All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1740 lines
71 KiB
Python
1740 lines
71 KiB
Python
"""
|
|
Fully Generic Streaming S-expression Interpreter.
|
|
|
|
The interpreter knows NOTHING about video, audio, or any domain.
|
|
All domain logic comes from primitives loaded via (require-primitives ...).
|
|
|
|
Built-in forms:
|
|
- Control: if, cond, let, let*, lambda, ->
|
|
- Arithmetic: +, -, *, /, mod, map-range
|
|
- Comparison: <, >, =, <=, >=, and, or, not
|
|
- Data: dict, get, list, nth, len, quote
|
|
- Random: rand, rand-int, rand-range
|
|
- Scan: bind (access scan state)
|
|
|
|
Everything else comes from primitives or effects.
|
|
|
|
Context (ctx) is passed explicitly to frame evaluation:
|
|
- ctx.t: current time
|
|
- ctx.frame-num: current frame number
|
|
- ctx.fps: frames per second
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import time
|
|
import json
|
|
import hashlib
|
|
import math
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Any, Optional, Tuple, Callable
|
|
|
|
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
|
|
|
|
# JAX backend (optional - loaded on demand)
# Flipped to True only after _init_jax() successfully imports the compiler.
_JAX_AVAILABLE = False
# Cache of compiler entry points ({'JaxCompiler', 'compile_effect_file'});
# stays None until _init_jax() runs, which is how _init_jax detects first use.
_jax_compiler = None
|
|
|
def _init_jax():
    """Import the JAX compiler on first use and cache the result.

    Returns True when the JAX backend is usable, False otherwise.  After the
    first call the cached module dict short-circuits further import attempts.
    """
    global _JAX_AVAILABLE, _jax_compiler
    if _jax_compiler is not None:
        return _JAX_AVAILABLE
    try:
        from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
    except ImportError as e:
        _JAX_AVAILABLE = False
        print(f"JAX backend not available: {e}", file=sys.stderr)
    else:
        _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
        _JAX_AVAILABLE = True
        print("JAX backend initialized", file=sys.stderr)
    return _JAX_AVAILABLE
|
|
|
|
|
|
@dataclass
class Context:
    """Runtime context passed to frame evaluation."""
    # Current time in seconds within the stream.
    t: float = 0.0
    # Zero-based index of the frame being rendered.
    frame_num: int = 0
    # Output frame rate (frames per second).
    fps: float = 30.0
|
|
|
|
|
|
class DeferredEffectChain:
    """A lazily-evaluated chain of JAX effects.

    Instead of executing each effect immediately, effects accumulated through
    let bindings are recorded here so the whole chain can later be fused into
    a single JIT-compiled function when the result is actually needed.
    """
    __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')

    def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
        # Effect names in execution order, innermost first.
        self.effects = effects
        # One parameter dict per entry in `effects`.
        self.params_list = params_list
        # The concrete frame array the chain starts from.
        self.base_frame = base_frame
        self.t = t
        self.frame_num = frame_num

    def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
        """Return a new chain with `effect_name` appended as the outermost effect."""
        grown_effects = [*self.effects, effect_name]
        grown_params = [*self.params_list, params]
        return DeferredEffectChain(grown_effects, grown_params,
                                   self.base_frame, self.t, self.frame_num)

    @property
    def shape(self):
        """Shape of the base frame, available without forcing execution."""
        base = self.base_frame
        return base.shape if hasattr(base, 'shape') else None
|
|
|
|
|
|
class StreamInterpreter:
|
|
"""
|
|
Fully generic streaming sexp interpreter.
|
|
|
|
No domain-specific knowledge - just evaluates expressions
|
|
and calls primitives.
|
|
"""
|
|
|
|
def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
    """Parse the recipe at `sexp_path` and prepare interpreter state.

    Args:
        sexp_path: Path to the top-level .sexp recipe file.
        actor_id: Optional actor identifier used for friendly-name asset
            resolution.
        use_jax: Request JAX-accelerated effects; falls back to the pure
            interpreter (with a warning) when JAX cannot be initialized.
    """
    self.sexp_path = Path(sexp_path)
    self.sexp_dir = self.sexp_path.parent
    self.actor_id = actor_id  # For friendly name resolution

    # Parse the recipe into an AST up front; config is read from it below.
    text = self.sexp_path.read_text()
    self.ast = parse(text)

    self.config = self._parse_config()

    # Global environment for def bindings
    self.globals: Dict[str, Any] = {}

    # Scans
    self.scans: Dict[str, dict] = {}

    # Audio playback path (for syncing output)
    self.audio_playback: Optional[str] = None

    # Registries for external definitions
    self.primitives: Dict[str, Any] = {}
    self.effects: Dict[str, dict] = {}
    self.macros: Dict[str, dict] = {}

    # JAX backend for accelerated effect evaluation
    self.use_jax = use_jax
    self.jax_effects: Dict[str, Callable] = {}  # Cache of JAX-compiled effects
    self.jax_effect_paths: Dict[str, Path] = {}  # Track source paths for effects
    self.jax_fused_chains: Dict[str, Callable] = {}  # Cache of fused effect chains
    self.jax_batched_chains: Dict[str, Callable] = {}  # Cache of vmapped chains
    self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30"))  # Configurable via env
    if use_jax:
        if _init_jax():
            print("JAX acceleration enabled", file=sys.stderr)
        else:
            print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
            self.use_jax = False
    # Try multiple locations for primitive_libs
    possible_paths = [
        self.sexp_dir.parent / "sexp_effects" / "primitive_libs",  # recipes/ subdir
        self.sexp_dir / "sexp_effects" / "primitive_libs",  # app root
        Path(__file__).parent.parent / "sexp_effects" / "primitive_libs",  # relative to interpreter
    ]
    # First existing candidate wins; fall back to the first entry so later
    # "not found" errors at least name a sensible path.
    self.primitive_lib_dir = next((p for p in possible_paths if p.exists()), possible_paths[0])

    self.frame_pipeline = None

    # External config files (set before run())
    self.sources_config: Optional[Path] = None
    self.audio_config: Optional[Path] = None

    # Error tracking
    self.errors: List[str] = []

    # Callback for live streaming (called when IPFS playlist is updated)
    self.on_playlist_update: callable = None

    # Callback for progress updates (called periodically during rendering)
    # Signature: on_progress(percent: float, frame_num: int, total_frames: int)
    self.on_progress: callable = None

    # Callback for checkpoint saves (called at segment boundaries for resumability)
    # Signature: on_checkpoint(checkpoint: dict)
    # checkpoint contains: frame_num, t, scans
    self.on_checkpoint: callable = None

    # Frames per segment for checkpoint timing (default 4 seconds at 30fps = 120 frames)
    self._frames_per_segment: int = 120
|
|
|
|
def _resolve_name(self, name: str) -> Optional[Path]:
|
|
"""Resolve a friendly name to a file path using the naming service."""
|
|
try:
|
|
# Import here to avoid circular imports
|
|
from tasks.streaming import resolve_asset
|
|
path = resolve_asset(name, self.actor_id)
|
|
if path:
|
|
return path
|
|
except Exception as e:
|
|
print(f"Warning: failed to resolve name '{name}': {e}", file=sys.stderr)
|
|
return None
|
|
|
|
def _record_error(self, msg: str):
|
|
"""Record an error that occurred during evaluation."""
|
|
self.errors.append(msg)
|
|
print(f"ERROR: {msg}", file=sys.stderr)
|
|
|
|
def _maybe_to_numpy(self, val, for_gpu_primitive: bool = False):
|
|
"""Convert GPU frames/CuPy arrays to numpy for CPU primitives.
|
|
|
|
If for_gpu_primitive=True, preserve GPU data (CuPy arrays stay on GPU).
|
|
"""
|
|
if val is None:
|
|
return val
|
|
|
|
# For GPU primitives, keep data on GPU
|
|
if for_gpu_primitive:
|
|
# Handle GPUFrame - return the GPU array
|
|
if hasattr(val, 'gpu') and hasattr(val, 'is_on_gpu'):
|
|
if val.is_on_gpu:
|
|
return val.gpu
|
|
return val.cpu
|
|
# CuPy arrays pass through unchanged
|
|
if hasattr(val, '__cuda_array_interface__'):
|
|
return val
|
|
return val
|
|
|
|
# For CPU primitives, convert to numpy
|
|
# Handle GPUFrame objects (have .cpu property)
|
|
if hasattr(val, 'cpu'):
|
|
return val.cpu
|
|
# Handle CuPy arrays (have .get() method)
|
|
if hasattr(val, 'get') and callable(val.get):
|
|
return val.get()
|
|
return val
|
|
|
|
def _load_config_file(self, config_path):
|
|
"""Load a config file and process its definitions."""
|
|
config_path = Path(config_path) # Accept str or Path
|
|
if not config_path.exists():
|
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
|
|
|
text = config_path.read_text()
|
|
ast = parse_all(text)
|
|
|
|
for form in ast:
|
|
if not isinstance(form, list) or not form:
|
|
continue
|
|
if not isinstance(form[0], Symbol):
|
|
continue
|
|
|
|
cmd = form[0].name
|
|
|
|
if cmd == 'require-primitives':
|
|
lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
|
|
self._load_primitives(lib_name)
|
|
|
|
elif cmd == 'def':
|
|
name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
|
|
value = self._eval(form[2], self.globals)
|
|
self.globals[name] = value
|
|
print(f"Config: {name}", file=sys.stderr)
|
|
|
|
elif cmd == 'audio-playback':
|
|
# Path relative to working directory (consistent with other paths)
|
|
path = str(form[1]).strip('"')
|
|
self.audio_playback = str(Path(path).resolve())
|
|
print(f"Audio playback: {self.audio_playback}", file=sys.stderr)
|
|
|
|
def _parse_config(self) -> dict:
|
|
"""Parse config from (stream name :key val ...)."""
|
|
config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720}
|
|
if not self.ast or not isinstance(self.ast[0], Symbol):
|
|
return config
|
|
if self.ast[0].name != 'stream':
|
|
return config
|
|
|
|
i = 2
|
|
while i < len(self.ast):
|
|
if isinstance(self.ast[i], Keyword):
|
|
config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None
|
|
i += 2
|
|
elif isinstance(self.ast[i], list):
|
|
break
|
|
else:
|
|
i += 1
|
|
return config
|
|
|
|
def _load_primitives(self, lib_name: str):
    """Load primitives from a Python library file.

    Prefers GPU-accelerated versions (*_gpu.py) when available.
    Uses cached modules from sys.modules to ensure consistent state
    (e.g., same RNG instance for all interpreters).

    Primitives are discovered two ways: module-level `prim_*` functions, and
    an optional `PRIMITIVES` dict.  Either way they are registered under
    "<lib_name>:<dashed-name>" — the original lib name, even when the _gpu
    variant was the one loaded.

    Raises:
        FileNotFoundError: If no candidate library file exists.
    """
    import importlib.util

    # Try GPU version first, then fall back to CPU version
    lib_names_to_try = [f"{lib_name}_gpu", lib_name]

    lib_path = None
    actual_lib_name = lib_name

    # Search each candidate name across all known primitive_libs locations.
    for try_lib in lib_names_to_try:
        lib_paths = [
            self.primitive_lib_dir / f"{try_lib}.py",
            self.sexp_dir / "primitive_libs" / f"{try_lib}.py",
            self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{try_lib}.py",
        ]
        for p in lib_paths:
            if p.exists():
                lib_path = p
                actual_lib_name = try_lib
                break
        if lib_path:
            break

    if not lib_path:
        raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")

    # Use cached module if already imported to preserve state (e.g., RNG)
    # This is critical for deterministic random number sequences
    # Check multiple possible module keys (standard import paths and our cache)
    possible_keys = [
        f"sexp_effects.primitive_libs.{actual_lib_name}",
        f"sexp_primitives.{actual_lib_name}",
    ]

    module = None
    for key in possible_keys:
        if key in sys.modules:
            module = sys.modules[key]
            break

    if module is None:
        # Not cached anywhere: import directly from the file we found.
        spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Cache for future use under our key
        sys.modules[f"sexp_primitives.{actual_lib_name}"] = module

    # Check if this is a GPU-accelerated module
    is_gpu = actual_lib_name.endswith('_gpu')
    gpu_tag = " [GPU]" if is_gpu else ""

    count = 0
    # Convention 1: any module-level function named prim_* is a primitive.
    for name in dir(module):
        if name.startswith('prim_'):
            func = getattr(module, name)
            prim_name = name[5:]
            dash_name = prim_name.replace('_', '-')
            # Register with original lib_name namespace (geometry:rotate, not geometry_gpu:rotate)
            # Don't overwrite if already registered (allows pre-registration of overrides)
            key = f"{lib_name}:{dash_name}"
            if key not in self.primitives:
                self.primitives[key] = func
                count += 1

    # Convention 2: an explicit PRIMITIVES dict of {name: callable}.
    if hasattr(module, 'PRIMITIVES'):
        prims = getattr(module, 'PRIMITIVES')
        if isinstance(prims, dict):
            for name, func in prims.items():
                # Register with original lib_name namespace
                # Don't overwrite if already registered
                dash_name = name.replace('_', '-')
                key = f"{lib_name}:{dash_name}"
                if key not in self.primitives:
                    self.primitives[key] = func
                    count += 1

    print(f"Loaded primitives: {lib_name} ({count} functions){gpu_tag}", file=sys.stderr)
|
|
|
|
def _load_effect(self, effect_path: Path):
    """Load and register an effect from a .sexp file.

    Processes every top-level form in the file: require-primitives,
    define-effect, defmacro, scan, and nested effect/include forms (which
    recurse into this method).  Registered effects are also handed to the
    JAX compiler when the JAX backend is active.

    Raises:
        FileNotFoundError: If `effect_path` does not exist.
        RuntimeError: If a friendly-name reference cannot be resolved.
    """
    if not effect_path.exists():
        raise FileNotFoundError(f"Effect/include file not found: {effect_path}")

    text = effect_path.read_text()
    ast = parse_all(text)

    for form in ast:
        # Only non-empty lists headed by a Symbol are commands.
        if not isinstance(form, list) or not form:
            continue
        if not isinstance(form[0], Symbol):
            continue

        cmd = form[0].name

        if cmd == 'require-primitives':
            lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
            self._load_primitives(lib_name)

        elif cmd == 'define-effect':
            # (define-effect name :params ((p :default v ...) ...) body)
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = {}
            body = None
            i = 2
            while i < len(form):
                if isinstance(form[i], Keyword):
                    if form[i].name == 'params' and i + 1 < len(form):
                        # Each pdef is (pname :kw val ...); unknown keywords
                        # are stored verbatim in the param info dict.
                        for pdef in form[i + 1]:
                            if isinstance(pdef, list) and pdef:
                                pname = pdef[0].name if isinstance(pdef[0], Symbol) else str(pdef[0])
                                pinfo = {'default': 0}
                                j = 1
                                while j < len(pdef):
                                    if isinstance(pdef[j], Keyword) and j + 1 < len(pdef):
                                        pinfo[pdef[j].name] = pdef[j + 1]
                                        j += 2
                                    else:
                                        j += 1
                                params[pname] = pinfo
                    i += 2  # skip keyword + its value
                else:
                    # Last non-keyword form wins as the effect body.
                    body = form[i]
                    i += 1

            self.effects[name] = {'params': params, 'body': body}
            self.jax_effect_paths[name] = effect_path  # Track source for JAX compilation
            print(f"Effect: {name}", file=sys.stderr)

            # Try to compile with JAX if enabled
            if self.use_jax and _JAX_AVAILABLE:
                self._compile_jax_effect(name, effect_path)

        elif cmd == 'defmacro':
            # (defmacro name (params...) body)
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
            body = form[3]
            self.macros[name] = {'params': params, 'body': body}

        elif cmd == 'effect':
            # Handle (effect name :path "...") or (effect name :name "...") in included files
            i = 2
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        # Paths are relative to the file currently being loaded.
                        path = str(form[i + 1]).strip('"')
                        full = (effect_path.parent / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1

        elif cmd == 'include':
            # Handle (include :path "...") or (include :name "...") in included files
            i = 1
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        path = str(form[i + 1]).strip('"')
                        full = (effect_path.parent / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1

        elif cmd == 'scan':
            # Handle scans from included files
            # (scan name trigger :init expr :step expr)
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            trigger_expr = form[2]
            init_val, step_expr = {}, None
            i = 3
            while i < len(form):
                if isinstance(form[i], Keyword):
                    if form[i].name == 'init' and i + 1 < len(form):
                        # :init is evaluated immediately against the globals.
                        init_val = self._eval(form[i + 1], self.globals)
                    elif form[i].name == 'step' and i + 1 < len(form):
                        # :step stays unevaluated; it runs on each trigger.
                        step_expr = form[i + 1]
                    i += 2
                else:
                    i += 1

            self.scans[name] = {
                # Non-dict init values are wrapped as a single 'acc' slot.
                'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val},
                'init': init_val,
                'step': step_expr,
                'trigger': trigger_expr,
            }
            print(f"Scan: {name}", file=sys.stderr)
|
|
|
|
def _compile_jax_effect(self, name: str, effect_path: Path):
    """Compile the named effect into a JAX function and cache it.

    No-op when JAX is unavailable or the effect is already compiled.
    Compilation failures leave the effect to the interpreter; only
    unexpected errors are reported.
    """
    if not _JAX_AVAILABLE:
        return
    if name in self.jax_effects:
        return
    try:
        compiler = _jax_compiler['compile_effect_file']
        self.jax_effects[name] = compiler(str(effect_path))
        print(f" [JAX compiled: {name}]", file=sys.stderr)
    except Exception as e:
        # Silently fall back to interpreter for unsupported effects
        if 'Unknown operation' not in str(e):
            print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr)
|
|
|
|
def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]:
|
|
"""Apply a JAX-compiled effect to a frame."""
|
|
if name not in self.jax_effects:
|
|
return None
|
|
|
|
try:
|
|
jax_fn = self.jax_effects[name]
|
|
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
|
|
# JAX handles numpy and JAX arrays natively, no conversion needed
|
|
if hasattr(frame, 'cpu'):
|
|
frame = frame.cpu
|
|
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
|
frame = frame.get() # CuPy array -> numpy
|
|
|
|
# Get seed from config for deterministic random
|
|
seed = self.config.get('seed', 42)
|
|
|
|
# Call JAX function with parameters
|
|
# Return JAX array directly - don't block or convert per-effect
|
|
# Conversion to numpy happens once at frame write time
|
|
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
|
except Exception as e:
|
|
# Fall back to interpreter on error
|
|
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
|
return None
|
|
|
|
def _is_jax_effect_expr(self, expr) -> bool:
|
|
"""Check if an expression is a JAX-compiled effect call."""
|
|
if not isinstance(expr, list) or not expr:
|
|
return False
|
|
head = expr[0]
|
|
if not isinstance(head, Symbol):
|
|
return False
|
|
return head.name in self.jax_effects
|
|
|
|
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
    """
    Extract a chain of JAX effects from an expression.

    Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
    effect_names and params_list are in execution order (innermost first).

    For (effect1 (effect2 frame :p1 v1) :p2 v2):
    Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)

    Keyword arguments are evaluated eagerly against `env`; the base frame
    expression is returned unevaluated.
    """
    if not self._is_jax_effect_expr(expr):
        return None

    chain = []
    params_list = []
    current = expr

    # Walk inward: each iteration peels one effect call off the expression.
    while self._is_jax_effect_expr(current):
        head = current[0]
        effect_name = head.name
        args = current[1:]

        # Extract params for this effect
        effect = self.effects[effect_name]
        effect_params = {}
        # Start from declared defaults so unspecified params are filled in.
        for pname, pdef in effect['params'].items():
            effect_params[pname] = pdef.get('default', 0)

        # Find the frame argument (first positional) and other params
        frame_arg = None
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                pname = args[i].name
                if pname in effect['params'] and i + 1 < len(args):
                    effect_params[pname] = self._eval(args[i + 1], env)
                # Skip keyword + value regardless of whether it was recognized.
                i += 2
            else:
                if frame_arg is None:
                    frame_arg = args[i]  # First positional is frame
                i += 1

        chain.append(effect_name)
        params_list.append(effect_params)

        if frame_arg is None:
            return None  # No frame argument found

        # Check if frame_arg is another effect call
        if self._is_jax_effect_expr(frame_arg):
            current = frame_arg
        else:
            # End of chain - frame_arg is the base frame
            # Reverse to get innermost-first execution order
            chain.reverse()
            params_list.reverse()
            return (chain, params_list, frame_arg)

    # Unreachable in practice: the loop either recurses or returns.
    return None
|
|
|
|
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
|
|
"""Generate a cache key for an effect chain.
|
|
|
|
Includes static param values in the key since they affect compilation.
|
|
"""
|
|
parts = []
|
|
for name, params in zip(effect_names, params_list):
|
|
param_parts = []
|
|
for pname in sorted(params.keys()):
|
|
pval = params[pname]
|
|
# Include static values in key (strings, bools)
|
|
if isinstance(pval, (str, bool)):
|
|
param_parts.append(f"{pname}={pval}")
|
|
else:
|
|
param_parts.append(pname)
|
|
parts.append(f"{name}:{','.join(param_parts)}")
|
|
return '|'.join(parts)
|
|
|
|
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Compile a chain of effects into a single fused JAX function.

    Args:
        effect_names: List of effect names in order [innermost, ..., outermost]
        params_list: List of param dicts for each effect (used to detect static types)

    Returns:
        JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame
        or None when JAX is unavailable or compilation fails.
    """
    if not _JAX_AVAILABLE:
        return None

    try:
        import jax

        # Get the individual JAX functions
        jax_fns = [self.jax_effects[name] for name in effect_names]

        # Pre-extract param names and identify static params from actual values
        effect_param_names = []
        static_params = ['seed']  # seed is always static
        for i, (name, params) in enumerate(zip(effect_names, params_list)):
            param_names = list(params.keys())
            effect_param_names.append(param_names)
            # Check actual values to identify static types
            # (strings/bools cannot be traced, so JIT must specialize on them)
            for pname, pval in params.items():
                if isinstance(pval, (str, bool)):
                    static_params.append(f"_p{i}_{pname}")

        def fused_fn(frame, t, frame_num, seed, **kwargs):
            # Thread the frame through every effect in chain order; params
            # for effect i arrive in kwargs prefixed as _p{i}_<name>.
            result = frame
            for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
                # Extract params for this effect from kwargs
                effect_kwargs = {}
                for pname in param_names:
                    key = f"_p{i}_{pname}"
                    if key in kwargs:
                        effect_kwargs[pname] = kwargs[key]
                result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
            return result

        # JIT with static params (seed + any string/bool params)
        return jax.jit(fused_fn, static_argnames=static_params)
    except Exception as e:
        print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
        return None
