Import L1 (celery) as l1/
This commit is contained in:
305
l1/tests/test_xector.py
Normal file
305
l1/tests/test_xector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Tests for xector primitives - parallel array operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sexp_effects.primitive_libs.xector import (
|
||||
Xector,
|
||||
xector_red, xector_green, xector_blue, xector_rgb,
|
||||
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
|
||||
xector_dist_from_center,
|
||||
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
|
||||
alpha_sin, alpha_cos, alpha_sq,
|
||||
alpha_lt, alpha_gt, alpha_eq,
|
||||
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
|
||||
xector_where, xector_fill, xector_zeros, xector_ones,
|
||||
is_xector,
|
||||
)
|
||||
|
||||
|
||||
class TestXectorBasics:
|
||||
"""Test Xector class basic operations."""
|
||||
|
||||
def test_create_from_list(self):
|
||||
x = Xector([1, 2, 3])
|
||||
assert len(x) == 3
|
||||
assert is_xector(x)
|
||||
|
||||
def test_create_from_numpy(self):
|
||||
arr = np.array([1.0, 2.0, 3.0])
|
||||
x = Xector(arr)
|
||||
assert len(x) == 3
|
||||
np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))
|
||||
|
||||
def test_implicit_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = a + b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_implicit_mul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([2, 2, 2])
|
||||
c = a * b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_scalar_broadcast(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = a + 10
|
||||
np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])
|
||||
|
||||
def test_scalar_broadcast_rmul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = 2 * a
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
|
||||
class TestAlphaOperations:
|
||||
"""Test α (element-wise) operations."""
|
||||
|
||||
def test_alpha_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = alpha_add(a, b)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_alpha_add_multi(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([1, 1, 1])
|
||||
c = Xector([10, 10, 10])
|
||||
d = alpha_add(a, b, c)
|
||||
np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])
|
||||
|
||||
def test_alpha_mul_scalar(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = alpha_mul(a, 2)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_alpha_sqrt(self):
|
||||
a = Xector([1, 4, 9, 16])
|
||||
c = alpha_sqrt(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_alpha_clamp(self):
|
||||
a = Xector([-5, 0, 5, 10, 15])
|
||||
c = alpha_clamp(a, 0, 10)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])
|
||||
|
||||
def test_alpha_sin_cos(self):
|
||||
a = Xector([0, np.pi/2, np.pi])
|
||||
s = alpha_sin(a)
|
||||
c = alpha_cos(a)
|
||||
np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
|
||||
np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)
|
||||
|
||||
def test_alpha_sq(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
c = alpha_sq(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])
|
||||
|
||||
def test_alpha_comparison(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
b = Xector([2, 2, 2, 2])
|
||||
lt = alpha_lt(a, b)
|
||||
gt = alpha_gt(a, b)
|
||||
eq = alpha_eq(a, b)
|
||||
np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
|
||||
np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
|
||||
np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])
|
||||
|
||||
|
||||
class TestBetaOperations:
|
||||
"""Test β (reduction) operations."""
|
||||
|
||||
def test_beta_add(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_add(a) == 10
|
||||
|
||||
def test_beta_mul(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mul(a) == 24
|
||||
|
||||
def test_beta_min_max(self):
|
||||
a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
assert beta_min(a) == 1
|
||||
assert beta_max(a) == 9
|
||||
|
||||
def test_beta_mean(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mean(a) == 2.5
|
||||
|
||||
def test_beta_count(self):
|
||||
a = Xector([1, 2, 3, 4, 5])
|
||||
assert beta_count(a) == 5
|
||||
|
||||
|
||||
class TestFrameConversion:
|
||||
"""Test frame/xector conversion."""
|
||||
|
||||
def test_extract_channels(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[255, 0, 0], [0, 255, 0]],
|
||||
[[0, 0, 255], [128, 128, 128]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
assert len(r) == 4
|
||||
np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
|
||||
np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
|
||||
np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])
|
||||
|
||||
def test_rgb_roundtrip(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[100, 150, 200], [50, 75, 100]],
|
||||
[[200, 100, 50], [25, 50, 75]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
reconstructed = xector_rgb(r, g, b)
|
||||
np.testing.assert_array_equal(reconstructed, frame)
|
||||
|
||||
def test_modify_and_reconstruct(self):
|
||||
frame = np.array([
|
||||
[[100, 100, 100], [100, 100, 100]],
|
||||
[[100, 100, 100], [100, 100, 100]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
# Double red channel
|
||||
r_doubled = r * 2
|
||||
|
||||
result = xector_rgb(r_doubled, g, b)
|
||||
|
||||
# Red should be 200, others unchanged
|
||||
assert result[0, 0, 0] == 200
|
||||
assert result[0, 0, 1] == 100
|
||||
assert result[0, 0, 2] == 100
|
||||
|
||||
|
||||
class TestCoordinates:
|
||||
"""Test coordinate generation."""
|
||||
|
||||
def test_x_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
x = xector_x_coords(frame)
|
||||
# Should be [0,1,2, 0,1,2] (x coords repeated for each row)
|
||||
np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])
|
||||
|
||||
def test_y_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
y = xector_y_coords(frame)
|
||||
# Should be [0,0,0, 1,1,1] (y coords for each pixel)
|
||||
np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])
|
||||
|
||||
def test_normalized_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_x_norm(frame)
|
||||
y = xector_y_norm(frame)
|
||||
|
||||
# x should go 0 to 1 across width
|
||||
assert x.to_numpy()[0] == 0
|
||||
assert x.to_numpy()[2] == 1
|
||||
|
||||
# y should go 0 to 1 down height
|
||||
assert y.to_numpy()[0] == 0
|
||||
assert y.to_numpy()[3] == 1
|
||||
|
||||
|
||||
class TestConditional:
|
||||
"""Test conditional operations."""
|
||||
|
||||
def test_where(self):
|
||||
cond = Xector([True, False, True, False])
|
||||
true_val = Xector([1, 1, 1, 1])
|
||||
false_val = Xector([0, 0, 0, 0])
|
||||
|
||||
result = xector_where(cond, true_val, false_val)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])
|
||||
|
||||
def test_where_with_comparison(self):
|
||||
a = Xector([1, 5, 3, 7])
|
||||
threshold = 4
|
||||
|
||||
# Elements > 4 become 255, others become 0
|
||||
result = xector_where(alpha_gt(a, threshold), 255, 0)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])
|
||||
|
||||
def test_fill(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_fill(42, frame)
|
||||
assert len(x) == 6
|
||||
assert all(v == 42 for v in x.to_numpy())
|
||||
|
||||
def test_zeros_ones(self):
|
||||
frame = np.zeros((2, 2, 3), dtype=np.uint8)
|
||||
z = xector_zeros(frame)
|
||||
o = xector_ones(frame)
|
||||
|
||||
assert all(v == 0 for v in z.to_numpy())
|
||||
assert all(v == 1 for v in o.to_numpy())
|
||||
|
||||
|
||||
class TestInterpreterIntegration:
|
||||
"""Test xector operations through the interpreter."""
|
||||
|
||||
def test_xector_vignette_effect(self):
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
|
||||
# Load the xector vignette effect
|
||||
interp.load_effect('sexp_effects/effects/xector_vignette.sexp')
|
||||
|
||||
# Create a test frame (white)
|
||||
frame = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||
|
||||
# Run effect
|
||||
result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})
|
||||
|
||||
# Center should be brighter than corners
|
||||
center = result[50, 50]
|
||||
corner = result[0, 0]
|
||||
|
||||
assert center.mean() > corner.mean(), "Center should be brighter than corners"
|
||||
# Corners should be darkened
|
||||
assert corner.mean() < 255, "Corners should be darkened"
|
||||
|
||||
def test_implicit_elementwise(self):
|
||||
"""Test that regular + works element-wise on xectors."""
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
# Load xector primitives
|
||||
from sexp_effects.primitive_libs.xector import PRIMITIVES
|
||||
for name, fn in PRIMITIVES.items():
|
||||
interp.global_env.set(name, fn)
|
||||
|
||||
# Parse and eval a simple xector expression
|
||||
from sexp_effects.parser import parse
|
||||
expr = parse('(+ (red frame) 10)')
|
||||
|
||||
# Create test frame
|
||||
frame = np.full((2, 2, 3), 100, dtype=np.uint8)
|
||||
interp.global_env.set('frame', frame)
|
||||
|
||||
result = interp.eval(expr)
|
||||
|
||||
# Should be a xector with values 110
|
||||
assert is_xector(result)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user