Add multi-step DAG execution support
tasks.py:
- Import artdag DAG, Node, Engine, Executor
- Register executors for effect:dog, effect:identity, SOURCE
- Add execute_dag task for running full DAG workflows
- Add build_effect_dag helper for simple effect-to-DAG conversion

server.py:
- Add use_dag and dag_json fields to RunRequest
- Update create_run to support DAG mode
- Handle both legacy render_effect and new execute_dag result formats
- Import new tasks (execute_dag, build_effect_dag)

The DAG engine executes nodes in topological order with automatic caching. This enables multi-step pipelines like: source -> effect1 -> effect2 -> output.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
72
server.py
72
server.py
@@ -26,7 +26,7 @@ import requests as http_requests
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from celery_app import app as celery_app
|
from celery_app import app as celery_app
|
||||||
from tasks import render_effect
|
from tasks import render_effect, execute_dag, build_effect_dag
|
||||||
from cache_manager import L1CacheManager, get_cache_manager
|
from cache_manager import L1CacheManager, get_cache_manager
|
||||||
|
|
||||||
# L2 server for auth verification
|
# L2 server for auth verification
|
||||||
@@ -98,9 +98,11 @@ app = FastAPI(
|
|||||||
|
|
||||||
class RunRequest(BaseModel):
|
class RunRequest(BaseModel):
|
||||||
"""Request to start a run."""
|
"""Request to start a run."""
|
||||||
recipe: str # Recipe name (e.g., "dog", "identity")
|
recipe: str # Recipe name (e.g., "dog", "identity") or "dag" for custom DAG
|
||||||
inputs: list[str] # List of content hashes
|
inputs: list[str] # List of content hashes
|
||||||
output_name: Optional[str] = None
|
output_name: Optional[str] = None
|
||||||
|
use_dag: bool = False # Use DAG engine instead of legacy effect runner
|
||||||
|
dag_json: Optional[str] = None # Custom DAG JSON (required if recipe="dag")
|
||||||
|
|
||||||
|
|
||||||
class RunStatus(BaseModel):
|
class RunStatus(BaseModel):
|
||||||
@@ -301,13 +303,25 @@ async def create_run(request: RunRequest, username: str = Depends(get_required_u
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Submit to Celery
|
# Submit to Celery
|
||||||
# For now, we only support single-input recipes
|
if request.use_dag or request.recipe == "dag":
|
||||||
if len(request.inputs) != 1:
|
# DAG mode - use artdag engine
|
||||||
raise HTTPException(400, "Currently only single-input recipes supported")
|
if request.dag_json:
|
||||||
|
# Custom DAG provided
|
||||||
|
dag_json = request.dag_json
|
||||||
|
else:
|
||||||
|
# Build simple effect DAG from recipe and inputs
|
||||||
|
dag = build_effect_dag(request.inputs, request.recipe)
|
||||||
|
dag_json = dag.to_json()
|
||||||
|
|
||||||
input_hash = request.inputs[0]
|
task = execute_dag.delay(dag_json, run.run_id)
|
||||||
|
else:
|
||||||
|
# Legacy mode - single effect
|
||||||
|
if len(request.inputs) != 1:
|
||||||
|
raise HTTPException(400, "Legacy mode only supports single-input recipes. Use use_dag=true for multi-input.")
|
||||||
|
|
||||||
|
input_hash = request.inputs[0]
|
||||||
|
task = render_effect.delay(input_hash, request.recipe, output_name)
|
||||||
|
|
||||||
task = render_effect.delay(input_hash, request.recipe, output_name)
|
|
||||||
run.celery_task_id = task.id
|
run.celery_task_id = task.id
|
||||||
run.status = "running"
|
run.status = "running"
|
||||||
|
|
||||||
@@ -331,29 +345,37 @@ async def get_run(run_id: str):
|
|||||||
result = task.result
|
result = task.result
|
||||||
run.status = "completed"
|
run.status = "completed"
|
||||||
run.completed_at = datetime.now(timezone.utc).isoformat()
|
run.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
run.output_hash = result.get("output", {}).get("content_hash")
|
|
||||||
|
|
||||||
# Extract effects info from provenance
|
# Handle both legacy (render_effect) and new (execute_dag) result formats
|
||||||
effects = result.get("effects", [])
|
if "output_hash" in result:
|
||||||
if effects:
|
# New DAG result format
|
||||||
run.effects_commit = effects[0].get("repo_commit")
|
run.output_hash = result.get("output_hash")
|
||||||
run.effect_url = effects[0].get("repo_url")
|
output_path = Path(result.get("output_path", "")) if result.get("output_path") else None
|
||||||
|
else:
|
||||||
|
# Legacy render_effect format
|
||||||
|
run.output_hash = result.get("output", {}).get("content_hash")
|
||||||
|
output_path = Path(result.get("output", {}).get("local_path", ""))
|
||||||
|
|
||||||
# Extract infrastructure info
|
# Extract effects info from provenance (legacy only)
|
||||||
run.infrastructure = result.get("infrastructure")
|
effects = result.get("effects", [])
|
||||||
|
if effects:
|
||||||
|
run.effects_commit = effects[0].get("repo_commit")
|
||||||
|
run.effect_url = effects[0].get("repo_url")
|
||||||
|
|
||||||
# Cache the output
|
# Extract infrastructure info (legacy only)
|
||||||
output_path = Path(result.get("output", {}).get("local_path", ""))
|
run.infrastructure = result.get("infrastructure")
|
||||||
if output_path.exists():
|
|
||||||
|
# Cache the output (legacy mode - DAG already caches via cache_manager)
|
||||||
|
if output_path and output_path.exists() and "output_hash" not in result:
|
||||||
cache_file(output_path, node_type="effect_output")
|
cache_file(output_path, node_type="effect_output")
|
||||||
|
|
||||||
# Record activity for deletion tracking
|
# Record activity for deletion tracking (legacy mode)
|
||||||
if run.output_hash and run.inputs:
|
if run.output_hash and run.inputs:
|
||||||
cache_manager.record_simple_activity(
|
cache_manager.record_simple_activity(
|
||||||
input_hashes=run.inputs,
|
input_hashes=run.inputs,
|
||||||
output_hash=run.output_hash,
|
output_hash=run.output_hash,
|
||||||
run_id=run.run_id,
|
run_id=run.run_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
run.status = "failed"
|
run.status = "failed"
|
||||||
run.error = str(task.result)
|
run.error = str(task.result)
|
||||||
|
|||||||
191
tasks.py
191
tasks.py
@@ -2,23 +2,33 @@
|
|||||||
Art DAG Celery Tasks
|
Art DAG Celery Tasks
|
||||||
|
|
||||||
Distributed rendering tasks for the Art DAG system.
|
Distributed rendering tasks for the Art DAG system.
|
||||||
|
Supports both single-effect runs and multi-step DAG execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from celery import Task
|
from celery import Task
|
||||||
from celery_app import app
|
from celery_app import app
|
||||||
|
|
||||||
|
# Import artdag components
|
||||||
|
from artdag import DAG, Node, NodeType
|
||||||
|
from artdag.engine import Engine
|
||||||
|
from artdag.executor import register_executor, Executor, get_executor
|
||||||
|
|
||||||
# Add effects to path (use env var in Docker, fallback to home dir locally)
|
# Add effects to path (use env var in Docker, fallback to home dir locally)
|
||||||
EFFECTS_PATH = Path(os.environ.get("EFFECTS_PATH", str(Path.home() / "artdag-effects")))
|
EFFECTS_PATH = Path(os.environ.get("EFFECTS_PATH", str(Path.home() / "artdag-effects")))
|
||||||
ARTDAG_PATH = Path(os.environ.get("ARTDAG_PATH", str(Path.home() / "art" / "artdag")))
|
ARTDAG_PATH = Path(os.environ.get("ARTDAG_PATH", str(Path.home() / "art" / "artdag")))
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_effects_commit() -> str:
|
def get_effects_commit() -> str:
|
||||||
"""Get current git commit hash of effects repo."""
|
"""Get current git commit hash of effects repo."""
|
||||||
@@ -65,6 +75,60 @@ def file_hash(path: Path) -> str:
|
|||||||
return hasher.hexdigest()
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# Cache directory (shared between server and worker)
|
||||||
|
CACHE_DIR = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache")))
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Executors for Effects ============
|
||||||
|
|
||||||
|
@register_executor("effect:dog")
class DogExecutor(Executor):
    """Executor that runs the ``dog`` effect on a single input."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Apply ``effect_dog`` to the only input and return what it returns.

        Args:
            config: Node configuration, forwarded to ``effect_dog`` unchanged.
            inputs: Input paths; exactly one is accepted.
            output_path: Destination path handed to ``effect_dog``.

        Raises:
            ValueError: If ``inputs`` does not hold exactly one path.
        """
        # Imported at call time, before the arity check.
        from effect import effect_dog

        if len(inputs) == 1:
            only_input = inputs[0]
            return effect_dog(only_input, output_path, config)

        raise ValueError(f"Dog effect expects 1 input, got {len(inputs)}")
|
||||||
|
|
||||||
|
|
||||||
|
@register_executor("effect:identity")
class IdentityExecutor(Executor):
    """Executor that runs the identity (passthrough) effect on a single input."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Apply ``effect_identity`` to the only input and return what it returns.

        Args:
            config: Node configuration, forwarded to ``effect_identity`` unchanged.
            inputs: Input paths; exactly one is accepted.
            output_path: Destination path handed to ``effect_identity``.

        Raises:
            ValueError: If ``inputs`` does not hold exactly one path.
        """
        # Imported at call time, before the arity check.
        from artdag.nodes.effect import effect_identity

        if len(inputs) == 1:
            only_input = inputs[0]
            return effect_identity(only_input, output_path, config)

        raise ValueError(f"Identity effect expects 1 input, got {len(inputs)}")
|
||||||
|
|
||||||
|
|
||||||
|
@register_executor(NodeType.SOURCE)
class SourceExecutor(Executor):
    """Executor for SOURCE nodes - loads content from cache by hash."""

    def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path:
        """Resolve ``config["content_hash"]`` to an existing cached file.

        The file is looked up directly under ``CACHE_DIR`` first and, if it
        is not there, through ``cache_manager.get_by_content_hash``. No
        transformation is applied; ``inputs`` and ``output_path`` are unused.

        Args:
            config: Node configuration; must contain ``content_hash``.
            inputs: Ignored for SOURCE nodes.
            output_path: Ignored for SOURCE nodes.

        Returns:
            Path of the cached source file.

        Raises:
            ValueError: If ``content_hash`` is missing, is not a plain file
                name, or does not resolve to an existing cached file.
        """
        content_hash = config.get("content_hash")
        if not content_hash:
            raise ValueError("SOURCE node requires content_hash in config")

        # Node config can originate from a client-supplied DAG JSON and the
        # hash is joined onto CACHE_DIR below, so refuse anything that is not
        # a single path component (no separators, no "."/"..", not absolute).
        if (
            not isinstance(content_hash, str)
            or content_hash in (".", "..")
            or Path(content_hash).name != content_hash
        ):
            raise ValueError(f"Invalid content_hash for SOURCE node: {content_hash!r}")

        # Look up in cache
        source_path = CACHE_DIR / content_hash
        if not source_path.exists():
            # Fall back to the cache manager's lookup by content hash
            from cache_manager import get_cache_manager

            cache_manager = get_cache_manager()
            source_path = cache_manager.get_by_content_hash(content_hash)

        if not source_path or not source_path.exists():
            raise ValueError(f"Source content not in cache: {content_hash}")

        # For source nodes, we just return the path (no transformation).
        # The engine will use this as input to subsequent nodes.
        return source_path
|
||||||
|
|
||||||
|
|
||||||
class RenderTask(Task):
|
class RenderTask(Task):
|
||||||
"""Base task with provenance tracking."""
|
"""Base task with provenance tracking."""
|
||||||
|
|
||||||
@@ -197,3 +261,130 @@ def render_dog_from_cat() -> dict:
|
|||||||
"""Convenience task: render cat through dog effect."""
|
"""Convenience task: render cat through dog effect."""
|
||||||
CAT_HASH = "33268b6e167deaf018cc538de12dbe562612b33e89a749391cef855b320a269b"
|
CAT_HASH = "33268b6e167deaf018cc538de12dbe562612b33e89a749391cef855b320a269b"
|
||||||
return render_effect.delay(CAT_HASH, "dog", "dog-from-cat-celery").get()
|
return render_effect.delay(CAT_HASH, "dog", "dog-from-cat-celery").get()
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(base=RenderTask, bind=True)
def execute_dag(self, dag_json: str, run_id: Optional[str] = None) -> dict:
    """
    Execute a multi-step DAG.

    Args:
        dag_json: Serialized DAG as JSON string
        run_id: Optional run ID for tracking

    Returns:
        Dict with ``success``, ``run_id``, ``output_hash``, ``output_path``,
        ``execution_time``, ``nodes_executed``, ``nodes_cached`` and
        ``node_results`` (node id -> stringified path).

    Raises:
        ValueError: If the JSON cannot be parsed into a DAG or the DAG
            fails validation.
        RuntimeError: If the engine reports an unsuccessful execution.
    """
    from cache_manager import get_cache_manager

    # Parse DAG; keep the original exception as the cause for debugging.
    try:
        dag = DAG.from_json(dag_json)
    except Exception as e:
        raise ValueError(f"Invalid DAG JSON: {e}") from e

    # Validate DAG
    errors = dag.validate()
    if errors:
        raise ValueError(f"Invalid DAG: {errors}")

    # Create engine with cache directory
    engine = Engine(CACHE_DIR / "nodes")

    # Forward per-node progress to the Celery task state and the log.
    def progress_callback(progress):
        self.update_state(
            state='EXECUTING',
            meta={
                'node_id': progress.node_id,
                'node_type': progress.node_type,
                'status': progress.status,
                'progress': progress.progress,
                'message': progress.message,
            }
        )
        logger.info(
            "DAG progress: %s - %s - %s",
            progress.node_id,
            progress.status,
            progress.message,
        )

    engine.set_progress_callback(progress_callback)

    # Execute DAG
    self.update_state(state='EXECUTING', meta={'status': 'starting', 'nodes': len(dag.nodes)})
    result = engine.execute(dag)

    if not result.success:
        raise RuntimeError(f"DAG execution failed: {result.error}")

    # Get output hash
    cache_manager = get_cache_manager()
    output_hash = None
    if result.output_path and result.output_path.exists():
        output_hash = file_hash(result.output_path)

        # Store in cache_manager for proper tracking
        cache_manager.put(result.output_path, node_type="dag_output")

        # Record activity for deletion tracking. Kept under the output guard
        # so an activity is never recorded with output_hash=None.
        input_hashes = [
            node.config["content_hash"]
            for node in dag.nodes.values()
            if (node.node_type == NodeType.SOURCE or str(node.node_type) == "SOURCE")
            and node.config.get("content_hash")
        ]

        if input_hashes:
            cache_manager.record_simple_activity(
                input_hashes=input_hashes,
                output_hash=output_hash,
                run_id=run_id,
            )

    # Build result
    return {
        "success": True,
        "run_id": run_id,
        "output_hash": output_hash,
        "output_path": str(result.output_path) if result.output_path else None,
        "execution_time": result.execution_time,
        "nodes_executed": result.nodes_executed,
        "nodes_cached": result.nodes_cached,
        "node_results": {
            node_id: str(path) for node_id, path in result.node_results.items()
        },
    }
|
||||||
|
|
||||||
|
|
||||||
|
def build_effect_dag(input_hashes: List[str], effect_name: str) -> DAG:
    """
    Build a two-layer DAG: one SOURCE node per input feeding one effect node.

    Args:
        input_hashes: Content hashes; each becomes a SOURCE node named
            ``source_<index>``.
        effect_name: Effect to apply (e.g., "dog", "identity"); the effect
            node gets type ``effect:<effect_name>``.

    Returns:
        DAG whose output is the effect node.
    """
    graph = DAG()

    # One SOURCE node per input hash, in the order given.
    upstream_ids = []
    for index, digest in enumerate(input_hashes):
        src = Node(
            node_type=NodeType.SOURCE,
            config={"content_hash": digest},
            name=f"source_{index}",
        )
        graph.add_node(src)
        upstream_ids.append(src.node_id)

    # Single effect node consuming every source; it is the DAG's output.
    effect = Node(
        node_type=f"effect:{effect_name}",
        config={},
        inputs=upstream_ids,
        name=f"effect_{effect_name}",
    )
    graph.add_node(effect)
    graph.set_output(effect.node_id)

    return graph
|
||||||
|
|||||||
Reference in New Issue
Block a user