Add 3-phase execution with IPFS cache and hash-based task claiming

New files:
- claiming.py - Redis Lua scripts for atomic task claiming
- tasks/analyze.py - Analysis Celery task
- tasks/execute.py - Step execution with IPFS-backed cache
- tasks/orchestrate.py - Plan orchestration (run_plan, run_recipe)

New API endpoints (/api/v2/):
- POST /api/v2/plan - Generate execution plan
- POST /api/v2/execute - Execute a plan
- POST /api/v2/run-recipe - Full 3-phase pipeline
- GET /api/v2/run/{run_id} - Get run status

Features:
- Hash-based task claiming prevents duplicate work
- Parallel execution within dependency levels
- IPFS-backed cache for durability
- Integration with artdag planning module

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-01-10 11:44:00 +00:00
parent 7d05011daa
commit f7890dd1ad
7 changed files with 1468 additions and 1 deletions

View File

@@ -14,7 +14,7 @@ app = Celery(
'art_celery',
broker=REDIS_URL,
backend=REDIS_URL,
include=['tasks']
include=['tasks', 'tasks.analyze', 'tasks.execute', 'tasks.orchestrate']
)
app.conf.update(

421
claiming.py Normal file
View File

@@ -0,0 +1,421 @@
"""
Hash-based task claiming for distributed execution.
Prevents duplicate work when multiple workers process the same plan.
Uses Redis Lua scripts for atomic claim operations.
"""
import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
import redis
logger = logging.getLogger(__name__)

# Redis connection URL; defaults to DB 5 on localhost.
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')

# Key prefix for task claims (claim key = CLAIM_PREFIX + cache_id)
CLAIM_PREFIX = "artdag:claim:"

# Default TTL for claims (5 minutes) — a crashed worker's claim expires on its own
DEFAULT_CLAIM_TTL = 300

# TTL for completed results (1 hour)
COMPLETED_TTL = 3600
class ClaimStatus(Enum):
    """Lifecycle states a task claim can be in."""
    PENDING = "pending"
    CLAIMED = "claimed"
    RUNNING = "running"
    COMPLETED = "completed"
    CACHED = "cached"
    FAILED = "failed"


@dataclass
class ClaimInfo:
    """Snapshot of a single task claim; serializable to/from plain dicts."""
    cache_id: str
    status: ClaimStatus
    worker_id: Optional[str] = None
    task_id: Optional[str] = None
    claimed_at: Optional[str] = None
    completed_at: Optional[str] = None
    output_path: Optional[str] = None
    error: Optional[str] = None

    # Field order used for (de)serialization — matches the declaration order.
    _FIELD_ORDER = (
        "cache_id", "status", "worker_id", "task_id",
        "claimed_at", "completed_at", "output_path", "error",
    )

    def to_dict(self) -> dict:
        """Serialize to a plain dict, with status flattened to its string value."""
        out = {}
        for name in self._FIELD_ORDER:
            value = getattr(self, name)
            out[name] = value.value if name == "status" else value
        return out

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """Inverse of to_dict(); missing optional keys default to None."""
        optional = {
            name: data.get(name)
            for name in cls._FIELD_ORDER
            if name not in ("cache_id", "status")
        }
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            **optional,
        )
# Lua script for atomic task claiming.
# Runs server-side in Redis, so the GET + status check + SETEX is one atomic
# unit — two workers racing on the same key cannot both claim it.
# Returns 1 if claim successful, 0 if already claimed/completed.
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)
if data then
local status = cjson.decode(data)
local s = status['status']
-- Already claimed, running, completed, or cached - don't claim
if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
return 0
end
end
-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""

