Files
celery/tasks/execute.py
gilesb 3db606bf15 Plan-based caching with artifact playback in UI
RunStatus now stores:
- plan_id, plan_name for linking to execution plan
- step_results for per-step execution status
- all_outputs for all artifacts from all steps

Plan visualization:
- Shows human-readable step names from recipe structure
- Video/audio artifact preview on node click
- Outputs list with links to cached artifacts
- Stats reflect actual execution status (completed/cached/pending)

Execution:
- Step results include outputs list with cache_ids
- run_plan returns all outputs from all steps
- Support for completed_by_other status

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 00:20:19 +00:00

327 lines
9.9 KiB
Python

"""
Step execution task.
Phase 3 of the 3-phase execution model.
Executes individual steps from an execution plan with IPFS-backed caching.
"""
import json
import logging
import os
import shutil
import socket
import sys
import tempfile
from pathlib import Path
from typing import Dict, List, Optional

from celery import current_task

# Import from the Celery app
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
from claiming import (
get_claimer,
claim_task,
complete_task,
fail_task,
ClaimStatus,
)
from cache_manager import get_cache_manager, L1CacheManager
# Import artdag
try:
from artdag import Cache, NodeType
from artdag.executor import get_executor
from artdag.planning import ExecutionStep
except ImportError:
Cache = None
NodeType = None
get_executor = None
ExecutionStep = None
logger = logging.getLogger(__name__)
def get_worker_id() -> str:
    """Return a unique identifier for this worker, as ``hostname:pid``."""
    return f"{socket.gethostname()}:{os.getpid()}"
