Squashed 'core/' content from commit 4957443

git-subtree-dir: core
git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07
This commit is contained in:
giles
2026-02-24 23:09:39 +00:00
commit cc2dcbddd4
80 changed files with 25711 additions and 0 deletions

75
artdag/sexp/__init__.py Normal file
View File

@@ -0,0 +1,75 @@
"""
S-expression parsing, compilation, and planning for ArtDAG.
This module provides:
- parser: Parse S-expression text into Python data structures
- compiler: Compile recipe S-expressions into DAG format
- planner: Generate execution plans from recipes
"""
from .parser import (
parse,
parse_all,
serialize,
Symbol,
Keyword,
ParseError,
)
from .compiler import (
compile_recipe,
compile_string,
CompiledRecipe,
CompileError,
ParamDef,
_parse_params,
)
from .planner import (
create_plan,
ExecutionPlanSexp,
PlanStep,
step_to_task_sexp,
task_cache_id,
)
from .scheduler import (
PlanScheduler,
PlanResult,
StepResult,
schedule_plan,
step_to_sexp,
step_sexp_to_string,
verify_step_cache_id,
)
__all__ = [
# Parser
'parse',
'parse_all',
'serialize',
'Symbol',
'Keyword',
'ParseError',
# Compiler
'compile_recipe',
'compile_string',
'CompiledRecipe',
'CompileError',
'ParamDef',
'_parse_params',
# Planner
'create_plan',
'ExecutionPlanSexp',
'PlanStep',
'step_to_task_sexp',
'task_cache_id',
# Scheduler
'PlanScheduler',
'PlanResult',
'StepResult',
'schedule_plan',
'step_to_sexp',
'step_sexp_to_string',
'verify_step_cache_id',
]

2463
artdag/sexp/compiler.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,337 @@
"""
Sexp effect loader.
Loads sexp effect definitions (define-effect forms) and creates
frame processors that evaluate the sexp body with primitives.
Effects must use :params syntax:
(define-effect name
:params (
(param1 :type int :default 8 :range [4 32] :desc "description")
(param2 :type string :default "value" :desc "description")
)
body)
For effects with no parameters, use empty :params ():
(define-effect name
:params ()
body)
Unknown parameters passed to effects will raise an error.
Usage:
loader = SexpEffectLoader()
effect_fn = loader.load_effect_file(Path("effects/ascii_art.sexp"))
output = effect_fn(input_path, output_path, config)
"""
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from .parser import parse_all, Symbol, Keyword
from .evaluator import evaluate
from .primitives import PRIMITIVES
from .compiler import ParamDef, _parse_params, CompileError
logger = logging.getLogger(__name__)
def _parse_define_effect(sexp) -> tuple:
    """
    Parse a define-effect form.

    Required syntax:
        (define-effect name
          :params (
            (param1 :type int :default 8 :range [4 32] :desc "description")
          )
          body)

    Effects MUST use :params syntax. Legacy ((param default) ...) syntax is
    not supported.

    Returns:
        (name, params_with_defaults, param_defs, body) where param_defs is a
        list of ParamDef objects.

    Raises:
        ValueError: if the form is malformed, uses the legacy parameter
            syntax, is missing its :params block, or has no body.
    """
    if not isinstance(sexp, list) or len(sexp) < 3:
        raise ValueError(f"Invalid define-effect form: {sexp}")
    head = sexp[0]
    if not (isinstance(head, Symbol) and head.name == "define-effect"):
        raise ValueError(f"Expected define-effect, got {head}")
    name = sexp[1]
    if isinstance(name, Symbol):
        name = name.name
    params_with_defaults = {}
    param_defs: List[ParamDef] = []
    body = None
    found_params = False
    # Parse :params and body: walk items after the name; keywords consume a
    # value slot, the first non-keyword item becomes the body.
    i = 2
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword) and item.name == "params":
            # :params syntax
            if i + 1 >= len(sexp):
                raise ValueError(f"Effect '{name}': Missing params list after :params keyword")
            try:
                param_defs = _parse_params(sexp[i + 1])
                # Build params_with_defaults from ParamDef objects
                for pd in param_defs:
                    params_with_defaults[pd.name] = pd.default
            except CompileError as e:
                # Chain the compiler error so the original cause is preserved.
                raise ValueError(f"Effect '{name}': Error parsing :params: {e}") from e
            found_params = True
            i += 2
        elif isinstance(item, Keyword):
            # Skip other keywords we don't recognize (and their value slot)
            i += 2
        elif body is None:
            # First non-keyword item is the body
            if isinstance(item, list) and item:
                first_elem = item[0]
                # Check for legacy syntax and reject it
                if isinstance(first_elem, list) and len(first_elem) >= 2:
                    raise ValueError(
                        f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. "
                        f"Use :params block instead:\n"
                        f"  :params (\n"
                        f"    (param_name :type int :default 0 :desc \"description\")\n"
                        f"  )"
                    )
            body = item
            i += 1
        else:
            # Extra non-keyword items after the body are ignored.
            i += 1
    if body is None:
        raise ValueError(f"Effect '{name}': No body found")
    if not found_params:
        raise ValueError(
            f"Effect '{name}': Missing :params block. Effects must declare parameters.\n"
            f"For effects with no parameters, use empty :params ():\n"
            f"  (define-effect {name}\n"
            f"    :params ()\n"
            f"    body)"
        )
    return name, params_with_defaults, param_defs, body
def _create_process_frame(
    effect_name: str,
    params_with_defaults: Dict[str, Any],
    param_defs: List[ParamDef],
    body: Any,
) -> Callable:
    """
    Create a process_frame function that evaluates the sexp body.

    The function signature is: (frame, params, state) -> (frame, state)

    Args:
        effect_name: Name of the effect (used in error/log messages).
        params_with_defaults: Mapping of declared parameter names to defaults.
        param_defs: ParamDef objects (kept for introspection; not used here).
        body: Parsed sexp body to evaluate per frame.

    Raises (from the returned function):
        ValueError: when a caller passes a parameter not declared in :params.
    """
    import math
    def process_frame(frame: np.ndarray, params: Dict[str, Any], state: Any):
        """Evaluate sexp effect body on a frame."""
        # Build environment with primitives
        env = dict(PRIMITIVES)
        # Add math functions (floor/ceil/round are coerced to int here,
        # unlike the evaluator's own builtins)
        env["floor"] = lambda x: int(math.floor(x))
        env["ceil"] = lambda x: int(math.ceil(x))
        env["round"] = lambda x: int(round(x))
        env["abs"] = abs
        env["min"] = min
        env["max"] = max
        env["sqrt"] = math.sqrt
        env["sin"] = math.sin
        env["cos"] = math.cos
        # Add list operations
        env["list"] = lambda *args: tuple(args)
        # NOTE(review): this "nth" returns None only for a falsy coll; an
        # out-of-range index will raise IndexError — TODO confirm intended.
        env["nth"] = lambda coll, i: coll[int(i)] if coll else None
        # Bind frame
        env["frame"] = frame
        # Validate that all provided params are known
        known_params = set(params_with_defaults.keys())
        for k in params.keys():
            if k not in known_params:
                raise ValueError(
                    f"Effect '{effect_name}': Unknown parameter '{k}'. "
                    f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}"
                )
        # Bind parameters (defaults + overrides from config)
        for param_name, default in params_with_defaults.items():
            # Use config value if provided, otherwise default.
            # A declared param whose default is None and which is not
            # supplied is left unbound in the environment.
            if param_name in params:
                env[param_name] = params[param_name]
            elif default is not None:
                env[param_name] = default
        # Evaluate the body
        try:
            result = evaluate(body, env)
            if isinstance(result, np.ndarray):
                return result, state
            else:
                # Non-ndarray results are discarded; pass the frame through.
                logger.warning(f"Effect {effect_name} returned {type(result)}, expected ndarray")
                return frame, state
        except Exception as e:
            logger.error(f"Error evaluating effect {effect_name}: {e}")
            raise
    return process_frame
def load_sexp_effect(source: str, base_path: Optional[Path] = None) -> tuple:
    """
    Load a sexp effect from source code.

    Args:
        source: Sexp source code
        base_path: Base path for resolving relative imports
            (NOTE(review): currently unused by this function — TODO confirm
            whether import resolution happens elsewhere or this is vestigial)

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection

    Raises:
        ValueError: if no define-effect form is found.
    """
    exprs = parse_all(source)
    # Find define-effect form.
    # Bug fix: the original fallback branch ("elif isinstance(exprs, list)")
    # was unreachable because the first branch already matched any list, so
    # a parse result that was itself a single define-effect form was never
    # recognized. Check that case first.
    define_effect = None
    if (
        isinstance(exprs, list)
        and exprs
        and isinstance(exprs[0], Symbol)
        and exprs[0].name == "define-effect"
    ):
        # parse_all returned a single form rather than a list of forms
        define_effect = exprs
    elif isinstance(exprs, list):
        for expr in exprs:
            if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
                if expr[0].name == "define-effect":
                    define_effect = expr
                    break
    if not define_effect:
        raise ValueError("No define-effect form found in sexp effect")
    name, params_with_defaults, param_defs, body = _parse_define_effect(define_effect)
    process_frame = _create_process_frame(name, params_with_defaults, param_defs, body)
    return name, process_frame, params_with_defaults, param_defs
def load_sexp_effect_file(path: Path) -> tuple:
    """
    Load a sexp effect definition from a file on disk.

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection
    """
    text = path.read_text()
    return load_sexp_effect(text, base_path=path.parent)
