"""
Tests for xector primitives - parallel array operations.
"""

from pathlib import Path

import numpy as np
import pytest

from sexp_effects.primitive_libs.xector import (
    Xector,
    xector_red, xector_green, xector_blue, xector_rgb,
    xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
    xector_dist_from_center,
    alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
    alpha_sin, alpha_cos, alpha_sq,
    alpha_lt, alpha_gt, alpha_eq,
    beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
    xector_where, xector_fill, xector_zeros, xector_ones,
    is_xector,
)


def _effect_path(filename):
    """Return a path to *filename* inside the ``sexp_effects/effects`` directory.

    The directories of the imported ``sexp_effects`` package are searched
    first, so the result does not depend on the current working directory.
    If the file is not found there, the historical repo-root-relative path
    is returned unchanged.
    """
    import sexp_effects

    # __path__ exists for both regular and namespace packages.
    for package_dir in sexp_effects.__path__:
        candidate = Path(package_dir) / 'effects' / filename
        if candidate.is_file():
            return str(candidate)
    return f'sexp_effects/effects/{filename}'


class TestXectorBasics:
    """Test Xector class basic operations."""

    def test_create_from_list(self):
        x = Xector([1, 2, 3])
        assert len(x) == 3
        assert is_xector(x)

    def test_create_from_numpy(self):
        arr = np.array([1.0, 2.0, 3.0])
        x = Xector(arr)
        assert len(x) == 3
        np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))

    def test_implicit_add(self):
        a = Xector([1, 2, 3])
        b = Xector([4, 5, 6])
        c = a + b
        np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])

    def test_implicit_mul(self):
        a = Xector([1, 2, 3])
        b = Xector([2, 2, 2])
        c = a * b
        np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])

    def test_scalar_broadcast(self):
        a = Xector([1, 2, 3])
        c = a + 10
        np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])

    def test_scalar_broadcast_rmul(self):
        # Scalar on the left exercises the reflected operator path.
        a = Xector([1, 2, 3])
        c = 2 * a
        np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])


class TestAlphaOperations:
    """Test α (element-wise) operations."""

    def test_alpha_add(self):
        a = Xector([1, 2, 3])
        b = Xector([4, 5, 6])
        c = alpha_add(a, b)
        np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])

    def test_alpha_add_multi(self):
        # alpha_add is variadic: all operands are summed element-wise.
        a = Xector([1, 2, 3])
        b = Xector([1, 1, 1])
        c = Xector([10, 10, 10])
        d = alpha_add(a, b, c)
        np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])

    def test_alpha_mul_scalar(self):
        a = Xector([1, 2, 3])
        c = alpha_mul(a, 2)
        np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])

    def test_alpha_sqrt(self):
        a = Xector([1, 4, 9, 16])
        c = alpha_sqrt(a)
        np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])

    def test_alpha_clamp(self):
        a = Xector([-5, 0, 5, 10, 15])
        c = alpha_clamp(a, 0, 10)
        np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])

    def test_alpha_sin_cos(self):
        a = Xector([0, np.pi/2, np.pi])
        s = alpha_sin(a)
        c = alpha_cos(a)
        # Approximate comparison: trig results at these points are not exact.
        np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
        np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)

    def test_alpha_sq(self):
        a = Xector([1, 2, 3, 4])
        c = alpha_sq(a)
        np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])

    def test_alpha_comparison(self):
        a = Xector([1, 2, 3, 4])
        b = Xector([2, 2, 2, 2])
        lt = alpha_lt(a, b)
        gt = alpha_gt(a, b)
        eq = alpha_eq(a, b)
        np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
        np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
        np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])


class TestBetaOperations:
    """Test β (reduction) operations."""

    def test_beta_add(self):
        a = Xector([1, 2, 3, 4])
        assert beta_add(a) == 10

    def test_beta_mul(self):
        a = Xector([1, 2, 3, 4])
        assert beta_mul(a) == 24

    def test_beta_min_max(self):
        a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
        assert beta_min(a) == 1
        assert beta_max(a) == 9

    def test_beta_mean(self):
        a = Xector([1, 2, 3, 4])
        assert beta_mean(a) == 2.5

    def test_beta_count(self):
        a = Xector([1, 2, 3, 4, 5])
        assert beta_count(a) == 5


