#!/usr/bin/env python3
"""
Test framework to verify JAX primitives match Python primitives.

Compares the output of each primitive through:
1. Python/NumPy path (the interpreter's registered primitive functions)
2. JAX path (the JAX compiler, on whichever backend JAX selects by default;
   this module does not choose CPU or GPU explicitly)

Reports any mismatches with summary diff statistics.
"""

import sys
sys.path.insert(0, '/home/giles/art/art-celery')

import json
import os
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np

# Test configuration
TEST_WIDTH = 64
TEST_HEIGHT = 48
TOLERANCE_MEAN = 1.0   # Max allowed mean absolute difference
TOLERANCE_MAX = 10.0   # Max allowed single-element absolute difference
TOLERANCE_PCT = 0.95   # Min fraction of elements with |diff| <= 1

# Result returned by compare_outputs() when two outputs cannot be compared.
_MISMATCH: Tuple[float, float, float] = (float('inf'), float('inf'), 0.0)


@dataclass
class TestResult:
    """Outcome of comparing one primitive between the Python and JAX paths."""

    # Keeps pytest from trying to collect this dataclass as a test class
    # (its name matches pytest's default ``Test*`` pattern).
    __test__ = False

    primitive: str              # primitive name, e.g. 'color_ops:invert'
    passed: bool                # True when all tolerances are satisfied
    python_mean: float = 0.0    # mean of the Python output (array outputs only)
    jax_mean: float = 0.0       # mean of the JAX output (array outputs only)
    diff_mean: float = 0.0      # mean absolute difference
    diff_max: float = 0.0       # max absolute difference
    pct_within_1: float = 0.0   # fraction of elements with |diff| <= 1
    error: str = ""             # failure description; empty on success


def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
    """Create an (h, w, 3) uint8 test frame with a known pattern.

    Args:
        w: Frame width in pixels.
        h: Frame height in pixels.
        pattern: 'gradient' (red ramps along x, green along y, blue along
            x + y), 'checkerboard' (8x8 black/white squares) or 'solid'
            (every value 128). Any other value yields random noise drawn
            from NumPy's global random state.

    Returns:
        A ``numpy.ndarray`` of shape (h, w, 3) and dtype uint8.
    """
    if pattern == 'gradient':
        # Diagonal gradient
        y, x = np.mgrid[0:h, 0:w]
        r = (x * 255 / w).astype(np.uint8)
        g = (y * 255 / h).astype(np.uint8)
        b = ((x + y) * 127 / (w + h)).astype(np.uint8)
        return np.stack([r, g, b], axis=2)
    elif pattern == 'checkerboard':
        y, x = np.mgrid[0:h, 0:w]
        check = ((x // 8) + (y // 8)) % 2
        v = (check * 255).astype(np.uint8)
        return np.stack([v, v, v], axis=2)
    elif pattern == 'solid':
        return np.full((h, w, 3), 128, dtype=np.uint8)
    else:
        # randint's upper bound is exclusive: 256 is needed for 255 to occur.
        return np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)


def _diff_stats(diff: np.ndarray) -> Tuple[float, float, float]:
    """Reduce an absolute-difference array to (mean, max, fraction <= 1)."""
    if diff.size == 0:
        # Two empty outputs of equal shape are trivially identical;
        # np.max would raise on an empty array.
        return 0.0, 0.0, 1.0
    return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))


def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
    """Compare two outputs, return (mean_diff, max_diff, pct_within_1).

    Both outputs may be array-likes/scalars, or dicts of array-likes
    (e.g. coordinate maps). For dicts, every key of ``py_out`` must be
    present in ``jax_out`` with the same shape; the differences of all
    keys are pooled before the statistics are computed.

    Returns ``(inf, inf, 0.0)`` when the outputs cannot be compared:
    either is ``None``, only one of them is a dict, a reference key is
    missing, the reference dict is empty, or shapes differ.
    """
    if py_out is None or jax_out is None:
        return _MISMATCH

    py_is_dict = isinstance(py_out, dict)
    jax_is_dict = isinstance(jax_out, dict)
    if py_is_dict != jax_is_dict:
        # A dict cannot be meaningfully converted to a numeric array.
        return _MISMATCH

    if py_is_dict:
        # Compare coordinate maps key by key. A missing key or a shape
        # mismatch is a failure, not something to skip: skipping would
        # let a partially-wrong result pass on the remaining keys.
        diffs = []
        for key, py_val in py_out.items():
            if key not in jax_out:
                return _MISMATCH
            py_arr = np.asarray(py_val)
            jax_arr = np.asarray(jax_out[key])
            if py_arr.shape != jax_arr.shape:
                return _MISMATCH
            diffs.append(
                np.abs(py_arr.astype(float) - jax_arr.astype(float)).ravel()
            )
        if not diffs:
            return _MISMATCH
        return _diff_stats(np.concatenate(diffs))

    py_arr = np.asarray(py_out)
    jax_arr = np.asarray(jax_out)
    if py_arr.shape != jax_arr.shape:
        return _MISMATCH
    return _diff_stats(np.abs(py_arr.astype(float) - jax_arr.astype(float)))


