Add JAX backend with frame-varying random keys
- Add sexp_to_jax.py: JAX compiler for S-expression effects - Use jax.random.fold_in for deterministic but varying random per frame - Pass seed from recipe config through to JAX effects - Fix NVENC detection to do actual encode test - Add set_random_seed for deterministic Python random The fold_in approach allows frame_num to be traced (not static) while still producing different random patterns per frame, fixing the interference pattern issue. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -52,15 +52,33 @@ def prim_max(*args):
|
||||
|
||||
|
||||
def prim_round(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.round(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.round(x)
|
||||
return round(x)
|
||||
|
||||
|
||||
def prim_floor(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.floor(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.floor(x)
|
||||
import math
|
||||
return math.floor(x)
|
||||
|
||||
|
||||
def prim_ceil(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.ceil(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.ceil(x)
|
||||
import math
|
||||
return math.ceil(x)
|
||||
|
||||
@@ -193,6 +211,11 @@ def prim_range(*args):
|
||||
import random
|
||||
_rng = random.Random()
|
||||
|
||||
def set_random_seed(seed):
|
||||
"""Set the random seed for deterministic output."""
|
||||
global _rng
|
||||
_rng = random.Random(seed)
|
||||
|
||||
def prim_rand():
|
||||
"""Return random float in [0, 1)."""
|
||||
return _rng.random()
|
||||
|
||||
@@ -37,19 +37,39 @@ _nvenc_available: Optional[bool] = None
|
||||
|
||||
|
||||
def check_nvenc_available() -> bool:
|
||||
"""Check if NVENC hardware encoding is available."""
|
||||
"""Check if NVENC hardware encoding is available and working.
|
||||
|
||||
Does a real encode test to catch cases where nvenc is listed
|
||||
but CUDA libraries aren't loaded.
|
||||
"""
|
||||
global _nvenc_available
|
||||
if _nvenc_available is not None:
|
||||
return _nvenc_available
|
||||
|
||||
try:
|
||||
# First check if encoder is listed
|
||||
result = subprocess.run(
|
||||
["ffmpeg", "-encoders"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
_nvenc_available = "h264_nvenc" in result.stdout
|
||||
if "h264_nvenc" not in result.stdout:
|
||||
_nvenc_available = False
|
||||
return _nvenc_available
|
||||
|
||||
# Actually try to encode a small test frame
|
||||
result = subprocess.run(
|
||||
["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1",
|
||||
"-c:v", "h264_nvenc", "-f", "null", "-"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
_nvenc_available = result.returncode == 0
|
||||
if not _nvenc_available:
|
||||
import sys
|
||||
print("NVENC listed but not working, falling back to libx264", file=sys.stderr)
|
||||
except Exception:
|
||||
_nvenc_available = False
|
||||
|
||||
|
||||
3638
streaming/sexp_to_jax.py
Normal file
3638
streaming/sexp_to_jax.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -28,12 +28,31 @@ import math
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from typing import Dict, List, Any, Optional, Tuple, Callable
|
||||
|
||||
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
|
||||
|
||||
# JAX backend (optional - loaded on demand)
|
||||
_JAX_AVAILABLE = False
|
||||
_jax_compiler = None
|
||||
|
||||
def _init_jax():
|
||||
"""Lazily initialize JAX compiler."""
|
||||
global _JAX_AVAILABLE, _jax_compiler
|
||||
if _jax_compiler is not None:
|
||||
return _JAX_AVAILABLE
|
||||
try:
|
||||
from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
|
||||
_jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
|
||||
_JAX_AVAILABLE = True
|
||||
print("JAX backend initialized", file=sys.stderr)
|
||||
except ImportError as e:
|
||||
print(f"JAX backend not available: {e}", file=sys.stderr)
|
||||
_JAX_AVAILABLE = False
|
||||
return _JAX_AVAILABLE
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
@@ -51,7 +70,7 @@ class StreamInterpreter:
|
||||
and calls primitives.
|
||||
"""
|
||||
|
||||
def __init__(self, sexp_path: str, actor_id: Optional[str] = None):
|
||||
def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
|
||||
self.sexp_path = Path(sexp_path)
|
||||
self.sexp_dir = self.sexp_path.parent
|
||||
self.actor_id = actor_id # For friendly name resolution
|
||||
@@ -74,6 +93,17 @@ class StreamInterpreter:
|
||||
self.primitives: Dict[str, Any] = {}
|
||||
self.effects: Dict[str, dict] = {}
|
||||
self.macros: Dict[str, dict] = {}
|
||||
|
||||
# JAX backend for accelerated effect evaluation
|
||||
self.use_jax = use_jax
|
||||
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
||||
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
||||
if use_jax:
|
||||
if _init_jax():
|
||||
print("JAX acceleration enabled", file=sys.stderr)
|
||||
else:
|
||||
print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
|
||||
self.use_jax = False
|
||||
# Try multiple locations for primitive_libs
|
||||
possible_paths = [
|
||||
self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir
|
||||
@@ -307,8 +337,13 @@ class StreamInterpreter:
|
||||
i += 1
|
||||
|
||||
self.effects[name] = {'params': params, 'body': body}
|
||||
self.jax_effect_paths[name] = effect_path # Track source for JAX compilation
|
||||
print(f"Effect: {name}", file=sys.stderr)
|
||||
|
||||
# Try to compile with JAX if enabled
|
||||
if self.use_jax and _JAX_AVAILABLE:
|
||||
self._compile_jax_effect(name, effect_path)
|
||||
|
||||
elif cmd == 'defmacro':
|
||||
name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
|
||||
params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
|
||||
@@ -387,8 +422,62 @@ class StreamInterpreter:
|
||||
}
|
||||
print(f"Scan: {name}", file=sys.stderr)
|
||||
|
||||
def _compile_jax_effect(self, name: str, effect_path: Path):
|
||||
"""Compile an effect with JAX for accelerated execution."""
|
||||
if not _JAX_AVAILABLE or name in self.jax_effects:
|
||||
return
|
||||
|
||||
try:
|
||||
compile_effect_file = _jax_compiler['compile_effect_file']
|
||||
jax_fn = compile_effect_file(str(effect_path))
|
||||
self.jax_effects[name] = jax_fn
|
||||
print(f" [JAX compiled: {name}]", file=sys.stderr)
|
||||
except Exception as e:
|
||||
# Silently fall back to interpreter for unsupported effects
|
||||
if 'Unknown operation' not in str(e):
|
||||
print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr)
|
||||
|
||||
def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]:
|
||||
"""Apply a JAX-compiled effect to a frame."""
|
||||
if name not in self.jax_effects:
|
||||
return None
|
||||
|
||||
try:
|
||||
jax_fn = self.jax_effects[name]
|
||||
# Ensure frame is numpy array
|
||||
if hasattr(frame, 'cpu'):
|
||||
frame = frame.cpu
|
||||
elif hasattr(frame, 'get'):
|
||||
frame = frame.get()
|
||||
|
||||
# Get seed from config for deterministic random
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Call JAX function with parameters
|
||||
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||
|
||||
# Convert result back to numpy if needed
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready() # Ensure computation is complete
|
||||
if hasattr(result, '__array__'):
|
||||
result = np.asarray(result)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# Fall back to interpreter on error
|
||||
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _init(self):
|
||||
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
||||
# Set random seed for deterministic output
|
||||
seed = self.config.get('seed', 42)
|
||||
try:
|
||||
from sexp_effects.primitive_libs.core import set_random_seed
|
||||
set_random_seed(seed)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Load external config files first (they can override recipe definitions)
|
||||
if self.sources_config:
|
||||
self._load_config_file(self.sources_config)
|
||||
@@ -780,6 +869,7 @@ class StreamInterpreter:
|
||||
effect_env[pname] = pdef.get('default', 0)
|
||||
|
||||
positional_idx = 0
|
||||
frame_val = None
|
||||
i = 0
|
||||
while i < len(args):
|
||||
if isinstance(args[i], Keyword):
|
||||
@@ -791,11 +881,24 @@ class StreamInterpreter:
|
||||
val = self._eval(args[i], env)
|
||||
if positional_idx == 0:
|
||||
effect_env['frame'] = val
|
||||
frame_val = val
|
||||
elif positional_idx - 1 < len(param_names):
|
||||
effect_env[param_names[positional_idx - 1]] = val
|
||||
positional_idx += 1
|
||||
i += 1
|
||||
|
||||
# Try JAX-accelerated execution first
|
||||
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
||||
# Build params dict for JAX (exclude 'frame')
|
||||
jax_params = {k: v for k, v in effect_env.items()
|
||||
if k != 'frame' and k in effect['params']}
|
||||
t = env.get('t', 0.0)
|
||||
frame_num = env.get('frame-num', 0)
|
||||
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
|
||||
if result is not None:
|
||||
return result
|
||||
# Fall through to interpreter if JAX fails
|
||||
|
||||
return self._eval(effect['body'], effect_env)
|
||||
|
||||
# === Primitives ===
|
||||
@@ -1049,9 +1152,9 @@ class StreamInterpreter:
|
||||
|
||||
|
||||
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
|
||||
sources_config: str = None, audio_config: str = None):
|
||||
sources_config: str = None, audio_config: str = None, use_jax: bool = False):
|
||||
"""Run a streaming sexp."""
|
||||
interp = StreamInterpreter(sexp_path)
|
||||
interp = StreamInterpreter(sexp_path, use_jax=use_jax)
|
||||
if fps:
|
||||
interp.config['fps'] = fps
|
||||
if sources_config:
|
||||
@@ -1070,7 +1173,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--fps", type=float, default=None)
|
||||
parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
|
||||
parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
|
||||
parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
|
||||
args = parser.parse_args()
|
||||
|
||||
run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps,
|
||||
sources_config=args.sources_config, audio_config=args.audio_config)
|
||||
sources_config=args.sources_config, audio_config=args.audio_config,
|
||||
use_jax=args.jax)
|
||||
|
||||
Reference in New Issue
Block a user