"""
Step execution task.

Phase 3 of the 3-phase execution model.
Executes individual steps from an execution plan with IPFS-backed caching.
"""

import json
import logging
import os
import shutil
import socket
import tempfile
from pathlib import Path
from typing import Dict, List, Optional

from celery import current_task, group

# Import from the Celery app
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from celery_app import app
from claiming import (
    get_claimer,
    claim_task,
    complete_task,
    fail_task,
    ClaimStatus,
)
from cache_manager import get_cache_manager, L1CacheManager

# Import artdag. The names fall back to None so this module can still be
# imported when artdag is not installed; execute_step checks for that.
try:
    from artdag import Cache, NodeType
    from artdag.executor import get_executor
    from artdag.planning import ExecutionStep
except ImportError:
    Cache = None
    NodeType = None
    get_executor = None
    ExecutionStep = None

logger = logging.getLogger(__name__)

# Seconds to wait for another worker that already holds the claim on a step.
CLAIM_WAIT_TIMEOUT_SECONDS = 600

# Seconds execute_level waits for all steps of one level to finish (1 hour).
LEVEL_TIMEOUT_SECONDS = 3600


def get_worker_id() -> str:
    """Get a unique identifier for this worker, formatted as ``hostname:pid``."""
    hostname = socket.gethostname()
    pid = os.getpid()
    return f"{hostname}:{pid}"


def _wait_for_other_worker(step) -> dict:
    """Wait for the worker holding the claim on *step* and report its outcome.

    Returns a result dict whose ``status`` is ``completed_by_other``,
    ``cached`` or ``failed`` depending on the final claim status, or
    ``timeout`` when no such final status was observed.
    """
    result = get_claimer().wait_for_completion(
        step.cache_id, timeout=CLAIM_WAIT_TIMEOUT_SECONDS
    )
    status = result.status if result else None

    if status == ClaimStatus.COMPLETED:
        return {
            "status": "completed_by_other",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": result.output_path,
        }
    if status == ClaimStatus.CACHED:
        return {
            "status": "cached",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": result.output_path,
        }
    if status == ClaimStatus.FAILED:
        return {
            "status": "failed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": result.error,
        }
    # No result at all, or a result that is still in a non-final state.
    return {
        "status": "timeout",
        "step_id": step.step_id,
        "cache_id": step.cache_id,
        "error": "Timeout waiting for other worker",
    }


def _execute_source(step, cache_mgr, worker_id: str) -> dict:
    """Resolve a SOURCE step by looking up its ``content_hash`` in the cache.

    Raises:
        ValueError: if the step config has no ``content_hash`` or the
            content is not present in the cache.
    """
    content_hash = step.config.get("content_hash")
    if not content_hash:
        raise ValueError("SOURCE step missing content_hash")

    path = cache_mgr.get_by_content_hash(content_hash)
    if not path:
        raise ValueError(f"SOURCE input not found in cache: {content_hash[:16]}...")

    output_path = str(path)
    complete_task(step.cache_id, worker_id, output_path)
    return {
        "status": "completed",
        "step_id": step.step_id,
        "cache_id": step.cache_id,
        "output_path": output_path,
    }


def _execute_list(step, cache_mgr, input_cache_ids: Dict[str, str], worker_id: str) -> dict:
    """Resolve a _LIST virtual step into the cached paths of its items.

    Items with no known cache_id, or whose cache_id is not in the cache, are
    left out of the result (best-effort); a warning is logged for each.
    """
    item_paths = []
    for item_id in step.config.get("items", []):
        item_cache_id = input_cache_ids.get(item_id)
        path = cache_mgr.get_by_content_hash(item_cache_id) if item_cache_id else None
        if path:
            item_paths.append(str(path))
        else:
            logger.warning(
                f"Step {step.step_id}: _LIST item {item_id} not resolved in cache, skipping"
            )

    # The claim's output_path field carries the JSON-encoded list of paths.
    complete_task(step.cache_id, worker_id, json.dumps(item_paths))
    return {
        "status": "completed",
        "step_id": step.step_id,
        "cache_id": step.cache_id,
        "output_path": None,
        "item_paths": item_paths,
    }


