diff --git a/rrl.py b/rrl.py index 6603833..2d67d99 100644 --- a/rrl.py +++ b/rrl.py @@ -17,10 +17,10 @@ class RateLimitExceeded(Exception): pass -def _get_redis_connection(): +def _get_redis_connection() -> Redis: host = os.environ.get("RRL_REDIS_HOST", "localhost") - port = os.environ.get("RRL_REDIS_PORT", 6379) - db = os.environ.get("RRL_REDIS_DB", 0) + port = int(os.environ.get("RRL_REDIS_PORT", 6379)) + db = int(os.environ.get("RRL_REDIS_DB", 0)) return Redis(host=host, port=port, db=db) @@ -31,13 +31,15 @@ class RateLimiter: :: never expires """ - def __init__(self, tiers: typing.List[Tier], *, prefix="", use_redis_time=True): + def __init__( + self, tiers: typing.List[Tier], *, prefix: str = "", use_redis_time: bool = True + ): self.redis = _get_redis_connection() self.tiers = {tier.name: tier for tier in tiers} self.prefix = prefix self.use_redis_time = use_redis_time - def check_limit(self, zone: str, key: str, tier_name: str): + def check_limit(self, zone: str, key: str, tier_name: str) -> bool: if self.use_redis_time: timestamp = self.redis.time()[0] now = datetime.datetime.fromtimestamp(timestamp)