"""
Simplified step execution with IPFS-primary architecture.

Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
IPFS nodes form a distributed cache automatically.

Design summary:
- No local cache management (IPFS handles it).
- Minimal Redis: just execution claims + the cache_id → CID mapping.
- Each step runs in a temp workspace that is removed afterwards.
"""

# Provenance: tasks/execute_cid.py, introduced by commit 383dbf6e
# ("Add IPFS-primary execute_step_cid implementation").

import logging
import os
import shutil
import socket
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from celery import current_task

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
import ipfs_client

# Redis for claiming and cache_id → CID mapping
import redis
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
_redis: Optional[redis.Redis] = None


def get_redis() -> redis.Redis:
    """Return the process-wide Redis client, creating it on first use."""
    global _redis
    if _redis is None:
        # decode_responses=True so hash/string values come back as str.
        _redis = redis.from_url(REDIS_URL, decode_responses=True)
    return _redis


# Import artdag
try:
    from artdag import NodeType
    from artdag.executor import get_executor
    from artdag.planning import ExecutionStep
    from artdag import nodes  # Register executors
except ImportError:
    # Keep the module importable without artdag; the tasks check for None.
    NodeType = None
    get_executor = None
    ExecutionStep = None

logger = logging.getLogger(__name__)

# Redis keys
CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → CID
CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker_id


def get_worker_id() -> str:
    """Get unique worker identifier (``hostname:pid``)."""
    return f"{socket.gethostname()}:{os.getpid()}"


def get_cached_cid(cache_id: str) -> Optional[str]:
    """Return the CID recorded for cache_id, or None if there is none."""
    return get_redis().hget(CACHE_KEY, cache_id)


def set_cached_cid(cache_id: str, cid: str) -> None:
    """Store cache_id → CID mapping."""
    get_redis().hset(CACHE_KEY, cache_id, cid)


def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
    """
    Try to claim a cache_id for execution.

    Args:
        cache_id: Cache identifier of the step being claimed
        worker_id: Value stored under the claim key
        ttl: Claim expiry in seconds

    Returns:
        True if this call created the claim, False if one already existed.
    """
    key = f"{CLAIM_KEY_PREFIX}{cache_id}"
    # SET ... NX yields None (not False) when the key already exists, so
    # coerce to honour the declared bool return type.
    return bool(get_redis().set(key, worker_id, nx=True, ex=ttl))


def release_claim(cache_id: str) -> None:
    """Release a claim by deleting its key (does not check who holds it)."""
    key = f"{CLAIM_KEY_PREFIX}{cache_id}"
    get_redis().delete(key)


def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:
    """
    Wait for another worker to produce a CID for cache_id.

    Args:
        cache_id: Cache identifier to poll for
        timeout: Maximum time to wait, in seconds
        poll_interval: Delay between polls, in seconds

    Returns:
        The CID once it appears, or None if the timeout elapses first.
    """
    # Monotonic clock: wall-clock adjustments must not stretch or cut the wait.
    deadline = time.monotonic() + timeout
    while True:
        # Always check at least once, even with timeout=0.
        cid = get_cached_cid(cache_id)
        if cid:
            return cid
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        time.sleep(min(poll_interval, remaining))


def fetch_from_ipfs(cid: str, dest_dir: Path, filename: Optional[str] = None) -> Path:
    """
    Fetch a CID from IPFS to a local temp file.

    Args:
        cid: IPFS content identifier to fetch
        dest_dir: Directory to place the file in
        filename: File name inside dest_dir; defaults to ``<cid>.mkv``

    Returns:
        Path of the fetched file.

    Raises:
        RuntimeError: If ipfs_client.get_file reports failure.
    """
    dest_path = dest_dir / (filename or f"{cid}.mkv")
    if not ipfs_client.get_file(cid, dest_path):
        raise RuntimeError(f"Failed to fetch CID from IPFS: {cid}")
    return dest_path


def _step_result(step: Any, status: str, **extra: Any) -> Dict[str, Any]:
    """Build a task result dict for step with the given status and extras."""
    result: Dict[str, Any] = {
        "status": status,
        "step_id": step.step_id,
        "cache_id": step.cache_id,
    }
    result.update(extra)
    return result


def _resolve_source_cid(step: Any, input_cids: Dict[str, str]) -> str:
    """Look up the CID of a SOURCE step by its configured name or step_id."""
    # SOURCE nodes should have their CID in input_cids
    source_name = step.config.get("name") or step.step_id
    cid = input_cids.get(source_name) or input_cids.get(step.step_id)
    if not cid:
        raise ValueError(f"SOURCE missing input CID: {source_name}")
    return cid


def _run_step(step: Any, input_cids: Dict[str, str]) -> str:
    """Fetch inputs, run the step's executor in a temp dir, return output CID."""
    try:
        node_type = NodeType[step.node_type]
    except KeyError:
        # Not a NodeType member: hand the raw name to get_executor instead.
        node_type = step.node_type

    executor = get_executor(node_type)
    if executor is None:
        raise ValueError(f"No executor for: {step.node_type}")

    # Create temp workspace
    work_dir = Path(tempfile.mkdtemp(prefix="artdag_"))
    try:
        # Fetch inputs from IPFS
        input_paths: List[Path] = []
        for i, input_step_id in enumerate(step.input_steps):
            input_cid = input_cids.get(input_step_id)
            if not input_cid:
                raise ValueError(f"Missing input CID for: {input_step_id}")

            logger.info("[CID] Fetching input %d: %s", i, input_cid)
            input_paths.append(
                fetch_from_ipfs(input_cid, work_dir, f"input_{i}_{input_cid[:16]}.mkv")
            )

        # Execute
        output_path = work_dir / f"output_{step.cache_id[:16]}.mkv"
        logger.info("[CID] Running %s with %d inputs", step.node_type, len(input_paths))
        result_path = executor.execute(step.config, input_paths, output_path)

        # Add output to IPFS before the workspace is deleted
        output_cid = ipfs_client.add_file(result_path)
        if not output_cid:
            raise RuntimeError("Failed to add output to IPFS")

        logger.info("[CID] Completed: %s → %s", step.step_id, output_cid)
        return output_cid
    finally:
        # Cleanup temp workspace
        shutil.rmtree(work_dir, ignore_errors=True)


