feat: initial shared library extraction

Contains shared infrastructure for all coop services:
- shared/ (factory, urls, user_loader, context, internal_api, jinja_setup)
- models/ (User, Order, Calendar, Ticket, Product, Ghost CMS)
- db/ (SQLAlchemy async session, base)
- suma_browser/app/ (csrf, middleware, errors, authz, redis_cacher, payments, filters, utils)
- suma_browser/templates/ (shared base layouts, macros, error pages)
- static/ (CSS, JS, fonts, images)
- alembic/ (database migrations)
- config/ (app-config.yaml)
- editor/ (Lexical editor Node.js build)
- requirements.txt

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
giles
2026-02-09 23:11:36 +00:00
commit 668d9c7df8
446 changed files with 22741 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
# The monolith has been split into three apps (apps/coop, apps/market, apps/cart).
# This package remains for shared infrastructure modules (middleware, redis_cacher,
# csrf, errors, authz, filters, utils, bp/*).
#
# To run individual apps:
# hypercorn apps.coop.app:app --bind 0.0.0.0:8000
# hypercorn apps.market.app:app --bind 0.0.0.0:8001
# hypercorn apps.cart.app:app --bind 0.0.0.0:8002
#
# Legacy single-process:
# hypercorn suma_browser.app.app:app --bind 0.0.0.0:8000
# (runs the old monolith from app.py, which still works)

152
suma_browser/app/authz.py Normal file
View File

@@ -0,0 +1,152 @@
from __future__ import annotations
from functools import wraps
from typing import Any, Dict, Iterable, Optional
import inspect
from quart import g, abort, redirect, request, current_app
from shared.urls import login_url
def require_rights(*rights: str, any_of: bool = True):
    """
    Decorator for routes that require certain user rights.

    With ``any_of=True`` (default) access is granted when the user holds ANY
    of the named rights; with ``any_of=False`` ALL rights are required.
    Anonymous users are redirected to the login page.
    """
    if not rights:
        raise ValueError("require_rights needs at least one right name")
    required_set = frozenset(rights)

    def decorator(view_func):
        @wraps(view_func)
        async def wrapper(*args: Any, **kwargs: Any):
            user = g.get("user")
            if not user:
                # Not logged in → go to login, with ?next=<current path>
                return redirect(login_url(request.url))
            held = g.get("rights") or {}
            check = any if any_of else all
            if not check(held.get(name) for name in required_set):
                abort(403)
            outcome = view_func(*args, **kwargs)
            # Support both sync and async views.
            return await outcome if inspect.isawaitable(outcome) else outcome

        # Expose access requirements for introspection (used by has_access).
        wrapper.__access_requires__ = {
            "rights": required_set,
            "any_of": any_of,
        }
        return wrapper

    return decorator
def require_login(view_func):
    """
    Decorator for routes that require any logged-in user.

    Anonymous users are redirected to the login page with the current URL
    carried along so they come back after signing in.
    """
    @wraps(view_func)
    async def wrapper(*args: Any, **kwargs: Any):
        if not g.get("user"):
            return redirect(login_url(request.url))
        outcome = view_func(*args, **kwargs)
        # The wrapped view may be sync or async.
        return await outcome if inspect.isawaitable(outcome) else outcome

    return wrapper
def require_admin(view_func=None):
    """
    Shortcut for routes that require the 'admin' right.

    Works both as a bare decorator (``@require_admin``) and as a decorator
    factory (``@require_admin()``).
    """
    decorator = require_rights("admin")
    return decorator if view_func is None else decorator(view_func)
def require_post_author(view_func):
    """Allow access for admins or for the owner of the current post."""

    async def _invoke(*args, **kwargs):
        # Run the view, awaiting if it is a coroutine.
        outcome = view_func(*args, **kwargs)
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome

    @wraps(view_func)
    async def wrapper(*args, **kwargs):
        user = g.get("user")
        if not user:
            return redirect(login_url(request.url))
        # Admins always pass.
        if bool((g.get("rights") or {}).get("admin")):
            return await _invoke(*args, **kwargs)
        # Otherwise the post (loaded earlier into g.post_data) must be theirs.
        post = getattr(g, "post_data", {}).get("original_post")
        if post and post.user_id == user.id:
            return await _invoke(*args, **kwargs)
        abort(403)

    return wrapper
def _get_access_meta(view_func) -> Optional[Dict[str, Any]]:
"""
Walk the wrapper chain looking for __access_requires__ metadata.
"""
func = view_func
seen: set[int] = set()
while func is not None and id(func) not in seen:
seen.add(id(func))
meta = getattr(func, "__access_requires__", None)
if meta is not None:
return meta
func = getattr(func, "__wrapped__", None)
return None
def has_access(endpoint: str) -> bool:
    """
    Return True if the current user has access to the given endpoint.

    Example:
        has_access("settings.home")
        has_access("settings.clear_cache_view")

    Must be called inside a request context (reads g.user / g.rights).
    """
    view = current_app.view_functions.get(endpoint)
    if view is None:
        # Unknown endpoint: be conservative
        return False
    meta = _get_access_meta(view)
    if meta is None:
        # No rights metadata on the route → treat it as public.
        return True
    # From here on a logged-in user is required.
    if not g.get("user"):
        return False
    held = g.get("rights") or {}
    check = any if meta["any_of"] else all
    return check(held.get(name) for name in meta["rights"])

