All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
335 lines
9.3 KiB
Python
335 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test framework to verify JAX primitives match Python primitives.
|
|
|
|
Compares output of each primitive through:
|
|
1. Python/NumPy path
|
|
2. JAX path (CPU)
|
|
3. JAX path (GPU) - if available
|
|
|
|
Reports any mismatches with detailed diffs.
|
|
"""
|
|
import sys
|
|
sys.path.insert(0, '/home/giles/art/art-celery')
|
|
|
|
import numpy as np
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple, Any, Optional
|
|
from dataclasses import dataclass, field
|
|
|
|
# Test configuration: frame geometry plus the pass/fail thresholds applied
# to the element-wise absolute difference between Python and JAX outputs.
TEST_WIDTH, TEST_HEIGHT = 64, 48  # Size (pixels) of every generated test frame

TOLERANCE_MEAN = 1.0   # Largest acceptable mean absolute difference
TOLERANCE_MAX = 10.0   # Largest acceptable absolute difference at any single element
TOLERANCE_PCT = 0.95   # Smallest acceptable fraction (0..1) of elements differing by <= 1
|
|
|
|
|
|
@dataclass
class TestResult:
    """Outcome of comparing one primitive's Python output with its JAX output.

    Instances start as ``passed=False`` with zeroed statistics and are filled
    in by ``test_primitive``.
    """

    # The class name starts with "Test", so pytest would try to collect this
    # dataclass as a test class (and warn that it has an __init__). Opt out.
    # It is unannotated, so @dataclass does not turn it into a field.
    __test__ = False

    primitive: str  # Primitive name, e.g. 'color_ops:invert'
    passed: bool  # True only when every tolerance check succeeded
    python_mean: float = 0.0  # Mean of the Python output (stays 0.0 if it has no .shape)
    jax_mean: float = 0.0  # Mean of the JAX output (stays 0.0 if it has no .shape)
    diff_mean: float = 0.0  # Mean absolute element-wise difference
    diff_max: float = 0.0  # Largest absolute element-wise difference
    pct_within_1: float = 0.0  # Fraction (0..1) of elements differing by <= 1
    error: str = ""  # Human-readable failure reason; empty when none recorded
|
|
|
|
|
|
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
    """Create an ``(h, w, 3)`` uint8 RGB test frame with a known pattern.

    Args:
        w: frame width in pixels.
        h: frame height in pixels.
        pattern: one of
            ``'gradient'``     - red ramps left->right, green top->bottom,
                                 blue along the diagonal (always below 127);
            ``'checkerboard'`` - 8x8 black/white squares, identical in all
                                 three channels;
            ``'solid'``        - every channel of every pixel set to 128;
            anything else      - uniform noise over the full 0..255 range,
                                 drawn from NumPy's global RNG (reproducible
                                 only if that RNG has been seeded).

    Returns:
        numpy.ndarray of dtype uint8 and shape ``(h, w, 3)``.
    """
    if pattern == 'gradient':
        y, x = np.mgrid[0:h, 0:w]
        r = (x * 255 / w).astype(np.uint8)
        g = (y * 255 / h).astype(np.uint8)
        b = ((x + y) * 127 / (w + h)).astype(np.uint8)
        return np.stack([r, g, b], axis=2)

    if pattern == 'checkerboard':
        y, x = np.mgrid[0:h, 0:w]
        check = ((x // 8) + (y // 8)) % 2
        v = (check * 255).astype(np.uint8)
        return np.stack([v, v, v], axis=2)

    if pattern == 'solid':
        return np.full((h, w, 3), 128, dtype=np.uint8)

    # randint's upper bound is exclusive: it must be 256 for 255 (pure white)
    # to be reachable; a bound of 255 silently excludes it.
    return np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
|
|
|
|
|
|
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
    """Compare a Python and a JAX primitive output element-wise.

    Returns ``(mean_diff, max_diff, pct_within_1)``. The diffs are absolute
    differences computed in float, so unsigned inputs such as uint8 cannot
    wrap around. ``pct_within_1`` is the fraction (0..1) of elements whose
    difference is <= 1.

    Dict outputs (e.g. coordinate maps) are compared entry by entry over
    every key of ``py_out``; extra keys in ``jax_out`` are ignored. The
    sentinel ``(inf, inf, 0.0)``, which can never pass the tolerance checks,
    is returned when:

    * either output is None;
    * only one of the outputs is a dict;
    * a key of ``py_out`` is missing from ``jax_out``, or both dicts are empty;
    * shapes differ, for the whole output or for any dict entry.

    Two empty outputs of matching shape compare as identical: ``(0.0, 0.0, 1.0)``.
    """
    mismatch = (float('inf'), float('inf'), 0.0)

    if py_out is None or jax_out is None:
        return mismatch

    py_is_dict = isinstance(py_out, dict)
    jax_is_dict = isinstance(jax_out, dict)
    if py_is_dict != jax_is_dict:
        # One backend returned a mapping and the other did not.
        return mismatch

    if py_is_dict:
        pairs = []
        for key, py_val in py_out.items():
            if key not in jax_out:
                # Skipping a missing key would let an incomplete result pass.
                return mismatch
            pairs.append((np.asarray(py_val), np.asarray(jax_out[key])))
    else:
        pairs = [(np.asarray(py_out), np.asarray(jax_out))]

    diffs = []
    for py_arr, jax_arr in pairs:
        if py_arr.shape != jax_arr.shape:
            return mismatch
        # Subtract in float so unsigned inputs (e.g. uint8) cannot wrap around.
        diffs.append(np.abs(py_arr.astype(float) - jax_arr.astype(float)).ravel())

    if not diffs:
        # Empty dicts: nothing was actually compared, so do not report a match.
        return mismatch

    all_diff = np.concatenate(diffs)
    if all_diff.size == 0:
        # Matching empty shapes are identical (and np.max would raise on them).
        return 0.0, 0.0, 1.0

    return (
        float(np.mean(all_diff)),
        float(np.max(all_diff)),
        float(np.mean(all_diff <= 1)),
    )
|
|
|
|
|
|
# ============================================================================
# Primitive Test Definitions
# ============================================================================
#
# Each entry maps a primitive name to a spec with two keys:
#   'args'    - positional arguments for the primitive. The strings 'frame'
#               and 'frame2' are placeholders for the two generated test
#               frames; every other value is passed through unchanged.
#   'returns' - a descriptive label for the kind of value produced; nothing
#               in this script currently reads it.
# Entries are exercised in the order they appear here.

PRIMITIVE_TESTS = {
    # --- Geometry primitives ---
    'geometry:ripple-displace': dict(
        args=[TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH / 2, TEST_HEIGHT / 2, 1, 0.5],
        returns='coords',
    ),
    'geometry:rotate-img': dict(args=['frame', 45], returns='frame'),
    'geometry:scale-img': dict(args=['frame', 1.5], returns='frame'),
    'geometry:flip-h': dict(args=['frame'], returns='frame'),
    'geometry:flip-v': dict(args=['frame'], returns='frame'),

    # --- Color operations ---
    'color_ops:invert': dict(args=['frame'], returns='frame'),
    'color_ops:grayscale': dict(args=['frame'], returns='frame'),
    'color_ops:brightness': dict(args=['frame', 1.5], returns='frame'),
    'color_ops:contrast': dict(args=['frame', 1.5], returns='frame'),
    'color_ops:hue-shift': dict(args=['frame', 90], returns='frame'),

    # --- Image operations ---
    'image:width': dict(args=['frame'], returns='scalar'),
    'image:height': dict(args=['frame'], returns='scalar'),
    'image:channel': dict(args=['frame', 0], returns='array'),

    # --- Blending ---
    'blending:blend': dict(args=['frame', 'frame2', 0.5], returns='frame'),
    'blending:blend-add': dict(args=['frame', 'frame2'], returns='frame'),
    'blending:blend-multiply': dict(args=['frame', 'frame2'], returns='frame'),
}
|
|
|
|
|
|
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the Python interpreter.

    Args:
        interp: object exposing a ``primitives`` mapping of name -> callable.
        prim_name: key to look up in ``interp.primitives``.
        test_def: spec whose ``'args'`` list is passed positionally. The
            string placeholders ``'frame'`` / ``'frame2'`` are replaced by
            *copies* of ``frame`` / ``frame2``, so an in-place primitive
            cannot corrupt the caller's arrays.
        frame: array substituted for the ``'frame'`` placeholder.
        frame2: array substituted for the ``'frame2'`` placeholder.

    Returns:
        Whatever the primitive returns, or None if ``prim_name`` is not
        registered with the interpreter.

    Raises:
        Any exception raised by the primitive propagates unchanged, so the
        caller can report the real failure rather than a bare None.
    """
    if prim_name not in interp.primitives:
        return None

    func = interp.primitives[prim_name]
    placeholders = {'frame': frame, 'frame2': frame2}
    # isinstance guard: only strings can be placeholders (and comparing an
    # array argument with == would be element-wise and ambiguous).
    args = [
        placeholders[a].copy() if isinstance(a, str) and a in placeholders else a
        for a in test_def['args']
    ]
    return func(*args)
