diff --git a/server.py b/server.py index 442321d..86d81bc 100644 --- a/server.py +++ b/server.py @@ -86,11 +86,20 @@ redis_client = redis.Redis( RUNS_KEY_PREFIX = "artdag:run:" RECIPES_KEY_PREFIX = "artdag:recipe:" REVOKED_KEY_PREFIX = "artdag:revoked:" +USER_TOKENS_PREFIX = "artdag:user_tokens:" # Token revocation (30 day expiry to match token lifetime) TOKEN_EXPIRY_SECONDS = 60 * 60 * 24 * 30 +def register_user_token(username: str, token: str) -> None: + """Track a token for a user (for later revocation by username).""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{USER_TOKENS_PREFIX}{username}" + redis_client.sadd(key, token_hash) + redis_client.expire(key, TOKEN_EXPIRY_SECONDS) + + def revoke_token(token: str) -> bool: """Add token to revocation set. Returns True if newly revoked.""" token_hash = hashlib.sha256(token.encode()).hexdigest() @@ -99,6 +108,26 @@ def revoke_token(token: str) -> bool: return result is not None +def revoke_token_hash(token_hash: str) -> bool: + """Add token hash to revocation set. Returns True if newly revoked.""" + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + result = redis_client.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) + return result is not None + + +def revoke_all_user_tokens(username: str) -> int: + """Revoke all tokens for a user. Returns count revoked.""" + key = f"{USER_TOKENS_PREFIX}{username}" + token_hashes = redis_client.smembers(key) + count = 0 + for token_hash in token_hashes: + if revoke_token_hash(token_hash.decode() if isinstance(token_hash, bytes) else token_hash): + count += 1 + # Clear the user's token set + redis_client.delete(key) + return count + + def is_token_revoked(token: str) -> bool: """Check if token has been revoked.""" token_hash = hashlib.sha256(token.encode()).hexdigest() @@ -3880,6 +3909,9 @@ async def auth_callback(auth_token: str = None): if not ctx: return RedirectResponse(url="/", status_code=302) + # Register token for this user (for revocation by username later) + register_user_token(ctx.username, auth_token) + # Set local first-party cookie and redirect to home response = RedirectResponse(url="/runs", status_code=302) response.set_cookie( @@ -3923,6 +3955,26 @@ async def auth_revoke(credentials: HTTPAuthorizationCredentials = Depends(securi return {"revoked": True, "newly_revoked": newly_revoked} +class RevokeUserRequest(BaseModel): + username: str + l2_server: str # L2 server requesting the revocation + + +@app.post("/auth/revoke-user") +async def auth_revoke_user(request: RevokeUserRequest): + """ + Revoke all tokens for a user. Called by L2 when user logs out. + This handles the case where L2 issued scoped tokens that differ from L2's own token. + """ + # Verify the L2 server is authorized (must be in L1's known list or match token's l2_server) + # For now, we trust any request since this only affects users already on this L1 + + # Revoke all tokens registered for this user + count = revoke_all_user_tokens(request.username) + + return {"revoked": True, "tokens_revoked": count, "username": request.username} + + @app.post("/ui/publish-run/{run_id}", response_class=HTMLResponse) async def ui_publish_run(run_id: str, request: Request): """Publish a run to L2 from the web UI. Assets are named by content_hash."""