Squashed 'core/' content from commit 4957443
git-subtree-dir: core git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07
This commit is contained in:
75
artdag/sexp/__init__.py
Normal file
75
artdag/sexp/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
S-expression parsing, compilation, and planning for ArtDAG.
|
||||
|
||||
This module provides:
|
||||
- parser: Parse S-expression text into Python data structures
|
||||
- compiler: Compile recipe S-expressions into DAG format
|
||||
- planner: Generate execution plans from recipes
|
||||
"""
|
||||
|
||||
from .parser import (
|
||||
parse,
|
||||
parse_all,
|
||||
serialize,
|
||||
Symbol,
|
||||
Keyword,
|
||||
ParseError,
|
||||
)
|
||||
|
||||
from .compiler import (
|
||||
compile_recipe,
|
||||
compile_string,
|
||||
CompiledRecipe,
|
||||
CompileError,
|
||||
ParamDef,
|
||||
_parse_params,
|
||||
)
|
||||
|
||||
from .planner import (
|
||||
create_plan,
|
||||
ExecutionPlanSexp,
|
||||
PlanStep,
|
||||
step_to_task_sexp,
|
||||
task_cache_id,
|
||||
)
|
||||
|
||||
from .scheduler import (
|
||||
PlanScheduler,
|
||||
PlanResult,
|
||||
StepResult,
|
||||
schedule_plan,
|
||||
step_to_sexp,
|
||||
step_sexp_to_string,
|
||||
verify_step_cache_id,
|
||||
)
|
||||
|
||||
# Explicit public API of the artdag.sexp package; keeps `import *` and doc
# tooling in sync with the re-exports above.
# NOTE(review): '_parse_params' is re-exported despite its private name --
# presumably intentional (effect_loader uses it); confirm before removing.
__all__ = [
    # Parser
    'parse',
    'parse_all',
    'serialize',
    'Symbol',
    'Keyword',
    'ParseError',
    # Compiler
    'compile_recipe',
    'compile_string',
    'CompiledRecipe',
    'CompileError',
    'ParamDef',
    '_parse_params',
    # Planner
    'create_plan',
    'ExecutionPlanSexp',
    'PlanStep',
    'step_to_task_sexp',
    'task_cache_id',
    # Scheduler
    'PlanScheduler',
    'PlanResult',
    'StepResult',
    'schedule_plan',
    'step_to_sexp',
    'step_sexp_to_string',
    'verify_step_cache_id',
]
|
||||
2463
artdag/sexp/compiler.py
Normal file
2463
artdag/sexp/compiler.py
Normal file
File diff suppressed because it is too large
Load Diff
337
artdag/sexp/effect_loader.py
Normal file
337
artdag/sexp/effect_loader.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Sexp effect loader.
|
||||
|
||||
Loads sexp effect definitions (define-effect forms) and creates
|
||||
frame processors that evaluate the sexp body with primitives.
|
||||
|
||||
Effects must use :params syntax:
|
||||
|
||||
(define-effect name
|
||||
:params (
|
||||
(param1 :type int :default 8 :range [4 32] :desc "description")
|
||||
(param2 :type string :default "value" :desc "description")
|
||||
)
|
||||
body)
|
||||
|
||||
For effects with no parameters, use empty :params ():
|
||||
|
||||
(define-effect name
|
||||
:params ()
|
||||
body)
|
||||
|
||||
Unknown parameters passed to effects will raise an error.
|
||||
|
||||
Usage:
|
||||
loader = SexpEffectLoader()
|
||||
effect_fn = loader.load_effect_file(Path("effects/ascii_art.sexp"))
|
||||
output = effect_fn(input_path, output_path, config)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .parser import parse_all, Symbol, Keyword
|
||||
from .evaluator import evaluate
|
||||
from .primitives import PRIMITIVES
|
||||
from .compiler import ParamDef, _parse_params, CompileError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_define_effect(sexp) -> tuple:
|
||||
"""
|
||||
Parse a define-effect form.
|
||||
|
||||
Required syntax:
|
||||
(define-effect name
|
||||
:params (
|
||||
(param1 :type int :default 8 :range [4 32] :desc "description")
|
||||
)
|
||||
body)
|
||||
|
||||
Effects MUST use :params syntax. Legacy ((param default) ...) syntax is not supported.
|
||||
|
||||
Returns (name, params_with_defaults, param_defs, body)
|
||||
where param_defs is a list of ParamDef objects
|
||||
"""
|
||||
if not isinstance(sexp, list) or len(sexp) < 3:
|
||||
raise ValueError(f"Invalid define-effect form: {sexp}")
|
||||
|
||||
head = sexp[0]
|
||||
if not (isinstance(head, Symbol) and head.name == "define-effect"):
|
||||
raise ValueError(f"Expected define-effect, got {head}")
|
||||
|
||||
name = sexp[1]
|
||||
if isinstance(name, Symbol):
|
||||
name = name.name
|
||||
|
||||
params_with_defaults = {}
|
||||
param_defs: List[ParamDef] = []
|
||||
body = None
|
||||
found_params = False
|
||||
|
||||
# Parse :params and body
|
||||
i = 2
|
||||
while i < len(sexp):
|
||||
item = sexp[i]
|
||||
if isinstance(item, Keyword) and item.name == "params":
|
||||
# :params syntax
|
||||
if i + 1 >= len(sexp):
|
||||
raise ValueError(f"Effect '{name}': Missing params list after :params keyword")
|
||||
try:
|
||||
param_defs = _parse_params(sexp[i + 1])
|
||||
# Build params_with_defaults from ParamDef objects
|
||||
for pd in param_defs:
|
||||
params_with_defaults[pd.name] = pd.default
|
||||
except CompileError as e:
|
||||
raise ValueError(f"Effect '{name}': Error parsing :params: {e}")
|
||||
found_params = True
|
||||
i += 2
|
||||
elif isinstance(item, Keyword):
|
||||
# Skip other keywords we don't recognize
|
||||
i += 2
|
||||
elif body is None:
|
||||
# First non-keyword item is the body
|
||||
if isinstance(item, list) and item:
|
||||
first_elem = item[0]
|
||||
# Check for legacy syntax and reject it
|
||||
if isinstance(first_elem, list) and len(first_elem) >= 2:
|
||||
raise ValueError(
|
||||
f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. "
|
||||
f"Use :params block instead:\n"
|
||||
f" :params (\n"
|
||||
f" (param_name :type int :default 0 :desc \"description\")\n"
|
||||
f" )"
|
||||
)
|
||||
body = item
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
if body is None:
|
||||
raise ValueError(f"Effect '{name}': No body found")
|
||||
|
||||
if not found_params:
|
||||
raise ValueError(
|
||||
f"Effect '{name}': Missing :params block. Effects must declare parameters.\n"
|
||||
f"For effects with no parameters, use empty :params ():\n"
|
||||
f" (define-effect {name}\n"
|
||||
f" :params ()\n"
|
||||
f" body)"
|
||||
)
|
||||
|
||||
return name, params_with_defaults, param_defs, body
|
||||
|
||||
|
||||
def _create_process_frame(
    effect_name: str,
    params_with_defaults: Dict[str, Any],
    param_defs: List[ParamDef],
    body: Any,
) -> Callable:
    """
    Create a process_frame function that evaluates the sexp body.

    Args:
        effect_name: Effect name, used in error and log messages.
        params_with_defaults: Mapping of declared parameter name -> default.
        param_defs: ParamDef objects (kept for interface parity; not used here).
        body: Parsed sexp body, evaluated once per frame.

    Returns:
        A function with signature (frame, params, state) -> (frame, state).
    """
    import math

    # These helper bindings are identical for every frame, so build them once
    # here instead of re-creating the lambdas on each call.
    static_helpers: Dict[str, Any] = {
        # Math functions
        "floor": lambda x: int(math.floor(x)),
        "ceil": lambda x: int(math.ceil(x)),
        "round": lambda x: int(round(x)),
        "abs": abs,
        "min": min,
        "max": max,
        "sqrt": math.sqrt,
        "sin": math.sin,
        "cos": math.cos,
        # List operations
        "list": lambda *args: tuple(args),
        "nth": lambda coll, i: coll[int(i)] if coll else None,
    }
    known_params = set(params_with_defaults)

    def process_frame(frame: np.ndarray, params: Dict[str, Any], state: Any):
        """Evaluate sexp effect body on a frame."""
        # PRIMITIVES is snapshotted per call (as in the original code) so any
        # late registrations are still picked up; the static helpers overwrite
        # same-named primitives, matching the original binding order.
        env = dict(PRIMITIVES)
        env.update(static_helpers)

        # Bind frame
        env["frame"] = frame

        # Validate that all provided params are known
        for k in params.keys():
            if k not in known_params:
                raise ValueError(
                    f"Effect '{effect_name}': Unknown parameter '{k}'. "
                    f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}"
                )

        # Bind parameters (defaults + overrides from config)
        for param_name, default in params_with_defaults.items():
            if param_name in params:
                env[param_name] = params[param_name]
            elif default is not None:
                env[param_name] = default
            # NOTE(review): a declared param whose default is None and that is
            # absent from `params` stays unbound, so referencing it in the body
            # raises "Undefined symbol" rather than yielding nil -- confirm
            # this is intended.

        # Evaluate the body
        try:
            result = evaluate(body, env)
            if isinstance(result, np.ndarray):
                return result, state
            logger.warning(f"Effect {effect_name} returned {type(result)}, expected ndarray")
            return frame, state
        except Exception as e:
            logger.error(f"Error evaluating effect {effect_name}: {e}")
            raise

    return process_frame
|
||||
|
||||
|
||||
def load_sexp_effect(source: str, base_path: Optional[Path] = None) -> tuple:
    """
    Load a sexp effect from source code.

    Args:
        source: Sexp source code
        base_path: Base path for resolving relative imports (currently unused
            by this function; kept for interface compatibility)

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection

    Raises:
        ValueError: If no define-effect form is found in the source.
    """
    exprs = parse_all(source)

    # Find the first top-level define-effect form.  (The old `elif` fallback
    # that treated `exprs` itself as a single form was unreachable -- its
    # condition repeated `isinstance(exprs, list)`, which the branch above
    # already handles -- and has been removed.)
    define_effect = None
    if isinstance(exprs, list):
        for expr in exprs:
            if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
                if expr[0].name == "define-effect":
                    define_effect = expr
                    break

    if not define_effect:
        raise ValueError("No define-effect form found in sexp effect")

    name, params_with_defaults, param_defs, body = _parse_define_effect(define_effect)
    process_frame = _create_process_frame(name, params_with_defaults, param_defs, body)

    return name, process_frame, params_with_defaults, param_defs
|
||||
|
||||
|
||||
def load_sexp_effect_file(path: Path) -> tuple:
|
||||
"""
|
||||
Load a sexp effect from file.
|
||||
|
||||
Returns:
|
||||
(effect_name, process_frame_fn, params_with_defaults, param_defs)
|
||||
where param_defs is a list of ParamDef objects for introspection
|
||||
"""
|
||||
source = path.read_text()
|
||||
return load_sexp_effect(source, base_path=path.parent)
|
||||
|
||||
|
||||
class SexpEffectLoader:
    """
    Loader for sexp effect definitions.

    Creates effect functions compatible with the EffectExecutor.
    """

    def __init__(self, recipe_dir: Optional[Path] = None):
        """
        Initialize loader.

        Args:
            recipe_dir: Base directory for resolving relative effect paths
                (defaults to the current working directory)
        """
        self.recipe_dir = recipe_dir or Path.cwd()
        # Cache: effect_path -> (name, params_with_defaults, param_defs).
        # Used for introspection and to avoid re-parsing effect files.
        self._loaded_effects: Dict[str, tuple] = {}

    def _load_file(self, effect_path: str) -> tuple:
        """Resolve *effect_path* against recipe_dir and load it.

        Shared by load_effect_path and get_effect_params (previously
        duplicated).  Raises FileNotFoundError if the file does not exist.
        """
        full_path = self.recipe_dir / effect_path
        if not full_path.exists():
            raise FileNotFoundError(f"Sexp effect not found: {full_path}")
        return load_sexp_effect_file(full_path)

    def load_effect_path(self, effect_path: str) -> Callable:
        """
        Load a sexp effect from a relative path.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            Effect function (input_path, output_path, config) -> output_path

        Raises:
            FileNotFoundError: If the effect file does not exist.
        """
        from ..effects.frame_processor import process_video

        name, process_frame_fn, params_defaults, param_defs = self._load_file(effect_path)
        logger.info(f"Loaded sexp effect: {name} from {effect_path}")

        # Cache for introspection
        self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        def effect_fn(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
            """Run sexp effect via frame processor."""
            # Effect params: defaults overridden by config values, excluding
            # executor-internal bookkeeping keys.
            params = dict(params_defaults)
            for k, v in config.items():
                if k not in ("effect", "cid", "hash", "effect_path", "_binding"):
                    params[k] = v

            # Collect analysis bindings: config dict values that carry
            # pre-resolved data under "_resolved_values".
            bindings = {}
            for key, value in config.items():
                if isinstance(value, dict) and value.get("_resolved_values"):
                    bindings[key] = value["_resolved_values"]

            output_path.parent.mkdir(parents=True, exist_ok=True)
            # The frame processor emits mp4 regardless of the requested suffix.
            actual_output = output_path.with_suffix(".mp4")

            process_video(
                input_path=input_path,
                output_path=actual_output,
                process_frame=process_frame_fn,
                params=params,
                bindings=bindings,
            )

            logger.info(f"Processed sexp effect '{name}' from {effect_path}")
            return actual_output

        return effect_fn

    def get_effect_params(self, effect_path: str) -> "List[ParamDef]":
        """
        Get parameter definitions for an effect.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            List of ParamDef objects describing the effect's parameters

        Raises:
            FileNotFoundError: If the effect file does not exist.
        """
        if effect_path not in self._loaded_effects:
            # Load the effect (and cache it) to get its params
            name, _, params_defaults, param_defs = self._load_file(effect_path)
            self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        return self._loaded_effects[effect_path][2]
|
||||
|
||||
|
||||
def get_sexp_effect_loader(recipe_dir: Optional[Path] = None) -> SexpEffectLoader:
    """Return a fresh SexpEffectLoader rooted at *recipe_dir* (or the cwd)."""
    return SexpEffectLoader(recipe_dir=recipe_dir)
|
||||
869
artdag/sexp/evaluator.py
Normal file
869
artdag/sexp/evaluator.py
Normal file
@@ -0,0 +1,869 @@
|
||||
"""
|
||||
Expression evaluator for S-expression DSL.
|
||||
|
||||
Supports:
|
||||
- Arithmetic: +, -, *, /, mod, sqrt, pow, abs, floor, ceil, round, min, max, clamp
|
||||
- Comparison: =, <, >, <=, >=
|
||||
- Logic: and, or, not
|
||||
- Predicates: odd?, even?, zero?, nil?
|
||||
- Conditionals: if, cond, case
|
||||
- Data: list, dict/map construction, get
|
||||
- Lambda calls
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Callable
|
||||
from .parser import Symbol, Keyword, Lambda, Binding
|
||||
|
||||
|
||||
class EvalError(Exception):
    """Raised when evaluation of an S-expression fails."""
|
||||
|
||||
|
||||
# Registry of built-in functions, keyed by their S-expression name.
BUILTINS: Dict[str, Callable] = {}


def builtin(name: str):
    """Decorator: register the wrapped function in BUILTINS under *name*.

    The function itself is returned unchanged, so it also remains callable
    under its Python name.
    """
    def register(fn):
        BUILTINS[name] = fn
        return fn
    return register
