#!/usr/bin/env python3
"""
Tests for typography FX: gradients, rotation, shadow, and combined effects.

Runnable under pytest (each ``test_*`` function raises on failure and
returns ``None``) or standalone via ``python <this file>``, which runs every
test, prints a PASS/FAIL line per test plus a summary, and exits non-zero if
any test failed.
"""

import os
import sys
import tempfile
import traceback

import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image

from streaming.jax_typography import (
    render_text_strip,
    place_text_strip_jax,
    _load_font,
    make_linear_gradient,
    make_radial_gradient,
    make_multi_stop_gradient,
    place_text_strip_gradient_jax,
    rotate_strip_jax,
    place_text_strip_shadow_jax,
    place_text_strip_fx_jax,
    bind_typography_primitives,
)

# Fill value used by make_frame(); a pixel below this level can only have been
# darkened by a (black) shadow.
BACKGROUND_LEVEL = 40


def make_frame(w=400, h=200):
    """Create a dark gray ``(h, w, 3)`` uint8 test frame."""
    return jnp.full((h, w, 3), BACKGROUND_LEVEL, dtype=jnp.uint8)


def get_strip(text="Hello", font_size=48):
    """Get a pre-rendered text strip (default font)."""
    return render_text_strip(text, None, font_size)


def has_visible_pixels(frame, threshold=50):
    """Return True if any pixel value in *frame* exceeds *threshold*."""
    return int(frame.max()) > threshold


def max_abs_diff(a, b):
    """Return the largest per-element absolute difference between two images."""
    # Widen to int16 so uint8 subtraction cannot wrap around.
    a16 = np.asarray(a).astype(np.int16)
    b16 = np.asarray(b).astype(np.int16)
    return int(np.abs(a16 - b16).max())


def save_debug(name, frame):
    """Save *frame* as ``fx_<name>.png`` in the system temp dir for inspection.

    Accepts a NumPy or JAX array. Returns the path written.
    """
    arr = np.asarray(frame)
    path = os.path.join(tempfile.gettempdir(), f"fx_{name}.png")
    Image.fromarray(arr).save(path)
    return path


# =============================================================================
# Gradient Tests
# =============================================================================

def test_linear_gradient_shape():
    grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255))
    assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}"
    assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}"
    # Left edge should be red-ish, right edge blue-ish
    assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}"
    assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}"


def test_linear_gradient_angle():
    # 90 degrees: top-to-bottom
    grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0)
    # Top row should be red, bottom row should be blue
    assert grad[0, 50, 0] > 0.8, "Top should be red"
    assert grad[-1, 50, 2] > 0.8, "Bottom should be blue"


def test_radial_gradient_shape():
    grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128))
    assert grad.shape == (100, 100, 3)
    # Center should be yellow (color1)
    assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)"
    assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)"
    # Corner should be closer to dark blue (color2)
    assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue"


def test_multi_stop_gradient():
    stops = [
        (0.0, (255, 0, 0)),
        (0.5, (0, 255, 0)),
        (1.0, (0, 0, 255)),
    ]
    grad = make_multi_stop_gradient(100, 10, stops)
    assert grad.shape == (10, 100, 3)
    # Left: red, Middle: green, Right: blue
    assert grad[5, 0, 0] > 0.8, "Left should be red"
    assert grad[5, 50, 1] > 0.8, "Middle should be green"
    assert grad[5, -1, 2] > 0.8, "Right should be blue"


def test_place_gradient():
    """Test gradient text rendering produces visible output."""
    frame = make_frame()
    strip = get_strip()
    grad = make_linear_gradient(strip.width, strip.height, (255, 0, 0), (0, 0, 255))
    grad_jax = jnp.asarray(grad)
    strip_img = jnp.asarray(strip.image)
    result = place_text_strip_gradient_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        grad_jax, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )
    assert result.shape == frame.shape
    # Should have visible colored pixels
    assert max_abs_diff(result, frame) > 50, "Gradient text should be visible"
    save_debug("gradient", result)


