"""Redis-backed page caching helpers for a Quart application.

Provides:

* ``register(app)``: opens/closes the Redis client around the serving
  lifecycle and stores it on ``app.redis``.
* ``cache_page``: decorator that caches GET responses in a Redis hash.
* ``clear_cache``: decorator that invalidates cached pages after a view
  has run successfully.
* ``invalidate_*`` / ``clear_all_cache``: explicit invalidation helpers.

Every helper is a no-op when Redis is disabled (``app.redis`` is falsy).
"""
from __future__ import annotations

import asyncio
from functools import wraps
from typing import Literal, Optional

from quart import (
    Quart,
    request,
    Response,
    g,
    current_app,
)
from redis import asyncio as aioredis

Scope = Literal["user", "global", "anon"]
TagScope = Literal["all", "user"]  # for clear_cache


# ---------------------------------------------------------------------------
# Redis setup
# ---------------------------------------------------------------------------

def register(app: Quart) -> None:
    """Install the Redis startup/shutdown hooks on *app*.

    On startup ``app.redis`` is set to an asyncio Redis client built from
    ``app.config["REDIS_URL"]``. If that setting is missing, empty, or the
    literal string ``'no'``, ``app.redis`` is set to ``False`` and caching
    is disabled.
    """

    @app.before_serving
    async def setup_redis() -> None:
        # .get(): an unset REDIS_URL means "caching disabled", not a crash.
        redis_url = app.config.get("REDIS_URL")
        if redis_url and redis_url != "no":
            app.redis = aioredis.Redis.from_url(
                redis_url,
                encoding="utf-8",
                decode_responses=False,  # store bytes
            )
        else:
            app.redis = False

    @app.after_serving
    async def close_redis() -> None:
        # getattr: shutdown may run even if startup never assigned app.redis.
        client = getattr(app, "redis", None)
        if not client:
            return
        # Prefer aclose() where the client provides it (newer redis-py
        # releases deprecate close() on the asyncio client); fall back to
        # close() otherwise.
        closer = getattr(client, "aclose", None) or client.close
        await closer()
        # optional: await app.redis.connection_pool.disconnect()


def get_redis():
    """Return the current app's Redis client, or ``False`` if caching is off.

    Also returns ``False`` when ``register()`` was never applied to the app,
    so callers only need a truthiness check.
    """
    return getattr(current_app, "redis", False)


# ---------------------------------------------------------------------------
# Key helpers
# ---------------------------------------------------------------------------

def get_user_id() -> str:
    """
    Returns a string id or 'anon'.
    Adjust based on your auth system.

    Reads ``g.user``; when it is set and truthy, ``str(g.user.id)`` is
    returned.
    """
    user = getattr(g, "user", None)
    if user:
        return str(user.id)
    return "anon"


def _page_key_prefix(cache_user_id: str) -> str:
    """Return the common prefix of every page key owned by *cache_user_id*.

    Single source of truth for the key layout so that key creation
    (``make_cache_key``) and prefix-based invalidation
    (``invalidate_tag_cache_for_user``) cannot drift apart.
    """
    app_prefix = current_app.config.get("CACHE_APP_PREFIX", "app")
    return f"cache:{app_prefix}:page:{cache_user_id}:"


def make_cache_key(cache_user_id: str) -> str:
    """
    Build a cache key for this (user/global/anon) + path + query + HTMX status.

    HTMX requests and normal requests get different cache keys because they
    return different content (partials vs full pages).

    Keys are namespaced by app name (from CACHE_APP_PREFIX) to avoid
    collisions between apps that may share the same paths.

    Layout: ``cache:{app_prefix}:page:{cache_user_id}:{path}[?{qs}][:htmx]``
    """
    key = f"{_page_key_prefix(cache_user_id)}{request.path}"

    qs = request.query_string.decode() if request.query_string else ""
    if qs:
        key = f"{key}?{qs}"

    # Check if this is an HTMX request
    is_htmx = request.headers.get("HX-Request", "").lower() == "true"
    if is_htmx:
        key = f"{key}:htmx"

    return key


def user_set_key(user_id: str) -> str:
    """
    Redis set that tracks all cache keys for a given user id.
    Only used when scope='user'.
    """
    # NOTE(review): unlike page keys, this key is not namespaced by
    # CACHE_APP_PREFIX; apps sharing one Redis database share these sets.
    return f"cache:user:{user_id}"


