"""
Fully Generic Streaming S-expression Interpreter.

The interpreter knows NOTHING about video, audio, or any domain.
All domain logic comes from primitives loaded via (require-primitives ...).

Built-in forms:
- Control: if, cond, let, let*, lambda, -> (threading)
- Arithmetic: +, -, *, /, mod
- Comparison: <, >, =, <=, >=, and, or, not
- Data: dict, get, list, nth, len, map, quote
- Scan: bind (access scan state)

Everything else -- including helpers such as map-range and the rand /
rand-int / rand-range family -- is NOT built in; it must be supplied by a
loaded primitive library, an effect, or a macro.

Context (ctx) is passed explicitly to frame evaluation:
- ctx.t: current time
- ctx.frame-num: current frame number
- ctx.fps: frames per second
"""

import hashlib
import json
import math
import operator
import os
import sys
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

# Use local sexp_effects parser (supports namespaced symbols like math:sin)
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword  # noqa: E402

# JAX backend (optional - loaded on demand)
_JAX_AVAILABLE = False
_jax_compiler = None

# Sentinel returned by StreamInterpreter._eval_builtin when `op` is not a built-in.
_NOT_BUILTIN = object()

# Symbols that evaluate to fixed values (checked before any environment lookup).
_CONSTANTS = {'pi': math.pi, 'true': True, 'false': False, 'nil': None}

# Binary comparison built-ins; both operands are evaluated, left first.
_COMPARISONS = {
    '<': operator.lt,
    '>': operator.gt,
    '=': operator.eq,
    '<=': operator.le,
    '>=': operator.ge,
}


def _init_jax():
    """Lazily initialize the optional JAX compiler.

    Returns:
        True if ``streaming.sexp_to_jax`` could be imported, False otherwise.
    """
    global _JAX_AVAILABLE, _jax_compiler
    if _jax_compiler is not None:
        return _JAX_AVAILABLE
    try:
        from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
        _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
        _JAX_AVAILABLE = True
        print("JAX backend initialized", file=sys.stderr)
    except ImportError as e:
        print(f"JAX backend not available: {e}", file=sys.stderr)
        _JAX_AVAILABLE = False
    return _JAX_AVAILABLE


@dataclass
class Context:
    """Runtime context passed to frame evaluation."""
    t: float = 0.0
    frame_num: int = 0
    fps: float = 30.0


