Squashed 'core/' content from commit 4957443
git-subtree-dir: core git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07
This commit is contained in:
29
artdag/planning/__init__.py
Normal file
29
artdag/planning/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# artdag/planning - Execution plan generation
|
||||
#
|
||||
# Provides the Planning phase of the 3-phase execution model:
|
||||
# 1. ANALYZE - Extract features from inputs
|
||||
# 2. PLAN - Generate execution plan with cache IDs
|
||||
# 3. EXECUTE - Run steps with caching
|
||||
|
||||
from .schema import (
|
||||
ExecutionStep,
|
||||
ExecutionPlan,
|
||||
StepStatus,
|
||||
StepOutput,
|
||||
StepInput,
|
||||
PlanInput,
|
||||
)
|
||||
from .planner import RecipePlanner, Recipe
|
||||
from .tree_reduction import TreeReducer
|
||||
|
||||
__all__ = [
|
||||
"ExecutionStep",
|
||||
"ExecutionPlan",
|
||||
"StepStatus",
|
||||
"StepOutput",
|
||||
"StepInput",
|
||||
"PlanInput",
|
||||
"RecipePlanner",
|
||||
"Recipe",
|
||||
"TreeReducer",
|
||||
]
|
||||
756
artdag/planning/planner.py
Normal file
756
artdag/planning/planner.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# artdag/planning/planner.py
|
||||
"""
|
||||
Recipe planner - converts recipes into execution plans.
|
||||
|
||||
The planner is the second phase of the 3-phase execution model.
|
||||
It takes a recipe and analysis results and generates a complete
|
||||
execution plan with pre-computed cache IDs.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from .schema import ExecutionPlan, ExecutionStep, StepOutput, StepInput, PlanInput
|
||||
from .tree_reduction import TreeReducer, reduce_sequence
|
||||
from ..analysis import AnalysisResult
|
||||
|
||||
|
||||
def _infer_media_type(node_type: str, config: Dict[str, Any] = None) -> str:
|
||||
"""Infer media type from node type and config."""
|
||||
config = config or {}
|
||||
|
||||
# Audio operations
|
||||
if node_type in ("AUDIO", "MIX_AUDIO", "EXTRACT_AUDIO"):
|
||||
return "audio/wav"
|
||||
if "audio" in node_type.lower():
|
||||
return "audio/wav"
|
||||
|
||||
# Image operations
|
||||
if node_type in ("FRAME", "THUMBNAIL", "IMAGE"):
|
||||
return "image/png"
|
||||
|
||||
# Default to video
|
||||
return "video/mp4"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""Create stable hash from arbitrary data."""
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class RecipeNode:
    """A node in the recipe DAG.

    Nodes are declared under ``dag.nodes`` in the recipe YAML and are
    wired together by listing upstream node ids in ``inputs``.
    """
    id: str  # unique node identifier within the recipe
    type: str  # primitive type, e.g. "SOURCE", "SEQUENCE", "MAP"
    config: Dict[str, Any]  # free-form per-type configuration
    inputs: List[str]  # ids of upstream nodes (or external input names)
|
||||
|
||||
|
||||
@dataclass
class Recipe:
    """Parsed recipe structure.

    Attributes:
        name: Recipe name (defaults to "unnamed").
        version: Recipe version string.
        description: Human-readable description.
        nodes: DAG nodes in declaration order.
        output: Id of the node whose result is the recipe output.
        registry: Fixed assets/effects resolvable to content hashes.
        owner: Owner identifier.
        raw_yaml: Original YAML text (basis for recipe_hash).
    """
    name: str
    version: str
    description: str
    nodes: List[RecipeNode]
    output: str
    registry: Dict[str, Any]
    owner: str
    raw_yaml: str

    @property
    def recipe_hash(self) -> str:
        """Content hash of the raw YAML text (identifies this exact recipe)."""
        return _stable_hash({"yaml": self.raw_yaml})

    @staticmethod
    def _normalize_inputs(inputs: Any) -> List[str]:
        """Flatten the various YAML 'inputs' forms into a list of refs.

        Accepts a list (used as-is), a single string (wrapped), or a
        dict whose values are strings or lists of strings (flattened).
        """
        if isinstance(inputs, dict):
            flat: List[str] = []
            for value in inputs.values():
                if isinstance(value, str):
                    flat.append(value)
                elif isinstance(value, list):
                    flat.extend(value)
            return flat
        if isinstance(inputs, str):
            return [inputs]
        return inputs

    @classmethod
    def from_yaml(cls, yaml_content: str) -> "Recipe":
        """Parse a recipe from a YAML string.

        An empty document parses as an empty recipe rather than
        crashing (safe_load returns None for empty input).

        Raises:
            yaml.YAMLError: If the content is not valid YAML.
            KeyError: If a node is missing a required "id" or "type".
        """
        # safe_load yields None for empty documents; treat as empty recipe.
        data = yaml.safe_load(yaml_content) or {}

        nodes = []
        for node_data in data.get("dag", {}).get("nodes", []):
            nodes.append(RecipeNode(
                id=node_data["id"],
                type=node_data["type"],
                config=node_data.get("config", {}),
                inputs=cls._normalize_inputs(node_data.get("inputs", [])),
            ))

        return cls(
            name=data.get("name", "unnamed"),
            version=data.get("version", "1.0"),
            description=data.get("description", ""),
            nodes=nodes,
            output=data.get("dag", {}).get("output", ""),
            registry=data.get("registry", {}),
            owner=data.get("owner", ""),
            raw_yaml=yaml_content,
        )

    @classmethod
    def from_file(cls, path: Path) -> "Recipe":
        """Load and parse a recipe from a YAML file on disk."""
        with open(path, "r") as f:
            return cls.from_yaml(f.read())
|
||||
|
||||
|
||||
class RecipePlanner:
|
||||
"""
|
||||
Generates execution plans from recipes.
|
||||
|
||||
The planner:
|
||||
1. Parses the recipe
|
||||
2. Resolves fixed inputs from registry
|
||||
3. Maps variable inputs to provided hashes
|
||||
4. Expands MAP/iteration nodes
|
||||
5. Applies tree reduction for SEQUENCE nodes
|
||||
6. Computes cache IDs for all steps
|
||||
"""
|
||||
|
||||
    def __init__(self, use_tree_reduction: bool = True):
        """
        Initialize the planner.

        Args:
            use_tree_reduction: Whether to use tree reduction for SEQUENCE
                nodes with more than two inputs (balanced pairwise
                composition instead of one flat concat step).
        """
        self.use_tree_reduction = use_tree_reduction
