diff --git a/server.py b/server.py index dab6bd4..ada0d10 100644 --- a/server.py +++ b/server.py @@ -26,7 +26,7 @@ import requests as http_requests from urllib.parse import urlparse from celery_app import app as celery_app -from tasks import render_effect +from tasks import render_effect, execute_dag, build_effect_dag from cache_manager import L1CacheManager, get_cache_manager # L2 server for auth verification @@ -98,9 +98,11 @@ app = FastAPI( class RunRequest(BaseModel): """Request to start a run.""" - recipe: str # Recipe name (e.g., "dog", "identity") + recipe: str # Recipe name (e.g., "dog", "identity") or "dag" for custom DAG inputs: list[str] # List of content hashes output_name: Optional[str] = None + use_dag: bool = False # Use DAG engine instead of legacy effect runner + dag_json: Optional[str] = None # Custom DAG JSON (required if recipe="dag") class RunStatus(BaseModel): @@ -301,13 +303,25 @@ async def create_run(request: RunRequest, username: str = Depends(get_required_u ) # Submit to Celery - # For now, we only support single-input recipes - if len(request.inputs) != 1: - raise HTTPException(400, "Currently only single-input recipes supported") + if request.use_dag or request.recipe == "dag": + # DAG mode - use artdag engine + if request.dag_json: + # Custom DAG provided + dag_json = request.dag_json + else: + # Build simple effect DAG from recipe and inputs + dag = build_effect_dag(request.inputs, request.recipe) + dag_json = dag.to_json() - input_hash = request.inputs[0] + task = execute_dag.delay(dag_json, run.run_id) + else: + # Legacy mode - single effect + if len(request.inputs) != 1: + raise HTTPException(400, "Legacy mode only supports single-input recipes. Use use_dag=true for multi-input.") + + input_hash = request.inputs[0] + task = render_effect.delay(input_hash, request.recipe, output_name) - task = render_effect.delay(input_hash, request.recipe, output_name) run.celery_task_id = task.id run.status = "running" @@ -331,29 +345,37 @@ async def get_run(run_id: str): result = task.result run.status = "completed" run.completed_at = datetime.now(timezone.utc).isoformat() - run.output_hash = result.get("output", {}).get("content_hash") - # Extract effects info from provenance - effects = result.get("effects", []) - if effects: - run.effects_commit = effects[0].get("repo_commit") - run.effect_url = effects[0].get("repo_url") + # Handle both legacy (render_effect) and new (execute_dag) result formats + if "output_hash" in result: + # New DAG result format + run.output_hash = result.get("output_hash") + output_path = Path(result.get("output_path", "")) if result.get("output_path") else None + else: + # Legacy render_effect format + run.output_hash = result.get("output", {}).get("content_hash") + output_path = Path(result.get("output", {}).get("local_path", "")) - # Extract infrastructure info - run.infrastructure = result.get("infrastructure") + # Extract effects info from provenance (legacy only) + effects = result.get("effects", []) + if effects: + run.effects_commit = effects[0].get("repo_commit") + run.effect_url = effects[0].get("repo_url") - # Cache the output - output_path = Path(result.get("output", {}).get("local_path", "")) - if output_path.exists(): + # Extract infrastructure info (legacy only) + run.infrastructure = result.get("infrastructure") + + # Cache the output (legacy mode - DAG already caches via cache_manager) + if output_path and output_path.exists() and "output_hash" not in result: cache_file(output_path, node_type="effect_output") - # Record activity for deletion tracking - if run.output_hash and run.inputs: - cache_manager.record_simple_activity( - input_hashes=run.inputs, - output_hash=run.output_hash, - run_id=run.run_id, - ) + # Record activity for deletion tracking (legacy mode) + if run.output_hash and run.inputs: + cache_manager.record_simple_activity( + input_hashes=run.inputs, + output_hash=run.output_hash, + run_id=run.run_id, + ) else: run.status = "failed" run.error = str(task.result) diff --git a/tasks.py b/tasks.py index 5f0df0c..62ebb76 100644 --- a/tasks.py +++ b/tasks.py @@ -2,23 +2,33 @@ Art DAG Celery Tasks Distributed rendering tasks for the Art DAG system. +Supports both single-effect runs and multi-step DAG execution. """ import hashlib import json +import logging import os import subprocess import sys from datetime import datetime, timezone from pathlib import Path +from typing import Dict, List, Optional from celery import Task from celery_app import app +# Import artdag components +from artdag import DAG, Node, NodeType +from artdag.engine import Engine +from artdag.executor import register_executor, Executor, get_executor + # Add effects to path (use env var in Docker, fallback to home dir locally) EFFECTS_PATH = Path(os.environ.get("EFFECTS_PATH", str(Path.home() / "artdag-effects"))) ARTDAG_PATH = Path(os.environ.get("ARTDAG_PATH", str(Path.home() / "art" / "artdag"))) +logger = logging.getLogger(__name__) + def get_effects_commit() -> str: """Get current git commit hash of effects repo.""" @@ -65,6 +75,60 @@ def file_hash(path: Path) -> str: return hasher.hexdigest() +# Cache directory (shared between server and worker) +CACHE_DIR = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache"))) + + +# ============ Executors for Effects ============ + +@register_executor("effect:dog") +class DogExecutor(Executor): + """Executor for the dog effect.""" + + def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path: + from effect import effect_dog + if len(inputs) != 1: + raise ValueError(f"Dog effect expects 1 input, got {len(inputs)}") + return effect_dog(inputs[0], output_path, config) + + +@register_executor("effect:identity") +class IdentityExecutor(Executor): + """Executor for the identity effect (passthrough).""" + + def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path: + from artdag.nodes.effect import effect_identity + if len(inputs) != 1: + raise ValueError(f"Identity effect expects 1 input, got {len(inputs)}") + return effect_identity(inputs[0], output_path, config) + + +@register_executor(NodeType.SOURCE) +class SourceExecutor(Executor): + """Executor for SOURCE nodes - loads content from cache by hash.""" + + def execute(self, config: Dict, inputs: List[Path], output_path: Path) -> Path: + # Source nodes load from cache by content_hash + content_hash = config.get("content_hash") + if not content_hash: + raise ValueError("SOURCE node requires content_hash in config") + + # Look up in cache + source_path = CACHE_DIR / content_hash + if not source_path.exists(): + # Try nodes directory + from cache_manager import get_cache_manager + cache_manager = get_cache_manager() + source_path = cache_manager.get_by_content_hash(content_hash) + + if not source_path or not source_path.exists(): + raise ValueError(f"Source content not in cache: {content_hash}") + + # For source nodes, we just return the path (no transformation) + # The engine will use this as input to subsequent nodes + return source_path + + class RenderTask(Task): """Base task with provenance tracking.""" @@ -197,3 +261,130 @@ def render_dog_from_cat() -> dict: """Convenience task: render cat through dog effect.""" CAT_HASH = "33268b6e167deaf018cc538de12dbe562612b33e89a749391cef855b320a269b" return render_effect.delay(CAT_HASH, "dog", "dog-from-cat-celery").get() + + +@app.task(base=RenderTask, bind=True) +def execute_dag(self, dag_json: str, run_id: str = None) -> dict: + """ + Execute a multi-step DAG. + + Args: + dag_json: Serialized DAG as JSON string + run_id: Optional run ID for tracking + + Returns: + Execution result with output hash and node results + """ + from cache_manager import get_cache_manager + + # Parse DAG + try: + dag = DAG.from_json(dag_json) + except Exception as e: + raise ValueError(f"Invalid DAG JSON: {e}") + + # Validate DAG + errors = dag.validate() + if errors: + raise ValueError(f"Invalid DAG: {errors}") + + # Create engine with cache directory + engine = Engine(CACHE_DIR / "nodes") + + # Set up progress callback + def progress_callback(progress): + self.update_state( + state='EXECUTING', + meta={ + 'node_id': progress.node_id, + 'node_type': progress.node_type, + 'status': progress.status, + 'progress': progress.progress, + 'message': progress.message, + } + ) + logger.info(f"DAG progress: {progress.node_id} - {progress.status} - {progress.message}") + + engine.set_progress_callback(progress_callback) + + # Execute DAG + self.update_state(state='EXECUTING', meta={'status': 'starting', 'nodes': len(dag.nodes)}) + result = engine.execute(dag) + + if not result.success: + raise RuntimeError(f"DAG execution failed: {result.error}") + + # Get output hash + cache_manager = get_cache_manager() + output_hash = None + if result.output_path and result.output_path.exists(): + output_hash = file_hash(result.output_path) + + # Store in cache_manager for proper tracking + cached = cache_manager.put(result.output_path, node_type="dag_output") + + # Record activity for deletion tracking + input_hashes = [] + for node_id, node in dag.nodes.items(): + if node.node_type == NodeType.SOURCE or str(node.node_type) == "SOURCE": + content_hash = node.config.get("content_hash") + if content_hash: + input_hashes.append(content_hash) + + if input_hashes: + cache_manager.record_simple_activity( + input_hashes=input_hashes, + output_hash=output_hash, + run_id=run_id, + ) + + # Build result + return { + "success": True, + "run_id": run_id, + "output_hash": output_hash, + "output_path": str(result.output_path) if result.output_path else None, + "execution_time": result.execution_time, + "nodes_executed": result.nodes_executed, + "nodes_cached": result.nodes_cached, + "node_results": { + node_id: str(path) for node_id, path in result.node_results.items() + }, + } + + +def build_effect_dag(input_hashes: List[str], effect_name: str) -> DAG: + """ + Build a simple DAG for applying an effect to inputs. + + Args: + input_hashes: List of input content hashes + effect_name: Name of effect to apply (e.g., "dog", "identity") + + Returns: + DAG ready for execution + """ + dag = DAG() + + # Add source nodes for each input + source_ids = [] + for i, content_hash in enumerate(input_hashes): + source_node = Node( + node_type=NodeType.SOURCE, + config={"content_hash": content_hash}, + name=f"source_{i}", + ) + dag.add_node(source_node) + source_ids.append(source_node.node_id) + + # Add effect node + effect_node = Node( + node_type=f"effect:{effect_name}", + config={}, + inputs=source_ids, + name=f"effect_{effect_name}", + ) + dag.add_node(effect_node) + dag.set_output(effect_node.node_id) + + return dag