"""
Simplified step execution with IPFS-primary architecture.

Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
IPFS nodes form a distributed cache automatically.
"""

import logging
import os
import shutil
import socket
import tempfile
import time
from pathlib import Path
from typing import Dict, Optional

from celery import current_task

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from celery_app import app
import ipfs_client

# Redis for claiming and cache_id → CID mapping
import redis

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
_redis: Optional[redis.Redis] = None


def get_redis() -> redis.Redis:
    """Return the process-wide Redis client, creating it on first use.

    The client uses ``decode_responses=True``, so values read back
    (CIDs, worker ids) are ``str`` rather than ``bytes``.
    """
    global _redis
    if _redis is None:
        _redis = redis.from_url(REDIS_URL, decode_responses=True)
    return _redis


# Import artdag
try:
    from artdag import NodeType
    from artdag.executor import get_executor
    from artdag.planning import ExecutionStep
    from artdag import nodes  # Register executors
except ImportError:
    NodeType = None
    get_executor = None
    ExecutionStep = None

logger = logging.getLogger(__name__)

# Redis keys
CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → CID
CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker_id

# Step statuses that carry a usable "cid" in the result dict.
_SUCCESS_STATUSES = ("completed", "cached", "completed_by_other")

# Atomically delete a claim only if it is still held by the given worker.
# A plain DEL could remove a claim that expired and was re-taken by someone else.
_RELEASE_IF_OWNER_LUA = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


def get_worker_id() -> str:
    """Get unique worker identifier (``hostname:pid``)."""
    return f"{socket.gethostname()}:{os.getpid()}"


def get_cached_cid(cache_id: str) -> Optional[str]:
    """Return the CID recorded for ``cache_id``, or ``None`` if unknown."""
    return get_redis().hget(CACHE_KEY, cache_id)


def set_cached_cid(cache_id: str, cid: str) -> None:
    """Record the ``cache_id`` → ``cid`` mapping in the shared Redis hash."""
    get_redis().hset(CACHE_KEY, cache_id, cid)


def _claim_key(cache_id: str) -> str:
    """Build the Redis key holding the execution claim for ``cache_id``."""
    return f"{CLAIM_KEY_PREFIX}{cache_id}"


def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
    """Try to claim a cache_id for execution.

    Uses ``SET key worker_id NX EX ttl``: at most one worker holds the claim,
    and it lapses automatically after ``ttl`` seconds if never released.

    Returns True if claimed.
    """
    # redis-py returns True on success and None when NX prevents the write.
    return bool(get_redis().set(_claim_key(cache_id), worker_id, nx=True, ex=ttl))


def release_claim(cache_id: str, worker_id: Optional[str] = None) -> None:
    """Release a claim.

    Args:
        cache_id: The claimed cache id.
        worker_id: When given, the claim is deleted only if this worker still
            holds it (atomic compare-and-delete). When omitted, the claim is
            deleted unconditionally.
    """
    key = _claim_key(cache_id)
    if worker_id is None:
        get_redis().delete(key)
    else:
        get_redis().eval(_RELEASE_IF_OWNER_LUA, 1, key, worker_id)


def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:
    """Wait for another worker to produce a CID for cache_id.

    Polls the cache every ``poll_interval`` seconds for up to ``timeout``
    seconds. Returns the CID, or ``None`` if none appeared in time.
    """
    # Monotonic clock: the wait is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        cid = get_cached_cid(cache_id)
        if cid:
            return cid
        time.sleep(poll_interval)
    # One last look in case the CID landed during the final sleep.
    return get_cached_cid(cache_id)


def fetch_from_ipfs(cid: str, dest_dir: Path) -> Path:
    """Fetch a CID from IPFS to a local temp file.

    Raises:
        RuntimeError: if ``ipfs_client.get_file`` reports failure.
    """
    dest_path = dest_dir / f"{cid}.mkv"
    if not ipfs_client.get_file(cid, dest_path):
        raise RuntimeError(f"Failed to fetch CID from IPFS: {cid}")
    return dest_path