@app.task(bind=True, name='tasks.execute_step_cid')
def execute_step_cid(
    self,
    step_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute a step using IPFS-primary architecture.

    Args:
        step_json: JSON-serialized ExecutionStep
        input_cids: Mapping from input step_id to their IPFS CID

    Returns:
        Dict with 'status', 'step_id', 'cache_id' and either 'cid' (statuses
        "cached", "completed_by_other", "completed") or 'error' (statuses
        "timeout", "failed").

    Raises:
        ImportError: If artdag could not be imported.
    """
    if ExecutionStep is None:
        raise ImportError("artdag not available")

    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()

    logger.info("[CID] Executing %s (%s)", step.step_id, step.node_type)

    # 1. Check if already computed
    existing_cid = get_cached_cid(step.cache_id)
    if existing_cid:
        logger.info("[CID] Cache hit: %s... → %s", step.cache_id[:16], existing_cid)
        return _step_result(step, "cached", cid=existing_cid)

    # 2. Try to claim; if someone else holds the claim, wait for their result
    if not try_claim(step.cache_id, worker_id):
        logger.info("[CID] Claimed by another worker, waiting...")
        cid = wait_for_cid(step.cache_id)
        if cid:
            return _step_result(step, "completed_by_other", cid=cid)
        return _step_result(step, "timeout", error="Timeout waiting for other worker")

    # 3. We have the claim - execute
    try:
        if step.node_type == "SOURCE":
            output_cid = _resolve_source_cid(step, input_cids)
        else:
            output_cid = _run_step(step, input_cids)

        # Store mapping so later calls and waiting workers see the result.
        # The claim is left to expire via its TTL on success.
        set_cached_cid(step.cache_id, output_cid)
        return _step_result(step, "completed", cid=output_cid)

    except Exception as e:
        # Task boundary: report the failure as a result rather than raising.
        # logger.exception keeps the traceback.
        logger.exception("[CID] Failed: %s: %s", step.step_id, e)
        release_claim(step.cache_id)
        return _step_result(step, "failed", error=str(e))


@app.task(bind=True, name='tasks.execute_plan_cid')
def execute_plan_cid(
    self,
    plan_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute an entire plan using IPFS-primary architecture.

    Steps are dispatched level by level; all steps of one level run in
    parallel and must finish before the next level starts.

    Args:
        plan_json: JSON-serialized ExecutionPlan
        input_cids: Mapping from input name to IPFS CID

    Returns:
        On success: dict with 'status' "completed", 'plan_id', 'output_cid'
        and 'step_cids'. On failure: dict with 'status' "failed",
        'failed_step' and 'error'.
    """
    from celery import group
    from artdag.planning import ExecutionPlan

    plan = ExecutionPlan.from_json(plan_json)
    logger.info("[CID] Executing plan: %s... (%d steps)", plan.plan_id[:16], len(plan.steps))

    # step_id → CID, filled in as steps complete (for dependency resolution)
    step_cids: Dict[str, str] = {}

    steps_by_level = plan.get_steps_by_level()

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        logger.info("[CID] Level %s: %d steps", level, len(steps))
        if not steps:
            continue

        # Build input CIDs for this level: plan inputs plus earlier outputs
        level_input_cids = {**input_cids, **step_cids}

        # Dispatch steps in parallel
        signatures = [
            execute_step_cid.s(step.to_json(), level_input_cids)
            for step in steps
        ]

        # Celery refuses to block on a result from inside a task (it raises
        # RuntimeError) unless disable_sync_subtasks=False is passed.
        if len(signatures) == 1:
            # Single task - no group needed
            pending = signatures[0].apply_async()
            results = [pending.get(timeout=3600, disable_sync_subtasks=False)]
        else:
            # Multiple tasks - run in parallel
            pending = group(signatures).apply_async()
            results = pending.get(timeout=3600, disable_sync_subtasks=False)

        # Collect output CIDs; abort the plan on the first failed step
        for step, result in zip(steps, results):
            if result.get("status") in ("completed", "cached", "completed_by_other"):
                step_cids[step.step_id] = result["cid"]
            else:
                return {
                    "status": "failed",
                    "failed_step": step.step_id,
                    "error": result.get("error", "Unknown error"),
                }

    # Get final output CID: the declared output step, else the last step
    if plan.output_step:
        output_step_id = plan.output_step
    elif plan.steps:
        output_step_id = plan.steps[-1].step_id
    else:
        output_step_id = None

    output_cid = step_cids.get(output_step_id) if output_step_id else None
    if not output_cid:
        # Do not report success without an output.
        return {
            "status": "failed",
            "failed_step": output_step_id,
            "error": "No output CID produced for plan",
        }

    logger.info("[CID] Plan complete: %s", output_cid)

    return {
        "status": "completed",
        "plan_id": plan.plan_id,
        "output_cid": output_cid,
        "step_cids": step_cids,
    }