Files
celery/app/routers/api.py
gilesb 2e9ba46f19 Use UserContext from artdag_common, remove duplicate
- Updated requirements.txt to use art-common@11aa056 with l2_server field
- All routers now import UserContext from artdag_common
- Removed duplicate UserContext from auth_service.py
- dependencies.py sets l2_server from settings on user context

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 17:38:02 +00:00

262 lines
7.5 KiB
Python

"""
3-phase API routes for L1 server.
Provides the plan/execute/run-recipe endpoints for programmatic access.
"""
import hashlib
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional
import yaml
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from artdag_common.middleware.auth import UserContext
from ..dependencies import require_auth, get_redis_client, get_cache_manager
# Router that the plan/execute/run-recipe/run-status endpoints below attach to.
router = APIRouter()
logger = logging.getLogger(__name__)
# Redis key prefix: a run's status record is stored under this prefix
# immediately followed by its run_id.
RUNS_KEY_PREFIX = "artdag:run:"
class PlanRequest(BaseModel):
    """Request body for ``POST /plan``."""

    # Recipe definition as YAML text; handed unchanged to the plan task.
    recipe_yaml: str
    # String-to-string mapping handed to the plan task as ``input_hashes``.
    input_hashes: Dict[str, str]
    # Feature names handed to the plan task as ``features``.
    features: List[str] = ["beats", "energy"]
class ExecutePlanRequest(BaseModel):
    """Request body for ``POST /execute``."""

    # Serialized plan; handed unchanged to the execution task.
    plan_json: str
    # Caller-chosen run identifier. When omitted (or empty) the endpoint
    # generates a random UUID instead.
    run_id: Optional[str] = None
class RecipeRunRequest(BaseModel):
    """Request body for ``POST /run-recipe``."""

    # Recipe definition as YAML text. The endpoint reads its ``name`` field
    # and hashes the raw text when deriving the run_id.
    recipe_yaml: str
    # String-to-string mapping; its values feed the run_id computation and the
    # whole mapping is handed to the recipe task.
    input_hashes: Dict[str, str]
    # Feature names handed to the recipe task as ``features``.
    features: List[str] = ["beats", "energy"]
def compute_run_id(
    input_hashes: List[str],
    recipe: str,
    recipe_hash: Optional[str] = None,
) -> str:
    """Compute a deterministic run_id from the inputs and the recipe.

    Args:
        input_hashes: Hashes of the inputs. They are sorted before being
            hashed, so the order in which they are supplied does not affect
            the result.
        recipe: Recipe name. Only used, as ``effect:<recipe>``, when
            ``recipe_hash`` is ``None`` or empty.
        recipe_hash: Hash identifying the recipe. Takes precedence over
            ``recipe`` whenever it is a non-empty string.

    Returns:
        The hex-encoded SHA3-256 digest of the canonical JSON payload.
    """
    payload = {
        "inputs": sorted(input_hashes),
        "recipe": recipe_hash or f"effect:{recipe}",
        "version": "1",
    }
    # Sorted keys and compact separators make the serialization byte-stable,
    # so identical arguments always hash to the same run_id.
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha3_256(canonical.encode()).hexdigest()
@router.post("/plan")
async def generate_plan_endpoint(
    request: PlanRequest,
    ctx: UserContext = Depends(require_auth),
):
    """
    Generate an execution plan without executing it.

    Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model.
    Returns the plan with cache status for each step.

    Args:
        request: Recipe YAML, input hashes and feature names to plan for.
        ctx: Authenticated user context. Not read here; the dependency is
            declared so that the route requires authentication.

    Returns:
        A dict with the ``status``, ``recipe``, ``plan_id``, ``total_steps``,
        ``cached_steps``, ``pending_steps`` and ``steps`` entries of the task
        result (``None`` for any entry the task did not provide).

    Raises:
        HTTPException: 500 if submitting the task, waiting for it (at most
            60 seconds) or reading its result fails.
    """
    import asyncio
    import functools

    from tasks.orchestrate import generate_plan

    try:
        task = generate_plan.delay(
            recipe_yaml=request.recipe_yaml,
            input_hashes=request.input_hashes,
            features=request.features,
        )
        # task.get() blocks the calling thread until the worker replies.
        # Calling it directly inside this coroutine would freeze the event
        # loop (and every other request) for up to 60 seconds, so the wait is
        # pushed onto the default thread-pool executor instead.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, functools.partial(task.get, timeout=60)
        )
        return {
            "status": result.get("status"),
            "recipe": result.get("recipe"),
            "plan_id": result.get("plan_id"),
            "total_steps": result.get("total_steps"),
            "cached_steps": result.get("cached_steps"),
            "pending_steps": result.get("pending_steps"),
            "steps": result.get("steps"),
        }
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() drops.
        logger.exception("Plan generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/execute")