|
||||
|
||||
    def plan(
        self,
        recipe: Recipe,
        input_hashes: Dict[str, str],
        analysis: Optional[Dict[str, AnalysisResult]] = None,
        seed: Optional[int] = None,
    ) -> ExecutionPlan:
        """
        Generate an execution plan from a recipe.

        Orchestrates the whole planning pipeline: topologically sorts
        the recipe DAG, resolves registry and user inputs, compiles each
        node into execution steps, then computes cache IDs, dependency
        levels, and per-step outputs.

        Args:
            recipe: The parsed recipe
            input_hashes: Mapping from input name to content hash
            analysis: Analysis results for inputs (keyed by hash)
            seed: Random seed for deterministic planning

        Returns:
            ExecutionPlan with pre-computed cache IDs

        Raises:
            ValueError: If the recipe's declared output node was never
                compiled into a step.
        """
        logger.info(f"Planning recipe: {recipe.name}")

        # Build node lookup
        nodes_by_id = {n.id: n for n in recipe.nodes}

        # Topologically sort nodes so producers are compiled before consumers
        sorted_ids = self._topological_sort(recipe.nodes)

        # Resolve registry references (fixed assets/effects) to content hashes
        registry_hashes = self._resolve_registry(recipe.registry)

        # Build PlanInput objects from input_hashes
        plan_inputs = []
        for name, cid in input_hashes.items():
            # Try to find matching SOURCE node for media type; fall back
            # to an opaque octet-stream when no SOURCE node matches.
            media_type = "application/octet-stream"
            for node in recipe.nodes:
                if node.id == name and node.type == "SOURCE":
                    media_type = _infer_media_type("SOURCE", node.config)
                    break

            plan_inputs.append(PlanInput(
                name=name,
                cache_id=cid,
                cid=cid,
                media_type=media_type,
            ))

        # Generate steps
        steps = []
        step_id_map = {}  # Maps recipe node ID to step ID(s)
        step_name_map = {}  # Maps recipe node ID to human-readable name
        analysis_cache_ids = {}

        for node_id in sorted_ids:
            node = nodes_by_id[node_id]
            logger.debug(f"Processing node: {node.id} ({node.type})")

            new_steps, output_step_id = self._process_node(
                node=node,
                step_id_map=step_id_map,
                step_name_map=step_name_map,
                input_hashes=input_hashes,
                registry_hashes=registry_hashes,
                analysis=analysis or {},
                recipe_name=recipe.name,
            )

            steps.extend(new_steps)
            step_id_map[node_id] = output_step_id
            # Track human-readable name for this node
            if new_steps:
                step_name_map[node_id] = new_steps[-1].name

        # Find output step
        output_step = step_id_map.get(recipe.output)
        if not output_step:
            raise ValueError(f"Output node '{recipe.output}' not found")

        # Determine output name; prefer the step's own first output name
        output_name = f"{recipe.name}.output"
        output_step_obj = next((s for s in steps if s.step_id == output_step), None)
        if output_step_obj and output_step_obj.outputs:
            output_name = output_step_obj.outputs[0].name

        # Build analysis cache IDs (skip entries without a cache_id)
        if analysis:
            analysis_cache_ids = {
                h: a.cache_id for h, a in analysis.items()
                if a.cache_id
            }

        # Create plan
        plan = ExecutionPlan(
            plan_id=None,  # Computed in __post_init__
            name=f"{recipe.name}_plan",
            recipe_id=recipe.name,
            recipe_name=recipe.name,
            recipe_hash=recipe.recipe_hash,
            seed=seed,
            inputs=plan_inputs,
            steps=steps,
            output_step=output_step,
            output_name=output_name,
            analysis_cache_ids=analysis_cache_ids,
            input_hashes=input_hashes,
            metadata={
                "recipe_version": recipe.version,
                "recipe_description": recipe.description,
                "owner": recipe.owner,
            },
        )

        # Compute all cache IDs and then generate outputs
        plan.compute_all_cache_ids()
        plan.compute_levels()

        # Now add outputs to each step (needs cache_id to be computed first)
        self._add_step_outputs(plan, recipe.name)

        # Recompute plan_id after outputs are added
        plan.plan_id = plan._compute_plan_id()

        logger.info(f"Generated plan with {len(steps)} steps")
        return plan