class SexpEffectLoader:
    """
    Loader for sexp effect definitions.

    Creates effect functions compatible with the EffectExecutor.
    """
    def __init__(self, recipe_dir: Optional[Path] = None):
        """
        Initialize loader.

        Args:
            recipe_dir: Base directory for resolving relative effect paths
                (defaults to the current working directory)
        """
        self.recipe_dir = recipe_dir or Path.cwd()
        # Cache loaded effects with their param_defs for introspection.
        # Maps effect_path -> (name, params_with_defaults, param_defs).
        # NOTE(review): load_effect_path writes this cache but always
        # re-reads the file; only get_effect_params reads the cache —
        # TODO confirm the reload-on-every-call is intended.
        self._loaded_effects: Dict[str, tuple] = {}
    def load_effect_path(self, effect_path: str) -> Callable:
        """
        Load a sexp effect from a relative path.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            Effect function (input_path, output_path, config) -> output_path

        Raises:
            FileNotFoundError: if the resolved .sexp file does not exist.
        """
        from ..effects.frame_processor import process_video
        full_path = self.recipe_dir / effect_path
        if not full_path.exists():
            raise FileNotFoundError(f"Sexp effect not found: {full_path}")
        name, process_frame_fn, params_defaults, param_defs = load_sexp_effect_file(full_path)
        logger.info(f"Loaded sexp effect: {name} from {effect_path}")
        # Cache for introspection
        self._loaded_effects[effect_path] = (name, params_defaults, param_defs)
        def effect_fn(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
            """Run sexp effect via frame processor."""
            # Extract params (excluding internal keys)
            params = dict(params_defaults)  # Start with defaults
            for k, v in config.items():
                if k not in ("effect", "cid", "hash", "effect_path", "_binding"):
                    params[k] = v
            # Get bindings if present: config values that are dicts carrying
            # a "_resolved_values" key are treated as analysis bindings.
            bindings = {}
            for key, value in config.items():
                if isinstance(value, dict) and value.get("_resolved_values"):
                    bindings[key] = value["_resolved_values"]
            output_path.parent.mkdir(parents=True, exist_ok=True)
            # Output is always written as .mp4, regardless of the requested
            # suffix; the actual path is returned to the caller.
            actual_output = output_path.with_suffix(".mp4")
            process_video(
                input_path=input_path,
                output_path=actual_output,
                process_frame=process_frame_fn,
                params=params,
                bindings=bindings,
            )
            logger.info(f"Processed sexp effect '{name}' from {effect_path}")
            return actual_output
        return effect_fn
    def get_effect_params(self, effect_path: str) -> List[ParamDef]:
        """
        Get parameter definitions for an effect.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            List of ParamDef objects describing the effect's parameters

        Raises:
            FileNotFoundError: if the effect file is neither cached nor on disk.
        """
        if effect_path not in self._loaded_effects:
            # Load the effect to get its params (cache miss)
            full_path = self.recipe_dir / effect_path
            if not full_path.exists():
                raise FileNotFoundError(f"Sexp effect not found: {full_path}")
            name, _, params_defaults, param_defs = load_sexp_effect_file(full_path)
            self._loaded_effects[effect_path] = (name, params_defaults, param_defs)
        return self._loaded_effects[effect_path][2]
def get_sexp_effect_loader(recipe_dir: Optional[Path] = None) -> SexpEffectLoader:
    """Construct and return a sexp effect loader rooted at ``recipe_dir``."""
    loader = SexpEffectLoader(recipe_dir)
    return loader

869
artdag/sexp/evaluator.py Normal file
View File

@@ -0,0 +1,869 @@
"""
Expression evaluator for S-expression DSL.
Supports:
- Arithmetic: +, -, *, /, mod, sqrt, pow, abs, floor, ceil, round, min, max, clamp
- Comparison: =, <, >, <=, >=
- Logic: and, or, not
- Predicates: odd?, even?, zero?, nil?
- Conditionals: if, cond, case
- Data: list, dict/map construction, get
- Lambda calls
"""
from typing import Any, Dict, List, Callable
from .parser import Symbol, Keyword, Lambda, Binding
class EvalError(Exception):
    """Error during expression evaluation.

    Raised by :func:`evaluate` and by the builtin functions when an
    expression is malformed, a symbol is undefined, or an argument has an
    unexpected type.
    """
    pass
# Built-in functions: registry mapping DSL name -> Python callable.
BUILTINS: Dict[str, Callable] = {}
def builtin(name: str):
    """Decorator factory: register the decorated function in BUILTINS
    under ``name`` and return it unchanged."""
    def _register(fn):
        BUILTINS[name] = fn
        return fn
    return _register
@builtin("+")
def add(*args):
return sum(args)
@builtin("-")
def sub(a, b=None):
if b is None:
return -a
return a - b
@builtin("*")
def mul(*args):
result = 1
for a in args:
result *= a
return result
@builtin("/")
def div(a, b):
return a / b
@builtin("mod")
def mod(a, b):
return a % b
@builtin("sqrt")
def sqrt(x):
return x ** 0.5
@builtin("pow")
def power(x, n):
return x ** n
@builtin("abs")
def absolute(x):
return abs(x)
@builtin("floor")
def floor_fn(x):
import math
return math.floor(x)
@builtin("ceil")
def ceil_fn(x):
import math
return math.ceil(x)
@builtin("round")
def round_fn(x, ndigits=0):
return round(x, int(ndigits))
@builtin("min")
def min_fn(*args):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
return min(args[0])
return min(args)
@builtin("max")
def max_fn(*args):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
return max(args[0])
return max(args)
@builtin("clamp")
def clamp(x, lo, hi):
return max(lo, min(hi, x))
@builtin("=")
def eq(a, b):
return a == b
@builtin("<")
def lt(a, b):
return a < b
@builtin(">")
def gt(a, b):
return a > b
@builtin("<=")
def lte(a, b):
return a <= b
@builtin(">=")
def gte(a, b):
return a >= b
@builtin("odd?")
def is_odd(n):
return n % 2 == 1
@builtin("even?")
def is_even(n):
return n % 2 == 0
@builtin("zero?")
def is_zero(n):
return n == 0
@builtin("nil?")
def is_nil(x):
return x is None
@builtin("not")
def not_fn(x):
return not x
@builtin("inc")
def inc(n):
return n + 1
@builtin("dec")
def dec(n):
return n - 1
@builtin("list")
def make_list(*args):
return list(args)
@builtin("assert")
def assert_true(condition, message="Assertion failed"):
if not condition:
raise RuntimeError(f"Assertion error: {message}")
return True
@builtin("get")
def get(coll, key, default=None):
if isinstance(coll, dict):
# Try the key directly first
result = coll.get(key, None)
if result is not None:
return result
# If key is a Keyword, also try its string name (for Python dicts with string keys)
if isinstance(key, Keyword):
result = coll.get(key.name, None)
if result is not None:
return result
# Return the default
return default
elif isinstance(coll, list):
return coll[key] if 0 <= key < len(coll) else default
else:
raise EvalError(f"get: expected dict or list, got {type(coll).__name__}: {str(coll)[:100]}")
@builtin("dict?")
def is_dict(x):
return isinstance(x, dict)
@builtin("list?")
def is_list(x):
return isinstance(x, list)
@builtin("nil?")
def is_nil(x):
return x is None
@builtin("number?")
def is_number(x):
return isinstance(x, (int, float))
@builtin("string?")
def is_string(x):
return isinstance(x, str)
@builtin("len")
def length(coll):
return len(coll)
@builtin("first")
def first(coll):
return coll[0] if coll else None
@builtin("last")
def last(coll):
return coll[-1] if coll else None
@builtin("chunk-every")
def chunk_every(coll, n):
"""Split collection into chunks of n elements."""
n = int(n)
return [coll[i:i+n] for i in range(0, len(coll), n)]
@builtin("rest")
def rest(coll):
return coll[1:] if coll else []
@builtin("nth")
def nth(coll, n):
return coll[n] if 0 <= n < len(coll) else None
@builtin("concat")
def concat(*colls):
"""Concatenate multiple lists/sequences."""
result = []
for c in colls:
if c is not None:
result.extend(c)
return result
@builtin("cons")
def cons(x, coll):
"""Prepend x to collection."""
return [x] + list(coll) if coll else [x]
@builtin("append")
def append(coll, x):
"""Append x to collection."""
return list(coll) + [x] if coll else [x]
@builtin("range")
def make_range(start, end, step=1):
"""Create a range of numbers."""
return list(range(int(start), int(end), int(step)))
@builtin("zip-pairs")
def zip_pairs(coll):
"""Zip consecutive pairs: [a,b,c,d] -> [[a,b],[b,c],[c,d]]."""
if not coll or len(coll) < 2:
return []
return [[coll[i], coll[i+1]] for i in range(len(coll)-1)]
@builtin("dict")
def make_dict(*pairs):
"""Create dict from key-value pairs: (dict :a 1 :b 2)."""
result = {}
i = 0
while i < len(pairs) - 1:
key = pairs[i]
if isinstance(key, Keyword):
key = key.name
result[key] = pairs[i + 1]
i += 2
return result
@builtin("keys")
def keys(d):
"""Get the keys of a dict as a list."""
if not isinstance(d, dict):
raise EvalError(f"keys: expected dict, got {type(d).__name__}")
return list(d.keys())
@builtin("vals")
def vals(d):
"""Get the values of a dict as a list."""
if not isinstance(d, dict):
raise EvalError(f"vals: expected dict, got {type(d).__name__}")
return list(d.values())
@builtin("merge")
def merge(*dicts):
"""Merge multiple dicts, later dicts override earlier."""
result = {}
for d in dicts:
if d is not None:
if not isinstance(d, dict):
raise EvalError(f"merge: expected dict, got {type(d).__name__}")
result.update(d)
return result
@builtin("assoc")
def assoc(d, *pairs):
"""Associate keys with values in a dict: (assoc d :a 1 :b 2)."""
if d is None:
result = {}
elif isinstance(d, dict):
result = dict(d)
else:
raise EvalError(f"assoc: expected dict or nil, got {type(d).__name__}")
i = 0
while i < len(pairs) - 1:
key = pairs[i]
if isinstance(key, Keyword):
key = key.name
result[key] = pairs[i + 1]
i += 2
return result
@builtin("dissoc")
def dissoc(d, *keys_to_remove):
"""Remove keys from a dict: (dissoc d :a :b)."""
if d is None:
return {}
if not isinstance(d, dict):
raise EvalError(f"dissoc: expected dict or nil, got {type(d).__name__}")
result = dict(d)
for key in keys_to_remove:
if isinstance(key, Keyword):
key = key.name
result.pop(key, None)
return result
@builtin("into")
def into(target, coll):
"""Convert a collection into another collection type.
(into [] {:a 1 :b 2}) -> [["a" 1] ["b" 2]]
(into {} [[:a 1] [:b 2]]) -> {"a": 1, "b": 2}
(into [] [1 2 3]) -> [1 2 3]
"""
if isinstance(target, list):
if isinstance(coll, dict):
return [[k, v] for k, v in coll.items()]
elif isinstance(coll, (list, tuple)):
return list(coll)
else:
raise EvalError(f"into: cannot convert {type(coll).__name__} into list")
elif isinstance(target, dict):
if isinstance(coll, dict):
return dict(coll)
elif isinstance(coll, (list, tuple)):
result = {}
for item in coll:
if isinstance(item, (list, tuple)) and len(item) >= 2:
key = item[0]
if isinstance(key, Keyword):
key = key.name
result[key] = item[1]
else:
raise EvalError(f"into: expected [key value] pairs, got {item}")
return result
else:
raise EvalError(f"into: cannot convert {type(coll).__name__} into dict")
else:
raise EvalError(f"into: unsupported target type {type(target).__name__}")
@builtin("filter")
def filter_fn(pred, coll):
"""Filter collection by predicate. Pred must be a lambda."""
if not isinstance(pred, Lambda):
raise EvalError(f"filter: expected lambda as predicate, got {type(pred).__name__}")
result = []
for item in coll:
# Evaluate predicate with item
local_env = {}
if pred.closure:
local_env.update(pred.closure)
local_env[pred.params[0]] = item
# Inline evaluation of pred.body
from . import evaluator
if evaluator.evaluate(pred.body, local_env):
result.append(item)
return result
@builtin("some")
def some(pred, coll):
"""Return first truthy value of (pred item) for items in coll, or nil."""
if not isinstance(pred, Lambda):
raise EvalError(f"some: expected lambda as predicate, got {type(pred).__name__}")
for item in coll:
local_env = {}
if pred.closure:
local_env.update(pred.closure)
local_env[pred.params[0]] = item
from . import evaluator
result = evaluator.evaluate(pred.body, local_env)
if result:
return result
return None
@builtin("every?")
def every(pred, coll):
"""Return true if (pred item) is truthy for all items in coll."""
if not isinstance(pred, Lambda):
raise EvalError(f"every?: expected lambda as predicate, got {type(pred).__name__}")
for item in coll:
local_env = {}
if pred.closure:
local_env.update(pred.closure)
local_env[pred.params[0]] = item
from . import evaluator
if not evaluator.evaluate(pred.body, local_env):
return False
return True
@builtin("empty?")
def is_empty(coll):
"""Return true if collection is empty."""
if coll is None:
return True
return len(coll) == 0
@builtin("contains?")
def contains(coll, key):
"""Check if collection contains key (for dicts) or element (for lists)."""
if isinstance(coll, dict):
if isinstance(key, Keyword):
key = key.name
return key in coll
elif isinstance(coll, (list, tuple)):
return key in coll
return False
def evaluate(expr: Any, env: Dict[str, Any] = None) -> Any:
    """
    Evaluate an S-expression in the given environment.

    Dispatch order: literals, Symbol lookup (env, then BUILTINS, then the
    true/false/nil constants), Keyword (returned as its string name),
    Lambda/Binding (returned as-is), dict literals (values evaluated), then
    lists (special forms, otherwise function application).

    Args:
        expr: The expression to evaluate
        env: Variable bindings (name -> value)

    Returns:
        The result of evaluation

    Raises:
        EvalError: on undefined symbols, malformed special forms, or
            application of a non-callable value.
    """
    if env is None:
        env = {}
    # Literals
    if isinstance(expr, (int, float, str, bool)) or expr is None:
        return expr
    # Symbol - variable lookup
    if isinstance(expr, Symbol):
        name = expr.name
        if name in env:
            return env[name]
        if name in BUILTINS:
            return BUILTINS[name]
        if name == "true":
            return True
        if name == "false":
            return False
        if name == "nil":
            return None
        raise EvalError(f"Undefined symbol: {name}")
    # Keyword - return as-is (used as map keys)
    if isinstance(expr, Keyword):
        return expr.name
    # Lambda - return as-is (it's a value)
    if isinstance(expr, Lambda):
        return expr
    # Binding - return as-is (resolved at execution time)
    if isinstance(expr, Binding):
        return expr
    # Dict literal
    if isinstance(expr, dict):
        return {k: evaluate(v, env) for k, v in expr.items()}
    # List - function call or special form
    if isinstance(expr, list):
        if not expr:
            return []
        head = expr[0]
        # If head is a string/number/etc (not Symbol), treat as data list
        if not isinstance(head, (Symbol, Lambda, list)):
            return [evaluate(x, env) for x in expr]
        # Special forms
        if isinstance(head, Symbol):
            name = head.name
            # if - conditional (else-branch optional; missing branch -> None)
            if name == "if":
                if len(expr) < 3:
                    raise EvalError("if requires condition and then-branch")
                cond_result = evaluate(expr[1], env)
                if cond_result:
                    return evaluate(expr[2], env)
                elif len(expr) > 3:
                    return evaluate(expr[3], env)
                return None
            # cond - multi-way conditional
            # Supports both Clojure style: (cond test1 result1 test2 result2 :else default)
            # and Scheme style: (cond (test1 result1) (test2 result2) (else default))
            if name == "cond":
                clauses = expr[1:]
                # Check if Clojure style (flat list) or Scheme style (nested pairs)
                # Scheme style: first clause is (test result) - exactly 2 elements
                # Clojure style: first clause is a test expression like (= x 0) - 3+ elements
                # NOTE(review): a 2-element Clojure test whose head is not in
                # the operator list below would be misread as a Scheme clause.
                first_is_scheme_clause = (
                    clauses and
                    isinstance(clauses[0], list) and
                    len(clauses[0]) == 2 and
                    not (isinstance(clauses[0][0], Symbol) and clauses[0][0].name in ('=', '<', '>', '<=', '>=', '!=', 'not=', 'and', 'or'))
                )
                if first_is_scheme_clause:
                    # Scheme style: ((test result) ...)
                    for clause in clauses:
                        if not isinstance(clause, list) or len(clause) < 2:
                            raise EvalError("cond clause must be (test result)")
                        test = clause[0]
                        # Check for else/default
                        if isinstance(test, Symbol) and test.name in ("else", ":else"):
                            return evaluate(clause[1], env)
                        if isinstance(test, Keyword) and test.name == "else":
                            return evaluate(clause[1], env)
                        if evaluate(test, env):
                            return evaluate(clause[1], env)
                else:
                    # Clojure style: test1 result1 test2 result2 ...
                    i = 0
                    while i < len(clauses) - 1:
                        test = clauses[i]
                        result = clauses[i + 1]
                        # Check for :else
                        if isinstance(test, Keyword) and test.name == "else":
                            return evaluate(result, env)
                        if isinstance(test, Symbol) and test.name == ":else":
                            return evaluate(result, env)
                        if evaluate(test, env):
                            return evaluate(result, env)
                        i += 2
                # No clause matched
                return None
            # case - switch on value
            # (case expr val1 result1 val2 result2 :else default)
            if name == "case":
                if len(expr) < 2:
                    raise EvalError("case requires expression to match")
                match_val = evaluate(expr[1], env)
                clauses = expr[2:]
                i = 0
                while i < len(clauses) - 1:
                    test = clauses[i]
                    result = clauses[i + 1]
                    # Check for :else / else
                    if isinstance(test, Keyword) and test.name == "else":
                        return evaluate(result, env)
                    if isinstance(test, Symbol) and test.name in (":else", "else"):
                        return evaluate(result, env)
                    # Evaluate test value and compare
                    test_val = evaluate(test, env)
                    if match_val == test_val:
                        return evaluate(result, env)
                    i += 2
                return None
            # and - short-circuit; returns the last value evaluated
            if name == "and":
                result = True
                for arg in expr[1:]:
                    result = evaluate(arg, env)
                    if not result:
                        return result
                return result
            # or - short-circuit; returns the first truthy value
            if name == "or":
                result = False
                for arg in expr[1:]:
                    result = evaluate(arg, env)
                    if result:
                        return result
                return result
            # let and let* - local bindings (both bind sequentially in this impl)
            if name in ("let", "let*"):
                if len(expr) < 3:
                    raise EvalError(f"{name} requires bindings and body")
                bindings = expr[1]
                local_env = dict(env)
                if isinstance(bindings, list):
                    # Check if it's ((name value) ...) style (Lisp let* style)
                    if bindings and isinstance(bindings[0], list):
                        for binding in bindings:
                            if len(binding) != 2:
                                raise EvalError(f"{name} binding must be (name value)")
                            var_name = binding[0]
                            if isinstance(var_name, Symbol):
                                var_name = var_name.name
                            value = evaluate(binding[1], local_env)
                            local_env[var_name] = value
                    # Vector-style [name value ...]
                    elif len(bindings) % 2 == 0:
                        for i in range(0, len(bindings), 2):
                            var_name = bindings[i]
                            if isinstance(var_name, Symbol):
                                var_name = var_name.name
                            value = evaluate(bindings[i + 1], local_env)
                            local_env[var_name] = value
                    else:
                        raise EvalError(f"{name} bindings must be [name value ...] or ((name value) ...)")
                else:
                    raise EvalError(f"{name} bindings must be a list")
                # NOTE(review): only expr[2] is evaluated as the body; any
                # additional body forms are ignored — TODO confirm intended.
                return evaluate(expr[2], local_env)
            # lambda / fn - create function with closure
            if name in ("lambda", "fn"):
                if len(expr) < 3:
                    raise EvalError("lambda requires params and body")
                params = expr[1]
                if not isinstance(params, list):
                    raise EvalError("lambda params must be a list")
                param_names = []
                for p in params:
                    if isinstance(p, Symbol):
                        param_names.append(p.name)
                    elif isinstance(p, str):
                        param_names.append(p)
                    else:
                        raise EvalError(f"Invalid param: {p}")
                # Capture current environment as closure
                return Lambda(param_names, expr[2], dict(env))
            # quote - return unevaluated
            if name == "quote":
                return expr[1] if len(expr) > 1 else None
            # bind - create binding to analysis data
            # (bind analysis-var)
            # (bind analysis-var :range [0.3 1.0])
            # (bind analysis-var :range [0 100] :transform sqrt)
            if name == "bind":
                if len(expr) < 2:
                    raise EvalError("bind requires analysis reference")
                analysis_ref = expr[1]
                if isinstance(analysis_ref, Symbol):
                    symbol_name = analysis_ref.name
                    # Look up the symbol in environment
                    if symbol_name in env:
                        resolved = env[symbol_name]
                        # If resolved is actual analysis data (dict with times/values or
                        # S-expression list with Keywords), keep the symbol name as reference
                        # for later lookup at execution time
                        if isinstance(resolved, dict) and ("times" in resolved or "values" in resolved):
                            analysis_ref = symbol_name  # Use name as reference, not the data
                        elif isinstance(resolved, list) and any(isinstance(x, Keyword) for x in resolved):
                            # Parsed S-expression analysis data ([:times [...] :duration ...])
                            analysis_ref = symbol_name
                        else:
                            analysis_ref = resolved
                    else:
                        raise EvalError(f"bind: undefined symbol '{symbol_name}' - must reference analysis data")
                # Parse optional :range [min max] and :transform
                range_min, range_max = 0.0, 1.0
                transform = None
                i = 2
                while i < len(expr):
                    if isinstance(expr[i], Keyword):
                        kw = expr[i].name
                        if kw == "range" and i + 1 < len(expr):
                            range_val = evaluate(expr[i + 1], env)  # Evaluate to get actual value
                            if isinstance(range_val, list) and len(range_val) >= 2:
                                range_min = float(range_val[0])
                                range_max = float(range_val[1])
                            i += 2
                        elif kw == "transform" and i + 1 < len(expr):
                            t = expr[i + 1]
                            if isinstance(t, Symbol):
                                transform = t.name
                            elif isinstance(t, str):
                                transform = t
                            i += 2
                        else:
                            # Unrecognized keyword: skip just the keyword
                            i += 1
                    else:
                        i += 1
                return Binding(analysis_ref, range_min=range_min, range_max=range_max, transform=transform)
            # Vector literal [a b c]
            if name == "vec" or name == "vector":
                return [evaluate(e, env) for e in expr[1:]]
            # map - (map fn coll)
            if name == "map":
                if len(expr) != 3:
                    raise EvalError("map requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"map requires lambda, got {type(fn)}")
                result = []
                for item in coll:
                    # Closure first, then the calling env overlays it, then
                    # the loop item binds the lambda parameter.
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = item
                    result.append(evaluate(fn.body, local_env))
                return result
            # map-indexed - (map-indexed fn coll)
            if name == "map-indexed":
                if len(expr) != 3:
                    raise EvalError("map-indexed requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"map-indexed requires lambda, got {type(fn)}")
                if len(fn.params) < 2:
                    raise EvalError("map-indexed lambda needs (i item) params")
                result = []
                for i, item in enumerate(coll):
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = i
                    local_env[fn.params[1]] = item
                    result.append(evaluate(fn.body, local_env))
                return result
            # reduce - (reduce fn init coll)
            if name == "reduce":
                if len(expr) != 4:
                    raise EvalError("reduce requires fn, init, and collection")
                fn = evaluate(expr[1], env)
                acc = evaluate(expr[2], env)
                coll = evaluate(expr[3], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"reduce requires lambda, got {type(fn)}")
                if len(fn.params) < 2:
                    raise EvalError("reduce lambda needs (acc item) params")
                for item in coll:
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = acc
                    local_env[fn.params[1]] = item
                    acc = evaluate(fn.body, local_env)
                return acc
            # for-each - (for-each fn coll) - iterate with side effects
            if name == "for-each":
                if len(expr) != 3:
                    raise EvalError("for-each requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"for-each requires lambda, got {type(fn)}")
                for item in coll:
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = item
                    evaluate(fn.body, local_env)
                return None
        # Function call (head is a Symbol not matching a special form,
        # a Lambda, or a nested list that evaluates to a callable)
        fn = evaluate(head, env)
        args = [evaluate(arg, env) for arg in expr[1:]]
        # Call builtin
        if callable(fn):
            return fn(*args)
        # Call lambda
        if isinstance(fn, Lambda):
            if len(args) != len(fn.params):
                raise EvalError(f"Lambda expects {len(fn.params)} args, got {len(args)}")
            # Start with closure (captured env), then overlay calling env, then params
            local_env = {}
            if fn.closure:
                local_env.update(fn.closure)
            local_env.update(env)
            for name, value in zip(fn.params, args):
                local_env[name] = value
            return evaluate(fn.body, local_env)
        raise EvalError(f"Not callable: {fn}")
    raise EvalError(f"Cannot evaluate: {expr!r}")
def make_env(**kwargs) -> Dict[str, Any]:
    """Build a fresh evaluation environment from keyword bindings."""
    env: Dict[str, Any] = {}
    env.update(kwargs)
    return env

View File

@@ -0,0 +1,292 @@
"""
External tool runners for effects that can't be done in FFmpeg.
Supports:
- datamosh: via ffglitch or datamoshing Python CLI
- pixelsort: via Rust pixelsort or Python pixelsort-cli
"""
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
def find_tool(tool_names: List[str]) -> Optional[str]:
    """Return the resolved path of the first candidate found on PATH, or None."""
    return next(
        (resolved for resolved in map(shutil.which, tool_names) if resolved),
        None,
    )
def check_python_package(package: str) -> bool:
    """Return True when ``python3 -c "import <package>"`` exits cleanly.

    Any failure (missing interpreter, timeout, non-zero exit) yields False.
    """
    try:
        proc = subprocess.run(
            ["python3", "-c", f"import {package}"],
            capture_output=True,
            timeout=5,
        )
    except Exception:
        return False
    return proc.returncode == 0
# Tool detection: candidate executable names probed on PATH, in priority order.
DATAMOSH_TOOLS = ["ffgac", "ffedit"]  # ffglitch tools
PIXELSORT_TOOLS = ["pixelsort"]  # Rust CLI
def get_available_tools() -> Dict[str, Optional[str]]:
    """Probe the system for the external tools this module can drive.

    Returns a dict whose values are a resolved path / marker string when the
    tool is available, or None when it is not.
    """
    datamosh_py = "datamoshing" if check_python_package("datamoshing") else None
    pixelsort_py = "pixelsort" if check_python_package("pixelsort") else None
    return {
        "datamosh": find_tool(DATAMOSH_TOOLS),
        "pixelsort": find_tool(PIXELSORT_TOOLS),
        "datamosh_python": datamosh_py,
        "pixelsort_python": pixelsort_py,
    }
def run_datamosh(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run datamosh effect using available tool.

    Tries ffglitch (ffgac/ffedit) with a generated motion-vector corruption
    script first, then falls back to the Python `datamoshing` package.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters (corruption, block_size, etc.)
            Only "corruption" (default 0.3) is read here.

    Returns:
        (success, error_message) — error_message is "" on success and is
        truncated to 500 characters when taken from a subprocess' stderr.
    """
    tools = get_available_tools()
    corruption = params.get("corruption", 0.3)
    # Try ffglitch first
    if tools.get("datamosh"):
        ffgac = tools["datamosh"]
        try:
            # ffglitch approach: remove I-frames to create datamosh effect
            # This is a simplified version - full datamosh needs custom scripts
            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                # Write a simple ffglitch script that corrupts motion vectors
                f.write(f"""
// Datamosh script - corrupt motion vectors
let corruption = {corruption};
export function glitch_frame(frame, stream) {{
if (frame.pict_type === 'P' && Math.random() < corruption) {{
// Corrupt motion vectors
let dominated = frame.mv?.forward?.overflow;
if (dominated) {{
for (let i = 0; i < dominated.length; i++) {{
if (Math.random() < corruption) {{
dominated[i] = [
Math.floor(Math.random() * 64 - 32),
Math.floor(Math.random() * 64 - 32)
];
}}
}}
}}
}}
return frame;
}}
""")
                script_path = f.name
            cmd = [
                ffgac,
                "-i", str(input_path),
                "-s", script_path,
                "-o", str(output_path),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            # Best-effort cleanup of the temp script
            Path(script_path).unlink(missing_ok=True)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except subprocess.TimeoutExpired:
            return False, "Datamosh timeout"
        except Exception as e:
            return False, str(e)
    # Fall back to Python datamoshing package
    if tools.get("datamosh_python"):
        try:
            cmd = [
                "python3", "-m", "datamoshing",
                str(input_path),
                str(output_path),
                "--mode", "iframe_removal",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)
    return False, "No datamosh tool available. Install ffglitch or: pip install datamoshing"
def run_pixelsort(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Apply the pixelsort effect to a single image.

    Prefers the Rust `pixelsort` CLI (faster), falling back to the Python
    `pixelsort` package. Never raises; failures come back via the tuple.

    Args:
        input_path: Input image/frame file
        output_path: Output image file
        params: Effect parameters (sort_by, threshold_low, threshold_high, angle)

    Returns:
        (success, error_message)
    """
    available = get_available_tools()
    sort_by = params.get("sort_by", "lightness")
    threshold_low = params.get("threshold_low", 50)
    threshold_high = params.get("threshold_high", 200)  # read but unused by both backends
    angle = params.get("angle", 0)

    def _invoke(cmd: List[str]) -> Tuple[bool, str]:
        # Shared runner: success on exit code 0, else truncated stderr.
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
        except Exception as exc:
            return False, str(exc)
        if proc.returncode == 0:
            return True, ""
        return False, proc.stderr[:500]

    rust_cli = available.get("pixelsort")
    if rust_cli:
        return _invoke([
            rust_cli,
            str(input_path),
            "-o", str(output_path),
            "--sort", sort_by,
            "-r", str(angle),
        ])
    if available.get("pixelsort_python"):
        return _invoke([
            "python3", "-m", "pixelsort",
            "--image_path", str(input_path),
            "--output", str(output_path),
            "--angle", str(angle),
            "--threshold", str(threshold_low / 255),  # Normalize to 0-1
        ])
    return False, "No pixelsort tool available. Install: cargo install pixelsort or pip install pixelsort-cli"
def run_pixelsort_video(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    fps: float = 30,
) -> Tuple[bool, str]:
    """
    Run pixelsort on a video by processing each frame.

    Extracts frames with ffmpeg, pixel-sorts each one via run_pixelsort,
    then reassembles at the same fps, copying audio from the original
    when present.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters forwarded to run_pixelsort
        fps: Frame rate used for both extraction and reassembly

    Returns:
        (success, error_message)
    """
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        # Extract frames (ffmpeg numbers the sequence from 0001)
        extract_cmd = [
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-vf", f"fps={fps}",
            str(workdir / "frame_%04d.png"),
        ]
        try:
            result = subprocess.run(extract_cmd, capture_output=True, timeout=300)
        except subprocess.TimeoutExpired:
            # Previously an uncaught exception — siblings report via the tuple.
            return False, "Failed to extract frames"
        if result.returncode != 0:
            return False, "Failed to extract frames"
        # Process each frame (outputs numbered from 0000; the image2 demuxer
        # auto-detects the starting index on reassembly)
        frame_files = sorted(workdir.glob("frame_*.png"))
        for i, frame in enumerate(frame_files):
            out_frame = workdir / f"sorted_{i:04d}.png"
            success, error = run_pixelsort(frame, out_frame, params)
            if not success:
                return False, f"Frame {i}: {error}"
        # Reassemble; "1:a?" maps audio from the original only if it exists
        reassemble_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", str(workdir / "sorted_%04d.png"),
            "-i", str(input_path),
            "-map", "0:v", "-map", "1:a?",
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ]
        try:
            result = subprocess.run(reassemble_cmd, capture_output=True, timeout=300)
        except subprocess.TimeoutExpired:
            return False, "Failed to reassemble video"
        if result.returncode != 0:
            return False, "Failed to reassemble video"
        return True, ""
def run_external_effect(
    effect_name: str,
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    is_video: bool = True,
) -> Tuple[bool, str]:
    """
    Dispatch an external effect to its runner.

    Args:
        effect_name: Name of effect (datamosh, pixelsort)
        input_path: Input file
        output_path: Output file
        params: Effect parameters
        is_video: Whether input is video (vs single image)

    Returns:
        (success, error_message)
    """
    if effect_name == "datamosh":
        return run_datamosh(input_path, output_path, params)
    if effect_name == "pixelsort":
        # Videos go through the per-frame wrapper; stills are sorted directly.
        runner = run_pixelsort_video if is_video else run_pixelsort
        return runner(input_path, output_path, params)
    return False, f"Unknown external effect: {effect_name}"
if __name__ == "__main__":
    # Script entry point: report which external glitch tools are installed.
    print("Available external tools:")
    for name, path in get_available_tools().items():
        status = path if path else "NOT INSTALLED"
        print(f"  {name}: {status}")

View File

@@ -0,0 +1,616 @@
"""
FFmpeg filter compiler for sexp effects.
Compiles sexp effect definitions to FFmpeg filter expressions,
with support for dynamic parameters via sendcmd scripts.
Usage:
compiler = FFmpegCompiler()
# Compile an effect with static params
filter_str = compiler.compile_effect("brightness", {"amount": 50})
# -> "eq=brightness=0.196"
# Compile with dynamic binding to analysis data
filter_str, sendcmd = compiler.compile_effect_with_binding(
"brightness",
{"amount": {"_bind": "bass-data", "range_min": 0, "range_max": 100}},
analysis_data={"bass-data": {"times": [...], "values": [...]}},
segment_start=0.0,
segment_duration=5.0,
)
# -> ("eq=brightness=0.5", "0.0 [eq] brightness 0.5;\n0.05 [eq] brightness 0.6;...")
"""
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# FFmpeg filter mappings for common effects
# Maps effect name -> {filter: str, params: {param_name: {ffmpeg_param, scale, offset}}}
#
# Recognized keys per entry (as consumed by FFmpegCompiler below):
#   filter           FFmpeg filter name, or None when no FFmpeg equivalent exists
#   static           fixed argument string emitted as "<filter>=<static>"
#   params           param -> {ffmpeg_param, scale, offset}; values are converted
#                    as ffmpeg_value = value * scale + offset
#   complex          True when "static" encodes a multi-filter chain with labels
#   time_based       True when params compile to FFmpeg time expressions using 't'
#   python_primitive name of a Python frame-processing fallback
#   external_tool    name of an external CLI tool (see external tool runners)
EFFECT_MAPPINGS = {
    "invert": {
        "filter": "negate",
        "params": {},
    },
    "grayscale": {
        "filter": "colorchannelmixer",
        "static": "0.3:0.4:0.3:0:0.3:0.4:0.3:0:0.3:0.4:0.3",
        "params": {},
    },
    "sepia": {
        "filter": "colorchannelmixer",
        "static": "0.393:0.769:0.189:0:0.349:0.686:0.168:0:0.272:0.534:0.131",
        "params": {},
    },
    "brightness": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
        },
    },
    "contrast": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
    "saturation": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "saturation", "scale": 1.0, "offset": 0},
        },
    },
    "hue_shift": {
        "filter": "hue",
        "params": {
            "degrees": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "blur": {
        "filter": "gblur",
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "sharpen": {
        "filter": "unsharp",
        "params": {
            "amount": {"ffmpeg_param": "la", "scale": 1.0, "offset": 0},
        },
    },
    "pixelate": {
        # Scale down then up to create pixelation effect
        "filter": "scale",
        "static": "iw/8:ih/8:flags=neighbor,scale=iw*8:ih*8:flags=neighbor",
        "params": {},
    },
    "vignette": {
        "filter": "vignette",
        "params": {
            "strength": {"ffmpeg_param": "a", "scale": 1.0, "offset": 0},
        },
    },
    "noise": {
        "filter": "noise",
        "params": {
            "amount": {"ffmpeg_param": "alls", "scale": 1.0, "offset": 0},
        },
    },
    "flip": {
        "filter": "hflip",  # Default horizontal
        "params": {},
    },
    "mirror": {
        "filter": "hflip",
        "params": {},
    },
    "rotate": {
        "filter": "rotate",
        "params": {
            "angle": {"ffmpeg_param": "a", "scale": math.pi/180, "offset": 0},  # degrees to radians
        },
    },
    "zoom": {
        "filter": "zoompan",
        "params": {
            "factor": {"ffmpeg_param": "z", "scale": 1.0, "offset": 0},
        },
    },
    "posterize": {
        # Use lutyuv to quantize levels (approximate posterization)
        "filter": "lutyuv",
        "static": "y=floor(val/32)*32:u=floor(val/32)*32:v=floor(val/32)*32",
        "params": {},
    },
    "threshold": {
        # Use geq for thresholding
        "filter": "geq",
        "static": "lum='if(gt(lum(X,Y),128),255,0)':cb=128:cr=128",
        "params": {},
    },
    "edge_detect": {
        "filter": "edgedetect",
        "params": {
            "low": {"ffmpeg_param": "low", "scale": 1/255, "offset": 0},
            "high": {"ffmpeg_param": "high", "scale": 1/255, "offset": 0},
        },
    },
    "swirl": {
        "filter": "lenscorrection",  # Approximate with lens distortion
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 0.1, "offset": 0},
        },
    },
    "fisheye": {
        "filter": "lenscorrection",
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 1.0, "offset": 0},
        },
    },
    "wave": {
        # Wave displacement using geq - need r/g/b for RGB mode
        "filter": "geq",
        "static": "r='r(X+10*sin(Y/20),Y)':g='g(X+10*sin(Y/20),Y)':b='b(X+10*sin(Y/20),Y)'",
        "params": {},
    },
    "rgb_split": {
        # Chromatic aberration using geq
        "filter": "geq",
        "static": "r='p(X+5,Y)':g='p(X,Y)':b='p(X-5,Y)'",
        "params": {},
    },
    "scanlines": {
        "filter": "drawgrid",
        "params": {
            "spacing": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "film_grain": {
        "filter": "noise",
        "params": {
            "intensity": {"ffmpeg_param": "alls", "scale": 100, "offset": 0},
        },
    },
    "crt": {
        "filter": "vignette",  # Simplified - just vignette for CRT look
        "params": {},
    },
    "bloom": {
        "filter": "gblur",  # Simplified bloom = blur overlay
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "color_cycle": {
        "filter": "hue",
        "params": {
            "speed": {"ffmpeg_param": "h", "scale": 360.0, "offset": 0, "time_expr": True},
        },
        "time_based": True,  # Uses time expression
    },
    "strobe": {
        # Strobe using select to drop frames
        "filter": "select",
        "static": "'mod(n,4)'",
        "params": {},
    },
    "echo": {
        # Echo using tmix
        "filter": "tmix",
        "static": "frames=4:weights='1 0.5 0.25 0.125'",
        "params": {},
    },
    "trails": {
        # Trails using tblend
        "filter": "tblend",
        "static": "all_mode=average",
        "params": {},
    },
    "kaleidoscope": {
        # 4-way mirror kaleidoscope using FFmpeg filter chain
        # Crops top-left quadrant, mirrors horizontally, then vertically
        "filter": "crop",
        "complex": True,
        "static": "iw/2:ih/2:0:0[q];[q]split[q1][q2];[q1]hflip[qr];[q2][qr]hstack[top];[top]split[t1][t2];[t2]vflip[bot];[t1][bot]vstack",
        "params": {},
    },
    "emboss": {
        "filter": "convolution",
        "static": "-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2",
        "params": {},
    },
    "neon_glow": {
        # Edge detect + negate for neon-like effect
        "filter": "edgedetect",
        "static": "mode=colormix:high=0.1",
        "params": {},
    },
    "ascii_art": {
        # Requires Python frame processing - no FFmpeg equivalent
        "filter": None,
        "python_primitive": "ascii_art_frame",
        "params": {
            "char_size": {"default": 8},
            "alphabet": {"default": "standard"},
            "color_mode": {"default": "color"},
        },
    },
    "ascii_zones": {
        # Requires Python frame processing - zone-based ASCII
        "filter": None,
        "python_primitive": "ascii_zones_frame",
        "params": {
            "char_size": {"default": 8},
            "zone_threshold": {"default": 128},
        },
    },
    "datamosh": {
        # External tool: ffglitch or datamoshing CLI, falls back to Python
        "filter": None,
        "external_tool": "datamosh",
        "python_primitive": "datamosh_frame",
        "params": {
            "block_size": {"default": 32},
            "corruption": {"default": 0.3},
        },
    },
    "pixelsort": {
        # External tool: pixelsort CLI (Rust or Python), falls back to Python
        "filter": None,
        "external_tool": "pixelsort",
        "python_primitive": "pixelsort_frame",
        "params": {
            "sort_by": {"default": "lightness"},
            "threshold_low": {"default": 50},
            "threshold_high": {"default": 200},
            "angle": {"default": 0},
        },
    },
    "ripple": {
        # Use geq for ripple displacement
        "filter": "geq",
        "static": "lum='lum(X+5*sin(hypot(X-W/2,Y-H/2)/10),Y+5*cos(hypot(X-W/2,Y-H/2)/10))'",
        "params": {},
    },
    "tile_grid": {
        # Use tile filter for grid
        "filter": "tile",
        "static": "2x2",
        "params": {},
    },
    "outline": {
        "filter": "edgedetect",
        "params": {},
    },
    "color-adjust": {
        "filter": "eq",
        "params": {
            "brightness": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
            "contrast": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
}
class FFmpegCompiler:
    """Compiles sexp effects to FFmpeg filters with sendcmd support.

    Lookups go through EFFECT_MAPPINGS (or a caller-supplied table).
    The compile_* methods turn an effect name plus a parameter dict into
    an FFmpeg filter expression, optionally with a sendcmd script for
    parameters bound to time-varying analysis data.
    """

    def __init__(self, effect_mappings: Dict = None):
        # Fall back to the module-level EFFECT_MAPPINGS table when the
        # caller does not supply one.
        self.mappings = effect_mappings or EFFECT_MAPPINGS

    def get_full_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get full mapping for an effect (including external tools and python primitives)."""
        mapping = self.mappings.get(effect_name)
        if not mapping:
            # Try with underscores/hyphens converted
            normalized = effect_name.replace("-", "_").replace(" ", "_").lower()
            mapping = self.mappings.get(normalized)
        return mapping

    def get_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get FFmpeg filter mapping for an effect (returns None for non-FFmpeg effects)."""
        mapping = self.get_full_mapping(effect_name)
        # Return None if no mapping or no FFmpeg filter
        if mapping and mapping.get("filter") is None:
            return None
        return mapping

    def has_external_tool(self, effect_name: str) -> Optional[str]:
        """Check if effect uses an external tool. Returns tool name or None."""
        mapping = self.get_full_mapping(effect_name)
        if mapping:
            return mapping.get("external_tool")
        return None

    def has_python_primitive(self, effect_name: str) -> Optional[str]:
        """Check if effect uses a Python primitive. Returns primitive name or None."""
        mapping = self.get_full_mapping(effect_name)
        if mapping:
            return mapping.get("python_primitive")
        return None

    def is_complex_filter(self, effect_name: str) -> bool:
        """Check if effect uses a complex filter chain."""
        mapping = self.get_full_mapping(effect_name)
        return bool(mapping and mapping.get("complex"))

    def compile_effect(
        self,
        effect_name: str,
        params: Dict[str, Any],
    ) -> Optional[str]:
        """
        Compile an effect to an FFmpeg filter string with static params.
        Returns None if effect has no FFmpeg mapping.

        Dict-valued params containing "_bind"/"_binding" are skipped here;
        they are handled by compile_effect_with_bindings.
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None
        filter_name = mapping["filter"]
        # Handle static filters (no params)
        if "static" in mapping:
            return f"{filter_name}={mapping['static']}"
        if not mapping.get("params"):
            return filter_name
        # Build param string
        filter_params = []
        for param_name, param_config in mapping["params"].items():
            if param_name in params:
                value = params[param_name]
                # Skip if it's a binding (handled separately)
                if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                    continue
                ffmpeg_param = param_config["ffmpeg_param"]
                scale = param_config.get("scale", 1.0)
                offset = param_config.get("offset", 0)
                # Handle various value types: numbers are scaled/offset,
                # strings pass through verbatim, lists use their first number.
                if isinstance(value, (int, float)):
                    ffmpeg_value = value * scale + offset
                    filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                elif isinstance(value, str):
                    filter_params.append(f"{ffmpeg_param}={value}")
                elif isinstance(value, list) and value and isinstance(value[0], (int, float)):
                    ffmpeg_value = value[0] * scale + offset
                    filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
        if filter_params:
            return f"{filter_name}={':'.join(filter_params)}"
        return filter_name

    def compile_effect_with_bindings(
        self,
        effect_name: str,
        params: Dict[str, Any],
        analysis_data: Dict[str, Dict],
        segment_start: float,
        segment_duration: float,
        sample_interval: float = 0.04,  # ~25 fps
    ) -> Tuple[Optional[str], Optional[str], List[str]]:
        """
        Compile an effect with dynamic bindings to a filter + sendcmd script.

        Args:
            effect_name: Effect name as used in the mappings table.
            params: Parameter values; dict values with "_bind"/"_binding"
                are resolved against analysis_data.
            analysis_data: {name: {"times": [...], "values": [...]}} series.
            segment_start: Segment start time in the source (seconds).
            segment_duration: Segment duration (seconds).
            sample_interval: Spacing of generated sendcmd samples.

        Returns:
            (filter_string, sendcmd_script, bound_param_names)
            - filter_string: Initial FFmpeg filter (may have placeholder values)
            - sendcmd_script: Script content for sendcmd filter
            - bound_param_names: List of params that have bindings
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None, None, []
        filter_name = mapping["filter"]
        static_params = []
        bound_params = []
        sendcmd_lines = []
        # Handle time-based effects (use FFmpeg expressions with 't')
        if mapping.get("time_based"):
            for param_name, param_config in mapping.get("params", {}).items():
                if param_name in params:
                    value = params[param_name]
                    ffmpeg_param = param_config["ffmpeg_param"]
                    scale = param_config.get("scale", 1.0)
                    if isinstance(value, (int, float)):
                        # Create time expression: h='t*speed*scale'
                        static_params.append(f"{ffmpeg_param}='t*{value}*{scale}'")
                    else:
                        static_params.append(f"{ffmpeg_param}='t*{scale}'")
            if static_params:
                filter_str = f"{filter_name}={':'.join(static_params)}"
            else:
                filter_str = f"{filter_name}=h='t*360'"  # Default rotation
            return filter_str, None, []
        # Process each param
        for param_name, param_config in mapping.get("params", {}).items():
            if param_name not in params:
                continue
            value = params[param_name]
            ffmpeg_param = param_config["ffmpeg_param"]
            scale = param_config.get("scale", 1.0)
            offset = param_config.get("offset", 0)
            # Check if it's a binding
            if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                bind_ref = value.get("_bind") or value.get("_binding")
                range_min = value.get("range_min", 0.0)
                range_max = value.get("range_max", 1.0)
                transform = value.get("transform")
                # Get analysis data
                analysis = analysis_data.get(bind_ref)
                if not analysis:
                    # Try without -data suffix
                    analysis = analysis_data.get(bind_ref.replace("-data", ""))
                if analysis and "times" in analysis and "values" in analysis:
                    times = analysis["times"]
                    values = analysis["values"]
                    # Generate sendcmd entries for this segment
                    # NOTE(review): segment_end is unused — the loop below
                    # bounds on segment_duration directly.
                    segment_end = segment_start + segment_duration
                    t = 0.0  # Time relative to segment start
                    while t < segment_duration:
                        abs_time = segment_start + t
                        # Find analysis value at this time
                        raw_value = self._interpolate_value(times, values, abs_time)
                        # Apply transform (inputs clamped to keep sqrt/log defined)
                        if transform == "sqrt":
                            raw_value = math.sqrt(max(0, raw_value))
                        elif transform == "pow2":
                            raw_value = raw_value ** 2
                        elif transform == "log":
                            raw_value = math.log(max(0.001, raw_value))
                        # Map to range
                        mapped_value = range_min + raw_value * (range_max - range_min)
                        # Apply FFmpeg scaling
                        ffmpeg_value = mapped_value * scale + offset
                        # Add sendcmd line (time relative to segment)
                        sendcmd_lines.append(f"{t:.3f} [{filter_name}] {ffmpeg_param} {ffmpeg_value:.4f};")
                        t += sample_interval
                    bound_params.append(param_name)
                    # Use initial value for the filter string
                    initial_value = self._interpolate_value(times, values, segment_start)
                    initial_mapped = range_min + initial_value * (range_max - range_min)
                    initial_ffmpeg = initial_mapped * scale + offset
                    static_params.append(f"{ffmpeg_param}={initial_ffmpeg:.4f}")
                else:
                    # No analysis data, use range midpoint
                    mid_value = (range_min + range_max) / 2
                    ffmpeg_value = mid_value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
            else:
                # Static value - handle various types
                if isinstance(value, (int, float)):
                    ffmpeg_value = value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                elif isinstance(value, str):
                    # String value - use as-is (e.g., for direction parameters)
                    static_params.append(f"{ffmpeg_param}={value}")
                elif isinstance(value, list) and value:
                    # List - try to use first numeric element
                    first = value[0]
                    if isinstance(first, (int, float)):
                        ffmpeg_value = first * scale + offset
                        static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                # Skip other types
        # Handle static filters
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        elif static_params:
            filter_str = f"{filter_name}={':'.join(static_params)}"
        else:
            filter_str = filter_name
        # Combine sendcmd lines
        sendcmd_script = "\n".join(sendcmd_lines) if sendcmd_lines else None
        return filter_str, sendcmd_script, bound_params

    def _interpolate_value(
        self,
        times: List[float],
        values: List[float],
        target_time: float,
    ) -> float:
        """Interpolate a value from analysis data at a given time.

        Clamps to the endpoints outside the sampled range; returns a
        neutral 0.5 when the series is empty.
        """
        if not times or not values:
            return 0.5
        # Find surrounding indices: clamp outside the sampled range
        if target_time <= times[0]:
            return values[0]
        if target_time >= times[-1]:
            return values[-1]
        # Binary search for efficiency
        lo, hi = 0, len(times) - 1
        while lo < hi - 1:
            mid = (lo + hi) // 2
            if times[mid] <= target_time:
                lo = mid
            else:
                hi = mid
        # Linear interpolation between the bracketing samples
        t0, t1 = times[lo], times[hi]
        v0, v1 = values[lo], values[hi]
        if t1 == t0:
            return v0
        alpha = (target_time - t0) / (t1 - t0)
        return v0 + alpha * (v1 - v0)
def generate_sendcmd_filter(
    effects: List[Dict],
    analysis_data: Dict[str, Dict],
    segment_start: float,
    segment_duration: float,
) -> Tuple[str, Optional[Path]]:
    """
    Generate an FFmpeg filter chain for a list of effect configs.

    Args:
        effects: List of effect configs; each dict carries the effect name
            under the "effect" key, all other keys are parameters.
        analysis_data: Analysis data keyed by name (for dynamic bindings).
        segment_start: Segment start time in source.
        segment_duration: Segment duration.

    Returns:
        (filter_chain_string, sendcmd_file_path or None)

    NOTE: the sendcmd path is always None - FFmpeg's sendcmd filter has
    compatibility issues, so bindings use their initial value (sampled at
    segment start). This is acceptable since each segment is only ~8
    seconds, and the value is still music-reactive per segment.
    """
    compiler = FFmpegCompiler()
    filters = []
    for effect in effects:
        effect_name = effect.get("effect")
        # Everything except the effect name is treated as a parameter.
        params = {k: v for k, v in effect.items() if k != "effect"}
        filter_str, _sendcmd, _bound = compiler.compile_effect_with_bindings(
            effect_name,
            params,
            analysis_data,
            segment_start,
            segment_duration,
        )
        if filter_str:
            filters.append(filter_str)
    if not filters:
        return "", None
    # sendcmd intentionally disabled (see docstring).
    return ",".join(filters), None

425
artdag/sexp/parser.py Normal file
View File

@@ -0,0 +1,425 @@
"""
S-expression parser for ArtDAG recipes and plans.
Supports:
- Lists: (a b c)
- Symbols: foo, bar-baz, ->
- Keywords: :key
- Strings: "hello world"
- Numbers: 42, 3.14, -1.5
- Comments: ; to end of line
- Vectors: [a b c] (syntactic sugar for lists)
- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts)
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import re
@dataclass
class Symbol:
    """An unquoted symbol/identifier.

    Compares equal both to other Symbols and to bare strings with the
    same name, and hashes like its name so the two are interchangeable
    as dict/set keys.
    """
    name: str

    def __repr__(self):
        return f"Symbol({self.name!r})"

    def __eq__(self, other):
        if isinstance(other, Symbol):
            other = other.name
        return self.name == other if isinstance(other, str) else False

    def __hash__(self):
        return hash(self.name)
@dataclass
class Keyword:
    """A keyword starting with colon.

    Unlike Symbol, a Keyword only compares equal to another Keyword,
    never to a plain string; the hash is salted with ':' so keywords
    and same-named symbols hash differently.
    """
    name: str

    def __repr__(self):
        return f"Keyword({self.name!r})"

    def __eq__(self, other):
        return isinstance(other, Keyword) and self.name == other.name

    def __hash__(self):
        return hash((':', self.name))
@dataclass
class Lambda:
    """A lambda/anonymous function with closure."""
    params: List[str]     # parameter names
    body: Any             # expression body
    closure: Dict = None  # captured environment (optional for backwards compat)

    def __repr__(self):
        return "Lambda({}, {!r})".format(self.params, self.body)
@dataclass
class Binding:
    """A binding to analysis data for dynamic effect parameters."""
    analysis_ref: str       # name of the analysis variable
    track: str = None       # optional track name (e.g., "bass", "energy")
    range_min: float = 0.0  # output range minimum
    range_max: float = 1.0  # output range maximum
    transform: str = None   # optional transform: "sqrt", "pow2", "log", etc.

    def __repr__(self):
        suffix = f", transform={self.transform!r}" if self.transform else ""
        return (f"Binding({self.analysis_ref!r}, track={self.track!r}, "
                f"range=[{self.range_min}, {self.range_max}]{suffix})")
class ParseError(Exception):
    """Error during S-expression parsing.

    Carries the raw character offset plus 1-based line/column so callers
    can point at the offending spot; str(err) includes the location.
    """

    def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1):
        located = f"{message} at line {line}, column {col}"
        super().__init__(located)
        self.position, self.line, self.col = position, line, col
class Tokenizer:
"""Tokenize S-expression text into tokens."""
# Token patterns
WHITESPACE = re.compile(r'\s+')
COMMENT = re.compile(r';[^\n]*')
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*')
def __init__(self, text: str):
self.text = text
self.pos = 0
self.line = 1
self.col = 1
def _advance(self, count: int = 1):
"""Advance position, tracking line/column."""
for _ in range(count):
if self.pos < len(self.text):
if self.text[self.pos] == '\n':
self.line += 1
self.col = 1
else:
self.col += 1
self.pos += 1
def _skip_whitespace_and_comments(self):
"""Skip whitespace and comments."""
while self.pos < len(self.text):
# Whitespace
match = self.WHITESPACE.match(self.text, self.pos)
if match:
self._advance(match.end() - self.pos)
continue
# Comments
match = self.COMMENT.match(self.text, self.pos)
if match:
self._advance(match.end() - self.pos)
continue
break
def peek(self) -> str | None:
"""Peek at current character."""
self._skip_whitespace_and_comments()
if self.pos >= len(self.text):
return None
return self.text[self.pos]
def next_token(self) -> Any:
"""Get the next token."""
self._skip_whitespace_and_comments()
if self.pos >= len(self.text):
return None
char = self.text[self.pos]
start_line, start_col = self.line, self.col
# Single-character tokens (parens, brackets, braces)
if char in '()[]{}':
self._advance()
return char
# String
if char == '"':
match = self.STRING.match(self.text, self.pos)
if not match:
raise ParseError("Unterminated string", self.pos, self.line, self.col)
self._advance(match.end() - self.pos)
# Parse escape sequences
content = match.group()[1:-1]
content = content.replace('\\n', '\n')
content = content.replace('\\t', '\t')
content = content.replace('\\"', '"')
content = content.replace('\\\\', '\\')
return content
# Keyword
if char == ':':
match = self.KEYWORD.match(self.text, self.pos)
if match:
self._advance(match.end() - self.pos)
return Keyword(match.group()[1:]) # Strip leading colon
raise ParseError(f"Invalid keyword", self.pos, self.line, self.col)
# Number (must check before symbol due to - prefix)
if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and
(self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')):
match = self.NUMBER.match(self.text, self.pos)
if match:
self._advance(match.end() - self.pos)
num_str = match.group()
if '.' in num_str or 'e' in num_str or 'E' in num_str:
return float(num_str)
return int(num_str)
# Symbol
match = self.SYMBOL.match(self.text, self.pos)
if match:
self._advance(match.end() - self.pos)
return Symbol(match.group())
raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col)
def parse(text: str) -> Any:
    """
    Parse one S-expression string into Python data structures.

    Lists/vectors become Python lists, maps become dicts, symbols become
    Symbol, keywords become Keyword, and strings/numbers become native
    Python types.

    Raises:
        ParseError: on malformed input, or when content remains after
            the first complete expression.

    Example:
        >>> parse('(recipe "test" :version "1.0")')
        [Symbol('recipe'), 'test', Keyword('version'), '1.0']
    """
    stream = Tokenizer(text)
    expr = _parse_expr(stream)
    # Reject trailing content beyond the first expression
    if stream.peek() is not None:
        raise ParseError(
            "Unexpected content after expression",
            stream.pos, stream.line, stream.col,
        )
    return expr
def parse_all(text: str) -> List[Any]:
    """Parse every S-expression in `text` and return them as a list."""
    stream = Tokenizer(text)
    expressions: List[Any] = []
    while stream.peek() is not None:
        expressions.append(_parse_expr(stream))
    return expressions
def _parse_expr(tokenizer: Tokenizer) -> Any:
    """Parse one expression starting at the tokenizer's current position."""
    token = tokenizer.next_token()
    if token is None:
        raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col)
    # Lists and vectors share the same representation; only the closer differs.
    if token == '(':
        return _parse_list(tokenizer, ')')
    if token == '[':
        return _parse_list(tokenizer, ']')
    # Map/dict: {:key1 val1 :key2 val2}
    if token == '{':
        return _parse_map(tokenizer)
    # A closer with no matching opener is an error.
    if token in (')', ']', '}'):
        raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col)
    # Anything else is already an atom (str, number, Symbol, Keyword).
    return token
def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]:
    """Parse expressions up to (and consuming) the closing delimiter."""
    items: List[Any] = []
    while (ahead := tokenizer.peek()) != closer:
        if ahead is None:
            raise ParseError(f"Unterminated list, expected {closer!r}",
                             tokenizer.pos, tokenizer.line, tokenizer.col)
        items.append(_parse_expr(tokenizer))
    tokenizer.next_token()  # consume the closing delimiter
    return items
def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]:
    """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}."""
    mapping: Dict[str, Any] = {}
    while (ahead := tokenizer.peek()) != '}':
        if ahead is None:
            raise ParseError("Unterminated map, expected '}'",
                             tokenizer.pos, tokenizer.line, tokenizer.col)
        # Keys must be keywords (:key) or plain strings; both map to str keys.
        key_expr = _parse_expr(tokenizer)
        if isinstance(key_expr, Keyword):
            key = key_expr.name
        elif isinstance(key_expr, str):
            key = key_expr
        else:
            raise ParseError(f"Map key must be keyword or string, got {type(key_expr).__name__}",
                             tokenizer.pos, tokenizer.line, tokenizer.col)
        mapping[key] = _parse_expr(tokenizer)
    tokenizer.next_token()  # consume '}'
    return mapping
