Add JAX backend with frame-varying random keys
- Add sexp_to_jax.py: JAX compiler for S-expression effects - Use jax.random.fold_in for deterministic but varying random per frame - Pass seed from recipe config through to JAX effects - Fix NVENC detection to do actual encode test - Add set_random_seed for deterministic Python random The fold_in approach allows frame_num to be traced (not static) while still producing different random patterns per frame, fixing the interference pattern issue. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -52,15 +52,33 @@ def prim_max(*args):
|
|||||||
|
|
||||||
|
|
||||||
def prim_round(x):
|
def prim_round(x):
|
||||||
|
import numpy as np
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from .xector import Xector
|
||||||
|
return Xector(np.round(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.round(x)
|
||||||
return round(x)
|
return round(x)
|
||||||
|
|
||||||
|
|
||||||
def prim_floor(x):
|
def prim_floor(x):
|
||||||
|
import numpy as np
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from .xector import Xector
|
||||||
|
return Xector(np.floor(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.floor(x)
|
||||||
import math
|
import math
|
||||||
return math.floor(x)
|
return math.floor(x)
|
||||||
|
|
||||||
|
|
||||||
def prim_ceil(x):
|
def prim_ceil(x):
|
||||||
|
import numpy as np
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from .xector import Xector
|
||||||
|
return Xector(np.ceil(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.ceil(x)
|
||||||
import math
|
import math
|
||||||
return math.ceil(x)
|
return math.ceil(x)
|
||||||
|
|
||||||
@@ -193,6 +211,11 @@ def prim_range(*args):
|
|||||||
import random
|
import random
|
||||||
_rng = random.Random()
|
_rng = random.Random()
|
||||||
|
|
||||||
|
def set_random_seed(seed):
|
||||||
|
"""Set the random seed for deterministic output."""
|
||||||
|
global _rng
|
||||||
|
_rng = random.Random(seed)
|
||||||
|
|
||||||
def prim_rand():
|
def prim_rand():
|
||||||
"""Return random float in [0, 1)."""
|
"""Return random float in [0, 1)."""
|
||||||
return _rng.random()
|
return _rng.random()
|
||||||
|
|||||||
@@ -37,19 +37,39 @@ _nvenc_available: Optional[bool] = None
|
|||||||
|
|
||||||
|
|
||||||
def check_nvenc_available() -> bool:
|
def check_nvenc_available() -> bool:
|
||||||
"""Check if NVENC hardware encoding is available."""
|
"""Check if NVENC hardware encoding is available and working.
|
||||||
|
|
||||||
|
Does a real encode test to catch cases where nvenc is listed
|
||||||
|
but CUDA libraries aren't loaded.
|
||||||
|
"""
|
||||||
global _nvenc_available
|
global _nvenc_available
|
||||||
if _nvenc_available is not None:
|
if _nvenc_available is not None:
|
||||||
return _nvenc_available
|
return _nvenc_available
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# First check if encoder is listed
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["ffmpeg", "-encoders"],
|
["ffmpeg", "-encoders"],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=5
|
timeout=5
|
||||||
)
|
)
|
||||||
_nvenc_available = "h264_nvenc" in result.stdout
|
if "h264_nvenc" not in result.stdout:
|
||||||
|
_nvenc_available = False
|
||||||
|
return _nvenc_available
|
||||||
|
|
||||||
|
# Actually try to encode a small test frame
|
||||||
|
result = subprocess.run(
|
||||||
|
["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1",
|
||||||
|
"-c:v", "h264_nvenc", "-f", "null", "-"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
_nvenc_available = result.returncode == 0
|
||||||
|
if not _nvenc_available:
|
||||||
|
import sys
|
||||||
|
print("NVENC listed but not working, falling back to libx264", file=sys.stderr)
|
||||||
except Exception:
|
except Exception:
|
||||||
_nvenc_available = False
|
_nvenc_available = False
|
||||||
|
|
||||||
|
|||||||
3638
streaming/sexp_to_jax.py
Normal file
3638
streaming/sexp_to_jax.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -28,12 +28,31 @@ import math
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Any, Optional, Tuple
|
from typing import Dict, List, Any, Optional, Tuple, Callable
|
||||||
|
|
||||||
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
|
# Use local sexp_effects parser (supports namespaced symbols like math:sin)
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
|
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
|
||||||
|
|
||||||
|
# JAX backend (optional - loaded on demand)
|
||||||
|
_JAX_AVAILABLE = False
|
||||||
|
_jax_compiler = None
|
||||||
|
|
||||||
|
def _init_jax():
|
||||||
|
"""Lazily initialize JAX compiler."""
|
||||||
|
global _JAX_AVAILABLE, _jax_compiler
|
||||||
|
if _jax_compiler is not None:
|
||||||
|
return _JAX_AVAILABLE
|
||||||
|
try:
|
||||||
|
from streaming.sexp_to_jax import JaxCompiler, compile_effect_file
|
||||||
|
_jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file}
|
||||||
|
_JAX_AVAILABLE = True
|
||||||
|
print("JAX backend initialized", file=sys.stderr)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"JAX backend not available: {e}", file=sys.stderr)
|
||||||
|
_JAX_AVAILABLE = False
|
||||||
|
return _JAX_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Context:
|
class Context:
|
||||||
@@ -51,7 +70,7 @@ class StreamInterpreter:
|
|||||||
and calls primitives.
|
and calls primitives.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sexp_path: str, actor_id: Optional[str] = None):
|
def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False):
|
||||||
self.sexp_path = Path(sexp_path)
|
self.sexp_path = Path(sexp_path)
|
||||||
self.sexp_dir = self.sexp_path.parent
|
self.sexp_dir = self.sexp_path.parent
|
||||||
self.actor_id = actor_id # For friendly name resolution
|
self.actor_id = actor_id # For friendly name resolution
|
||||||
@@ -74,6 +93,17 @@ class StreamInterpreter:
|
|||||||
self.primitives: Dict[str, Any] = {}
|
self.primitives: Dict[str, Any] = {}
|
||||||
self.effects: Dict[str, dict] = {}
|
self.effects: Dict[str, dict] = {}
|
||||||
self.macros: Dict[str, dict] = {}
|
self.macros: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
# JAX backend for accelerated effect evaluation
|
||||||
|
self.use_jax = use_jax
|
||||||
|
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
||||||
|
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
||||||
|
if use_jax:
|
||||||
|
if _init_jax():
|
||||||
|
print("JAX acceleration enabled", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr)
|
||||||
|
self.use_jax = False
|
||||||
# Try multiple locations for primitive_libs
|
# Try multiple locations for primitive_libs
|
||||||
possible_paths = [
|
possible_paths = [
|
||||||
self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir
|
self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir
|
||||||
@@ -307,8 +337,13 @@ class StreamInterpreter:
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
self.effects[name] = {'params': params, 'body': body}
|
self.effects[name] = {'params': params, 'body': body}
|
||||||
|
self.jax_effect_paths[name] = effect_path # Track source for JAX compilation
|
||||||
print(f"Effect: {name}", file=sys.stderr)
|
print(f"Effect: {name}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Try to compile with JAX if enabled
|
||||||
|
if self.use_jax and _JAX_AVAILABLE:
|
||||||
|
self._compile_jax_effect(name, effect_path)
|
||||||
|
|
||||||
elif cmd == 'defmacro':
|
elif cmd == 'defmacro':
|
||||||
name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
|
name = form[1].name if isinstance(form[1], Symbol) else str(form[1])
|
||||||
params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
|
params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]]
|
||||||
@@ -387,8 +422,62 @@ class StreamInterpreter:
|
|||||||
}
|
}
|
||||||
print(f"Scan: {name}", file=sys.stderr)
|
print(f"Scan: {name}", file=sys.stderr)
|
||||||
|
|
||||||
|
def _compile_jax_effect(self, name: str, effect_path: Path):
|
||||||
|
"""Compile an effect with JAX for accelerated execution."""
|
||||||
|
if not _JAX_AVAILABLE or name in self.jax_effects:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
compile_effect_file = _jax_compiler['compile_effect_file']
|
||||||
|
jax_fn = compile_effect_file(str(effect_path))
|
||||||
|
self.jax_effects[name] = jax_fn
|
||||||
|
print(f" [JAX compiled: {name}]", file=sys.stderr)
|
||||||
|
except Exception as e:
|
||||||
|
# Silently fall back to interpreter for unsupported effects
|
||||||
|
if 'Unknown operation' not in str(e):
|
||||||
|
print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr)
|
||||||
|
|
||||||
|
def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]:
|
||||||
|
"""Apply a JAX-compiled effect to a frame."""
|
||||||
|
if name not in self.jax_effects:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
jax_fn = self.jax_effects[name]
|
||||||
|
# Ensure frame is numpy array
|
||||||
|
if hasattr(frame, 'cpu'):
|
||||||
|
frame = frame.cpu
|
||||||
|
elif hasattr(frame, 'get'):
|
||||||
|
frame = frame.get()
|
||||||
|
|
||||||
|
# Get seed from config for deterministic random
|
||||||
|
seed = self.config.get('seed', 42)
|
||||||
|
|
||||||
|
# Call JAX function with parameters
|
||||||
|
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||||
|
|
||||||
|
# Convert result back to numpy if needed
|
||||||
|
if hasattr(result, 'block_until_ready'):
|
||||||
|
result.block_until_ready() # Ensure computation is complete
|
||||||
|
if hasattr(result, '__array__'):
|
||||||
|
result = np.asarray(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
# Fall back to interpreter on error
|
||||||
|
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
||||||
|
# Set random seed for deterministic output
|
||||||
|
seed = self.config.get('seed', 42)
|
||||||
|
try:
|
||||||
|
from sexp_effects.primitive_libs.core import set_random_seed
|
||||||
|
set_random_seed(seed)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Load external config files first (they can override recipe definitions)
|
# Load external config files first (they can override recipe definitions)
|
||||||
if self.sources_config:
|
if self.sources_config:
|
||||||
self._load_config_file(self.sources_config)
|
self._load_config_file(self.sources_config)
|
||||||
@@ -780,6 +869,7 @@ class StreamInterpreter:
|
|||||||
effect_env[pname] = pdef.get('default', 0)
|
effect_env[pname] = pdef.get('default', 0)
|
||||||
|
|
||||||
positional_idx = 0
|
positional_idx = 0
|
||||||
|
frame_val = None
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(args):
|
while i < len(args):
|
||||||
if isinstance(args[i], Keyword):
|
if isinstance(args[i], Keyword):
|
||||||
@@ -791,11 +881,24 @@ class StreamInterpreter:
|
|||||||
val = self._eval(args[i], env)
|
val = self._eval(args[i], env)
|
||||||
if positional_idx == 0:
|
if positional_idx == 0:
|
||||||
effect_env['frame'] = val
|
effect_env['frame'] = val
|
||||||
|
frame_val = val
|
||||||
elif positional_idx - 1 < len(param_names):
|
elif positional_idx - 1 < len(param_names):
|
||||||
effect_env[param_names[positional_idx - 1]] = val
|
effect_env[param_names[positional_idx - 1]] = val
|
||||||
positional_idx += 1
|
positional_idx += 1
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
# Try JAX-accelerated execution first
|
||||||
|
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
||||||
|
# Build params dict for JAX (exclude 'frame')
|
||||||
|
jax_params = {k: v for k, v in effect_env.items()
|
||||||
|
if k != 'frame' and k in effect['params']}
|
||||||
|
t = env.get('t', 0.0)
|
||||||
|
frame_num = env.get('frame-num', 0)
|
||||||
|
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
# Fall through to interpreter if JAX fails
|
||||||
|
|
||||||
return self._eval(effect['body'], effect_env)
|
return self._eval(effect['body'], effect_env)
|
||||||
|
|
||||||
# === Primitives ===
|
# === Primitives ===
|
||||||
@@ -1049,9 +1152,9 @@ class StreamInterpreter:
|
|||||||
|
|
||||||
|
|
||||||
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
|
def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None,
|
||||||
sources_config: str = None, audio_config: str = None):
|
sources_config: str = None, audio_config: str = None, use_jax: bool = False):
|
||||||
"""Run a streaming sexp."""
|
"""Run a streaming sexp."""
|
||||||
interp = StreamInterpreter(sexp_path)
|
interp = StreamInterpreter(sexp_path, use_jax=use_jax)
|
||||||
if fps:
|
if fps:
|
||||||
interp.config['fps'] = fps
|
interp.config['fps'] = fps
|
||||||
if sources_config:
|
if sources_config:
|
||||||
@@ -1070,7 +1173,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--fps", type=float, default=None)
|
parser.add_argument("--fps", type=float, default=None)
|
||||||
parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
|
parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file")
|
||||||
parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
|
parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file")
|
||||||
|
parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps,
|
run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps,
|
||||||
sources_config=args.sources_config, audio_config=args.audio_config)
|
sources_config=args.sources_config, audio_config=args.audio_config,
|
||||||
|
use_jax=args.jax)
|
||||||
|
|||||||
Reference in New Issue
Block a user