Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s

- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 15:12:54 +00:00
parent dbc4ece2cc
commit fc9597456f
30 changed files with 7749 additions and 165 deletions

View File

@@ -21,6 +21,7 @@ Context (ctx) is passed explicitly to frame evaluation:
"""
import sys
import os
import time
import json
import hashlib
@@ -62,6 +63,38 @@ class Context:
fps: float = 30.0
class DeferredEffectChain:
    """
    A pending pipeline of JAX effects over a single base frame.

    Nothing runs while the chain is being built: effects collected through
    let bindings are only recorded here, so they can later be fused into one
    JIT-compiled function when the result is actually needed.
    """
    __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')

    def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
        # effects[i] is configured by params_list[i]; both are innermost first.
        self.effects, self.params_list = effects, params_list
        # The actual frame array that the first effect will receive.
        self.base_frame = base_frame
        self.t, self.frame_num = t, frame_num

    def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
        """Return a new chain with ``effect_name`` appended as the outermost effect.

        The receiver is left untouched; the effect and param lists are copied
        by concatenation.
        """
        longer_effects = self.effects + [effect_name]
        longer_params = self.params_list + [params]
        return DeferredEffectChain(
            longer_effects, longer_params, self.base_frame, self.t, self.frame_num
        )

    @property
    def shape(self):
        """Shape of the base frame (None if it has none), without forcing execution."""
        frame = self.base_frame
        if hasattr(frame, 'shape'):
            return frame.shape
        return None
class StreamInterpreter:
"""
Fully generic streaming sexp interpreter.
@@ -98,6 +131,9 @@ class StreamInterpreter:
self.use_jax = use_jax
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains
self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains
self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env
if use_jax:
if _init_jax():
print("JAX acceleration enabled", file=sys.stderr)
@@ -238,6 +274,8 @@ class StreamInterpreter:
"""Load primitives from a Python library file.
Prefers GPU-accelerated versions (*_gpu.py) when available.
Uses cached modules from sys.modules to ensure consistent state
(e.g., same RNG instance for all interpreters).
"""
import importlib.util
@@ -264,9 +302,26 @@ class StreamInterpreter:
if not lib_path:
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Use cached module if already imported to preserve state (e.g., RNG)
# This is critical for deterministic random number sequences
# Check multiple possible module keys (standard import paths and our cache)
possible_keys = [
f"sexp_effects.primitive_libs.{actual_lib_name}",
f"sexp_primitives.{actual_lib_name}",
]
module = None
for key in possible_keys:
if key in sys.modules:
module = sys.modules[key]
break
if module is None:
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Cache for future use under our key
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
# Check if this is a GPU-accelerated module
is_gpu = actual_lib_name.endswith('_gpu')
@@ -452,30 +507,353 @@ class StreamInterpreter:
try:
jax_fn = self.jax_effects[name]
# Ensure frame is numpy array
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
# JAX handles numpy and JAX arrays natively, no conversion needed
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get'):
frame = frame.get()
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get() # CuPy array -> numpy
# Get seed from config for deterministic random
seed = self.config.get('seed', 42)
# Call JAX function with parameters
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
# Convert result back to numpy if needed
if hasattr(result, 'block_until_ready'):
result.block_until_ready() # Ensure computation is complete
if hasattr(result, '__array__'):
result = np.asarray(result)
return result
# Return JAX array directly - don't block or convert per-effect
# Conversion to numpy happens once at frame write time
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
except Exception as e:
# Fall back to interpreter on error
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
return None
def _is_jax_effect_expr(self, expr) -> bool:
"""Check if an expression is a JAX-compiled effect call."""
if not isinstance(expr, list) or not expr:
return False
head = expr[0]
if not isinstance(head, Symbol):
return False
return head.name in self.jax_effects
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
    """
    Peel nested JAX effect calls off ``expr`` until a non-effect frame is reached.

    For (effect1 (effect2 frame :p1 v1) :p2 v2) the result is
    (['effect2', 'effect1'], [params2, params1], frame_expr): effect names
    and param dicts are in execution order (innermost first) and the base
    frame expression is returned unevaluated.

    Each param dict starts from the effect's declared defaults (0 when a
    param declares none); keyword values are evaluated with self._eval as
    they are encountered.

    Returns:
        The (effect_names, params_list, base_frame_expr) tuple, or None when
        ``expr`` is not a JAX effect call or a call in the chain has no
        positional (frame) argument.
    """
    if not self._is_jax_effect_expr(expr):
        return None

    # Collected while walking from the outermost call inwards
    names_outer_first = []
    params_outer_first = []
    node = expr

    while self._is_jax_effect_expr(node):
        effect_name = node[0].name
        call_args = node[1:]
        declared = self.effects[effect_name]['params']

        # Seed every declared param with its default
        resolved = {pname: pdef.get('default', 0) for pname, pdef in declared.items()}

        # Keywords consume two slots (name + value), positionals one
        inner = None
        pos = 0
        while pos < len(call_args):
            arg = call_args[pos]
            if not isinstance(arg, Keyword):
                if inner is None:
                    inner = arg  # First positional is the frame
                pos += 1
                continue
            # Undeclared keywords and keywords without a value are skipped
            if arg.name in declared and pos + 1 < len(call_args):
                resolved[arg.name] = self._eval(call_args[pos + 1], env)
            pos += 2

        names_outer_first.append(effect_name)
        params_outer_first.append(resolved)

        if inner is None:
            return None  # No frame argument found

        if not self._is_jax_effect_expr(inner):
            # Reached the base frame - flip into innermost-first order
            return (names_outer_first[::-1], params_outer_first[::-1], inner)

        # The frame argument is itself an effect call: keep descending
        node = inner

    return None
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
"""Generate a cache key for an effect chain.
Includes static param values in the key since they affect compilation.
"""
parts = []
for name, params in zip(effect_names, params_list):
param_parts = []
for pname in sorted(params.keys()):
pval = params[pname]
# Include static values in key (strings, bools)
if isinstance(pval, (str, bool)):
param_parts.append(f"{pname}={pval}")
else:
param_parts.append(pname)
parts.append(f"{name}:{','.join(param_parts)}")
return '|'.join(parts)
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Fuse a chain of effects into one JIT-compiled JAX function.

    Args:
        effect_names: Effect names in order [innermost, ..., outermost]
        params_list: One param dict per effect; the value types decide which
            params are treated as static (strings and bools)

    Returns:
        JIT-compiled function (frame, t, frame_num, seed, **all_params) -> frame,
        where effect i's param ``p`` is passed as the keyword ``_p{i}_p``.
        None when JAX is unavailable or building the function fails.
    """
    if not _JAX_AVAILABLE:
        return None

    try:
        import jax

        # Look up the individual JAX functions
        stage_fns = [self.jax_effects[effect] for effect in effect_names]

        # Record each stage's param names and collect the static argument names
        stage_params = []
        static_names = ['seed']  # seed is always static
        for idx, (_, params) in enumerate(zip(effect_names, params_list)):
            stage_params.append(list(params.keys()))
            static_names.extend(
                f"_p{idx}_{pname}"
                for pname, pval in params.items()
                if isinstance(pval, (str, bool))
            )

        def fused_fn(frame, t, frame_num, seed, **kwargs):
            out = frame
            for idx, (stage_fn, pnames) in enumerate(zip(stage_fns, stage_params)):
                # Pick this stage's params out of the prefixed kwargs
                stage_kwargs = {
                    pname: kwargs[f"_p{idx}_{pname}"]
                    for pname in pnames
                    if f"_p{idx}_{pname}" in kwargs
                }
                out = stage_fn(out, t=t, frame_num=frame_num, seed=seed, **stage_kwargs)
            return out

        # JIT with static params (seed + any string/bool params)
        return jax.jit(fused_fn, static_argnames=static_names)

    except Exception as e:
        print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
        return None
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
"""Apply a chain of effects, using fused compilation if available."""
chain_key = self._get_chain_key(effect_names, params_list)
# Try to get or compile fused chain
if chain_key not in self.jax_fused_chains:
fused_fn = self._compile_effect_chain(effect_names, params_list)
self.jax_fused_chains[chain_key] = fused_fn
if fused_fn:
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
fused_fn = self.jax_fused_chains.get(chain_key)
if fused_fn is not None:
# Build kwargs with prefixed param names
kwargs = {}
for i, params in enumerate(params_list):
for pname, pval in params.items():
kwargs[f"_p{i}_{pname}"] = pval
seed = self.config.get('seed', 42)
# Handle GPU frames
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get()
try:
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
except Exception as e:
print(f"Fused chain error: {e}", file=sys.stderr)
# Fall back to sequential application
result = frame
for name, params in zip(effect_names, params_list):
result = self._apply_jax_effect(name, result, params, t, frame_num)
if result is None:
return None
return result
def _force_deferred(self, deferred: DeferredEffectChain):
"""Execute a deferred effect chain and return the actual array."""
if len(deferred.effects) == 0:
return deferred.base_frame
return self._apply_effect_chain(
deferred.effects,
deferred.params_list,
deferred.base_frame,
deferred.t,
deferred.frame_num
)
def _maybe_force(self, value):
    """Return ``value`` as-is unless it is a DeferredEffectChain, which is executed first."""
    if not isinstance(value, DeferredEffectChain):
        return value
    return self._force_deferred(value)
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
    """
    Compile the single-frame function used for batched execution of a chain.

    The returned function is NOT vmapped: it processes one frame. Batching
    is applied by the caller, which wraps it in jax.vmap at call time so it
    can choose the in_axes.

    Args:
        effect_names: List of effect names in order [innermost, ..., outermost]
        params_list: List of param dicts (used to detect static types:
            string and bool values are marked static)

    Returns:
        JIT-compiled function (frame, t, frame_num, seed, **params) -> frame,
        where effect i's param ``p`` is passed as the keyword ``_p{i}_p``.
        None when JAX is unavailable or building the function fails.
    """
    if not _JAX_AVAILABLE:
        return None

    try:
        import jax

        # Get the individual JAX functions
        jax_fns = [self.jax_effects[name] for name in effect_names]

        # Pre-extract param names and the static (string/bool) argument names
        effect_param_names = []
        static_params = ['seed']  # seed is always static
        for i, (_, params) in enumerate(zip(effect_names, params_list)):
            effect_param_names.append(list(params.keys()))
            for pname, pval in params.items():
                if isinstance(pval, (str, bool)):
                    static_params.append(f"_p{i}_{pname}")

        # Single-frame function (vmapped later by the caller)
        def single_frame_fn(frame, t, frame_num, seed, **kwargs):
            result = frame
            for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
                effect_kwargs = {}
                for pname in param_names:
                    key = f"_p{i}_{pname}"
                    if key in kwargs:
                        effect_kwargs[pname] = kwargs[key]
                result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
            return result

        # Return unbatched function - it is vmapped at call time with proper in_axes
        return jax.jit(single_frame_fn, static_argnames=static_params)

    except Exception as e:
        print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
        return None
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
frames: list, ts: list, frame_nums: list) -> Optional[list]:
"""
Apply an effect chain to a batch of frames using vmap.
Args:
effect_names: List of effect names
params_list_batch: List of params_list for each frame in batch
frames: List of input frames
ts: List of time values
frame_nums: List of frame numbers
Returns:
List of output frames, or None on failure
"""
if not frames:
return []
# Use first frame's params for chain key (assume same structure)
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
batch_key = f"batch:{chain_key}"
# Compile batched version if needed
if batch_key not in self.jax_batched_chains:
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
self.jax_batched_chains[batch_key] = batched_fn
if batched_fn:
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
batched_fn = self.jax_batched_chains.get(batch_key)
if batched_fn is not None:
try:
import jax
import jax.numpy as jnp
# Stack frames into batch array
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
ts_array = jnp.array(ts)
frame_nums_array = jnp.array(frame_nums)
# Build kwargs - all numeric params as arrays for vmap
kwargs = {}
static_kwargs = {} # Non-vmapped (strings, bools)
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
for j, pname in enumerate(params_list_batch[0][i].keys()):
key = f"_p{i}_{pname}"
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
first = values[0]
if isinstance(first, (str, bool)):
# Static params - not vmapped
static_kwargs[key] = first
elif isinstance(first, (int, float)):
# Always batch numeric params for simplicity
kwargs[key] = jnp.array(values)
elif hasattr(first, 'shape'):
kwargs[key] = jnp.stack(values)
else:
kwargs[key] = jnp.array(values)
seed = self.config.get('seed', 42)
# Create wrapper that unpacks the params dict
def single_call(frame, t, frame_num, params_dict):
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
# vmap over frame, t, frame_num, and the params dict (as pytree)
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
# Stack kwargs into a dict of arrays (pytree with matching structure)
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
# Unstack results
return [results[i] for i in range(len(frames))]
except Exception as e:
print(f"Batched chain error: {e}", file=sys.stderr)
# Fall back to sequential
return None
def _init(self):
"""Initialize from sexp - load primitives, effects, defs, scans."""
# Set random seed for deterministic output
@@ -869,6 +1247,22 @@ class StreamInterpreter:
# === Effects ===
if op in self.effects:
# Try to detect and fuse effect chains for JAX acceleration
if self.use_jax and op in self.jax_effects:
chain_info = self._extract_effect_chain(expr, env)
if chain_info is not None:
effect_names, params_list, base_frame_expr = chain_info
# Only use chain if we have 2+ effects (worth fusing)
if len(effect_names) >= 2:
base_frame = self._eval(base_frame_expr, env)
if base_frame is not None and hasattr(base_frame, 'shape'):
t = env.get('t', 0.0)
frame_num = env.get('frame-num', 0)
result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
if result is not None:
return result
# Fall through if chain application fails
effect = self.effects[op]
effect_env = dict(env)
@@ -895,17 +1289,28 @@ class StreamInterpreter:
positional_idx += 1
i += 1
# Try JAX-accelerated execution first
# Try JAX-accelerated execution with deferred chaining
if self.use_jax and op in self.jax_effects and frame_val is not None:
# Build params dict for JAX (exclude 'frame')
jax_params = {k: v for k, v in effect_env.items()
jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
if k != 'frame' and k in effect['params']}
t = env.get('t', 0.0)
frame_num = env.get('frame-num', 0)
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
if result is not None:
return result
# Fall through to interpreter if JAX fails
# Check if input is a deferred chain - if so, extend it
if isinstance(frame_val, DeferredEffectChain):
return frame_val.extend(op, jax_params)
# Check if input is a valid frame - create new deferred chain
if hasattr(frame_val, 'shape'):
return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
# Fall through to interpreter if not a valid frame
# Force any deferred frame before interpreter evaluation
if isinstance(frame_val, DeferredEffectChain):
frame_val = self._force_deferred(frame_val)
effect_env['frame'] = frame_val
return self._eval(effect['body'], effect_env)
@@ -922,10 +1327,15 @@ class StreamInterpreter:
if isinstance(args[i], Keyword):
k = args[i].name
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
# Force deferred chains before passing to primitives
v = self._maybe_force(v)
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
i += 2
else:
evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim))
val = self._eval(args[i], env)
# Force deferred chains before passing to primitives
val = self._maybe_force(val)
evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
i += 1
try:
if kwargs:
@@ -1152,6 +1562,61 @@ class StreamInterpreter:
eval_times = []
write_times = []
# Batch accumulation for JAX
batch_deferred = [] # Accumulated DeferredEffectChains
batch_times = [] # Corresponding time values
batch_start_frame = 0
def flush_batch():
    """Execute accumulated batch and write results.

    Batched execution is attempted when JAX is enabled, at least two chains
    are pending and they all share the same effect list. If that is not the
    case, or the batched call returns None, each chain is forced on its own;
    results that are None or have no ``shape`` are not written.

    Every written result is synchronized (when it supports
    block_until_ready), converted with np.asarray and passed to out.write
    together with its recorded time. The pending lists are emptied in place
    afterwards.
    """
    if not batch_deferred:
        return

    # Check if all chains have same structure (can batch)
    first = batch_deferred[0]
    can_batch = (
        self.use_jax and
        len(batch_deferred) >= 2 and
        all(d.effects == first.effects for d in batch_deferred)
    )

    results = None
    if can_batch:
        # Try batched execution
        results = self._apply_batched_chain(
            first.effects,
            [d.params_list for d in batch_deferred],
            [d.base_frame for d in batch_deferred],
            [d.t for d in batch_deferred],
            [d.frame_num for d in batch_deferred],
        )

    if results is not None:
        # Write batched results
        for result, t in zip(results, batch_times):
            if hasattr(result, 'block_until_ready'):
                result.block_until_ready()
            out.write(np.asarray(result), t)
    else:
        # Fall back to sequential execution
        for deferred, t in zip(batch_deferred, batch_times):
            result = self._force_deferred(deferred)
            if result is not None and hasattr(result, 'shape'):
                if hasattr(result, 'block_until_ready'):
                    result.block_until_ready()
                out.write(np.asarray(result), t)

    # Empty the shared lists in place so the enclosing loop keeps appending
    # to the same objects (no nonlocal rebinding needed)
    batch_deferred.clear()
    batch_times.clear()
for frame_num in range(start_frame, n_frames):
if not out.is_open:
break
@@ -1182,8 +1647,23 @@ class StreamInterpreter:
eval_times.append(time.time() - t1)
t2 = time.time()
if result is not None and hasattr(result, 'shape'):
out.write(result, ctx.t)
if result is not None:
if isinstance(result, DeferredEffectChain):
# Accumulate for batching
batch_deferred.append(result)
batch_times.append(ctx.t)
# Flush when batch is full
if len(batch_deferred) >= self.jax_batch_size:
flush_batch()
else:
# Not deferred - flush any pending batch first, then write
flush_batch()
if hasattr(result, 'shape'):
if hasattr(result, 'block_until_ready'):
result.block_until_ready()
result = np.asarray(result)
out.write(result, ctx.t)
write_times.append(time.time() - t2)
frame_elapsed = time.time() - frame_start
@@ -1219,6 +1699,9 @@ class StreamInterpreter:
except Exception as e:
print(f"Warning: progress callback failed: {e}", file=sys.stderr)
# Flush any remaining batch
flush_batch()
finally:
out.close()
# Store output for access to properties like playlist_cid