Squashed 'core/' content from commit 4957443
git-subtree-dir: core git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07
This commit is contained in:
594
artdag/planning/schema.py
Normal file
594
artdag/planning/schema.py
Normal file
@@ -0,0 +1,594 @@
|
||||
# artdag/planning/schema.py
|
||||
"""
|
||||
Data structures for execution plans.
|
||||
|
||||
An ExecutionPlan contains all steps needed to execute a recipe,
|
||||
with pre-computed cache IDs for each step.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Cluster key for trust domains
|
||||
# Systems with the same key produce the same cache_ids and can share work
|
||||
# Systems with different keys have isolated cache namespaces
|
||||
CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY")
|
||||
|
||||
|
||||
def set_cluster_key(key: Optional[str]) -> None:
|
||||
"""Set the cluster key programmatically."""
|
||||
global CLUSTER_KEY
|
||||
CLUSTER_KEY = key
|
||||
|
||||
|
||||
def get_cluster_key() -> Optional[str]:
|
||||
"""Get the current cluster key."""
|
||||
return CLUSTER_KEY
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""
|
||||
Create stable hash from arbitrary data.
|
||||
|
||||
If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create
|
||||
isolated trust domains. Systems with the same key can share work;
|
||||
systems with different keys have separate cache namespaces.
|
||||
"""
|
||||
# Mix in cluster key if set
|
||||
if CLUSTER_KEY:
|
||||
data = {"_cluster_key": CLUSTER_KEY, "_data": data}
|
||||
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class StepStatus(Enum):
|
||||
"""Status of an execution step."""
|
||||
PENDING = "pending"
|
||||
CLAIMED = "claimed"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
CACHED = "cached"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepOutput:
|
||||
"""
|
||||
A single output from an execution step.
|
||||
|
||||
Nodes may produce multiple outputs (e.g., split_on_beats produces N segments).
|
||||
Each output has a human-readable name and a cache_id for storage.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name (e.g., "beats.split.segment[0]")
|
||||
cache_id: Content-addressed hash for caching
|
||||
media_type: MIME type of the output (e.g., "video/mp4", "audio/wav")
|
||||
index: Output index for multi-output nodes
|
||||
metadata: Optional additional metadata (time_range, etc.)
|
||||
"""
|
||||
name: str
|
||||
cache_id: str
|
||||
media_type: str = "application/octet-stream"
|
||||
index: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"cache_id": self.cache_id,
|
||||
"media_type": self.media_type,
|
||||
"index": self.index,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "StepOutput":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
cache_id=data["cache_id"],
|
||||
media_type=data.get("media_type", "application/octet-stream"),
|
||||
index=data.get("index", 0),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepInput:
|
||||
"""
|
||||
Reference to an input for a step.
|
||||
|
||||
Inputs can reference outputs from other steps by name.
|
||||
|
||||
Attributes:
|
||||
name: Input slot name (e.g., "video", "audio", "segments")
|
||||
source: Source output name (e.g., "beats.split.segment[0]")
|
||||
cache_id: Resolved cache_id of the source (populated during planning)
|
||||
"""
|
||||
name: str
|
||||
source: str
|
||||
cache_id: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"source": self.source,
|
||||
"cache_id": self.cache_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "StepInput":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
source=data["source"],
|
||||
cache_id=data.get("cache_id"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionStep:
|
||||
"""
|
||||
A single step in the execution plan.
|
||||
|
||||
Each step has a pre-computed cache_id that uniquely identifies
|
||||
its output based on its configuration and input cache_ids.
|
||||
|
||||
Steps can produce multiple outputs (e.g., split_on_beats produces N segments).
|
||||
Each output has its own cache_id derived from the step's cache_id + index.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name relating to recipe (e.g., "beats.split")
|
||||
step_id: Unique identifier (hash) for this step
|
||||
node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.)
|
||||
config: Configuration for the primitive
|
||||
input_steps: IDs of steps this depends on (legacy, use inputs for new code)
|
||||
inputs: Structured input references with names and sources
|
||||
cache_id: Pre-computed cache ID (hash of config + input cache_ids)
|
||||
outputs: List of outputs this step produces
|
||||
estimated_duration: Optional estimated execution time
|
||||
level: Dependency level (0 = no dependencies, higher = more deps)
|
||||
"""
|
||||
step_id: str
|
||||
node_type: str
|
||||
config: Dict[str, Any]
|
||||
input_steps: List[str] = field(default_factory=list)
|
||||
inputs: List[StepInput] = field(default_factory=list)
|
||||
cache_id: Optional[str] = None
|
||||
outputs: List[StepOutput] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
estimated_duration: Optional[float] = None
|
||||
level: int = 0
|
||||
|
||||
def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str:
|
||||
"""
|
||||
Compute cache ID from configuration and input cache IDs.
|
||||
|
||||
cache_id = SHA3-256(node_type + config + sorted(input_cache_ids))
|
||||
|
||||
Args:
|
||||
input_cache_ids: Mapping from input step_id/name to their cache_id
|
||||
|
||||
Returns:
|
||||
The computed cache_id
|
||||
"""
|
||||
# Use structured inputs if available, otherwise fall back to input_steps
|
||||
if self.inputs:
|
||||
resolved_inputs = [
|
||||
inp.cache_id or input_cache_ids.get(inp.source, inp.source)
|
||||
for inp in sorted(self.inputs, key=lambda x: x.name)
|
||||
]
|
||||
else:
|
||||
resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)]
|
||||
|
||||
content = {
|
||||
"node_type": self.node_type,
|
||||
"config": self.config,
|
||||
"inputs": resolved_inputs,
|
||||
}
|
||||
self.cache_id = _stable_hash(content)
|
||||
return self.cache_id
|
||||
|
||||
def compute_output_cache_id(self, index: int) -> str:
|
||||
"""
|
||||
Compute cache ID for a specific output index.
|
||||
|
||||
output_cache_id = SHA3-256(step_cache_id + index)
|
||||
|
||||
Args:
|
||||
index: The output index
|
||||
|
||||
Returns:
|
||||
Cache ID for that output
|
||||
"""
|
||||
if not self.cache_id:
|
||||
raise ValueError("Step cache_id must be computed first")
|
||||
content = {"step_cache_id": self.cache_id, "output_index": index}
|
||||
return _stable_hash(content)
|
||||
|
||||
def add_output(
|
||||
self,
|
||||
name: str,
|
||||
media_type: str = "application/octet-stream",
|
||||
index: Optional[int] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> StepOutput:
|
||||
"""
|
||||
Add an output to this step.
|
||||
|
||||
Args:
|
||||
name: Human-readable output name
|
||||
media_type: MIME type of the output
|
||||
index: Output index (defaults to next available)
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
The created StepOutput
|
||||
"""
|
||||
if index is None:
|
||||
index = len(self.outputs)
|
||||
|
||||
cache_id = self.compute_output_cache_id(index)
|
||||
output = StepOutput(
|
||||
name=name,
|
||||
cache_id=cache_id,
|
||||
media_type=media_type,
|
||||
index=index,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.outputs.append(output)
|
||||
return output
|
||||
|
||||
def get_output(self, index: int = 0) -> Optional[StepOutput]:
|
||||
"""Get output by index."""
|
||||
if index < len(self.outputs):
|
||||
return self.outputs[index]
|
||||
return None
|
||||
|
||||
def get_output_by_name(self, name: str) -> Optional[StepOutput]:
|
||||
"""Get output by name."""
|
||||
for output in self.outputs:
|
||||
if output.name == name:
|
||||
return output
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step_id": self.step_id,
|
||||
"name": self.name,
|
||||
"node_type": self.node_type,
|
||||
"config": self.config,
|
||||
"input_steps": self.input_steps,
|
||||
"inputs": [inp.to_dict() for inp in self.inputs],
|
||||
"cache_id": self.cache_id,
|
||||
"outputs": [out.to_dict() for out in self.outputs],
|
||||
"estimated_duration": self.estimated_duration,
|
||||
"level": self.level,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
|
||||
inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])]
|
||||
outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])]
|
||||
return cls(
|
||||
step_id=data["step_id"],
|
||||
node_type=data["node_type"],
|
||||
config=data.get("config", {}),
|
||||
input_steps=data.get("input_steps", []),
|
||||
inputs=inputs,
|
||||
cache_id=data.get("cache_id"),
|
||||
outputs=outputs,
|
||||
name=data.get("name"),
|
||||
estimated_duration=data.get("estimated_duration"),
|
||||
level=data.get("level", 0),
|
||||
)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "ExecutionStep":
|
||||
return cls.from_dict(json.loads(json_str))
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanInput:
|
||||
"""
|
||||
An input to the execution plan.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name from recipe (e.g., "source_video")
|
||||
cache_id: Content hash of the input file
|
||||
cid: Same as cache_id (for clarity)
|
||||
media_type: MIME type of the input
|
||||
"""
|
||||
name: str
|
||||
cache_id: str
|
||||
cid: str
|
||||
media_type: str = "application/octet-stream"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"cache_id": self.cache_id,
|
||||
"cid": self.cid,
|
||||
"media_type": self.media_type,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PlanInput":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
cache_id=data["cache_id"],
|
||||
cid=data.get("cid", data["cache_id"]),
|
||||
media_type=data.get("media_type", "application/octet-stream"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionPlan:
|
||||
"""
|
||||
Complete execution plan for a recipe.
|
||||
|
||||
Contains all steps in topological order with pre-computed cache IDs.
|
||||
The plan is deterministic: same recipe + same inputs = same plan.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable plan name from recipe
|
||||
plan_id: Hash of the entire plan (for deduplication)
|
||||
recipe_id: Source recipe identifier
|
||||
recipe_name: Human-readable recipe name
|
||||
recipe_hash: Hash of the recipe content
|
||||
seed: Random seed used for planning
|
||||
steps: List of steps in execution order
|
||||
output_step: ID of the final output step
|
||||
output_name: Human-readable name of the final output
|
||||
inputs: Structured input definitions
|
||||
analysis_cache_ids: Cache IDs of analysis results used
|
||||
input_hashes: Content hashes of input files (legacy, use inputs)
|
||||
created_at: When the plan was generated
|
||||
metadata: Optional additional metadata
|
||||
"""
|
||||
plan_id: Optional[str]
|
||||
recipe_id: str
|
||||
recipe_hash: str
|
||||
steps: List[ExecutionStep]
|
||||
output_step: str
|
||||
name: Optional[str] = None
|
||||
recipe_name: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
output_name: Optional[str] = None
|
||||
inputs: List[PlanInput] = field(default_factory=list)
|
||||
analysis_cache_ids: Dict[str, str] = field(default_factory=dict)
|
||||
input_hashes: Dict[str, str] = field(default_factory=dict)
|
||||
created_at: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now(timezone.utc).isoformat()
|
||||
if self.plan_id is None:
|
||||
self.plan_id = self._compute_plan_id()
|
||||
|
||||
def _compute_plan_id(self) -> str:
|
||||
"""Compute plan ID from contents."""
|
||||
content = {
|
||||
"recipe_hash": self.recipe_hash,
|
||||
"steps": [s.to_dict() for s in self.steps],
|
||||
"input_hashes": self.input_hashes,
|
||||
"analysis_cache_ids": self.analysis_cache_ids,
|
||||
}
|
||||
return _stable_hash(content)
|
||||
|
||||
def compute_all_cache_ids(self) -> None:
|
||||
"""
|
||||
Compute cache IDs for all steps in dependency order.
|
||||
|
||||
Must be called after all steps are added to ensure
|
||||
cache IDs propagate correctly through dependencies.
|
||||
"""
|
||||
# Build step lookup
|
||||
step_by_id = {s.step_id: s for s in self.steps}
|
||||
|
||||
# Cache IDs start with input hashes
|
||||
cache_ids = dict(self.input_hashes)
|
||||
|
||||
# Process in order (assumes topological order)
|
||||
for step in self.steps:
|
||||
# For SOURCE steps referencing inputs, use input hash
|
||||
if step.node_type == "SOURCE" and step.config.get("input_ref"):
|
||||
ref = step.config["input_ref"]
|
||||
if ref in self.input_hashes:
|
||||
step.cache_id = self.input_hashes[ref]
|
||||
cache_ids[step.step_id] = step.cache_id
|
||||
continue
|
||||
|
||||
# For other steps, compute from inputs
|
||||
input_cache_ids = {}
|
||||
for input_step_id in step.input_steps:
|
||||
if input_step_id in cache_ids:
|
||||
input_cache_ids[input_step_id] = cache_ids[input_step_id]
|
||||
elif input_step_id in step_by_id:
|
||||
# Step should have been processed already
|
||||
input_cache_ids[input_step_id] = step_by_id[input_step_id].cache_id
|
||||
else:
|
||||
raise ValueError(f"Input step {input_step_id} not found for {step.step_id}")
|
||||
|
||||
step.compute_cache_id(input_cache_ids)
|
||||
cache_ids[step.step_id] = step.cache_id
|
||||
|
||||
# Recompute plan_id with final cache IDs
|
||||
self.plan_id = self._compute_plan_id()
|
||||
|
||||
def compute_levels(self) -> int:
|
||||
"""
|
||||
Compute dependency levels for all steps.
|
||||
|
||||
Level 0 = no dependencies (can start immediately)
|
||||
Level N = depends on steps at level N-1
|
||||
|
||||
Returns:
|
||||
Maximum level (number of sequential dependency levels)
|
||||
"""
|
||||
step_by_id = {s.step_id: s for s in self.steps}
|
||||
levels = {}
|
||||
|
||||
def compute_level(step_id: str) -> int:
|
||||
if step_id in levels:
|
||||
return levels[step_id]
|
||||
|
||||
step = step_by_id.get(step_id)
|
||||
if step is None:
|
||||
return 0 # Input from outside the plan
|
||||
|
||||
if not step.input_steps:
|
||||
levels[step_id] = 0
|
||||
step.level = 0
|
||||
return 0
|
||||
|
||||
max_input_level = max(compute_level(s) for s in step.input_steps)
|
||||
level = max_input_level + 1
|
||||
levels[step_id] = level
|
||||
step.level = level
|
||||
return level
|
||||
|
||||
for step in self.steps:
|
||||
compute_level(step.step_id)
|
||||
|
||||
return max(levels.values()) if levels else 0
|
||||
|
||||
def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]:
|
||||
"""
|
||||
Group steps by dependency level.
|
||||
|
||||
Steps at the same level can execute in parallel.
|
||||
|
||||
Returns:
|
||||
Dict mapping level -> list of steps at that level
|
||||
"""
|
||||
by_level: Dict[int, List[ExecutionStep]] = {}
|
||||
for step in self.steps:
|
||||
by_level.setdefault(step.level, []).append(step)
|
||||
return by_level
|
||||
|
||||
def get_step(self, step_id: str) -> Optional[ExecutionStep]:
|
||||
"""Get step by ID."""
|
||||
for step in self.steps:
|
||||
if step.step_id == step_id:
|
||||
return step
|
||||
return None
|
||||
|
||||
def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]:
|
||||
"""Get step by cache ID."""
|
||||
for step in self.steps:
|
||||
if step.cache_id == cache_id:
|
||||
return step
|
||||
return None
|
||||
|
||||
def get_step_by_name(self, name: str) -> Optional[ExecutionStep]:
|
||||
"""Get step by human-readable name."""
|
||||
for step in self.steps:
|
||||
if step.name == name:
|
||||
return step
|
||||
return None
|
||||
|
||||
def get_all_outputs(self) -> Dict[str, StepOutput]:
|
||||
"""
|
||||
Get all outputs from all steps, keyed by output name.
|
||||
|
||||
Returns:
|
||||
Dict mapping output name -> StepOutput
|
||||
"""
|
||||
outputs = {}
|
||||
for step in self.steps:
|
||||
for output in step.outputs:
|
||||
outputs[output.name] = output
|
||||
return outputs
|
||||
|
||||
def get_output_cache_ids(self) -> Dict[str, str]:
|
||||
"""
|
||||
Get mapping of output names to cache IDs.
|
||||
|
||||
Returns:
|
||||
Dict mapping output name -> cache_id
|
||||
"""
|
||||
return {
|
||||
output.name: output.cache_id
|
||||
for step in self.steps
|
||||
for output in step.outputs
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"plan_id": self.plan_id,
|
||||
"name": self.name,
|
||||
"recipe_id": self.recipe_id,
|
||||
"recipe_name": self.recipe_name,
|
||||
"recipe_hash": self.recipe_hash,
|
||||
"seed": self.seed,
|
||||
"inputs": [i.to_dict() for i in self.inputs],
|
||||
"steps": [s.to_dict() for s in self.steps],
|
||||
"output_step": self.output_step,
|
||||
"output_name": self.output_name,
|
||||
"analysis_cache_ids": self.analysis_cache_ids,
|
||||
"input_hashes": self.input_hashes,
|
||||
"created_at": self.created_at,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan":
|
||||
inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])]
|
||||
return cls(
|
||||
plan_id=data.get("plan_id"),
|
||||
name=data.get("name"),
|
||||
recipe_id=data["recipe_id"],
|
||||
recipe_name=data.get("recipe_name"),
|
||||
recipe_hash=data["recipe_hash"],
|
||||
seed=data.get("seed"),
|
||||
inputs=inputs,
|
||||
steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])],
|
||||
output_step=data["output_step"],
|
||||
output_name=data.get("output_name"),
|
||||
analysis_cache_ids=data.get("analysis_cache_ids", {}),
|
||||
input_hashes=data.get("input_hashes", {}),
|
||||
created_at=data.get("created_at"),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "ExecutionPlan":
|
||||
return cls.from_dict(json.loads(json_str))
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Get a human-readable summary of the plan."""
|
||||
by_level = self.get_steps_by_level()
|
||||
max_level = max(by_level.keys()) if by_level else 0
|
||||
|
||||
lines = [
|
||||
f"Execution Plan: {self.plan_id[:16]}...",
|
||||
f"Recipe: {self.recipe_id}",
|
||||
f"Steps: {len(self.steps)}",
|
||||
f"Levels: {max_level + 1}",
|
||||
"",
|
||||
]
|
||||
|
||||
for level in sorted(by_level.keys()):
|
||||
steps = by_level[level]
|
||||
lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)")
|
||||
for step in steps:
|
||||
cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]"
|
||||
lines.append(f" - {step.step_id}: {step.node_type} {cache_status}")
|
||||
|
||||
return "\n".join(lines)
|
||||
Reference in New Issue
Block a user