# Lua script for releasing a claim (e.g., on failure).
# Deletes the key only if ARGV[1] matches the stored worker_id, so one
# worker cannot release another worker's claim. Returns 1 on success, 0 otherwise.
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)
if data then
local status = cjson.decode(data)
-- Only release if we own the claim
if status['worker_id'] == worker_id then
redis.call('DEL', key)
return 1
end
end
return 0
"""

# Lua script for updating claim status (claimed -> running -> completed).
# Ownership-checked like RELEASE_CLAIM_SCRIPT: ARGV[1] must match the stored
# worker_id. Replaces the whole value with ARGV[3] under TTL ARGV[4].
# NOTE(review): ARGV[2] (the new status string) is accepted but unused by the
# script itself — the status is carried inside the JSON payload ARGV[3].
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])
local data = redis.call('GET', key)
if not data then
return 0
end
local status = cjson.decode(data)
-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
return 0
end
redis.call('SETEX', key, ttl, new_data)
return 1
"""
class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers.
    Each task is identified by its cache_id (content-addressed), so two
    workers racing on the same step contend on a single Redis key.
    """

    def __init__(self, redis_url: str = None):
        """
        Initialize the claimer.

        Args:
            redis_url: Redis connection URL (defaults to module-level REDIS_URL)
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        # Lua script handles; registered lazily by the `redis` property.
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization).

        Also registers the claim/release/update Lua scripts on first
        access, so this property must run before any script handle is used.
        """
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )
        # BUGFIX: touch the `redis` property BEFORE reading the script
        # attribute. The scripts are registered only inside the lazy
        # property, and Python evaluates the callable expression before
        # its arguments — so on a fresh instance `self._claim_script`
        # was still None when looked up, and calling it raised TypeError.
        client = self.redis
        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )
        if result == 1:
            logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}")
            return True
        else:
            logger.debug(f"Task {cache_id[:16]}... already claimed")
            return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful
        """
        if ttl is None:
            # Finished results live longer so other workers can read them.
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL
        # Get existing claim info (also initializes the Redis connection
        # and Lua scripts via the lazy property).
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning(f"No claim found for {cache_id[:16]}...")
            return False
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            completed_at=datetime.now(timezone.utc).isoformat() if status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ) else None,
            output_path=output_path,
            error=error,
        )
        client = self.redis
        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )
        if result == 1:
            logger.debug(f"Updated task {cache_id[:16]}... to {status.value}")
            return True
        else:
            logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)")
            return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful
        """
        # Same lazy-init ordering fix as in claim(): ensure the Lua
        # scripts are registered before the handle is looked up.
        client = self.redis
        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )
        if result == 1:
            logger.debug(f"Released claim on {cache_id[:16]}...")
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED, ClaimStatus.CACHED
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to complete (polling loop; blocks the caller).

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo if completed/cached/failed, None if timeout
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ):
                return info
            time.sleep(poll_interval)
        logger.warning(f"Timeout waiting for {cache_id[:16]}...")
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim. Unconditional SETEX — no ownership
        check, since no worker needs to run the task.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )
        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0
# Process-wide singleton claimer, created lazily.
_claimer: Optional[TaskClaimer] = None


def get_claimer() -> TaskClaimer:
    """Get the global TaskClaimer instance."""
    global _claimer
    if _claimer is not None:
        return _claimer
    _claimer = TaskClaimer()
    return _claimer
def claim_task(cache_id: str, worker_id: str, task_id: Optional[str] = None) -> bool:
    """Convenience function to claim a task via the global claimer.

    Annotation fixed: task_id defaults to None, so it is Optional[str],
    not a bare str. Behavior is unchanged.
    """
    return get_claimer().claim(cache_id, worker_id, task_id)
def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Convenience function to mark a task as completed."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.COMPLETED,
        output_path=output_path,
    )
def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Convenience function to mark a task as failed."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.FAILED,
        error=error,
    )

225
server.py
View File

@@ -4964,6 +4964,231 @@ async def download_client():
)
# ============================================================================
# 3-Phase Execution API (Analyze → Plan → Execute)
# ============================================================================
class RecipeRunRequest(BaseModel):
    """Request to run a recipe with the 3-phase execution model."""
    recipe_yaml: str  # Recipe YAML content
    input_hashes: dict  # Mapping from input name to content hash
    features: Optional[list[str]] = None  # Features to extract (default: beats, energy)
class PlanRequest(BaseModel):
    """Request to generate an execution plan (phases 1+2 only, no execution)."""
    recipe_yaml: str  # Recipe YAML content
    input_hashes: dict  # Mapping from input name to content hash
    features: Optional[list[str]] = None  # Features to extract from inputs
class ExecutePlanRequest(BaseModel):
    """Request to execute a pre-generated plan (phase 3 only)."""
    plan_json: str  # JSON-serialized ExecutionPlan
@app.post("/api/v2/plan")
async def generate_plan_endpoint(
    request: PlanRequest,
    ctx: UserContext = Depends(get_required_user_context)
):
    """
    Generate an execution plan without executing it.

    Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model.
    Returns the plan with cache status for each step.

    Raises:
        HTTPException: 500 if plan generation fails or times out.
    """
    import asyncio

    from tasks.orchestrate import generate_plan
    try:
        # Submit to Celery
        task = generate_plan.delay(
            recipe_yaml=request.recipe_yaml,
            input_hashes=request.input_hashes,
            features=request.features,
        )
        # Plan generation is usually fast, but task.get() is a blocking
        # wait — run it in a worker thread so the event loop stays free
        # for other requests instead of stalling for up to 60 seconds.
        result = await asyncio.to_thread(task.get, timeout=60)
        return {
            "status": result.get("status"),
            "recipe": result.get("recipe"),
            "plan_id": result.get("plan_id"),
            "total_steps": result.get("total_steps"),
            "cached_steps": result.get("cached_steps"),
            "pending_steps": result.get("pending_steps"),
            "steps": result.get("steps"),
        }
    except Exception as e:
        logger.error(f"Plan generation failed: {e}")
        # Chain the cause so tracebacks keep the original failure.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/v2/execute")
async def execute_plan_endpoint(
    request: ExecutePlanRequest,
    ctx: UserContext = Depends(get_required_user_context)
):
    """
    Execute a pre-generated execution plan.

    Phase 3 (Execute) of the 3-phase model. Submits the plan to Celery
    for parallel execution and returns immediately with a fresh run_id.
    """
    from tasks.orchestrate import run_plan

    run_id = str(uuid.uuid4())
    try:
        # Fire-and-forget: the caller polls for status separately.
        task = run_plan.delay(plan_json=request.plan_json, run_id=run_id)
    except Exception as e:
        logger.error(f"Plan execution failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "status": "submitted",
        "run_id": run_id,
        "celery_task_id": task.id,
    }
@app.post("/api/v2/run-recipe")
async def run_recipe_endpoint(
    request: RecipeRunRequest,
    ctx: UserContext = Depends(get_required_user_context)
):
    """
    Run a complete recipe through all 3 phases.

    1. Analyze: Extract features from inputs
    2. Plan: Generate execution plan with cache IDs
    3. Execute: Run steps with parallel execution

    Returns immediately with run_id. Poll /api/v2/run/{run_id} for status.
    """
    from tasks.orchestrate import run_recipe
    # Compute run_id from inputs and recipe.
    # run_id is deterministic: identical inputs + identical recipe text
    # map to the same run, which is what makes the cache check below work.
    try:
        recipe_data = yaml.safe_load(request.recipe_yaml)
        recipe_name = recipe_data.get("name", "unknown")
    except Exception:
        # Unparseable YAML still gets a run_id; the Celery task will
        # surface the real parse error to the caller.
        recipe_name = "unknown"
    run_id = compute_run_id(
        list(request.input_hashes.values()),
        recipe_name,
        hashlib.sha3_256(request.recipe_yaml.encode()).hexdigest()
    )
    # Check if already completed: short-circuit with the cached result
    # when the database knows this run AND the output is still in cache.
    cached = await database.get_run_cache(run_id)
    if cached:
        output_hash = cached.get("output_hash")
        if cache_manager.has_content(output_hash):
            return {
                "status": "completed",
                "run_id": run_id,
                "output_hash": output_hash,
                "output_ipfs_cid": cache_manager.get_ipfs_cid(output_hash),
                "cached": True,
            }
    # Submit to Celery (asynchronous; we do not wait for the result here)
    try:
        task = run_recipe.delay(
            recipe_yaml=request.recipe_yaml,
            input_hashes=request.input_hashes,
            features=request.features,
            run_id=run_id,
        )
        # Store run status in Redis so /api/v2/run/{run_id} can report it
        run_data = {
            "run_id": run_id,
            "status": "pending",
            "recipe": recipe_name,
            "inputs": list(request.input_hashes.values()),
            "celery_task_id": task.id,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "username": ctx.actor_id,
        }
        redis_client.setex(
            f"{RUNS_KEY_PREFIX}{run_id}",
            86400,  # 24 hour expiry
            json.dumps(run_data)
        )
        return {
            "status": "submitted",
            "run_id": run_id,
            "celery_task_id": task.id,
            "recipe": recipe_name,
        }
    except Exception as e:
        logger.error(f"Recipe run failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/v2/run/{run_id}")
async def get_run_v2(run_id: str, ctx: UserContext = Depends(get_required_user_context)):
    """
    Get status of a 3-phase execution run.

    Resolution order: Redis run record (refreshed from the Celery result
    if the task has finished since the last poll), then the database run
    cache, then 404.
    """
    # Check Redis for run status
    run_data = redis_client.get(f"{RUNS_KEY_PREFIX}{run_id}")
    if run_data:
        data = json.loads(run_data)
        # If pending, check Celery task status and fold the final result
        # into the Redis record so later polls are cheap.
        if data.get("status") == "pending" and data.get("celery_task_id"):
            from celery.result import AsyncResult
            result = AsyncResult(data["celery_task_id"])
            if result.ready():
                if result.successful():
                    task_result = result.get()
                    data["status"] = task_result.get("status", "completed")
                    data["output_hash"] = task_result.get("output_cache_id")
                    data["output_ipfs_cid"] = task_result.get("output_ipfs_cid")
                    data["total_steps"] = task_result.get("total_steps")
                    data["cached"] = task_result.get("cached")
                    data["executed"] = task_result.get("executed")
                    # Update Redis (re-arms the 24h expiry)
                    redis_client.setex(
                        f"{RUNS_KEY_PREFIX}{run_id}",
                        86400,
                        json.dumps(data)
                    )
                else:
                    # NOTE(review): failure state is returned but not written
                    # back to Redis, so it is recomputed on every poll.
                    data["status"] = "failed"
                    data["error"] = str(result.result)
            else:
                data["celery_status"] = result.status
        return data
    # Check database cache (Redis record may have expired after 24h)
    cached = await database.get_run_cache(run_id)
    if cached:
        return {
            "run_id": run_id,
            "status": "completed",
            "output_hash": cached.get("output_hash"),
            "cached": True,
        }
    raise HTTPException(status_code=404, detail="Run not found")
if __name__ == "__main__":
import uvicorn
# Workers enabled - cache indexes shared via Redis

18
tasks/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
# art-celery/tasks - Celery tasks for 3-phase execution
#
# Tasks for the Art DAG distributed execution system:
# 1. analyze_input - Extract features from input media
# 2. execute_step - Execute a single step from the plan
# 3. run_plan - Orchestrate execution of a full plan
from .analyze import analyze_input, analyze_inputs
from .execute import execute_step
from .orchestrate import run_plan, run_recipe
__all__ = [
"analyze_input",
"analyze_inputs",
"execute_step",
"run_plan",
"run_recipe",
]

132
tasks/analyze.py Normal file
View File

@@ -0,0 +1,132 @@
"""
Analysis tasks for extracting features from input media.
Phase 1 of the 3-phase execution model.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
from celery import current_task
# Import from the Celery app
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
# Import artdag analysis module
try:
from artdag.analysis import Analyzer, AnalysisResult
except ImportError:
# artdag not installed, will fail at runtime
Analyzer = None
AnalysisResult = None
logger = logging.getLogger(__name__)
# Cache directory for analysis results
CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache'))
ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis'
@app.task(bind=True, name='tasks.analyze_input')
def analyze_input(
    self,
    input_hash: str,
    input_path: str,
    features: List[str],
) -> dict:
    """
    Analyze a single input file.

    Args:
        input_hash: Content hash of the input
        input_path: Path to the input file
        features: List of features to extract

    Returns:
        Dict with status "completed" (plus analysis payload) or "failed"
        (plus the error message). Analysis errors are reported in the
        return value rather than raised.
    """
    if Analyzer is None:
        raise ImportError("artdag.analysis not available")

    logger.info(f"Analyzing {input_hash[:16]}... for features: {features}")

    # Analyzer results are cached on disk under ANALYSIS_CACHE_DIR.
    ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)

    try:
        analysis = analyzer.analyze(
            input_hash=input_hash,
            features=features,
            input_path=Path(input_path),
        )
        return {
            "status": "completed",
            "input_hash": input_hash,
            "cache_id": analysis.cache_id,
            "features": features,
            "result": analysis.to_dict(),
        }
    except Exception as e:
        logger.error(f"Analysis failed for {input_hash}: {e}")
        return {
            "status": "failed",
            "input_hash": input_hash,
            "error": str(e),
        }
