Files
rose-ash/l1/test_typography_fx.py
2026-02-24 23:07:19 +00:00

487 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Tests for typography FX: gradients, rotation, shadow, and combined effects.
"""
import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image
from streaming.jax_typography import (
render_text_strip, place_text_strip_jax, _load_font,
make_linear_gradient, make_radial_gradient, make_multi_stop_gradient,
place_text_strip_gradient_jax, rotate_strip_jax,
place_text_strip_shadow_jax, place_text_strip_fx_jax,
bind_typography_primitives,
)
def make_frame(w=400, h=200):
    """Return an (h, w, 3) uint8 JAX frame filled with dark gray (value 40)."""
    shape = (h, w, 3)
    return jnp.full(shape, 40, dtype=jnp.uint8)
def get_strip(text="Hello", font_size=48):
    """Render *text* at *font_size* and return the resulting text strip.

    The font argument is passed as None; presumably render_text_strip then
    falls back to its default font (not visible here — confirm).
    """
    font = None
    return render_text_strip(text, font, font_size)
def has_visible_pixels(frame, threshold=50):
    """Return True when the brightest value anywhere in *frame* exceeds *threshold*."""
    brightest = int(frame.max())
    return brightest > threshold
def save_debug(name, frame):
    """Save *frame* as ``fx_<name>.png`` in the system temp directory.

    The directory comes from ``tempfile.gettempdir()`` (honours TMPDIR/TEMP),
    so this also works on hosts without ``/tmp``.

    Args:
        name: Suffix used in the file name.
        frame: NumPy or JAX array that PIL can convert (e.g. HxWx3 / HxWx4 uint8).

    Returns:
        The path the image was written to.
    """
    import os
    import tempfile

    # np.asarray is a no-op for ndarrays and converts JAX arrays to host memory.
    arr = np.asarray(frame)
    path = os.path.join(tempfile.gettempdir(), f"fx_{name}.png")
    Image.fromarray(arr).save(path)
    return path
# =============================================================================
# Gradient Tests
# =============================================================================
def test_linear_gradient_shape():
    """Default linear gradient: (h, w, 3) shape, float dtype, red left / blue right."""
    grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255))
    assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}"
    # issubdtype accepts any floating dtype, which is what the message promises,
    # instead of relying on dtype == scalar-type coercion for two specific widths.
    assert np.issubdtype(grad.dtype, np.floating), f"Expected float, got {grad.dtype}"
    # Left edge should be red-ish, right edge blue-ish
    assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}"
    assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}"
    print("PASS: test_linear_gradient_shape")
    return True
def test_linear_gradient_angle():
    """At angle=90 the linear gradient runs from color1 at the top to color2 at the bottom."""
    red, blue = (255, 0, 0), (0, 0, 255)
    grad = make_linear_gradient(100, 100, red, blue, angle=90.0)
    column = 50
    top_red = grad[0, column, 0]
    bottom_blue = grad[-1, column, 2]
    assert top_red > 0.8, "Top should be red"
    assert bottom_blue > 0.8, "Bottom should be blue"
    print("PASS: test_linear_gradient_angle")
    return True
def test_radial_gradient_shape():
    """Radial gradient: square shape, color1 at the centre, more color2 toward a corner."""
    yellow, navy = (255, 255, 0), (0, 0, 128)
    grad = make_radial_gradient(100, 100, yellow, navy)
    assert grad.shape == (100, 100, 3)
    center = grad[50, 50]
    corner = grad[0, 0]
    # The centre pixel carries color1 (yellow: high R and G).
    assert center[0] > 0.9, "Center should be yellow (R)"
    assert center[1] > 0.9, "Center should be yellow (G)"
    # The corner is pulled toward color2 (dark blue).
    assert corner[2] > center[2], "Corner should have more blue"
    print("PASS: test_radial_gradient_shape")
    return True
def test_multi_stop_gradient():
    """Three stops at 0, 0.5 and 1 produce red, green and blue across one row."""
    positions = (0.0, 0.5, 1.0)
    colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
    stops = list(zip(positions, colors))
    grad = make_multi_stop_gradient(100, 10, stops)
    assert grad.shape == (10, 100, 3)
    row = grad[5]
    assert row[0, 0] > 0.8, "Left should be red"
    assert row[50, 1] > 0.8, "Middle should be green"
    assert row[-1, 2] > 0.8, "Right should be blue"
    print("PASS: test_multi_stop_gradient")
    return True
def test_place_gradient():
    """Placing a strip with a linear-gradient fill must visibly change the frame."""
    frame = make_frame()
    strip = get_strip()
    gradient = jnp.asarray(
        make_linear_gradient(strip.width, strip.height, (255, 0, 0), (0, 0, 255))
    )
    result = place_text_strip_gradient_jax(
        frame,
        jnp.asarray(strip.image),
        50.0,
        100.0,
        strip.baseline_y,
        strip.bearing_x,
        gradient,
        1.0,
        anchor_x=strip.anchor_x,
        anchor_y=strip.anchor_y,
    )
    assert result.shape == frame.shape
    # Signed int16 avoids uint8 wrap-around when subtracting.
    delta = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert delta.max() > 50, "Gradient text should be visible"
    save_debug("gradient", result)
    print("PASS: test_place_gradient")
    return True
# =============================================================================
# Rotation Tests
# =============================================================================
def test_rotate_strip_identity():
    """Rotation by 0 degrees should preserve content.

    The rotated canvas is larger than the input (sized for the diagonal), so
    the input is compared with the centred crop of the output, using the same
    offsets as test_rotate_360_exact.
    """
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    sh, sw = strip.height, strip.width
    rotated = rotate_strip_jax(strip_img, 0.0)
    # Output is larger (diagonal size)
    assert rotated.shape[2] == 4, "Should be RGBA"
    assert rotated.shape[0] >= sh
    assert rotated.shape[1] >= sw
    # Alpha should have non-zero pixels (text was preserved)
    assert rotated[:, :, 3].max() > 200, "Should have visible alpha"
    # Actually verify preservation: the centred window must equal the input.
    rh, rw = rotated.shape[:2]
    off_y, off_x = (rh - sh) // 2, (rw - sw) // 2
    crop = np.asarray(rotated[off_y:off_y + sh, off_x:off_x + sw]).astype(np.int16)
    orig = np.asarray(strip_img).astype(np.int16)
    max_diff = int(np.abs(crop - orig).max())
    assert max_diff == 0, f"0° rotation should preserve content, max_diff={max_diff}"
    print("PASS: test_rotate_strip_identity")
    return True
def test_rotate_strip_90():
    """A 90-degree rotation keeps the RGBA layout and still shows opaque glyph pixels."""
    strip = get_strip()
    rotated = rotate_strip_jax(jnp.asarray(strip.image), 90.0)
    assert rotated.shape[2] == 4
    alpha = rotated[..., 3]
    assert alpha.max() > 200, "Rotated strip should have visible alpha"
    save_debug("rotated_90", np.array(rotated))
    print("PASS: test_rotate_strip_90")
    return True
def test_rotate_360_exact():
    """A full-turn (360°) rotation must reproduce the strip bit-for-bit (trig-snapping regression)."""
    strip = get_strip()
    source = jnp.asarray(strip.image)
    height, width = strip.height, strip.width
    rotated = rotate_strip_jax(source, 360.0)
    # The output canvas may be larger; compare the centred input-sized window.
    top = (rotated.shape[0] - height) // 2
    left = (rotated.shape[1] - width) // 2
    window = np.array(rotated[top:top + height, left:left + width]).astype(np.int16)
    reference = np.array(source).astype(np.int16)
    max_diff = int(np.abs(window - reference).max())
    assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"
    print("PASS: test_rotate_360_exact")
    return True
def test_place_rotated():
    """Drawing yellow text rotated by 30° through the FX entry point must alter the frame."""
    frame = make_frame()
    strip = get_strip()
    yellow = jnp.array([255, 255, 0], dtype=jnp.float32)
    result = place_text_strip_fx_jax(
        frame,
        jnp.asarray(strip.image),
        200.0,
        100.0,
        baseline_y=strip.baseline_y,
        bearing_x=strip.bearing_x,
        color=yellow,
        opacity=1.0,
        anchor_x=strip.anchor_x,
        anchor_y=strip.anchor_y,
        angle=30.0,
    )
    assert result.shape == frame.shape
    delta = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert delta.max() > 50, "Rotated text should be visible"
    save_debug("rotated_30", result)
    print("PASS: test_place_rotated")
    return True
# =============================================================================
# Shadow Tests
# =============================================================================
def test_shadow_basic():
    """White text with an offset black drop shadow (no blur radius passed).

    On the uniform gray frame the result must contain both bright text pixels
    and pixels darker than the background, where the shadow shows through.
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=5.0, shadow_offset_y=5.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.8,
    )
    assert result.shape == frame.shape
    # Should have both bright (text) and dark (shadow) pixels
    assert result.max() > 200, "Should have bright text"
    # A black shadow composited over the uniform gray frame must darken some pixels.
    assert int(result.min()) < int(frame.min()), "Should have dark shadow"
    save_debug("shadow_basic", result)
    print("PASS: test_shadow_basic")
    return True
