Files
celery/streaming/stream_sexp_generic.py
gilesb fc9597456f
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:41:19 +00:00

1740 lines
71 KiB
Python

"""
Fully Generic Streaming S-expression Interpreter.
The interpreter knows NOTHING about video, audio, or any domain.
All domain logic comes from primitives loaded via (require-primitives ...).
Built-in forms:
- Control: if, cond, let, let*, lambda, ->
- Arithmetic: +, -, *, /, mod, map-range
- Comparison: <, >, =, <=, >=, and, or, not
- Data: dict, get, list, nth, len, quote
- Random: rand, rand-int, rand-range
- Scan: bind (access scan state)
Everything else comes from primitives or effects.
Context (ctx) is passed explicitly to frame evaluation:
- ctx.t: current time
- ctx.frame-num: current frame number
- ctx.fps: frames per second
"""
import sys
import os
import time
import json
import hashlib
import math
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Any, Optional, Tuple, Callable
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
# JAX backend (optional - loaded on demand)
_JAX_AVAILABLE = False
_jax_compiler = None
# Set once an initialisation attempt has finished (success or ImportError) so a
# missing backend is not re-imported and re-reported for every interpreter.
_jax_init_attempted = False


def _init_jax():
    """Lazily initialize the JAX compiler backend.

    On first use, imports ``streaming.sexp_to_jax`` and stores its
    ``JaxCompiler`` / ``compile_effect_file`` in ``_jax_compiler``. The outcome
    is cached: later calls return the remembered availability without
    re-importing, including after a failed import.

    Returns:
        True if the JAX backend is available, False otherwise.
    """
    global _JAX_AVAILABLE, _jax_compiler, _jax_init_attempted
    if _jax_init_attempted or _jax_compiler is not None:
        return _JAX_AVAILABLE
    try:
        from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
        _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
        _JAX_AVAILABLE = True
        print("JAX backend initialized", file=sys.stderr)
    except ImportError as e:
        print(f"JAX backend not available: {e}", file=sys.stderr)
        _JAX_AVAILABLE = False
    # Only reached on success or ImportError; other exceptions propagate and
    # leave the flag unset so the next call can retry.
    _jax_init_attempted = True
    return _JAX_AVAILABLE
@dataclass
class Context:
    """Per-frame runtime values handed to frame evaluation."""

    t: float = 0.0       # current time
    frame_num: int = 0   # current frame number
    fps: float = 30.0    # frames per second
class DeferredEffectChain:
    """
    A not-yet-executed sequence of JAX effects applied to one base frame.

    Effects are accumulated (e.g. through let bindings) and only run when the
    result is actually needed, so the whole sequence can be fused into a
    single JIT-compiled function.
    """

    __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')

    def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
        # Parallel lists: effect names and their param dicts, ordered from the
        # innermost (applied first) to the outermost effect.
        self.effects = effects
        self.params_list = params_list
        # The concrete frame the chain starts from, plus timing shared by all steps.
        self.base_frame = base_frame
        self.t = t
        self.frame_num = frame_num

    def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
        """Return a new chain with *effect_name* appended as the outermost step.

        The receiver is not modified; new effect/param lists are built.
        """
        grown_effects = self.effects + [effect_name]
        grown_params = self.params_list + [params]
        return DeferredEffectChain(grown_effects, grown_params, self.base_frame, self.t, self.frame_num)

    @property
    def shape(self):
        """Shape of the base frame (None if it has none), without running the chain."""
        base = self.base_frame
        return base.shape if hasattr(base, 'shape') else None
class StreamInterpreter:
"""
Fully generic streaming sexp interpreter.
No domain-specific knowledge - just evaluates expressions
and calls primitives.
"""
def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
    """Read and parse the recipe at *sexp_path* and set up empty interpreter state.

    Only the recipe is read/parsed here (plus optional JAX backend setup);
    primitives, effects, defs and scans are loaded later by ``_init``.

    Args:
        sexp_path: Path to the recipe file; its directory (``sexp_dir``) is the
            base for relative effect/include paths.
        actor_id: Optional actor identity used for friendly-name resolution.
        use_jax: Request JAX-accelerated effects. Downgraded to False, with a
            warning on stderr, when the JAX backend cannot be initialised.
    """
    self.sexp_path = Path(sexp_path)
    self.sexp_dir = self.sexp_path.parent
    self.actor_id = actor_id  # For friendly name resolution
    text = self.sexp_path.read_text()
    self.ast = parse(text)
    self.config = self._parse_config()
    # Global environment for def bindings
    self.globals: Dict[str, Any] = {}
    # Scans
    self.scans: Dict[str, dict] = {}
    # Audio playback path (for syncing output)
    self.audio_playback: Optional[str] = None
    # Registries for external definitions
    self.primitives: Dict[str, Any] = {}
    self.effects: Dict[str, dict] = {}
    self.macros: Dict[str, dict] = {}
    # JAX backend for accelerated effect evaluation
    self.use_jax = use_jax
    self.jax_effects: Dict[str, Callable] = {}  # Cache of JAX-compiled effects
    self.jax_effect_paths: Dict[str, Path] = {}  # Track source paths for effects
    # Cache of fused effect chains; None marks a chain that failed to compile
    self.jax_fused_chains: Dict[str, Optional[Callable]] = {}
    # Cache of per-frame jitted chains used by the vmapped batch path; None = failed
    self.jax_batched_chains: Dict[str, Optional[Callable]] = {}
    self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30"))  # Configurable via env
    if use_jax:
        if _init_jax():
            print("JAX acceleration enabled", file=sys.stderr)
        else:
            print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
            self.use_jax = False
    # Try multiple locations for primitive_libs
    possible_paths = [
        self.sexp_dir.parent / "sexp_effects" / "primitive_libs",  # recipes/ subdir
        self.sexp_dir / "sexp_effects" / "primitive_libs",  # app root
        Path(__file__).parent.parent / "sexp_effects" / "primitive_libs",  # relative to interpreter
    ]
    self.primitive_lib_dir = next((p for p in possible_paths if p.exists()), possible_paths[0])
    self.frame_pipeline = None
    # External config files (set before run())
    self.sources_config: Optional[Path] = None
    self.audio_config: Optional[Path] = None
    # Error tracking
    self.errors: List[str] = []
    # Callback for live streaming (called when IPFS playlist is updated)
    self.on_playlist_update: Optional[Callable[..., Any]] = None
    # Callback for progress updates (called periodically during rendering)
    # Signature: on_progress(percent: float, frame_num: int, total_frames: int)
    self.on_progress: Optional[Callable[[float, int, int], Any]] = None
    # Callback for checkpoint saves (called at segment boundaries for resumability)
    # Signature: on_checkpoint(checkpoint: dict)
    # checkpoint contains: frame_num, t, scans
    self.on_checkpoint: Optional[Callable[[dict], Any]] = None
    # Frames per segment for checkpoint timing (default 4 seconds at 30fps = 120 frames)
    self._frames_per_segment: int = 120