def _step_result(step, status: str, **extra) -> Dict:
    """Build the result dict shared by every execute_step_cid outcome."""
    return {"status": status, "step_id": step.step_id, "cache_id": step.cache_id, **extra}


def _resolve_source_cid(step, input_cids: Dict[str, str]) -> str:
    """Return the externally supplied CID for a SOURCE step.

    Looks up the configured source ``name`` first, then the step id.

    Raises:
        ValueError: if neither key maps to a CID.
    """
    source_name = step.config.get("name") or step.step_id
    cid = input_cids.get(source_name) or input_cids.get(step.step_id)
    if not cid:
        raise ValueError(f"SOURCE missing input CID: {source_name}")
    return cid


def _run_executor(step, input_cids: Dict[str, str]) -> str:
    """Fetch inputs from IPFS, run the step's executor, and add the output to IPFS.

    All local files live in a temporary workspace that is always removed.

    Returns:
        The IPFS CID of the executor's output.

    Raises:
        ValueError: unknown executor or missing input CID.
        RuntimeError: IPFS fetch/add failure.
    """
    try:
        node_type = NodeType[step.node_type]
    except KeyError:
        # Not a built-in NodeType member; let get_executor resolve the raw name.
        node_type = step.node_type

    executor = get_executor(node_type)
    if executor is None:
        raise ValueError(f"No executor for: {step.node_type}")

    work_dir = Path(tempfile.mkdtemp(prefix="artdag_"))
    try:
        # Fetch inputs from IPFS; the index keeps paths distinct for repeated CIDs.
        input_paths = []
        for i, input_step_id in enumerate(step.input_steps):
            input_cid = input_cids.get(input_step_id)
            if not input_cid:
                raise ValueError(f"Missing input CID for: {input_step_id}")
            input_path = work_dir / f"input_{i}_{input_cid[:16]}.mkv"
            logger.info(f"[CID] Fetching input {i}: {input_cid}")
            if not ipfs_client.get_file(input_cid, input_path):
                raise RuntimeError(f"Failed to fetch: {input_cid}")
            input_paths.append(input_path)

        output_path = work_dir / f"output_{step.cache_id[:16]}.mkv"
        logger.info(f"[CID] Running {step.node_type} with {len(input_paths)} inputs")
        result_path = executor.execute(step.config, input_paths, output_path)

        output_cid = ipfs_client.add_file(result_path)
        if not output_cid:
            raise RuntimeError("Failed to add output to IPFS")
        logger.info(f"[CID] Completed: {step.step_id} → {output_cid}")
        return output_cid
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