def test_shadow_blur():
    """White text with a blurred (radius 3) black drop shadow.

    Besides the shape check, the output must still show bright text and a
    shadow that darkens some pixels below the gray background.
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=4.0, shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.7,
        shadow_blur_radius=3,
    )
    assert result.shape == frame.shape
    # Previously only the shape was checked; make sure something was drawn.
    assert result.max() > 200, "Should have bright text"
    assert int(result.min()) < int(frame.min()), "Blurred shadow should darken the frame"
    save_debug("shadow_blur", result)
    print("PASS: test_shadow_blur")
    return True
# =============================================================================
# Combined FX Tests
# =============================================================================
def test_fx_combined():
    """Gradient fill, a 15° rotation and a blurred drop shadow applied in one FX call."""
    frame = make_frame(500, 300)
    strip = get_strip("FX Test", 64)
    warm, cool = (255, 100, 0), (0, 100, 255)
    gradient = jnp.asarray(make_linear_gradient(strip.width, strip.height, warm, cool))
    fx_options = dict(
        baseline_y=strip.baseline_y,
        bearing_x=strip.bearing_x,
        anchor_x=strip.anchor_x,
        anchor_y=strip.anchor_y,
        gradient_map=gradient,
        angle=15.0,
        shadow_offset_x=4.0,
        shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.6,
        shadow_blur_radius=2,
    )
    result = place_text_strip_fx_jax(
        frame, jnp.asarray(strip.image), 250.0, 150.0, **fx_options
    )
    assert result.shape == frame.shape
    delta = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert delta.max() > 50, "Combined FX should produce visible output"
    save_debug("fx_combined", result)
    print("PASS: test_fx_combined")
    return True
def test_fx_no_effects():
    """With all effects at their defaults, the FX path must equal place_text_strip_jax exactly."""
    frame = make_frame()
    strip = get_strip()
    img = jnp.asarray(strip.image)
    white = jnp.array([255, 255, 255], dtype=jnp.float32)
    anchors = {"anchor_x": strip.anchor_x, "anchor_y": strip.anchor_y}
    # FX entry point, no effect parameters supplied
    fx_out = place_text_strip_fx_jax(
        frame, img, 50.0, 100.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        color=white, opacity=1.0,
        **anchors,
    )
    # Reference: the plain placement function
    ref_out = place_text_strip_jax(
        frame, img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        white, 1.0,
        **anchors,
    )
    max_diff = int(jnp.abs(fx_out.astype(jnp.int16) - ref_out.astype(jnp.int16)).max())
    assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"
    print("PASS: test_fx_no_effects")
    return True
# =============================================================================
# S-Expression Binding Tests
# =============================================================================
def test_sexp_bindings():
    """All typography FX primitives are registered in the env, and each one is callable.

    'render-text-strip' is included because the other s-expression tests look
    it up in an environment built the same way.
    """
    env = {}
    bind_typography_primitives(env)
    expected = [
        'render-text-strip',
        'linear-gradient', 'radial-gradient', 'multi-stop-gradient',
        'place-text-strip-gradient', 'place-text-strip-rotated',
        'place-text-strip-shadow', 'place-text-strip-fx',
    ]
    # Collect every problem so a single run reports all missing/broken names.
    missing = [name for name in expected if name not in env]
    assert not missing, f"Missing bindings: {missing}"
    not_callable = [name for name in expected if not callable(env[name])]
    assert not not_callable, f"Bindings are not callable: {not_callable}"
    print("PASS: test_sexp_bindings")
    return True
def test_sexp_gradient_primitive():
    """The 'linear-gradient' binding builds a gradient sized to the strip it is given."""
    primitives = {}
    bind_typography_primitives(primitives)
    strip = primitives['render-text-strip']("Test", 36)
    expected_shape = (strip.height, strip.width, 3)
    gradient = primitives['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))
    assert gradient.shape == expected_shape
    print("PASS: test_sexp_gradient_primitive")
    return True
def test_sexp_fx_primitive():
    """The 'place-text-strip-fx' binding accepts a strip object and plain-tuple colors."""
    primitives = {}
    bind_typography_primitives(primitives)
    strip = primitives['render-text-strip']("FX", 36)
    frame = make_frame()
    options = {
        "color": (255, 200, 0),
        "opacity": 0.9,
        "shadow_offset_x": 3,
        "shadow_offset_y": 3,
        "shadow_opacity": 0.5,
    }
    result = primitives['place-text-strip-fx'](frame, strip, 100.0, 80.0, **options)
    assert result.shape == frame.shape
    print("PASS: test_sexp_fx_primitive")
    return True
# =============================================================================
# JIT Compilation Test
# =============================================================================
def test_jit_fx():
    """Check that place_text_strip_fx_jax compiles under jax.jit with traced position and opacity.

    Only frame, x, y and opacity are jit arguments. The strip, colors, shadow
    offsets and blur radius are captured by closure, so they are fixed when
    the function is traced. No static_argnums are involved.
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)

    # Closure-captured values (including shadow_blur_radius) act as trace-time constants.
    @jax.jit
    def render(frame, x, y, opacity):
        return place_text_strip_fx_jax(
            frame, strip_img, x, y,
            baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
            color=color, opacity=opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            shadow_offset_x=3.0, shadow_offset_y=3.0,
            shadow_color=shadow_color,
            shadow_opacity=0.5,
            shadow_blur_radius=2,
        )

    # First call traces, second reuses the compiled function with new values.
    result1 = render(frame, 50.0, 100.0, 1.0)
    result2 = render(frame, 60.0, 90.0, 0.8)
    assert result1.shape == frame.shape
    assert result2.shape == frame.shape
    # The traced arguments must actually influence the output.
    assert not bool(jnp.array_equal(result1, result2)), \
        "Different position/opacity should produce different output"
    print("PASS: test_jit_fx")
    return True
