add get_usage_since
This commit is contained in:
parent
944fa00de7
commit
b82c0af81e
38
rrl.py
38
rrl.py
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
from redis import Redis
|
|
||||||
import datetime
|
import datetime
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from redis import Redis
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -32,12 +32,18 @@ class RateLimiter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, tiers: typing.List[Tier], *, prefix: str = "", use_redis_time: bool = True
|
self,
|
||||||
|
tiers: typing.List[Tier],
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
use_redis_time: bool = True,
|
||||||
|
track_daily_usage: bool = True,
|
||||||
):
|
):
|
||||||
self.redis = _get_redis_connection()
|
self.redis = _get_redis_connection()
|
||||||
self.tiers = {tier.name: tier for tier in tiers}
|
self.tiers = {tier.name: tier for tier in tiers}
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.use_redis_time = use_redis_time
|
self.use_redis_time = use_redis_time
|
||||||
|
self.track_daily_usage = track_daily_usage
|
||||||
|
|
||||||
def check_limit(self, zone: str, key: str, tier_name: str) -> bool:
|
def check_limit(self, zone: str, key: str, tier_name: str) -> bool:
|
||||||
if self.use_redis_time:
|
if self.use_redis_time:
|
||||||
@ -56,11 +62,13 @@ class RateLimiter:
|
|||||||
hour_key = f"{self.prefix}:{zone}:{key}:h{now.hour}"
|
hour_key = f"{self.prefix}:{zone}:{key}:h{now.hour}"
|
||||||
pipe.incr(hour_key)
|
pipe.incr(hour_key)
|
||||||
pipe.expire(hour_key, 3600)
|
pipe.expire(hour_key, 3600)
|
||||||
if tier.per_day:
|
if tier.per_day or self.track_daily_usage:
|
||||||
day = now.strftime("%Y%m%d")
|
day = now.strftime("%Y%m%d")
|
||||||
day_key = f"{self.prefix}:{zone}:{key}:d{day}"
|
day_key = f"{self.prefix}:{zone}:{key}:d{day}"
|
||||||
pipe.incr(day_key)
|
pipe.incr(day_key)
|
||||||
# do not expire day keys for now, useful for metrics
|
# keep data around for usage tracking
|
||||||
|
if not self.track_daily_usage:
|
||||||
|
pipe.expire(day_key, 86400)
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
# the result is pairs of results of incr and expire calls, so if all 3 limits are set
|
# the result is pairs of results of incr and expire calls, so if all 3 limits are set
|
||||||
@ -86,3 +94,25 @@ class RateLimiter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_usage_since(
|
||||||
|
self,
|
||||||
|
zone: str,
|
||||||
|
key: str,
|
||||||
|
start: datetime.date,
|
||||||
|
end: typing.Optional[datetime.date] = None,
|
||||||
|
) -> typing.List[typing.Dict[str, typing.Union[datetime.date, int]]]:
|
||||||
|
if not end:
|
||||||
|
end = datetime.date.today()
|
||||||
|
days = []
|
||||||
|
day = start
|
||||||
|
while day <= end:
|
||||||
|
days.append(day)
|
||||||
|
day += datetime.timedelta(days=1)
|
||||||
|
day_keys = [
|
||||||
|
f"{self.prefix}:{zone}:{key}:d{day.strftime('%Y%m%d')}" for day in days
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
{"date": d, "calls": int(calls.decode()) if calls else 0}
|
||||||
|
for d, calls in zip(days, self.redis.mget(day_keys))
|
||||||
|
]
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from rrl import Tier, RateLimiter, RateLimitExceeded, _get_redis_connection
|
from rrl import Tier, RateLimiter, RateLimitExceeded, _get_redis_connection
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
@ -8,6 +9,7 @@ simple_hour_tier = Tier("10/hour", 0, 10, 0)
|
|||||||
simple_daily_tier = Tier("10/day", 0, 0, 10)
|
simple_daily_tier = Tier("10/day", 0, 0, 10)
|
||||||
long_minute_short_hour_tier = Tier("long_min_short_hour", 100, 10, 0)
|
long_minute_short_hour_tier = Tier("long_min_short_hour", 100, 10, 0)
|
||||||
everything_set_short_day_tier = Tier("everything_set", 100, 100, 10)
|
everything_set_short_day_tier = Tier("everything_set", 100, 100, 10)
|
||||||
|
unlimited_tier = Tier("unlimited", 0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -51,7 +53,7 @@ def test_using_redis_time():
|
|||||||
try:
|
try:
|
||||||
rl.check_limit("test-zone", "test-key", simple_daily_tier.name)
|
rl.check_limit("test-zone", "test-key", simple_daily_tier.name)
|
||||||
count += 1
|
count += 1
|
||||||
except RateLimitExceeded as e:
|
except RateLimitExceeded:
|
||||||
break
|
break
|
||||||
assert count == 10
|
assert count == 10
|
||||||
|
|
||||||
@ -86,3 +88,25 @@ def test_multiple_keys():
|
|||||||
except RateLimitExceeded:
|
except RateLimitExceeded:
|
||||||
break
|
break
|
||||||
assert count == 10
|
assert count == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_daily_usage():
|
||||||
|
redis.flushall()
|
||||||
|
rl = RateLimiter(
|
||||||
|
tiers=[unlimited_tier], use_redis_time=False, track_daily_usage=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# make Nth day have N calls
|
||||||
|
for n in range(1, 10):
|
||||||
|
with freeze_time(f"2020-01-0{n}"):
|
||||||
|
for _ in range(n):
|
||||||
|
rl.check_limit("zone", "test-key", unlimited_tier.name)
|
||||||
|
|
||||||
|
with freeze_time("2020-01-15"):
|
||||||
|
usage = rl.get_usage_since("zone", "test-key", datetime.date(2020, 1, 1))
|
||||||
|
assert usage[0] == {"date": datetime.date(2020, 1, 1), "calls": 1}
|
||||||
|
assert usage[3] == {"date": datetime.date(2020, 1, 4), "calls": 4}
|
||||||
|
assert usage[8] == {"date": datetime.date(2020, 1, 9), "calls": 9}
|
||||||
|
assert usage[9] == {"date": datetime.date(2020, 1, 10), "calls": 0}
|
||||||
|
assert usage[14] == {"date": datetime.date(2020, 1, 15), "calls": 0}
|
||||||
|
assert len(usage) == 15
|
||||||
|
Loading…
Reference in New Issue
Block a user