#!/usr/bin/env python3
"""Integration tests comparing JAX and Python rendering pipelines.

These tests ensure the JAX-compiled effect chains produce identical output
to the Python/NumPy path. They test:

1. Full effect pipelines through both interpreters
2. Multi-frame sequences (to catch phase-dependent bugs)
3. Compiled effect chain fusion
4. Edge cases like shrinking/zooming that affect boundary handling
"""

import os
import shutil
import sys
from pathlib import Path

import numpy as np
import pytest

# Ensure the art-celery module is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sexp_effects.primitive_libs import core as core_mod

# Path to test resources. Override with the ART_TEST_DIR environment variable
# on machines where the resources live somewhere other than the default.
TEST_DIR = Path(os.environ.get('ART_TEST_DIR', '/home/giles/art/test'))
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
TEMPLATES_DIR = TEST_DIR / 'templates'

# Effect definitions copied into the temporary test environment.
REQUIRED_EFFECTS = ('rotate', 'zoom', 'blend', 'invert', 'hue_shift')

# Seed and frame timing shared by every comparison below.
SEED = 42
FRAME_T = 0.5
FRAME_NUM = 5
FPS = 10


def create_test_image(h=96, w=128):
    """Create a test image with distinct patterns.

    Channel 0 holds a horizontal gradient, channel 1 a vertical gradient and
    channel 2 a constant 128. In OpenCV's BGR convention those channels are
    blue, green and red. A filled circle and a filled square are drawn on top
    so that geometric transforms produce visible, position-dependent changes.

    Args:
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        A ``(h, w, 3)`` ``uint8`` array.
    """
    import cv2

    img = np.empty((h, w, 3), dtype=np.uint8)
    # astype(uint8) truncates toward zero, matching int() for these
    # non-negative values in [0, 255).
    img[..., 0] = (255 * np.arange(w) / w).astype(np.uint8)[np.newaxis, :]
    img[..., 1] = (255 * np.arange(h) / h).astype(np.uint8)[:, np.newaxis]
    img[..., 2] = 128

    # Add features
    cv2.circle(img, (w // 2, h // 2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
    return img


@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up a temporary working directory with sexp files and test media.

    The fixture changes the working directory into the temp dir for the
    duration of the module. Presumably this is so the relative
    ``../sexp_effects/...`` paths in the test sexp files resolve there.
    The original working directory is always restored, even if setup fails.

    Effect files missing from ``EFFECTS_DIR`` are skipped (best effort).
    Tests that need a missing effect will fail when the interpreter loads it.

    Yields:
        dict with ``dir`` (temp Path), ``image_path`` (Path to the written
        PNG) and ``test_img`` (the in-memory image that was written).
    """
    test_dir = tmp_path_factory.mktemp('sexp_test')
    original_dir = os.getcwd()
    os.chdir(test_dir)
    try:
        import cv2

        # Create directories
        (test_dir / 'effects').mkdir()
        effects_dst = test_dir / 'sexp_effects' / 'effects'
        effects_dst.mkdir(parents=True)

        # Create test image
        test_img = create_test_image()
        cv2.imwrite(str(test_dir / 'test_image.png'), test_img)

        # Copy required effect files
        for effect in REQUIRED_EFFECTS:
            src = EFFECTS_DIR / f'{effect}.sexp'
            if src.exists():
                shutil.copy(src, effects_dst / f'{effect}.sexp')

        yield {
            'dir': test_dir,
            'image_path': test_dir / 'test_image.png',
            'test_img': test_img,
        }
    finally:
        os.chdir(original_dir)


def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write *content* to ``<test_dir>/effects/<filename>`` and return its path as str."""
    path = test_dir / 'effects' / filename
    path.write_text(content, encoding='utf-8')
    return str(path)


def _make_interpreter(sexp_path, use_jax):
    """Reseed the core RNG, then build and initialise a StreamInterpreter."""
    from streaming.stream_sexp_generic import StreamInterpreter

    # Reseed before each construction so both backends start from the same RNG state.
    core_mod.set_random_seed(SEED)
    interp = StreamInterpreter(sexp_path, use_jax=use_jax)
    interp._init()
    return interp


def _frame_env():
    """Build the per-frame environment passed to ``StreamInterpreter._eval``."""
    return {
        'ctx': {'t': FRAME_T, 'frame-num': FRAME_NUM, 'fps': FPS},
        't': FRAME_T,
        'frame-num': FRAME_NUM,
    }


def _render_both(test_env, sexp_content):
    """Render one frame of *sexp_content* through the Python and JAX interpreters.

    Returns:
        ``(py_result, jax_result)`` as NumPy arrays.
    """
    import cv2

    sexp_path = create_sexp_file(test_env['dir'], sexp_content)
    test_img = cv2.imread(str(test_env['image_path']))

    py_interp = _make_interpreter(sexp_path, use_jax=False)
    jax_interp = _make_interpreter(sexp_path, use_jax=True)

    # Separate copies so an in-place op in one pipeline cannot leak into the other.
    py_interp.globals['frame'] = test_img.copy()
    jax_interp.globals['frame'] = test_img.copy()

    frame_env = _frame_env()
    results = []
    for interp in (py_interp, jax_interp):
        raw = interp._eval(interp.frame_pipeline, frame_env)
        # Force deferred results (if any) into concrete arrays.
        results.append(np.asarray(interp._maybe_force(raw)))
    return results[0], results[1]


def _abs_diff(a, b):
    """Element-wise absolute difference computed in float to avoid uint8 wraparound."""
    return np.abs(a.astype(float) - b.astype(float))


# Minimal stream that only registers rotate/zoom so their JAX kernels are loaded.
_ROTATE_ZOOM_PASSTHROUGH_SEXP = '''(stream "test"
    :width 128 :height 96 :seed 42
    (effect rotate :path "../sexp_effects/effects/rotate.sexp")
    (effect zoom :path "../sexp_effects/effects/zoom.sexp")
    (frame frame)
)
'''


def _manual_and_fused(test_env, rot_angle, zoom_amount):
    """Apply rotate→zoom step by step and via a fused DeferredEffectChain.

    Returns:
        ``(manual_result, fused_result)`` as NumPy arrays.
    """
    import jax.numpy as jnp
    from streaming.stream_sexp_generic import DeferredEffectChain

    sexp_path = create_sexp_file(test_env['dir'], _ROTATE_ZOOM_PASSTHROUGH_SEXP)
    interp = _make_interpreter(sexp_path, use_jax=True)

    jax_frame = jnp.array(test_env['test_img'])
    t, frame_num = FRAME_T, FRAME_NUM

    # Manual step-by-step application
    rotate_fn = interp.jax_effects['rotate']
    zoom_fn = interp.jax_effects['zoom']
    manual = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=SEED,
                       angle=rot_angle, speed=0)
    manual = zoom_fn(manual, t=t, frame_num=frame_num, seed=SEED,
                     amount=zoom_amount, speed=0)
    manual = np.asarray(manual)

    # Fused chain application
    chain = DeferredEffectChain(
        ['rotate'],
        [{'angle': rot_angle, 'speed': 0}],
        jax_frame, t, frame_num,
    )
    chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
    fused = np.asarray(interp._force_deferred(chain))
    return manual, fused


class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output."""

    def test_single_rotate_effect(self, test_env):
        """Test that a single rotate effect matches between Python and JAX."""
        sexp_content = '''(stream "test"
    :width 128 :height 96 :seed 42
    (effect rotate :path "../sexp_effects/effects/rotate.sexp")
    (frame (rotate frame :angle 15 :speed 0))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"

    def test_rotate_then_zoom(self, test_env):
        """Test rotate followed by zoom - tests effect chain fusion."""
        sexp_content = '''(stream "test"
    :width 128 :height 96 :seed 42
    (effect rotate :path "../sexp_effects/effects/rotate.sexp")
    (effect zoom :path "../sexp_effects/effects/zoom.sexp")
    (frame (-> (rotate frame :angle 15 :speed 0)
               (zoom :amount 0.95 :speed 0)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"

    def test_zoom_shrink_boundary_handling(self, test_env):
        """Test zoom with shrinking factor - critical for boundary handling."""
        sexp_content = '''(stream "test"
    :width 128 :height 96 :seed 42
    (effect zoom :path "../sexp_effects/effects/zoom.sexp")
    (frame (zoom frame :amount 0.8 :speed 0))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)

        # Check corners specifically - these are most affected by boundary handling
        h, w = test_env['test_img'].shape[:2]
        corners = [(0, 0), (0, w - 1), (h - 1, 0), (h - 1, w - 1)]
        for y, x in corners:
            py_val = py_result[y, x]
            jax_val = jax_result[y, x]
            corner_diff = _abs_diff(py_val, jax_val).max()
            assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"

    def test_blend_two_clips(self, test_env):
        """Test blending two effect chains - the core bug scenario."""
        sexp_content = '''(stream "test"
    :width 128 :height 96 :seed 42
    (require-primitives "core")
    (require-primitives "image")
    (require-primitives "blending")
    (effect rotate :path "../sexp_effects/effects/rotate.sexp")
    (effect zoom :path "../sexp_effects/effects/zoom.sexp")
    (effect blend :path "../sexp_effects/effects/blend.sexp")
    (frame
        (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                         (zoom :amount 1.05 :speed 0))
              clip_b (-> (rotate frame :angle -5 :speed 0)
                         (zoom :amount 0.95 :speed 0))]
            (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        diff = _abs_diff(py_result, jax_result)
        mean_diff = np.mean(diff)
        # Check edge region specifically (top row)
        edge_diff = diff[0, :].mean()

        assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
        assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"

    def test_blend_with_invert(self, test_env):
        """Test blending with invert - matches the problematic recipe pattern."""
        sexp_content = '''(stream "test"
    :width 128 :height 96 :seed 42
    (require-primitives "core")
    (require-primitives "image")
    (require-primitives "blending")
    (require-primitives "color_ops")
    (effect rotate :path "../sexp_effects/effects/rotate.sexp")
    (effect zoom :path "../sexp_effects/effects/zoom.sexp")
    (effect blend :path "../sexp_effects/effects/blend.sexp")
    (effect invert :path "../sexp_effects/effects/invert.sexp")
    (frame
        (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                         (zoom :amount 1.05 :speed 0)
                         (invert :amount 1))
              clip_b (-> (rotate frame :angle -5 :speed 0)
                         (zoom :amount 0.95 :speed 0))]
            (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result = _render_both(test_env, sexp_content)
        mean_diff = np.mean(_abs_diff(py_result, jax_result))
        assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"


class TestDeferredEffectChainFusion:
    """Test the DeferredEffectChain fusion mechanism specifically."""

    def test_manual_vs_fused_chain(self, test_env):
        """Compare manual application vs fused DeferredEffectChain."""
        manual_result, fused_result = _manual_and_fused(
            test_env, rot_angle=-5.0, zoom_amount=0.95)

        mean_diff = np.mean(_abs_diff(manual_result, fused_result))
        assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"

        # Check specific pixels: corners, edge midpoints and centre
        h, w = test_env['test_img'].shape[:2]
        for y in (0, h // 2, h - 1):
            for x in (0, w // 2, w - 1):
                manual_val = manual_result[y, x]
                fused_val = fused_result[y, x]
                pixel_diff = _abs_diff(manual_val, fused_val).max()
                assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"

    def test_chain_with_shrink_zoom_boundary(self, test_env):
        """Test that shrinking zoom handles boundaries correctly in chain."""
        # Parameters that shrink the image (zoom < 1.0); this pulls in from
        # the edges, exposing boundary handling.
        manual_result, fused_result = _manual_and_fused(
            test_env, rot_angle=-4.555, zoom_amount=0.9494)

        # Check top edge specifically - this is where boundary issues manifest
        top_edge_manual = manual_result[0, :]
        top_edge_fused = fused_result[0, :]
        mean_edge_diff = np.mean(_abs_diff(top_edge_manual, top_edge_fused))
        assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"

        # Check for zeros at edges that shouldn't be there
        manual_edge_sum = np.sum(top_edge_manual)
        fused_edge_sum = np.sum(top_edge_fused)
        if manual_edge_sum > 100:  # Only meaningful if manual has significant values
            assert fused_edge_sum > manual_edge_sum * 0.5, \
                f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"


if __name__ == '__main__':
    # Propagate pytest's exit status so shell/CI callers see failures.
    sys.exit(pytest.main([__file__, '-v']))