|
|
|
|
|
|
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the JAX compiler.

    Builds the s-expression ``(prim_name arg1 arg2 ...)``. In it, the
    placeholders ``'frame'`` / ``'frame2'`` become symbols bound in the
    evaluation environment to JAX copies of ``frame`` / ``frame2``. The
    expression is evaluated with a fresh ``JaxCompiler``.

    Returns:
        The evaluation result. Anything exposing ``__array__`` is converted
        to a NumPy array; other values (scalars, tuples, dicts) are returned
        as-is.

    Raises:
        Any exception propagates to the caller, including ImportError when
        JAX or the streaming package is unavailable. The caller then records
        the actual cause instead of an uninformative None.
    """
    # Imported here rather than at module level, so the file still imports
    # when JAX / the project packages are absent.
    from streaming.sexp_to_jax import JaxCompiler
    import jax.numpy as jnp
    from sexp_effects.parser import Symbol

    compiler = JaxCompiler()
    env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}

    # Placeholder strings become symbols that resolve through ``env``;
    # all other argument values are embedded literally.
    args = [Symbol(a) if isinstance(a, str) and a in env else a for a in test_def['args']]
    expr = [Symbol(prim_name)] + args

    # NOTE(review): _eval is a private JaxCompiler method; this couples the
    # test to compiler internals.
    result = compiler._eval(expr, env)

    if hasattr(result, '__array__'):
        return np.asarray(result)
    return result
|
|
|
|
|
|
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
    """Run one primitive through both backends and compare the outputs.

    Both backends receive the same inputs: a 'gradient' frame for the
    ``'frame'`` placeholder and a 'checkerboard' frame for ``'frame2'``.

    Args:
        interp: Python interpreter providing the reference primitives.
        prim_name: primitive to test.
        test_def: spec from PRIMITIVE_TESTS (its ``'args'`` list is used).

    Returns:
        A TestResult. ``passed`` is True only if the mean diff, max diff and
        fraction-within-1 satisfy TOLERANCE_MEAN, TOLERANCE_MAX and
        TOLERANCE_PCT respectively. Otherwise ``error`` explains why. Errors
        from either backend, or from comparing their outputs, are recorded in
        ``error`` rather than raised.
    """
    frame = create_test_frame(pattern='gradient')
    frame2 = create_test_frame(pattern='checkerboard')

    result = TestResult(primitive=prim_name, passed=False)

    # Python (reference) implementation.
    try:
        py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
        if py_out is not None and hasattr(py_out, 'shape'):
            result.python_mean = float(np.mean(py_out))
    except Exception as e:
        result.error = f"Python error: {e}"
        return result

    # JAX implementation.
    try:
        jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
        if jax_out is not None and hasattr(jax_out, 'shape'):
            result.jax_mean = float(np.mean(jax_out))
    except Exception as e:
        result.error = f"JAX error: {e}"
        return result

    if py_out is None:
        result.error = "Python returned None"
        return result
    if jax_out is None:
        result.error = "JAX returned None"
        return result

    # An unexpected output type (e.g. object arrays) can make the comparison
    # raise; record it so one odd primitive does not abort the whole suite.
    try:
        diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
    except Exception as e:
        result.error = f"Compare error: {e}"
        return result

    result.diff_mean = diff_mean
    result.diff_max = diff_max
    result.pct_within_1 = pct

    result.passed = (
        diff_mean <= TOLERANCE_MEAN
        and diff_max <= TOLERANCE_MAX
        and pct >= TOLERANCE_PCT
    )

    if not result.passed:
        result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"

    return result


# The "test_" prefix makes pytest collect this helper as a test function,
# which then errors because no 'interp' fixture exists. Opt it out.
test_primitive.__test__ = False
|
|
|
|
|
|
def discover_primitives(interp) -> List[str]:
    """Return the name of every primitive registered on *interp*, alphabetically."""
    names = list(interp.primitives.keys())
    names.sort()
    return names
|
|
|
|
|
|
def run_all_tests(verbose=True, *, workdir='/home/giles/art/test',
                  sexp_path='effects/quick_test_explicit.sexp'):
    """Run every entry in PRIMITIVE_TESTS and print a report.

    Args:
        verbose: print a PASS/FAIL line per primitive, plus the error for
            failures. The header and summary are printed regardless.
        workdir: directory to run from while testing; ``sexp_path`` is
            resolved relative to it. The machine-specific default is kept
            for backward compatibility.
        sexp_path: effect file used to construct the Python
            ``StreamInterpreter`` whose primitives serve as the reference.

    Returns:
        A list of TestResult, one per PRIMITIVE_TESTS entry, in definition order.

    The previous working directory and warning filters are restored on exit,
    so calling this does not leave process-wide side effects behind.
    """
    import os
    import warnings

    previous_cwd = os.getcwd()
    # catch_warnings() restores the global filter list on exit, so the blanket
    # 'ignore' no longer leaks into whatever runs after this function.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        os.chdir(workdir)
        try:
            from streaming.stream_sexp_generic import StreamInterpreter
            from sexp_effects.primitive_libs import core as core_mod

            core_mod.set_random_seed(42)

            # Create interpreter to get Python primitives
            interp = StreamInterpreter(sexp_path, use_jax=False)
            interp._init()

            results = []

            print("=" * 60)
            print("JAX PRIMITIVE TEST SUITE")
            print("=" * 60)

            for prim_name, test_def in PRIMITIVE_TESTS.items():
                result = test_primitive(interp, prim_name, test_def)
                results.append(result)

                status = "✓ PASS" if result.passed else "✗ FAIL"
                if verbose:
                    print(f"{status} {prim_name}")
                    if not result.passed:
                        print(f" {result.error}")
        finally:
            os.chdir(previous_cwd)

    passed = sum(1 for r in results if r.passed)
    failed = len(results) - passed

    print("\n" + "=" * 60)
    print(f"SUMMARY: {passed} passed, {failed} failed")
    print("=" * 60)

    if failed > 0:
        print("\nFailed primitives:")
        for r in results:
            if not r.passed:
                print(f" - {r.primitive}: {r.error}")

    return results
|
|
|
|
|
|
if __name__ == '__main__':
    # Exit non-zero when any primitive fails so CI and shell callers notice;
    # previously the script always exited 0.
    _results = run_all_tests()
    sys.exit(0 if all(r.passed for r in _results) else 1)
|