|
||||
|
||||
|
||||
@builtin("+")
|
||||
def add(*args):
|
||||
return sum(args)
|
||||
|
||||
|
||||
@builtin("-")
|
||||
def sub(a, b=None):
|
||||
if b is None:
|
||||
return -a
|
||||
return a - b
|
||||
|
||||
|
||||
@builtin("*")
|
||||
def mul(*args):
|
||||
result = 1
|
||||
for a in args:
|
||||
result *= a
|
||||
return result
|
||||
|
||||
|
||||
@builtin("/")
|
||||
def div(a, b):
|
||||
return a / b
|
||||
|
||||
|
||||
@builtin("mod")
|
||||
def mod(a, b):
|
||||
return a % b
|
||||
|
||||
|
||||
@builtin("sqrt")
|
||||
def sqrt(x):
|
||||
return x ** 0.5
|
||||
|
||||
|
||||
@builtin("pow")
|
||||
def power(x, n):
|
||||
return x ** n
|
||||
|
||||
|
||||
@builtin("abs")
|
||||
def absolute(x):
|
||||
return abs(x)
|
||||
|
||||
|
||||
@builtin("floor")
|
||||
def floor_fn(x):
|
||||
import math
|
||||
return math.floor(x)
|
||||
|
||||
|
||||
@builtin("ceil")
|
||||
def ceil_fn(x):
|
||||
import math
|
||||
return math.ceil(x)
|
||||
|
||||
|
||||
@builtin("round")
|
||||
def round_fn(x, ndigits=0):
|
||||
return round(x, int(ndigits))
|
||||
|
||||
|
||||
@builtin("min")
|
||||
def min_fn(*args):
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
return min(args[0])
|
||||
return min(args)
|
||||
|
||||
|
||||
@builtin("max")
|
||||
def max_fn(*args):
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
return max(args[0])
|
||||
return max(args)
|
||||
|
||||
|
||||
@builtin("clamp")
|
||||
def clamp(x, lo, hi):
|
||||
return max(lo, min(hi, x))
|
||||
|
||||
|
||||
@builtin("=")
|
||||
def eq(a, b):
|
||||
return a == b
|
||||
|
||||
|
||||
@builtin("<")
|
||||
def lt(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
@builtin(">")
|
||||
def gt(a, b):
|
||||
return a > b
|
||||
|
||||
|
||||
@builtin("<=")
|
||||
def lte(a, b):
|
||||
return a <= b
|
||||
|
||||
|
||||
@builtin(">=")
|
||||
def gte(a, b):
|
||||
return a >= b
|
||||
|
||||
|
||||
@builtin("odd?")
|
||||
def is_odd(n):
|
||||
return n % 2 == 1
|
||||
|
||||
|
||||
@builtin("even?")
|
||||
def is_even(n):
|
||||
return n % 2 == 0
|
||||
|
||||
|
||||
@builtin("zero?")
|
||||
def is_zero(n):
|
||||
return n == 0
|
||||
|
||||
|
||||
@builtin("nil?")
|
||||
def is_nil(x):
|
||||
return x is None
|
||||
|
||||
|
||||
@builtin("not")
|
||||
def not_fn(x):
|
||||
return not x
|
||||
|
||||
|
||||
@builtin("inc")
|
||||
def inc(n):
|
||||
return n + 1
|
||||
|
||||
|
||||
@builtin("dec")
|
||||
def dec(n):
|
||||
return n - 1
|
||||
|
||||
|
||||
@builtin("list")
|
||||
def make_list(*args):
|
||||
return list(args)
|
||||
|
||||
|
||||
@builtin("assert")
|
||||
def assert_true(condition, message="Assertion failed"):
|
||||
if not condition:
|
||||
raise RuntimeError(f"Assertion error: {message}")
|
||||
return True
|
||||
|
||||
|
||||
@builtin("get")
|
||||
def get(coll, key, default=None):
|
||||
if isinstance(coll, dict):
|
||||
# Try the key directly first
|
||||
result = coll.get(key, None)
|
||||
if result is not None:
|
||||
return result
|
||||
# If key is a Keyword, also try its string name (for Python dicts with string keys)
|
||||
if isinstance(key, Keyword):
|
||||
result = coll.get(key.name, None)
|
||||
if result is not None:
|
||||
return result
|
||||
# Return the default
|
||||
return default
|
||||
elif isinstance(coll, list):
|
||||
return coll[key] if 0 <= key < len(coll) else default
|
||||
else:
|
||||
raise EvalError(f"get: expected dict or list, got {type(coll).__name__}: {str(coll)[:100]}")
|
||||
|
||||
|
||||
@builtin("dict?")
|
||||
def is_dict(x):
|
||||
return isinstance(x, dict)
|
||||
|
||||
|
||||
@builtin("list?")
|
||||
def is_list(x):
|
||||
return isinstance(x, list)
|
||||
|
||||
|
||||
@builtin("nil?")
|
||||
def is_nil(x):
|
||||
return x is None
|
||||
|
||||
|
||||
@builtin("number?")
|
||||
def is_number(x):
|
||||
return isinstance(x, (int, float))
|
||||
|
||||
|
||||
@builtin("string?")
|
||||
def is_string(x):
|
||||
return isinstance(x, str)
|
||||
|
||||
|
||||
@builtin("len")
|
||||
def length(coll):
|
||||
return len(coll)
|
||||
|
||||
|
||||
@builtin("first")
|
||||
def first(coll):
|
||||
return coll[0] if coll else None
|
||||
|
||||
|
||||
@builtin("last")
|
||||
def last(coll):
|
||||
return coll[-1] if coll else None
|
||||
|
||||
|
||||
@builtin("chunk-every")
|
||||
def chunk_every(coll, n):
|
||||
"""Split collection into chunks of n elements."""
|
||||
n = int(n)
|
||||
return [coll[i:i+n] for i in range(0, len(coll), n)]
|
||||
|
||||
|
||||
@builtin("rest")
|
||||
def rest(coll):
|
||||
return coll[1:] if coll else []
|
||||
|
||||
|
||||
@builtin("nth")
|
||||
def nth(coll, n):
|
||||
return coll[n] if 0 <= n < len(coll) else None
|
||||
|
||||
|
||||
@builtin("concat")
|
||||
def concat(*colls):
|
||||
"""Concatenate multiple lists/sequences."""
|
||||
result = []
|
||||
for c in colls:
|
||||
if c is not None:
|
||||
result.extend(c)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("cons")
|
||||
def cons(x, coll):
|
||||
"""Prepend x to collection."""
|
||||
return [x] + list(coll) if coll else [x]
|
||||
|
||||
|
||||
@builtin("append")
|
||||
def append(coll, x):
|
||||
"""Append x to collection."""
|
||||
return list(coll) + [x] if coll else [x]
|
||||
|
||||
|
||||
@builtin("range")
|
||||
def make_range(start, end, step=1):
|
||||
"""Create a range of numbers."""
|
||||
return list(range(int(start), int(end), int(step)))
|
||||
|
||||
|
||||
@builtin("zip-pairs")
|
||||
def zip_pairs(coll):
|
||||
"""Zip consecutive pairs: [a,b,c,d] -> [[a,b],[b,c],[c,d]]."""
|
||||
if not coll or len(coll) < 2:
|
||||
return []
|
||||
return [[coll[i], coll[i+1]] for i in range(len(coll)-1)]
|
||||
|
||||
|
||||
@builtin("dict")
|
||||
def make_dict(*pairs):
|
||||
"""Create dict from key-value pairs: (dict :a 1 :b 2)."""
|
||||
result = {}
|
||||
i = 0
|
||||
while i < len(pairs) - 1:
|
||||
key = pairs[i]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = pairs[i + 1]
|
||||
i += 2
|
||||
return result
|
||||
|
||||
|
||||
@builtin("keys")
|
||||
def keys(d):
|
||||
"""Get the keys of a dict as a list."""
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"keys: expected dict, got {type(d).__name__}")
|
||||
return list(d.keys())
|
||||
|
||||
|
||||
@builtin("vals")
|
||||
def vals(d):
|
||||
"""Get the values of a dict as a list."""
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"vals: expected dict, got {type(d).__name__}")
|
||||
return list(d.values())
|
||||
|
||||
|
||||
@builtin("merge")
|
||||
def merge(*dicts):
|
||||
"""Merge multiple dicts, later dicts override earlier."""
|
||||
result = {}
|
||||
for d in dicts:
|
||||
if d is not None:
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"merge: expected dict, got {type(d).__name__}")
|
||||
result.update(d)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("assoc")
|
||||
def assoc(d, *pairs):
|
||||
"""Associate keys with values in a dict: (assoc d :a 1 :b 2)."""
|
||||
if d is None:
|
||||
result = {}
|
||||
elif isinstance(d, dict):
|
||||
result = dict(d)
|
||||
else:
|
||||
raise EvalError(f"assoc: expected dict or nil, got {type(d).__name__}")
|
||||
|
||||
i = 0
|
||||
while i < len(pairs) - 1:
|
||||
key = pairs[i]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = pairs[i + 1]
|
||||
i += 2
|
||||
return result
|
||||
|
||||
|
||||
@builtin("dissoc")
|
||||
def dissoc(d, *keys_to_remove):
|
||||
"""Remove keys from a dict: (dissoc d :a :b)."""
|
||||
if d is None:
|
||||
return {}
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"dissoc: expected dict or nil, got {type(d).__name__}")
|
||||
|
||||
result = dict(d)
|
||||
for key in keys_to_remove:
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result.pop(key, None)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("into")
|
||||
def into(target, coll):
|
||||
"""Convert a collection into another collection type.
|
||||
|
||||
(into [] {:a 1 :b 2}) -> [["a" 1] ["b" 2]]
|
||||
(into {} [[:a 1] [:b 2]]) -> {"a": 1, "b": 2}
|
||||
(into [] [1 2 3]) -> [1 2 3]
|
||||
"""
|
||||
if isinstance(target, list):
|
||||
if isinstance(coll, dict):
|
||||
return [[k, v] for k, v in coll.items()]
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
return list(coll)
|
||||
else:
|
||||
raise EvalError(f"into: cannot convert {type(coll).__name__} into list")
|
||||
elif isinstance(target, dict):
|
||||
if isinstance(coll, dict):
|
||||
return dict(coll)
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
result = {}
|
||||
for item in coll:
|
||||
if isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
key = item[0]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = item[1]
|
||||
else:
|
||||
raise EvalError(f"into: expected [key value] pairs, got {item}")
|
||||
return result
|
||||
else:
|
||||
raise EvalError(f"into: cannot convert {type(coll).__name__} into dict")
|
||||
else:
|
||||
raise EvalError(f"into: unsupported target type {type(target).__name__}")
|
||||
|
||||
|
||||
@builtin("filter")
|
||||
def filter_fn(pred, coll):
|
||||
"""Filter collection by predicate. Pred must be a lambda."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"filter: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
result = []
|
||||
for item in coll:
|
||||
# Evaluate predicate with item
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
# Inline evaluation of pred.body
|
||||
from . import evaluator
|
||||
if evaluator.evaluate(pred.body, local_env):
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("some")
|
||||
def some(pred, coll):
|
||||
"""Return first truthy value of (pred item) for items in coll, or nil."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"some: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
from . import evaluator
|
||||
result = evaluator.evaluate(pred.body, local_env)
|
||||
if result:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
@builtin("every?")
|
||||
def every(pred, coll):
|
||||
"""Return true if (pred item) is truthy for all items in coll."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"every?: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
from . import evaluator
|
||||
if not evaluator.evaluate(pred.body, local_env):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@builtin("empty?")
|
||||
def is_empty(coll):
|
||||
"""Return true if collection is empty."""
|
||||
if coll is None:
|
||||
return True
|
||||
return len(coll) == 0
|
||||
|
||||
|
||||
@builtin("contains?")
|
||||
def contains(coll, key):
|
||||
"""Check if collection contains key (for dicts) or element (for lists)."""
|
||||
if isinstance(coll, dict):
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
return key in coll
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
return key in coll
|
||||
return False
|
||||
|
||||
|
||||
def evaluate(expr: Any, env: Dict[str, Any] = None) -> Any:
|
||||
"""
|
||||
Evaluate an S-expression in the given environment.
|
||||
|
||||
Args:
|
||||
expr: The expression to evaluate
|
||||
env: Variable bindings (name -> value)
|
||||
|
||||
Returns:
|
||||
The result of evaluation
|
||||
"""
|
||||
if env is None:
|
||||
env = {}
|
||||
|
||||
# Literals
|
||||
if isinstance(expr, (int, float, str, bool)) or expr is None:
|
||||
return expr
|
||||
|
||||
# Symbol - variable lookup
|
||||
if isinstance(expr, Symbol):
|
||||
name = expr.name
|
||||
if name in env:
|
||||
return env[name]
|
||||
if name in BUILTINS:
|
||||
return BUILTINS[name]
|
||||
if name == "true":
|
||||
return True
|
||||
if name == "false":
|
||||
return False
|
||||
if name == "nil":
|
||||
return None
|
||||
raise EvalError(f"Undefined symbol: {name}")
|
||||
|
||||
# Keyword - return as-is (used as map keys)
|
||||
if isinstance(expr, Keyword):
|
||||
return expr.name
|
||||
|
||||
# Lambda - return as-is (it's a value)
|
||||
if isinstance(expr, Lambda):
|
||||
return expr
|
||||
|
||||
# Binding - return as-is (resolved at execution time)
|
||||
if isinstance(expr, Binding):
|
||||
return expr
|
||||
|
||||
# Dict literal
|
||||
if isinstance(expr, dict):
|
||||
return {k: evaluate(v, env) for k, v in expr.items()}
|
||||
|
||||
# List - function call or special form
|
||||
if isinstance(expr, list):
|
||||
if not expr:
|
||||
return []
|
||||
|
||||
head = expr[0]
|
||||
|
||||
# If head is a string/number/etc (not Symbol), treat as data list
|
||||
if not isinstance(head, (Symbol, Lambda, list)):
|
||||
return [evaluate(x, env) for x in expr]
|
||||
|
||||
# Special forms
|
||||
if isinstance(head, Symbol):
|
||||
name = head.name
|
||||
|
||||
# if - conditional
|
||||
if name == "if":
|
||||
if len(expr) < 3:
|
||||
raise EvalError("if requires condition and then-branch")
|
||||
cond_result = evaluate(expr[1], env)
|
||||
if cond_result:
|
||||
return evaluate(expr[2], env)
|
||||
elif len(expr) > 3:
|
||||
return evaluate(expr[3], env)
|
||||
return None
|
||||
|
||||
# cond - multi-way conditional
|
||||
# Supports both Clojure style: (cond test1 result1 test2 result2 :else default)
|
||||
# and Scheme style: (cond (test1 result1) (test2 result2) (else default))
|
||||
if name == "cond":
|
||||
clauses = expr[1:]
|
||||
# Check if Clojure style (flat list) or Scheme style (nested pairs)
|
||||
# Scheme style: first clause is (test result) - exactly 2 elements
|
||||
# Clojure style: first clause is a test expression like (= x 0) - 3+ elements
|
||||
first_is_scheme_clause = (
|
||||
clauses and
|
||||
isinstance(clauses[0], list) and
|
||||
len(clauses[0]) == 2 and
|
||||
not (isinstance(clauses[0][0], Symbol) and clauses[0][0].name in ('=', '<', '>', '<=', '>=', '!=', 'not=', 'and', 'or'))
|
||||
)
|
||||
if first_is_scheme_clause:
|
||||
# Scheme style: ((test result) ...)
|
||||
for clause in clauses:
|
||||
if not isinstance(clause, list) or len(clause) < 2:
|
||||
raise EvalError("cond clause must be (test result)")
|
||||
test = clause[0]
|
||||
# Check for else/default
|
||||
if isinstance(test, Symbol) and test.name in ("else", ":else"):
|
||||
return evaluate(clause[1], env)
|
||||
if isinstance(test, Keyword) and test.name == "else":
|
||||
return evaluate(clause[1], env)
|
||||
if evaluate(test, env):
|
||||
return evaluate(clause[1], env)
|
||||
else:
|
||||
# Clojure style: test1 result1 test2 result2 ...
|
||||
i = 0
|
||||
while i < len(clauses) - 1:
|
||||
test = clauses[i]
|
||||
result = clauses[i + 1]
|
||||
# Check for :else
|
||||
if isinstance(test, Keyword) and test.name == "else":
|
||||
return evaluate(result, env)
|
||||
if isinstance(test, Symbol) and test.name == ":else":
|
||||
return evaluate(result, env)
|
||||
if evaluate(test, env):
|
||||
return evaluate(result, env)
|
||||
i += 2
|
||||
return None
|
||||
|
||||
# case - switch on value
|
||||
# (case expr val1 result1 val2 result2 :else default)
|
||||
if name == "case":
|
||||
if len(expr) < 2:
|
||||
raise EvalError("case requires expression to match")
|
||||
match_val = evaluate(expr[1], env)
|
||||
clauses = expr[2:]
|
||||
i = 0
|
||||
while i < len(clauses) - 1:
|
||||
test = clauses[i]
|
||||
result = clauses[i + 1]
|
||||
# Check for :else / else
|
||||
if isinstance(test, Keyword) and test.name == "else":
|
||||
return evaluate(result, env)
|
||||
if isinstance(test, Symbol) and test.name in (":else", "else"):
|
||||
return evaluate(result, env)
|
||||
# Evaluate test value and compare
|
||||
test_val = evaluate(test, env)
|
||||
if match_val == test_val:
|
||||
return evaluate(result, env)
|
||||
i += 2
|
||||
return None
|
||||
|
||||
# and - short-circuit
|
||||
if name == "and":
|
||||
result = True
|
||||
for arg in expr[1:]:
|
||||
result = evaluate(arg, env)
|
||||
if not result:
|
||||
return result
|
||||
return result
|
||||
|
||||
# or - short-circuit
|
||||
if name == "or":
|
||||
result = False
|
||||
for arg in expr[1:]:
|
||||
result = evaluate(arg, env)
|
||||
if result:
|
||||
return result
|
||||
return result
|
||||
|
||||
# let and let* - local bindings (both bind sequentially in this impl)
|
||||
if name in ("let", "let*"):
|
||||
if len(expr) < 3:
|
||||
raise EvalError(f"{name} requires bindings and body")
|
||||
bindings = expr[1]
|
||||
|
||||
local_env = dict(env)
|
||||
|
||||
if isinstance(bindings, list):
|
||||
# Check if it's ((name value) ...) style (Lisp let* style)
|
||||
if bindings and isinstance(bindings[0], list):
|
||||
for binding in bindings:
|
||||
if len(binding) != 2:
|
||||
raise EvalError(f"{name} binding must be (name value)")
|
||||
var_name = binding[0]
|
||||
if isinstance(var_name, Symbol):
|
||||
var_name = var_name.name
|
||||
value = evaluate(binding[1], local_env)
|
||||
local_env[var_name] = value
|
||||
# Vector-style [name value ...]
|
||||
elif len(bindings) % 2 == 0:
|
||||
for i in range(0, len(bindings), 2):
|
||||
var_name = bindings[i]
|
||||
if isinstance(var_name, Symbol):
|
||||
var_name = var_name.name
|
||||
value = evaluate(bindings[i + 1], local_env)
|
||||
local_env[var_name] = value
|
||||
else:
|
||||
raise EvalError(f"{name} bindings must be [name value ...] or ((name value) ...)")
|
||||
else:
|
||||
raise EvalError(f"{name} bindings must be a list")
|
||||
|
||||
return evaluate(expr[2], local_env)
|
||||
|
||||
# lambda / fn - create function with closure
|
||||
if name in ("lambda", "fn"):
|
||||
if len(expr) < 3:
|
||||
raise EvalError("lambda requires params and body")
|
||||
params = expr[1]
|
||||
if not isinstance(params, list):
|
||||
raise EvalError("lambda params must be a list")
|
||||
param_names = []
|
||||
for p in params:
|
||||
if isinstance(p, Symbol):
|
||||
param_names.append(p.name)
|
||||
elif isinstance(p, str):
|
||||
param_names.append(p)
|
||||
else:
|
||||
raise EvalError(f"Invalid param: {p}")
|
||||
# Capture current environment as closure
|
||||
return Lambda(param_names, expr[2], dict(env))
|
||||
|
||||
# quote - return unevaluated
|
||||
if name == "quote":
|
||||
return expr[1] if len(expr) > 1 else None
|
||||
|
||||
# bind - create binding to analysis data
|
||||
# (bind analysis-var)
|
||||
# (bind analysis-var :range [0.3 1.0])
|
||||
# (bind analysis-var :range [0 100] :transform sqrt)
|
||||
if name == "bind":
|
||||
if len(expr) < 2:
|
||||
raise EvalError("bind requires analysis reference")
|
||||
analysis_ref = expr[1]
|
||||
if isinstance(analysis_ref, Symbol):
|
||||
symbol_name = analysis_ref.name
|
||||
# Look up the symbol in environment
|
||||
if symbol_name in env:
|
||||
resolved = env[symbol_name]
|
||||
# If resolved is actual analysis data (dict with times/values or
|
||||
# S-expression list with Keywords), keep the symbol name as reference
|
||||
# for later lookup at execution time
|
||||
if isinstance(resolved, dict) and ("times" in resolved or "values" in resolved):
|
||||
analysis_ref = symbol_name # Use name as reference, not the data
|
||||
elif isinstance(resolved, list) and any(isinstance(x, Keyword) for x in resolved):
|
||||
# Parsed S-expression analysis data ([:times [...] :duration ...])
|
||||
analysis_ref = symbol_name
|
||||
else:
|
||||
analysis_ref = resolved
|
||||
else:
|
||||
raise EvalError(f"bind: undefined symbol '{symbol_name}' - must reference analysis data")
|
||||
|
||||
# Parse optional :range [min max] and :transform
|
||||
range_min, range_max = 0.0, 1.0
|
||||
transform = None
|
||||
i = 2
|
||||
while i < len(expr):
|
||||
if isinstance(expr[i], Keyword):
|
||||
kw = expr[i].name
|
||||
if kw == "range" and i + 1 < len(expr):
|
||||
range_val = evaluate(expr[i + 1], env) # Evaluate to get actual value
|
||||
if isinstance(range_val, list) and len(range_val) >= 2:
|
||||
range_min = float(range_val[0])
|
||||
range_max = float(range_val[1])
|
||||
i += 2
|
||||
elif kw == "transform" and i + 1 < len(expr):
|
||||
t = expr[i + 1]
|
||||
if isinstance(t, Symbol):
|
||||
transform = t.name
|
||||
elif isinstance(t, str):
|
||||
transform = t
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return Binding(analysis_ref, range_min=range_min, range_max=range_max, transform=transform)
|
||||
|
||||
# Vector literal [a b c]
|
||||
if name == "vec" or name == "vector":
|
||||
return [evaluate(e, env) for e in expr[1:]]
|
||||
|
||||
# map - (map fn coll)
|
||||
if name == "map":
|
||||
if len(expr) != 3:
|
||||
raise EvalError("map requires fn and collection")
|
||||
fn = evaluate(expr[1], env)
|
||||
coll = evaluate(expr[2], env)
|
||||
if not isinstance(fn, Lambda):
|
||||
raise EvalError(f"map requires lambda, got {type(fn)}")
|
||||
result = []
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if fn.closure:
|
||||
local_env.update(fn.closure)
|
||||
local_env.update(env)
|
||||
local_env[fn.params[0]] = item
|
||||
result.append(evaluate(fn.body, local_env))
|
||||
return result
|
||||
|
||||
# map-indexed - (map-indexed fn coll)
|
||||
if name == "map-indexed":
|
||||
if len(expr) != 3:
|
||||
raise EvalError("map-indexed requires fn and collection")
|
||||
fn = evaluate(expr[1], env)
|
||||
coll = evaluate(expr[2], env)
|
||||
if not isinstance(fn, Lambda):
|
||||
raise EvalError(f"map-indexed requires lambda, got {type(fn)}")
|
||||
if len(fn.params) < 2:
|
||||
raise EvalError("map-indexed lambda needs (i item) params")
|
||||
result = []
|
||||
for i, item in enumerate(coll):
|
||||
local_env = {}
|
||||
if fn.closure:
|
||||
local_env.update(fn.closure)
|
||||
local_env.update(env)
|
||||
local_env[fn.params[0]] = i
|
||||
local_env[fn.params[1]] = item
|
||||
result.append(evaluate(fn.body, local_env))
|
||||
return result
|
||||
|
||||
# reduce - (reduce fn init coll)
|
||||
if name == "reduce":
|
||||
if len(expr) != 4:
|
||||
raise EvalError("reduce requires fn, init, and collection")
|
||||
fn = evaluate(expr[1], env)
|
||||
acc = evaluate(expr[2], env)
|
||||
coll = evaluate(expr[3], env)
|
||||
if not isinstance(fn, Lambda):
|
||||
raise EvalError(f"reduce requires lambda, got {type(fn)}")
|
||||
if len(fn.params) < 2:
|
||||
raise EvalError("reduce lambda needs (acc item) params")
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if fn.closure:
|
||||
local_env.update(fn.closure)
|
||||
local_env.update(env)
|
||||
local_env[fn.params[0]] = acc
|
||||
local_env[fn.params[1]] = item
|
||||
acc = evaluate(fn.body, local_env)
|
||||
return acc
|
||||
|
||||
# for-each - (for-each fn coll) - iterate with side effects
|
||||
if name == "for-each":
|
||||
if len(expr) != 3:
|
||||
raise EvalError("for-each requires fn and collection")
|
||||
fn = evaluate(expr[1], env)
|
||||
coll = evaluate(expr[2], env)
|
||||
if not isinstance(fn, Lambda):
|
||||
raise EvalError(f"for-each requires lambda, got {type(fn)}")
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if fn.closure:
|
||||
local_env.update(fn.closure)
|
||||
local_env.update(env)
|
||||
local_env[fn.params[0]] = item
|
||||
evaluate(fn.body, local_env)
|
||||
return None
|
||||
|
||||
# Function call
|
||||
fn = evaluate(head, env)
|
||||
args = [evaluate(arg, env) for arg in expr[1:]]
|
||||
|
||||
# Call builtin
|
||||
if callable(fn):
|
||||
return fn(*args)
|
||||
|
||||
# Call lambda
|
||||
if isinstance(fn, Lambda):
|
||||
if len(args) != len(fn.params):
|
||||
raise EvalError(f"Lambda expects {len(fn.params)} args, got {len(args)}")
|
||||
# Start with closure (captured env), then overlay calling env, then params
|
||||
local_env = {}
|
||||
if fn.closure:
|
||||
local_env.update(fn.closure)
|
||||
local_env.update(env)
|
||||
for name, value in zip(fn.params, args):
|
||||
local_env[name] = value
|
||||
return evaluate(fn.body, local_env)
|
||||
|
||||
raise EvalError(f"Not callable: {fn}")
|
||||
|
||||
raise EvalError(f"Cannot evaluate: {expr!r}")
|
||||
|
||||
|
||||
def make_env(**kwargs) -> Dict[str, Any]:
    """Build a fresh evaluation environment from keyword bindings."""
    return {**kwargs}
|
||||
292
artdag/sexp/external_tools.py
Normal file
292
artdag/sexp/external_tools.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
External tool runners for effects that can't be done in FFmpeg.
|
||||
|
||||
Supports:
|
||||
- datamosh: via ffglitch or datamoshing Python CLI
|
||||
- pixelsort: via Rust pixelsort or Python pixelsort-cli
|
||||
"""
|
||||
|
||||
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def find_tool(tool_names: List[str]) -> Optional[str]:
    """Return the resolved path of the first candidate found on PATH, else None.

    Candidates are probed in the order given, so earlier names win.
    """
    return next(
        (hit for hit in map(shutil.which, tool_names) if hit),
        None,
    )
|
||||
|
||||
|
||||
def check_python_package(package: str) -> bool:
    """Return True if *package* is importable.

    Fix: probe with ``sys.executable`` (the interpreter actually running this
    code) instead of a hard-coded ``python3``, which could point at a different
    interpreter with different packages — or not exist at all (e.g. Windows).

    Args:
        package: Module name to test-import, e.g. ``"datamoshing"``.

    Returns:
        True when ``import {package}`` succeeds in a subprocess; False on
        import failure, missing interpreter, or timeout.
    """
    interpreter = sys.executable or "python3"
    try:
        result = subprocess.run(
            [interpreter, "-c", f"import {package}"],
            capture_output=True,
            timeout=5,
        )
        return result.returncode == 0
    except Exception:
        # Best-effort probe: any failure just means "not available".
        return False
|
||||
|
||||
|
||||
# Tool detection: candidate executables probed on PATH, in preference order.
DATAMOSH_TOOLS = ["ffgac", "ffedit"]  # ffglitch suite binaries
PIXELSORT_TOOLS = ["pixelsort"]  # Rust CLI
|
||||
|
||||
|
||||
def get_available_tools() -> Dict[str, Optional[str]]:
    """Probe the system for supported external effect backends.

    Returns a mapping of backend key -> executable path (or package name for
    the Python fallbacks), with None for anything that is not installed.
    """
    has_datamosh_pkg = check_python_package("datamoshing")
    has_pixelsort_pkg = check_python_package("pixelsort")
    return {
        "datamosh": find_tool(DATAMOSH_TOOLS),
        "pixelsort": find_tool(PIXELSORT_TOOLS),
        "datamosh_python": "datamoshing" if has_datamosh_pkg else None,
        "pixelsort_python": "pixelsort" if has_pixelsort_pkg else None,
    }
