"""Redis-based rate limiting for auth endpoints.

Provides a decorator that enforces per-key rate limits using a sliding-window
log (a Redis sorted set of request timestamps) stored in the auth Redis
(auth DB 15), plus a helper that paces RFC 8628 device-token polling.

Usage::

    from shared.infrastructure.rate_limit import rate_limit

    async def _email_key():
        # Quart's ``request.form`` is awaitable, so this key function is async.
        form = await request.form
        return form.get("email", "").lower()

    @bp.post("/start/")
    @rate_limit(key_func=_email_key, max_requests=5,
                window_seconds=900, scope="magic_link")
    async def start_login():
        ...

``@rate_limit`` must sit *below* the route decorator. Quart's route
decorators register the function they receive, so a ``@rate_limit`` placed
above ``@bp.post`` would wrap a function that is never routed to.
"""
from __future__ import annotations

import functools
import inspect
import logging
import secrets
import time

from quart import request, jsonify, make_response

logger = logging.getLogger(__name__)

# Device-poll pacing (RFC 8628 section 3.5: each ``slow_down`` adds 5 seconds).
_POLL_INITIAL_INTERVAL = 5
_POLL_INTERVAL_STEP = 5
_POLL_MAX_INTERVAL = 60
# Poll state expires after this many seconds without a poll.
_POLL_STATE_TTL = 900


async def _check_rate_limit(
    key: str,
    max_requests: int,
    window_seconds: int,
) -> tuple[bool, int]:
    """Record one request for *key* and report whether it is within the limit.

    Uses a sorted set whose members are individual requests scored by their
    timestamp. Entries older than the window are trimmed, the current request
    is added, and the remaining members are counted. The current request is
    recorded even when it ends up rejected, so a client that keeps retrying
    stays limited until it backs off for a full window.

    Returns ``(allowed, remaining)`` where *remaining* is how many more
    requests fit in the current window (never negative).
    """
    from shared.infrastructure.auth_redis import get_auth_redis

    r = await get_auth_redis()
    now = time.time()
    window_start = now - window_seconds
    redis_key = f"rl:{key}"
    # The member must be unique per request. Keyed on the timestamp alone,
    # two requests landing on the same clock tick would collapse into one
    # sorted-set entry and be under-counted.
    member = f"{now}:{secrets.token_hex(8)}".encode()

    pipe = r.pipeline()
    pipe.zremrangebyscore(redis_key, 0, window_start)  # drop expired entries
    pipe.zadd(redis_key, {member: now})  # record this request
    pipe.zcard(redis_key)  # count entries in the window
    pipe.expire(redis_key, window_seconds)  # let idle keys auto-expire
    results = await pipe.execute()

    count = results[2]
    allowed = count <= max_requests
    remaining = max(0, max_requests - count)
    return allowed, remaining


def rate_limit(
    *,
    key_func,
    max_requests: int,
    window_seconds: int,
    scope: str,
):
    """Decorator that rate-limits a Quart route.

    Parameters
    ----------
    key_func:
        Callable returning the rate-limit key (e.g. email, IP). Called with
        no arguments inside the request context. It may be a coroutine
        function (needed to read Quart's awaitable ``request.form``); an
        awaitable result is awaited. A falsy key skips rate limiting for
        that request.
    max_requests:
        Maximum number of requests allowed in the window.
    window_seconds:
        Sliding window duration in seconds. It is also sent as
        ``Retry-After`` on 429 responses. That value is an upper bound: after
        that long with no further requests, the window is empty again.
    scope:
        Namespace prefix for the Redis key (e.g. "magic_link").

    Rejected requests receive a 429 response with the JSON body
    ``{"error": "rate_limited", "retry_after": window_seconds}``. If the
    Redis check itself raises, the request is allowed (fail-open) and a
    warning is logged.

    Apply it below the route decorator so that the wrapped function is the
    one that gets registered.
    """

    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            raw_key = key_func()
            if inspect.isawaitable(raw_key):
                raw_key = await raw_key
            if not raw_key:
                return await fn(*args, **kwargs)

            full_key = f"{scope}:{raw_key}"
            try:
                allowed, _remaining = await _check_rate_limit(
                    full_key,
                    max_requests,
                    window_seconds,
                )
            except Exception:
                # Fail open: if Redis is unavailable, let the request through
                # rather than block it. Log the scope only; the raw key may be
                # an email address.
                logger.warning(
                    "rate limit check failed for scope %r; allowing request",
                    scope,
                    exc_info=True,
                )
                allowed = True

            if not allowed:
                resp = await make_response(
                    jsonify({"error": "rate_limited", "retry_after": window_seconds}),
                    429,
                )
                resp.headers["Retry-After"] = str(window_seconds)
                return resp

            return await fn(*args, **kwargs)

        return wrapper

    return decorator


def _parse_poll_state(raw) -> tuple[float, int] | None:
    """Decode a stored ``"<last_poll>:<interval>"`` value.

    Returns ``(last_poll, interval)``, or ``None`` when *raw* is empty or
    cannot be parsed. The caller then treats the poll as the first one.
    """
    if not raw:
        return None
    data = raw.decode() if isinstance(raw, bytes) else raw
    last_poll, _, interval = data.partition(":")
    try:
        return float(last_poll), int(interval)
    except ValueError:
        return None


async def check_poll_backoff(device_code: str) -> tuple[bool, int]:
    """Pace device-token polling per RFC 8628 ``slow_down`` semantics.

    Each poll is compared against the previous one for *device_code*:

    - A poll arriving sooner than the current interval is rejected, and the
      interval grows by 5 seconds (capped at 60).
    - Otherwise the interval is kept.
    - The first poll starts at 5 seconds.

    The stored state is refreshed on every poll and expires after 900
    seconds without one.

    Returns ``(allowed, interval)`` where *interval* is the recommended poll
    interval in seconds. If not allowed, the caller should return a
    ``slow_down`` error per RFC 8628.
    """
    from shared.infrastructure.auth_redis import get_auth_redis

    r = await get_auth_redis()
    key = f"rl:devpoll:{device_code}"
    now = time.time()

    allowed = True
    interval = _POLL_INITIAL_INTERVAL
    state = _parse_poll_state(await r.get(key))
    if state is not None:
        last_poll, interval = state
        if now - last_poll < interval:
            # Too fast: reject this poll and widen the interval.
            allowed = False
            interval = min(interval + _POLL_INTERVAL_STEP, _POLL_MAX_INTERVAL)

    await r.set(key, f"{now}:{interval}".encode(), ex=_POLL_STATE_TTL)
    return allowed, interval