def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str:
    """
    Serialize a Python data structure back to S-expression format.

    Args:
        expr: The expression to serialize
        indent: Current indentation level (for pretty printing)
        pretty: Whether to use pretty printing with newlines

    Returns:
        S-expression string

    Raises:
        ValueError: for types with no S-expression representation.
    """
    # Dispatch order matters: bool must precede int (bool is an int subclass).
    if isinstance(expr, list):
        if not expr:
            return "()"
        if pretty:
            return _serialize_pretty(expr, indent)
        return "(" + " ".join(serialize(item, indent, False) for item in expr) + ")"
    if isinstance(expr, Symbol):
        return expr.name
    if isinstance(expr, Keyword):
        return ":" + expr.name
    if isinstance(expr, Lambda):
        return "(fn [{}] {})".format(" ".join(expr.params),
                                     serialize(expr.body, indent, pretty))
    if isinstance(expr, Binding):
        # analysis_ref can be a string, node ID, or dict - serialize it properly
        if isinstance(expr.analysis_ref, str):
            ref_str = f'"{expr.analysis_ref}"'
        else:
            ref_str = serialize(expr.analysis_ref, indent, pretty)
        out = f"(bind {ref_str} :range [{expr.range_min} {expr.range_max}]"
        if expr.transform:
            out += f" :transform {expr.transform}"
        return out + ")"
    if isinstance(expr, str):
        # Backslash must be escaped first, then the other specials.
        body = (expr.replace('\\', '\\\\')
                    .replace('"', '\\"')
                    .replace('\n', '\\n')
                    .replace('\t', '\\t'))
        return f'"{body}"'
    if isinstance(expr, bool):
        return "true" if expr else "false"
    if isinstance(expr, (int, float)):
        return str(expr)
    if expr is None:
        return "nil"
    if isinstance(expr, dict):
        # Serialize dict as property list: {:key1 val1 :key2 val2}
        pieces: List[str] = []
        for key, val in expr.items():
            pieces.append(f":{key}")
            pieces.append(serialize(val, indent, pretty))
        return "{" + " ".join(pieces) + "}"
    raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}")
