Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
517
tests/test_jax_pipeline_integration.py
Normal file
517
tests/test_jax_pipeline_integration.py
Normal file
@@ -0,0 +1,517 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Integration tests comparing JAX and Python rendering pipelines.
|
||||
|
||||
These tests ensure the JAX-compiled effect chains produce identical output
|
||||
to the Python/NumPy path. They test:
|
||||
1. Full effect pipelines through both interpreters
|
||||
2. Multi-frame sequences (to catch phase-dependent bugs)
|
||||
3. Compiled effect chain fusion
|
||||
4. Edge cases like shrinking/zooming that affect boundary handling
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import numpy as np
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the art-celery module is importable
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
|
||||
# Path to test resources
|
||||
TEST_DIR = Path('/home/giles/art/test')
|
||||
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
|
||||
TEMPLATES_DIR = TEST_DIR / 'templates'
|
||||
|
||||
|
||||
def create_test_image(h=96, w=128):
|
||||
"""Create a test image with distinct patterns."""
|
||||
import cv2
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
|
||||
# Create gradient background
|
||||
for y in range(h):
|
||||
for x in range(w):
|
||||
img[y, x] = [
|
||||
int(255 * x / w), # R: horizontal gradient
|
||||
int(255 * y / h), # G: vertical gradient
|
||||
128 # B: constant
|
||||
]
|
||||
|
||||
# Add features
|
||||
cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1)
|
||||
cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def test_env(tmp_path_factory):
|
||||
"""Set up test environment with sexp files and test media."""
|
||||
test_dir = tmp_path_factory.mktemp('sexp_test')
|
||||
original_dir = os.getcwd()
|
||||
os.chdir(test_dir)
|
||||
|
||||
# Create directories
|
||||
(test_dir / 'effects').mkdir()
|
||||
(test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)
|
||||
|
||||
# Create test image
|
||||
import cv2
|
||||
test_img = create_test_image()
|
||||
cv2.imwrite(str(test_dir / 'test_image.png'), test_img)
|
||||
|
||||
# Copy required effect files
|
||||
for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
|
||||
src = EFFECTS_DIR / f'{effect}.sexp'
|
||||
dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
|
||||
if src.exists():
|
||||
shutil.copy(src, dst)
|
||||
|
||||
yield {
|
||||
'dir': test_dir,
|
||||
'image_path': test_dir / 'test_image.png',
|
||||
'test_img': test_img,
|
||||
}
|
||||
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
def create_sexp_file(test_dir, content, filename='test.sexp'):
|
||||
"""Create a test sexp file."""
|
||||
path = test_dir / 'effects' / filename
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
class TestJaxPythonPipelineEquivalence:
|
||||
"""Test that JAX and Python pipelines produce equivalent output."""
|
||||
|
||||
def test_single_rotate_effect(self, test_env):
|
||||
"""Test that a single rotate effect matches between Python and JAX."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
|
||||
(frame (rotate frame :angle 15 :speed 0))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
# Python path
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
# JAX path
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
# Inject test image into globals
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = py_interp._eval(py_interp.frame_pipeline, frame_env)
|
||||
jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env)
|
||||
|
||||
# Force deferred if needed
|
||||
py_result = np.asarray(py_interp._maybe_force(py_result))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(jax_result))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_rotate_then_zoom(self, test_env):
|
||||
"""Test rotate followed by zoom - tests effect chain fusion."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame (-> (rotate frame :angle 15 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_zoom_shrink_boundary_handling(self, test_env):
|
||||
"""Test zoom with shrinking factor - critical for boundary handling."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame (zoom frame :amount 0.8 :speed 0))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
# Check corners specifically - these are most affected by boundary handling
|
||||
h, w = test_img.shape[:2]
|
||||
corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)]
|
||||
for y, x in corners:
|
||||
py_val = py_result[y, x]
|
||||
jax_val = jax_result[y, x]
|
||||
corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
|
||||
assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"
|
||||
|
||||
def test_blend_two_clips(self, test_env):
|
||||
"""Test blending two effect chains - the core bug scenario."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(require-primitives "core")
|
||||
(require-primitives "image")
|
||||
(require-primitives "blending")
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
(effect blend :path "../sexp_effects/effects/blend.sexp")
|
||||
|
||||
(frame
|
||||
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
|
||||
(zoom :amount 1.05 :speed 0))
|
||||
clip_b (-> (rotate frame :angle -5 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0))]
|
||||
(blend :base clip_a :overlay clip_b :opacity 0.5)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
# Check edge region specifically
|
||||
edge_diff = diff[0, :].mean()
|
||||
|
||||
assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_blend_with_invert(self, test_env):
|
||||
"""Test blending with invert - matches the problematic recipe pattern."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(require-primitives "core")
|
||||
(require-primitives "image")
|
||||
(require-primitives "blending")
|
||||
(require-primitives "color_ops")
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
(effect blend :path "../sexp_effects/effects/blend.sexp")
|
||||
(effect invert :path "../sexp_effects/effects/invert.sexp")
|
||||
|
||||
(frame
|
||||
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
|
||||
(zoom :amount 1.05 :speed 0)
|
||||
(invert :amount 1))
|
||||
clip_b (-> (rotate frame :angle -5 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0))]
|
||||
(blend :base clip_a :overlay clip_b :opacity 0.5)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
|
||||
class TestDeferredEffectChainFusion:
|
||||
"""Test the DeferredEffectChain fusion mechanism specifically."""
|
||||
|
||||
def test_manual_vs_fused_chain(self, test_env):
|
||||
"""Compare manual application vs fused DeferredEffectChain."""
|
||||
import jax.numpy as jnp
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
|
||||
|
||||
# Create minimal sexp to load effects
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame frame)
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
interp._init()
|
||||
|
||||
test_img = test_env['test_img']
|
||||
jax_frame = jnp.array(test_img)
|
||||
t = 0.5
|
||||
frame_num = 5
|
||||
|
||||
# Manual step-by-step application
|
||||
rotate_fn = interp.jax_effects['rotate']
|
||||
zoom_fn = interp.jax_effects['zoom']
|
||||
|
||||
rot_angle = -5.0
|
||||
zoom_amount = 0.95
|
||||
|
||||
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
|
||||
angle=rot_angle, speed=0)
|
||||
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
|
||||
amount=zoom_amount, speed=0)
|
||||
manual_result = np.asarray(manual_result)
|
||||
|
||||
# Fused chain application
|
||||
chain = DeferredEffectChain(
|
||||
['rotate'],
|
||||
[{'angle': rot_angle, 'speed': 0}],
|
||||
jax_frame, t, frame_num
|
||||
)
|
||||
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
|
||||
|
||||
fused_result = np.asarray(interp._force_deferred(chain))
|
||||
|
||||
diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"
|
||||
|
||||
# Check specific pixels
|
||||
h, w = test_img.shape[:2]
|
||||
for y in [0, h//2, h-1]:
|
||||
for x in [0, w//2, w-1]:
|
||||
manual_val = manual_result[y, x]
|
||||
fused_val = fused_result[y, x]
|
||||
pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
|
||||
assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"
|
||||
|
||||
def test_chain_with_shrink_zoom_boundary(self, test_env):
|
||||
"""Test that shrinking zoom handles boundaries correctly in chain."""
|
||||
import jax.numpy as jnp
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
|
||||
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame frame)
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
interp._init()
|
||||
|
||||
test_img = test_env['test_img']
|
||||
jax_frame = jnp.array(test_img)
|
||||
t = 0.5
|
||||
frame_num = 5
|
||||
|
||||
# Parameters that shrink the image (zoom < 1.0)
|
||||
rot_angle = -4.555
|
||||
zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries
|
||||
|
||||
# Manual application
|
||||
rotate_fn = interp.jax_effects['rotate']
|
||||
zoom_fn = interp.jax_effects['zoom']
|
||||
|
||||
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
|
||||
angle=rot_angle, speed=0)
|
||||
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
|
||||
amount=zoom_amount, speed=0)
|
||||
manual_result = np.asarray(manual_result)
|
||||
|
||||
# Fused chain
|
||||
chain = DeferredEffectChain(
|
||||
['rotate'],
|
||||
[{'angle': rot_angle, 'speed': 0}],
|
||||
jax_frame, t, frame_num
|
||||
)
|
||||
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
|
||||
|
||||
fused_result = np.asarray(interp._force_deferred(chain))
|
||||
|
||||
# Check top edge specifically - this is where boundary issues manifest
|
||||
top_edge_manual = manual_result[0, :]
|
||||
top_edge_fused = fused_result[0, :]
|
||||
|
||||
edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
|
||||
mean_edge_diff = np.mean(edge_diff)
|
||||
|
||||
assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"
|
||||
|
||||
# Check for zeros at edges that shouldn't be there
|
||||
manual_edge_sum = np.sum(top_edge_manual)
|
||||
fused_edge_sum = np.sum(top_edge_fused)
|
||||
|
||||
if manual_edge_sum > 100: # If manual has significant values
|
||||
assert fused_edge_sum > manual_edge_sum * 0.5, \
|
||||
f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
334
tests/test_jax_primitives.py
Normal file
334
tests/test_jax_primitives.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test framework to verify JAX primitives match Python primitives.
|
||||
|
||||
Compares output of each primitive through:
|
||||
1. Python/NumPy path
|
||||
2. JAX path (CPU)
|
||||
3. JAX path (GPU) - if available
|
||||
|
||||
Reports any mismatches with detailed diffs.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '/home/giles/art/art-celery')
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Test configuration
|
||||
TEST_WIDTH = 64
|
||||
TEST_HEIGHT = 48
|
||||
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
|
||||
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
|
||||
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
primitive: str
|
||||
passed: bool
|
||||
python_mean: float = 0.0
|
||||
jax_mean: float = 0.0
|
||||
diff_mean: float = 0.0
|
||||
diff_max: float = 0.0
|
||||
pct_within_1: float = 0.0
|
||||
error: str = ""
|
||||
|
||||
|
||||
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
|
||||
"""Create a test frame with known pattern."""
|
||||
if pattern == 'gradient':
|
||||
# Diagonal gradient
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
r = (x * 255 / w).astype(np.uint8)
|
||||
g = (y * 255 / h).astype(np.uint8)
|
||||
b = ((x + y) * 127 / (w + h)).astype(np.uint8)
|
||||
return np.stack([r, g, b], axis=2)
|
||||
elif pattern == 'checkerboard':
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
check = ((x // 8) + (y // 8)) % 2
|
||||
v = (check * 255).astype(np.uint8)
|
||||
return np.stack([v, v, v], axis=2)
|
||||
elif pattern == 'solid':
|
||||
return np.full((h, w, 3), 128, dtype=np.uint8)
|
||||
else:
|
||||
return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
|
||||
"""Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
|
||||
if py_out is None or jax_out is None:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
if isinstance(py_out, dict) and isinstance(jax_out, dict):
|
||||
# Compare coordinate maps
|
||||
diffs = []
|
||||
for k in py_out:
|
||||
if k in jax_out:
|
||||
py_arr = np.asarray(py_out[k])
|
||||
jax_arr = np.asarray(jax_out[k])
|
||||
if py_arr.shape == jax_arr.shape:
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
diffs.append(diff)
|
||||
if diffs:
|
||||
all_diff = np.concatenate([d.flatten() for d in diffs])
|
||||
return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
py_arr = np.asarray(py_out)
|
||||
jax_arr = np.asarray(jax_out)
|
||||
|
||||
if py_arr.shape != jax_arr.shape:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Test Definitions
|
||||
# ============================================================================
|
||||
|
||||
PRIMITIVE_TESTS = {
|
||||
# Geometry primitives
|
||||
'geometry:ripple-displace': {
|
||||
'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
|
||||
'returns': 'coords',
|
||||
},
|
||||
'geometry:rotate-img': {
|
||||
'args': ['frame', 45],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:scale-img': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-h': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-v': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Color operations
|
||||
'color_ops:invert': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:grayscale': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:brightness': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:contrast': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:hue-shift': {
|
||||
'args': ['frame', 90],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Image operations
|
||||
'image:width': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:height': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:channel': {
|
||||
'args': ['frame', 0],
|
||||
'returns': 'array',
|
||||
},
|
||||
|
||||
# Blending
|
||||
'blending:blend': {
|
||||
'args': ['frame', 'frame2', 0.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-add': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-multiply': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the Python interpreter."""
|
||||
if prim_name not in interp.primitives:
|
||||
return None
|
||||
|
||||
func = interp.primitives[prim_name]
|
||||
args = []
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(frame.copy())
|
||||
elif a == 'frame2':
|
||||
args.append(frame2.copy())
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
try:
|
||||
return func(*args)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the JAX compiler."""
|
||||
try:
|
||||
from streaming.sexp_to_jax import JaxCompiler
|
||||
import jax.numpy as jnp
|
||||
|
||||
compiler = JaxCompiler()
|
||||
|
||||
# Build a simple expression to test the primitive
|
||||
from sexp_effects.parser import Symbol, Keyword
|
||||
|
||||
args = []
|
||||
env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
|
||||
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(Symbol('frame'))
|
||||
elif a == 'frame2':
|
||||
args.append(Symbol('frame2'))
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
# Create expression: (prim_name arg1 arg2 ...)
|
||||
expr = [Symbol(prim_name)] + args
|
||||
|
||||
result = compiler._eval(expr, env)
|
||||
|
||||
if hasattr(result, '__array__'):
|
||||
return np.asarray(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
|
||||
"""Test a single primitive."""
|
||||
frame = create_test_frame(pattern='gradient')
|
||||
frame2 = create_test_frame(pattern='checkerboard')
|
||||
|
||||
result = TestResult(primitive=prim_name, passed=False)
|
||||
|
||||
# Run Python version
|
||||
try:
|
||||
py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
|
||||
if py_out is not None and hasattr(py_out, 'shape'):
|
||||
result.python_mean = float(np.mean(py_out))
|
||||
except Exception as e:
|
||||
result.error = f"Python error: {e}"
|
||||
return result
|
||||
|
||||
# Run JAX version
|
||||
try:
|
||||
jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
|
||||
if jax_out is not None and hasattr(jax_out, 'shape'):
|
||||
result.jax_mean = float(np.mean(jax_out))
|
||||
except Exception as e:
|
||||
result.error = f"JAX error: {e}"
|
||||
return result
|
||||
|
||||
if py_out is None:
|
||||
result.error = "Python returned None"
|
||||
return result
|
||||
if jax_out is None:
|
||||
result.error = "JAX returned None"
|
||||
return result
|
||||
|
||||
# Compare
|
||||
diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
|
||||
result.diff_mean = diff_mean
|
||||
result.diff_max = diff_max
|
||||
result.pct_within_1 = pct
|
||||
|
||||
# Check pass/fail
|
||||
result.passed = (
|
||||
diff_mean <= TOLERANCE_MEAN and
|
||||
diff_max <= TOLERANCE_MAX and
|
||||
pct >= TOLERANCE_PCT
|
||||
)
|
||||
|
||||
if not result.passed:
|
||||
result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def discover_primitives(interp) -> List[str]:
|
||||
"""Discover all primitives available in the interpreter."""
|
||||
return sorted(interp.primitives.keys())
|
||||
|
||||
|
||||
def run_all_tests(verbose=True):
|
||||
"""Run all primitive tests."""
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
import os
|
||||
os.chdir('/home/giles/art/test')
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
|
||||
# Create interpreter to get Python primitives
|
||||
interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
|
||||
interp._init()
|
||||
|
||||
results = []
|
||||
|
||||
print("=" * 60)
|
||||
print("JAX PRIMITIVE TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
# Test defined primitives
|
||||
for prim_name, test_def in PRIMITIVE_TESTS.items():
|
||||
result = test_primitive(interp, prim_name, test_def)
|
||||
results.append(result)
|
||||
|
||||
status = "✓ PASS" if result.passed else "✗ FAIL"
|
||||
if verbose:
|
||||
print(f"{status} {prim_name}")
|
||||
if not result.passed:
|
||||
print(f" {result.error}")
|
||||
|
||||
# Summary
|
||||
passed = sum(1 for r in results if r.passed)
|
||||
failed = sum(1 for r in results if not r.passed)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"SUMMARY: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
|
||||
if failed > 0:
|
||||
print("\nFailed primitives:")
|
||||
for r in results:
|
||||
if not r.passed:
|
||||
print(f" - {r.primitive}: {r.error}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_all_tests()
|
||||
305
tests/test_xector.py
Normal file
305
tests/test_xector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Tests for xector primitives - parallel array operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sexp_effects.primitive_libs.xector import (
|
||||
Xector,
|
||||
xector_red, xector_green, xector_blue, xector_rgb,
|
||||
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
|
||||
xector_dist_from_center,
|
||||
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
|
||||
alpha_sin, alpha_cos, alpha_sq,
|
||||
alpha_lt, alpha_gt, alpha_eq,
|
||||
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
|
||||
xector_where, xector_fill, xector_zeros, xector_ones,
|
||||
is_xector,
|
||||
)
|
||||
|
||||
|
||||
class TestXectorBasics:
|
||||
"""Test Xector class basic operations."""
|
||||
|
||||
def test_create_from_list(self):
|
||||
x = Xector([1, 2, 3])
|
||||
assert len(x) == 3
|
||||
assert is_xector(x)
|
||||
|
||||
def test_create_from_numpy(self):
|
||||
arr = np.array([1.0, 2.0, 3.0])
|
||||
x = Xector(arr)
|
||||
assert len(x) == 3
|
||||
np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))
|
||||
|
||||
def test_implicit_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = a + b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_implicit_mul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([2, 2, 2])
|
||||
c = a * b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_scalar_broadcast(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = a + 10
|
||||
np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])
|
||||
|
||||
def test_scalar_broadcast_rmul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = 2 * a
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
|
||||
class TestAlphaOperations:
|
||||
"""Test α (element-wise) operations."""
|
||||
|
||||
def test_alpha_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = alpha_add(a, b)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_alpha_add_multi(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([1, 1, 1])
|
||||
c = Xector([10, 10, 10])
|
||||
d = alpha_add(a, b, c)
|
||||
np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])
|
||||
|
||||
def test_alpha_mul_scalar(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = alpha_mul(a, 2)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_alpha_sqrt(self):
|
||||
a = Xector([1, 4, 9, 16])
|
||||
c = alpha_sqrt(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_alpha_clamp(self):
|
||||
a = Xector([-5, 0, 5, 10, 15])
|
||||
c = alpha_clamp(a, 0, 10)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])
|
||||
|
||||
def test_alpha_sin_cos(self):
|
||||
a = Xector([0, np.pi/2, np.pi])
|
||||
s = alpha_sin(a)
|
||||
c = alpha_cos(a)
|
||||
np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
|
||||
np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)
|
||||
|
||||
def test_alpha_sq(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
c = alpha_sq(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])
|
||||
|
||||
def test_alpha_comparison(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
b = Xector([2, 2, 2, 2])
|
||||
lt = alpha_lt(a, b)
|
||||
gt = alpha_gt(a, b)
|
||||
eq = alpha_eq(a, b)
|
||||
np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
|
||||
np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
|
||||
np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])
|
||||
|
||||
|
||||
class TestBetaOperations:
|
||||
"""Test β (reduction) operations."""
|
||||
|
||||
def test_beta_add(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_add(a) == 10
|
||||
|
||||
def test_beta_mul(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mul(a) == 24
|
||||
|
||||
def test_beta_min_max(self):
|
||||
a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
assert beta_min(a) == 1
|
||||
assert beta_max(a) == 9
|
||||
|
||||
def test_beta_mean(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mean(a) == 2.5
|
||||
|
||||
def test_beta_count(self):
|
||||
a = Xector([1, 2, 3, 4, 5])
|
||||
assert beta_count(a) == 5
|
||||
|
||||
|
||||
class TestFrameConversion:
|
||||
"""Test frame/xector conversion."""
|
||||
|
||||
def test_extract_channels(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[255, 0, 0], [0, 255, 0]],
|
||||
[[0, 0, 255], [128, 128, 128]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
assert len(r) == 4
|
||||
np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
|
||||
np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
|
||||
np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])
|
||||
|
||||
def test_rgb_roundtrip(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[100, 150, 200], [50, 75, 100]],
|
||||
[[200, 100, 50], [25, 50, 75]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
reconstructed = xector_rgb(r, g, b)
|
||||
np.testing.assert_array_equal(reconstructed, frame)
|
||||
|
||||
def test_modify_and_reconstruct(self):
|
||||
frame = np.array([
|
||||
[[100, 100, 100], [100, 100, 100]],
|
||||
[[100, 100, 100], [100, 100, 100]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
# Double red channel
|
||||
r_doubled = r * 2
|
||||
|
||||
result = xector_rgb(r_doubled, g, b)
|
||||
|
||||
# Red should be 200, others unchanged
|
||||
assert result[0, 0, 0] == 200
|
||||
assert result[0, 0, 1] == 100
|
||||
assert result[0, 0, 2] == 100
|
||||
|
||||
|
||||
class TestCoordinates:
|
||||
"""Test coordinate generation."""
|
||||
|
||||
def test_x_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
x = xector_x_coords(frame)
|
||||
# Should be [0,1,2, 0,1,2] (x coords repeated for each row)
|
||||
np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])
|
||||
|
||||
def test_y_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
y = xector_y_coords(frame)
|
||||
# Should be [0,0,0, 1,1,1] (y coords for each pixel)
|
||||
np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])
|
||||
|
||||
def test_normalized_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_x_norm(frame)
|
||||
y = xector_y_norm(frame)
|
||||
|
||||
# x should go 0 to 1 across width
|
||||
assert x.to_numpy()[0] == 0
|
||||
assert x.to_numpy()[2] == 1
|
||||
|
||||
# y should go 0 to 1 down height
|
||||
assert y.to_numpy()[0] == 0
|
||||
assert y.to_numpy()[3] == 1
|
||||
|
||||
|
||||
class TestConditional:
|
||||
"""Test conditional operations."""
|
||||
|
||||
def test_where(self):
|
||||
cond = Xector([True, False, True, False])
|
||||
true_val = Xector([1, 1, 1, 1])
|
||||
false_val = Xector([0, 0, 0, 0])
|
||||
|
||||
result = xector_where(cond, true_val, false_val)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])
|
||||
|
||||
def test_where_with_comparison(self):
|
||||
a = Xector([1, 5, 3, 7])
|
||||
threshold = 4
|
||||
|
||||
# Elements > 4 become 255, others become 0
|
||||
result = xector_where(alpha_gt(a, threshold), 255, 0)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])
|
||||
|
||||
def test_fill(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_fill(42, frame)
|
||||
assert len(x) == 6
|
||||
assert all(v == 42 for v in x.to_numpy())
|
||||
|
||||
def test_zeros_ones(self):
|
||||
frame = np.zeros((2, 2, 3), dtype=np.uint8)
|
||||
z = xector_zeros(frame)
|
||||
o = xector_ones(frame)
|
||||
|
||||
assert all(v == 0 for v in z.to_numpy())
|
||||
assert all(v == 1 for v in o.to_numpy())
|
||||
|
||||
|
||||
class TestInterpreterIntegration:
|
||||
"""Test xector operations through the interpreter."""
|
||||
|
||||
def test_xector_vignette_effect(self):
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
|
||||
# Load the xector vignette effect
|
||||
interp.load_effect('sexp_effects/effects/xector_vignette.sexp')
|
||||
|
||||
# Create a test frame (white)
|
||||
frame = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||
|
||||
# Run effect
|
||||
result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})
|
||||
|
||||
# Center should be brighter than corners
|
||||
center = result[50, 50]
|
||||
corner = result[0, 0]
|
||||
|
||||
assert center.mean() > corner.mean(), "Center should be brighter than corners"
|
||||
# Corners should be darkened
|
||||
assert corner.mean() < 255, "Corners should be darkened"
|
||||
|
||||
def test_implicit_elementwise(self):
|
||||
"""Test that regular + works element-wise on xectors."""
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
# Load xector primitives
|
||||
from sexp_effects.primitive_libs.xector import PRIMITIVES
|
||||
for name, fn in PRIMITIVES.items():
|
||||
interp.global_env.set(name, fn)
|
||||
|
||||
# Parse and eval a simple xector expression
|
||||
from sexp_effects.parser import parse
|
||||
expr = parse('(+ (red frame) 10)')
|
||||
|
||||
# Create test frame
|
||||
frame = np.full((2, 2, 3), 100, dtype=np.uint8)
|
||||
interp.global_env.set('frame', frame)
|
||||
|
||||
result = interp.eval(expr)
|
||||
|
||||
# Should be a xector with values 110
|
||||
assert is_xector(result)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user