rrl/rrl.py
2020-11-19 15:29:56 -05:00

127 lines
3.9 KiB
Python

import os
import datetime
import typing
from dataclasses import dataclass
from redis import Redis
@dataclass
class Tier:
name: str
per_minute: int
per_hour: int
per_day: int
@dataclass
class DailyUsage:
date: datetime.date
calls: int
class RateLimitExceeded(Exception):
pass
def _get_redis_connection() -> Redis:
host = os.environ.get("RRL_REDIS_HOST", "localhost")
port = int(os.environ.get("RRL_REDIS_PORT", 6379))
db = int(os.environ.get("RRL_REDIS_DB", 0))
return Redis(host=host, port=port, db=db)
class RateLimiter:
"""
<prefix>:<key>:<hour><minute> expires in 2 minutes
<prefix>:<key>:<hour> expires in 2 hours
<prefix>:<key>:<day> never expires
"""
def __init__(
self,
tiers: typing.List[Tier],
*,
prefix: str = "",
use_redis_time: bool = True,
track_daily_usage: bool = True,
):
self.redis = _get_redis_connection()
self.tiers = {tier.name: tier for tier in tiers}
self.prefix = prefix
self.use_redis_time = use_redis_time
self.track_daily_usage = track_daily_usage
def check_limit(self, key: str, tier_name: str) -> bool:
try:
tier = self.tiers[tier_name]
except KeyError:
raise ValueError(f"unknown tier: {tier_name}")
if self.use_redis_time:
timestamp = self.redis.time()[0]
now = datetime.datetime.fromtimestamp(timestamp)
else:
now = datetime.datetime.utcnow()
pipe = self.redis.pipeline()
if tier.per_minute:
minute_key = f"{self.prefix}:{key}:m{now.minute}"
pipe.incr(minute_key)
pipe.expire(minute_key, 60)
if tier.per_hour:
hour_key = f"{self.prefix}:{key}:h{now.hour}"
pipe.incr(hour_key)
pipe.expire(hour_key, 3600)
if tier.per_day or self.track_daily_usage:
day = now.strftime("%Y%m%d")
day_key = f"{self.prefix}:{key}:d{day}"
pipe.incr(day_key)
# keep data around for usage tracking
if not self.track_daily_usage:
pipe.expire(day_key, 86400)
result = pipe.execute()
# the result is pairs of results of incr and expire calls, so if all 3 limits are set
# it looks like [per_minute_calls, True, per_hour_calls, True, per_day_calls]
# we increment value_pos as we consume values so we know which location we're looking at
value_pos = 0
if tier.per_minute:
if result[value_pos] > tier.per_minute:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_minute}/min: {result[value_pos]}"
)
value_pos += 2
if tier.per_hour:
if result[value_pos] > tier.per_hour:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_hour}/hour: {result[value_pos]}"
)
value_pos += 2
if tier.per_day:
if result[value_pos] > tier.per_day:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_day}/day: {result[value_pos]}"
)
return True
def get_usage_since(
self,
key: str,
start: datetime.date,
end: typing.Optional[datetime.date] = None,
) -> typing.List[DailyUsage]:
if not self.track_daily_usage:
raise RuntimeError("track_daily_usage is not enabled")
if not end:
end = datetime.date.today()
days = []
day = start
while day <= end:
days.append(day)
day += datetime.timedelta(days=1)
day_keys = [f"{self.prefix}:{key}:d{day.strftime('%Y%m%d')}" for day in days]
return [
DailyUsage(d, int(calls.decode()) if calls else 0)
for d, calls in zip(days, self.redis.mget(day_keys))
]