# artdag/planning/schema.py
"""
Data structures for execution plans.

An ExecutionPlan contains all steps needed to execute a recipe, with
pre-computed cache IDs for each step.
"""

import hashlib
import json
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

# Cluster key for trust domains
# Systems with the same key produce the same cache_ids and can share work
# Systems with different keys have isolated cache namespaces
CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY")


def set_cluster_key(key: Optional[str]) -> None:
    """Set the cluster key programmatically."""
    global CLUSTER_KEY
    CLUSTER_KEY = key


def get_cluster_key() -> Optional[str]:
    """Get the current cluster key."""
    return CLUSTER_KEY


def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
    """
    Create stable hash from arbitrary (JSON-serializable) data.

    If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create
    isolated trust domains. Systems with the same key can share work;
    systems with different keys have separate cache namespaces.

    Args:
        data: JSON-serializable payload to hash
        algorithm: Any algorithm name accepted by hashlib.new()

    Returns:
        Hex digest of the canonical JSON encoding of data
    """
    # Mix in cluster key if set
    if CLUSTER_KEY:
        data = {"_cluster_key": CLUSTER_KEY, "_data": data}
    # sort_keys + compact separators give a canonical, whitespace-free encoding
    # so equal payloads always hash identically
    json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
    hasher = hashlib.new(algorithm)
    hasher.update(json_str.encode())
    return hasher.hexdigest()


