diff --git a/server.py b/server.py index a8c2264..122c772 100644 --- a/server.py +++ b/server.py @@ -8,6 +8,7 @@ Manages rendering runs and provides access to the cache. - GET /cache/{content_hash} - get cached content """ +import asyncio import base64 import hashlib import json @@ -334,8 +335,8 @@ def get_user_context_from_token(token: str) -> Optional[UserContext]: return UserContext(username=username, l2_server=l2_server, l2_domain=l2_domain) -def verify_token_with_l2(token: str, l2_server: str) -> Optional[str]: - """Verify token with the L2 server that issued it, return username if valid.""" +def _verify_token_with_l2_sync(token: str, l2_server: str) -> Optional[str]: + """Verify token with the L2 server that issued it, return username if valid. (Sync version)""" try: resp = http_requests.post( f"{l2_server}/auth/verify", @@ -349,14 +350,19 @@ def verify_token_with_l2(token: str, l2_server: str) -> Optional[str]: return None -def get_verified_user_context(token: str) -> Optional[UserContext]: +async def verify_token_with_l2(token: str, l2_server: str) -> Optional[str]: + """Verify token with the L2 server that issued it, return username if valid.""" + return await asyncio.to_thread(_verify_token_with_l2_sync, token, l2_server) + + +async def get_verified_user_context(token: str) -> Optional[UserContext]: """Get verified user context from token. Verifies with the L2 that issued it.""" ctx = get_user_context_from_token(token) if not ctx: return None - # Verify token with the L2 server from the token - verified_username = verify_token_with_l2(token, ctx.l2_server) + # Verify token with the L2 server from the token (non-blocking) + verified_username = await verify_token_with_l2(token, ctx.l2_server) if not verified_username: return None @@ -369,7 +375,7 @@ async def get_optional_user( """Get username if authenticated, None otherwise.""" if not credentials: return None - ctx = get_verified_user_context(credentials.credentials) + ctx = await get_verified_user_context(credentials.credentials) return ctx.username if ctx else None @@ -379,7 +385,7 @@ async def get_required_user( """Get username, raise 401 if not authenticated.""" if not credentials: raise HTTPException(401, "Not authenticated") - ctx = get_verified_user_context(credentials.credentials) + ctx = await get_verified_user_context(credentials.credentials) if not ctx: raise HTTPException(401, "Invalid token") return ctx.username @@ -391,7 +397,7 @@ async def get_required_user_context( """Get full user context, raise 401 if not authenticated.""" if not credentials: raise HTTPException(401, "Not authenticated") - ctx = get_verified_user_context(credentials.credentials) + ctx = await get_verified_user_context(credentials.credentials) if not ctx: raise HTTPException(401, "Invalid token") return ctx @@ -427,11 +433,12 @@ def get_cache_path(content_hash: str) -> Optional[Path]: @app.get("/api") async def api_info(): """Server info (JSON).""" + runs = await asyncio.to_thread(list_all_runs) return { "name": "Art DAG L1 Server", "version": "0.1.0", "cache_dir": str(CACHE_DIR), - "runs_count": len(list_all_runs()) + "runs_count": len(runs) } @@ -580,10 +587,21 @@ async def create_run(request: RunRequest, ctx: UserContext = Depends(get_require run.celery_task_id = task.id run.status = "running" - save_run(run) + await asyncio.to_thread(save_run, run) return run +def _check_celery_task_sync(task_id: str) -> tuple[bool, bool, Optional[dict], Optional[str]]: + """Check Celery task status synchronously. Returns (is_ready, is_successful, result, error).""" + task = celery_app.AsyncResult(task_id) + if not task.ready(): + return (False, False, None, None) + if task.successful(): + return (True, True, task.result, None) + else: + return (True, False, None, str(task.result)) + + @app.get("/runs/{run_id}", response_model=RunStatus) async def get_run(run_id: str): """Get status of a run.""" @@ -591,7 +609,7 @@ async def get_run(run_id: str): logger.info(f"get_run: Starting for {run_id}") t0 = time.time() - run = load_run(run_id) + run = await asyncio.to_thread(load_run, run_id) logger.info(f"get_run: load_run took {time.time()-t0:.3f}s, status={run.status if run else 'None'}") if not run: @@ -600,16 +618,13 @@ async def get_run(run_id: str): # Check Celery task status if running if run.status == "running" and run.celery_task_id: t0 = time.time() - task = celery_app.AsyncResult(run.celery_task_id) - logger.info(f"get_run: AsyncResult took {time.time()-t0:.3f}s") - - t0 = time.time() - is_ready = task.ready() - logger.info(f"get_run: task.ready() took {time.time()-t0:.3f}s, ready={is_ready}") + is_ready, is_successful, result, error = await asyncio.to_thread( + _check_celery_task_sync, run.celery_task_id + ) + logger.info(f"get_run: Celery check took {time.time()-t0:.3f}s, ready={is_ready}") if is_ready: - if task.successful(): - result = task.result + if is_successful: run.status = "completed" run.completed_at = datetime.now(timezone.utc).isoformat() @@ -642,18 +657,19 @@ async def get_run(run_id: str): # Record activity for deletion tracking (legacy mode) if run.output_hash and run.inputs: - cache_manager.record_simple_activity( + await asyncio.to_thread( + cache_manager.record_simple_activity, input_hashes=run.inputs, output_hash=run.output_hash, run_id=run.run_id, ) else: run.status = "failed" - run.error = str(task.result) + run.error = error # Save updated status t0 = time.time() - save_run(run) + await asyncio.to_thread(save_run, run) logger.info(f"get_run: save_run took {time.time()-t0:.3f}s") logger.info(f"get_run: Total time {time.time()-start:.3f}s") @@ -670,7 +686,7 @@ async def discard_run(run_id: str, ctx: UserContext = Depends(get_required_user_ - Deletes outputs and intermediate cache entries - Preserves inputs (cache items and recipes are NOT deleted) """ - run = load_run(run_id) + run = await asyncio.to_thread(load_run, run_id) if not run: raise HTTPException(404, f"Run {run_id} not found") @@ -688,16 +704,16 @@ async def discard_run(run_id: str, ctx: UserContext = Depends(get_required_user_ raise HTTPException(400, f"Cannot discard run: output {run.output_hash[:16]}... is pinned ({pin_reason})") # Check if activity exists for this run - activity = cache_manager.get_activity(run_id) + activity = await asyncio.to_thread(cache_manager.get_activity, run_id) if activity: # Discard the activity - only delete outputs, preserve inputs - success, msg = cache_manager.discard_activity_outputs_only(run_id) + success, msg = await asyncio.to_thread(cache_manager.discard_activity_outputs_only, run_id) if not success: raise HTTPException(400, f"Cannot discard run: {msg}") # Remove from Redis - redis_client.delete(f"{RUNS_KEY_PREFIX}{run_id}") + await asyncio.to_thread(redis_client.delete, f"{RUNS_KEY_PREFIX}{run_id}") return {"discarded": True, "run_id": run_id} @@ -709,7 +725,7 @@ async def ui_discard_run(run_id: str, request: Request): if not ctx: return '
Run not found: {run_id}
' @@ -757,10 +773,11 @@ async def run_detail(run_id: str, request: Request): # Check Celery task status if running if run.status == "running" and run.celery_task_id: - task = celery_app.AsyncResult(run.celery_task_id) - if task.ready(): - if task.successful(): - result = task.result + is_ready, is_successful, result, error = await asyncio.to_thread( + _check_celery_task_sync, run.celery_task_id + ) + if is_ready: + if is_successful: run.status = "completed" run.completed_at = datetime.now(timezone.utc).isoformat() run.output_hash = result.get("output", {}).get("content_hash") @@ -774,8 +791,8 @@ async def run_detail(run_id: str, request: Request): await cache_file(output_path) else: run.status = "failed" - run.error = str(task.result) - save_run(run) + run.error = error + await asyncio.to_thread(save_run, run) if wants_html(request): ctx = get_user_context_from_cookie(request) @@ -999,7 +1016,7 @@ async def list_runs(request: Request, page: int = 1, limit: int = 20): """List runs. HTML for browsers (with infinite scroll), JSON for APIs (with pagination).""" ctx = get_user_context_from_cookie(request) - all_runs = list_all_runs() + all_runs = await asyncio.to_thread(list_all_runs) total = len(all_runs) # Filter by user if logged in for HTML @@ -1166,7 +1183,7 @@ async def upload_recipe(file: UploadFile = File(...), ctx: UserContext = Depends except Exception as e: raise HTTPException(400, f"Failed to parse recipe: {e}") - save_recipe(recipe_status) + await asyncio.to_thread(save_recipe, recipe_status) # Save cache metadata to database await database.save_item_metadata( @@ -1191,7 +1208,7 @@ async def list_recipes_api(request: Request, page: int = 1, limit: int = 20): """List recipes. HTML for browsers, JSON for APIs.""" ctx = get_user_context_from_cookie(request) - all_recipes = list_all_recipes() + all_recipes = await asyncio.to_thread(list_all_recipes) if wants_html(request): # HTML response @@ -1307,7 +1324,7 @@ async def remove_recipe(recipe_id: str, ctx: UserContext = Depends(get_required_ @app.post("/recipes/{recipe_id}/run") async def run_recipe(recipe_id: str, request: RecipeRunRequest, ctx: UserContext = Depends(get_required_user_context)): """Run a recipe with provided variable inputs. Requires authentication.""" - recipe = load_recipe(recipe_id) + recipe = await asyncio.to_thread(load_recipe, recipe_id) if not recipe: raise HTTPException(404, f"Recipe {recipe_id} not found") @@ -1317,7 +1334,7 @@ async def run_recipe(recipe_id: str, request: RecipeRunRequest, ctx: UserContext raise HTTPException(400, f"Missing required input: {var_input.name}") # Load recipe YAML - recipe_path = cache_manager.get_by_content_hash(recipe_id) + recipe_path = await asyncio.to_thread(cache_manager.get_by_content_hash, recipe_id) if not recipe_path: raise HTTPException(500, "Recipe YAML not found in cache") @@ -1353,7 +1370,7 @@ async def run_recipe(recipe_id: str, request: RecipeRunRequest, ctx: UserContext run.celery_task_id = task.id run.status = "running" - save_run(run) + await asyncio.to_thread(save_run, run) return run @@ -2534,18 +2551,18 @@ async def discard_cache(content_hash: str, ctx: UserContext = Depends(get_requir raise HTTPException(400, f"Cannot discard pinned item (reason: {pin_reason})") # Check if used by any run (Redis runs, not just activity store) - runs_using = find_runs_using_content(content_hash) + runs_using = await asyncio.to_thread(find_runs_using_content, content_hash) if runs_using: run, role = runs_using[0] raise HTTPException(400, f"Cannot discard: item is {role} of run {run.run_id}") # Check deletion rules via cache_manager (L2 shared status, activity store) - can_delete, reason = cache_manager.can_delete(content_hash) + can_delete, reason = await asyncio.to_thread(cache_manager.can_delete, content_hash) if not can_delete: raise HTTPException(400, f"Cannot discard: {reason}") # Delete via cache_manager - success, msg = cache_manager.delete_by_content_hash(content_hash) + success, msg = await asyncio.to_thread(cache_manager.delete_by_content_hash, content_hash) if not success: # Fallback to legacy deletion cache_path = get_cache_path(content_hash) @@ -2576,7 +2593,8 @@ async def ui_discard_cache(content_hash: str, request: Request): return '