@app.task(bind=True, name='tasks.execute_step')
def execute_step(
    self,
    step_json: str,
    plan_id: str,
    input_cache_ids: Dict[str, str],
) -> dict:
    """
    Execute a single step from an execution plan.

    Uses hash-based claiming to prevent duplicate work: if another worker
    already holds the claim for this step's cache_id, we wait for that
    worker instead of re-executing. Results are stored in the IPFS-backed
    L1 cache, keyed by the step's cache_id.

    Args:
        step_json: JSON-serialized ExecutionStep.
        plan_id: ID of the parent execution plan (kept in the signature for
            traceability; not read by this task body).
        input_cache_ids: Mapping from input step_id to that step's cache_id.

    Returns:
        Dict describing the execution result. "status" is one of:
        "cached", "completed", "completed_by_other", "failed", "timeout".

    Raises:
        ImportError: If artdag.planning could not be imported.
    """
    if ExecutionStep is None:
        raise ImportError("artdag.planning not available")
    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()
    task_id = self.request.id
    logger.info(f"Executing step {step.step_id} ({step.node_type}) cache_id={step.cache_id[:16]}...")
    # Get L1 cache manager (IPFS-backed)
    cache_mgr = get_cache_manager()
    # Fast path: already cached (cache_id doubles as the content hash).
    cached_path = cache_mgr.get_by_content_hash(step.cache_id)
    if cached_path:
        logger.info(f"Step {step.step_id} already cached at {cached_path}")
        # Record the hit in the claiming system so waiting workers see it.
        claimer = get_claimer()
        claimer.mark_cached(step.cache_id, str(cached_path))
        return {
            "status": "cached",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": str(cached_path),
        }
    # Try to claim the step; if we lose the race, wait for the winner.
    if not claim_task(step.cache_id, worker_id, task_id):
        logger.info(f"Step {step.step_id} claimed by another worker, waiting...")
        claimer = get_claimer()
        result = claimer.wait_for_completion(step.cache_id, timeout=600)
        if result and result.status == ClaimStatus.COMPLETED:
            return {
                "status": "completed_by_other",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": result.output_path,
            }
        elif result and result.status == ClaimStatus.CACHED:
            return {
                "status": "cached",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": result.output_path,
            }
        elif result and result.status == ClaimStatus.FAILED:
            return {
                "status": "failed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "error": result.error,
            }
        else:
            return {
                "status": "timeout",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "error": "Timeout waiting for other worker",
            }
    # We own the claim; mark it as running.
    claimer = get_claimer()
    claimer.update_status(step.cache_id, worker_id, ClaimStatus.RUNNING)
    output_dir: Optional[Path] = None  # temp dir; only executor-backed steps create one
    try:
        # SOURCE nodes: resolve a pre-ingested input from the cache.
        if step.node_type == "SOURCE":
            content_hash = step.config.get("content_hash")
            if not content_hash:
                raise ValueError("SOURCE step missing content_hash")
            # Look up in cache
            path = cache_mgr.get_by_content_hash(content_hash)
            if not path:
                raise ValueError(f"SOURCE input not found in cache: {content_hash[:16]}...")
            output_path = str(path)
            complete_task(step.cache_id, worker_id, output_path)
            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": output_path,
            }
        # _LIST virtual nodes: gather cached paths of the listed items.
        # Missing items are silently skipped (best-effort list assembly).
        if step.node_type == "_LIST":
            item_paths = []
            for item_id in step.config.get("items", []):
                item_cache_id = input_cache_ids.get(item_id)
                if item_cache_id:
                    path = cache_mgr.get_by_content_hash(item_cache_id)
                    if path:
                        item_paths.append(str(path))
            complete_task(step.cache_id, worker_id, json.dumps(item_paths))
            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": None,
                "item_paths": item_paths,
            }
        # Look up the executor for this node type; fall back to the raw
        # string when it is not a member of the NodeType enum.
        try:
            node_type = NodeType[step.node_type]
        except KeyError:
            node_type = step.node_type
        executor = get_executor(node_type)
        if executor is None:
            raise ValueError(f"No executor for node type: {step.node_type}")
        # Resolve input paths from cache; every declared input must exist.
        input_paths = []
        for input_step_id in step.input_steps:
            input_cache_id = input_cache_ids.get(input_step_id)
            if not input_cache_id:
                raise ValueError(f"No cache_id for input step: {input_step_id}")
            path = cache_mgr.get_by_content_hash(input_cache_id)
            if not path:
                raise ValueError(f"Input not in cache: {input_cache_id[:16]}...")
            input_paths.append(Path(path))
        # Create temp output path; removed in the finally block below.
        output_dir = Path(tempfile.mkdtemp())
        output_path = output_dir / f"output_{step.cache_id[:16]}.mp4"
        # Execute
        logger.info(f"Running executor for {step.node_type} with {len(input_paths)} inputs")
        result_path = executor.execute(step.config, input_paths, output_path)
        # Store in IPFS-backed cache
        cached_file, ipfs_cid = cache_mgr.put(
            source_path=result_path,
            node_type=step.node_type,
            node_id=step.cache_id,
        )
        logger.info(f"Step {step.step_id} completed, IPFS CID: {ipfs_cid}")
        # Mark completed
        complete_task(step.cache_id, worker_id, str(cached_file.path))
        # Build outputs list (for multi-output support)
        outputs = []
        if step.outputs:
            # Use pre-defined outputs from step
            for output_def in step.outputs:
                outputs.append({
                    "name": output_def.name,
                    "cache_id": output_def.cache_id,
                    "media_type": output_def.media_type,
                    "index": output_def.index,
                    "path": str(cached_file.path),
                    "content_hash": cached_file.content_hash,
                    "ipfs_cid": ipfs_cid,
                })
        else:
            # Single output (backwards compat)
            outputs.append({
                "name": step.name or step.step_id,
                "cache_id": step.cache_id,
                "media_type": "video/mp4",
                "index": 0,
                "path": str(cached_file.path),
                "content_hash": cached_file.content_hash,
                "ipfs_cid": ipfs_cid,
            })
        return {
            "status": "completed",
            "step_id": step.step_id,
            "name": step.name,
            "cache_id": step.cache_id,
            "output_path": str(cached_file.path),
            "content_hash": cached_file.content_hash,
            "ipfs_cid": ipfs_cid,
            "outputs": outputs,
        }
    except Exception as e:
        # logger.exception records the traceback (logger.error did not).
        logger.exception(f"Step {step.step_id} failed: {e}")
        fail_task(step.cache_id, worker_id, str(e))
        return {
            "status": "failed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": str(e),
        }
    finally:
        # Fix: clean the temp dir on ALL exits — the original only removed
        # it on success, leaking a directory whenever the executor raised.
        if output_dir is not None and output_dir.exists():
            shutil.rmtree(output_dir, ignore_errors=True)
@app.task(bind=True, name='tasks.execute_level')
def execute_level(
    self,
    steps_json: List[str],
    plan_id: str,
    cache_ids: Dict[str, str],
) -> dict:
    """
    Execute all steps at a given dependency level.

    Steps within one level have no dependencies on each other, so they are
    fanned out as a parallel Celery group.

    Args:
        steps_json: List of JSON-serialized ExecutionSteps.
        plan_id: ID of the parent execution plan.
        cache_ids: Mapping from step_id to cache_id.

    Returns:
        Dict with per-step results and the updated step_id -> cache_id map.
    """
    from celery import group

    # One execute_step signature per step; all share the same cache_id map.
    signatures = [
        execute_step.s(serialized_step, plan_id, cache_ids)
        for serialized_step in steps_json
    ]
    # NOTE(review): this blocks inside a task on the group's result — a
    # synchronous sub-task wait. Works here, but confirm the Celery setup
    # permits it (newer versions guard against it by default).
    async_result = group(signatures).apply_async()
    step_results = async_result.get(timeout=3600)  # 1 hour timeout
    # Layer newly reported cache_ids on top of the incoming mapping.
    merged_cache_ids = dict(cache_ids)
    merged_cache_ids.update({
        result["step_id"]: result["cache_id"]
        for result in step_results
        if result.get("step_id") and result.get("cache_id")
    })
    return {
        "status": "completed",
        "results": step_results,
        "cache_ids": merged_cache_ids,
    }