Files
rose-ash/artdag/l1/tests/test_jax_primitives.py
giles 1a74d811f7
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 2m33s
Incorporate art-dag-mono repo into artdag/ subfolder
Merges full history from art-dag/mono.git into the monorepo
under the artdag/ directory. Contains: core (DAG engine),
l1 (Celery rendering server), l2 (ActivityPub registry),
common (shared templates/middleware), client (CLI), test (e2e).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

git-subtree-dir: artdag
git-subtree-mainline: 1a179de547
git-subtree-split: 4c2e716558
2026-02-27 09:07:23 +00:00

335 lines
9.3 KiB
Python

#!/usr/bin/env python3
"""
Test framework to verify JAX primitives match Python primitives.
Compares output of each primitive through:
1. Python/NumPy path
2. JAX path (CPU)
3. JAX path (GPU) - if available
Reports any mismatches with detailed diffs.
"""
import sys
sys.path.insert(0, '/home/giles/art/art-celery')
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass, field
# --- Test configuration -----------------------------------------------------
# Size in pixels of the synthetic frames produced by create_test_frame().
TEST_WIDTH, TEST_HEIGHT = 64, 48
# Pass/fail thresholds that test_primitive() applies to compare_outputs().
TOLERANCE_MEAN = 1.0   # largest acceptable mean absolute difference
TOLERANCE_MAX = 10.0   # largest acceptable absolute difference of any one value
TOLERANCE_PCT = 0.95   # smallest acceptable fraction (0-1) of values within ±1
@dataclass
class TestResult:
    """Outcome of comparing one primitive's Python output with its JAX output.

    The ``diff_*`` and ``pct_within_1`` fields hold the statistics produced by
    the output comparison; ``error`` carries a human-readable failure reason.
    """

    # Despite the ``Test`` prefix this is a plain record, not a pytest test
    # class. Without this flag pytest tries to collect it from this test_*.py
    # module and emits a PytestCollectionWarning (dataclasses define __init__).
    # An unannotated class attribute is not treated as a dataclass field.
    __test__ = False

    primitive: str             # primitive name, e.g. 'color_ops:invert'
    passed: bool               # True only when every tolerance was met
    python_mean: float = 0.0   # mean of the Python output (0.0 if not computed)
    jax_mean: float = 0.0      # mean of the JAX output (0.0 if not computed)
    diff_mean: float = 0.0     # mean absolute difference between the outputs
    diff_max: float = 0.0      # largest absolute difference between the outputs
    pct_within_1: float = 0.0  # fraction (0-1) of values differing by <= 1
    error: str = ""            # failure reason; empty when none was recorded
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
    """Create an ``(h, w, 3)`` uint8 RGB test frame with a known pattern.

    Args:
        w: Frame width in pixels.
        h: Frame height in pixels.
        pattern: One of
            ``'gradient'``     - red ramps along x, green along y, blue along
                                 x + y (blue stays below 128);
            ``'checkerboard'`` - 8x8-pixel squares of 0 and 255, starting
                                 with 0 at the top-left;
            ``'solid'``        - every value set to 128;
            anything else      - uniform random noise over the full 0-255
                                 range, drawn from NumPy's global RNG (seed it
                                 for reproducible frames).

    Returns:
        A new ``np.uint8`` array of shape ``(h, w, 3)``.
    """
    if pattern == 'gradient':
        # Diagonal gradient
        y, x = np.mgrid[0:h, 0:w]
        r = (x * 255 / w).astype(np.uint8)
        g = (y * 255 / h).astype(np.uint8)
        b = ((x + y) * 127 / (w + h)).astype(np.uint8)
        return np.stack([r, g, b], axis=2)
    if pattern == 'checkerboard':
        y, x = np.mgrid[0:h, 0:w]
        check = ((x // 8) + (y // 8)) % 2
        v = (check * 255).astype(np.uint8)
        return np.stack([v, v, v], axis=2)
    if pattern == 'solid':
        return np.full((h, w, 3), 128, dtype=np.uint8)
    # randint's upper bound is exclusive: 256 (not 255) is required for the
    # value 255 to be reachable, i.e. to cover the whole uint8 range.
    return np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
# Result reported whenever two outputs cannot be meaningfully compared.
_MISMATCH = (float('inf'), float('inf'), 0.0)


def _abs_diff(py_val, jax_val) -> Optional[np.ndarray]:
    """Return ``|py_val - jax_val|`` as a float array, or None if incomparable.

    Values are incomparable when their shapes differ or when they cannot be
    converted to float arrays (ragged sequences, strings, arbitrary objects).
    """
    try:
        py_arr = np.asarray(py_val)
        jax_arr = np.asarray(jax_val)
        if py_arr.shape != jax_arr.shape:
            return None
        # Cast before subtracting so uint8 frames cannot wrap around.
        return np.abs(py_arr.astype(float) - jax_arr.astype(float))
    except (TypeError, ValueError):
        return None


def _diff_stats(diff: np.ndarray) -> Tuple[float, float, float]:
    """Summarise an absolute-difference array as (mean, max, fraction <= 1)."""
    if diff.size == 0:
        # Two empty outputs of identical shape are trivially equal (and
        # np.max would raise on an empty array).
        return 0.0, 0.0, 1.0
    return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))


