Add shadow, gradient, rotation FX to JAX typography with pixel-exact precision
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s
- Add gradient functions: linear, radial, and multi-stop color maps - Add RGBA strip rotation with bilinear interpolation - Add shadow compositing with optional Gaussian blur - Add combined place_text_strip_fx_jax pipeline (gradient + rotation + shadow) - Add 7 new S-expression bindings for all FX primitives - Extract shared _composite_strip_onto_frame helper - Fix rotation precision: snap trig values near 0/±1 to exact values, use pixel-center convention (dim-1)/2, and parity-matched output buffers - All 99 tests pass with zero pixel differences Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -54,9 +54,12 @@ Kerning support:
|
|||||||
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
|
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from typing import Tuple, Dict, Any
|
from jax import lax
|
||||||
|
from typing import Tuple, Dict, Any, List, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@@ -637,6 +640,661 @@ def place_glyph_simple(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Gradient Functions (compile-time: generate color maps from strip dimensions)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def make_linear_gradient(
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
color1: tuple,
|
||||||
|
color2: tuple,
|
||||||
|
angle: float = 0.0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Create a linear gradient color map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width, height: Dimensions of the gradient (match strip dimensions)
|
||||||
|
color1: Start color (R, G, B) 0-255
|
||||||
|
color2: End color (R, G, B) 0-255
|
||||||
|
angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(height, width, 3) float32 array with values in [0, 1]
|
||||||
|
"""
|
||||||
|
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
|
||||||
|
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
|
||||||
|
|
||||||
|
# Create coordinate grid
|
||||||
|
ys = np.arange(height, dtype=np.float32)
|
||||||
|
xs = np.arange(width, dtype=np.float32)
|
||||||
|
yy, xx = np.meshgrid(ys, xs, indexing='ij')
|
||||||
|
|
||||||
|
# Normalize to [0, 1]
|
||||||
|
nx = xx / max(width - 1, 1)
|
||||||
|
ny = yy / max(height - 1, 1)
|
||||||
|
|
||||||
|
# Project onto gradient axis
|
||||||
|
theta = angle * np.pi / 180.0
|
||||||
|
cos_t = np.cos(theta)
|
||||||
|
sin_t = np.sin(theta)
|
||||||
|
|
||||||
|
# Project (nx - 0.5, ny - 0.5) onto direction vector, then remap to [0, 1]
|
||||||
|
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
|
||||||
|
# Normalize: max projection is 0.5*|cos|+0.5*|sin| = 0.5*(|cos|+|sin|)
|
||||||
|
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
|
||||||
|
if max_proj > 0:
|
||||||
|
t = (proj / max_proj + 1.0) / 2.0
|
||||||
|
else:
|
||||||
|
t = np.full_like(proj, 0.5)
|
||||||
|
t = np.clip(t, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Interpolate
|
||||||
|
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
|
||||||
|
return gradient
|
||||||
|
|
||||||
|
|
||||||
|
def make_radial_gradient(
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
color1: tuple,
|
||||||
|
color2: tuple,
|
||||||
|
center_x: float = 0.5,
|
||||||
|
center_y: float = 0.5,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Create a radial gradient color map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width, height: Dimensions
|
||||||
|
color1: Inner color (R, G, B)
|
||||||
|
color2: Outer color (R, G, B)
|
||||||
|
center_x, center_y: Center position in [0, 1] (0.5 = center)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(height, width, 3) float32 array with values in [0, 1]
|
||||||
|
"""
|
||||||
|
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
|
||||||
|
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
|
||||||
|
|
||||||
|
ys = np.arange(height, dtype=np.float32)
|
||||||
|
xs = np.arange(width, dtype=np.float32)
|
||||||
|
yy, xx = np.meshgrid(ys, xs, indexing='ij')
|
||||||
|
|
||||||
|
# Normalize to [0, 1]
|
||||||
|
nx = xx / max(width - 1, 1)
|
||||||
|
ny = yy / max(height - 1, 1)
|
||||||
|
|
||||||
|
# Distance from center, normalized so corners are ~1.0
|
||||||
|
dx = nx - center_x
|
||||||
|
dy = ny - center_y
|
||||||
|
# Max possible distance from center to a corner
|
||||||
|
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
|
||||||
|
if max_dist > 0:
|
||||||
|
t = np.sqrt(dx**2 + dy**2) / max_dist
|
||||||
|
else:
|
||||||
|
t = np.zeros_like(dx)
|
||||||
|
t = np.clip(t, 0.0, 1.0)
|
||||||
|
|
||||||
|
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
|
||||||
|
return gradient
|
||||||
|
|
||||||
|
|
||||||
|
def make_multi_stop_gradient(
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
stops: list,
|
||||||
|
angle: float = 0.0,
|
||||||
|
radial: bool = False,
|
||||||
|
center_x: float = 0.5,
|
||||||
|
center_y: float = 0.5,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Create a multi-stop gradient color map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width, height: Dimensions
|
||||||
|
stops: List of (position, (R, G, B)) tuples, position in [0, 1]
|
||||||
|
angle: Gradient angle in degrees (for linear mode)
|
||||||
|
radial: If True, use radial gradient
|
||||||
|
center_x, center_y: Center for radial gradient
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(height, width, 3) float32 array with values in [0, 1]
|
||||||
|
"""
|
||||||
|
if len(stops) < 2:
|
||||||
|
if len(stops) == 1:
|
||||||
|
c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0
|
||||||
|
return np.broadcast_to(c, (height, width, 3)).copy()
|
||||||
|
return np.zeros((height, width, 3), dtype=np.float32)
|
||||||
|
|
||||||
|
# Sort stops by position
|
||||||
|
stops = sorted(stops, key=lambda s: s[0])
|
||||||
|
|
||||||
|
ys = np.arange(height, dtype=np.float32)
|
||||||
|
xs = np.arange(width, dtype=np.float32)
|
||||||
|
yy, xx = np.meshgrid(ys, xs, indexing='ij')
|
||||||
|
|
||||||
|
nx = xx / max(width - 1, 1)
|
||||||
|
ny = yy / max(height - 1, 1)
|
||||||
|
|
||||||
|
if radial:
|
||||||
|
dx = nx - center_x
|
||||||
|
dy = ny - center_y
|
||||||
|
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
|
||||||
|
t = np.sqrt(dx**2 + dy**2) / max(max_dist, 1e-6)
|
||||||
|
else:
|
||||||
|
theta = angle * np.pi / 180.0
|
||||||
|
cos_t = np.cos(theta)
|
||||||
|
sin_t = np.sin(theta)
|
||||||
|
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
|
||||||
|
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
|
||||||
|
if max_proj > 0:
|
||||||
|
t = (proj / max_proj + 1.0) / 2.0
|
||||||
|
else:
|
||||||
|
t = np.full_like(proj, 0.5)
|
||||||
|
|
||||||
|
t = np.clip(t, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Build gradient from stops using piecewise linear interpolation
|
||||||
|
colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops])
|
||||||
|
positions = np.array([s[0] for s in stops], dtype=np.float32)
|
||||||
|
|
||||||
|
# Start with first color
|
||||||
|
gradient = np.broadcast_to(colors[0], (height, width, 3)).copy()
|
||||||
|
|
||||||
|
for i in range(len(stops) - 1):
|
||||||
|
p0, p1 = positions[i], positions[i + 1]
|
||||||
|
c0, c1 = colors[i], colors[i + 1]
|
||||||
|
|
||||||
|
if p1 <= p0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Segment interpolation factor
|
||||||
|
seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0)
|
||||||
|
# Only apply where t >= p0
|
||||||
|
mask = (t >= p0)[:, :, None]
|
||||||
|
seg_color = c0[None, None, :] * (1 - seg_t[:, :, None]) + c1[None, None, :] * seg_t[:, :, None]
|
||||||
|
gradient = np.where(mask, seg_color, gradient)
|
||||||
|
|
||||||
|
return gradient
|
||||||
|
|
||||||
|
|
||||||
|
def _composite_strip_onto_frame(
|
||||||
|
frame: jnp.ndarray,
|
||||||
|
strip_rgb: jnp.ndarray,
|
||||||
|
strip_alpha: jnp.ndarray,
|
||||||
|
dst_x: jnp.ndarray,
|
||||||
|
dst_y: jnp.ndarray,
|
||||||
|
sh: int,
|
||||||
|
sw: int,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Core compositing: place tinted+alpha strip onto frame using padded buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: (H, W, 3) RGB uint8
|
||||||
|
strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB
|
||||||
|
strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha
|
||||||
|
dst_x, dst_y: int32 destination position
|
||||||
|
sh, sw: strip dimensions (compile-time constants)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composited frame (H, W, 3) uint8
|
||||||
|
"""
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
|
||||||
|
buf_h = h + 2 * sh
|
||||||
|
buf_w = w + 2 * sw
|
||||||
|
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
|
||||||
|
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
|
||||||
|
|
||||||
|
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
|
||||||
|
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
|
||||||
|
|
||||||
|
rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0))
|
||||||
|
alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
|
||||||
|
|
||||||
|
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
|
||||||
|
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
|
||||||
|
|
||||||
|
# PIL-compatible integer alpha blending
|
||||||
|
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
|
||||||
|
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
|
||||||
|
dst_int = frame.astype(jnp.int32)
|
||||||
|
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
|
||||||
|
|
||||||
|
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def place_text_strip_gradient_jax(
|
||||||
|
frame: jnp.ndarray,
|
||||||
|
strip_image: jnp.ndarray,
|
||||||
|
x: float,
|
||||||
|
y: float,
|
||||||
|
baseline_y: int,
|
||||||
|
bearing_x: float,
|
||||||
|
gradient_map: jnp.ndarray,
|
||||||
|
opacity: float = 1.0,
|
||||||
|
anchor_x: float = 0.0,
|
||||||
|
anchor_y: float = 0.0,
|
||||||
|
stroke_width: int = 0,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Place text strip with gradient coloring instead of solid color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: (H, W, 3) RGB frame
|
||||||
|
strip_image: (sh, sw, 4) RGBA text strip
|
||||||
|
gradient_map: (sh, sw, 3) float32 color map in [0, 1]
|
||||||
|
Other args same as place_text_strip_jax
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composited frame
|
||||||
|
"""
|
||||||
|
sh, sw = strip_image.shape[:2]
|
||||||
|
|
||||||
|
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
|
||||||
|
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
|
||||||
|
|
||||||
|
# Extract alpha with opacity
|
||||||
|
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||||
|
opacity_int = jnp.round(opacity * 255)
|
||||||
|
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
|
||||||
|
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||||
|
|
||||||
|
# Apply gradient instead of solid color
|
||||||
|
tinted = strip_rgb * gradient_map
|
||||||
|
|
||||||
|
return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Strip Rotation (RGBA bilinear interpolation)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _sample_rgba(strip, x, y):
|
||||||
|
"""Bilinear sample all 4 RGBA channels from a strip.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strip: (H, W, 4) RGBA float32
|
||||||
|
x, y: coordinate arrays (flattened)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(r, g, b, a) each same shape as x
|
||||||
|
"""
|
||||||
|
h, w = strip.shape[:2]
|
||||||
|
|
||||||
|
x0 = jnp.floor(x).astype(jnp.int32)
|
||||||
|
y0 = jnp.floor(y).astype(jnp.int32)
|
||||||
|
x1 = x0 + 1
|
||||||
|
y1 = y0 + 1
|
||||||
|
|
||||||
|
fx = x - x0.astype(jnp.float32)
|
||||||
|
fy = y - y0.astype(jnp.float32)
|
||||||
|
|
||||||
|
valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h)
|
||||||
|
valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h)
|
||||||
|
valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h)
|
||||||
|
valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h)
|
||||||
|
|
||||||
|
x0_safe = jnp.clip(x0, 0, w - 1)
|
||||||
|
x1_safe = jnp.clip(x1, 0, w - 1)
|
||||||
|
y0_safe = jnp.clip(y0, 0, h - 1)
|
||||||
|
y1_safe = jnp.clip(y1, 0, h - 1)
|
||||||
|
|
||||||
|
channels = []
|
||||||
|
for c in range(4):
|
||||||
|
c00 = jnp.where(valid00, strip[y0_safe, x0_safe, c], 0.0)
|
||||||
|
c10 = jnp.where(valid10, strip[y0_safe, x1_safe, c], 0.0)
|
||||||
|
c01 = jnp.where(valid01, strip[y1_safe, x0_safe, c], 0.0)
|
||||||
|
c11 = jnp.where(valid11, strip[y1_safe, x1_safe, c], 0.0)
|
||||||
|
|
||||||
|
val = (c00 * (1 - fx) * (1 - fy) +
|
||||||
|
c10 * fx * (1 - fy) +
|
||||||
|
c01 * (1 - fx) * fy +
|
||||||
|
c11 * fx * fy)
|
||||||
|
channels.append(val)
|
||||||
|
|
||||||
|
return channels[0], channels[1], channels[2], channels[3]
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_strip_jax(
|
||||||
|
strip_image: jnp.ndarray,
|
||||||
|
angle: float,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Rotate an RGBA strip by angle (degrees), counter-clockwise.
|
||||||
|
|
||||||
|
Output buffer is sized to contain the full rotated strip.
|
||||||
|
The output size is ceil(sqrt(w^2 + h^2)), computed at trace time
|
||||||
|
from the strip's static shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strip_image: (H, W, 4) RGBA uint8
|
||||||
|
angle: Rotation angle in degrees
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(out_h, out_w, 4) RGBA uint8 - rotated strip
|
||||||
|
"""
|
||||||
|
sh, sw = strip_image.shape[:2]
|
||||||
|
|
||||||
|
# Output size: diagonal of original strip (compile-time constant).
|
||||||
|
# Ensure output dimensions have same parity as source so that the
|
||||||
|
# center offset (out - src) / 2 is always an integer. Otherwise
|
||||||
|
# identity rotations would place content at half-pixel offsets.
|
||||||
|
diag = int(math.ceil(math.sqrt(sw * sw + sh * sh)))
|
||||||
|
out_w = diag + ((diag % 2) != (sw % 2))
|
||||||
|
out_h = diag + ((diag % 2) != (sh % 2))
|
||||||
|
|
||||||
|
# Center of input strip and output buffer (pixel-center convention).
|
||||||
|
# Using (dim-1)/2 ensures integer coords map to integer coords for
|
||||||
|
# identity rotation regardless of even/odd dimension parity.
|
||||||
|
src_cx = (sw - 1) / 2.0
|
||||||
|
src_cy = (sh - 1) / 2.0
|
||||||
|
dst_cx = (out_w - 1) / 2.0
|
||||||
|
dst_cy = (out_h - 1) / 2.0
|
||||||
|
|
||||||
|
# Convert to radians and snap trig values near 0/±1 to exact values.
|
||||||
|
# Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0, causing
|
||||||
|
# bilinear blending at pixel edges and 1-value differences.
|
||||||
|
theta = angle * jnp.pi / 180.0
|
||||||
|
cos_t = jnp.cos(theta)
|
||||||
|
sin_t = jnp.sin(theta)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
|
||||||
|
|
||||||
|
# Create output coordinate grid
|
||||||
|
y_coords = jnp.repeat(jnp.arange(out_h), out_w).reshape(out_h, out_w)
|
||||||
|
x_coords = jnp.tile(jnp.arange(out_w), out_h).reshape(out_h, out_w)
|
||||||
|
|
||||||
|
# Inverse rotation: map output coords to source coords
|
||||||
|
x_centered = x_coords.astype(jnp.float32) - dst_cx
|
||||||
|
y_centered = y_coords.astype(jnp.float32) - dst_cy
|
||||||
|
|
||||||
|
src_x = cos_t * x_centered - sin_t * y_centered + src_cx
|
||||||
|
src_y = sin_t * x_centered + cos_t * y_centered + src_cy
|
||||||
|
|
||||||
|
# Sample all 4 channels
|
||||||
|
strip_f = strip_image.astype(jnp.float32)
|
||||||
|
r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten())
|
||||||
|
|
||||||
|
return jnp.stack([
|
||||||
|
jnp.clip(r, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
|
||||||
|
jnp.clip(g, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
|
||||||
|
jnp.clip(b, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
|
||||||
|
jnp.clip(a, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
|
||||||
|
], axis=2)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Shadow Compositing
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray:
|
||||||
|
"""Blur a single-channel alpha array using Gaussian convolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha: (H, W) float32 alpha channel
|
||||||
|
radius: Blur radius (compile-time constant)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(H, W) float32 blurred alpha
|
||||||
|
"""
|
||||||
|
size = radius * 2 + 1
|
||||||
|
x = jnp.arange(size, dtype=jnp.float32) - radius
|
||||||
|
sigma = max(radius / 2.0, 0.5)
|
||||||
|
gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2))
|
||||||
|
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
||||||
|
kernel = jnp.outer(gaussian_1d, gaussian_1d)
|
||||||
|
|
||||||
|
# Use JAX conv with SAME padding
|
||||||
|
h, w = alpha.shape
|
||||||
|
data_4d = alpha.reshape(1, h, w, 1)
|
||||||
|
kernel_4d = kernel.reshape(size, size, 1, 1)
|
||||||
|
|
||||||
|
result = lax.conv_general_dilated(
|
||||||
|
data_4d, kernel_4d,
|
||||||
|
window_strides=(1, 1),
|
||||||
|
padding='SAME',
|
||||||
|
dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
|
||||||
|
)
|
||||||
|
return result.reshape(h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def place_text_strip_shadow_jax(
|
||||||
|
frame: jnp.ndarray,
|
||||||
|
strip_image: jnp.ndarray,
|
||||||
|
x: float,
|
||||||
|
y: float,
|
||||||
|
baseline_y: int,
|
||||||
|
bearing_x: float,
|
||||||
|
color: jnp.ndarray,
|
||||||
|
opacity: float = 1.0,
|
||||||
|
anchor_x: float = 0.0,
|
||||||
|
anchor_y: float = 0.0,
|
||||||
|
stroke_width: int = 0,
|
||||||
|
shadow_offset_x: float = 3.0,
|
||||||
|
shadow_offset_y: float = 3.0,
|
||||||
|
shadow_color: jnp.ndarray = None,
|
||||||
|
shadow_opacity: float = 0.5,
|
||||||
|
shadow_blur_radius: int = 0,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Place text strip with a drop shadow.
|
||||||
|
|
||||||
|
Composites the strip twice: first as shadow (offset, colored, optionally blurred),
|
||||||
|
then the text itself on top.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: (H, W, 3) RGB frame
|
||||||
|
strip_image: (sh, sw, 4) RGBA text strip
|
||||||
|
shadow_offset_x/y: Shadow offset in pixels
|
||||||
|
shadow_color: (3,) RGB color for shadow (default black)
|
||||||
|
shadow_opacity: Shadow opacity 0-1
|
||||||
|
shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time)
|
||||||
|
Other args same as place_text_strip_jax
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composited frame
|
||||||
|
"""
|
||||||
|
if shadow_color is None:
|
||||||
|
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
|
||||||
|
|
||||||
|
sh, sw = strip_image.shape[:2]
|
||||||
|
|
||||||
|
# --- Shadow pass ---
|
||||||
|
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
|
||||||
|
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
|
||||||
|
|
||||||
|
# Shadow alpha from strip alpha
|
||||||
|
shadow_opacity_int = jnp.round(shadow_opacity * 255)
|
||||||
|
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
|
||||||
|
|
||||||
|
if shadow_blur_radius > 0:
|
||||||
|
# Blur the alpha channel for soft shadow
|
||||||
|
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
|
||||||
|
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
|
||||||
|
else:
|
||||||
|
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
|
||||||
|
shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1)
|
||||||
|
|
||||||
|
# Shadow RGB: solid shadow color
|
||||||
|
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
|
||||||
|
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
|
||||||
|
|
||||||
|
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
|
||||||
|
|
||||||
|
# --- Text pass ---
|
||||||
|
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
|
||||||
|
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
|
||||||
|
|
||||||
|
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||||
|
opacity_int = jnp.round(opacity * 255)
|
||||||
|
text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0
|
||||||
|
|
||||||
|
color_norm = color.astype(jnp.float32) / 255.0
|
||||||
|
tinted = strip_rgb * color_norm
|
||||||
|
|
||||||
|
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Combined FX Pipeline
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def place_text_strip_fx_jax(
|
||||||
|
frame: jnp.ndarray,
|
||||||
|
strip_image: jnp.ndarray,
|
||||||
|
x: float,
|
||||||
|
y: float,
|
||||||
|
baseline_y: int = 0,
|
||||||
|
bearing_x: float = 0.0,
|
||||||
|
color: jnp.ndarray = None,
|
||||||
|
opacity: float = 1.0,
|
||||||
|
anchor_x: float = 0.0,
|
||||||
|
anchor_y: float = 0.0,
|
||||||
|
stroke_width: int = 0,
|
||||||
|
gradient_map: jnp.ndarray = None,
|
||||||
|
angle: float = 0.0,
|
||||||
|
shadow_offset_x: float = 0.0,
|
||||||
|
shadow_offset_y: float = 0.0,
|
||||||
|
shadow_color: jnp.ndarray = None,
|
||||||
|
shadow_opacity: float = 0.0,
|
||||||
|
shadow_blur_radius: int = 0,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
"""Combined text placement with gradient, rotation, and shadow.
|
||||||
|
|
||||||
|
Pipeline order:
|
||||||
|
1. Build color layer (solid color or gradient map)
|
||||||
|
2. Rotate strip + color layer if angle != 0
|
||||||
|
3. Composite shadow if shadow_opacity > 0
|
||||||
|
4. Composite text
|
||||||
|
|
||||||
|
Note: angle and shadow_blur_radius should be compile-time constants
|
||||||
|
for optimal JIT performance (they affect buffer shapes/kernel sizes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: (H, W, 3) RGB frame
|
||||||
|
strip_image: (sh, sw, 4) RGBA text strip
|
||||||
|
x, y: Anchor point position
|
||||||
|
color: (3,) RGB color (ignored if gradient_map provided)
|
||||||
|
opacity: Text opacity
|
||||||
|
gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color
|
||||||
|
angle: Rotation angle in degrees (0 = no rotation)
|
||||||
|
shadow_offset_x/y: Shadow offset
|
||||||
|
shadow_color: (3,) RGB shadow color
|
||||||
|
shadow_opacity: Shadow opacity (0 = no shadow)
|
||||||
|
shadow_blur_radius: Shadow blur radius
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composited frame
|
||||||
|
"""
|
||||||
|
if color is None:
|
||||||
|
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||||
|
if shadow_color is None:
|
||||||
|
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
|
||||||
|
|
||||||
|
sh, sw = strip_image.shape[:2]
|
||||||
|
|
||||||
|
# --- Step 1: Build color layer ---
|
||||||
|
if gradient_map is not None:
|
||||||
|
color_layer = gradient_map # (sh, sw, 3) float32 [0, 1]
|
||||||
|
else:
|
||||||
|
color_norm = color.astype(jnp.float32) / 255.0
|
||||||
|
color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3))
|
||||||
|
|
||||||
|
# --- Step 2: Rotate if needed ---
|
||||||
|
# angle is expected to be a compile-time constant or static value
|
||||||
|
# We check at Python level to avoid tracing issues with dynamic shapes
|
||||||
|
use_rotation = not isinstance(angle, (int, float)) or angle != 0.0
|
||||||
|
|
||||||
|
if use_rotation:
|
||||||
|
# Rotate the strip
|
||||||
|
rotated_strip = rotate_strip_jax(strip_image, angle)
|
||||||
|
rh, rw = rotated_strip.shape[:2]
|
||||||
|
|
||||||
|
# Rotate the color layer by building a 4-channel color+dummy image
|
||||||
|
# Actually, just re-create color layer at rotated size
|
||||||
|
if gradient_map is not None:
|
||||||
|
# Rotate gradient map: pack into 3-channel "image", rotate via sampling
|
||||||
|
grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8)
|
||||||
|
# Create RGBA from gradient (alpha=255 everywhere)
|
||||||
|
grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2)
|
||||||
|
rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle)
|
||||||
|
color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0
|
||||||
|
else:
|
||||||
|
# Solid color: just broadcast to rotated size
|
||||||
|
color_norm = color.astype(jnp.float32) / 255.0
|
||||||
|
color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3))
|
||||||
|
|
||||||
|
# Update anchor point for rotation (pixel-center convention)
|
||||||
|
# Rotate the anchor offset around the strip center
|
||||||
|
theta = angle * jnp.pi / 180.0
|
||||||
|
cos_t = jnp.cos(theta)
|
||||||
|
sin_t = jnp.sin(theta)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
|
||||||
|
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
|
||||||
|
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
|
||||||
|
|
||||||
|
# Original anchor relative to strip pixel center
|
||||||
|
src_cx = (sw - 1) / 2.0
|
||||||
|
src_cy = (sh - 1) / 2.0
|
||||||
|
dst_cx = (rw - 1) / 2.0
|
||||||
|
dst_cy = (rh - 1) / 2.0
|
||||||
|
|
||||||
|
ax_rel = anchor_x - src_cx
|
||||||
|
ay_rel = anchor_y - src_cy
|
||||||
|
|
||||||
|
# Rotate anchor point (forward rotation, not inverse)
|
||||||
|
new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx
|
||||||
|
new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy
|
||||||
|
|
||||||
|
strip_image = rotated_strip
|
||||||
|
anchor_x = new_ax
|
||||||
|
anchor_y = new_ay
|
||||||
|
sh, sw = rh, rw
|
||||||
|
|
||||||
|
# --- Step 3: Shadow ---
|
||||||
|
has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0
|
||||||
|
if has_shadow:
|
||||||
|
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
|
||||||
|
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
|
||||||
|
|
||||||
|
shadow_opacity_int = jnp.round(shadow_opacity * 255)
|
||||||
|
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
|
||||||
|
|
||||||
|
if shadow_blur_radius > 0:
|
||||||
|
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
|
||||||
|
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
|
||||||
|
else:
|
||||||
|
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
|
||||||
|
shadow_alpha = shadow_alpha[:, :, None]
|
||||||
|
|
||||||
|
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
|
||||||
|
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
|
||||||
|
|
||||||
|
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
|
||||||
|
|
||||||
|
# --- Step 4: Composite text ---
|
||||||
|
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
|
||||||
|
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
|
||||||
|
|
||||||
|
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||||
|
opacity_int = jnp.round(opacity * 255)
|
||||||
|
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
|
||||||
|
text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||||
|
|
||||||
|
tinted = strip_rgb * color_layer
|
||||||
|
|
||||||
|
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# S-Expression Primitive Bindings
|
# S-Expression Primitive Bindings
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -788,6 +1446,119 @@ def bind_typography_primitives(env: dict) -> dict:
|
|||||||
stroke_width=strip.stroke_width
|
stroke_width=strip.stroke_width
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --- Gradient primitives ---
|
||||||
|
|
||||||
|
def prim_linear_gradient(strip, color1, color2, angle=0.0):
|
||||||
|
"""Create linear gradient color map for a text strip. Compile-time."""
|
||||||
|
grad = make_linear_gradient(strip.width, strip.height,
|
||||||
|
tuple(int(c) for c in color1),
|
||||||
|
tuple(int(c) for c in color2),
|
||||||
|
float(angle))
|
||||||
|
return jnp.asarray(grad)
|
||||||
|
|
||||||
|
def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5):
|
||||||
|
"""Create radial gradient color map for a text strip. Compile-time."""
|
||||||
|
grad = make_radial_gradient(strip.width, strip.height,
|
||||||
|
tuple(int(c) for c in color1),
|
||||||
|
tuple(int(c) for c in color2),
|
||||||
|
float(center_x), float(center_y))
|
||||||
|
return jnp.asarray(grad)
|
||||||
|
|
||||||
|
def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False,
|
||||||
|
center_x=0.5, center_y=0.5):
|
||||||
|
"""Create multi-stop gradient for a text strip. Compile-time.
|
||||||
|
|
||||||
|
stops: list of (position, (R, G, B)) tuples
|
||||||
|
"""
|
||||||
|
parsed_stops = []
|
||||||
|
for s in stops:
|
||||||
|
pos = float(s[0])
|
||||||
|
color_tuple = tuple(int(c) for c in s[1])
|
||||||
|
parsed_stops.append((pos, color_tuple))
|
||||||
|
grad = make_multi_stop_gradient(strip.width, strip.height,
|
||||||
|
parsed_stops, float(angle),
|
||||||
|
bool(radial),
|
||||||
|
float(center_x), float(center_y))
|
||||||
|
return jnp.asarray(grad)
|
||||||
|
|
||||||
|
def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0):
|
||||||
|
"""Place text strip with gradient coloring. Runtime JAX operation."""
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
return place_text_strip_gradient_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
gradient_map, opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
stroke_width=strip.stroke_width
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Rotation primitive ---
|
||||||
|
|
||||||
|
def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255),
|
||||||
|
opacity=1.0, angle=0.0):
|
||||||
|
"""Place text strip with rotation. Runtime JAX operation."""
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||||
|
return place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
color=color_arr, opacity=opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
stroke_width=strip.stroke_width,
|
||||||
|
angle=float(angle),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Shadow primitive ---
|
||||||
|
|
||||||
|
def prim_place_text_strip_shadow(frame, strip, x, y,
|
||||||
|
color=(255, 255, 255), opacity=1.0,
|
||||||
|
shadow_offset_x=3.0, shadow_offset_y=3.0,
|
||||||
|
shadow_color=(0, 0, 0), shadow_opacity=0.5,
|
||||||
|
shadow_blur_radius=0):
|
||||||
|
"""Place text strip with shadow. Runtime JAX operation."""
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||||
|
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
|
||||||
|
return place_text_strip_shadow_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
color_arr, opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
stroke_width=strip.stroke_width,
|
||||||
|
shadow_offset_x=float(shadow_offset_x),
|
||||||
|
shadow_offset_y=float(shadow_offset_y),
|
||||||
|
shadow_color=shadow_color_arr,
|
||||||
|
shadow_opacity=float(shadow_opacity),
|
||||||
|
shadow_blur_radius=int(shadow_blur_radius),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Combined FX primitive ---
|
||||||
|
|
||||||
|
def prim_place_text_strip_fx(frame, strip, x, y,
|
||||||
|
color=(255, 255, 255), opacity=1.0,
|
||||||
|
gradient=None, angle=0.0,
|
||||||
|
shadow_offset_x=0.0, shadow_offset_y=0.0,
|
||||||
|
shadow_color=(0, 0, 0), shadow_opacity=0.0,
|
||||||
|
shadow_blur=0):
|
||||||
|
"""Place text strip with all effects. Runtime JAX operation."""
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||||
|
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
|
||||||
|
return place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
color=color_arr, opacity=opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
stroke_width=strip.stroke_width,
|
||||||
|
gradient_map=gradient,
|
||||||
|
angle=float(angle),
|
||||||
|
shadow_offset_x=float(shadow_offset_x),
|
||||||
|
shadow_offset_y=float(shadow_offset_y),
|
||||||
|
shadow_color=shadow_color_arr,
|
||||||
|
shadow_opacity=float(shadow_opacity),
|
||||||
|
shadow_blur_radius=int(shadow_blur),
|
||||||
|
)
|
||||||
|
|
||||||
# Add to environment
|
# Add to environment
|
||||||
env.update({
|
env.update({
|
||||||
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
|
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
|
||||||
@@ -814,6 +1585,17 @@ def bind_typography_primitives(env: dict) -> dict:
|
|||||||
'text-strip-anchor-x': prim_text_strip_anchor_x,
|
'text-strip-anchor-x': prim_text_strip_anchor_x,
|
||||||
'text-strip-anchor-y': prim_text_strip_anchor_y,
|
'text-strip-anchor-y': prim_text_strip_anchor_y,
|
||||||
'place-text-strip': prim_place_text_strip,
|
'place-text-strip': prim_place_text_strip,
|
||||||
|
# Gradient primitives
|
||||||
|
'linear-gradient': prim_linear_gradient,
|
||||||
|
'radial-gradient': prim_radial_gradient,
|
||||||
|
'multi-stop-gradient': prim_multi_stop_gradient,
|
||||||
|
'place-text-strip-gradient': prim_place_text_strip_gradient,
|
||||||
|
# Rotation
|
||||||
|
'place-text-strip-rotated': prim_place_text_strip_rotated,
|
||||||
|
# Shadow
|
||||||
|
'place-text-strip-shadow': prim_place_text_strip_shadow,
|
||||||
|
# Combined FX
|
||||||
|
'place-text-strip-fx': prim_place_text_strip_fx,
|
||||||
})
|
})
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|||||||
486
test_typography_fx.py
Normal file
486
test_typography_fx.py
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for typography FX: gradients, rotation, shadow, and combined effects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from streaming.jax_typography import (
|
||||||
|
render_text_strip, place_text_strip_jax, _load_font,
|
||||||
|
make_linear_gradient, make_radial_gradient, make_multi_stop_gradient,
|
||||||
|
place_text_strip_gradient_jax, rotate_strip_jax,
|
||||||
|
place_text_strip_shadow_jax, place_text_strip_fx_jax,
|
||||||
|
bind_typography_primitives,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame(w=400, h=200):
|
||||||
|
"""Create a dark gray test frame."""
|
||||||
|
return jnp.full((h, w, 3), 40, dtype=jnp.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def get_strip(text="Hello", font_size=48):
|
||||||
|
"""Get a pre-rendered text strip."""
|
||||||
|
return render_text_strip(text, None, font_size)
|
||||||
|
|
||||||
|
|
||||||
|
def has_visible_pixels(frame, threshold=50):
|
||||||
|
"""Check if frame has pixels above threshold."""
|
||||||
|
return int(frame.max()) > threshold
|
||||||
|
|
||||||
|
|
||||||
|
def save_debug(name, frame):
|
||||||
|
"""Save frame for visual inspection."""
|
||||||
|
arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame
|
||||||
|
Image.fromarray(arr).save(f"/tmp/fx_{name}.png")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Gradient Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_linear_gradient_shape():
|
||||||
|
grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255))
|
||||||
|
assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}"
|
||||||
|
assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}"
|
||||||
|
# Left edge should be red-ish, right edge blue-ish
|
||||||
|
assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}"
|
||||||
|
assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}"
|
||||||
|
print("PASS: test_linear_gradient_shape")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_gradient_angle():
|
||||||
|
# 90 degrees: top-to-bottom
|
||||||
|
grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0)
|
||||||
|
# Top row should be red, bottom row should be blue
|
||||||
|
assert grad[0, 50, 0] > 0.8, "Top should be red"
|
||||||
|
assert grad[-1, 50, 2] > 0.8, "Bottom should be blue"
|
||||||
|
print("PASS: test_linear_gradient_angle")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_radial_gradient_shape():
|
||||||
|
grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128))
|
||||||
|
assert grad.shape == (100, 100, 3)
|
||||||
|
# Center should be yellow (color1)
|
||||||
|
assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)"
|
||||||
|
assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)"
|
||||||
|
# Corner should be closer to dark blue (color2)
|
||||||
|
assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue"
|
||||||
|
print("PASS: test_radial_gradient_shape")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_stop_gradient():
|
||||||
|
stops = [
|
||||||
|
(0.0, (255, 0, 0)),
|
||||||
|
(0.5, (0, 255, 0)),
|
||||||
|
(1.0, (0, 0, 255)),
|
||||||
|
]
|
||||||
|
grad = make_multi_stop_gradient(100, 10, stops)
|
||||||
|
assert grad.shape == (10, 100, 3)
|
||||||
|
# Left: red, Middle: green, Right: blue
|
||||||
|
assert grad[5, 0, 0] > 0.8, "Left should be red"
|
||||||
|
assert grad[5, 50, 1] > 0.8, "Middle should be green"
|
||||||
|
assert grad[5, -1, 2] > 0.8, "Right should be blue"
|
||||||
|
print("PASS: test_multi_stop_gradient")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_place_gradient():
|
||||||
|
"""Test gradient text rendering produces visible output."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
grad = make_linear_gradient(strip.width, strip.height,
|
||||||
|
(255, 0, 0), (0, 0, 255))
|
||||||
|
grad_jax = jnp.asarray(grad)
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
|
||||||
|
result = place_text_strip_gradient_jax(
|
||||||
|
frame, strip_img, 50.0, 100.0,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
grad_jax, 1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
# Should have visible colored pixels
|
||||||
|
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||||
|
assert diff.max() > 50, "Gradient text should be visible"
|
||||||
|
save_debug("gradient", result)
|
||||||
|
print("PASS: test_place_gradient")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Rotation Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_rotate_strip_identity():
|
||||||
|
"""Rotation by 0 degrees should preserve content."""
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
|
||||||
|
rotated = rotate_strip_jax(strip_img, 0.0)
|
||||||
|
# Output is larger (diagonal size)
|
||||||
|
assert rotated.shape[2] == 4, "Should be RGBA"
|
||||||
|
assert rotated.shape[0] >= strip.height
|
||||||
|
assert rotated.shape[1] >= strip.width
|
||||||
|
|
||||||
|
# Alpha should have non-zero pixels (text was preserved)
|
||||||
|
assert rotated[:, :, 3].max() > 200, "Should have visible alpha"
|
||||||
|
print("PASS: test_rotate_strip_identity")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_strip_90():
|
||||||
|
"""Rotation by 90 degrees."""
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
|
||||||
|
rotated = rotate_strip_jax(strip_img, 90.0)
|
||||||
|
assert rotated.shape[2] == 4
|
||||||
|
# Should still have visible content
|
||||||
|
assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha"
|
||||||
|
save_debug("rotated_90", np.array(rotated))
|
||||||
|
print("PASS: test_rotate_strip_90")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_360_exact():
|
||||||
|
"""360-degree rotation must be pixel-exact (regression test for trig snapping)."""
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
sh, sw = strip.height, strip.width
|
||||||
|
|
||||||
|
rotated = rotate_strip_jax(strip_img, 360.0)
|
||||||
|
rh, rw = rotated.shape[:2]
|
||||||
|
off_y = (rh - sh) // 2
|
||||||
|
off_x = (rw - sw) // 2
|
||||||
|
|
||||||
|
crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw])
|
||||||
|
orig = np.array(strip_img)
|
||||||
|
d = np.abs(crop.astype(np.int16) - orig.astype(np.int16))
|
||||||
|
max_diff = int(d.max())
|
||||||
|
assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"
|
||||||
|
print("PASS: test_rotate_360_exact")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_place_rotated():
|
||||||
|
"""Test rotated text placement produces visible output."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array([255, 255, 0], dtype=jnp.float32)
|
||||||
|
|
||||||
|
result = place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, 200.0, 100.0,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
color=color, opacity=1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
angle=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||||
|
assert diff.max() > 50, "Rotated text should be visible"
|
||||||
|
save_debug("rotated_30", result)
|
||||||
|
print("PASS: test_place_rotated")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Shadow Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_shadow_basic():
|
||||||
|
"""Test shadow produces visible offset copy."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||||
|
|
||||||
|
result = place_text_strip_shadow_jax(
|
||||||
|
frame, strip_img, 50.0, 100.0,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
color, 1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
shadow_offset_x=5.0, shadow_offset_y=5.0,
|
||||||
|
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||||
|
shadow_opacity=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
# Should have both bright (text) and dark (shadow) pixels
|
||||||
|
assert result.max() > 200, "Should have bright text"
|
||||||
|
save_debug("shadow_basic", result)
|
||||||
|
print("PASS: test_shadow_basic")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_shadow_blur():
|
||||||
|
"""Test blurred shadow."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||||
|
|
||||||
|
result = place_text_strip_shadow_jax(
|
||||||
|
frame, strip_img, 50.0, 100.0,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
color, 1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
shadow_offset_x=4.0, shadow_offset_y=4.0,
|
||||||
|
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||||
|
shadow_opacity=0.7,
|
||||||
|
shadow_blur_radius=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
save_debug("shadow_blur", result)
|
||||||
|
print("PASS: test_shadow_blur")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Combined FX Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_fx_combined():
|
||||||
|
"""Test combined gradient + shadow + rotation."""
|
||||||
|
frame = make_frame(500, 300)
|
||||||
|
strip = get_strip("FX Test", 64)
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
|
||||||
|
grad = make_linear_gradient(strip.width, strip.height,
|
||||||
|
(255, 100, 0), (0, 100, 255))
|
||||||
|
grad_jax = jnp.asarray(grad)
|
||||||
|
|
||||||
|
result = place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, 250.0, 150.0,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
gradient_map=grad_jax,
|
||||||
|
angle=15.0,
|
||||||
|
shadow_offset_x=4.0, shadow_offset_y=4.0,
|
||||||
|
shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
|
||||||
|
shadow_opacity=0.6,
|
||||||
|
shadow_blur_radius=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
|
||||||
|
assert diff.max() > 50, "Combined FX should produce visible output"
|
||||||
|
save_debug("fx_combined", result)
|
||||||
|
print("PASS: test_fx_combined")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_fx_no_effects():
|
||||||
|
"""FX function with no effects should match basic place_text_strip_jax."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||||
|
|
||||||
|
# Using FX function with defaults
|
||||||
|
result_fx = place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, 50.0, 100.0,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
color=color, opacity=1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Using original function
|
||||||
|
result_orig = place_text_strip_jax(
|
||||||
|
frame, strip_img, 50.0, 100.0,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
color, 1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
)
|
||||||
|
|
||||||
|
diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16))
|
||||||
|
max_diff = int(diff.max())
|
||||||
|
assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"
|
||||||
|
print("PASS: test_fx_no_effects")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# S-Expression Binding Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_sexp_bindings():
|
||||||
|
"""Test that all new primitives are registered."""
|
||||||
|
env = {}
|
||||||
|
bind_typography_primitives(env)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
'linear-gradient', 'radial-gradient', 'multi-stop-gradient',
|
||||||
|
'place-text-strip-gradient', 'place-text-strip-rotated',
|
||||||
|
'place-text-strip-shadow', 'place-text-strip-fx',
|
||||||
|
]
|
||||||
|
for name in expected:
|
||||||
|
assert name in env, f"Missing binding: {name}"
|
||||||
|
|
||||||
|
print("PASS: test_sexp_bindings")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sexp_gradient_primitive():
|
||||||
|
"""Test gradient primitive via binding."""
|
||||||
|
env = {}
|
||||||
|
bind_typography_primitives(env)
|
||||||
|
|
||||||
|
strip = env['render-text-strip']("Test", 36)
|
||||||
|
grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))
|
||||||
|
|
||||||
|
assert grad.shape == (strip.height, strip.width, 3)
|
||||||
|
print("PASS: test_sexp_gradient_primitive")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sexp_fx_primitive():
|
||||||
|
"""Test combined FX primitive via binding."""
|
||||||
|
env = {}
|
||||||
|
bind_typography_primitives(env)
|
||||||
|
|
||||||
|
strip = env['render-text-strip']("FX", 36)
|
||||||
|
frame = make_frame()
|
||||||
|
|
||||||
|
result = env['place-text-strip-fx'](
|
||||||
|
frame, strip, 100.0, 80.0,
|
||||||
|
color=(255, 200, 0), opacity=0.9,
|
||||||
|
shadow_offset_x=3, shadow_offset_y=3,
|
||||||
|
shadow_opacity=0.5,
|
||||||
|
)
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
print("PASS: test_sexp_fx_primitive")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# JIT Compilation Test
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_jit_fx():
|
||||||
|
"""Test that place_text_strip_fx_jax can be JIT compiled."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||||
|
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
|
||||||
|
|
||||||
|
# JIT compile with static args for angle and blur radius
|
||||||
|
@jax.jit
|
||||||
|
def render(frame, x, y, opacity):
|
||||||
|
return place_text_strip_fx_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
|
||||||
|
color=color, opacity=opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
shadow_offset_x=3.0, shadow_offset_y=3.0,
|
||||||
|
shadow_color=shadow_color,
|
||||||
|
shadow_opacity=0.5,
|
||||||
|
shadow_blur_radius=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call traces, second uses cache
|
||||||
|
result1 = render(frame, 50.0, 100.0, 1.0)
|
||||||
|
result2 = render(frame, 60.0, 90.0, 0.8)
|
||||||
|
|
||||||
|
assert result1.shape == frame.shape
|
||||||
|
assert result2.shape == frame.shape
|
||||||
|
print("PASS: test_jit_fx")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_jit_gradient():
|
||||||
|
"""Test that gradient placement can be JIT compiled."""
|
||||||
|
frame = make_frame()
|
||||||
|
strip = get_strip()
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
grad = jnp.asarray(make_linear_gradient(strip.width, strip.height,
|
||||||
|
(255, 0, 0), (0, 0, 255)))
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def render(frame, x, y):
|
||||||
|
return place_text_strip_gradient_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
grad, 1.0,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = render(frame, 50.0, 100.0)
|
||||||
|
assert result.shape == frame.shape
|
||||||
|
print("PASS: test_jit_gradient")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Main
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=" * 60)
|
||||||
|
print("Typography FX Tests")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
# Gradients
|
||||||
|
test_linear_gradient_shape,
|
||||||
|
test_linear_gradient_angle,
|
||||||
|
test_radial_gradient_shape,
|
||||||
|
test_multi_stop_gradient,
|
||||||
|
test_place_gradient,
|
||||||
|
# Rotation
|
||||||
|
test_rotate_strip_identity,
|
||||||
|
test_rotate_strip_90,
|
||||||
|
test_rotate_360_exact,
|
||||||
|
test_place_rotated,
|
||||||
|
# Shadow
|
||||||
|
test_shadow_basic,
|
||||||
|
test_shadow_blur,
|
||||||
|
# Combined FX
|
||||||
|
test_fx_combined,
|
||||||
|
test_fx_no_effects,
|
||||||
|
# S-expression bindings
|
||||||
|
test_sexp_bindings,
|
||||||
|
test_sexp_gradient_primitive,
|
||||||
|
test_sexp_fx_primitive,
|
||||||
|
# JIT compilation
|
||||||
|
test_jit_fx,
|
||||||
|
test_jit_gradient,
|
||||||
|
]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
results.append(test())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"FAIL: {test.__name__}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
results.append(False)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
passed = sum(r for r in results if r)
|
||||||
|
total = len(results)
|
||||||
|
print(f"Results: {passed}/{total} passed")
|
||||||
|
if passed == total:
|
||||||
|
print("ALL TESTS PASSED!")
|
||||||
|
else:
|
||||||
|
print(f"FAILED: {total - passed} tests")
|
||||||
|
print("=" * 60)
|
||||||
|
return passed == total
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
sys.exit(0 if main() else 1)
|
||||||
Reference in New Issue
Block a user