def _serialize_pretty(expr: List, indent: int) -> str:
    """Pretty-print a list expression, grouping :keyword value pairs per line."""
    if not expr:
        return "()"
    inner = " " * (indent + 1)
    # Short expressions stay on a single line.
    flat = serialize(expr, indent, False)
    if len(flat) < 60 and '\n' not in flat:
        return flat
    # Multiline form: head on the first line, one item (or kv pair) per line.
    lines = ["(" + serialize(expr[0], indent + 1, False)]
    i = 1
    while i < len(expr):
        item = expr[i]
        if isinstance(item, Keyword) and i + 1 < len(expr):
            key = serialize(item, 0, False)
            compact = serialize(expr[i + 1], indent + 1, False)
            if len(compact) < 50 and '\n' not in compact:
                # Short value: keep the pair on one line.
                lines.append(f"{inner}{key} {compact}")
            else:
                # Complex value: recurse with pretty printing.
                lines.append(f"{inner}{key} {serialize(expr[i + 1], indent + 1, True)}")
            i += 2
        else:
            lines.append(inner + serialize(item, indent + 1, True))
            i += 1
    return "\n".join(lines) + ")"

2187
artdag/sexp/planner.py Normal file

File diff suppressed because it is too large Load Diff

620
artdag/sexp/primitives.py Normal file
View File

@@ -0,0 +1,620 @@
"""
Frame processing primitives for sexp effects.
These primitives are called by sexp effect definitions and operate on
numpy arrays (frames). They're used when falling back to Python rendering
instead of FFmpeg.
Required: numpy, PIL
"""
import math
from typing import Any, Dict, List, Optional, Tuple
try:
import numpy as np
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
np = None
try:
from PIL import Image, ImageDraw, ImageFont
HAS_PIL = True
except ImportError:
HAS_PIL = False
# ASCII character sets for different styles, ordered dark -> bright.
ASCII_ALPHABETS = {
    "standard": " .:-=+*#%@",  # classic 10-level luminance ramp
    "blocks": " ░▒▓█",  # Unicode block-shading characters
    "simple": " .-:+=xX#",
    "detailed": " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$",
    "binary": "",  # NOTE(review): empty string -- indexing into it raises IndexError; confirm intended
}
def check_deps():
    """Raise ImportError unless both numpy and PIL are importable."""
    for available, message in (
        (HAS_NUMPY, "numpy required for frame primitives: pip install numpy"),
        (HAS_PIL, "PIL required for frame primitives: pip install Pillow"),
    ):
        if not available:
            raise ImportError(message)
def frame_to_image(frame: np.ndarray) -> Image.Image:
    """Convert numpy frame to PIL Image.

    Non-uint8 input is clipped to [0, 255] and cast before conversion.
    """
    check_deps()
    if frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return Image.fromarray(frame)
def image_to_frame(img: Image.Image) -> np.ndarray:
    """Return the pixel data of a PIL Image as a numpy frame."""
    check_deps()
    frame = np.array(img)
    return frame
# ============================================================================
# ASCII Art Primitives
# ============================================================================
def cell_sample(frame: np.ndarray, cell_size: int = 8) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample frame into cells, returning average colors and luminances.

    The frame is divided into a grid of cell_size x cell_size blocks;
    trailing rows/columns that do not fill a whole cell are ignored,
    matching the previous per-cell loop behavior.

    Args:
        frame: Input frame (H, W, C) with at least 3 channels
        cell_size: Size of each cell in pixels

    Returns:
        (colors, luminances) - colors is (rows, cols, 3) float32,
        luminances is (rows, cols) float32 using ITU-R BT.601 weights
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = h // cell_size
    cols = w // cell_size
    if rows == 0 or cols == 0:
        # Frame smaller than one cell: return empty grids rather than
        # attempting a zero-size reshape.
        return (np.zeros((rows, cols, 3), dtype=np.float32),
                np.zeros((rows, cols), dtype=np.float32))
    # Vectorized block averaging: crop to whole cells, then reshape so each
    # cell becomes its own (cell_size, cell_size) axis pair and reduce.
    # This replaces a Python double loop over every cell.
    cropped = frame[:rows * cell_size, :cols * cell_size].astype(np.float32)
    blocks = cropped.reshape(rows, cell_size, cols, cell_size, -1)
    avg = blocks.mean(axis=(1, 3))  # (rows, cols, C)
    colors = np.ascontiguousarray(avg[:, :, :3], dtype=np.float32)  # RGB only
    # Luminance (ITU-R BT.601)
    luminances = (0.299 * colors[:, :, 0]
                  + 0.587 * colors[:, :, 1]
                  + 0.114 * colors[:, :, 2]).astype(np.float32)
    return colors, luminances
def luminance_to_chars(
    luminances: np.ndarray,
    alphabet: str = "standard",
    contrast: float = 1.0,
) -> List[List[str]]:
    """
    Convert luminance values to ASCII characters.

    Args:
        luminances: 2D array of luminance values (0-255)
        alphabet: Name of a set in ASCII_ALPHABETS or a custom string
        contrast: Contrast multiplier applied around the 128 midpoint

    Returns:
        2D list of characters (rows x cols)

    Raises:
        ValueError: If the resolved alphabet is empty (previously this
            crashed with an opaque IndexError, e.g. for the built-in
            "binary" alphabet which is an empty string).
    """
    check_deps()
    chars = ASCII_ALPHABETS.get(alphabet, alphabet)
    n_chars = len(chars)
    if n_chars == 0:
        raise ValueError(f"alphabet {alphabet!r} resolves to an empty character set")
    # Apply contrast around the midpoint, clip, then map 0-255 onto a
    # character index (vectorized; truncation matches the old int() cast
    # since values are non-negative).
    adjusted = np.clip(128 + (luminances - 128) * contrast, 0, 255)
    indices = np.minimum((adjusted / 256 * n_chars).astype(int), n_chars - 1)
    return [[chars[idx] for idx in row] for row in indices]
def render_char_grid(
    frame: np.ndarray,
    chars: List[List[str]],
    colors: np.ndarray,
    char_size: int = 8,
    color_mode: str = "color",
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Render character grid to an image.

    Args:
        frame: Original frame (used only for output dimensions)
        chars: 2D list of characters
        colors: Color for each cell (rows, cols, 3)
        char_size: Size of each character cell in pixels
        color_mode: "color" (per-cell average color), "white", or "green"
        background: Background RGB color

    Returns:
        Rendered RGB frame with the same width/height as the input
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = len(chars)
    cols = len(chars[0]) if chars else 0
    # Create output image filled with the background color
    img = Image.new("RGB", (w, h), background)
    draw = ImageDraw.Draw(img)
    # Try to get a monospace font; these paths are Linux-specific.
    # NOTE(review): on other systems this falls through to PIL's small
    # bitmap default font, which ignores char_size -- confirm acceptable.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size)
    except (IOError, OSError):
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", char_size)
        except (IOError, OSError):
            font = ImageFont.load_default()
    for r in range(rows):
        for c in range(cols):
            char = chars[r][c]
            if char == ' ':
                # Spaces are just background; skip the draw call.
                continue
            x = c * char_size
            y = r * char_size
            if color_mode == "color":
                color = tuple(int(v) for v in colors[r, c])
            elif color_mode == "green":
                color = (0, 255, 0)
            else:  # white
                color = (255, 255, 255)
            draw.text((x, y), char, fill=color, font=font)
    return np.array(img)
def ascii_art_frame(
    frame: np.ndarray,
    char_size: int = 8,
    alphabet: str = "standard",
    color_mode: str = "color",
    contrast: float = 1.5,
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Apply ASCII art effect to a frame: sample cells, map luminance to
    characters, then render the character grid.

    This is the main entry point for the ascii_art effect.
    """
    check_deps()
    cell_colors, cell_lums = cell_sample(frame, char_size)
    grid = luminance_to_chars(cell_lums, alphabet, contrast)
    return render_char_grid(frame, grid, cell_colors, char_size, color_mode, background)
# ============================================================================
# ASCII Zones Primitives
# ============================================================================
def ascii_zones_frame(
    frame: np.ndarray,
    char_size: int = 8,
    zone_threshold: int = 128,
    dark_chars: str = " .-:",
    light_chars: str = "=+*#",
) -> np.ndarray:
    """
    Apply zone-based ASCII art: cells below zone_threshold draw from
    dark_chars, the rest from light_chars, each scaled within its zone.
    """
    check_deps()
    colors, luminances = cell_sample(frame, char_size)
    n_rows, n_cols = luminances.shape

    def pick(lum):
        # Select the zone's charset and the 0-1 position within that zone.
        if lum < zone_threshold:
            charset, frac = dark_chars, lum / zone_threshold
        else:
            charset, frac = light_chars, (lum - zone_threshold) / (255 - zone_threshold)
        return charset[min(int(frac * len(charset)), len(charset) - 1)]

    grid = [[pick(luminances[r, c]) for c in range(n_cols)] for r in range(n_rows)]
    return render_char_grid(frame, grid, colors, char_size, "color", (0, 0, 0))
# ============================================================================
# Kaleidoscope Primitives (Python fallback)
# ============================================================================
def kaleidoscope_displace(
    w: int,
    h: int,
    segments: int = 6,
    rotation: float = 0,
    cx: Optional[float] = None,
    cy: Optional[float] = None,
    zoom: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute kaleidoscope displacement coordinates.

    Args:
        w, h: Output dimensions in pixels.
        segments: Number of mirror segments around the center.
        rotation: Rotation of the segment pattern in degrees.
        cx, cy: Effect center in pixels; defaults to the image center.
        zoom: Radial zoom factor (> 1 magnifies).

    Returns (x_coords, y_coords) arrays for remapping.
    """
    check_deps()
    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2
    # Create coordinate grids
    y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
    # Center coordinates
    x_centered = x_grid - cx
    y_centered = y_grid - cy
    # Convert to polar
    r = np.sqrt(x_centered**2 + y_centered**2) / zoom
    theta = np.arctan2(y_centered, x_centered)
    # Apply rotation
    theta = theta - np.radians(rotation)
    # Kaleidoscope: fold the angle into one mirrored half-segment
    segment_angle = 2 * np.pi / segments
    theta = np.abs(np.mod(theta, segment_angle) - segment_angle / 2)
    # Convert back to cartesian
    x_new = r * np.cos(theta) + cx
    y_new = r * np.sin(theta) + cy
    return x_new, y_new
def remap(
    frame: np.ndarray,
    x_coords: np.ndarray,
    y_coords: np.ndarray,
) -> np.ndarray:
    """
    Resample frame at the given coordinate arrays using bilinear
    interpolation (scipy map_coordinates with reflect boundary).
    """
    check_deps()
    from scipy import ndimage
    h, w = frame.shape[:2]
    # Keep sample coordinates inside the frame.
    xs = np.clip(x_coords, 0, w - 1)
    ys = np.clip(y_coords, 0, h - 1)
    if frame.ndim != 3:
        return ndimage.map_coordinates(frame, [ys, xs], order=1, mode='reflect')
    out = np.zeros_like(frame)
    for ch in range(frame.shape[2]):
        out[:, :, ch] = ndimage.map_coordinates(
            frame[:, :, ch],
            [ys, xs],
            order=1,
            mode='reflect',
        )
    return out
def kaleidoscope_frame(
    frame: np.ndarray,
    segments: int = 6,
    rotation: float = 0,
    center_x: float = 0.5,
    center_y: float = 0.5,
    zoom: float = 1.0,
) -> np.ndarray:
    """
    Apply kaleidoscope effect to a frame.

    Python fallback path; the FFmpeg implementation is faster.
    """
    check_deps()
    h, w = frame.shape[:2]
    coords = kaleidoscope_displace(
        w, h, segments, rotation, w * center_x, h * center_y, zoom
    )
    return remap(frame, *coords)
# ============================================================================
# Datamosh Primitives (simplified Python version)
# ============================================================================
def datamosh_frame(
    frame: np.ndarray,
    prev_frame: Optional[np.ndarray],
    block_size: int = 32,
    corruption: float = 0.3,
    max_offset: int = 50,
    color_corrupt: bool = True,
) -> np.ndarray:
    """
    Simplified datamosh effect using block displacement.

    This is a basic approximation - real datamosh works on compressed video.

    Args:
        frame: Current frame (H, W, C)
        prev_frame: Previous frame, or None on the first frame
        block_size: Side length of the square displacement blocks
        corruption: Probability that any given block is displaced
        max_offset: Maximum displacement in pixels along each axis
        color_corrupt: Randomly roll color channels on some displaced blocks

    Returns:
        New frame with moshed blocks (the input frame is not modified)

    Note:
        Uses the global numpy RNG; output is non-deterministic unless the
        caller seeds np.random.
    """
    check_deps()
    if prev_frame is None:
        # Nothing to displace from on the first frame.
        return frame.copy()
    h, w = frame.shape[:2]
    result = frame.copy()
    # Process in blocks; trailing partial blocks are left untouched.
    for y in range(0, h - block_size, block_size):
        for x in range(0, w - block_size, block_size):
            if np.random.random() < corruption:
                # Random offset
                ox = np.random.randint(-max_offset, max_offset + 1)
                oy = np.random.randint(-max_offset, max_offset + 1)
                # Source from previous frame with offset, clamped in-bounds
                src_y = np.clip(y + oy, 0, h - block_size)
                src_x = np.clip(x + ox, 0, w - block_size)
                block = prev_frame[src_y:src_y+block_size, src_x:src_x+block_size]
                # Color corruption
                if color_corrupt and np.random.random() < 0.3:
                    # Swap or shift channels
                    block = np.roll(block, np.random.randint(1, 3), axis=2)
                result[y:y+block_size, x:x+block_size] = block
    return result
# ============================================================================
# Pixelsort Primitives (Python version)
# ============================================================================
def pixelsort_frame(
    frame: np.ndarray,
    sort_by: str = "lightness",
    threshold_low: float = 50,
    threshold_high: float = 200,
    angle: float = 0,
    reverse: bool = False,
) -> np.ndarray:
    """
    Apply pixel sorting effect to a frame.

    Pixels whose sort key falls within [threshold_low, threshold_high]
    form horizontal runs that are sorted by that key. A nonzero angle
    rotates the frame first so the sort runs along that direction.

    Args:
        frame: Input frame (H, W, C)
        sort_by: "lightness", "hue", "saturation", or anything else for
            the red channel
        threshold_low, threshold_high: Inclusive key range for sortable pixels.
            NOTE(review): hue keys are in radians and saturation in [0, 1],
            so the 0-255 style defaults only make sense for "lightness" --
            confirm callers adjust thresholds per mode.
        angle: Sort direction in degrees (frame is rotated and rotated back)
        reverse: Sort descending instead of ascending
    """
    check_deps()
    from scipy import ndimage
    # Rotate if needed so runs are horizontal in the working frame
    if angle != 0:
        frame = ndimage.rotate(frame, -angle, reshape=False, mode='reflect')
    h, w = frame.shape[:2]
    result = frame.copy()
    # Compute per-pixel sort key
    if sort_by == "lightness":
        key = 0.299 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2]
    elif sort_by == "hue":
        # Simple hue approximation (radians from RGB geometry)
        key = np.arctan2(
            np.sqrt(3) * (frame[:,:,1].astype(float) - frame[:,:,2]),
            2 * frame[:,:,0].astype(float) - frame[:,:,1] - frame[:,:,2]
        )
    elif sort_by == "saturation":
        mx = frame.max(axis=2).astype(float)
        mn = frame.min(axis=2).astype(float)
        key = np.where(mx > 0, (mx - mn) / mx, 0)
    else:
        key = frame[:,:,0]  # Red channel
    # Sort each row independently
    for y in range(h):
        row = result[y]
        row_key = key[y]
        # Find sortable intervals (pixels within threshold)
        mask = (row_key >= threshold_low) & (row_key <= threshold_high)
        # Find runs of True in mask; runs of length 1 are skipped
        runs = []
        start = None
        for x in range(w):
            if mask[x] and start is None:
                start = x
            elif not mask[x] and start is not None:
                if x - start > 1:
                    runs.append((start, x))
                start = None
        if start is not None and w - start > 1:
            # Run extends to the right edge of the row
            runs.append((start, w))
        # Sort each run by its key values
        for start, end in runs:
            indices = np.argsort(row_key[start:end])
            if reverse:
                indices = indices[::-1]
            # Fancy indexing copies the run before writing it back in place
            result[y, start:end] = row[start:end][indices]
    # Rotate back to the original orientation
    if angle != 0:
        result = ndimage.rotate(result, angle, reshape=False, mode='reflect')
    return result
# ============================================================================
# Primitive Registry
# ============================================================================
def map_char_grid(
    chars,
    luminances,
    fn,
):
    """
    Map a function over each cell of a character grid.

    Args:
        chars: 2D array/list of characters (rows, cols); rows may also be
            strings, indexed per character
        luminances: 2D array of luminance values (1D arrays are treated as
            one value per row)
        fn: Function or sexp Lambda (row, col, char, luminance) -> new_char

    Returns:
        New character grid with mapped values (list of lists)
    """
    from .parser import Lambda
    from .evaluator import evaluate
    # Handle both list and numpy array inputs
    if isinstance(chars, np.ndarray):
        rows, cols = chars.shape[:2]
    else:
        rows = len(chars)
        # NOTE(review): cols falls back to 1 when the first row is not a
        # sequence -- confirm that degenerate case actually occurs.
        cols = len(chars[0]) if rows > 0 and isinstance(chars[0], (list, tuple, str)) else 1
    # Get luminances as 2D
    if isinstance(luminances, np.ndarray):
        lum_arr = luminances
    else:
        lum_arr = np.array(luminances)
    # Check if fn is a Lambda (from sexp) or a Python callable
    is_lambda = isinstance(fn, Lambda)
    result = []
    for r in range(rows):
        row_result = []
        for c in range(cols):
            # Get character; out-of-range columns default to a space
            if isinstance(chars, np.ndarray):
                ch = chars[r, c] if len(chars.shape) > 1 else chars[r]
            elif isinstance(chars[r], str):
                ch = chars[r][c] if c < len(chars[r]) else ' '
            else:
                ch = chars[r][c] if c < len(chars[r]) else ' '
            # Get luminance (per-cell for 2D input, per-row for 1D)
            if len(lum_arr.shape) > 1:
                lum = lum_arr[r, c]
            else:
                lum = lum_arr[r]
            # Call the function
            if is_lambda:
                # Evaluate the sexp Lambda with arguments bound into a
                # fresh environment derived from its closure
                call_env = dict(fn.closure) if fn.closure else {}
                for param, val in zip(fn.params, [r, c, ch, float(lum)]):
                    call_env[param] = val
                new_ch = evaluate(fn.body, call_env)
            else:
                new_ch = fn(r, c, ch, float(lum))
            row_result.append(new_ch)
        result.append(row_result)
    return result
def alphabet_char(alphabet: str, index: int) -> str:
    """
    Get a character from an alphabet at a given index.

    Args:
        alphabet: Alphabet name (from ASCII_ALPHABETS) or literal string
        index: Index into the alphabet (clamped to valid range)

    Returns:
        Character at the index, or a space when the alphabet is empty
        (e.g. the built-in "binary" alphabet). Previously an empty
        alphabet crashed with IndexError; a space renders as blank in
        render_char_grid, which is the safe degenerate output.
    """
    # Resolve named alphabets; unknown names are treated as literal strings
    chars = ASCII_ALPHABETS.get(alphabet, alphabet)
    if not chars:
        return ' '
    # Clamp index to valid range
    index = max(0, min(int(index), len(chars) - 1))
    return chars[index]
# Registry mapping sexp primitive names to their Python implementations.
# Some primitives are registered under both kebab-case and snake_case names
# so effect definitions can use either spelling.
PRIMITIVES = {
    # ASCII
    "cell-sample": cell_sample,
    "luminance-to-chars": luminance_to_chars,
    "render-char-grid": render_char_grid,
    "map-char-grid": map_char_grid,
    "alphabet-char": alphabet_char,
    "ascii_art_frame": ascii_art_frame,
    "ascii_zones_frame": ascii_zones_frame,
    # Kaleidoscope
    "kaleidoscope-displace": kaleidoscope_displace,
    "remap": remap,
    "kaleidoscope_frame": kaleidoscope_frame,
    # Datamosh (alias + canonical name)
    "datamosh": datamosh_frame,
    "datamosh_frame": datamosh_frame,
    # Pixelsort (alias + canonical name)
    "pixelsort": pixelsort_frame,
    "pixelsort_frame": pixelsort_frame,
}
def get_primitive(name: str):
    """Look up a primitive implementation by name; None if unknown."""
    try:
        return PRIMITIVES[name]
    except KeyError:
        return None
def list_primitives() -> List[str]:
    """Return the names of all registered primitives."""
    return [name for name in PRIMITIVES]

779
artdag/sexp/scheduler.py Normal file
View File

@@ -0,0 +1,779 @@
"""
Celery scheduler for S-expression execution plans.
Distributes plan steps to workers as S-expressions.
The S-expression is the canonical format - workers receive
serialized S-expressions and can verify cache_ids by hashing them.
Usage:
from artdag.sexp import compile_string, create_plan
from artdag.sexp.scheduler import schedule_plan
recipe = compile_string(sexp_content)
plan = create_plan(recipe, inputs={'video': 'abc123...'})
result = schedule_plan(plan)
"""
import hashlib
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Callable
from .parser import Symbol, Keyword, serialize, parse
from .planner import ExecutionPlanSexp, PlanStep
logger = logging.getLogger(__name__)
@dataclass
class StepResult:
    """Result from executing a step."""
    step_id: str  # plan-local identifier of the step
    cache_id: str  # content hash identifying the step's output
    status: str  # 'completed', 'cached', 'failed', 'pending'
    output_path: Optional[str] = None  # local path to the step output, if any
    error: Optional[str] = None  # error message when status == 'failed'
    ipfs_cid: Optional[str] = None  # IPFS CID of the output, if reported by the worker
@dataclass
class PlanResult:
    """Result from executing a complete plan."""
    plan_id: str  # identifier of the executed plan
    status: str  # 'completed', 'failed', 'partial'
    steps_completed: int = 0  # steps executed by workers this run
    steps_cached: int = 0  # steps skipped due to cache hits
    steps_failed: int = 0  # steps that reported failure
    output_cache_id: Optional[str] = None  # cache_id of the plan's output step
    output_path: Optional[str] = None  # local path of the final output, if known
    output_ipfs_cid: Optional[str] = None  # IPFS CID of the final output, if any
    step_results: Dict[str, StepResult] = field(default_factory=dict)  # keyed by step_id
    error: Optional[str] = None  # first error encountered, when failed
def step_to_sexp(step: PlanStep) -> List:
    """
    Convert a PlanStep to minimal S-expression for worker.

    This is the canonical form that workers receive; workers can verify
    the cache_id by hashing this S-expression.
    """
    parts: List = [Symbol(step.node_type.lower())]
    # Config entries become keyword/value pairs (snake_case -> kebab-case)
    for key, value in step.config.items():
        parts.append(Keyword(key.replace('_', '-')))
        parts.append(value)
    # Inputs are attached as cache IDs (not step IDs)
    if step.inputs:
        parts.append(Keyword("inputs"))
        parts.append(step.inputs)
    return parts
def step_sexp_to_string(step: PlanStep) -> str:
    """Render a step's canonical S-expression as a string for dispatch."""
    sexp = step_to_sexp(step)
    return serialize(sexp)
def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: Optional[str] = None) -> bool:
    """
    Verify that a step's cache_id matches its S-expression.

    The cache_id is the SHA3-256 hex digest of the canonical JSON
    (sorted keys, no whitespace) of {"sexp": <sexp>}, optionally wrapped
    as {"_cluster_key": <key>, "_data": {...}} when a cluster key is set.

    Workers should call this to verify they're executing the correct task.

    Args:
        step_sexp: Serialized step S-expression.
        expected_cache_id: cache_id the step was dispatched with.
        cluster_key: Optional cluster namespace key used at plan time.
            (Annotation fixed: the default is None, so this is Optional.)

    Returns:
        True if the computed hash equals expected_cache_id.
    """
    data: Dict[str, Any] = {"sexp": step_sexp}
    if cluster_key:
        data = {"_cluster_key": cluster_key, "_data": data}
    # Canonical JSON encoding -- must match the planner's hashing exactly.
    json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
    computed = hashlib.sha3_256(json_str.encode()).hexdigest()
    return computed == expected_cache_id