|
||||
|
||||
|
||||
def run_datamosh(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run datamosh effect using available tool.

    Tries the ffglitch CLI first, then falls back to the Python
    ``datamoshing`` package.

    Fix: the temporary JS script is now removed in a ``finally`` block, so it
    no longer leaks when the subprocess times out or raises.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters (corruption, block_size, etc.)

    Returns:
        (success, error_message)
    """
    tools = get_available_tools()

    corruption = params.get("corruption", 0.3)

    # Try ffglitch first
    if tools.get("datamosh"):
        ffgac = tools["datamosh"]
        script_path = None
        try:
            # ffglitch approach: remove I-frames to create datamosh effect
            # This is a simplified version - full datamosh needs custom scripts
            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                # Write a simple ffglitch script that corrupts motion vectors
                f.write(f"""
// Datamosh script - corrupt motion vectors
let corruption = {corruption};

export function glitch_frame(frame, stream) {{
    if (frame.pict_type === 'P' && Math.random() < corruption) {{
        // Corrupt motion vectors
        let dominated = frame.mv?.forward?.overflow;
        if (dominated) {{
            for (let i = 0; i < dominated.length; i++) {{
                if (Math.random() < corruption) {{
                    dominated[i] = [
                        Math.floor(Math.random() * 64 - 32),
                        Math.floor(Math.random() * 64 - 32)
                    ];
                }}
            }}
        }}
    }}
    return frame;
}}
""")
                script_path = f.name

            cmd = [
                ffgac,
                "-i", str(input_path),
                "-s", script_path,
                "-o", str(output_path),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]

        except subprocess.TimeoutExpired:
            return False, "Datamosh timeout"
        except Exception as e:
            return False, str(e)
        finally:
            # Always clean up the temp script (previously leaked on error).
            if script_path:
                Path(script_path).unlink(missing_ok=True)

    # Fall back to Python datamoshing package
    if tools.get("datamosh_python"):
        try:
            cmd = [
                "python3", "-m", "datamoshing",
                str(input_path),
                str(output_path),
                "--mode", "iframe_removal",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    return False, "No datamosh tool available. Install ffglitch or: pip install datamoshing"
|
||||
|
||||
|
||||
def run_pixelsort(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run pixelsort effect using available tool.

    Tries the Rust CLI first (faster), then falls back to the Python
    ``pixelsort`` package invoked as a module.

    Args:
        input_path: Input image/frame file
        output_path: Output image file
        params: Effect parameters (sort_by, threshold_low, threshold_high, angle)

    Returns:
        (success, error_message)
    """
    tools = get_available_tools()

    sort_by = params.get("sort_by", "lightness")
    threshold_low = params.get("threshold_low", 50)
    # NOTE(review): threshold_high is read but not forwarded to either CLI
    # below — confirm whether the backends support an upper threshold.
    threshold_high = params.get("threshold_high", 200)
    angle = params.get("angle", 0)

    # Try Rust pixelsort first (faster)
    if tools.get("pixelsort"):
        try:
            cmd = [
                tools["pixelsort"],
                str(input_path),
                "-o", str(output_path),
                "--sort", sort_by,
                "-r", str(angle),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode == 0:
                return True, ""
            # Truncate tool output so the error message stays bounded.
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    # Fall back to Python pixelsort-cli
    if tools.get("pixelsort_python"):
        try:
            cmd = [
                "python3", "-m", "pixelsort",
                "--image_path", str(input_path),
                "--output", str(output_path),
                "--angle", str(angle),
                "--threshold", str(threshold_low / 255),  # Normalize to 0-1
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    return False, "No pixelsort tool available. Install: cargo install pixelsort or pip install pixelsort-cli"
|
||||
|
||||
|
||||
def run_pixelsort_video(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    fps: float = 30,
) -> Tuple[bool, str]:
    """
    Run pixelsort on a video by processing each frame.

    This extracts frames with ffmpeg, pixel-sorts each one via
    ``run_pixelsort``, then reassembles the video (copying audio from the
    original if present).

    Fixes: removed the unused ``frames_out`` local; added an explicit error
    when extraction produces zero frames instead of proceeding vacuously.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters forwarded to run_pixelsort
        fps: Frame rate used for both extraction and reassembly

    Returns:
        (success, error_message)
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        frames_in = tmpdir / "frame_%04d.png"

        # Extract frames
        extract_cmd = [
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-vf", f"fps={fps}",
            str(frames_in),
        ]
        result = subprocess.run(extract_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to extract frames"

        # Process each frame
        frame_files = sorted(tmpdir.glob("frame_*.png"))
        if not frame_files:
            # ffmpeg "succeeded" but produced nothing — fail loudly.
            return False, "No frames extracted"
        for i, frame in enumerate(frame_files):
            out_frame = tmpdir / f"sorted_{i:04d}.png"
            success, error = run_pixelsort(frame, out_frame, params)
            if not success:
                return False, f"Frame {i}: {error}"

        # Reassemble; "1:a?" maps audio from the original only if it exists.
        reassemble_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", str(tmpdir / "sorted_%04d.png"),
            "-i", str(input_path),
            "-map", "0:v", "-map", "1:a?",
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ]
        result = subprocess.run(reassemble_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to reassemble video"

    return True, ""
|
||||
|
||||
|
||||
def run_external_effect(
    effect_name: str,
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    is_video: bool = True,
) -> Tuple[bool, str]:
    """
    Dispatch an external effect to its tool runner.

    Args:
        effect_name: Name of effect (datamosh, pixelsort)
        input_path: Input file
        output_path: Output file
        params: Effect parameters
        is_video: Whether input is video (vs single image)

    Returns:
        (success, error_message)
    """
    if effect_name == "datamosh":
        return run_datamosh(input_path, output_path, params)
    if effect_name == "pixelsort":
        # Videos go through the per-frame pipeline; stills are sorted directly.
        runner = run_pixelsort_video if is_video else run_pixelsort
        return runner(input_path, output_path, params)
    return False, f"Unknown external effect: {effect_name}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Report which external tool backends this system provides.
    print("Available external tools:")
    for tool_name, tool_path in get_available_tools().items():
        status = tool_path if tool_path else "NOT INSTALLED"
        print(f"  {tool_name}: {status}")
|
||||
616
artdag/sexp/ffmpeg_compiler.py
Normal file
616
artdag/sexp/ffmpeg_compiler.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
FFmpeg filter compiler for sexp effects.
|
||||
|
||||
Compiles sexp effect definitions to FFmpeg filter expressions,
|
||||
with support for dynamic parameters via sendcmd scripts.
|
||||
|
||||
Usage:
|
||||
compiler = FFmpegCompiler()
|
||||
|
||||
# Compile an effect with static params
|
||||
filter_str = compiler.compile_effect("brightness", {"amount": 50})
|
||||
# -> "eq=brightness=0.196"
|
||||
|
||||
# Compile with dynamic binding to analysis data
|
||||
filter_str, sendcmd = compiler.compile_effect_with_binding(
|
||||
"brightness",
|
||||
{"amount": {"_bind": "bass-data", "range_min": 0, "range_max": 100}},
|
||||
analysis_data={"bass-data": {"times": [...], "values": [...]}},
|
||||
segment_start=0.0,
|
||||
segment_duration=5.0,
|
||||
)
|
||||
# -> ("eq=brightness=0.5", "0.0 [eq] brightness 0.5;\n0.05 [eq] brightness 0.6;...")
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
# FFmpeg filter mappings for common effects
# Maps effect name -> {filter: str, params: {param_name: {ffmpeg_param, scale, offset}}}
# Special keys per entry:
#   static: fixed argument string, emitted as "<filter>=<static>"
#   complex: labeled-pad filter chain (not a simple linear filter)
#   time_based: params become FFmpeg 't' expressions instead of constants
#   python_primitive / external_tool: no FFmpeg equivalent; "filter" is None
EFFECT_MAPPINGS = {
    "invert": {
        "filter": "negate",
        "params": {},
    },
    "grayscale": {
        "filter": "colorchannelmixer",
        "static": "0.3:0.4:0.3:0:0.3:0.4:0.3:0:0.3:0.4:0.3",
        "params": {},
    },
    "sepia": {
        "filter": "colorchannelmixer",
        "static": "0.393:0.769:0.189:0:0.349:0.686:0.168:0:0.272:0.534:0.131",
        "params": {},
    },
    "brightness": {
        "filter": "eq",
        "params": {
            # Recipe amounts are 0-255; eq's brightness expects roughly -1..1.
            "amount": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
        },
    },
    "contrast": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
    "saturation": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "saturation", "scale": 1.0, "offset": 0},
        },
    },
    "hue_shift": {
        "filter": "hue",
        "params": {
            "degrees": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "blur": {
        "filter": "gblur",
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "sharpen": {
        "filter": "unsharp",
        "params": {
            "amount": {"ffmpeg_param": "la", "scale": 1.0, "offset": 0},
        },
    },
    "pixelate": {
        # Scale down then up to create pixelation effect
        "filter": "scale",
        "static": "iw/8:ih/8:flags=neighbor,scale=iw*8:ih*8:flags=neighbor",
        "params": {},
    },
    "vignette": {
        "filter": "vignette",
        "params": {
            "strength": {"ffmpeg_param": "a", "scale": 1.0, "offset": 0},
        },
    },
    "noise": {
        "filter": "noise",
        "params": {
            "amount": {"ffmpeg_param": "alls", "scale": 1.0, "offset": 0},
        },
    },
    "flip": {
        "filter": "hflip",  # Default horizontal
        "params": {},
    },
    "mirror": {
        "filter": "hflip",
        "params": {},
    },
    "rotate": {
        "filter": "rotate",
        "params": {
            "angle": {"ffmpeg_param": "a", "scale": math.pi/180, "offset": 0},  # degrees to radians
        },
    },
    "zoom": {
        "filter": "zoompan",
        "params": {
            "factor": {"ffmpeg_param": "z", "scale": 1.0, "offset": 0},
        },
    },
    "posterize": {
        # Use lutyuv to quantize levels (approximate posterization)
        "filter": "lutyuv",
        "static": "y=floor(val/32)*32:u=floor(val/32)*32:v=floor(val/32)*32",
        "params": {},
    },
    "threshold": {
        # Use geq for thresholding
        "filter": "geq",
        "static": "lum='if(gt(lum(X,Y),128),255,0)':cb=128:cr=128",
        "params": {},
    },
    "edge_detect": {
        "filter": "edgedetect",
        "params": {
            "low": {"ffmpeg_param": "low", "scale": 1/255, "offset": 0},
            "high": {"ffmpeg_param": "high", "scale": 1/255, "offset": 0},
        },
    },
    "swirl": {
        "filter": "lenscorrection",  # Approximate with lens distortion
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 0.1, "offset": 0},
        },
    },
    "fisheye": {
        "filter": "lenscorrection",
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 1.0, "offset": 0},
        },
    },
    "wave": {
        # Wave displacement using geq - need r/g/b for RGB mode
        "filter": "geq",
        "static": "r='r(X+10*sin(Y/20),Y)':g='g(X+10*sin(Y/20),Y)':b='b(X+10*sin(Y/20),Y)'",
        "params": {},
    },
    "rgb_split": {
        # Chromatic aberration using geq
        "filter": "geq",
        "static": "r='p(X+5,Y)':g='p(X,Y)':b='p(X-5,Y)'",
        "params": {},
    },
    "scanlines": {
        "filter": "drawgrid",
        "params": {
            "spacing": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "film_grain": {
        "filter": "noise",
        "params": {
            "intensity": {"ffmpeg_param": "alls", "scale": 100, "offset": 0},
        },
    },
    "crt": {
        "filter": "vignette",  # Simplified - just vignette for CRT look
        "params": {},
    },
    "bloom": {
        "filter": "gblur",  # Simplified bloom = blur overlay
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "color_cycle": {
        "filter": "hue",
        "params": {
            "speed": {"ffmpeg_param": "h", "scale": 360.0, "offset": 0, "time_expr": True},
        },
        "time_based": True,  # Uses time expression
    },
    "strobe": {
        # Strobe using select to drop frames
        "filter": "select",
        "static": "'mod(n,4)'",
        "params": {},
    },
    "echo": {
        # Echo using tmix
        "filter": "tmix",
        "static": "frames=4:weights='1 0.5 0.25 0.125'",
        "params": {},
    },
    "trails": {
        # Trails using tblend
        "filter": "tblend",
        "static": "all_mode=average",
        "params": {},
    },
    "kaleidoscope": {
        # 4-way mirror kaleidoscope using FFmpeg filter chain
        # Crops top-left quadrant, mirrors horizontally, then vertically
        "filter": "crop",
        "complex": True,
        "static": "iw/2:ih/2:0:0[q];[q]split[q1][q2];[q1]hflip[qr];[q2][qr]hstack[top];[top]split[t1][t2];[t2]vflip[bot];[t1][bot]vstack",
        "params": {},
    },
    "emboss": {
        "filter": "convolution",
        "static": "-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2",
        "params": {},
    },
    "neon_glow": {
        # Edge detect + negate for neon-like effect
        "filter": "edgedetect",
        "static": "mode=colormix:high=0.1",
        "params": {},
    },
    "ascii_art": {
        # Requires Python frame processing - no FFmpeg equivalent
        "filter": None,
        "python_primitive": "ascii_art_frame",
        "params": {
            "char_size": {"default": 8},
            "alphabet": {"default": "standard"},
            "color_mode": {"default": "color"},
        },
    },
    "ascii_zones": {
        # Requires Python frame processing - zone-based ASCII
        "filter": None,
        "python_primitive": "ascii_zones_frame",
        "params": {
            "char_size": {"default": 8},
            "zone_threshold": {"default": 128},
        },
    },
    "datamosh": {
        # External tool: ffglitch or datamoshing CLI, falls back to Python
        "filter": None,
        "external_tool": "datamosh",
        "python_primitive": "datamosh_frame",
        "params": {
            "block_size": {"default": 32},
            "corruption": {"default": 0.3},
        },
    },
    "pixelsort": {
        # External tool: pixelsort CLI (Rust or Python), falls back to Python
        "filter": None,
        "external_tool": "pixelsort",
        "python_primitive": "pixelsort_frame",
        "params": {
            "sort_by": {"default": "lightness"},
            "threshold_low": {"default": 50},
            "threshold_high": {"default": 200},
            "angle": {"default": 0},
        },
    },
    "ripple": {
        # Use geq for ripple displacement
        "filter": "geq",
        "static": "lum='lum(X+5*sin(hypot(X-W/2,Y-H/2)/10),Y+5*cos(hypot(X-W/2,Y-H/2)/10))'",
        "params": {},
    },
    "tile_grid": {
        # Use tile filter for grid
        "filter": "tile",
        "static": "2x2",
        "params": {},
    },
    "outline": {
        "filter": "edgedetect",
        "params": {},
    },
    "color-adjust": {
        "filter": "eq",
        "params": {
            "brightness": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
            "contrast": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
}
|
||||
|
||||
|
||||
class FFmpegCompiler:
    """Compiles sexp effects to FFmpeg filters with sendcmd support.

    Fixes: removed the unused ``segment_end`` local in
    ``compile_effect_with_bindings``; corrected the ``__init__`` annotation
    to ``Optional[Dict]``.
    """

    def __init__(self, effect_mappings: Optional[Dict] = None):
        # Falls back to the module-level EFFECT_MAPPINGS table.
        self.mappings = effect_mappings or EFFECT_MAPPINGS

    def get_full_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get full mapping for an effect (including external tools and python primitives)."""
        mapping = self.mappings.get(effect_name)
        if not mapping:
            # Try with underscores/hyphens converted
            normalized = effect_name.replace("-", "_").replace(" ", "_").lower()
            mapping = self.mappings.get(normalized)
        return mapping

    def get_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get FFmpeg filter mapping for an effect (returns None for non-FFmpeg effects)."""
        mapping = self.get_full_mapping(effect_name)
        # Return None if no mapping or no FFmpeg filter
        if mapping and mapping.get("filter") is None:
            return None
        return mapping

    def has_external_tool(self, effect_name: str) -> Optional[str]:
        """Check if effect uses an external tool. Returns tool name or None."""
        mapping = self.get_full_mapping(effect_name)
        if mapping:
            return mapping.get("external_tool")
        return None

    def has_python_primitive(self, effect_name: str) -> Optional[str]:
        """Check if effect uses a Python primitive. Returns primitive name or None."""
        mapping = self.get_full_mapping(effect_name)
        if mapping:
            return mapping.get("python_primitive")
        return None

    def is_complex_filter(self, effect_name: str) -> bool:
        """Check if effect uses a complex filter chain."""
        mapping = self.get_full_mapping(effect_name)
        return bool(mapping and mapping.get("complex"))

    def compile_effect(
        self,
        effect_name: str,
        params: Dict[str, Any],
    ) -> Optional[str]:
        """
        Compile an effect to an FFmpeg filter string with static params.

        Returns None if effect has no FFmpeg mapping.
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None

        filter_name = mapping["filter"]

        # Handle static filters (no params)
        if "static" in mapping:
            return f"{filter_name}={mapping['static']}"

        if not mapping.get("params"):
            return filter_name

        # Build param string
        filter_params = []
        for param_name, param_config in mapping["params"].items():
            if param_name in params:
                value = params[param_name]
                # Skip if it's a binding (handled separately)
                if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                    continue
                ffmpeg_param = param_config["ffmpeg_param"]
                scale = param_config.get("scale", 1.0)
                offset = param_config.get("offset", 0)
                # Handle various value types
                if isinstance(value, (int, float)):
                    ffmpeg_value = value * scale + offset
                    filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                elif isinstance(value, str):
                    filter_params.append(f"{ffmpeg_param}={value}")
                elif isinstance(value, list) and value and isinstance(value[0], (int, float)):
                    # Lists collapse to their first numeric element.
                    ffmpeg_value = value[0] * scale + offset
                    filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")

        if filter_params:
            return f"{filter_name}={':'.join(filter_params)}"
        return filter_name

    def compile_effect_with_bindings(
        self,
        effect_name: str,
        params: Dict[str, Any],
        analysis_data: Dict[str, Dict],
        segment_start: float,
        segment_duration: float,
        sample_interval: float = 0.04,  # ~25 fps
    ) -> Tuple[Optional[str], Optional[str], List[str]]:
        """
        Compile an effect with dynamic bindings to a filter + sendcmd script.

        Returns:
            (filter_string, sendcmd_script, bound_param_names)
            - filter_string: Initial FFmpeg filter (may have placeholder values)
            - sendcmd_script: Script content for sendcmd filter
            - bound_param_names: List of params that have bindings
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None, None, []

        filter_name = mapping["filter"]
        static_params = []
        bound_params = []
        sendcmd_lines = []

        # Handle time-based effects (use FFmpeg expressions with 't')
        if mapping.get("time_based"):
            for param_name, param_config in mapping.get("params", {}).items():
                if param_name in params:
                    value = params[param_name]
                    ffmpeg_param = param_config["ffmpeg_param"]
                    scale = param_config.get("scale", 1.0)
                    if isinstance(value, (int, float)):
                        # Create time expression: h='t*speed*scale'
                        static_params.append(f"{ffmpeg_param}='t*{value}*{scale}'")
                    else:
                        static_params.append(f"{ffmpeg_param}='t*{scale}'")
            if static_params:
                filter_str = f"{filter_name}={':'.join(static_params)}"
            else:
                filter_str = f"{filter_name}=h='t*360'"  # Default rotation
            return filter_str, None, []

        # Process each param
        for param_name, param_config in mapping.get("params", {}).items():
            if param_name not in params:
                continue

            value = params[param_name]
            ffmpeg_param = param_config["ffmpeg_param"]
            scale = param_config.get("scale", 1.0)
            offset = param_config.get("offset", 0)

            # Check if it's a binding
            if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                bind_ref = value.get("_bind") or value.get("_binding")
                range_min = value.get("range_min", 0.0)
                range_max = value.get("range_max", 1.0)
                transform = value.get("transform")

                # Get analysis data
                analysis = analysis_data.get(bind_ref)
                if not analysis:
                    # Try without -data suffix
                    analysis = analysis_data.get(bind_ref.replace("-data", ""))

                if analysis and "times" in analysis and "values" in analysis:
                    times = analysis["times"]
                    values = analysis["values"]

                    # Generate sendcmd entries across the segment duration.
                    t = 0.0  # Time relative to segment start
                    while t < segment_duration:
                        abs_time = segment_start + t

                        # Find analysis value at this time
                        raw_value = self._interpolate_value(times, values, abs_time)

                        # Apply transform
                        if transform == "sqrt":
                            raw_value = math.sqrt(max(0, raw_value))
                        elif transform == "pow2":
                            raw_value = raw_value ** 2
                        elif transform == "log":
                            raw_value = math.log(max(0.001, raw_value))

                        # Map to range
                        mapped_value = range_min + raw_value * (range_max - range_min)

                        # Apply FFmpeg scaling
                        ffmpeg_value = mapped_value * scale + offset

                        # Add sendcmd line (time relative to segment)
                        sendcmd_lines.append(f"{t:.3f} [{filter_name}] {ffmpeg_param} {ffmpeg_value:.4f};")

                        t += sample_interval

                    bound_params.append(param_name)
                    # Use initial value for the filter string
                    initial_value = self._interpolate_value(times, values, segment_start)
                    initial_mapped = range_min + initial_value * (range_max - range_min)
                    initial_ffmpeg = initial_mapped * scale + offset
                    static_params.append(f"{ffmpeg_param}={initial_ffmpeg:.4f}")
                else:
                    # No analysis data, use range midpoint
                    mid_value = (range_min + range_max) / 2
                    ffmpeg_value = mid_value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
            else:
                # Static value - handle various types
                if isinstance(value, (int, float)):
                    ffmpeg_value = value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                elif isinstance(value, str):
                    # String value - use as-is (e.g., for direction parameters)
                    static_params.append(f"{ffmpeg_param}={value}")
                elif isinstance(value, list) and value:
                    # List - try to use first numeric element
                    first = value[0]
                    if isinstance(first, (int, float)):
                        ffmpeg_value = first * scale + offset
                        static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                # Skip other types

        # Handle static filters
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        elif static_params:
            filter_str = f"{filter_name}={':'.join(static_params)}"
        else:
            filter_str = filter_name

        # Combine sendcmd lines
        sendcmd_script = "\n".join(sendcmd_lines) if sendcmd_lines else None

        return filter_str, sendcmd_script, bound_params

    def _interpolate_value(
        self,
        times: List[float],
        values: List[float],
        target_time: float,
    ) -> float:
        """Interpolate a value from analysis data at a given time.

        Returns 0.5 (neutral midpoint) when no data is available; clamps to
        the endpoint values outside the sampled time range.
        """
        if not times or not values:
            return 0.5

        # Find surrounding indices
        if target_time <= times[0]:
            return values[0]
        if target_time >= times[-1]:
            return values[-1]

        # Binary search for efficiency
        lo, hi = 0, len(times) - 1
        while lo < hi - 1:
            mid = (lo + hi) // 2
            if times[mid] <= target_time:
                lo = mid
            else:
                hi = mid

        # Linear interpolation
        t0, t1 = times[lo], times[hi]
        v0, v1 = values[lo], values[hi]

        if t1 == t0:
            return v0

        alpha = (target_time - t0) / (t1 - t0)
        return v0 + alpha * (v1 - v0)
|
||||
|
||||
|
||||
def generate_sendcmd_filter(
    effects: List[Dict],
    analysis_data: Dict[str, Dict],
    segment_start: float,
    segment_duration: float,
) -> Tuple[str, Optional[Path]]:
    """
    Generate FFmpeg filter chain with sendcmd for dynamic effects.

    Fix: removed an unused function-local ``import tempfile`` (nothing in the
    body uses it since sendcmd emission was disabled).

    Args:
        effects: List of effect configs with name and params
        analysis_data: Analysis data keyed by name
        segment_start: Segment start time in source
        segment_duration: Segment duration

    Returns:
        (filter_chain_string, sendcmd_file_path or None)
    """
    compiler = FFmpegCompiler()
    filters = []
    all_sendcmd_lines = []

    for effect in effects:
        effect_name = effect.get("effect")
        # Everything except the effect name is treated as a parameter.
        params = {k: v for k, v in effect.items() if k != "effect"}

        filter_str, sendcmd, _ = compiler.compile_effect_with_bindings(
            effect_name,
            params,
            analysis_data,
            segment_start,
            segment_duration,
        )

        if filter_str:
            filters.append(filter_str)
        if sendcmd:
            all_sendcmd_lines.append(sendcmd)

    if not filters:
        return "", None

    filter_chain = ",".join(filters)

    # NOTE: sendcmd disabled - FFmpeg's sendcmd filter has compatibility issues.
    # Bindings use their initial value (sampled at segment start time).
    # This is acceptable since each segment is only ~8 seconds.
    # The binding value is still music-reactive (varies per segment).
    sendcmd_path = None

    return filter_chain, sendcmd_path
|
||||
425
artdag/sexp/parser.py
Normal file
425
artdag/sexp/parser.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
S-expression parser for ArtDAG recipes and plans.
|
||||
|
||||
Supports:
|
||||
- Lists: (a b c)
|
||||
- Symbols: foo, bar-baz, ->
|
||||
- Keywords: :key
|
||||
- Strings: "hello world"
|
||||
- Numbers: 42, 3.14, -1.5
|
||||
- Comments: ; to end of line
|
||||
- Vectors: [a b c] (syntactic sugar for lists)
|
||||
- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Union
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
class Symbol:
    """An unquoted symbol/identifier parsed from S-expression text."""
    name: str

    def __repr__(self):
        return f"Symbol({self.name!r})"

    def __eq__(self, other):
        # A Symbol compares equal to another Symbol, or to a plain string,
        # carrying the same name; anything else is unequal.
        if isinstance(other, Symbol):
            other = other.name
        if isinstance(other, str):
            return self.name == other
        return False

    def __hash__(self):
        # Hash like the bare name so Symbol/str interchange works in dicts.
        return hash(self.name)
|
||||
|
||||
|
||||
@dataclass
class Keyword:
    """A keyword token, written with a leading colon (e.g. ``:version``)."""
    name: str

    def __repr__(self):
        return f"Keyword({self.name!r})"

    def __eq__(self, other):
        # Unlike Symbol, a Keyword only ever equals another Keyword.
        return isinstance(other, Keyword) and self.name == other.name

    def __hash__(self):
        # Tagged with ':' so a Keyword never collides with its bare name.
        return hash((':', self.name))
|
||||
|
||||
|
||||
@dataclass
class Lambda:
    """A lambda/anonymous function with closure."""
    params: List[str]  # Parameter names
    body: Any  # Expression body
    closure: Dict | None = None  # Captured environment (optional for backwards compat)

    def __repr__(self):
        return f"Lambda({self.params}, {self.body!r})"
|
||||
|
||||
|
||||
@dataclass
class Binding:
    """A binding to analysis data for dynamic effect parameters."""
    # NOTE(review): serialize() also handles non-str analysis_ref values
    # (node ids, dicts), so this annotation may be looser in practice.
    analysis_ref: str  # Name of analysis variable
    track: str | None = None  # Optional track name (e.g., "bass", "energy")
    range_min: float = 0.0  # Output range minimum
    range_max: float = 1.0  # Output range maximum
    transform: str | None = None  # Optional transform: "sqrt", "pow2", "log", etc.

    def __repr__(self):
        t = f", transform={self.transform!r}" if self.transform else ""
        return f"Binding({self.analysis_ref!r}, track={self.track!r}, range=[{self.range_min}, {self.range_max}]{t})"
|
||||
|
||||
|
||||
class ParseError(Exception):
    """Error during S-expression parsing, carrying the source location."""

    def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1):
        # Keep the raw location so callers/tooling can point at the source.
        self.position = position
        self.line = line
        self.col = col
        located = f"{message} at line {line}, column {col}"
        super().__init__(located)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""Tokenize S-expression text into tokens."""
|
||||
|
||||
# Token patterns
|
||||
WHITESPACE = re.compile(r'\s+')
|
||||
COMMENT = re.compile(r';[^\n]*')
|
||||
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
|
||||
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
|
||||
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
|
||||
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*')
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
self.pos = 0
|
||||
self.line = 1
|
||||
self.col = 1
|
||||
|
||||
def _advance(self, count: int = 1):
|
||||
"""Advance position, tracking line/column."""
|
||||
for _ in range(count):
|
||||
if self.pos < len(self.text):
|
||||
if self.text[self.pos] == '\n':
|
||||
self.line += 1
|
||||
self.col = 1
|
||||
else:
|
||||
self.col += 1
|
||||
self.pos += 1
|
||||
|
||||
def _skip_whitespace_and_comments(self):
|
||||
"""Skip whitespace and comments."""
|
||||
while self.pos < len(self.text):
|
||||
# Whitespace
|
||||
match = self.WHITESPACE.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
continue
|
||||
|
||||
# Comments
|
||||
match = self.COMMENT.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
def peek(self) -> str | None:
|
||||
"""Peek at current character."""
|
||||
self._skip_whitespace_and_comments()
|
||||
if self.pos >= len(self.text):
|
||||
return None
|
||||
return self.text[self.pos]
|
||||
|
||||
def next_token(self) -> Any:
|
||||
"""Get the next token."""
|
||||
self._skip_whitespace_and_comments()
|
||||
|
||||
if self.pos >= len(self.text):
|
||||
return None
|
||||
|
||||
char = self.text[self.pos]
|
||||
start_line, start_col = self.line, self.col
|
||||
|
||||
# Single-character tokens (parens, brackets, braces)
|
||||
if char in '()[]{}':
|
||||
self._advance()
|
||||
return char
|
||||
|
||||
# String
|
||||
if char == '"':
|
||||
match = self.STRING.match(self.text, self.pos)
|
||||
if not match:
|
||||
raise ParseError("Unterminated string", self.pos, self.line, self.col)
|
||||
self._advance(match.end() - self.pos)
|
||||
# Parse escape sequences
|
||||
content = match.group()[1:-1]
|
||||
content = content.replace('\\n', '\n')
|
||||
content = content.replace('\\t', '\t')
|
||||
content = content.replace('\\"', '"')
|
||||
content = content.replace('\\\\', '\\')
|
||||
return content
|
||||
|
||||
# Keyword
|
||||
if char == ':':
|
||||
match = self.KEYWORD.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
return Keyword(match.group()[1:]) # Strip leading colon
|
||||
raise ParseError(f"Invalid keyword", self.pos, self.line, self.col)
|
||||
|
||||
# Number (must check before symbol due to - prefix)
|
||||
if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and
|
||||
(self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')):
|
||||
match = self.NUMBER.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
num_str = match.group()
|
||||
if '.' in num_str or 'e' in num_str or 'E' in num_str:
|
||||
return float(num_str)
|
||||
return int(num_str)
|
||||
|
||||
# Symbol
|
||||
match = self.SYMBOL.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
return Symbol(match.group())
|
||||
|
||||
raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col)
|
||||
|
||||
|
||||
def parse(text: str) -> Any:
    """
    Parse exactly one S-expression from *text*.

    Returns nested Python structures: lists stay lists, symbols become
    Symbol, keywords become Keyword, strings become str, numbers become
    int/float.

    Raises:
        ParseError: on malformed input or trailing content.

    Example:
        >>> parse('(recipe "test" :version "1.0")')
        [Symbol('recipe'), 'test', Keyword('version'), '1.0']
    """
    tok = Tokenizer(text)
    expr = _parse_expr(tok)

    # Anything left over means the input held more than one expression.
    if tok.peek() is not None:
        raise ParseError("Unexpected content after expression",
                         tok.pos, tok.line, tok.col)

    return expr
|
||||
|
||||
|
||||
def parse_all(text: str) -> List[Any]:
    """
    Parse every S-expression in *text*.

    Returns the parsed expressions as a list (possibly empty).
    """
    tok = Tokenizer(text)
    exprs: List[Any] = []

    # peek() skips whitespace/comments, so None means true end of input.
    while tok.peek() is not None:
        exprs.append(_parse_expr(tok))

    return exprs
|
||||
|
||||
|
||||
def _parse_expr(tokenizer: Tokenizer) -> Any:
    """Parse one expression (compound form or atom) from the token stream."""
    token = tokenizer.next_token()

    if token is None:
        raise ParseError("Unexpected end of input",
                         tokenizer.pos, tokenizer.line, tokenizer.col)

    # Opening delimiters start a compound form.
    if token == '(':
        return _parse_list(tokenizer, ')')
    if token == '[':
        # Vector syntax is sugar for a plain list.
        return _parse_list(tokenizer, ']')
    if token == '{':
        # Map/dict: {:key1 val1 :key2 val2}
        return _parse_map(tokenizer)

    # A closer with no matching opener is a syntax error.
    if token in (')', ']', '}'):
        raise ParseError(f"Unexpected {token!r}",
                         tokenizer.pos, tokenizer.line, tokenizer.col)

    # Anything else is already an atom (str/int/float/Symbol/Keyword).
    return token
|
||||
|
||||
|
||||
def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]:
    """Collect expressions until *closer* is seen; return them as a list."""
    elements: List[Any] = []

    while True:
        lookahead = tokenizer.peek()

        if lookahead is None:
            # Input ended before the matching delimiter appeared.
            raise ParseError(f"Unterminated list, expected {closer!r}",
                             tokenizer.pos, tokenizer.line, tokenizer.col)

        if lookahead == closer:
            tokenizer.next_token()  # swallow the closing delimiter
            return elements

        elements.append(_parse_expr(tokenizer))
|
||||
|
||||
|
||||
def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]:
    """Parse ``{:key1 val1 :key2 val2}`` into ``{"key1": val1, "key2": val2}``."""
    mapping: Dict[str, Any] = {}

    while True:
        lookahead = tokenizer.peek()

        if lookahead is None:
            raise ParseError("Unterminated map, expected '}'",
                             tokenizer.pos, tokenizer.line, tokenizer.col)

        if lookahead == '}':
            tokenizer.next_token()  # swallow the closing brace
            return mapping

        # Keys must be keywords (:key) or plain strings; both map to str.
        key_expr = _parse_expr(tokenizer)
        if isinstance(key_expr, Keyword):
            key = key_expr.name
        elif isinstance(key_expr, str):
            key = key_expr
        else:
            raise ParseError(
                f"Map key must be keyword or string, got {type(key_expr).__name__}",
                tokenizer.pos, tokenizer.line, tokenizer.col)

        # The value is whatever expression follows the key.
        mapping[key] = _parse_expr(tokenizer)
|
||||
|
||||
|
||||
def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str:
    """
    Render a parsed S-expression structure back to text.

    Args:
        expr: The expression to serialize
        indent: Current indentation level (for pretty printing)
        pretty: Whether to use pretty printing with newlines

    Returns:
        S-expression string

    Raises:
        ValueError: for types with no S-expression representation.
    """
    if isinstance(expr, list):
        if not expr:
            return "()"
        if pretty:
            return _serialize_pretty(expr, indent)
        return "(" + " ".join(serialize(e, indent, False) for e in expr) + ")"

    if isinstance(expr, Symbol):
        return expr.name

    if isinstance(expr, Keyword):
        return ":" + expr.name

    if isinstance(expr, Lambda):
        head = " ".join(expr.params)
        return f"(fn [{head}] {serialize(expr.body, indent, pretty)})"

    if isinstance(expr, Binding):
        # The ref may be a bare string, a node ID, or a dict - serialize
        # non-strings recursively.
        if isinstance(expr.analysis_ref, str):
            ref = f'"{expr.analysis_ref}"'
        else:
            ref = serialize(expr.analysis_ref, indent, pretty)
        out = f"(bind {ref} :range [{expr.range_min} {expr.range_max}]"
        if expr.transform:
            out += f" :transform {expr.transform}"
        return out + ")"

    if isinstance(expr, str):
        # Escape backslash first so later escapes aren't double-processed.
        body = (expr.replace('\\', '\\\\')
                    .replace('"', '\\"')
                    .replace('\n', '\\n')
                    .replace('\t', '\\t'))
        return f'"{body}"'

    # bool must be tested before int: bool is an int subclass.
    if isinstance(expr, bool):
        return "true" if expr else "false"

    if isinstance(expr, (int, float)):
        return str(expr)

    if expr is None:
        return "nil"

    if isinstance(expr, dict):
        # Serialize dict as property list: {:key1 val1 :key2 val2}
        pieces = []
        for key, val in expr.items():
            pieces.append(f":{key}")
            pieces.append(serialize(val, indent, pretty))
        return "{" + " ".join(pieces) + "}"

    raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}")
