Add JAX backend with frame-varying random keys
Some checks failed
GPU Worker CI/CD / test (push) Has been cancelled
GPU Worker CI/CD / deploy (push) Has been cancelled

- Add sexp_to_jax.py: JAX compiler for S-expression effects
- Use jax.random.fold_in for deterministic but varying random per frame
- Pass seed from recipe config through to JAX effects
- Fix NVENC detection to do actual encode test
- Add set_random_seed for deterministic Python random

The fold_in approach allows frame_num to be traced (not static)
while still producing different random patterns per frame,
fixing the interference pattern issue.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-05 11:07:02 +00:00
parent 0534081e44
commit 7411aa74c4
4 changed files with 3793 additions and 7 deletions

View File

@@ -52,15 +52,33 @@ def prim_max(*args):
def prim_round(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.round(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.round(x)
return round(x)
def prim_floor(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.floor(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.floor(x)
import math
return math.floor(x)
def prim_ceil(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.ceil(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.ceil(x)
import math
return math.ceil(x)
@@ -193,6 +211,11 @@ def prim_range(*args):
import random
_rng = random.Random()
def set_random_seed(seed):
"""Set the random seed for deterministic output."""
global _rng
_rng = random.Random(seed)
def prim_rand():
"""Return random float in [0, 1)."""
return _rng.random()

View File

@@ -37,19 +37,39 @@ _nvenc_available: Optional[bool] = None
def check_nvenc_available() -> bool:
"""Check if NVENC hardware encoding is available."""
"""Check if NVENC hardware encoding is available and working.
Does a real encode test to catch cases where nvenc is listed
but CUDA libraries aren't loaded.
"""
global _nvenc_available
if _nvenc_available is not None:
return _nvenc_available
try:
# First check if encoder is listed
result = subprocess.run(
["ffmpeg", "-encoders"],
capture_output=True,
text=True,
timeout=5
)
_nvenc_available = "h264_nvenc" in result.stdout
if "h264_nvenc" not in result.stdout:
_nvenc_available = False
return _nvenc_available
# Actually try to encode a small test frame
result = subprocess.run(
["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1",
"-c:v", "h264_nvenc", "-f", "null", "-"],
capture_output=True,
text=True,
timeout=10
)
_nvenc_available = result.returncode == 0
if not _nvenc_available:
import sys
print("NVENC listed but not working, falling back to libx264", file=sys.stderr)
except Exception:
_nvenc_available = False

3638
streaming/sexp_to_jax.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -28,12 +28,31 @@ import math
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Any, Optional, Tuple
from typing import Dict, List, Any, Optional, Tuple, Callable
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
# JAX backend (optional - loaded on demand)
_JAX_AVAILABLE = False
_jax_compiler = None
def _init_jax():
"""Lazily initialize JAX compiler."""
global _JAX_AVAILABLE, _jax_compiler
if _jax_compiler is not None:
return _JAX_AVAILABLE
try:
from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
_jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
_JAX_AVAILABLE = True
print("JAX backend initialized", file=sys.stderr)
except ImportError as e:
print(f"JAX backend not available: {e}", file=sys.stderr)
_JAX_AVAILABLE = False
return _JAX_AVAILABLE
@dataclass
class Context:
@@ -51,7 +70,7 @@ class StreamInterpreter:
and calls primitives.
"""
def __init__(self, sexp_path: str, actor_id: Optional[str] = None):
def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
self.sexp_path = Path(sexp_path)
self.sexp_dir = self.sexp_path.parent
self.actor_id = actor_id # For friendly name resolution
@@ -74,6 +93,17 @@ class StreamInterpreter:
self.primitives: Dict[str, Any] = {}
self.effects: Dict[str, dict] = {}
self.macros: Dict[str, dict] = {}
# JAX backend for accelerated effect evaluation
self.use_jax = use_jax
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
if use_jax:
if _init_jax():
print("JAX acceleration enabled", file=sys.stderr)
else:
print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
self.use_jax = False
# Try multiple locations for primitive_libs
possible_paths = [
self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir
@@ -307,8 +337,13 @@ class StreamInterpreter:
i += 1
self.effects[name] = {'params': params, 'body': body}
self.jax_effect_paths[name] = effect_path # Track source for JAX compilation
print(f"Effect: {name}", file=sys.stderr)
# Try to compile with JAX if enabled
if self.use_jax and _JAX_AVAILABLE:
self._compile_jax_effect(name, effect_path)
elif cmd == 'defmacro':
name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
@@ -387,8 +422,62 @@ class StreamInterpreter:
}
print(f"Scan: {name}", file=sys.stderr)
def _compile_jax_effect(self, name: str, effect_path: Path):
"""Compile an effect with JAX for accelerated execution."""
if not _JAX_AVAILABLE or name in self.jax_effects:
return
try:
compile_effect_file = _jax_compiler['compile_effect_file']
jax_fn = compile_effect_file(str(effect_path))
self.jax_effects[name] = jax_fn
print(f" [JAX compiled: {name}]", file=sys.stderr)
except Exception as e:
# Silently fall back to interpreter for unsupported effects
if 'Unknown operation' not in str(e):
print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr)
def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]:
"""Apply a JAX-compiled effect to a frame."""
if name not in self.jax_effects:
return None
try:
jax_fn = self.jax_effects[name]
# Ensure frame is numpy array
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get'):
frame = frame.get()
# Get seed from config for deterministic random
seed = self.config.get('seed', 42)
# Call JAX function with parameters
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
# Convert result back to numpy if needed
if hasattr(result, 'block_until_ready'):
result.block_until_ready() # Ensure computation is complete
if hasattr(result, '__array__'):
result = np.asarray(result)
return result
except Exception as e:
# Fall back to interpreter on error
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
return None
def _init(self):
"""Initialize from sexp - load primitives, effects, defs, scans."""
# Set random seed for deterministic output
seed = self.config.get('seed', 42)
try:
from sexp_effects.primitive_libs.core import set_random_seed
set_random_seed(seed)
except ImportError:
pass
# Load external config files first (they can override recipe definitions)
if self.sources_config:
self._load_config_file(self.sources_config)
@@ -780,6 +869,7 @@ class StreamInterpreter:
effect_env[pname] = pdef.get('default', 0)
positional_idx = 0
frame_val = None
i = 0
while i < len(args):
if isinstance(args[i], Keyword):
@@ -791,11 +881,24 @@ class StreamInterpreter:
val = self._eval(args[i], env)
if positional_idx == 0:
effect_env['frame'] = val
frame_val = val
elif positional_idx - 1 < len(param_names):
effect_env[param_names[positional_idx - 1]] = val
positional_idx += 1
i += 1
# Try JAX-accelerated execution first
if self.use_jax and op in self.jax_effects and frame_val is not None:
# Build params dict for JAX (exclude 'frame')
jax_params = {k: v for k, v in effect_env.items()
if k != 'frame' and k in effect['params']}
t = env.get('t', 0.0)
frame_num = env.get('frame-num', 0)
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
if result is not None:
return result
# Fall through to interpreter if JAX fails
return self._eval(effect['body'], effect_env)
# === Primitives ===
@@ -1049,9 +1152,9 @@ class StreamInterpreter:
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
sources_config: str = None, audio_config: str = None):
sources_config: str = None, audio_config: str = None, use_jax: bool = False):
"""Run a streaming sexp."""
interp = StreamInterpreter(sexp_path)
interp = StreamInterpreter(sexp_path, use_jax=use_jax)
if fps:
interp.config['fps'] = fps
if sources_config:
@@ -1070,7 +1173,9 @@ if __name__ == "__main__":
parser.add_argument("--fps", type=float, default=None)
parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
args = parser.parse_args()
run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps,
sources_config=args.sources_config, audio_config=args.audio_config)
sources_config=args.sources_config, audio_config=args.audio_config,
use_jax=args.jax)