# ---- celery_app.py (hunk): register the new task modules with Celery ----
#   include=['tasks', 'tasks.analyze', 'tasks.execute', 'tasks.orchestrate']

# ---- claiming.py (new file) ----
"""
Hash-based task claiming for distributed execution.

Prevents duplicate work when multiple workers process the same plan.
Uses Redis Lua scripts for atomic claim operations.
"""

import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import redis

logger = logging.getLogger(__name__)

REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')

# Key prefix for task claims
CLAIM_PREFIX = "artdag:claim:"

# Default TTL for claims (5 minutes)
DEFAULT_CLAIM_TTL = 300

# TTL for completed results (1 hour)
COMPLETED_TTL = 3600


class ClaimStatus(Enum):
    """Lifecycle states of a task claim."""
    PENDING = "pending"
    CLAIMED = "claimed"
    RUNNING = "running"
    COMPLETED = "completed"
    CACHED = "cached"
    FAILED = "failed"


@dataclass
class ClaimInfo:
    """Information about a task claim, serialized to JSON in Redis."""
    cache_id: str                        # content-addressed task identity
    status: ClaimStatus
    worker_id: Optional[str] = None      # "<hostname>:<pid>" of the owner
    task_id: Optional[str] = None        # Celery task id, if any
    claimed_at: Optional[str] = None     # ISO-8601 UTC timestamps
    completed_at: Optional[str] = None
    output_path: Optional[str] = None    # set on completion
    error: Optional[str] = None          # set on failure

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (enum flattened to its value)."""
        return {
            "cache_id": self.cache_id,
            "status": self.status.value,
            "worker_id": self.worker_id,
            "task_id": self.task_id,
            "claimed_at": self.claimed_at,
            "completed_at": self.completed_at,
            "output_path": self.output_path,
            "error": self.error,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """Inverse of to_dict(); raises KeyError/ValueError on malformed data."""
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            worker_id=data.get("worker_id"),
            task_id=data.get("task_id"),
            claimed_at=data.get("claimed_at"),
            completed_at=data.get("completed_at"),
            output_path=data.get("output_path"),
            error=data.get("error"),
        )


# Lua script for atomic task claiming
# Returns 1 if claim successful, 0 if already claimed/completed
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    local s = status['status']
    -- Already claimed, running, completed, or cached - don't claim
    if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
        return 0
    end
end

-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""

# Lua script for releasing a claim (e.g., on failure)
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    -- Only release if we own the claim
    if status['worker_id'] == worker_id then
        redis.call('DEL', key)
        return 1
    end
end
return 0
"""

# Lua script for updating claim status (claimed -> running -> completed)
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
if not data then
    return 0
end

local status = cjson.decode(data)

-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
    return 0
end

redis.call('SETEX', key, ttl, new_data)
return 1
"""


class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers.
    Each task is identified by its cache_id (content-addressed).
    """

    def __init__(self, redis_url: str = None):
        """
        Initialize the claimer.

        Args:
            redis_url: Redis connection URL (falls back to REDIS_URL env/default)
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        # Lua Script objects; registered lazily by the `redis` property.
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization, registers Lua scripts)."""
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )

        # BUG FIX: access the `redis` property BEFORE referencing the script.
        # The scripts are registered lazily inside the property; Python
        # evaluates the callee before its arguments, so the previous code
        # (`self._claim_script(..., client=self.redis)`) called None on the
        # first use of a fresh TaskClaimer.
        client = self.redis
        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}")
            return True
        else:
            logger.debug(f"Task {cache_id[:16]}... already claimed")
            return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful
        """
        if ttl is None:
            # Completed/cached results live longer so other workers can reuse them.
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL

        # Get existing claim info (this also initializes the Redis connection
        # and registers the Lua scripts via the `redis` property).
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning(f"No claim found for {cache_id[:16]}...")
            return False

        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            completed_at=datetime.now(timezone.utc).isoformat() if status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ) else None,
            output_path=output_path,
            error=error,
        )

        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=self.redis,
        )

        if result == 1:
            logger.debug(f"Updated task {cache_id[:16]}... to {status.value}")
            return True
        else:
            logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)")
            return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful
        """
        # Same lazy-registration fix as claim(): force script registration first.
        client = self.redis
        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )

        if result == 1:
            logger.debug(f"Released claim on {cache_id[:16]}...")
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED, ClaimStatus.CACHED
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to complete (polling; no pub/sub).

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo if completed/cached/failed, None if timeout
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ):
                return info
            time.sleep(poll_interval)

        logger.warning(f"Timeout waiting for {cache_id[:16]}...")
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim.

        NOTE(review): this unconditionally SETEXes the key and will overwrite
        a live claim held by another worker — acceptable only because a cached
        result supersedes any in-flight computation; confirm that invariant.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0