|
||||
|
||||
|
||||
def _serialize_pretty(expr: List, indent: int) -> str:
    """Pretty-print a list expression with smart formatting.

    Short lists (under 60 chars, no newlines) stay on one line; longer ones
    break across lines with keyword/value pairs grouped together.
    """
    if not expr:
        return "()"

    # (Removed unused local `prefix` -- only the inner indentation is used.)
    inner_prefix = " " * (indent + 1)

    # Check if this is a simple list that fits on one line
    simple = serialize(expr, indent, False)
    if len(simple) < 60 and '\n' not in simple:
        return simple

    # Start building multiline output: the head stays on the opening line.
    head = serialize(expr[0], indent + 1, False)
    parts = [f"({head}"]

    i = 1
    while i < len(expr):
        item = expr[i]

        # Group keyword-value pairs on same line
        if isinstance(item, Keyword) and i + 1 < len(expr):
            key = serialize(item, 0, False)
            val = serialize(expr[i + 1], indent + 1, False)

            # If value is short, put on same line
            if len(val) < 50 and '\n' not in val:
                parts.append(f"{inner_prefix}{key} {val}")
            else:
                # Value is complex, serialize it pretty
                val_pretty = serialize(expr[i + 1], indent + 1, True)
                parts.append(f"{inner_prefix}{key} {val_pretty}")
            i += 2
        else:
            # Regular item
            item_str = serialize(item, indent + 1, True)
            parts.append(f"{inner_prefix}{item_str}")
            i += 1

    return "\n".join(parts) + ")"
|
||||
2187
artdag/sexp/planner.py
Normal file
2187
artdag/sexp/planner.py
Normal file
File diff suppressed because it is too large
Load Diff
620
artdag/sexp/primitives.py
Normal file
620
artdag/sexp/primitives.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""
|
||||
Frame processing primitives for sexp effects.
|
||||
|
||||
These primitives are called by sexp effect definitions and operate on
|
||||
numpy arrays (frames). They're used when falling back to Python rendering
|
||||
instead of FFmpeg.
|
||||
|
||||
Required: numpy, PIL
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
# numpy is required for all frame primitives; guard the import so the module
# can still be imported and report a clear error via check_deps().
try:
    import numpy as np
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
    np = None

# PIL is required for the character-rendering primitives.
# NOTE(review): unlike the numpy branch, the failure path leaves
# Image/ImageDraw/ImageFont undefined rather than None - confirm nothing
# touches them without calling check_deps() first.
try:
    from PIL import Image, ImageDraw, ImageFont
    HAS_PIL = True
except ImportError:
    HAS_PIL = False


# ASCII character sets for different styles (ordered dark -> bright).
ASCII_ALPHABETS = {
    "standard": " .:-=+*#%@",
    "blocks": " ░▒▓█",
    "simple": " .-:+=xX#",
    "detailed": " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$",
    "binary": " █",
}
|
||||
|
||||
|
||||
def check_deps():
    """Raise ImportError unless both numpy and PIL are importable."""
    # numpy is listed first so its error takes precedence when both are missing.
    requirements = (
        (HAS_NUMPY, "numpy required for frame primitives: pip install numpy"),
        (HAS_PIL, "PIL required for frame primitives: pip install Pillow"),
    )
    for available, message in requirements:
        if not available:
            raise ImportError(message)
|
||||
|
||||
|
||||
def frame_to_image(frame: np.ndarray) -> Image.Image:
    """Convert a numpy frame into a PIL Image."""
    check_deps()
    # PIL expects uint8 pixel data; clip-and-cast anything else.
    if frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return Image.fromarray(frame)
|
||||
|
||||
|
||||
def image_to_frame(img: Image.Image) -> np.ndarray:
    """Convert a PIL Image into a (writable) numpy frame."""
    check_deps()
    frame = np.array(img)  # np.array (not asarray): always a fresh, writable copy
    return frame
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Art Primitives
|
||||
# ============================================================================
|
||||
|
||||
def cell_sample(frame: np.ndarray, cell_size: int = 8) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample frame into cells, returning average colors and luminances.

    Args:
        frame: Input frame (H, W, C).  Trailing rows/columns that do not
            fill a whole cell are ignored (same as the original per-cell loop).
        cell_size: Size of each (square) cell in pixels.

    Returns:
        (colors, luminances) - colors is (rows, cols, 3) float32 mean RGB,
        luminances is (rows, cols) float32 ITU-R BT.601 luma of the mean color.
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = h // cell_size
    cols = w // cell_size

    if rows == 0 or cols == 0:
        # Frame smaller than one cell: nothing to sample.
        return (np.zeros((rows, cols, 3), dtype=np.float32),
                np.zeros((rows, cols), dtype=np.float32))

    # Vectorized per-cell mean: crop to whole cells, reshape so each cell's
    # pixels occupy axes 1 and 3, and average over them in one pass.
    # Replaces the original O(rows*cols) Python loop with identical results
    # (means computed in float64, then cast to float32, as before).
    cropped = frame[:rows * cell_size, :cols * cell_size].astype(np.float64)
    cells = cropped.reshape(rows, cell_size, cols, cell_size, -1)
    means = cells.mean(axis=(1, 3))  # (rows, cols, C)

    colors = means[:, :, :3].astype(np.float32)  # RGB only

    # Luminance (ITU-R BT.601) of the averaged color.
    lum = 0.299 * means[:, :, 0] + 0.587 * means[:, :, 1] + 0.114 * means[:, :, 2]
    luminances = lum.astype(np.float32)

    return colors, luminances
|
||||
|
||||
|
||||
def luminance_to_chars(
    luminances: np.ndarray,
    alphabet: str = "standard",
    contrast: float = 1.0,
) -> List[List[str]]:
    """
    Map luminance values (0-255) onto ASCII characters.

    Args:
        luminances: 2D array of luminance values (0-255)
        alphabet: Name of a set in ASCII_ALPHABETS, or a literal char string
        contrast: Contrast multiplier applied around the 128 midpoint

    Returns:
        2D list of characters, same shape as ``luminances``
    """
    check_deps()
    palette = ASCII_ALPHABETS.get(alphabet, alphabet)
    count = len(palette)

    grid: List[List[str]] = []
    for row in luminances:
        line = []
        for raw in row:
            # Stretch/compress around mid-gray, then clamp to byte range.
            adjusted = np.clip(128 + (raw - 128) * contrast, 0, 255)
            # Scale 0-255 onto palette indices, clamping the top end.
            slot = min(int(adjusted / 256 * count), count - 1)
            line.append(palette[slot])
        grid.append(line)

    return grid
