Files
mono/l1/tests/test_xector.py
2026-02-24 23:07:19 +00:00

306 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tests for xector primitives - parallel array operations.
"""
import pytest
import numpy as np
from sexp_effects.primitive_libs.xector import (
Xector,
xector_red, xector_green, xector_blue, xector_rgb,
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
xector_dist_from_center,
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
alpha_sin, alpha_cos, alpha_sq,
alpha_lt, alpha_gt, alpha_eq,
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
xector_where, xector_fill, xector_zeros, xector_ones,
is_xector,
)
class TestXectorBasics:
    """Construction of ``Xector`` objects and their overloaded operators."""

    def test_create_from_list(self):
        """A Python list can seed a xector that reports its length."""
        vec = Xector([1, 2, 3])
        assert len(vec) == 3
        assert is_xector(vec)

    def test_create_from_numpy(self):
        """A NumPy array can seed a xector; values equal its float32 cast."""
        source = np.array([1.0, 2.0, 3.0])
        vec = Xector(source)
        assert len(vec) == 3
        expected = source.astype(np.float32)
        np.testing.assert_array_equal(vec.to_numpy(), expected)

    def test_implicit_add(self):
        """``+`` between two xectors adds element by element."""
        lhs = Xector([1, 2, 3])
        rhs = Xector([4, 5, 6])
        np.testing.assert_array_equal((lhs + rhs).to_numpy(), [5, 7, 9])

    def test_implicit_mul(self):
        """``*`` between two xectors multiplies element by element."""
        lhs = Xector([1, 2, 3])
        rhs = Xector([2, 2, 2])
        np.testing.assert_array_equal((lhs * rhs).to_numpy(), [2, 4, 6])

    def test_scalar_broadcast(self):
        """A scalar on the right of ``+`` is applied to every element."""
        shifted = Xector([1, 2, 3]) + 10
        np.testing.assert_array_equal(shifted.to_numpy(), [11, 12, 13])

    def test_scalar_broadcast_rmul(self):
        """A scalar on the left of ``*`` is applied to every element."""
        scaled = 2 * Xector([1, 2, 3])
        np.testing.assert_array_equal(scaled.to_numpy(), [2, 4, 6])
class TestAlphaOperations:
    """Checks for the α (element-wise) primitives."""

    def test_alpha_add(self):
        """``alpha_add`` of two xectors sums matching elements."""
        total = alpha_add(Xector([1, 2, 3]), Xector([4, 5, 6]))
        np.testing.assert_array_equal(total.to_numpy(), [5, 7, 9])

    def test_alpha_add_multi(self):
        """``alpha_add`` accepts more than two operands."""
        first = Xector([1, 2, 3])
        second = Xector([1, 1, 1])
        third = Xector([10, 10, 10])
        total = alpha_add(first, second, third)
        np.testing.assert_array_equal(total.to_numpy(), [12, 13, 14])

    def test_alpha_mul_scalar(self):
        """``alpha_mul`` with a scalar scales every element."""
        scaled = alpha_mul(Xector([1, 2, 3]), 2)
        np.testing.assert_array_equal(scaled.to_numpy(), [2, 4, 6])

    def test_alpha_sqrt(self):
        """``alpha_sqrt`` of perfect squares yields their integer roots."""
        roots = alpha_sqrt(Xector([1, 4, 9, 16]))
        np.testing.assert_array_equal(roots.to_numpy(), [1, 2, 3, 4])

    def test_alpha_clamp(self):
        """``alpha_clamp`` pins values to the given lower/upper bounds."""
        clamped = alpha_clamp(Xector([-5, 0, 5, 10, 15]), 0, 10)
        np.testing.assert_array_equal(clamped.to_numpy(), [0, 0, 5, 10, 10])

    def test_alpha_sin_cos(self):
        """``alpha_sin``/``alpha_cos`` at 0, π/2 and π (to 5 decimals)."""
        angles = Xector([0, np.pi/2, np.pi])
        sines = alpha_sin(angles)
        cosines = alpha_cos(angles)
        np.testing.assert_array_almost_equal(
            sines.to_numpy(), [0, 1, 0], decimal=5
        )
        np.testing.assert_array_almost_equal(
            cosines.to_numpy(), [1, 0, -1], decimal=5
        )

    def test_alpha_sq(self):
        """``alpha_sq`` squares every element."""
        squares = alpha_sq(Xector([1, 2, 3, 4]))
        np.testing.assert_array_equal(squares.to_numpy(), [1, 4, 9, 16])

    def test_alpha_comparison(self):
        """``alpha_lt``/``alpha_gt``/``alpha_eq`` give per-element booleans."""
        values = Xector([1, 2, 3, 4])
        pivot = Xector([2, 2, 2, 2])
        less = alpha_lt(values, pivot)
        greater = alpha_gt(values, pivot)
        equal = alpha_eq(values, pivot)
        np.testing.assert_array_equal(
            less.to_numpy(), [True, False, False, False]
        )
        np.testing.assert_array_equal(
            greater.to_numpy(), [False, False, True, True]
        )
        np.testing.assert_array_equal(
            equal.to_numpy(), [False, True, False, False]
        )
class TestBetaOperations:
    """Checks for the β (reduction) primitives."""

    def test_beta_add(self):
        """``beta_add`` reduces a xector to the sum of its elements."""
        total = beta_add(Xector([1, 2, 3, 4]))
        assert total == 10

    def test_beta_mul(self):
        """``beta_mul`` reduces a xector to the product of its elements."""
        product = beta_mul(Xector([1, 2, 3, 4]))
        assert product == 24

    def test_beta_min_max(self):
        """``beta_min``/``beta_max`` return the smallest/largest element."""
        values = Xector([3, 1, 4, 1, 5, 9, 2, 6])
        smallest = beta_min(values)
        assert smallest == 1
        largest = beta_max(values)
        assert largest == 9

    def test_beta_mean(self):
        """``beta_mean`` returns the arithmetic mean of the elements."""
        average = beta_mean(Xector([1, 2, 3, 4]))
        assert average == 2.5

    def test_beta_count(self):
        """``beta_count`` returns the number of elements."""
        size = beta_count(Xector([1, 2, 3, 4, 5]))
        assert size == 5
class TestFrameConversion:
    """Splitting image frames into channel xectors and rebuilding them."""

    def test_extract_channels(self):
        """Each channel extractor returns one value per pixel, row by row."""
        # 2x2 RGB frame: pure red, pure green, pure blue, mid grey.
        pixels = [
            [[255, 0, 0], [0, 255, 0]],
            [[0, 0, 255], [128, 128, 128]]
        ]
        frame = np.array(pixels, dtype=np.uint8)
        red, green, blue = xector_red(frame), xector_green(frame), xector_blue(frame)

        assert len(red) == 4
        np.testing.assert_array_equal(red.to_numpy(), [255, 0, 0, 128])
        np.testing.assert_array_equal(green.to_numpy(), [0, 255, 0, 128])
        np.testing.assert_array_equal(blue.to_numpy(), [0, 0, 255, 128])

    def test_rgb_roundtrip(self):
        """Splitting a frame and recombining with ``xector_rgb`` is lossless."""
        # 2x2 RGB frame with distinct values in every channel.
        pixels = [
            [[100, 150, 200], [50, 75, 100]],
            [[200, 100, 50], [25, 50, 75]]
        ]
        frame = np.array(pixels, dtype=np.uint8)
        red, green, blue = xector_red(frame), xector_green(frame), xector_blue(frame)

        rebuilt = xector_rgb(red, green, blue)
        np.testing.assert_array_equal(rebuilt, frame)

    def test_modify_and_reconstruct(self):
        """Doubling the red xector changes only red in the rebuilt frame."""
        pixels = [
            [[100, 100, 100], [100, 100, 100]],
            [[100, 100, 100], [100, 100, 100]]
        ]
        frame = np.array(pixels, dtype=np.uint8)
        red, green, blue = xector_red(frame), xector_green(frame), xector_blue(frame)

        # Scale only the red channel before recombining.
        boosted_red = red * 2
        rebuilt = xector_rgb(boosted_red, green, blue)

        # Top-left pixel: red is now 200, green and blue stay at 100.
        assert rebuilt[0, 0, 0] == 200
        assert rebuilt[0, 0, 1] == 100
        assert rebuilt[0, 0, 2] == 100
class TestCoordinates:
    """Per-pixel coordinate xectors derived from a frame's shape."""

    def test_x_coords(self):
        """Column indices repeat for every row of the frame."""
        blank = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        columns = xector_x_coords(blank)
        # Expect 0,1,2 for the first row, then 0,1,2 again for the second.
        np.testing.assert_array_equal(columns.to_numpy(), [0, 1, 2, 0, 1, 2])

    def test_y_coords(self):
        """Row indices are constant along each row of the frame."""
        blank = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        rows = xector_y_coords(blank)
        # Expect three 0s for the first row, then three 1s for the second.
        np.testing.assert_array_equal(rows.to_numpy(), [0, 0, 0, 1, 1, 1])

    def test_normalized_coords(self):
        """Normalized coordinates span 0 to 1 across width and height."""
        blank = np.zeros((2, 3, 3), dtype=np.uint8)
        x_values = xector_x_norm(blank).to_numpy()
        y_values = xector_y_norm(blank).to_numpy()

        # First and last column of the first row.
        assert x_values[0] == 0
        assert x_values[2] == 1

        # First pixel of the first row and first pixel of the second row.
        assert y_values[0] == 0
        assert y_values[3] == 1