99
suma_browser/app/csrf.py Normal file
View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import secrets
from typing import Callable, Awaitable, Optional
from quart import (
abort,
current_app,
request,
session as qsession,
)
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
def generate_csrf_token() -> str:
    """
    Return the per-session CSRF token, creating one on first use.

    In Jinja:
        <input type="hidden" name="csrf_token" value="{{ csrf_token() }}">
    """
    existing = qsession.get("csrf_token")
    if existing:
        return existing
    fresh = secrets.token_urlsafe(32)
    qsession["csrf_token"] = fresh
    return fresh
def _is_exempt_endpoint() -> bool:
    """True if the resolved view, or any function it wraps, is CSRF-exempt."""
    endpoint = request.endpoint
    if not endpoint:
        return False
    fn = current_app.view_functions.get(endpoint)
    # Walk the decorator stack via __wrapped__ looking for the exempt flag.
    while fn is not None:
        if getattr(fn, "_csrf_exempt", False):
            return True
        fn = getattr(fn, "__wrapped__", None)
    return False
async def protect() -> None:
    """
    Enforce CSRF on unsafe methods.

    Supports:
        * Forms: hidden input "csrf_token"
        * JSON: "csrf_token" or "csrfToken" field
        * HTMX/AJAX: "X-CSRFToken" or "X-CSRF-Token" header

    Aborts with 400 when no session token exists or the supplied token
    does not match. Token comparison uses secrets.compare_digest to avoid
    leaking token bytes through timing differences.
    """
    if request.method in SAFE_METHODS:
        return
    if _is_exempt_endpoint():
        return
    session_token = qsession.get("csrf_token")
    if not session_token:
        abort(400, "Missing CSRF session token")
    supplied_token: Optional[str] = None
    # JSON body — guard against non-dict JSON (list/scalar) which would
    # otherwise crash on .get()
    if request.mimetype == "application/json":
        data = await request.get_json(silent=True)
        if isinstance(data, dict):
            supplied_token = data.get("csrf_token") or data.get("csrfToken")
    # Form body
    if not supplied_token and request.mimetype != "application/json":
        form = await request.form
        supplied_token = form.get("csrf_token")
    # Headers (HTMX / fetch)
    if not supplied_token:
        supplied_token = (
            request.headers.get("X-CSRFToken")
            or request.headers.get("X-CSRF-Token")
        )
    # Constant-time comparison (timing-attack hardening).
    if not supplied_token or not secrets.compare_digest(supplied_token, session_token):
        abort(400, "Invalid CSRF token")