class StepStatus(Enum):
    """Status of an execution step."""

    PENDING = "pending"
    CLAIMED = "claimed"
    RUNNING = "running"
    COMPLETED = "completed"
    CACHED = "cached"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class StepOutput:
    """
    A single output from an execution step.

    Nodes may produce multiple outputs (e.g., split_on_beats produces N
    segments). Each output has a human-readable name and a cache_id for
    storage.

    Attributes:
        name: Human-readable name (e.g., "beats.split.segment[0]")
        cache_id: Content-addressed hash for caching
        media_type: MIME type of the output (e.g., "video/mp4", "audio/wav")
        index: Output index for multi-output nodes
        metadata: Optional additional metadata (time_range, etc.)
    """

    name: str
    cache_id: str
    media_type: str = "application/octet-stream"
    index: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-compatible)."""
        return {
            "name": self.name,
            "cache_id": self.cache_id,
            "media_type": self.media_type,
            "index": self.index,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepOutput":
        """Deserialize from a dict produced by to_dict()."""
        return cls(
            name=data["name"],
            cache_id=data["cache_id"],
            media_type=data.get("media_type", "application/octet-stream"),
            index=data.get("index", 0),
            metadata=data.get("metadata", {}),
        )


@dataclass
class StepInput:
    """
    Reference to an input for a step.

    Inputs can reference outputs from other steps by name.

    Attributes:
        name: Input slot name (e.g., "video", "audio", "segments")
        source: Source output name (e.g., "beats.split.segment[0]")
        cache_id: Resolved cache_id of the source (populated during planning)
    """

    name: str
    source: str
    cache_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-compatible)."""
        return {
            "name": self.name,
            "source": self.source,
            "cache_id": self.cache_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepInput":
        """Deserialize from a dict produced by to_dict()."""
        return cls(
            name=data["name"],
            source=data["source"],
            cache_id=data.get("cache_id"),
        )


@dataclass
class ExecutionStep:
    """
    A single step in the execution plan.

    Each step has a pre-computed cache_id that uniquely identifies its output
    based on its configuration and input cache_ids.

    Steps can produce multiple outputs (e.g., split_on_beats produces N
    segments). Each output has its own cache_id derived from the step's
    cache_id + index.

    Attributes:
        name: Human-readable name relating to recipe (e.g., "beats.split")
        step_id: Unique identifier (hash) for this step
        node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.)
        config: Configuration for the primitive
        input_steps: IDs of steps this depends on (legacy, use inputs for new code)
        inputs: Structured input references with names and sources
        cache_id: Pre-computed cache ID (hash of config + input cache_ids)
        outputs: List of outputs this step produces
        estimated_duration: Optional estimated execution time
        level: Dependency level (0 = no dependencies, higher = more deps)
    """

    step_id: str
    node_type: str
    config: Dict[str, Any]
    input_steps: List[str] = field(default_factory=list)
    inputs: List[StepInput] = field(default_factory=list)
    cache_id: Optional[str] = None
    outputs: List[StepOutput] = field(default_factory=list)
    name: Optional[str] = None
    estimated_duration: Optional[float] = None
    level: int = 0

    def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str:
        """
        Compute cache ID from configuration and input cache IDs.

        cache_id = SHA3-256(node_type + config + sorted(input_cache_ids))

        Args:
            input_cache_ids: Mapping from input step_id/name to their cache_id

        Returns:
            The computed cache_id (also stored on self.cache_id)
        """
        # Use structured inputs if available, otherwise fall back to input_steps
        if self.inputs:
            # Sort by slot name so input ordering never changes the hash;
            # unresolved sources fall back to the source name itself
            resolved_inputs = [
                inp.cache_id or input_cache_ids.get(inp.source, inp.source)
                for inp in sorted(self.inputs, key=lambda x: x.name)
            ]
        else:
            resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)]

        content = {
            "node_type": self.node_type,
            "config": self.config,
            "inputs": resolved_inputs,
        }
        self.cache_id = _stable_hash(content)
        return self.cache_id

    def compute_output_cache_id(self, index: int) -> str:
        """
        Compute cache ID for a specific output index.

        output_cache_id = SHA3-256(step_cache_id + index)

        Args:
            index: The output index

        Returns:
            Cache ID for that output

        Raises:
            ValueError: If the step's own cache_id has not been computed yet
        """
        if not self.cache_id:
            raise ValueError("Step cache_id must be computed first")
        content = {"step_cache_id": self.cache_id, "output_index": index}
        return _stable_hash(content)

    def add_output(
        self,
        name: str,
        media_type: str = "application/octet-stream",
        index: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> StepOutput:
        """
        Add an output to this step.

        Args:
            name: Human-readable output name
            media_type: MIME type of the output
            index: Output index (defaults to next available)
            metadata: Optional metadata

        Returns:
            The created StepOutput

        Raises:
            ValueError: If the step's cache_id has not been computed yet
        """
        if index is None:
            index = len(self.outputs)
        cache_id = self.compute_output_cache_id(index)
        output = StepOutput(
            name=name,
            cache_id=cache_id,
            media_type=media_type,
            index=index,
            metadata=metadata or {},
        )
        self.outputs.append(output)
        return output

    def get_output(self, index: int = 0) -> Optional[StepOutput]:
        """Get output by index, or None if out of range."""
        # Guard against negative indices: Python's negative indexing would
        # otherwise silently return the wrong output (e.g. -1 -> last)
        if 0 <= index < len(self.outputs):
            return self.outputs[index]
        return None

    def get_output_by_name(self, name: str) -> Optional[StepOutput]:
        """Get output by name, or None if no output has that name."""
        for output in self.outputs:
            if output.name == name:
                return output
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-compatible)."""
        return {
            "step_id": self.step_id,
            "name": self.name,
            "node_type": self.node_type,
            "config": self.config,
            "input_steps": self.input_steps,
            "inputs": [inp.to_dict() for inp in self.inputs],
            "cache_id": self.cache_id,
            "outputs": [out.to_dict() for out in self.outputs],
            "estimated_duration": self.estimated_duration,
            "level": self.level,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
        """Deserialize from a dict produced by to_dict()."""
        inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])]
        outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])]
        return cls(
            step_id=data["step_id"],
            node_type=data["node_type"],
            config=data.get("config", {}),
            input_steps=data.get("input_steps", []),
            inputs=inputs,
            cache_id=data.get("cache_id"),
            outputs=outputs,
            name=data.get("name"),
            estimated_duration=data.get("estimated_duration"),
            level=data.get("level", 0),
        )

    def to_json(self) -> str:
        """Serialize to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "ExecutionStep":
        """Deserialize from a JSON string produced by to_json()."""
        return cls.from_dict(json.loads(json_str))