|
|
|
|
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
|
|
"""Apply a chain of effects, using fused compilation if available."""
|
|
chain_key = self._get_chain_key(effect_names, params_list)
|
|
|
|
# Try to get or compile fused chain
|
|
if chain_key not in self.jax_fused_chains:
|
|
fused_fn = self._compile_effect_chain(effect_names, params_list)
|
|
self.jax_fused_chains[chain_key] = fused_fn
|
|
if fused_fn:
|
|
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
|
|
|
|
fused_fn = self.jax_fused_chains.get(chain_key)
|
|
|
|
if fused_fn is not None:
|
|
# Build kwargs with prefixed param names
|
|
kwargs = {}
|
|
for i, params in enumerate(params_list):
|
|
for pname, pval in params.items():
|
|
kwargs[f"_p{i}_{pname}"] = pval
|
|
|
|
seed = self.config.get('seed', 42)
|
|
|
|
# Handle GPU frames
|
|
if hasattr(frame, 'cpu'):
|
|
frame = frame.cpu
|
|
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
|
frame = frame.get()
|
|
|
|
try:
|
|
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
|
|
except Exception as e:
|
|
print(f"Fused chain error: {e}", file=sys.stderr)
|
|
|
|
# Fall back to sequential application
|
|
result = frame
|
|
for name, params in zip(effect_names, params_list):
|
|
result = self._apply_jax_effect(name, result, params, t, frame_num)
|
|
if result is None:
|
|
return None
|
|
return result
|
|
|
|
def _force_deferred(self, deferred: DeferredEffectChain):
    """Execute a deferred effect chain and return the actual array."""
    if not deferred.effects:
        # Nothing was chained; the base frame is already the result.
        return deferred.base_frame
    return self._apply_effect_chain(
        deferred.effects,
        deferred.params_list,
        deferred.base_frame,
        deferred.t,
        deferred.frame_num,
    )