def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
    """Compare a Python-path output with a JAX-path output.

    Outputs may be arrays/scalars, or dicts of arrays (coordinate maps). For
    dicts, every key in ``py_out`` must be present in ``jax_out`` with the same
    shape; the per-key differences are pooled before summarising. Extra keys
    in ``jax_out`` are ignored.

    Returns:
        ``(mean_diff, max_diff, pct_within_1)``, where ``pct_within_1`` is the
        fraction (0-1) of values whose absolute difference is <= 1. Returns
        ``(inf, inf, 0.0)`` when either output is None, only one is a dict, a
        key is missing, shapes differ, values are non-numeric, or a dict is
        empty.
    """
    if py_out is None or jax_out is None:
        return _MISMATCH

    py_is_map = isinstance(py_out, dict)
    if py_is_map != isinstance(jax_out, dict):
        # One path returned a coordinate map and the other did not.
        return _MISMATCH

    if py_is_map:
        diffs = []
        for key, py_val in py_out.items():
            if key not in jax_out:
                # A missing map must fail rather than be silently skipped.
                return _MISMATCH
            diff = _abs_diff(py_val, jax_out[key])
            if diff is None:
                return _MISMATCH
            diffs.append(diff.ravel())
        if not diffs:
            return _MISMATCH
        return _diff_stats(np.concatenate(diffs))

    diff = _abs_diff(py_out, jax_out)
    return _MISMATCH if diff is None else _diff_stats(diff)
