Add 3-phase execution with IPFS cache and hash-based task claiming
New files:
- claiming.py - Redis Lua scripts for atomic task claiming
- tasks/analyze.py - Analysis Celery task
- tasks/execute.py - Step execution with IPFS-backed cache
- tasks/orchestrate.py - Plan orchestration (run_plan, run_recipe)
New API endpoints (/api/v2/):
- POST /api/v2/plan - Generate execution plan
- POST /api/v2/execute - Execute a plan
- POST /api/v2/run-recipe - Full 3-phase pipeline
- GET /api/v2/run/{run_id} - Get run status
Features:
- Hash-based task claiming prevents duplicate work
- Parallel execution within dependency levels
- IPFS-backed cache for durability
- Integration with artdag planning module
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,7 @@ app = Celery(
|
|||||||
'art_celery',
|
'art_celery',
|
||||||
broker=REDIS_URL,
|
broker=REDIS_URL,
|
||||||
backend=REDIS_URL,
|
backend=REDIS_URL,
|
||||||
include=['tasks']
|
include=['tasks', 'tasks.analyze', 'tasks.execute', 'tasks.orchestrate']
|
||||||
)
|
)
|
||||||
|
|
||||||
app.conf.update(
|
app.conf.update(
|
||||||
|
|||||||
421
claiming.py
Normal file
421
claiming.py
Normal file
@@ -0,0 +1,421 @@
|
|||||||
|
"""
|
||||||
|
Hash-based task claiming for distributed execution.
|
||||||
|
|
||||||
|
Prevents duplicate work when multiple workers process the same plan.
|
||||||
|
Uses Redis Lua scripts for atomic claim operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
|
||||||
|
|
||||||
|
# Key prefix for task claims
|
||||||
|
CLAIM_PREFIX = "artdag:claim:"
|
||||||
|
|
||||||
|
# Default TTL for claims (5 minutes)
|
||||||
|
DEFAULT_CLAIM_TTL = 300
|
||||||
|
|
||||||
|
# TTL for completed results (1 hour)
|
||||||
|
COMPLETED_TTL = 3600
|
||||||
|
|
||||||
|
|
||||||
|
class ClaimStatus(Enum):
|
||||||
|
"""Status of a task claim."""
|
||||||
|
PENDING = "pending"
|
||||||
|
CLAIMED = "claimed"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
CACHED = "cached"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClaimInfo:
|
||||||
|
"""Information about a task claim."""
|
||||||
|
cache_id: str
|
||||||
|
status: ClaimStatus
|
||||||
|
worker_id: Optional[str] = None
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
claimed_at: Optional[str] = None
|
||||||
|
completed_at: Optional[str] = None
|
||||||
|
output_path: Optional[str] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"cache_id": self.cache_id,
|
||||||
|
"status": self.status.value,
|
||||||
|
"worker_id": self.worker_id,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"claimed_at": self.claimed_at,
|
||||||
|
"completed_at": self.completed_at,
|
||||||
|
"output_path": self.output_path,
|
||||||
|
"error": self.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "ClaimInfo":
|
||||||
|
return cls(
|
||||||
|
cache_id=data["cache_id"],
|
||||||
|
status=ClaimStatus(data["status"]),
|
||||||
|
worker_id=data.get("worker_id"),
|
||||||
|
task_id=data.get("task_id"),
|
||||||
|
claimed_at=data.get("claimed_at"),
|
||||||
|
completed_at=data.get("completed_at"),
|
||||||
|
output_path=data.get("output_path"),
|
||||||
|
error=data.get("error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Lua script for atomic task claiming
|
||||||
|
# Returns 1 if claim successful, 0 if already claimed/completed
|
||||||
|
CLAIM_TASK_SCRIPT = """
|
||||||
|
local key = KEYS[1]
|
||||||
|
local data = redis.call('GET', key)
|
||||||
|
|
||||||
|
if data then
|
||||||
|
local status = cjson.decode(data)
|
||||||
|
local s = status['status']
|
||||||
|
-- Already claimed, running, completed, or cached - don't claim
|
||||||
|
if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Claim the task
|
||||||
|
local claim_data = ARGV[1]
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
redis.call('SETEX', key, ttl, claim_data)
|
||||||
|
return 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Lua script for releasing a claim (e.g., on failure)
|
||||||
|
RELEASE_CLAIM_SCRIPT = """
|
||||||
|
local key = KEYS[1]
|
||||||
|
local worker_id = ARGV[1]
|
||||||
|
local data = redis.call('GET', key)
|
||||||
|
|
||||||
|
if data then
|
||||||
|
local status = cjson.decode(data)
|
||||||
|
-- Only release if we own the claim
|
||||||
|
if status['worker_id'] == worker_id then
|
||||||
|
redis.call('DEL', key)
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Lua script for updating claim status (claimed -> running -> completed)
|
||||||
|
UPDATE_STATUS_SCRIPT = """
|
||||||
|
local key = KEYS[1]
|
||||||
|
local worker_id = ARGV[1]
|
||||||
|
local new_status = ARGV[2]
|
||||||
|
local new_data = ARGV[3]
|
||||||
|
local ttl = tonumber(ARGV[4])
|
||||||
|
|
||||||
|
local data = redis.call('GET', key)
|
||||||
|
if not data then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
local status = cjson.decode(data)
|
||||||
|
|
||||||
|
-- Only update if we own the claim
|
||||||
|
if status['worker_id'] ~= worker_id then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
redis.call('SETEX', key, ttl, new_data)
|
||||||
|
return 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TaskClaimer:
|
||||||
|
"""
|
||||||
|
Manages hash-based task claiming for distributed execution.
|
||||||
|
|
||||||
|
Uses Redis for coordination between workers.
|
||||||
|
Each task is identified by its cache_id (content-addressed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = None):
|
||||||
|
"""
|
||||||
|
Initialize the claimer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_url: Redis connection URL
|
||||||
|
"""
|
||||||
|
self.redis_url = redis_url or REDIS_URL
|
||||||
|
self._redis: Optional[redis.Redis] = None
|
||||||
|
self._claim_script = None
|
||||||
|
self._release_script = None
|
||||||
|
self._update_script = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def redis(self) -> redis.Redis:
|
||||||
|
"""Get Redis connection (lazy initialization)."""
|
||||||
|
if self._redis is None:
|
||||||
|
self._redis = redis.from_url(self.redis_url, decode_responses=True)
|
||||||
|
# Register Lua scripts
|
||||||
|
self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
|
||||||
|
self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
|
||||||
|
self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
def _key(self, cache_id: str) -> str:
|
||||||
|
"""Get Redis key for a cache_id."""
|
||||||
|
return f"{CLAIM_PREFIX}{cache_id}"
|
||||||
|
|
||||||
|
def claim(
|
||||||
|
self,
|
||||||
|
cache_id: str,
|
||||||
|
worker_id: str,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
ttl: int = DEFAULT_CLAIM_TTL,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Attempt to claim a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task to claim
|
||||||
|
worker_id: Identifier for the claiming worker
|
||||||
|
task_id: Optional Celery task ID
|
||||||
|
ttl: Time-to-live for the claim in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if claim successful, False if already claimed
|
||||||
|
"""
|
||||||
|
claim_info = ClaimInfo(
|
||||||
|
cache_id=cache_id,
|
||||||
|
status=ClaimStatus.CLAIMED,
|
||||||
|
worker_id=worker_id,
|
||||||
|
task_id=task_id,
|
||||||
|
claimed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self._claim_script(
|
||||||
|
keys=[self._key(cache_id)],
|
||||||
|
args=[json.dumps(claim_info.to_dict()), ttl],
|
||||||
|
client=self.redis,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == 1:
|
||||||
|
logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.debug(f"Task {cache_id[:16]}... already claimed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_status(
|
||||||
|
self,
|
||||||
|
cache_id: str,
|
||||||
|
worker_id: str,
|
||||||
|
status: ClaimStatus,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
error: Optional[str] = None,
|
||||||
|
ttl: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Update the status of a claimed task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task
|
||||||
|
worker_id: Worker ID that owns the claim
|
||||||
|
status: New status
|
||||||
|
output_path: Path to output (for completed)
|
||||||
|
error: Error message (for failed)
|
||||||
|
ttl: New TTL (defaults based on status)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if update successful
|
||||||
|
"""
|
||||||
|
if ttl is None:
|
||||||
|
if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
|
||||||
|
ttl = COMPLETED_TTL
|
||||||
|
else:
|
||||||
|
ttl = DEFAULT_CLAIM_TTL
|
||||||
|
|
||||||
|
# Get existing claim info
|
||||||
|
existing = self.get_status(cache_id)
|
||||||
|
if not existing:
|
||||||
|
logger.warning(f"No claim found for {cache_id[:16]}...")
|
||||||
|
return False
|
||||||
|
|
||||||
|
claim_info = ClaimInfo(
|
||||||
|
cache_id=cache_id,
|
||||||
|
status=status,
|
||||||
|
worker_id=worker_id,
|
||||||
|
task_id=existing.task_id,
|
||||||
|
claimed_at=existing.claimed_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat() if status in (
|
||||||
|
ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
|
||||||
|
) else None,
|
||||||
|
output_path=output_path,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self._update_script(
|
||||||
|
keys=[self._key(cache_id)],
|
||||||
|
args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
|
||||||
|
client=self.redis,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == 1:
|
||||||
|
logger.debug(f"Updated task {cache_id[:16]}... to {status.value}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def release(self, cache_id: str, worker_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Release a claim (e.g., on task failure before completion).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task
|
||||||
|
worker_id: Worker ID that owns the claim
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if release successful
|
||||||
|
"""
|
||||||
|
result = self._release_script(
|
||||||
|
keys=[self._key(cache_id)],
|
||||||
|
args=[worker_id],
|
||||||
|
client=self.redis,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == 1:
|
||||||
|
logger.debug(f"Released claim on {cache_id[:16]}...")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
|
||||||
|
"""
|
||||||
|
Get the current status of a task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClaimInfo if task has been claimed, None otherwise
|
||||||
|
"""
|
||||||
|
data = self.redis.get(self._key(cache_id))
|
||||||
|
if data:
|
||||||
|
return ClaimInfo.from_dict(json.loads(data))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_completed(self, cache_id: str) -> bool:
|
||||||
|
"""Check if a task is completed or cached."""
|
||||||
|
info = self.get_status(cache_id)
|
||||||
|
return info is not None and info.status in (
|
||||||
|
ClaimStatus.COMPLETED, ClaimStatus.CACHED
|
||||||
|
)
|
||||||
|
|
||||||
|
def wait_for_completion(
|
||||||
|
self,
|
||||||
|
cache_id: str,
|
||||||
|
timeout: float = 300,
|
||||||
|
poll_interval: float = 0.5,
|
||||||
|
) -> Optional[ClaimInfo]:
|
||||||
|
"""
|
||||||
|
Wait for a task to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task
|
||||||
|
timeout: Maximum time to wait in seconds
|
||||||
|
poll_interval: How often to check status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClaimInfo if completed, None if timeout
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
info = self.get_status(cache_id)
|
||||||
|
if info and info.status in (
|
||||||
|
ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
|
||||||
|
):
|
||||||
|
return info
|
||||||
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
|
logger.warning(f"Timeout waiting for {cache_id[:16]}...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def mark_cached(self, cache_id: str, output_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Mark a task as already cached (no processing needed).
|
||||||
|
|
||||||
|
This is used when we discover the result already exists
|
||||||
|
before attempting to claim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: The cache ID of the task
|
||||||
|
output_path: Path to the cached output
|
||||||
|
"""
|
||||||
|
claim_info = ClaimInfo(
|
||||||
|
cache_id=cache_id,
|
||||||
|
status=ClaimStatus.CACHED,
|
||||||
|
output_path=output_path,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.redis.setex(
|
||||||
|
self._key(cache_id),
|
||||||
|
COMPLETED_TTL,
|
||||||
|
json.dumps(claim_info.to_dict()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_all(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all claims (for testing/reset).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of claims cleared
|
||||||
|
"""
|
||||||
|
pattern = f"{CLAIM_PREFIX}*"
|
||||||
|
keys = list(self.redis.scan_iter(match=pattern))
|
||||||
|
if keys:
|
||||||
|
return self.redis.delete(*keys)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
# Global claimer instance
|
||||||
|
_claimer: Optional[TaskClaimer] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_claimer() -> TaskClaimer:
|
||||||
|
"""Get the global TaskClaimer instance."""
|
||||||
|
global _claimer
|
||||||
|
if _claimer is None:
|
||||||
|
_claimer = TaskClaimer()
|
||||||
|
return _claimer
|
||||||
|
|
||||||
|
|
||||||
|
def claim_task(cache_id: str, worker_id: str, task_id: str = None) -> bool:
|
||||||
|
"""Convenience function to claim a task."""
|
||||||
|
return get_claimer().claim(cache_id, worker_id, task_id)
|
||||||
|
|
||||||
|
|
||||||
|
def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
|
||||||
|
"""Convenience function to mark a task as completed."""
|
||||||
|
return get_claimer().update_status(
|
||||||
|
cache_id, worker_id, ClaimStatus.COMPLETED, output_path=output_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
|
||||||
|
"""Convenience function to mark a task as failed."""
|
||||||
|
return get_claimer().update_status(
|
||||||
|
cache_id, worker_id, ClaimStatus.FAILED, error=error
|
||||||
|
)
|
||||||
225
server.py
225
server.py
@@ -4964,6 +4964,231 @@ async def download_client():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 3-Phase Execution API (Analyze → Plan → Execute)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class RecipeRunRequest(BaseModel):
|
||||||
|
"""Request to run a recipe with the 3-phase execution model."""
|
||||||
|
recipe_yaml: str # Recipe YAML content
|
||||||
|
input_hashes: dict # Mapping from input name to content hash
|
||||||
|
features: Optional[list[str]] = None # Features to extract (default: beats, energy)
|
||||||
|
|
||||||
|
|
||||||
|
class PlanRequest(BaseModel):
|
||||||
|
"""Request to generate an execution plan."""
|
||||||
|
recipe_yaml: str
|
||||||
|
input_hashes: dict
|
||||||
|
features: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutePlanRequest(BaseModel):
|
||||||
|
"""Request to execute a pre-generated plan."""
|
||||||
|
plan_json: str # JSON-serialized ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v2/plan")
|
||||||
|
async def generate_plan_endpoint(
|
||||||
|
request: PlanRequest,
|
||||||
|
ctx: UserContext = Depends(get_required_user_context)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate an execution plan without executing it.
|
||||||
|
|
||||||
|
Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model.
|
||||||
|
|
||||||
|
Returns the plan with cache status for each step.
|
||||||
|
"""
|
||||||
|
from tasks.orchestrate import generate_plan
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Submit to Celery
|
||||||
|
task = generate_plan.delay(
|
||||||
|
recipe_yaml=request.recipe_yaml,
|
||||||
|
input_hashes=request.input_hashes,
|
||||||
|
features=request.features,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for result (plan generation is usually fast)
|
||||||
|
result = task.get(timeout=60)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": result.get("status"),
|
||||||
|
"recipe": result.get("recipe"),
|
||||||
|
"plan_id": result.get("plan_id"),
|
||||||
|
"total_steps": result.get("total_steps"),
|
||||||
|
"cached_steps": result.get("cached_steps"),
|
||||||
|
"pending_steps": result.get("pending_steps"),
|
||||||
|
"steps": result.get("steps"),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Plan generation failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v2/execute")
|
||||||
|
async def execute_plan_endpoint(
|
||||||
|
request: ExecutePlanRequest,
|
||||||
|
ctx: UserContext = Depends(get_required_user_context)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Execute a pre-generated execution plan.
|
||||||
|
|
||||||
|
Phase 3 (Execute) of the 3-phase model.
|
||||||
|
|
||||||
|
Submits the plan to Celery for parallel execution.
|
||||||
|
"""
|
||||||
|
from tasks.orchestrate import run_plan
|
||||||
|
|
||||||
|
run_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Submit to Celery (async)
|
||||||
|
task = run_plan.delay(
|
||||||
|
plan_json=request.plan_json,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "submitted",
|
||||||
|
"run_id": run_id,
|
||||||
|
"celery_task_id": task.id,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Plan execution failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v2/run-recipe")
|
||||||
|
async def run_recipe_endpoint(
|
||||||
|
request: RecipeRunRequest,
|
||||||
|
ctx: UserContext = Depends(get_required_user_context)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run a complete recipe through all 3 phases.
|
||||||
|
|
||||||
|
1. Analyze: Extract features from inputs
|
||||||
|
2. Plan: Generate execution plan with cache IDs
|
||||||
|
3. Execute: Run steps with parallel execution
|
||||||
|
|
||||||
|
Returns immediately with run_id. Poll /api/v2/run/{run_id} for status.
|
||||||
|
"""
|
||||||
|
from tasks.orchestrate import run_recipe
|
||||||
|
|
||||||
|
# Compute run_id from inputs and recipe
|
||||||
|
try:
|
||||||
|
recipe_data = yaml.safe_load(request.recipe_yaml)
|
||||||
|
recipe_name = recipe_data.get("name", "unknown")
|
||||||
|
except Exception:
|
||||||
|
recipe_name = "unknown"
|
||||||
|
|
||||||
|
run_id = compute_run_id(
|
||||||
|
list(request.input_hashes.values()),
|
||||||
|
recipe_name,
|
||||||
|
hashlib.sha3_256(request.recipe_yaml.encode()).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if already completed
|
||||||
|
cached = await database.get_run_cache(run_id)
|
||||||
|
if cached:
|
||||||
|
output_hash = cached.get("output_hash")
|
||||||
|
if cache_manager.has_content(output_hash):
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"run_id": run_id,
|
||||||
|
"output_hash": output_hash,
|
||||||
|
"output_ipfs_cid": cache_manager.get_ipfs_cid(output_hash),
|
||||||
|
"cached": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Submit to Celery
|
||||||
|
try:
|
||||||
|
task = run_recipe.delay(
|
||||||
|
recipe_yaml=request.recipe_yaml,
|
||||||
|
input_hashes=request.input_hashes,
|
||||||
|
features=request.features,
|
||||||
|
run_id=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store run status in Redis
|
||||||
|
run_data = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"status": "pending",
|
||||||
|
"recipe": recipe_name,
|
||||||
|
"inputs": list(request.input_hashes.values()),
|
||||||
|
"celery_task_id": task.id,
|
||||||
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"username": ctx.actor_id,
|
||||||
|
}
|
||||||
|
redis_client.setex(
|
||||||
|
f"{RUNS_KEY_PREFIX}{run_id}",
|
||||||
|
86400, # 24 hour expiry
|
||||||
|
json.dumps(run_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "submitted",
|
||||||
|
"run_id": run_id,
|
||||||
|
"celery_task_id": task.id,
|
||||||
|
"recipe": recipe_name,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Recipe run failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v2/run/{run_id}")
|
||||||
|
async def get_run_v2(run_id: str, ctx: UserContext = Depends(get_required_user_context)):
|
||||||
|
"""
|
||||||
|
Get status of a 3-phase execution run.
|
||||||
|
"""
|
||||||
|
# Check Redis for run status
|
||||||
|
run_data = redis_client.get(f"{RUNS_KEY_PREFIX}{run_id}")
|
||||||
|
if run_data:
|
||||||
|
data = json.loads(run_data)
|
||||||
|
|
||||||
|
# If pending, check Celery task status
|
||||||
|
if data.get("status") == "pending" and data.get("celery_task_id"):
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
result = AsyncResult(data["celery_task_id"])
|
||||||
|
|
||||||
|
if result.ready():
|
||||||
|
if result.successful():
|
||||||
|
task_result = result.get()
|
||||||
|
data["status"] = task_result.get("status", "completed")
|
||||||
|
data["output_hash"] = task_result.get("output_cache_id")
|
||||||
|
data["output_ipfs_cid"] = task_result.get("output_ipfs_cid")
|
||||||
|
data["total_steps"] = task_result.get("total_steps")
|
||||||
|
data["cached"] = task_result.get("cached")
|
||||||
|
data["executed"] = task_result.get("executed")
|
||||||
|
|
||||||
|
# Update Redis
|
||||||
|
redis_client.setex(
|
||||||
|
f"{RUNS_KEY_PREFIX}{run_id}",
|
||||||
|
86400,
|
||||||
|
json.dumps(data)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data["status"] = "failed"
|
||||||
|
data["error"] = str(result.result)
|
||||||
|
else:
|
||||||
|
data["celery_status"] = result.status
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Check database cache
|
||||||
|
cached = await database.get_run_cache(run_id)
|
||||||
|
if cached:
|
||||||
|
return {
|
||||||
|
"run_id": run_id,
|
||||||
|
"status": "completed",
|
||||||
|
"output_hash": cached.get("output_hash"),
|
||||||
|
"cached": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail="Run not found")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
# Workers enabled - cache indexes shared via Redis
|
# Workers enabled - cache indexes shared via Redis
|
||||||
|
|||||||
18
tasks/__init__.py
Normal file
18
tasks/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# art-celery/tasks - Celery tasks for 3-phase execution
|
||||||
|
#
|
||||||
|
# Tasks for the Art DAG distributed execution system:
|
||||||
|
# 1. analyze_input - Extract features from input media
|
||||||
|
# 2. execute_step - Execute a single step from the plan
|
||||||
|
# 3. run_plan - Orchestrate execution of a full plan
|
||||||
|
|
||||||
|
from .analyze import analyze_input, analyze_inputs
|
||||||
|
from .execute import execute_step
|
||||||
|
from .orchestrate import run_plan, run_recipe
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"analyze_input",
|
||||||
|
"analyze_inputs",
|
||||||
|
"execute_step",
|
||||||
|
"run_plan",
|
||||||
|
"run_recipe",
|
||||||
|
]
|
||||||
132
tasks/analyze.py
Normal file
132
tasks/analyze.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
Analysis tasks for extracting features from input media.
|
||||||
|
|
||||||
|
Phase 1 of the 3-phase execution model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from celery import current_task
|
||||||
|
|
||||||
|
# Import from the Celery app
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from celery_app import app
|
||||||
|
|
||||||
|
# Import artdag analysis module
|
||||||
|
try:
|
||||||
|
from artdag.analysis import Analyzer, AnalysisResult
|
||||||
|
except ImportError:
|
||||||
|
# artdag not installed, will fail at runtime
|
||||||
|
Analyzer = None
|
||||||
|
AnalysisResult = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache directory for analysis results
|
||||||
|
CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache'))
|
||||||
|
ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis'
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.analyze_input')
|
||||||
|
def analyze_input(
|
||||||
|
self,
|
||||||
|
input_hash: str,
|
||||||
|
input_path: str,
|
||||||
|
features: List[str],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Analyze a single input file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_hash: Content hash of the input
|
||||||
|
input_path: Path to the input file
|
||||||
|
features: List of features to extract
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with analysis results
|
||||||
|
"""
|
||||||
|
if Analyzer is None:
|
||||||
|
raise ImportError("artdag.analysis not available")
|
||||||
|
|
||||||
|
logger.info(f"Analyzing {input_hash[:16]}... for features: {features}")
|
||||||
|
|
||||||
|
# Create analyzer with caching
|
||||||
|
ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = analyzer.analyze(
|
||||||
|
input_hash=input_hash,
|
||||||
|
features=features,
|
||||||
|
input_path=Path(input_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"input_hash": input_hash,
|
||||||
|
"cache_id": result.cache_id,
|
||||||
|
"features": features,
|
||||||
|
"result": result.to_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Analysis failed for {input_hash}: {e}")
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"input_hash": input_hash,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.analyze_inputs')
|
||||||
|
def analyze_inputs(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, str],
|
||||||
|
features: List[str],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Analyze multiple inputs in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Dict mapping input_hash to file path
|
||||||
|
features: List of features to extract from all inputs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with all analysis results
|
||||||
|
"""
|
||||||
|
if Analyzer is None:
|
||||||
|
raise ImportError("artdag.analysis not available")
|
||||||
|
|
||||||
|
logger.info(f"Analyzing {len(inputs)} inputs for features: {features}")
|
||||||
|
|
||||||
|
ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for input_hash, input_path in inputs.items():
|
||||||
|
try:
|
||||||
|
result = analyzer.analyze(
|
||||||
|
input_hash=input_hash,
|
||||||
|
features=features,
|
||||||
|
input_path=Path(input_path),
|
||||||
|
)
|
||||||
|
results[input_hash] = result.to_dict()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Analysis failed for {input_hash}: {e}")
|
||||||
|
errors.append({"input_hash": input_hash, "error": str(e)})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed" if not errors else "partial",
|
||||||
|
"results": results,
|
||||||
|
"errors": errors,
|
||||||
|
"total": len(inputs),
|
||||||
|
"successful": len(results),
|
||||||
|
}
|
||||||
298
tasks/execute.py
Normal file
298
tasks/execute.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""
|
||||||
|
Step execution task.
|
||||||
|
|
||||||
|
Phase 3 of the 3-phase execution model.
|
||||||
|
Executes individual steps from an execution plan with IPFS-backed caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from celery import current_task
|
||||||
|
|
||||||
|
# Import from the Celery app
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from celery_app import app
|
||||||
|
from claiming import (
|
||||||
|
get_claimer,
|
||||||
|
claim_task,
|
||||||
|
complete_task,
|
||||||
|
fail_task,
|
||||||
|
ClaimStatus,
|
||||||
|
)
|
||||||
|
from cache_manager import get_cache_manager, L1CacheManager
|
||||||
|
|
||||||
|
# Import artdag
|
||||||
|
try:
|
||||||
|
from artdag import Cache, NodeType
|
||||||
|
from artdag.executor import get_executor
|
||||||
|
from artdag.planning import ExecutionStep
|
||||||
|
except ImportError:
|
||||||
|
Cache = None
|
||||||
|
NodeType = None
|
||||||
|
get_executor = None
|
||||||
|
ExecutionStep = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_worker_id() -> str:
|
||||||
|
"""Get a unique identifier for this worker."""
|
||||||
|
hostname = socket.gethostname()
|
||||||
|
pid = os.getpid()
|
||||||
|
return f"{hostname}:{pid}"
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.execute_step')
|
||||||
|
def execute_step(
|
||||||
|
self,
|
||||||
|
step_json: str,
|
||||||
|
plan_id: str,
|
||||||
|
input_cache_ids: Dict[str, str],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Execute a single step from an execution plan.
|
||||||
|
|
||||||
|
Uses hash-based claiming to prevent duplicate work.
|
||||||
|
Results are stored in IPFS-backed cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_json: JSON-serialized ExecutionStep
|
||||||
|
plan_id: ID of the parent execution plan
|
||||||
|
input_cache_ids: Mapping from input step_id to their cache_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with execution result
|
||||||
|
"""
|
||||||
|
if ExecutionStep is None:
|
||||||
|
raise ImportError("artdag.planning not available")
|
||||||
|
|
||||||
|
step = ExecutionStep.from_json(step_json)
|
||||||
|
worker_id = get_worker_id()
|
||||||
|
task_id = self.request.id
|
||||||
|
|
||||||
|
logger.info(f"Executing step {step.step_id} ({step.node_type}) cache_id={step.cache_id[:16]}...")
|
||||||
|
|
||||||
|
# Get L1 cache manager (IPFS-backed)
|
||||||
|
cache_mgr = get_cache_manager()
|
||||||
|
|
||||||
|
# Check if already cached (by cache_id as content_hash)
|
||||||
|
cached_path = cache_mgr.get_by_content_hash(step.cache_id)
|
||||||
|
if cached_path:
|
||||||
|
logger.info(f"Step {step.step_id} already cached at {cached_path}")
|
||||||
|
|
||||||
|
# Mark as cached in claiming system
|
||||||
|
claimer = get_claimer()
|
||||||
|
claimer.mark_cached(step.cache_id, str(cached_path))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "cached",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": str(cached_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to claim the task
|
||||||
|
if not claim_task(step.cache_id, worker_id, task_id):
|
||||||
|
# Another worker is handling it
|
||||||
|
logger.info(f"Step {step.step_id} claimed by another worker, waiting...")
|
||||||
|
|
||||||
|
claimer = get_claimer()
|
||||||
|
result = claimer.wait_for_completion(step.cache_id, timeout=600)
|
||||||
|
|
||||||
|
if result and result.status == ClaimStatus.COMPLETED:
|
||||||
|
return {
|
||||||
|
"status": "completed_by_other",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": result.output_path,
|
||||||
|
}
|
||||||
|
elif result and result.status == ClaimStatus.CACHED:
|
||||||
|
return {
|
||||||
|
"status": "cached",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": result.output_path,
|
||||||
|
}
|
||||||
|
elif result and result.status == ClaimStatus.FAILED:
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"error": result.error,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status": "timeout",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"error": "Timeout waiting for other worker",
|
||||||
|
}
|
||||||
|
|
||||||
|
# We have the claim, update to running
|
||||||
|
claimer = get_claimer()
|
||||||
|
claimer.update_status(step.cache_id, worker_id, ClaimStatus.RUNNING)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle SOURCE nodes
|
||||||
|
if step.node_type == "SOURCE":
|
||||||
|
content_hash = step.config.get("content_hash")
|
||||||
|
if not content_hash:
|
||||||
|
raise ValueError(f"SOURCE step missing content_hash")
|
||||||
|
|
||||||
|
# Look up in cache
|
||||||
|
path = cache_mgr.get_by_content_hash(content_hash)
|
||||||
|
if not path:
|
||||||
|
raise ValueError(f"SOURCE input not found in cache: {content_hash[:16]}...")
|
||||||
|
|
||||||
|
output_path = str(path)
|
||||||
|
complete_task(step.cache_id, worker_id, output_path)
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": output_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle _LIST virtual nodes
|
||||||
|
if step.node_type == "_LIST":
|
||||||
|
item_paths = []
|
||||||
|
for item_id in step.config.get("items", []):
|
||||||
|
item_cache_id = input_cache_ids.get(item_id)
|
||||||
|
if item_cache_id:
|
||||||
|
path = cache_mgr.get_by_content_hash(item_cache_id)
|
||||||
|
if path:
|
||||||
|
item_paths.append(str(path))
|
||||||
|
|
||||||
|
complete_task(step.cache_id, worker_id, json.dumps(item_paths))
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": None,
|
||||||
|
"item_paths": item_paths,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get executor for this node type
|
||||||
|
try:
|
||||||
|
node_type = NodeType[step.node_type]
|
||||||
|
except KeyError:
|
||||||
|
node_type = step.node_type
|
||||||
|
|
||||||
|
executor = get_executor(node_type)
|
||||||
|
if executor is None:
|
||||||
|
raise ValueError(f"No executor for node type: {step.node_type}")
|
||||||
|
|
||||||
|
# Resolve input paths from cache
|
||||||
|
input_paths = []
|
||||||
|
for input_step_id in step.input_steps:
|
||||||
|
input_cache_id = input_cache_ids.get(input_step_id)
|
||||||
|
if not input_cache_id:
|
||||||
|
raise ValueError(f"No cache_id for input step: {input_step_id}")
|
||||||
|
|
||||||
|
path = cache_mgr.get_by_content_hash(input_cache_id)
|
||||||
|
if not path:
|
||||||
|
raise ValueError(f"Input not in cache: {input_cache_id[:16]}...")
|
||||||
|
|
||||||
|
input_paths.append(Path(path))
|
||||||
|
|
||||||
|
# Create temp output path
|
||||||
|
import tempfile
|
||||||
|
output_dir = Path(tempfile.mkdtemp())
|
||||||
|
output_path = output_dir / f"output_{step.cache_id[:16]}.mp4"
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
logger.info(f"Running executor for {step.node_type} with {len(input_paths)} inputs")
|
||||||
|
result_path = executor.execute(step.config, input_paths, output_path)
|
||||||
|
|
||||||
|
# Store in IPFS-backed cache
|
||||||
|
cached_file, ipfs_cid = cache_mgr.put(
|
||||||
|
source_path=result_path,
|
||||||
|
node_type=step.node_type,
|
||||||
|
node_id=step.cache_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Step {step.step_id} completed, IPFS CID: {ipfs_cid}")
|
||||||
|
|
||||||
|
# Mark completed
|
||||||
|
complete_task(step.cache_id, worker_id, str(cached_file.path))
|
||||||
|
|
||||||
|
# Cleanup temp
|
||||||
|
if output_dir.exists():
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(output_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": str(cached_file.path),
|
||||||
|
"content_hash": cached_file.content_hash,
|
||||||
|
"ipfs_cid": ipfs_cid,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Step {step.step_id} failed: {e}")
|
||||||
|
fail_task(step.cache_id, worker_id, str(e))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.execute_level')
|
||||||
|
def execute_level(
|
||||||
|
self,
|
||||||
|
steps_json: List[str],
|
||||||
|
plan_id: str,
|
||||||
|
cache_ids: Dict[str, str],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Execute all steps at a given dependency level.
|
||||||
|
|
||||||
|
Steps at the same level can run in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
steps_json: List of JSON-serialized ExecutionSteps
|
||||||
|
plan_id: ID of the parent execution plan
|
||||||
|
cache_ids: Mapping from step_id to cache_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with results for all steps
|
||||||
|
"""
|
||||||
|
from celery import group
|
||||||
|
|
||||||
|
# Dispatch all steps in parallel
|
||||||
|
tasks = [
|
||||||
|
execute_step.s(step_json, plan_id, cache_ids)
|
||||||
|
for step_json in steps_json
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute in parallel and collect results
|
||||||
|
job = group(tasks)
|
||||||
|
results = job.apply_async()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
step_results = results.get(timeout=3600) # 1 hour timeout
|
||||||
|
|
||||||
|
# Build cache_ids from results
|
||||||
|
new_cache_ids = dict(cache_ids)
|
||||||
|
for result in step_results:
|
||||||
|
step_id = result.get("step_id")
|
||||||
|
cache_id = result.get("cache_id")
|
||||||
|
if step_id and cache_id:
|
||||||
|
new_cache_ids[step_id] = cache_id
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"results": step_results,
|
||||||
|
"cache_ids": new_cache_ids,
|
||||||
|
}
|
||||||
373
tasks/orchestrate.py
Normal file
373
tasks/orchestrate.py
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
"""
|
||||||
|
Plan orchestration tasks.
|
||||||
|
|
||||||
|
Coordinates the full 3-phase execution:
|
||||||
|
1. Analyze inputs
|
||||||
|
2. Generate plan
|
||||||
|
3. Execute steps level by level
|
||||||
|
|
||||||
|
Uses IPFS-backed cache for durability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from celery import current_task, group, chain
|
||||||
|
|
||||||
|
# Import from the Celery app
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from celery_app import app
|
||||||
|
from claiming import get_claimer
|
||||||
|
from cache_manager import get_cache_manager
|
||||||
|
|
||||||
|
# Import artdag modules
|
||||||
|
try:
|
||||||
|
from artdag import Cache
|
||||||
|
from artdag.analysis import Analyzer, AnalysisResult
|
||||||
|
from artdag.planning import RecipePlanner, ExecutionPlan, Recipe
|
||||||
|
except ImportError:
|
||||||
|
Cache = None
|
||||||
|
Analyzer = None
|
||||||
|
AnalysisResult = None
|
||||||
|
RecipePlanner = None
|
||||||
|
ExecutionPlan = None
|
||||||
|
Recipe = None
|
||||||
|
|
||||||
|
from .execute import execute_step
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache directories
|
||||||
|
CACHE_DIR = Path(os.environ.get('CACHE_DIR', '/data/cache'))
|
||||||
|
ANALYSIS_CACHE_DIR = CACHE_DIR / 'analysis'
|
||||||
|
PLAN_CACHE_DIR = CACHE_DIR / 'plans'
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.run_plan')
|
||||||
|
def run_plan(
|
||||||
|
self,
|
||||||
|
plan_json: str,
|
||||||
|
run_id: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Execute a complete execution plan.
|
||||||
|
|
||||||
|
Runs steps level by level, with parallel execution within each level.
|
||||||
|
Results are stored in IPFS-backed cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan_json: JSON-serialized ExecutionPlan
|
||||||
|
run_id: Optional run ID for tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with execution results
|
||||||
|
"""
|
||||||
|
if ExecutionPlan is None:
|
||||||
|
raise ImportError("artdag.planning not available")
|
||||||
|
|
||||||
|
plan = ExecutionPlan.from_json(plan_json)
|
||||||
|
cache_mgr = get_cache_manager()
|
||||||
|
|
||||||
|
logger.info(f"Executing plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")
|
||||||
|
|
||||||
|
# Build initial cache_ids mapping (step_id -> cache_id)
|
||||||
|
cache_ids = {}
|
||||||
|
for step in plan.steps:
|
||||||
|
cache_ids[step.step_id] = step.cache_id
|
||||||
|
|
||||||
|
# Also map input hashes
|
||||||
|
for name, content_hash in plan.input_hashes.items():
|
||||||
|
cache_ids[name] = content_hash
|
||||||
|
|
||||||
|
# Group steps by level
|
||||||
|
steps_by_level = plan.get_steps_by_level()
|
||||||
|
max_level = max(steps_by_level.keys()) if steps_by_level else 0
|
||||||
|
|
||||||
|
results_by_step = {}
|
||||||
|
total_cached = 0
|
||||||
|
total_executed = 0
|
||||||
|
|
||||||
|
for level in range(max_level + 1):
|
||||||
|
level_steps = steps_by_level.get(level, [])
|
||||||
|
if not level_steps:
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Executing level {level}: {len(level_steps)} steps")
|
||||||
|
|
||||||
|
# Check which steps need execution
|
||||||
|
steps_to_run = []
|
||||||
|
|
||||||
|
for step in level_steps:
|
||||||
|
# Check if cached
|
||||||
|
cached_path = cache_mgr.get_by_content_hash(step.cache_id)
|
||||||
|
if cached_path:
|
||||||
|
results_by_step[step.step_id] = {
|
||||||
|
"status": "cached",
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"output_path": str(cached_path),
|
||||||
|
}
|
||||||
|
total_cached += 1
|
||||||
|
else:
|
||||||
|
steps_to_run.append(step)
|
||||||
|
|
||||||
|
if not steps_to_run:
|
||||||
|
logger.info(f"Level {level}: all steps cached")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build input cache_ids for this level
|
||||||
|
level_cache_ids = dict(cache_ids)
|
||||||
|
|
||||||
|
# Execute steps in parallel
|
||||||
|
tasks = [
|
||||||
|
execute_step.s(step.to_json(), plan.plan_id, level_cache_ids)
|
||||||
|
for step in steps_to_run
|
||||||
|
]
|
||||||
|
|
||||||
|
job = group(tasks)
|
||||||
|
async_results = job.apply_async()
|
||||||
|
|
||||||
|
# Wait for completion
|
||||||
|
try:
|
||||||
|
step_results = async_results.get(timeout=3600)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Level {level} execution failed: {e}")
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(e),
|
||||||
|
"level": level,
|
||||||
|
"results": results_by_step,
|
||||||
|
"run_id": run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process results
|
||||||
|
for result in step_results:
|
||||||
|
step_id = result.get("step_id")
|
||||||
|
cache_id = result.get("cache_id")
|
||||||
|
|
||||||
|
results_by_step[step_id] = result
|
||||||
|
cache_ids[step_id] = cache_id
|
||||||
|
|
||||||
|
if result.get("status") in ("completed", "cached", "completed_by_other"):
|
||||||
|
total_executed += 1
|
||||||
|
elif result.get("status") == "failed":
|
||||||
|
logger.error(f"Step {step_id} failed: {result.get('error')}")
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": f"Step {step_id} failed: {result.get('error')}",
|
||||||
|
"level": level,
|
||||||
|
"results": results_by_step,
|
||||||
|
"run_id": run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get final output
|
||||||
|
output_step = plan.get_step(plan.output_step)
|
||||||
|
output_cache_id = output_step.cache_id if output_step else None
|
||||||
|
output_path = None
|
||||||
|
output_ipfs_cid = None
|
||||||
|
|
||||||
|
if output_cache_id:
|
||||||
|
output_path = cache_mgr.get_by_content_hash(output_cache_id)
|
||||||
|
output_ipfs_cid = cache_mgr.get_ipfs_cid(output_cache_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"run_id": run_id,
|
||||||
|
"plan_id": plan.plan_id,
|
||||||
|
"output_cache_id": output_cache_id,
|
||||||
|
"output_path": str(output_path) if output_path else None,
|
||||||
|
"output_ipfs_cid": output_ipfs_cid,
|
||||||
|
"total_steps": len(plan.steps),
|
||||||
|
"cached": total_cached,
|
||||||
|
"executed": total_executed,
|
||||||
|
"results": results_by_step,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.run_recipe')
|
||||||
|
def run_recipe(
|
||||||
|
self,
|
||||||
|
recipe_yaml: str,
|
||||||
|
input_hashes: Dict[str, str],
|
||||||
|
features: List[str] = None,
|
||||||
|
run_id: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Run a complete recipe through all 3 phases.
|
||||||
|
|
||||||
|
1. Analyze: Extract features from inputs
|
||||||
|
2. Plan: Generate execution plan
|
||||||
|
3. Execute: Run the plan
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recipe_yaml: Recipe YAML content
|
||||||
|
input_hashes: Mapping from input name to content hash
|
||||||
|
features: Features to extract (default: ["beats", "energy"])
|
||||||
|
run_id: Optional run ID for tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with final results
|
||||||
|
"""
|
||||||
|
if RecipePlanner is None or Analyzer is None:
|
||||||
|
raise ImportError("artdag modules not available")
|
||||||
|
|
||||||
|
if features is None:
|
||||||
|
features = ["beats", "energy"]
|
||||||
|
|
||||||
|
cache_mgr = get_cache_manager()
|
||||||
|
|
||||||
|
logger.info(f"Running recipe with {len(input_hashes)} inputs")
|
||||||
|
|
||||||
|
# Phase 1: Analyze
|
||||||
|
logger.info("Phase 1: Analyzing inputs...")
|
||||||
|
|
||||||
|
ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
|
||||||
|
|
||||||
|
analysis_results = {}
|
||||||
|
for name, content_hash in input_hashes.items():
|
||||||
|
# Get path from cache
|
||||||
|
path = cache_mgr.get_by_content_hash(content_hash)
|
||||||
|
if path:
|
||||||
|
try:
|
||||||
|
result = analyzer.analyze(
|
||||||
|
input_hash=content_hash,
|
||||||
|
features=features,
|
||||||
|
input_path=Path(path),
|
||||||
|
)
|
||||||
|
analysis_results[content_hash] = result
|
||||||
|
logger.info(f"Analyzed {name}: tempo={result.tempo}, beats={len(result.beat_times or [])}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Analysis failed for {name}: {e}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Input {name} ({content_hash[:16]}...) not in cache")
|
||||||
|
|
||||||
|
logger.info(f"Analyzed {len(analysis_results)} inputs")
|
||||||
|
|
||||||
|
# Phase 2: Plan
|
||||||
|
logger.info("Phase 2: Generating execution plan...")
|
||||||
|
|
||||||
|
recipe = Recipe.from_yaml(recipe_yaml)
|
||||||
|
planner = RecipePlanner(use_tree_reduction=True)
|
||||||
|
|
||||||
|
plan = planner.plan(
|
||||||
|
recipe=recipe,
|
||||||
|
input_hashes=input_hashes,
|
||||||
|
analysis=analysis_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Generated plan with {len(plan.steps)} steps")
|
||||||
|
|
||||||
|
# Save plan for debugging
|
||||||
|
PLAN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
plan_path = PLAN_CACHE_DIR / f"{plan.plan_id}.json"
|
||||||
|
with open(plan_path, "w") as f:
|
||||||
|
f.write(plan.to_json())
|
||||||
|
|
||||||
|
# Phase 3: Execute
|
||||||
|
logger.info("Phase 3: Executing plan...")
|
||||||
|
|
||||||
|
result = run_plan(plan.to_json(), run_id=run_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": result.get("status"),
|
||||||
|
"run_id": run_id,
|
||||||
|
"recipe": recipe.name,
|
||||||
|
"plan_id": plan.plan_id,
|
||||||
|
"output_path": result.get("output_path"),
|
||||||
|
"output_cache_id": result.get("output_cache_id"),
|
||||||
|
"output_ipfs_cid": result.get("output_ipfs_cid"),
|
||||||
|
"analysis_count": len(analysis_results),
|
||||||
|
"total_steps": len(plan.steps),
|
||||||
|
"cached": result.get("cached", 0),
|
||||||
|
"executed": result.get("executed", 0),
|
||||||
|
"error": result.get("error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name='tasks.generate_plan')
|
||||||
|
def generate_plan(
|
||||||
|
self,
|
||||||
|
recipe_yaml: str,
|
||||||
|
input_hashes: Dict[str, str],
|
||||||
|
features: List[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Generate an execution plan without executing it.
|
||||||
|
|
||||||
|
Useful for:
|
||||||
|
- Previewing what will be executed
|
||||||
|
- Checking cache status
|
||||||
|
- Debugging recipe issues
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recipe_yaml: Recipe YAML content
|
||||||
|
input_hashes: Mapping from input name to content hash
|
||||||
|
features: Features to extract for analysis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with plan details
|
||||||
|
"""
|
||||||
|
if RecipePlanner is None or Analyzer is None:
|
||||||
|
raise ImportError("artdag modules not available")
|
||||||
|
|
||||||
|
if features is None:
|
||||||
|
features = ["beats", "energy"]
|
||||||
|
|
||||||
|
cache_mgr = get_cache_manager()
|
||||||
|
|
||||||
|
# Analyze inputs
|
||||||
|
ANALYSIS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
analyzer = Analyzer(cache_dir=ANALYSIS_CACHE_DIR)
|
||||||
|
|
||||||
|
analysis_results = {}
|
||||||
|
for name, content_hash in input_hashes.items():
|
||||||
|
path = cache_mgr.get_by_content_hash(content_hash)
|
||||||
|
if path:
|
||||||
|
try:
|
||||||
|
result = analyzer.analyze(
|
||||||
|
input_hash=content_hash,
|
||||||
|
features=features,
|
||||||
|
input_path=Path(path),
|
||||||
|
)
|
||||||
|
analysis_results[content_hash] = result
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Analysis failed for {name}: {e}")
|
||||||
|
|
||||||
|
# Generate plan
|
||||||
|
recipe = Recipe.from_yaml(recipe_yaml)
|
||||||
|
planner = RecipePlanner(use_tree_reduction=True)
|
||||||
|
|
||||||
|
plan = planner.plan(
|
||||||
|
recipe=recipe,
|
||||||
|
input_hashes=input_hashes,
|
||||||
|
analysis=analysis_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check cache status for each step
|
||||||
|
steps_status = []
|
||||||
|
for step in plan.steps:
|
||||||
|
cached = cache_mgr.has_content(step.cache_id)
|
||||||
|
steps_status.append({
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"node_type": step.node_type,
|
||||||
|
"cache_id": step.cache_id,
|
||||||
|
"level": step.level,
|
||||||
|
"cached": cached,
|
||||||
|
})
|
||||||
|
|
||||||
|
cached_count = sum(1 for s in steps_status if s["cached"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "planned",
|
||||||
|
"recipe": recipe.name,
|
||||||
|
"plan_id": plan.plan_id,
|
||||||
|
"total_steps": len(plan.steps),
|
||||||
|
"cached_steps": cached_count,
|
||||||
|
"pending_steps": len(plan.steps) - cached_count,
|
||||||
|
"steps": steps_status,
|
||||||
|
"plan_json": plan.to_json(),
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user