All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
518 lines
17 KiB
Python
518 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""Integration tests comparing JAX and Python rendering pipelines.
|
|
|
|
These tests ensure the JAX-compiled effect chains produce identical output
|
|
to the Python/NumPy path. They test:
|
|
1. Full effect pipelines through both interpreters
|
|
2. Multi-frame sequences (to catch phase-dependent bugs)
|
|
3. Compiled effect chain fusion
|
|
4. Edge cases like shrinking/zooming that affect boundary handling
|
|
"""
|
|
import os
|
|
import sys
|
|
import pytest
|
|
import numpy as np
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
# Ensure the art-celery module is importable
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from sexp_effects.primitive_libs import core as core_mod
|
|
|
|
|
|
# Root of the on-disk test resources. The default is the original hard-coded
# checkout location; set the SEXP_TEST_DIR environment variable to run the
# suite against a checkout that lives somewhere else.
TEST_DIR = Path(os.environ.get('SEXP_TEST_DIR', '/home/giles/art/test'))
# Directory the `test_env` fixture copies the `*.sexp` effect files from.
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
# Template directory under the same root (not referenced in the visible code).
TEMPLATES_DIR = TEST_DIR / 'templates'
|
|
|
|
|
|
def create_test_image(h=96, w=128):
    """Create a synthetic 3-channel uint8 test image with distinct patterns.

    The background is a smooth gradient (channel 0 ramps left-to-right,
    channel 1 ramps top-to-bottom, channel 2 is a constant 128) with a filled
    circle at the centre and a filled square near the top-left corner drawn on
    top, so geometric transforms produce easily detectable differences.

    Args:
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        A ``numpy.ndarray`` of shape ``(h, w, 3)`` and dtype ``uint8``.
    """
    import cv2

    img = np.zeros((h, w, 3), dtype=np.uint8)

    # Gradient background, computed once per axis instead of once per pixel.
    # The float division followed by truncation to uint8 reproduces
    # ``int(255 * i / n)`` exactly; every value stays below 255.
    horizontal = (255 * np.arange(w) / w).astype(np.uint8)
    vertical = (255 * np.arange(h) / h).astype(np.uint8)
    img[:, :, 0] = horizontal[np.newaxis, :]  # channel 0: horizontal gradient
    img[:, :, 1] = vertical[:, np.newaxis]    # channel 1: vertical gradient
    img[:, :, 2] = 128                        # channel 2: constant

    # Add solid features (thickness -1 means "filled" in OpenCV).
    cv2.circle(img, (w // 2, h // 2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)

    return img
|
|
|
|
|
|
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up a module-scoped test environment with sexp files and test media.

    Creates a temporary directory containing:

    * ``effects/`` - where the tests write their generated ``.sexp`` files,
    * ``sexp_effects/effects/`` - copies of the effect definitions taken from
      ``EFFECTS_DIR`` (files that do not exist there are skipped),
    * ``test_image.png`` - the image produced by ``create_test_image()``.

    The process working directory is switched to the temporary directory for
    the lifetime of the fixture and is always restored afterwards.

    Yields:
        dict with keys ``'dir'`` (temporary directory), ``'image_path'`` (path
        of the PNG on disk) and ``'test_img'`` (the in-memory image array).
    """
    import cv2

    test_dir = tmp_path_factory.mktemp('sexp_test')
    original_dir = os.getcwd()
    os.chdir(test_dir)
    # Everything after the chdir lives inside try/finally so that a failure
    # during setup cannot leave the whole test session in the temp directory.
    try:
        # Create directories
        (test_dir / 'effects').mkdir()
        effects_dst = test_dir / 'sexp_effects' / 'effects'
        effects_dst.mkdir(parents=True)

        # Create test image
        test_img = create_test_image()
        image_path = test_dir / 'test_image.png'
        cv2.imwrite(str(image_path), test_img)

        # Copy required effect files. Missing sources are skipped on purpose
        # (best effort); tests that need them fail when loading the effect.
        for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
            src = EFFECTS_DIR / f'{effect}.sexp'
            if src.exists():
                shutil.copy(src, effects_dst / f'{effect}.sexp')

        yield {
            'dir': test_dir,
            'image_path': image_path,
            'test_img': test_img,
        }
    finally:
        os.chdir(original_dir)