@app.task(bind=True, name='tasks.execute_step_cid')
def execute_step_cid(
    self,
    step_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute a step using IPFS-primary architecture.

    Flow: cache lookup → claim (or wait for the claim holder) → execute →
    publish the cache_id → CID mapping.

    Args:
        step_json: JSON-serialized ExecutionStep
        input_cids: Mapping from input step_id to their IPFS CID

    Returns:
        Dict with 'status', 'step_id', 'cache_id' and either 'cid' (statuses
        "completed", "cached", "completed_by_other") or 'error' (statuses
        "failed", "timeout").

    Raises:
        ImportError: if artdag is not installed.
    """
    if ExecutionStep is None:
        raise ImportError("artdag not available")

    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()

    logger.info(f"[CID] Executing {step.step_id} ({step.node_type})")

    # 1. Check if already computed
    existing_cid = get_cached_cid(step.cache_id)
    if existing_cid:
        logger.info(f"[CID] Cache hit: {step.cache_id[:16]}... → {existing_cid}")
        return _step_result(step, "cached", cid=existing_cid)

    # 2. Try to claim.
    # NOTE(review): the claim TTL (300s default) can be shorter than a long
    # render; after it lapses another worker may claim and re-execute the step.
    if not try_claim(step.cache_id, worker_id):
        logger.info("[CID] Claimed by another worker, waiting...")
        cid = wait_for_cid(step.cache_id)
        if cid:
            return _step_result(step, "completed_by_other", cid=cid)
        # No CID arrived. If the holder's claim has lapsed (crash) or was
        # released (failure), take over once instead of failing outright.
        if not try_claim(step.cache_id, worker_id):
            return _step_result(step, "timeout", error="Timeout waiting for other worker")
        logger.info(f"[CID] Previous claim lapsed, taking over {step.step_id}")

    # Re-check after claiming: a previous holder may have published the CID
    # between our cache lookup and our claim.
    existing_cid = get_cached_cid(step.cache_id)
    if existing_cid:
        release_claim(step.cache_id, worker_id)
        return _step_result(step, "cached", cid=existing_cid)

    # 3. We have the claim - execute
    try:
        if step.node_type == "SOURCE":
            # SOURCE nodes carry no computation; their CID is supplied externally.
            cid = _resolve_source_cid(step, input_cids)
        else:
            cid = _run_executor(step, input_cids)

        # Publish the mapping; the claim is left to expire via its TTL.
        set_cached_cid(step.cache_id, cid)
        return _step_result(step, "completed", cid=cid)

    except Exception as e:
        logger.exception(f"[CID] Failed: {step.step_id}: {e}")
        # Only drop the claim if we still own it.
        release_claim(step.cache_id, worker_id)
        return _step_result(step, "failed", error=str(e))


@app.task(bind=True, name='tasks.execute_plan_cid')
def execute_plan_cid(
    self,
    plan_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute an entire plan using IPFS-primary architecture.

    Levels run in ascending order. Steps within a level are dispatched in
    parallel as execute_step_cid tasks, and this task blocks until they finish.

    Args:
        plan_json: JSON-serialized ExecutionPlan
        input_cids: Mapping from input name to IPFS CID

    Returns:
        On success: {'status': 'completed', 'plan_id', 'output_cid',
        'step_cids'}. On the first failed step: {'status': 'failed',
        'failed_step', 'error'}.
    """
    from celery import group
    from artdag.planning import ExecutionPlan

    plan = ExecutionPlan.from_json(plan_json)
    logger.info(f"[CID] Executing plan: {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

    # step_id → CID, accumulated as levels complete (used for dependency resolution)
    step_cids: Dict[str, str] = {}

    steps_by_level = plan.get_steps_by_level()

    for level in sorted(steps_by_level):
        steps = steps_by_level[level]
        logger.info(f"[CID] Level {level}: {len(steps)} steps")
        if not steps:
            continue

        # Plan inputs plus every CID produced so far (step ids win on a clash).
        level_input_cids = {**input_cids, **step_cids}

        tasks = [
            execute_step_cid.s(step.to_json(), level_input_cids)
            for step in steps
        ]

        # Celery refuses to wait on subtasks from inside a task unless
        # disable_sync_subtasks=False. NOTE: waiting occupies this worker slot;
        # if every slot is busy running plans, the step tasks cannot start.
        if len(tasks) == 1:
            # Single step: still dispatched to a worker, just without a group.
            results = [tasks[0].apply_async().get(timeout=3600, disable_sync_subtasks=False)]
        else:
            results = group(tasks).apply_async().get(timeout=3600, disable_sync_subtasks=False)

        # Collect output CIDs; abort the plan on the first unsuccessful step.
        for step, result in zip(steps, results):
            if result.get("status") in _SUCCESS_STATUSES:
                step_cids[step.step_id] = result["cid"]
            else:
                return {
                    "status": "failed",
                    "failed_step": step.step_id,
                    "error": result.get("error", "Unknown error"),
                }

    # Final output: the declared output step, else the last step (None if the plan is empty).
    output_step_id = plan.output_step or (plan.steps[-1].step_id if plan.steps else None)
    output_cid = step_cids.get(output_step_id)

    logger.info(f"[CID] Plan complete: {output_cid}")

    return {
        "status": "completed",
        "plan_id": plan.plan_id,
        "output_cid": output_cid,
        "step_cids": step_cids,
    }