Files
celery/tests/test_jax_pipeline_integration.py
gilesb fc9597456f
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:41:19 +00:00

518 lines
17 KiB
Python

#!/usr/bin/env python3
"""Integration tests comparing JAX and Python rendering pipelines.
These tests ensure the JAX-compiled effect chains produce identical output
to the Python/NumPy path. They test:
1. Full effect pipelines through both interpreters
2. Multi-frame sequences (to catch phase-dependent bugs)
3. Compiled effect chain fusion
4. Edge cases like shrinking/zooming that affect boundary handling
"""
import os
import sys
import pytest
import numpy as np
import shutil
from pathlib import Path
# Ensure the art-celery module is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sexp_effects.primitive_libs import core as core_mod
# Path to test resources
TEST_DIR = Path('/home/giles/art/test')
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
TEMPLATES_DIR = TEST_DIR / 'templates'
def create_test_image(h=96, w=128):
    """Create an ``(h, w, 3)`` uint8 test image with distinct patterns.

    The image combines a two-axis gradient background with two solid
    shapes, so geometric effects (rotate, zoom) visibly move content.
    Channels follow OpenCV's BGR order, because the shapes are drawn with
    cv2 and the image is later written with ``cv2.imwrite``:

    * channel 0 (B): horizontal gradient ``int(255 * x / w)``
    * channel 1 (G): vertical gradient ``int(255 * y / h)``
    * channel 2 (R): constant 128

    Args:
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        The generated image as a NumPy ``uint8`` array.
    """
    import cv2
    img = np.empty((h, w, 3), dtype=np.uint8)
    # Vectorised form of the per-pixel int(255 * i / n) gradients. astype()
    # truncates toward zero exactly like int() for these non-negative values.
    img[:, :, 0] = (255 * np.arange(w) / w).astype(np.uint8)[np.newaxis, :]
    img[:, :, 1] = (255 * np.arange(h) / h).astype(np.uint8)[:, np.newaxis]
    img[:, :, 2] = 128
    # Features: a filled disc at the centre and a filled square near the
    # top-left corner (colours are BGR tuples).
    cv2.circle(img, (w // 2, h // 2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
    return img
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up a scratch sexp workspace with effect files and a test image.

    Creates a module-scoped temporary directory containing:

    * ``effects/`` -- where ``create_sexp_file`` writes generated recipes;
    * ``sexp_effects/effects/`` -- copies of those effect definitions from
      ``EFFECTS_DIR`` that exist;
    * ``test_image.png`` -- the image produced by ``create_test_image``.

    The process working directory is switched to the temporary directory
    for the lifetime of the module and is restored afterwards, even if
    setup fails part-way. (NOTE(review): presumably the interpreter
    resolves some paths against the cwd -- confirm.)

    Tests depending on this fixture are skipped when ``EFFECTS_DIR`` does
    not exist, because every recipe in this module loads effects from it.

    Yields:
        dict with ``'dir'`` (Path of the workspace), ``'image_path'``
        (Path of the PNG) and ``'test_img'`` (the uint8 array written).
    """
    import cv2
    if not EFFECTS_DIR.is_dir():
        pytest.skip(f"sexp effect definitions not found at {EFFECTS_DIR} "
                    "(set ART_TEST_DIR to the art test checkout)")
    test_dir = tmp_path_factory.mktemp('sexp_test')
    original_dir = os.getcwd()
    os.chdir(test_dir)
    try:
        # Directory layout expected by the recipes' relative effect paths.
        (test_dir / 'effects').mkdir()
        effects_dst = test_dir / 'sexp_effects' / 'effects'
        effects_dst.mkdir(parents=True)
        # Test image, written to disk so tests can re-read a fresh copy.
        test_img = create_test_image()
        image_path = test_dir / 'test_image.png'
        if not cv2.imwrite(str(image_path), test_img):
            raise RuntimeError(f"cv2.imwrite failed for {image_path}")
        # Copy the effect definitions that are available; absent ones are
        # tolerated (not every effect listed is used by every test).
        for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
            src = EFFECTS_DIR / f'{effect}.sexp'
            if src.exists():
                shutil.copy(src, effects_dst / f'{effect}.sexp')
        yield {
            'dir': test_dir,
            'image_path': image_path,
            'test_img': test_img,
        }
    finally:
        os.chdir(original_dir)
def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write *content* to ``<test_dir>/effects/<filename>`` and return its path.

    An existing file of the same name is overwritten. Every test in this
    module reuses the default ``test.sexp``.

    Args:
        test_dir: Workspace root (``Path`` or ``str``). Its ``effects``
            subdirectory must already exist.
        content: Sexp source text.
        filename: Name of the file to create inside ``effects``.

    Returns:
        The written file's path as a ``str``.
    """
    path = Path(test_dir) / 'effects' / filename
    # Explicit encoding so the result does not depend on the platform locale.
    path.write_text(content, encoding='utf-8')
    return str(path)
# Simulated playback position and seed shared by the equivalence tests.
_EQUIV_FPS = 10
_EQUIV_T = 0.5
_EQUIV_FRAME_NUM = 5
_EQUIV_SEED = 42


def _render_pipeline(sexp_path, frame, use_jax):
    """Evaluate the recipe's frame pipeline on *frame* with one interpreter.

    The RNG is reseeded immediately before the interpreter is built. Each
    interpreter is created, initialised and evaluated in isolation, so the
    Python and JAX runs see the same random state.

    Returns:
        The (forced, if deferred) pipeline output as a NumPy array.
    """
    from streaming.stream_sexp_generic import StreamInterpreter
    frame_env = {
        'ctx': {'t': _EQUIV_T, 'frame-num': _EQUIV_FRAME_NUM, 'fps': _EQUIV_FPS},
        't': _EQUIV_T,
        'frame-num': _EQUIV_FRAME_NUM,
    }
    core_mod.set_random_seed(_EQUIV_SEED)
    interp = StreamInterpreter(sexp_path, use_jax=use_jax)
    interp._init()
    # Private copy: the two pipelines must not share one input buffer, or an
    # in-place write by the first would corrupt the second's input.
    interp.globals['frame'] = frame.copy()
    result = interp._eval(interp.frame_pipeline, frame_env)
    return np.asarray(interp._maybe_force(result))


def _render_both(test_env, sexp_content):
    """Render *sexp_content* with the Python and then the JAX interpreter.

    Returns:
        ``(py_result, jax_result)`` as NumPy arrays. Their shapes are
        asserted equal first, so broadcasting cannot silently compare
        mismatched images.
    """
    import cv2
    sexp_path = create_sexp_file(test_env['dir'], sexp_content)
    test_img = cv2.imread(str(test_env['image_path']))
    assert test_img is not None, f"could not read {test_env['image_path']}"
    py_result = _render_pipeline(sexp_path, test_img, use_jax=False)
    jax_result = _render_pipeline(sexp_path, test_img, use_jax=True)
    assert py_result.shape == jax_result.shape, (
        f"Shape mismatch: py={py_result.shape}, jax={jax_result.shape}")
    return py_result, jax_result


def _abs_diff(a, b):
    """Element-wise absolute difference of two images, computed in float."""
    return np.abs(a.astype(float) - b.astype(float))


class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output."""

    def test_single_rotate_effect(self, test_env):
        """Test that a single rotate effect matches between Python and JAX."""
        sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(frame (rotate frame :angle 15 :speed 0))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"

    def test_rotate_then_zoom(self, test_env):
        """Test rotate followed by zoom - tests effect chain fusion."""
        sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (-> (rotate frame :angle 15 :speed 0)
(zoom :amount 0.95 :speed 0)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"

    def test_zoom_shrink_boundary_handling(self, test_env):
        """Test zoom with shrinking factor - critical for boundary handling."""
        sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (zoom frame :amount 0.8 :speed 0))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        # Corners are the pixels most affected by boundary handling.
        h, w = py_result.shape[:2]
        for y, x in [(0, 0), (0, w - 1), (h - 1, 0), (h - 1, w - 1)]:
            py_val = py_result[y, x]
            jax_val = jax_result[y, x]
            corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
            assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"

    def test_blend_two_clips(self, test_env):
        """Test blending two effect chains - the core bug scenario."""
        sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        diff = _abs_diff(py_result, jax_result)
        mean_diff = np.mean(diff)
        max_diff = np.max(diff)
        # The top row is where boundary-handling differences show up.
        edge_diff = diff[0, :].mean()
        assert mean_diff < 3.0, (
            f"Blend: mean diff {mean_diff:.2f} exceeds threshold (max diff {max_diff:.2f})")
        assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"

    def test_blend_with_invert(self, test_env):
        """Test blending with invert - matches the problematic recipe pattern."""
        sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(require-primitives "color_ops")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(effect invert :path "../sexp_effects/effects/invert.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0)
(invert :amount 1))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
# Minimal recipe that only loads rotate/zoom so their JAX kernels are available.
_CHAIN_SEXP = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame frame)
)
'''


def _load_chain_interp(test_env):
    """Return an initialised JAX ``StreamInterpreter`` with rotate/zoom loaded."""
    from streaming.stream_sexp_generic import StreamInterpreter
    sexp_path = create_sexp_file(test_env['dir'], _CHAIN_SEXP)
    core_mod.set_random_seed(42)
    interp = StreamInterpreter(sexp_path, use_jax=True)
    interp._init()
    return interp


def _manual_and_fused(interp, jax_frame, t, frame_num, rot_angle, zoom_amount):
    """Apply rotate -> zoom twice: step by step, and via DeferredEffectChain.

    Returns:
        ``(manual, fused)`` NumPy arrays, asserted to have the same shape so
        the callers' element-wise comparisons cannot broadcast silently.
    """
    from streaming.stream_sexp_generic import DeferredEffectChain
    rotate_fn = interp.jax_effects['rotate']
    zoom_fn = interp.jax_effects['zoom']
    # Manual step-by-step application of each compiled effect.
    manual = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
                       angle=rot_angle, speed=0)
    manual = zoom_fn(manual, t=t, frame_num=frame_num, seed=42,
                     amount=zoom_amount, speed=0)
    manual = np.asarray(manual)
    # Same effects as a deferred chain, forced in one fused evaluation.
    chain = DeferredEffectChain(
        ['rotate'],
        [{'angle': rot_angle, 'speed': 0}],
        jax_frame, t, frame_num
    )
    chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
    fused = np.asarray(interp._force_deferred(chain))
    assert manual.shape == fused.shape, (
        f"Shape mismatch: manual={manual.shape}, fused={fused.shape}")
    return manual, fused


class TestDeferredEffectChainFusion:
    """Test the DeferredEffectChain fusion mechanism specifically."""

    def test_manual_vs_fused_chain(self, test_env):
        """Compare manual application vs fused DeferredEffectChain."""
        import jax.numpy as jnp
        interp = _load_chain_interp(test_env)
        manual_result, fused_result = _manual_and_fused(
            interp, jnp.array(test_env['test_img']), t=0.5, frame_num=5,
            rot_angle=-5.0, zoom_amount=0.95)

        diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
        mean_diff = np.mean(diff)
        assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"

        # Spot-check corners, edge midpoints and the centre.
        h, w = manual_result.shape[:2]
        for y in (0, h // 2, h - 1):
            for x in (0, w // 2, w - 1):
                manual_val = manual_result[y, x]
                fused_val = fused_result[y, x]
                pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
                assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"

    def test_chain_with_shrink_zoom_boundary(self, test_env):
        """Test that shrinking zoom handles boundaries correctly in chain."""
        import jax.numpy as jnp
        interp = _load_chain_interp(test_env)
        # zoom < 1.0 pulls content in from the edges, exposing boundary handling.
        manual_result, fused_result = _manual_and_fused(
            interp, jnp.array(test_env['test_img']), t=0.5, frame_num=5,
            rot_angle=-4.555, zoom_amount=0.9494)

        # The top edge is where boundary issues manifest.
        top_edge_manual = manual_result[0, :]
        top_edge_fused = fused_result[0, :]
        edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
        mean_edge_diff = np.mean(edge_diff)
        assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"

        # Catch spurious zero (black) pixels along the fused edge.
        manual_edge_sum = np.sum(top_edge_manual)
        fused_edge_sum = np.sum(top_edge_fused)
        if manual_edge_sum > 100:  # only meaningful when manual has content there
            assert fused_edge_sum > manual_edge_sum * 0.5, \
                f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly fails
    # visibly (non-zero exit) when any test fails.
    sys.exit(pytest.main([__file__, '-v']))