# =============================================================================
# Rotation Tests
# =============================================================================

def test_rotate_strip_identity():
    """Rotation by 0 degrees should preserve content."""
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    rotated = rotate_strip_jax(strip_img, 0.0)
    # Output is larger (diagonal size)
    assert rotated.shape[2] == 4, "Should be RGBA"
    assert rotated.shape[0] >= strip.height
    assert rotated.shape[1] >= strip.width
    # Alpha should have non-zero pixels (text was preserved)
    assert rotated[:, :, 3].max() > 200, "Should have visible alpha"


def test_rotate_strip_90():
    """Rotation by 90 degrees."""
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    rotated = rotate_strip_jax(strip_img, 90.0)
    assert rotated.shape[2] == 4
    # Should still have visible content
    assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha"
    save_debug("rotated_90", rotated)


def test_rotate_360_exact():
    """360-degree rotation must be pixel-exact (regression test for trig snapping)."""
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    sh, sw = strip.height, strip.width
    rotated = rotate_strip_jax(strip_img, 360.0)
    # The rotated canvas is enlarged; the original content sits centered in it.
    rh, rw = rotated.shape[:2]
    off_y = (rh - sh) // 2
    off_x = (rw - sw) // 2
    crop = rotated[off_y:off_y + sh, off_x:off_x + sw]
    max_diff = max_abs_diff(crop, strip_img)
    assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"


def test_place_rotated():
    """Test rotated text placement produces visible output."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 0], dtype=jnp.float32)
    result = place_text_strip_fx_jax(
        frame, strip_img, 200.0, 100.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        color=color, opacity=1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        angle=30.0,
    )
    assert result.shape == frame.shape
    assert max_abs_diff(result, frame) > 50, "Rotated text should be visible"
    save_debug("rotated_30", result)


# =============================================================================
# Shadow Tests
# =============================================================================

def test_shadow_basic():
    """Test shadow produces visible offset copy."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=5.0, shadow_offset_y=5.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.8,
    )
    assert result.shape == frame.shape
    # Should have both bright (text) and dark (shadow) pixels
    assert result.max() > 200, "Should have bright text"
    # White text only brightens; anything below the background must be shadow.
    assert int(result.min()) < BACKGROUND_LEVEL, "Should have dark shadow"
    save_debug("shadow_basic", result)


def test_shadow_blur():
    """Test blurred shadow."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=4.0, shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.7,
        shadow_blur_radius=3,
    )
    assert result.shape == frame.shape
    assert result.max() > 200, "Should have bright text"
    assert int(result.min()) < BACKGROUND_LEVEL, "Blurred shadow should darken the frame"
    save_debug("shadow_blur", result)


# =============================================================================
# Combined FX Tests
# =============================================================================

def test_fx_combined():
    """Test combined gradient + shadow + rotation."""
    frame = make_frame(500, 300)
    strip = get_strip("FX Test", 64)
    strip_img = jnp.asarray(strip.image)
    grad = make_linear_gradient(strip.width, strip.height, (255, 100, 0), (0, 100, 255))
    grad_jax = jnp.asarray(grad)
    result = place_text_strip_fx_jax(
        frame, strip_img, 250.0, 150.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        gradient_map=grad_jax,
        angle=15.0,
        shadow_offset_x=4.0, shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.6,
        shadow_blur_radius=2,
    )
    assert result.shape == frame.shape
    assert max_abs_diff(result, frame) > 50, "Combined FX should produce visible output"
    save_debug("fx_combined", result)


def test_fx_no_effects():
    """FX function with no effects should match basic place_text_strip_jax."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)

    # Using FX function with defaults
    result_fx = place_text_strip_fx_jax(
        frame, strip_img, 50.0, 100.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        color=color, opacity=1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )

    # Using original function
    result_orig = place_text_strip_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )

    max_diff = max_abs_diff(result_fx, result_orig)
    assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"


# =============================================================================
# S-Expression Binding Tests
# =============================================================================