|
||||
|
||||
|
||||
def render_char_grid(
    frame: np.ndarray,
    chars: List[List[str]],
    colors: np.ndarray,
    char_size: int = 8,
    color_mode: str = "color",
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Render character grid to an image.

    Args:
        frame: Original frame (for dimensions)
        chars: 2D list of characters
        colors: Color for each cell (rows, cols, 3)
        char_size: Size of each character cell
        color_mode: "color", "white", or "green"
        background: Background RGB color

    Returns:
        Rendered frame
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = len(chars)
    cols = len(chars[0]) if chars else 0

    # Create output image
    img = Image.new("RGB", (w, h), background)
    draw = ImageDraw.Draw(img)

    # Try to get a monospace font; fall back through common Linux font
    # paths to PIL's built-in bitmap font when neither is installed.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size)
    except (IOError, OSError):
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", char_size)
        except (IOError, OSError):
            font = ImageFont.load_default()

    for r in range(rows):
        for c in range(cols):
            char = chars[r][c]
            if char == ' ':
                # Spaces show as background; skip the draw call entirely.
                continue

            x = c * char_size
            y = r * char_size

            if color_mode == "color":
                # Use the sampled per-cell color.
                color = tuple(int(v) for v in colors[r, c])
            elif color_mode == "green":
                color = (0, 255, 0)
            else:  # white
                color = (255, 255, 255)

            draw.text((x, y), char, fill=color, font=font)

    return np.array(img)
|
||||
|
||||
|
||||
def ascii_art_frame(
    frame: np.ndarray,
    char_size: int = 8,
    alphabet: str = "standard",
    color_mode: str = "color",
    contrast: float = 1.5,
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Apply ASCII art effect to a frame.

    Pipeline: sample cells -> map luminance to characters -> draw the grid.
    This is the main entry point for the ascii_art effect.
    """
    check_deps()
    cell_colors, cell_lums = cell_sample(frame, char_size)
    glyphs = luminance_to_chars(cell_lums, alphabet, contrast)
    return render_char_grid(frame, glyphs, cell_colors, char_size, color_mode, background)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Zones Primitives
|
||||
# ============================================================================
|
||||
|
||||
def ascii_zones_frame(
    frame: np.ndarray,
    char_size: int = 8,
    zone_threshold: int = 128,
    dark_chars: str = " .-:",
    light_chars: str = "=+*#",
) -> np.ndarray:
    """
    Apply zone-based ASCII art effect.

    Cells darker than ``zone_threshold`` use ``dark_chars``, the rest use
    ``light_chars``; luminance is re-normalized to 0-1 within each zone to
    pick a character.

    Args:
        frame: Input frame (H, W, C).
        char_size: Cell/character size in pixels.
        zone_threshold: Luminance split between dark and light zones.
        dark_chars: Character ramp for the dark zone (dim -> bright).
        light_chars: Character ramp for the light zone (dim -> bright).

    Returns:
        Rendered frame.
    """
    check_deps()
    colors, luminances = cell_sample(frame, char_size)

    rows, cols = luminances.shape

    # Guard the zone widths so degenerate thresholds (0 or 255) cannot
    # produce a divide-by-zero / NaN while normalizing within a zone.
    # For thresholds in 1..254 the denominators are unchanged.
    dark_span = max(float(zone_threshold), 1.0)
    light_span = max(float(255 - zone_threshold), 1.0)

    chars = []
    for r in range(rows):
        row_chars = []
        for c in range(cols):
            lum = luminances[r, c]
            if lum < zone_threshold:
                # Dark zone
                charset = dark_chars
                local_lum = lum / dark_span  # 0-1 within zone
            else:
                # Light zone
                charset = light_chars
                local_lum = (lum - zone_threshold) / light_span

            idx = int(local_lum * len(charset))
            idx = min(idx, len(charset) - 1)
            row_chars.append(charset[idx])
        chars.append(row_chars)

    return render_char_grid(frame, chars, colors, char_size, "color", (0, 0, 0))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Kaleidoscope Primitives (Python fallback)
|
||||
# ============================================================================
|
||||
|
||||
def kaleidoscope_displace(
    w: int,
    h: int,
    segments: int = 6,
    rotation: float = 0,
    cx: float = None,
    cy: float = None,
    zoom: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute kaleidoscope displacement coordinates.

    Maps every pixel to polar form around (cx, cy), folds the angle into
    one mirrored wedge of ``segments`` wedges, then maps back to cartesian
    sampling coordinates.

    Returns (x_coords, y_coords) arrays for remapping.
    """
    check_deps()
    # Default to the image center when no pivot is supplied.
    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2

    # Per-pixel coordinate grids, centered on the pivot.
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
    dx = xx - cx
    dy = yy - cy

    # Polar form; radius divided by zoom so zoom > 1 magnifies.
    radius = np.sqrt(dx**2 + dy**2) / zoom

    # Angle with the requested rotation applied.
    angle = np.arctan2(dy, dx) - np.radians(rotation)

    # Fold the angle into a single mirrored segment.
    wedge = 2 * np.pi / segments
    angle = np.abs(np.mod(angle, wedge) - wedge / 2)

    # Back to cartesian source coordinates.
    x_new = radius * np.cos(angle) + cx
    y_new = radius * np.sin(angle) + cy

    return x_new, y_new
|
||||
|
||||
|
||||
def remap(
    frame: np.ndarray,
    x_coords: np.ndarray,
    y_coords: np.ndarray,
) -> np.ndarray:
    """
    Sample ``frame`` at the given per-pixel coordinates.

    Uses bilinear interpolation (order=1) with reflected borders.
    """
    check_deps()
    from scipy import ndimage

    h, w = frame.shape[:2]

    # Keep sample positions inside the frame.
    xs = np.clip(x_coords, 0, w - 1)
    ys = np.clip(y_coords, 0, h - 1)

    if len(frame.shape) != 3:
        # Single-channel frame: remap directly.
        return ndimage.map_coordinates(frame, [ys, xs], order=1, mode='reflect')

    # Multi-channel frame: remap each channel independently.
    out = np.zeros_like(frame)
    for ch in range(frame.shape[2]):
        out[:, :, ch] = ndimage.map_coordinates(
            frame[:, :, ch],
            [ys, xs],
            order=1,
            mode='reflect',
        )
    return out
|
||||
|
||||
|
||||
def kaleidoscope_frame(
    frame: np.ndarray,
    segments: int = 6,
    rotation: float = 0,
    center_x: float = 0.5,
    center_y: float = 0.5,
    zoom: float = 1.0,
) -> np.ndarray:
    """
    Apply kaleidoscope effect to a frame.

    center_x / center_y are fractions of the frame size (0.5 = center).
    This is a Python fallback - FFmpeg version is faster.
    """
    check_deps()
    h, w = frame.shape[:2]

    # Convert fractional center into pixel coordinates.
    pivot_x = w * center_x
    pivot_y = h * center_y

    xs, ys = kaleidoscope_displace(w, h, segments, rotation, pivot_x, pivot_y, zoom)
    return remap(frame, xs, ys)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Datamosh Primitives (simplified Python version)
|
||||
# ============================================================================
|
||||
|
||||
def datamosh_frame(
    frame: np.ndarray,
    prev_frame: Optional[np.ndarray],
    block_size: int = 32,
    corruption: float = 0.3,
    max_offset: int = 50,
    color_corrupt: bool = True,
) -> np.ndarray:
    """
    Simplified datamosh effect using block displacement.

    Randomly replaces square blocks of the current frame with offset blocks
    taken from the previous frame, optionally rolling color channels for an
    extra glitch.  Uses the global numpy RNG, so output varies per call
    unless the caller seeds np.random.

    Args:
        frame: Current frame (H, W, C).
        prev_frame: Previous frame, or None on the first frame.
        block_size: Edge length of the square blocks.
        corruption: Probability that any given block is replaced.
        max_offset: Maximum displacement in pixels, each axis.
        color_corrupt: Also roll color channels on ~30% of corrupted blocks.

    Returns:
        New frame (the inputs are not modified).

    This is a basic approximation - real datamosh works on compressed video.
    """
    check_deps()
    if prev_frame is None:
        # No motion history yet; pass the frame through untouched.
        return frame.copy()

    h, w = frame.shape[:2]
    result = frame.copy()

    # Process in blocks
    for y in range(0, h - block_size, block_size):
        for x in range(0, w - block_size, block_size):
            if np.random.random() < corruption:
                # Random offset
                ox = np.random.randint(-max_offset, max_offset + 1)
                oy = np.random.randint(-max_offset, max_offset + 1)

                # Source from previous frame with offset, clamped in-bounds
                src_y = np.clip(y + oy, 0, h - block_size)
                src_x = np.clip(x + ox, 0, w - block_size)

                block = prev_frame[src_y:src_y+block_size, src_x:src_x+block_size]

                # Color corruption
                if color_corrupt and np.random.random() < 0.3:
                    # Swap or shift channels by rolling the color axis
                    block = np.roll(block, np.random.randint(1, 3), axis=2)

                result[y:y+block_size, x:x+block_size] = block

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pixelsort Primitives (Python version)
|
||||
# ============================================================================
|
||||
|
||||
def pixelsort_frame(
    frame: np.ndarray,
    sort_by: str = "lightness",
    threshold_low: float = 50,
    threshold_high: float = 200,
    angle: float = 0,
    reverse: bool = False,
) -> np.ndarray:
    """
    Apply pixel sorting effect to a frame.

    Pixels whose sort key falls within [threshold_low, threshold_high] form
    horizontal runs; each run is reordered by its key.  A non-zero angle
    rotates the frame first so runs follow that direction, then rotates back.

    Args:
        frame: Input frame (H, W, 3).
        sort_by: "lightness", "hue", "saturation"; anything else sorts by
            the red channel.
        threshold_low: Lower bound of the sortable key range.
            NOTE(review): for "hue"/"saturation" the key is not on a 0-255
            scale (hue is radians, saturation is 0-1), so the 50/200
            defaults effectively disable sorting there - confirm intended.
        threshold_high: Upper bound of the sortable key range.
        angle: Sort direction in degrees (0 = horizontal rows).
        reverse: Sort runs descending instead of ascending.

    Returns:
        New frame with sorted runs (the input is not modified).
    """
    check_deps()
    from scipy import ndimage

    # Rotate if needed
    if angle != 0:
        frame = ndimage.rotate(frame, -angle, reshape=False, mode='reflect')

    h, w = frame.shape[:2]
    result = frame.copy()

    # Compute sort key
    if sort_by == "lightness":
        # ITU-R BT.601 luma.
        key = 0.299 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2]
    elif sort_by == "hue":
        # Simple hue approximation (in radians, not degrees)
        key = np.arctan2(
            np.sqrt(3) * (frame[:,:,1].astype(float) - frame[:,:,2]),
            2 * frame[:,:,0].astype(float) - frame[:,:,1] - frame[:,:,2]
        )
    elif sort_by == "saturation":
        mx = frame.max(axis=2).astype(float)
        mn = frame.min(axis=2).astype(float)
        # HSV-style saturation; 0 where the pixel is black
        key = np.where(mx > 0, (mx - mn) / mx, 0)
    else:
        key = frame[:,:,0]  # Red channel

    # Sort each row
    for y in range(h):
        row = result[y]
        row_key = key[y]

        # Find sortable intervals (pixels within threshold)
        mask = (row_key >= threshold_low) & (row_key <= threshold_high)

        # Find runs of True in mask; runs shorter than 2 pixels are skipped
        runs = []
        start = None
        for x in range(w):
            if mask[x] and start is None:
                start = x
            elif not mask[x] and start is not None:
                if x - start > 1:
                    runs.append((start, x))
                start = None
        if start is not None and w - start > 1:
            # Run extends to the right edge of the row
            runs.append((start, w))

        # Sort each run
        for start, end in runs:
            indices = np.argsort(row_key[start:end])
            if reverse:
                indices = indices[::-1]
            result[y, start:end] = row[start:end][indices]

    # Rotate back
    if angle != 0:
        result = ndimage.rotate(result, angle, reshape=False, mode='reflect')

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Registry
|
||||
# ============================================================================
|
||||
|
||||
def map_char_grid(
    chars,
    luminances,
    fn,
):
    """
    Map a function over each cell of a character grid.

    Args:
        chars: 2D array/list of characters (rows, cols); rows may also be
            plain strings.
        luminances: 2D array of luminance values (a 1D array yields one
            value per row)
        fn: Function or Lambda (row, col, char, luminance) -> new_char

    Returns:
        New character grid with mapped values (list of lists)
    """
    from .parser import Lambda
    from .evaluator import evaluate

    # Handle both list and numpy array inputs
    if isinstance(chars, np.ndarray):
        rows, cols = chars.shape[:2]
    else:
        rows = len(chars)
        # Rows may be lists, tuples, or strings; otherwise treat as 1-wide.
        cols = len(chars[0]) if rows > 0 and isinstance(chars[0], (list, tuple, str)) else 1

    # Get luminances as 2D
    if isinstance(luminances, np.ndarray):
        lum_arr = luminances
    else:
        lum_arr = np.array(luminances)

    # Check if fn is a Lambda (from sexp) or a Python callable
    is_lambda = isinstance(fn, Lambda)

    result = []
    for r in range(rows):
        row_result = []
        for c in range(cols):
            # Get character (fall back to a space past a ragged row's end)
            if isinstance(chars, np.ndarray):
                ch = chars[r, c] if len(chars.shape) > 1 else chars[r]
            elif isinstance(chars[r], str):
                ch = chars[r][c] if c < len(chars[r]) else ' '
            else:
                ch = chars[r][c] if c < len(chars[r]) else ' '

            # Get luminance
            if len(lum_arr.shape) > 1:
                lum = lum_arr[r, c]
            else:
                lum = lum_arr[r]

            # Call the function
            if is_lambda:
                # Evaluate the Lambda with its arguments bound on top of
                # the captured closure environment.
                call_env = dict(fn.closure) if fn.closure else {}
                for param, val in zip(fn.params, [r, c, ch, float(lum)]):
                    call_env[param] = val
                new_ch = evaluate(fn.body, call_env)
            else:
                new_ch = fn(r, c, ch, float(lum))

            row_result.append(new_ch)
        result.append(row_result)

    return result
|
||||
|
||||
|
||||
def alphabet_char(alphabet: str, index: int) -> str:
    """
    Look up a single character in an alphabet.

    Args:
        alphabet: Named alphabet (a key of ASCII_ALPHABETS) or a literal
            string of characters to index into
        index: Position in the alphabet; out-of-range values are clamped
            to the nearest valid end

    Returns:
        The character at the (clamped) index
    """
    # Named alphabets resolve through the registry; any other string is
    # treated as a literal character set.
    chars = ASCII_ALPHABETS[alphabet] if alphabet in ASCII_ALPHABETS else alphabet

    # Clamp into [0, len(chars) - 1] so arbitrary indices never overflow.
    pos = min(max(int(index), 0), len(chars) - 1)
    return chars[pos]
|
||||
|
||||
|
||||
# Registry mapping S-expression primitive names to their Python
# implementations. Some effects are registered under both a hyphenated
# sexp-style alias and the underscored frame-function name (e.g.
# "datamosh" and "datamosh_frame" both resolve to datamosh_frame).
PRIMITIVES = {
    # ASCII
    "cell-sample": cell_sample,
    "luminance-to-chars": luminance_to_chars,
    "render-char-grid": render_char_grid,
    "map-char-grid": map_char_grid,
    "alphabet-char": alphabet_char,
    "ascii_art_frame": ascii_art_frame,
    "ascii_zones_frame": ascii_zones_frame,

    # Kaleidoscope
    "kaleidoscope-displace": kaleidoscope_displace,
    "remap": remap,
    "kaleidoscope_frame": kaleidoscope_frame,

    # Datamosh
    "datamosh": datamosh_frame,
    "datamosh_frame": datamosh_frame,

    # Pixelsort
    "pixelsort": pixelsort_frame,
    "pixelsort_frame": pixelsort_frame,
}
|
||||
|
||||
|
||||
def get_primitive(name: str):
    """Look up a primitive by registered name; None if unknown."""
    try:
        return PRIMITIVES[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def list_primitives() -> List[str]:
    """Return the names of all registered primitives, in registration order."""
    return [name for name in PRIMITIVES]
|
||||
779
artdag/sexp/scheduler.py
Normal file
779
artdag/sexp/scheduler.py
Normal file
@@ -0,0 +1,779 @@
|
||||
"""
|
||||
Celery scheduler for S-expression execution plans.
|
||||
|
||||
Distributes plan steps to workers as S-expressions.
|
||||
The S-expression is the canonical format - workers receive
|
||||
serialized S-expressions and can verify cache_ids by hashing them.
|
||||
|
||||
Usage:
|
||||
from artdag.sexp import compile_string, create_plan
|
||||
from artdag.sexp.scheduler import schedule_plan
|
||||
|
||||
recipe = compile_string(sexp_content)
|
||||
plan = create_plan(recipe, inputs={'video': 'abc123...'})
|
||||
result = schedule_plan(plan)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
from .parser import Symbol, Keyword, serialize, parse
|
||||
from .planner import ExecutionPlanSexp, PlanStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StepResult:
    """Result from executing a step."""
    # Identifier of the step within the plan.
    step_id: str
    # Content-addressed ID of the step's output.
    cache_id: str
    status: str  # 'completed', 'cached', 'failed', 'pending'
    # Local filesystem path of the produced artifact, when available.
    output_path: Optional[str] = None
    # Error message, set when status == 'failed'.
    error: Optional[str] = None
    # IPFS CID of the output, if the worker reported one.
    ipfs_cid: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class PlanResult:
    """Result from executing a complete plan."""
    plan_id: str
    status: str  # 'completed', 'failed', 'partial'
    # Steps actually executed (or completed by another worker) this run.
    steps_completed: int = 0
    # Steps skipped because their output was already cached.
    steps_cached: int = 0
    steps_failed: int = 0
    # Cache ID / path / IPFS CID of the plan's final output step.
    output_cache_id: Optional[str] = None
    output_path: Optional[str] = None
    output_ipfs_cid: Optional[str] = None
    # Per-step results keyed by step_id.
    step_results: Dict[str, StepResult] = field(default_factory=dict)
    error: Optional[str] = None
|
||||
|
||||
|
||||
def step_to_sexp(step: PlanStep) -> List:
    """
    Build the minimal worker-facing S-expression for a plan step.

    This is the canonical form workers receive; hashing exactly this
    expression lets a worker verify the step's cache_id.
    """
    expr: List = [Symbol(step.node_type.lower())]

    # Config entries become :keyword value pairs; underscores are
    # normalized to hyphens per sexp convention.
    for cfg_key, cfg_val in step.config.items():
        expr.append(Keyword(cfg_key.replace('_', '-')))
        expr.append(cfg_val)

    # Inputs are attached as cache IDs (not step IDs).
    if step.inputs:
        expr.append(Keyword("inputs"))
        expr.append(step.inputs)

    return expr
|
||||
|
||||
|
||||
def step_sexp_to_string(step: PlanStep) -> str:
    """Render a step as the serialized S-expression sent in the Celery task."""
    expr = step_to_sexp(step)
    return serialize(expr)
|
||||
|
||||
|
||||
def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: Optional[str] = None) -> bool:
    """
    Verify that a step's cache_id matches its S-expression.

    Workers should call this to verify they're executing the correct task.

    Args:
        step_sexp: Serialized step S-expression (the canonical form
            produced by step_sexp_to_string).
        expected_cache_id: Cache ID the step was dispatched with.
        cluster_key: Optional cluster scoping key; when set, the hash is
            computed over the cluster-wrapped payload.
            (Fixed annotation: was ``str = None``, an implicit Optional.)

    Returns:
        True if the SHA3-256 hex digest of the canonical JSON payload
        equals expected_cache_id.
    """
    data: Dict[str, Any] = {"sexp": step_sexp}
    if cluster_key:
        data = {"_cluster_key": cluster_key, "_data": data}

    # Canonical JSON (sorted keys, no whitespace) so the digest is stable
    # across workers and Python versions.
    json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
    computed = hashlib.sha3_256(json_str.encode()).hexdigest()
    return computed == expected_cache_id
|
||||
|
||||
|
||||
class PlanScheduler:
    """
    Schedules execution of S-expression plans on Celery workers.

    The scheduler:
    1. Groups steps by dependency level
    2. Checks cache for already-computed results
    3. Dispatches uncached steps to workers
    4. Waits for completion before proceeding to next level
    """

    def __init__(
        self,
        cache_manager=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the scheduler.

        Args:
            cache_manager: L1 cache manager for checking cached results
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> PlanResult:
        """
        Schedule and execute a plan.

        Levels are executed strictly in order; steps within a level run in
        parallel via a Celery group. Execution aborts and returns a
        'failed' PlanResult as soon as any step or level fails.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds applied per level's group.get()

        Returns:
            PlanResult with execution results
        """
        # Imported lazily so the module can be loaded without Celery installed.
        from celery import group

        logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

        # Build step lookup and group by level
        steps_by_id = {s.step_id: s for s in plan.steps}
        steps_by_level = self._group_by_level(plan.steps)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track results
        result = PlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Map step_id -> cache_id for resolving inputs
        cache_ids = dict(plan.inputs)  # seed with the plan's input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Execute level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            logger.info(f"Level {level}: {len(level_steps)} steps")

            # Check cache for each step; cached steps are recorded and skipped.
            steps_to_run = []
            for step in level_steps:
                if self._is_cached(step.cache_id):
                    result.steps_cached += 1
                    result.step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                logger.info(f"Level {level}: all {len(level_steps)} steps cached")
                continue

            # Dispatch uncached steps to workers
            logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps")

            tasks = []
            for step in steps_to_run:
                # Build task arguments: the serialized sexp plus input
                # step-IDs resolved to cache IDs (unknown inputs pass through).
                step_sexp = step_sexp_to_string(step)
                input_cache_ids = {
                    inp: cache_ids.get(inp, inp)
                    for inp in step.inputs
                }

                task = self._get_execute_task().s(
                    step_sexp=step_sexp,
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    plan_id=plan.plan_id,
                    input_cache_ids=input_cache_ids,
                )
                tasks.append(task)

            # Execute in parallel
            job = group(tasks)
            async_result = job.apply_async()

            try:
                step_results = async_result.get(timeout=timeout)
            except Exception as e:
                logger.error(f"Level {level} failed: {e}")
                result.status = "failed"
                result.error = f"Level {level} failed: {e}"
                return result

            # Process worker result dicts for this level.
            for step_result in step_results:
                step_id = step_result.get("step_id")
                status = step_result.get("status")

                result.step_results[step_id] = StepResult(
                    step_id=step_id,
                    cache_id=step_result.get("cache_id"),
                    status=status,
                    output_path=step_result.get("output_path"),
                    error=step_result.get("error"),
                    ipfs_cid=step_result.get("ipfs_cid"),
                )

                # 'completed_by_other' means another worker finished the
                # same cache_id first; all three count as done.
                if status in ("completed", "cached", "completed_by_other"):
                    result.steps_completed += 1
                elif status == "failed":
                    result.steps_failed += 1
                    result.status = "failed"
                    result.error = step_result.get("error")
                    # Fail fast on the first failed step.
                    return result

        # Propagate the designated output step's artifact info to the plan result.
        output_step = steps_by_id.get(plan.output_step_id)
        if output_step:
            output_result = result.step_results.get(output_step.step_id)
            if output_result:
                result.output_cache_id = output_step.cache_id
                result.output_path = output_result.output_path
                result.output_ipfs_cid = output_result.ipfs_cid

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.steps_completed} executed, {result.steps_cached} cached"
        )
        return result

    def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]:
        """Group steps by dependency level."""
        by_level = {}
        for step in steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def _is_cached(self, cache_id: str) -> bool:
        """Check if a cache_id exists in cache (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached item, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _get_execute_task(self):
        """Get the Celery task for step execution.

        Raises:
            RuntimeError: if no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        return self.celery_app.tasks[self.execute_task_name]
|
||||
|
||||
|
||||
def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler:
    """
    Build a PlanScheduler, resolving missing dependencies.

    Dependencies not supplied are imported from art-celery on a
    best-effort basis; an ImportError simply leaves them unset.
    """
    if celery_app is None:
        try:
            from celery_app import app as celery_app
        except ImportError:
            pass

    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass

    return PlanScheduler(cache_manager=cache_manager, celery_app=celery_app)
|
||||
|
||||
|
||||
def schedule_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    celery_app=None,
    timeout: int = 3600,
) -> PlanResult:
    """
    One-shot helper: build a scheduler and run *plan* on it.

    Args:
        plan: The execution plan
        cache_manager: Optional cache manager (auto-resolved if omitted)
        celery_app: Optional Celery app (auto-resolved if omitted)
        timeout: Execution timeout in seconds

    Returns:
        PlanResult describing the run
    """
    return create_scheduler(cache_manager, celery_app).schedule(plan, timeout=timeout)
|
||||
|
||||
|
||||
# Stage-aware scheduling
|
||||
|
||||
@dataclass
class StageResult:
    """Result from executing a stage."""
    stage_name: str
    # Content-addressed ID of the stage as a whole.
    cache_id: str
    status: str  # 'completed', 'cached', 'failed', 'pending'
    # Per-step results keyed by step_id.
    step_results: Dict[str, StepResult] = field(default_factory=dict)
    outputs: Dict[str, str] = field(default_factory=dict)  # binding_name -> cache_id
    error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class StagePlanResult:
    """Result from executing a plan with stages."""
    plan_id: str
    status: str  # 'completed', 'failed', 'partial'
    # Stage-level counters.
    stages_completed: int = 0
    stages_cached: int = 0
    stages_failed: int = 0
    # Step-level counters, accumulated across executed stages.
    steps_completed: int = 0
    steps_cached: int = 0
    steps_failed: int = 0
    # Per-stage results keyed by stage name.
    stage_results: Dict[str, StageResult] = field(default_factory=dict)
    output_cache_id: Optional[str] = None
    output_path: Optional[str] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
class StagePlanScheduler:
    """
    Stage-aware scheduler for S-expression plans.

    The scheduler:
    1. Groups stages by level (parallel groups)
    2. For each stage level:
       - Check stage cache, skip entire stage if hit
       - Execute stage steps (grouped by level within stage)
       - Cache stage outputs
    3. Stages at same level can run in parallel
    """

    def __init__(
        self,
        cache_manager=None,
        stage_cache=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the stage-aware scheduler.

        Args:
            cache_manager: L1 cache manager for step-level caching
            stage_cache: StageCache instance for stage-level caching
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.stage_cache = stage_cache
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> StagePlanResult:
        """
        Schedule and execute a plan with stage awareness.

        If the plan has stages, uses stage-level scheduling.
        Otherwise, falls back to step-level scheduling.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds for the entire plan

        Returns:
            StagePlanResult with execution results
        """
        # If no stages, delegate to the plain step-level scheduler and
        # wrap its PlanResult into a StagePlanResult.
        if not plan.stage_plans:
            logger.info("Plan has no stages, using step-level scheduling")
            regular_scheduler = PlanScheduler(
                cache_manager=self.cache_manager,
                celery_app=self.celery_app,
                execute_task_name=self.execute_task_name,
            )
            step_result = regular_scheduler.schedule(plan, timeout)
            return StagePlanResult(
                plan_id=step_result.plan_id,
                status=step_result.status,
                steps_completed=step_result.steps_completed,
                steps_cached=step_result.steps_cached,
                steps_failed=step_result.steps_failed,
                output_cache_id=step_result.output_cache_id,
                output_path=step_result.output_path,
                error=step_result.error,
            )

        logger.info(
            f"Scheduling plan {plan.plan_id[:16]}... "
            f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)"
        )

        result = StagePlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Group stages by level
        stages_by_level = self._group_stages_by_level(plan.stage_plans)
        max_level = max(stages_by_level.keys()) if stages_by_level else 0

        # Track stage outputs for data flow
        stage_outputs = {}  # stage_name -> {binding_name -> cache_id}

        # Execute stage by stage level
        for level in range(max_level + 1):
            level_stages = stages_by_level.get(level, [])
            if not level_stages:
                continue

            logger.info(f"Stage level {level}: {len(level_stages)} stages")

            # Check stage cache for each stage
            stages_to_run = []
            for stage_plan in level_stages:
                if self._is_stage_cached(stage_plan.cache_id):
                    result.stages_cached += 1
                    cached_entry = self._load_cached_stage(stage_plan.cache_id)
                    if cached_entry:
                        stage_outputs[stage_plan.stage_name] = {
                            name: out.cache_id
                            for name, out in cached_entry.outputs.items()
                        }
                        result.stage_results[stage_plan.stage_name] = StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="cached",
                            outputs=stage_outputs[stage_plan.stage_name],
                        )
                        logger.info(f"Stage {stage_plan.stage_name}: cached")
                    # NOTE(review): if the cache file exists but fails to
                    # load, the stage is counted cached yet neither re-run
                    # nor recorded — confirm this is intended.
                else:
                    stages_to_run.append(stage_plan)

            if not stages_to_run:
                logger.info(f"Stage level {level}: all {len(level_stages)} stages cached")
                continue

            # Execute uncached stages
            # For now, execute sequentially; L1 Celery will add parallel execution
            for stage_plan in stages_to_run:
                logger.info(f"Executing stage: {stage_plan.stage_name}")

                stage_result = self._execute_stage(
                    stage_plan,
                    plan,
                    stage_outputs,
                    timeout,
                )

                result.stage_results[stage_plan.stage_name] = stage_result

                if stage_result.status == "completed":
                    result.stages_completed += 1
                    stage_outputs[stage_plan.stage_name] = stage_result.outputs

                    # Cache the stage result
                    self._cache_stage(stage_plan, stage_result)
                elif stage_result.status == "failed":
                    result.stages_failed += 1
                    result.status = "failed"
                    result.error = stage_result.error
                    # Fail fast on the first failed stage.
                    return result

                # Accumulate step counts
                for sr in stage_result.step_results.values():
                    if sr.status == "completed":
                        result.steps_completed += 1
                    elif sr.status == "cached":
                        result.steps_cached += 1
                    elif sr.status == "failed":
                        result.steps_failed += 1

        # Get final output: the last declared stage is treated as the
        # plan's output stage.
        if plan.stage_plans:
            last_stage = plan.stage_plans[-1]
            if last_stage.stage_name in result.stage_results:
                stage_res = result.stage_results[last_stage.stage_name]
                result.output_cache_id = last_stage.cache_id
                # Find the output step's path from step results
                for step_res in stage_res.step_results.values():
                    if step_res.output_path:
                        result.output_path = step_res.output_path

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.stages_completed} stages executed, "
            f"{result.stages_cached} stages cached"
        )
        return result

    def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]:
        """Group stage plans by their level."""
        by_level = {}
        for stage_plan in stage_plans:
            by_level.setdefault(stage_plan.level, []).append(stage_plan)
        return by_level

    def _is_stage_cached(self, cache_id: str) -> bool:
        """Check if a stage is cached (False when no stage cache configured)."""
        if self.stage_cache is None:
            return False
        return self.stage_cache.has_stage(cache_id)

    def _load_cached_stage(self, cache_id: str):
        """Load a cached stage entry, or None when unavailable."""
        if self.stage_cache is None:
            return None
        return self.stage_cache.load_stage(cache_id)

    def _cache_stage(self, stage_plan, stage_result: StageResult) -> None:
        """Cache a stage result (no-op when no stage cache configured)."""
        if self.stage_cache is None:
            return

        from .stage_cache import StageCacheEntry, StageOutput

        outputs = {}
        for name, cache_id in stage_result.outputs.items():
            outputs[name] = StageOutput(
                cache_id=cache_id,
                output_type="artifact",
            )

        entry = StageCacheEntry(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            outputs=outputs,
        )
        self.stage_cache.save_stage(entry)

    def _execute_stage(
        self,
        stage_plan,
        plan: ExecutionPlanSexp,
        stage_outputs: Dict,
        timeout: int,
    ) -> StageResult:
        """
        Execute a single stage.

        Uses the PlanScheduler to execute the stage's steps.
        """
        # Create a mini-plan with just this stage's steps
        stage_steps = stage_plan.steps

        # Build step lookup
        steps_by_id = {s.step_id: s for s in stage_steps}
        steps_by_level = {}
        for step in stage_steps:
            steps_by_level.setdefault(step.level, []).append(step)

        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track step results
        step_results = {}
        cache_ids = dict(plan.inputs)  # seed with the plan's input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Include outputs from previous stages so cross-stage bindings resolve.
        for stage_name, outputs in stage_outputs.items():
            for binding_name, binding_cache_id in outputs.items():
                cache_ids[binding_name] = binding_cache_id

        # Execute steps level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            # Check cache for each step
            steps_to_run = []
            for step in level_steps:
                if self._is_step_cached(step.cache_id):
                    step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                continue

            # Execute steps (for now, sequentially - L1 will add Celery dispatch)
            for step in steps_to_run:
                # In a full implementation, this would dispatch to Celery
                # For now, mark as pending
                step_results[step.step_id] = StepResult(
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    status="pending",
                )

                # If Celery is configured, dispatch the task (blocking).
                if self.celery_app:
                    try:
                        task_result = self._dispatch_step(step, cache_ids, timeout)
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status=task_result.get("status", "completed"),
                            output_path=task_result.get("output_path"),
                            error=task_result.get("error"),
                            ipfs_cid=task_result.get("ipfs_cid"),
                        )
                    except Exception as e:
                        # A dispatch/execution error fails the whole stage.
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status="failed",
                            error=str(e),
                        )
                        return StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="failed",
                            step_results=step_results,
                            error=str(e),
                        )

        # Build output bindings: each declared output resolves through the
        # cache_ids map (unknown node IDs pass through unchanged).
        outputs = {}
        for out_name, node_id in stage_plan.output_bindings.items():
            outputs[out_name] = cache_ids.get(node_id, node_id)

        return StageResult(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            status="completed",
            step_results=step_results,
            outputs=outputs,
        )

    def _is_step_cached(self, cache_id: str) -> bool:
        """Check if a step is cached (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached step, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict:
        """Dispatch a step to Celery and block for its result dict.

        Raises:
            RuntimeError: if no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")

        task = self.celery_app.tasks[self.execute_task_name]

        step_sexp = step_sexp_to_string(step)
        input_cache_ids = {
            inp: cache_ids.get(inp, inp)
            for inp in step.inputs
        }

        async_result = task.apply_async(
            kwargs={
                "step_sexp": step_sexp,
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "input_cache_ids": input_cache_ids,
            }
        )

        return async_result.get(timeout=timeout)
|
||||
|
||||
|
||||
def create_stage_scheduler(
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
) -> StagePlanScheduler:
    """
    Build a StagePlanScheduler, resolving missing dependencies.

    The Celery app and cache manager are imported from art-celery when
    not supplied; import failures are tolerated and leave them unset.

    Args:
        cache_manager: L1 cache manager for step-level caching
        stage_cache: StageCache instance for stage-level caching
        celery_app: Celery application instance

    Returns:
        A configured StagePlanScheduler
    """
    if celery_app is None:
        try:
            from celery_app import app as celery_app
        except ImportError:
            pass

    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass

    return StagePlanScheduler(
        cache_manager=cache_manager,
        stage_cache=stage_cache,
        celery_app=celery_app,
    )
|
||||
|
||||
|
||||
def schedule_staged_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
    timeout: int = 3600,
) -> StagePlanResult:
    """
    One-shot helper: run *plan* through a stage-aware scheduler.

    Args:
        plan: The execution plan
        cache_manager: Optional step-level cache manager
        stage_cache: Optional stage-level cache
        celery_app: Optional Celery app
        timeout: Execution timeout in seconds

    Returns:
        StagePlanResult describing the run
    """
    sched = create_stage_scheduler(cache_manager, stage_cache, celery_app)
    return sched.schedule(plan, timeout=timeout)
|
||||
412
artdag/sexp/stage_cache.py
Normal file
412
artdag/sexp/stage_cache.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Stage-level cache layer using S-expression storage.
|
||||
|
||||
Provides caching for stage outputs, enabling:
|
||||
- Stage-level cache hits (skip entire stage execution)
|
||||
- Analysis result persistence as sexp
|
||||
- Cross-worker stage cache sharing (for L1 Celery integration)
|
||||
|
||||
All cache files use .sexp extension - no JSON in the pipeline.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .parser import Symbol, Keyword, parse, serialize
|
||||
|
||||
|
||||
@dataclass
class StageOutput:
    """One named output produced by a stage."""
    cache_id: Optional[str] = None  # For artifacts (files, analysis data)
    value: Any = None  # For scalar values
    output_type: str = "artifact"  # "artifact", "analysis", "scalar"

    def to_sexp(self) -> List:
        """Render as a flat :keyword/value S-expression list."""
        parts: List = []
        if self.cache_id:
            parts += [Keyword("cache-id"), self.cache_id]
        if self.value is not None:
            parts += [Keyword("value"), self.value]
        parts += [Keyword("type"), Keyword(self.output_type)]
        return parts

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageOutput':
        """Rebuild a StageOutput from a :keyword/value S-expression list."""
        # First pass: collect keyword -> value pairs. A trailing keyword
        # with no value, or any non-keyword item, is skipped; on duplicate
        # keys the last occurrence wins.
        found = {}
        pos = 0
        while pos < len(sexp):
            item = sexp[pos]
            if isinstance(item, Keyword) and pos + 1 < len(sexp):
                found[item.name] = sexp[pos + 1]
                pos += 2
            else:
                pos += 1

        # :type may arrive as a Keyword or any other value; normalize to str.
        type_val = found.get("type")
        if type_val is None:
            output_type = "artifact"
        elif isinstance(type_val, Keyword):
            output_type = type_val.name
        else:
            output_type = str(type_val)

        return cls(
            cache_id=found.get("cache-id"),
            value=found.get("value"),
            output_type=output_type,
        )
|
||||
|
||||
|
||||
@dataclass
class StageCacheEntry:
    """Cached result of a stage execution."""
    stage_name: str
    # Content-addressed ID of the stage; also the cache filename stem.
    cache_id: str
    outputs: Dict[str, StageOutput]  # binding_name -> output info
    completed_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_sexp(self) -> List:
        """
        Convert to S-expression for storage.

        Format:
            (stage-result
              :name "analyze-a"
              :cache-id "abc123..."
              :completed-at 1705678900.123
              :outputs
              ((beats-a :cache-id "def456..." :type :analysis)
               (tempo :value 120.5 :type :scalar)))
        """
        sexp = [Symbol("stage-result")]
        sexp.extend([Keyword("name"), self.stage_name])
        sexp.extend([Keyword("cache-id"), self.cache_id])
        sexp.extend([Keyword("completed-at"), self.completed_at])

        # :outputs and :metadata sections are omitted entirely when empty.
        if self.outputs:
            outputs_sexp = []
            for name, output in self.outputs.items():
                # Each output is a list headed by its binding name symbol.
                output_sexp = [Symbol(name)] + output.to_sexp()
                outputs_sexp.append(output_sexp)
            sexp.extend([Keyword("outputs"), outputs_sexp])

        if self.metadata:
            sexp.extend([Keyword("metadata"), self.metadata])

        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageCacheEntry':
        """Parse from S-expression.

        Raises:
            ValueError: if the sexp is not a (stage-result ...) form or is
                missing the required :name / :cache-id fields.
        """
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "stage-result":
            raise ValueError("Invalid stage-result sexp")

        stage_name = None
        cache_id = None
        completed_at = time.time()
        outputs = {}
        metadata = {}

        # Scan keyword/value pairs after the head symbol; unknown keywords
        # are consumed and ignored.
        i = 1
        while i < len(sexp):
            if isinstance(sexp[i], Keyword):
                key = sexp[i].name
                if i + 1 < len(sexp):
                    val = sexp[i + 1]
                    if key == "name":
                        stage_name = val
                    elif key == "cache-id":
                        cache_id = val
                    elif key == "completed-at":
                        completed_at = float(val)
                    elif key == "outputs":
                        if isinstance(val, list):
                            for output_sexp in val:
                                if isinstance(output_sexp, list) and output_sexp:
                                    out_name = output_sexp[0]
                                    if isinstance(out_name, Symbol):
                                        out_name = out_name.name
                                    outputs[out_name] = StageOutput.from_sexp(output_sexp[1:])
                    elif key == "metadata":
                        metadata = val if isinstance(val, dict) else {}
                    i += 2
                else:
                    i += 1
            else:
                i += 1

        if not stage_name or not cache_id:
            raise ValueError("stage-result missing required fields (name, cache-id)")

        return cls(
            stage_name=stage_name,
            cache_id=cache_id,
            outputs=outputs,
            completed_at=completed_at,
            metadata=metadata,
        )

    @classmethod
    def from_string(cls, text: str) -> 'StageCacheEntry':
        """Parse from S-expression string."""
        sexp = parse(text)
        return cls.from_sexp(sexp)
|
||||
|
||||
|
||||
class StageCache:
    """
    File-backed cache of stage results, stored as S-expression documents.

    Layout on disk:
        cache_dir/
            _stages/
                {cache_id}.sexp   <- one file per cached stage result
    """

    def __init__(self, cache_dir: Union[str, Path]):
        """Create the cache rooted at *cache_dir*, making _stages/ if needed."""
        self.cache_dir = Path(cache_dir)
        self.stages_dir = self.cache_dir / "_stages"
        self.stages_dir.mkdir(parents=True, exist_ok=True)

    def get_cache_path(self, cache_id: str) -> Path:
        """Return the file path that stores the given stage result."""
        return self.stages_dir / f"{cache_id}.sexp"

    def has_stage(self, cache_id: str) -> bool:
        """Return True when a result for *cache_id* is present on disk."""
        return self.get_cache_path(cache_id).exists()

    def load_stage(self, cache_id: str) -> Optional[StageCacheEntry]:
        """
        Read a cached stage result.

        Args:
            cache_id: Stage cache ID

        Returns:
            The parsed StageCacheEntry, or None when the file is missing or
            unparseable (a warning goes to stderr in the corrupted case).
        """
        path = self.get_cache_path(cache_id)
        if not path.exists():
            return None
        try:
            return StageCacheEntry.from_string(path.read_text())
        except Exception as e:
            # A corrupted file is treated as a cache miss rather than an error.
            import sys
            print(f"Warning: corrupted stage cache {cache_id}: {e}", file=sys.stderr)
            return None

    def save_stage(self, entry: StageCacheEntry) -> Path:
        """
        Write *entry* to its cache file.

        Args:
            entry: Stage cache entry to save

        Returns:
            Path to the saved cache file
        """
        path = self.get_cache_path(entry.cache_id)
        path.write_text(entry.to_string(pretty=True))
        return path

    def delete_stage(self, cache_id: str) -> bool:
        """
        Remove a cached stage result.

        Args:
            cache_id: Stage cache ID

        Returns:
            True if a file was deleted, False if nothing was cached.
        """
        path = self.get_cache_path(cache_id)
        if not path.exists():
            return False
        path.unlink()
        return True

    def list_stages(self) -> List[str]:
        """Return the cache IDs of every stored stage result."""
        return [entry.stem for entry in self.stages_dir.glob("*.sexp")]

    def clear(self) -> int:
        """Delete every cached stage result and return how many were removed."""
        removed = 0
        for entry in self.stages_dir.glob("*.sexp"):
            entry.unlink()
            removed += 1
        return removed
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """
    A single analyzer's output, serializable as an S-expression.

    Example on-disk form:
        (analysis-result
          :analyzer "beats"
          :input-hash "abc123..."
          :duration 120.5
          :tempo 128.0
          :times (0.0 0.468 0.937 1.406 ...)
          :values (0.8 0.9 0.7 0.85 ...))
    """
    analyzer: str
    input_hash: str
    data: Dict[str, Any]  # analyzer payload (times, values, duration, ...)
    computed_at: float = field(default_factory=time.time)

    def to_sexp(self) -> List:
        """Build the S-expression list form of this result."""
        sexp = [Symbol("analysis-result")]
        sexp.extend([Keyword("analyzer"), self.analyzer])
        sexp.extend([Keyword("input-hash"), self.input_hash])
        sexp.extend([Keyword("computed-at"), self.computed_at])
        # Data keys are emitted as kebab-case keywords.
        for key, value in self.data.items():
            sexp.extend([Keyword(key.replace("_", "-")), value])
        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to an S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'AnalysisResult':
        """Parse an AnalysisResult; raises ValueError on a malformed form."""
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis-result":
            raise ValueError("Invalid analysis-result sexp")

        analyzer = None
        input_hash = None
        computed_at = time.time()  # fallback when :computed-at is absent
        data = {}

        idx = 1
        while idx < len(sexp):
            token = sexp[idx]
            # Skip non-keyword tokens, and a trailing keyword with no value.
            if not isinstance(token, Keyword) or idx + 1 >= len(sexp):
                idx += 1
                continue
            value = sexp[idx + 1]
            key = token.name
            if key == "analyzer":
                analyzer = value
            elif key == "input-hash":
                input_hash = value
            elif key == "computed-at":
                computed_at = float(value)
            else:
                # Unknown keywords are payload: kebab-case back to snake_case.
                data[key.replace("-", "_")] = value
            idx += 2

        if not analyzer:
            raise ValueError("analysis-result missing analyzer field")

        return cls(
            analyzer=analyzer,
            input_hash=input_hash or "",
            data=data,
            computed_at=computed_at,
        )

    @classmethod
    def from_string(cls, text: str) -> 'AnalysisResult':
        """Parse an AnalysisResult from an S-expression string."""
        return cls.from_sexp(parse(text))
|
||||
|
||||
|
||||
def save_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
    result: AnalysisResult,
) -> Path:
    """
    Write *result* to the cache as <cache_dir>/<node_id>/analysis.sexp.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis
        result: Analysis result to save

    Returns:
        Path to the saved file (the node directory is created on demand).
    """
    node_dir = Path(cache_dir) / node_id
    node_dir.mkdir(parents=True, exist_ok=True)

    target = node_dir / "analysis.sexp"
    target.write_text(result.to_string(pretty=True))
    return target
|
||||
|
||||
|
||||
def load_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
) -> Optional[AnalysisResult]:
    """
    Read a cached analysis result for *node_id*.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis

    Returns:
        AnalysisResult when analysis.sexp exists and parses; None when the
        file is missing or corrupted (a warning goes to stderr in that case).
    """
    path = Path(cache_dir) / node_id / "analysis.sexp"
    if not path.exists():
        return None

    try:
        return AnalysisResult.from_string(path.read_text())
    except Exception as e:
        # Treat unreadable cache files as misses instead of propagating.
        import sys
        print(f"Warning: corrupted analysis cache {node_id}: {e}", file=sys.stderr)
        return None
|
||||
146
artdag/sexp/test_ffmpeg_compiler.py
Normal file
146
artdag/sexp/test_ffmpeg_compiler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Tests for FFmpeg filter compilation.
|
||||
|
||||
Validates that each filter mapping produces valid FFmpeg commands.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .ffmpeg_compiler import FFmpegCompiler, EFFECT_MAPPINGS
|
||||
|
||||
|
||||
def test_filter_syntax(filter_str: str, duration: float = 0.1, is_complex: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Test if an FFmpeg filter string is valid by running it on a test pattern.
|
||||
|
||||
Args:
|
||||
filter_str: The filter string to test
|
||||
duration: Duration of test video
|
||||
is_complex: If True, use -filter_complex instead of -vf
|
||||
|
||||
Returns (success, error_message)
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
|
||||
output_path = f.name
|
||||
|
||||
try:
|
||||
if is_complex:
|
||||
# Complex filter graph needs -filter_complex and explicit output mapping
|
||||
cmd = [
|
||||
'ffmpeg', '-y',
|
||||
'-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
|
||||
'-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
|
||||
'-filter_complex', f'[0:v]{filter_str}[out]',
|
||||
'-map', '[out]', '-map', '1:a',
|
||||
'-c:v', 'libx264', '-preset', 'ultrafast',
|
||||
'-c:a', 'aac',
|
||||
'-t', str(duration),
|
||||
output_path
|
||||
]
|
||||
else:
|
||||
# Simple filter uses -vf
|
||||
cmd = [
|
||||
'ffmpeg', '-y',
|
||||
'-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
|
||||
'-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
|
||||
'-vf', filter_str,
|
||||
'-c:v', 'libx264', '-preset', 'ultrafast',
|
||||
'-c:a', 'aac',
|
||||
'-t', str(duration),
|
||||
output_path
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
|
||||
if result.returncode == 0:
|
||||
return True, ""
|
||||
else:
|
||||
# Extract relevant error
|
||||
stderr = result.stderr
|
||||
for line in stderr.split('\n'):
|
||||
if 'Error' in line or 'error' in line or 'Invalid' in line:
|
||||
return False, line.strip()
|
||||
return False, stderr[-500:] if len(stderr) > 500 else stderr
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Timeout"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
finally:
|
||||
Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def run_all_tests():
    """Validate every FFmpeg-backed effect mapping.

    Returns:
        A list of (effect_name, status, detail) tuples where status is
        "PASS", "FAIL", or "SKIP".
    """
    # NOTE: the original created an unused FFmpegCompiler() instance here;
    # only EFFECT_MAPPINGS is actually needed.
    results = []

    for effect_name, mapping in EFFECT_MAPPINGS.items():
        filter_name = mapping.get("filter")

        # Skip effects with no FFmpeg equivalent (external tools or python primitives)
        if filter_name is None:
            reason = "No FFmpeg equivalent"
            if mapping.get("external_tool"):
                reason = f"External tool: {mapping['external_tool']}"
            elif mapping.get("python_primitive"):
                reason = f"Python primitive: {mapping['python_primitive']}"
            results.append((effect_name, "SKIP", reason))
            continue

        # Complex filter graphs need different command plumbing downstream.
        is_complex = mapping.get("complex", False)

        # Build the concrete filter string (static args appended when present).
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        else:
            filter_str = filter_name

        # Run the filter against a synthetic test pattern.
        success, error = test_filter_syntax(filter_str, is_complex=is_complex)

        if success:
            results.append((effect_name, "PASS", filter_str))
        else:
            results.append((effect_name, "FAIL", f"{filter_str} -> {error}"))

    return results
|
||||
|
||||
|
||||
def print_results(results):
    """Render a (name, status, message) result list as a console report."""
    passed = sum(1 for _, status, _ in results if status == "PASS")
    failed = sum(1 for _, status, _ in results if status == "FAIL")
    skipped = sum(1 for _, status, _ in results if status == "SKIP")

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"FFmpeg Filter Tests: {passed} passed, {failed} failed, {skipped} skipped")
    print(f"{banner}\n")

    # Failures come first so they are hardest to miss.
    if failed > 0:
        print("FAILURES:")
        for name, status, msg in results:
            if status == "FAIL":
                print(f"  {name}: {msg}")
        print()

    print("PASSED:")
    for name, status, msg in results:
        if status == "PASS":
            print(f"  {name}: {msg}")

    # Skips are listed last, without their reason strings.
    if skipped > 0:
        print("\nSKIPPED (Python fallback):")
        for name, status, msg in results:
            if status == "SKIP":
                print(f"  {name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual entry point: exercise every effect mapping against real ffmpeg
    # and print a pass/fail/skip report.
    results = run_all_tests()
    print_results(results)
|
||||
201
artdag/sexp/test_primitives.py
Normal file
201
artdag/sexp/test_primitives.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Tests for Python primitive effects.
|
||||
|
||||
Tests that ascii_art, ascii_zones, and other Python primitives
|
||||
can be executed via the EffectExecutor.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
HAS_DEPS = True
|
||||
except ImportError:
|
||||
HAS_DEPS = False
|
||||
|
||||
from .primitives import (
|
||||
ascii_art_frame,
|
||||
ascii_zones_frame,
|
||||
get_primitive,
|
||||
list_primitives,
|
||||
)
|
||||
from .ffmpeg_compiler import FFmpegCompiler
|
||||
|
||||
|
||||
def create_test_video(path: Path, duration: float = 0.5, size: str = "64x64") -> bool:
    """Render a short synthetic test video at *path*; True on success."""
    command = [
        "ffmpeg", "-y",
        "-f", "lavfi", "-i", f"testsrc=duration={duration}:size={size}:rate=10",
        "-c:v", "libx264", "-preset", "ultrafast",
        str(path),
    ]
    proc = subprocess.run(command, capture_output=True)
    return proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestPrimitives:
    """Exercise the primitive frame functions directly."""

    def test_ascii_art_frame_basic(self):
        """ascii_art_frame keeps the input's shape and dtype."""
        src = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        out = ascii_art_frame(src, char_size=8)
        assert out.shape == src.shape
        assert out.dtype == np.uint8

    def test_ascii_zones_frame_basic(self):
        """ascii_zones_frame keeps the input's shape and dtype."""
        src = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        out = ascii_zones_frame(src, char_size=8)
        assert out.shape == src.shape
        assert out.dtype == np.uint8

    def test_get_primitive(self):
        """get_primitive resolves known names and returns None otherwise."""
        assert get_primitive("ascii_art_frame") is ascii_art_frame
        assert get_primitive("ascii_zones_frame") is ascii_zones_frame
        assert get_primitive("nonexistent") is None

    def test_list_primitives(self):
        """list_primitives includes the ascii primitives among others."""
        names = list_primitives()
        assert "ascii_art_frame" in names
        assert "ascii_zones_frame" in names
        assert len(names) > 5
|
||||
|
||||
|
||||
class TestFFmpegCompilerPrimitives:
    """Check the python_primitive mappings exposed by FFmpegCompiler."""

    def test_has_python_primitive_ascii_art(self):
        """ascii_art maps to the ascii_art_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_art") == "ascii_art_frame"

    def test_has_python_primitive_ascii_zones(self):
        """ascii_zones maps to the ascii_zones_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_zones") == "ascii_zones_frame"

    def test_has_python_primitive_ffmpeg_effect(self):
        """Plain FFmpeg effects expose no python primitive."""
        compiler = FFmpegCompiler()
        assert compiler.has_python_primitive("brightness") is None
        assert compiler.has_python_primitive("blur") is None

    def test_compile_effect_returns_none_for_primitives(self):
        """compile_effect declines primitive-backed effects."""
        compiler = FFmpegCompiler()
        assert compiler.compile_effect("ascii_art", {}) is None
        assert compiler.compile_effect("ascii_zones", {}) is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestEffectExecutorPrimitives:
    """End-to-end checks of EffectExecutor with Python primitives."""

    def test_executor_loads_primitive(self):
        """Known primitive effects resolve to callables."""
        from ..nodes.effect import _get_python_primitive_effect

        assert _get_python_primitive_effect("ascii_art") is not None
        assert _get_python_primitive_effect("ascii_zones") is not None

    def test_executor_rejects_unknown_effect(self):
        """Unknown effect names resolve to None."""
        from ..nodes.effect import _get_python_primitive_effect

        assert _get_python_primitive_effect("nonexistent_effect") is None

    def test_execute_ascii_art_effect(self, tmp_path):
        """Running the ascii_art effect produces a non-empty output file."""
        from ..nodes.effect import EffectExecutor

        # Build a short synthetic input; skip when ffmpeg is unavailable.
        source = tmp_path / "input.mp4"
        if not create_test_video(source):
            pytest.skip("Could not create test video")

        destination = tmp_path / "output.mkv"
        produced = EffectExecutor().execute(
            config={"effect": "ascii_art", "char_size": 8},
            inputs=[source],
            output_path=destination,
        )

        assert produced.exists()
        assert produced.stat().st_size > 0
|
||||
|
||||
|
||||
def run_all_tests():
    """Run the primitive checks manually (without pytest), printing progress.

    Mirrors the pytest classes above: primitive frame functions, the
    FFmpegCompiler mappings, and the executor lookup (skipped when the
    relative import fails, e.g. when run as a standalone script).
    """
    # NOTE: the original imported sys here but never used it; removed.

    # Check dependencies
    if not HAS_DEPS:
        print("SKIP: numpy/PIL not available")
        return

    print("Testing primitives...")

    # Test primitive functions
    frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)

    print("  ascii_art_frame...", end=" ")
    result = ascii_art_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")

    print("  ascii_zones_frame...", end=" ")
    result = ascii_zones_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")

    # Test FFmpegCompiler mappings
    print("\nTesting FFmpegCompiler mappings...")
    compiler = FFmpegCompiler()

    print("  ascii_art python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame"
    print("PASS")

    print("  ascii_zones python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame"
    print("PASS")

    # Test executor lookup
    print("\nTesting EffectExecutor...")
    try:
        from ..nodes.effect import _get_python_primitive_effect

        print("  _get_python_primitive_effect(ascii_art)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_art")
        assert effect_fn is not None
        print("PASS")

        print("  _get_python_primitive_effect(ascii_zones)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_zones")
        assert effect_fn is not None
        print("PASS")

    except ImportError as e:
        print(f"SKIP: {e}")

    print("\n=== All tests passed ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual entry point for running the checks outside pytest.
    run_all_tests()
|
||||
324
artdag/sexp/test_stage_cache.py
Normal file
324
artdag/sexp/test_stage_cache.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Tests for stage cache layer.
|
||||
|
||||
Tests S-expression storage for stage results and analysis data.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .stage_cache import (
|
||||
StageCache,
|
||||
StageCacheEntry,
|
||||
StageOutput,
|
||||
AnalysisResult,
|
||||
save_analysis_result,
|
||||
load_analysis_result,
|
||||
)
|
||||
from .parser import parse, serialize
|
||||
|
||||
|
||||
class TestStageOutput:
    """Construction and serialization round-trips for StageOutput."""

    def test_stage_output_artifact(self):
        """An artifact output carries its cache id."""
        out = StageOutput(cache_id="abc123", output_type="artifact")
        assert out.cache_id == "abc123"
        assert out.output_type == "artifact"

    def test_stage_output_scalar(self):
        """A scalar output carries its value."""
        out = StageOutput(value=120.5, output_type="scalar")
        assert out.value == 120.5
        assert out.output_type == "scalar"

    def test_stage_output_to_sexp(self):
        """The serialized form mentions the cache id and type."""
        out = StageOutput(cache_id="abc123", output_type="artifact")
        text = serialize(out.to_sexp())
        for fragment in ("cache-id", "abc123", "type", "artifact"):
            assert fragment in text

    def test_stage_output_from_sexp(self):
        """Parsing recovers cache id and type from a plist form."""
        parsed = StageOutput.from_sexp(parse('(:cache-id "def456" :type :analysis)'))
        assert parsed.cache_id == "def456"
        assert parsed.output_type == "analysis"
|
||||
|
||||
|
||||
class TestStageCacheEntry:
    """Serialization round-trips for StageCacheEntry."""

    def test_stage_cache_entry_to_sexp(self):
        """Serialized entry names the stage, cache id, and outputs."""
        entry = StageCacheEntry(
            stage_name="analyze-a",
            cache_id="stage_abc123",
            outputs={
                "beats": StageOutput(cache_id="beats_def456", output_type="analysis"),
                "tempo": StageOutput(value=120.5, output_type="scalar"),
            },
            completed_at=1705678900.123,
        )
        text = serialize(entry.to_sexp())
        for fragment in ("stage-result", "analyze-a", "stage_abc123", "outputs", "beats"):
            assert fragment in text

    def test_stage_cache_entry_roundtrip(self):
        """to_string followed by from_string preserves the entry."""
        entry = StageCacheEntry(
            stage_name="analyze-b",
            cache_id="stage_xyz789",
            outputs={
                "segments": StageOutput(cache_id="seg_123", output_type="artifact"),
            },
            completed_at=1705678900.0,
        )
        loaded = StageCacheEntry.from_string(entry.to_string())
        assert loaded.stage_name == entry.stage_name
        assert loaded.cache_id == entry.cache_id
        assert "segments" in loaded.outputs
        assert loaded.outputs["segments"].cache_id == "seg_123"

    def test_stage_cache_entry_from_sexp(self):
        """Entries parse from a hand-written S-expression."""
        sexp_str = '''
        (stage-result
          :name "test-stage"
          :cache-id "cache123"
          :completed-at 1705678900.0
          :outputs ((beats :cache-id "beats123" :type :analysis)))
        '''
        entry = StageCacheEntry.from_string(sexp_str)
        assert entry.stage_name == "test-stage"
        assert entry.cache_id == "cache123"
        assert "beats" in entry.outputs
        assert entry.outputs["beats"].cache_id == "beats123"
|
||||
|
||||
|
||||
class TestStageCache:
    """File-level behavior of StageCache."""

    @staticmethod
    def _entry(name, cache_id, outputs=None):
        # Small fixture builder shared by the tests below.
        return StageCacheEntry(
            stage_name=name,
            cache_id=cache_id,
            outputs=outputs if outputs is not None else {},
        )

    def test_save_and_load_stage(self):
        """A saved entry can be loaded back intact."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            entry = self._entry(
                "analyze",
                "test_cache_id",
                {"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
            )

            saved_path = cache.save_stage(entry)
            assert saved_path.exists()
            assert saved_path.suffix == ".sexp"

            loaded = cache.load_stage("test_cache_id")
            assert loaded is not None
            assert loaded.stage_name == "analyze"
            assert "beats" in loaded.outputs

    def test_has_stage(self):
        """has_stage reflects whether an entry was saved."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)

            assert not cache.has_stage("nonexistent")
            cache.save_stage(self._entry("test", "exists_cache_id"))
            assert cache.has_stage("exists_cache_id")

    def test_delete_stage(self):
        """delete_stage removes the entry and reports success."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            cache.save_stage(self._entry("test", "to_delete"))

            assert cache.has_stage("to_delete")
            assert cache.delete_stage("to_delete") is True
            assert not cache.has_stage("to_delete")

    def test_list_stages(self):
        """list_stages returns every saved cache id."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            for i in range(3):
                cache.save_stage(self._entry(f"stage{i}", f"cache_{i}"))

            stages = cache.list_stages()
            assert len(stages) == 3
            for i in range(3):
                assert f"cache_{i}" in stages

    def test_clear(self):
        """clear deletes everything and reports the count."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            for i in range(3):
                cache.save_stage(self._entry(f"stage{i}", f"cache_{i}"))

            assert cache.clear() == 3
            assert len(cache.list_stages()) == 0

    def test_cache_file_extension(self):
        """Cache files carry the .sexp extension."""
        with tempfile.TemporaryDirectory() as tmpdir:
            assert StageCache(tmpdir).get_cache_path("test_id").suffix == ".sexp"

    def test_invalid_sexp_error_handling(self):
        """Corrupted cache files load as None rather than raising."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)

            # Write corrupted content directly where the cache would look.
            cache.get_cache_path("corrupted").write_text("this is not valid sexp )()(")

            # Should return None, not raise.
            assert cache.load_stage("corrupted") is None
|
||||
|
||||
|
||||
class TestAnalysisResult:
    """Serialization behavior of AnalysisResult."""

    def test_analysis_result_to_sexp(self):
        """Serialized form contains the analyzer name and data keys."""
        result = AnalysisResult(
            analyzer="beats",
            input_hash="input_abc123",
            data={
                "duration": 120.5,
                "tempo": 128.0,
                "times": [0.0, 0.468, 0.937, 1.406],
                "values": [0.8, 0.9, 0.7, 0.85],
            },
        )
        text = serialize(result.to_sexp())
        for fragment in ("analysis-result", "beats", "duration", "tempo", "times"):
            assert fragment in text

    def test_analysis_result_roundtrip(self):
        """to_string / from_string preserves analyzer, hash, and data."""
        original = AnalysisResult(
            analyzer="scenes",
            input_hash="video_xyz",
            data={
                "scene_count": 5,
                "scene_times": [0.0, 10.5, 25.0, 45.2, 60.0],
            },
        )
        loaded = AnalysisResult.from_string(original.to_string())
        assert loaded.analyzer == original.analyzer
        assert loaded.input_hash == original.input_hash
        assert loaded.data["scene_count"] == 5

    def test_save_and_load_analysis_result(self):
        """save_analysis_result / load_analysis_result round-trip via disk."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = AnalysisResult(
                analyzer="beats",
                input_hash="audio_123",
                data={
                    "tempo": 120.0,
                    "times": [0.0, 0.5, 1.0],
                },
            )

            path = save_analysis_result(tmpdir, "node_abc", result)
            assert path.exists()
            assert path.name == "analysis.sexp"

            loaded = load_analysis_result(tmpdir, "node_abc")
            assert loaded is not None
            assert loaded.analyzer == "beats"
            assert loaded.data["tempo"] == 120.0

    def test_analysis_result_kebab_case(self):
        """Data keys serialize as kebab-case and parse back as snake_case."""
        result = AnalysisResult(
            analyzer="test",
            input_hash="hash",
            data={
                "scene_count": 5,
                "beat_times": [1, 2, 3],
            },
        )

        text = result.to_string()
        # Kebab case in the serialized form...
        assert "scene-count" in text
        assert "beat-times" in text

        # ...and back to snake_case after parsing.
        loaded = AnalysisResult.from_string(text)
        assert "scene_count" in loaded.data
        assert "beat_times" in loaded.data
|
||||
286
artdag/sexp/test_stage_compiler.py
Normal file
286
artdag/sexp/test_stage_compiler.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Tests for stage compilation and scoping.
|
||||
|
||||
Tests the CompiledStage dataclass, stage form parsing,
|
||||
variable scoping, and dependency validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .parser import parse, Symbol, Keyword
|
||||
from .compiler import (
|
||||
compile_recipe,
|
||||
CompileError,
|
||||
CompiledStage,
|
||||
CompilerContext,
|
||||
_topological_sort_stages,
|
||||
)
|
||||
|
||||
|
||||
class TestStageCompilation:
    """Test stage form compilation.

    Each test feeds a small recipe S-expression through compile_recipe and
    inspects the resulting CompiledStage records.
    """

    def test_parse_stage_form_basic(self):
        """Stage parses correctly with name and outputs."""
        recipe = '''
        (recipe "test-stage"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        # One stage, carrying its declared output and at least one node.
        assert len(compiled.stages) == 1
        assert compiled.stages[0].name == "analyze"
        assert "beats" in compiled.stages[0].outputs
        assert len(compiled.stages[0].node_ids) > 0

    def test_parse_stage_with_requires(self):
        """Stage parses correctly with requires and inputs."""
        recipe = '''
        (recipe "test-requires"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [segments]
            (def segments (-> audio (segment :times beats)))
            (-> segments (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 2
        # The :process stage should record its dependency wiring.
        process_stage = next(s for s in compiled.stages if s.name == "process")
        assert process_stage.requires == ["analyze"]
        assert "beats" in process_stage.inputs
        assert "segments" in process_stage.outputs

    def test_stage_outputs_recorded(self):
        """Stage outputs are tracked in CompiledStage."""
        recipe = '''
        (recipe "test-outputs"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        # Both declared outputs appear in outputs and output_bindings.
        stage = compiled.stages[0]
        assert "beats" in stage.outputs
        assert "tempo" in stage.outputs
        assert "beats" in stage.output_bindings
        assert "tempo" in stage.output_bindings

    def test_stage_order_topological(self):
        """Stages are topologically sorted."""
        recipe = '''
        (recipe "test-order"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        # analyze should come before output
        assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output")
|
||||
|
||||
|
||||
class TestStageValidation:
    """Validation of stage dependencies, inputs, and declared outputs."""

    def test_stage_requires_validation(self):
        """Requiring a stage that was never defined is a compile error."""
        src = '''
        (recipe "test-bad-require"
          (def audio (source :path "test.mp3"))

          (stage :process
            :requires [:nonexistent]
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))

    def test_stage_inputs_validation(self):
        """An :inputs name must be produced by some required stage."""
        src = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [nonexistent]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(src))

    def test_undeclared_output_error(self):
        """Declaring an :outputs name with no matching def is a compile error."""
        src = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(src))

    def test_forward_reference_detection(self):
        """A stage may only require stages defined earlier in the recipe."""
        # Forward references are rejected: :b is defined after :a requires it.
        src = '''
        (recipe "test-forward"
          (def audio (source :path "test.mp3"))

          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio))

          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))
|
||||
|
||||
|
||||
class TestStageScoping:
    """Variable visibility rules across stage boundaries."""

    def test_pre_stage_bindings_accessible(self):
        """Bindings defined before any stage are visible inside every stage."""
        src = '''
        (recipe "test-pre-stage"
          (def audio (source :path "test.mp3"))
          (def video (source :path "test.mp4"))

          (stage :analyze-audio
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :analyze-video
            :outputs [scenes]
            (def scenes (-> video (analyze scenes)))
            (-> video (segment :times scenes) (sequence))))
        '''
        # Compiling succeeds: both stages can see the pre-stage sources.
        compiled = compile_recipe(parse(src))
        assert len(compiled.stages) == 2

    def test_stage_bindings_flow_through_requires(self):
        """An output of a required stage is usable via :inputs downstream."""
        src = '''
        (recipe "test-binding-flow"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [result]
            (def result (-> audio (segment :times beats)))
            (-> result (sequence))))
        '''
        # Compiling succeeds: beats flows from analyze into process.
        compiled = compile_recipe(parse(src))
        assert len(compiled.stages) == 2
|
||||
|
||||
|
||||
class TestTopologicalSort:
    """Behavior of _topological_sort_stages over hand-built stage graphs."""

    @staticmethod
    def _stage(name, requires, inputs, outputs, node_ids, output_bindings):
        """Build a CompiledStage with positional-style arguments for brevity."""
        return CompiledStage(
            name=name,
            requires=requires,
            inputs=inputs,
            outputs=outputs,
            node_ids=node_ids,
            output_bindings=output_bindings,
        )

    def test_empty_stages(self):
        """No stages sorts to an empty ordering."""
        assert _topological_sort_stages({}) == []

    def test_single_stage(self):
        """A lone stage sorts to a one-element ordering."""
        graph = {"a": self._stage("a", [], [], ["out"], ["n1"], {"out": "n1"})}
        assert _topological_sort_stages(graph) == ["a"]

    def test_linear_chain(self):
        """a -> b -> c is ordered as a, b, c."""
        graph = {
            "a": self._stage("a", [], [], ["x"], ["n1"], {"x": "n1"}),
            "b": self._stage("b", ["a"], ["x"], ["y"], ["n2"], {"y": "n2"}),
            "c": self._stage("c", ["b"], ["y"], ["z"], ["n3"], {"z": "n3"}),
        }
        ordering = _topological_sort_stages(graph)
        assert ordering.index("a") < ordering.index("b") < ordering.index("c")

    def test_parallel_stages_same_level(self):
        """Independent stages may come out in either order, but both appear."""
        graph = {
            "a": self._stage("a", [], [], ["x"], ["n1"], {"x": "n1"}),
            "b": self._stage("b", [], [], ["y"], ["n2"], {"y": "n2"}),
        }
        assert set(_topological_sort_stages(graph)) == {"a", "b"}

    def test_diamond_dependency(self):
        """Diamond a -> (b, c) -> d keeps a first and d last."""
        graph = {
            "a": self._stage("a", [], [], ["x"], ["n1"], {"x": "n1"}),
            "b": self._stage("b", ["a"], ["x"], ["y"], ["n2"], {"y": "n2"}),
            "c": self._stage("c", ["a"], ["x"], ["z"], ["n3"], {"z": "n3"}),
            "d": self._stage("d", ["b", "c"], ["y", "z"], ["out"],
                             ["n4"], {"out": "n4"}),
        }
        ordering = _topological_sort_stages(graph)
        assert ordering[0] == "a"
        assert ordering[-1] == "d"
        # The middle layer precedes the join.
        assert ordering.index("b") < ordering.index("d")
        assert ordering.index("c") < ordering.index("d")
|
||||
# ---- new file: artdag/sexp/test_stage_integration.py (739 lines) ----
|
||||
"""
|
||||
End-to-end integration tests for staged recipes.
|
||||
|
||||
Tests the complete flow: compile -> plan -> execute
|
||||
for recipes with stages.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .parser import parse, serialize
|
||||
from .compiler import compile_recipe, CompileError
|
||||
from .planner import ExecutionPlanSexp, StagePlan
|
||||
from .stage_cache import StageCache, StageCacheEntry, StageOutput
|
||||
from .scheduler import StagePlanScheduler, StagePlanResult
|
||||
|
||||
|
||||
class TestSimpleTwoStageRecipe:
    """Compile-level checks for a minimal two-stage recipe."""

    def test_compile_two_stage_recipe(self):
        """Both stages compile with the expected names, deps, and data flow."""
        src = '''
        (recipe "test-two-stages"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(src))

        assert compiled.stage_order == ["analyze", "output"]
        assert len(compiled.stages) == 2

        first, second = compiled.stages
        assert first.name == "analyze"
        assert "beats" in first.outputs

        assert second.name == "output"
        assert second.requires == ["analyze"]
        assert "beats" in second.inputs
|
||||
|
||||
|
||||
class TestParallelAnalysisStages:
    """Recipes whose analysis stages have no mutual dependencies."""

    def test_compile_parallel_stages(self):
        """Two independent analysis stages feed a single combining stage."""
        src = '''
        (recipe "test-parallel"
          (def audio-a (source :path "a.mp3"))
          (def audio-b (source :path "b.mp3"))

          (stage :analyze-a
            :outputs [beats-a]
            (def beats-a (-> audio-a (analyze beats))))

          (stage :analyze-b
            :outputs [beats-b]
            (def beats-b (-> audio-b (analyze beats))))

          (stage :combine
            :requires [:analyze-a :analyze-b]
            :inputs [beats-a beats-b]
            (-> audio-a (segment :times beats-a) (sequence))))
        '''
        compiled = compile_recipe(parse(src))

        assert len(compiled.stages) == 3

        by_name = {s.name: s for s in compiled.stages}
        # The two analysis stages are roots (no requires) and may run in parallel.
        assert by_name["analyze-a"].requires == []
        assert by_name["analyze-b"].requires == []
        # The combiner depends on both of them.
        assert set(by_name["combine"].requires) == {"analyze-a", "analyze-b"}
|
||||
|
||||
|
||||
class TestDiamondDependency:
    """Diamond-shaped stage graph: A -> B, A -> C, B+C -> D."""

    def test_compile_diamond_pattern(self):
        """All four stages compile with diamond dependencies and valid order."""
        src = '''
        (recipe "test-diamond"
          (def audio (source :path "test.mp3"))

          (stage :source-stage
            :outputs [audio-ref]
            (def audio-ref audio))

          (stage :branch-b
            :requires [:source-stage]
            :inputs [audio-ref]
            :outputs [result-b]
            (def result-b (-> audio-ref (effect gain :amount 0.5))))

          (stage :branch-c
            :requires [:source-stage]
            :inputs [audio-ref]
            :outputs [result-c]
            (def result-c (-> audio-ref (effect gain :amount 0.8))))

          (stage :merge
            :requires [:branch-b :branch-c]
            :inputs [result-b result-c]
            (-> result-b (blend result-c :mode "mix"))))
        '''
        compiled = compile_recipe(parse(src))

        assert len(compiled.stages) == 4

        by_name = {s.name: s for s in compiled.stages}
        assert by_name["source-stage"].requires == []
        assert by_name["branch-b"].requires == ["source-stage"]
        assert by_name["branch-c"].requires == ["source-stage"]
        assert set(by_name["merge"].requires) == {"branch-b", "branch-c"}

        # Ordering: the root precedes both branches, and both precede the join.
        pos = compiled.stage_order.index
        assert pos("source-stage") < pos("branch-b")
        assert pos("source-stage") < pos("branch-c")
        assert pos("branch-b") < pos("merge")
        assert pos("branch-c") < pos("merge")
|
||||
|
||||
|
||||
class TestStageReuseOnRerun:
    """Cached stage results survive across recipe runs."""

    def test_stage_reuse(self):
        """A stage saved on the first run is found and loaded on the second."""
        with tempfile.TemporaryDirectory() as cache_dir:
            cache = StageCache(cache_dir)

            # First run: persist one analysis stage under a known cache id.
            cache.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="fixed_cache_id",
                outputs={"beats": StageOutput(cache_id="beats_out",
                                              output_type="analysis")},
            ))

            assert cache.has_stage("fixed_cache_id")

            # Second run: the same entry round-trips intact.
            reloaded = cache.load_stage("fixed_cache_id")
            assert reloaded is not None
            assert reloaded.stage_name == "analyze"
|
||||
|
||||
|
||||
class TestExplicitDataFlowEndToEnd:
    """Analysis results must travel via explicit :inputs/:outputs lists."""

    def test_data_flow_declaration(self):
        """Declared outputs of one stage match the declared inputs of the next."""
        src = '''
        (recipe "test-data-flow"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo))))

          (stage :process
            :requires [:analyze]
            :inputs [beats tempo]
            :outputs [result]
            (def result (-> audio (segment :times beats) (effect speed :factor tempo)))
            (-> result (sequence))))
        '''
        compiled = compile_recipe(parse(src))

        by_name = {s.name: s for s in compiled.stages}
        analyze = by_name["analyze"]
        process = by_name["process"]

        # Producer side: both values declared and bound to nodes.
        assert set(analyze.outputs) == {"beats", "tempo"}
        for name in ("beats", "tempo"):
            assert name in analyze.output_bindings

        # Consumer side: same names requested, with the dependency declared.
        assert set(process.inputs) == {"beats", "tempo"}
        assert process.requires == ["analyze"]
|
||||
|
||||
|
||||
class TestRecipeFixtures:
    """Compile checks driven by reusable recipe fixtures.

    Note: the fixtures are deliberately NOT prefixed with ``test_`` — that
    prefix is reserved by pytest for collected test functions, and using it
    on fixtures is confusing and against pytest naming convention.
    """

    @pytest.fixture
    def recipe_two_stages(self):
        """A minimal two-stage recipe: analysis feeding an output stage."""
        return '''
        (recipe "test-two-stages"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''

    @pytest.fixture
    def recipe_parallel_stages(self):
        """Two independent analysis stages joined by a combine stage."""
        return '''
        (recipe "test-parallel"
          (def audio-a (source :path "a.mp3"))
          (def audio-b (source :path "b.mp3"))

          (stage :analyze-a
            :outputs [beats-a]
            (def beats-a (-> audio-a (analyze beats))))

          (stage :analyze-b
            :outputs [beats-b]
            (def beats-b (-> audio-b (analyze beats))))

          (stage :combine
            :requires [:analyze-a :analyze-b]
            :inputs [beats-a beats-b]
            (-> audio-a (blend audio-b :mode "mix"))))
        '''

    def test_two_stages_fixture(self, recipe_two_stages):
        """Two-stage recipe fixture compiles into exactly two stages."""
        compiled = compile_recipe(parse(recipe_two_stages))
        assert len(compiled.stages) == 2

    def test_parallel_stages_fixture(self, recipe_parallel_stages):
        """Parallel-stages recipe fixture compiles into exactly three stages."""
        compiled = compile_recipe(parse(recipe_parallel_stages))
        assert len(compiled.stages) == 3
|
||||
|
||||
|
||||
class TestStageValidationErrors:
    """Invalid staged recipes are rejected with descriptive CompileErrors."""

    def test_missing_output_declaration(self):
        """An :outputs name with no matching def in the body is an error."""
        src = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(src))

    def test_input_without_requires(self):
        """An :inputs name is an error when no required stage produces it."""
        src = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires []
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(src))

    def test_forward_reference(self):
        """Requiring a stage defined later in the recipe is an error."""
        src = '''
        (recipe "test-forward-ref"
          (def audio (source :path "test.mp3"))

          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio)
            audio)

          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))
|
||||
|
||||
|
||||
class TestBeatSyncDemoRecipe:
    """End-to-end compile checks against the beat-sync demo recipe."""

    BEAT_SYNC_RECIPE = '''
    ;; Simple staged recipe demo
    (recipe "beat-sync-demo"
      :version "1.0"
      :description "Demo of staged beat-sync workflow"

      ;; Pre-stage definitions (available to all stages)
      (def audio (source :path "input.mp3"))

      ;; Stage 1: Analysis (expensive, cached)
      (stage :analyze
        :outputs [beats tempo]
        (def beats (-> audio (analyze beats)))
        (def tempo (-> audio (analyze tempo))))

      ;; Stage 2: Processing (uses analysis results)
      (stage :process
        :requires [:analyze]
        :inputs [beats]
        :outputs [segments]
        (def segments (-> audio (segment :times beats)))
        (-> segments (sequence))))
    '''

    def _compiled(self):
        """Compile the demo recipe fresh for each assertion set."""
        return compile_recipe(parse(self.BEAT_SYNC_RECIPE))

    def test_compile_beat_sync_recipe(self):
        """Recipe-level metadata survives compilation."""
        compiled = self._compiled()
        assert compiled.name == "beat-sync-demo"
        assert compiled.version == "1.0"
        assert compiled.description == "Demo of staged beat-sync workflow"

    def test_beat_sync_stage_count(self):
        """Exactly two stages, ordered analyze then process."""
        compiled = self._compiled()
        assert len(compiled.stages) == 2
        assert compiled.stage_order == ["analyze", "process"]

    def test_beat_sync_analyze_stage(self):
        """The analyze stage is a root and exposes beats and tempo."""
        compiled = self._compiled()
        analyze = {s.name: s for s in compiled.stages}["analyze"]
        assert analyze.requires == []
        assert analyze.inputs == []
        assert set(analyze.outputs) == {"beats", "tempo"}
        for name in ("beats", "tempo"):
            assert name in analyze.output_bindings

    def test_beat_sync_process_stage(self):
        """The process stage depends on analyze and consumes beats."""
        compiled = self._compiled()
        process = {s.name: s for s in compiled.stages}["process"]
        assert process.requires == ["analyze"]
        assert "beats" in process.inputs
        assert "segments" in process.outputs

    def test_beat_sync_node_count(self):
        """1 SOURCE + 2 ANALYZE + 1 SEGMENT + 1 SEQUENCE = 5 nodes."""
        assert len(self._compiled().nodes) == 5

    def test_beat_sync_node_types(self):
        """Node types appear with their expected multiplicities."""
        kinds = [node["type"] for node in self._compiled().nodes]
        expected = {"SOURCE": 1, "ANALYZE": 2, "SEGMENT": 1, "SEQUENCE": 1}
        for kind, count in expected.items():
            assert kinds.count(kind) == count

    def test_beat_sync_output_is_sequence(self):
        """The designated output node is the final SEQUENCE."""
        compiled = self._compiled()
        terminal = next(n for n in compiled.nodes
                        if n["id"] == compiled.output_node_id)
        assert terminal["type"] == "SEQUENCE"
|
||||
|
||||
|
||||
class TestAsciiArtStagedRecipe:
    """Compile checks for the three-stage ASCII art example recipe."""

    ASCII_ART_STAGED_RECIPE = '''
    ;; ASCII art effect with staged execution
    (recipe "ascii_art_staged"
      :version "1.0"
      :description "ASCII art effect with staged execution"
      :encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30)

      ;; Registry
      (effect ascii_art :path "sexp_effects/effects/ascii_art.sexp")
      (analyzer energy :path "../artdag-analyzers/energy/analyzer.py")

      ;; Pre-stage definitions
      (def color_mode "color")
      (def video (source :path "monday.webm"))
      (def audio (source :path "dizzy.mp3"))

      ;; Stage 1: Analysis
      (stage :analyze
        :outputs [energy-data]
        (def audio-clip (-> audio (segment :start 60 :duration 10)))
        (def energy-data (-> audio-clip (analyze energy))))

      ;; Stage 2: Process
      (stage :process
        :requires [:analyze]
        :inputs [energy-data]
        :outputs [result audio-clip]
        (def clip (-> video (segment :start 0 :duration 10)))
        (def audio-clip (-> audio (segment :start 60 :duration 10)))
        (def result (-> clip
                        (effect ascii_art
                          :char_size (bind energy-data values :range [2 32])
                          :color_mode color_mode))))

      ;; Stage 3: Output
      (stage :output
        :requires [:process]
        :inputs [result audio-clip]
        (mux result audio-clip)))
    '''

    def _compiled(self):
        """Compile the ASCII art recipe fresh for each assertion set."""
        return compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

    def test_compile_ascii_art_staged(self):
        """Recipe metadata survives compilation."""
        compiled = self._compiled()
        assert compiled.name == "ascii_art_staged"
        assert compiled.version == "1.0"

    def test_ascii_art_stage_count(self):
        """Three stages, ordered analyze -> process -> output."""
        compiled = self._compiled()
        assert len(compiled.stages) == 3
        assert compiled.stage_order == ["analyze", "process", "output"]

    def test_ascii_art_analyze_stage(self):
        """The analyze stage is a root producing energy-data."""
        analyze = {s.name: s for s in self._compiled().stages}["analyze"]
        assert analyze.requires == []
        assert analyze.inputs == []
        assert "energy-data" in analyze.outputs

    def test_ascii_art_process_stage(self):
        """The process stage consumes energy-data and exposes its results."""
        process = {s.name: s for s in self._compiled().stages}["process"]
        assert process.requires == ["analyze"]
        assert "energy-data" in process.inputs
        for name in ("result", "audio-clip"):
            assert name in process.outputs

    def test_ascii_art_output_stage(self):
        """The output stage depends on process and muxes its outputs."""
        output = {s.name: s for s in self._compiled().stages}["output"]
        assert output.requires == ["process"]
        for name in ("result", "audio-clip"):
            assert name in output.inputs

    def test_ascii_art_node_count(self):
        """At least 2 SOURCE + 2 SEGMENT + 1 ANALYZE + 1 EFFECT + 1 MUX nodes."""
        assert len(self._compiled().nodes) >= 7

    def test_ascii_art_has_mux_output(self):
        """The designated output node is the final MUX."""
        compiled = self._compiled()
        terminal = next(n for n in compiled.nodes
                        if n["id"] == compiled.output_node_id)
        assert terminal["type"] == "MUX"
|
||||
|
||||
|
||||
class TestMixedStagedAndNonStagedRecipes:
    """Recipes without stages (and mixed forms) keep compiling."""

    def test_recipe_without_stages(self):
        """A stage-free recipe compiles with empty stage metadata."""
        src = '''
        (recipe "no-stages"
          (-> (source :path "test.mp3")
              (effect gain :amount 0.5)))
        '''
        compiled = compile_recipe(parse(src))

        assert compiled.stages == []
        assert compiled.stage_order == []
        # The DAG itself is still built.
        assert len(compiled.nodes) > 0

    def test_mixed_pre_stage_and_stages(self):
        """Pre-stage definitions remain usable inside a stage body."""
        src = '''
        (recipe "mixed"
          ;; Pre-stage definitions
          (def audio (source :path "test.mp3"))
          (def volume 0.8)

          ;; Stage using pre-stage definitions, ending with output expression
          (stage :process
            :outputs [result]
            (def result (-> audio (effect gain :amount volume)))
            result))
        '''
        compiled = compile_recipe(parse(src))

        assert len(compiled.stages) == 1
        # audio and volume were resolvable inside the stage.
        only_stage = compiled.stages[0]
        assert only_stage.name == "process"
        assert "result" in only_stage.outputs
|
||||
|
||||
|
||||
class TestEffectParamsBlock:
    """Parsing of the :params block in effect definitions.

    The effect loader is imported locally because the module header does not
    import it; pytest, tempfile, and Path are already imported at module
    level, so the redundant method-local re-imports were removed.
    """

    def test_parse_effect_with_params_block(self):
        """Parse effect with new :params syntax."""
        from .effect_loader import load_sexp_effect

        effect_code = '''
        (define-effect test_effect
          :params (
            (size :type int :default 10 :range [1 100] :desc "Size parameter")
            (color :type string :default "red" :desc "Color parameter")
            (enabled :type int :default 1 :range [0 1] :desc "Enable flag")
          )
          frame)
        '''
        name, process_fn, defaults, param_defs = load_sexp_effect(effect_code)

        assert name == "test_effect"
        assert len(param_defs) == 3
        assert defaults["size"] == 10
        assert defaults["color"] == "red"
        assert defaults["enabled"] == 1

        # Check ParamDef objects carry type, default, range, and description.
        size_param = param_defs[0]
        assert size_param.name == "size"
        assert size_param.param_type == "int"
        assert size_param.default == 10
        assert size_param.range_min == 1.0
        assert size_param.range_max == 100.0
        assert size_param.description == "Size parameter"

        color_param = param_defs[1]
        assert color_param.name == "color"
        assert color_param.param_type == "string"
        assert color_param.default == "red"

    def test_parse_effect_with_choices(self):
        """Parse effect with choices in :params."""
        from .effect_loader import load_sexp_effect

        effect_code = '''
        (define-effect mode_effect
          :params (
            (mode :type string :default "fast"
                  :choices [fast slow medium]
                  :desc "Processing mode")
          )
          frame)
        '''
        name, _, defaults, param_defs = load_sexp_effect(effect_code)

        assert name == "mode_effect"
        assert defaults["mode"] == "fast"

        mode_param = param_defs[0]
        assert mode_param.choices == ["fast", "slow", "medium"]

    def test_legacy_effect_syntax_rejected(self):
        """Legacy effect syntax should be rejected."""
        from .effect_loader import load_sexp_effect

        effect_code = '''
        (define-effect legacy_effect
          ((width 100)
           (height 200)
           (name "default"))
          frame)
        '''
        with pytest.raises(ValueError) as exc_info:
            load_sexp_effect(effect_code)

        # The error must point users at the replacement :params syntax.
        assert "Legacy parameter syntax" in str(exc_info.value)
        assert ":params" in str(exc_info.value)

    def test_effect_params_introspection(self):
        """Test that effect params are available for introspection."""
        from .effect_loader import load_sexp_effect_file

        # Create a temp effect file; delete=False so it can be re-opened
        # by path on platforms where an open NamedTemporaryFile is locked.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f:
            f.write('''
            (define-effect introspect_test
              :params (
                (alpha :type float :default 0.5 :range [0 1] :desc "Alpha value")
              )
              frame)
            ''')
            temp_path = Path(f.name)

        try:
            name, _, defaults, param_defs = load_sexp_effect_file(temp_path)
            assert name == "introspect_test"
            assert len(param_defs) == 1
            assert param_defs[0].name == "alpha"
            assert param_defs[0].param_type == "float"
        finally:
            # Always remove the temp file, even if assertions fail.
            temp_path.unlink()
|
||||
|
||||
|
||||
class TestConstructParamsBlock:
    """Parsing of the :params block in construct definitions."""

    def test_parse_construct_params_helper(self):
        """_parse_construct_params extracts names and defaults in order."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword

        raw_params = [
            [Symbol("duration"), Keyword("type"), Symbol("float"),
             Keyword("default"), 5.0, Keyword("desc"), "Duration in seconds"],
            [Symbol("count"), Keyword("type"), Symbol("int"),
             Keyword("default"), 10],
        ]

        names, defaults = _parse_construct_params(raw_params)

        assert names == ["duration", "count"]
        assert defaults == {"duration": 5.0, "count": 10}

    def test_construct_params_with_no_defaults(self):
        """A parameter without :default is recorded with a None default."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword

        raw_params = [
            [Symbol("required_param"), Keyword("type"), Symbol("string")],
            [Symbol("optional_param"), Keyword("type"), Symbol("int"),
             Keyword("default"), 42],
        ]

        names, defaults = _parse_construct_params(raw_params)

        assert names == ["required_param", "optional_param"]
        assert defaults["required_param"] is None
        assert defaults["optional_param"] == 42
|
||||
|
||||
|
||||
class TestParameterValidation:
    """Unknown parameters passed to an effect must be rejected at call time.

    numpy is imported locally because the module header does not import it;
    the redundant method-local ``import pytest`` lines were removed since
    pytest is already imported at module level.
    """

    def test_effect_rejects_unknown_params(self):
        """Effects should reject unknown parameters."""
        from .effect_loader import load_sexp_effect
        import numpy as np

        effect_code = '''
        (define-effect test_effect
          :params (
            (brightness :type int :default 0 :desc "Brightness")
          )
          frame)
        '''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)

        # Create a test frame.
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}

        # Valid param should work.
        result, _ = process_frame(frame, {"brightness": 10}, state)
        assert isinstance(result, np.ndarray)

        # Unknown param should raise, naming the bad key and the valid ones.
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"unknown_param": 42}, state)

        assert "Unknown parameter 'unknown_param'" in str(exc_info.value)
        assert "brightness" in str(exc_info.value)

    def test_effect_no_params_rejects_all(self):
        """Effects with no params should reject any parameter."""
        from .effect_loader import load_sexp_effect
        import numpy as np

        effect_code = '''
        (define-effect no_params_effect
          :params ()
          frame)
        '''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)

        # Create a test frame.
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}

        # Empty params should work.
        result, _ = process_frame(frame, {}, state)
        assert isinstance(result, np.ndarray)

        # Any param should raise; the message lists "(none)" as valid params.
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"any_param": 42}, state)

        assert "Unknown parameter 'any_param'" in str(exc_info.value)
        assert "(none)" in str(exc_info.value)
|
||||
228
artdag/sexp/test_stage_planner.py
Normal file
228
artdag/sexp/test_stage_planner.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Tests for stage-aware planning.
|
||||
|
||||
Tests stage topological sorting, level computation, cache ID computation,
|
||||
and plan metadata generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from .parser import parse
|
||||
from .compiler import compile_recipe, CompiledStage
|
||||
from .planner import (
|
||||
create_plan,
|
||||
StagePlan,
|
||||
_compute_stage_levels,
|
||||
_compute_stage_cache_id,
|
||||
)
|
||||
|
||||
|
||||
class TestStagePlanning:
    """Test stage-aware plan creation."""

    def test_stage_topological_sort_in_plan(self):
        """Stages sorted by dependencies in plan."""
        recipe = '''
        (recipe "test-sort"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        # create_plan needs recipe_dir for analysis, so check the ordering
        # on the compiled recipe instead.
        order = compiled.stage_order
        assert order.index("analyze") < order.index("output")

    def test_stage_level_computation(self):
        """Independent stages get same level."""
        stage_a = CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                                node_ids=["n1"], output_bindings={"x": "n1"})
        stage_b = CompiledStage(name="b", requires=[], inputs=[], outputs=["y"],
                                node_ids=["n2"], output_bindings={"y": "n2"})
        stage_c = CompiledStage(name="c", requires=["a", "b"],
                                inputs=["x", "y"], outputs=["z"],
                                node_ids=["n3"], output_bindings={"z": "n3"})

        levels = _compute_stage_levels([stage_a, stage_b, stage_c])

        # a and b are independent roots; c waits on both of them.
        assert levels["a"] == 0
        assert levels["b"] == 0
        assert levels["c"] == 1

    def test_stage_level_chain(self):
        """Chain stages get increasing levels."""
        chain = [
            CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                          node_ids=["n1"], output_bindings={"x": "n1"}),
            CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"],
                          node_ids=["n2"], output_bindings={"y": "n2"}),
            CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"],
                          node_ids=["n3"], output_bindings={"z": "n3"}),
        ]
        levels = _compute_stage_levels(chain)

        # A linear dependency chain yields strictly increasing levels.
        assert [levels[name] for name in ("a", "b", "c")] == [0, 1, 2]

    def test_stage_cache_id_deterministic(self):
        """Same stage = same cache ID."""
        stage = CompiledStage(
            name="analyze",
            requires=[],
            inputs=[],
            outputs=["beats"],
            node_ids=["abc123"],
            output_bindings={"beats": "abc123"},
        )

        # Computing the ID twice with identical inputs must be stable.
        first, second = (
            _compute_stage_cache_id(
                stage,
                stage_cache_ids={},
                node_cache_ids={"abc123": "nodeabc"},
                cluster_key=None,
            )
            for _ in range(2)
        )

        assert first == second

    def test_stage_cache_id_includes_requires(self):
        """Cache ID changes when required stage cache ID changes."""
        stage = CompiledStage(
            name="process",
            requires=["analyze"],
            inputs=["beats"],
            outputs=["result"],
            node_ids=["def456"],
            output_bindings={"result": "def456"},
        )

        def cache_id_for(upstream_id):
            # Only the required stage's cache ID varies between calls.
            return _compute_stage_cache_id(
                stage,
                stage_cache_ids={"analyze": upstream_id},
                node_cache_ids={"def456": "node_def"},
                cluster_key=None,
            )

        # Different required stage cache IDs should produce different cache IDs
        assert cache_id_for("req_cache_a") != cache_id_for("req_cache_b")

    def test_stage_cache_id_cluster_key(self):
        """Cache ID changes with cluster key."""
        stage = CompiledStage(
            name="analyze",
            requires=[],
            inputs=[],
            outputs=["beats"],
            node_ids=["abc123"],
            output_bindings={"beats": "abc123"},
        )

        def cache_id_for(key):
            return _compute_stage_cache_id(
                stage,
                stage_cache_ids={},
                node_cache_ids={"abc123": "nodeabc"},
                cluster_key=key,
            )

        # Cluster key should change the cache ID
        assert cache_id_for(None) != cache_id_for("cluster123")
