diff --git a/.gitignore b/.gitignore index ba9db1a..34abdb8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.pyc dist/ +*.egg-info/ diff --git a/rrl.py b/rrl.py index 7a64629..ae70243 100644 --- a/rrl.py +++ b/rrl.py @@ -52,12 +52,15 @@ class RateLimiter: self.track_daily_usage = track_daily_usage def check_limit(self, zone: str, key: str, tier_name: str) -> bool: + try: + tier = self.tiers[tier_name] + except KeyError: + raise ValueError(f"unknown tier: {tier_name}") if self.use_redis_time: timestamp = self.redis.time()[0] now = datetime.datetime.fromtimestamp(timestamp) else: now = datetime.datetime.utcnow() - tier = self.tiers[tier_name] pipe = self.redis.pipeline() if tier.per_minute: diff --git a/test_ratelimit.py b/test_ratelimit.py index 8faf29a..84de977 100644 --- a/test_ratelimit.py +++ b/test_ratelimit.py @@ -58,6 +58,14 @@ def test_using_redis_time(): assert count == 10 +def test_invalid_tier(): + redis.flushall() + rl = RateLimiter(tiers=[simple_daily_tier], use_redis_time=True) + + with pytest.raises(ValueError): + rl.check_limit("test-zone", "test-key", "non-existent-tier") + + def test_multiple_zones(): redis.flushall() rl = RateLimiter(tiers=[simple_daily_tier], use_redis_time=True)