def tag_set_key(tag: str) -> str:
    """
    Redis set that tracks all cache keys associated with a tag
    (across all scopes/users).
    """
    # NOTE(review): not namespaced by CACHE_APP_PREFIX (see user_set_key).
    return f"cache:tag:{tag}"


# ---------------------------------------------------------------------------
# Invalidation helpers
# ---------------------------------------------------------------------------

async def invalidate_user_cache(user_id: str) -> None:
    """
    Delete all cached pages for a specific user (scope='user' caches).

    Removes every page key recorded in the user's tracking set, and the
    tracking set itself. No-op when Redis is disabled.
    """
    r = get_redis()
    if not r:
        return

    s_key = user_set_key(user_id)
    keys = await r.smembers(s_key)  # set of bytes
    if keys:
        # An empty SMEMBERS result means the set does not exist, so there is
        # nothing to remove; otherwise drop pages and the set in one call.
        await r.delete(*keys, s_key)


async def invalidate_tag_cache(tag: str) -> None:
    """
    Delete all cached pages associated with this tag, for all users/scopes.

    Removes every page key recorded in the tag's tracking set, and the
    tracking set itself. No-op when Redis is disabled.
    """
    r = get_redis()
    if not r:
        return

    t_key = tag_set_key(tag)
    keys = await r.smembers(t_key)  # set of bytes
    if keys:
        await r.delete(*keys, t_key)


async def invalidate_tag_cache_for_user(tag: str, cache_uid: str) -> None:
    """
    Delete the cached pages tagged *tag* that belong to *cache_uid* only.

    Other users' entries stay in the tag set and in the cache.
    No-op when Redis is disabled.
    """
    r = get_redis()
    if not r:
        return

    t_key = tag_set_key(tag)
    keys = await r.smembers(t_key)  # set of bytes
    if not keys:
        return

    # Must match the layout produced by make_cache_key(), including the app
    # prefix; otherwise nothing ever matches and nothing is invalidated.
    prefix = _page_key_prefix(cache_uid).encode("utf-8")

    # Filter keys belonging to this cache_uid only
    to_delete = [k for k in keys if k.startswith(prefix)]
    if not to_delete:
        return

    # Delete those page entries
    await r.delete(*to_delete)

    # Remove them from the tag set (leave other users' keys intact)
    await r.srem(t_key, *to_delete)


async def invalidate_tag_cache_for_current_user(tag: str) -> None:
    """
    Convenience helper: delete tag cache for the current user_id (scope='user').
    """
    uid = get_user_id()
    await invalidate_tag_cache_for_user(tag, uid)


# ---------------------------------------------------------------------------
# Cache decorator for GET
# ---------------------------------------------------------------------------

def cache_page(
    ttl: int = 0,
    tag: Optional[str] = None,
    scope: Scope = "user",
):
    """
    Cache GET responses in Redis.

    ttl:
        Seconds to keep the cache. 0 = no expiry.

    tag:
        Optional tag name used for bulk invalidation via invalidate_tag_cache().

    scope:
        "user"   → cache per-user (includes 'anon'), tracked in cache:user:{id}
        "global" → single cache shared by everyone (no per-user tracking)
        "anon"   → cache only for anonymous users; logged-in users bypass cache

    Only responses with status 200 are stored. The stored entry holds the
    body, status and content type; a cache hit is rebuilt from those three
    fields only, so any other headers set by the view are not replayed.
    Non-GET requests, and all requests when Redis is disabled, go straight
    to the view.
    """

    def decorator(view):
        @wraps(view)
        async def wrapper(*args, **kwargs):
            r = get_redis()
            if not r or request.method != "GET":
                return await view(*args, **kwargs)

            uid = get_user_id()

            # Decide who the cache key is keyed on
            if scope == "global":
                cache_uid = "global"
            elif scope == "anon":
                # Only cache for anonymous users
                if uid != "anon":
                    return await view(*args, **kwargs)
                cache_uid = "anon"
            else:  # scope == "user"
                cache_uid = uid

            key = make_cache_key(cache_uid)

            cached = await r.hgetall(key)
            if cached:
                # The client stores bytes, so hash fields come back as bytes.
                body = cached[b"body"]
                status = int(cached[b"status"].decode())
                content_type = cached.get(b"content_type", b"text/html").decode()
                return Response(body, status=status, content_type=content_type)

            # Not cached, call the view
            resp = await view(*args, **kwargs)

            # Normalise the view's return value into a Response.
            if not isinstance(resp, Response):
                if isinstance(resp, (str, bytes)):
                    resp = Response(resp, content_type="text/html")
                else:
                    # Tuples such as (body, status), dicts, etc.: let the app
                    # build the response so the status/headers are honoured
                    # instead of being treated as an HTML body.
                    resp = await current_app.make_response(resp)
                    if not isinstance(resp, Response):
                        # Not a Quart Response (no awaitable get_data());
                        # pass it through uncached.
                        return resp

            # Only cache successful responses
            if resp.status_code == 200:
                body = await resp.get_data()  # bytes

                pipe = r.pipeline()
                pipe.hset(
                    key,
                    mapping={
                        "body": body,
                        "status": str(resp.status_code),
                        "content_type": resp.content_type or "text/html",
                    },
                )
                if ttl:
                    pipe.expire(key, ttl)

                # Track per-user keys only when scope='user'
                if scope == "user":
                    pipe.sadd(user_set_key(cache_uid), key)

                # Track per-tag keys (all scopes)
                if tag:
                    pipe.sadd(tag_set_key(tag), key)

                await pipe.execute()

                # Put the body back on the response after reading it out.
                resp.set_data(body)

            return resp

        return wrapper

    return decorator