@app.task(bind=True, name='tasks.analyze_inputs')
def analyze_inputs(
    self,
    inputs: Dict[str, str],
    features: List[str],
) -> dict:
    """
    Analyze a batch of inputs within a single task.

    Inputs are processed sequentially in this task; per-input failures
    are collected in the result instead of aborting the batch.

    Args:
        inputs: Dict mapping input_hash to file path
        features: List of features to extract from all inputs

    Returns:
        Dict with per-input results, a list of errors, and summary counts.
        Status is "completed" when every input succeeded, "partial" otherwise.
    """
    if Analyzer is None:
        raise ImportError("artdag.analysis not available")

    logger.info(f"Analyzing {len(inputs)} inputs for features: {features}")

    ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)

    results = {}
    errors = []
    for input_hash, input_path in inputs.items():
        try:
            analysis = analyzer.analyze(
                input_hash=input_hash,
                features=features,
                input_path=Path(input_path),
            )
            results[input_hash] = analysis.to_dict()
        except Exception as e:
            logger.error(f"Analysis failed for {input_hash}: {e}")
            errors.append({"input_hash": input_hash, "error": str(e)})

    status = "partial" if errors else "completed"
    return {
        "status": status,
        "results": results,
        "errors": errors,
        "total": len(inputs),
        "successful": len(results),
    }