class PlanScheduler:
    """
    Schedules execution of S-expression plans on Celery workers.

    The scheduler:
    1. Groups steps by dependency level
    2. Checks cache for already-computed results
    3. Dispatches uncached steps to workers
    4. Waits for completion before proceeding to next level
    """

    def __init__(
        self,
        cache_manager=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the scheduler.

        Args:
            cache_manager: L1 cache manager for checking cached results
                (None disables cache checks; every step is dispatched)
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> PlanResult:
        """
        Schedule and execute a plan.

        Levels are executed sequentially; steps within one level run in
        parallel as a Celery group. Execution aborts at the first failed
        level.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds applied to each level's group wait

        Returns:
            PlanResult with execution results
        """
        from celery import group
        logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")
        # Build step lookup and group by level
        steps_by_id = {s.step_id: s for s in plan.steps}
        steps_by_level = self._group_by_level(plan.steps)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0
        # Track results
        result = PlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )
        # Map step_id -> cache_id for resolving inputs
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id
        # Execute level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue
            logger.info(f"Level {level}: {len(level_steps)} steps")
            # Check cache for each step; cache hits are recorded, not dispatched
            steps_to_run = []
            for step in level_steps:
                if self._is_cached(step.cache_id):
                    result.steps_cached += 1
                    result.step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)
            if not steps_to_run:
                logger.info(f"Level {level}: all {len(level_steps)} steps cached")
                continue
            # Dispatch uncached steps to workers
            logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps")
            tasks = []
            for step in steps_to_run:
                # Build task arguments: canonical sexp plus resolved input ids
                step_sexp = step_sexp_to_string(step)
                input_cache_ids = {
                    inp: cache_ids.get(inp, inp)
                    for inp in step.inputs
                }
                task = self._get_execute_task().s(
                    step_sexp=step_sexp,
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    plan_id=plan.plan_id,
                    input_cache_ids=input_cache_ids,
                )
                tasks.append(task)
            # Execute in parallel and block until the whole level finishes
            job = group(tasks)
            async_result = job.apply_async()
            try:
                step_results = async_result.get(timeout=timeout)
            except Exception as e:
                logger.error(f"Level {level} failed: {e}")
                result.status = "failed"
                result.error = f"Level {level} failed: {e}"
                return result
            # Process results reported by the workers
            for step_result in step_results:
                step_id = step_result.get("step_id")
                status = step_result.get("status")
                result.step_results[step_id] = StepResult(
                    step_id=step_id,
                    cache_id=step_result.get("cache_id"),
                    status=status,
                    output_path=step_result.get("output_path"),
                    error=step_result.get("error"),
                    ipfs_cid=step_result.get("ipfs_cid"),
                )
                if status in ("completed", "cached", "completed_by_other"):
                    result.steps_completed += 1
                elif status == "failed":
                    result.steps_failed += 1
                    result.status = "failed"
                    result.error = step_result.get("error")
                    return result
        # Get final output from the plan's designated output step
        output_step = steps_by_id.get(plan.output_step_id)
        if output_step:
            output_result = result.step_results.get(output_step.step_id)
            if output_result:
                result.output_cache_id = output_step.cache_id
                result.output_path = output_result.output_path
                result.output_ipfs_cid = output_result.ipfs_cid
        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.steps_completed} executed, {result.steps_cached} cached"
        )
        return result

    def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]:
        """Group steps by dependency level."""
        by_level = {}
        for step in steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def _is_cached(self, cache_id: str) -> bool:
        """Check if a cache_id exists in cache (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached item, if present."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _get_execute_task(self):
        """Get the Celery task for step execution.

        Raises:
            RuntimeError: If no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        return self.celery_app.tasks[self.execute_task_name]
def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler:
    """
    Build a PlanScheduler, lazily importing defaults from art-celery
    for any dependency not supplied by the caller.
    """
    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass  # run cache-less; every step will be treated as uncached
    if celery_app is None:
        try:
            from celery_app import app as celery_app
        except ImportError:
            pass  # caller must supply an app before dispatching
    return PlanScheduler(
        cache_manager=cache_manager,
        celery_app=celery_app,
    )
def schedule_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    celery_app=None,
    timeout: int = 3600,
) -> PlanResult:
    """
    Convenience wrapper: build a scheduler and execute the plan on it.

    Args:
        plan: The execution plan
        cache_manager: Optional cache manager
        celery_app: Optional Celery app
        timeout: Execution timeout in seconds

    Returns:
        PlanResult
    """
    return create_scheduler(cache_manager, celery_app).schedule(plan, timeout=timeout)
# Stage-aware scheduling
@dataclass
class StageResult:
    """Result from executing a stage."""
    stage_name: str  # name of the stage in the plan
    cache_id: str  # stage-level cache key
    status: str  # 'completed', 'cached', 'failed', 'pending'
    step_results: Dict[str, StepResult] = field(default_factory=dict)  # keyed by step_id
    outputs: Dict[str, str] = field(default_factory=dict)  # binding_name -> cache_id
    error: Optional[str] = None  # error message when status == 'failed'
@dataclass
class StagePlanResult:
    """Result from executing a plan with stages."""
    plan_id: str  # identifier of the executed plan
    status: str  # 'completed', 'failed', 'partial'
    stages_completed: int = 0  # stages executed this run
    stages_cached: int = 0  # stages skipped via stage cache
    stages_failed: int = 0  # stages that failed
    steps_completed: int = 0  # step-level tallies across all stages
    steps_cached: int = 0
    steps_failed: int = 0
    stage_results: Dict[str, StageResult] = field(default_factory=dict)  # keyed by stage_name
    output_cache_id: Optional[str] = None  # cache_id of the final stage
    output_path: Optional[str] = None  # local path of the final output, if known
    error: Optional[str] = None  # first error encountered, when failed
class StagePlanScheduler:
    """
    Stage-aware scheduler for S-expression plans.

    The scheduler:
    1. Groups stages by level (parallel groups)
    2. For each stage level:
       - Check stage cache, skip entire stage if hit
       - Execute stage steps (grouped by level within stage)
       - Cache stage outputs
    3. Stages at same level can run in parallel
    """

    def __init__(
        self,
        cache_manager=None,
        stage_cache=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the stage-aware scheduler.

        Args:
            cache_manager: L1 cache manager for step-level caching
            stage_cache: StageCache instance for stage-level caching
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.stage_cache = stage_cache
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> StagePlanResult:
        """
        Schedule and execute a plan with stage awareness.

        If the plan has stages, uses stage-level scheduling.
        Otherwise, falls back to step-level scheduling.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds for the entire plan

        Returns:
            StagePlanResult with execution results
        """
        # If no stages, use regular scheduling and adapt the result type
        if not plan.stage_plans:
            logger.info("Plan has no stages, using step-level scheduling")
            regular_scheduler = PlanScheduler(
                cache_manager=self.cache_manager,
                celery_app=self.celery_app,
                execute_task_name=self.execute_task_name,
            )
            step_result = regular_scheduler.schedule(plan, timeout)
            return StagePlanResult(
                plan_id=step_result.plan_id,
                status=step_result.status,
                steps_completed=step_result.steps_completed,
                steps_cached=step_result.steps_cached,
                steps_failed=step_result.steps_failed,
                output_cache_id=step_result.output_cache_id,
                output_path=step_result.output_path,
                error=step_result.error,
            )
        logger.info(
            f"Scheduling plan {plan.plan_id[:16]}... "
            f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)"
        )
        result = StagePlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )
        # Group stages by level
        stages_by_level = self._group_stages_by_level(plan.stage_plans)
        max_level = max(stages_by_level.keys()) if stages_by_level else 0
        # Track stage outputs for data flow into later stages
        stage_outputs = {}  # stage_name -> {binding_name -> cache_id}
        # Execute stage by stage level
        for level in range(max_level + 1):
            level_stages = stages_by_level.get(level, [])
            if not level_stages:
                continue
            logger.info(f"Stage level {level}: {len(level_stages)} stages")
            # Check stage cache for each stage; hits reuse recorded outputs
            stages_to_run = []
            for stage_plan in level_stages:
                if self._is_stage_cached(stage_plan.cache_id):
                    result.stages_cached += 1
                    cached_entry = self._load_cached_stage(stage_plan.cache_id)
                    if cached_entry:
                        stage_outputs[stage_plan.stage_name] = {
                            name: out.cache_id
                            for name, out in cached_entry.outputs.items()
                        }
                        result.stage_results[stage_plan.stage_name] = StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="cached",
                            outputs=stage_outputs[stage_plan.stage_name],
                        )
                    logger.info(f"Stage {stage_plan.stage_name}: cached")
                else:
                    stages_to_run.append(stage_plan)
            if not stages_to_run:
                logger.info(f"Stage level {level}: all {len(level_stages)} stages cached")
                continue
            # Execute uncached stages
            # For now, execute sequentially; L1 Celery will add parallel execution
            for stage_plan in stages_to_run:
                logger.info(f"Executing stage: {stage_plan.stage_name}")
                stage_result = self._execute_stage(
                    stage_plan,
                    plan,
                    stage_outputs,
                    timeout,
                )
                result.stage_results[stage_plan.stage_name] = stage_result
                if stage_result.status == "completed":
                    result.stages_completed += 1
                    stage_outputs[stage_plan.stage_name] = stage_result.outputs
                    # Cache the stage result for future runs
                    self._cache_stage(stage_plan, stage_result)
                elif stage_result.status == "failed":
                    # Abort the whole plan on the first failed stage
                    result.stages_failed += 1
                    result.status = "failed"
                    result.error = stage_result.error
                    return result
                # Accumulate step counts into the plan-level tallies
                for sr in stage_result.step_results.values():
                    if sr.status == "completed":
                        result.steps_completed += 1
                    elif sr.status == "cached":
                        result.steps_cached += 1
                    elif sr.status == "failed":
                        result.steps_failed += 1
        # Get final output from the last stage in the plan
        if plan.stage_plans:
            last_stage = plan.stage_plans[-1]
            if last_stage.stage_name in result.stage_results:
                stage_res = result.stage_results[last_stage.stage_name]
                result.output_cache_id = last_stage.cache_id
                # Find the output step's path from step results
                # NOTE(review): takes the last step result with any path,
                # not a specifically designated output step -- confirm.
                for step_res in stage_res.step_results.values():
                    if step_res.output_path:
                        result.output_path = step_res.output_path
        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.stages_completed} stages executed, "
            f"{result.stages_cached} stages cached"
        )
        return result

    def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]:
        """Group stage plans by their level."""
        by_level = {}
        for stage_plan in stage_plans:
            by_level.setdefault(stage_plan.level, []).append(stage_plan)
        return by_level

    def _is_stage_cached(self, cache_id: str) -> bool:
        """Check if a stage is cached (False when no stage cache)."""
        if self.stage_cache is None:
            return False
        return self.stage_cache.has_stage(cache_id)

    def _load_cached_stage(self, cache_id: str):
        """Load a cached stage entry, or None when no stage cache."""
        if self.stage_cache is None:
            return None
        return self.stage_cache.load_stage(cache_id)

    def _cache_stage(self, stage_plan, stage_result: StageResult) -> None:
        """Persist a completed stage's outputs to the stage cache (no-op without one)."""
        if self.stage_cache is None:
            return
        from .stage_cache import StageCacheEntry, StageOutput
        outputs = {}
        for name, cache_id in stage_result.outputs.items():
            outputs[name] = StageOutput(
                cache_id=cache_id,
                output_type="artifact",
            )
        entry = StageCacheEntry(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            outputs=outputs,
        )
        self.stage_cache.save_stage(entry)

    def _execute_stage(
        self,
        stage_plan,
        plan: ExecutionPlanSexp,
        stage_outputs: Dict,
        timeout: int,
    ) -> StageResult:
        """
        Execute a single stage.

        Steps within the stage run level by level; each step is either
        satisfied from cache, dispatched to Celery (when configured), or
        left 'pending'.
        """
        # Create a mini-plan with just this stage's steps
        stage_steps = stage_plan.steps
        # Build step lookup
        # NOTE(review): steps_by_id is currently unused below.
        steps_by_id = {s.step_id: s for s in stage_steps}
        steps_by_level = {}
        for step in stage_steps:
            steps_by_level.setdefault(step.level, []).append(step)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0
        # Track step results
        step_results = {}
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id
        # Include outputs from previous stages so bindings resolve
        for stage_name, outputs in stage_outputs.items():
            for binding_name, binding_cache_id in outputs.items():
                cache_ids[binding_name] = binding_cache_id
        # Execute steps level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue
            # Check cache for each step
            steps_to_run = []
            for step in level_steps:
                if self._is_step_cached(step.cache_id):
                    step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)
            if not steps_to_run:
                continue
            # Execute steps (for now, sequentially - L1 will add Celery dispatch)
            for step in steps_to_run:
                # In a full implementation, this would dispatch to Celery
                # For now, mark as pending
                step_results[step.step_id] = StepResult(
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    status="pending",
                )
                # If Celery is configured, dispatch the task synchronously
                if self.celery_app:
                    try:
                        task_result = self._dispatch_step(step, cache_ids, timeout)
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status=task_result.get("status", "completed"),
                            output_path=task_result.get("output_path"),
                            error=task_result.get("error"),
                            ipfs_cid=task_result.get("ipfs_cid"),
                        )
                    except Exception as e:
                        # First failed dispatch fails the whole stage
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status="failed",
                            error=str(e),
                        )
                        return StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="failed",
                            step_results=step_results,
                            error=str(e),
                        )
        # Build output bindings (fall back to the raw node_id when unresolved)
        outputs = {}
        for out_name, node_id in stage_plan.output_bindings.items():
            outputs[out_name] = cache_ids.get(node_id, node_id)
        return StageResult(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            status="completed",
            step_results=step_results,
            outputs=outputs,
        )

    def _is_step_cached(self, cache_id: str) -> bool:
        """Check if a step is cached (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached step, if present."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict:
        """Dispatch a single step to Celery and block for its result.

        Raises:
            RuntimeError: If no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        task = self.celery_app.tasks[self.execute_task_name]
        step_sexp = step_sexp_to_string(step)
        input_cache_ids = {
            inp: cache_ids.get(inp, inp)
            for inp in step.inputs
        }
        async_result = task.apply_async(
            kwargs={
                "step_sexp": step_sexp,
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "input_cache_ids": input_cache_ids,
            }
        )
        return async_result.get(timeout=timeout)
def create_stage_scheduler(
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
) -> StagePlanScheduler:
    """
    Create a stage-aware scheduler with the given dependencies.

    Dependencies passed as None are best-effort resolved from the default
    application modules; when those modules are not importable the scheduler
    is built without that dependency (step/stage caching or Celery dispatch
    is then simply disabled).

    Args:
        cache_manager: L1 cache manager for step-level caching
        stage_cache: StageCache instance for stage-level caching
        celery_app: Celery application instance
    Returns:
        StagePlanScheduler
    """
    if celery_app is None:
        try:
            # Rebinds the local name `celery_app` to the app object.
            from celery_app import app as celery_app
        except ImportError:
            pass
    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass
    return StagePlanScheduler(
        cache_manager=cache_manager,
        stage_cache=stage_cache,
        celery_app=celery_app,
    )
def schedule_staged_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
    timeout: int = 3600,
) -> StagePlanResult:
    """
    Convenience function to schedule a plan with stage awareness.

    Thin wrapper: builds a StagePlanScheduler via create_stage_scheduler
    (with the same dependency-resolution fallbacks) and runs the plan.

    Args:
        plan: The execution plan
        cache_manager: Optional step-level cache manager
        stage_cache: Optional stage-level cache
        celery_app: Optional Celery app
        timeout: Execution timeout
    Returns:
        StagePlanResult
    """
    scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app)
    return scheduler.schedule(plan, timeout=timeout)

412
artdag/sexp/stage_cache.py Normal file
View File

