from __future__ import annotations import secrets from typing import Callable, Awaitable, Optional from quart import ( abort, current_app, request, session as qsession, ) SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"} def generate_csrf_token() -> str: """ Per-session CSRF token. In Jinja: """ token = qsession.get("csrf_token") if not token: token = secrets.token_urlsafe(32) qsession["csrf_token"] = token return token def _is_exempt_endpoint() -> bool: endpoint = request.endpoint if not endpoint: return False view = current_app.view_functions.get(endpoint) # Walk decorator stack (__wrapped__) to find csrf_exempt while view is not None: if getattr(view, "_csrf_exempt", False): return True view = getattr(view, "__wrapped__", None) return False async def protect() -> None: """ Enforce CSRF on unsafe methods. Supports: * Forms: hidden input "csrf_token" * JSON: "csrf_token" or "csrfToken" field * HTMX/AJAX: "X-CSRFToken" or "X-CSRF-Token" header """ if request.method in SAFE_METHODS: return if _is_exempt_endpoint(): return # Internal service-to-service calls are already gated by header checks # and only reachable on the Docker overlay network. if request.headers.get("X-Internal-Action") or request.headers.get("X-Internal-Data"): return session_token = qsession.get("csrf_token") if not session_token: abort(400, "Missing CSRF session token") supplied_token: Optional[str] = None # JSON body if request.mimetype == "application/json": data = await request.get_json(silent=True) or {} supplied_token = data.get("csrf_token") or data.get("csrfToken") # Form body if not supplied_token and request.mimetype != "application/json": form = await request.form supplied_token = form.get("csrf_token") # Headers (HTMX / fetch) if not supplied_token: supplied_token = ( request.headers.get("X-CSRFToken") or request.headers.get("X-CSRF-Token") ) if not supplied_token or supplied_token != session_token: abort(400, "Invalid CSRF token") def csrf_exempt(view: Callable[..., Awaitable]) -> Callable[..., Awaitable]: """ Mark a view as CSRF-exempt. from suma_browser.app.csrf import csrf_exempt @csrf_exempt @blueprint.post("/hook") async def webhook(): ... """ setattr(view, "_csrf_exempt", True) return view