def _resolve_name(self, name: str) -> Optional[Path]:
"""Resolve a friendly name to a file path using the naming service."""
try:
# Import here to avoid circular imports
from tasks.streaming import resolve_asset
path = resolve_asset(name, self.actor_id)
if path:
return path
except Exception as e:
print(f"Warning: failed to resolve name '{name}': {e}", file=sys.stderr)
return None
def _record_error(self, msg: str):
"""Record an error that occurred during evaluation."""
self.errors.append(msg)
print(f"ERROR: {msg}", file=sys.stderr)
def _maybe_to_numpy(self, val, for_gpu_primitive: bool = False):
"""Convert GPU frames/CuPy arrays to numpy for CPU primitives.
If for_gpu_primitive=True, preserve GPU data (CuPy arrays stay on GPU).
"""
if val is None:
return val
# For GPU primitives, keep data on GPU
if for_gpu_primitive:
# Handle GPUFrame - return the GPU array
if hasattr(val, 'gpu') and hasattr(val, 'is_on_gpu'):
if val.is_on_gpu:
return val.gpu
return val.cpu
# CuPy arrays pass through unchanged
if hasattr(val, '__cuda_array_interface__'):
return val
return val
# For CPU primitives, convert to numpy
# Handle GPUFrame objects (have .cpu property)
if hasattr(val, 'cpu'):
return val.cpu
# Handle CuPy arrays (have .get() method)
if hasattr(val, 'get') and callable(val.get):
return val.get()
return val
def _load_config_file(self, config_path):
    """Load an external config file and apply its top-level forms.

    Accepts a str or Path. Recognised forms:
      (require-primitives lib)  -> load a primitive library
      (def name expr)           -> evaluate expr against globals and store it
      (audio-playback "path")   -> set the audio file, resolved against the CWD
    Any other form is ignored.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    cfg_file = Path(config_path)  # Accept str or Path
    if not cfg_file.exists():
        raise FileNotFoundError(f"Config file not found: {cfg_file}")
    for form in parse_all(cfg_file.read_text()):
        # Only non-empty lists headed by a symbol are commands.
        if not (isinstance(form, list) and form and isinstance(form[0], Symbol)):
            continue
        cmd = form[0].name
        if cmd == 'require-primitives':
            raw = form[1]
            self._load_primitives(raw if isinstance(raw, str) else str(raw).strip('"'))
        elif cmd == 'def':
            target = form[1]
            name = target.name if isinstance(target, Symbol) else str(target)
            self.globals[name] = self._eval(form[2], self.globals)
            print(f"Config: {name}", file=sys.stderr)
        elif cmd == 'audio-playback':
            # Relative paths resolve against the working directory (consistent with other paths)
            self.audio_playback = str(Path(str(form[1]).strip('"')).resolve())
            print(f"Audio playback: {self.audio_playback}", file=sys.stderr)
def _parse_config(self) -> dict:
"""Parse config from (stream name :key val ...)."""
config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720}
if not self.ast or not isinstance(self.ast[0], Symbol):
return config
if self.ast[0].name != 'stream':
return config
i = 2
while i < len(self.ast):
if isinstance(self.ast[i], Keyword):
config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None
i += 2
elif isinstance(self.ast[i], list):
break
else:
i += 1
return config
def _load_primitives(self, lib_name: str):
"""Load primitives from a Python library file.
Prefers GPU-accelerated versions (*_gpu.py) when available.
Uses cached modules from sys.modules to ensure consistent state
(e.g., same RNG instance for all interpreters).
"""
import importlib.util
# Try GPU version first, then fall back to CPU version
lib_names_to_try = [f"{lib_name}_gpu", lib_name]
lib_path = None
actual_lib_name = lib_name
for try_lib in lib_names_to_try:
lib_paths = [
self.primitive_lib_dir / f"{try_lib}.py",
self.sexp_dir / "primitive_libs" / f"{try_lib}.py",
self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{try_lib}.py",
]
for p in lib_paths:
if p.exists():
lib_path = p
actual_lib_name = try_lib
break
if lib_path:
break
if not lib_path:
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
# Use cached module if already imported to preserve state (e.g., RNG)
# This is critical for deterministic random number sequences
# Check multiple possible module keys (standard import paths and our cache)
possible_keys = [
f"sexp_effects.primitive_libs.{actual_lib_name}",
f"sexp_primitives.{actual_lib_name}",
]
module = None
for key in possible_keys:
if key in sys.modules:
module = sys.modules[key]
break
if module is None:
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Cache for future use under our key
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
# Check if this is a GPU-accelerated module
is_gpu = actual_lib_name.endswith('_gpu')
gpu_tag = " [GPU]" if is_gpu else ""
count = 0
for name in dir(module):
if name.startswith('prim_'):
func = getattr(module, name)
prim_name = name[5:]
dash_name = prim_name.replace('_', '-')
# Register with original lib_name namespace (geometry:rotate, not geometry_gpu:rotate)
# Don't overwrite if already registered (allows pre-registration of overrides)
key = f"{lib_name}:{dash_name}"
if key not in self.primitives:
self.primitives[key] = func
count += 1
if hasattr(module, 'PRIMITIVES'):
prims = getattr(module, 'PRIMITIVES')
if isinstance(prims, dict):
for name, func in prims.items():
# Register with original lib_name namespace
# Don't overwrite if already registered
dash_name = name.replace('_', '-')
key = f"{lib_name}:{dash_name}"
if key not in self.primitives:
self.primitives[key] = func
count += 1
print(f"Loaded primitives: {lib_name} ({count} functions){gpu_tag}", file=sys.stderr)
def _load_effect(self, effect_path: Path):
    """Load and register an effect from a .sexp file.

    Walks the file's top-level forms in order:
      (require-primitives lib)              -> self._load_primitives(lib)
      (define-effect name :params (...) body)
                                            -> registers self.effects[name]
                                               (and tries a JAX compile when enabled)
      (defmacro name (params...) body)      -> registers self.macros[name]
      (effect alias :path p | :name n) and
      (include :path p | :name n)           -> recursively load another file;
                                               :path is relative to this file's
                                               directory, :name goes through
                                               self._resolve_name
      (scan name trigger :init v :step e)   -> registers self.scans[name]
    Any other form is ignored.

    NOTE(review): nested effect/include loads recurse with no visited-set, so
    a file that (indirectly) includes itself recurses until RecursionError.

    Raises:
        FileNotFoundError: if effect_path does not exist.
        RuntimeError: if a :name reference cannot be resolved.
    """
    if not effect_path.exists():
        raise FileNotFoundError(f"Effect/include file not found: {effect_path}")
    text = effect_path.read_text()
    ast = parse_all(text)
    for form in ast:
        # Only non-empty lists headed by a symbol are commands.
        if not isinstance(form, list) or not form:
            continue
        if not isinstance(form[0], Symbol):
            continue
        cmd = form[0].name
        if cmd == 'require-primitives':
            lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
            self._load_primitives(lib_name)
        elif cmd == 'define-effect':
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = {}
            body = None
            i = 2
            # ":params (...)" declares parameters; any other keyword is skipped
            # together with its value; a non-keyword item becomes the body
            # (if there are several, the last one wins).
            while i < len(form):
                if isinstance(form[i], Keyword):
                    if form[i].name == 'params' and i + 1 < len(form):
                        # Each param spec is (pname :key val ...); values are
                        # stored as raw AST (not evaluated), default 0.
                        for pdef in form[i + 1]:
                            if isinstance(pdef, list) and pdef:
                                pname = pdef[0].name if isinstance(pdef[0], Symbol) else str(pdef[0])
                                pinfo = {'default': 0}
                                j = 1
                                while j < len(pdef):
                                    if isinstance(pdef[j], Keyword) and j + 1 < len(pdef):
                                        pinfo[pdef[j].name] = pdef[j + 1]
                                        j += 2
                                    else:
                                        j += 1
                                params[pname] = pinfo
                    i += 2
                else:
                    body = form[i]
                    i += 1
            self.effects[name] = {'params': params, 'body': body}
            self.jax_effect_paths[name] = effect_path  # Track source for JAX compilation
            print(f"Effect: {name}", file=sys.stderr)
            # Try to compile with JAX if enabled
            if self.use_jax and _JAX_AVAILABLE:
                self._compile_jax_effect(name, effect_path)
        elif cmd == 'defmacro':
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
            body = form[3]
            self.macros[name] = {'params': params, 'body': body}
        elif cmd == 'effect':
            # Handle (effect name :path "...") or (effect name :name "...") in included files
            # form[1] (the alias) is never read here; the loaded file's
            # define-effect forms decide the registered names.
            i = 2
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        path = str(form[i + 1]).strip('"')
                        full = (effect_path.parent / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1
        elif cmd == 'include':
            # Handle (include :path "...") or (include :name "...") in included files
            i = 1
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        path = str(form[i + 1]).strip('"')
                        full = (effect_path.parent / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1
        elif cmd == 'scan':
            # Handle scans from included files
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            trigger_expr = form[2]
            init_val, step_expr = {}, None
            i = 3
            while i < len(form):
                if isinstance(form[i], Keyword):
                    # :init is evaluated now against globals; :step is kept as AST.
                    if form[i].name == 'init' and i + 1 < len(form):
                        init_val = self._eval(form[i + 1], self.globals)
                    elif form[i].name == 'step' and i + 1 < len(form):
                        step_expr = form[i + 1]
                    i += 2
                else:
                    i += 1
            # Live state is a shallow copy of a dict init; non-dict inits are
            # wrapped as {'acc': value}. The original init is kept alongside.
            self.scans[name] = {
                'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val},
                'init': init_val,
                'step': step_expr,
                'trigger': trigger_expr,
            }
            print(f"Scan: {name}", file=sys.stderr)
def _compile_jax_effect(self, name: str, effect_path: Path):
    """Try to JAX-compile the effect defined in *effect_path* and cache it under *name*.

    No-op if JAX is unavailable or *name* is already compiled. Compilation
    failures are non-fatal: the effect keeps using the interpreter, and a short
    note is logged unless the error is an 'Unknown operation'.
    """
    if name in self.jax_effects or not _JAX_AVAILABLE:
        return
    try:
        compiled = _jax_compiler['compile_effect_file'](str(effect_path))
        self.jax_effects[name] = compiled
        print(f" [JAX compiled: {name}]", file=sys.stderr)
    except Exception as exc:
        reason = str(exc)
        if 'Unknown operation' in reason:
            return  # unsupported effect: fall back to the interpreter quietly
        print(f" [JAX skip {name}: {reason[:50]}]", file=sys.stderr)
def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]:
"""Apply a JAX-compiled effect to a frame."""
if name not in self.jax_effects:
return None
try:
jax_fn = self.jax_effects[name]
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
# JAX handles numpy and JAX arrays natively, no conversion needed
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get() # CuPy array -> numpy
# Get seed from config for deterministic random
seed = self.config.get('seed', 42)
# Call JAX function with parameters
# Return JAX array directly - don't block or convert per-effect
# Conversion to numpy happens once at frame write time
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
except Exception as e:
# Fall back to interpreter on error
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
return None
def _is_jax_effect_expr(self, expr) -> bool:
"""Check if an expression is a JAX-compiled effect call."""
if not isinstance(expr, list) or not expr:
return False
head = expr[0]
if not isinstance(head, Symbol):
return False
return head.name in self.jax_effects
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
"""
Extract a chain of JAX effects from an expression.
Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
effect_names and params_list are in execution order (innermost first).
For (effect1 (effect2 frame :p1 v1) :p2 v2):
Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
"""
if not self._is_jax_effect_expr(expr):
return None
chain = []
params_list = []
current = expr
while self._is_jax_effect_expr(current):
head = current[0]
effect_name = head.name
args = current[1:]
# Extract params for this effect
effect = self.effects[effect_name]
effect_params = {}
for pname, pdef in effect['params'].items():
effect_params[pname] = pdef.get('default', 0)
# Find the frame argument (first positional) and other params
frame_arg = None
i = 0
while i < len(args):
if isinstance(args[i], Keyword):
pname = args[i].name
if pname in effect['params'] and i + 1 < len(args):
effect_params[pname] = self._eval(args[i + 1], env)
i += 2
else:
if frame_arg is None:
frame_arg = args[i] # First positional is frame
i += 1
chain.append(effect_name)
params_list.append(effect_params)
if frame_arg is None:
return None # No frame argument found
# Check if frame_arg is another effect call
if self._is_jax_effect_expr(frame_arg):
current = frame_arg
else:
# End of chain - frame_arg is the base frame
# Reverse to get innermost-first execution order
chain.reverse()
params_list.reverse()
return (chain, params_list, frame_arg)
return None
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
"""Generate a cache key for an effect chain.
Includes static param values in the key since they affect compilation.
"""
parts = []
for name, params in zip(effect_names, params_list):
param_parts = []
for pname in sorted(params.keys()):
pval = params[pname]
# Include static values in key (strings, bools)
if isinstance(pval, (str, bool)):
param_parts.append(f"{pname}={pval}")
else:
param_parts.append(pname)
parts.append(f"{name}:{','.join(param_parts)}")
return '|'.join(parts)
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Fuse a chain of JAX effects into one JIT-compiled function.

    Args:
        effect_names: Effect names, innermost first.
        params_list: Param dicts matching effect_names; used only to find
            parameters whose current values are str/bool, which are marked static.

    Returns:
        ``jit(fn)`` where ``fn(frame, t, frame_num, seed, **kwargs) -> frame``
        takes each effect's params as ``_p{i}_{name}`` keyword arguments
        (``seed`` is always static), or None if JAX is unavailable or building
        the function raises.
    """
    if not _JAX_AVAILABLE:
        return None
    try:
        import jax

        stages = [self.jax_effects[name] for name in effect_names]
        stage_param_names = []
        static_names = ['seed']  # seed is always static
        for idx, (_, params) in enumerate(zip(effect_names, params_list)):
            stage_param_names.append(list(params.keys()))
            static_names.extend(
                f"_p{idx}_{pname}"
                for pname, pval in params.items()
                if isinstance(pval, (str, bool))
            )

        def fused_fn(frame, t, frame_num, seed, **kwargs):
            out = frame
            for idx, (stage, names) in enumerate(zip(stages, stage_param_names)):
                # Pick this stage's prefixed kwargs back out by their plain names.
                stage_kwargs = {}
                for pname in names:
                    slot = f"_p{idx}_{pname}"
                    if slot in kwargs:
                        stage_kwargs[pname] = kwargs[slot]
                out = stage(out, t=t, frame_num=frame_num, seed=seed, **stage_kwargs)
            return out

        return jax.jit(fused_fn, static_argnames=static_names)
    except Exception as e:
        print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
        return None
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
"""Apply a chain of effects, using fused compilation if available."""
chain_key = self._get_chain_key(effect_names, params_list)
# Try to get or compile fused chain
if chain_key not in self.jax_fused_chains:
fused_fn = self._compile_effect_chain(effect_names, params_list)
self.jax_fused_chains[chain_key] = fused_fn
if fused_fn:
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
fused_fn = self.jax_fused_chains.get(chain_key)
if fused_fn is not None:
# Build kwargs with prefixed param names
kwargs = {}
for i, params in enumerate(params_list):
for pname, pval in params.items():
kwargs[f"_p{i}_{pname}"] = pval
seed = self.config.get('seed', 42)
# Handle GPU frames
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get()
try:
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
except Exception as e:
print(f"Fused chain error: {e}", file=sys.stderr)
# Fall back to sequential application
result = frame
for name, params in zip(effect_names, params_list):
result = self._apply_jax_effect(name, result, params, t, frame_num)
if result is None:
return None
return result
def _force_deferred(self, deferred: DeferredEffectChain):
    """Run a deferred chain and return the resulting frame.

    An empty chain short-circuits to its base frame. Otherwise the whole chain
    goes through ``_apply_effect_chain`` (fused when possible), which may
    return None if JAX execution fails.
    """
    if deferred.effects:
        return self._apply_effect_chain(
            deferred.effects,
            deferred.params_list,
            deferred.base_frame,
            deferred.t,
            deferred.frame_num,
        )
    return deferred.base_frame