# ============================================================================
# Primitive Test Definitions
# ============================================================================
# Each entry maps a primitive name to its positional 'args' and a 'returns'
# tag. The strings 'frame' / 'frame2' in 'args' are placeholders that the
# runners below replace with the two test frames.

PRIMITIVE_TESTS = {
    # Geometry primitives
    'geometry:ripple-displace': {
        'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
        'returns': 'coords',
    },
    'geometry:rotate-img': {
        'args': ['frame', 45],
        'returns': 'frame',
    },
    'geometry:scale-img': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'geometry:flip-h': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'geometry:flip-v': {
        'args': ['frame'],
        'returns': 'frame',
    },
    # Color operations
    'color_ops:invert': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:grayscale': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:brightness': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:contrast': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:hue-shift': {
        'args': ['frame', 90],
        'returns': 'frame',
    },
    # Image operations
    'image:width': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:height': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:channel': {
        'args': ['frame', 0],
        'returns': 'array',
    },
    # Blending
    'blending:blend': {
        'args': ['frame', 'frame2', 0.5],
        'returns': 'frame',
    },
    'blending:blend-add': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
    'blending:blend-multiply': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
}


def run_python_primitive(interp, prim_name: str, test_def: dict,
                         frame: np.ndarray, frame2: np.ndarray,
                         *, raise_errors: bool = False) -> Any:
    """Run a primitive through the Python interpreter.

    Args:
        interp: Object exposing a ``primitives`` mapping of name -> callable.
        prim_name: Name of the primitive to look up.
        test_def: Test definition; its 'args' list is passed positionally,
            with 'frame' / 'frame2' replaced by copies of the test frames.
        frame: First test frame.
        frame2: Second test frame.
        raise_errors: When True, exceptions raised by the primitive
            propagate to the caller. When False (the default), they are
            swallowed and ``None`` is returned.

    Returns:
        The primitive's result, or ``None`` when the primitive is not
        registered (or failed while ``raise_errors`` is False).
    """
    if prim_name not in interp.primitives:
        return None
    func = interp.primitives[prim_name]

    args = []
    for a in test_def['args']:
        if a == 'frame':
            # Copy so a primitive that mutates its input cannot corrupt
            # the frame later handed to the JAX path.
            args.append(frame.copy())
        elif a == 'frame2':
            args.append(frame2.copy())
        else:
            args.append(a)

    try:
        return func(*args)
    except Exception:
        if raise_errors:
            raise
        return None


def run_jax_primitive(prim_name: str, test_def: dict,
                      frame: np.ndarray, frame2: np.ndarray,
                      *, raise_errors: bool = False) -> Any:
    """Run a primitive through the JAX compiler.

    Builds the expression ``(prim_name arg1 arg2 ...)`` and evaluates it
    with ``JaxCompiler._eval`` in an environment binding ``frame`` and
    ``frame2`` to JAX arrays.

    Args:
        prim_name: Name of the primitive to evaluate.
        test_def: Test definition; 'frame' / 'frame2' entries of its
            'args' list become symbol references to the bound frames.
        frame: First test frame.
        frame2: Second test frame.
        raise_errors: When True, import and evaluation errors propagate.
            When False (the default), they are swallowed and ``None`` is
            returned.

    Returns:
        The result converted to a NumPy array when it supports
        ``__array__``, otherwise the raw result; ``None`` on a swallowed
        failure.
    """
    try:
        # Imported lazily so this module can be loaded without JAX or the
        # project packages being installed.
        from streaming.sexp_to_jax import JaxCompiler
        import jax.numpy as jnp
        from sexp_effects.parser import Symbol

        compiler = JaxCompiler()
        env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}

        # Build a simple expression to test the primitive
        args = []
        for a in test_def['args']:
            if a == 'frame':
                args.append(Symbol('frame'))
            elif a == 'frame2':
                args.append(Symbol('frame2'))
            else:
                args.append(a)

        # Create expression: (prim_name arg1 arg2 ...)
        expr = [Symbol(prim_name)] + args
        result = compiler._eval(expr, env)

        if hasattr(result, '__array__'):
            return np.asarray(result)
        return result
    except Exception:
        if raise_errors:
            raise
        return None