class TestConditional:
    """Test conditional selection and constant-valued xector constructors."""

    def test_where(self):
        """``xector_where`` picks from the second or third xector per element."""
        cond = Xector([True, False, True, False])
        true_val = Xector([1, 1, 1, 1])
        false_val = Xector([0, 0, 0, 0])
        result = xector_where(cond, true_val, false_val)
        np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])

    def test_where_with_comparison(self):
        """``xector_where`` accepts a comparison mask and scalar branches."""
        a = Xector([1, 5, 3, 7])
        threshold = 4
        # Elements > 4 become 255, others become 0
        result = xector_where(alpha_gt(a, threshold), 255, 0)
        np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])

    def test_fill(self):
        """``xector_fill`` yields one copy of the value per frame pixel."""
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        x = xector_fill(42, frame)
        assert len(x) == 6
        # Array assertion instead of a Python-level all(): it reports the
        # mismatching elements on failure.
        np.testing.assert_array_equal(x.to_numpy(), [42] * 6)

    def test_zeros_ones(self):
        """``xector_zeros``/``xector_ones`` yield constant xectors for a frame."""
        frame = np.zeros((2, 2, 3), dtype=np.uint8)
        z = xector_zeros(frame)
        o = xector_ones(frame)
        # Pin the length explicitly: an ``all(...)`` over an empty xector
        # would pass vacuously. One element per pixel (2x2 = 4) mirrors what
        # test_fill asserts for xector_fill.
        assert len(z) == 4
        assert len(o) == 4
        np.testing.assert_array_equal(z.to_numpy(), [0] * 4)
        np.testing.assert_array_equal(o.to_numpy(), [1] * 4)
