Add IPFS-primary execute_step_cid implementation
Simplified step execution where: steps receive CIDs and produce CIDs; there is no local cache management (IPFS handles it); Redis usage is minimal (just claims plus the cache_id→CID mapping); a temp workspace is used for execution and cleaned up afterwards. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
311
tasks/execute_cid.py
Normal file
311
tasks/execute_cid.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Simplified step execution with IPFS-primary architecture.
|
||||
|
||||
Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
|
||||
IPFS nodes form a distributed cache automatically.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from celery import current_task
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from celery_app import app
|
||||
import ipfs_client
|
||||
|
||||
# Redis for claiming and cache_id → CID mapping
|
||||
import redis
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
def get_redis() -> redis.Redis:
    """Lazily create and return the process-wide Redis client."""
    global _redis
    if _redis is not None:
        return _redis
    _redis = redis.from_url(REDIS_URL, decode_responses=True)
    return _redis
|
||||
|
||||
# Import artdag
|
||||
try:
|
||||
from artdag import NodeType
|
||||
from artdag.executor import get_executor
|
||||
from artdag.planning import ExecutionStep
|
||||
from artdag import nodes # Register executors
|
||||
except ImportError:
|
||||
NodeType = None
|
||||
get_executor = None
|
||||
ExecutionStep = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis keys
|
||||
CACHE_KEY = "artdag:cid_cache" # hash: cache_id → CID
|
||||
CLAIM_KEY_PREFIX = "artdag:claim:" # string: cache_id → worker_id
|
||||
|
||||
|
||||
def get_worker_id() -> str:
    """Return an identifier unique to this worker (hostname plus PID)."""
    host = socket.gethostname()
    pid = os.getpid()
    return f"{host}:{pid}"
|
||||
|
||||
|
||||
def get_cached_cid(cache_id: str) -> Optional[str]:
    """Look up the IPFS CID previously recorded for *cache_id*, if any."""
    client = get_redis()
    return client.hget(CACHE_KEY, cache_id)
|
||||
|
||||
|
||||
def set_cached_cid(cache_id: str, cid: str) -> None:
    """Record the cache_id → CID mapping so any worker can reuse it."""
    client = get_redis()
    client.hset(CACHE_KEY, cache_id, cid)
|
||||
|
||||
|
||||
def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
    """Atomically claim *cache_id* for execution by this worker.

    Uses Redis SET with NX so only one worker wins, and EX so a crashed
    worker's claim expires on its own after *ttl* seconds.

    Args:
        cache_id: Cache key identifying the step's work unit.
        worker_id: Identifier of the claiming worker.
        ttl: Claim lifetime in seconds.

    Returns:
        True if the claim was acquired, False if another worker holds it.
    """
    key = f"{CLAIM_KEY_PREFIX}{cache_id}"
    # redis-py returns True on success but None (not False) when NX loses
    # the race; coerce so the annotated bool contract actually holds.
    return bool(get_redis().set(key, worker_id, nx=True, ex=ttl))
|
||||
|
||||
|
||||
def release_claim(cache_id: str) -> None:
    """Drop the execution claim for *cache_id* (e.g. after a failure)."""
    get_redis().delete(f"{CLAIM_KEY_PREFIX}{cache_id}")
|
||||
|
||||
|
||||
def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:
    """Poll Redis until another worker publishes a CID for *cache_id*.

    Args:
        cache_id: Cache key another worker has claimed and is computing.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The CID once available, or None if the timeout elapsed.
    """
    import time
    # Use a monotonic deadline: time.time() can jump backwards/forwards
    # under NTP or manual clock changes, which would corrupt a long wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        cid = get_cached_cid(cache_id)
        if cid:
            return cid
        time.sleep(poll_interval)
    return None
|
||||
|
||||
|
||||
def fetch_from_ipfs(cid: str, dest_dir: Path) -> Path:
    """Download *cid* from IPFS into *dest_dir* and return the local path."""
    target = dest_dir / f"{cid}.mkv"
    fetched = ipfs_client.get_file(cid, target)
    if not fetched:
        raise RuntimeError(f"Failed to fetch CID from IPFS: {cid}")
    return target
