Security audit: fix IDOR, add rate limiting, HMAC auth, token hashing, XSS sanitization

Critical: Add ownership checks to all order routes (IDOR fix).
High: Redis rate limiting on auth endpoints, HMAC-signed internal
service calls replacing header-presence-only checks, nh3 HTML
sanitization on ghost_sync and product import, internal auth on
market API endpoints, SHA-256 hashed OAuth grant/code tokens.
Medium: SECRET_KEY production guard, AP signature enforcement,
is_admin param removal, cart_sid validation, SSRF protection on
remote actor fetch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-26 13:30:27 +00:00
parent 404449fcab
commit c015f3f02f
27 changed files with 607 additions and 33 deletions

View File

@@ -0,0 +1,86 @@
"""Add token_hash columns to oauth_grants and oauth_codes
Revision ID: acct_0002
Revises: acct_0001
Create Date: 2026-02-26
"""
import hashlib
import sqlalchemy as sa
from alembic import op
revision = "acct_0002"
down_revision = "acct_0001"
branch_labels = None
depends_on = None
def _hash(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest()
def upgrade():
# Add new hash columns
op.add_column("oauth_grants", sa.Column("token_hash", sa.String(64), nullable=True))
op.add_column("oauth_codes", sa.Column("code_hash", sa.String(64), nullable=True))
op.add_column("oauth_codes", sa.Column("grant_token_hash", sa.String(64), nullable=True))
# Backfill hashes from existing plaintext tokens
conn = op.get_bind()
grants = conn.execute(sa.text("SELECT id, token FROM oauth_grants WHERE token IS NOT NULL"))
for row in grants:
conn.execute(
sa.text("UPDATE oauth_grants SET token_hash = :h WHERE id = :id"),
{"h": _hash(row.token), "id": row.id},
)
codes = conn.execute(sa.text("SELECT id, code, grant_token FROM oauth_codes WHERE code IS NOT NULL"))
for row in codes:
params = {"id": row.id, "ch": _hash(row.code)}
params["gh"] = _hash(row.grant_token) if row.grant_token else None
conn.execute(
sa.text("UPDATE oauth_codes SET code_hash = :ch, grant_token_hash = :gh WHERE id = :id"),
params,
)
# Create unique indexes on hash columns
op.create_index("ix_oauth_grant_token_hash", "oauth_grants", ["token_hash"], unique=True)
op.create_index("ix_oauth_code_code_hash", "oauth_codes", ["code_hash"], unique=True)
# Make original token columns nullable (keep for rollback safety)
op.alter_column("oauth_grants", "token", nullable=True)
op.alter_column("oauth_codes", "code", nullable=True)
# Drop old unique indexes on plaintext columns
try:
op.drop_index("ix_oauth_grant_token", "oauth_grants")
except Exception:
pass
try:
op.drop_index("ix_oauth_code_code", "oauth_codes")
except Exception:
pass
def downgrade():
# Restore original NOT NULL constraints
op.alter_column("oauth_grants", "token", nullable=False)
op.alter_column("oauth_codes", "code", nullable=False)
# Drop hash columns and indexes
try:
op.drop_index("ix_oauth_grant_token_hash", "oauth_grants")
except Exception:
pass
try:
op.drop_index("ix_oauth_code_code_hash", "oauth_codes")
except Exception:
pass
op.drop_column("oauth_grants", "token_hash")
op.drop_column("oauth_codes", "code_hash")
op.drop_column("oauth_codes", "grant_token_hash")
# Restore original unique indexes
op.create_index("ix_oauth_grant_token", "oauth_grants", ["token"], unique=True)
op.create_index("ix_oauth_code_code", "oauth_codes", ["code"], unique=True)

View File

@@ -17,6 +17,9 @@ def register() -> Blueprint:
async def _require_action_header(): async def _require_action_header():
if not request.headers.get(ACTION_HEADER): if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -26,9 +26,10 @@ from sqlalchemy.exc import SQLAlchemyError
from shared.db.session import get_session from shared.db.session import get_session
from shared.models import User from shared.models import User
from shared.models.oauth_code import OAuthCode from shared.models.oauth_code import OAuthCode
from shared.models.oauth_grant import OAuthGrant from shared.models.oauth_grant import OAuthGrant, hash_token
from shared.infrastructure.urls import account_url, app_url from shared.infrastructure.urls import account_url, app_url
from shared.infrastructure.cart_identity import current_cart_identity from shared.infrastructure.cart_identity import current_cart_identity
from shared.infrastructure.rate_limit import rate_limit, check_poll_backoff
from shared.events import emit_activity from shared.events import emit_activity
from .services import ( from .services import (
@@ -98,7 +99,8 @@ def register(url_prefix="/auth"):
async with get_session() as s: async with get_session() as s:
async with s.begin(): async with s.begin():
grant = OAuthGrant( grant = OAuthGrant(
token=grant_token, token=None,
token_hash=hash_token(grant_token),
user_id=g.user.id, user_id=g.user.id,
client_id=client_id, client_id=client_id,
issuer_session=account_sid, issuer_session=account_sid,
@@ -107,12 +109,14 @@ def register(url_prefix="/auth"):
s.add(grant) s.add(grant)
oauth_code = OAuthCode( oauth_code = OAuthCode(
code=code, code=None,
code_hash=hash_token(code),
user_id=g.user.id, user_id=g.user.id,
client_id=client_id, client_id=client_id,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
expires_at=expires, expires_at=expires,
grant_token=grant_token, grant_token=None,
grant_token_hash=hash_token(grant_token),
) )
s.add(oauth_code) s.add(oauth_code)
@@ -149,11 +153,15 @@ def register(url_prefix="/auth"):
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
code_h = hash_token(code)
async with get_session() as s: async with get_session() as s:
async with s.begin(): async with s.begin():
# Look up by hash first (new grants), fall back to plaintext (migration)
result = await s.execute( result = await s.execute(
select(OAuthCode) select(OAuthCode)
.where(OAuthCode.code == code) .where(
(OAuthCode.code_hash == code_h) | (OAuthCode.code == code)
)
.with_for_update() .with_for_update()
) )
oauth_code = result.scalar_one_or_none() oauth_code = result.scalar_one_or_none()
@@ -197,9 +205,12 @@ def register(url_prefix="/auth"):
if not token: if not token:
return jsonify({"valid": False}), 200 return jsonify({"valid": False}), 200
token_h = hash_token(token)
async with get_session() as s: async with get_session() as s:
grant = await s.scalar( grant = await s.scalar(
select(OAuthGrant).where(OAuthGrant.token == token) select(OAuthGrant).where(
(OAuthGrant.token_hash == token_h) | (OAuthGrant.token == token)
)
) )
if not grant or grant.revoked_at is not None: if not grant or grant.revoked_at is not None:
return jsonify({"valid": False}), 200 return jsonify({"valid": False}), 200
@@ -257,12 +268,19 @@ def register(url_prefix="/auth"):
store_login_redirect_target() store_login_redirect_target()
cross_cart_sid = request.args.get("cart_sid") cross_cart_sid = request.args.get("cart_sid")
if cross_cart_sid: if cross_cart_sid:
qsession["cart_sid"] = cross_cart_sid import re
# Validate cart_sid is a hex token (32 chars from token_hex(16))
if re.fullmatch(r"[0-9a-f]{32}", cross_cart_sid):
qsession["cart_sid"] = cross_cart_sid
if g.get("user"): if g.get("user"):
redirect_url = pop_login_redirect_target() redirect_url = pop_login_redirect_target()
return redirect(redirect_url) return redirect(redirect_url)
return await render_template("auth/login.html") return await render_template("auth/login.html")
@rate_limit(
key_func=lambda: request.headers.get("X-Forwarded-For", request.remote_addr),
max_requests=10, window_seconds=900, scope="magic_ip",
)
@auth_bp.post("/start/") @auth_bp.post("/start/")
async def start_login(): async def start_login():
form = await request.form form = await request.form
@@ -279,6 +297,22 @@ def register(url_prefix="/auth"):
400, 400,
) )
# Per-email rate limit: 5 magic links per 15 minutes
from shared.infrastructure.rate_limit import _check_rate_limit
try:
allowed, _ = await _check_rate_limit(f"magic_email:{email}", 5, 900)
if not allowed:
return (
await render_template(
"auth/check_email.html",
email=email,
email_error=None,
),
200,
)
except Exception:
pass # Redis down — allow the request
user = await find_or_create_user(g.s, email) user = await find_or_create_user(g.s, email)
token, expires = await create_magic_link(g.s, user.id) token, expires = await create_magic_link(g.s, user.id)
@@ -521,7 +555,8 @@ def register(url_prefix="/auth"):
async with get_session() as s: async with get_session() as s:
async with s.begin(): async with s.begin():
grant = OAuthGrant( grant = OAuthGrant(
token=grant_token, token=None,
token_hash=hash_token(grant_token),
user_id=user.id, user_id=user.id,
client_id=blob["client_id"], client_id=blob["client_id"],
issuer_session=account_sid, issuer_session=account_sid,
@@ -546,6 +581,10 @@ def register(url_prefix="/auth"):
return True return True
@rate_limit(
key_func=lambda: request.headers.get("X-Forwarded-For", request.remote_addr),
max_requests=10, window_seconds=3600, scope="dev_auth",
)
@csrf_exempt @csrf_exempt
@auth_bp.post("/device/authorize") @auth_bp.post("/device/authorize")
@auth_bp.post("/device/authorize/") @auth_bp.post("/device/authorize/")
@@ -600,6 +639,14 @@ def register(url_prefix="/auth"):
if not device_code or client_id not in ALLOWED_CLIENTS: if not device_code or client_id not in ALLOWED_CLIENTS:
return jsonify({"error": "invalid_request"}), 400 return jsonify({"error": "invalid_request"}), 400
# Enforce polling backoff per RFC 8628
try:
poll_ok, interval = await check_poll_backoff(device_code)
if not poll_ok:
return jsonify({"error": "slow_down", "interval": interval}), 400
except Exception:
pass # Redis down — allow the request
from shared.infrastructure.auth_redis import get_auth_redis from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis() r = await get_auth_redis()

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header(): async def _require_data_header():
if not request.headers.get(DATA_HEADER): if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -17,6 +17,9 @@ def register() -> Blueprint:
async def _require_action_header(): async def _require_action_header():
if not request.headers.get(ACTION_HEADER): if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -15,6 +15,7 @@ from html import escape as html_escape
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import httpx import httpx
import nh3
from sqlalchemy import select, delete from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,6 +30,35 @@ GHOST_ADMIN_API_URL = os.environ["GHOST_ADMIN_API_URL"]
from shared.browser.app.utils import utcnow from shared.browser.app.utils import utcnow
def _sanitize_html(html: str | None) -> str | None:
"""Sanitize HTML content using nh3, allowing safe formatting tags."""
if not html:
return html
return nh3.clean(
html,
tags={
"a", "abbr", "acronym", "b", "blockquote", "br", "code",
"div", "em", "figcaption", "figure", "h1", "h2", "h3",
"h4", "h5", "h6", "hr", "i", "img", "li", "ol", "p",
"pre", "span", "strong", "sub", "sup", "table", "tbody",
"td", "th", "thead", "tr", "ul", "video", "source",
"picture", "iframe", "audio",
},
attributes={
"*": {"class", "id", "style"},
"a": {"href", "title", "target", "rel"},
"img": {"src", "alt", "title", "width", "height", "loading"},
"video": {"src", "controls", "width", "height", "poster"},
"audio": {"src", "controls"},
"source": {"src", "type"},
"iframe": {"src", "width", "height", "frameborder", "allowfullscreen"},
"td": {"colspan", "rowspan"},
"th": {"colspan", "rowspan"},
},
url_schemes={"http", "https", "mailto"},
)
def _auth_header() -> dict[str, str]: def _auth_header() -> dict[str, str]:
return {"Authorization": f"Ghost {make_ghost_admin_jwt()}"} return {"Authorization": f"Ghost {make_ghost_admin_jwt()}"}
@@ -99,13 +129,13 @@ def _apply_ghost_fields(obj: Post, gp: Dict[str, Any], author_map: Dict[str, Aut
obj.uuid = gp.get("uuid") or obj.uuid obj.uuid = gp.get("uuid") or obj.uuid
obj.slug = gp.get("slug") or obj.slug obj.slug = gp.get("slug") or obj.slug
obj.title = gp.get("title") or obj.title obj.title = gp.get("title") or obj.title
obj.html = gp.get("html") obj.html = _sanitize_html(gp.get("html"))
obj.plaintext = gp.get("plaintext") obj.plaintext = gp.get("plaintext")
obj.mobiledoc = gp.get("mobiledoc") obj.mobiledoc = gp.get("mobiledoc")
obj.lexical = gp.get("lexical") obj.lexical = gp.get("lexical")
obj.feature_image = gp.get("feature_image") obj.feature_image = gp.get("feature_image")
obj.feature_image_alt = gp.get("feature_image_alt") obj.feature_image_alt = gp.get("feature_image_alt")
obj.feature_image_caption = gp.get("feature_image_caption") obj.feature_image_caption = _sanitize_html(gp.get("feature_image_caption"))
obj.excerpt = gp.get("excerpt") obj.excerpt = gp.get("excerpt")
obj.custom_excerpt = gp.get("custom_excerpt") obj.custom_excerpt = gp.get("custom_excerpt")
obj.visibility = gp.get("visibility") or obj.visibility obj.visibility = gp.get("visibility") or obj.visibility

View File

@@ -35,6 +35,9 @@ def register() -> Blueprint:
async def _require_data_header(): async def _require_data_header():
if not request.headers.get(DATA_HEADER): if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header(): async def _require_action_header():
if not request.headers.get(ACTION_HEADER): if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header(): async def _require_data_header():
if not request.headers.get(DATA_HEADER): if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -11,12 +11,23 @@ from shared.browser.app.payments.sumup import create_checkout as sumup_create_ch
from shared.config import config from shared.config import config
from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page
from shared.infrastructure.cart_identity import current_cart_identity
from bp.cart.services import check_sumup_status from bp.cart.services import check_sumup_status
from shared.browser.app.utils.htmx import is_htmx_request from shared.browser.app.utils.htmx import is_htmx_request
from .filters.qs import makeqs_factory, decode from .filters.qs import makeqs_factory, decode
def _owner_filter():
"""Return SQLAlchemy clause restricting orders to current user/session."""
ident = current_cart_identity()
if ident["user_id"]:
return Order.user_id == ident["user_id"]
if ident["session_id"]:
return Order.session_id == ident["session_id"]
return None
def register() -> Blueprint: def register() -> Blueprint:
bp = Blueprint("order", __name__, url_prefix='/<int:order_id>') bp = Blueprint("order", __name__, url_prefix='/<int:order_id>')
@@ -32,12 +43,15 @@ def register() -> Blueprint:
""" """
Show a single order + items. Show a single order + items.
""" """
owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute( result = await g.s.execute(
select(Order) select(Order)
.options( .options(
selectinload(Order.items).selectinload(OrderItem.product) selectinload(Order.items).selectinload(OrderItem.product)
) )
.where(Order.id == order_id) .where(Order.id == order_id, owner)
) )
order = result.scalar_one_or_none() order = result.scalar_one_or_none()
if not order: if not order:
@@ -58,7 +72,10 @@ def register() -> Blueprint:
If already paid, just go back to the order detail. If already paid, just go back to the order detail.
If not, (re)create a SumUp checkout and redirect. If not, (re)create a SumUp checkout and redirect.
""" """
result = await g.s.execute(select(Order).where(Order.id == order_id)) owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute(select(Order).where(Order.id == order_id, owner))
order = result.scalar_one_or_none() order = result.scalar_one_or_none()
if not order: if not order:
return await make_response("Order not found", 404) return await make_response("Order not found", 404)
@@ -115,7 +132,10 @@ def register() -> Blueprint:
Manually re-check this order's status with SumUp. Manually re-check this order's status with SumUp.
Useful if the webhook hasn't fired or the user didn't return correctly. Useful if the webhook hasn't fired or the user didn't return correctly.
""" """
result = await g.s.execute(select(Order).where(Order.id == order_id)) owner = _owner_filter()
if owner is None:
return await make_response("Order not found", 404)
result = await g.s.execute(select(Order).where(Order.id == order_id, owner))
order = result.scalar_one_or_none() order = result.scalar_one_or_none()
if not order: if not order:
return await make_response("Order not found", 404) return await make_response("Order not found", 404)

View File

@@ -11,6 +11,7 @@ from shared.browser.app.payments.sumup import create_checkout as sumup_create_ch
from shared.config import config from shared.config import config
from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page from shared.infrastructure.http_utils import vary as _vary, current_url_without_page as _current_url_without_page
from shared.infrastructure.cart_identity import current_cart_identity
from bp.cart.services import check_sumup_status from bp.cart.services import check_sumup_status
from shared.browser.app.utils.htmx import is_htmx_request from shared.browser.app.utils.htmx import is_htmx_request
from bp import register_order from bp import register_order
@@ -42,9 +43,25 @@ def register(url_prefix: str) -> Blueprint:
# this is the crucial bit for the |qs filter # this is the crucial bit for the |qs filter
g.makeqs_factory = makeqs_factory g.makeqs_factory = makeqs_factory
@bp.before_request
async def _require_identity():
"""Orders require a logged-in user or at least a cart session."""
ident = current_cart_identity()
if not ident["user_id"] and not ident["session_id"]:
return redirect(url_for("auth.login_form"))
@bp.get("/") @bp.get("/")
async def list_orders(): async def list_orders():
# --- ownership: only show orders belonging to current user/session ---
ident = current_cart_identity()
if ident["user_id"]:
owner_clause = Order.user_id == ident["user_id"]
elif ident["session_id"]:
owner_clause = Order.session_id == ident["session_id"]
else:
return redirect(url_for("auth.login_form"))
# --- decode filters from query string (page + search) --- # --- decode filters from query string (page + search) ---
q = decode() q = decode()
page, search = q.page, q.search page, search = q.page, q.search
@@ -97,8 +114,8 @@ def register(url_prefix: str) -> Blueprint:
where_clause = or_(*conditions) where_clause = or_(*conditions)
# --- total count & total pages (respecting search) --- # --- total count & total pages (respecting search + ownership) ---
count_stmt = select(func.count()).select_from(Order) count_stmt = select(func.count()).select_from(Order).where(owner_clause)
if where_clause is not None: if where_clause is not None:
count_stmt = count_stmt.where(where_clause) count_stmt = count_stmt.where(where_clause)
@@ -110,10 +127,11 @@ def register(url_prefix: str) -> Blueprint:
if page > total_pages: if page > total_pages:
page = total_pages page = total_pages
# --- paginated orders (respecting search) --- # --- paginated orders (respecting search + ownership) ---
offset = (page - 1) * ORDERS_PER_PAGE offset = (page - 1) * ORDERS_PER_PAGE
stmt = ( stmt = (
select(Order) select(Order)
.where(owner_clause)
.order_by(Order.created_at.desc()) .order_by(Order.created_at.desc())
.offset(offset) .offset(offset)
.limit(ORDERS_PER_PAGE) .limit(ORDERS_PER_PAGE)

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header(): async def _require_action_header():
if not request.headers.get(ACTION_HEADER): if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header(): async def _require_data_header():
if not request.headers.get(DATA_HEADER): if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}
@@ -131,8 +134,9 @@ def register() -> Blueprint:
period_start = datetime.fromisoformat(request.args.get("period_start", "")) period_start = datetime.fromisoformat(request.args.get("period_start", ""))
period_end = datetime.fromisoformat(request.args.get("period_end", "")) period_end = datetime.fromisoformat(request.args.get("period_end", ""))
user_id = request.args.get("user_id", type=int) user_id = request.args.get("user_id", type=int)
is_admin = request.args.get("is_admin", "false").lower() == "true"
session_id = request.args.get("session_id") session_id = request.args.get("session_id")
# is_admin determined server-side, never from client params
is_admin = False
entries = await services.calendar.visible_entries_for_period( entries = await services.calendar.visible_entries_for_period(
g.s, calendar_id, period_start, period_end, g.s, calendar_id, period_start, period_end,
user_id=user_id, is_admin=is_admin, session_id=session_id, user_id=user_id, is_admin=is_admin, session_id=session_id,

View File

@@ -18,6 +18,9 @@ def register() -> Blueprint:
async def _require_action_header(): async def _require_action_header():
if not request.headers.get(ACTION_HEADER): if not request.headers.get(ACTION_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
from decimal import Decimal from decimal import Decimal
from typing import Any, Dict, List, Tuple, Iterable, Optional from typing import Any, Dict, List, Tuple, Iterable, Optional
import nh3
from quart import Blueprint, request, jsonify, g from quart import Blueprint, request, jsonify, g
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,10 +30,18 @@ from models.market import (
from shared.browser.app.redis_cacher import clear_cache from shared.browser.app.redis_cacher import clear_cache
from shared.browser.app.csrf import csrf_exempt from shared.browser.app.csrf import csrf_exempt
from shared.infrastructure.internal_auth import validate_internal_request
products_api = Blueprint("products_api", __name__, url_prefix="/api/products") products_api = Blueprint("products_api", __name__, url_prefix="/api/products")
@products_api.before_request
async def _require_internal_auth():
"""All product API endpoints require HMAC-signed internal requests."""
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
# ---- Comparison config (matches your schema) -------------------------------- # ---- Comparison config (matches your schema) --------------------------------
PRODUCT_FIELDS: List[str] = [ PRODUCT_FIELDS: List[str] = [
@@ -219,9 +228,35 @@ def _deep_equal(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
# ---- Mutation helpers ------------------------------------------------------- # ---- Mutation helpers -------------------------------------------------------
_PRODUCT_HTML_FIELDS = {"description_html"}
_SANITIZE_TAGS = {
"a", "b", "blockquote", "br", "code", "div", "em", "h1", "h2", "h3",
"h4", "h5", "h6", "hr", "i", "img", "li", "ol", "p", "pre", "span",
"strong", "sub", "sup", "table", "tbody", "td", "th", "thead", "tr",
"ul", "figure", "figcaption",
}
_SANITIZE_ATTRS = {
"*": {"class", "id"},
"a": {"href", "title", "target", "rel"},
"img": {"src", "alt", "title", "width", "height", "loading"},
"td": {"colspan", "rowspan"},
"th": {"colspan", "rowspan"},
}
def _sanitize_product_html(value: Any) -> Any:
if isinstance(value, str) and value:
return nh3.clean(value, tags=_SANITIZE_TAGS, attributes=_SANITIZE_ATTRS)
return value
def _apply_product_fields(p: Product, payload: Dict[str, Any]) -> None: def _apply_product_fields(p: Product, payload: Dict[str, Any]) -> None:
for f in PRODUCT_FIELDS: for f in PRODUCT_FIELDS:
setattr(p, f, payload.get(f)) val = payload.get(f)
if f in _PRODUCT_HTML_FIELDS:
val = _sanitize_product_html(val)
setattr(p, f, val)
p.updated_at = _now_utc() p.updated_at = _now_utc()
def _replace_children(p: Product, payload: Dict[str, Any]) -> None: def _replace_children(p: Product, payload: Dict[str, Any]) -> None:
@@ -239,7 +274,7 @@ def _replace_children(p: Product, payload: Dict[str, Any]) -> None:
for row in payload.get("sections") or []: for row in payload.get("sections") or []:
p.sections.append(ProductSection( p.sections.append(ProductSection(
title=row.get("title") or "", title=row.get("title") or "",
html=row.get("html") or "", html=_sanitize_product_html(row.get("html") or ""),
created_at=_now_utc(), updated_at=_now_utc(), created_at=_now_utc(), updated_at=_now_utc(),
)) ))

View File

@@ -19,6 +19,9 @@ def register() -> Blueprint:
async def _require_data_header(): async def _require_data_header():
if not request.headers.get(DATA_HEADER): if not request.headers.get(DATA_HEADER):
return jsonify({"error": "forbidden"}), 403 return jsonify({"error": "forbidden"}), 403
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
return jsonify({"error": "forbidden"}), 403
_handlers: dict[str, object] = {} _handlers: dict[str, object] = {}

View File

@@ -57,10 +57,13 @@ async def protect() -> None:
if _is_exempt_endpoint(): if _is_exempt_endpoint():
return return
# Internal service-to-service calls are already gated by header checks # Internal service-to-service calls — validate HMAC signature
# and only reachable on the Docker overlay network.
if request.headers.get("X-Internal-Action") or request.headers.get("X-Internal-Data"): if request.headers.get("X-Internal-Action") or request.headers.get("X-Internal-Data"):
return from shared.infrastructure.internal_auth import validate_internal_request
if validate_internal_request():
return
# Reject unsigned internal requests
abort(403, "Invalid internal request signature")
session_token = qsession.get("csrf_token") session_token = qsession.get("csrf_token")
if not session_token: if not session_token:

View File

@@ -13,6 +13,8 @@ import os
import httpx import httpx
from shared.infrastructure.internal_auth import sign_internal_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Re-usable async client (created lazily, one per process) # Re-usable async client (created lazily, one per process)
@@ -65,10 +67,11 @@ async def call_action(
base = _internal_url(app_name) base = _internal_url(app_name)
url = f"{base}/internal/actions/{action_name}" url = f"{base}/internal/actions/{action_name}"
try: try:
headers = {ACTION_HEADER: "1", **sign_internal_headers(app_name)}
resp = await _get_client().post( resp = await _get_client().post(
url, url,
json=payload or {}, json=payload or {},
headers={ACTION_HEADER: "1"}, headers=headers,
timeout=timeout, timeout=timeout,
) )
if 200 <= resp.status_code < 300: if 200 <= resp.status_code < 300:

View File

@@ -328,9 +328,10 @@ def create_activitypub_blueprint(app_name: str) -> Blueprint:
if not sig_valid: if not sig_valid:
log.warning( log.warning(
"Unverified inbox POST from %s (%s) on %saccepting anyway for now", "Unverified inbox POST from %s (%s) on %srejecting",
from_actor_url, activity_type, domain, from_actor_url, activity_type, domain,
) )
abort(401, "Invalid or missing HTTP signature")
# Load actor row for DB operations # Load actor row for DB operations
actor_row = ( actor_row = (

View File

@@ -29,8 +29,43 @@ AP_CONTENT_TYPE = "application/activity+json"
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _is_safe_url(url: str) -> bool:
"""Reject URLs pointing to private/internal IPs to prevent SSRF."""
from urllib.parse import urlparse
import ipaddress
parsed = urlparse(url)
# Require HTTPS
if parsed.scheme != "https":
return False
hostname = parsed.hostname
if not hostname:
return False
# Block obvious internal hostnames
if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"):
return False
try:
addr = ipaddress.ip_address(hostname)
if addr.is_private or addr.is_loopback or addr.is_reserved or addr.is_link_local:
return False
except ValueError:
# Not an IP literal — hostname is fine (DNS resolution handled by httpx)
# Block common internal DNS patterns
if hostname.endswith(".internal") or hostname.endswith(".local"):
return False
return True
async def fetch_remote_actor(actor_url: str) -> dict | None: async def fetch_remote_actor(actor_url: str) -> dict | None:
"""Fetch a remote actor's JSON-LD profile.""" """Fetch a remote actor's JSON-LD profile."""
if not _is_safe_url(actor_url):
log.warning("Blocked SSRF attempt: %s", actor_url)
return None
try: try:
async with httpx.AsyncClient(timeout=10) as client: async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get( resp = await client.get(

View File

@@ -13,6 +13,8 @@ import os
import httpx import httpx
from shared.infrastructure.internal_auth import sign_internal_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Re-usable async client (created lazily, one per process) # Re-usable async client (created lazily, one per process)
@@ -66,10 +68,11 @@ async def fetch_data(
base = _internal_url(app_name) base = _internal_url(app_name)
url = f"{base}/internal/data/{query_name}" url = f"{base}/internal/data/{query_name}"
try: try:
headers = {DATA_HEADER: "1", **sign_internal_headers(app_name)}
resp = await _get_client().get( resp = await _get_client().get(
url, url,
params=params, params=params,
headers={DATA_HEADER: "1"}, headers=headers,
timeout=timeout, timeout=timeout,
) )
if resp.status_code == 200: if resp.status_code == 200:

View File

@@ -77,7 +77,13 @@ def create_base_app(
configure_logging(name) configure_logging(name)
app.secret_key = os.getenv("SECRET_KEY", "dev-secret-key-change-me-777") secret_key = os.getenv("SECRET_KEY")
if not secret_key:
env = os.getenv("ENVIRONMENT", "development")
if env in ("production", "staging"):
raise RuntimeError("SECRET_KEY environment variable must be set in production")
secret_key = "dev-secret-key-change-me-777"
app.secret_key = secret_key
# Per-app first-party session cookie (no shared domain — avoids Safari ITP) # Per-app first-party session cookie (no shared domain — avoids Safari ITP)
app.config["SESSION_COOKIE_NAME"] = f"{name}_session" app.config["SESSION_COOKIE_NAME"] = f"{name}_session"
@@ -192,11 +198,14 @@ def create_base_app(
from sqlalchemy import select from sqlalchemy import select
from shared.db.session import get_account_session from shared.db.session import get_account_session
from shared.models.oauth_grant import OAuthGrant from shared.models.oauth_grant import OAuthGrant, hash_token
try: try:
token_h = hash_token(grant_token)
async with get_account_session() as s: async with get_account_session() as s:
grant = await s.scalar( grant = await s.scalar(
select(OAuthGrant).where(OAuthGrant.token == grant_token) select(OAuthGrant).where(
(OAuthGrant.token_hash == token_h) | (OAuthGrant.token == grant_token)
)
) )
valid = grant is not None and grant.revoked_at is None valid = grant is not None and grant.revoked_at is None
except Exception: except Exception:

View File

@@ -0,0 +1,92 @@
"""HMAC-based authentication for internal service-to-service calls.
Replaces the previous header-presence-only check with a signed token
that includes a timestamp to prevent replay attacks.
Signing side (data_client.py / actions.py)::
from shared.infrastructure.internal_auth import sign_internal_headers
headers = sign_internal_headers("cart")
Validation side (before_request guards, csrf.py)::
from shared.infrastructure.internal_auth import validate_internal_request
if not validate_internal_request():
abort(403)
"""
from __future__ import annotations
import hashlib
import hmac
import os
import time
from quart import request
# Shared secret — MUST be set in production
_SECRET = os.getenv("INTERNAL_HMAC_SECRET", "").encode() or os.getenv("SECRET_KEY", "").encode()
# Maximum age of a signed request (seconds)
_MAX_AGE = 300 # 5 minutes
def _get_secret() -> bytes:
return _SECRET or os.getenv("SECRET_KEY", "dev-secret-key-change-me-777").encode()
def sign_internal_headers(app_name: str) -> dict[str, str]:
"""Generate signed headers for an internal request.
Returns a dict of headers to include in the request.
"""
ts = str(int(time.time()))
payload = f"{ts}:{app_name}".encode()
sig = hmac.new(_get_secret(), payload, hashlib.sha256).hexdigest()
return {
"X-Internal-Timestamp": ts,
"X-Internal-App": app_name,
"X-Internal-Signature": sig,
}
def validate_internal_request() -> bool:
"""Validate that an incoming request has a valid HMAC signature.
Checks X-Internal-Timestamp, X-Internal-App, and X-Internal-Signature
headers. Returns True if valid, False otherwise.
"""
ts = request.headers.get("X-Internal-Timestamp", "")
app_name = request.headers.get("X-Internal-App", "")
sig = request.headers.get("X-Internal-Signature", "")
if not ts or not app_name or not sig:
return False
# Check timestamp freshness
try:
req_time = int(ts)
except (ValueError, TypeError):
return False
now = int(time.time())
if abs(now - req_time) > _MAX_AGE:
return False
# Verify signature
payload = f"{ts}:{app_name}".encode()
expected = hmac.new(_get_secret(), payload, hashlib.sha256).hexdigest()
return hmac.compare_digest(sig, expected)
def is_internal_request() -> bool:
"""Check if the current request is a signed internal request.
This is a convenience that checks for any of the internal headers
(legacy or new HMAC-signed).
"""
# New HMAC-signed headers
if request.headers.get("X-Internal-Signature"):
return validate_internal_request()
# Legacy: presence-only headers (still accepted during migration,
# but callers should be updated to use signed headers)
return False

View File

@@ -0,0 +1,142 @@
"""Redis-based rate limiter for auth endpoints.
Provides a decorator that enforces per-key rate limits using a sliding
window counter stored in Redis (auth DB 15).
Usage::
from shared.infrastructure.rate_limit import rate_limit
@rate_limit(key_func=lambda: request.form.get("email", "").lower(),
max_requests=5, window_seconds=900, scope="magic_link")
@bp.post("/start/")
async def start_login():
...
"""
from __future__ import annotations
import functools
import time
from quart import request, jsonify, make_response
async def _check_rate_limit(
key: str,
max_requests: int,
window_seconds: int,
) -> tuple[bool, int]:
"""Check and increment rate limit counter.
Returns (allowed, remaining).
"""
from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis()
now = time.time()
window_start = now - window_seconds
redis_key = f"rl:{key}"
pipe = r.pipeline()
# Remove expired entries
pipe.zremrangebyscore(redis_key, 0, window_start)
# Add current request
pipe.zadd(redis_key, {str(now).encode(): now})
# Count entries in window
pipe.zcard(redis_key)
# Set TTL so key auto-expires
pipe.expire(redis_key, window_seconds)
results = await pipe.execute()
count = results[2]
allowed = count <= max_requests
remaining = max(0, max_requests - count)
return allowed, remaining
def rate_limit(
*,
key_func,
max_requests: int,
window_seconds: int,
scope: str,
):
"""Decorator that rate-limits a Quart route.
Parameters
----------
key_func:
Callable returning the rate-limit key (e.g. email, IP).
Called inside request context.
max_requests:
Maximum number of requests allowed in the window.
window_seconds:
Sliding window duration in seconds.
scope:
Namespace prefix for the Redis key (e.g. "magic_link").
"""
def decorator(fn):
@functools.wraps(fn)
async def wrapper(*args, **kwargs):
raw_key = key_func()
if not raw_key:
return await fn(*args, **kwargs)
full_key = f"{scope}:{raw_key}"
try:
allowed, remaining = await _check_rate_limit(
full_key, max_requests, window_seconds,
)
except Exception:
# If Redis is down, allow the request
return await fn(*args, **kwargs)
if not allowed:
resp = await make_response(
jsonify({"error": "rate_limited", "retry_after": window_seconds}),
429,
)
resp.headers["Retry-After"] = str(window_seconds)
return resp
return await fn(*args, **kwargs)
return wrapper
return decorator
async def check_poll_backoff(device_code: str) -> tuple[bool, int]:
"""Enforce exponential backoff on device token polling.
Returns (allowed, interval) where interval is the recommended
poll interval in seconds. If not allowed, caller should return
a 'slow_down' error per RFC 8628.
"""
from shared.infrastructure.auth_redis import get_auth_redis
r = await get_auth_redis()
key = f"rl:devpoll:{device_code}"
now = time.time()
raw = await r.get(key)
if raw:
data = raw.decode() if isinstance(raw, bytes) else raw
parts = data.split(":")
last_poll = float(parts[0])
interval = int(parts[1])
elapsed = now - last_poll
if elapsed < interval:
# Too fast — increase interval
new_interval = min(interval + 5, 60)
await r.set(key, f"{now}:{new_interval}".encode(), ex=900)
return False, new_interval
# Acceptable pace — keep current interval
await r.set(key, f"{now}:{interval}".encode(), ex=900)
return True, interval
# First poll
initial_interval = 5
await r.set(key, f"{now}:{initial_interval}".encode(), ex=900)
return True, initial_interval

View File

@@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import hashlib
from datetime import datetime from datetime import datetime
from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -6,21 +7,28 @@ from shared.db.base import Base
class OAuthCode(Base): class OAuthCode(Base):
"""Short-lived authorization code issued during OAuth flow.
The ``code`` column is retained during migration but new codes store
only ``code_hash``. Lookups should use ``code_hash``.
"""
__tablename__ = "oauth_codes" __tablename__ = "oauth_codes"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
code: Mapped[str] = mapped_column(String(128), unique=True, index=True, nullable=False) code: Mapped[str | None] = mapped_column(String(128), nullable=True)
code_hash: Mapped[str | None] = mapped_column(String(64), unique=True, nullable=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
client_id: Mapped[str] = mapped_column(String(64), nullable=False) client_id: Mapped[str] = mapped_column(String(64), nullable=False)
redirect_uri: Mapped[str] = mapped_column(String(512), nullable=False) redirect_uri: Mapped[str] = mapped_column(String(512), nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
grant_token: Mapped[str | None] = mapped_column(String(128), nullable=True) grant_token: Mapped[str | None] = mapped_column(String(128), nullable=True)
grant_token_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
user = relationship("User", backref="oauth_codes") user = relationship("User", backref="oauth_codes")
__table_args__ = ( __table_args__ = (
Index("ix_oauth_code_code", "code", unique=True), Index("ix_oauth_code_code_hash", "code_hash", unique=True),
Index("ix_oauth_code_user", "user_id"), Index("ix_oauth_code_user", "user_id"),
) )

View File

@@ -1,21 +1,31 @@
from __future__ import annotations from __future__ import annotations
import hashlib
from datetime import datetime from datetime import datetime
from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index from sqlalchemy import String, Integer, DateTime, ForeignKey, func, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.db.base import Base from shared.db.base import Base
def hash_token(token: str) -> str:
"""SHA-256 hash a token for secure DB storage."""
return hashlib.sha256(token.encode()).hexdigest()
class OAuthGrant(Base): class OAuthGrant(Base):
"""Long-lived grant tracking each client-app session authorization. """Long-lived grant tracking each client-app session authorization.
Created when the OAuth authorize endpoint issues a code. Tied to the Created when the OAuth authorize endpoint issues a code. Tied to the
account session that issued it (``issuer_session``) so that logging out account session that issued it (``issuer_session``) so that logging out
on one device revokes only that device's grants. on one device revokes only that device's grants.
The ``token`` column is retained during migration but new grants store
only ``token_hash``. Lookups should use ``token_hash``.
""" """
__tablename__ = "oauth_grants" __tablename__ = "oauth_grants"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
token: Mapped[str] = mapped_column(String(128), unique=True, nullable=False) token: Mapped[str | None] = mapped_column(String(128), nullable=True)
token_hash: Mapped[str | None] = mapped_column(String(64), unique=True, nullable=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
client_id: Mapped[str] = mapped_column(String(64), nullable=False) client_id: Mapped[str] = mapped_column(String(64), nullable=False)
issuer_session: Mapped[str] = mapped_column(String(128), nullable=False, index=True) issuer_session: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
@@ -26,7 +36,7 @@ class OAuthGrant(Base):
user = relationship("User", backref="oauth_grants") user = relationship("User", backref="oauth_grants")
__table_args__ = ( __table_args__ = (
Index("ix_oauth_grant_token", "token", unique=True), Index("ix_oauth_grant_token_hash", "token_hash", unique=True),
Index("ix_oauth_grant_issuer", "issuer_session"), Index("ix_oauth_grant_issuer", "issuer_session"),
Index("ix_oauth_grant_device", "device_id", "client_id"), Index("ix_oauth_grant_device", "device_id", "client_id"),
) )

View File

@@ -44,6 +44,7 @@ Werkzeug==3.1.3
wsproto==1.2.0 wsproto==1.2.0
zstandard==0.25.0 zstandard==0.25.0
redis>=5.0 redis>=5.0
nh3>=0.2.14
mistune>=3.0 mistune>=3.0
pytest>=8.0 pytest>=8.0
pytest-asyncio>=0.23 pytest-asyncio>=0.23