def csrf_exempt(view: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
    """
    Mark a view as CSRF-exempt.

    from suma_browser.app.csrf import csrf_exempt
    @csrf_exempt
    @blueprint.post("/hook")
    async def webhook():
        ...
    """
    # Flag is discovered by _is_exempt_endpoint via the __wrapped__ chain.
    view._csrf_exempt = True
    return view

126
suma_browser/app/errors.py Normal file
View File

@@ -0,0 +1,126 @@
from werkzeug.exceptions import HTTPException
from utils import hx_fragment_request
from quart import (
request,
render_template,
make_response,
current_app
)
from markupsafe import escape
class AppError(ValueError):
"""
Base class for app-level, client-safe errors.
Behaves like ValueError so existing except ValueError: still works.
"""
status_code: int = 400
def __init__(self, message, *, status_code: int | None = None):
# Support a single message or a list/tuple of messages
if isinstance(message, (list, tuple, set)):
self.messages = [str(m) for m in message]
msg = self.messages[0] if self.messages else ""
else:
self.messages = [str(message)]
msg = str(message)
super().__init__(msg)
if status_code is not None:
self.status_code = status_code
def errors(app):
    """Register error handlers (404, 403, AppError, catch-all) on *app*.

    HTMX (fragment) requests get small HTML snippets; normal requests get
    full error pages rendered from the exceptions templates.
    """

    def _info(e):
        # Compact request snapshot for log lines; proxy headers included
        # because routing problems usually involve X-Forwarded-*.
        return {
            "exception": e,
            "method": request.method,
            "url": str(request.url),
            "base_url": str(request.base_url),
            "root_path": request.root_path,
            "path": request.path,
            "full_path": request.full_path,
            "endpoint": request.endpoint,
            "url_rule": str(request.url_rule) if request.url_rule else None,
            "headers": {k: v for k, v in request.headers.items()
                        if k.lower().startswith("x-forwarded") or k in ("Host",)},
        }

    async def _status_page(errnum: str, status: int):
        # Shared renderer for the simple status pages (404/403): fragment
        # template for HTMX requests, full page otherwise.
        if hx_fragment_request():
            html = await render_template(
                "_types/root/exceptions/hx/_.html",
                errnum=errnum,
            )
        else:
            html = await render_template(
                "_types/root/exceptions/_.html",
                errnum=errnum,
            )
        return await make_response(html, status)

    @app.errorhandler(404)
    async def not_found(e):
        current_app.logger.warning("404 %s", _info(e))
        return await _status_page('404', 404)

    @app.errorhandler(403)
    async def not_allowed(e):
        current_app.logger.warning("403 %s", _info(e))
        return await _status_page('403', 403)

    @app.errorhandler(AppError)
    async def app_error(e: AppError):
        # App-level, client-safe errors: messages may be shown to the user.
        current_app.logger.info("AppError %s", _info(e))
        status = getattr(e, "status_code", 400)
        messages = getattr(e, "messages", [str(e)])
        if request.headers.get("HX-Request") == "true":
            # Build a little styled <ul><li>...</li></ul> snippet
            lis = "".join(
                f"<li>{escape(m)}</li>"
                for m in messages if m
            )
            html = (
                "<ul class='list-disc pl-5 space-y-1 text-sm text-red-600'>"
                f"{lis}"
                "</ul>"
            )
            return await make_response(html, status)
        # Non-HTMX: show a nicer page with error messages
        html = await render_template(
            "_types/root/exceptions/app_error.html",
            messages=messages,
        )
        return await make_response(html, status)

    @app.errorhandler(Exception)
    async def error(e):
        current_app.logger.exception("Exception %s", _info(e))
        status = 500
        if isinstance(e, HTTPException):
            status = e.code or 500
        if request.headers.get("HX-Request") == "true":
            # Generic message only — never leak internals of unexpected errors.
            return await make_response(
                "Something went wrong. Please try again.",
                status,
            )
        html = await render_template("_types/root/exceptions/error.html")
        return await make_response(html, status)

View File

@@ -0,0 +1,17 @@
def register(app):
    """Install the shared Jinja filters on *app*.

    Imports are local so each filter module is only loaded when the shared
    filter set is actually installed. Aliases are prefixed ``register_`` so
    none of them shadows a builtin (the old ``getattr`` alias shadowed
    ``builtins.getattr`` inside this function).
    """
    from .highlight import highlight
    app.jinja_env.filters["highlight"] = highlight
    from .qs import register as register_qs
    from .url_join import register as register_url_join
    from .combine import register as register_combine
    from .currency import register as register_currency
    from .truncate import register as register_truncate  # noqa: F401 — currently disabled below
    from .getattr import register as register_getattr
    register_qs(app)
    register_url_join(app)
    register_combine(app)
    register_currency(app)
    register_getattr(app)
    # register_truncate(app)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Any, Mapping
def _deep_merge(dst: dict, src: Mapping) -> dict:
out = dict(dst)
for k, v in src.items():
if isinstance(v, Mapping) and isinstance(out.get(k), Mapping):
out[k] = _deep_merge(out[k], v) # type: ignore[arg-type]
else:
out[k] = v
return out
def register(app):
    @app.template_filter("combine")
    def combine_filter(a: Any, b: Any, deep: bool = False, drop_none: bool = False) -> Any:
        """
        Jinja filter: merge two dict-like objects.
        - Non-dict inputs: returns `a` unchanged.
        - If drop_none=True, keys in `b` with value None are ignored.
        - If deep=True, nested dicts are merged recursively.
        """
        if not (isinstance(a, Mapping) and isinstance(b, Mapping)):
            return a
        overlay = {key: val for key, val in b.items() if val is not None or not drop_none}
        if deep:
            return _deep_merge(a, overlay)
        return {**a, **overlay}

View File

@@ -0,0 +1,12 @@
from decimal import Decimal
def register(app):
    @app.template_filter("currency")
    def currency_filter(value, code="GBP"):
        """Format *value* as money: £ for GBP, otherwise the raw code prefix."""
        if value is None:
            return ""
        if isinstance(value, float):
            # Round-trip through str() so we format the human-visible value
            # rather than the binary-float representation.
            value = Decimal(str(value))
        prefix = "£" if code == "GBP" else code
        return f"{prefix}{value:.2f}"

View File

@@ -0,0 +1,6 @@
def register(app):
    @app.template_filter("getattr")
    def jinja_getattr(obj, name, default=None):
        """Template-safe getattr: a missing attribute yields *default*, never raises."""
        value = getattr(obj, name, default)
        return value

View File

@@ -0,0 +1,21 @@
# ---------- misc helpers / filters ----------
from markupsafe import Markup, escape
def highlight(text: str, needle: str, cls: str = "bg-yellow-200 rounded") -> Markup:
    """
    Wrap case-insensitive matches of *needle* in ``<mark class="...">``,
    escaping everything safely.

    Matching is done on the RAW text and each segment is escaped
    separately. (The previous version matched against the escaped text,
    so needles containing HTML-special characters never matched — e.g.
    "a&b" could not match "a&amp;b" — and a needle like "amp" could
    falsely match inside entities produced by escaping.)
    """
    import re
    if not text or not needle:
        return Markup(escape(text or ""))
    pattern = re.compile(re.escape(needle), re.IGNORECASE)
    pieces: list[str] = []
    last = 0
    for m in pattern.finditer(text):
        # Escape the plain span before the match, then the match itself.
        pieces.append(str(escape(text[last:m.start()])))
        pieces.append(f'<mark class="{escape(cls)}">{escape(m.group(0))}</mark>')
        last = m.end()
    pieces.append(str(escape(text[last:])))
    return Markup("".join(pieces))

View File

@@ -0,0 +1,13 @@
from typing import Dict
from quart import g
def register(app):
    @app.template_filter("qs")
    def qs_filter(params: Dict):
        """Build a query string from *params* via the request's makeqs factory.

        Returns "" when no factory has been installed on ``g`` (e.g. in a
        context that never set one up). The parameter was renamed from
        ``dict`` to avoid shadowing the builtin; Jinja passes the piped
        value positionally, so callers are unaffected.
        """
        if getattr(g, "makeqs_factory", False):
            return g.makeqs_factory()(**params)
        return ""

View File

@@ -0,0 +1,78 @@
"""
Shared query-string primitives used by blog, market, and order qs modules.
"""
from __future__ import annotations
from urllib.parse import urlencode
# Sentinel meaning "leave value as-is" (used as default arg in makeqs)
KEEP = object()
def _iterify(x):
"""Normalize *x* to a list: None → [], scalar → [scalar], iterable → as-is."""
if x is None:
return []
if isinstance(x, (list, tuple, set)):
return x
return [x]
def _norm(s: str) -> str:
    """Strip + lowercase — canonical form for case-insensitive filter dedup."""
    return s.strip().lower()
def make_filter_set(
base: list[str],
add,
remove,
clear_filters: bool,
*,
single_select: bool = False,
) -> list[str]:
"""
Build a deduplicated, sorted filter list.
Parameters
----------
base : list[str]
Current filter values.
add : str | list | None
Value(s) to add.
remove : str | list | None
Value(s) to remove.
clear_filters : bool
If True, start from empty instead of *base*.
single_select : bool
If True, *add* **replaces** the list (blog tags/authors).
If False, *add* is **appended** (market brands/stickers/labels).
"""
add_list = [s for s in _iterify(add) if s is not None]
if single_select:
# Blog-style: adding replaces the entire set
if add_list:
table = {_norm(s): s for s in add_list}
else:
table = {_norm(s): s for s in base if not clear_filters}
else:
# Market-style: adding appends to the existing set
table = {_norm(s): s for s in base if not clear_filters}
for s in add_list:
k = _norm(s)
if k not in table:
table[k] = s
for s in _iterify(remove):
if s is None:
continue
table.pop(_norm(s), None)
return [table[k] for k in sorted(table)]
def build_qs(params: list[tuple[str, str]], *, leading_q: bool = True) -> str:
    """URL-encode *params*; prepend ``?`` when non-empty and *leading_q* is set."""
    encoded = urlencode(params, doseq=True)
    if encoded and leading_q:
        return "?" + encoded
    return encoded

View File

@@ -0,0 +1,33 @@
"""
NamedTuple types returned by each blueprint's ``decode()`` function.
"""
from __future__ import annotations
from typing import NamedTuple
class BlogQuery(NamedTuple):
    """Decoded query-string state returned by the blog blueprint's ``decode()``."""
    page: int
    search: str | None
    sort: str | None
    selected_tags: tuple[str, ...]
    selected_authors: tuple[str, ...]
    liked: str | None
    view: str | None
    drafts: str | None
    selected_groups: tuple[str, ...]
class MarketQuery(NamedTuple):
    """Decoded query-string state returned by the market blueprint's ``decode()``."""
    page: int
    search: str | None
    sort: str | None
    selected_brands: tuple[str, ...]
    selected_stickers: tuple[str, ...]
    selected_labels: tuple[str, ...]
    liked: str | None
class OrderQuery(NamedTuple):
    """Decoded query-string state returned by the order blueprint's ``decode()``."""
    page: int
    search: str | None

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
def register(app):
    @app.template_filter("truncate")
    def truncate(text, max_length=100):
        """
        Truncate text to max_length characters and add an ellipsis character (…)
        if it was longer. None renders as "".
        """
        if text is None:
            return ""
        text = str(text)
        if len(text) <= max_length:
            return text
        # A budget of 1 (or less) can't fit any content plus the ellipsis.
        if max_length <= 1:
            return ""
        # Leave space for the ellipsis itself.
        # BUG FIX: the ellipsis was previously lost — "" was appended.
        return text[:max_length - 1] + "…"

View File

@@ -0,0 +1,20 @@
from typing import Iterable, Union
from utils import join_url, host_url, _join_url_parts, route_prefix
# --- Register as a Jinja filter (Quart / Flask) ---
def register(app):
    """Install URL-building filters: urljoin, urlhost, urlhost_no_slash, host."""

    @app.template_filter("urljoin")
    def urljoin_filter(value: Union[str, Iterable[str]]):
        # Join one or more path parts via the shared join_url helper.
        return join_url(value)

    @app.template_filter("urlhost")
    def urlhost_filter(value: Union[str, Iterable[str]]):
        # Absolute URL (host + path) via the shared host_url helper.
        return host_url(value)

    @app.template_filter("urlhost_no_slash")
    def urlhost_no_slash_filter(value: Union[str, Iterable[str]]):
        # Same as urlhost but with host_url's second flag set
        # (presumably suppresses the trailing slash — see utils.host_url).
        return host_url(value, True)

    @app.template_filter("host")
    def host_filter(value: str):
        # Prefix *value* with the app's route prefix.
        return _join_url_parts([route_prefix(), value])

View File

@@ -0,0 +1,58 @@
def register(app):
    """Wrap the ASGI app with proxy-header handling.

    Also defines (but does not install) ScopeDumpMiddleware, a debugging
    aid that prints the routing-relevant parts of each ASGI scope.
    """
    import json
    from typing import Any

    def _decode_headers(scope) -> dict[str, str]:
        # ASGI headers are (bytes, bytes) pairs; decode them for readable output.
        out = {}
        for k, v in scope.get("headers", []):
            try:
                out[k.decode("latin1")] = v.decode("latin1")
            except Exception:
                # Fall back to repr() for anything that won't decode.
                out[repr(k)] = repr(v)
        return out

    def _safe(obj: Any):
        # make scope json-serialisable; fall back to repr()
        try:
            json.dumps(obj)
            return obj
        except Exception:
            return repr(obj)

    class ScopeDumpMiddleware:
        # Debug-only: prints a compact view of each http/websocket scope.
        # Enable by uncommenting the wrap line at the bottom of register().
        def __init__(self, app, *, log_bodies: bool = False):
            self.app = app
            self.log_bodies = log_bodies  # keep False; bodies aren't needed for routing

        async def __call__(self, scope, receive, send):
            if scope["type"] in ("http", "websocket"):
                # Build a compact view of keys relevant to routing
                scope_view = {
                    "type": scope.get("type"),
                    "asgi": scope.get("asgi"),
                    "http_version": scope.get("http_version"),
                    "scheme": scope.get("scheme"),
                    "method": scope.get("method"),
                    "server": scope.get("server"),
                    "client": scope.get("client"),
                    "root_path": scope.get("root_path"),
                    "path": scope.get("path"),
                    "raw_path": scope.get("raw_path").decode("latin1") if scope.get("raw_path") else None,
                    "query_string": scope.get("query_string", b"").decode("latin1"),
                    "headers": _decode_headers(scope),
                }
                print("\n=== ASGI SCOPE (routing) ===")
                print(json.dumps({_safe(k): _safe(v) for k, v in scope_view.items()}, indent=2))
                print("=== END SCOPE ===\n", flush=True)
            return await self.app(scope, receive, send)

    # wrap LAST so you see what hits Quart
    #app.asgi_app = ScopeDumpMiddleware(app.asgi_app)
    from hypercorn.middleware import ProxyFixMiddleware
    # trust a single proxy hop; use legacy X-Forwarded-* headers
    app.asgi_app = ProxyFixMiddleware(app.asgi_app, mode="legacy", trusted_hops=1)

View File

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
import os
from typing import Any, Dict
import httpx
from quart import current_app
from config import config
from models.order import Order
SUMUP_BASE_URL = "https://api.sumup.com/v0.1"
def _sumup_settings() -> Dict[str, str]:
    """Load SumUp credentials and options from app config + environment.

    Raises RuntimeError when the API key env var is unset or the merchant
    code is missing from app-config.yaml.
    """
    cfg = config()
    sumup_cfg = cfg.get("sumup", {}) or {}
    key_env = sumup_cfg.get("api_key_env", "SUMUP_API_KEY")
    api_key = os.getenv(key_env)
    if not api_key:
        raise RuntimeError(f"Missing SumUp API key in environment variable {key_env}")
    merchant_code = sumup_cfg.get("merchant_code")
    if not merchant_code:
        raise RuntimeError("Missing 'sumup.merchant_code' in app-config.yaml")
    return {
        "api_key": api_key,
        "merchant_code": merchant_code,
        "currency": sumup_cfg.get("currency", "GBP"),
        "checkout_reference_prefix": sumup_cfg.get("checkout_prefix", ""),
    }
async def create_checkout(
    order: Order,
    redirect_url: str,
    webhook_url: str | None = None,
    description: str | None = None,
) -> Dict[str, Any]:
    """Create (or recover) a SumUp hosted checkout for *order*.

    Args:
        order: Order being paid; supplies id, total_amount, and an optional
            stored sumup_reference.
        redirect_url: Where the customer lands after payment.
        webhook_url: If given, used as SumUp's ``return_url`` instead of
            *redirect_url*.
        description: Optional checkout description shown to the customer.

    Returns:
        The checkout object (dict) as returned by the SumUp API.

    Raises:
        httpx.HTTPStatusError: On unrecoverable API errors.
        RuntimeError: Propagated from _sumup_settings() on missing config.
    """
    settings = _sumup_settings()
    # Use stored reference if present, otherwise build it
    checkout_reference = order.sumup_reference or f"{settings['checkout_reference_prefix']}{order.id}"
    payload: Dict[str, Any] = {
        "checkout_reference": checkout_reference,
        "amount": float(order.total_amount),
        "currency": settings["currency"],
        "merchant_code": settings["merchant_code"],
        "description": description or f"Order {order.id} at {current_app.config.get('APP_TITLE', 'Rose Ash')}",
        "return_url": webhook_url or redirect_url,
        "redirect_url": redirect_url,
        "hosted_checkout": {"enabled": True},
    }
    headers = {
        "Authorization": f"Bearer {settings['api_key']}",
        "Content-Type": "application/json",
    }
    # Optional: log for debugging
    current_app.logger.info(
        "Creating SumUp checkout %s for Order %s amount %.2f",
        checkout_reference,
        order.id,
        float(order.total_amount),
    )
    async with httpx.AsyncClient(timeout=15.0) as client:
        resp = await client.post(f"{SUMUP_BASE_URL}/checkouts", json=payload, headers=headers)
        if resp.status_code == 409:
            # Duplicate checkout — retrieve the existing one by reference
            current_app.logger.warning(
                "SumUp duplicate checkout for ref %s order %s, fetching existing",
                checkout_reference,
                order.id,
            )
            list_resp = await client.get(
                f"{SUMUP_BASE_URL}/checkouts",
                params={"checkout_reference": checkout_reference},
                headers=headers,
            )
            list_resp.raise_for_status()
            items = list_resp.json()
            # The list endpoint may return either a bare list or an
            # {"items": [...]} envelope — handle both shapes.
            if isinstance(items, list) and items:
                return items[0]
            if isinstance(items, dict) and items.get("items"):
                return items["items"][0]
            # Fallback: re-raise original error
            resp.raise_for_status()
        if resp.status_code >= 400:
            # Log the body before raising so the failure is diagnosable.
            current_app.logger.error(
                "SumUp checkout error for ref %s order %s: %s",
                checkout_reference,
                order.id,
                resp.text,
            )
            resp.raise_for_status()
        data = resp.json()
        return data
async def get_checkout(checkout_id: str) -> Dict[str, Any]:
    """Fetch checkout status/details from SumUp.

    Raises httpx.HTTPStatusError on any non-success response.
    """
    settings = _sumup_settings()
    auth_headers = {
        "Authorization": f"Bearer {settings['api_key']}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(
            f"{SUMUP_BASE_URL}/checkouts/{checkout_id}",
            headers=auth_headers,
        )
        response.raise_for_status()
        return response.json()

View File

@@ -0,0 +1,346 @@
from __future__ import annotations
from functools import wraps
from typing import Optional, Literal
import asyncio
from quart import (
Quart,
request,
Response,
g,
current_app,
)
from redis import asyncio as aioredis
Scope = Literal["user", "global", "anon"]
TagScope = Literal["all", "user"] # for clear_cache
# ---------------------------------------------------------------------------
# Redis setup
# ---------------------------------------------------------------------------
def register(app: Quart) -> None:
    """Attach Redis lifecycle hooks: connect before serving, close after."""

    @app.before_serving
    async def setup_redis() -> None:
        # REDIS_URL of 'no' (or empty/None) disables caching entirely;
        # every caller checks the falsy ``app.redis`` before using it.
        if app.config["REDIS_URL"] and app.config["REDIS_URL"] != 'no':
            app.redis = aioredis.Redis.from_url(
                app.config["REDIS_URL"],
                encoding="utf-8",
                decode_responses=False,  # store bytes
            )
        else:
            app.redis = False

    @app.after_serving
    async def close_redis() -> None:
        if app.redis:
            await app.redis.close()
            # optional: await app.redis.connection_pool.disconnect()
def get_redis():
    """Return the app's Redis client, or a falsy value when caching is disabled."""
    return current_app.redis
# ---------------------------------------------------------------------------
# Key helpers
# ---------------------------------------------------------------------------
def get_user_id() -> str:
    """
    Returns a string id or 'anon'.
    Adjust based on your auth system.
    """
    user = getattr(g, "user", None)
    return str(user.id) if user else "anon"
def make_cache_key(cache_user_id: str) -> str:
    """
    Build a cache key for this (user/global/anon) + path + query + HTMX status.

    HTMX requests and normal requests get different cache keys because they
    return different content (partials vs full pages). Keys are namespaced
    by app name (CACHE_APP_PREFIX) to avoid collisions between apps that
    may share the same paths.
    """
    app_prefix = current_app.config.get("CACHE_APP_PREFIX", "app")
    is_htmx = request.headers.get("HX-Request", "").lower() == "true"
    htmx_suffix = ":htmx" if is_htmx else ""
    target = request.path
    if request.query_string:
        target = f"{target}?{request.query_string.decode()}"
    return f"cache:{app_prefix}:page:{cache_user_id}:{target}{htmx_suffix}"
def user_set_key(user_id: str) -> str:
    """
    Key of the Redis set that tracks all cache keys for a given user id.
    Only used when scope='user'.
    """
    return "cache:user:" + user_id
def tag_set_key(tag: str) -> str:
    """
    Key of the Redis set that tracks all cache keys associated with a tag
    (across all scopes/users).
    """
    return "cache:tag:" + tag
# ---------------------------------------------------------------------------
# Invalidation helpers
# ---------------------------------------------------------------------------
async def invalidate_user_cache(user_id: str) -> None:
    """
    Delete all cached pages for a specific user (scope='user' caches),
    plus the tracking set itself.
    """
    r = get_redis()
    if not r:
        return
    s_key = user_set_key(user_id)
    tracked = await r.smembers(s_key)  # set of bytes
    if tracked:
        await r.delete(*tracked)
    await r.delete(s_key)
async def invalidate_tag_cache(tag: str) -> None:
    """
    Delete all cached pages associated with this tag, for all users/scopes,
    plus the tag's tracking set.
    """
    r = get_redis()
    if not r:
        return
    t_key = tag_set_key(tag)
    tracked = await r.smembers(t_key)  # set of bytes
    if tracked:
        await r.delete(*tracked)
    await r.delete(t_key)
async def invalidate_tag_cache_for_user(tag: str, cache_uid: str) -> None:
    """Delete the cached pages under *tag* that belong to *cache_uid* only.

    Other users' entries under the same tag are left intact (removed from
    neither Redis nor the tag set).

    BUG FIX: page keys are namespaced by CACHE_APP_PREFIX (see
    make_cache_key: "cache:{app_prefix}:page:{uid}:..."), so the old
    hard-coded "cache:page:{uid}:" prefix never matched any key and this
    invalidation silently did nothing.
    """
    r = get_redis()
    if not r:
        return
    t_key = tag_set_key(tag)
    keys = await r.smembers(t_key)  # set of bytes
    if not keys:
        return
    app_prefix = current_app.config.get("CACHE_APP_PREFIX", "app")
    prefix = f"cache:{app_prefix}:page:{cache_uid}:".encode("utf-8")
    # Filter keys belonging to this cache_uid only
    to_delete = [k for k in keys if k.startswith(prefix)]
    if not to_delete:
        return
    # Delete those page entries
    await r.delete(*to_delete)
    # Remove them from the tag set (leave other users' keys intact)
    await r.srem(t_key, *to_delete)
async def invalidate_tag_cache_for_current_user(tag: str) -> None:
    """
    Convenience helper: delete tag cache for the current user_id (scope='user').
    """
    await invalidate_tag_cache_for_user(tag, get_user_id())
# ---------------------------------------------------------------------------
# Cache decorator for GET
# ---------------------------------------------------------------------------
def cache_page(
    ttl: int = 0,
    tag: Optional[str] = None,
    scope: Scope = "user",
):
    """
    Cache GET responses in Redis.

    ttl:
        Seconds to keep the cache. 0 = no expiry.
    tag:
        Optional tag name used for bulk invalidation via invalidate_tag_cache().
    scope:
        "user"   → cache per-user (includes 'anon'), tracked in cache:user:{id}
        "global" → single cache shared by everyone (no per-user tracking)
        "anon"   → cache only for anonymous users; logged-in users bypass cache
    """
    def decorator(view):
        @wraps(view)
        async def wrapper(*args, **kwargs):
            r = get_redis()
            # No Redis / non-GET: pass straight through, never cache.
            if not r or request.method != "GET":
                return await view(*args, **kwargs)
            uid = get_user_id()
            # Decide who the cache key is keyed on
            if scope == "global":
                cache_uid = "global"
            elif scope == "anon":
                # Only cache for anonymous users
                if uid != "anon":
                    return await view(*args, **kwargs)
                cache_uid = "anon"
            else:  # scope == "user"
                cache_uid = uid
            key = make_cache_key(cache_uid)
            # Cached entries are stored as a Redis hash: body/status/content_type.
            cached = await r.hgetall(key)
            if cached:
                body = cached[b"body"]
                status = int(cached[b"status"].decode())
                content_type = cached.get(b"content_type", b"text/html").decode()
                return Response(body, status=status, content_type=content_type)
            # Not cached, call the view
            resp = await view(*args, **kwargs)
            # Normalise: if the view returned a string/bytes, wrap it
            if not isinstance(resp, Response):
                resp = Response(resp, content_type="text/html")
            # Only cache successful responses
            if resp.status_code == 200:
                body = await resp.get_data()  # bytes
                # Pipeline: store the page plus its tracking-set memberships
                # in one round trip.
                pipe = r.pipeline()
                pipe.hset(
                    key,
                    mapping={
                        "body": body,
                        "status": str(resp.status_code),
                        "content_type": resp.content_type or "text/html",
                    },
                )
                if ttl:
                    pipe.expire(key, ttl)
                # Track per-user keys only when scope='user'
                if scope == "user":
                    pipe.sadd(user_set_key(cache_uid), key)
                # Track per-tag keys (all scopes)
                if tag:
                    pipe.sadd(tag_set_key(tag), key)
                await pipe.execute()
                # get_data() consumed the body; restore it before returning.
                resp.set_data(body)
            return resp
        return wrapper
    return decorator
# ---------------------------------------------------------------------------
# Clear cache decorator for POST (or any method)
# ---------------------------------------------------------------------------
def clear_cache(
    *,
    tag: Optional[str] = None,
    tag_scope: TagScope = "all",
    clear_user: bool = False,
):
    """
    Decorator for routes that should clear cache after they run.
    Use on POST/PUT/PATCH/DELETE handlers.

    Params:
        tag:
            If set, will clear caches for this tag.
        tag_scope:
            "all"  → invalidate_tag_cache(tag)  (all users/scopes)
            "user" → invalidate_tag_cache_for_current_user(tag)
        clear_user:
            If True, also run invalidate_user_cache(current_user_id).

    Typical usage:
        @bp.post("/posts/<slug>/edit")
        @clear_cache(tag="post.post_detail", tag_scope="all")
        async def edit_post(slug):
            ...

        @bp.post("/prefs")
        @clear_cache(tag="dashboard", tag_scope="user", clear_user=True)
        async def update_prefs():
            ...
    """
    def decorator(view):
        @wraps(view)
        async def wrapper(*args, **kwargs):
            # Run the view first
            resp = await view(*args, **kwargs)
            if get_redis():
                # Only clear cache if the view succeeded (2xx)
                status = getattr(resp, "status_code", None)
                if status is None:
                    # Non-Response return (string, dict) -> treat as success
                    success = True
                else:
                    success = 200 <= status < 300
                if not success:
                    return resp
                # Perform invalidations
                tasks = []
                if clear_user:
                    uid = get_user_id()
                    tasks.append(invalidate_user_cache(uid))
                if tag:
                    if tag_scope == "all":
                        tasks.append(invalidate_tag_cache(tag))
                    else:  # tag_scope == "user"
                        tasks.append(invalidate_tag_cache_for_current_user(tag))
                if tasks:
                    # Run them concurrently
                    await asyncio.gather(*tasks)
            return resp
        return wrapper
    return decorator
async def clear_all_cache(prefix: str = "cache:") -> None:
    """Delete every key matching ``{prefix}*`` via incremental SCAN.

    SCAN avoids blocking Redis the way KEYS would on a large keyspace.
    """
    r = get_redis()
    if not r:
        return
    pattern = f"{prefix}*"
    cursor = 0
    while True:
        cursor, batch = await r.scan(cursor=cursor, match=pattern, count=500)
        if batch:
            await r.delete(*batch)
        if cursor == 0:
            return

View File

@@ -0,0 +1,12 @@
from .parse import (
parse_time,
parse_cost,
parse_dt
)
from .utils import (
current_route_relative_path,
current_url_without_page,
vary,
)
from .utc import utcnow

View File

@@ -0,0 +1,46 @@
"""HTMX utilities for detecting and handling HTMX requests."""
from quart import request
def is_htmx_request() -> bool:
    """
    Check if the current request is an HTMX request.
    Returns:
        bool: True if HX-Request header is present and true
    """
    header = request.headers.get("HX-Request", "")
    return header.lower() == "true"
def get_htmx_target() -> str | None:
    """
    Get the target element ID from HTMX request headers.
    Returns:
        str | None: Target element ID, or None when the header is absent
    """
    return request.headers.get("HX-Target")
def get_htmx_trigger() -> str | None:
    """
    Get the trigger element ID from HTMX request headers.
    Returns:
        str | None: Trigger element ID, or None when the header is absent
    """
    return request.headers.get("HX-Trigger")
def should_return_fragment() -> bool:
    """
    Determine if we should return a fragment vs full page.
    For HTMX requests, return fragment.
    For normal requests, return full page.
    Returns:
        bool: True if fragment should be returned
    """
    # Currently a straight alias of is_htmx_request(); kept separate so the
    # fragment policy can evolve without touching callers.
    return is_htmx_request()

View File

@@ -0,0 +1,36 @@
from datetime import datetime, timezone
def parse_time(val: str | None):
if not val:
return None
try:
h,m = val.split(':', 1)
from datetime import time
return time(int(h), int(m))
except Exception:
return None
def parse_cost(val: str | None):
if not val:
return None
try:
return float(val)
except Exception:
return None
if not val:
return None
dt = datetime.fromisoformat(val)
# make TZ-aware (assume local if naive; convert to UTC)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
def parse_dt(val: str | None) -> datetime | None:
if not val:
return None
dt = datetime.fromisoformat(val)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt

View File

@@ -0,0 +1,6 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
    """Timezone-aware "now" in UTC (use instead of the naive datetime.utcnow)."""
    return datetime.now(tz=timezone.utc)

View File

@@ -0,0 +1,51 @@
from quart import (
Response,
request,
g,
)
from utils import host_url
from urllib.parse import urlencode
def current_route_relative_path() -> str:
    """
    Return the current request path relative to the app's mount point.

    Fixes two defects:
    * ``path[len(g.root+1):]`` raised TypeError (str + int); the intent was
      to skip the leading "/" plus the root segment: ``len(g.root) + 1``.
    * A ``request.script_root`` expression was computed and then discarded.
    """
    path = request.path  # excludes query string
    if g.root and path.startswith(f"/{g.root}"):
        rel = path[len(g.root) + 1:]
        return rel if rel.startswith("/") else "/" + rel
    return path  # app mounted at /
def current_url_without_page() -> str:
    """
    Build current URL (host+path+qs) but with ?page= removed.
    Used for Hx-Push-Url.
    """
    args = request.args.to_dict(flat=False)  # flat=False keeps multi-valued params
    args.pop("page", None)
    query = urlencode(args, doseq=True)
    base = host_url(current_route_relative_path())
    return base + ("?" + query if query else "")
def vary(resp: Response) -> Response:
    """
    Ensure caches/CDNs vary on HX headers so htmx and non-htmx versions
    of a page never get served interchangeably.
    """
    existing = [part.strip() for part in resp.headers.get("Vary", "").split(",") if part.strip()]
    for header in ("HX-Request", "X-Origin"):
        if header not in existing:
            existing.append(header)
    if existing:
        resp.headers["Vary"] = ", ".join(existing)
    return resp