def test_jit_gradient():
    """Gradient text placement must trace and run under jax.jit with a traced position."""
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    red, blue = (255, 0, 0), (0, 0, 255)
    gradient = jnp.asarray(make_linear_gradient(strip.width, strip.height, red, blue))

    def _place(canvas, x, y):
        return place_text_strip_gradient_jax(
            canvas, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            gradient, 1.0,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        )

    render = jax.jit(_place)
    out = render(frame, 50.0, 100.0)
    assert out.shape == frame.shape
    print("PASS: test_jit_gradient")
    return True
# =============================================================================
# Main
# =============================================================================
def main():
    """Run every test in this module in order and print a pass/fail summary.

    A test passes when it finishes without raising and does not explicitly
    return ``False``. Tests returning ``None`` (plain pytest style) therefore
    count as passes. An explicit ``False`` return gets its own FAIL line
    instead of being counted silently.

    Returns:
        True if every test passed, False otherwise.
    """
    import traceback

    print("=" * 60)
    print("Typography FX Tests")
    print("=" * 60)
    tests = [
        # Gradients
        test_linear_gradient_shape,
        test_linear_gradient_angle,
        test_radial_gradient_shape,
        test_multi_stop_gradient,
        test_place_gradient,
        # Rotation
        test_rotate_strip_identity,
        test_rotate_strip_90,
        test_rotate_360_exact,
        test_place_rotated,
        # Shadow
        test_shadow_basic,
        test_shadow_blur,
        # Combined FX
        test_fx_combined,
        test_fx_no_effects,
        # S-expression bindings
        test_sexp_bindings,
        test_sexp_gradient_primitive,
        test_sexp_fx_primitive,
        # JIT compilation
        test_jit_fx,
        test_jit_gradient,
    ]
    results = []
    for test in tests:
        try:
            outcome = test()
        except Exception as e:
            # Report and continue so one failure does not hide the others.
            print(f"FAIL: {test.__name__}: {e}")
            traceback.print_exc()
            results.append(False)
            continue
        if outcome is False:
            print(f"FAIL: {test.__name__}: returned False")
            results.append(False)
        else:
            results.append(True)
    print("=" * 60)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print("=" * 60)
    return passed == total
if __name__ == "__main__":
    import sys

    # Exit status 0 when every test passed, 1 otherwise.
    all_passed = main()
    sys.exit(0 if all_passed else 1)