"""HTTP Signature verification for incoming AP-style inbox requests. Implements the same RSA-SHA256 / PKCS1v15 scheme used by the coop's shared/utils/http_signatures.py, but only the verification side. """ from __future__ import annotations import base64 import re from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding def verify_request_signature( public_key_pem: str, signature_header: str, method: str, path: str, headers: dict[str, str], ) -> bool: """Verify an incoming HTTP Signature. Args: public_key_pem: PEM-encoded public key of the sender. signature_header: Value of the ``Signature`` header. method: HTTP method (GET, POST, etc.). path: Request path (e.g. ``/inbox``). headers: All request headers (case-insensitive keys). Returns: True if the signature is valid. """ parts = _parse_signature_header(signature_header) signed_headers = parts.get("headers", "date").split() signature_b64 = parts.get("signature", "") # Reconstruct the signed string lc_headers = {k.lower(): v for k, v in headers.items()} lines: list[str] = [] for h in signed_headers: if h == "(request-target)": lines.append(f"(request-target): {method.lower()} {path}") else: lines.append(f"{h}: {lc_headers.get(h, '')}") signed_string = "\n".join(lines) public_key = serialization.load_pem_public_key(public_key_pem.encode()) try: public_key.verify( base64.b64decode(signature_b64), signed_string.encode(), padding.PKCS1v15(), hashes.SHA256(), ) return True except Exception: return False def parse_key_id(signature_header: str) -> str: """Extract the keyId from a Signature header. keyId is typically ``https://domain/users/username#main-key``. Returns the actor URL (strips ``#main-key``). """ parts = _parse_signature_header(signature_header) key_id = parts.get("keyId", "") return re.sub(r"#.*$", "", key_id) def _parse_signature_header(header: str) -> dict[str, str]: """Parse a Signature header into its component parts.""" parts: dict[str, str] = {} for part in header.split(","): part = part.strip() eq = part.find("=") if eq < 0: continue key = part[:eq] val = part[eq + 1:].strip('"') parts[key] = val return parts