Add JAX backend with frame-varying random keys
- Add sexp_to_jax.py: JAX compiler for S-expression effects - Use jax.random.fold_in for deterministic but varying random per frame - Pass seed from recipe config through to JAX effects - Fix NVENC detection to do actual encode test - Add set_random_seed for deterministic Python random The fold_in approach allows frame_num to be traced (not static) while still producing different random patterns per frame, fixing the interference pattern issue. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -52,15 +52,33 @@ def prim_max(*args):
|
||||
|
||||
|
||||
def prim_round(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.round(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.round(x)
|
||||
return round(x)
|
||||
|
||||
|
||||
def prim_floor(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.floor(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.floor(x)
|
||||
import math
|
||||
return math.floor(x)
|
||||
|
||||
|
||||
def prim_ceil(x):
|
||||
import numpy as np
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from .xector import Xector
|
||||
return Xector(np.ceil(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.ceil(x)
|
||||
import math
|
||||
return math.ceil(x)
|
||||
|
||||
@@ -193,6 +211,11 @@ def prim_range(*args):
|
||||
import random
|
||||
_rng = random.Random()
|
||||
|
||||
def set_random_seed(seed):
|
||||
"""Set the random seed for deterministic output."""
|
||||
global _rng
|
||||
_rng = random.Random(seed)
|
||||
|
||||
def prim_rand():
|
||||
"""Return random float in [0, 1)."""
|
||||
return _rng.random()
|
||||
|
||||
Reference in New Issue
Block a user