@@ -0,0 +1,412 @@
"""
Stage-level cache layer using S-expression storage.
Provides caching for stage outputs, enabling:
- Stage-level cache hits (skip entire stage execution)
- Analysis result persistence as sexp
- Cross-worker stage cache sharing (for L1 Celery integration)
All cache files use .sexp extension - no JSON in the pipeline.
"""
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from .parser import Symbol, Keyword, parse, serialize
@dataclass
class StageOutput:
    """A single output produced by a stage.

    Either ``cache_id`` (artifacts / analysis data) or ``value`` (inline
    scalars) is normally populated; ``output_type`` records which flavor
    this output is ("artifact", "analysis", or "scalar").
    """
    cache_id: Optional[str] = None  # set for artifact/analysis outputs
    value: Any = None               # set for scalar outputs
    output_type: str = "artifact"

    def to_sexp(self) -> List:
        """Render as a flat keyword/value S-expression fragment."""
        fragment: List = []
        if self.cache_id:
            fragment += [Keyword("cache-id"), self.cache_id]
        if self.value is not None:
            fragment += [Keyword("value"), self.value]
        fragment += [Keyword("type"), Keyword(self.output_type)]
        return fragment

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageOutput':
        """Rebuild a StageOutput from a keyword/value fragment."""
        cache_id: Optional[str] = None
        value: Any = None
        output_type = "artifact"
        idx = 0
        while idx < len(sexp):
            token = sexp[idx]
            # Consume keyword/value pairs; skip stray or trailing tokens.
            if isinstance(token, Keyword) and idx + 1 < len(sexp):
                payload = sexp[idx + 1]
                if token.name == "cache-id":
                    cache_id = payload
                elif token.name == "value":
                    value = payload
                elif token.name == "type":
                    output_type = payload.name if isinstance(payload, Keyword) else str(payload)
                idx += 2
            else:
                idx += 1
        return cls(cache_id=cache_id, value=value, output_type=output_type)
@dataclass
class StageCacheEntry:
    """Cached result of a single stage execution."""
    stage_name: str
    cache_id: str
    outputs: Dict[str, StageOutput]  # binding name -> output descriptor
    completed_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_sexp(self) -> List:
        """
        Convert to S-expression for storage.
        Format:
        (stage-result
          :name "analyze-a"
          :cache-id "abc123..."
          :completed-at 1705678900.123
          :outputs
          ((beats-a :cache-id "def456..." :type :analysis)
           (tempo :value 120.5 :type :scalar)))
        """
        form: List = [
            Symbol("stage-result"),
            Keyword("name"), self.stage_name,
            Keyword("cache-id"), self.cache_id,
            Keyword("completed-at"), self.completed_at,
        ]
        if self.outputs:
            rendered = [[Symbol(binding)] + out.to_sexp()
                        for binding, out in self.outputs.items()]
            form += [Keyword("outputs"), rendered]
        if self.metadata:
            form += [Keyword("metadata"), self.metadata]
        return form

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to an S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageCacheEntry':
        """Parse a (stage-result ...) form back into an entry.

        Raises:
            ValueError: when the form is not a stage-result or lacks
                the required name/cache-id fields.
        """
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "stage-result":
            raise ValueError("Invalid stage-result sexp")
        stage_name = None
        cache_id = None
        completed_at = time.time()
        outputs: Dict[str, StageOutput] = {}
        metadata: Dict[str, Any] = {}
        idx = 1
        while idx < len(sexp):
            token = sexp[idx]
            if isinstance(token, Keyword) and idx + 1 < len(sexp):
                payload = sexp[idx + 1]
                if token.name == "name":
                    stage_name = payload
                elif token.name == "cache-id":
                    cache_id = payload
                elif token.name == "completed-at":
                    completed_at = float(payload)
                elif token.name == "outputs":
                    if isinstance(payload, list):
                        for entry in payload:
                            # Each entry is (binding-symbol :kw val ...).
                            if isinstance(entry, list) and entry:
                                binding = entry[0]
                                if isinstance(binding, Symbol):
                                    binding = binding.name
                                outputs[binding] = StageOutput.from_sexp(entry[1:])
                elif token.name == "metadata":
                    metadata = payload if isinstance(payload, dict) else {}
                idx += 2
            else:
                idx += 1
        if not stage_name or not cache_id:
            raise ValueError("stage-result missing required fields (name, cache-id)")
        return cls(
            stage_name=stage_name,
            cache_id=cache_id,
            outputs=outputs,
            completed_at=completed_at,
            metadata=metadata,
        )

    @classmethod
    def from_string(cls, text: str) -> 'StageCacheEntry':
        """Parse from an S-expression string."""
        return cls.from_sexp(parse(text))
class StageCache:
"""
Stage-level cache manager using S-expression files.
Cache structure:
cache_dir/
_stages/
{cache_id}.sexp <- Stage result files
"""
def __init__(self, cache_dir: Union[str, Path]):
"""
Initialize stage cache.
Args:
cache_dir: Base cache directory
"""
self.cache_dir = Path(cache_dir)
self.stages_dir = self.cache_dir / "_stages"
self.stages_dir.mkdir(parents=True, exist_ok=True)
def get_cache_path(self, cache_id: str) -> Path:
"""Get the path for a stage cache file."""
return self.stages_dir / f"{cache_id}.sexp"
def has_stage(self, cache_id: str) -> bool:
"""Check if a stage result is cached."""
return self.get_cache_path(cache_id).exists()
def load_stage(self, cache_id: str) -> Optional[StageCacheEntry]:
"""
Load a cached stage result.
Args:
cache_id: Stage cache ID
Returns:
StageCacheEntry if found, None otherwise
"""
path = self.get_cache_path(cache_id)
if not path.exists():
return None
try:
content = path.read_text()
return StageCacheEntry.from_string(content)
except Exception as e:
# Corrupted cache file - log and return None
import sys
print(f"Warning: corrupted stage cache {cache_id}: {e}", file=sys.stderr)
return None
def save_stage(self, entry: StageCacheEntry) -> Path:
"""
Save a stage result to cache.
Args:
entry: Stage cache entry to save
Returns:
Path to the saved cache file
"""
path = self.get_cache_path(entry.cache_id)
content = entry.to_string(pretty=True)
path.write_text(content)
return path
def delete_stage(self, cache_id: str) -> bool:
"""
Delete a cached stage result.
Args:
cache_id: Stage cache ID
Returns:
True if deleted, False if not found
"""
path = self.get_cache_path(cache_id)
if path.exists():
path.unlink()
return True
return False
def list_stages(self) -> List[str]:
"""List all cached stage IDs."""
return [
p.stem for p in self.stages_dir.glob("*.sexp")
]
def clear(self) -> int:
"""
Clear all cached stages.
Returns:
Number of entries cleared
"""
count = 0
for path in self.stages_dir.glob("*.sexp"):
path.unlink()
count += 1
return count
@dataclass
class AnalysisResult:
    """
    Analysis result stored as S-expression.
    Format:
    (analysis-result
    :analyzer "beats"
    :input-hash "abc123..."
    :duration 120.5
    :tempo 128.0
    :times (0.0 0.468 0.937 1.406 ...)
    :values (0.8 0.9 0.7 0.85 ...))
    """
    analyzer: str
    input_hash: str
    data: Dict[str, Any]  # analysis payload (times, values, duration, ...)
    computed_at: float = field(default_factory=time.time)

    def to_sexp(self) -> List:
        """Convert to S-expression; data keys are emitted as kebab-case keywords."""
        form: List = [
            Symbol("analysis-result"),
            Keyword("analyzer"), self.analyzer,
            Keyword("input-hash"), self.input_hash,
            Keyword("computed-at"), self.computed_at,
        ]
        for key, payload in self.data.items():
            # snake_case python keys become kebab-case sexp keywords.
            form += [Keyword(key.replace("_", "-")), payload]
        return form

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'AnalysisResult':
        """Parse an (analysis-result ...) form; unrecognized keys land in ``data``."""
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis-result":
            raise ValueError("Invalid analysis-result sexp")
        analyzer = None
        input_hash = None
        computed_at = time.time()
        data: Dict[str, Any] = {}
        idx = 1
        while idx < len(sexp):
            token = sexp[idx]
            if isinstance(token, Keyword) and idx + 1 < len(sexp):
                payload = sexp[idx + 1]
                if token.name == "analyzer":
                    analyzer = payload
                elif token.name == "input-hash":
                    input_hash = payload
                elif token.name == "computed-at":
                    computed_at = float(payload)
                else:
                    # Any other keyword is analysis data; restore snake_case.
                    data[token.name.replace("-", "_")] = payload
                idx += 2
            else:
                idx += 1
        if not analyzer:
            raise ValueError("analysis-result missing analyzer field")
        return cls(
            analyzer=analyzer,
            input_hash=input_hash or "",
            data=data,
            computed_at=computed_at,
        )

    @classmethod
    def from_string(cls, text: str) -> 'AnalysisResult':
        """Parse from S-expression string."""
        return cls.from_sexp(parse(text))
def save_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
    result: AnalysisResult,
) -> Path:
    """
    Save an analysis result as an S-expression file.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis
        result: Analysis result to save
    Returns:
        Path to the saved file ({cache_dir}/{node_id}/analysis.sexp)
    """
    target_dir = Path(cache_dir) / node_id
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "analysis.sexp"
    target.write_text(result.to_string(pretty=True))
    return target
def load_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
) -> Optional[AnalysisResult]:
    """
    Load an analysis result from cache.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis
    Returns:
        AnalysisResult if found and parseable, None otherwise. A corrupted
        file is reported on stderr and treated as a miss, never raised.
    """
    target = Path(cache_dir) / node_id / "analysis.sexp"
    if not target.exists():
        return None
    try:
        return AnalysisResult.from_string(target.read_text())
    except Exception as e:
        import sys
        print(f"Warning: corrupted analysis cache {node_id}: {e}", file=sys.stderr)
        return None

View File

@@ -0,0 +1,146 @@
"""
Tests for FFmpeg filter compilation.
Validates that each filter mapping produces valid FFmpeg commands.
"""
import subprocess
import tempfile
from pathlib import Path
from .ffmpeg_compiler import FFmpegCompiler, EFFECT_MAPPINGS
def test_filter_syntax(filter_str: str, duration: float = 0.1, is_complex: bool = False) -> tuple[bool, str]:
"""
Test if an FFmpeg filter string is valid by running it on a test pattern.
Args:
filter_str: The filter string to test
duration: Duration of test video
is_complex: If True, use -filter_complex instead of -vf
Returns (success, error_message)
"""
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
output_path = f.name
try:
if is_complex:
# Complex filter graph needs -filter_complex and explicit output mapping
cmd = [
'ffmpeg', '-y',
'-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
'-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
'-filter_complex', f'[0:v]{filter_str}[out]',
'-map', '[out]', '-map', '1:a',
'-c:v', 'libx264', '-preset', 'ultrafast',
'-c:a', 'aac',
'-t', str(duration),
output_path
]
else:
# Simple filter uses -vf
cmd = [
'ffmpeg', '-y',
'-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
'-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
'-vf', filter_str,
'-c:v', 'libx264', '-preset', 'ultrafast',
'-c:a', 'aac',
'-t', str(duration),
output_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0:
return True, ""
else:
# Extract relevant error
stderr = result.stderr
for line in stderr.split('\n'):
if 'Error' in line or 'error' in line or 'Invalid' in line:
return False, line.strip()
return False, stderr[-500:] if len(stderr) > 500 else stderr
except subprocess.TimeoutExpired:
return False, "Timeout"
except Exception as e:
return False, str(e)
finally:
Path(output_path).unlink(missing_ok=True)
def run_all_tests():
    """Validate every effect mapping by running its filter through FFmpeg.

    Returns:
        List of (effect_name, status, message) tuples where status is
        "PASS", "FAIL", or "SKIP".
    """
    # NOTE: the previous version constructed an unused FFmpegCompiler();
    # the mappings are read directly from EFFECT_MAPPINGS.
    results = []
    for effect_name, mapping in EFFECT_MAPPINGS.items():
        filter_name = mapping.get("filter")
        # Skip effects with no FFmpeg equivalent (external tools or python primitives)
        if filter_name is None:
            reason = "No FFmpeg equivalent"
            if mapping.get("external_tool"):
                reason = f"External tool: {mapping['external_tool']}"
            elif mapping.get("python_primitive"):
                reason = f"Python primitive: {mapping['python_primitive']}"
            results.append((effect_name, "SKIP", reason))
            continue
        # Check if complex filter
        is_complex = mapping.get("complex", False)
        # Build filter string (static args are baked into the mapping).
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        else:
            filter_str = filter_name
        # Test it against a real FFmpeg invocation.
        success, error = test_filter_syntax(filter_str, is_complex=is_complex)
        if success:
            results.append((effect_name, "PASS", filter_str))
        else:
            results.append((effect_name, "FAIL", f"{filter_str} -> {error}"))
    return results
def print_results(results):
    """Pretty-print filter test results: summary, failures, passes, skips."""
    passed = sum(1 for _, status, _ in results if status == "PASS")
    failed = sum(1 for _, status, _ in results if status == "FAIL")
    skipped = sum(1 for _, status, _ in results if status == "SKIP")
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"FFmpeg Filter Tests: {passed} passed, {failed} failed, {skipped} skipped")
    print(f"{banner}\n")
    # Failures first so they are impossible to miss.
    if failed > 0:
        print("FAILURES:")
        for name, status, msg in results:
            if status == "FAIL":
                print(f"  {name}: {msg}")
        print()
    print("PASSED:")
    for name, status, msg in results:
        if status == "PASS":
            print(f"  {name}: {msg}")
    if skipped > 0:
        print("\nSKIPPED (Python fallback):")
        for name, status, msg in results:
            if status == "SKIP":
                print(f"  {name}")
# Allow running this module directly (outside pytest) as a smoke test.
if __name__ == "__main__":
    results = run_all_tests()
    print_results(results)

View File

@@ -0,0 +1,201 @@
"""
Tests for Python primitive effects.
Tests that ascii_art, ascii_zones, and other Python primitives
can be executed via the EffectExecutor.
"""
import subprocess
import tempfile
from pathlib import Path
import pytest
try:
import numpy as np
from PIL import Image
HAS_DEPS = True
except ImportError:
HAS_DEPS = False
from .primitives import (
ascii_art_frame,
ascii_zones_frame,
get_primitive,
list_primitives,
)
from .ffmpeg_compiler import FFmpegCompiler
def create_test_video(path: Path, duration: float = 0.5, size: str = "64x64") -> bool:
    """Create a short H.264 test clip at *path* using ffmpeg's testsrc.

    Args:
        path: Destination file for the encoded clip.
        duration: Clip length in seconds.
        size: Frame size as "WxH".

    Returns:
        True on success; False when encoding fails, hangs, or the ffmpeg
        binary is not installed — callers use the bool to skip gracefully,
        so a missing binary must not raise.
    """
    cmd = [
        "ffmpeg", "-y",
        "-f", "lavfi", "-i", f"testsrc=duration={duration}:size={size}:rate=10",
        "-c:v", "libx264", "-preset", "ultrafast",
        str(path)
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, timeout=30)
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # ffmpeg absent or wedged: report failure instead of raising.
        return False
    return result.returncode == 0
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestPrimitives:
    """Exercise the frame-level primitive functions directly."""

    def test_ascii_art_frame_basic(self):
        """ascii_art_frame preserves frame shape and dtype."""
        source = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        rendered = ascii_art_frame(source, char_size=8)
        assert rendered.shape == source.shape
        assert rendered.dtype == np.uint8

    def test_ascii_zones_frame_basic(self):
        """ascii_zones_frame preserves frame shape and dtype."""
        source = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        rendered = ascii_zones_frame(source, char_size=8)
        assert rendered.shape == source.shape
        assert rendered.dtype == np.uint8

    def test_get_primitive(self):
        """Registry lookup returns the function itself, or None for unknowns."""
        assert get_primitive("ascii_art_frame") is ascii_art_frame
        assert get_primitive("ascii_zones_frame") is ascii_zones_frame
        assert get_primitive("nonexistent") is None

    def test_list_primitives(self):
        """Registry listing contains the known primitives."""
        registry = list_primitives()
        assert "ascii_art_frame" in registry
        assert "ascii_zones_frame" in registry
        assert len(registry) > 5
class TestFFmpegCompilerPrimitives:
    """Check the python_primitive mappings exposed by FFmpegCompiler."""

    def test_has_python_primitive_ascii_art(self):
        """ascii_art maps to the ascii_art_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_art") == "ascii_art_frame"

    def test_has_python_primitive_ascii_zones(self):
        """ascii_zones maps to the ascii_zones_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_zones") == "ascii_zones_frame"

    def test_has_python_primitive_ffmpeg_effect(self):
        """Pure-FFmpeg effects report no python primitive."""
        probe = FFmpegCompiler()
        assert probe.has_python_primitive("brightness") is None
        assert probe.has_python_primitive("blur") is None

    def test_compile_effect_returns_none_for_primitives(self):
        """compile_effect declines primitive-backed effects."""
        probe = FFmpegCompiler()
        assert probe.compile_effect("ascii_art", {}) is None
        assert probe.compile_effect("ascii_zones", {}) is None
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestEffectExecutorPrimitives:
    """Exercise EffectExecutor's Python-primitive code path."""

    def test_executor_loads_primitive(self):
        """Executor resolves both registered primitive effects."""
        from ..nodes.effect import _get_python_primitive_effect
        assert _get_python_primitive_effect("ascii_art") is not None
        assert _get_python_primitive_effect("ascii_zones") is not None

    def test_executor_rejects_unknown_effect(self):
        """Executor yields None for effects it does not know."""
        from ..nodes.effect import _get_python_primitive_effect
        assert _get_python_primitive_effect("nonexistent_effect") is None

    def test_execute_ascii_art_effect(self, tmp_path):
        """Run the ascii_art effect end-to-end on a tiny clip."""
        from ..nodes.effect import EffectExecutor
        # Build a short synthetic input clip; skip when ffmpeg is unusable.
        source_clip = tmp_path / "input.mp4"
        if not create_test_video(source_clip):
            pytest.skip("Could not create test video")
        rendered = EffectExecutor().execute(
            config={"effect": "ascii_art", "char_size": 8},
            inputs=[source_clip],
            output_path=tmp_path / "output.mkv",
        )
        assert rendered.exists()
        assert rendered.stat().st_size > 0