|
||||
|
||||
|
||||
class TestStagePlanMetadata:
    """Test stage metadata in execution plans."""

    def test_plan_without_stages(self):
        """Plan without stages has empty stage fields."""
        recipe = '''
        (recipe "no-stages"
          (-> (source :path "test.mp3") (effect gain :amount 0.5)))
        '''
        parsed = parse(recipe)
        compiled = compile_recipe(parsed)

        # A recipe with no (stage ...) forms compiles to empty stage metadata.
        assert (compiled.stages, compiled.stage_order) == ([], [])
|
||||
|
||||
|
||||
class TestStagePlanDataclass:
    """Test StagePlan dataclass."""

    def test_stage_plan_creation(self):
        """StagePlan can be created with all fields."""
        from .planner import PlanStep

        analyze_step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={"analyzer": "beats"},
            inputs=["input1"],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_123",
        )

        plan = StagePlan(
            stage_name="analyze",
            cache_id="stage_cache_123",
            steps=[analyze_step],
            requires=[],
            output_bindings={"beats": "cache123"},
            level=0,
        )

        # All constructor fields round-trip through the dataclass.
        assert plan.stage_name == "analyze"
        assert plan.cache_id == "stage_cache_123"
        assert len(plan.steps) == 1
        assert plan.level == 0