def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
    """Test a single primitive.

    Runs the primitive through both paths on a gradient frame and a
    checkerboard frame, then compares the outputs against the module
    tolerances.

    Returns:
        A ``TestResult``; ``passed`` is True only when both paths produced
        output and all three tolerances hold. On failure ``error``
        describes the first problem encountered.
    """
    frame = create_test_frame(pattern='gradient')
    frame2 = create_test_frame(pattern='checkerboard')
    result = TestResult(primitive=prim_name, passed=False)

    # Run Python version. raise_errors=True so the real exception is
    # reported instead of an uninformative "returned None".
    try:
        py_out = run_python_primitive(interp, prim_name, test_def, frame,
                                      frame2, raise_errors=True)
        if py_out is not None and hasattr(py_out, 'shape'):
            result.python_mean = float(np.mean(py_out))
    except Exception as e:
        result.error = f"Python error: {type(e).__name__}: {e}"
        return result

    # Run JAX version
    try:
        jax_out = run_jax_primitive(prim_name, test_def, frame, frame2,
                                    raise_errors=True)
        if jax_out is not None and hasattr(jax_out, 'shape'):
            result.jax_mean = float(np.mean(jax_out))
    except Exception as e:
        result.error = f"JAX error: {type(e).__name__}: {e}"
        return result

    if py_out is None:
        result.error = "Python returned None"
        return result
    if jax_out is None:
        result.error = "JAX returned None"
        return result

    # Compare. Non-numeric outputs can make the float conversion raise;
    # report that as a failure of this primitive rather than aborting
    # the whole suite.
    try:
        diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
    except Exception as e:
        result.error = f"Compare error: {type(e).__name__}: {e}"
        return result

    result.diff_mean = diff_mean
    result.diff_max = diff_max
    result.pct_within_1 = pct

    # Check pass/fail
    result.passed = (
        diff_mean <= TOLERANCE_MEAN
        and diff_max <= TOLERANCE_MAX
        and pct >= TOLERANCE_PCT
    )
    if not result.passed:
        result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
    return result


# The name matches pytest's default ``test_*`` pattern, but this is a helper
# that needs explicit arguments; opt out of pytest collection.
test_primitive.__test__ = False


def discover_primitives(interp) -> List[str]:
    """Discover all primitives available in the interpreter.

    Returns the sorted keys of ``interp.primitives``.
    """
    return sorted(interp.primitives.keys())


def run_all_tests(verbose=True, *, work_dir='/home/giles/art/test',
                  effect_file='effects/quick_test_explicit.sexp'):
    """Run all primitive tests.

    Side effects: silences all Python warnings for the process, changes
    the current working directory to ``work_dir`` (it is not restored),
    and prints a report to stdout.

    Args:
        verbose: When True, print one PASS/FAIL line per primitive in
            addition to the header and summary.
        work_dir: Directory to change into before loading the effect.
        effect_file: Effect file passed to ``StreamInterpreter``.

    Returns:
        List of ``TestResult``, one per entry in ``PRIMITIVE_TESTS``.
    """
    warnings.filterwarnings('ignore')
    os.chdir(work_dir)

    from streaming.stream_sexp_generic import StreamInterpreter
    from sexp_effects.primitive_libs import core as core_mod

    core_mod.set_random_seed(42)

    # Create interpreter to get Python primitives
    interp = StreamInterpreter(effect_file, use_jax=False)
    interp._init()

    results = []

    print("=" * 60)
    print("JAX PRIMITIVE TEST SUITE")
    print("=" * 60)

    # Test defined primitives
    for prim_name, test_def in PRIMITIVE_TESTS.items():
        result = test_primitive(interp, prim_name, test_def)
        results.append(result)

        status = "✓ PASS" if result.passed else "✗ FAIL"
        if verbose:
            print(f"{status} {prim_name}")
            if not result.passed:
                print(f" {result.error}")

    # Summary
    passed = sum(1 for r in results if r.passed)
    failed = len(results) - passed

    print("\n" + "=" * 60)
    print(f"SUMMARY: {passed} passed, {failed} failed")
    print("=" * 60)

    if failed > 0:
        print("\nFailed primitives:")
        for r in results:
            if not r.passed:
                print(f" - {r.primitive}: {r.error}")

    return results


if __name__ == '__main__':
    # Exit non-zero on any failure so shells and CI can detect it.
    _results = run_all_tests()
    sys.exit(0 if all(r.passed for r in _results) else 1)