def test_sexp_bindings():
    """Test that all new primitives are registered."""
    env = {}
    bind_typography_primitives(env)

    expected = [
        'linear-gradient',
        'radial-gradient',
        'multi-stop-gradient',
        'place-text-strip-gradient',
        'place-text-strip-rotated',
        'place-text-strip-shadow',
        'place-text-strip-fx',
    ]
    for name in expected:
        assert name in env, f"Missing binding: {name}"


def test_sexp_gradient_primitive():
    """Test gradient primitive via binding."""
    env = {}
    bind_typography_primitives(env)

    strip = env['render-text-strip']("Test", 36)
    grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))

    assert grad.shape == (strip.height, strip.width, 3)


def test_sexp_fx_primitive():
    """Test combined FX primitive via binding."""
    env = {}
    bind_typography_primitives(env)

    strip = env['render-text-strip']("FX", 36)
    frame = make_frame()

    result = env['place-text-strip-fx'](
        frame, strip, 100.0, 80.0,
        color=(255, 200, 0), opacity=0.9,
        shadow_offset_x=3, shadow_offset_y=3,
        shadow_opacity=0.5,
    )
    assert result.shape == frame.shape


# =============================================================================
# JIT Compilation Test
# =============================================================================

def test_jit_fx():
    """Test that place_text_strip_fx_jax can be JIT compiled."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)

    # Only frame/x/y/opacity are traced arguments. The strip metrics, shadow
    # offsets and blur radius are closed-over Python constants, so they are
    # baked into the compiled function rather than passed via static_argnums.
    @jax.jit
    def render(frame, x, y, opacity):
        return place_text_strip_fx_jax(
            frame, strip_img, x, y,
            baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
            color=color, opacity=opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            shadow_offset_x=3.0, shadow_offset_y=3.0,
            shadow_color=shadow_color,
            shadow_opacity=0.5,
            shadow_blur_radius=2,
        )

    # First call traces, second uses cache
    result1 = render(frame, 50.0, 100.0, 1.0)
    result2 = render(frame, 60.0, 90.0, 0.8)
    assert result1.shape == frame.shape
    assert result2.shape == frame.shape


def test_jit_gradient():
    """Test that gradient placement can be JIT compiled."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    grad = jnp.asarray(make_linear_gradient(strip.width, strip.height, (255, 0, 0), (0, 0, 255)))

    @jax.jit
    def render(frame, x, y):
        return place_text_strip_gradient_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            grad, 1.0,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        )

    result = render(frame, 50.0, 100.0)
    assert result.shape == frame.shape


# =============================================================================
# Main
# =============================================================================

def main():
    """Run every test standalone and print a PASS/FAIL line for each.

    A test passes when it returns without raising. Returns True if all
    tests passed, False otherwise.
    """
    print("=" * 60)
    print("Typography FX Tests")
    print("=" * 60)

    tests = [
        # Gradients
        test_linear_gradient_shape,
        test_linear_gradient_angle,
        test_radial_gradient_shape,
        test_multi_stop_gradient,
        test_place_gradient,
        # Rotation
        test_rotate_strip_identity,
        test_rotate_strip_90,
        test_rotate_360_exact,
        test_place_rotated,
        # Shadow
        test_shadow_basic,
        test_shadow_blur,
        # Combined FX
        test_fx_combined,
        test_fx_no_effects,
        # S-expression bindings
        test_sexp_bindings,
        test_sexp_gradient_primitive,
        test_sexp_fx_primitive,
        # JIT compilation
        test_jit_fx,
        test_jit_gradient,
    ]

    failures = []
    for test in tests:
        try:
            test()
        except Exception as exc:
            print(f"FAIL: {test.__name__}: {exc}")
            traceback.print_exc()
            failures.append(test.__name__)
        else:
            print(f"PASS: {test.__name__}")

    print("=" * 60)
    total = len(tests)
    passed = total - len(failures)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print("=" * 60)
    return passed == total


if __name__ == "__main__":
    sys.exit(0 if main() else 1)