Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s

- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 15:12:54 +00:00
parent dbc4ece2cc
commit fc9597456f
30 changed files with 7749 additions and 165 deletions

View File

@@ -0,0 +1,517 @@
#!/usr/bin/env python3
"""Integration tests comparing JAX and Python rendering pipelines.
These tests ensure the JAX-compiled effect chains produce identical output
to the Python/NumPy path. They test:
1. Full effect pipelines through both interpreters
2. Multi-frame sequences (to catch phase-dependent bugs)
3. Compiled effect chain fusion
4. Edge cases like shrinking/zooming that affect boundary handling
"""
import os
import sys
import pytest
import numpy as np
import shutil
from pathlib import Path
# Ensure the art-celery module is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sexp_effects.primitive_libs import core as core_mod
# Path to test resources
TEST_DIR = Path('/home/giles/art/test')
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
TEMPLATES_DIR = TEST_DIR / 'templates'
def create_test_image(h=96, w=128):
"""Create a test image with distinct patterns."""
import cv2
img = np.zeros((h, w, 3), dtype=np.uint8)
# Create gradient background
for y in range(h):
for x in range(w):
img[y, x] = [
int(255 * x / w), # R: horizontal gradient
int(255 * y / h), # G: vertical gradient
128 # B: constant
]
# Add features
cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1)
cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
return img
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
"""Set up test environment with sexp files and test media."""
test_dir = tmp_path_factory.mktemp('sexp_test')
original_dir = os.getcwd()
os.chdir(test_dir)
# Create directories
(test_dir / 'effects').mkdir()
(test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)
# Create test image
import cv2
test_img = create_test_image()
cv2.imwrite(str(test_dir / 'test_image.png'), test_img)
# Copy required effect files
for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
src = EFFECTS_DIR / f'{effect}.sexp'
dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
if src.exists():
shutil.copy(src, dst)
yield {
'dir': test_dir,
'image_path': test_dir / 'test_image.png',
'test_img': test_img,
}
os.chdir(original_dir)
def create_sexp_file(test_dir, content, filename='test.sexp'):
"""Create a test sexp file."""
path = test_dir / 'effects' / filename
with open(path, 'w') as f:
f.write(content)
return str(path)
class TestJaxPythonPipelineEquivalence:
"""Test that JAX and Python pipelines produce equivalent output."""
def test_single_rotate_effect(self, test_env):
"""Test that a single rotate effect matches between Python and JAX."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(frame (rotate frame :angle 15 :speed 0))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
# Python path
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
# JAX path
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
# Inject test image into globals
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = py_interp._eval(py_interp.frame_pipeline, frame_env)
jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env)
# Force deferred if needed
py_result = np.asarray(py_interp._maybe_force(py_result))
jax_result = np.asarray(jax_interp._maybe_force(jax_result))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"
def test_rotate_then_zoom(self, test_env):
"""Test rotate followed by zoom - tests effect chain fusion."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (-> (rotate frame :angle 15 :speed 0)
(zoom :amount 0.95 :speed 0)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"
def test_zoom_shrink_boundary_handling(self, test_env):
"""Test zoom with shrinking factor - critical for boundary handling."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (zoom frame :amount 0.8 :speed 0))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
# Check corners specifically - these are most affected by boundary handling
h, w = test_img.shape[:2]
corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)]
for y, x in corners:
py_val = py_result[y, x]
jax_val = jax_result[y, x]
corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"
def test_blend_two_clips(self, test_env):
"""Test blending two effect chains - the core bug scenario."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
max_diff = np.max(diff)
# Check edge region specifically
edge_diff = diff[0, :].mean()
assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"
def test_blend_with_invert(self, test_env):
"""Test blending with invert - matches the problematic recipe pattern."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(require-primitives "color_ops")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(effect invert :path "../sexp_effects/effects/invert.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0)
(invert :amount 1))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
class TestDeferredEffectChainFusion:
"""Test the DeferredEffectChain fusion mechanism specifically."""
def test_manual_vs_fused_chain(self, test_env):
"""Compare manual application vs fused DeferredEffectChain."""
import jax.numpy as jnp
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
# Create minimal sexp to load effects
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame frame)
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
core_mod.set_random_seed(42)
interp = StreamInterpreter(sexp_path, use_jax=True)
interp._init()
test_img = test_env['test_img']
jax_frame = jnp.array(test_img)
t = 0.5
frame_num = 5
# Manual step-by-step application
rotate_fn = interp.jax_effects['rotate']
zoom_fn = interp.jax_effects['zoom']
rot_angle = -5.0
zoom_amount = 0.95
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
angle=rot_angle, speed=0)
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
amount=zoom_amount, speed=0)
manual_result = np.asarray(manual_result)
# Fused chain application
chain = DeferredEffectChain(
['rotate'],
[{'angle': rot_angle, 'speed': 0}],
jax_frame, t, frame_num
)
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
fused_result = np.asarray(interp._force_deferred(chain))
diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"
# Check specific pixels
h, w = test_img.shape[:2]
for y in [0, h//2, h-1]:
for x in [0, w//2, w-1]:
manual_val = manual_result[y, x]
fused_val = fused_result[y, x]
pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"
def test_chain_with_shrink_zoom_boundary(self, test_env):
"""Test that shrinking zoom handles boundaries correctly in chain."""
import jax.numpy as jnp
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame frame)
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
core_mod.set_random_seed(42)
interp = StreamInterpreter(sexp_path, use_jax=True)
interp._init()
test_img = test_env['test_img']
jax_frame = jnp.array(test_img)
t = 0.5
frame_num = 5
# Parameters that shrink the image (zoom < 1.0)
rot_angle = -4.555
zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries
# Manual application
rotate_fn = interp.jax_effects['rotate']
zoom_fn = interp.jax_effects['zoom']
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
angle=rot_angle, speed=0)
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
amount=zoom_amount, speed=0)
manual_result = np.asarray(manual_result)
# Fused chain
chain = DeferredEffectChain(
['rotate'],
[{'angle': rot_angle, 'speed': 0}],
jax_frame, t, frame_num
)
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
fused_result = np.asarray(interp._force_deferred(chain))
# Check top edge specifically - this is where boundary issues manifest
top_edge_manual = manual_result[0, :]
top_edge_fused = fused_result[0, :]
edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
mean_edge_diff = np.mean(edge_diff)
assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"
# Check for zeros at edges that shouldn't be there
manual_edge_sum = np.sum(top_edge_manual)
fused_edge_sum = np.sum(top_edge_fused)
if manual_edge_sum > 100: # If manual has significant values
assert fused_edge_sum > manual_edge_sum * 0.5, \
f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
Test framework to verify JAX primitives match Python primitives.
Compares output of each primitive through:
1. Python/NumPy path
2. JAX path (CPU)
3. JAX path (GPU) - if available
Reports any mismatches with detailed diffs.
"""
import sys
sys.path.insert(0, '/home/giles/art/art-celery')
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass, field
# Test configuration
TEST_WIDTH = 64
TEST_HEIGHT = 48
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
@dataclass
class TestResult:
primitive: str
passed: bool
python_mean: float = 0.0
jax_mean: float = 0.0
diff_mean: float = 0.0
diff_max: float = 0.0
pct_within_1: float = 0.0
error: str = ""
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
"""Create a test frame with known pattern."""
if pattern == 'gradient':
# Diagonal gradient
y, x = np.mgrid[0:h, 0:w]
r = (x * 255 / w).astype(np.uint8)
g = (y * 255 / h).astype(np.uint8)
b = ((x + y) * 127 / (w + h)).astype(np.uint8)
return np.stack([r, g, b], axis=2)
elif pattern == 'checkerboard':
y, x = np.mgrid[0:h, 0:w]
check = ((x // 8) + (y // 8)) % 2
v = (check * 255).astype(np.uint8)
return np.stack([v, v, v], axis=2)
elif pattern == 'solid':
return np.full((h, w, 3), 128, dtype=np.uint8)
else:
return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
"""Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
if py_out is None or jax_out is None:
return float('inf'), float('inf'), 0.0
if isinstance(py_out, dict) and isinstance(jax_out, dict):
# Compare coordinate maps
diffs = []
for k in py_out:
if k in jax_out:
py_arr = np.asarray(py_out[k])
jax_arr = np.asarray(jax_out[k])
if py_arr.shape == jax_arr.shape:
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
diffs.append(diff)
if diffs:
all_diff = np.concatenate([d.flatten() for d in diffs])
return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))
return float('inf'), float('inf'), 0.0
py_arr = np.asarray(py_out)
jax_arr = np.asarray(jax_out)
if py_arr.shape != jax_arr.shape:
return float('inf'), float('inf'), 0.0
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))
# ============================================================================
# Primitive Test Definitions
# ============================================================================
PRIMITIVE_TESTS = {
# Geometry primitives
'geometry:ripple-displace': {
'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
'returns': 'coords',
},
'geometry:rotate-img': {
'args': ['frame', 45],
'returns': 'frame',
},
'geometry:scale-img': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'geometry:flip-h': {
'args': ['frame'],
'returns': 'frame',
},
'geometry:flip-v': {
'args': ['frame'],
'returns': 'frame',
},
# Color operations
'color_ops:invert': {
'args': ['frame'],
'returns': 'frame',
},
'color_ops:grayscale': {
'args': ['frame'],
'returns': 'frame',
},
'color_ops:brightness': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'color_ops:contrast': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'color_ops:hue-shift': {
'args': ['frame', 90],
'returns': 'frame',
},
# Image operations
'image:width': {
'args': ['frame'],
'returns': 'scalar',
},
'image:height': {
'args': ['frame'],
'returns': 'scalar',
},
'image:channel': {
'args': ['frame', 0],
'returns': 'array',
},
# Blending
'blending:blend': {
'args': ['frame', 'frame2', 0.5],
'returns': 'frame',
},
'blending:blend-add': {
'args': ['frame', 'frame2'],
'returns': 'frame',
},
'blending:blend-multiply': {
'args': ['frame', 'frame2'],
'returns': 'frame',
},
}
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
"""Run a primitive through the Python interpreter."""
if prim_name not in interp.primitives:
return None
func = interp.primitives[prim_name]
args = []
for a in test_def['args']:
if a == 'frame':
args.append(frame.copy())
elif a == 'frame2':
args.append(frame2.copy())
else:
args.append(a)
try:
return func(*args)
except Exception as e:
return None
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
"""Run a primitive through the JAX compiler."""
try:
from streaming.sexp_to_jax import JaxCompiler
import jax.numpy as jnp
compiler = JaxCompiler()
# Build a simple expression to test the primitive
from sexp_effects.parser import Symbol, Keyword
args = []
env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
for a in test_def['args']:
if a == 'frame':
args.append(Symbol('frame'))
elif a == 'frame2':
args.append(Symbol('frame2'))
else:
args.append(a)
# Create expression: (prim_name arg1 arg2 ...)
expr = [Symbol(prim_name)] + args
result = compiler._eval(expr, env)
if hasattr(result, '__array__'):
return np.asarray(result)
return result
except Exception as e:
return None
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
"""Test a single primitive."""
frame = create_test_frame(pattern='gradient')
frame2 = create_test_frame(pattern='checkerboard')
result = TestResult(primitive=prim_name, passed=False)
# Run Python version
try:
py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
if py_out is not None and hasattr(py_out, 'shape'):
result.python_mean = float(np.mean(py_out))
except Exception as e:
result.error = f"Python error: {e}"
return result
# Run JAX version
try:
jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
if jax_out is not None and hasattr(jax_out, 'shape'):
result.jax_mean = float(np.mean(jax_out))
except Exception as e:
result.error = f"JAX error: {e}"
return result
if py_out is None:
result.error = "Python returned None"
return result
if jax_out is None:
result.error = "JAX returned None"
return result
# Compare
diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
result.diff_mean = diff_mean
result.diff_max = diff_max
result.pct_within_1 = pct
# Check pass/fail
result.passed = (
diff_mean <= TOLERANCE_MEAN and
diff_max <= TOLERANCE_MAX and
pct >= TOLERANCE_PCT
)
if not result.passed:
result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
return result
def discover_primitives(interp) -> List[str]:
"""Discover all primitives available in the interpreter."""
return sorted(interp.primitives.keys())
def run_all_tests(verbose=True):
"""Run all primitive tests."""
import warnings
warnings.filterwarnings('ignore')
import os
os.chdir('/home/giles/art/test')
from streaming.stream_sexp_generic import StreamInterpreter
from sexp_effects.primitive_libs import core as core_mod
core_mod.set_random_seed(42)
# Create interpreter to get Python primitives
interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
interp._init()
results = []
print("=" * 60)
print("JAX PRIMITIVE TEST SUITE")
print("=" * 60)
# Test defined primitives
for prim_name, test_def in PRIMITIVE_TESTS.items():
result = test_primitive(interp, prim_name, test_def)
results.append(result)
status = "✓ PASS" if result.passed else "✗ FAIL"
if verbose:
print(f"{status} {prim_name}")
if not result.passed:
print(f" {result.error}")
# Summary
passed = sum(1 for r in results if r.passed)
failed = sum(1 for r in results if not r.passed)
print("\n" + "=" * 60)
print(f"SUMMARY: {passed} passed, {failed} failed")
print("=" * 60)
if failed > 0:
print("\nFailed primitives:")
for r in results:
if not r.passed:
print(f" - {r.primitive}: {r.error}")
return results
if __name__ == '__main__':
run_all_tests()

305
tests/test_xector.py Normal file
View File

@@ -0,0 +1,305 @@
"""
Tests for xector primitives - parallel array operations.
"""
import pytest
import numpy as np
from sexp_effects.primitive_libs.xector import (
Xector,
xector_red, xector_green, xector_blue, xector_rgb,
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
xector_dist_from_center,
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
alpha_sin, alpha_cos, alpha_sq,
alpha_lt, alpha_gt, alpha_eq,
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
xector_where, xector_fill, xector_zeros, xector_ones,
is_xector,
)
class TestXectorBasics:
"""Test Xector class basic operations."""
def test_create_from_list(self):
x = Xector([1, 2, 3])
assert len(x) == 3
assert is_xector(x)
def test_create_from_numpy(self):
arr = np.array([1.0, 2.0, 3.0])
x = Xector(arr)
assert len(x) == 3
np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))
def test_implicit_add(self):
a = Xector([1, 2, 3])
b = Xector([4, 5, 6])
c = a + b
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
def test_implicit_mul(self):
a = Xector([1, 2, 3])
b = Xector([2, 2, 2])
c = a * b
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
def test_scalar_broadcast(self):
a = Xector([1, 2, 3])
c = a + 10
np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])
def test_scalar_broadcast_rmul(self):
a = Xector([1, 2, 3])
c = 2 * a
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
class TestAlphaOperations:
"""Test α (element-wise) operations."""
def test_alpha_add(self):
a = Xector([1, 2, 3])
b = Xector([4, 5, 6])
c = alpha_add(a, b)
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
def test_alpha_add_multi(self):
a = Xector([1, 2, 3])
b = Xector([1, 1, 1])
c = Xector([10, 10, 10])
d = alpha_add(a, b, c)
np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])
def test_alpha_mul_scalar(self):
a = Xector([1, 2, 3])
c = alpha_mul(a, 2)
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
def test_alpha_sqrt(self):
a = Xector([1, 4, 9, 16])
c = alpha_sqrt(a)
np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])
def test_alpha_clamp(self):
a = Xector([-5, 0, 5, 10, 15])
c = alpha_clamp(a, 0, 10)
np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])
def test_alpha_sin_cos(self):
a = Xector([0, np.pi/2, np.pi])
s = alpha_sin(a)
c = alpha_cos(a)
np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)
def test_alpha_sq(self):
a = Xector([1, 2, 3, 4])
c = alpha_sq(a)
np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])
def test_alpha_comparison(self):
a = Xector([1, 2, 3, 4])
b = Xector([2, 2, 2, 2])
lt = alpha_lt(a, b)
gt = alpha_gt(a, b)
eq = alpha_eq(a, b)
np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])
class TestBetaOperations:
"""Test β (reduction) operations."""
def test_beta_add(self):
a = Xector([1, 2, 3, 4])
assert beta_add(a) == 10
def test_beta_mul(self):
a = Xector([1, 2, 3, 4])
assert beta_mul(a) == 24
def test_beta_min_max(self):
a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
assert beta_min(a) == 1
assert beta_max(a) == 9
def test_beta_mean(self):
a = Xector([1, 2, 3, 4])
assert beta_mean(a) == 2.5
def test_beta_count(self):
a = Xector([1, 2, 3, 4, 5])
assert beta_count(a) == 5
class TestFrameConversion:
"""Test frame/xector conversion."""
def test_extract_channels(self):
# Create a 2x2 RGB frame
frame = np.array([
[[255, 0, 0], [0, 255, 0]],
[[0, 0, 255], [128, 128, 128]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
assert len(r) == 4
np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])
def test_rgb_roundtrip(self):
# Create a 2x2 RGB frame
frame = np.array([
[[100, 150, 200], [50, 75, 100]],
[[200, 100, 50], [25, 50, 75]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
reconstructed = xector_rgb(r, g, b)
np.testing.assert_array_equal(reconstructed, frame)
def test_modify_and_reconstruct(self):
frame = np.array([
[[100, 100, 100], [100, 100, 100]],
[[100, 100, 100], [100, 100, 100]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
# Double red channel
r_doubled = r * 2
result = xector_rgb(r_doubled, g, b)
# Red should be 200, others unchanged
assert result[0, 0, 0] == 200
assert result[0, 0, 1] == 100
assert result[0, 0, 2] == 100
class TestCoordinates:
"""Test coordinate generation."""
def test_x_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
x = xector_x_coords(frame)
# Should be [0,1,2, 0,1,2] (x coords repeated for each row)
np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])
def test_y_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
y = xector_y_coords(frame)
# Should be [0,0,0, 1,1,1] (y coords for each pixel)
np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])
def test_normalized_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8)
x = xector_x_norm(frame)
y = xector_y_norm(frame)
# x should go 0 to 1 across width
assert x.to_numpy()[0] == 0
assert x.to_numpy()[2] == 1
# y should go 0 to 1 down height
assert y.to_numpy()[0] == 0
assert y.to_numpy()[3] == 1
class TestConditional:
"""Test conditional operations."""
def test_where(self):
cond = Xector([True, False, True, False])
true_val = Xector([1, 1, 1, 1])
false_val = Xector([0, 0, 0, 0])
result = xector_where(cond, true_val, false_val)
np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])
def test_where_with_comparison(self):
a = Xector([1, 5, 3, 7])
threshold = 4
# Elements > 4 become 255, others become 0
result = xector_where(alpha_gt(a, threshold), 255, 0)
np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])
def test_fill(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8)
x = xector_fill(42, frame)
assert len(x) == 6
assert all(v == 42 for v in x.to_numpy())
def test_zeros_ones(self):
frame = np.zeros((2, 2, 3), dtype=np.uint8)
z = xector_zeros(frame)
o = xector_ones(frame)
assert all(v == 0 for v in z.to_numpy())
assert all(v == 1 for v in o.to_numpy())
class TestInterpreterIntegration:
"""Test xector operations through the interpreter."""
def test_xector_vignette_effect(self):
from sexp_effects.interpreter import Interpreter
interp = Interpreter(minimal_primitives=True)
# Load the xector vignette effect
interp.load_effect('sexp_effects/effects/xector_vignette.sexp')
# Create a test frame (white)
frame = np.full((100, 100, 3), 255, dtype=np.uint8)
# Run effect
result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})
# Center should be brighter than corners
center = result[50, 50]
corner = result[0, 0]
assert center.mean() > corner.mean(), "Center should be brighter than corners"
# Corners should be darkened
assert corner.mean() < 255, "Corners should be darkened"
def test_implicit_elementwise(self):
"""Test that regular + works element-wise on xectors."""
from sexp_effects.interpreter import Interpreter
interp = Interpreter(minimal_primitives=True)
# Load xector primitives
from sexp_effects.primitive_libs.xector import PRIMITIVES
for name, fn in PRIMITIVES.items():
interp.global_env.set(name, fn)
# Parse and eval a simple xector expression
from sexp_effects.parser import parse
expr = parse('(+ (red frame) 10)')
# Create test frame
frame = np.full((2, 2, 3), 100, dtype=np.uint8)
interp.global_env.set('frame', frame)
result = interp.eval(expr)
# Should be a xector with values 110
assert is_xector(result)
np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
if __name__ == '__main__':
pytest.main([__file__, '-v'])