Add multi-step DAG execution support
tasks.py:
- Import artdag DAG, Node, Engine, Executor
- Register executors for effect:dog, effect:identity, SOURCE
- Add execute_dag task for running full DAG workflows
- Add build_effect_dag helper for simple effect-to-DAG conversion

server.py:
- Add use_dag and dag_json fields to RunRequest
- Update create_run to support DAG mode
- Handle both legacy render_effect and new execute_dag result formats
- Import new tasks (execute_dag, build_effect_dag)

The DAG engine executes nodes in topological order with automatic caching.
This enables multi-step pipelines like: source -> effect1 -> effect2 -> output.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
72
server.py
72
server.py
@@ -26,7 +26,7 @@ import requests as http_requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from celery_app import app as celery_app
|
||||
from tasks import render_effect
|
||||
from tasks import render_effect, execute_dag, build_effect_dag
|
||||
from cache_manager import L1CacheManager, get_cache_manager
|
||||
|
||||
# L2 server for auth verification
|
||||
@@ -98,9 +98,11 @@ app = FastAPI(
|
||||
|
||||
class RunRequest(BaseModel):
    """Request to start a run.

    Covers both the legacy single-effect runner and the DAG engine:
    set use_dag=True (or recipe="dag" with dag_json) to route the run
    through execute_dag instead of render_effect.
    """

    # Recipe name (e.g., "dog", "identity") or "dag" for custom DAG
    recipe: str
    # List of content hashes identifying the run's inputs
    inputs: list[str]
    # Optional name for the produced output artifact
    output_name: Optional[str] = None
    # Use DAG engine instead of legacy effect runner
    use_dag: bool = False
    # Custom DAG JSON (required if recipe="dag")
    dag_json: Optional[str] = None
|
||||
|
||||
|
||||
class RunStatus(BaseModel):
|
||||
@@ -301,13 +303,25 @@ async def create_run(request: RunRequest, username: str = Depends(get_required_u
|
||||
)
|
||||
|
||||
# Submit to Celery
|
||||
# For now, we only support single-input recipes
|
||||
if len(request.inputs) != 1:
|
||||
raise HTTPException(400, "Currently only single-input recipes supported")
|
||||
if request.use_dag or request.recipe == "dag":
|
||||
# DAG mode - use artdag engine
|
||||
if request.dag_json:
|
||||
# Custom DAG provided
|
||||
dag_json = request.dag_json
|
||||
else:
|
||||
# Build simple effect DAG from recipe and inputs
|
||||
dag = build_effect_dag(request.inputs, request.recipe)
|
||||
dag_json = dag.to_json()
|
||||
|
||||
input_hash = request.inputs[0]
|
||||
task = execute_dag.delay(dag_json, run.run_id)
|
||||
else:
|
||||
# Legacy mode - single effect
|
||||
if len(request.inputs) != 1:
|
||||
raise HTTPException(400, "Legacy mode only supports single-input recipes. Use use_dag=true for multi-input.")
|
||||
|
||||
input_hash = request.inputs[0]
|
||||
task = render_effect.delay(input_hash, request.recipe, output_name)
|
||||
|
||||
task = render_effect.delay(input_hash, request.recipe, output_name)
|
||||
run.celery_task_id = task.id
|
||||
run.status = "running"
|
||||
|
||||
@@ -331,29 +345,37 @@ async def get_run(run_id: str):
|
||||
result = task.result
|
||||
run.status = "completed"
|
||||
run.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
run.output_hash = result.get("output", {}).get("content_hash")
|
||||
|
||||
# Extract effects info from provenance
|
||||
effects = result.get("effects", [])
|
||||
if effects:
|
||||
run.effects_commit = effects[0].get("repo_commit")
|
||||
run.effect_url = effects[0].get("repo_url")
|
||||
# Handle both legacy (render_effect) and new (execute_dag) result formats
|
||||
if "output_hash" in result:
|
||||
# New DAG result format
|
||||
run.output_hash = result.get("output_hash")
|
||||
output_path = Path(result.get("output_path", "")) if result.get("output_path") else None
|
||||
else:
|
||||
# Legacy render_effect format
|
||||
run.output_hash = result.get("output", {}).get("content_hash")
|
||||
output_path = Path(result.get("output", {}).get("local_path", ""))
|
||||
|
||||
# Extract infrastructure info
|
||||
run.infrastructure = result.get("infrastructure")
|
||||
# Extract effects info from provenance (legacy only)
|
||||
effects = result.get("effects", [])
|
||||
if effects:
|
||||
run.effects_commit = effects[0].get("repo_commit")
|
||||
run.effect_url = effects[0].get("repo_url")
|
||||
|
||||
# Cache the output
|
||||
output_path = Path(result.get("output", {}).get("local_path", ""))
|
||||
if output_path.exists():
|
||||
# Extract infrastructure info (legacy only)
|
||||
run.infrastructure = result.get("infrastructure")
|
||||
|
||||
# Cache the output (legacy mode - DAG already caches via cache_manager)
|
||||
if output_path and output_path.exists() and "output_hash" not in result:
|
||||
cache_file(output_path, node_type="effect_output")
|
||||
|
||||
# Record activity for deletion tracking
|
||||
if run.output_hash and run.inputs:
|
||||
cache_manager.record_simple_activity(
|
||||
input_hashes=run.inputs,
|
||||
output_hash=run.output_hash,
|
||||
run_id=run.run_id,
|
||||
)
|
||||
# Record activity for deletion tracking (legacy mode)
|
||||
if run.output_hash and run.inputs:
|
||||
cache_manager.record_simple_activity(
|
||||
input_hashes=run.inputs,
|
||||
output_hash=run.output_hash,
|
||||
run_id=run.run_id,
|
||||
)
|
||||
else:
|
||||
run.status = "failed"
|
||||
run.error = str(task.result)
|
||||
|
||||
191
tasks.py
191
tasks.py
@@ -2,23 +2,33 @@
|
||||
Art DAG Celery Tasks
|
||||
|
||||
Distributed rendering tasks for the Art DAG system.
|
||||
Supports both single-effect runs and multi-step DAG execution.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from celery import Task
|
||||
from celery_app import app
|
||||
|
||||
# Import artdag components
|
||||
from artdag import DAG, Node, NodeType
|
||||
from artdag.engine import Engine
|
||||
from artdag.executor import register_executor, Executor, get_executor
|
||||
|
||||
# Add effects to path (use env var in Docker, fallback to home dir locally)
|
||||
EFFECTS_PATH = Path(os.environ.get("EFFECTS_PATH", str(Path.home() / "artdag-effects")))
|
||||
ARTDAG_PATH = Path(os.environ.get("ARTDAG_PATH", str(Path.home() / "art" / "artdag")))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_effects_commit() -> str:
|
||||
"""Get current git commit hash of effects repo."""
|
||||
@@ -65,6 +75,60 @@ def file_hash(path: Path) -> str:
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
# Cache directory (shared between server and worker)
|
||||
CACHE_DIR = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache")))
|
||||
|
||||
|
||||
# ============ Executors for Effects ============
|
||||
|
||||
@register_executor("effect:dog")
class DogExecutor(Executor):
    """Executor for the dog effect."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Apply the dog effect to a single input and return the output path."""
        # Imported lazily so the worker can start without the effects repo on path.
        from effect import effect_dog

        n_inputs = len(inputs)
        if n_inputs != 1:
            raise ValueError(f"Dog effect expects 1 input, got {n_inputs}")
        source = inputs[0]
        return effect_dog(source, output_path, config)
|
||||
|
||||
|
||||
@register_executor("effect:identity")
class IdentityExecutor(Executor):
    """Executor for the identity effect (passthrough)."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Pass a single input through the identity effect."""
        # Imported lazily so module import does not require artdag effect modules.
        from artdag.nodes.effect import effect_identity

        n_inputs = len(inputs)
        if n_inputs != 1:
            raise ValueError(f"Identity effect expects 1 input, got {n_inputs}")
        source = inputs[0]
        return effect_identity(source, output_path, config)
|
||||
|
||||
|
||||
@register_executor(NodeType.SOURCE)
class SourceExecutor(Executor):
    """Executor for SOURCE nodes - loads content from cache by hash."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Resolve a SOURCE node to an existing cached file (no transformation).

        The returned path is handed by the engine to downstream nodes as
        their input; nothing is written to output_path here.
        """
        content_hash = config.get("content_hash")
        if not content_hash:
            raise ValueError("SOURCE node requires content_hash in config")

        # First try the flat cache layout: <CACHE_DIR>/<content_hash>.
        source_path = CACHE_DIR / content_hash
        if not source_path.exists():
            # Fall back to the cache manager's node store.
            from cache_manager import get_cache_manager

            source_path = get_cache_manager().get_by_content_hash(content_hash)

        if not source_path or not source_path.exists():
            raise ValueError(f"Source content not in cache: {content_hash}")

        return source_path
|
||||
|
||||
|
||||
class RenderTask(Task):
|
||||
"""Base task with provenance tracking."""
|
||||
|
||||
@@ -197,3 +261,130 @@ def render_dog_from_cat() -> dict:
|
||||
"""Convenience task: render cat through dog effect."""
|
||||
CAT_HASH = "33268b6e167deaf018cc538de12dbe562612b33e89a749391cef855b320a269b"
|
||||
return render_effect.delay(CAT_HASH, "dog", "dog-from-cat-celery").get()
|
||||
|
||||
|
||||
@app.task(base=RenderTask, bind=True)
def execute_dag(self, dag_json: str, run_id: Optional[str] = None) -> dict:
    """
    Execute a multi-step DAG.

    Parses and validates the serialized DAG, runs it through the artdag
    Engine, caches the final output via the cache manager, and records
    input/output activity for deletion tracking.

    Args:
        dag_json: Serialized DAG as JSON string
        run_id: Optional run ID for tracking

    Returns:
        Execution result with output hash and node results

    Raises:
        ValueError: If the DAG JSON cannot be parsed or the DAG fails validation.
        RuntimeError: If the engine reports a failed execution.
    """
    from cache_manager import get_cache_manager

    # Parse DAG — chain the original exception so the real parse error survives.
    try:
        dag = DAG.from_json(dag_json)
    except Exception as e:
        raise ValueError(f"Invalid DAG JSON: {e}") from e

    # Validate structure before spending any execution time.
    errors = dag.validate()
    if errors:
        raise ValueError(f"Invalid DAG: {errors}")

    # Engine stores per-node outputs under the shared cache directory.
    engine = Engine(CACHE_DIR / "nodes")

    # Relay per-node progress into Celery task state so clients can poll it.
    def progress_callback(progress):
        self.update_state(
            state='EXECUTING',
            meta={
                'node_id': progress.node_id,
                'node_type': progress.node_type,
                'status': progress.status,
                'progress': progress.progress,
                'message': progress.message,
            }
        )
        logger.info(f"DAG progress: {progress.node_id} - {progress.status} - {progress.message}")

    engine.set_progress_callback(progress_callback)

    # Execute DAG
    self.update_state(state='EXECUTING', meta={'status': 'starting', 'nodes': len(dag.nodes)})
    result = engine.execute(dag)

    if not result.success:
        raise RuntimeError(f"DAG execution failed: {result.error}")

    # Hash and cache the final output so the server can serve it by hash.
    cache_manager = get_cache_manager()
    output_hash = None
    if result.output_path and result.output_path.exists():
        output_hash = file_hash(result.output_path)

        # Store in cache_manager for proper tracking (return value not needed).
        cache_manager.put(result.output_path, node_type="dag_output")

    # Collect SOURCE-node content hashes as the run's inputs.
    # NOTE(review): the str() comparison looks like a guard for string-typed
    # node_type values in deserialized DAGs — confirm against artdag's model.
    input_hashes = []
    for node_id, node in dag.nodes.items():
        if node.node_type == NodeType.SOURCE or str(node.node_type) == "SOURCE":
            content_hash = node.config.get("content_hash")
            if content_hash:
                input_hashes.append(content_hash)

    # Record activity for deletion tracking
    if input_hashes:
        cache_manager.record_simple_activity(
            input_hashes=input_hashes,
            output_hash=output_hash,
            run_id=run_id,
        )

    # Build result
    return {
        "success": True,
        "run_id": run_id,
        "output_hash": output_hash,
        "output_path": str(result.output_path) if result.output_path else None,
        "execution_time": result.execution_time,
        "nodes_executed": result.nodes_executed,
        "nodes_cached": result.nodes_cached,
        "node_results": {
            node_id: str(path) for node_id, path in result.node_results.items()
        },
    }
|
||||
|
||||
|
||||
def build_effect_dag(input_hashes: List[str], effect_name: str) -> DAG:
    """
    Construct a minimal DAG that feeds the given inputs into one effect.

    Args:
        input_hashes: List of input content hashes
        effect_name: Name of effect to apply (e.g., "dog", "identity")

    Returns:
        DAG ready for execution
    """
    dag = DAG()

    # One SOURCE node per input hash; remember their ids for wiring.
    source_ids = []
    for index, content_hash in enumerate(input_hashes):
        src = Node(
            node_type=NodeType.SOURCE,
            config={"content_hash": content_hash},
            name=f"source_{index}",
        )
        dag.add_node(src)
        source_ids.append(src.node_id)

    # Single effect node consuming every source; it becomes the DAG output.
    effect = Node(
        node_type=f"effect:{effect_name}",
        config={},
        inputs=source_ids,
        name=f"effect_{effect_name}",
    )
    dag.add_node(effect)
    dag.set_output(effect.node_id)

    return dag
|
||||
|
||||
Reference in New Issue
Block a user