298
tasks/execute.py Normal file
View File

@@ -0,0 +1,298 @@
"""
Step execution task.
Phase 3 of the 3-phase execution model.
Executes individual steps from an execution plan with IPFS-backed caching.
"""
import json
import logging
import os
import socket
from pathlib import Path
from typing import Dict, List, Optional
from celery import current_task
# Import from the Celery app
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
from claiming import (
get_claimer,
claim_task,
complete_task,
fail_task,
ClaimStatus,
)
from cache_manager import get_cache_manager, L1CacheManager
# Import artdag
try:
from artdag import Cache, NodeType
from artdag.executor import get_executor
from artdag.planning import ExecutionStep
except ImportError:
Cache = None
NodeType = None
get_executor = None
ExecutionStep = None
logger = logging.getLogger(__name__)
def get_worker_id() -> str:
    """Return this worker's unique identifier: "<hostname>:<pid>"."""
    return f"{socket.gethostname()}:{os.getpid()}"
@app.task(bind=True, name='tasks.execute_step')
def execute_step(
    self,
    step_json: str,
    plan_id: str,
    input_cache_ids: Dict[str, str],
) -> dict:
    """
    Execute a single step from an execution plan.

    Uses hash-based claiming to prevent duplicate work.
    Results are stored in IPFS-backed cache.

    Flow: cache hit → return immediately; claim lost → poll the winning
    worker's claim record and relay its outcome; claim won → run the
    step, cache the output, and mark the claim completed/failed.

    Args:
        step_json: JSON-serialized ExecutionStep
        plan_id: ID of the parent execution plan
        input_cache_ids: Mapping from input step_id to their cache_id

    Returns:
        Dict with execution result; "status" is one of "cached",
        "completed", "completed_by_other", "timeout", or "failed".
        Failures are reported in the return value, not raised.
    """
    if ExecutionStep is None:
        raise ImportError("artdag.planning not available")
    step = ExecutionStep.from_json(step_json)
    worker_id = get_worker_id()
    task_id = self.request.id
    logger.info(f"Executing step {step.step_id} ({step.node_type}) cache_id={step.cache_id[:16]}...")
    # Get L1 cache manager (IPFS-backed)
    cache_mgr = get_cache_manager()
    # Check if already cached (by cache_id as content_hash)
    cached_path = cache_mgr.get_by_content_hash(step.cache_id)
    if cached_path:
        logger.info(f"Step {step.step_id} already cached at {cached_path}")
        # Mark as cached in claiming system so waiting workers see it too
        claimer = get_claimer()
        claimer.mark_cached(step.cache_id, str(cached_path))
        return {
            "status": "cached",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": str(cached_path),
        }
    # Try to claim the task
    if not claim_task(step.cache_id, worker_id, task_id):
        # Another worker is handling it — block (up to 10 min) on its
        # claim record and relay whatever outcome it reports.
        logger.info(f"Step {step.step_id} claimed by another worker, waiting...")
        claimer = get_claimer()
        result = claimer.wait_for_completion(step.cache_id, timeout=600)
        if result and result.status == ClaimStatus.COMPLETED:
            return {
                "status": "completed_by_other",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": result.output_path,
            }
        elif result and result.status == ClaimStatus.CACHED:
            return {
                "status": "cached",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": result.output_path,
            }
        elif result and result.status == ClaimStatus.FAILED:
            return {
                "status": "failed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "error": result.error,
            }
        else:
            # wait_for_completion returned None: the other worker neither
            # finished nor failed within the timeout.
            return {
                "status": "timeout",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "error": "Timeout waiting for other worker",
            }
    # We have the claim, update to running
    claimer = get_claimer()
    claimer.update_status(step.cache_id, worker_id, ClaimStatus.RUNNING)
    try:
        # Handle SOURCE nodes: no computation, just resolve the raw input
        # from cache and record its path as this step's output.
        if step.node_type == "SOURCE":
            content_hash = step.config.get("content_hash")
            if not content_hash:
                raise ValueError(f"SOURCE step missing content_hash")
            # Look up in cache
            path = cache_mgr.get_by_content_hash(content_hash)
            if not path:
                raise ValueError(f"SOURCE input not found in cache: {content_hash[:16]}...")
            output_path = str(path)
            complete_task(step.cache_id, worker_id, output_path)
            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": output_path,
            }
        # Handle _LIST virtual nodes: gather resolvable item paths.
        # Items missing from input_cache_ids or the cache are silently
        # skipped rather than treated as errors.
        if step.node_type == "_LIST":
            item_paths = []
            for item_id in step.config.get("items", []):
                item_cache_id = input_cache_ids.get(item_id)
                if item_cache_id:
                    path = cache_mgr.get_by_content_hash(item_cache_id)
                    if path:
                        item_paths.append(str(path))
            # The claim's output_path stores the JSON list of item paths
            complete_task(step.cache_id, worker_id, json.dumps(item_paths))
            return {
                "status": "completed",
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "output_path": None,
                "item_paths": item_paths,
            }
        # Get executor for this node type; fall back to the raw string
        # name when it is not a NodeType enum member.
        try:
            node_type = NodeType[step.node_type]
        except KeyError:
            node_type = step.node_type
        executor = get_executor(node_type)
        if executor is None:
            raise ValueError(f"No executor for node type: {step.node_type}")
        # Resolve input paths from cache (all inputs must be present)
        input_paths = []
        for input_step_id in step.input_steps:
            input_cache_id = input_cache_ids.get(input_step_id)
            if not input_cache_id:
                raise ValueError(f"No cache_id for input step: {input_step_id}")
            path = cache_mgr.get_by_content_hash(input_cache_id)
            if not path:
                raise ValueError(f"Input not in cache: {input_cache_id[:16]}...")
            input_paths.append(Path(path))
        # Create temp output path
        import tempfile
        output_dir = Path(tempfile.mkdtemp())
        output_path = output_dir / f"output_{step.cache_id[:16]}.mp4"
        # Execute
        logger.info(f"Running executor for {step.node_type} with {len(input_paths)} inputs")
        result_path = executor.execute(step.config, input_paths, output_path)
        # Store in IPFS-backed cache
        cached_file, ipfs_cid = cache_mgr.put(
            source_path=result_path,
            node_type=step.node_type,
            node_id=step.cache_id,
        )
        logger.info(f"Step {step.step_id} completed, IPFS CID: {ipfs_cid}")
        # Mark completed
        complete_task(step.cache_id, worker_id, str(cached_file.path))
        # Cleanup temp
        # NOTE(review): cleanup only runs on success; if the executor or
        # cache put raises, output_dir leaks — consider try/finally.
        if output_dir.exists():
            import shutil
            shutil.rmtree(output_dir, ignore_errors=True)
        return {
            "status": "completed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "output_path": str(cached_file.path),
            "content_hash": cached_file.content_hash,
            "ipfs_cid": ipfs_cid,
        }
    except Exception as e:
        logger.error(f"Step {step.step_id} failed: {e}")
        # Record the failure on the claim so waiting workers stop polling
        fail_task(step.cache_id, worker_id, str(e))
        return {
            "status": "failed",
            "step_id": step.step_id,
            "cache_id": step.cache_id,
            "error": str(e),
        }