# ============================================================================
# Primitive Test Definitions
# ============================================================================
# Each entry maps a primitive name to a spec dict with two keys:
#   'args'    - positional arguments; the strings 'frame' and 'frame2' are
#               placeholders for the test frames supplied at run time.
#   'returns' - the kind of value the primitive is expected to yield
#               ('coords', 'frame', 'scalar', 'array'). NOTE(review): nothing
#               in the visible code reads this key - it appears descriptive.
PRIMITIVE_TESTS = {
    name: {'args': args, 'returns': returns}
    for name, args, returns in [
        # Geometry primitives
        ('geometry:ripple-displace',
         [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
         'coords'),
        ('geometry:rotate-img', ['frame', 45], 'frame'),
        ('geometry:scale-img', ['frame', 1.5], 'frame'),
        ('geometry:flip-h', ['frame'], 'frame'),
        ('geometry:flip-v', ['frame'], 'frame'),
        # Color operations
        ('color_ops:invert', ['frame'], 'frame'),
        ('color_ops:grayscale', ['frame'], 'frame'),
        ('color_ops:brightness', ['frame', 1.5], 'frame'),
        ('color_ops:contrast', ['frame', 1.5], 'frame'),
        ('color_ops:hue-shift', ['frame', 90], 'frame'),
        # Image operations
        ('image:width', ['frame'], 'scalar'),
        ('image:height', ['frame'], 'scalar'),
        ('image:channel', ['frame', 0], 'array'),
        # Blending
        ('blending:blend', ['frame', 'frame2', 0.5], 'frame'),
        ('blending:blend-add', ['frame', 'frame2'], 'frame'),
        ('blending:blend-multiply', ['frame', 'frame2'], 'frame'),
    ]
}
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the Python interpreter.

    Args:
        interp: Object exposing a ``primitives`` mapping of name -> callable.
        prim_name: Key to look up in ``interp.primitives``.
        test_def: Test spec; ``test_def['args']`` lists the positional
            arguments, where the strings ``'frame'`` and ``'frame2'`` are
            placeholders for the two frames.
        frame: Substituted for every ``'frame'`` placeholder.
        frame2: Substituted for every ``'frame2'`` placeholder.

    Returns:
        Whatever the primitive returns, or None if ``prim_name`` is not
        registered.

    Raises:
        Any exception raised by the primitive. It is propagated rather than
        turned into None so the caller can report the real cause
        (``test_primitive`` records it as ``"Python error: ..."``).
    """
    if prim_name not in interp.primitives:
        return None
    func = interp.primitives[prim_name]
    placeholders = {'frame': frame, 'frame2': frame2}
    # Each placeholder use gets a fresh copy so an in-place primitive cannot
    # corrupt the frames that are later handed to the JAX path.
    args = [
        placeholders[a].copy() if isinstance(a, str) and a in placeholders else a
        for a in test_def['args']
    ]
    return func(*args)
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the JAX compiler.

    Builds the expression ``(prim_name arg1 arg2 ...)``. The placeholder
    strings ``'frame'`` and ``'frame2'`` become symbols that the evaluation
    environment binds to ``jnp.array`` copies of the two frames. All other
    arguments are passed as literals. The expression is evaluated with
    ``JaxCompiler._eval``.

    Returns:
        The result converted with ``np.asarray`` when it exposes
        ``__array__``; otherwise the raw result.

    Raises:
        Any exception from importing the JAX stack, constructing the
        compiler, or evaluating the expression. It is propagated rather than
        turned into None so the caller can report the real cause
        (``test_primitive`` records it as ``"JAX error: ..."``).
    """
    from streaming.sexp_to_jax import JaxCompiler
    import jax.numpy as jnp

    compiler = JaxCompiler()
    from sexp_effects.parser import Symbol

    env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
    # Placeholders become symbols resolved through env; everything else is
    # passed through as a literal argument.
    args = [
        Symbol(a) if isinstance(a, str) and a in env else a
        for a in test_def['args']
    ]
    # Create expression: (prim_name arg1 arg2 ...)
    expr = [Symbol(prim_name)] + args
    result = compiler._eval(expr, env)
    if hasattr(result, '__array__'):
        return np.asarray(result)
    return result
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
    """Run one primitive through the Python and JAX paths and compare them.

    Both paths receive the same inputs: a gradient frame for ``'frame'`` and a
    checkerboard frame for ``'frame2'``. The primitive passes when the outputs
    agree within TOLERANCE_MEAN, TOLERANCE_MAX and TOLERANCE_PCT.

    Args:
        interp: Python-path interpreter exposing a ``primitives`` mapping.
        prim_name: Primitive to test, e.g. ``'color_ops:invert'``.
        test_def: Test spec whose ``'args'`` list is forwarded to both paths.

    Returns:
        A TestResult. Exceptions raised by either path or by the comparison
        are captured in ``error`` instead of being raised.
    """
    frame = create_test_frame(pattern='gradient')
    frame2 = create_test_frame(pattern='checkerboard')
    result = TestResult(primitive=prim_name, passed=False)

    # Python reference path.
    try:
        py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
        if py_out is not None and hasattr(py_out, 'shape'):
            result.python_mean = float(np.mean(py_out))
    except Exception as e:
        result.error = f"Python error: {e}"
        return result

    # JAX path under test.
    try:
        jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
        if jax_out is not None and hasattr(jax_out, 'shape'):
            result.jax_mean = float(np.mean(jax_out))
    except Exception as e:
        result.error = f"JAX error: {e}"
        return result

    if py_out is None:
        result.error = "Python returned None"
        return result
    if jax_out is None:
        result.error = "JAX returned None"
        return result

    # Guard the comparison as well, so one unexpected output type is reported
    # for this primitive instead of aborting the whole suite.
    try:
        diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
    except Exception as e:
        result.error = f"Compare error: {e}"
        return result
    result.diff_mean = diff_mean
    result.diff_max = diff_max
    result.pct_within_1 = pct

    result.passed = (
        diff_mean <= TOLERANCE_MEAN
        and diff_max <= TOLERANCE_MAX
        and pct >= TOLERANCE_PCT
    )
    if not result.passed:
        result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
    return result


# This helper takes plain arguments, not pytest fixtures. Because of its
# ``test_`` prefix, pytest would otherwise collect it from this test_*.py file
# and error out with "fixture 'interp' not found".
test_primitive.__test__ = False
def discover_primitives(interp) -> List[str]:
    """List every primitive name registered on *interp*, in ascending order.

    ``interp.primitives`` is iterated for its keys, which are the names.
    """
    names = list(interp.primitives)
    names.sort()
    return names
def run_all_tests(verbose=True, *, workdir='/home/giles/art/test',
                  sexp_path='effects/quick_test_explicit.sexp', seed=42):
    """Run every primitive in PRIMITIVE_TESTS and print a pass/fail report.

    Args:
        verbose: When true, print one status line per primitive (plus the
            error text for failures). The header and summary always print.
        workdir: Directory to switch into while the suite runs; ``sexp_path``
            is resolved relative to it. The caller's working directory is
            restored afterwards, even on error.
        sexp_path: Stream definition used to build the Python-path
            interpreter, whose primitives serve as the reference output.
        seed: Value passed to ``core.set_random_seed`` before the
            interpreter is created.

    Returns:
        A list of TestResult, one per PRIMITIVE_TESTS entry, in order.
    """
    import os
    import warnings

    previous_cwd = os.getcwd()
    # Scope both the warning suppression and the chdir to this run instead of
    # leaking them into the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        os.chdir(workdir)
        try:
            from streaming.stream_sexp_generic import StreamInterpreter
            from sexp_effects.primitive_libs import core as core_mod
            core_mod.set_random_seed(seed)
            # Create interpreter to get Python primitives
            interp = StreamInterpreter(sexp_path, use_jax=False)
            interp._init()

            results = []
            print("=" * 60)
            print("JAX PRIMITIVE TEST SUITE")
            print("=" * 60)
            for prim_name, test_def in PRIMITIVE_TESTS.items():
                try:
                    result = test_primitive(interp, prim_name, test_def)
                except Exception as e:
                    # One broken primitive must not abort the whole report.
                    result = TestResult(primitive=prim_name, passed=False,
                                        error=f"Unexpected error: {e}")
                results.append(result)
                status = "✓ PASS" if result.passed else "✗ FAIL"
                if verbose:
                    print(f"{status} {prim_name}")
                    if not result.passed:
                        print(f" {result.error}")
        finally:
            os.chdir(previous_cwd)

    # Summary
    passed = sum(1 for r in results if r.passed)
    failed = len(results) - passed
    print("\n" + "=" * 60)
    print(f"SUMMARY: {passed} passed, {failed} failed")
    print("=" * 60)
    if failed > 0:
        print("\nFailed primitives:")
        for r in results:
            if not r.passed:
                print(f" - {r.primitive}: {r.error}")
    return results
if __name__ == '__main__':
    # Exit non-zero when any primitive fails so shell scripts and CI can
    # detect regressions (the report alone always exited with status 0).
    sys.exit(0 if all(r.passed for r in run_all_tests()) else 1)