|
|
|
|
def _maybe_force(self, value):
    """Force a deferred chain if needed, otherwise return as-is."""
    if isinstance(value, DeferredEffectChain):
        value = self._force_deferred(value)
    return value
|
|
|
|
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Compile a vmapped version of an effect chain for batch processing.

    Args:
        effect_names: List of effect names in order [innermost, ..., outermost]
        params_list: List of param dicts (used to detect static types)

    Returns:
        Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames
        Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar
        (Despite the name, the returned function is single-frame and JIT-ed;
        the caller vmaps it with the appropriate in_axes — see the comment
        at the return statement.)
    """
    if not _JAX_AVAILABLE:
        return None

    try:
        import jax
        import jax.numpy as jnp

        # Get the individual JAX functions
        jax_fns = [self.jax_effects[name] for name in effect_names]

        # Pre-extract param info
        effect_param_names = []
        static_params = set()
        for i, (name, params) in enumerate(zip(effect_names, params_list)):
            param_names = list(params.keys())
            effect_param_names.append(param_names)
            # Strings/bools cannot be traced, so they must be static.
            for pname, pval in params.items():
                if isinstance(pval, (str, bool)):
                    static_params.add(f"_p{i}_{pname}")

        # Single-frame function (will be vmapped)
        def single_frame_fn(frame, t, frame_num, seed, **kwargs):
            # Same threading scheme as the fused chain: params for effect i
            # arrive in kwargs prefixed as _p{i}_<name>.
            result = frame
            for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
                effect_kwargs = {}
                for pname in param_names:
                    key = f"_p{i}_{pname}"
                    if key in kwargs:
                        effect_kwargs[pname] = kwargs[key]
                result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
            return result

        # Return unbatched function - we'll vmap at call time with proper in_axes
        return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params))
    except Exception as e:
        print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
        return None
|
|
|
|
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
                         frames: list, ts: list, frame_nums: list) -> Optional[list]:
    """
    Apply an effect chain to a batch of frames using vmap.

    Args:
        effect_names: List of effect names
        params_list_batch: List of params_list for each frame in batch
        frames: List of input frames
        ts: List of time values
        frame_nums: List of frame numbers

    Returns:
        List of output frames, or None on failure (caller falls back to
        per-frame processing)
    """
    if not frames:
        return []

    # Use first frame's params for chain key (assume same structure)
    chain_key = self._get_chain_key(effect_names, params_list_batch[0])
    batch_key = f"batch:{chain_key}"

    # Compile batched version if needed (None is cached too, so failures
    # are not recompiled on every batch)
    if batch_key not in self.jax_batched_chains:
        batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
        self.jax_batched_chains[batch_key] = batched_fn
        if batched_fn:
            print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)

    batched_fn = self.jax_batched_chains.get(batch_key)

    if batched_fn is not None:
        try:
            import jax
            import jax.numpy as jnp

            # Stack frames into batch array (CuPy arrays are pulled to host
            # via .get() first)
            frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
            ts_array = jnp.array(ts)
            frame_nums_array = jnp.array(frame_nums)

            # Build kwargs - all numeric params as arrays for vmap
            kwargs = {}
            static_kwargs = {}  # Non-vmapped (strings, bools)

            # NOTE(review): `plist` and `j` are unused — the outer zip is only
            # used for its index i (one per effect in the chain), and the inner
            # list comprehensions re-derive the per-frame values.  Presumably
            # this was meant to iterate effects directly; confirm before
            # simplifying.
            for i, plist in enumerate(zip(*[p for p in params_list_batch])):
                for j, pname in enumerate(params_list_batch[0][i].keys()):
                    key = f"_p{i}_{pname}"
                    values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]

                    # Static vs batched is decided from the first frame's value.
                    first = values[0]
                    if isinstance(first, (str, bool)):
                        # Static params - not vmapped
                        static_kwargs[key] = first
                    elif isinstance(first, (int, float)):
                        # Always batch numeric params for simplicity
                        kwargs[key] = jnp.array(values)
                    elif hasattr(first, 'shape'):
                        kwargs[key] = jnp.stack(values)
                    else:
                        kwargs[key] = jnp.array(values)

            seed = self.config.get('seed', 42)

            # Create wrapper that unpacks the params dict
            def single_call(frame, t, frame_num, params_dict):
                return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)

            # vmap over frame, t, frame_num, and the params dict (as pytree)
            vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))

            # Stack kwargs into a dict of arrays (pytree with matching structure)
            results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)

            # Unstack results
            return [results[i] for i in range(len(frames))]
        except Exception as e:
            print(f"Batched chain error: {e}", file=sys.stderr)

    # Fall back to sequential
    return None
|
|
|
|
def _init(self):
    """Initialize from sexp - load primitives, effects, defs, scans.

    Walks the top-level forms of the parsed recipe in order and dispatches
    on the head symbol. External config files (sources/audio) are loaded
    first so they can pre-populate globals and audio settings; recipe forms
    then skip values the configs already provided.
    """
    # Set random seed for deterministic output
    seed = self.config.get('seed', 42)
    try:
        from sexp_effects.primitive_libs.core import set_random_seed
        set_random_seed(seed)
    except ImportError:
        # Core primitive lib is optional; proceed unseeded if absent.
        pass

    # Load external config files first (they can override recipe definitions)
    if self.sources_config:
        self._load_config_file(self.sources_config)
    if self.audio_config:
        self._load_config_file(self.audio_config)

    for form in self.ast:
        # Only (symbol ...) list forms are top-level directives.
        if not isinstance(form, list) or not form:
            continue
        if not isinstance(form[0], Symbol):
            continue

        cmd = form[0].name

        if cmd == 'require-primitives':
            lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
            self._load_primitives(lib_name)

        elif cmd == 'effect':
            # (effect name :path "..." / :name "...") - form[1] is the
            # declared effect name (informational); loading is driven by
            # the keyword arguments starting at index 2.
            self._load_effect_refs(form, 2, 'effect')

        elif cmd == 'include':
            # (include :path "..." / :name "...") - same keyword handling
            # as 'effect' but with no leading name element.
            self._load_effect_refs(form, 1, 'include')

        elif cmd == 'audio-playback':
            # (audio-playback "path") - set audio file for playback sync
            # Skip if already set by config file
            if self.audio_playback is None:
                path = str(form[1]).strip('"')
                # Try to resolve as friendly name first
                resolved = self._resolve_name(path)
                if resolved:
                    self.audio_playback = str(resolved)
                else:
                    # Fall back to relative path
                    self.audio_playback = str((self.sexp_dir / path).resolve())
                print(f"Audio playback: {self.audio_playback}", file=sys.stderr)

        elif cmd == 'def':
            # (def name expr) - evaluate and store in globals
            # Skip if already defined by config file
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            if name in self.globals:
                print(f"Def: {name} (from config, skipped)", file=sys.stderr)
                continue
            value = self._eval(form[2], self.globals)
            self.globals[name] = value
            print(f"Def: {name}", file=sys.stderr)

        elif cmd == 'defmacro':
            # (defmacro name (params...) body) - body stored unevaluated;
            # expanded function-style at call time in _eval.
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
            body = form[3]
            self.macros[name] = {'params': params, 'body': body}

        elif cmd == 'scan':
            # (scan name trigger :init expr :step expr) - a stateful
            # accumulator stepped each frame when its trigger is truthy.
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            trigger_expr = form[2]
            init_val, step_expr = {}, None
            i = 3
            while i < len(form):
                if isinstance(form[i], Keyword):
                    if form[i].name == 'init' and i + 1 < len(form):
                        # Init expressions are evaluated once, now.
                        init_val = self._eval(form[i + 1], self.globals)
                    elif form[i].name == 'step' and i + 1 < len(form):
                        # Step expression stays unevaluated until triggered.
                        step_expr = form[i + 1]
                    i += 2
                else:
                    i += 1

            self.scans[name] = {
                # Non-dict init values are wrapped under the 'acc' key.
                'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val},
                'init': init_val,
                'step': step_expr,
                'trigger': trigger_expr,
            }
            print(f"Scan: {name}", file=sys.stderr)

        elif cmd == 'frame':
            # (frame expr) - the per-frame pipeline expression.
            self.frame_pipeline = form[1] if len(form) > 1 else None

def _load_effect_refs(self, form, start: int, label: str):
    """Load effects referenced via :path/:name keyword args of a form.

    Shared by the (effect ...) and (include ...) top-level directives,
    which previously duplicated this keyword-walking loop.

    Args:
        form: The full top-level form.
        start: Index of the first argument to inspect.
        label: 'effect' or 'include'; used only in error messages.

    Raises:
        RuntimeError: If a :name reference cannot be resolved.
    """
    i = start
    while i < len(form):
        if isinstance(form[i], Keyword):
            kw = form[i].name
            if kw == 'path':
                # Paths are resolved relative to the recipe file.
                path = str(form[i + 1]).strip('"')
                full = (self.sexp_dir / path).resolve()
                self._load_effect(full)
                i += 2
            elif kw == 'name':
                # Resolve friendly name to path
                fname = str(form[i + 1]).strip('"')
                resolved = self._resolve_name(fname)
                if resolved:
                    self._load_effect(resolved)
                else:
                    raise RuntimeError(f"Could not resolve {label} name '{fname}' - make sure it's uploaded and you're logged in")
                i += 2
            else:
                # Unrecognized keyword: skip just this token.
                i += 1
        else:
            i += 1
|
|
|
|
def _eval(self, expr, env: dict) -> Any:
    """Evaluate an expression.

    Dispatch order matters and is load-bearing: self-evaluating literals
    first, then Symbol lookup (built-in constants -> env -> globals ->
    scan state), then Keyword/dict, then list forms. For a list form the
    head symbol is tried as, in order: a closure bound in env, a closure
    bound in globals, a built-in special form/operator, an effect, a
    primitive, a user macro, and finally a dash->underscore primitive
    alias. The first match wins.

    Args:
        expr: Parsed s-expression node (literal, Symbol, Keyword, dict,
            or list form).
        env: Local variable environment for this evaluation.

    Returns:
        The evaluated value (may be a DeferredEffectChain under JAX).

    Raises:
        NameError: For an unbound Symbol.
        RuntimeError: For an unknown head symbol or a failing primitive.
    """

    # Primitives (self-evaluating literals)
    if isinstance(expr, (int, float)):
        return expr
    if isinstance(expr, str):
        return expr
    # NOTE: bool is a subclass of int, so booleans were already returned
    # by the (int, float) branch above; this check is redundant but
    # harmless (same value is returned either way).
    if isinstance(expr, bool):
        return expr

    if isinstance(expr, Symbol):
        name = expr.name
        # Built-in constants
        if name == 'pi':
            return math.pi
        if name == 'true':
            return True
        if name == 'false':
            return False
        if name == 'nil':
            return None
        # Environment lookup
        if name in env:
            return env[name]
        # Global lookup
        if name in self.globals:
            return self.globals[name]
        # Scan state lookup - a bare scan name evaluates to its state dict
        if name in self.scans:
            return self.scans[name]['state']
        raise NameError(f"Undefined variable: {name}")

    # A bare keyword evaluates to its name string.
    if isinstance(expr, Keyword):
        return expr.name

    # Handle dicts from new parser - evaluate values
    if isinstance(expr, dict):
        return {k: self._eval(v, env) for k, v in expr.items()}

    # Anything else that is not a non-empty list is returned as-is.
    if not isinstance(expr, list) or not expr:
        return expr

    # Dict literal {:key val ...} - a list whose first element is a Keyword
    if isinstance(expr[0], Keyword):
        result = {}
        i = 0
        while i < len(expr):
            if isinstance(expr[i], Keyword):
                # A trailing keyword with no value maps to None.
                result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None
                i += 2
            else:
                i += 1
        return result

    head = expr[0]
    # A list with a non-symbol head is treated as a plain list literal.
    if not isinstance(head, Symbol):
        return [self._eval(e, env) for e in expr]

    op = head.name
    args = expr[1:]

    # Check for closure call - a local binding that holds a closure
    # shadows everything below; non-closure bindings fall through.
    if op in env:
        val = env[op]
        if isinstance(val, dict) and val.get('_type') == 'closure':
            closure = val
            closure_env = dict(closure['env'])
            # Missing arguments bind to None.
            for i, pname in enumerate(closure['params']):
                closure_env[pname] = self._eval(args[i], env) if i < len(args) else None
            return self._eval(closure['body'], closure_env)

    # Same closure-call check against global bindings.
    if op in self.globals:
        val = self.globals[op]
        if isinstance(val, dict) and val.get('_type') == 'closure':
            closure = val
            closure_env = dict(closure['env'])
            for i, pname in enumerate(closure['params']):
                closure_env[pname] = self._eval(args[i], env) if i < len(args) else None
            return self._eval(closure['body'], closure_env)

    # Threading macro: (-> x (f a) g) == (g (f x a))
    if op == '->':
        result = self._eval(args[0], env)
        for form in args[1:]:
            if isinstance(form, list) and form:
                # Thread the accumulated result in as the first argument.
                new_form = [form[0], result] + form[1:]
                result = self._eval(new_form, env)
            else:
                # Bare symbol step: call it with the result as sole arg.
                result = self._eval([form, result], env)
        return result

    # === Binding ===

    # (bind scan-name [:key]) - read a scan's state (or one key of it).
    if op == 'bind':
        scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0])
        if scan_name in self.scans:
            state = self.scans[scan_name]['state']
            if len(args) > 1:
                key = args[1].name if isinstance(args[1], Keyword) else str(args[1])
                # Missing keys read as 0.
                return state.get(key, 0)
            return state
        # Unknown scan names read as 0 rather than raising.
        return 0

    # === Arithmetic ===

    if op == '+':
        return sum(self._eval(a, env) for a in args)
    if op == '-':
        vals = [self._eval(a, env) for a in args]
        # Unary minus negates; n-ary subtracts the rest from the first.
        return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0]
    if op == '*':
        result = 1
        for a in args:
            result *= self._eval(a, env)
        return result
    if op == '/':
        vals = [self._eval(a, env) for a in args]
        # Division by zero (or a missing divisor) yields 0 instead of raising.
        return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0
    if op == 'mod':
        vals = [self._eval(a, env) for a in args]
        # Same zero-divisor policy as '/'.
        return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0

    # === Comparison ===

    if op == '<':
        return self._eval(args[0], env) < self._eval(args[1], env)
    if op == '>':
        return self._eval(args[0], env) > self._eval(args[1], env)
    if op == '=':
        return self._eval(args[0], env) == self._eval(args[1], env)
    if op == '<=':
        return self._eval(args[0], env) <= self._eval(args[1], env)
    if op == '>=':
        return self._eval(args[0], env) >= self._eval(args[1], env)

    # Short-circuiting 'and': always returns a bool.
    if op == 'and':
        for arg in args:
            if not self._eval(arg, env):
                return False
        return True

    # Short-circuiting 'or': returns the first truthy value (Lisp-style),
    # or the last evaluated value / False when none are truthy.
    if op == 'or':
        result = False
        for arg in args:
            result = self._eval(arg, env)
            if result:
                return result
        return result

    if op == 'not':
        return not self._eval(args[0], env)

    # === Logic ===

    if op == 'if':
        cond = self._eval(args[0], env)
        if cond:
            return self._eval(args[1], env)
        # The else branch is optional; missing means None.
        return self._eval(args[2], env) if len(args) > 2 else None

    # (cond pred1 expr1 pred2 expr2 ...) - flat predicate/value pairs.
    if op == 'cond':
        i = 0
        while i < len(args) - 1:
            pred = self._eval(args[i], env)
            if pred:
                return self._eval(args[i + 1], env)
            i += 2
        return None

    # (lambda (params...) body) - captures a snapshot of env.
    if op == 'lambda':
        params = args[0]
        body = args[1]
        param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params]
        return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)}

    # Both let and let* evaluate bindings sequentially (each binding sees
    # the previous ones), so they behave identically here.
    if op == 'let' or op == 'let*':
        bindings = args[0]
        body = args[1]
        new_env = dict(env)

        if bindings and isinstance(bindings[0], list):
            # Paired form: ((name expr) (name expr) ...)
            for binding in bindings:
                if isinstance(binding, list) and len(binding) >= 2:
                    name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
                    val = self._eval(binding[1], new_env)
                    new_env[name] = val
        else:
            # Flat form: (name expr name expr ...)
            i = 0
            while i < len(bindings):
                name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
                val = self._eval(bindings[i + 1], new_env)
                new_env[name] = val
                i += 2
        return self._eval(body, new_env)

    # === Dict ===

    # (dict :key val ...) - non-keyword tokens in key position are skipped.
    if op == 'dict':
        result = {}
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                key = args[i].name
                val = self._eval(args[i + 1], env) if i + 1 < len(args) else None
                result[key] = val
                i += 2
            else:
                i += 1
        return result

    # (get obj :key) - missing keys and non-dict objects read as 0.
    if op == 'get':
        obj = self._eval(args[0], env)
        key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env)
        if isinstance(obj, dict):
            return obj.get(key, 0)
        return 0

    # === List ===

    if op == 'list':
        return [self._eval(a, env) for a in args]

    # (quote x) - return x unevaluated.
    if op == 'quote':
        return args[0] if args else None

    # (nth lst i) - out-of-range or non-sequence yields None.
    if op == 'nth':
        lst = self._eval(args[0], env)
        idx = int(self._eval(args[1], env))
        if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst):
            return lst[idx]
        return None

    # (len x) - objects without __len__ report 0.
    if op == 'len':
        val = self._eval(args[0], env)
        return len(val) if hasattr(val, '__len__') else 0

    # (map seq fn) - note: sequence first, function second.
    if op == 'map':
        seq = self._eval(args[0], env)
        fn = self._eval(args[1], env)
        if not isinstance(seq, (list, tuple)):
            return []
        # Handle closure (lambda from sexp)
        if isinstance(fn, dict) and fn.get('_type') == 'closure':
            results = []
            for item in seq:
                closure_env = dict(fn['env'])
                # Only the first parameter receives the item.
                if fn['params']:
                    closure_env[fn['params'][0]] = item
                results.append(self._eval(fn['body'], closure_env))
            return results
        # Handle Python callable
        if callable(fn):
            return [fn(item) for item in seq]
        return []

    # === Effects ===

    if op in self.effects:
        # Try to detect and fuse effect chains for JAX acceleration
        if self.use_jax and op in self.jax_effects:
            chain_info = self._extract_effect_chain(expr, env)
            if chain_info is not None:
                effect_names, params_list, base_frame_expr = chain_info
                # Only use chain if we have 2+ effects (worth fusing)
                if len(effect_names) >= 2:
                    base_frame = self._eval(base_frame_expr, env)
                    if base_frame is not None and hasattr(base_frame, 'shape'):
                        t = env.get('t', 0.0)
                        frame_num = env.get('frame-num', 0)
                        result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
                        if result is not None:
                            return result
                # Fall through if chain application fails

        effect = self.effects[op]
        # Effects evaluate their body in env extended with parameters.
        effect_env = dict(env)

        # Seed every declared parameter with its default first.
        param_names = list(effect['params'].keys())
        for pname, pdef in effect['params'].items():
            effect_env[pname] = pdef.get('default', 0)

        # Then overlay call-site args: keywords bind by name; positional
        # args bind the frame first, then declared params in order.
        positional_idx = 0
        frame_val = None
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                pname = args[i].name
                if pname in effect['params'] and i + 1 < len(args):
                    effect_env[pname] = self._eval(args[i + 1], env)
                # Unknown keywords are skipped along with their value.
                i += 2
            else:
                val = self._eval(args[i], env)
                if positional_idx == 0:
                    effect_env['frame'] = val
                    frame_val = val
                elif positional_idx - 1 < len(param_names):
                    effect_env[param_names[positional_idx - 1]] = val
                positional_idx += 1
                i += 1

        # Try JAX-accelerated execution with deferred chaining
        if self.use_jax and op in self.jax_effects and frame_val is not None:
            # Build params dict for JAX (exclude 'frame')
            jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
                          if k != 'frame' and k in effect['params']}
            t = env.get('t', 0.0)
            frame_num = env.get('frame-num', 0)

            # Check if input is a deferred chain - if so, extend it
            if isinstance(frame_val, DeferredEffectChain):
                return frame_val.extend(op, jax_params)

            # Check if input is a valid frame - create new deferred chain
            if hasattr(frame_val, 'shape'):
                return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)

            # Fall through to interpreter if not a valid frame

        # Force any deferred frame before interpreter evaluation
        if isinstance(frame_val, DeferredEffectChain):
            frame_val = self._force_deferred(frame_val)
            effect_env['frame'] = frame_val

        return self._eval(effect['body'], effect_env)

    # === Primitives ===

    if op in self.primitives:
        prim_func = self.primitives[op]
        # Check if this is a GPU primitive (preserves GPU arrays)
        is_gpu_prim = op.startswith('gpu:') or 'gpu' in op.lower()
        evaluated_args = []
        kwargs = {}
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                k = args[i].name
                v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
                # Force deferred chains before passing to primitives
                v = self._maybe_force(v)
                kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
                i += 2
            else:
                val = self._eval(args[i], env)
                # Force deferred chains before passing to primitives
                val = self._maybe_force(val)
                evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
                i += 1
        try:
            if kwargs:
                return prim_func(*evaluated_args, **kwargs)
            return prim_func(*evaluated_args)
        except Exception as e:
            # Record for diagnostics, then surface to the caller.
            self._record_error(f"Primitive {op} error: {e}")
            raise RuntimeError(f"Primitive {op} failed: {e}")

    # === Macros (function-like: args evaluated before binding) ===

    if op in self.macros:
        macro = self.macros[op]
        macro_env = dict(env)
        for i, pname in enumerate(macro['params']):
            # Evaluate args in calling environment before binding
            macro_env[pname] = self._eval(args[i], env) if i < len(args) else None
        return self._eval(macro['body'], macro_env)

    # Underscore variant lookup: dash-named ops map to underscore-named
    # primitives (e.g. foo-bar -> foo_bar).
    prim_name = op.replace('-', '_')
    if prim_name in self.primitives:
        prim_func = self.primitives[prim_name]
        # Check if this is a GPU primitive (preserves GPU arrays)
        is_gpu_prim = 'gpu' in prim_name.lower()
        evaluated_args = []
        kwargs = {}
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                # Keyword names are also dash->underscore converted here.
                k = args[i].name.replace('-', '_')
                v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
                # NOTE(review): unlike the direct-primitive branch above,
                # this path does not _maybe_force deferred values - confirm
                # whether that asymmetry is intentional.
                kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
                i += 2
            else:
                evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim))
                i += 1

        try:
            if kwargs:
                return prim_func(*evaluated_args, **kwargs)
            return prim_func(*evaluated_args)
        except Exception as e:
            self._record_error(f"Primitive {op} error: {e}")
            raise RuntimeError(f"Primitive {op} failed: {e}")

    # Unknown function call - raise meaningful error
    raise RuntimeError(f"Unknown function or primitive: '{op}'. "
                       f"Available primitives: {sorted(list(self.primitives.keys())[:10])}... "
                       f"Available effects: {sorted(list(self.effects.keys())[:10])}... "
                       f"Available macros: {sorted(list(self.macros.keys())[:10])}...")
|
|
|
|
def _step_scans(self, ctx: Context, env: dict):
    """Advance every registered scan whose trigger fires for this frame.

    Args:
        ctx: Frame context (currently unused here; kept for interface
            compatibility with callers).
        env: Frame environment the trigger/step expressions evaluate in.
    """
    for scan_name, scan_def in self.scans.items():
        # A truthy trigger result means this scan advances now.
        if not self._eval(scan_def['trigger'], env):
            continue

        # The step expression sees the current scan state merged with
        # the frame environment; env values win on key collisions.
        step_env = {**scan_def['state'], **env}
        updated = self._eval(scan_def['step'], step_env)

        # Dict results replace the state wholesale; scalar results are
        # wrapped under the conventional 'acc' key.
        scan_def['state'] = updated if isinstance(updated, dict) else {'acc': updated}
|
|
|
|
def _restore_checkpoint(self, checkpoint: dict):
|
|
"""Restore scan states from a checkpoint.
|
|
|
|
Called when resuming a render from a previous checkpoint.
|
|
|
|
Args:
|
|
checkpoint: Dict with 'scans' key containing {scan_name: state_dict}
|
|
"""
|
|
scans_state = checkpoint.get('scans', {})
|
|
for name, state in scans_state.items():
|
|
if name in self.scans:
|
|
self.scans[name]['state'] = dict(state)
|
|
print(f"Restored scan '{name}' state from checkpoint", file=sys.stderr)
|
|
|
|
def _get_checkpoint_state(self) -> dict:
|
|
"""Get current scan states for checkpointing.
|
|
|
|
Returns:
|
|
Dict mapping scan names to their current state dicts
|
|
"""
|
|
return {name: dict(scan['state']) for name, scan in self.scans.items()}
|
|
|
|
def run(self, duration: float = None, output: str = "pipe", resume_from: dict = None):
    """Run the streaming pipeline.

    Initializes the interpreter, constructs the requested output sink,
    then evaluates the (frame ...) pipeline once per frame, stepping
    scans first. Under JAX, frames that evaluate to DeferredEffectChain
    objects are accumulated and flushed in batches.

    Args:
        duration: Duration in seconds (auto-detected from audio if None)
        output: Output mode ("pipe", "preview", path/hls, path/ipfs-hls, or file path)
        resume_from: Checkpoint dict to resume from, with keys:
            - frame_num: Last completed frame
            - t: Time value for checkpoint frame
            - scans: Dict of scan states to restore
            - segment_cids: Dict of quality -> {seg_num: cid} for output resume
    """
    # Import output classes - handle both package and direct execution
    try:
        from .output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
        from .gpu_output import GPUHLSOutput, check_gpu_encode_available
        from .multi_res_output import MultiResolutionHLSOutput
    except ImportError:
        # Direct (non-package) execution: absolute imports, with GPU and
        # multi-resolution backends treated as optional.
        from output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
        try:
            from gpu_output import GPUHLSOutput, check_gpu_encode_available
        except ImportError:
            GPUHLSOutput = None
            check_gpu_encode_available = lambda: False
        try:
            from multi_res_output import MultiResolutionHLSOutput
        except ImportError:
            MultiResolutionHLSOutput = None

    self._init()

    # Restore checkpoint state if resuming
    if resume_from:
        self._restore_checkpoint(resume_from)
        print(f"Resuming from frame {resume_from.get('frame_num', 0)}", file=sys.stderr)

    # Nothing to render without a (frame ...) form.
    if not self.frame_pipeline:
        print("Error: no (frame ...) pipeline defined", file=sys.stderr)
        return

    w = self.config.get('width', 720)
    h = self.config.get('height', 720)
    fps = self.config.get('fps', 30)

    if duration is None:
        # Try to get duration from audio if available: first global with
        # a 'duration' attribute wins; otherwise fall back to 60s.
        for name, val in self.globals.items():
            if hasattr(val, 'duration'):
                duration = val.duration
                print(f"Using audio duration: {duration:.1f}s", file=sys.stderr)
                break
        else:
            duration = 60.0

    n_frames = int(duration * fps)
    frame_time = 1.0 / fps

    print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr)

    # Create context
    ctx = Context(fps=fps)

    # Output (with optional audio sync)
    # Resolve audio path lazily here if it wasn't resolved during parsing
    audio = self.audio_playback
    if audio and not Path(audio).exists():
        # Try to resolve as friendly name (may have failed during parsing)
        audio_name = Path(audio).name  # Get just the name part
        resolved = self._resolve_name(audio_name)
        if resolved and resolved.exists():
            audio = str(resolved)
            print(f"Lazy resolved audio: {audio}", file=sys.stderr)
        else:
            raise FileNotFoundError(f"Audio file not found: {audio}")
    if output == "pipe":
        out = PipeOutput(size=(w, h), fps=fps, audio_source=audio)
    elif output == "preview":
        out = DisplayOutput(size=(w, h), fps=fps, audio_source=audio)
    elif output.endswith("/hls"):
        # HLS output - output is a directory path ending in /hls
        hls_dir = output[:-4]  # Remove /hls suffix
        out = HLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio)
    elif output.endswith("/ipfs-hls"):
        # IPFS HLS output - multi-resolution adaptive streaming
        hls_dir = output[:-9]  # Remove /ipfs-hls suffix
        import os
        ipfs_gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs")

        # Build resume state for output if resuming
        output_resume = None
        if resume_from and resume_from.get('segment_cids'):
            output_resume = {'segment_cids': resume_from['segment_cids']}

        # Use multi-resolution output (renders original + 720p + 360p)
        if MultiResolutionHLSOutput is not None:
            print(f"[StreamInterpreter] Using multi-resolution HLS output ({w}x{h} + 720p + 360p)", file=sys.stderr)
            out = MultiResolutionHLSOutput(
                hls_dir,
                source_size=(w, h),
                fps=fps,
                ipfs_gateway=ipfs_gateway,
                on_playlist_update=self.on_playlist_update,
                audio_source=audio,
                resume_from=output_resume,
            )
        # Fallback to GPU single-resolution if multi-res not available
        elif GPUHLSOutput is not None and check_gpu_encode_available():
            print(f"[StreamInterpreter] Using GPU zero-copy encoding (single resolution)", file=sys.stderr)
            out = GPUHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway,
                               on_playlist_update=self.on_playlist_update)
        else:
            out = IPFSHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway,
                                on_playlist_update=self.on_playlist_update)
    else:
        # Any other string is treated as a plain output file path.
        out = FileOutput(output, size=(w, h), fps=fps, audio_source=audio)

    # Calculate frames per segment based on fps and segment duration (4 seconds default)
    segment_duration = 4.0
    self._frames_per_segment = int(fps * segment_duration)

    # Determine start frame (resume from checkpoint + 1, or 0)
    start_frame = 0
    if resume_from and resume_from.get('frame_num') is not None:
        start_frame = resume_from['frame_num'] + 1
        print(f"Starting from frame {start_frame} (checkpoint was at {resume_from['frame_num']})", file=sys.stderr)

    try:
        # Per-phase timing buffers for the periodic profile line.
        frame_times = []
        profile_interval = 30  # Profile every N frames
        scan_times = []
        eval_times = []
        write_times = []

        # Batch accumulation for JAX
        batch_deferred = []  # Accumulated DeferredEffectChains
        batch_times = []  # Corresponding time values
        batch_start_frame = 0  # NOTE(review): appears unused below

        def flush_batch():
            """Execute accumulated batch and write results."""
            nonlocal batch_deferred, batch_times
            if not batch_deferred:
                return

            t_flush = time.time()  # NOTE(review): appears unused below

            # Check if all chains have same structure (can batch)
            first = batch_deferred[0]
            can_batch = (
                self.use_jax and
                len(batch_deferred) >= 2 and
                all(d.effects == first.effects for d in batch_deferred)
            )

            if can_batch:
                # Try batched execution
                frames = [d.base_frame for d in batch_deferred]
                ts = [d.t for d in batch_deferred]
                frame_nums = [d.frame_num for d in batch_deferred]
                params_batch = [d.params_list for d in batch_deferred]

                results = self._apply_batched_chain(
                    first.effects, params_batch, frames, ts, frame_nums
                )

                # A None result means batched execution failed; fall
                # through to the sequential path below.
                if results is not None:
                    # Write batched results
                    for result, t in zip(results, batch_times):
                        # Sync JAX device arrays before converting to numpy.
                        if hasattr(result, 'block_until_ready'):
                            result.block_until_ready()
                        result = np.asarray(result)
                        out.write(result, t)
                    batch_deferred = []
                    batch_times = []
                    return

            # Fall back to sequential execution
            for deferred, t in zip(batch_deferred, batch_times):
                result = self._force_deferred(deferred)
                if result is not None and hasattr(result, 'shape'):
                    if hasattr(result, 'block_until_ready'):
                        result.block_until_ready()
                    result = np.asarray(result)
                    out.write(result, t)

            batch_deferred = []
            batch_times = []

        for frame_num in range(start_frame, n_frames):
            # Stop rendering if the sink was closed (e.g. preview window).
            if not out.is_open:
                break

            frame_start = time.time()
            ctx.t = frame_num * frame_time
            ctx.frame_num = frame_num

            # Build frame environment with context
            frame_env = {
                'ctx': {
                    't': ctx.t,
                    'frame-num': ctx.frame_num,
                    'fps': ctx.fps,
                },
                't': ctx.t,  # Also expose t directly for convenience
                'frame-num': ctx.frame_num,
            }

            # Step scans
            t0 = time.time()
            self._step_scans(ctx, frame_env)
            scan_times.append(time.time() - t0)

            # Evaluate pipeline
            t1 = time.time()
            result = self._eval(self.frame_pipeline, frame_env)
            eval_times.append(time.time() - t1)

            t2 = time.time()
            if result is not None:
                if isinstance(result, DeferredEffectChain):
                    # Accumulate for batching
                    batch_deferred.append(result)
                    batch_times.append(ctx.t)

                    # Flush when batch is full
                    if len(batch_deferred) >= self.jax_batch_size:
                        flush_batch()
                else:
                    # Not deferred - flush any pending batch first, then write
                    # (keeps frames in presentation order).
                    flush_batch()
                    if hasattr(result, 'shape'):
                        if hasattr(result, 'block_until_ready'):
                            result.block_until_ready()
                        result = np.asarray(result)
                        out.write(result, ctx.t)
            write_times.append(time.time() - t2)

            frame_elapsed = time.time() - frame_start
            frame_times.append(frame_elapsed)

            # Checkpoint at segment boundaries (every _frames_per_segment frames)
            if frame_num > 0 and frame_num % self._frames_per_segment == 0:
                if self.on_checkpoint:
                    try:
                        checkpoint = {
                            'frame_num': frame_num,
                            't': ctx.t,
                            'scans': self._get_checkpoint_state(),
                        }
                        self.on_checkpoint(checkpoint)
                    except Exception as e:
                        # Checkpointing is best-effort; never abort the render.
                        print(f"Warning: checkpoint callback failed: {e}", file=sys.stderr)

            # Progress with timing and profile breakdown
            if frame_num % profile_interval == 0 and frame_num > 0:
                pct = 100 * frame_num / n_frames
                # Averages over the most recent profile window only.
                avg_ms = 1000 * sum(frame_times[-profile_interval:]) / max(1, len(frame_times[-profile_interval:]))
                avg_scan = 1000 * sum(scan_times[-profile_interval:]) / max(1, len(scan_times[-profile_interval:]))
                avg_eval = 1000 * sum(eval_times[-profile_interval:]) / max(1, len(eval_times[-profile_interval:]))
                avg_write = 1000 * sum(write_times[-profile_interval:]) / max(1, len(write_times[-profile_interval:]))
                target_ms = 1000 * frame_time
                print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms] scan={avg_scan:.0f}ms eval={avg_eval:.0f}ms write={avg_write:.0f}ms", end="", file=sys.stderr, flush=True)

                # Call progress callback if set (for Celery task state updates)
                if self.on_progress:
                    try:
                        self.on_progress(pct, frame_num, n_frames)
                    except Exception as e:
                        # Progress reporting is also best-effort.
                        print(f"Warning: progress callback failed: {e}", file=sys.stderr)

        # Flush any remaining batch
        flush_batch()

    finally:
        out.close()
        # Store output for access to properties like playlist_cid
        self.output = out
        print("\nDone", file=sys.stderr)
|
|
|
|
|
|
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
               sources_config: str = None, audio_config: str = None, use_jax: bool = False):
    """Convenience entry point: build a StreamInterpreter and run it.

    Args:
        sexp_path: Path to the .sexp recipe file.
        duration: Seconds to render (auto-detected when None).
        output: Output mode ("pipe", "preview", HLS dir, or file path).
        fps: Frames-per-second override; falsy leaves the recipe default.
        sources_config: Optional sources config .sexp path.
        audio_config: Optional audio config .sexp path.
        use_jax: Enable JAX acceleration for effects.
    """
    interpreter = StreamInterpreter(sexp_path, use_jax=use_jax)
    if fps:
        interpreter.config['fps'] = fps
    if sources_config:
        interpreter.sources_config = Path(sources_config)
    if audio_config:
        interpreter.audio_config = Path(audio_config)
    interpreter.run(duration=duration, output=output)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line front-end for the generic streaming interpreter.
    arg_parser = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)")
    arg_parser.add_argument("sexp", help="Path to .sexp file")
    arg_parser.add_argument("-d", "--duration", type=float, default=None)
    arg_parser.add_argument("-o", "--output", default="pipe")
    arg_parser.add_argument("--fps", type=float, default=None)
    arg_parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
    arg_parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
    arg_parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
    cli = arg_parser.parse_args()

    run_stream(
        cli.sexp,
        duration=cli.duration,
        output=cli.output,
        fps=cli.fps,
        sources_config=cli.sources_config,
        audio_config=cli.audio_config,
        use_jax=cli.jax,
    )
|