|
||||
|
||||
def _add_step_outputs(self, plan: ExecutionPlan, recipe_name: str) -> None:
|
||||
"""Add output definitions to each step after cache_ids are computed."""
|
||||
for step in plan.steps:
|
||||
if step.outputs:
|
||||
continue # Already has outputs
|
||||
|
||||
# Generate output name from step name
|
||||
base_name = step.name or step.step_id
|
||||
output_name = f"{recipe_name}.{base_name}.out"
|
||||
|
||||
media_type = _infer_media_type(step.node_type, step.config)
|
||||
|
||||
step.add_output(
|
||||
name=output_name,
|
||||
media_type=media_type,
|
||||
index=0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
def plan_from_yaml(
|
||||
self,
|
||||
yaml_content: str,
|
||||
input_hashes: Dict[str, str],
|
||||
analysis: Optional[Dict[str, AnalysisResult]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""
|
||||
Generate plan from YAML string.
|
||||
|
||||
Args:
|
||||
yaml_content: Recipe YAML content
|
||||
input_hashes: Mapping from input name to content hash
|
||||
analysis: Analysis results
|
||||
|
||||
Returns:
|
||||
ExecutionPlan
|
||||
"""
|
||||
recipe = Recipe.from_yaml(yaml_content)
|
||||
return self.plan(recipe, input_hashes, analysis)
|
||||
|
||||
def plan_from_file(
|
||||
self,
|
||||
recipe_path: Path,
|
||||
input_hashes: Dict[str, str],
|
||||
analysis: Optional[Dict[str, AnalysisResult]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""
|
||||
Generate plan from recipe file.
|
||||
|
||||
Args:
|
||||
recipe_path: Path to recipe YAML file
|
||||
input_hashes: Mapping from input name to content hash
|
||||
analysis: Analysis results
|
||||
|
||||
Returns:
|
||||
ExecutionPlan
|
||||
"""
|
||||
recipe = Recipe.from_file(recipe_path)
|
||||
return self.plan(recipe, input_hashes, analysis)
|
||||
|
||||
def _topological_sort(self, nodes: List[RecipeNode]) -> List[str]:
|
||||
"""Topologically sort recipe nodes."""
|
||||
nodes_by_id = {n.id: n for n in nodes}
|
||||
visited = set()
|
||||
order = []
|
||||
|
||||
def visit(node_id: str):
|
||||
if node_id in visited:
|
||||
return
|
||||
if node_id not in nodes_by_id:
|
||||
return # External input
|
||||
visited.add(node_id)
|
||||
node = nodes_by_id[node_id]
|
||||
for input_id in node.inputs:
|
||||
visit(input_id)
|
||||
order.append(node_id)
|
||||
|
||||
for node in nodes:
|
||||
visit(node.id)
|
||||
|
||||
return order
|
||||
|
||||
def _resolve_registry(self, registry: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""
|
||||
Resolve registry references to content hashes.
|
||||
|
||||
Args:
|
||||
registry: Registry section from recipe
|
||||
|
||||
Returns:
|
||||
Mapping from name to content hash
|
||||
"""
|
||||
hashes = {}
|
||||
|
||||
# Assets
|
||||
for name, asset_data in registry.get("assets", {}).items():
|
||||
if isinstance(asset_data, dict) and "hash" in asset_data:
|
||||
hashes[name] = asset_data["hash"]
|
||||
elif isinstance(asset_data, str):
|
||||
hashes[name] = asset_data
|
||||
|
||||
# Effects
|
||||
for name, effect_data in registry.get("effects", {}).items():
|
||||
if isinstance(effect_data, dict) and "hash" in effect_data:
|
||||
hashes[f"effect:{name}"] = effect_data["hash"]
|
||||
elif isinstance(effect_data, str):
|
||||
hashes[f"effect:{name}"] = effect_data
|
||||
|
||||
return hashes
|
||||
|
||||
def _process_node(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
step_name_map: Dict[str, str],
|
||||
input_hashes: Dict[str, str],
|
||||
registry_hashes: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process a recipe node into execution steps.
|
||||
|
||||
Args:
|
||||
node: Recipe node to process
|
||||
step_id_map: Mapping from processed node IDs to step IDs
|
||||
step_name_map: Mapping from node IDs to human-readable names
|
||||
input_hashes: User-provided input hashes
|
||||
registry_hashes: Registry-resolved hashes
|
||||
analysis: Analysis results
|
||||
recipe_name: Name of the recipe (for generating readable names)
|
||||
|
||||
Returns:
|
||||
Tuple of (new steps, output step ID)
|
||||
"""
|
||||
# SOURCE nodes
|
||||
if node.type == "SOURCE":
|
||||
return self._process_source(node, input_hashes, registry_hashes, recipe_name)
|
||||
|
||||
# SOURCE_LIST nodes
|
||||
if node.type == "SOURCE_LIST":
|
||||
return self._process_source_list(node, input_hashes, recipe_name)
|
||||
|
||||
# ANALYZE nodes
|
||||
if node.type == "ANALYZE":
|
||||
return self._process_analyze(node, step_id_map, analysis, recipe_name)
|
||||
|
||||
# MAP nodes
|
||||
if node.type == "MAP":
|
||||
return self._process_map(node, step_id_map, input_hashes, analysis, recipe_name)
|
||||
|
||||
# SEQUENCE nodes (may use tree reduction)
|
||||
if node.type == "SEQUENCE":
|
||||
return self._process_sequence(node, step_id_map, recipe_name)
|
||||
|
||||
# SEGMENT_AT nodes
|
||||
if node.type == "SEGMENT_AT":
|
||||
return self._process_segment_at(node, step_id_map, analysis, recipe_name)
|
||||
|
||||
# Standard nodes (SEGMENT, RESIZE, TRANSFORM, LAYER, MUX, BLEND, etc.)
|
||||
return self._process_standard(node, step_id_map, recipe_name)
|
||||
|
||||
def _process_source(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
input_hashes: Dict[str, str],
|
||||
registry_hashes: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""Process SOURCE node."""
|
||||
config = dict(node.config)
|
||||
|
||||
# Variable input?
|
||||
if config.get("input"):
|
||||
# Look up in user-provided inputs
|
||||
if node.id not in input_hashes:
|
||||
raise ValueError(f"Missing input for SOURCE node '{node.id}'")
|
||||
cid = input_hashes[node.id]
|
||||
# Fixed asset from registry?
|
||||
elif config.get("asset"):
|
||||
asset_name = config["asset"]
|
||||
if asset_name not in registry_hashes:
|
||||
raise ValueError(f"Asset '{asset_name}' not found in registry")
|
||||
cid = registry_hashes[asset_name]
|
||||
else:
|
||||
raise ValueError(f"SOURCE node '{node.id}' has no input or asset")
|
||||
|
||||
# Human-readable name
|
||||
display_name = config.get("name", node.id)
|
||||
step_name = f"{recipe_name}.inputs.{display_name}" if recipe_name else display_name
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="SOURCE",
|
||||
config={"input_ref": node.id, "cid": cid},
|
||||
input_steps=[],
|
||||
cache_id=cid, # SOURCE cache_id is just the content hash
|
||||
name=step_name,
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
|
||||
def _process_source_list(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
input_hashes: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process SOURCE_LIST node.
|
||||
|
||||
Creates individual SOURCE steps for each item in the list.
|
||||
"""
|
||||
# Look for list input
|
||||
if node.id not in input_hashes:
|
||||
raise ValueError(f"Missing input for SOURCE_LIST node '{node.id}'")
|
||||
|
||||
input_value = input_hashes[node.id]
|
||||
|
||||
# Parse as comma-separated list if string
|
||||
if isinstance(input_value, str):
|
||||
items = [h.strip() for h in input_value.split(",")]
|
||||
else:
|
||||
items = list(input_value)
|
||||
|
||||
display_name = node.config.get("name", node.id)
|
||||
base_name = f"{recipe_name}.{display_name}" if recipe_name else display_name
|
||||
|
||||
steps = []
|
||||
for i, cid in enumerate(items):
|
||||
step = ExecutionStep(
|
||||
step_id=f"{node.id}_{i}",
|
||||
node_type="SOURCE",
|
||||
config={"input_ref": f"{node.id}[{i}]", "cid": cid},
|
||||
input_steps=[],
|
||||
cache_id=cid,
|
||||
name=f"{base_name}[{i}]",
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# Return list marker as output
|
||||
list_step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="_LIST",
|
||||
config={"items": [s.step_id for s in steps]},
|
||||
input_steps=[s.step_id for s in steps],
|
||||
name=f"{base_name}.list",
|
||||
)
|
||||
steps.append(list_step)
|
||||
|
||||
return steps, list_step.step_id
|
||||
|
||||
def _process_analyze(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process ANALYZE node.
|
||||
|
||||
ANALYZE nodes reference pre-computed analysis results.
|
||||
"""
|
||||
input_step = step_id_map.get(node.inputs[0]) if node.inputs else None
|
||||
if not input_step:
|
||||
raise ValueError(f"ANALYZE node '{node.id}' has no input")
|
||||
|
||||
feature = node.config.get("feature", "all")
|
||||
step_name = f"{recipe_name}.analysis.{feature}" if recipe_name else f"analysis.{feature}"
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="ANALYZE",
|
||||
config={
|
||||
"feature": feature,
|
||||
**node.config,
|
||||
},
|
||||
input_steps=[input_step],
|
||||
name=step_name,
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
|
||||
    def _process_map(
        self,
        node: RecipeNode,
        step_id_map: Dict[str, str],
        input_hashes: Dict[str, str],
        analysis: Dict[str, AnalysisResult],
        recipe_name: str = "",
    ) -> Tuple[List[ExecutionStep], str]:
        """
        Process MAP node - expand iteration over list.

        MAP applies an operation to each item in a list: one step per
        item (typed by the "operation" config, default TRANSFORM) plus a
        trailing "_LIST" aggregate step published under this node's id.

        Args:
            node: The MAP recipe node.
            step_id_map: Mapping from processed node IDs to step IDs.
            input_hashes: User-provided input hashes (not read here yet).
            analysis: Analysis results (not read here yet).
            recipe_name: Recipe name used to build readable step names.

        Returns:
            Tuple of (new steps, output step ID of the aggregate list).

        Raises:
            ValueError: If no items input can be resolved.
        """
        operation = node.config.get("operation", "TRANSFORM")
        base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id

        # Get items input: config "items" wins; otherwise the node's
        # first input (list form) or its "items" key (dict form).
        items_ref = node.config.get("items") or (
            node.inputs[0] if isinstance(node.inputs, list) else
            node.inputs.get("items") if isinstance(node.inputs, dict) else None
        )

        if not items_ref:
            raise ValueError(f"MAP node '{node.id}' has no items input")

        # Resolve items to list of step IDs
        if items_ref in step_id_map:
            # Reference to SOURCE_LIST output
            items_step = step_id_map[items_ref]
            # TODO: expand list items
            # NOTE(review): the _LIST step is passed through as a single
            # item instead of fanning out per element — confirm intended.
            logger.warning(f"MAP node '{node.id}' references list step, expansion TBD")
            item_steps = [items_step]
        else:
            # Unknown reference: treat it as a single literal item.
            item_steps = [items_ref]

        # Generate step for each item
        steps = []
        output_steps = []

        for i, item_step in enumerate(item_steps):
            step_id = f"{node.id}_{i}"

            if operation == "RANDOM_SLICE":
                # Seeded random segment extraction per item.
                step = ExecutionStep(
                    step_id=step_id,
                    node_type="SEGMENT",
                    config={
                        "random": True,
                        "seed_from": node.config.get("seed_from"),
                        "index": i,
                    },
                    input_steps=[item_step],
                    name=f"{base_name}.slice[{i}]",
                )
            elif operation == "TRANSFORM":
                # Only the "effects" sub-config is forwarded for transforms.
                step = ExecutionStep(
                    step_id=step_id,
                    node_type="TRANSFORM",
                    config=node.config.get("effects", {}),
                    input_steps=[item_step],
                    name=f"{base_name}.transform[{i}]",
                )
            elif operation == "ANALYZE":
                step = ExecutionStep(
                    step_id=step_id,
                    node_type="ANALYZE",
                    config={"feature": node.config.get("feature", "all")},
                    input_steps=[item_step],
                    name=f"{base_name}.analyze[{i}]",
                )
            else:
                # Any other operation: pass the config through verbatim.
                step = ExecutionStep(
                    step_id=step_id,
                    node_type=operation,
                    config=node.config,
                    input_steps=[item_step],
                    name=f"{base_name}.{operation.lower()}[{i}]",
                )

            steps.append(step)
            output_steps.append(step_id)

        # Create list output aggregating every per-item result.
        list_step = ExecutionStep(
            step_id=node.id,
            node_type="_LIST",
            config={"items": output_steps},
            input_steps=output_steps,
            name=f"{base_name}.results",
        )
        steps.append(list_step)

        return steps, list_step.step_id
|
||||
|
||||
    def _process_sequence(
        self,
        node: RecipeNode,
        step_id_map: Dict[str, str],
        recipe_name: str = "",
    ) -> Tuple[List[ExecutionStep], str]:
        """
        Process SEQUENCE node.

        Concatenates its inputs in order. With tree reduction enabled
        and more than two inputs, the concat is compiled into pairwise
        reduction steps via reduce_sequence(); otherwise a single flat
        SEQUENCE step is emitted. A single-input sequence is a no-op and
        passes its input straight through.

        Raises:
            ValueError: If the node has no inputs.
        """
        base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id

        # Resolve input steps; unknown ids pass through as external refs.
        input_steps = []
        for input_id in node.inputs:
            if input_id in step_id_map:
                input_steps.append(step_id_map[input_id])
            else:
                input_steps.append(input_id)

        if len(input_steps) == 0:
            raise ValueError(f"SEQUENCE node '{node.id}' has no inputs")

        if len(input_steps) == 1:
            # Single input, no sequence needed
            return [], input_steps[0]

        # Default transition between clips is a hard cut.
        transition_config = node.config.get("transition", {"type": "cut"})
        config = {"transition": transition_config}

        if self.use_tree_reduction and len(input_steps) > 2:
            # Use tree reduction: reduce_sequence returns the pairwise
            # reduction schedule plus the id of the final combined step.
            reduction_steps, output_id = reduce_sequence(
                input_steps,
                transition_config=config,
                id_prefix=node.id,
            )

            steps = []
            for i, (step_id, inputs, step_config) in enumerate(reduction_steps):
                step = ExecutionStep(
                    step_id=step_id,
                    node_type="SEQUENCE",
                    config=step_config,
                    input_steps=inputs,
                    name=f"{base_name}.reduce[{i}]",
                )
                steps.append(step)

            return steps, output_id
        else:
            # Direct sequence: one flat concat over all inputs.
            step = ExecutionStep(
                step_id=node.id,
                node_type="SEQUENCE",
                config=config,
                input_steps=input_steps,
                name=f"{base_name}.concat",
            )
            return [step], step.step_id
|
||||
|
||||
    def _process_segment_at(
        self,
        node: RecipeNode,
        step_id_map: Dict[str, str],
        analysis: Dict[str, AnalysisResult],
        recipe_name: str = "",
    ) -> Tuple[List[ExecutionStep], str]:
        """
        Process SEGMENT_AT node - cut at specific times.

        Creates SEGMENT steps for each time range.

        NOTE(review): currently emits a single placeholder SEGMENT_AT
        step; per-range expansion from analysis is not implemented yet
        (see TODO below). The *analysis* argument is unused until then.
        """
        base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id
        # Extracted for the planned implementation; unused until time
        # resolution from analysis lands.
        times_from = node.config.get("times_from")
        distribute = node.config.get("distribute", "round_robin")

        # TODO: Resolve times from analysis
        # For now, create a placeholder
        step = ExecutionStep(
            step_id=node.id,
            node_type="SEGMENT_AT",
            config=node.config,
            # Unknown input ids pass through unchanged (external refs).
            input_steps=[step_id_map.get(i, i) for i in node.inputs],
            name=f"{base_name}.segment",
        )

        return [step], step.step_id
|
||||
|
||||
def _process_standard(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""Process standard transformation/composition node."""
|
||||
base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id
|
||||
input_steps = [step_id_map.get(i, i) for i in node.inputs]
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type=node.type,
|
||||
config=node.config,
|
||||
input_steps=input_steps,
|
||||
name=f"{base_name}.{node.type.lower()}",
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
594
artdag/planning/schema.py
Normal file
594
artdag/planning/schema.py
Normal file
@@ -0,0 +1,594 @@
|
||||
# artdag/planning/schema.py
|
||||
"""
|
||||
Data structures for execution plans.
|
||||
|
||||
An ExecutionPlan contains all steps needed to execute a recipe,
|
||||
with pre-computed cache IDs for each step.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Cluster key for trust domains
|
||||
# Systems with the same key produce the same cache_ids and can share work
|
||||
# Systems with different keys have isolated cache namespaces
|
||||
CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY")
|
||||
|
||||
|
||||
def set_cluster_key(key: Optional[str]) -> None:
    """Set the cluster (trust-domain) key programmatically.

    Overrides the value read from the ARTDAG_CLUSTER_KEY environment
    variable at import time; pass None to clear the key.
    """
    global CLUSTER_KEY
    CLUSTER_KEY = key
|
||||
|
||||
|
||||
def get_cluster_key() -> Optional[str]:
    """Get the current cluster key (None when no trust domain is set)."""
    return CLUSTER_KEY
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
    """
    Create stable hash from arbitrary data.

    If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create
    isolated trust domains. Systems with the same key can share work;
    systems with different keys have separate cache namespaces.
    """
    # Wrap the payload with the cluster key (when set) so different
    # trust domains never collide on the same cache_id.
    if CLUSTER_KEY:
        payload = {"_cluster_key": CLUSTER_KEY, "_data": data}
    else:
        payload = data

    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    digest = hashlib.new(algorithm)
    digest.update(canonical.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
class StepStatus(Enum):
    """Status of an execution step through its lifecycle."""
    PENDING = "pending"      # not yet picked up
    CLAIMED = "claimed"      # reserved by a worker
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully
    CACHED = "cached"        # result served from cache, not re-run
    FAILED = "failed"        # execution errored
    SKIPPED = "skipped"      # intentionally not executed
|
||||
|
||||
|
||||
@dataclass
class StepOutput:
    """
    A single output from an execution step.

    Nodes may produce multiple outputs (e.g., split_on_beats produces N segments).
    Each output pairs a human-readable name with the cache_id under
    which its content is stored.

    Attributes:
        name: Human-readable name (e.g., "beats.split.segment[0]")
        cache_id: Content-addressed hash for caching
        media_type: MIME type of the output (e.g., "video/mp4", "audio/wav")
        index: Output index for multi-output nodes
        metadata: Optional additional metadata (time_range, etc.)
    """
    name: str
    cache_id: str
    media_type: str = "application/octet-stream"
    index: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (inverse of from_dict)."""
        keys = ("name", "cache_id", "media_type", "index", "metadata")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepOutput":
        """Deserialize from a plain dict; missing optional fields get defaults."""
        return cls(
            name=data["name"],
            cache_id=data["cache_id"],
            media_type=data.get("media_type", "application/octet-stream"),
            index=data.get("index", 0),
            metadata=data.get("metadata", {}),
        )
|
||||
|
||||
|
||||
@dataclass
class StepInput:
    """
    Reference to an input for a step.

    Inputs can reference outputs from other steps by name.

    Attributes:
        name: Input slot name (e.g., "video", "audio", "segments")
        source: Source output name (e.g., "beats.split.segment[0]")
        cache_id: Resolved cache_id of the source (populated during planning)
    """
    name: str
    source: str
    cache_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (inverse of from_dict)."""
        return {attr: getattr(self, attr) for attr in ("name", "source", "cache_id")}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepInput":
        """Deserialize; cache_id is optional and defaults to None."""
        return cls(data["name"], data["source"], data.get("cache_id"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionStep:
|
||||
"""
|
||||
A single step in the execution plan.
|
||||
|
||||
Each step has a pre-computed cache_id that uniquely identifies
|
||||
its output based on its configuration and input cache_ids.
|
||||
|
||||
Steps can produce multiple outputs (e.g., split_on_beats produces N segments).
|
||||
Each output has its own cache_id derived from the step's cache_id + index.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name relating to recipe (e.g., "beats.split")
|
||||
step_id: Unique identifier (hash) for this step
|
||||
node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.)
|
||||
config: Configuration for the primitive
|
||||
input_steps: IDs of steps this depends on (legacy, use inputs for new code)
|
||||
inputs: Structured input references with names and sources
|
||||
cache_id: Pre-computed cache ID (hash of config + input cache_ids)
|
||||
outputs: List of outputs this step produces
|
||||
estimated_duration: Optional estimated execution time
|
||||
level: Dependency level (0 = no dependencies, higher = more deps)
|
||||
"""
|
||||
step_id: str
|
||||
node_type: str
|
||||
config: Dict[str, Any]
|
||||
input_steps: List[str] = field(default_factory=list)
|
||||
inputs: List[StepInput] = field(default_factory=list)
|
||||
cache_id: Optional[str] = None
|
||||
outputs: List[StepOutput] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
estimated_duration: Optional[float] = None
|
||||
level: int = 0
|
||||
|
||||
    def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str:
        """
        Compute cache ID from configuration and input cache IDs.

        cache_id = SHA3-256(node_type + config + sorted(input_cache_ids))

        Inputs are sorted (by slot name, or by step id in the legacy
        path) so the hash is independent of declaration order; any
        unresolvable input falls back to its raw reference string,
        keeping the hash deterministic. The result is also stored on
        ``self.cache_id``.

        Args:
            input_cache_ids: Mapping from input step_id/name to their cache_id

        Returns:
            The computed cache_id
        """
        # Use structured inputs if available, otherwise fall back to input_steps
        if self.inputs:
            resolved_inputs = [
                inp.cache_id or input_cache_ids.get(inp.source, inp.source)
                for inp in sorted(self.inputs, key=lambda x: x.name)
            ]
        else:
            resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)]

        content = {
            "node_type": self.node_type,
            "config": self.config,
            "inputs": resolved_inputs,
        }
        self.cache_id = _stable_hash(content)
        return self.cache_id
|
||||
|
||||
def compute_output_cache_id(self, index: int) -> str:
|
||||
"""
|
||||
Compute cache ID for a specific output index.
|
||||
|
||||
output_cache_id = SHA3-256(step_cache_id + index)
|
||||
|
||||
Args:
|
||||
index: The output index
|
||||
|
||||
Returns:
|
||||
Cache ID for that output
|
||||
"""
|
||||
if not self.cache_id:
|
||||
raise ValueError("Step cache_id must be computed first")
|
||||
content = {"step_cache_id": self.cache_id, "output_index": index}
|
||||
return _stable_hash(content)
|
||||
|
||||
def add_output(
|
||||
self,
|
||||
name: str,
|
||||
media_type: str = "application/octet-stream",
|
||||
index: Optional[int] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> StepOutput:
|
||||
"""
|
||||
Add an output to this step.
|
||||
|
||||
Args:
|
||||
name: Human-readable output name
|
||||
media_type: MIME type of the output
|
||||
index: Output index (defaults to next available)
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
The created StepOutput
|
||||
"""
|
||||
if index is None:
|
||||
index = len(self.outputs)
|
||||
|
||||
cache_id = self.compute_output_cache_id(index)
|
||||
output = StepOutput(
|
||||
name=name,
|
||||
cache_id=cache_id,
|
||||
media_type=media_type,
|
||||
index=index,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.outputs.append(output)
|
||||
return output
|
||||
|
||||
def get_output(self, index: int = 0) -> Optional[StepOutput]:
|
||||
"""Get output by index."""
|
||||
if index < len(self.outputs):
|
||||
return self.outputs[index]
|
||||
return None
|
||||
|
||||
def get_output_by_name(self, name: str) -> Optional[StepOutput]:
|
||||
"""Get output by name."""
|
||||
for output in self.outputs:
|
||||
if output.name == name:
|
||||
return output
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step_id": self.step_id,
|
||||
"name": self.name,
|
||||
"node_type": self.node_type,
|
||||
"config": self.config,
|
||||
"input_steps": self.input_steps,
|
||||
"inputs": [inp.to_dict() for inp in self.inputs],
|
||||
"cache_id": self.cache_id,
|
||||
"outputs": [out.to_dict() for out in self.outputs],
|
||||
"estimated_duration": self.estimated_duration,
|
||||
"level": self.level,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
|
||||
inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])]
|
||||
outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])]
|
||||
return cls(
|
||||
step_id=data["step_id"],
|
||||
node_type=data["node_type"],
|
||||
config=data.get("config", {}),
|
||||
input_steps=data.get("input_steps", []),
|
||||
inputs=inputs,
|
||||
cache_id=data.get("cache_id"),
|
||||
outputs=outputs,
|
||||
name=data.get("name"),
|
||||
estimated_duration=data.get("estimated_duration"),
|
||||
level=data.get("level", 0),
|
||||
)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "ExecutionStep":
|
||||
return cls.from_dict(json.loads(json_str))
|
||||
|
||||
|
||||
@dataclass
class PlanInput:
    """
    An input to the execution plan.

    Attributes:
        name: Human-readable name from recipe (e.g., "source_video")
        cache_id: Content hash of the input file
        cid: Same as cache_id (for clarity)
        media_type: MIME type of the input
    """

    name: str
    cache_id: str
    cid: str
    media_type: str = "application/octet-stream"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return {
            key: getattr(self, key)
            for key in ("name", "cache_id", "cid", "media_type")
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PlanInput":
        """Rehydrate a PlanInput; `cid` falls back to `cache_id` when absent."""
        content_hash = data["cache_id"]
        return cls(
            name=data["name"],
            cache_id=content_hash,
            cid=data.get("cid", content_hash),
            media_type=data.get("media_type", "application/octet-stream"),
        )
|
||||
|
||||
|
||||
@dataclass
class ExecutionPlan:
    """
    Complete execution plan for a recipe.

    Contains all steps in topological order with pre-computed cache IDs.
    The plan is deterministic: same recipe + same inputs = same plan.

    Attributes:
        name: Human-readable plan name from recipe
        plan_id: Hash of the entire plan (for deduplication)
        recipe_id: Source recipe identifier
        recipe_name: Human-readable recipe name
        recipe_hash: Hash of the recipe content
        seed: Random seed used for planning
        steps: List of steps in execution order
        output_step: ID of the final output step
        output_name: Human-readable name of the final output
        inputs: Structured input definitions
        analysis_cache_ids: Cache IDs of analysis results used
        input_hashes: Content hashes of input files (legacy, use inputs)
        created_at: When the plan was generated
        metadata: Optional additional metadata
    """

    # plan_id may be passed as None; __post_init__ then derives it from the
    # plan contents via _compute_plan_id().
    plan_id: Optional[str]
    recipe_id: str
    recipe_hash: str
    steps: List[ExecutionStep]
    output_step: str
    name: Optional[str] = None
    recipe_name: Optional[str] = None
    seed: Optional[int] = None
    output_name: Optional[str] = None
    inputs: List[PlanInput] = field(default_factory=list)
    analysis_cache_ids: Dict[str, str] = field(default_factory=dict)
    input_hashes: Dict[str, str] = field(default_factory=dict)
    created_at: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in derived defaults: creation timestamp and plan_id."""
        if self.created_at is None:
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.plan_id is None:
            self.plan_id = self._compute_plan_id()

    def _compute_plan_id(self) -> str:
        """Compute plan ID from contents."""
        # Hash covers the recipe, every serialized step (including any
        # already-computed cache IDs), and all external inputs.
        content = {
            "recipe_hash": self.recipe_hash,
            "steps": [s.to_dict() for s in self.steps],
            "input_hashes": self.input_hashes,
            "analysis_cache_ids": self.analysis_cache_ids,
        }
        return _stable_hash(content)

    def compute_all_cache_ids(self) -> None:
        """
        Compute cache IDs for all steps in dependency order.

        Must be called after all steps are added to ensure
        cache IDs propagate correctly through dependencies.
        """
        # Build step lookup
        step_by_id = {s.step_id: s for s in self.steps}

        # Cache IDs start with input hashes
        cache_ids = dict(self.input_hashes)

        # Process in order (assumes topological order)
        for step in self.steps:
            # For SOURCE steps referencing inputs, use input hash
            if step.node_type == "SOURCE" and step.config.get("input_ref"):
                ref = step.config["input_ref"]
                if ref in self.input_hashes:
                    step.cache_id = self.input_hashes[ref]
                    cache_ids[step.step_id] = step.cache_id
                    continue
                # NOTE(review): a SOURCE step whose input_ref is missing from
                # input_hashes falls through to the generic path below —
                # presumably intentional; confirm against the planner.

            # For other steps, compute from inputs
            input_cache_ids = {}
            for input_step_id in step.input_steps:
                if input_step_id in cache_ids:
                    input_cache_ids[input_step_id] = cache_ids[input_step_id]
                elif input_step_id in step_by_id:
                    # Step should have been processed already
                    input_cache_ids[input_step_id] = step_by_id[input_step_id].cache_id
                else:
                    raise ValueError(f"Input step {input_step_id} not found for {step.step_id}")

            step.compute_cache_id(input_cache_ids)
            cache_ids[step.step_id] = step.cache_id

        # Recompute plan_id with final cache IDs
        self.plan_id = self._compute_plan_id()

    def compute_levels(self) -> int:
        """
        Compute dependency levels for all steps.

        Level 0 = no dependencies (can start immediately)
        Level N = depends on steps at level N-1

        Returns:
            Maximum level (number of sequential dependency levels)
        """
        step_by_id = {s.step_id: s for s in self.steps}
        # Memo of step_id -> computed level; also mirrored onto step.level.
        levels: Dict[str, int] = {}

        # NOTE(review): recursive with no cycle guard; a cyclic or extremely
        # deep dependency chain would exhaust the recursion limit — confirm
        # plans are always acyclic and of modest depth.
        def compute_level(step_id: str) -> int:
            if step_id in levels:
                return levels[step_id]

            step = step_by_id.get(step_id)
            if step is None:
                return 0  # Input from outside the plan

            if not step.input_steps:
                levels[step_id] = 0
                step.level = 0
                return 0

            max_input_level = max(compute_level(s) for s in step.input_steps)
            level = max_input_level + 1
            levels[step_id] = level
            step.level = level
            return level

        for step in self.steps:
            compute_level(step.step_id)

        return max(levels.values()) if levels else 0

    def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]:
        """
        Group steps by dependency level.

        Steps at the same level can execute in parallel.

        Returns:
            Dict mapping level -> list of steps at that level
        """
        by_level: Dict[int, List[ExecutionStep]] = {}
        for step in self.steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def get_step(self, step_id: str) -> Optional[ExecutionStep]:
        """Get step by ID."""
        for step in self.steps:
            if step.step_id == step_id:
                return step
        return None

    def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]:
        """Get step by cache ID."""
        for step in self.steps:
            if step.cache_id == cache_id:
                return step
        return None

    def get_step_by_name(self, name: str) -> Optional[ExecutionStep]:
        """Get step by human-readable name."""
        for step in self.steps:
            if step.name == name:
                return step
        return None

    def get_all_outputs(self) -> Dict[str, StepOutput]:
        """
        Get all outputs from all steps, keyed by output name.

        Note: if two steps declare outputs with the same name, the
        later step's output wins.

        Returns:
            Dict mapping output name -> StepOutput
        """
        outputs = {}
        for step in self.steps:
            for output in step.outputs:
                outputs[output.name] = output
        return outputs

    def get_output_cache_ids(self) -> Dict[str, str]:
        """
        Get mapping of output names to cache IDs.

        Returns:
            Dict mapping output name -> cache_id
        """
        return {
            output.name: output.cache_id
            for step in self.steps
            for output in step.outputs
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the plan (and all nested steps/inputs) to a dict."""
        return {
            "plan_id": self.plan_id,
            "name": self.name,
            "recipe_id": self.recipe_id,
            "recipe_name": self.recipe_name,
            "recipe_hash": self.recipe_hash,
            "seed": self.seed,
            "inputs": [i.to_dict() for i in self.inputs],
            "steps": [s.to_dict() for s in self.steps],
            "output_step": self.output_step,
            "output_name": self.output_name,
            "analysis_cache_ids": self.analysis_cache_ids,
            "input_hashes": self.input_hashes,
            "created_at": self.created_at,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan":
        """Rehydrate an ExecutionPlan from its to_dict() representation."""
        inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])]
        return cls(
            plan_id=data.get("plan_id"),
            name=data.get("name"),
            recipe_id=data["recipe_id"],
            recipe_name=data.get("recipe_name"),
            recipe_hash=data["recipe_hash"],
            seed=data.get("seed"),
            inputs=inputs,
            steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])],
            output_step=data["output_step"],
            output_name=data.get("output_name"),
            analysis_cache_ids=data.get("analysis_cache_ids", {}),
            input_hashes=data.get("input_hashes", {}),
            created_at=data.get("created_at"),
            metadata=data.get("metadata", {}),
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize the plan to a JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, json_str: str) -> "ExecutionPlan":
        """Rehydrate an ExecutionPlan from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def summary(self) -> str:
        """Get a human-readable summary of the plan."""
        by_level = self.get_steps_by_level()
        max_level = max(by_level.keys()) if by_level else 0

        lines = [
            f"Execution Plan: {self.plan_id[:16]}...",
            f"Recipe: {self.recipe_id}",
            f"Steps: {len(self.steps)}",
            f"Levels: {max_level + 1}",
            "",
        ]

        for level in sorted(by_level.keys()):
            steps = by_level[level]
            lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)")
            for step in steps:
                cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]"
                lines.append(f"  - {step.step_id}: {step.node_type} {cache_status}")

        return "\n".join(lines)
|
||||
231
artdag/planning/tree_reduction.py
Normal file
231
artdag/planning/tree_reduction.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# artdag/planning/tree_reduction.py
|
||||
"""
|
||||
Tree reduction for parallel composition.
|
||||
|
||||
Instead of sequential pairwise composition:
|
||||
A → AB → ABC → ABCD (3 sequential steps)
|
||||
|
||||
Use parallel tree reduction:
|
||||
A ─┬─ AB ─┬─ ABCD
|
||||
B ─┘ │
|
||||
C ─┬─ CD ─┘
|
||||
D ─┘
|
||||
|
||||
This reduces O(N) to O(log N) levels of sequential dependency.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Any, Dict
|
||||
|
||||
|
||||
@dataclass
class ReductionNode:
    """A node in the reduction tree: one pairwise composition step."""
    node_id: str            # Generated step ID for this composition
    input_ids: List[str]    # Node/step IDs consumed by this composition
    level: int              # Tree level (0 = first round of pairings)
    position: int           # Position within level
|
||||
|
||||
|
||||
class TreeReducer:
    """
    Generates tree reduction plans for parallel composition.

    Instead of composing N inputs sequentially (N-1 dependent steps),
    inputs are paired level by level, so only ceil(log2(N)) sequential
    levels are required while the total operation count stays N-1.

    Used to convert N inputs into optimal parallel SEQUENCE operations.
    """

    def __init__(self, node_type: str = "SEQUENCE"):
        """
        Initialize the reducer.

        Args:
            node_type: The composition node type (SEQUENCE, AUDIO_MIX, etc.)
        """
        self.node_type = node_type

    def reduce(
        self,
        input_ids: List[str],
        id_prefix: str = "reduce",
    ) -> Tuple[List[ReductionNode], str]:
        """
        Generate a tree reduction plan for the given inputs.

        Args:
            input_ids: List of input step IDs to reduce
            id_prefix: Prefix for generated node IDs

        Returns:
            Tuple of (list of reduction nodes, final output node ID)

        Raises:
            ValueError: If input_ids is empty.
        """
        if len(input_ids) == 0:
            raise ValueError("Cannot reduce empty input list")

        if len(input_ids) == 1:
            # Single input, no reduction needed
            return [], input_ids[0]

        if len(input_ids) == 2:
            # Two inputs, single reduction
            node_id = f"{id_prefix}_final"
            node = ReductionNode(
                node_id=node_id,
                input_ids=input_ids,
                level=0,
                position=0,
            )
            return [node], node_id

        # Build tree levels
        nodes = []
        current_level = list(input_ids)
        level_num = 0

        while len(current_level) > 1:
            next_level = []
            position = 0

            # Pair up nodes at current level
            i = 0
            while i < len(current_level):
                if i + 1 < len(current_level):
                    # Pair available
                    left = current_level[i]
                    right = current_level[i + 1]
                    node_id = f"{id_prefix}_L{level_num}_P{position}"
                    node = ReductionNode(
                        node_id=node_id,
                        input_ids=[left, right],
                        level=level_num,
                        position=position,
                    )
                    nodes.append(node)
                    next_level.append(node_id)
                    i += 2
                else:
                    # Odd one out, promote to next level unchanged
                    next_level.append(current_level[i])
                    i += 1

                # NOTE: position also advances when an odd element is merely
                # promoted; kept as-is so generated node IDs stay stable.
                position += 1

            current_level = next_level
            level_num += 1

        # The last remaining entry is the output. The loop always exits
        # right after a pairing, so nodes[-1] is that final node.
        output_id = current_level[0]

        # Rename final node for clarity
        if nodes and nodes[-1].node_id == output_id:
            nodes[-1].node_id = f"{id_prefix}_final"
            output_id = f"{id_prefix}_final"

        return nodes, output_id

    def get_reduction_depth(self, n: int) -> int:
        """
        Calculate the number of reduction levels needed.

        For n >= 2, ceil(log2(n)) == (n - 1).bit_length(). The integer form
        is exact, unlike math.ceil(math.log2(n)), which can round incorrectly
        for very large n due to floating-point representation.

        Args:
            n: Number of inputs

        Returns:
            Number of sequential reduction levels (log2(n) ceiling)
        """
        if n <= 1:
            return 0
        return (n - 1).bit_length()

    def get_total_operations(self, n: int) -> int:
        """
        Calculate total number of reduction operations.

        Args:
            n: Number of inputs

        Returns:
            Total composition operations (always n-1)
        """
        return max(0, n - 1)

    def reduce_with_config(
        self,
        input_ids: List[str],
        base_config: Dict[str, Any],
        id_prefix: str = "reduce",
    ) -> Tuple[List[Tuple[ReductionNode, Dict[str, Any]]], str]:
        """
        Generate reduction plan with configuration for each node.

        Each node receives a shallow copy of base_config; note that nested
        values are still shared between nodes.

        Args:
            input_ids: List of input step IDs
            base_config: Base configuration to use for each reduction
            id_prefix: Prefix for generated node IDs

        Returns:
            Tuple of (list of (node, config) pairs, final output ID)
        """
        nodes, output_id = self.reduce(input_ids, id_prefix)
        result = [(node, dict(base_config)) for node in nodes]
        return result, output_id
|
||||
|
||||
|
||||
def reduce_sequence(
    input_ids: List[str],
    transition_config: Dict[str, Any] = None,
    id_prefix: str = "seq",
) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]:
    """
    Convenience function for SEQUENCE reduction.

    Args:
        input_ids: Input step IDs to sequence
        transition_config: Transition configuration (default: cut)
        id_prefix: Prefix for generated step IDs

    Returns:
        Tuple of (list of (step_id, inputs, config), final step ID)
    """
    import copy  # local import: only needed to isolate per-node configs

    if transition_config is None:
        transition_config = {"transition": {"type": "cut"}}

    reducer = TreeReducer("SEQUENCE")
    nodes, output_id = reducer.reduce(input_ids, id_prefix)

    # Deep-copy per node: a shallow dict() copy would share nested dicts
    # (e.g. the "transition" sub-dict) across every generated step, so
    # mutating one step's config would silently leak into all the others.
    result = [
        (node.node_id, node.input_ids, copy.deepcopy(transition_config))
        for node in nodes
    ]

    return result, output_id
|
||||
|
||||
|
||||
def reduce_audio_mix(
    input_ids: List[str],
    mix_config: Dict[str, Any] = None,
    id_prefix: str = "mix",
) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]:
    """
    Convenience function for AUDIO_MIX reduction.

    Args:
        input_ids: Input step IDs to mix
        mix_config: Mix configuration
        id_prefix: Prefix for generated step IDs

    Returns:
        Tuple of (list of (step_id, inputs, config), final step ID)
    """
    import copy  # local import: only needed to isolate per-node configs

    if mix_config is None:
        mix_config = {"normalize": True}

    reducer = TreeReducer("AUDIO_MIX")
    nodes, output_id = reducer.reduce(input_ids, id_prefix)

    # Deep-copy per node: a shallow dict() copy would share any nested
    # values of a caller-supplied mix_config across all generated steps.
    result = [
        (node.node_id, node.input_ids, copy.deepcopy(mix_config))
        for node in nodes
    ]

    return result, output_id
|
||||
Reference in New Issue
Block a user