Phase 1-3 of decoupling plan: - Shared DB, models, infrastructure, browser, config, utils - Event infrastructure (domain_events outbox, bus, processor) - Structured logging - Generic container concept (container_type/container_id) - Alembic migrations for all schema changes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
100 lines
2.4 KiB
Python
100 lines
2.4 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
from typing import Callable, Awaitable, Optional
|
|
|
|
from quart import (
|
|
abort,
|
|
current_app,
|
|
request,
|
|
session as qsession,
|
|
)
|
|
|
|
# HTTP methods treated as read-only ("safe" in RFC 9110 terms). protect()
# returns immediately for requests using any of these, so no token is required.
SAFE_METHODS = set("GET HEAD OPTIONS TRACE".split())
|
|
|
|
|
|
def generate_csrf_token() -> str:
    """
    Return the CSRF token bound to the current session, creating it on first use.

    The token is a 32-byte URL-safe random string stored under the
    ``"csrf_token"`` session key. A missing or empty stored value is replaced.

    In Jinja:

        <input type="hidden" name="csrf_token" value="{{ csrf_token() }}">
    """
    existing = qsession.get("csrf_token")
    if existing:
        return existing

    fresh = secrets.token_urlsafe(32)
    qsession["csrf_token"] = fresh
    return fresh
|
|
|
|
|
|
def _is_exempt_endpoint() -> bool:
    """
    Report whether the view serving the current request is CSRF-exempt.

    Looks up the view registered for ``request.endpoint`` and checks it, and
    every function reachable through its ``__wrapped__`` chain (as left by
    decorators such as those using ``functools.wraps``), for a truthy
    ``_csrf_exempt`` attribute.

    Returns False when the request has no endpoint or no view is registered
    for it.
    """
    endpoint_name = request.endpoint
    if not endpoint_name:
        return False

    def _decorator_layers(fn):
        # Yield the registered view, then each function it wraps, outermost first.
        while fn is not None:
            yield fn
            fn = getattr(fn, "__wrapped__", None)

    registered = current_app.view_functions.get(endpoint_name)
    return any(
        getattr(layer, "_csrf_exempt", False)
        for layer in _decorator_layers(registered)
    )
|
|
|
|
|
|
async def _read_supplied_token() -> object:
    """
    Return the CSRF token the client sent with the current request, or None.

    Lookup order:

    * Mimetype ``application/json``: the ``csrf_token`` (else ``csrfToken``)
      field of a top-level JSON object.
    * Any other mimetype: the ``csrf_token`` form field.
    * If the body yielded nothing: the ``X-CSRFToken`` (else ``X-CSRF-Token``)
      header.

    Typed as ``object`` because a JSON field can hold any JSON value; callers
    must confirm it is a ``str`` before comparing it.
    """
    supplied: object = None

    if request.mimetype == "application/json":
        data = await request.get_json(silent=True)
        # A valid JSON body need not be an object (e.g. a top-level array);
        # only a dict can carry a named token field.
        if isinstance(data, dict):
            supplied = data.get("csrf_token") or data.get("csrfToken")
    else:
        form = await request.form
        supplied = form.get("csrf_token")

    if not supplied:
        supplied = request.headers.get("X-CSRFToken") or request.headers.get(
            "X-CSRF-Token"
        )
    return supplied


async def protect() -> None:
    """
    Enforce CSRF on unsafe methods.

    Returns normally when the request may proceed. Safe methods
    (``SAFE_METHODS``) and endpoints marked with ``csrf_exempt`` are skipped.
    Otherwise the request is aborted with HTTP 400 if the session has no token,
    or if the supplied token is missing, not a string, or does not match.

    Supports:

    * Forms: hidden input "csrf_token"
    * JSON: "csrf_token" or "csrfToken" field
    * HTMX/AJAX: "X-CSRFToken" or "X-CSRF-Token" header
    """
    if request.method in SAFE_METHODS:
        return

    if _is_exempt_endpoint():
        return

    session_token = qsession.get("csrf_token")
    if not session_token:
        abort(400, "Missing CSRF session token")

    supplied = await _read_supplied_token()

    # Reject absent or non-string tokens up front. JSON bodies may carry any
    # value type, and compare_digest below only accepts str/bytes.
    if (
        not supplied
        or not isinstance(supplied, str)
        or not isinstance(session_token, str)
    ):
        abort(400, "Invalid CSRF token")

    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed token matched. Compare bytes because compare_digest raises
    # TypeError on non-ASCII str. "surrogatepass" keeps lone surrogates (which
    # a JSON body can contain) from raising; such a token simply won't match.
    if not secrets.compare_digest(
        supplied.encode("utf-8", "surrogatepass"),
        session_token.encode("utf-8", "surrogatepass"),
    ):
        abort(400, "Invalid CSRF token")
|
|
|
|
|
|
def csrf_exempt(view: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
    """
    Decorator that opts a view out of CSRF enforcement.

    The view object is returned unchanged. It only gains a ``_csrf_exempt``
    attribute set to True, which the request-time exemption check looks for on
    the registered view and on anything it wraps via ``__wrapped__``.

    Usage::

        from suma_browser.app.csrf import csrf_exempt

        @csrf_exempt
        @blueprint.post("/hook")
        async def webhook():
            ...
    """
    view._csrf_exempt = True
    return view
|