# Global claimer instance (process-wide singleton; not thread-synchronized)
_claimer: Optional[TaskClaimer] = None


def get_claimer() -> TaskClaimer:
    """Get the global TaskClaimer instance."""
    global _claimer
    if _claimer is None:
        _claimer = TaskClaimer()
    return _claimer


def claim_task(cache_id: str, worker_id: str, task_id: str = None) -> bool:
    """Convenience function to claim a task."""
    return get_claimer().claim(cache_id, worker_id, task_id)


def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Convenience function to mark a task as completed."""
    return get_claimer().update_status(
        cache_id, worker_id, ClaimStatus.COMPLETED, output_path=output_path
    )


def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Convenience function to mark a task as failed."""
    return get_claimer().update_status(
        cache_id, worker_id, ClaimStatus.FAILED, error=error
    )


# ---- server.py (hunk): 3-Phase Execution API (Analyze -> Plan -> Execute) ----

class RecipeRunRequest(BaseModel):
    """Request to run a recipe with the 3-phase execution model."""
recipe_yaml: str # Recipe YAML content + input_hashes: dict # Mapping from input name to content hash + features: Optional[list[str]] = None # Features to extract (default: beats, energy) + + +class PlanRequest(BaseModel): + """Request to generate an execution plan.""" + recipe_yaml: str + input_hashes: dict + features: Optional[list[str]] = None + + +class ExecutePlanRequest(BaseModel): + """Request to execute a pre-generated plan.""" + plan_json: str # JSON-serialized ExecutionPlan + + +@app.post("/api/v2/plan") +async def generate_plan_endpoint( + request: PlanRequest, + ctx: UserContext = Depends(get_required_user_context) +): + """ + Generate an execution plan without executing it. + + Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model. + + Returns the plan with cache status for each step. + """ + from tasks.orchestrate import generate_plan + + try: + # Submit to Celery + task = generate_plan.delay( + recipe_yaml=request.recipe_yaml, + input_hashes=request.input_hashes, + features=request.features, + ) + + # Wait for result (plan generation is usually fast) + result = task.get(timeout=60) + + return { + "status": result.get("status"), + "recipe": result.get("recipe"), + "plan_id": result.get("plan_id"), + "total_steps": result.get("total_steps"), + "cached_steps": result.get("cached_steps"), + "pending_steps": result.get("pending_steps"), + "steps": result.get("steps"), + } + except Exception as e: + logger.error(f"Plan generation failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/v2/execute") +async def execute_plan_endpoint( + request: ExecutePlanRequest, + ctx: UserContext = Depends(get_required_user_context) +): + """ + Execute a pre-generated execution plan. + + Phase 3 (Execute) of the 3-phase model. + + Submits the plan to Celery for parallel execution. 
+ """ + from tasks.orchestrate import run_plan + + run_id = str(uuid.uuid4()) + + try: + # Submit to Celery (async) + task = run_plan.delay( + plan_json=request.plan_json, + run_id=run_id, + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + } + except Exception as e: + logger.error(f"Plan execution failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/v2/run-recipe") +async def run_recipe_endpoint( + request: RecipeRunRequest, + ctx: UserContext = Depends(get_required_user_context) +): + """ + Run a complete recipe through all 3 phases. + + 1. Analyze: Extract features from inputs + 2. Plan: Generate execution plan with cache IDs + 3. Execute: Run steps with parallel execution + + Returns immediately with run_id. Poll /api/v2/run/{run_id} for status. + """ + from tasks.orchestrate import run_recipe + + # Compute run_id from inputs and recipe + try: + recipe_data = yaml.safe_load(request.recipe_yaml) + recipe_name = recipe_data.get("name", "unknown") + except Exception: + recipe_name = "unknown" + + run_id = compute_run_id( + list(request.input_hashes.values()), + recipe_name, + hashlib.sha3_256(request.recipe_yaml.encode()).hexdigest() + ) + + # Check if already completed + cached = await database.get_run_cache(run_id) + if cached: + output_hash = cached.get("output_hash") + if cache_manager.has_content(output_hash): + return { + "status": "completed", + "run_id": run_id, + "output_hash": output_hash, + "output_ipfs_cid": cache_manager.get_ipfs_cid(output_hash), + "cached": True, + } + + # Submit to Celery + try: + task = run_recipe.delay( + recipe_yaml=request.recipe_yaml, + input_hashes=request.input_hashes, + features=request.features, + run_id=run_id, + ) + + # Store run status in Redis + run_data = { + "run_id": run_id, + "status": "pending", + "recipe": recipe_name, + "inputs": list(request.input_hashes.values()), + "celery_task_id": task.id, + "created_at": 
datetime.now(timezone.utc).isoformat(), + "username": ctx.actor_id, + } + redis_client.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, # 24 hour expiry + json.dumps(run_data) + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + "recipe": recipe_name, + } + except Exception as e: + logger.error(f"Recipe run failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/api/v2/run/{run_id}") +async def get_run_v2(run_id: str, ctx: UserContext = Depends(get_required_user_context)): + """ + Get status of a 3-phase execution run. + """ + # Check Redis for run status + run_data = redis_client.get(f"{RUNS_KEY_PREFIX}{run_id}") + if run_data: + data = json.loads(run_data) + + # If pending, check Celery task status + if data.get("status") == "pending" and data.get("celery_task_id"): + from celery.result import AsyncResult + result = AsyncResult(data["celery_task_id"]) + + if result.ready(): + if result.successful(): + task_result = result.get() + data["status"] = task_result.get("status", "completed") + data["output_hash"] = task_result.get("output_cache_id") + data["output_ipfs_cid"] = task_result.get("output_ipfs_cid") + data["total_steps"] = task_result.get("total_steps") + data["cached"] = task_result.get("cached") + data["executed"] = task_result.get("executed") + + # Update Redis + redis_client.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, + json.dumps(data) + ) + else: + data["status"] = "failed" + data["error"] = str(result.result) + else: + data["celery_status"] = result.status + + return data + + # Check database cache + cached = await database.get_run_cache(run_id) + if cached: + return { + "run_id": run_id, + "status": "completed", + "output_hash": cached.get("output_hash"), + "cached": True, + } + + raise HTTPException(status_code=404, detail="Run not found") + + if __name__ == "__main__": import uvicorn # Workers enabled - cache indexes shared via Redis diff --git a/tasks/__init__.py 
b/tasks/__init__.py new file mode 100644 index 0000000..02ca273 --- /dev/null +++ b/tasks/__init__.py @@ -0,0 +1,18 @@ +# art-celery/tasks - Celery tasks for 3-phase execution +# +# Tasks for the Art DAG distributed execution system: +# 1. analyze_input - Extract features from input media +# 2. execute_step - Execute a single step from the plan +# 3. run_plan - Orchestrate execution of a full plan + +from .analyze import analyze_input, analyze_inputs +from .execute import execute_step +from .orchestrate import run_plan, run_recipe + +__all__ = [ + "analyze_input", + "analyze_inputs", + "execute_step", + "run_plan", + "run_recipe", +] diff --git a/tasks/analyze.py b/tasks/analyze.py new file mode 100644 index 0000000..d68f9bf --- /dev/null +++ b/tasks/analyze.py @@ -0,0 +1,132 @@ +""" +Analysis tasks for extracting features from input media. + +Phase 1 of the 3-phase execution model. +""" + +import json +import logging +import os +from pathlib import Path +from typing import Dict, List, Optional + +from celery import current_task + +# Import from the Celery app +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from celery_app import app + +# Import artdag analysis module +try: + from artdag.analysis import Analyzer, AnalysisResult +except ImportError: + # artdag not installed, will fail at runtime + Analyzer = None + AnalysisResult = None + +logger = logging.getLogger(__name__) + +# Cache directory for analysis results +CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache')) +ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis' + + +@app.task(bind=True, name='tasks.analyze_input') +def analyze_input( + self, + input_hash: str, + input_path: str, + features: List[str], +) -> dict: + """ + Analyze a single input file. 
+ + Args: + input_hash: Content hash of the input + input_path: Path to the input file + features: List of features to extract + + Returns: + Dict with analysis results + """ + if Analyzer is None: + raise ImportError("artdag.analysis not available") + + logger.info(f"Analyzing {input_hash[:16]}... for features: {features}") + + # Create analyzer with caching + ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True) + analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR) + + try: + result = analyzer.analyze( + input_hash=input_hash, + features=features, + input_path=Path(input_path), + ) + + return { + "status": "completed", + "input_hash": input_hash, + "cache_id": result.cache_id, + "features": features, + "result": result.to_dict(), + } + + except Exception as e: + logger.error(f"Analysis failed for {input_hash}: {e}") + return { + "status": "failed", + "input_hash": input_hash, + "error": str(e), + } + + +@app.task(bind=True, name='tasks.analyze_inputs') +def analyze_inputs( + self, + inputs: Dict[str, str], + features: List[str], +) -> dict: + """ + Analyze multiple inputs in parallel. 
+ + Args: + inputs: Dict mapping input_hash to file path + features: List of features to extract from all inputs + + Returns: + Dict with all analysis results + """ + if Analyzer is None: + raise ImportError("artdag.analysis not available") + + logger.info(f"Analyzing {len(inputs)} inputs for features: {features}") + + ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True) + analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR) + + results = {} + errors = [] + + for input_hash, input_path in inputs.items(): + try: + result = analyzer.analyze( + input_hash=input_hash, + features=features, + input_path=Path(input_path), + ) + results[input_hash] = result.to_dict() + + except Exception as e: + logger.error(f"Analysis failed for {input_hash}: {e}") + errors.append({"input_hash": input_hash, "error": str(e)}) + + return { + "status": "completed" if not errors else "partial", + "results": results, + "errors": errors, + "total": len(inputs), + "successful": len(results), + } diff --git a/tasks/execute.py b/tasks/execute.py new file mode 100644 index 0000000..326bfb6 --- /dev/null +++ b/tasks/execute.py @@ -0,0 +1,298 @@ +""" +Step execution task. + +Phase 3 of the 3-phase execution model. +Executes individual steps from an execution plan with IPFS-backed caching. 
+""" + +import json +import logging +import os +import socket +from pathlib import Path +from typing import Dict, List, Optional + +from celery import current_task + +# Import from the Celery app +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from celery_app import app +from claiming import ( + get_claimer, + claim_task, + complete_task, + fail_task, + ClaimStatus, +) +from cache_manager import get_cache_manager, L1CacheManager + +# Import artdag +try: + from artdag import Cache, NodeType + from artdag.executor import get_executor + from artdag.planning import ExecutionStep +except ImportError: + Cache = None + NodeType = None + get_executor = None + ExecutionStep = None + +logger = logging.getLogger(__name__) + + +def get_worker_id() -> str: + """Get a unique identifier for this worker.""" + hostname = socket.gethostname() + pid = os.getpid() + return f"{hostname}:{pid}" + + +@app.task(bind=True, name='tasks.execute_step') +def execute_step( + self, + step_json: str, + plan_id: str, + input_cache_ids: Dict[str, str], +) -> dict: + """ + Execute a single step from an execution plan. + + Uses hash-based claiming to prevent duplicate work. + Results are stored in IPFS-backed cache. 
+ + Args: + step_json: JSON-serialized ExecutionStep + plan_id: ID of the parent execution plan + input_cache_ids: Mapping from input step_id to their cache_id + + Returns: + Dict with execution result + """ + if ExecutionStep is None: + raise ImportError("artdag.planning not available") + + step = ExecutionStep.from_json(step_json) + worker_id = get_worker_id() + task_id = self.request.id + + logger.info(f"Executing step {step.step_id} ({step.node_type}) cache_id={step.cache_id[:16]}...") + + # Get L1 cache manager (IPFS-backed) + cache_mgr = get_cache_manager() + + # Check if already cached (by cache_id as content_hash) + cached_path = cache_mgr.get_by_content_hash(step.cache_id) + if cached_path: + logger.info(f"Step {step.step_id} already cached at {cached_path}") + + # Mark as cached in claiming system + claimer = get_claimer() + claimer.mark_cached(step.cache_id, str(cached_path)) + + return { + "status": "cached", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": str(cached_path), + } + + # Try to claim the task + if not claim_task(step.cache_id, worker_id, task_id): + # Another worker is handling it + logger.info(f"Step {step.step_id} claimed by another worker, waiting...") + + claimer = get_claimer() + result = claimer.wait_for_completion(step.cache_id, timeout=600) + + if result and result.status == ClaimStatus.COMPLETED: + return { + "status": "completed_by_other", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": result.output_path, + } + elif result and result.status == ClaimStatus.CACHED: + return { + "status": "cached", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": result.output_path, + } + elif result and result.status == ClaimStatus.FAILED: + return { + "status": "failed", + "step_id": step.step_id, + "cache_id": step.cache_id, + "error": result.error, + } + else: + return { + "status": "timeout", + "step_id": step.step_id, + "cache_id": step.cache_id, + "error": "Timeout 
waiting for other worker", + } + + # We have the claim, update to running + claimer = get_claimer() + claimer.update_status(step.cache_id, worker_id, ClaimStatus.RUNNING) + + try: + # Handle SOURCE nodes + if step.node_type == "SOURCE": + content_hash = step.config.get("content_hash") + if not content_hash: + raise ValueError(f"SOURCE step missing content_hash") + + # Look up in cache + path = cache_mgr.get_by_content_hash(content_hash) + if not path: + raise ValueError(f"SOURCE input not found in cache: {content_hash[:16]}...") + + output_path = str(path) + complete_task(step.cache_id, worker_id, output_path) + return { + "status": "completed", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": output_path, + } + + # Handle _LIST virtual nodes + if step.node_type == "_LIST": + item_paths = [] + for item_id in step.config.get("items", []): + item_cache_id = input_cache_ids.get(item_id) + if item_cache_id: + path = cache_mgr.get_by_content_hash(item_cache_id) + if path: + item_paths.append(str(path)) + + complete_task(step.cache_id, worker_id, json.dumps(item_paths)) + return { + "status": "completed", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": None, + "item_paths": item_paths, + } + + # Get executor for this node type + try: + node_type = NodeType[step.node_type] + except KeyError: + node_type = step.node_type + + executor = get_executor(node_type) + if executor is None: + raise ValueError(f"No executor for node type: {step.node_type}") + + # Resolve input paths from cache + input_paths = [] + for input_step_id in step.input_steps: + input_cache_id = input_cache_ids.get(input_step_id) + if not input_cache_id: + raise ValueError(f"No cache_id for input step: {input_step_id}") + + path = cache_mgr.get_by_content_hash(input_cache_id) + if not path: + raise ValueError(f"Input not in cache: {input_cache_id[:16]}...") + + input_paths.append(Path(path)) + + # Create temp output path + import tempfile + output_dir = 
Path(tempfile.mkdtemp()) + output_path = output_dir / f"output_{step.cache_id[:16]}.mp4" + + # Execute + logger.info(f"Running executor for {step.node_type} with {len(input_paths)} inputs") + result_path = executor.execute(step.config, input_paths, output_path) + + # Store in IPFS-backed cache + cached_file, ipfs_cid = cache_mgr.put( + source_path=result_path, + node_type=step.node_type, + node_id=step.cache_id, + ) + + logger.info(f"Step {step.step_id} completed, IPFS CID: {ipfs_cid}") + + # Mark completed + complete_task(step.cache_id, worker_id, str(cached_file.path)) + + # Cleanup temp + if output_dir.exists(): + import shutil + shutil.rmtree(output_dir, ignore_errors=True) + + return { + "status": "completed", + "step_id": step.step_id, + "cache_id": step.cache_id, + "output_path": str(cached_file.path), + "content_hash": cached_file.content_hash, + "ipfs_cid": ipfs_cid, + } + + except Exception as e: + logger.error(f"Step {step.step_id} failed: {e}") + fail_task(step.cache_id, worker_id, str(e)) + + return { + "status": "failed", + "step_id": step.step_id, + "cache_id": step.cache_id, + "error": str(e), + } + + +@app.task(bind=True, name='tasks.execute_level') +def execute_level( + self, + steps_json: List[str], + plan_id: str, + cache_ids: Dict[str, str], +) -> dict: + """ + Execute all steps at a given dependency level. + + Steps at the same level can run in parallel. 
+ + Args: + steps_json: List of JSON-serialized ExecutionSteps + plan_id: ID of the parent execution plan + cache_ids: Mapping from step_id to cache_id + + Returns: + Dict with results for all steps + """ + from celery import group + + # Dispatch all steps in parallel + tasks = [ + execute_step.s(step_json, plan_id, cache_ids) + for step_json in steps_json + ] + + # Execute in parallel and collect results + job = group(tasks) + results = job.apply_async() + + # Wait for completion + step_results = results.get(timeout=3600) # 1 hour timeout + + # Build cache_ids from results + new_cache_ids = dict(cache_ids) + for result in step_results: + step_id = result.get("step_id") + cache_id = result.get("cache_id") + if step_id and cache_id: + new_cache_ids[step_id] = cache_id + + return { + "status": "completed", + "results": step_results, + "cache_ids": new_cache_ids, + } diff --git a/tasks/orchestrate.py b/tasks/orchestrate.py new file mode 100644 index 0000000..4de6942 --- /dev/null +++ b/tasks/orchestrate.py @@ -0,0 +1,373 @@ +""" +Plan orchestration tasks. + +Coordinates the full 3-phase execution: +1. Analyze inputs +2. Generate plan +3. Execute steps level by level + +Uses IPFS-backed cache for durability. 
+""" + +import json +import logging +import os +from pathlib import Path +from typing import Dict, List, Optional + +from celery import current_task, group, chain + +# Import from the Celery app +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from celery_app import app +from claiming import get_claimer +from cache_manager import get_cache_manager + +# Import artdag modules +try: + from artdag import Cache + from artdag.analysis import Analyzer, AnalysisResult + from artdag.planning import RecipePlanner, ExecutionPlan, Recipe +except ImportError: + Cache = None + Analyzer = None + AnalysisResult = None + RecipePlanner = None + ExecutionPlan = None + Recipe = None + +from .execute import execute_step + +logger = logging.getLogger(__name__) + +# Cache directories +CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache')) +ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis' +PLAN_CACHE_DIR = CACHE_DIR / 'plans' + + +@app.task(bind=True, name='tasks.run_plan') +def run_plan( + self, + plan_json: str, + run_id: Optional[str] = None, +) -> dict: + """ + Execute a complete execution plan. + + Runs steps level by level, with parallel execution within each level. + Results are stored in IPFS-backed cache. + + Args: + plan_json: JSON-serialized ExecutionPlan + run_id: Optional run ID for tracking + + Returns: + Dict with execution results + """ + if ExecutionPlan is None: + raise ImportError("artdag.planning not available") + + plan = ExecutionPlan.from_json(plan_json) + cache_mgr = get_cache_manager() + + logger.info(f"Executing plan {plan.plan_id[:16]}... 
({len(plan.steps)} steps)") + + # Build initial cache_ids mapping (step_id -> cache_id) + cache_ids = {} + for step in plan.steps: + cache_ids[step.step_id] = step.cache_id + + # Also map input hashes + for name, content_hash in plan.input_hashes.items(): + cache_ids[name] = content_hash + + # Group steps by level + steps_by_level = plan.get_steps_by_level() + max_level = max(steps_by_level.keys()) if steps_by_level else 0 + + results_by_step = {} + total_cached = 0 + total_executed = 0 + + for level in range(max_level + 1): + level_steps = steps_by_level.get(level, []) + if not level_steps: + continue + + logger.info(f"Executing level {level}: {len(level_steps)} steps") + + # Check which steps need execution + steps_to_run = [] + + for step in level_steps: + # Check if cached + cached_path = cache_mgr.get_by_content_hash(step.cache_id) + if cached_path: + results_by_step[step.step_id] = { + "status": "cached", + "cache_id": step.cache_id, + "output_path": str(cached_path), + } + total_cached += 1 + else: + steps_to_run.append(step) + + if not steps_to_run: + logger.info(f"Level {level}: all steps cached") + continue + + # Build input cache_ids for this level + level_cache_ids = dict(cache_ids) + + # Execute steps in parallel + tasks = [ + execute_step.s(step.to_json(), plan.plan_id, level_cache_ids) + for step in steps_to_run + ] + + job = group(tasks) + async_results = job.apply_async() + + # Wait for completion + try: + step_results = async_results.get(timeout=3600) + except Exception as e: + logger.error(f"Level {level} execution failed: {e}") + return { + "status": "failed", + "error": str(e), + "level": level, + "results": results_by_step, + "run_id": run_id, + } + + # Process results + for result in step_results: + step_id = result.get("step_id") + cache_id = result.get("cache_id") + + results_by_step[step_id] = result + cache_ids[step_id] = cache_id + + if result.get("status") in ("completed", "cached", "completed_by_other"): + total_executed += 1 + elif 
result.get("status") == "failed": + logger.error(f"Step {step_id} failed: {result.get('error')}") + return { + "status": "failed", + "error": f"Step {step_id} failed: {result.get('error')}", + "level": level, + "results": results_by_step, + "run_id": run_id, + } + + # Get final output + output_step = plan.get_step(plan.output_step) + output_cache_id = output_step.cache_id if output_step else None + output_path = None + output_ipfs_cid = None + + if output_cache_id: + output_path = cache_mgr.get_by_content_hash(output_cache_id) + output_ipfs_cid = cache_mgr.get_ipfs_cid(output_cache_id) + + return { + "status": "completed", + "run_id": run_id, + "plan_id": plan.plan_id, + "output_cache_id": output_cache_id, + "output_path": str(output_path) if output_path else None, + "output_ipfs_cid": output_ipfs_cid, + "total_steps": len(plan.steps), + "cached": total_cached, + "executed": total_executed, + "results": results_by_step, + } + + +@app.task(bind=True, name='tasks.run_recipe') +def run_recipe( + self, + recipe_yaml: str, + input_hashes: Dict[str, str], + features: List[str] = None, + run_id: Optional[str] = None, +) -> dict: + """ + Run a complete recipe through all 3 phases. + + 1. Analyze: Extract features from inputs + 2. Plan: Generate execution plan + 3. 
Execute: Run the plan + + Args: + recipe_yaml: Recipe YAML content + input_hashes: Mapping from input name to content hash + features: Features to extract (default: ["beats", "energy"]) + run_id: Optional run ID for tracking + + Returns: + Dict with final results + """ + if RecipePlanner is None or Analyzer is None: + raise ImportError("artdag modules not available") + + if features is None: + features = ["beats", "energy"] + + cache_mgr = get_cache_manager() + + logger.info(f"Running recipe with {len(input_hashes)} inputs") + + # Phase 1: Analyze + logger.info("Phase 1: Analyzing inputs...") + + ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True) + analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR) + + analysis_results = {} + for name, content_hash in input_hashes.items(): + # Get path from cache + path = cache_mgr.get_by_content_hash(content_hash) + if path: + try: + result = analyzer.analyze( + input_hash=content_hash, + features=features, + input_path=Path(path), + ) + analysis_results[content_hash] = result + logger.info(f"Analyzed {name}: tempo={result.tempo}, beats={len(result.beat_times or [])}") + except Exception as e: + logger.warning(f"Analysis failed for {name}: {e}") + else: + logger.warning(f"Input {name} ({content_hash[:16]}...) 
not in cache") + + logger.info(f"Analyzed {len(analysis_results)} inputs") + + # Phase 2: Plan + logger.info("Phase 2: Generating execution plan...") + + recipe = Recipe.from_yaml(recipe_yaml) + planner = RecipePlanner(use_tree_reduction=True) + + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + analysis=analysis_results, + ) + + logger.info(f"Generated plan with {len(plan.steps)} steps") + + # Save plan for debugging + PLAN_CACHE_DIR.mkdir(parents=True, exist_ok=True) + plan_path = PLAN_CACHE_DIR / f"{plan.plan_id}.json" + with open(plan_path, "w") as f: + f.write(plan.to_json()) + + # Phase 3: Execute + logger.info("Phase 3: Executing plan...") + + result = run_plan(plan.to_json(), run_id=run_id) + + return { + "status": result.get("status"), + "run_id": run_id, + "recipe": recipe.name, + "plan_id": plan.plan_id, + "output_path": result.get("output_path"), + "output_cache_id": result.get("output_cache_id"), + "output_ipfs_cid": result.get("output_ipfs_cid"), + "analysis_count": len(analysis_results), + "total_steps": len(plan.steps), + "cached": result.get("cached", 0), + "executed": result.get("executed", 0), + "error": result.get("error"), + } + + +@app.task(bind=True, name='tasks.generate_plan') +def generate_plan( + self, + recipe_yaml: str, + input_hashes: Dict[str, str], + features: List[str] = None, +) -> dict: + """ + Generate an execution plan without executing it. 
+ + Useful for: + - Previewing what will be executed + - Checking cache status + - Debugging recipe issues + + Args: + recipe_yaml: Recipe YAML content + input_hashes: Mapping from input name to content hash + features: Features to extract for analysis + + Returns: + Dict with plan details + """ + if RecipePlanner is None or Analyzer is None: + raise ImportError("artdag modules not available") + + if features is None: + features = ["beats", "energy"] + + cache_mgr = get_cache_manager() + + # Analyze inputs + ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True) + analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR) + + analysis_results = {} + for name, content_hash in input_hashes.items(): + path = cache_mgr.get_by_content_hash(content_hash) + if path: + try: + result = analyzer.analyze( + input_hash=content_hash, + features=features, + input_path=Path(path), + ) + analysis_results[content_hash] = result + except Exception as e: + logger.warning(f"Analysis failed for {name}: {e}") + + # Generate plan + recipe = Recipe.from_yaml(recipe_yaml) + planner = RecipePlanner(use_tree_reduction=True) + + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + analysis=analysis_results, + ) + + # Check cache status for each step + steps_status = [] + for step in plan.steps: + cached = cache_mgr.has_content(step.cache_id) + steps_status.append({ + "step_id": step.step_id, + "node_type": step.node_type, + "cache_id": step.cache_id, + "level": step.level, + "cached": cached, + }) + + cached_count = sum(1 for s in steps_status if s["cached"]) + + return { + "status": "planned", + "recipe": recipe.name, + "plan_id": plan.plan_id, + "total_steps": len(plan.steps), + "cached_steps": cached_count, + "pending_steps": len(plan.steps) - cached_count, + "steps": steps_status, + "plan_json": plan.to_json(), + }