Add shadow, gradient, rotation FX to JAX typography with pixel-exact precision
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s
- Add gradient functions: linear, radial, and multi-stop color maps - Add RGBA strip rotation with bilinear interpolation - Add shadow compositing with optional Gaussian blur - Add combined place_text_strip_fx_jax pipeline (gradient + rotation + shadow) - Add 7 new S-expression bindings for all FX primitives - Extract shared _composite_strip_onto_frame helper - Fix rotation precision: snap trig values near 0/±1 to exact values, use pixel-center convention (dim-1)/2, and parity-matched output buffers - All 99 tests pass with zero pixel differences Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
486
test_typography_fx.py
Normal file
486
test_typography_fx.py
Normal file
@@ -0,0 +1,486 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for typography FX: gradients, rotation, shadow, and combined effects.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from PIL import Image
|
||||
|
||||
from streaming.jax_typography import (
|
||||
render_text_strip, place_text_strip_jax, _load_font,
|
||||
make_linear_gradient, make_radial_gradient, make_multi_stop_gradient,
|
||||
place_text_strip_gradient_jax, rotate_strip_jax,
|
||||
place_text_strip_shadow_jax, place_text_strip_fx_jax,
|
||||
bind_typography_primitives,
|
||||
)
|
||||
|
||||
|
||||
def make_frame(w=400, h=200):
|
||||
"""Create a dark gray test frame."""
|
||||
return jnp.full((h, w, 3), 40, dtype=jnp.uint8)
|
||||
|
||||
|
||||
def get_strip(text="Hello", font_size=48):
|
||||
"""Get a pre-rendered text strip."""
|
||||
return render_text_strip(text, None, font_size)
|
||||
|
||||
|
||||
def has_visible_pixels(frame, threshold=50):
|
||||
"""Check if frame has pixels above threshold."""
|
||||
return int(frame.max()) > threshold
|
||||
|
||||
|
||||
def save_debug(name, frame):
|
||||
"""Save frame for visual inspection."""
|
||||
arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame
|
||||
Image.fromarray(arr).save(f"/tmp/fx_{name}.png")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gradient Tests
|
||||
# =============================================================================
|
||||
|
||||
def test_linear_gradient_shape():
|
||||
grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255))
|
||||
assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}"
|
||||
assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}"
|
||||
# Left edge should be red-ish, right edge blue-ish
|
||||
assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}"
|
||||
assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}"
|
||||
print("PASS: test_linear_gradient_shape")
|
||||
return True
|
||||
|
||||
|
||||
def test_linear_gradient_angle():
|
||||
# 90 degrees: top-to-bottom
|
||||
grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0)
|
||||
# Top row should be red, bottom row should be blue
|
||||
assert grad[0, 50, 0] > 0.8, "Top should be red"
|
||||
assert grad[-1, 50, 2] > 0.8, "Bottom should be blue"
|
||||
print("PASS: test_linear_gradient_angle")
|
||||
return True
|
||||
|
||||
|
||||
def test_radial_gradient_shape():
|
||||
grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128))
|
||||
assert grad.shape == (100, 100, 3)
|
||||
# Center should be yellow (color1)
|
||||
assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)"
|
||||
assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)"
|
||||
# Corner should be closer to dark blue (color2)
|
||||
assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue"
|
||||
print("PASS: test_radial_gradient_shape")
|
||||
return True
|
||||
|
||||
|
||||
def test_multi_stop_gradient():
|
||||
stops = [
|
||||
(0.0, (255, 0, 0)),
|
||||
(0.5, (0, 255, 0)),
|
||||
(1.0, (0, 0, 255)),
|
||||
]
|
||||
grad = make_multi_stop_gradient(100, 10, stops)
|
||||
assert grad.shape == (10, 100, 3)
|
||||
# Left: red, Middle: green, Right: blue
|
||||
assert grad[5, 0, 0] > 0.8, "Left should be red"
|
||||
assert grad[5, 50, 1] > 0.8, "Middle should be green"
|
||||
assert grad[5, -1, 2] > 0.8, "Right should be blue"
|
||||
print("PASS: test_multi_stop_gradient")
|
||||
return True
|
||||
|
||||
|
||||
def test_place_gradient():
|
||||
"""Test gradient text rendering produces visible output."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
grad = make_linear_gradient(strip.width, strip.height,
|
||||
(255, 0, 0), (0, 0, 255))
|
||||
grad_jax = jnp.asarray(grad)
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
|
||||
result = place_text_strip_gradient_jax(
|
||||
frame, strip_img, 50.0, 100.0,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
grad_jax, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
)
|
||||
|
||||
assert result.shape == frame.shape
|
||||
# Should have visible colored pixels
|
||||
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||
assert diff.max() > 50, "Gradient text should be visible"
|
||||
save_debug("gradient", result)
|
||||
print("PASS: test_place_gradient")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rotation Tests
|
||||
# =============================================================================
|
||||
|
||||
def test_rotate_strip_identity():
|
||||
"""Rotation by 0 degrees should preserve content."""
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
|
||||
rotated = rotate_strip_jax(strip_img, 0.0)
|
||||
# Output is larger (diagonal size)
|
||||
assert rotated.shape[2] == 4, "Should be RGBA"
|
||||
assert rotated.shape[0] >= strip.height
|
||||
assert rotated.shape[1] >= strip.width
|
||||
|
||||
# Alpha should have non-zero pixels (text was preserved)
|
||||
assert rotated[:, :, 3].max() > 200, "Should have visible alpha"
|
||||
print("PASS: test_rotate_strip_identity")
|
||||
return True
|
||||
|
||||
|
||||
def test_rotate_strip_90():
|
||||
"""Rotation by 90 degrees."""
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
|
||||
rotated = rotate_strip_jax(strip_img, 90.0)
|
||||
assert rotated.shape[2] == 4
|
||||
# Should still have visible content
|
||||
assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha"
|
||||
save_debug("rotated_90", np.array(rotated))
|
||||
print("PASS: test_rotate_strip_90")
|
||||
return True
|
||||
|
||||
|
||||
def test_rotate_360_exact():
|
||||
"""360-degree rotation must be pixel-exact (regression test for trig snapping)."""
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
sh, sw = strip.height, strip.width
|
||||
|
||||
rotated = rotate_strip_jax(strip_img, 360.0)
|
||||
rh, rw = rotated.shape[:2]
|
||||
off_y = (rh - sh) // 2
|
||||
off_x = (rw - sw) // 2
|
||||
|
||||
crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw])
|
||||
orig = np.array(strip_img)
|
||||
d = np.abs(crop.astype(np.int16) - orig.astype(np.int16))
|
||||
max_diff = int(d.max())
|
||||
assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"
|
||||
print("PASS: test_rotate_360_exact")
|
||||
return True
|
||||
|
||||
|
||||
def test_place_rotated():
|
||||
"""Test rotated text placement produces visible output."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 0], dtype=jnp.float32)
|
||||
|
||||
result = place_text_strip_fx_jax(
|
||||
frame, strip_img, 200.0, 100.0,
|
||||
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||
color=color, opacity=1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
angle=30.0,
|
||||
)
|
||||
|
||||
assert result.shape == frame.shape
|
||||
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||
assert diff.max() > 50, "Rotated text should be visible"
|
||||
save_debug("rotated_30", result)
|
||||
print("PASS: test_place_rotated")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shadow Tests
|
||||
# =============================================================================
|
||||
|
||||
def test_shadow_basic():
|
||||
"""Test shadow produces visible offset copy."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||
|
||||
result = place_text_strip_shadow_jax(
|
||||
frame, strip_img, 50.0, 100.0,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
shadow_offset_x=5.0, shadow_offset_y=5.0,
|
||||
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||
shadow_opacity=0.8,
|
||||
)
|
||||
|
||||
assert result.shape == frame.shape
|
||||
# Should have both bright (text) and dark (shadow) pixels
|
||||
assert result.max() > 200, "Should have bright text"
|
||||
save_debug("shadow_basic", result)
|
||||
print("PASS: test_shadow_basic")
|
||||
return True
|
||||
|
||||
|
||||
def test_shadow_blur():
|
||||
"""Test blurred shadow."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||
|
||||
result = place_text_strip_shadow_jax(
|
||||
frame, strip_img, 50.0, 100.0,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
shadow_offset_x=4.0, shadow_offset_y=4.0,
|
||||
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||
shadow_opacity=0.7,
|
||||
shadow_blur_radius=3,
|
||||
)
|
||||
|
||||
assert result.shape == frame.shape
|
||||
save_debug("shadow_blur", result)
|
||||
print("PASS: test_shadow_blur")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined FX Tests
|
||||
# =============================================================================
|
||||
|
||||
def test_fx_combined():
|
||||
"""Test combined gradient + shadow + rotation."""
|
||||
frame = make_frame(500, 300)
|
||||
strip = get_strip("FX Test", 64)
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
|
||||
grad = make_linear_gradient(strip.width, strip.height,
|
||||
(255, 100, 0), (0, 100, 255))
|
||||
grad_jax = jnp.asarray(grad)
|
||||
|
||||
result = place_text_strip_fx_jax(
|
||||
frame, strip_img, 250.0, 150.0,
|
||||
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
gradient_map=grad_jax,
|
||||
angle=15.0,
|
||||
shadow_offset_x=4.0, shadow_offset_y=4.0,
|
||||
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||
shadow_opacity=0.6,
|
||||
shadow_blur_radius=2,
|
||||
)
|
||||
|
||||
assert result.shape == frame.shape
|
||||
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||
assert diff.max() > 50, "Combined FX should produce visible output"
|
||||
save_debug("fx_combined", result)
|
||||
print("PASS: test_fx_combined")
|
||||
return True
|
||||
|
||||
|
||||
def test_fx_no_effects():
|
||||
"""FX function with no effects should match basic place_text_strip_jax."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||
|
||||
# Using FX function with defaults
|
||||
result_fx = place_text_strip_fx_jax(
|
||||
frame, strip_img, 50.0, 100.0,
|
||||
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||
color=color, opacity=1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
)
|
||||
|
||||
# Using original function
|
||||
result_orig = place_text_strip_jax(
|
||||
frame, strip_img, 50.0, 100.0,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
)
|
||||
|
||||
diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16))
|
||||
max_diff = int(diff.max())
|
||||
assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"
|
||||
print("PASS: test_fx_no_effects")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# S-Expression Binding Tests
|
||||
# =============================================================================
|
||||
|
||||
def test_sexp_bindings():
|
||||
"""Test that all new primitives are registered."""
|
||||
env = {}
|
||||
bind_typography_primitives(env)
|
||||
|
||||
expected = [
|
||||
'linear-gradient', 'radial-gradient', 'multi-stop-gradient',
|
||||
'place-text-strip-gradient', 'place-text-strip-rotated',
|
||||
'place-text-strip-shadow', 'place-text-strip-fx',
|
||||
]
|
||||
for name in expected:
|
||||
assert name in env, f"Missing binding: {name}"
|
||||
|
||||
print("PASS: test_sexp_bindings")
|
||||
return True
|
||||
|
||||
|
||||
def test_sexp_gradient_primitive():
|
||||
"""Test gradient primitive via binding."""
|
||||
env = {}
|
||||
bind_typography_primitives(env)
|
||||
|
||||
strip = env['render-text-strip']("Test", 36)
|
||||
grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))
|
||||
|
||||
assert grad.shape == (strip.height, strip.width, 3)
|
||||
print("PASS: test_sexp_gradient_primitive")
|
||||
return True
|
||||
|
||||
|
||||
def test_sexp_fx_primitive():
|
||||
"""Test combined FX primitive via binding."""
|
||||
env = {}
|
||||
bind_typography_primitives(env)
|
||||
|
||||
strip = env['render-text-strip']("FX", 36)
|
||||
frame = make_frame()
|
||||
|
||||
result = env['place-text-strip-fx'](
|
||||
frame, strip, 100.0, 80.0,
|
||||
color=(255, 200, 0), opacity=0.9,
|
||||
shadow_offset_x=3, shadow_offset_y=3,
|
||||
shadow_opacity=0.5,
|
||||
)
|
||||
assert result.shape == frame.shape
|
||||
print("PASS: test_sexp_fx_primitive")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JIT Compilation Test
|
||||
# =============================================================================
|
||||
|
||||
def test_jit_fx():
|
||||
"""Test that place_text_strip_fx_jax can be JIT compiled."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
|
||||
|
||||
# JIT compile with static args for angle and blur radius
|
||||
@jax.jit
|
||||
def render(frame, x, y, opacity):
|
||||
return place_text_strip_fx_jax(
|
||||
frame, strip_img, x, y,
|
||||
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||
color=color, opacity=opacity,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
shadow_offset_x=3.0, shadow_offset_y=3.0,
|
||||
shadow_color=shadow_color,
|
||||
shadow_opacity=0.5,
|
||||
shadow_blur_radius=2,
|
||||
)
|
||||
|
||||
# First call traces, second uses cache
|
||||
result1 = render(frame, 50.0, 100.0, 1.0)
|
||||
result2 = render(frame, 60.0, 90.0, 0.8)
|
||||
|
||||
assert result1.shape == frame.shape
|
||||
assert result2.shape == frame.shape
|
||||
print("PASS: test_jit_fx")
|
||||
return True
|
||||
|
||||
|
||||
def test_jit_gradient():
|
||||
"""Test that gradient placement can be JIT compiled."""
|
||||
frame = make_frame()
|
||||
strip = get_strip()
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
grad = jnp.asarray(make_linear_gradient(strip.width, strip.height,
|
||||
(255, 0, 0), (0, 0, 255)))
|
||||
|
||||
@jax.jit
|
||||
def render(frame, x, y):
|
||||
return place_text_strip_gradient_jax(
|
||||
frame, strip_img, x, y,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
grad, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
)
|
||||
|
||||
result = render(frame, 50.0, 100.0)
|
||||
assert result.shape == frame.shape
|
||||
print("PASS: test_jit_gradient")
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Typography FX Tests")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
# Gradients
|
||||
test_linear_gradient_shape,
|
||||
test_linear_gradient_angle,
|
||||
test_radial_gradient_shape,
|
||||
test_multi_stop_gradient,
|
||||
test_place_gradient,
|
||||
# Rotation
|
||||
test_rotate_strip_identity,
|
||||
test_rotate_strip_90,
|
||||
test_rotate_360_exact,
|
||||
test_place_rotated,
|
||||
# Shadow
|
||||
test_shadow_basic,
|
||||
test_shadow_blur,
|
||||
# Combined FX
|
||||
test_fx_combined,
|
||||
test_fx_no_effects,
|
||||
# S-expression bindings
|
||||
test_sexp_bindings,
|
||||
test_sexp_gradient_primitive,
|
||||
test_sexp_fx_primitive,
|
||||
# JIT compilation
|
||||
test_jit_fx,
|
||||
test_jit_gradient,
|
||||
]
|
||||
|
||||
results = []
|
||||
for test in tests:
|
||||
try:
|
||||
results.append(test())
|
||||
except Exception as e:
|
||||
print(f"FAIL: {test.__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results.append(False)
|
||||
|
||||
print("=" * 60)
|
||||
passed = sum(r for r in results if r)
|
||||
total = len(results)
|
||||
print(f"Results: {passed}/{total} passed")
|
||||
if passed == total:
|
||||
print("ALL TESTS PASSED!")
|
||||
else:
|
||||
print(f"FAILED: {total - passed} tests")
|
||||
print("=" * 60)
|
||||
return passed == total
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(0 if main() else 1)
|
||||
Reference in New Issue
Block a user