def _maybe_force(self, value):
    """Return *value*, executing it first if it is a pending DeferredEffectChain."""
    if not isinstance(value, DeferredEffectChain):
        return value
    return self._force_deferred(value)
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Compile the per-frame function used for batched execution of an effect chain.

    Despite the name, the returned callable is NOT vectorised itself: it is a
    jitted single-frame function ``(frame, t, frame_num, seed, **kwargs) -> frame``
    whose kwargs are each effect's params named ``_p{i}_{param}`` (``i`` is the
    effect's index in the chain). ``_apply_batched_chain`` wraps it in
    ``jax.vmap`` at call time.

    Args:
        effect_names: List of effect names in order [innermost, ..., outermost]
        params_list: List of param dicts; params whose values are str/bool are
            marked static for jit (as is ``seed``).

    Returns:
        The jitted single-frame function, or None if JAX is unavailable or
        building it raises (e.g. an effect that has no JAX version).
    """
    if not _JAX_AVAILABLE:
        return None
    try:
        import jax
        # Get the individual JAX functions
        jax_fns = [self.jax_effects[name] for name in effect_names]
        # Pre-extract param names and the static (str/bool) params
        effect_param_names = []
        static_params = ['seed']
        for i, (_name, params) in enumerate(zip(effect_names, params_list)):
            effect_param_names.append(list(params.keys()))
            for pname, pval in params.items():
                if isinstance(pval, (str, bool)):
                    static_params.append(f"_p{i}_{pname}")

        # Single-frame function (vmapped by the caller)
        def single_frame_fn(frame, t, frame_num, seed, **kwargs):
            result = frame
            for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
                effect_kwargs = {}
                for pname in param_names:
                    key = f"_p{i}_{pname}"
                    if key in kwargs:
                        effect_kwargs[pname] = kwargs[key]
                result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
            return result

        # Return unbatched function - we'll vmap at call time with proper in_axes
        return jax.jit(single_frame_fn, static_argnames=static_params)
    except Exception as e:
        print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
        return None
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
frames: list, ts: list, frame_nums: list) -> Optional[list]:
"""
Apply an effect chain to a batch of frames using vmap.
Args:
effect_names: List of effect names
params_list_batch: List of params_list for each frame in batch
frames: List of input frames
ts: List of time values
frame_nums: List of frame numbers
Returns:
List of output frames, or None on failure
"""
if not frames:
return []
# Use first frame's params for chain key (assume same structure)
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
batch_key = f"batch:{chain_key}"
# Compile batched version if needed
if batch_key not in self.jax_batched_chains:
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
self.jax_batched_chains[batch_key] = batched_fn
if batched_fn:
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
batched_fn = self.jax_batched_chains.get(batch_key)
if batched_fn is not None:
try:
import jax
import jax.numpy as jnp
# Stack frames into batch array
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
ts_array = jnp.array(ts)
frame_nums_array = jnp.array(frame_nums)
# Build kwargs - all numeric params as arrays for vmap
kwargs = {}
static_kwargs = {} # Non-vmapped (strings, bools)
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
for j, pname in enumerate(params_list_batch[0][i].keys()):
key = f"_p{i}_{pname}"
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
first = values[0]
if isinstance(first, (str, bool)):
# Static params - not vmapped
static_kwargs[key] = first
elif isinstance(first, (int, float)):
# Always batch numeric params for simplicity
kwargs[key] = jnp.array(values)
elif hasattr(first, 'shape'):
kwargs[key] = jnp.stack(values)
else:
kwargs[key] = jnp.array(values)
seed = self.config.get('seed', 42)
# Create wrapper that unpacks the params dict
def single_call(frame, t, frame_num, params_dict):
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
# vmap over frame, t, frame_num, and the params dict (as pytree)
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
# Stack kwargs into a dict of arrays (pytree with matching structure)
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
# Unstack results
return [results[i] for i in range(len(frames))]
except Exception as e:
print(f"Batched chain error: {e}", file=sys.stderr)
# Fall back to sequential
return None
def _init(self):
    """Initialize from sexp - load primitives, effects, defs, scans.

    Order of work:
      1. Seed the core RNG from config['seed'] (skipped silently if
         sexp_effects.primitive_libs.core cannot be imported).
      2. Load sources_config / audio_config first. A recipe `def` whose name
         is already in globals, and a recipe `audio-playback` when one is
         already set, are skipped, so the config files take precedence.
      3. Walk the recipe's top-level forms: require-primitives, effect,
         include, audio-playback, def, defmacro, scan, frame. Relative
         effect/include paths resolve against self.sexp_dir.

    Raises:
        RuntimeError: if an effect/include `:name` reference cannot be resolved.
    """
    # Set random seed for deterministic output
    seed = self.config.get('seed', 42)
    try:
        from sexp_effects.primitive_libs.core import set_random_seed
        set_random_seed(seed)
    except ImportError:
        pass
    # Load external config files first (they can override recipe definitions)
    if self.sources_config:
        self._load_config_file(self.sources_config)
    if self.audio_config:
        self._load_config_file(self.audio_config)
    for form in self.ast:
        # Only non-empty lists headed by a symbol are commands.
        if not isinstance(form, list) or not form:
            continue
        if not isinstance(form[0], Symbol):
            continue
        cmd = form[0].name
        if cmd == 'require-primitives':
            lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
            self._load_primitives(lib_name)
        elif cmd == 'effect':
            # NOTE(review): `name` (the alias in form[1]) is computed but not
            # used below; effects register under their define-effect names.
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            i = 2
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        path = str(form[i + 1]).strip('"')
                        full = (self.sexp_dir / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        # Resolve friendly name to path
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1
        elif cmd == 'include':
            i = 1
            while i < len(form):
                if isinstance(form[i], Keyword):
                    kw = form[i].name
                    if kw == 'path':
                        path = str(form[i + 1]).strip('"')
                        full = (self.sexp_dir / path).resolve()
                        self._load_effect(full)
                        i += 2
                    elif kw == 'name':
                        # Resolve friendly name to path
                        fname = str(form[i + 1]).strip('"')
                        resolved = self._resolve_name(fname)
                        if resolved:
                            self._load_effect(resolved)
                        else:
                            raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in")
                        i += 2
                    else:
                        i += 1
                else:
                    i += 1
        elif cmd == 'audio-playback':
            # (audio-playback "path") - set audio file for playback sync
            # Skip if already set by config file
            if self.audio_playback is None:
                path = str(form[1]).strip('"')
                # Try to resolve as friendly name first
                resolved = self._resolve_name(path)
                if resolved:
                    self.audio_playback = str(resolved)
                else:
                    # Fall back to relative path
                    self.audio_playback = str((self.sexp_dir / path).resolve())
                print(f"Audio playback: {self.audio_playback}", file=sys.stderr)
        elif cmd == 'def':
            # (def name expr) - evaluate and store in globals
            # Skip if already defined by config file
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            if name in self.globals:
                print(f"Def: {name} (from config, skipped)", file=sys.stderr)
                continue
            value = self._eval(form[2], self.globals)
            self.globals[name] = value
            print(f"Def: {name}", file=sys.stderr)
        elif cmd == 'defmacro':
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
            body = form[3]
            self.macros[name] = {'params': params, 'body': body}
        elif cmd == 'scan':
            name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
            trigger_expr = form[2]
            init_val, step_expr = {}, None
            i = 3
            while i < len(form):
                if isinstance(form[i], Keyword):
                    # :init is evaluated now against globals; :step is kept as AST.
                    if form[i].name == 'init' and i + 1 < len(form):
                        init_val = self._eval(form[i + 1], self.globals)
                    elif form[i].name == 'step' and i + 1 < len(form):
                        step_expr = form[i + 1]
                    i += 2
                else:
                    i += 1
            # Live state is a shallow copy of a dict init; non-dict inits are
            # wrapped as {'acc': value}. The original init is kept alongside.
            self.scans[name] = {
                'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val},
                'init': init_val,
                'step': step_expr,
                'trigger': trigger_expr,
            }
            print(f"Scan: {name}", file=sys.stderr)
        elif cmd == 'frame':
            # The per-frame expression is stored unevaluated.
            self.frame_pipeline = form[1] if len(form) > 1 else None
def _eval(self, expr, env: dict) -> Any:
    """Evaluate an s-expression in ``env``.

    Atoms evaluate to themselves, except:
    - Symbols resolve through the built-in constants (``pi``, ``true``,
      ``false``, ``nil``), then ``env``, then ``self.globals``, then scan
      state.
    - Keywords evaluate to their name string.

    A list headed by a keyword is a dict literal. A list headed by a
    symbol is a call, dispatched in this order:
    1. user closures (``env`` first, then globals);
    2. built-in special forms;
    3. effects;
    4. primitives;
    5. macros;
    6. the primitive whose name is ``op`` with ``-`` replaced by ``_``.

    Args:
        expr: Parsed expression or an already-evaluated value.
        env: Local variable bindings. Never mutated; nested scopes copy it.

    Returns:
        The evaluated value. With JAX enabled, effect calls may return a
        ``DeferredEffectChain`` instead of a concrete frame.

    Raises:
        NameError: A symbol is not bound anywhere.
        RuntimeError: A primitive raised, or the call head is unknown.
    """
    # --- Atoms ---
    if isinstance(expr, (int, float)):
        return expr
    if isinstance(expr, str):
        return expr
    if isinstance(expr, bool):
        return expr
    if isinstance(expr, Symbol):
        name = expr.name
        # Built-in constants
        if name == 'pi':
            return math.pi
        if name == 'true':
            return True
        if name == 'false':
            return False
        if name == 'nil':
            return None
        # Resolution order: local env, then globals, then scan state.
        if name in env:
            return env[name]
        if name in self.globals:
            return self.globals[name]
        if name in self.scans:
            return self.scans[name]['state']
        raise NameError(f"Undefined variable: {name}")
    if isinstance(expr, Keyword):
        return expr.name
    # Dicts produced by the parser: evaluate values, keep keys as-is.
    if isinstance(expr, dict):
        return {k: self._eval(v, env) for k, v in expr.items()}
    # Anything that is not a non-empty list (None, [], arrays, already
    # evaluated objects) evaluates to itself.
    if not isinstance(expr, list) or not expr:
        return expr

    # Dict literal {:key val ...}; a trailing key without a value maps to None.
    if isinstance(expr[0], Keyword):
        result = {}
        i = 0
        while i < len(expr):
            if isinstance(expr[i], Keyword):
                result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None
                i += 2
            else:
                i += 1
        return result

    head = expr[0]
    # A list not headed by a symbol is plain data: evaluate element-wise.
    if not isinstance(head, Symbol):
        return [self._eval(e, env) for e in expr]
    op = head.name
    args = expr[1:]

    # User closures shadow every built-in: local env first, then globals.
    if op in env:
        val = env[op]
        if isinstance(val, dict) and val.get('_type') == 'closure':
            return self._eval_call_closure(val, args, env)
    if op in self.globals:
        val = self.globals[op]
        if isinstance(val, dict) and val.get('_type') == 'closure':
            return self._eval_call_closure(val, args, env)

    # Threading macro: (-> x (f a) g) == (g (f x a))
    if op == '->':
        result = self._eval(args[0], env)
        for form in args[1:]:
            if isinstance(form, list) and form:
                new_form = [form[0], result] + form[1:]
                result = self._eval(new_form, env)
            else:
                result = self._eval([form, result], env)
        return result

    # === Binding ===
    if op == 'bind':
        scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0])
        if scan_name in self.scans:
            state = self.scans[scan_name]['state']
            if len(args) > 1:
                key = args[1].name if isinstance(args[1], Keyword) else str(args[1])
                return state.get(key, 0)
            return state
        return 0

    # === Arithmetic ===
    if op == '+':
        return sum(self._eval(a, env) for a in args)
    if op == '-':
        vals = [self._eval(a, env) for a in args]
        return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0]
    if op == '*':
        result = 1
        for a in args:
            result *= self._eval(a, env)
        return result
    if op == '/':
        # Division by zero deliberately yields 0 instead of raising.
        vals = [self._eval(a, env) for a in args]
        return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0
    if op == 'mod':
        vals = [self._eval(a, env) for a in args]
        return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0

    # === Comparison ===
    if op == '<':
        return self._eval(args[0], env) < self._eval(args[1], env)
    if op == '>':
        return self._eval(args[0], env) > self._eval(args[1], env)
    if op == '=':
        return self._eval(args[0], env) == self._eval(args[1], env)
    if op == '<=':
        return self._eval(args[0], env) <= self._eval(args[1], env)
    if op == '>=':
        return self._eval(args[0], env) >= self._eval(args[1], env)
    if op == 'and':
        # Short-circuits; always returns a bool.
        for arg in args:
            if not self._eval(arg, env):
                return False
        return True
    if op == 'or':
        # Short-circuits; returns the first truthy value (or the last value).
        result = False
        for arg in args:
            result = self._eval(arg, env)
            if result:
                return result
        return result
    if op == 'not':
        return not self._eval(args[0], env)

    # === Logic ===
    if op == 'if':
        cond = self._eval(args[0], env)
        if cond:
            return self._eval(args[1], env)
        return self._eval(args[2], env) if len(args) > 2 else None
    if op == 'cond':
        # (cond pred1 expr1 pred2 expr2 ...)
        i = 0
        while i < len(args) - 1:
            pred = self._eval(args[i], env)
            if pred:
                return self._eval(args[i + 1], env)
            i += 2
        return None
    if op == 'lambda':
        params = args[0]
        body = args[1]
        param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params]
        return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)}
    if op == 'let' or op == 'let*':
        # Bindings are sequential (let* semantics) in both accepted shapes:
        # ((a 1) (b 2)) or (a 1 b 2).
        bindings = args[0]
        body = args[1]
        new_env = dict(env)
        if bindings and isinstance(bindings[0], list):
            for binding in bindings:
                if isinstance(binding, list) and len(binding) >= 2:
                    name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
                    val = self._eval(binding[1], new_env)
                    new_env[name] = val
        else:
            i = 0
            while i < len(bindings):
                name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
                val = self._eval(bindings[i + 1], new_env)
                new_env[name] = val
                i += 2
        return self._eval(body, new_env)

    # === Dict ===
    if op == 'dict':
        result = {}
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                key = args[i].name
                val = self._eval(args[i + 1], env) if i + 1 < len(args) else None
                result[key] = val
                i += 2
            else:
                i += 1
        return result
    if op == 'get':
        obj = self._eval(args[0], env)
        key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env)
        if isinstance(obj, dict):
            return obj.get(key, 0)
        return 0

    # === List ===
    if op == 'list':
        return [self._eval(a, env) for a in args]
    if op == 'quote':
        return args[0] if args else None
    if op == 'nth':
        lst = self._eval(args[0], env)
        idx = int(self._eval(args[1], env))
        if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst):
            return lst[idx]
        return None
    if op == 'len':
        val = self._eval(args[0], env)
        return len(val) if hasattr(val, '__len__') else 0
    if op == 'map':
        seq = self._eval(args[0], env)
        fn = self._eval(args[1], env)
        if not isinstance(seq, (list, tuple)):
            return []
        # Handle closure (lambda from sexp); only its first param is bound.
        if isinstance(fn, dict) and fn.get('_type') == 'closure':
            results = []
            for item in seq:
                closure_env = dict(fn['env'])
                if fn['params']:
                    closure_env[fn['params'][0]] = item
                results.append(self._eval(fn['body'], closure_env))
            return results
        # Handle Python callable
        if callable(fn):
            return [fn(item) for item in seq]
        return []

    # === Effects ===
    if op in self.effects:
        # Try to detect and fuse effect chains for JAX acceleration
        if self.use_jax and op in self.jax_effects:
            chain_info = self._extract_effect_chain(expr, env)
            if chain_info is not None:
                effect_names, params_list, base_frame_expr = chain_info
                # Only use chain if we have 2+ effects (worth fusing)
                if len(effect_names) >= 2:
                    base_frame = self._eval(base_frame_expr, env)
                    if base_frame is not None and hasattr(base_frame, 'shape'):
                        t = env.get('t', 0.0)
                        frame_num = env.get('frame-num', 0)
                        result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
                        if result is not None:
                            return result
                # Fall through if chain application fails
        effect = self.effects[op]
        effect_env = dict(env)
        param_names = list(effect['params'].keys())
        for pname, pdef in effect['params'].items():
            effect_env[pname] = pdef.get('default', 0)
        # First positional arg is the frame; later positionals fill
        # declared params in order. Unknown keywords are ignored.
        positional_idx = 0
        frame_val = None
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                pname = args[i].name
                if pname in effect['params'] and i + 1 < len(args):
                    effect_env[pname] = self._eval(args[i + 1], env)
                i += 2
            else:
                val = self._eval(args[i], env)
                if positional_idx == 0:
                    effect_env['frame'] = val
                    frame_val = val
                elif positional_idx - 1 < len(param_names):
                    effect_env[param_names[positional_idx - 1]] = val
                positional_idx += 1
                i += 1
        # Try JAX-accelerated execution with deferred chaining
        if self.use_jax and op in self.jax_effects and frame_val is not None:
            # Build params dict for JAX (exclude 'frame')
            jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
                          if k != 'frame' and k in effect['params']}
            t = env.get('t', 0.0)
            frame_num = env.get('frame-num', 0)
            # Input is already a deferred chain: extend it
            if isinstance(frame_val, DeferredEffectChain):
                return frame_val.extend(op, jax_params)
            # Input is a concrete frame: start a new deferred chain
            if hasattr(frame_val, 'shape'):
                return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
            # Fall through to interpreter if not a valid frame
        # The interpreted effect body needs a concrete frame.
        if isinstance(frame_val, DeferredEffectChain):
            frame_val = self._force_deferred(frame_val)
            effect_env['frame'] = frame_val
        return self._eval(effect['body'], effect_env)

    # === Primitives ===
    if op in self.primitives:
        return self._eval_call_primitive(op, self.primitives[op], args, env)

    # === Macros (function-like: args evaluated before binding) ===
    if op in self.macros:
        macro = self.macros[op]
        macro_env = dict(env)
        for i, pname in enumerate(macro['params']):
            # Evaluate args in calling environment before binding
            macro_env[pname] = self._eval(args[i], env) if i < len(args) else None
        return self._eval(macro['body'], macro_env)

    # Underscore variant lookup: (foo-bar :x-y 1) -> foo_bar(x_y=1).
    # Arguments go through the same forcing/conversion as direct primitive calls.
    prim_name = op.replace('-', '_')
    if prim_name in self.primitives:
        return self._eval_call_primitive(op, self.primitives[prim_name], args, env,
                                         normalize_kwargs=True)

    # Unknown function call - raise meaningful error
    raise RuntimeError(f"Unknown function or primitive: '{op}'. "
                       f"Available primitives: {sorted(list(self.primitives.keys())[:10])}... "
                       f"Available effects: {sorted(list(self.effects.keys())[:10])}... "
                       f"Available macros: {sorted(list(self.macros.keys())[:10])}...")


def _eval_call_closure(self, closure: dict, args: list, env: dict) -> Any:
    """Call a sexp closure, evaluating its arguments in the caller's ``env``.

    Parameters are bound positionally on a copy of the closure's captured
    environment. Missing arguments bind to None; extra arguments are
    neither bound nor evaluated.
    """
    closure_env = dict(closure['env'])
    for idx, pname in enumerate(closure['params']):
        closure_env[pname] = self._eval(args[idx], env) if idx < len(args) else None
    return self._eval(closure['body'], closure_env)


def _eval_call_primitive(self, op: str, prim_func, args: list, env: dict,
                         normalize_kwargs: bool = False) -> Any:
    """Evaluate call arguments and invoke a Python primitive.

    Every positional and ``:keyword`` argument is evaluated in ``env``.
    Deferred effect chains are forced, then the value is passed through
    ``self._maybe_to_numpy``; GPU arrays are kept for primitives whose name
    contains "gpu". A trailing keyword with no value is passed as None.

    Args:
        op: Name as written in the sexp; used in error messages.
        prim_func: The callable to invoke.
        args: Unevaluated argument expressions.
        env: Environment in which arguments are evaluated.
        normalize_kwargs: If True, ``-`` in keyword names becomes ``_`` so
            they match Python parameter names.

    Raises:
        RuntimeError: The primitive raised. The original exception is
            chained as ``__cause__`` and recorded via ``self._record_error``.
    """
    # 'gpu' contains neither '-' nor '_', so testing op covers the
    # underscore alias too (and subsumes a 'gpu:' prefix check).
    is_gpu_prim = 'gpu' in op.lower()
    positional = []
    kwargs = {}
    i = 0
    while i < len(args):
        node = args[i]
        if isinstance(node, Keyword):
            key = node.name.replace('-', '_') if normalize_kwargs else node.name
            raw = self._eval(args[i + 1], env) if i + 1 < len(args) else None
            kwargs[key] = self._maybe_to_numpy(self._maybe_force(raw), for_gpu_primitive=is_gpu_prim)
            i += 2
        else:
            raw = self._eval(node, env)
            positional.append(self._maybe_to_numpy(self._maybe_force(raw), for_gpu_primitive=is_gpu_prim))
            i += 1
    try:
        return prim_func(*positional, **kwargs)
    except Exception as e:
        self._record_error(f"Primitive {op} error: {e}")
        raise RuntimeError(f"Primitive {op} failed: {e}") from e
def _step_scans(self, ctx: Context, env: dict):
    """Advance every scan whose trigger evaluates truthy for this frame.

    For each scan, the trigger expression is evaluated in ``env``. If it is
    truthy, the step expression is evaluated in an environment made of the
    scan's current state overlaid by ``env``; frame bindings win on name
    clashes. A dict result replaces the state wholesale. Any other result
    is stored as ``{'acc': value}``.

    Scans declared without ``:step`` are skipped and their trigger is not
    evaluated. Evaluating the missing step would yield None and overwrite
    the state with ``{'acc': None}``.

    Args:
        ctx: Frame context. Not read here; frame values arrive via ``env``.
        env: Frame environment (e.g. ``t``, ``frame-num``, ``ctx``).
    """
    for scan in self.scans.values():
        step_expr = scan['step']
        if step_expr is None:
            continue
        if not self._eval(scan['trigger'], env):
            continue
        step_env = dict(scan['state'])
        step_env.update(env)
        new_state = self._eval(step_expr, step_env)
        scan['state'] = new_state if isinstance(new_state, dict) else {'acc': new_state}
def _restore_checkpoint(self, checkpoint: dict):
"""Restore scan states from a checkpoint.
Called when resuming a render from a previous checkpoint.
Args:
checkpoint: Dict with 'scans' key containing {scan_name: state_dict}
"""
scans_state = checkpoint.get('scans', {})
for name, state in scans_state.items():
if name in self.scans:
self.scans[name]['state'] = dict(state)
print(f"Restored scan '{name}' state from checkpoint", file=sys.stderr)
def _get_checkpoint_state(self) -> dict:
"""Get current scan states for checkpointing.
Returns:
Dict mapping scan names to their current state dicts
"""
return {name: dict(scan['state']) for name, scan in self.scans.items()}
def run(self, duration: float = None, output: str = "pipe", resume_from: dict = None):
    """Run the streaming pipeline.

    For each frame, scans are stepped, ``self.frame_pipeline`` is evaluated,
    and the result is written to the selected output. Results that are
    ``DeferredEffectChain`` objects are queued and executed in groups of
    ``self.jax_batch_size``: batched when every queued chain has the same
    effects, sequentially otherwise.

    When ``self.on_checkpoint`` is set, a checkpoint is emitted every
    ``self._frames_per_segment`` frames. Queued frames are flushed first,
    so every frame up to and including the checkpoint's ``frame_num`` has
    been handed to the output before the checkpoint is reported. A resume
    from ``frame_num + 1`` therefore cannot skip unwritten frames.

    Args:
        duration: Duration in seconds. When None, the ``duration`` attribute
            of the first global value that has one is used (e.g. loaded
            audio); otherwise 60 seconds.
        output: Output mode ("pipe", "preview", path/hls, path/ipfs-hls, or file path)
        resume_from: Checkpoint dict to resume from, with keys:
            - frame_num: Last completed frame
            - t: Time value for checkpoint frame
            - scans: Dict of scan states to restore
            - segment_cids: Dict of quality -> {seg_num: cid} for output resume

    Raises:
        FileNotFoundError: The audio playback file does not exist and cannot
            be resolved by name.
    """
    from collections import deque

    # Import output classes - handle both package and direct execution
    try:
        from .output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
        from .gpu_output import GPUHLSOutput, check_gpu_encode_available
        from .multi_res_output import MultiResolutionHLSOutput
    except ImportError:
        from output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
        try:
            from gpu_output import GPUHLSOutput, check_gpu_encode_available
        except ImportError:
            GPUHLSOutput = None
            check_gpu_encode_available = lambda: False
        try:
            from multi_res_output import MultiResolutionHLSOutput
        except ImportError:
            MultiResolutionHLSOutput = None

    self._init()

    # Restore checkpoint state if resuming
    if resume_from:
        self._restore_checkpoint(resume_from)
        print(f"Resuming from frame {resume_from.get('frame_num', 0)}", file=sys.stderr)

    if not self.frame_pipeline:
        print("Error: no (frame ...) pipeline defined", file=sys.stderr)
        return

    w = self.config.get('width', 720)
    h = self.config.get('height', 720)
    fps = self.config.get('fps', 30)

    if duration is None:
        # Try to get duration from audio if available
        for val in self.globals.values():
            if hasattr(val, 'duration'):
                duration = val.duration
                print(f"Using audio duration: {duration:.1f}s", file=sys.stderr)
                break
        else:
            duration = 60.0

    n_frames = int(duration * fps)
    frame_time = 1.0 / fps
    print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr)

    # Create context
    ctx = Context(fps=fps)

    # Output (with optional audio sync)
    # Resolve audio path lazily here if it wasn't resolved during parsing
    audio = self.audio_playback
    if audio and not Path(audio).exists():
        # Try to resolve as friendly name (may have failed during parsing)
        audio_name = Path(audio).name  # Get just the name part
        resolved = self._resolve_name(audio_name)
        if resolved and resolved.exists():
            audio = str(resolved)
            print(f"Lazy resolved audio: {audio}", file=sys.stderr)
        else:
            raise FileNotFoundError(f"Audio file not found: {audio}")

    if output == "pipe":
        out = PipeOutput(size=(w, h), fps=fps, audio_source=audio)
    elif output == "preview":
        out = DisplayOutput(size=(w, h), fps=fps, audio_source=audio)
    elif output.endswith("/hls"):
        # HLS output - output is a directory path ending in /hls
        hls_dir = output[:-4]  # Remove /hls suffix
        out = HLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio)
    elif output.endswith("/ipfs-hls"):
        # IPFS HLS output - multi-resolution adaptive streaming
        hls_dir = output[:-9]  # Remove /ipfs-hls suffix
        # Uses the module-level `os` import.
        ipfs_gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs")
        # Build resume state for output if resuming
        output_resume = None
        if resume_from and resume_from.get('segment_cids'):
            output_resume = {'segment_cids': resume_from['segment_cids']}
        # Use multi-resolution output (renders original + 720p + 360p)
        if MultiResolutionHLSOutput is not None:
            print(f"[StreamInterpreter] Using multi-resolution HLS output ({w}x{h} + 720p + 360p)", file=sys.stderr)
            out = MultiResolutionHLSOutput(
                hls_dir,
                source_size=(w, h),
                fps=fps,
                ipfs_gateway=ipfs_gateway,
                on_playlist_update=self.on_playlist_update,
                audio_source=audio,
                resume_from=output_resume,
            )
        # Fallback to GPU single-resolution if multi-res not available
        elif GPUHLSOutput is not None and check_gpu_encode_available():
            print(f"[StreamInterpreter] Using GPU zero-copy encoding (single resolution)", file=sys.stderr)
            out = GPUHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway,
                               on_playlist_update=self.on_playlist_update)
        else:
            out = IPFSHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway,
                                on_playlist_update=self.on_playlist_update)
    else:
        out = FileOutput(output, size=(w, h), fps=fps, audio_source=audio)

    # Calculate frames per segment based on fps and segment duration (4 seconds default).
    # Clamp to >= 1 so the checkpoint modulo below can never divide by zero.
    segment_duration = 4.0
    self._frames_per_segment = max(1, int(fps * segment_duration))

    # Determine start frame (resume from checkpoint + 1, or 0)
    start_frame = 0
    if resume_from and resume_from.get('frame_num') is not None:
        start_frame = resume_from['frame_num'] + 1
        print(f"Starting from frame {start_frame} (checkpoint was at {resume_from['frame_num']})", file=sys.stderr)

    try:
        profile_interval = 30  # Profile every N frames
        # Only the most recent profile_interval samples are ever averaged, so
        # bounded deques keep memory flat on long renders.
        frame_times = deque(maxlen=profile_interval)
        scan_times = deque(maxlen=profile_interval)
        eval_times = deque(maxlen=profile_interval)
        write_times = deque(maxlen=profile_interval)

        # Batch accumulation for JAX
        batch_deferred = []  # Accumulated DeferredEffectChains
        batch_times = []  # Corresponding time values

        def mean_ms(samples):
            """Mean of ``samples`` (seconds) in milliseconds; 0 when empty."""
            return 1000 * sum(samples) / max(1, len(samples))

        def write_frame(frame, t):
            """Write one frame, syncing and copying JAX arrays to host first."""
            if hasattr(frame, 'block_until_ready'):
                frame.block_until_ready()
                frame = np.asarray(frame)
            out.write(frame, t)

        def flush_batch():
            """Execute accumulated deferred chains and write their frames."""
            nonlocal batch_deferred, batch_times
            if not batch_deferred:
                return
            # Chains can only be batched when they all apply the same effects
            first = batch_deferred[0]
            can_batch = (
                self.use_jax and
                len(batch_deferred) >= 2 and
                all(d.effects == first.effects for d in batch_deferred)
            )
            if can_batch:
                frames = [d.base_frame for d in batch_deferred]
                ts = [d.t for d in batch_deferred]
                frame_nums = [d.frame_num for d in batch_deferred]
                params_batch = [d.params_list for d in batch_deferred]
                results = self._apply_batched_chain(
                    first.effects, params_batch, frames, ts, frame_nums
                )
                if results is not None:
                    for result, t in zip(results, batch_times):
                        write_frame(result, t)
                    batch_deferred = []
                    batch_times = []
                    return
            # Fall back to sequential execution
            for deferred, t in zip(batch_deferred, batch_times):
                result = self._force_deferred(deferred)
                if result is not None and hasattr(result, 'shape'):
                    write_frame(result, t)
            batch_deferred = []
            batch_times = []

        for frame_num in range(start_frame, n_frames):
            if not out.is_open:
                break

            frame_start = time.time()
            ctx.t = frame_num * frame_time
            ctx.frame_num = frame_num

            # Build frame environment with context
            frame_env = {
                'ctx': {
                    't': ctx.t,
                    'frame-num': ctx.frame_num,
                    'fps': ctx.fps,
                },
                't': ctx.t,  # Also expose t directly for convenience
                'frame-num': ctx.frame_num,
            }

            # Step scans
            t0 = time.time()
            self._step_scans(ctx, frame_env)
            scan_times.append(time.time() - t0)

            # Evaluate pipeline
            t1 = time.time()
            result = self._eval(self.frame_pipeline, frame_env)
            eval_times.append(time.time() - t1)

            t2 = time.time()
            if result is not None:
                if isinstance(result, DeferredEffectChain):
                    # Accumulate for batching; flush when the batch is full
                    batch_deferred.append(result)
                    batch_times.append(ctx.t)
                    if len(batch_deferred) >= self.jax_batch_size:
                        flush_batch()
                else:
                    # Not deferred - flush pending frames first to keep order
                    flush_batch()
                    if hasattr(result, 'shape'):
                        write_frame(result, ctx.t)
            write_times.append(time.time() - t2)

            frame_times.append(time.time() - frame_start)

            # Checkpoint at segment boundaries (every _frames_per_segment frames)
            if frame_num > 0 and frame_num % self._frames_per_segment == 0:
                if self.on_checkpoint:
                    # Frames still queued for batching have not been written
                    # yet; flush them so the checkpoint's frame_num (and every
                    # frame before it) is really complete.
                    flush_batch()
                    try:
                        checkpoint = {
                            'frame_num': frame_num,
                            't': ctx.t,
                            'scans': self._get_checkpoint_state(),
                        }
                        self.on_checkpoint(checkpoint)
                    except Exception as e:
                        print(f"Warning: checkpoint callback failed: {e}", file=sys.stderr)

            # Progress with timing and profile breakdown
            if frame_num % profile_interval == 0 and frame_num > 0:
                pct = 100 * frame_num / n_frames
                avg_ms = mean_ms(frame_times)
                avg_scan = mean_ms(scan_times)
                avg_eval = mean_ms(eval_times)
                avg_write = mean_ms(write_times)
                target_ms = 1000 * frame_time
                print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms] scan={avg_scan:.0f}ms eval={avg_eval:.0f}ms write={avg_write:.0f}ms", end="", file=sys.stderr, flush=True)

                # Call progress callback if set (for Celery task state updates)
                if self.on_progress:
                    try:
                        self.on_progress(pct, frame_num, n_frames)
                    except Exception as e:
                        print(f"Warning: progress callback failed: {e}", file=sys.stderr)

        # Flush any remaining batch
        flush_batch()
    finally:
        out.close()
        # Store output for access to properties like playlist_cid
        self.output = out

    print("\nDone", file=sys.stderr)
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
               sources_config: str = None, audio_config: str = None, use_jax: bool = False):
    """Build a StreamInterpreter for ``sexp_path`` and run it.

    Overrides are applied before the run starts:
    - a truthy ``fps`` replaces ``config['fps']``;
    - truthy ``sources_config`` / ``audio_config`` strings are stored on the
      interpreter as ``Path`` objects.
    """
    interpreter = StreamInterpreter(sexp_path, use_jax=use_jax)
    if fps:
        interpreter.config['fps'] = fps
    path_overrides = (('sources_config', sources_config), ('audio_config', audio_config))
    for attr_name, raw_path in path_overrides:
        if raw_path:
            setattr(interpreter, attr_name, Path(raw_path))
    interpreter.run(duration=duration, output=output)
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to run_stream.
    import argparse

    cli = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)")
    cli.add_argument("sexp", help="Path to .sexp file")
    cli.add_argument("-d", "--duration", type=float, default=None)
    cli.add_argument("-o", "--output", default="pipe")
    cli.add_argument("--fps", type=float, default=None)
    cli.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
    cli.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
    cli.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
    opts = cli.parse_args()

    run_stream(
        opts.sexp,
        duration=opts.duration,
        output=opts.output,
        fps=opts.fps,
        sources_config=opts.sources_config,
        audio_config=opts.audio_config,
        use_jax=opts.jax,
    )