|
||||
|
||||
|
||||
class TestExplicitDataRouting:
    """Test that plan includes explicit data routing."""

    def test_plan_step_includes_stage_info(self):
        """PlanStep includes stage and stage_cache_id."""
        from .planner import PlanStep
        from .parser import serialize

        step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={},
            inputs=[],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_abc",
        )

        # Serialize the step's S-expression and scan it for the stage markers.
        rendered = serialize(step.to_sexp())

        for fragment in ("stage", "analyze", "stage-cache-id"):
            assert fragment in rendered
|
||||
323
artdag/sexp/test_stage_scheduler.py
Normal file
323
artdag/sexp/test_stage_scheduler.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Tests for stage-aware scheduler.
|
||||
|
||||
Tests stage cache hit/miss, stage execution ordering,
|
||||
and parallel stage support.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from .scheduler import (
|
||||
StagePlanScheduler,
|
||||
StageResult,
|
||||
StagePlanResult,
|
||||
create_stage_scheduler,
|
||||
schedule_staged_plan,
|
||||
)
|
||||
from .planner import ExecutionPlanSexp, PlanStep, StagePlan
|
||||
from .stage_cache import StageCache, StageCacheEntry, StageOutput
|
||||
|
||||
|
||||
class TestStagePlanScheduler:
    """Test stage-aware scheduling."""

    def test_plan_without_stages_uses_regular_scheduling(self):
        """Plans without stages fall back to regular scheduling."""
        plan = ExecutionPlanSexp(
            plan_id="test_plan",
            recipe_id="test_recipe",
            recipe_hash="abc123",
            steps=[],
            output_step_id="output",
            stage_plans=[],  # no stages -> regular PlanScheduler path
        )

        scheduler = StagePlanScheduler()
        # The fallback uses PlanScheduler internally; without Celery it
        # simply returns a completed result.
        outcome = scheduler.schedule(plan)

        assert isinstance(outcome, StagePlanResult)

    def test_stage_cache_hit_skips_execution(self):
        """Cached stage not re-executed."""
        with tempfile.TemporaryDirectory() as cache_dir:
            stage_cache = StageCache(cache_dir)

            # Pre-populate the cache so the stage counts as a hit.
            stage_cache.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="stage_cache_123",
                outputs={"beats": StageOutput(cache_id="beats_out",
                                              output_type="analysis")},
            ))

            analyze_step = PlanStep(
                step_id="step1",
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="step_cache",
                level=0,
                stage="analyze",
                stage_cache_id="stage_cache_123",
            )
            analyze_plan = StagePlan(
                stage_name="analyze",
                cache_id="stage_cache_123",
                steps=[analyze_step],
                requires=[],
                output_bindings={"beats": "beats_out"},
                level=0,
            )
            plan = ExecutionPlanSexp(
                plan_id="test_plan",
                recipe_id="test_recipe",
                recipe_hash="abc123",
                steps=[analyze_step],
                output_step_id="step1",
                stage_plans=[analyze_plan],
                stage_order=["analyze"],
                stage_levels={"analyze": 0},
                stage_cache_ids={"analyze": "stage_cache_123"},
            )

            result = StagePlanScheduler(stage_cache=stage_cache).schedule(plan)

            # The pre-cached stage is counted as cached, not completed.
            assert result.stages_cached == 1
            assert result.stages_completed == 0

    def test_stage_inputs_loaded_from_cache(self):
        """Stage receives inputs from required stage cache."""
        with tempfile.TemporaryDirectory() as cache_dir:
            stage_cache = StageCache(cache_dir)

            # Pre-populate the upstream stage's cache entry.
            stage_cache.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="upstream_cache",
                outputs={"beats": StageOutput(cache_id="beats_data",
                                              output_type="analysis")},
            ))

            # One step per stage: analyze feeds process.
            analyze_step = PlanStep(
                step_id="analyze_step",
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="analyze_cache",
                level=0,
                stage="analyze",
                stage_cache_id="upstream_cache",
            )
            process_step = PlanStep(
                step_id="process_step",
                node_type="SEGMENT",
                config={},
                inputs=["analyze_step"],
                cache_id="process_cache",
                level=1,
                stage="process",
                stage_cache_id="downstream_cache",
            )

            analyze_plan = StagePlan(
                stage_name="analyze",
                cache_id="upstream_cache",
                steps=[analyze_step],
                requires=[],
                output_bindings={"beats": "beats_data"},
                level=0,
            )
            process_plan = StagePlan(
                stage_name="process",
                cache_id="downstream_cache",
                steps=[process_step],
                requires=["analyze"],
                output_bindings={"result": "process_cache"},
                level=1,
            )

            plan = ExecutionPlanSexp(
                plan_id="test_plan",
                recipe_id="test_recipe",
                recipe_hash="abc123",
                steps=[analyze_step, process_step],
                output_step_id="process_step",
                stage_plans=[analyze_plan, process_plan],
                stage_order=["analyze", "process"],
                stage_levels={"analyze": 0, "process": 1},
                stage_cache_ids={"analyze": "upstream_cache",
                                 "process": "downstream_cache"},
            )

            result = StagePlanScheduler(stage_cache=stage_cache).schedule(plan)

            # Upstream should be cached, downstream executed
            assert result.stages_cached == 1
            assert "analyze" in result.stage_results
            assert result.stage_results["analyze"].status == "cached"

    def test_parallel_stages_same_level(self):
        """Stages at same level can run in parallel."""
        steps = {}
        stage_plans = []
        # Build two independent single-step stages that differ only in suffix.
        for suffix in ("a", "b"):
            step = PlanStep(
                step_id="step_" + suffix,
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="cache_" + suffix,
                level=0,
                stage="analyze-" + suffix,
                stage_cache_id="stage_" + suffix,
            )
            steps[suffix] = step
            stage_plans.append(StagePlan(
                stage_name="analyze-" + suffix,
                cache_id="stage_" + suffix,
                steps=[step],
                requires=[],
                output_bindings={"beats-" + suffix: "cache_" + suffix},
                level=0,
            ))

        plan = ExecutionPlanSexp(
            plan_id="test_plan",
            recipe_id="test_recipe",
            recipe_hash="abc123",
            steps=[steps["a"], steps["b"]],
            output_step_id="step_b",
            stage_plans=stage_plans,
            stage_order=["analyze-a", "analyze-b"],
            stage_levels={"analyze-a": 0, "analyze-b": 0},
            stage_cache_ids={"analyze-a": "stage_a", "analyze-b": "stage_b"},
        )

        by_level = StagePlanScheduler()._group_stages_by_level(plan.stage_plans)

        # Both independent stages should share level 0.
        assert len(by_level[0]) == 2

    def test_stage_outputs_cached_after_execution(self):
        """Stage outputs written to cache after completion."""
        with tempfile.TemporaryDirectory() as cache_dir:
            stage_cache = StageCache(cache_dir)

            step = PlanStep(
                step_id="step1",
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="step_cache",
                level=0,
                stage="analyze",
                stage_cache_id="new_stage_cache",
            )
            stage_plan = StagePlan(
                stage_name="analyze",
                cache_id="new_stage_cache",
                steps=[step],
                requires=[],
                output_bindings={"beats": "step_cache"},
                level=0,
            )
            plan = ExecutionPlanSexp(
                plan_id="test_plan",
                recipe_id="test_recipe",
                recipe_hash="abc123",
                steps=[step],
                output_step_id="step1",
                stage_plans=[stage_plan],
                stage_order=["analyze"],
                stage_levels={"analyze": 0},
                stage_cache_ids={"analyze": "new_stage_cache"},
            )

            StagePlanScheduler(stage_cache=stage_cache).schedule(plan)

            # Stage should now be cached
            assert stage_cache.has_stage("new_stage_cache")