def run_all_tests():
    """Run the primitive smoke tests manually (without pytest).

    NOTE: removed an unused local ``import sys`` from the previous version.
    """
    # Dependency gate mirrors the pytest skipif markers above.
    if not HAS_DEPS:
        print("SKIP: numpy/PIL not available")
        return
    print("Testing primitives...")
    # Exercise the frame-level primitives on random RGB data.
    frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    print("  ascii_art_frame...", end=" ")
    result = ascii_art_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")
    print("  ascii_zones_frame...", end=" ")
    result = ascii_zones_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")
    # Verify compiler effect -> primitive mappings.
    print("\nTesting FFmpegCompiler mappings...")
    compiler = FFmpegCompiler()
    print("  ascii_art python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame"
    print("PASS")
    print("  ascii_zones python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame"
    print("PASS")
    # Executor lookup may fail on relative import when run standalone.
    print("\nTesting EffectExecutor...")
    try:
        from ..nodes.effect import _get_python_primitive_effect
        print("  _get_python_primitive_effect(ascii_art)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_art")
        assert effect_fn is not None
        print("PASS")
        print("  _get_python_primitive_effect(ascii_zones)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_zones")
        assert effect_fn is not None
        print("PASS")
    except ImportError as e:
        print(f"SKIP: {e}")
    print("\n=== All tests passed ===")
# Allow running this module directly (outside pytest) as a smoke test.
if __name__ == "__main__":
    run_all_tests()

View File

@@ -0,0 +1,324 @@
"""
Tests for stage cache layer.
Tests S-expression storage for stage results and analysis data.
"""
import pytest
import tempfile
from pathlib import Path
from .stage_cache import (
StageCache,
StageCacheEntry,
StageOutput,
AnalysisResult,
save_analysis_result,
load_analysis_result,
)
from .parser import parse, serialize
class TestStageOutput:
    """Behavior of the StageOutput dataclass and its sexp round-trip."""

    def test_stage_output_artifact(self):
        """StageOutput can represent an artifact."""
        artifact = StageOutput(cache_id="abc123", output_type="artifact")
        assert artifact.cache_id == "abc123"
        assert artifact.output_type == "artifact"

    def test_stage_output_scalar(self):
        """StageOutput can represent a scalar value."""
        scalar = StageOutput(value=120.5, output_type="scalar")
        assert scalar.value == 120.5
        assert scalar.output_type == "scalar"

    def test_stage_output_to_sexp(self):
        """StageOutput serializes to sexp."""
        rendered = serialize(StageOutput(cache_id="abc123", output_type="artifact").to_sexp())
        assert "cache-id" in rendered
        assert "abc123" in rendered
        assert "type" in rendered
        assert "artifact" in rendered

    def test_stage_output_from_sexp(self):
        """StageOutput parses from sexp."""
        parsed = StageOutput.from_sexp(parse('(:cache-id "def456" :type :analysis)'))
        assert parsed.cache_id == "def456"
        assert parsed.output_type == "analysis"
class TestStageCacheEntry:
    """Test StageCacheEntry serialization."""

    def test_stage_cache_entry_to_sexp(self):
        """StageCacheEntry serializes to sexp."""
        entry = StageCacheEntry(
            stage_name="analyze-a",
            cache_id="stage_abc123",
            outputs={
                "beats": StageOutput(cache_id="beats_def456", output_type="analysis"),
                "tempo": StageOutput(value=120.5, output_type="scalar"),
            },
            completed_at=1705678900.123,
        )
        sexp = entry.to_sexp()
        sexp_str = serialize(sexp)
        # Spot-check the rendered form for the header symbol, identity
        # fields, and the outputs section.
        assert "stage-result" in sexp_str
        assert "analyze-a" in sexp_str
        assert "stage_abc123" in sexp_str
        assert "outputs" in sexp_str
        assert "beats" in sexp_str

    def test_stage_cache_entry_roundtrip(self):
        """save -> load produces identical data."""
        entry = StageCacheEntry(
            stage_name="analyze-b",
            cache_id="stage_xyz789",
            outputs={
                "segments": StageOutput(cache_id="seg_123", output_type="artifact"),
            },
            completed_at=1705678900.0,
        )
        sexp_str = entry.to_string()
        loaded = StageCacheEntry.from_string(sexp_str)
        assert loaded.stage_name == entry.stage_name
        assert loaded.cache_id == entry.cache_id
        assert "segments" in loaded.outputs
        assert loaded.outputs["segments"].cache_id == "seg_123"

    def test_stage_cache_entry_from_sexp(self):
        """StageCacheEntry parses from sexp."""
        # Hand-written sexp text (whitespace is insignificant to the parser).
        sexp_str = '''
        (stage-result
            :name "test-stage"
            :cache-id "cache123"
            :completed-at 1705678900.0
            :outputs ((beats :cache-id "beats123" :type :analysis)))
        '''
        entry = StageCacheEntry.from_string(sexp_str)
        assert entry.stage_name == "test-stage"
        assert entry.cache_id == "cache123"
        assert "beats" in entry.outputs
        assert entry.outputs["beats"].cache_id == "beats123"
class TestStageCache:
    """File-level operations of the StageCache manager."""

    def test_save_and_load_stage(self):
        """Save and load a stage result."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            saved_path = store.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="test_cache_id",
                outputs={
                    "beats": StageOutput(cache_id="beats_out", output_type="analysis"),
                },
            ))
            assert saved_path.exists()
            assert saved_path.suffix == ".sexp"
            restored = store.load_stage("test_cache_id")
            assert restored is not None
            assert restored.stage_name == "analyze"
            assert "beats" in restored.outputs

    def test_has_stage(self):
        """Check if stage is cached."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            assert not store.has_stage("nonexistent")
            store.save_stage(StageCacheEntry(
                stage_name="test",
                cache_id="exists_cache_id",
                outputs={},
            ))
            assert store.has_stage("exists_cache_id")

    def test_delete_stage(self):
        """Delete a cached stage."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            store.save_stage(StageCacheEntry(
                stage_name="test",
                cache_id="to_delete",
                outputs={},
            ))
            assert store.has_stage("to_delete")
            assert store.delete_stage("to_delete") is True
            assert not store.has_stage("to_delete")

    def test_list_stages(self):
        """List all cached stages."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            for idx in range(3):
                store.save_stage(StageCacheEntry(
                    stage_name=f"stage{idx}",
                    cache_id=f"cache_{idx}",
                    outputs={},
                ))
            listed = store.list_stages()
            assert len(listed) == 3
            assert "cache_0" in listed
            assert "cache_1" in listed
            assert "cache_2" in listed

    def test_clear(self):
        """Clear all cached stages."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            for idx in range(3):
                store.save_stage(StageCacheEntry(
                    stage_name=f"stage{idx}",
                    cache_id=f"cache_{idx}",
                    outputs={},
                ))
            assert store.clear() == 3
            assert len(store.list_stages()) == 0

    def test_cache_file_extension(self):
        """Cache files use .sexp extension."""
        with tempfile.TemporaryDirectory() as workdir:
            assert StageCache(workdir).get_cache_path("test_id").suffix == ".sexp"

    def test_invalid_sexp_error_handling(self):
        """Graceful error on corrupted cache file."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            # Plant corrupted content where a cache entry would live.
            store.get_cache_path("corrupted").write_text("this is not valid sexp )()(")
            # load_stage must swallow the parse failure and report a miss.
            assert store.load_stage("corrupted") is None
class TestAnalysisResult:
    """Serialization behavior of AnalysisResult."""

    def test_analysis_result_to_sexp(self):
        """AnalysisResult serializes to sexp."""
        rendered = serialize(AnalysisResult(
            analyzer="beats",
            input_hash="input_abc123",
            data={
                "duration": 120.5,
                "tempo": 128.0,
                "times": [0.0, 0.468, 0.937, 1.406],
                "values": [0.8, 0.9, 0.7, 0.85],
            },
        ).to_sexp())
        assert "analysis-result" in rendered
        assert "beats" in rendered
        assert "duration" in rendered
        assert "tempo" in rendered
        assert "times" in rendered

    def test_analysis_result_roundtrip(self):
        """Analysis result round-trips through sexp."""
        original = AnalysisResult(
            analyzer="scenes",
            input_hash="video_xyz",
            data={
                "scene_count": 5,
                "scene_times": [0.0, 10.5, 25.0, 45.2, 60.0],
            },
        )
        restored = AnalysisResult.from_string(original.to_string())
        assert restored.analyzer == original.analyzer
        assert restored.input_hash == original.input_hash
        assert restored.data["scene_count"] == 5

    def test_save_and_load_analysis_result(self):
        """Save and load analysis result from cache."""
        with tempfile.TemporaryDirectory() as workdir:
            stored_path = save_analysis_result(workdir, "node_abc", AnalysisResult(
                analyzer="beats",
                input_hash="audio_123",
                data={
                    "tempo": 120.0,
                    "times": [0.0, 0.5, 1.0],
                },
            ))
            assert stored_path.exists()
            assert stored_path.name == "analysis.sexp"
            restored = load_analysis_result(workdir, "node_abc")
            assert restored is not None
            assert restored.analyzer == "beats"
            assert restored.data["tempo"] == 120.0

    def test_analysis_result_kebab_case(self):
        """Keys convert between snake_case and kebab-case."""
        rendered = AnalysisResult(
            analyzer="test",
            input_hash="hash",
            data={
                "scene_count": 5,
                "beat_times": [1, 2, 3],
            },
        ).to_string()
        # Sexp side uses kebab-case...
        assert "scene-count" in rendered
        assert "beat-times" in rendered
        # ...and parsing restores snake_case.
        restored = AnalysisResult.from_string(rendered)
        assert "scene_count" in restored.data
        assert "beat_times" in restored.data

View File

@@ -0,0 +1,286 @@
"""
Tests for stage compilation and scoping.
Tests the CompiledStage dataclass, stage form parsing,
variable scoping, and dependency validation.
"""
import pytest
from .parser import parse, Symbol, Keyword
from .compiler import (
compile_recipe,
CompileError,
CompiledStage,
CompilerContext,
_topological_sort_stages,
)
class TestStageCompilation:
    """Test stage form compilation."""

    def test_parse_stage_form_basic(self):
        """Stage parses correctly with name and outputs."""
        # Single stage declaring one output binding, ending in a sink chain.
        recipe = '''
        (recipe "test-stage"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        assert len(compiled.stages) == 1
        assert compiled.stages[0].name == "analyze"
        assert "beats" in compiled.stages[0].outputs
        assert len(compiled.stages[0].node_ids) > 0

    def test_parse_stage_with_requires(self):
        """Stage parses correctly with requires and inputs."""
        # Second stage consumes the first stage's output via :requires/:inputs.
        recipe = '''
        (recipe "test-requires"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))
          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [segments]
            (def segments (-> audio (segment :times beats)))
            (-> segments (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        assert len(compiled.stages) == 2
        process_stage = next(s for s in compiled.stages if s.name == "process")
        assert process_stage.requires == ["analyze"]
        assert "beats" in process_stage.inputs
        assert "segments" in process_stage.outputs

    def test_stage_outputs_recorded(self):
        """Stage outputs are tracked in CompiledStage."""
        recipe = '''
        (recipe "test-outputs"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        stage = compiled.stages[0]
        # Both declared names appear in outputs and in the node bindings.
        assert "beats" in stage.outputs
        assert "tempo" in stage.outputs
        assert "beats" in stage.output_bindings
        assert "tempo" in stage.output_bindings

    def test_stage_order_topological(self):
        """Stages are topologically sorted."""
        recipe = '''
        (recipe "test-order"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))
          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        # analyze should come before output
        assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output")
class TestStageValidation:
    """Test stage dependency and input validation."""

    def test_stage_requires_validation(self):
        """Error if requiring non-existent stage."""
        recipe = '''
        (recipe "test-bad-require"
          (def audio (source :path "test.mp3"))
          (stage :process
            :requires [:nonexistent]
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(recipe))

    def test_stage_inputs_validation(self):
        """Error if input not produced by required stage."""
        recipe = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))
          (stage :process
            :requires [:analyze]
            :inputs [nonexistent]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(recipe))

    def test_undeclared_output_error(self):
        """Error if stage declares output not defined in body."""
        recipe = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))
          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(recipe))

    def test_forward_reference_detection(self):
        """Error when requiring a stage not yet defined."""
        # Forward references are not allowed - stages must be defined
        # before they can be required
        recipe = '''
        (recipe "test-forward"
          (def audio (source :path "test.mp3"))
          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio))
          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(recipe))
class TestStageScoping:
    """Variable visibility across stage boundaries."""

    def test_pre_stage_bindings_accessible(self):
        """Sources defined before any stage are visible inside every stage."""
        src = '''
(recipe "test-pre-stage"
(def audio (source :path "test.mp3"))
(def video (source :path "test.mp4"))
(stage :analyze-audio
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :analyze-video
:outputs [scenes]
(def scenes (-> video (analyze scenes)))
(-> video (segment :times scenes) (sequence))))
'''
        # Must compile: both stages reference the shared pre-stage sources.
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 2

    def test_stage_bindings_flow_through_requires(self):
        """Stage outputs reach dependants that list them under :inputs."""
        src = '''
(recipe "test-binding-flow"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :process
:requires [:analyze]
:inputs [beats]
:outputs [result]
(def result (-> audio (segment :times beats)))
(-> result (sequence))))
'''
        # Must compile: `beats` flows from analyze into process.
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 2
class TestTopologicalSort:
    """Behaviour of the stage topological sort helper."""

    @staticmethod
    def _stage(name, requires, inputs, outputs, node_id):
        # Minimal CompiledStage with one backing node mapped to every output.
        return CompiledStage(
            name=name,
            requires=requires,
            inputs=inputs,
            outputs=outputs,
            node_ids=[node_id],
            output_bindings={out: node_id for out in outputs},
        )

    def test_empty_stages(self):
        """No stages yields an empty ordering."""
        assert _topological_sort_stages({}) == []

    def test_single_stage(self):
        """A lone stage is its own ordering."""
        stages = {"a": self._stage("a", [], [], ["out"], "n1")}
        assert _topological_sort_stages(stages) == ["a"]

    def test_linear_chain(self):
        """a -> b -> c is kept in chain order."""
        stages = {
            "a": self._stage("a", [], [], ["x"], "n1"),
            "b": self._stage("b", ["a"], ["x"], ["y"], "n2"),
            "c": self._stage("c", ["b"], ["y"], ["z"], "n3"),
        }
        order = _topological_sort_stages(stages)
        assert order.index("a") < order.index("b") < order.index("c")

    def test_parallel_stages_same_level(self):
        """Independent stages may appear in either order but must both appear."""
        stages = {
            "a": self._stage("a", [], [], ["x"], "n1"),
            "b": self._stage("b", [], [], ["y"], "n2"),
        }
        assert set(_topological_sort_stages(stages)) == {"a", "b"}

    def test_diamond_dependency(self):
        """Diamond a -> {b, c} -> d: a first, d last, branches before d."""
        stages = {
            "a": self._stage("a", [], [], ["x"], "n1"),
            "b": self._stage("b", ["a"], ["x"], ["y"], "n2"),
            "c": self._stage("c", ["a"], ["x"], ["z"], "n3"),
            "d": self._stage("d", ["b", "c"], ["y", "z"], ["out"], "n4"),
        }
        order = _topological_sort_stages(stages)
        assert order[0] == "a"
        assert order[-1] == "d"
        assert order.index("b") < order.index("d")
        assert order.index("c") < order.index("d")

View File

@@ -0,0 +1,739 @@
"""
End-to-end integration tests for staged recipes.
Tests the complete flow: compile -> plan -> execute
for recipes with stages.
"""
import pytest
import tempfile
from pathlib import Path
from .parser import parse, serialize
from .compiler import compile_recipe, CompileError
from .planner import ExecutionPlanSexp, StagePlan
from .stage_cache import StageCache, StageCacheEntry, StageOutput
from .scheduler import StagePlanScheduler, StagePlanResult
class TestSimpleTwoStageRecipe:
    """Basic two-stage recipe flow: compile and inspect."""

    def test_compile_two_stage_recipe(self):
        """Both stages compile with the declared ordering and dataflow."""
        src = '''
(recipe "test-two-stages"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :output
:requires [:analyze]
:inputs [beats]
(-> audio (segment :times beats) (sequence))))
'''
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 2
        assert dag.stage_order == ["analyze", "output"]
        first, second = dag.stages
        assert first.name == "analyze"
        assert "beats" in first.outputs
        assert second.name == "output"
        assert second.requires == ["analyze"]
        assert "beats" in second.inputs
class TestParallelAnalysisStages:
    """Two independent analysis stages feeding a combiner."""

    def test_compile_parallel_stages(self):
        """Both analysis stages are roots; the combiner requires both."""
        src = '''
(recipe "test-parallel"
(def audio-a (source :path "a.mp3"))
(def audio-b (source :path "b.mp3"))
(stage :analyze-a
:outputs [beats-a]
(def beats-a (-> audio-a (analyze beats))))
(stage :analyze-b
:outputs [beats-b]
(def beats-b (-> audio-b (analyze beats))))
(stage :combine
:requires [:analyze-a :analyze-b]
:inputs [beats-a beats-b]
(-> audio-a (segment :times beats-a) (sequence))))
'''
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 3
        by_name = {s.name: s for s in dag.stages}
        # Root stages have no requirements, so they may run in parallel.
        assert by_name["analyze-a"].requires == []
        assert by_name["analyze-b"].requires == []
        assert set(by_name["combine"].requires) == {"analyze-a", "analyze-b"}
class TestDiamondDependency:
    """Diamond dependency pattern: A -> B, A -> C, B+C -> D."""

    def test_compile_diamond_pattern(self):
        """The diamond compiles with correct requires and topological order."""
        src = '''
(recipe "test-diamond"
(def audio (source :path "test.mp3"))
(stage :source-stage
:outputs [audio-ref]
(def audio-ref audio))
(stage :branch-b
:requires [:source-stage]
:inputs [audio-ref]
:outputs [result-b]
(def result-b (-> audio-ref (effect gain :amount 0.5))))
(stage :branch-c
:requires [:source-stage]
:inputs [audio-ref]
:outputs [result-c]
(def result-c (-> audio-ref (effect gain :amount 0.8))))
(stage :merge
:requires [:branch-b :branch-c]
:inputs [result-b result-c]
(-> result-b (blend result-c :mode "mix"))))
'''
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 4
        by_name = {s.name: s for s in dag.stages}
        assert by_name["source-stage"].requires == []
        assert by_name["branch-b"].requires == ["source-stage"]
        assert by_name["branch-c"].requires == ["source-stage"]
        assert set(by_name["merge"].requires) == {"branch-b", "branch-c"}
        # Ordering: the source precedes both branches; merge comes after both.
        pos = dag.stage_order.index
        assert pos("source-stage") < pos("branch-b")
        assert pos("source-stage") < pos("branch-c")
        assert pos("branch-b") < pos("merge")
        assert pos("branch-c") < pos("merge")
class TestStageReuseOnRerun:
    """A second run should find previously persisted stage results."""

    def test_stage_reuse(self):
        """Saved stage entries are visible and loadable on re-run."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            # First run: persist the analyze stage result.
            cache.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="fixed_cache_id",
                outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
            ))
            # Second run: the cached stage is discoverable by its cache id.
            assert cache.has_stage("fixed_cache_id")
            reloaded = cache.load_stage("fixed_cache_id")
            assert reloaded is not None
            assert reloaded.stage_name == "analyze"
class TestExplicitDataFlowEndToEnd:
    """Analysis results flow between stages via :inputs / :outputs."""

    def test_data_flow_declaration(self):
        """Declared outputs, bindings, and inputs all line up."""
        src = '''
(recipe "test-data-flow"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats tempo]
(def beats (-> audio (analyze beats)))
(def tempo (-> audio (analyze tempo))))
(stage :process
:requires [:analyze]
:inputs [beats tempo]
:outputs [result]
(def result (-> audio (segment :times beats) (effect speed :factor tempo)))
(-> result (sequence))))
'''
        dag = compile_recipe(parse(src))
        by_name = {s.name: s for s in dag.stages}
        producer = by_name["analyze"]
        consumer = by_name["process"]
        # Producer side: both analysis results are exported and bound.
        assert set(producer.outputs) == {"beats", "tempo"}
        assert "beats" in producer.output_bindings
        assert "tempo" in producer.output_bindings
        # Consumer side: both results are imported from the required stage.
        assert set(consumer.inputs) == {"beats", "tempo"}
        assert consumer.requires == ["analyze"]
class TestRecipeFixtures:
    """Fixture-provided recipe texts compile cleanly."""

    @pytest.fixture
    def test_recipe_two_stages(self):
        """Recipe text with an analyze stage feeding an output stage."""
        return '''
(recipe "test-two-stages"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :output
:requires [:analyze]
:inputs [beats]
(-> audio (segment :times beats) (sequence))))
'''

    @pytest.fixture
    def test_recipe_parallel_stages(self):
        """Recipe text with two parallel analysis stages and a combiner."""
        return '''
(recipe "test-parallel"
(def audio-a (source :path "a.mp3"))
(def audio-b (source :path "b.mp3"))
(stage :analyze-a
:outputs [beats-a]
(def beats-a (-> audio-a (analyze beats))))
(stage :analyze-b
:outputs [beats-b]
(def beats-b (-> audio-b (analyze beats))))
(stage :combine
:requires [:analyze-a :analyze-b]
:inputs [beats-a beats-b]
(-> audio-a (blend audio-b :mode "mix"))))
'''

    def test_two_stages_fixture(self, test_recipe_two_stages):
        """The two-stage fixture yields two compiled stages."""
        assert len(compile_recipe(parse(test_recipe_two_stages)).stages) == 2

    def test_parallel_stages_fixture(self, test_recipe_parallel_stages):
        """The parallel fixture yields three compiled stages."""
        assert len(compile_recipe(parse(test_recipe_parallel_stages)).stages) == 3
class TestStageValidationErrors:
    """Invalid stage declarations fail compilation with clear errors."""

    def test_missing_output_declaration(self):
        """Declaring an output the stage body never defines is an error."""
        src = '''
(recipe "test-missing-output"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats nonexistent]
(def beats (-> audio (analyze beats)))))
'''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(src))

    def test_input_without_requires(self):
        """Consuming an input without requiring its producer is an error."""
        src = '''
(recipe "test-bad-input"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :process
:requires []
:inputs [beats]
(def result audio)))
'''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(src))

    def test_forward_reference(self):
        """Requiring a stage defined later (forward reference) is an error."""
        src = '''
(recipe "test-forward-ref"
(def audio (source :path "test.mp3"))
(stage :a
:requires [:b]
:outputs [out-a]
(def out-a audio)
audio)
(stage :b
:outputs [out-b]
(def out-b audio)
audio))
'''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))
class TestBeatSyncDemoRecipe:
    """Exercises the beat-sync demo recipe from the examples."""

    BEAT_SYNC_RECIPE = '''
;; Simple staged recipe demo
(recipe "beat-sync-demo"
:version "1.0"
:description "Demo of staged beat-sync workflow"
;; Pre-stage definitions (available to all stages)
(def audio (source :path "input.mp3"))
;; Stage 1: Analysis (expensive, cached)
(stage :analyze
:outputs [beats tempo]
(def beats (-> audio (analyze beats)))
(def tempo (-> audio (analyze tempo))))
;; Stage 2: Processing (uses analysis results)
(stage :process
:requires [:analyze]
:inputs [beats]
:outputs [segments]
(def segments (-> audio (segment :times beats)))
(-> segments (sequence))))
'''

    def _compiled(self):
        # Helper: compile the demo recipe fresh for each test.
        return compile_recipe(parse(self.BEAT_SYNC_RECIPE))

    def test_compile_beat_sync_recipe(self):
        """Recipe metadata survives compilation."""
        dag = self._compiled()
        assert dag.name == "beat-sync-demo"
        assert dag.version == "1.0"
        assert dag.description == "Demo of staged beat-sync workflow"

    def test_beat_sync_stage_count(self):
        """Two stages, analysis before processing."""
        dag = self._compiled()
        assert len(dag.stages) == 2
        assert dag.stage_order == ["analyze", "process"]

    def test_beat_sync_analyze_stage(self):
        """The analyze stage is a root and exports both analysis results."""
        dag = self._compiled()
        analyze = {s.name: s for s in dag.stages}["analyze"]
        assert analyze.requires == []
        assert analyze.inputs == []
        assert set(analyze.outputs) == {"beats", "tempo"}
        assert "beats" in analyze.output_bindings
        assert "tempo" in analyze.output_bindings

    def test_beat_sync_process_stage(self):
        """The process stage consumes beats and exports segments."""
        dag = self._compiled()
        process = {s.name: s for s in dag.stages}["process"]
        assert process.requires == ["analyze"]
        assert "beats" in process.inputs
        assert "segments" in process.outputs

    def test_beat_sync_node_count(self):
        """1 SOURCE + 2 ANALYZE + 1 SEGMENT + 1 SEQUENCE = 5 nodes."""
        assert len(self._compiled().nodes) == 5

    def test_beat_sync_node_types(self):
        """Node type tallies match the recipe structure."""
        kinds = [n["type"] for n in self._compiled().nodes]
        assert kinds.count("SOURCE") == 1
        assert kinds.count("ANALYZE") == 2
        assert kinds.count("SEGMENT") == 1
        assert kinds.count("SEQUENCE") == 1

    def test_beat_sync_output_is_sequence(self):
        """The plan's terminal node is the SEQUENCE node."""
        dag = self._compiled()
        terminal = next(n for n in dag.nodes if n["id"] == dag.output_node_id)
        assert terminal["type"] == "SEQUENCE"