|
||||
|
||||
|
||||
@app.task(bind=True, name='tasks.execute_step_cid')
def execute_step_cid(
    self,
    step_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute a step using IPFS-primary architecture.

    Flow: check the global Redis cache, claim the cache_id, fetch input
    CIDs into a temp workspace, run the node executor on local files,
    add the output to IPFS, and publish the cache_id → CID mapping.
    Only CIDs ever cross machine boundaries.

    Args:
        step_json: JSON-serialized ExecutionStep
        input_cids: Mapping from input step_id to their IPFS CID

    Returns:
        Dict with 'cid' (output CID) and 'status'. Possible statuses:
        'cached', 'completed', 'completed_by_other', 'timeout', 'failed'
        (the last two carry 'error' instead of 'cid').
    """
    # Fail fast if the optional artdag import at module scope did not resolve.
    if ExecutionStep is None:
        raise ImportError("artdag not available")

    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()

    logger.info(f"[CID] Executing {step.step_id} ({step.node_type})")

    # 1. Check if already computed — any worker may have published the CID.
    existing_cid = get_cached_cid(step.cache_id)
    if existing_cid:
        logger.info(f"[CID] Cache hit: {step.cache_id[:16]}... → {existing_cid}")
        return {
            "status": "cached",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "cid": existing_cid,
        }

    # 2. Try to claim. If another worker holds the claim, poll until it
    # publishes the CID (or the wait times out).
    if not try_claim(step.cache_id, worker_id):
        logger.info(f"[CID] Claimed by another worker, waiting...")
        cid = wait_for_cid(step.cache_id)
        if cid:
            return {
                "status": "completed_by_other",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "cid": cid,
            }
        return {
            "status": "timeout",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": "Timeout waiting for other worker",
        }

    # 3. We have the claim - execute
    # NOTE(review): on success the claim is never explicitly released; it
    # relies on the try_claim TTL expiring. Harmless once the CID is cached,
    # but worth confirming this is intentional.
    try:
        # Handle SOURCE nodes: no computation — their CID is supplied by the
        # caller, so just publish the mapping.
        if step.node_type == "SOURCE":
            # SOURCE nodes should have their CID in input_cids, keyed by
            # either the configured name or the step_id.
            source_name = step.config.get("name") or step.step_id
            cid = input_cids.get(source_name) or input_cids.get(step.step_id)
            if not cid:
                raise ValueError(f"SOURCE missing input CID: {source_name}")

            set_cached_cid(step.cache_id, cid)
            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "cid": cid,
            }

        # Get executor — resolve the enum member by name, falling back to the
        # raw string for node types not present in NodeType.
        try:
            node_type = NodeType[step.node_type]
        except KeyError:
            node_type = step.node_type

        executor = get_executor(node_type)
        if executor is None:
            raise ValueError(f"No executor for: {step.node_type}")

        # Create temp workspace — all fetched inputs and the output live
        # here and are removed in the finally block below.
        work_dir = Path(tempfile.mkdtemp(prefix="artdag_"))

        try:
            # Fetch inputs from IPFS in the order declared by input_steps.
            input_paths = []
            for i, input_step_id in enumerate(step.input_steps):
                input_cid = input_cids.get(input_step_id)
                if not input_cid:
                    raise ValueError(f"Missing input CID for: {input_step_id}")

                input_path = work_dir / f"input_{i}_{input_cid[:16]}.mkv"
                logger.info(f"[CID] Fetching input {i}: {input_cid}")

                if not ipfs_client.get_file(input_cid, input_path):
                    raise RuntimeError(f"Failed to fetch: {input_cid}")

                input_paths.append(input_path)

            # Execute the node against local files.
            output_path = work_dir / f"output_{step.cache_id[:16]}.mkv"
            logger.info(f"[CID] Running {step.node_type} with {len(input_paths)} inputs")

            # Executor may write elsewhere; trust the path it returns.
            result_path = executor.execute(step.config, input_paths, output_path)

            # Add output to IPFS; the CID becomes the step's result.
            output_cid = ipfs_client.add_file(result_path)
            if not output_cid:
                raise RuntimeError("Failed to add output to IPFS")

            logger.info(f"[CID] Completed: {step.step_id} → {output_cid}")

            # Store mapping so waiting/later workers see the result.
            set_cached_cid(step.cache_id, output_cid)

            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "cid": output_cid,
            }

        finally:
            # Cleanup temp workspace regardless of success or failure.
            shutil.rmtree(work_dir, ignore_errors=True)

    except Exception as e:
        logger.error(f"[CID] Failed: {step.step_id}: {e}")
        # Release the claim so another worker may retry this step.
        release_claim(step.cache_id)
        return {
            "status": "failed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": str(e),
        }
|
||||
|
||||
|
||||
@app.task(bind=True, name='tasks.execute_plan_cid')
def execute_plan_cid(
    self,
    plan_json: str,
    input_cids: Dict[str, str],
) -> Dict:
    """
    Execute an entire plan using IPFS-primary architecture.

    Steps are dispatched level by level: each level's steps run in
    parallel as a Celery group, and their output CIDs feed the next
    level. Execution aborts on the first failed step.

    NOTE(review): this waits synchronously (`.get()`) on subtasks from
    inside a task, which Celery discourages (possible worker deadlock
    when the pool is saturated) — confirm the deployment sizes worker
    pools to tolerate this.

    Args:
        plan_json: JSON-serialized ExecutionPlan
        input_cids: Mapping from input name to IPFS CID

    Returns:
        Dict with 'output_cid' and per-step results; on failure a dict
        with 'status': 'failed', 'failed_step' and 'error'.
    """
    from celery import group
    from artdag.planning import ExecutionPlan

    plan = ExecutionPlan.from_json(plan_json)
    logger.info(f"[CID] Executing plan: {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

    # CID results accumulate as steps complete; seeded with plan inputs.
    cid_results = dict(input_cids)

    # Also map step_id → CID for dependency resolution between levels.
    step_cids = {}

    # Presumably levels are topological: a step's inputs sit in lower
    # levels — relies on ExecutionPlan.get_steps_by_level().
    steps_by_level = plan.get_steps_by_level()

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        logger.info(f"[CID] Level {level}: {len(steps)} steps")

        # Build input CIDs for this level: original inputs plus every
        # step output produced so far.
        level_input_cids = dict(cid_results)
        level_input_cids.update(step_cids)

        # Dispatch steps in parallel (signatures, not yet executed).
        tasks = [
            execute_step_cid.s(step.to_json(), level_input_cids)
            for step in steps
        ]

        if len(tasks) == 1:
            # Single task - run directly
            results = [tasks[0].apply_async().get(timeout=3600)]
        else:
            # Multiple tasks - run in parallel
            job = group(tasks)
            results = job.apply_async().get(timeout=3600)

        # Collect output CIDs; any non-success status aborts the plan.
        for step, result in zip(steps, results):
            if result.get("status") in ("completed", "cached", "completed_by_other"):
                step_cids[step.step_id] = result["cid"]
            else:
                return {
                    "status": "failed",
                    "failed_step": step.step_id,
                    "error": result.get("error", "Unknown error"),
                }

    # Get final output CID — the declared output step, or the last step
    # in the plan when none is declared.
    output_step_id = plan.output_step or plan.steps[-1].step_id
    output_cid = step_cids.get(output_step_id)

    logger.info(f"[CID] Plan complete: {output_cid}")

    return {
        "status": "completed",
        "plan_id": plan.plan_id,
        "output_cid": output_cid,
        "step_cids": step_cids,
    }
|
||||
Reference in New Issue
Block a user