class TestInterpreterIntegration:
    """End-to-end checks that drive xector primitives via the interpreter."""

    def test_xector_vignette_effect(self):
        """The vignette effect leaves the centre brighter than the corners."""
        from sexp_effects.interpreter import Interpreter

        engine = Interpreter(minimal_primitives=True)
        # NOTE(review): the effect path is relative, so this presumably
        # expects to be run from the project root - confirm.
        engine.load_effect('sexp_effects/effects/xector_vignette.sexp')

        # Uniform white 100x100 input frame.
        white = np.full((100, 100, 3), 255, dtype=np.uint8)
        params = {'strength': 0.5}
        output, _state = engine.run_effect('xector_vignette', white, params, {})

        centre_px = output[50, 50]
        corner_px = output[0, 0]
        # The centre pixel must end up brighter than the top-left corner.
        assert centre_px.mean() > corner_px.mean(), "Center should be brighter than corners"
        # The corner must have lost brightness relative to pure white.
        assert corner_px.mean() < 255, "Corners should be darkened"

    def test_implicit_elementwise(self):
        """The interpreter's ``+`` works element-wise on a xector operand."""
        from sexp_effects.interpreter import Interpreter
        from sexp_effects.primitive_libs.xector import PRIMITIVES
        from sexp_effects.parser import parse

        engine = Interpreter(minimal_primitives=True)
        env = engine.global_env

        # Register every xector primitive in the global environment.
        for prim_name, prim_fn in PRIMITIVES.items():
            env.set(prim_name, prim_fn)

        # 2x2 frame with every channel at 100, bound to the name `frame`.
        env.set('frame', np.full((2, 2, 3), 100, dtype=np.uint8))

        # Red channel plus 10, evaluated through the interpreter.
        program = parse('(+ (red frame) 10)')
        value = engine.eval(program)

        # The result is a xector holding 110 for each of the four pixels.
        assert is_xector(value)
        np.testing.assert_array_equal(value.to_numpy(), [110, 110, 110, 110])
if __name__ == '__main__':
    # Allow running this file directly. pytest.main returns the run's exit
    # code; raise it via SystemExit so a failing run yields a non-zero
    # process status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))