@app.task(bind=True, name='tasks.execute_level')
def execute_level(
    self,
    steps_json: List[str],
    plan_id: str,
    cache_ids: Dict[str, str],
) -> dict:
    """
    Execute all steps at a given dependency level.

    Steps at the same level can run in parallel, so they are dispatched
    as a Celery group and awaited together.

    Args:
        steps_json: List of JSON-serialized ExecutionSteps
        plan_id: ID of the parent execution plan
        cache_ids: Mapping from step_id to cache_id

    Returns:
        Dict with results for all steps and an updated cache_ids map
        that includes each executed step's cache_id.
    """
    from celery import group

    # Dispatch all steps in parallel
    tasks = [
        execute_step.s(step_json, plan_id, cache_ids)
        for step_json in steps_json
    ]
    job = group(tasks)
    results = job.apply_async()

    # Wait for completion (1 hour timeout).
    # BUGFIX: Celery forbids result.get() inside a task by default and
    # raises "RuntimeError: Never call result.get() within a task!";
    # disable_sync_subtasks=False opts in explicitly.
    # NOTE(review): a synchronous wait can still deadlock if every worker
    # slot is busy running execute_level — a chord would avoid this.
    step_results = results.get(timeout=3600, disable_sync_subtasks=False)

    # Fold each step's cache_id back into the mapping for later levels
    new_cache_ids = dict(cache_ids)
    for result in step_results:
        step_id = result.get("step_id")
        cache_id = result.get("cache_id")
        if step_id and cache_id:
            new_cache_ids[step_id] = cache_id
    return {
        "status": "completed",
        "results": step_results,
        "cache_ids": new_cache_ids,
    }

