""" Auth Service - token management and user verification. """ import hashlib import base64 import json from typing import Optional, Dict, Any, TYPE_CHECKING import httpx from artdag_common.middleware.auth import UserContext from ..config import settings if TYPE_CHECKING: import redis from starlette.requests import Request # Token expiry (30 days to match token lifetime) TOKEN_EXPIRY_SECONDS = 60 * 60 * 24 * 30 # Redis key prefixes REVOKED_KEY_PREFIX = "artdag:revoked:" USER_TOKENS_PREFIX = "artdag:user_tokens:" class AuthService: """Service for authentication and token management.""" def __init__(self, redis_client: "redis.Redis[bytes]") -> None: self.redis = redis_client def register_user_token(self, username: str, token: str) -> None: """Track a token for a user (for later revocation by username).""" token_hash = hashlib.sha256(token.encode()).hexdigest() key = f"{USER_TOKENS_PREFIX}{username}" self.redis.sadd(key, token_hash) self.redis.expire(key, TOKEN_EXPIRY_SECONDS) def revoke_token(self, token: str) -> bool: """Add token to revocation set. Returns True if newly revoked.""" token_hash = hashlib.sha256(token.encode()).hexdigest() key = f"{REVOKED_KEY_PREFIX}{token_hash}" result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) return result is not None def revoke_token_hash(self, token_hash: str) -> bool: """Add token hash to revocation set. Returns True if newly revoked.""" key = f"{REVOKED_KEY_PREFIX}{token_hash}" result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) return result is not None def revoke_all_user_tokens(self, username: str) -> int: """Revoke all tokens for a user. Returns count revoked.""" key = f"{USER_TOKENS_PREFIX}{username}" token_hashes = self.redis.smembers(key) count = 0 for token_hash in token_hashes: if self.revoke_token_hash( token_hash.decode() if isinstance(token_hash, bytes) else token_hash ): count += 1 self.redis.delete(key) return count def is_token_revoked(self, token: str) -> bool: """Check if token has been revoked.""" token_hash = hashlib.sha256(token.encode()).hexdigest() key = f"{REVOKED_KEY_PREFIX}{token_hash}" return self.redis.exists(key) > 0 def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]: """Decode JWT claims without verification.""" try: parts = token.split(".") if len(parts) != 3: return None payload = parts[1] # Add padding padding = 4 - len(payload) % 4 if padding != 4: payload += "=" * padding return json.loads(base64.urlsafe_b64decode(payload)) except (json.JSONDecodeError, ValueError): return None def get_user_context_from_token(self, token: str) -> Optional[UserContext]: """Extract user context from a token.""" if self.is_token_revoked(token): return None claims = self.decode_token_claims(token) if not claims: return None username = claims.get("username") or claims.get("sub") actor_id = claims.get("actor_id") or claims.get("actor") if not username: return None return UserContext( username=username, actor_id=actor_id or f"@{username}", token=token, l2_server=settings.l2_server, ) async def verify_token_with_l2(self, token: str) -> Optional[UserContext]: """Verify token with L2 server.""" ctx = self.get_user_context_from_token(token) if not ctx: return None # If L2 server configured, verify token if settings.l2_server: try: async with httpx.AsyncClient() as client: resp = await client.get( f"{settings.l2_server}/auth/verify", headers={"Authorization": f"Bearer {token}"}, timeout=5.0, ) if resp.status_code != 200: return None except httpx.RequestError: # L2 unavailable, trust the token pass return ctx def get_user_from_cookie(self, request: "Request") -> Optional[UserContext]: """Extract user context from auth cookie.""" token = request.cookies.get("auth_token") if not token: return None return self.get_user_context_from_token(token)