Add JAX backend with frame-varying random keys
Some checks failed
GPU Worker CI/CD / test (push) Has been cancelled
GPU Worker CI/CD / deploy (push) Has been cancelled

- Add sexp_to_jax.py: JAX compiler for S-expression effects
- Use jax.random.fold_in for deterministic but varying random per frame
- Pass seed from recipe config through to JAX effects
- Fix NVENC detection to do actual encode test
- Add set_random_seed for deterministic Python random

The fold_in approach allows frame_num to be traced (not static)
while still producing different random patterns per frame,
fixing the interference pattern issue.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-05 11:07:02 +00:00
parent 0534081e44
commit 7411aa74c4
4 changed files with 3793 additions and 7 deletions

View File

@@ -52,15 +52,33 @@ def prim_max(*args):
def prim_round(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.round(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.round(x)
return round(x)
def prim_floor(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.floor(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.floor(x)
import math
return math.floor(x)
def prim_ceil(x):
import numpy as np
if hasattr(x, '_data'): # Xector
from .xector import Xector
return Xector(np.ceil(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.ceil(x)
import math
return math.ceil(x)
@@ -193,6 +211,11 @@ def prim_range(*args):
import random
_rng = random.Random()
def set_random_seed(seed):
"""Set the random seed for deterministic output."""
global _rng
_rng = random.Random(seed)
def prim_rand():
"""Return random float in [0, 1)."""
return _rng.random()