Add shadow, gradient, rotation FX to JAX typography with pixel-exact precision
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m39s

- Add gradient functions: linear, radial, and multi-stop color maps
- Add RGBA strip rotation with bilinear interpolation
- Add shadow compositing with optional Gaussian blur
- Add combined place_text_strip_fx_jax pipeline (gradient + rotation + shadow)
- Add 7 new S-expression bindings for all FX primitives
- Extract shared _composite_strip_onto_frame helper
- Fix rotation precision: snap trig values near 0/±1 to exact values,
  use pixel-center convention (dim-1)/2, and parity-matched output buffers
- All 99 tests pass with zero pixel differences

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 19:25:26 +00:00
parent b322e003be
commit a29841f3c5
2 changed files with 1269 additions and 1 deletions

View File

@@ -54,9 +54,12 @@ Kerning support:
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
"""
import math
import numpy as np
import jax
import jax.numpy as jnp
from typing import Tuple, Dict, Any
from jax import lax
from typing import Tuple, Dict, Any, List, Optional
from dataclasses import dataclass
@@ -637,6 +640,661 @@ def place_glyph_simple(
)
# =============================================================================
# Gradient Functions (compile-time: generate color maps from strip dimensions)
# =============================================================================
def make_linear_gradient(
width: int,
height: int,
color1: tuple,
color2: tuple,
angle: float = 0.0,
) -> np.ndarray:
"""Create a linear gradient color map.
Args:
width, height: Dimensions of the gradient (match strip dimensions)
color1: Start color (R, G, B) 0-255
color2: End color (R, G, B) 0-255
angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom)
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
# Create coordinate grid
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
# Normalize to [0, 1]
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
# Project onto gradient axis
theta = angle * np.pi / 180.0
cos_t = np.cos(theta)
sin_t = np.sin(theta)
# Project (nx - 0.5, ny - 0.5) onto direction vector, then remap to [0, 1]
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
# Normalize: max projection is 0.5*|cos|+0.5*|sin| = 0.5*(|cos|+|sin|)
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
if max_proj > 0:
t = (proj / max_proj + 1.0) / 2.0
else:
t = np.full_like(proj, 0.5)
t = np.clip(t, 0.0, 1.0)
# Interpolate
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
return gradient
def make_radial_gradient(
width: int,
height: int,
color1: tuple,
color2: tuple,
center_x: float = 0.5,
center_y: float = 0.5,
) -> np.ndarray:
"""Create a radial gradient color map.
Args:
width, height: Dimensions
color1: Inner color (R, G, B)
color2: Outer color (R, G, B)
center_x, center_y: Center position in [0, 1] (0.5 = center)
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
c1 = np.array(color1[:3], dtype=np.float32) / 255.0
c2 = np.array(color2[:3], dtype=np.float32) / 255.0
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
# Normalize to [0, 1]
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
# Distance from center, normalized so corners are ~1.0
dx = nx - center_x
dy = ny - center_y
# Max possible distance from center to a corner
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
if max_dist > 0:
t = np.sqrt(dx**2 + dy**2) / max_dist
else:
t = np.zeros_like(dx)
t = np.clip(t, 0.0, 1.0)
gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]
return gradient
def make_multi_stop_gradient(
width: int,
height: int,
stops: list,
angle: float = 0.0,
radial: bool = False,
center_x: float = 0.5,
center_y: float = 0.5,
) -> np.ndarray:
"""Create a multi-stop gradient color map.
Args:
width, height: Dimensions
stops: List of (position, (R, G, B)) tuples, position in [0, 1]
angle: Gradient angle in degrees (for linear mode)
radial: If True, use radial gradient
center_x, center_y: Center for radial gradient
Returns:
(height, width, 3) float32 array with values in [0, 1]
"""
if len(stops) < 2:
if len(stops) == 1:
c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0
return np.broadcast_to(c, (height, width, 3)).copy()
return np.zeros((height, width, 3), dtype=np.float32)
# Sort stops by position
stops = sorted(stops, key=lambda s: s[0])
ys = np.arange(height, dtype=np.float32)
xs = np.arange(width, dtype=np.float32)
yy, xx = np.meshgrid(ys, xs, indexing='ij')
nx = xx / max(width - 1, 1)
ny = yy / max(height - 1, 1)
if radial:
dx = nx - center_x
dy = ny - center_y
max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2)
t = np.sqrt(dx**2 + dy**2) / max(max_dist, 1e-6)
else:
theta = angle * np.pi / 180.0
cos_t = np.cos(theta)
sin_t = np.sin(theta)
proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
if max_proj > 0:
t = (proj / max_proj + 1.0) / 2.0
else:
t = np.full_like(proj, 0.5)
t = np.clip(t, 0.0, 1.0)
# Build gradient from stops using piecewise linear interpolation
colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops])
positions = np.array([s[0] for s in stops], dtype=np.float32)
# Start with first color
gradient = np.broadcast_to(colors[0], (height, width, 3)).copy()
for i in range(len(stops) - 1):
p0, p1 = positions[i], positions[i + 1]
c0, c1 = colors[i], colors[i + 1]
if p1 <= p0:
continue
# Segment interpolation factor
seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0)
# Only apply where t >= p0
mask = (t >= p0)[:, :, None]
seg_color = c0[None, None, :] * (1 - seg_t[:, :, None]) + c1[None, None, :] * seg_t[:, :, None]
gradient = np.where(mask, seg_color, gradient)
return gradient
def _composite_strip_onto_frame(
frame: jnp.ndarray,
strip_rgb: jnp.ndarray,
strip_alpha: jnp.ndarray,
dst_x: jnp.ndarray,
dst_y: jnp.ndarray,
sh: int,
sw: int,
) -> jnp.ndarray:
"""Core compositing: place tinted+alpha strip onto frame using padded buffer.
Args:
frame: (H, W, 3) RGB uint8
strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB
strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha
dst_x, dst_y: int32 destination position
sh, sw: strip dimensions (compile-time constants)
Returns:
Composited frame (H, W, 3) uint8
"""
h, w = frame.shape[:2]
buf_h = h + 2 * sh
buf_w = w + 2 * sw
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0))
alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
# PIL-compatible integer alpha blending
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
dst_int = frame.astype(jnp.int32)
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
def place_text_strip_gradient_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int,
bearing_x: float,
gradient_map: jnp.ndarray,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
) -> jnp.ndarray:
"""Place text strip with gradient coloring instead of solid color.
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
gradient_map: (sh, sw, 3) float32 color map in [0, 1]
Other args same as place_text_strip_jax
Returns:
Composited frame
"""
sh, sw = strip_image.shape[:2]
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
# Extract alpha with opacity
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
# Apply gradient instead of solid color
tinted = strip_rgb * gradient_map
return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw)
# =============================================================================
# Strip Rotation (RGBA bilinear interpolation)
# =============================================================================
def _sample_rgba(strip, x, y):
"""Bilinear sample all 4 RGBA channels from a strip.
Args:
strip: (H, W, 4) RGBA float32
x, y: coordinate arrays (flattened)
Returns:
(r, g, b, a) each same shape as x
"""
h, w = strip.shape[:2]
x0 = jnp.floor(x).astype(jnp.int32)
y0 = jnp.floor(y).astype(jnp.int32)
x1 = x0 + 1
y1 = y0 + 1
fx = x - x0.astype(jnp.float32)
fy = y - y0.astype(jnp.float32)
valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h)
valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h)
valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h)
valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h)
x0_safe = jnp.clip(x0, 0, w - 1)
x1_safe = jnp.clip(x1, 0, w - 1)
y0_safe = jnp.clip(y0, 0, h - 1)
y1_safe = jnp.clip(y1, 0, h - 1)
channels = []
for c in range(4):
c00 = jnp.where(valid00, strip[y0_safe, x0_safe, c], 0.0)
c10 = jnp.where(valid10, strip[y0_safe, x1_safe, c], 0.0)
c01 = jnp.where(valid01, strip[y1_safe, x0_safe, c], 0.0)
c11 = jnp.where(valid11, strip[y1_safe, x1_safe, c], 0.0)
val = (c00 * (1 - fx) * (1 - fy) +
c10 * fx * (1 - fy) +
c01 * (1 - fx) * fy +
c11 * fx * fy)
channels.append(val)
return channels[0], channels[1], channels[2], channels[3]
def rotate_strip_jax(
strip_image: jnp.ndarray,
angle: float,
) -> jnp.ndarray:
"""Rotate an RGBA strip by angle (degrees), counter-clockwise.
Output buffer is sized to contain the full rotated strip.
The output size is ceil(sqrt(w^2 + h^2)), computed at trace time
from the strip's static shape.
Args:
strip_image: (H, W, 4) RGBA uint8
angle: Rotation angle in degrees
Returns:
(out_h, out_w, 4) RGBA uint8 - rotated strip
"""
sh, sw = strip_image.shape[:2]
# Output size: diagonal of original strip (compile-time constant).
# Ensure output dimensions have same parity as source so that the
# center offset (out - src) / 2 is always an integer. Otherwise
# identity rotations would place content at half-pixel offsets.
diag = int(math.ceil(math.sqrt(sw * sw + sh * sh)))
out_w = diag + ((diag % 2) != (sw % 2))
out_h = diag + ((diag % 2) != (sh % 2))
# Center of input strip and output buffer (pixel-center convention).
# Using (dim-1)/2 ensures integer coords map to integer coords for
# identity rotation regardless of even/odd dimension parity.
src_cx = (sw - 1) / 2.0
src_cy = (sh - 1) / 2.0
dst_cx = (out_w - 1) / 2.0
dst_cy = (out_h - 1) / 2.0
# Convert to radians and snap trig values near 0/±1 to exact values.
# Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0, causing
# bilinear blending at pixel edges and 1-value differences.
theta = angle * jnp.pi / 180.0
cos_t = jnp.cos(theta)
sin_t = jnp.sin(theta)
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
# Create output coordinate grid
y_coords = jnp.repeat(jnp.arange(out_h), out_w).reshape(out_h, out_w)
x_coords = jnp.tile(jnp.arange(out_w), out_h).reshape(out_h, out_w)
# Inverse rotation: map output coords to source coords
x_centered = x_coords.astype(jnp.float32) - dst_cx
y_centered = y_coords.astype(jnp.float32) - dst_cy
src_x = cos_t * x_centered - sin_t * y_centered + src_cx
src_y = sin_t * x_centered + cos_t * y_centered + src_cy
# Sample all 4 channels
strip_f = strip_image.astype(jnp.float32)
r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
jnp.clip(a, 0, 255).reshape(out_h, out_w).astype(jnp.uint8),
], axis=2)
# =============================================================================
# Shadow Compositing
# =============================================================================
def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray:
"""Blur a single-channel alpha array using Gaussian convolution.
Args:
alpha: (H, W) float32 alpha channel
radius: Blur radius (compile-time constant)
Returns:
(H, W) float32 blurred alpha
"""
size = radius * 2 + 1
x = jnp.arange(size, dtype=jnp.float32) - radius
sigma = max(radius / 2.0, 0.5)
gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2))
gaussian_1d = gaussian_1d / gaussian_1d.sum()
kernel = jnp.outer(gaussian_1d, gaussian_1d)
# Use JAX conv with SAME padding
h, w = alpha.shape
data_4d = alpha.reshape(1, h, w, 1)
kernel_4d = kernel.reshape(size, size, 1, 1)
result = lax.conv_general_dilated(
data_4d, kernel_4d,
window_strides=(1, 1),
padding='SAME',
dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
return result.reshape(h, w)
def place_text_strip_shadow_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int,
bearing_x: float,
color: jnp.ndarray,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
shadow_offset_x: float = 3.0,
shadow_offset_y: float = 3.0,
shadow_color: jnp.ndarray = None,
shadow_opacity: float = 0.5,
shadow_blur_radius: int = 0,
) -> jnp.ndarray:
"""Place text strip with a drop shadow.
Composites the strip twice: first as shadow (offset, colored, optionally blurred),
then the text itself on top.
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
shadow_offset_x/y: Shadow offset in pixels
shadow_color: (3,) RGB color for shadow (default black)
shadow_opacity: Shadow opacity 0-1
shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time)
Other args same as place_text_strip_jax
Returns:
Composited frame
"""
if shadow_color is None:
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
sh, sw = strip_image.shape[:2]
# --- Shadow pass ---
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
# Shadow alpha from strip alpha
shadow_opacity_int = jnp.round(shadow_opacity * 255)
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
if shadow_blur_radius > 0:
# Blur the alpha channel for soft shadow
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
else:
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1)
# Shadow RGB: solid shadow color
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
# --- Text pass ---
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0
color_norm = color.astype(jnp.float32) / 255.0
tinted = strip_rgb * color_norm
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
return frame
# =============================================================================
# Combined FX Pipeline
# =============================================================================
def place_text_strip_fx_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray,
x: float,
y: float,
baseline_y: int = 0,
bearing_x: float = 0.0,
color: jnp.ndarray = None,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
gradient_map: jnp.ndarray = None,
angle: float = 0.0,
shadow_offset_x: float = 0.0,
shadow_offset_y: float = 0.0,
shadow_color: jnp.ndarray = None,
shadow_opacity: float = 0.0,
shadow_blur_radius: int = 0,
) -> jnp.ndarray:
"""Combined text placement with gradient, rotation, and shadow.
Pipeline order:
1. Build color layer (solid color or gradient map)
2. Rotate strip + color layer if angle != 0
3. Composite shadow if shadow_opacity > 0
4. Composite text
Note: angle and shadow_blur_radius should be compile-time constants
for optimal JIT performance (they affect buffer shapes/kernel sizes).
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
x, y: Anchor point position
color: (3,) RGB color (ignored if gradient_map provided)
opacity: Text opacity
gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color
angle: Rotation angle in degrees (0 = no rotation)
shadow_offset_x/y: Shadow offset
shadow_color: (3,) RGB shadow color
shadow_opacity: Shadow opacity (0 = no shadow)
shadow_blur_radius: Shadow blur radius
Returns:
Composited frame
"""
if color is None:
color = jnp.array([255, 255, 255], dtype=jnp.float32)
if shadow_color is None:
shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)
sh, sw = strip_image.shape[:2]
# --- Step 1: Build color layer ---
if gradient_map is not None:
color_layer = gradient_map # (sh, sw, 3) float32 [0, 1]
else:
color_norm = color.astype(jnp.float32) / 255.0
color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3))
# --- Step 2: Rotate if needed ---
# angle is expected to be a compile-time constant or static value
# We check at Python level to avoid tracing issues with dynamic shapes
use_rotation = not isinstance(angle, (int, float)) or angle != 0.0
if use_rotation:
# Rotate the strip
rotated_strip = rotate_strip_jax(strip_image, angle)
rh, rw = rotated_strip.shape[:2]
# Rotate the color layer by building a 4-channel color+dummy image
# Actually, just re-create color layer at rotated size
if gradient_map is not None:
# Rotate gradient map: pack into 3-channel "image", rotate via sampling
grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8)
# Create RGBA from gradient (alpha=255 everywhere)
grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2)
rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle)
color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0
else:
# Solid color: just broadcast to rotated size
color_norm = color.astype(jnp.float32) / 255.0
color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3))
# Update anchor point for rotation (pixel-center convention)
# Rotate the anchor offset around the strip center
theta = angle * jnp.pi / 180.0
cos_t = jnp.cos(theta)
sin_t = jnp.sin(theta)
cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t)
cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t)
cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t)
sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t)
sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t)
# Original anchor relative to strip pixel center
src_cx = (sw - 1) / 2.0
src_cy = (sh - 1) / 2.0
dst_cx = (rw - 1) / 2.0
dst_cy = (rh - 1) / 2.0
ax_rel = anchor_x - src_cx
ay_rel = anchor_y - src_cy
# Rotate anchor point (forward rotation, not inverse)
new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx
new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy
strip_image = rotated_strip
anchor_x = new_ax
anchor_y = new_ay
sh, sw = rh, rw
# --- Step 3: Shadow ---
has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0
if has_shadow:
shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
shadow_opacity_int = jnp.round(shadow_opacity * 255)
strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
if shadow_blur_radius > 0:
blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
else:
shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
shadow_alpha = shadow_alpha[:, :, None]
shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw)
# --- Step 4: Composite text ---
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
opacity_int = jnp.round(opacity * 255)
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
tinted = strip_rgb * color_layer
frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)
return frame
# =============================================================================
# S-Expression Primitive Bindings
# =============================================================================
@@ -788,6 +1446,119 @@ def bind_typography_primitives(env: dict) -> dict:
stroke_width=strip.stroke_width
)
# --- Gradient primitives ---
def prim_linear_gradient(strip, color1, color2, angle=0.0):
"""Create linear gradient color map for a text strip. Compile-time."""
grad = make_linear_gradient(strip.width, strip.height,
tuple(int(c) for c in color1),
tuple(int(c) for c in color2),
float(angle))
return jnp.asarray(grad)
def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5):
"""Create radial gradient color map for a text strip. Compile-time."""
grad = make_radial_gradient(strip.width, strip.height,
tuple(int(c) for c in color1),
tuple(int(c) for c in color2),
float(center_x), float(center_y))
return jnp.asarray(grad)
def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False,
center_x=0.5, center_y=0.5):
"""Create multi-stop gradient for a text strip. Compile-time.
stops: list of (position, (R, G, B)) tuples
"""
parsed_stops = []
for s in stops:
pos = float(s[0])
color_tuple = tuple(int(c) for c in s[1])
parsed_stops.append((pos, color_tuple))
grad = make_multi_stop_gradient(strip.width, strip.height,
parsed_stops, float(angle),
bool(radial),
float(center_x), float(center_y))
return jnp.asarray(grad)
def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0):
"""Place text strip with gradient coloring. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
return place_text_strip_gradient_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
gradient_map, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width
)
# --- Rotation primitive ---
def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255),
opacity=1.0, angle=0.0):
"""Place text strip with rotation. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
return place_text_strip_fx_jax(
frame, strip_img, x, y,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color_arr, opacity=opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
angle=float(angle),
)
# --- Shadow primitive ---
def prim_place_text_strip_shadow(frame, strip, x, y,
color=(255, 255, 255), opacity=1.0,
shadow_offset_x=3.0, shadow_offset_y=3.0,
shadow_color=(0, 0, 0), shadow_opacity=0.5,
shadow_blur_radius=0):
"""Place text strip with shadow. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
return place_text_strip_shadow_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
color_arr, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
shadow_offset_x=float(shadow_offset_x),
shadow_offset_y=float(shadow_offset_y),
shadow_color=shadow_color_arr,
shadow_opacity=float(shadow_opacity),
shadow_blur_radius=int(shadow_blur_radius),
)
# --- Combined FX primitive ---
def prim_place_text_strip_fx(frame, strip, x, y,
color=(255, 255, 255), opacity=1.0,
gradient=None, angle=0.0,
shadow_offset_x=0.0, shadow_offset_y=0.0,
shadow_color=(0, 0, 0), shadow_opacity=0.0,
shadow_blur=0):
"""Place text strip with all effects. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
return place_text_strip_fx_jax(
frame, strip_img, x, y,
baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
color=color_arr, opacity=opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width,
gradient_map=gradient,
angle=float(angle),
shadow_offset_x=float(shadow_offset_x),
shadow_offset_y=float(shadow_offset_y),
shadow_color=shadow_color_arr,
shadow_opacity=float(shadow_opacity),
shadow_blur_radius=int(shadow_blur),
)
# Add to environment
env.update({
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
@@ -814,6 +1585,17 @@ def bind_typography_primitives(env: dict) -> dict:
'text-strip-anchor-x': prim_text_strip_anchor_x,
'text-strip-anchor-y': prim_text_strip_anchor_y,
'place-text-strip': prim_place_text_strip,
# Gradient primitives
'linear-gradient': prim_linear_gradient,
'radial-gradient': prim_radial_gradient,
'multi-stop-gradient': prim_multi_stop_gradient,
'place-text-strip-gradient': prim_place_text_strip_gradient,
# Rotation
'place-text-strip-rotated': prim_place_text_strip_rotated,
# Shadow
'place-text-strip-shadow': prim_place_text_strip_shadow,
# Combined FX
'place-text-strip-fx': prim_place_text_strip_fx,
})
return env