Import L1 (celery) as l1/
This commit is contained in:
334
l1/tests/test_jax_primitives.py
Normal file
334
l1/tests/test_jax_primitives.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test framework to verify JAX primitives match Python primitives.
|
||||
|
||||
Compares output of each primitive through:
|
||||
1. Python/NumPy path
|
||||
2. JAX path (CPU)
|
||||
3. JAX path (GPU) - if available
|
||||
|
||||
Reports any mismatches with detailed diffs.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '/home/giles/art/art-celery')
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Test configuration
|
||||
TEST_WIDTH = 64
|
||||
TEST_HEIGHT = 48
|
||||
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
|
||||
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
|
||||
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
primitive: str
|
||||
passed: bool
|
||||
python_mean: float = 0.0
|
||||
jax_mean: float = 0.0
|
||||
diff_mean: float = 0.0
|
||||
diff_max: float = 0.0
|
||||
pct_within_1: float = 0.0
|
||||
error: str = ""
|
||||
|
||||
|
||||
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
|
||||
"""Create a test frame with known pattern."""
|
||||
if pattern == 'gradient':
|
||||
# Diagonal gradient
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
r = (x * 255 / w).astype(np.uint8)
|
||||
g = (y * 255 / h).astype(np.uint8)
|
||||
b = ((x + y) * 127 / (w + h)).astype(np.uint8)
|
||||
return np.stack([r, g, b], axis=2)
|
||||
elif pattern == 'checkerboard':
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
check = ((x // 8) + (y // 8)) % 2
|
||||
v = (check * 255).astype(np.uint8)
|
||||
return np.stack([v, v, v], axis=2)
|
||||
elif pattern == 'solid':
|
||||
return np.full((h, w, 3), 128, dtype=np.uint8)
|
||||
else:
|
||||
return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
|
||||
"""Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
|
||||
if py_out is None or jax_out is None:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
if isinstance(py_out, dict) and isinstance(jax_out, dict):
|
||||
# Compare coordinate maps
|
||||
diffs = []
|
||||
for k in py_out:
|
||||
if k in jax_out:
|
||||
py_arr = np.asarray(py_out[k])
|
||||
jax_arr = np.asarray(jax_out[k])
|
||||
if py_arr.shape == jax_arr.shape:
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
diffs.append(diff)
|
||||
if diffs:
|
||||
all_diff = np.concatenate([d.flatten() for d in diffs])
|
||||
return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
py_arr = np.asarray(py_out)
|
||||
jax_arr = np.asarray(jax_out)
|
||||
|
||||
if py_arr.shape != jax_arr.shape:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Test Definitions
|
||||
# ============================================================================
|
||||
|
||||
PRIMITIVE_TESTS = {
|
||||
# Geometry primitives
|
||||
'geometry:ripple-displace': {
|
||||
'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
|
||||
'returns': 'coords',
|
||||
},
|
||||
'geometry:rotate-img': {
|
||||
'args': ['frame', 45],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:scale-img': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-h': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-v': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Color operations
|
||||
'color_ops:invert': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:grayscale': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:brightness': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:contrast': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:hue-shift': {
|
||||
'args': ['frame', 90],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Image operations
|
||||
'image:width': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:height': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:channel': {
|
||||
'args': ['frame', 0],
|
||||
'returns': 'array',
|
||||
},
|
||||
|
||||
# Blending
|
||||
'blending:blend': {
|
||||
'args': ['frame', 'frame2', 0.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-add': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-multiply': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the Python interpreter."""
|
||||
if prim_name not in interp.primitives:
|
||||
return None
|
||||
|
||||
func = interp.primitives[prim_name]
|
||||
args = []
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(frame.copy())
|
||||
elif a == 'frame2':
|
||||
args.append(frame2.copy())
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
try:
|
||||
return func(*args)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the JAX compiler."""
|
||||
try:
|
||||
from streaming.sexp_to_jax import JaxCompiler
|
||||
import jax.numpy as jnp
|
||||
|
||||
compiler = JaxCompiler()
|
||||
|
||||
# Build a simple expression to test the primitive
|
||||
from sexp_effects.parser import Symbol, Keyword
|
||||
|
||||
args = []
|
||||
env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
|
||||
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(Symbol('frame'))
|
||||
elif a == 'frame2':
|
||||
args.append(Symbol('frame2'))
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
# Create expression: (prim_name arg1 arg2 ...)
|
||||
expr = [Symbol(prim_name)] + args
|
||||
|
||||
result = compiler._eval(expr, env)
|
||||
|
||||
if hasattr(result, '__array__'):
|
||||
return np.asarray(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
|
||||
"""Test a single primitive."""
|
||||
frame = create_test_frame(pattern='gradient')
|
||||
frame2 = create_test_frame(pattern='checkerboard')
|
||||
|
||||
result = TestResult(primitive=prim_name, passed=False)
|
||||
|
||||
# Run Python version
|
||||
try:
|
||||
py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
|
||||
if py_out is not None and hasattr(py_out, 'shape'):
|
||||
result.python_mean = float(np.mean(py_out))
|
||||
except Exception as e:
|
||||
result.error = f"Python error: {e}"
|
||||
return result
|
||||
|
||||
# Run JAX version
|
||||
try:
|
||||
jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
|
||||
if jax_out is not None and hasattr(jax_out, 'shape'):
|
||||
result.jax_mean = float(np.mean(jax_out))
|
||||
except Exception as e:
|
||||
result.error = f"JAX error: {e}"
|
||||
return result
|
||||
|
||||
if py_out is None:
|
||||
result.error = "Python returned None"
|
||||
return result
|
||||
if jax_out is None:
|
||||
result.error = "JAX returned None"
|
||||
return result
|
||||
|
||||
# Compare
|
||||
diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
|
||||
result.diff_mean = diff_mean
|
||||
result.diff_max = diff_max
|
||||
result.pct_within_1 = pct
|
||||
|
||||
# Check pass/fail
|
||||
result.passed = (
|
||||
diff_mean <= TOLERANCE_MEAN and
|
||||
diff_max <= TOLERANCE_MAX and
|
||||
pct >= TOLERANCE_PCT
|
||||
)
|
||||
|
||||
if not result.passed:
|
||||
result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def discover_primitives(interp) -> List[str]:
|
||||
"""Discover all primitives available in the interpreter."""
|
||||
return sorted(interp.primitives.keys())
|
||||
|
||||
|
||||
def run_all_tests(verbose=True):
|
||||
"""Run all primitive tests."""
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
import os
|
||||
os.chdir('/home/giles/art/test')
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
|
||||
# Create interpreter to get Python primitives
|
||||
interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
|
||||
interp._init()
|
||||
|
||||
results = []
|
||||
|
||||
print("=" * 60)
|
||||
print("JAX PRIMITIVE TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
# Test defined primitives
|
||||
for prim_name, test_def in PRIMITIVE_TESTS.items():
|
||||
result = test_primitive(interp, prim_name, test_def)
|
||||
results.append(result)
|
||||
|
||||
status = "✓ PASS" if result.passed else "✗ FAIL"
|
||||
if verbose:
|
||||
print(f"{status} {prim_name}")
|
||||
if not result.passed:
|
||||
print(f" {result.error}")
|
||||
|
||||
# Summary
|
||||
passed = sum(1 for r in results if r.passed)
|
||||
failed = sum(1 for r in results if not r.passed)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"SUMMARY: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
|
||||
if failed > 0:
|
||||
print("\nFailed primitives:")
|
||||
for r in results:
|
||||
if not r.passed:
|
||||
print(f" - {r.primitive}: {r.error}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_all_tests()
|
||||
Reference in New Issue
Block a user