def _resolve_input_paths(step, cache_mgr, input_cache_ids: Dict[str, str]) -> List[Path]:
    """Map each input step of *step* to the cached path of its output.

    Raises:
        ValueError: if an input step has no cache_id or is not in the cache.
    """
    input_paths = []
    for input_step_id in step.input_steps:
        input_cache_id = input_cache_ids.get(input_step_id)
        if not input_cache_id:
            raise ValueError(f"No cache_id for input step: {input_step_id}")

        path = cache_mgr.get_by_content_hash(input_cache_id)
        if not path:
            raise ValueError(f"Input not in cache: {input_cache_id[:16]}...")

        input_paths.append(Path(path))
    return input_paths


def _build_outputs(step, cached_file, ipfs_cid) -> List[dict]:
    """Build the ``outputs`` list (multi-output support) for a finished step.

    Every entry points at the same cached file; when the step declares no
    outputs a single entry is synthesized for backwards compatibility.
    """
    stored = {
        "path": str(cached_file.path),
        "content_hash": cached_file.content_hash,
        "ipfs_cid": ipfs_cid,
    }

    if step.outputs:
        # Use pre-defined outputs from step
        return [
            {
                "name": output_def.name,
                "cache_id": output_def.cache_id,
                "media_type": output_def.media_type,
                "index": output_def.index,
                **stored,
            }
            for output_def in step.outputs
        ]

    # Single output (backwards compat)
    return [
        {
            "name": step.name or step.step_id,
            "cache_id": step.cache_id,
            "media_type": "video/mp4",
            "index": 0,
            **stored,
        }
    ]


def _execute_node(step, cache_mgr, input_cache_ids: Dict[str, str], worker_id: str) -> dict:
    """Run the executor for *step*, store the result in the cache and complete the claim.

    Raises:
        ValueError: if no executor is registered for the node type or an
            input cannot be resolved from the cache.
    """
    # Prefer the NodeType enum member; fall back to the raw string for node
    # types that are not members of the enum.
    try:
        node_type = NodeType[step.node_type]
    except KeyError:
        node_type = step.node_type

    executor = get_executor(node_type)
    if executor is None:
        raise ValueError(f"No executor for node type: {step.node_type}")

    input_paths = _resolve_input_paths(step, cache_mgr, input_cache_ids)

    # Create temp output path
    output_dir = Path(tempfile.mkdtemp())
    try:
        output_path = output_dir / f"output_{step.cache_id[:16]}.mp4"

        logger.info(f"Running executor for {step.node_type} with {len(input_paths)} inputs")
        result_path = executor.execute(step.config, input_paths, output_path)

        # Store in IPFS-backed cache
        cached_file, ipfs_cid = cache_mgr.put(
            source_path=result_path,
            node_type=step.node_type,
            node_id=step.cache_id,
        )

        logger.info(f"Step {step.step_id} completed, IPFS CID: {ipfs_cid}")

        # Mark completed
        complete_task(step.cache_id, worker_id, str(cached_file.path))

        outputs = _build_outputs(step, cached_file, ipfs_cid)
    finally:
        # Always remove the scratch directory, including when the executor
        # or the cache store raised; otherwise failed steps leak temp dirs.
        shutil.rmtree(output_dir, ignore_errors=True)

    return {
        "status": "completed",
        "step_id": step.step_id,
        "name": step.name,
        "cache_id": step.cache_id,
        "output_path": str(cached_file.path),
        "content_hash": cached_file.content_hash,
        "ipfs_cid": ipfs_cid,
        "outputs": outputs,
    }


