Security audit: fix IDOR, add rate limiting, HMAC auth, token hashing, XSS sanitization
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 3m22s

Critical: Add ownership checks to all order routes (IDOR fix).
High: Redis rate limiting on auth endpoints, HMAC-signed internal
service calls replacing header-presence-only checks, nh3 HTML
sanitization on ghost_sync and product import, internal auth on
market API endpoints, SHA-256 hashed OAuth grant/code tokens.
Medium: SECRET_KEY production guard, AP signature enforcement,
is_admin param removal, cart_sid validation, SSRF protection on
remote actor fetch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-26 13:30:27 +00:00
parent 404449fcab
commit c015f3f02f
27 changed files with 607 additions and 33 deletions

View File

@@ -0,0 +1,86 @@
"""Add token_hash columns to oauth_grants and oauth_codes
Revision ID: acct_0002
Revises: acct_0001
Create Date: 2026-02-26
"""
import hashlib
import sqlalchemy as sa
from alembic import op
revision = "acct_0002"
down_revision = "acct_0001"
branch_labels = None
depends_on = None
def _hash(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest()
def upgrade():
# Add new hash columns
op.add_column("oauth_grants", sa.Column("token_hash", sa.String(64), nullable=True))
op.add_column("oauth_codes", sa.Column("code_hash", sa.String(64), nullable=True))
op.add_column("oauth_codes", sa.Column("grant_token_hash", sa.String(64), nullable=True))
# Backfill hashes from existing plaintext tokens
conn = op.get_bind()
grants = conn.execute(sa.text("SELECT id, token FROM oauth_grants WHERE token IS NOT NULL"))
for row in grants:
conn.execute(
sa.text("UPDATE oauth_grants SET token_hash = :h WHERE id = :id"),
{"h": _hash(row.token), "id": row.id},
)
codes = conn.execute(sa.text("SELECT id, code, grant_token FROM oauth_codes WHERE code IS NOT NULL"))
for row in codes:
params = {"id": row.id, "ch": _hash(row.code)}
params["gh"] = _hash(row.grant_token) if row.grant_token else None
conn.execute(
sa.text("UPDATE oauth_codes SET code_hash = :ch, grant_token_hash = :gh WHERE id = :id"),
params,
)
# Create unique indexes on hash columns
op.create_index("ix_oauth_grant_token_hash", "oauth_grants", ["token_hash"], unique=True)
op.create_index("ix_oauth_code_code_hash", "oauth_codes", ["code_hash"], unique=True)
# Make original token columns nullable (keep for rollback safety)
op.alter_column("oauth_grants", "token", nullable=True)
op.alter_column("oauth_codes", "code", nullable=True)
# Drop old unique indexes on plaintext columns
try:
op.drop_index("ix_oauth_grant_token", "oauth_grants")
except Exception:
pass
try:
op.drop_index("ix_oauth_code_code", "oauth_codes")
except Exception:
pass
def downgrade():
# Restore original NOT NULL constraints
op.alter_column("oauth_grants", "token", nullable=False)
op.alter_column("oauth_codes", "code", nullable=False)
# Drop hash columns and indexes
try:
op.drop_index("ix_oauth_grant_token_hash", "oauth_grants")
except Exception:
pass
try:
op.drop_index("ix_oauth_code_code_hash", "oauth_codes")
except Exception:
pass
op.drop_column("oauth_grants", "token_hash")
op.drop_column("oauth_codes", "code_hash")
op.drop_column("oauth_codes", "grant_token_hash")
# Restore original unique indexes
op.create_index("ix_oauth_grant_token", "oauth_grants", ["token"], unique=True)
op.create_index("ix_oauth_code_code", "oauth_codes", ["code"], unique=True)

View File

@@ -17,6 +17,9 @@ def register() -> Blueprint:
async def _require_action_header():
if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -26,9 +26,10 @@ from sqlalchemy.exc import SQLAlchemyError
from shared.db.session import get_session
from shared.models import User
from shared.models.oauth_code import OAuthCode
from shared.models.oauth_grant import OAuthGrant
from shared.models.oauth_grant import OAuthGrant, hash_token
from shared.infrastructure.urls import account_url, app_url
from shared.infrastructure.cart_identity import current_cart_identity
from shared.infrastructure.rate_limit import rate_limit, check_poll_backoff
from shared.events import emit_activity
from .services import (
@@ -98,7 +99,8 @@ def register(url_prefix="/auth"):
async with get_session() as s:
async with s.begin():
grant = OAuthGrant(
token=grant_token,
token=None,
token_hash=hash_token(grant_token),
user_id=g.user.id,
client_id=client_id,
issuer_session=account_sid,
@@ -107,12 +109,14 @@ def register(url_prefix="/auth"):
s.add(grant)
oauth_code = OAuthCode(
code=code,
code=None,
code_hash=hash_token(code),
user_id=g.user.id,
client_id=client_id,
redirect_uri=redirect_uri,
expires_at=expires,
grant_token=grant_token,
grant_token=None,
grant_token_hash=hash_token(grant_token),
)
s.add(oauth_code)
@@ -149,11 +153,15 @@ def register(url_prefix="/auth"):
now = datetime.now(timezone.utc)
code_h = hash_token(code)
async with get_session() as s:
async with s.begin():
# Look up by hash first (new grants), fall back to plaintext (migration)
result = await s.execute(
select(OAuthCode)
.where(OAuthCode.code == code)
.where(
(OAuthCode.code_hash == code_h) | (OAuthCode.code == code)
)
.with_for_update()
)
oauth_code = result.scalar_one_or_none()
@@ -197,9 +205,12 @@ def register(url_prefix="/auth"):
if not token:
return jsonify({"valid": False}), 200
token_h = hash_token(token)
async with get_session() as s:
grant = await s.scalar(
select(OAuthGrant).where(OAuthGrant.token == token)
select(OAuthGrant).where(
(OAuthGrant.token_hash == token_h) | (OAuthGrant.token == token)
)
)
if not grant or grant.revoked_at is not None:
return jsonify({"valid": False}), 200
@@ -257,12 +268,19 @@ def register(url_prefix="/auth"):
store_login_redirect_target()
cross_cart_sid = request.args.get("cart_sid")
if cross_cart_sid:
import re
# Validate cart_sid is a hex token (32 chars from token_hex(16))
if re.fullmatch(r"[0-9a-f]{32}", cross_cart_sid):
qsession["cart_sid"] = cross_cart_sid
if g.get("user"):
redirect_url = pop_login_redirect_target()
return redirect(redirect_url)
return await render_template("auth/login.html")
@rate_limit(
key_func=lambda: request.headers.get("X-Forwarded-For", request.remote_addr),
max_requests=10, window_seconds=900, scope="magic_ip",
)
@auth_bp.post("/start/")
async def start_login():
form = await request.form
@@ -279,6 +297,22 @@ def register(url_prefix="/auth"):
400,
)
# Per-email rate limit: 5 magic links per 15 minutes
from shared.infrastructure.rate_limit import _check_rate_limit
try:
allowed, _ = await _check_rate_limit(f"magic_email:{email}", 5, 900)
if not allowed:
return (
await render_template(
"auth/check_email.html",
email=email,
email_error=None,
),
200,
)
except Exception:
pass # Redis down — allow the request
user = await find_or_create_user(g.s, email)
token, expires = await create_magic_link(g.s, user.id)
@@ -521,7 +555,8 @@ def register(url_prefix="/auth"):
async with get_session() as s:
async with s.begin():
grant = OAuthGrant(
token=grant_token,
token=None,
token_hash=hash_token(grant_token),
user_id=user.id,
client_id=blob["client_id"],
issuer_session=account_sid,
@@ -546,6 +581,10 @@ def register(url_prefix="/auth"):
return True
@rate_limit(
key_func=lambda: request.headers.get("X-Forwarded-For", request.remote_addr),
max_requests=10, window_seconds=3600, scope="dev_auth",
)
@csrf_exempt
@auth_bp.post("/device/authorize")
@auth_bp.post("/device/authorize/")
@@ -600,6 +639,14 @@ def register(url_prefix="/auth"):
if not device_code or client_id not in ALLOWED_CLIENTS:
return jsonify({"error": "invalid_request"}), 400
# Enforce polling backoff per RFC 8628
try:
poll_ok, interval = await check_poll_backoff(device_code)
if not poll_ok:
return jsonify({"error": "slow_down", "interval": interval}), 400
except Exception:
pass # Redis down — allow the request
from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis()

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header():
if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -17,6 +17,9 @@ def register() -> Blueprint:
async def _require_action_header():
if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -15,6 +15,7 @@ from html import escape as html_escape
from typing import Dict, Any, Optional
import httpx
import nh3
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,6 +30,35 @@ GHOST_ADMIN_API_URL = os.environ["GHOST_ADMIN_API_URL"]
from shared.browser.app.utils import utcnow
def _sanitize_html(html: str | None) -> str | None:
"""Sanitize HTML content using nh3, allowing safe formatting tags."""
if not html:
return html
return nh3.clean(
html,
tags={
"a", "abbr", "acronym", "b", "blockquote", "br", "code",
"div", "em", "figcaption", "figure", "h1", "h2", "h3",
"h4", "h5", "h6", "hr", "i", "img", "li", "ol", "p",
"pre", "span", "strong", "sub", "sup", "table", "tbody",
"td", "th", "thead", "tr", "ul", "video", "source",
"picture", "iframe", "audio",
},
attributes={
"*": {"class", "id", "style"},
"a": {"href", "title", "target", "rel"},
"img": {"src", "alt", "title", "width", "height", "loading"},
"video": {"src", "controls", "width", "height", "poster"},
"audio": {"src", "controls"},
"source": {"src", "type"},
"iframe": {"src", "width", "height", "frameborder", "allowfullscreen"},
"td": {"colspan", "rowspan"},
"th": {"colspan", "rowspan"},
},
url_schemes={"http", "https", "mailto"},
)
def _auth_header() -> dict[str, str]:
return {"Authorization": f"Ghost {make_ghost_admin_jwt()}"}
@@ -99,13 +129,13 @@ def _apply_ghost_fields(obj: Post, gp: Dict[str, Any], author_map: Dict[str, Aut
obj.uuid = gp.get("uuid") or obj.uuid
obj.slug = gp.get("slug") or obj.slug
obj.title = gp.get("title") or obj.title
obj.html = gp.get("html")
obj.html = _sanitize_html(gp.get("html"))
obj.plaintext = gp.get("plaintext")
obj.mobiledoc = gp.get("mobiledoc")
obj.lexical = gp.get("lexical")
obj.feature_image = gp.get("feature_image")
obj.feature_image_alt = gp.get("feature_image_alt")
obj.feature_image_caption = gp.get("feature_image_caption")
obj.feature_image_caption = _sanitize_html(gp.get("feature_image_caption"))
obj.excerpt = gp.get("excerpt")
obj.custom_excerpt = gp.get("custom_excerpt")
obj.visibility = gp.get("visibility") or obj.visibility

View File

@@ -35,6 +35,9 @@ def register() -> Blueprint:
async def _require_data_header():
if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header():
if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header():
if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -11,12 +11,23 @@ from shared.browser.app.payments.sumup import create_checkout as sumup_create_ch
from shared.config import config
from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page
from shared.infrastructure.cart_identity import current_cart_identity
from bp.cart.services import check_sumup_status
from shared.browser.app.utils.htmx import is_htmx_request
from .filters.qs import makeqs_factory, decode
def _owner_filter():
"""Return SQLAlchemy clause restricting orders to current user/session."""
ident = current_cart_identity()
if ident["user_id"]:
return Order.user_id == ident["user_id"]
if ident["session_id"]:
return Order.session_id == ident["session_id"]
return None
def register() -> Blueprint:
bp = Blueprint("order", __name__, url_prefix='/<int:order_id>')
@@ -32,12 +43,15 @@ def register() -> Blueprint:
"""
Show a single order + items.
"""
owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute(
select(Order)
.options(
selectinload(Order.items).selectinload(OrderItem.product)
)
.where(Order.id == order_id)
.where(Order.id == order_id, owner)
)
order = result.scalar_one_or_none()
if not order:
@@ -58,7 +72,10 @@ def register() -> Blueprint:
If already paid, just go back to the order detail.
If not, (re)create a SumUp checkout and redirect.
"""
result = await g.s.execute(select(Order).where(Order.id == order_id))
owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute(select(Order).where(Order.id == order_id, owner))
order = result.scalar_one_or_none()
if not order:
return await make_response("Order not found", 404)
@@ -115,7 +132,10 @@ def register() -> Blueprint:
Manually re-check this order's status with SumUp.
Useful if the webhook hasn't fired or the user didn't return correctly.
"""
result = await g.s.execute(select(Order).where(Order.id == order_id))
owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute(select(Order).where(Order.id == order_id, owner))
order = result.scalar_one_or_none()
if not order:
return await make_response("Order not found", 404)

View File

@@ -11,6 +11,7 @@ from shared.browser.app.payments.sumup import create_checkout as sumup_create_ch
from shared.config import config
from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page
from shared.infrastructure.cart_identity import current_cart_identity
from bp.cart.services import check_sumup_status
from shared.browser.app.utils.htmx import is_htmx_request
from bp import register_order
@@ -42,9 +43,25 @@ def register(url_prefix: str) -> Blueprint:
# this is the crucial bit for the |qs filter
g.makeqs_factory = makeqs_factory
@bp.before_request
async def _require_identity():
"""Orders require a logged-in user or at least a cart session."""
ident = current_cart_identity()
if not ident["user_id"] and not ident["session_id"]:
return redirect(url_for("auth.login_form"))
@bp.get("/")
async def list_orders():
# --- ownership: only show orders belonging to current user/session ---
ident = current_cart_identity()
if ident["user_id"]:
owner_clause = Order.user_id == ident["user_id"]
elif ident["session_id"]:
owner_clause = Order.session_id == ident["session_id"]
else:
return redirect(url_for("auth.login_form"))
# --- decode filters from query string (page + search) ---
q = decode()
page, search = q.page, q.search
@@ -97,8 +114,8 @@ def register(url_prefix: str) -> Blueprint:
where_clause = or_(*conditions)
# --- total count & total pages (respecting search) ---
count_stmt = select(func.count()).select_from(Order)
# --- total count & total pages (respecting search + ownership) ---
count_stmt = select(func.count()).select_from(Order).where(owner_clause)
if where_clause is not None:
count_stmt = count_stmt.where(where_clause)
@@ -110,10 +127,11 @@ def register(url_prefix: str) -> Blueprint:
if page > total_pages:
page = total_pages
# --- paginated orders (respecting search) ---
# --- paginated orders (respecting search + ownership) ---
offset = (page - 1) * ORDERS_PER_PAGE
stmt = (
select(Order)
.where(owner_clause)
.order_by(Order.created_at.desc())
.offset(offset)
.limit(ORDERS_PER_PAGE)

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header():
if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header():
if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}
@@ -131,8 +134,9 @@ def register() -> Blueprint:
period_start = datetime.fromisoformat(request.args.get("period_start", ""))
period_end = datetime.fromisoformat(request.args.get("period_end", ""))
user_id = request.args.get("user_id", type=int)
is_admin = request.args.get("is_admin", "false").lower() == "true"
session_id = request.args.get("session_id")
# is_admin determined server-side, never from client params
is_admin = False
entries = await services.calendar.visible_entries_for_period(
g.s, calendar_id, period_start, period_end,
user_id=user_id, is_admin=is_admin, session_id=session_id,

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header():
if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Dict, List, Tuple, Iterable, Optional
import nh3
from quart import Blueprint, request, jsonify, g
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,10 +30,18 @@ from models.market import (
from shared.browser.app.redis_cacher import clear_cache
from shared.browser.app.csrf import csrf_exempt
from shared.infrastructure.internal_auth import validate_internal_request
products_api = Blueprint("products_api", __name__, url_prefix="/api/products")
@products_api.before_request
async def _require_internal_auth():
"""All product API endpoints require HMAC-signed internal requests."""
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
# ---- Comparison config (matches your schema) --------------------------------
PRODUCT_FIELDS: List[str] = [
@@ -219,9 +228,35 @@ def _deep_equal(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
# ---- Mutation helpers -------------------------------------------------------
_PRODUCT_HTML_FIELDS = {"description_html"}
_SANITIZE_TAGS = {
"a", "b", "blockquote", "br", "code", "div", "em", "h1", "h2", "h3",
"h4", "h5", "h6", "hr", "i", "img", "li", "ol", "p", "pre", "span",
"strong", "sub", "sup", "table", "tbody", "td", "th", "thead", "tr",
"ul", "figure", "figcaption",
}
_SANITIZE_ATTRS = {
"*": {"class", "id"},
"a": {"href", "title", "target", "rel"},
"img": {"src", "alt", "title", "width", "height", "loading"},
"td": {"colspan", "rowspan"},
"th": {"colspan", "rowspan"},
}
def _sanitize_product_html(value: Any) -> Any:
if isinstance(value, str) and value:
return nh3.clean(value, tags=_SANITIZE_TAGS, attributes=_SANITIZE_ATTRS)
return value
def _apply_product_fields(p: Product, payload: Dict[str, Any]) -> None:
for f in PRODUCT_FIELDS:
setattr(p, f, payload.get(f))
val = payload.get(f)
if f in _PRODUCT_HTML_FIELDS:
val = _sanitize_product_html(val)
setattr(p, f, val)
p.updated_at = _now_utc()
def _replace_children(p: Product, payload: Dict[str, Any]) -> None:
@@ -239,7 +274,7 @@ def _replace_children(p: Product, payload: Dict[str, Any]) -> None:
for row in payload.get("sections") or []:
p.sections.append(ProductSection(
title=row.get("title") or "",
html=row.get("html") or "",
html=_sanitize_product_html(row.get("html") or ""),
created_at=_now_utc(), updated_at=_now_utc(),
))

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header():
if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {}

View File

@@ -57,10 +57,13 @@ async def protect() -> None:
if _is_exempt_endpoint():
return
# Internal service-to-service calls are already gated by header checks
# and only reachable on the Docker overlay network.
# Internal service-to-service calls — validate HMAC signature
if request.headers.get("X-Internal-Action") or request.headers.get("X-Internal-Data"):
from shared.infrastructure.internal_auth import validate_internal_request
if validate_internal_request():
return
# Reject unsigned internal requests
abort(403, "Invalid internal request signature")
session_token = qsession.get("csrf_token")
if not session_token:

View File

@@ -13,6 +13,8 @@ import os
import httpx
from shared.infrastructure.internal_auth import sign_internal_headers
log = logging.getLogger(__name__)
# Re-usable async client (created lazily, one per process)
@@ -65,10 +67,11 @@ async def call_action(
base = _internal_url(app_name)
url = f"{base}/internal/actions/{action_name}"
try:
headers = {ACTION_HEADER: "1", **sign_internal_headers(app_name)}
resp = await _get_client().post(
url,
json=payload or {},
headers={ACTION_HEADER: "1"},
headers=headers,
timeout=timeout,
)
if 200 <= resp.status_code < 300:

View File

@@ -328,9 +328,10 @@ def create_activitypub_blueprint(app_name: str) -> Blueprint:
if not sig_valid:
log.warning(
"Unverified inbox POST from %s (%s) on %saccepting anyway for now",
"Unverified inbox POST from %s (%s) on %srejecting",
from_actor_url, activity_type, domain,
)
abort(401, "Invalid or missing HTTP signature")
# Load actor row for DB operations
actor_row = (

View File

@@ -29,8 +29,43 @@ AP_CONTENT_TYPE = "application/activity+json"
# Helpers
# ---------------------------------------------------------------------------
def _is_safe_url(url: str) -> bool:
"""Reject URLs pointing to private/internal IPs to prevent SSRF."""
from urllib.parse import urlparse
import ipaddress
parsed = urlparse(url)
# Require HTTPS
if parsed.scheme != "https":
return False
hostname = parsed.hostname
if not hostname:
return False
# Block obvious internal hostnames
if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"):
return False
try:
addr = ipaddress.ip_address(hostname)
if addr.is_private or addr.is_loopback or addr.is_reserved or addr.is_link_local:
return False
except ValueError:
# Not an IP literal — hostname is fine (DNS resolution handled by httpx)
# Block common internal DNS patterns
if hostname.endswith(".internal") or hostname.endswith(".local"):
return False
return True
async def fetch_remote_actor(actor_url: str) -> dict | None:
"""Fetch a remote actor's JSON-LD profile."""
if not _is_safe_url(actor_url):
log.warning("Blocked SSRF attempt: %s", actor_url)
return None
try:
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get(

View File

@@ -13,6 +13,8 @@ import os
import httpx
from shared.infrastructure.internal_auth import sign_internal_headers
log = logging.getLogger(__name__)
# Re-usable async client (created lazily, one per process)
@@ -66,10 +68,11 @@ async def fetch_data(
base = _internal_url(app_name)
url = f"{base}/internal/data/{query_name}"
try:
headers = {DATA_HEADER: "1", **sign_internal_headers(app_name)}
resp = await _get_client().get(
url,
params=params,
headers={DATA_HEADER: "1"},
headers=headers,
timeout=timeout,
)
if resp.status_code == 200:

View File

@@ -77,7 +77,13 @@ def create_base_app(
configure_logging(name)
app.secret_key = os.getenv("SECRET_KEY", "dev-secret-key-change-me-777")
secret_key = os.getenv("SECRET_KEY")
if not secret_key:
env = os.getenv("ENVIRONMENT", "development")
if env in ("production", "staging"):
raise RuntimeError("SECRET_KEY environment variable must be set in production")
secret_key = "dev-secret-key-change-me-777"
app.secret_key = secret_key
# Per-app first-party session cookie (no shared domain — avoids Safari ITP)
app.config["SESSION_COOKIE_NAME"] = f"{name}_session"
@@ -192,11 +198,14 @@ def create_base_app(
from sqlalchemy import select
from shared.db.session import get_account_session
from shared.models.oauth_grant import OAuthGrant
from shared.models.oauth_grant import OAuthGrant, hash_token
try:
token_h = hash_token(grant_token)
async with get_account_session() as s:
grant = await s.scalar(
select(OAuthGrant).where(OAuthGrant.token == grant_token)
select(OAuthGrant).where(
(OAuthGrant.token_hash == token_h) | (OAuthGrant.token == grant_token)
)
)
valid = grant is not None and grant.revoked_at is None
except Exception:

View File

@@ -0,0 +1,92 @@
"""HMAC-based authentication for internal service-to-service calls.
Replaces the previous header-presence-only check with a signed token
that includes a timestamp to prevent replay attacks.
Signing side (data_client.py / actions.py)::
from shared.infrastructure.internal_auth import sign_internal_headers
headers = sign_internal_headers("cart")
Validation side (before_request guards, csrf.py)::
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
abort(403)
"""
from __future__ import annotations
import hashlib
import hmac
import os
import time
from quart import request
# Shared secret — MUST be set in production
_SECRET = os.getenv("INTERNAL_HMAC_SECRET", "").encode() or os.getenv("SECRET_KEY", "").encode()
# Maximum age of a signed request (seconds)
_MAX_AGE = 300 # 5 minutes
def _get_secret() -> bytes:
return _SECRET or os.getenv("SECRET_KEY", "dev-secret-key-change-me-777").encode()
def sign_internal_headers(app_name: str) -> dict[str, str]:
"""Generate signed headers for an internal request.
Returns a dict of headers to include in the request.
"""
ts = str(int(time.time()))
payload = f"{ts}:{app_name}".encode()
sig = hmac.new(_get_secret(), payload, hashlib.sha256).hexdigest()
return {
"X-Internal-Timestamp": ts,
"X-Internal-App": app_name,
"X-Internal-Signature": sig,
}
def validate_internal_request() -> bool:
"""Validate that an incoming request has a valid HMAC signature.
Checks X-Internal-Timestamp, X-Internal-App, and X-Internal-Signature
headers. Returns True if valid, False otherwise.
"""
ts = request.headers.get("X-Internal-Timestamp", "")
app_name = request.headers.get("X-Internal-App", "")
sig = request.headers.get("X-Internal-Signature", "")
if not ts or not app_name or not sig:
return False
# Check timestamp freshness
try:
req_time = int(ts)
except (ValueError, TypeError):
return False
now = int(time.time())
if abs(now - req_time) > _MAX_AGE:
return False
# Verify signature
payload = f"{ts}:{app_name}".encode()
expected = hmac.new(_get_secret(), payload, hashlib.sha256).hexdigest()
return hmac.compare_digest(sig, expected)
def is_internal_request() -> bool:
"""Check if the current request is a signed internal request.
This is a convenience that checks for any of the internal headers
(legacy or new HMAC-signed).
"""
# New HMAC-signed headers
if request.headers.get("X-Internal-Signature"):
return validate_internal_request()
# Legacy: presence-only headers (still accepted during migration,
# but callers should be updated to use signed headers)
return False

View File

@@ -0,0 +1,142 @@
"""Redis-based rate limiter for auth endpoints.
Provides a decorator that enforces per-key rate limits using a sliding
window counter stored in Redis (auth DB 15).
Usage::
from shared.infrastructure.rate_limit import rate_limit
@rate_limit(key_func=lambda: request.form.get("email", "").lower(),
max_requests=5, window_seconds=900, scope="magic_link")
@bp.post("/start/")
async def start_login():
...
"""
from __future__ import annotations
import functools
import time
from quart import request, jsonify, make_response
async def _check_rate_limit(
key: str,
max_requests: int,
window_seconds: int,
) -> tuple[bool, int]:
"""Check and increment rate limit counter.
Returns (allowed, remaining).
"""
from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis()
now = time.time()
window_start = now - window_seconds
redis_key = f"rl:{key}"
pipe = r.pipeline()
# Remove expired entries
pipe.zremrangebyscore(redis_key, 0, window_start)
# Add current request
pipe.zadd(redis_key, {str(now).encode(): now})
# Count entries in window
pipe.zcard(redis_key)
# Set TTL so key auto-expires
pipe.expire(redis_key, window_seconds)
results = await pipe.execute()
count = results[2]
allowed = count <= max_requests
remaining = max(0, max_requests - count)
return allowed, remaining
def rate_limit(
*,
key_func,
max_requests: int,
window_seconds: int,
scope: str,
):
"""Decorator that rate-limits a Quart route.
Parameters
----------
key_func:
Callable returning the rate-limit key (e.g. email, IP).
Called inside request context.
max_requests:
Maximum number of requests allowed in the window.
window_seconds:
Sliding window duration in seconds.
scope:
Namespace prefix for the Redis key (e.g. "magic_link").
"""
def decorator(fn):
@functools.wraps(fn)
async def wrapper(*args, **kwargs):
raw_key = key_func()
if not raw_key:
return await fn(*args, **kwargs)
full_key = f"{scope}:{raw_key}"
try:
allowed, remaining = await _check_rate_limit(
full_key, max_requests, window_seconds,
)
except Exception:
# If Redis is down, allow the request
return await fn(*args, **kwargs)
if not allowed:
resp = await make_response(
jsonify({"error": "rate_limited", "retry_after": window_seconds}),
429,
)
resp.headers["Retry-After"] = str(window_seconds)
return resp
return await fn(*args, **kwargs)
return wrapper
return decorator
async def check_poll_backoff(device_code: str) -> tuple[bool, int]:
"""Enforce exponential backoff on device token polling.
Returns (allowed, interval) where interval is the recommended
poll interval in seconds. If not allowed, caller should return
a 'slow_down' error per RFC 8628.
"""
from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis()
key = f"rl:devpoll:{device_code}"
now = time.time()
raw = await r.get(key)
if raw:
data = raw.decode() if isinstance(raw, bytes) else raw
parts = data.split(":")
last_poll = float(parts[0])
interval = int(parts[1])
elapsed = now - last_poll
if elapsed < interval:
# Too fast — increase interval
new_interval = min(interval + 5, 60)
await r.set(key, f"{now}:{new_interval}".encode(), ex=900)
return False, new_interval
# Acceptable pace — keep current interval
await r.set(key, f"{now}:{interval}".encode(), ex=900)
return True, interval
# First poll
initial_interval = 5
await r.set(key, f"{now}:{initial_interval}".encode(), ex=900)
return True, initial_interval

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import hashlib
from datetime import datetime
from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -6,21 +7,28 @@ from shared.db.base import Base
class OAuthCode(Base):
"""Short-lived authorization code issued during OAuth flow.
The ``code`` column is retained during migration but new codes store
only ``code_hash``. Lookups should use ``code_hash``.
"""
__tablename__ = "oauth_codes"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
code: Mapped[str] = mapped_column(String(128), unique=True, index=True, nullable=False)
code: Mapped[str | None] = mapped_column(String(128), nullable=True)
code_hash: Mapped[str | None] = mapped_column(String(64), unique=True, nullable=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
client_id: Mapped[str] = mapped_column(String(64), nullable=False)
redirect_uri: Mapped[str] = mapped_column(String(512), nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
grant_token: Mapped[str | None] = mapped_column(String(128), nullable=True)
grant_token_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
user = relationship("User", backref="oauth_codes")
__table_args__ = (
Index("ix_oauth_code_code", "code", unique=True),
Index("ix_oauth_code_code_hash", "code_hash", unique=True),
Index("ix_oauth_code_user", "user_id"),
)

View File

@@ -1,21 +1,31 @@
from __future__ import annotations
import hashlib
from datetime import datetime
from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.db.base import Base
def hash_token(token: str) -> str:
"""SHA-256 hash a token for secure DB storage."""
return hashlib.sha256(token.encode()).hexdigest()
class OAuthGrant(Base):
"""Long-lived grant tracking each client-app session authorization.
Created when the OAuth authorize endpoint issues a code. Tied to the
account session that issued it (``issuer_session``) so that logging out
on one device revokes only that device's grants.
The ``token`` column is retained during migration but new grants store
only ``token_hash``. Lookups should use ``token_hash``.
"""
__tablename__ = "oauth_grants"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
token: Mapped[str] = mapped_column(String(128), unique=True, nullable=False)
token: Mapped[str | None] = mapped_column(String(128), nullable=True)
token_hash: Mapped[str | None] = mapped_column(String(64), unique=True, nullable=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
client_id: Mapped[str] = mapped_column(String(64), nullable=False)
issuer_session: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
@@ -26,7 +36,7 @@ class OAuthGrant(Base):
user = relationship("User", backref="oauth_grants")
__table_args__ = (
Index("ix_oauth_grant_token", "token", unique=True),
Index("ix_oauth_grant_token_hash", "token_hash", unique=True),
Index("ix_oauth_grant_issuer", "issuer_session"),
Index("ix_oauth_grant_device", "device_id", "client_id"),
)

View File

@@ -44,6 +44,7 @@ Werkzeug==3.1.3
wsproto==1.2.0
zstandard==0.25.0
redis>=5.0
nh3>=0.2.14
mistune>=3.0
pytest>=8.0
pytest-asyncio>=0.23