class StreamInterpreter:
    """
    Fully generic streaming sexp interpreter.

    No domain-specific knowledge - just evaluates expressions and calls primitives.
    """

    def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
        """Parse the recipe at `sexp_path` and set up empty registries.

        Args:
            sexp_path: Path to the recipe ``.sexp`` file (read immediately).
            actor_id: Identity passed to the naming service for friendly names.
            use_jax: Try to accelerate effects with the optional JAX backend.
        """
        self.sexp_path = Path(sexp_path)
        self.sexp_dir = self.sexp_path.parent
        self.actor_id = actor_id  # For friendly name resolution

        text = self.sexp_path.read_text()
        self.ast = parse(text)
        self.config = self._parse_config()

        # Global environment for def bindings
        self.globals: Dict[str, Any] = {}

        # Scans: name -> {'state', 'init', 'step', 'trigger'}
        self.scans: Dict[str, dict] = {}

        # Audio playback path (for syncing output)
        self.audio_playback: Optional[str] = None

        # Registries for external definitions
        self.primitives: Dict[str, Any] = {}
        self.effects: Dict[str, dict] = {}
        self.macros: Dict[str, dict] = {}

        # Primitive library names already imported; re-importing registers nothing
        # new (keys are never overwritten) and only re-executes module code.
        self._loaded_primitive_libs: set = set()

        # JAX backend for accelerated effect evaluation
        self.use_jax = use_jax
        self.jax_effects: Dict[str, Callable] = {}  # Cache of JAX-compiled effects
        self.jax_effect_paths: Dict[str, Path] = {}  # Track source paths for effects
        if use_jax:
            if _init_jax():
                print("JAX acceleration enabled", file=sys.stderr)
            else:
                print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
                self.use_jax = False

        # Try multiple locations for primitive_libs
        possible_paths = [
            self.sexp_dir.parent / "sexp_effects" / "primitive_libs",  # recipes/ subdir
            self.sexp_dir / "sexp_effects" / "primitive_libs",  # app root
            Path(__file__).parent.parent / "sexp_effects" / "primitive_libs",  # relative to interpreter
        ]
        self.primitive_lib_dir = next((p for p in possible_paths if p.exists()), possible_paths[0])

        self.frame_pipeline = None

        # External config files (set before run())
        self.sources_config: Optional[Path] = None
        self.audio_config: Optional[Path] = None

        # Error tracking
        self.errors: List[str] = []

        # Callback for live streaming (called when IPFS playlist is updated)
        self.on_playlist_update: Optional[Callable[..., Any]] = None

        # Callback for progress updates (called periodically during rendering)
        # Signature: on_progress(percent: float, frame_num: int, total_frames: int)
        self.on_progress: Optional[Callable[[float, int, int], Any]] = None

        # Callback for checkpoint saves (called at segment boundaries for resumability)
        # Signature: on_checkpoint(checkpoint: dict)
        # checkpoint contains: frame_num, t, scans
        self.on_checkpoint: Optional[Callable[[dict], Any]] = None

        # Frames per segment for checkpoint timing (default 4 seconds at 30fps = 120 frames)
        self._frames_per_segment: int = 120

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _form_name(node) -> str:
        """Return a Symbol's name, or ``str(node)`` for anything else."""
        return node.name if isinstance(node, Symbol) else str(node)

    @staticmethod
    def _is_command_form(form) -> bool:
        """True for a non-empty list whose head is a Symbol, e.g. ``(def x 1)``."""
        return isinstance(form, list) and bool(form) and isinstance(form[0], Symbol)

    def _resolve_name(self, name: str) -> Optional[Path]:
        """Resolve a friendly name to a file path using the naming service.

        Returns:
            The resolved path as a ``Path`` (callers call ``.exists()`` on it),
            or None if the name could not be resolved. Failures are logged,
            not raised.
        """
        try:
            # Import here to avoid circular imports
            from tasks.streaming import resolve_asset
            path = resolve_asset(name, self.actor_id)
            if path:
                return Path(path)
        except Exception as e:
            print(f"Warning: failed to resolve name '{name}': {e}", file=sys.stderr)
        return None

    def _record_error(self, msg: str):
        """Record an error that occurred during evaluation."""
        self.errors.append(msg)
        print(f"ERROR: {msg}", file=sys.stderr)

    def _maybe_to_numpy(self, val, for_gpu_primitive: bool = False):
        """Convert GPU frames/CuPy arrays to numpy for CPU primitives.

        If for_gpu_primitive=True, preserve GPU data (CuPy arrays stay on GPU).
        Plain dicts are always returned unchanged.
        """
        if val is None:
            return val

        # For GPU primitives, keep data on GPU
        if for_gpu_primitive:
            # Handle GPUFrame - return the GPU array when it lives on the GPU
            if hasattr(val, 'gpu') and hasattr(val, 'is_on_gpu'):
                return val.gpu if val.is_on_gpu else val.cpu
            # CuPy arrays (and everything else) pass through unchanged
            return val

        # dicts (dict literals, scan state) also have a callable .get(); calling it
        # with no arguments raises TypeError, and they are never device arrays.
        if isinstance(val, dict):
            return val
        # Handle GPUFrame objects (have .cpu property)
        if hasattr(val, 'cpu'):
            return val.cpu
        # Handle CuPy arrays (have .get() method)
        if hasattr(val, 'get') and callable(val.get):
            return val.get()
        return val

    # ------------------------------------------------------------------
    # Loading: config, primitives, effects, macros, scans
    # ------------------------------------------------------------------

    def _require_primitives_form(self, form):
        """Handle ``(require-primitives "lib")``."""
        lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"')
        self._load_primitives(lib_name)

    def _load_config_file(self, config_path):
        """Load a config file and process its definitions.

        Supports ``require-primitives``, ``def`` (always overwrites globals), and
        ``audio-playback`` (path resolved relative to the working directory).

        Raises:
            FileNotFoundError: if `config_path` does not exist.
        """
        config_path = Path(config_path)  # Accept str or Path
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        for form in parse_all(config_path.read_text()):
            if not self._is_command_form(form):
                continue
            cmd = form[0].name

            if cmd == 'require-primitives':
                self._require_primitives_form(form)

            elif cmd == 'def':
                name = self._form_name(form[1])
                self.globals[name] = self._eval(form[2], self.globals)
                print(f"Config: {name}", file=sys.stderr)

            elif cmd == 'audio-playback':
                # Path relative to working directory (consistent with other paths)
                path = str(form[1]).strip('"')
                self.audio_playback = str(Path(path).resolve())
                print(f"Audio playback: {self.audio_playback}", file=sys.stderr)

    def _parse_config(self) -> dict:
        """Parse config from (stream name :key val ...).

        Keyword/value pairs after the stream name are read until the first
        nested list; defaults are fps=30, seed=42, width=720, height=720.
        """
        config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720}
        ast = self.ast
        if not ast or not isinstance(ast[0], Symbol) or ast[0].name != 'stream':
            return config

        i = 2
        while i < len(ast):
            node = ast[i]
            if isinstance(node, Keyword):
                config[node.name] = ast[i + 1] if i + 1 < len(ast) else None
                i += 2
            elif isinstance(node, list):
                break
            else:
                i += 1
        return config

    def _load_primitives(self, lib_name: str):
        """Load primitives from a Python library file.

        Prefers GPU-accelerated versions (*_gpu.py) when available. Functions
        named ``prim_<name>`` and entries of a module-level ``PRIMITIVES`` dict
        are registered as ``<lib_name>:<name-with-dashes>``; existing keys are
        never overwritten (allows pre-registration of overrides).

        Raises:
            FileNotFoundError: if neither the GPU nor the CPU library exists in
                any searched directory (all searched paths are listed).
        """
        import importlib.util

        if lib_name in self._loaded_primitive_libs:
            return

        # Try GPU version first, then fall back to CPU version
        searched: List[Path] = []
        lib_path = None
        actual_lib_name = lib_name
        for try_lib in (f"{lib_name}_gpu", lib_name):
            candidates = [
                self.primitive_lib_dir / f"{try_lib}.py",
                self.sexp_dir / "primitive_libs" / f"{try_lib}.py",
                self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{try_lib}.py",
            ]
            searched.extend(candidates)
            lib_path = next((p for p in candidates if p.exists()), None)
            if lib_path is not None:
                actual_lib_name = try_lib
                break

        if lib_path is None:
            raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {searched}")

        spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Check if this is a GPU-accelerated module
        is_gpu = actual_lib_name.endswith('_gpu')
        gpu_tag = " [GPU]" if is_gpu else ""

        # prim_* functions first, then the PRIMITIVES dict (same order as before)
        exported = [(attr[5:], getattr(module, attr)) for attr in dir(module) if attr.startswith('prim_')]
        prims = getattr(module, 'PRIMITIVES', None)
        if isinstance(prims, dict):
            exported.extend(prims.items())

        count = 0
        for raw_name, func in exported:
            # Register with original lib_name namespace (geometry:rotate, not geometry_gpu:rotate)
            key = f"{lib_name}:{raw_name.replace('_', '-')}"
            if key not in self.primitives:
                self.primitives[key] = func
                count += 1

        self._loaded_primitive_libs.add(lib_name)
        print(f"Loaded primitives: {lib_name} ({count} functions){gpu_tag}", file=sys.stderr)

    def _load_references(self, form, start: int, base_dir: Path, kind: str):
        """Load files named by ``:path``/``:name`` keywords of an effect/include form.

        Args:
            form: The parsed ``(effect ...)`` or ``(include ...)`` form.
            start: Index at which keyword scanning starts (2 for effect, 1 for include).
            base_dir: Directory that relative ``:path`` values are resolved against.
            kind: ``'effect'`` or ``'include'``; used only in the error message.

        Raises:
            RuntimeError: if a ``:name`` cannot be resolved by the naming service.
        """
        i = start
        while i < len(form):
            if not isinstance(form[i], Keyword):
                i += 1
                continue
            kw = form[i].name
            if kw == 'path':
                path = str(form[i + 1]).strip('"')
                self._load_effect((base_dir / path).resolve())
                i += 2
            elif kw == 'name':
                # Resolve friendly name to path
                fname = str(form[i + 1]).strip('"')
                resolved = self._resolve_name(fname)
                if not resolved:
                    raise RuntimeError(f"Could not resolve {kind} name '{fname}' - make sure it's uploaded and you're logged in")
                self._load_effect(resolved)
                i += 2
            else:
                i += 1

    @staticmethod
    def _parse_param_def(pdef) -> dict:
        """Parse ``(name :key val ...)`` into ``{'default': 0, key: val, ...}``."""
        pinfo = {'default': 0}
        j = 1
        while j < len(pdef):
            if isinstance(pdef[j], Keyword) and j + 1 < len(pdef):
                pinfo[pdef[j].name] = pdef[j + 1]
                j += 2
            else:
                j += 1
        return pinfo

    def _register_effect(self, form, effect_path: Path):
        """Handle ``(define-effect name :params (...) body)``.

        The last non-keyword element is taken as the body. When JAX is enabled
        the effect file is also compiled with the JAX backend.
        """
        name = self._form_name(form[1])
        params = {}
        body = None
        i = 2
        while i < len(form):
            if isinstance(form[i], Keyword):
                if form[i].name == 'params' and i + 1 < len(form):
                    for pdef in form[i + 1]:
                        if isinstance(pdef, list) and pdef:
                            params[self._form_name(pdef[0])] = self._parse_param_def(pdef)
                i += 2
            else:
                body = form[i]
                i += 1

        self.effects[name] = {'params': params, 'body': body}
        self.jax_effect_paths[name] = effect_path  # Track source for JAX compilation
        print(f"Effect: {name}", file=sys.stderr)

        # Try to compile with JAX if enabled
        if self.use_jax and _JAX_AVAILABLE:
            self._compile_jax_effect(name, effect_path)

    def _register_macro(self, form):
        """Handle ``(defmacro name (params...) body)``."""
        name = self._form_name(form[1])
        params = [self._form_name(p) for p in form[2]]
        self.macros[name] = {'params': params, 'body': form[3]}

    def _register_scan(self, form):
        """Handle ``(scan name trigger :init expr :step expr)``.

        ``:init`` is evaluated once against globals; a non-dict init value is
        stored as ``{'acc': value}``.
        """
        name = self._form_name(form[1])
        trigger_expr = form[2]
        init_val, step_expr = {}, None
        i = 3
        while i < len(form):
            if isinstance(form[i], Keyword):
                key = form[i].name
                if key == 'init' and i + 1 < len(form):
                    init_val = self._eval(form[i + 1], self.globals)
                elif key == 'step' and i + 1 < len(form):
                    step_expr = form[i + 1]
                i += 2
            else:
                i += 1

        self.scans[name] = {
            'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val},
            'init': init_val,
            'step': step_expr,
            'trigger': trigger_expr,
        }
        print(f"Scan: {name}", file=sys.stderr)

    def _load_effect(self, effect_path: Path):
        """Load and register an effect from a .sexp file.

        Handles require-primitives, define-effect, defmacro, nested effect /
        include references (relative to this file's directory), and scan.

        Raises:
            FileNotFoundError: if `effect_path` does not exist.
        """
        if not effect_path.exists():
            raise FileNotFoundError(f"Effect/include file not found: {effect_path}")

        for form in parse_all(effect_path.read_text()):
            if not self._is_command_form(form):
                continue
            cmd = form[0].name

            if cmd == 'require-primitives':
                self._require_primitives_form(form)
            elif cmd == 'define-effect':
                self._register_effect(form, effect_path)
            elif cmd == 'defmacro':
                self._register_macro(form)
            elif cmd == 'effect':
                # (effect name :path "...") or (effect name :name "...")
                self._load_references(form, 2, effect_path.parent, 'effect')
            elif cmd == 'include':
                # (include :path "...") or (include :name "...")
                self._load_references(form, 1, effect_path.parent, 'include')
            elif cmd == 'scan':
                self._register_scan(form)

    # ------------------------------------------------------------------
    # JAX backend
    # ------------------------------------------------------------------

    def _compile_jax_effect(self, name: str, effect_path: Path):
        """Compile an effect with JAX for accelerated execution (best effort)."""
        if not _JAX_AVAILABLE or name in self.jax_effects:
            return
        try:
            compile_effect_file = _jax_compiler['compile_effect_file']
            jax_fn = compile_effect_file(str(effect_path))
            self.jax_effects[name] = jax_fn
            print(f" [JAX compiled: {name}]", file=sys.stderr)
        except Exception as e:
            # Silently fall back to interpreter for unsupported effects
            if 'Unknown operation' not in str(e):
                print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr)

    def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any],
                          t: float, frame_num: int) -> Optional[np.ndarray]:
        """Apply a JAX-compiled effect to a frame.

        Returns:
            The result as a numpy array, or None if the effect is not compiled
            or raised (the caller then falls back to the interpreter).
        """
        if name not in self.jax_effects:
            return None

        try:
            jax_fn = self.jax_effects[name]

            # Ensure frame is a numpy array (same conversion as CPU primitives)
            frame = self._maybe_to_numpy(frame)

            # Get seed from config for deterministic random
            seed = self.config.get('seed', 42)

            result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)

            # Convert result back to numpy if needed
            if hasattr(result, 'block_until_ready'):
                result.block_until_ready()  # Ensure computation is complete
            if hasattr(result, '__array__'):
                result = np.asarray(result)
            return result
        except Exception as e:
            # Fall back to interpreter on error
            print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
            return None

    # ------------------------------------------------------------------
    # Initialization from the recipe
    # ------------------------------------------------------------------

    def _init(self):
        """Initialize from sexp - load primitives, effects, defs, scans.

        External config files are processed first; recipe ``def`` forms and
        ``audio-playback`` are skipped when the config already set them.
        """
        # Set random seed for deterministic output
        seed = self.config.get('seed', 42)
        try:
            from sexp_effects.primitive_libs.core import set_random_seed
            set_random_seed(seed)
        except ImportError:
            pass

        # Load external config files first (they can override recipe definitions)
        if self.sources_config:
            self._load_config_file(self.sources_config)
        if self.audio_config:
            self._load_config_file(self.audio_config)

        for form in self.ast:
            if not self._is_command_form(form):
                continue
            cmd = form[0].name

            if cmd == 'require-primitives':
                self._require_primitives_form(form)

            elif cmd == 'effect':
                self._load_references(form, 2, self.sexp_dir, 'effect')

            elif cmd == 'include':
                self._load_references(form, 1, self.sexp_dir, 'include')

            elif cmd == 'audio-playback':
                # (audio-playback "path") - set audio file for playback sync
                # Skip if already set by config file
                if self.audio_playback is None:
                    path = str(form[1]).strip('"')
                    # Try to resolve as friendly name first, else relative path
                    resolved = self._resolve_name(path)
                    if resolved:
                        self.audio_playback = str(resolved)
                    else:
                        self.audio_playback = str((self.sexp_dir / path).resolve())
                    print(f"Audio playback: {self.audio_playback}", file=sys.stderr)

            elif cmd == 'def':
                # (def name expr) - evaluate and store in globals
                # Skip if already defined by config file
                name = self._form_name(form[1])
                if name in self.globals:
                    print(f"Def: {name} (from config, skipped)", file=sys.stderr)
                    continue
                self.globals[name] = self._eval(form[2], self.globals)
                print(f"Def: {name}", file=sys.stderr)

            elif cmd == 'defmacro':
                self._register_macro(form)

            elif cmd == 'scan':
                self._register_scan(form)

            elif cmd == 'frame':
                self.frame_pipeline = form[1] if len(form) > 1 else None

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def _lookup_symbol(self, name: str, env: dict) -> Any:
        """Resolve a symbol: constants, then env, globals, and scan state.

        Raises:
            NameError: if the symbol is not bound anywhere.
        """
        if name in _CONSTANTS:
            return _CONSTANTS[name]
        if name in env:
            return env[name]
        if name in self.globals:
            return self.globals[name]
        if name in self.scans:
            return self.scans[name]['state']
        raise NameError(f"Undefined variable: {name}")

    def _eval_keyword_pairs(self, items, env: dict) -> dict:
        """Build a dict from ``:key val`` pairs; non-keyword items are skipped."""
        result = {}
        i = 0
        while i < len(items):
            if isinstance(items[i], Keyword):
                result[items[i].name] = self._eval(items[i + 1], env) if i + 1 < len(items) else None
                i += 2
            else:
                i += 1
        return result

    def _bind_params(self, params, args, env: dict, base_env: dict) -> dict:
        """Copy `base_env` and bind each param to its arg evaluated in `env`.

        Missing arguments are bound to None; extra arguments are ignored.
        """
        new_env = dict(base_env)
        for i, pname in enumerate(params):
            new_env[pname] = self._eval(args[i], env) if i < len(args) else None
        return new_env

    def _call_closure(self, closure: dict, args, env: dict) -> Any:
        """Call a lambda closure with arguments evaluated in the caller's env."""
        closure_env = self._bind_params(closure['params'], args, env, closure['env'])
        return self._eval(closure['body'], closure_env)

    def _eval(self, expr, env: dict) -> Any:
        """Evaluate an expression.

        Numbers and strings evaluate to themselves, symbols are looked up,
        keywords evaluate to their name, parser dicts have their values
        evaluated, and lists are dict literals (keyword head), plain lists
        (non-symbol head), or calls (symbol head).
        """
        # Self-evaluating atoms (bool is a subclass of int, so it is covered too)
        if isinstance(expr, (int, float)):
            return expr
        if isinstance(expr, str):
            return expr

        if isinstance(expr, Symbol):
            return self._lookup_symbol(expr.name, env)

        if isinstance(expr, Keyword):
            return expr.name

        # Handle dicts from new parser - evaluate values
        if isinstance(expr, dict):
            return {k: self._eval(v, env) for k, v in expr.items()}

        if not isinstance(expr, list) or not expr:
            return expr

        # Dict literal {:key val ...}
        if isinstance(expr[0], Keyword):
            return self._eval_keyword_pairs(expr, env)

        head = expr[0]
        if not isinstance(head, Symbol):
            return [self._eval(e, env) for e in expr]

        return self._eval_call(head.name, expr[1:], env)

    def _eval_call(self, op: str, args, env: dict) -> Any:
        """Dispatch a call ``(op args...)``.

        Lookup order: closures in env, closures in globals, built-in forms,
        effects, primitives, macros, then the underscore spelling of `op`
        among primitives.

        Raises:
            RuntimeError: if `op` is unknown or a primitive fails.
        """
        # Check for closure call (local env first, then globals)
        for scope in (env, self.globals):
            if op in scope:
                val = scope[op]
                if isinstance(val, dict) and val.get('_type') == 'closure':
                    return self._call_closure(val, args, env)

        result = self._eval_builtin(op, args, env)
        if result is not _NOT_BUILTIN:
            return result

        # === Effects ===
        if op in self.effects:
            return self._call_effect(op, args, env)

        # === Primitives ===
        if op in self.primitives:
            # GPU primitives receive GPU arrays unchanged
            is_gpu_prim = op.startswith('gpu:') or 'gpu' in op.lower()
            return self._call_primitive(op, self.primitives[op], args, env,
                                        is_gpu_prim, underscore_kwargs=False)

        # === Macros (function-like: args evaluated before binding) ===
        if op in self.macros:
            macro = self.macros[op]
            macro_env = self._bind_params(macro['params'], args, env, env)
            return self._eval(macro['body'], macro_env)

        # Underscore variant lookup
        prim_name = op.replace('-', '_')
        if prim_name in self.primitives:
            return self._call_primitive(op, self.primitives[prim_name], args, env,
                                        'gpu' in prim_name.lower(), underscore_kwargs=True)

        # Unknown function call - raise meaningful error
        raise RuntimeError(f"Unknown function or primitive: '{op}'. "
                           f"Available primitives: {sorted(list(self.primitives.keys())[:10])}... "
                           f"Available effects: {sorted(list(self.effects.keys())[:10])}... "
                           f"Available macros: {sorted(list(self.macros.keys())[:10])}...")

    def _eval_builtin(self, op: str, args, env: dict) -> Any:
        """Evaluate a built-in form, or return ``_NOT_BUILTIN`` if `op` is not one."""
        # Threading macro: insert the running result as the first argument
        if op == '->':
            result = self._eval(args[0], env)
            for form in args[1:]:
                if isinstance(form, list) and form:
                    result = self._eval([form[0], result] + form[1:], env)
                else:
                    result = self._eval([form, result], env)
            return result

        # === Binding ===
        if op == 'bind':
            scan_name = self._form_name(args[0])
            if scan_name in self.scans:
                state = self.scans[scan_name]['state']
                if len(args) > 1:
                    key = args[1].name if isinstance(args[1], Keyword) else str(args[1])
                    return state.get(key, 0)
                return state
            return 0

        # === Arithmetic ===
        if op == '+':
            return sum(self._eval(a, env) for a in args)
        if op == '-':
            vals = [self._eval(a, env) for a in args]
            return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0]
        if op == '*':
            result = 1
            for a in args:
                result *= self._eval(a, env)
            return result
        if op == '/':
            # Division by zero (or a single argument) yields 0 instead of raising
            vals = [self._eval(a, env) for a in args]
            return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0
        if op == 'mod':
            vals = [self._eval(a, env) for a in args]
            return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0

        # === Comparison ===
        compare = _COMPARISONS.get(op)
        if compare is not None:
            left = self._eval(args[0], env)
            right = self._eval(args[1], env)
            return compare(left, right)
        if op == 'and':
            for arg in args:
                if not self._eval(arg, env):
                    return False
            return True
        if op == 'or':
            result = False
            for arg in args:
                result = self._eval(arg, env)
                if result:
                    return result
            return result
        if op == 'not':
            return not self._eval(args[0], env)

        # === Logic ===
        if op == 'if':
            if self._eval(args[0], env):
                return self._eval(args[1], env)
            return self._eval(args[2], env) if len(args) > 2 else None
        if op == 'cond':
            # (cond pred1 expr1 pred2 expr2 ...)
            for i in range(0, len(args) - 1, 2):
                if self._eval(args[i], env):
                    return self._eval(args[i + 1], env)
            return None
        if op == 'lambda':
            param_names = [self._form_name(p) for p in args[0]]
            return {'_type': 'closure', 'params': param_names, 'body': args[1], 'env': dict(env)}
        if op in ('let', 'let*'):
            # Both forms bind sequentially; accepts ((a 1) (b 2)) or (a 1 b 2)
            bindings, body = args[0], args[1]
            new_env = dict(env)
            if bindings and isinstance(bindings[0], list):
                for binding in bindings:
                    if isinstance(binding, list) and len(binding) >= 2:
                        name = self._form_name(binding[0])
                        new_env[name] = self._eval(binding[1], new_env)
            else:
                for i in range(0, len(bindings), 2):
                    name = self._form_name(bindings[i])
                    new_env[name] = self._eval(bindings[i + 1], new_env)
            return self._eval(body, new_env)

        # === Dict ===
        if op == 'dict':
            return self._eval_keyword_pairs(args, env)
        if op == 'get':
            obj = self._eval(args[0], env)
            key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env)
            return obj.get(key, 0) if isinstance(obj, dict) else 0

        # === List ===
        if op == 'list':
            return [self._eval(a, env) for a in args]
        if op == 'quote':
            return args[0] if args else None
        if op == 'nth':
            lst = self._eval(args[0], env)
            idx = int(self._eval(args[1], env))
            if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst):
                return lst[idx]
            return None
        if op == 'len':
            val = self._eval(args[0], env)
            return len(val) if hasattr(val, '__len__') else 0
        if op == 'map':
            seq = self._eval(args[0], env)
            fn = self._eval(args[1], env)
            if not isinstance(seq, (list, tuple)):
                return []
            # Handle closure (lambda from sexp): only the first param is bound
            if isinstance(fn, dict) and fn.get('_type') == 'closure':
                results = []
                for item in seq:
                    closure_env = dict(fn['env'])
                    if fn['params']:
                        closure_env[fn['params'][0]] = item
                    results.append(self._eval(fn['body'], closure_env))
                return results
            # Handle Python callable
            if callable(fn):
                return [fn(item) for item in seq]
            return []

        return _NOT_BUILTIN

    def _call_effect(self, op: str, args, env: dict) -> Any:
        """Invoke a registered effect.

        The first positional argument binds to ``frame``; later positional
        arguments bind to params in declaration order; ``:name val`` overrides
        a declared param. Tries the JAX-compiled version first when available.
        """
        effect = self.effects[op]
        effect_params = effect['params']
        param_names = list(effect_params.keys())

        effect_env = dict(env)
        for pname, pdef in effect_params.items():
            effect_env[pname] = pdef.get('default', 0)

        positional_idx = 0
        frame_val = None
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                pname = args[i].name
                if pname in effect_params and i + 1 < len(args):
                    effect_env[pname] = self._eval(args[i + 1], env)
                i += 2
            else:
                val = self._eval(args[i], env)
                if positional_idx == 0:
                    effect_env['frame'] = val
                    frame_val = val
                elif positional_idx - 1 < len(param_names):
                    effect_env[param_names[positional_idx - 1]] = val
                positional_idx += 1
                i += 1

        # Try JAX-accelerated execution first
        if self.use_jax and op in self.jax_effects and frame_val is not None:
            # Build params dict for JAX (exclude 'frame')
            jax_params = {k: v for k, v in effect_env.items() if k != 'frame' and k in effect_params}
            t = env.get('t', 0.0)
            frame_num = env.get('frame-num', 0)
            result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
            if result is not None:
                return result
            # Fall through to interpreter if JAX fails

        return self._eval(effect['body'], effect_env)

    def _call_primitive(self, op: str, prim_func, args, env: dict,
                        is_gpu_prim: bool, underscore_kwargs: bool) -> Any:
        """Evaluate arguments and call a primitive.

        Args:
            op: Name as written in the source (used in error messages).
            prim_func: The registered Python callable.
            is_gpu_prim: Keep GPU data on the device instead of converting to numpy.
            underscore_kwargs: Convert ``:kw-name`` to ``kw_name`` keyword arguments.

        Raises:
            RuntimeError: if the primitive raises (the original error is chained
                and also recorded in ``self.errors``).
        """
        evaluated_args = []
        kwargs = {}
        i = 0
        while i < len(args):
            if isinstance(args[i], Keyword):
                k = args[i].name
                if underscore_kwargs:
                    k = k.replace('-', '_')
                v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
                kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
                i += 2
            else:
                evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env),
                                                           for_gpu_primitive=is_gpu_prim))
                i += 1

        try:
            return prim_func(*evaluated_args, **kwargs)
        except Exception as e:
            self._record_error(f"Primitive {op} error: {e}")
            raise RuntimeError(f"Primitive {op} failed: {e}") from e

    # ------------------------------------------------------------------
    # Scans and checkpoints
    # ------------------------------------------------------------------

    def _step_scans(self, ctx: Context, env: dict):
        """Step scans based on trigger evaluation.

        The step expression sees the scan state merged with `env` (env keys win);
        a non-dict step result is stored as ``{'acc': value}``.
        """
        for scan in self.scans.values():
            # Evaluate trigger in context
            if not self._eval(scan['trigger'], env):
                continue
            step_env = dict(scan['state'])
            step_env.update(env)
            new_state = self._eval(scan['step'], step_env)
            scan['state'] = new_state if isinstance(new_state, dict) else {'acc': new_state}

    def _restore_checkpoint(self, checkpoint: dict):
        """Restore scan states from a checkpoint.

        Called when resuming a render from a previous checkpoint. Scans that
        are not registered are ignored.

        Args:
            checkpoint: Dict with 'scans' key containing {scan_name: state_dict}
        """
        scans_state = checkpoint.get('scans', {})
        for name, state in scans_state.items():
            if name in self.scans:
                self.scans[name]['state'] = dict(state)
                print(f"Restored scan '{name}' state from checkpoint", file=sys.stderr)

    def _get_checkpoint_state(self) -> dict:
        """Get current scan states for checkpointing.

        Returns:
            Dict mapping scan names to (shallow copies of) their current state dicts
        """
        return {name: dict(scan['state']) for name, scan in self.scans.items()}

    # ------------------------------------------------------------------
    # Running
    # ------------------------------------------------------------------

    @staticmethod
    def _import_output_classes() -> Dict[str, Any]:
        """Import output backends, handling both package and direct execution.

        In direct execution the GPU and multi-resolution backends are optional
        (None / an always-False availability check when missing).
        """
        try:
            from .output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
            from .gpu_output import GPUHLSOutput, check_gpu_encode_available
            from .multi_res_output import MultiResolutionHLSOutput
        except ImportError:
            from output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput
            try:
                from gpu_output import GPUHLSOutput, check_gpu_encode_available
            except ImportError:
                GPUHLSOutput = None

                def check_gpu_encode_available():
                    return False
            try:
                from multi_res_output import MultiResolutionHLSOutput
            except ImportError:
                MultiResolutionHLSOutput = None
        return {
            'PipeOutput': PipeOutput,
            'DisplayOutput': DisplayOutput,
            'FileOutput': FileOutput,
            'HLSOutput': HLSOutput,
            'IPFSHLSOutput': IPFSHLSOutput,
            'GPUHLSOutput': GPUHLSOutput,
            'check_gpu_encode_available': check_gpu_encode_available,
            'MultiResolutionHLSOutput': MultiResolutionHLSOutput,
        }

    def _detect_duration(self) -> float:
        """Use the first global exposing ``.duration``, else default to 60 seconds."""
        for val in self.globals.values():
            if hasattr(val, 'duration'):
                duration = val.duration
                print(f"Using audio duration: {duration:.1f}s", file=sys.stderr)
                return duration
        return 60.0

    def _resolve_audio_source(self) -> Optional[str]:
        """Return the audio path for output, resolving it lazily by name if missing.

        Raises:
            FileNotFoundError: if the configured audio file neither exists nor
                resolves through the naming service.
        """
        audio = self.audio_playback
        if audio and not Path(audio).exists():
            # Try to resolve as friendly name (may have failed during parsing)
            audio_name = Path(audio).name  # Get just the name part
            resolved = self._resolve_name(audio_name)
            if resolved and resolved.exists():
                audio = str(resolved)
                print(f"Lazy resolved audio: {audio}", file=sys.stderr)
            else:
                raise FileNotFoundError(f"Audio file not found: {audio}")
        return audio

    def _create_output(self, outputs: Dict[str, Any], output: str, w: int, h: int,
                       fps, audio: Optional[str], resume_from: Optional[dict]):
        """Instantiate the output backend selected by the `output` string."""
        size = (w, h)
        if output == "pipe":
            return outputs['PipeOutput'](size=size, fps=fps, audio_source=audio)
        if output == "preview":
            return outputs['DisplayOutput'](size=size, fps=fps, audio_source=audio)
        if output.endswith("/hls"):
            # HLS output - output is a directory path ending in /hls
            hls_dir = output[:-4]  # Remove /hls suffix
            return outputs['HLSOutput'](hls_dir, size=size, fps=fps, audio_source=audio)
        if output.endswith("/ipfs-hls"):
            # IPFS HLS output - multi-resolution adaptive streaming
            hls_dir = output[:-9]  # Remove /ipfs-hls suffix
            ipfs_gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs")

            # Build resume state for output if resuming
            output_resume = None
            if resume_from and resume_from.get('segment_cids'):
                output_resume = {'segment_cids': resume_from['segment_cids']}

            multi_res_cls = outputs['MultiResolutionHLSOutput']
            gpu_hls_cls = outputs['GPUHLSOutput']
            # Use multi-resolution output (renders original + 720p + 360p)
            if multi_res_cls is not None:
                print(f"[StreamInterpreter] Using multi-resolution HLS output ({w}x{h} + 720p + 360p)", file=sys.stderr)
                return multi_res_cls(
                    hls_dir,
                    source_size=size,
                    fps=fps,
                    ipfs_gateway=ipfs_gateway,
                    on_playlist_update=self.on_playlist_update,
                    audio_source=audio,
                    resume_from=output_resume,
                )
            # Fallback to GPU single-resolution if multi-res not available
            if gpu_hls_cls is not None and outputs['check_gpu_encode_available']():
                print("[StreamInterpreter] Using GPU zero-copy encoding (single resolution)", file=sys.stderr)
                return gpu_hls_cls(hls_dir, size=size, fps=fps, audio_source=audio,
                                   ipfs_gateway=ipfs_gateway,
                                   on_playlist_update=self.on_playlist_update)
            return outputs['IPFSHLSOutput'](hls_dir, size=size, fps=fps, audio_source=audio,
                                            ipfs_gateway=ipfs_gateway,
                                            on_playlist_update=self.on_playlist_update)
        return outputs['FileOutput'](output, size=size, fps=fps, audio_source=audio)

    @staticmethod
    def _avg_ms(samples) -> float:
        """Mean of `samples` (seconds) in milliseconds; 0 for an empty window."""
        return 1000 * sum(samples) / max(1, len(samples))

    def run(self, duration: float = None, output: str = "pipe", resume_from: dict = None):
        """Run the streaming pipeline.

        Args:
            duration: Duration in seconds (auto-detected from audio if None)
            output: Output mode ("pipe", "preview", path/hls, path/ipfs-hls, or file path)
            resume_from: Checkpoint dict to resume from, with keys:
                - frame_num: Last completed frame
                - t: Time value for checkpoint frame
                - scans: Dict of scan states to restore
                - segment_cids: Dict of quality -> {seg_num: cid} for output resume
        """
        outputs = self._import_output_classes()

        self._init()

        # Restore checkpoint state if resuming
        if resume_from:
            self._restore_checkpoint(resume_from)
            print(f"Resuming from frame {resume_from.get('frame_num', 0)}", file=sys.stderr)

        if not self.frame_pipeline:
            print("Error: no (frame ...) pipeline defined", file=sys.stderr)
            return

        w = self.config.get('width', 720)
        h = self.config.get('height', 720)
        fps = self.config.get('fps', 30)

        if duration is None:
            # Try to get duration from audio if available
            duration = self._detect_duration()

        n_frames = int(duration * fps)
        frame_time = 1.0 / fps
        print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr)

        ctx = Context(fps=fps)

        # Output (with optional audio sync)
        audio = self._resolve_audio_source()
        out = self._create_output(outputs, output, w, h, fps, audio, resume_from)

        # Frames per segment for checkpoints (4 seconds); at least 1 so the
        # modulo below can never divide by zero for very low frame rates.
        segment_duration = 4.0
        self._frames_per_segment = max(1, int(fps * segment_duration))

        # Determine start frame (resume from checkpoint + 1, or 0)
        start_frame = 0
        if resume_from and resume_from.get('frame_num') is not None:
            start_frame = resume_from['frame_num'] + 1
            print(f"Starting from frame {start_frame} (checkpoint was at {resume_from['frame_num']})", file=sys.stderr)

        # Only the last `profile_interval` timings are ever reported, so keep a
        # bounded window instead of lists that grow for the whole render.
        profile_interval = 30  # Profile every N frames
        frame_times = deque(maxlen=profile_interval)
        scan_times = deque(maxlen=profile_interval)
        eval_times = deque(maxlen=profile_interval)
        write_times = deque(maxlen=profile_interval)

        try:
            for frame_num in range(start_frame, n_frames):
                if not out.is_open:
                    break

                frame_start = time.time()
                ctx.t = frame_num * frame_time
                ctx.frame_num = frame_num

                # Build frame environment with context
                frame_env = {
                    'ctx': {
                        't': ctx.t,
                        'frame-num': ctx.frame_num,
                        'fps': ctx.fps,
                    },
                    't': ctx.t,  # Also expose t directly for convenience
                    'frame-num': ctx.frame_num,
                }

                # Step scans
                t0 = time.time()
                self._step_scans(ctx, frame_env)
                scan_times.append(time.time() - t0)

                # Evaluate pipeline
                t1 = time.time()
                result = self._eval(self.frame_pipeline, frame_env)
                eval_times.append(time.time() - t1)

                t2 = time.time()
                if result is not None and hasattr(result, 'shape'):
                    out.write(result, ctx.t)
                write_times.append(time.time() - t2)

                frame_times.append(time.time() - frame_start)

                # Checkpoint at segment boundaries (every _frames_per_segment frames)
                if frame_num > 0 and frame_num % self._frames_per_segment == 0 and self.on_checkpoint:
                    try:
                        checkpoint = {
                            'frame_num': frame_num,
                            't': ctx.t,
                            'scans': self._get_checkpoint_state(),
                        }
                        self.on_checkpoint(checkpoint)
                    except Exception as e:
                        print(f"Warning: checkpoint callback failed: {e}", file=sys.stderr)

                # Progress with timing and profile breakdown
                if frame_num % profile_interval == 0 and frame_num > 0:
                    pct = 100 * frame_num / n_frames
                    avg_ms = self._avg_ms(frame_times)
                    avg_scan = self._avg_ms(scan_times)
                    avg_eval = self._avg_ms(eval_times)
                    avg_write = self._avg_ms(write_times)
                    target_ms = 1000 * frame_time
                    print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms] scan={avg_scan:.0f}ms eval={avg_eval:.0f}ms write={avg_write:.0f}ms",
                          end="", file=sys.stderr, flush=True)

                    # Call progress callback if set (for Celery task state updates)
                    if self.on_progress:
                        try:
                            self.on_progress(pct, frame_num, n_frames)
                        except Exception as e:
                            print(f"Warning: progress callback failed: {e}", file=sys.stderr)

        finally:
            out.close()

        # Store output for access to properties like playlist_cid
        self.output = out
        print("\nDone", file=sys.stderr)


def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
               sources_config: str = None, audio_config: str = None, use_jax: bool = False):
    """Run a streaming sexp.

    Args:
        sexp_path: Path to the recipe file.
        duration: Seconds to render (None: detect from audio, else 60s).
        output: Output mode passed to ``StreamInterpreter.run``.
        fps: Overrides the recipe's ``:fps`` when truthy.
        sources_config: Optional config ``.sexp`` loaded before the recipe.
        audio_config: Optional audio config ``.sexp`` loaded before the recipe.
        use_jax: Enable the optional JAX backend.
    """
    interp = StreamInterpreter(sexp_path, use_jax=use_jax)
    if fps:
        interp.config['fps'] = fps
    if sources_config:
        interp.sources_config = Path(sources_config)
    if audio_config:
        interp.audio_config = Path(audio_config)
    interp.run(duration=duration, output=output)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)")
    parser.add_argument("sexp", help="Path to .sexp file")
    parser.add_argument("-d", "--duration", type=float, default=None)
    parser.add_argument("-o", "--output", default="pipe")
    parser.add_argument("--fps", type=float, default=None)
    parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
    parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
    parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
    args = parser.parse_args()

    run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps,
               sources_config=args.sources_config, audio_config=args.audio_config,
               use_jax=args.jax)