Add shadow, gradient, rotation FX to JAX typography with pixel-exact precision
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s

- Add gradient functions: linear, radial, and multi-stop color maps
- Add RGBA strip rotation with bilinear interpolation
- Add shadow compositing with optional Gaussian blur
- Add combined place_text_strip_fx_jax pipeline (gradient + rotation + shadow)
- Add 7 new S-expression bindings for all FX primitives
- Extract shared _composite_strip_onto_frame helper
- Fix rotation precision: snap trig values near 0/±1 to exact values,
  use pixel-center convention (dim-1)/2, and parity-matched output buffers
- All 99 tests pass with zero pixel differences

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 19:25:26 +00:00
parent b322e003be
commit a29841f3c5
2 changed files with 1269 additions and 1 deletions

View File

@@ -54,9 +54,12 @@ Kerning support:
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size)) (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
""" """
import math
import numpy as np import numpy as np
import jax
import jax.numpy as jnp import jax.numpy as jnp
from typing import Tuple, Dict, Any from jax import lax
from typing import Tuple, Dict, Any, List, Optional
from dataclasses import dataclass from dataclasses import dataclass
@@ -637,6 +640,661 @@ def place_glyph_simple(
) )
# =============================================================================
# Gradient Functions (compile-time: generate color maps from strip dimensions)
# =============================================================================
def make_linear_gradient(
width: int,
height: int,
color1: tuple,
color2: tuple,
angle: float = 0.0,
) -> np.ndarray:
"""Create a linear gradient color map.
Args:
width, height: Dimensions of the gradient (match strip dimensions)
color1: Start color (R, G, B) 0-255
color2: End color (R, G, B) 0-255
angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom)
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
# Create coordinate grid
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
# Normalize to [0, 1]
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
# Project onto gradient axis
theta = angle * np.pi / 180.0
cos_t = np.cos(theta)
sin_t = np.sin(theta)
# Project (nx - 0.5, ny - 0.5) onto direction vector, then remap to [0, 1]
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
# Normalize: max projection is 0.5*|cos|+0.5*|sin| = 0.5*(|cos|+|sin|)
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
if max_proj > 0:
t = (proj / max_proj + 1.0) / 2.0
else:
t = np.full_like(proj, 0.5)
t = np.clip(t, 0.0, 1.0)
# Interpolate
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
return gradient
def make_radial_gradient(
width: int,
height: int,
color1: tuple,
color2: tuple,
center_x: float = 0.5,
center_y: float = 0.5,
) -> np.ndarray:
"""Create a radial gradient color map.
Args:
width, height: Dimensions
color1: Inner color (R, G, B)
color2: Outer color (R, G, B)
center_x, center_y: Center position in [0, 1] (0.5 = center)
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
# Normalize to [0, 1]
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
# Distance from center, normalized so corners are ~1.0
dx = nx - center_x
dy = ny - center_y
# Max possible distance from center to a corner
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
if max_dist > 0:
t = np.sqrt(dx**2 + dy**2) / max_dist
else:
t = np.zeros_like(dx)
t = np.clip(t, 0.0, 1.0)
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
return gradient
def make_multi_stop_gradient(
width: int,
height: int,
stops: list,
angle: float = 0.0,
radial: bool = False,
center_x: float = 0.5,
center_y: float = 0.5,
) -> np.ndarray:
"""Create a multi-stop gradient color map.
Args:
width, height: Dimensions
stops: List of (position, (R, G, B)) tuples, position in [0, 1]
angle: Gradient angle in degrees (for linear mode)
radial: If True, use radial gradient
center_x, center_y: Center for radial gradient
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
if len(stops) < 2:
if len(stops) == 1:
c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0
return np.broadcast_to(c, (height, width, 3)).copy()
return np.zeros((height, width, 3), dtype=np.float32)
# Sort stops by position
stops = sorted(stops, key=lambda s: s[0])
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
if radial:
dx = nx - center_x
dy = ny - center_y
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
t = np.sqrt(dx**2 + dy**2) / max(max_dist, 1e-6)
else:
theta = angle * np.pi / 180.0
cos_t = np.cos(theta)
sin_t = np.sin(theta)
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
if max_proj > 0:
t = (proj / max_proj + 1.0) / 2.0
else:
t = np.full_like(proj, 0.5)
t = np.clip(t, 0.0, 1.0)
# Build gradient from stops using piecewise linear interpolation
colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops])
positions = np.array([s[0] for s in stops], dtype=np.float32)
# Start with first color
gradient = np.broadcast_to(colors[0], (height, width, 3)).copy()
for i in range(len(stops) - 1):
p0, p1 = positions[i], positions[i + 1]
c0, c1 = colors[i], colors[i + 1]
if p1 <= p0:
continue
# Segment interpolation factor
seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0)
# Only apply where t >= p0
mask = (t >= p0)[:, :, None]
seg_color = c0[None, None, :] * (1 - seg_t[:, :, None]) + c1[None, None, :] * seg_t[:, :, None]
gradient = np.where(mask, seg_color, gradient)
return gradient
def _composite_strip_onto_frame(
frame: jnp.ndarray,
strip_rgb: jnp.ndarray,
strip_alpha: jnp.ndarray,
dst_x: jnp.ndarray,
dst_y: jnp.ndarray,
sh: int,
sw: int,
) -> jnp.ndarray:
"""Core compositing: place tinted+alpha strip onto frame using padded buffer.
Args:
frame: (H, W, 3) RGB uint8
strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB
strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha
dst_x, dst_y: int32 destination position
sh, sw: strip dimensions (compile-time constants)
Returns:
Composited frame (H, W, 3) uint8
"""
h, w = frame.shape[:2]
buf_h = h + 2 * sh
buf_w = w + 2 * sw
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0))
alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
# PIL-compatible integer alpha blending
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
dst_int = frame.astype(jnp.int32)
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
def place_text_strip_gradient_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int,
bearing_x: float,
gradient_map: jnp.ndarray,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
) -> jnp.ndarray:
"""Place text strip with gradient coloring instead of solid color.
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
gradient_map: (sh, sw, 3) float32 color map in [0, 1]
Other args same as place_text_strip_jax
Returns:
Composited frame
"""
sh, sw = strip_image.shape[:2]
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
# Extract alpha with opacity
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
# Apply gradient instead of solid color
tinted = strip_rgb * gradient_map
return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw)
# =============================================================================
# Strip Rotation (RGBA bilinear interpolation)
# =============================================================================
def _sample_rgba(strip, x, y):
"""Bilinear sample all 4 RGBA channels from a strip.
Args:
strip: (H, W, 4) RGBA float32
x, y: coordinate arrays (flattened)
Returns:
(r, g, b, a) each same shape as x
"""
h, w = strip.shape[:2]
x0 = jnp.floor(x).astype(jnp.int32)
y0 = jnp.floor(y).astype(jnp.int32)
x1 = x0 + 1
y1 = y0 + 1
fx = x - x0.astype(jnp.float32)
fy = y - y0.astype(jnp.float32)
valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h)
valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h)
valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h)
valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h)
x0_safe = jnp.clip(x0, 0, w - 1)
x1_safe = jnp.clip(x1, 0, w - 1)
y0_safe = jnp.clip(y0, 0, h - 1)
y1_safe = jnp.clip(y1, 0, h - 1)
channels = []
for c in range(4):
c00 = jnp.where(valid00, strip[y0_safe, x0_safe, c], 0.0)
c10 = jnp.where(valid10, strip[y0_safe, x1_safe, c], 0.0)
c01 = jnp.where(valid01, strip[y1_safe, x0_safe, c], 0.0)
c11 = jnp.where(valid11, strip[y1_safe, x1_safe, c], 0.0)
val = (c00 * (1 - fx) * (1 - fy) +
c10 * fx * (1 - fy) +
c01 * (1 - fx) * fy +
c11 * fx * fy)
channels.append(val)
return channels[0], channels[1], channels[2], channels[3]
def rotate_strip_jax(
strip_image: jnp.ndarray,
angle: float,
) -> jnp.ndarray:
"""Rotate an RGBA strip by angle (degrees), counter-clockwise.
Output buffer is sized to contain the full rotated strip.
The output size is ceil(sqrt(w^2 + h^2)), computed at trace time
from the strip's static shape.
Args:
strip_image: (H, W, 4) RGBA uint8
angle: Rotation angle in degrees
Returns:
(out_h, out_w, 4) RGBA uint8 - rotated strip
"""
sh, sw = strip_image.shape[:2]
# Output size: diagonal of original strip (compile-time constant).
# Ensure output dimensions have same parity as source so that the
# center offset (out - src) / 2 is always an integer. Otherwise
# identity rotations would place content at half-pixel offsets.
diag = int(math.ceil(math.sqrt(sw * sw + sh * sh)))
out_w = diag + ((diag % 2) != (sw % 2))
out_h = diag + ((diag % 2) != (sh % 2))
# Center of input strip and output buffer (pixel-center convention).
# Using (dim-1)/2 ensures integer coords map to integer coords for
# identity rotation regardless of even/odd dimension parity.
src_cx = (sw - 1) / 2.0
src_cy = (sh - 1) / 2.0
dst_cx = (out_w - 1) / 2.0
dst_cy = (out_h - 1) / 2.0
# Convert to radians and snap trig values near 0/±1 to exact values.
# Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0, causing
# bilinear blending at pixel edges and 1-value differences.
theta = angle * jnp.pi / 180.0
cos_t = jnp.cos(theta)
sin_t = jnp.sin(theta)
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
# Create output coordinate grid
y_coords = jnp.repeat(jnp.arange(out_h), out_w).reshape(out_h, out_w)
x_coords = jnp.tile(jnp.arange(out_w), out_h).reshape(out_h, out_w)
# Inverse rotation: map output coords to source coords
x_centered = x_coords.astype(jnp.float32) - dst_cx
y_centered = y_coords.astype(jnp.float32) - dst_cy
src_x = cos_t * x_centered - sin_t * y_centered + src_cx
src_y = sin_t * x_centered + cos_t * y_centered + src_cy
# Sample all 4 channels
strip_f = strip_image.astype(jnp.float32)
r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(a, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
], axis=2)
# =============================================================================
# Shadow Compositing
# =============================================================================
def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray:
"""Blur a single-channel alpha array using Gaussian convolution.
Args:
alpha: (H, W) float32 alpha channel
radius: Blur radius (compile-time constant)
Returns:
(H, W) float32 blurred alpha
"""
size = radius * 2 + 1
x = jnp.arange(size, dtype=jnp.float32) - radius
sigma = max(radius / 2.0, 0.5)
gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2))
gaussian_1d = gaussian_1d / gaussian_1d.sum()
kernel = jnp.outer(gaussian_1d, gaussian_1d)
# Use JAX conv with SAME padding
h, w = alpha.shape
data_4d = alpha.reshape(1, h, w, 1)
kernel_4d = kernel.reshape(size, size, 1, 1)
result = lax.conv_general_dilated(
data_4d, kernel_4d,
window_strides=(1, 1),
padding='SAME',
dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
return result.reshape(h, w)
def place_text_strip_shadow_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int,
bearing_x: float,
color: jnp.ndarray,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
shadow_offset_x: float = 3.0,
shadow_offset_y: float = 3.0,
shadow_color: jnp.ndarray = None,
shadow_opacity: float = 0.5,
shadow_blur_radius: int = 0,
) -> jnp.ndarray:
"""Place text strip with a drop shadow.
Composites the strip twice: first as shadow (offset, colored, optionally blurred),
then the text itself on top.
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
shadow_offset_x/y: Shadow offset in pixels
shadow_color: (3,) RGB color for shadow (default black)
shadow_opacity: Shadow opacity 0-1
shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time)
Other args same as place_text_strip_jax
Returns:
Composited frame
"""
if shadow_color is None:
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
sh, sw = strip_image.shape[:2]
# --- Shadow pass ---
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
# Shadow alpha from strip alpha
shadow_opacity_int = jnp.round(shadow_opacity * 255)
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
if shadow_blur_radius > 0:
# Blur the alpha channel for soft shadow
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
else:
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1)
# Shadow RGB: solid shadow color
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
# --- Text pass ---
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0
color_norm = color.astype(jnp.float32) / 255.0
tinted = strip_rgb * color_norm
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
return frame
# =============================================================================
# Combined FX Pipeline
# =============================================================================
def place_text_strip_fx_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int = 0,
bearing_x: float = 0.0,
color: jnp.ndarray = None,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
gradient_map: jnp.ndarray = None,
angle: float = 0.0,
shadow_offset_x: float = 0.0,
shadow_offset_y: float = 0.0,
shadow_color: jnp.ndarray = None,
shadow_opacity: float = 0.0,
shadow_blur_radius: int = 0,
) -> jnp.ndarray:
"""Combined text placement with gradient, rotation, and shadow.
Pipeline order:
1. Build color layer (solid color or gradient map)
2. Rotate strip + color layer if angle != 0
3. Composite shadow if shadow_opacity > 0
4. Composite text
Note: angle and shadow_blur_radius should be compile-time constants
for optimal JIT performance (they affect buffer shapes/kernel sizes).
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
x, y: Anchor point position
color: (3,) RGB color (ignored if gradient_map provided)
opacity: Text opacity
gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color
angle: Rotation angle in degrees (0 = no rotation)
shadow_offset_x/y: Shadow offset
shadow_color: (3,) RGB shadow color
shadow_opacity: Shadow opacity (0 = no shadow)
shadow_blur_radius: Shadow blur radius
Returns:
Composited frame
"""
if color is None:
color = jnp.array([255, 255, 255], dtype=jnp.float32)
if shadow_color is None:
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
sh, sw = strip_image.shape[:2]
# --- Step 1: Build color layer ---
if gradient_map is not None:
color_layer = gradient_map # (sh, sw, 3) float32 [0, 1]
else:
color_norm = color.astype(jnp.float32) / 255.0
color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3))
# --- Step 2: Rotate if needed ---
# angle is expected to be a compile-time constant or static value
# We check at Python level to avoid tracing issues with dynamic shapes
use_rotation = not isinstance(angle, (int, float)) or angle != 0.0
if use_rotation:
# Rotate the strip
rotated_strip = rotate_strip_jax(strip_image, angle)
rh, rw = rotated_strip.shape[:2]
# Rotate the color layer by building a 4-channel color+dummy image
# Actually, just re-create color layer at rotated size
if gradient_map is not None:
# Rotate gradient map: pack into 3-channel "image", rotate via sampling
grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8)
# Create RGBA from gradient (alpha=255 everywhere)
grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2)
rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle)
color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0
else:
# Solid color: just broadcast to rotated size
color_norm = color.astype(jnp.float32) / 255.0
color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3))
# Update anchor point for rotation (pixel-center convention)
# Rotate the anchor offset around the strip center
theta = angle * jnp.pi / 180.0
cos_t = jnp.cos(theta)
sin_t = jnp.sin(theta)
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
# Original anchor relative to strip pixel center
src_cx = (sw - 1) / 2.0
src_cy = (sh - 1) / 2.0
dst_cx = (rw - 1) / 2.0
dst_cy = (rh - 1) / 2.0
ax_rel = anchor_x - src_cx
ay_rel = anchor_y - src_cy
# Rotate anchor point (forward rotation, not inverse)
new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx
new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy
strip_image = rotated_strip
anchor_x = new_ax
anchor_y = new_ay
sh, sw = rh, rw
# --- Step 3: Shadow ---
has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0
if has_shadow:
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
shadow_opacity_int = jnp.round(shadow_opacity * 255)
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
if shadow_blur_radius > 0:
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
else:
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
shadow_alpha = shadow_alpha[:, :, None]
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
# --- Step 4: Composite text ---
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
tinted = strip_rgb * color_layer
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
return frame
# ============================================================================= # =============================================================================
# S-Expression Primitive Bindings # S-Expression Primitive Bindings
# ============================================================================= # =============================================================================
@@ -788,6 +1446,119 @@ def bind_typography_primitives(env: dict) -> dict:
stroke_width=strip.stroke_width stroke_width=strip.stroke_width
) )
# --- Gradient primitives ---
def prim_linear_gradient(strip, color1, color2, angle=0.0):
"""Create linear gradient color map for a text strip. Compile-time."""
grad = make_linear_gradient(strip.width, strip.height,
tuple(int(c) for c in color1),
tuple(int(c) for c in color2),
float(angle))
return jnp.asarray(grad)
def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5):
"""Create radial gradient color map for a text strip. Compile-time."""
grad = make_radial_gradient(strip.width, strip.height,
tuple(int(c) for c in color1),
tuple(int(c) for c in color2),
float(center_x), float(center_y))
return jnp.asarray(grad)
def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False,
center_x=0.5, center_y=0.5):
"""Create multi-stop gradient for a text strip. Compile-time.
stops: list of (position, (R, G, B)) tuples
"""
parsed_stops = []
for s in stops:
pos = float(s[0])
color_tuple = tuple(int(c) for c in s[1])
parsed_stops.append((pos, color_tuple))
grad = make_multi_stop_gradient(strip.width, strip.height,
parsed_stops, float(angle),
bool(radial),
float(center_x), float(center_y))
return jnp.asarray(grad)
def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0):
"""Place text strip with gradient coloring. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
return place_text_strip_gradient_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
gradient_map, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width
)
# --- Rotation primitive ---
def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255),
opacity=1.0, angle=0.0):
"""Place text strip with rotation. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
return place_text_strip_fx_jax(
frame, strip_img, x, y,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color_arr, opacity=opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
angle=float(angle),
)
# --- Shadow primitive ---
def prim_place_text_strip_shadow(frame, strip, x, y,
color=(255, 255, 255), opacity=1.0,
shadow_offset_x=3.0, shadow_offset_y=3.0,
shadow_color=(0, 0, 0), shadow_opacity=0.5,
shadow_blur_radius=0):
"""Place text strip with shadow. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
return place_text_strip_shadow_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
color_arr, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
shadow_offset_x=float(shadow_offset_x),
shadow_offset_y=float(shadow_offset_y),
shadow_color=shadow_color_arr,
shadow_opacity=float(shadow_opacity),
shadow_blur_radius=int(shadow_blur_radius),
)
# --- Combined FX primitive ---
def prim_place_text_strip_fx(frame, strip, x, y,
color=(255, 255, 255), opacity=1.0,
gradient=None, angle=0.0,
shadow_offset_x=0.0, shadow_offset_y=0.0,
shadow_color=(0, 0, 0), shadow_opacity=0.0,
shadow_blur=0):
"""Place text strip with all effects. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
return place_text_strip_fx_jax(
frame, strip_img, x, y,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color_arr, opacity=opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
gradient_map=gradient,
angle=float(angle),
shadow_offset_x=float(shadow_offset_x),
shadow_offset_y=float(shadow_offset_y),
shadow_color=shadow_color_arr,
shadow_opacity=float(shadow_opacity),
shadow_blur_radius=int(shadow_blur),
)
# Add to environment # Add to environment
env.update({ env.update({
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects) # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
@@ -814,6 +1585,17 @@ def bind_typography_primitives(env: dict) -> dict:
'text-strip-anchor-x': prim_text_strip_anchor_x, 'text-strip-anchor-x': prim_text_strip_anchor_x,
'text-strip-anchor-y': prim_text_strip_anchor_y, 'text-strip-anchor-y': prim_text_strip_anchor_y,
'place-text-strip': prim_place_text_strip, 'place-text-strip': prim_place_text_strip,
# Gradient primitives
'linear-gradient': prim_linear_gradient,
'radial-gradient': prim_radial_gradient,
'multi-stop-gradient': prim_multi_stop_gradient,
'place-text-strip-gradient': prim_place_text_strip_gradient,
# Rotation
'place-text-strip-rotated': prim_place_text_strip_rotated,
# Shadow
'place-text-strip-shadow': prim_place_text_strip_shadow,
# Combined FX
'place-text-strip-fx': prim_place_text_strip_fx,
}) })
return env return env

486
test_typography_fx.py Normal file
View File

@@ -0,0 +1,486 @@
#!/usr/bin/env python3
"""
Tests for typography FX: gradients, rotation, shadow, and combined effects.
"""
import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image
from streaming.jax_typography import (
render_text_strip, place_text_strip_jax, _load_font,
make_linear_gradient, make_radial_gradient, make_multi_stop_gradient,
place_text_strip_gradient_jax, rotate_strip_jax,
place_text_strip_shadow_jax, place_text_strip_fx_jax,
bind_typography_primitives,
)
def make_frame(w=400, h=200):
"""Create a dark gray test frame."""
return jnp.full((h, w, 3), 40, dtype=jnp.uint8)
def get_strip(text="Hello", font_size=48):
"""Get a pre-rendered text strip."""
return render_text_strip(text, None, font_size)
def has_visible_pixels(frame, threshold=50):
"""Check if frame has pixels above threshold."""
return int(frame.max()) > threshold
def save_debug(name, frame):
"""Save frame for visual inspection."""
arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame
Image.fromarray(arr).save(f"/tmp/fx_{name}.png")
# =============================================================================
# Gradient Tests
# =============================================================================
def test_linear_gradient_shape():
grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255))
assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}"
assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}"
# Left edge should be red-ish, right edge blue-ish
assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}"
assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}"
print("PASS: test_linear_gradient_shape")
return True
def test_linear_gradient_angle():
# 90 degrees: top-to-bottom
grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0)
# Top row should be red, bottom row should be blue
assert grad[0, 50, 0] > 0.8, "Top should be red"
assert grad[-1, 50, 2] > 0.8, "Bottom should be blue"
print("PASS: test_linear_gradient_angle")
return True
def test_radial_gradient_shape():
grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128))
assert grad.shape == (100, 100, 3)
# Center should be yellow (color1)
assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)"
assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)"
# Corner should be closer to dark blue (color2)
assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue"
print("PASS: test_radial_gradient_shape")
return True
def test_multi_stop_gradient():
stops = [
(0.0, (255, 0, 0)),
(0.5, (0, 255, 0)),
(1.0, (0, 0, 255)),
]
grad = make_multi_stop_gradient(100, 10, stops)
assert grad.shape == (10, 100, 3)
# Left: red, Middle: green, Right: blue
assert grad[5, 0, 0] > 0.8, "Left should be red"
assert grad[5, 50, 1] > 0.8, "Middle should be green"
assert grad[5, -1, 2] > 0.8, "Right should be blue"
print("PASS: test_multi_stop_gradient")
return True
def test_place_gradient():
"""Test gradient text rendering produces visible output."""
frame = make_frame()
strip = get_strip()
grad = make_linear_gradient(strip.width, strip.height,
(255, 0, 0), (0, 0, 255))
grad_jax = jnp.asarray(grad)
strip_img = jnp.asarray(strip.image)
result = place_text_strip_gradient_jax(
frame, strip_img, 50.0, 100.0,
strip.baseline_y, strip.bearing_x,
grad_jax, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
)
assert result.shape == frame.shape
# Should have visible colored pixels
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
assert diff.max() > 50, "Gradient text should be visible"
save_debug("gradient", result)
print("PASS: test_place_gradient")
return True
# =============================================================================
# Rotation Tests
# =============================================================================
def test_rotate_strip_identity():
"""Rotation by 0 degrees should preserve content."""
strip = get_strip()
strip_img = jnp.asarray(strip.image)
rotated = rotate_strip_jax(strip_img, 0.0)
# Output is larger (diagonal size)
assert rotated.shape[2] == 4, "Should be RGBA"
assert rotated.shape[0] >= strip.height
assert rotated.shape[1] >= strip.width
# Alpha should have non-zero pixels (text was preserved)
assert rotated[:, :, 3].max() > 200, "Should have visible alpha"
print("PASS: test_rotate_strip_identity")
return True
def test_rotate_strip_90():
"""Rotation by 90 degrees."""
strip = get_strip()
strip_img = jnp.asarray(strip.image)
rotated = rotate_strip_jax(strip_img, 90.0)
assert rotated.shape[2] == 4
# Should still have visible content
assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha"
save_debug("rotated_90", np.array(rotated))
print("PASS: test_rotate_strip_90")
return True
def test_rotate_360_exact():
"""360-degree rotation must be pixel-exact (regression test for trig snapping)."""
strip = get_strip()
strip_img = jnp.asarray(strip.image)
sh, sw = strip.height, strip.width
rotated = rotate_strip_jax(strip_img, 360.0)
rh, rw = rotated.shape[:2]
off_y = (rh - sh) // 2
off_x = (rw - sw) // 2
crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw])
orig = np.array(strip_img)
d = np.abs(crop.astype(np.int16) - orig.astype(np.int16))
max_diff = int(d.max())
assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"
print("PASS: test_rotate_360_exact")
return True
def test_place_rotated():
"""Test rotated text placement produces visible output."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 0], dtype=jnp.float32)
result = place_text_strip_fx_jax(
frame, strip_img, 200.0, 100.0,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color, opacity=1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
angle=30.0,
)
assert result.shape == frame.shape
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
assert diff.max() > 50, "Rotated text should be visible"
save_debug("rotated_30", result)
print("PASS: test_place_rotated")
return True
# =============================================================================
# Shadow Tests
# =============================================================================
def test_shadow_basic():
"""Test shadow produces visible offset copy."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 255], dtype=jnp.float32)
result = place_text_strip_shadow_jax(
frame, strip_img, 50.0, 100.0,
strip.baseline_y, strip.bearing_x,
color, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
shadow_offset_x=5.0, shadow_offset_y=5.0,
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
shadow_opacity=0.8,
)
assert result.shape == frame.shape
# Should have both bright (text) and dark (shadow) pixels
assert result.max() > 200, "Should have bright text"
save_debug("shadow_basic", result)
print("PASS: test_shadow_basic")
return True
def test_shadow_blur():
"""Test blurred shadow."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 255], dtype=jnp.float32)
result = place_text_strip_shadow_jax(
frame, strip_img, 50.0, 100.0,
strip.baseline_y, strip.bearing_x,
color, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
shadow_offset_x=4.0, shadow_offset_y=4.0,
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
shadow_opacity=0.7,
shadow_blur_radius=3,
)
assert result.shape == frame.shape
save_debug("shadow_blur", result)
print("PASS: test_shadow_blur")
return True
# =============================================================================
# Combined FX Tests
# =============================================================================
def test_fx_combined():
"""Test combined gradient + shadow + rotation."""
frame = make_frame(500, 300)
strip = get_strip("FX Test", 64)
strip_img = jnp.asarray(strip.image)
grad = make_linear_gradient(strip.width, strip.height,
(255, 100, 0), (0, 100, 255))
grad_jax = jnp.asarray(grad)
result = place_text_strip_fx_jax(
frame, strip_img, 250.0, 150.0,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
gradient_map=grad_jax,
angle=15.0,
shadow_offset_x=4.0, shadow_offset_y=4.0,
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
shadow_opacity=0.6,
shadow_blur_radius=2,
)
assert result.shape == frame.shape
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
assert diff.max() > 50, "Combined FX should produce visible output"
save_debug("fx_combined", result)
print("PASS: test_fx_combined")
return True
def test_fx_no_effects():
"""FX function with no effects should match basic place_text_strip_jax."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 255], dtype=jnp.float32)
# Using FX function with defaults
result_fx = place_text_strip_fx_jax(
frame, strip_img, 50.0, 100.0,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color, opacity=1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
)
# Using original function
result_orig = place_text_strip_jax(
frame, strip_img, 50.0, 100.0,
strip.baseline_y, strip.bearing_x,
color, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
)
diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16))
max_diff = int(diff.max())
assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"
print("PASS: test_fx_no_effects")
return True
# =============================================================================
# S-Expression Binding Tests
# =============================================================================
def test_sexp_bindings():
"""Test that all new primitives are registered."""
env = {}
bind_typography_primitives(env)
expected = [
'linear-gradient', 'radial-gradient', 'multi-stop-gradient',
'place-text-strip-gradient', 'place-text-strip-rotated',
'place-text-strip-shadow', 'place-text-strip-fx',
]
for name in expected:
assert name in env, f"Missing binding: {name}"
print("PASS: test_sexp_bindings")
return True
def test_sexp_gradient_primitive():
"""Test gradient primitive via binding."""
env = {}
bind_typography_primitives(env)
strip = env['render-text-strip']("Test", 36)
grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))
assert grad.shape == (strip.height, strip.width, 3)
print("PASS: test_sexp_gradient_primitive")
return True
def test_sexp_fx_primitive():
"""Test combined FX primitive via binding."""
env = {}
bind_typography_primitives(env)
strip = env['render-text-strip']("FX", 36)
frame = make_frame()
result = env['place-text-strip-fx'](
frame, strip, 100.0, 80.0,
color=(255, 200, 0), opacity=0.9,
shadow_offset_x=3, shadow_offset_y=3,
shadow_opacity=0.5,
)
assert result.shape == frame.shape
print("PASS: test_sexp_fx_primitive")
return True
# =============================================================================
# JIT Compilation Test
# =============================================================================
def test_jit_fx():
"""Test that place_text_strip_fx_jax can be JIT compiled."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 255], dtype=jnp.float32)
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
# JIT compile with static args for angle and blur radius
@jax.jit
def render(frame, x, y, opacity):
return place_text_strip_fx_jax(
frame, strip_img, x, y,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color, opacity=opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
shadow_offset_x=3.0, shadow_offset_y=3.0,
shadow_color=shadow_color,
shadow_opacity=0.5,
shadow_blur_radius=2,
)
# First call traces, second uses cache
result1 = render(frame, 50.0, 100.0, 1.0)
result2 = render(frame, 60.0, 90.0, 0.8)
assert result1.shape == frame.shape
assert result2.shape == frame.shape
print("PASS: test_jit_fx")
return True
def test_jit_gradient():
"""Test that gradient placement can be JIT compiled."""
frame = make_frame()
strip = get_strip()
strip_img = jnp.asarray(strip.image)
grad = jnp.asarray(make_linear_gradient(strip.width, strip.height,
(255, 0, 0), (0, 0, 255)))
@jax.jit
def render(frame, x, y):
return place_text_strip_gradient_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
grad, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
)
result = render(frame, 50.0, 100.0)
assert result.shape == frame.shape
print("PASS: test_jit_gradient")
return True
# =============================================================================
# Main
# =============================================================================
def main():
print("=" * 60)
print("Typography FX Tests")
print("=" * 60)
tests = [
# Gradients
test_linear_gradient_shape,
test_linear_gradient_angle,
test_radial_gradient_shape,
test_multi_stop_gradient,
test_place_gradient,
# Rotation
test_rotate_strip_identity,
test_rotate_strip_90,
test_rotate_360_exact,
test_place_rotated,
# Shadow
test_shadow_basic,
test_shadow_blur,
# Combined FX
test_fx_combined,
test_fx_no_effects,
# S-expression bindings
test_sexp_bindings,
test_sexp_gradient_primitive,
test_sexp_fx_primitive,
# JIT compilation
test_jit_fx,
test_jit_gradient,
]
results = []
for test in tests:
try:
results.append(test())
except Exception as e:
print(f"FAIL: {test.__name__}: {e}")
import traceback
traceback.print_exc()
results.append(False)
print("=" * 60)
passed = sum(r for r in results if r)
total = len(results)
print(f"Results: {passed}/{total} passed")
if passed == total:
print("ALL TESTS PASSED!")
else:
print(f"FAILED: {total - passed} tests")
print("=" * 60)
return passed == total
if __name__ == "__main__":
import sys
sys.exit(0 if main() else 1)