# ---------------------------------------------------------------------------
# Clear cache decorator for POST (or any method)
# ---------------------------------------------------------------------------

def _result_status(result) -> Optional[int]:
    """Best-effort HTTP status of a view's return value, or None if unknown.

    Handles response objects (``status_code`` attribute) and tuples whose
    second item is an integer status, e.g. ``(body, 400)``.
    """
    status = getattr(result, "status_code", None)
    if (
        status is None
        and isinstance(result, tuple)
        and len(result) >= 2
        and isinstance(result[1], int)
    ):
        status = result[1]
    return status


def clear_cache(
    *,
    tag: Optional[str] = None,
    tag_scope: TagScope = "all",
    clear_user: bool = False,
):
    """
    Decorator for routes that should clear cache after they run.

    Use on POST/PUT/PATCH/DELETE handlers.

    Params:
        tag:
            If set, will clear caches for this tag.
        tag_scope:
            "all"  → invalidate_tag_cache(tag) (all users/scopes)
            "user" → invalidate_tag_cache_for_current_user(tag)
        clear_user:
            If True, also run invalidate_user_cache(current_user_id).

    Invalidation runs only after the view returned a 2xx result (or a
    result with no discernible status, which is treated as success).
    The view's return value is always passed through unchanged.

    Typical usage:

        @bp.post("/posts/<slug>/edit")
        @clear_cache(tag="post.post_detail", tag_scope="all")
        async def edit_post(slug):
            ...

        @bp.post("/prefs")
        @clear_cache(tag="dashboard", tag_scope="user", clear_user=True)
        async def update_prefs():
            ...
    """

    def decorator(view):
        @wraps(view)
        async def wrapper(*args, **kwargs):
            # Run the view first
            resp = await view(*args, **kwargs)

            if not get_redis():
                return resp

            # Only clear cache if the view succeeded (2xx).
            # Unknown status (plain string, dict, ...) -> treat as success.
            status = _result_status(resp)
            if status is not None and not 200 <= status < 300:
                return resp

            # Perform invalidations
            tasks = []

            if clear_user:
                uid = get_user_id()
                tasks.append(invalidate_user_cache(uid))

            if tag:
                if tag_scope == "all":
                    tasks.append(invalidate_tag_cache(tag))
                else:  # tag_scope == "user"
                    tasks.append(invalidate_tag_cache_for_current_user(tag))

            if tasks:
                # Run them concurrently
                await asyncio.gather(*tasks)

            return resp

        return wrapper

    return decorator


async def clear_all_cache(prefix: str = "cache:") -> None:
    """Delete every Redis key whose name starts with *prefix*.

    Walks the keyspace with SCAN in batches and deletes each batch of
    matches. *prefix* is inserted into a glob pattern as-is, so glob
    metacharacters in it are interpreted by Redis. No-op when Redis is
    disabled.
    """
    r = get_redis()
    if not r:
        return

    cursor = 0
    pattern = f"{prefix}*"
    while True:
        cursor, keys = await r.scan(cursor=cursor, match=pattern, count=500)
        if keys:
            await r.delete(*keys)
        # SCAN signals the end of the iteration by returning cursor 0.
        if cursor == 0:
            break