class TestAsciiArtStagedRecipe:
    """Exercises the three-stage ASCII art recipe."""

    ASCII_ART_STAGED_RECIPE = '''
;; ASCII art effect with staged execution
(recipe "ascii_art_staged"
:version "1.0"
:description "ASCII art effect with staged execution"
:encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30)
;; Registry
(effect ascii_art :path "sexp_effects/effects/ascii_art.sexp")
(analyzer energy :path "../artdag-analyzers/energy/analyzer.py")
;; Pre-stage definitions
(def color_mode "color")
(def video (source :path "monday.webm"))
(def audio (source :path "dizzy.mp3"))
;; Stage 1: Analysis
(stage :analyze
:outputs [energy-data]
(def audio-clip (-> audio (segment :start 60 :duration 10)))
(def energy-data (-> audio-clip (analyze energy))))
;; Stage 2: Process
(stage :process
:requires [:analyze]
:inputs [energy-data]
:outputs [result audio-clip]
(def clip (-> video (segment :start 0 :duration 10)))
(def audio-clip (-> audio (segment :start 60 :duration 10)))
(def result (-> clip
(effect ascii_art
:char_size (bind energy-data values :range [2 32])
:color_mode color_mode))))
;; Stage 3: Output
(stage :output
:requires [:process]
:inputs [result audio-clip]
(mux result audio-clip)))
'''

    def _compiled(self):
        # Helper: compile the staged ASCII art recipe fresh for each test.
        return compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

    def test_compile_ascii_art_staged(self):
        """Recipe metadata survives compilation."""
        dag = self._compiled()
        assert dag.name == "ascii_art_staged"
        assert dag.version == "1.0"

    def test_ascii_art_stage_count(self):
        """Three stages in analyze -> process -> output order."""
        dag = self._compiled()
        assert len(dag.stages) == 3
        assert dag.stage_order == ["analyze", "process", "output"]

    def test_ascii_art_analyze_stage(self):
        """The analyze stage is a root exporting energy-data."""
        analyze = {s.name: s for s in self._compiled().stages}["analyze"]
        assert analyze.requires == []
        assert analyze.inputs == []
        assert "energy-data" in analyze.outputs

    def test_ascii_art_process_stage(self):
        """The process stage depends on analyze and exports its products."""
        process = {s.name: s for s in self._compiled().stages}["process"]
        assert process.requires == ["analyze"]
        assert "energy-data" in process.inputs
        assert "result" in process.outputs
        assert "audio-clip" in process.outputs

    def test_ascii_art_output_stage(self):
        """The output stage muxes the processed video with the audio clip."""
        output = {s.name: s for s in self._compiled().stages}["output"]
        assert output.requires == ["process"]
        assert "result" in output.inputs
        assert "audio-clip" in output.inputs

    def test_ascii_art_node_count(self):
        """2 SOURCE + 2 SEGMENT + 1 ANALYZE + 1 EFFECT + 1 MUX = at least 7."""
        assert len(self._compiled().nodes) >= 7

    def test_ascii_art_has_mux_output(self):
        """The plan's terminal node is the MUX node."""
        dag = self._compiled()
        terminal = next(n for n in dag.nodes if n["id"] == dag.output_node_id)
        assert terminal["type"] == "MUX"
class TestMixedStagedAndNonStagedRecipes:
    """Recipes without stages keep working alongside staged ones."""

    def test_recipe_without_stages(self):
        """A stage-free recipe compiles with empty stage metadata."""
        src = '''
(recipe "no-stages"
(-> (source :path "test.mp3")
(effect gain :amount 0.5)))
'''
        dag = compile_recipe(parse(src))
        assert dag.stages == []
        assert dag.stage_order == []
        # The node graph itself is still produced.
        assert len(dag.nodes) > 0

    def test_mixed_pre_stage_and_stages(self):
        """Pre-stage definitions remain visible inside a stage body."""
        src = '''
(recipe "mixed"
;; Pre-stage definitions
(def audio (source :path "test.mp3"))
(def volume 0.8)
;; Stage using pre-stage definitions, ending with output expression
(stage :process
:outputs [result]
(def result (-> audio (effect gain :amount volume)))
result))
'''
        dag = compile_recipe(parse(src))
        assert len(dag.stages) == 1
        only = dag.stages[0]
        assert only.name == "process"
        assert "result" in only.outputs
class TestEffectParamsBlock:
    """Tests for the :params block syntax in define-effect forms.

    Fix: dropped the redundant method-local ``import pytest``,
    ``from pathlib import Path`` and ``import tempfile`` — all three are
    already imported at this module's top level.
    """

    def test_parse_effect_with_params_block(self):
        """An effect with a :params block yields name, defaults and ParamDefs."""
        from .effect_loader import load_sexp_effect
        effect_code = '''
(define-effect test_effect
:params (
(size :type int :default 10 :range [1 100] :desc "Size parameter")
(color :type string :default "red" :desc "Color parameter")
(enabled :type int :default 1 :range [0 1] :desc "Enable flag")
)
frame)
'''
        name, process_fn, defaults, param_defs = load_sexp_effect(effect_code)
        assert name == "test_effect"
        assert len(param_defs) == 3
        assert defaults["size"] == 10
        assert defaults["color"] == "red"
        assert defaults["enabled"] == 1
        # ParamDef metadata is preserved for introspection.
        size_param = param_defs[0]
        assert size_param.name == "size"
        assert size_param.param_type == "int"
        assert size_param.default == 10
        assert size_param.range_min == 1.0
        assert size_param.range_max == 100.0
        assert size_param.description == "Size parameter"
        color_param = param_defs[1]
        assert color_param.name == "color"
        assert color_param.param_type == "string"
        assert color_param.default == "red"

    def test_parse_effect_with_choices(self):
        """:choices lists are parsed into ParamDef.choices."""
        from .effect_loader import load_sexp_effect
        effect_code = '''
(define-effect mode_effect
:params (
(mode :type string :default "fast"
:choices [fast slow medium]
:desc "Processing mode")
)
frame)
'''
        name, _, defaults, param_defs = load_sexp_effect(effect_code)
        assert name == "mode_effect"
        assert defaults["mode"] == "fast"
        mode_param = param_defs[0]
        assert mode_param.choices == ["fast", "slow", "medium"]

    def test_legacy_effect_syntax_rejected(self):
        """The pre-:params parameter list syntax raises a helpful error."""
        from .effect_loader import load_sexp_effect
        effect_code = '''
(define-effect legacy_effect
((width 100)
(height 200)
(name "default"))
frame)
'''
        with pytest.raises(ValueError) as exc_info:
            load_sexp_effect(effect_code)
        # The error should point users at the new syntax.
        assert "Legacy parameter syntax" in str(exc_info.value)
        assert ":params" in str(exc_info.value)

    def test_effect_params_introspection(self):
        """Params loaded from a file are available for introspection."""
        from .effect_loader import load_sexp_effect_file
        # delete=False so the file can be re-opened by the loader;
        # removed explicitly in the finally block below.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f:
            f.write('''
(define-effect introspect_test
:params (
(alpha :type float :default 0.5 :range [0 1] :desc "Alpha value")
)
frame)
''')
            temp_path = Path(f.name)
        try:
            name, _, defaults, param_defs = load_sexp_effect_file(temp_path)
            assert name == "introspect_test"
            assert len(param_defs) == 1
            assert param_defs[0].name == "alpha"
            assert param_defs[0].param_type == "float"
        finally:
            temp_path.unlink()
class TestConstructParamsBlock:
    """:params block parsing in construct definitions."""

    def test_parse_construct_params_helper(self):
        """Names and defaults are extracted from a raw :params list."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword
        raw = [
            [Symbol("duration"), Keyword("type"), Symbol("float"),
             Keyword("default"), 5.0, Keyword("desc"), "Duration in seconds"],
            [Symbol("count"), Keyword("type"), Symbol("int"),
             Keyword("default"), 10],
        ]
        names, defaults = _parse_construct_params(raw)
        assert names == ["duration", "count"]
        assert defaults["duration"] == 5.0
        assert defaults["count"] == 10

    def test_construct_params_with_no_defaults(self):
        """A param without :default maps to None in the defaults dict."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword
        raw = [
            [Symbol("required_param"), Keyword("type"), Symbol("string")],
            [Symbol("optional_param"), Keyword("type"), Symbol("int"),
             Keyword("default"), 42],
        ]
        names, defaults = _parse_construct_params(raw)
        assert names == ["required_param", "optional_param"]
        assert defaults["required_param"] is None
        assert defaults["optional_param"] == 42
class TestParameterValidation:
    """Unknown parameters passed to a sexp effect must be rejected.

    Fix: dropped the redundant method-local ``import pytest`` — pytest is
    already imported at this module's top level. The local ``numpy`` import
    is kept because the module header does not import it.
    """

    def test_effect_rejects_unknown_params(self):
        """A declared param is accepted; an undeclared one raises ValueError."""
        from .effect_loader import load_sexp_effect
        import numpy as np
        effect_code = '''
(define-effect test_effect
:params (
(brightness :type int :default 0 :desc "Brightness")
)
frame)
'''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}
        # Declared parameter passes through and yields a frame.
        result, _ = process_frame(frame, {"brightness": 10}, state)
        assert isinstance(result, np.ndarray)
        # Undeclared parameter raises, and the message lists valid names.
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"unknown_param": 42}, state)
        assert "Unknown parameter 'unknown_param'" in str(exc_info.value)
        assert "brightness" in str(exc_info.value)

    def test_effect_no_params_rejects_all(self):
        """With an empty :params block every supplied parameter is rejected."""
        from .effect_loader import load_sexp_effect
        import numpy as np
        effect_code = '''
(define-effect no_params_effect
:params ()
frame)
'''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}
        # No parameters supplied: the call succeeds.
        result, _ = process_frame(frame, {}, state)
        assert isinstance(result, np.ndarray)
        # Any parameter is unknown; the message shows "(none)" as valid names.
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"any_param": 42}, state)
        assert "Unknown parameter 'any_param'" in str(exc_info.value)
        assert "(none)" in str(exc_info.value)

View File

@@ -0,0 +1,228 @@
"""
Tests for stage-aware planning.
Tests stage topological sorting, level computation, cache ID computation,
and plan metadata generation.
"""
import pytest
from pathlib import Path
from .parser import parse
from .compiler import compile_recipe, CompiledStage
from .planner import (
create_plan,
StagePlan,
_compute_stage_levels,
_compute_stage_cache_id,
)
class TestStagePlanning:
    """Stage-aware plan creation: ordering, levels, and cache ids."""

    @staticmethod
    def _stage(name, requires, inputs, outputs, node_id):
        # Minimal CompiledStage whose single node backs every output.
        return CompiledStage(
            name=name,
            requires=requires,
            inputs=inputs,
            outputs=outputs,
            node_ids=[node_id],
            output_bindings={out: node_id for out in outputs},
        )

    def test_stage_topological_sort_in_plan(self):
        """Compiled stage order respects :requires dependencies."""
        src = '''
(recipe "test-sort"
(def audio (source :path "test.mp3"))
(stage :analyze
:outputs [beats]
(def beats (-> audio (analyze beats))))
(stage :output
:requires [:analyze]
:inputs [beats]
(-> audio (segment :times beats) (sequence))))
'''
        # create_plan needs a recipe_dir for analysis, so verify ordering
        # on the compiled recipe directly.
        order = compile_recipe(parse(src)).stage_order
        assert order.index("analyze") < order.index("output")

    def test_stage_level_computation(self):
        """Independent stages share level 0; a joint consumer sits above."""
        stages = [
            self._stage("a", [], [], ["x"], "n1"),
            self._stage("b", [], [], ["y"], "n2"),
            self._stage("c", ["a", "b"], ["x", "y"], ["z"], "n3"),
        ]
        levels = _compute_stage_levels(stages)
        assert (levels["a"], levels["b"], levels["c"]) == (0, 0, 1)

    def test_stage_level_chain(self):
        """A linear chain yields strictly increasing levels."""
        stages = [
            self._stage("a", [], [], ["x"], "n1"),
            self._stage("b", ["a"], ["x"], ["y"], "n2"),
            self._stage("c", ["b"], ["y"], ["z"], "n3"),
        ]
        levels = _compute_stage_levels(stages)
        assert (levels["a"], levels["b"], levels["c"]) == (0, 1, 2)

    def test_stage_cache_id_deterministic(self):
        """The same stage with the same inputs hashes to the same cache id."""
        stage = self._stage("analyze", [], [], ["beats"], "abc123")
        first, second = (
            _compute_stage_cache_id(
                stage,
                stage_cache_ids={},
                node_cache_ids={"abc123": "nodeabc"},
                cluster_key=None,
            )
            for _ in range(2)
        )
        assert first == second

    def test_stage_cache_id_includes_requires(self):
        """Changing an upstream stage's cache id changes the downstream id."""
        stage = self._stage("process", ["analyze"], ["beats"], ["result"], "def456")
        ids = [
            _compute_stage_cache_id(
                stage,
                stage_cache_ids={"analyze": upstream},
                node_cache_ids={"def456": "node_def"},
                cluster_key=None,
            )
            for upstream in ("req_cache_a", "req_cache_b")
        ]
        assert ids[0] != ids[1]

    def test_stage_cache_id_cluster_key(self):
        """The cluster key participates in the cache id hash."""
        stage = self._stage("analyze", [], [], ["beats"], "abc123")
        ids = [
            _compute_stage_cache_id(
                stage,
                stage_cache_ids={},
                node_cache_ids={"abc123": "nodeabc"},
                cluster_key=key,
            )
            for key in (None, "cluster123")
        ]
        assert ids[0] != ids[1]
class TestStagePlanMetadata:
    """Stage metadata on execution plans."""

    def test_plan_without_stages(self):
        """Compiling a stage-free recipe leaves the stage fields empty."""
        src = '''
(recipe "no-stages"
(-> (source :path "test.mp3") (effect gain :amount 0.5)))
'''
        dag = compile_recipe(parse(src))
        assert dag.stages == []
        assert dag.stage_order == []
class TestStagePlanDataclass:
    """Construction of the StagePlan dataclass."""

    def test_stage_plan_creation(self):
        """All fields round-trip through the constructor."""
        from .planner import PlanStep
        only_step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={"analyzer": "beats"},
            inputs=["input1"],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_123",
        )
        plan = StagePlan(
            stage_name="analyze",
            cache_id="stage_cache_123",
            steps=[only_step],
            requires=[],
            output_bindings={"beats": "cache123"},
            level=0,
        )
        assert plan.stage_name == "analyze"
        assert plan.cache_id == "stage_cache_123"
        assert len(plan.steps) == 1
        assert plan.level == 0
class TestExplicitDataRouting:
    """Serialized plan steps must carry their stage routing info."""

    def test_plan_step_includes_stage_info(self):
        """to_sexp() output contains the stage name and stage cache id."""
        from .planner import PlanStep
        from .parser import serialize
        step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={},
            inputs=[],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_abc",
        )
        rendered = serialize(step.to_sexp())
        # Routing metadata must survive serialization.
        for token in ("stage", "analyze", "stage-cache-id"):
            assert token in rendered

View File

@@ -0,0 +1,323 @@
"""
Tests for stage-aware scheduler.
Tests stage cache hit/miss, stage execution ordering,
and parallel stage support.
"""
import pytest
import tempfile
from unittest.mock import Mock, MagicMock, patch
from .scheduler import (
StagePlanScheduler,
StageResult,
StagePlanResult,
create_stage_scheduler,
schedule_staged_plan,
)
from .planner import ExecutionPlanSexp, PlanStep, StagePlan
from .stage_cache import StageCache, StageCacheEntry, StageOutput
class TestStagePlanScheduler:
"""Test stage-aware scheduling."""
def test_plan_without_stages_uses_regular_scheduling(self):
    """A plan carrying no stage metadata falls back to plain scheduling."""
    empty_plan = ExecutionPlanSexp(
        plan_id="test_plan",
        recipe_id="test_recipe",
        recipe_hash="abc123",
        steps=[],
        output_step_id="output",
        stage_plans=[],  # no stages -> PlanScheduler path
    )
    # Without Celery configured this simply reports completion,
    # but it must still come back as a StagePlanResult.
    outcome = StagePlanScheduler().schedule(empty_plan)
    assert isinstance(outcome, StagePlanResult)
def test_stage_cache_hit_skips_execution(self):
    """A stage whose cache id is already stored is not re-executed."""
    with tempfile.TemporaryDirectory() as tmpdir:
        cache = StageCache(tmpdir)
        # Warm the cache so the scheduler sees a hit for this stage.
        cache.save_stage(StageCacheEntry(
            stage_name="analyze",
            cache_id="stage_cache_123",
            outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
        ))
        only_step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={},
            inputs=[],
            cache_id="step_cache",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_123",
        )
        analyze_plan = StagePlan(
            stage_name="analyze",
            cache_id="stage_cache_123",
            steps=[only_step],
            requires=[],
            output_bindings={"beats": "beats_out"},
            level=0,
        )
        plan = ExecutionPlanSexp(
            plan_id="test_plan",
            recipe_id="test_recipe",
            recipe_hash="abc123",
            steps=[only_step],
            output_step_id="step1",
            stage_plans=[analyze_plan],
            stage_order=["analyze"],
            stage_levels={"analyze": 0},
            stage_cache_ids={"analyze": "stage_cache_123"},
        )
        outcome = StagePlanScheduler(stage_cache=cache).schedule(plan)
        # The only stage hit the cache, so nothing was executed.
        assert outcome.stages_cached == 1
        assert outcome.stages_completed == 0
def test_stage_inputs_loaded_from_cache(self):
"""Stage receives inputs from required stage cache."""
with tempfile.TemporaryDirectory() as tmpdir:
stage_cache = StageCache(tmpdir)
# Pre-populate upstream stage cache
upstream_entry = StageCacheEntry(
stage_name="analyze",
cache_id="upstream_cache",
outputs={"beats": StageOutput(cache_id="beats_data", output_type="analysis")},
)
stage_cache.save_stage(upstream_entry)
# Steps for stages
upstream_step = PlanStep(
step_id="analyze_step",
node_type="ANALYZE",
config={},
inputs=[],
cache_id="analyze_cache",
level=0,
stage="analyze",
stage_cache_id="upstream_cache",
)
downstream_step = PlanStep(
step_id="process_step",
node_type="SEGMENT",
config={},
inputs=["analyze_step"],
cache_id="process_cache",
level=1,
stage="process",
stage_cache_id="downstream_cache",
)
upstream_plan = StagePlan(
stage_name="analyze",
cache_id="upstream_cache",
steps=[upstream_step],
requires=[],
output_bindings={"beats": "beats_data"},
level=0,
)
downstream_plan = StagePlan(
stage_name="process",
cache_id="downstream_cache",
steps=[downstream_step],
requires=["analyze"],
output_bindings={"result": "process_cache"},
level=1,
)
plan = ExecutionPlanSexp(
plan_id="test_plan",
recipe_id="test_recipe",
recipe_hash="abc123",
steps=[upstream_step, downstream_step],
output_step_id="process_step",
stage_plans=[upstream_plan, downstream_plan],
stage_order=["analyze", "process"],
stage_levels={"analyze": 0, "process": 1},
stage_cache_ids={"analyze": "upstream_cache", "process": "downstream_cache"},
)
scheduler = StagePlanScheduler(stage_cache=stage_cache)
result = scheduler.schedule(plan)
# Upstream should be cached, downstream executed
assert result.stages_cached == 1
assert "analyze" in result.stage_results
assert result.stage_results["analyze"].status == "cached"
def test_parallel_stages_same_level(self):
"""Stages at same level can run in parallel."""
step_a = PlanStep(
step_id="step_a",
node_type="ANALYZE",
config={},
inputs=[],
cache_id="cache_a",
level=0,
stage="analyze-a",
stage_cache_id="stage_a",
)
step_b = PlanStep(
step_id="step_b",
node_type="ANALYZE",
config={},
inputs=[],
cache_id="cache_b",
level=0,
stage="analyze-b",
stage_cache_id="stage_b",
)
stage_a = StagePlan(
stage_name="analyze-a",
cache_id="stage_a",
steps=[step_a],
requires=[],
output_bindings={"beats-a": "cache_a"},
level=0,
)
stage_b = StagePlan(
stage_name="analyze-b",
cache_id="stage_b",
steps=[step_b],
requires=[],
output_bindings={"beats-b": "cache_b"},
level=0,
)
plan = ExecutionPlanSexp(
plan_id="test_plan",
recipe_id="test_recipe",
recipe_hash="abc123",
steps=[step_a, step_b],
output_step_id="step_b",
stage_plans=[stage_a, stage_b],
stage_order=["analyze-a", "analyze-b"],
stage_levels={"analyze-a": 0, "analyze-b": 0},
stage_cache_ids={"analyze-a": "stage_a", "analyze-b": "stage_b"},
)
scheduler = StagePlanScheduler()
# Group stages by level
stages_by_level = scheduler._group_stages_by_level(plan.stage_plans)
# Both stages should be at level 0
assert len(stages_by_level[0]) == 2
def test_stage_outputs_cached_after_execution(self):
"""Stage outputs written to cache after completion."""
with tempfile.TemporaryDirectory() as tmpdir:
stage_cache = StageCache(tmpdir)
step = PlanStep(
step_id="step1",
node_type="ANALYZE",
config={},
inputs=[],
cache_id="step_cache",
level=0,
stage="analyze",
stage_cache_id="new_stage_cache",
)
stage_plan = StagePlan(
stage_name="analyze",
cache_id="new_stage_cache",
steps=[step],
requires=[],
output_bindings={"beats": "step_cache"},
level=0,
)
plan = ExecutionPlanSexp(
plan_id="test_plan",
recipe_id="test_recipe",
recipe_hash="abc123",
steps=[step],
output_step_id="step1",
stage_plans=[stage_plan],
stage_order=["analyze"],
stage_levels={"analyze": 0},
stage_cache_ids={"analyze": "new_stage_cache"},
)
scheduler = StagePlanScheduler(stage_cache=stage_cache)
result = scheduler.schedule(plan)
# Stage should now be cached
assert stage_cache.has_stage("new_stage_cache")
class TestStageResult:
    """Tests for the StageResult dataclass."""

    def test_stage_result_creation(self):
        """All fields of StageResult round-trip through the constructor."""
        built = StageResult(
            stage_name="test",
            cache_id="cache123",
            status="completed",
            step_results={},
            outputs={"out": "out_cache"},
        )
        assert built.stage_name == "test"
        assert built.status == "completed"
        assert built.outputs["out"] == "out_cache"
class TestStagePlanResult:
    """Tests for the StagePlanResult dataclass."""

    def test_stage_plan_result_creation(self):
        """All fields of StagePlanResult round-trip through the constructor."""
        built = StagePlanResult(
            plan_id="plan123",
            status="completed",
            stages_completed=2,
            stages_cached=1,
            stages_failed=0,
        )
        assert built.plan_id == "plan123"
        assert built.stages_completed == 2
        assert built.stages_cached == 1
class TestSchedulerFactory:
    """Tests for the scheduler factory helpers."""

    def test_create_stage_scheduler(self):
        """The factory produces a StagePlanScheduler instance."""
        built = create_stage_scheduler()
        assert isinstance(built, StagePlanScheduler)

    def test_create_stage_scheduler_with_cache(self):
        """The factory wires a provided stage_cache onto the scheduler."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            built = create_stage_scheduler(stage_cache=cache)
            assert built.stage_cache is cache