@app.task(bind=True, name='tasks.execute_step')
def execute_step(
    self,
    step_json: str,
    plan_id: str,
    input_cache_ids: Dict[str, str],
) -> dict:
    """
    Execute a single step from an execution plan.

    Uses hash-based claiming to prevent duplicate work.
    Results are stored in IPFS-backed cache.

    Args:
        step_json: JSON-serialized ExecutionStep
        plan_id: ID of the parent execution plan (not used by this task's body)
        input_cache_ids: Mapping from input step_id to their cache_id

    Returns:
        Dict with execution result. ``status`` is one of ``cached``,
        ``completed``, ``completed_by_other``, ``failed`` or ``timeout``.
        Failures after the claim is taken are reported through the returned
        dict (``status == "failed"``) rather than raised.

    Raises:
        ImportError: if artdag.planning could not be imported.
    """
    if ExecutionStep is None:
        raise ImportError("artdag.planning not available")

    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()
    task_id = self.request.id

    logger.info(f"Executing step {step.step_id} ({step.node_type}) cache_id={step.cache_id[:16]}...")

    # Get L1 cache manager (IPFS-backed)
    cache_mgr = get_cache_manager()

    # Check if already cached (by cache_id as content_hash)
    cached_path = cache_mgr.get_by_content_hash(step.cache_id)
    if cached_path:
        logger.info(f"Step {step.step_id} already cached at {cached_path}")

        # Mark as cached in claiming system
        get_claimer().mark_cached(step.cache_id, str(cached_path))

        return {
            "status": "cached",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": str(cached_path),
        }

    # Try to claim the task
    if not claim_task(step.cache_id, worker_id, task_id):
        # Another worker is handling it
        logger.info(f"Step {step.step_id} claimed by another worker, waiting...")
        return _wait_for_other_worker(step)

    # We have the claim, update to running
    get_claimer().update_status(step.cache_id, worker_id, ClaimStatus.RUNNING)

    try:
        if step.node_type == "SOURCE":
            return _execute_source(step, cache_mgr, worker_id)

        if step.node_type == "_LIST":
            return _execute_list(step, cache_mgr, input_cache_ids, worker_id)

        return _execute_node(step, cache_mgr, input_cache_ids, worker_id)

    except Exception as e:
        # Task boundary: record the failure on the claim and report it in the
        # result instead of raising. logger.exception keeps the traceback.
        logger.exception(f"Step {step.step_id} failed: {e}")
        fail_task(step.cache_id, worker_id, str(e))

        return {
            "status": "failed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": str(e),
        }


@app.task(bind=True, name='tasks.execute_level')
def execute_level(
    self,
    steps_json: List[str],
    plan_id: str,
    cache_ids: Dict[str, str],
) -> dict:
    """
    Execute all steps at a given dependency level.

    Steps at the same level can run in parallel.

    Args:
        steps_json: List of JSON-serialized ExecutionSteps
        plan_id: ID of the parent execution plan
        cache_ids: Mapping from step_id to cache_id

    Returns:
        Dict with ``status`` (always ``"completed"``, even when individual
        steps report ``failed``/``timeout`` -- inspect ``results``),
        ``results`` (the per-step result dicts) and ``cache_ids`` (a copy of
        the input mapping updated with every step that reported both a
        step_id and a cache_id).
    """
    # Dispatch all steps in parallel
    tasks = [
        execute_step.s(step_json, plan_id, cache_ids)
        for step_json in steps_json
    ]

    # Execute in parallel and collect results
    job = group(tasks)
    results = job.apply_async()

    # Wait for completion.
    # NOTE(review): this blocks inside a task on its own subtasks. Celery
    # documents this as an anti-pattern and may refuse it at runtime unless
    # disable_sync_subtasks=False is passed -- confirm against the Celery
    # version and worker configuration in use.
    step_results = results.get(timeout=LEVEL_TIMEOUT_SECONDS)

    # Build cache_ids from results
    new_cache_ids = dict(cache_ids)
    for result in step_results:
        step_id = result.get("step_id")
        cache_id = result.get("cache_id")
        if step_id and cache_id:
            new_cache_ids[step_id] = cache_id

    return {
        "status": "completed",
        "results": step_results,
        "cache_ids": new_cache_ids,
    }