373
tasks/orchestrate.py Normal file
View File

@@ -0,0 +1,373 @@
"""
Plan orchestration tasks.
Coordinates the full 3-phase execution:
1. Analyze inputs
2. Generate plan
3. Execute steps level by level
Uses IPFS-backed cache for durability.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
from celery import current_task, group, chain
# Import from the Celery app
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
from claiming import get_claimer
from cache_manager import get_cache_manager
# Import artdag modules
try:
from artdag import Cache
from artdag.analysis import Analyzer, AnalysisResult
from artdag.planning import RecipePlanner, ExecutionPlan, Recipe
except ImportError:
Cache = None
Analyzer = None
AnalysisResult = None
RecipePlanner = None
ExecutionPlan = None
Recipe = None
from .execute import execute_step
logger = logging.getLogger(__name__)

# Cache directories (root shared via the CACHE_DIR environment variable)
CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache'))
ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis'  # feature-extraction results
PLAN_CACHE_DIR = CACHE_DIR / 'plans'  # serialized ExecutionPlans (kept for debugging)
@app.task(bind=True, name='tasks.run_plan')
def run_plan(
    self,
    plan_json: str,
    run_id: Optional[str] = None,
) -> dict:
    """
    Execute a complete execution plan.

    Runs steps level by level, with parallel execution within each level
    (dispatched as a Celery group). Results are stored in IPFS-backed cache.
    Execution stops at the first failed level or failed step.

    Args:
        plan_json: JSON-serialized ExecutionPlan
        run_id: Optional run ID for tracking

    Returns:
        Dict with execution results; "status" is "completed" or "failed".
    """
    if ExecutionPlan is None:
        raise ImportError("artdag.planning not available")
    plan = ExecutionPlan.from_json(plan_json)
    cache_mgr = get_cache_manager()
    logger.info(f"Executing plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

    # Build initial cache_ids mapping (step_id -> cache_id)
    cache_ids = {}
    for step in plan.steps:
        cache_ids[step.step_id] = step.cache_id
    # Also map input hashes so steps can reference raw inputs by name
    for name, content_hash in plan.input_hashes.items():
        cache_ids[name] = content_hash

    # Group steps by level; levels are executed in ascending order
    steps_by_level = plan.get_steps_by_level()
    max_level = max(steps_by_level.keys()) if steps_by_level else 0

    results_by_step = {}
    total_cached = 0
    total_executed = 0
    for level in range(max_level + 1):
        level_steps = steps_by_level.get(level, [])
        if not level_steps:
            continue
        logger.info(f"Executing level {level}: {len(level_steps)} steps")

        # Check which steps need execution (skip anything already cached)
        steps_to_run = []
        for step in level_steps:
            cached_path = cache_mgr.get_by_content_hash(step.cache_id)
            if cached_path:
                results_by_step[step.step_id] = {
                    "status": "cached",
                    "cache_id": step.cache_id,
                    "output_path": str(cached_path),
                }
                total_cached += 1
            else:
                steps_to_run.append(step)
        if not steps_to_run:
            logger.info(f"Level {level}: all steps cached")
            continue

        # Snapshot the cache_ids known so far for this level's steps
        level_cache_ids = dict(cache_ids)
        # Execute steps in parallel as a Celery group
        tasks = [
            execute_step.s(step.to_json(), plan.plan_id, level_cache_ids)
            for step in steps_to_run
        ]
        job = group(tasks)
        async_results = job.apply_async()

        # Wait for completion.
        # BUGFIX: Celery forbids result.get() inside a task by default and
        # raises "RuntimeError: Never call result.get() within a task!";
        # disable_sync_subtasks=False opts in explicitly.
        # NOTE(review): this synchronous wait can deadlock if all worker
        # slots are occupied by run_plan tasks — consider chords.
        try:
            step_results = async_results.get(timeout=3600, disable_sync_subtasks=False)
        except Exception as e:
            logger.error(f"Level {level} execution failed: {e}")
            return {
                "status": "failed",
                "error": str(e),
                "level": level,
                "results": results_by_step,
                "run_id": run_id,
            }

        # Process results; abort the whole plan on the first failed step
        for result in step_results:
            step_id = result.get("step_id")
            cache_id = result.get("cache_id")
            results_by_step[step_id] = result
            cache_ids[step_id] = cache_id
            if result.get("status") in ("completed", "cached", "completed_by_other"):
                total_executed += 1
            elif result.get("status") == "failed":
                logger.error(f"Step {step_id} failed: {result.get('error')}")
                return {
                    "status": "failed",
                    "error": f"Step {step_id} failed: {result.get('error')}",
                    "level": level,
                    "results": results_by_step,
                    "run_id": run_id,
                }

    # Get final output: the plan's designated output step
    output_step = plan.get_step(plan.output_step)
    output_cache_id = output_step.cache_id if output_step else None
    output_path = None
    output_ipfs_cid = None
    if output_cache_id:
        output_path = cache_mgr.get_by_content_hash(output_cache_id)
        output_ipfs_cid = cache_mgr.get_ipfs_cid(output_cache_id)
    return {
        "status": "completed",
        "run_id": run_id,
        "plan_id": plan.plan_id,
        "output_cache_id": output_cache_id,
        "output_path": str(output_path) if output_path else None,
        "output_ipfs_cid": output_ipfs_cid,
        "total_steps": len(plan.steps),
        "cached": total_cached,
        "executed": total_executed,
        "results": results_by_step,
    }
@app.task(bind=True, name='tasks.run_recipe')
def run_recipe(
    self,
    recipe_yaml: str,
    input_hashes: Dict[str, str],
    features: Optional[List[str]] = None,
    run_id: Optional[str] = None,
) -> dict:
    """
    Run a complete recipe through all 3 phases.

    1. Analyze: extract the requested features from each cached input.
    2. Plan: generate a content-addressed execution plan from the recipe.
    3. Execute: run the plan via ``run_plan``.

    Args:
        recipe_yaml: Recipe YAML content.
        input_hashes: Mapping from input name to content hash.
        features: Features to extract (default: ["beats", "energy"]).
        run_id: Optional run ID for tracking; if omitted, the run ID
            reported by the executor (if any) is returned instead.

    Returns:
        Dict with final status, output locations, and step/cache counts.

    Raises:
        ImportError: If the artdag planning/analysis modules are unavailable.
    """
    if RecipePlanner is None or Analyzer is None:
        raise ImportError("artdag modules not available")
    if features is None:
        # Avoid a mutable default argument; fall back to the standard set.
        features = ["beats", "energy"]
    cache_mgr = get_cache_manager()
    logger.info(f"Running recipe with {len(input_hashes)} inputs")

    # Phase 1: Analyze -- best-effort. A missing or failed input analysis
    # is logged and planning proceeds without that input's features.
    logger.info("Phase 1: Analyzing inputs...")
    ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
    analysis_results = {}
    for name, content_hash in input_hashes.items():
        # Inputs are content-addressed; resolve the hash to a local path.
        path = cache_mgr.get_by_content_hash(content_hash)
        if not path:
            logger.warning(f"Input {name} ({content_hash[:16]}...) not in cache")
            continue
        try:
            result = analyzer.analyze(
                input_hash=content_hash,
                features=features,
                input_path=Path(path),
            )
            analysis_results[content_hash] = result
            logger.info(f"Analyzed {name}: tempo={result.tempo}, beats={len(result.beat_times or [])}")
        except Exception as e:
            logger.warning(f"Analysis failed for {name}: {e}")
    logger.info(f"Analyzed {len(analysis_results)} inputs")

    # Phase 2: Plan
    logger.info("Phase 2: Generating execution plan...")
    recipe = Recipe.from_yaml(recipe_yaml)
    planner = RecipePlanner(use_tree_reduction=True)
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis_results,
    )
    logger.info(f"Generated plan with {len(plan.steps)} steps")

    # Persist the plan JSON so a failed run can be inspected or replayed.
    PLAN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    plan_path = PLAN_CACHE_DIR / f"{plan.plan_id}.json"
    plan_path.write_text(plan.to_json())

    # Phase 3: Execute
    logger.info("Phase 3: Executing plan...")
    result = run_plan(plan.to_json(), run_id=run_id)
    return {
        "status": result.get("status"),
        # Prefer the executor's run_id: run_plan reports one back even when
        # the caller did not supply it.
        "run_id": result.get("run_id", run_id),
        "recipe": recipe.name,
        "plan_id": plan.plan_id,
        "output_path": result.get("output_path"),
        "output_cache_id": result.get("output_cache_id"),
        "output_ipfs_cid": result.get("output_ipfs_cid"),
        "analysis_count": len(analysis_results),
        "total_steps": len(plan.steps),
        "cached": result.get("cached", 0),
        "executed": result.get("executed", 0),
        "error": result.get("error"),
    }
@app.task(bind=True, name='tasks.generate_plan')
def generate_plan(
    self,
    recipe_yaml: str,
    input_hashes: Dict[str, str],
    features: Optional[List[str]] = None,
) -> dict:
    """
    Generate an execution plan without executing it.

    Useful for:
    - Previewing what will be executed
    - Checking cache status
    - Debugging recipe issues

    Args:
        recipe_yaml: Recipe YAML content.
        input_hashes: Mapping from input name to content hash.
        features: Features to extract for analysis
            (default: ["beats", "energy"]).

    Returns:
        Dict with plan details and per-step cache status.

    Raises:
        ImportError: If the artdag planning/analysis modules are unavailable.
    """
    if RecipePlanner is None or Analyzer is None:
        raise ImportError("artdag modules not available")
    if features is None:
        # Avoid a mutable default argument; fall back to the standard set.
        features = ["beats", "energy"]
    cache_mgr = get_cache_manager()

    # Analyze inputs (best-effort, mirrors run_recipe's phase 1).
    ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
    analysis_results = {}
    for name, content_hash in input_hashes.items():
        path = cache_mgr.get_by_content_hash(content_hash)
        if not path:
            # Consistent with run_recipe: surface missing inputs instead of
            # silently skipping them.
            logger.warning(f"Input {name} ({content_hash[:16]}...) not in cache")
            continue
        try:
            analysis_results[content_hash] = analyzer.analyze(
                input_hash=content_hash,
                features=features,
                input_path=Path(path),
            )
        except Exception as e:
            logger.warning(f"Analysis failed for {name}: {e}")

    # Generate plan
    recipe = Recipe.from_yaml(recipe_yaml)
    planner = RecipePlanner(use_tree_reduction=True)
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis_results,
    )

    # Report each step's cache status so callers can see what would run.
    steps_status = [
        {
            "step_id": step.step_id,
            "node_type": step.node_type,
            "cache_id": step.cache_id,
            "level": step.level,
            "cached": cache_mgr.has_content(step.cache_id),
        }
        for step in plan.steps
    ]
    cached_count = sum(1 for s in steps_status if s["cached"])
    return {
        "status": "planned",
        "recipe": recipe.name,
        "plan_id": plan.plan_id,
        "total_steps": len(plan.steps),
        "cached_steps": cached_count,
        "pending_steps": len(plan.steps) - cached_count,
        "steps": steps_status,
        "plan_json": plan.to_json(),
    }