async def execute_plan_endpoint(
    request: ExecutePlanRequest,
    ctx: UserContext = Depends(require_auth),
):
    """
    Queue a pre-generated execution plan (Phase 3, Execute, of the 3-phase model).

    The plan is handed to the ``run_plan`` task via ``.delay`` and the handler
    returns right away with the identifiers needed to follow the run.

    Args:
        request: The serialized plan plus an optional caller-chosen run_id.
        ctx: Authenticated user context. Not read here; the dependency is
            declared so that the route requires authentication.

    Returns:
        A dict with ``status`` set to ``"submitted"``, the ``run_id`` in use
        and the ``celery_task_id`` of the queued task.

    Raises:
        HTTPException: 500 if the task could not be submitted.
    """
    from tasks.orchestrate import run_plan

    # Keep the caller's run id when one was given; otherwise mint a random UUID.
    effective_run_id = request.run_id if request.run_id else str(uuid.uuid4())

    try:
        queued = run_plan.delay(
            plan_json=request.plan_json,
            run_id=effective_run_id,
        )
        response = {
            "status": "submitted",
            "run_id": effective_run_id,
            "celery_task_id": queued.id,
        }
    except Exception as e:
        logger.error(f"Plan execution failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return response
def _recipe_name_from_yaml(recipe_yaml: str) -> str:
    """Return the recipe's top-level ``name``, or ``"unknown"`` if unreadable."""
    try:
        recipe_data = yaml.safe_load(recipe_yaml)
    except Exception:
        # Deliberately best-effort: the name is only a label, so an unparsable
        # recipe must not fail the request at this point.
        logger.debug("Could not parse recipe YAML to read its name", exc_info=True)
        return "unknown"
    # Valid YAML may be a scalar, a list or empty (None); only a mapping can
    # carry a "name" entry.
    if not isinstance(recipe_data, dict):
        return "unknown"
    return recipe_data.get("name", "unknown")


@router.post("/run-recipe")
async def run_recipe_endpoint(
    request: RecipeRunRequest,
    ctx: UserContext = Depends(require_auth),
):
    """
    Run a complete recipe through all 3 phases.

    1. Analyze: Extract features from inputs
    2. Plan: Generate execution plan with cache IDs
    3. Execute: Run steps with parallel execution

    Returns immediately with run_id. Poll /api/run/{run_id} for status.

    Args:
        request: Recipe YAML, input hashes and feature names.
        ctx: Authenticated user context; its ``actor_id`` is stored on the
            run record as ``username``.

    Returns:
        If the run cache already has an entry for this run_id whose output is
        present in the cache manager: a ``"completed"`` dict carrying
        ``output_hash``, ``output_ipfs_cid`` and ``cached: True``. Otherwise a
        ``"submitted"`` dict carrying ``run_id``, ``celery_task_id`` and
        ``recipe``.

    Raises:
        HTTPException: 500 if submitting the task or writing the run record
            to Redis fails.
    """
    from tasks.orchestrate import run_recipe
    import database

    redis = get_redis_client()
    cache = get_cache_manager()

    recipe_name = _recipe_name_from_yaml(request.recipe_yaml)
    input_hash_values = list(request.input_hashes.values())

    # The run_id is derived from the inputs and the recipe text, so the same
    # request always maps to the same run_id.
    run_id = compute_run_id(
        input_hash_values,
        recipe_name,
        hashlib.sha3_256(request.recipe_yaml.encode()).hexdigest(),
    )

    # Short-circuit when this exact run already completed and its output is
    # still available.
    cached = await database.get_run_cache(run_id)
    if cached:
        output_hash = cached.get("output_hash")
        # A cache row without an output hash cannot be served; fall through
        # and resubmit rather than querying the cache with None.
        if output_hash and cache.has_content(output_hash):
            return {
                "status": "completed",
                "run_id": run_id,
                "output_hash": output_hash,
                "output_ipfs_cid": cache.get_ipfs_cid(output_hash),
                "cached": True,
            }

    # Submit to Celery
    try:
        task = run_recipe.delay(
            recipe_yaml=request.recipe_yaml,
            input_hashes=request.input_hashes,
            features=request.features,
            run_id=run_id,
        )

        # Record the pending run in Redis (24 hour expiry).
        run_data = {
            "run_id": run_id,
            "status": "pending",
            "recipe": recipe_name,
            "inputs": input_hash_values,
            "celery_task_id": task.id,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "username": ctx.actor_id,
        }
        redis.setex(
            f"{RUNS_KEY_PREFIX}{run_id}",
            86400,
            json.dumps(run_data),
        )

        return {
            "status": "submitted",
            "run_id": run_id,
            "celery_task_id": task.id,
            "recipe": recipe_name,
        }
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() drops.
        logger.exception("Recipe run failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/run/{run_id}")
async def get_run_status(
    run_id: str,
    ctx: UserContext = Depends(require_auth),
):
    """
    Report the status of a recipe execution run.

    Lookup order: the Redis run record first, then the database run cache.
    A Redis record still marked ``"pending"`` is refreshed from the state of
    its Celery task before it is returned.

    Args:
        run_id: Identifier of the run to look up.
        ctx: Authenticated user context. Not read here; the dependency is
            declared so that the route requires authentication.

    Returns:
        The (possibly refreshed) Redis run record, or a ``"completed"`` dict
        built from the database run cache.

    Raises:
        HTTPException: 404 if the run is in neither Redis nor the run cache.
    """
    import database
    from celery.result import AsyncResult

    store = get_redis_client()
    record_key = f"{RUNS_KEY_PREFIX}{run_id}"

    raw_record = store.get(record_key)
    if not raw_record:
        # Nothing in Redis: fall back to the database run cache.
        db_entry = await database.get_run_cache(run_id)
        if not db_entry:
            raise HTTPException(status_code=404, detail="Run not found")
        return {
            "run_id": run_id,
            "status": "completed",
            "output_hash": db_entry.get("output_hash"),
            "cached": True,
        }

    record = json.loads(raw_record)
    task_id = record.get("celery_task_id")

    # Only a pending record that knows its task id is refreshed from Celery.
    if record.get("status") != "pending" or not task_id:
        return record

    outcome = AsyncResult(task_id)

    if not outcome.ready():
        # Still in flight: expose the raw Celery state next to "pending".
        record["celery_status"] = outcome.status
        return record

    if not outcome.successful():
        # The failure is reported to the caller but not written back to Redis.
        record["status"] = "failed"
        record["error"] = str(outcome.result)
        return record

    payload = outcome.get()
    record.update(
        {
            "status": payload.get("status", "completed"),
            "output_hash": payload.get("output_cache_id"),
            "output_ipfs_cid": payload.get("output_ipfs_cid"),
            "total_steps": payload.get("total_steps"),
            "cached": payload.get("cached"),
            "executed": payload.get("executed"),
        }
    )
    # Write the finished record back to Redis with a fresh 24 hour expiry.
    store.setex(record_key, 86400, json.dumps(record))
    return record