class TestFrameConversion:
    """Test frame/xector conversion."""

    def test_extract_channels(self):
        # Create a 2x2 RGB frame
        frame = np.array([
            [[255, 0, 0], [0, 255, 0]],
            [[0, 0, 255], [128, 128, 128]],
        ], dtype=np.uint8)

        r = xector_red(frame)
        g = xector_green(frame)
        b = xector_blue(frame)

        # One element per pixel, in row-major order.
        assert len(r) == 4
        np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
        np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
        np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])

    def test_rgb_roundtrip(self):
        # Create a 2x2 RGB frame
        frame = np.array([
            [[100, 150, 200], [50, 75, 100]],
            [[200, 100, 50], [25, 50, 75]],
        ], dtype=np.uint8)

        r = xector_red(frame)
        g = xector_green(frame)
        b = xector_blue(frame)

        reconstructed = xector_rgb(r, g, b)
        np.testing.assert_array_equal(reconstructed, frame)

    def test_modify_and_reconstruct(self):
        frame = np.array([
            [[100, 100, 100], [100, 100, 100]],
            [[100, 100, 100], [100, 100, 100]],
        ], dtype=np.uint8)

        r = xector_red(frame)
        g = xector_green(frame)
        b = xector_blue(frame)

        # Double red channel
        r_doubled = r * 2

        result = xector_rgb(r_doubled, g, b)

        # Check the shape first so the channel checks below cannot pass
        # vacuously on an empty result.
        assert result.shape == frame.shape
        # Red should be 200 everywhere, others unchanged
        np.testing.assert_array_equal(result[..., 0], 200)
        np.testing.assert_array_equal(result[..., 1], 100)
        np.testing.assert_array_equal(result[..., 2], 100)


class TestCoordinates:
    """Test coordinate generation."""

    def test_x_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        x = xector_x_coords(frame)
        # Should be [0,1,2, 0,1,2] (x coords repeated for each row)
        np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])

    def test_y_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        y = xector_y_coords(frame)
        # Should be [0,0,0, 1,1,1] (y coords for each pixel)
        np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])

    def test_normalized_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        x = xector_x_norm(frame)
        y = xector_y_norm(frame)

        # x should go 0 to 1 across width (index 2 is the last column of row 0)
        assert x.to_numpy()[0] == 0
        assert x.to_numpy()[2] == 1
        # y should go 0 to 1 down height (index 3 is the first pixel of row 1)
        assert y.to_numpy()[0] == 0
        assert y.to_numpy()[3] == 1


class TestConditional:
    """Test conditional operations."""

    def test_where(self):
        cond = Xector([True, False, True, False])
        true_val = Xector([1, 1, 1, 1])
        false_val = Xector([0, 0, 0, 0])
        result = xector_where(cond, true_val, false_val)
        np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])

    def test_where_with_comparison(self):
        a = Xector([1, 5, 3, 7])
        threshold = 4
        # Elements > 4 become 255, others become 0
        result = xector_where(alpha_gt(a, threshold), 255, 0)
        np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])

    def test_fill(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        x = xector_fill(42, frame)
        assert len(x) == 6
        np.testing.assert_array_equal(x.to_numpy(), np.full(6, 42))

    def test_zeros_ones(self):
        frame = np.zeros((2, 2, 3), dtype=np.uint8)
        z = xector_zeros(frame)
        o = xector_ones(frame)
        # Pin the length: a bare all(...) would pass vacuously on an empty xector.
        assert len(z) == 4
        assert len(o) == 4
        np.testing.assert_array_equal(z.to_numpy(), np.zeros(4))
        np.testing.assert_array_equal(o.to_numpy(), np.ones(4))


class TestInterpreterIntegration:
    """Test xector operations through the interpreter."""

    def test_xector_vignette_effect(self):
        from sexp_effects.interpreter import Interpreter

        interp = Interpreter(minimal_primitives=True)

        # Load the xector vignette effect. Resolve the file via the package
        # location so the test does not depend on pytest's working directory.
        interp.load_effect(_effect_path('xector_vignette.sexp'))

        # Create a test frame (white) so any brightness change comes from the effect
        frame = np.full((100, 100, 3), 255, dtype=np.uint8)

        # Run effect
        result, _state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})

        # Center should be brighter than corners
        center = result[50, 50]
        corner = result[0, 0]

        assert center.mean() > corner.mean(), "Center should be brighter than corners"
        # Corners should be darkened
        assert corner.mean() < 255, "Corners should be darkened"

    def test_implicit_elementwise(self):
        """Test that regular + works element-wise on xectors."""
        from sexp_effects.interpreter import Interpreter
        from sexp_effects.parser import parse
        from sexp_effects.primitive_libs.xector import PRIMITIVES

        interp = Interpreter(minimal_primitives=True)

        # Load xector primitives
        for name, fn in PRIMITIVES.items():
            interp.global_env.set(name, fn)

        # Parse a simple xector expression
        expr = parse('(+ (red frame) 10)')

        # Create test frame
        frame = np.full((2, 2, 3), 100, dtype=np.uint8)
        interp.global_env.set('frame', frame)

        result = interp.eval(expr)

        # Should be a xector with values 110
        assert is_xector(result)
        np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])


if __name__ == '__main__':
    pytest.main([__file__, '-v'])