|
||||
|
||||
|
||||
class TestStageResult:
    """Test StageResult dataclass."""

    def test_stage_result_creation(self):
        """StageResult can be created with all fields."""
        stage_result = StageResult(
            stage_name="test",
            cache_id="cache123",
            status="completed",
            step_results={},
            outputs={"out": "out_cache"},
        )

        # Fields round-trip unchanged through the dataclass.
        assert stage_result.stage_name == "test"
        assert stage_result.status == "completed"
        assert stage_result.outputs["out"] == "out_cache"
|
||||
|
||||
|
||||
class TestStagePlanResult:
    """Test StagePlanResult dataclass."""

    def test_stage_plan_result_creation(self):
        """StagePlanResult can be created with all fields."""
        plan_result = StagePlanResult(
            plan_id="plan123",
            status="completed",
            stages_completed=2,
            stages_cached=1,
            stages_failed=0,
        )

        # Fields round-trip unchanged through the dataclass.
        assert plan_result.plan_id == "plan123"
        assert plan_result.stages_completed == 2
        assert plan_result.stages_cached == 1
|
||||
|
||||
|
||||
class TestSchedulerFactory:
    """Test scheduler factory functions."""

    def test_create_stage_scheduler(self):
        """create_stage_scheduler returns StagePlanScheduler."""
        assert isinstance(create_stage_scheduler(), StagePlanScheduler)

    def test_create_stage_scheduler_with_cache(self):
        """create_stage_scheduler accepts stage_cache."""
        with tempfile.TemporaryDirectory() as cache_dir:
            cache = StageCache(cache_dir)
            scheduler = create_stage_scheduler(stage_cache=cache)
            # The exact cache instance must be attached, not a copy.
            assert scheduler.stage_cache is cache
|
||||
Reference in New Issue
Block a user