|
|
|
|
|
|
def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write ``content`` to ``<test_dir>/effects/<filename>`` and return its path.

    Args:
        test_dir: Root of the test environment (``str`` or ``Path``).
        content: Text of the sexp program to write.
        filename: Name of the file to create inside ``effects/``. An existing
            file with the same name is overwritten.

    Returns:
        The path of the written file as a ``str``.
    """
    effects_dir = Path(test_dir) / 'effects'
    # Create the directory on demand so the helper also works outside the
    # `test_env` fixture, which normally pre-creates it.
    effects_dir.mkdir(parents=True, exist_ok=True)
    path = effects_dir / filename
    # Explicit encoding: do not depend on the platform's locale default.
    path.write_text(content, encoding='utf-8')
    return str(path)
|
|
|
|
|
|
class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output.

    Every test writes a small sexp stream definition, evaluates its frame
    pipeline once with ``use_jax=False`` and once with ``use_jax=True`` on the
    same input image, and asserts that the two rendered frames differ by less
    than a per-test tolerance.
    """

    @staticmethod
    def _render_both(test_env, sexp_content):
        """Render one frame of ``sexp_content`` through both interpreters.

        Args:
            test_env: The dict yielded by the ``test_env`` fixture.
            sexp_content: Source text of the sexp stream to evaluate.

        Returns:
            Tuple ``(py_result, jax_result, test_img)`` where the first two
            are NumPy arrays produced by the Python and JAX interpreters and
            ``test_img`` is the input frame as read back from disk.
        """
        from streaming.stream_sexp_generic import StreamInterpreter, Context
        import cv2

        sexp_path = create_sexp_file(test_env['dir'], sexp_content)
        test_img = cv2.imread(str(test_env['image_path']))

        # Python path first, then JAX path. The seed is reset before each
        # interpreter is built so both start from the same random state.
        interpreters = []
        for use_jax in (False, True):
            core_mod.set_random_seed(42)
            interp = StreamInterpreter(sexp_path, use_jax=use_jax)
            interp._init()
            interpreters.append(interp)

        ctx = Context(fps=10)
        ctx.t = 0.5
        ctx.frame_num = 5

        # Both interpreters are evaluated against the same frame environment.
        frame_env = {
            'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
            't': ctx.t, 'frame-num': ctx.frame_num,
        }

        # Inject test image into globals
        for interp in interpreters:
            interp.globals['frame'] = test_img

        rendered = []
        for interp in interpreters:
            raw = interp._eval(interp.frame_pipeline, frame_env)
            # Force deferred if needed, then normalise to a NumPy array.
            rendered.append(np.asarray(interp._maybe_force(raw)))

        py_result, jax_result = rendered
        return py_result, jax_result, test_img

    @staticmethod
    def _abs_diff(py_result, jax_result):
        """Return the element-wise absolute difference as a float array."""
        return np.abs(py_result.astype(float) - jax_result.astype(float))

    def test_single_rotate_effect(self, test_env):
        """Test that a single rotate effect matches between Python and JAX."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")

  (frame (rotate frame :angle 15 :speed 0))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))

        assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"

    def test_rotate_then_zoom(self, test_env):
        """Test rotate followed by zoom - tests effect chain fusion."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (-> (rotate frame :angle 15 :speed 0)
             (zoom :amount 0.95 :speed 0)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))

        assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"

    def test_zoom_shrink_boundary_handling(self, test_env):
        """Test zoom with shrinking factor - critical for boundary handling."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (zoom frame :amount 0.8 :speed 0))
)
'''
        py_result, jax_result, test_img = self._render_both(test_env, sexp_content)

        # Check corners specifically - these are most affected by boundary handling
        h, w = test_img.shape[:2]
        corners = [(0, 0), (0, w - 1), (h - 1, 0), (h - 1, w - 1)]
        for y, x in corners:
            py_val = py_result[y, x]
            jax_val = jax_result[y, x]
            corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
            assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"

    def test_blend_two_clips(self, test_env):
        """Test blending two effect chains - the core bug scenario."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        diff = self._abs_diff(py_result, jax_result)
        mean_diff = np.mean(diff)

        # Check edge region specifically (top row of the frame)
        edge_diff = diff[0, :].mean()

        assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
        assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"

    def test_blend_with_invert(self, test_env):
        """Test blending with invert - matches the problematic recipe pattern."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (require-primitives "color_ops")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")
  (effect invert :path "../sexp_effects/effects/invert.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0)
                     (invert :amount 1))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))

        assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
|
|
|
|
|
|
class TestDeferredEffectChainFusion:
    """Test the DeferredEffectChain fusion mechanism specifically.

    Both tests apply rotate followed by zoom to the same frame in two ways -
    by calling the interpreter's JAX effect functions one after the other, and
    by building a ``DeferredEffectChain`` and forcing it - then compare the
    two results.
    """

    # Minimal stream: it only exists to make the interpreter load the effects.
    _SEXP_CONTENT = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame frame)
)
'''

    @classmethod
    def _manual_and_fused(cls, test_env, rot_angle, zoom_amount):
        """Apply rotate+zoom step by step and as a fused chain.

        Args:
            test_env: The dict yielded by the ``test_env`` fixture.
            rot_angle: ``angle`` parameter passed to the rotate effect.
            zoom_amount: ``amount`` parameter passed to the zoom effect.

        Returns:
            Tuple ``(manual_result, fused_result)`` of NumPy arrays.
        """
        import jax.numpy as jnp
        from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain

        sexp_path = create_sexp_file(test_env['dir'], cls._SEXP_CONTENT)

        core_mod.set_random_seed(42)
        interp = StreamInterpreter(sexp_path, use_jax=True)
        interp._init()

        jax_frame = jnp.array(test_env['test_img'])
        t = 0.5
        frame_num = 5

        # Manual step-by-step application
        rotate_fn = interp.jax_effects['rotate']
        zoom_fn = interp.jax_effects['zoom']

        rotated = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
                            angle=rot_angle, speed=0)
        zoomed = zoom_fn(rotated, t=t, frame_num=frame_num, seed=42,
                         amount=zoom_amount, speed=0)
        manual_result = np.asarray(zoomed)

        # Fused chain application: start with rotate, then append zoom.
        chain = DeferredEffectChain(
            ['rotate'],
            [{'angle': rot_angle, 'speed': 0}],
            jax_frame, t, frame_num
        )
        chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})

        fused_result = np.asarray(interp._force_deferred(chain))

        return manual_result, fused_result

    def test_manual_vs_fused_chain(self, test_env):
        """Compare manual application vs fused DeferredEffectChain."""
        manual_result, fused_result = self._manual_and_fused(
            test_env, rot_angle=-5.0, zoom_amount=0.95)

        diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
        mean_diff = np.mean(diff)

        assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"

        # Check specific pixels: a 3x3 grid of corners, edge midpoints, centre.
        h, w = test_env['test_img'].shape[:2]
        for y in [0, h // 2, h - 1]:
            for x in [0, w // 2, w - 1]:
                manual_val = manual_result[y, x]
                fused_val = fused_result[y, x]
                pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
                assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"

    def test_chain_with_shrink_zoom_boundary(self, test_env):
        """Test that shrinking zoom handles boundaries correctly in chain."""
        # Parameters that shrink the image (zoom < 1.0).
        # The zoom amount pulls in from edges, exposing boundaries.
        manual_result, fused_result = self._manual_and_fused(
            test_env, rot_angle=-4.555, zoom_amount=0.9494)

        # Check top edge specifically - this is where boundary issues manifest
        top_edge_manual = manual_result[0, :]
        top_edge_fused = fused_result[0, :]

        edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
        mean_edge_diff = np.mean(edge_diff)

        assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"

        # Check for zeros at edges that shouldn't be there
        manual_edge_sum = np.sum(top_edge_manual)
        fused_edge_sum = np.sum(top_edge_fused)

        if manual_edge_sum > 100:  # If manual has significant values
            assert fused_edge_sum > manual_edge_sum * 0.5, (
                f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
            )
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this file's tests verbosely and propagate pytest's exit status, so
    # that executing the file directly reports failures to the calling shell.
    raise SystemExit(pytest.main([__file__, '-v']))
|