@dataclass
class PlanInput:
    """
    An input to the execution plan.

    Attributes:
        name: Human-readable name from recipe (e.g., "source_video")
        cache_id: Content hash of the input file
        cid: Same as cache_id (for clarity)
        media_type: MIME type of the input
    """

    name: str
    cache_id: str
    cid: str
    media_type: str = "application/octet-stream"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-compatible)."""
        return {
            "name": self.name,
            "cache_id": self.cache_id,
            "cid": self.cid,
            "media_type": self.media_type,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PlanInput":
        """Deserialize from a dict produced by to_dict(); cid defaults to cache_id."""
        return cls(
            name=data["name"],
            cache_id=data["cache_id"],
            cid=data.get("cid", data["cache_id"]),
            media_type=data.get("media_type", "application/octet-stream"),
        )


@dataclass
class ExecutionPlan:
    """
    Complete execution plan for a recipe.

    Contains all steps in topological order with pre-computed cache IDs.
    The plan is deterministic: same recipe + same inputs = same plan.

    Attributes:
        name: Human-readable plan name from recipe
        plan_id: Hash of the entire plan (for deduplication)
        recipe_id: Source recipe identifier
        recipe_name: Human-readable recipe name
        recipe_hash: Hash of the recipe content
        seed: Random seed used for planning
        steps: List of steps in execution order
        output_step: ID of the final output step
        output_name: Human-readable name of the final output
        inputs: Structured input definitions
        analysis_cache_ids: Cache IDs of analysis results used
        input_hashes: Content hashes of input files (legacy, use inputs)
        created_at: When the plan was generated
        metadata: Optional additional metadata
    """

    plan_id: Optional[str]
    recipe_id: str
    recipe_hash: str
    steps: List[ExecutionStep]
    output_step: str
    name: Optional[str] = None
    recipe_name: Optional[str] = None
    seed: Optional[int] = None
    output_name: Optional[str] = None
    inputs: List[PlanInput] = field(default_factory=list)
    analysis_cache_ids: Dict[str, str] = field(default_factory=dict)
    input_hashes: Dict[str, str] = field(default_factory=dict)
    created_at: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Fill in defaults that depend on runtime state (clock, step contents)
        if self.created_at is None:
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.plan_id is None:
            self.plan_id = self._compute_plan_id()

    def _compute_plan_id(self) -> str:
        """Compute plan ID from contents."""
        content = {
            "recipe_hash": self.recipe_hash,
            "steps": [s.to_dict() for s in self.steps],
            "input_hashes": self.input_hashes,
            "analysis_cache_ids": self.analysis_cache_ids,
        }
        return _stable_hash(content)

    def compute_all_cache_ids(self) -> None:
        """
        Compute cache IDs for all steps in dependency order.

        Must be called after all steps are added to ensure cache IDs
        propagate correctly through dependencies. Steps are assumed to be
        in topological order.

        Raises:
            ValueError: If a step references an unknown input step, or a
                dependency has not been processed yet (non-topological order)
        """
        # Build step lookup
        step_by_id = {s.step_id: s for s in self.steps}

        # Cache IDs start with input hashes
        cache_ids = dict(self.input_hashes)

        # Process in order (assumes topological order)
        for step in self.steps:
            # For SOURCE steps referencing inputs, use input hash
            if step.node_type == "SOURCE" and step.config.get("input_ref"):
                ref = step.config["input_ref"]
                if ref in self.input_hashes:
                    step.cache_id = self.input_hashes[ref]
                    cache_ids[step.step_id] = step.cache_id
                    continue

            # For other steps, compute from inputs
            input_cache_ids = {}
            for input_step_id in step.input_steps:
                if input_step_id in cache_ids:
                    input_cache_ids[input_step_id] = cache_ids[input_step_id]
                elif input_step_id in step_by_id:
                    # Step should have been processed already; a None here
                    # means the plan is not topologically ordered -- fail
                    # loudly instead of hashing None into the cache_id
                    dep_cache_id = step_by_id[input_step_id].cache_id
                    if dep_cache_id is None:
                        raise ValueError(
                            f"Input step {input_step_id} has no cache_id yet; "
                            f"steps must be in topological order"
                        )
                    input_cache_ids[input_step_id] = dep_cache_id
                else:
                    raise ValueError(f"Input step {input_step_id} not found for {step.step_id}")

            step.compute_cache_id(input_cache_ids)
            cache_ids[step.step_id] = step.cache_id

        # Recompute plan_id with final cache IDs
        self.plan_id = self._compute_plan_id()

    def compute_levels(self) -> int:
        """
        Compute dependency levels for all steps.

        Level 0 = no dependencies (can start immediately)
        Level N = depends on steps at level N-1

        Returns:
            Maximum level (number of sequential dependency levels)

        Raises:
            ValueError: If the dependency graph contains a cycle
        """
        step_by_id = {s.step_id: s for s in self.steps}
        levels: Dict[str, int] = {}
        visiting: set = set()  # steps currently on the recursion stack

        def compute_level(step_id: str) -> int:
            if step_id in levels:
                return levels[step_id]
            # A step already on the recursion stack means a cycle; raise a
            # clear error instead of overflowing the stack
            if step_id in visiting:
                raise ValueError(f"Dependency cycle detected at step {step_id}")
            step = step_by_id.get(step_id)
            if step is None:
                return 0  # Input from outside the plan
            if not step.input_steps:
                levels[step_id] = 0
                step.level = 0
                return 0
            visiting.add(step_id)
            try:
                max_input_level = max(compute_level(s) for s in step.input_steps)
            finally:
                visiting.discard(step_id)
            level = max_input_level + 1
            levels[step_id] = level
            step.level = level
            return level

        for step in self.steps:
            compute_level(step.step_id)

        return max(levels.values()) if levels else 0

    def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]:
        """
        Group steps by dependency level.

        Steps at the same level can execute in parallel.

        Returns:
            Dict mapping level -> list of steps at that level
        """
        by_level: Dict[int, List[ExecutionStep]] = {}
        for step in self.steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def get_step(self, step_id: str) -> Optional[ExecutionStep]:
        """Get step by ID."""
        for step in self.steps:
            if step.step_id == step_id:
                return step
        return None

    def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]:
        """Get step by cache ID."""
        for step in self.steps:
            if step.cache_id == cache_id:
                return step
        return None

    def get_step_by_name(self, name: str) -> Optional[ExecutionStep]:
        """Get step by human-readable name."""
        for step in self.steps:
            if step.name == name:
                return step
        return None

    def get_all_outputs(self) -> Dict[str, StepOutput]:
        """
        Get all outputs from all steps, keyed by output name.

        Returns:
            Dict mapping output name -> StepOutput
        """
        outputs = {}
        for step in self.steps:
            for output in step.outputs:
                outputs[output.name] = output
        return outputs

    def get_output_cache_ids(self) -> Dict[str, str]:
        """
        Get mapping of output names to cache IDs.

        Returns:
            Dict mapping output name -> cache_id
        """
        return {
            output.name: output.cache_id
            for step in self.steps
            for output in step.outputs
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (JSON-compatible)."""
        return {
            "plan_id": self.plan_id,
            "name": self.name,
            "recipe_id": self.recipe_id,
            "recipe_name": self.recipe_name,
            "recipe_hash": self.recipe_hash,
            "seed": self.seed,
            "inputs": [i.to_dict() for i in self.inputs],
            "steps": [s.to_dict() for s in self.steps],
            "output_step": self.output_step,
            "output_name": self.output_name,
            "analysis_cache_ids": self.analysis_cache_ids,
            "input_hashes": self.input_hashes,
            "created_at": self.created_at,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan":
        """Deserialize from a dict produced by to_dict()."""
        inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])]
        return cls(
            plan_id=data.get("plan_id"),
            name=data.get("name"),
            recipe_id=data["recipe_id"],
            recipe_name=data.get("recipe_name"),
            recipe_hash=data["recipe_hash"],
            seed=data.get("seed"),
            inputs=inputs,
            steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])],
            output_step=data["output_step"],
            output_name=data.get("output_name"),
            analysis_cache_ids=data.get("analysis_cache_ids", {}),
            input_hashes=data.get("input_hashes", {}),
            created_at=data.get("created_at"),
            metadata=data.get("metadata", {}),
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, json_str: str) -> "ExecutionPlan":
        """Deserialize from a JSON string produced by to_json()."""
        return cls.from_dict(json.loads(json_str))

    def summary(self) -> str:
        """Get a human-readable summary of the plan."""
        by_level = self.get_steps_by_level()
        max_level = max(by_level.keys()) if by_level else 0
        lines = [
            f"Execution Plan: {self.plan_id[:16]}...",
            f"Recipe: {self.recipe_id}",
            f"Steps: {len(self.steps)}",
            f"Levels: {max_level + 1}",
            "",
        ]
        for level in sorted(by_level.keys()):
            steps = by_level[level]
            lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)")
            for step in steps:
                cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]"
                lines.append(f"  - {step.step_id}: {step.node_type} {cache_status}")
        return "\n".join(lines)