diff --git a/ratelimit.py b/ratelimit.py index 4bd3f75..6603833 100644 --- a/ratelimit.py +++ b/ratelimit.py @@ -1,3 +1,4 @@ +import os from redis import Redis import datetime import typing @@ -16,6 +17,13 @@ class RateLimitExceeded(Exception): pass +def _get_redis_connection(): + host = os.environ.get("RRL_REDIS_HOST", "localhost") + port = os.environ.get("RRL_REDIS_PORT", 6379) + db = os.environ.get("RRL_REDIS_DB", 0) + return Redis(host=host, port=port, db=db) + + class RateLimiter: """ :: expires in 2 minutes @@ -24,7 +32,7 @@ class RateLimiter: """ def __init__(self, tiers: typing.List[Tier], *, prefix="", use_redis_time=True): - self.redis = Redis() + self.redis = _get_redis_connection() self.tiers = {tier.name: tier for tier in tiers} self.prefix = prefix self.use_redis_time = use_redis_time diff --git a/test_ratelimit.py b/test_ratelimit.py index 9aaa4d5..2b5b83e 100644 --- a/test_ratelimit.py +++ b/test_ratelimit.py @@ -1,9 +1,8 @@ import pytest -from ratelimit import Tier, RateLimiter, RateLimitExceeded +from ratelimit import Tier, RateLimiter, RateLimitExceeded, _get_redis_connection from freezegun import freeze_time -from redis import Redis -redis = Redis() +redis = _get_redis_connection() simple_minute_tier = Tier("10/minute", 10, 0, 0) simple_hour_tier = Tier("10/hour", 0, 10, 0) simple_daily_tier = Tier("10/day", 0, 0, 10)