Files
mono/shared/infrastructure/rate_limit.py
giles c015f3f02f Security audit: fix IDOR, add rate limiting, HMAC auth, token hashing, XSS sanitization
Critical: Add ownership checks to all order routes (IDOR fix).
High: Redis rate limiting on auth endpoints, HMAC-signed internal
service calls replacing header-presence-only checks, nh3 HTML
sanitization on ghost_sync and product import, internal auth on
market API endpoints, SHA-256 hashed OAuth grant/code tokens.
Medium: SECRET_KEY production guard, AP signature enforcement,
is_admin param removal, cart_sid validation, SSRF protection on
remote actor fetch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 13:30:27 +00:00

143 lines
4.0 KiB
Python

"""Redis-based rate limiter for auth endpoints.
Provides a decorator that enforces per-key rate limits using a sliding
window counter stored in Redis (auth DB 15).
Usage::
from shared.infrastructure.rate_limit import rate_limit
@rate_limit(key_func=lambda: request.form.get("email", "").lower(),
max_requests=5, window_seconds=900, scope="magic_link")
@bp.post("/start/")
async def start_login():
...
"""
from __future__ import annotations
import functools
import time
from quart import request, jsonify, make_response
async def _check_rate_limit(
    key: str,
    max_requests: int,
    window_seconds: int,
) -> tuple[bool, int]:
    """Record one request for *key* and report whether it fits the window.

    Implements a sliding-window log. Each request is stored as a member of
    the Redis sorted set ``rl:<key>``, scored by its Unix timestamp. Entries
    at or before ``now - window_seconds`` are pruned, the current request is
    added, and the members are counted, all sent in one pipeline.

    The current request is recorded even when it ends up rejected. A client
    that keeps retrying therefore stays blocked until it stops for a full
    window.

    Returns
    -------
    (allowed, remaining)
        ``allowed`` is True when the count (including this request) is at
        most ``max_requests``. ``remaining`` is the number of further
        requests that fit in the current window, never negative.
    """
    import secrets

    from shared.infrastructure.auth_redis import get_auth_redis

    r = await get_auth_redis()
    now = time.time()
    window_start = now - window_seconds
    redis_key = f"rl:{key}"
    # Sorted-set members must be unique. Two requests that observe the same
    # time.time() value (coarse clocks, several workers) would otherwise
    # collapse into one member and be counted only once.
    member = f"{now}:{secrets.token_hex(8)}".encode()

    pipe = r.pipeline()
    # Drop entries that have fallen out of the window.
    pipe.zremrangebyscore(redis_key, 0, window_start)
    # Record this request.
    pipe.zadd(redis_key, {member: now})
    # Count requests still inside the window (including this one).
    pipe.zcard(redis_key)
    # Let idle keys expire on their own once the window has passed.
    pipe.expire(redis_key, window_seconds)
    _, _, count, _ = await pipe.execute()

    allowed = count <= max_requests
    remaining = max(0, max_requests - count)
    return allowed, remaining
def rate_limit(
    *,
    key_func,
    max_requests: int,
    window_seconds: int,
    scope: str,
):
    """Decorator that rate-limits a Quart route.

    Apply it *below* the route decorator (``@bp.post(...)`` goes on top).
    A route decorator registers the function it receives, so a
    ``@rate_limit`` placed above it wraps a function the router never calls.

    Parameters
    ----------
    key_func:
        Callable returning the rate-limit key (e.g. email, IP). It is called
        inside the request context. It may be a coroutine function, or may
        return an awaitable, which allows reading awaitable request data
        such as Quart's ``request.form``. A falsy key skips rate limiting.
    max_requests:
        Maximum number of requests allowed in the window.
    window_seconds:
        Sliding window duration in seconds.
    scope:
        Namespace prefix for the Redis key (e.g. "magic_link").

    When the limit is exceeded, the wrapper returns a JSON 429 response with
    ``Retry-After`` set to ``window_seconds`` (an upper bound on the wait).
    If the rate-limit check itself fails (e.g. Redis is unreachable), the
    request is allowed through and the failure is logged.
    """
    import inspect
    import logging

    logger = logging.getLogger(__name__)

    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            raw_key = key_func()
            # Support async key functions (e.g. ones that await request.form).
            if inspect.isawaitable(raw_key):
                raw_key = await raw_key
            if not raw_key:
                return await fn(*args, **kwargs)
            full_key = f"{scope}:{raw_key}"
            try:
                allowed, _remaining = await _check_rate_limit(
                    full_key, max_requests, window_seconds,
                )
            except Exception:
                # Fail open deliberately: a limiter outage must not take the
                # endpoint down. Log it so the outage does not go unnoticed.
                # Only the scope is logged; the key may be personal data.
                logger.warning(
                    "rate limit check failed for scope %r; allowing request",
                    scope,
                    exc_info=True,
                )
                allowed = True
            if not allowed:
                resp = await make_response(
                    jsonify({"error": "rate_limited", "retry_after": window_seconds}),
                    429,
                )
                resp.headers["Retry-After"] = str(window_seconds)
                return resp
            return await fn(*args, **kwargs)
        return wrapper
    return decorator
async def check_poll_backoff(
    device_code: str,
    *,
    initial_interval: int = 5,
    max_interval: int = 60,
    ttl_seconds: int = 900,
) -> tuple[bool, int]:
    """Enforce backoff on device token polling.

    Poll state lives in Redis under ``rl:devpoll:<device_code>`` as
    ``"<last_poll_timestamp>:<interval>"``. A poll that arrives less than
    ``interval`` seconds after the previous one is rejected. The interval
    then grows by 5 seconds, capped at ``max_interval``, and the larger
    value is kept for later polls. The caller should answer a rejected poll
    with a 'slow_down' error per RFC 8628.

    Parameters
    ----------
    device_code:
        The device code being polled.
    initial_interval:
        Interval used from the first poll. It should match the ``interval``
        advertised in the device authorization response.
    max_interval:
        Upper bound on the interval after repeated slow_down responses. An
        interval already above it is never reduced.
    ttl_seconds:
        Expiry of the stored poll state. It is refreshed on every poll.

    Returns
    -------
    (allowed, interval)
        ``interval`` is the recommended poll interval in seconds. If
        ``allowed`` is False, the caller should return 'slow_down'.
    """
    from shared.infrastructure.auth_redis import get_auth_redis

    # Increment applied on each too-fast poll.
    slow_down_step = 5

    r = await get_auth_redis()
    key = f"rl:devpoll:{device_code}"
    now = time.time()

    raw = await r.get(key)
    if not raw:
        # First poll for this device code.
        await r.set(key, f"{now}:{initial_interval}".encode(), ex=ttl_seconds)
        return True, initial_interval

    data = raw.decode() if isinstance(raw, bytes) else raw
    parts = data.split(":")
    last_poll = float(parts[0])
    interval = int(parts[1])

    allowed = now - last_poll >= interval
    if not allowed:
        # Too fast: increase the interval, but never shrink one that is
        # already above the cap.
        interval = min(interval + slow_down_step, max(max_interval, interval))
    # Record this poll (accepted or not) so the next check measures from it.
    await r.set(key, f"{now}:{interval}".encode(), ex=ttl_seconds)
    return allowed, interval