import math
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from typing import Tuple, Dict, Any, List, Optional
from dataclasses import dataclass


# =============================================================================
# Gradient Functions (compile-time: generate color maps from strip dimensions)
# =============================================================================

def _rgb_unit(color) -> np.ndarray:
    """Convert an (R, G, B[, ...]) 0-255 colour into a float32 RGB vector in [0, 1]."""
    return np.array(color[:3], dtype=np.float32) / 255.0


def _unit_grid(width: int, height: int):
    """Return (nx, ny): per-pixel x/y coordinates scaled to [0, 1], shape (height, width)."""
    ys = np.arange(height, dtype=np.float32)
    xs = np.arange(width, dtype=np.float32)
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    # max(..., 1) keeps 1-pixel-wide/tall maps from dividing by zero.
    return xx / max(width - 1, 1), yy / max(height - 1, 1)


def _linear_ramp(width: int, height: int, angle: float) -> np.ndarray:
    """Return the (height, width) interpolation factor t in [0, 1] along `angle` degrees."""
    nx, ny = _unit_grid(width, height)
    theta = angle * np.pi / 180.0
    # float(...) turns the NumPy float64 scalars into Python floats so they
    # do not promote the float32 grid to float64.
    cos_t = float(np.cos(theta))
    sin_t = float(np.sin(theta))

    # Project (nx - 0.5, ny - 0.5) onto the direction vector, then remap to [0, 1].
    proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
    # Largest possible projection over the unit square: 0.5 * (|cos| + |sin|).
    max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
    if max_proj > 0:
        t = (proj / max_proj + 1.0) / 2.0
    else:
        # Only reachable for a non-finite angle (NaN fails the comparison).
        t = np.full_like(proj, 0.5)
    return np.clip(t, 0.0, 1.0)


def _radial_ramp(width: int, height: int, center_x: float, center_y: float) -> np.ndarray:
    """Return the (height, width) factor t in [0, 1]: 0 at the centre, 1 at the farthest corner."""
    nx, ny = _unit_grid(width, height)
    dx = nx - center_x
    dy = ny - center_y
    # Distance from the centre to the farthest corner of the unit square.
    max_dist = float(np.sqrt(max(center_x, 1 - center_x) ** 2
                             + max(center_y, 1 - center_y) ** 2))
    if max_dist > 0:
        t = np.sqrt(dx ** 2 + dy ** 2) / max_dist
    else:
        t = np.zeros_like(dx)
    return np.clip(t, 0.0, 1.0)


def _blend(c_start: np.ndarray, c_end: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Linearly interpolate two RGB vectors by a (H, W) factor map; result is (H, W, 3) float32."""
    w = t[:, :, None]
    mixed = c_start[None, None, :] * (1 - w) + c_end[None, None, :] * w
    return mixed.astype(np.float32, copy=False)


def make_linear_gradient(
    width: int,
    height: int,
    color1: tuple,
    color2: tuple,
    angle: float = 0.0,
) -> np.ndarray:
    """Create a linear gradient color map.

    Args:
        width, height: Dimensions of the gradient (match strip dimensions)
        color1: Start color (R, G, B) 0-255
        color2: End color (R, G, B) 0-255
        angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom)

    Returns:
        (height, width, 3) float32 array with values in [0, 1]
    """
    t = _linear_ramp(width, height, angle)
    return _blend(_rgb_unit(color1), _rgb_unit(color2), t)


def make_radial_gradient(
    width: int,
    height: int,
    color1: tuple,
    color2: tuple,
    center_x: float = 0.5,
    center_y: float = 0.5,
) -> np.ndarray:
    """Create a radial gradient color map.

    Args:
        width, height: Dimensions
        color1: Inner color (R, G, B) 0-255
        color2: Outer color (R, G, B) 0-255
        center_x, center_y: Center position in [0, 1] (0.5 = center)

    Returns:
        (height, width, 3) float32 array with values in [0, 1]
    """
    t = _radial_ramp(width, height, center_x, center_y)
    return _blend(_rgb_unit(color1), _rgb_unit(color2), t)


def make_multi_stop_gradient(
    width: int,
    height: int,
    stops: list,
    angle: float = 0.0,
    radial: bool = False,
    center_x: float = 0.5,
    center_y: float = 0.5,
) -> np.ndarray:
    """Create a multi-stop gradient color map.

    Pixels before the first stop take the first stop's colour and pixels
    after the last stop take the last stop's colour. Stops are sorted by
    position first; where two stops share a position the later one wins
    from that position onward (a hard edge).

    Args:
        width, height: Dimensions
        stops: List of (position, (R, G, B)) tuples, position in [0, 1]
        angle: Gradient angle in degrees (for linear mode)
        radial: If True, use radial gradient
        center_x, center_y: Center for radial gradient

    Returns:
        (height, width, 3) float32 array with values in [0, 1]
    """
    if len(stops) < 2:
        if len(stops) == 1:
            # A single stop is a solid fill.
            solid = _rgb_unit(stops[0][1])
            return np.broadcast_to(solid, (height, width, 3)).copy()
        return np.zeros((height, width, 3), dtype=np.float32)

    ordered = sorted(stops, key=lambda s: s[0])

    if radial:
        t = _radial_ramp(width, height, center_x, center_y)
    else:
        t = _linear_ramp(width, height, angle)

    colors = np.array([_rgb_unit(s[1]) for s in ordered])
    positions = np.array([s[0] for s in ordered], dtype=np.float32)

    # Start from the first colour, then let each segment overwrite the
    # region at or beyond its own start position.
    gradient = np.broadcast_to(colors[0], (height, width, 3)).copy()
    segments = zip(positions[:-1], positions[1:], colors[:-1], colors[1:])
    for p0, p1, c0, c1 in segments:
        if p1 <= p0:
            # Zero-width segment: nothing to interpolate.
            continue
        seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0)
        reached = (t >= p0)[:, :, None]
        gradient = np.where(reached, _blend(c0, c1, seg_t), gradient)

    return gradient.astype(np.float32, copy=False)


def _composite_strip_onto_frame(
    frame: jnp.ndarray,
    strip_rgb: jnp.ndarray,
    strip_alpha: jnp.ndarray,
    dst_x: jnp.ndarray,
    dst_y: jnp.ndarray,
    sh: int,
    sw: int,
) -> jnp.ndarray:
    """Core compositing: place tinted+alpha strip onto frame using padded buffer.

    The strip is written into a zeroed buffer that is one strip-size larger
    than the frame on every side, and the frame-sized centre is then cut
    out. A strip that is partly or wholly outside the frame therefore needs
    no dynamic-shape clipping.

    Args:
        frame: (H, W, 3) RGB uint8
        strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB
        strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha
        dst_x, dst_y: int32 destination position of the strip's top-left pixel
        sh, sw: strip dimensions (compile-time constants)

    Returns:
        Composited frame (H, W, 3) uint8
    """
    h, w = frame.shape[:2]
    padded_shape = (h + 2 * sh, w + 2 * sw)

    # Shift by the padding. The lower clamp parks a strip that is far
    # off-frame above/left inside the top/left padding.
    place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
    place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
    origin = (place_y, place_x, 0)

    rgb_buf = lax.dynamic_update_slice(
        jnp.zeros(padded_shape + (3,), dtype=jnp.float32), strip_rgb, origin)
    alpha_buf = lax.dynamic_update_slice(
        jnp.zeros(padded_shape + (1,), dtype=jnp.float32), strip_alpha, origin)

    rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
    alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]

    # PIL-compatible integer alpha blending
    src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    dst_int = frame.astype(jnp.int32)
    blended = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255

    return jnp.clip(blended, 0, 255).astype(jnp.uint8)


def place_text_strip_gradient_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    gradient_map: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
) -> jnp.ndarray:
    """Place text strip with gradient coloring instead of solid color.

    Args:
        frame: (H, W, 3) RGB frame
        strip_image: (sh, sw, 4) RGBA text strip
        x, y: Position of the strip's anchor point in the frame
        baseline_y, bearing_x, stroke_width: Accepted but not read by this
            function
        gradient_map: (sh, sw, 3) float32 color map in [0, 1]
        opacity: Multiplier applied to the strip's alpha channel
        anchor_x, anchor_y: Anchor offset inside the strip, in pixels

    Returns:
        Composited frame
    """
    sh, sw = strip_image.shape[:2]

    # Round-half-up of the strip's top-left corner position.
    dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)

    # Alpha is scaled by opacity on the 0-255 integer scale, then normalised.
    opacity_int = jnp.round(opacity * 255)
    raw_alpha = strip_image[:, :, 3:4].astype(jnp.float32)
    strip_alpha = jnp.floor(raw_alpha * opacity_int / 255.0 + 0.5) / 255.0

    # The gradient map takes the place of a solid tint colour.
    strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
    tinted = strip_rgb * gradient_map

    return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw)


# =============================================================================
# Strip Rotation (RGBA bilinear interpolation)
# =============================================================================

def _sample_rgba(strip, x, y):
    """Bilinear sample all 4 RGBA channels from a strip.

    Taps that fall outside the strip contribute 0 (transparent black).

    Args:
        strip: (H, W, 4) RGBA float32
        x, y: coordinate arrays (flattened)

    Returns:
        (r, g, b, a) each same shape as x
    """
    h, w = strip.shape[:2]

    x0 = jnp.floor(x).astype(jnp.int32)
    y0 = jnp.floor(y).astype(jnp.int32)
    x1 = x0 + 1
    y1 = y0 + 1

    # Fractional offsets, with a trailing axis to broadcast over channels.
    fx = (x - x0.astype(jnp.float32))[..., None]
    fy = (y - y0.astype(jnp.float32))[..., None]

    x0_ok = (x0 >= 0) & (x0 < w)
    x1_ok = (x1 >= 0) & (x1 < w)
    y0_ok = (y0 >= 0) & (y0 < h)
    y1_ok = (y1 >= 0) & (y1 < h)

    def tap(yi, xi, ok):
        # One gather fetches the whole RGBA pixel. Indices are clamped so the
        # gather is always in range; invalid taps are zeroed afterwards.
        px = strip[jnp.clip(yi, 0, h - 1), jnp.clip(xi, 0, w - 1), :4]
        return jnp.where(ok[..., None], px, 0.0)

    c00 = tap(y0, x0, x0_ok & y0_ok)
    c10 = tap(y0, x1, x1_ok & y0_ok)
    c01 = tap(y1, x0, x0_ok & y1_ok)
    c11 = tap(y1, x1, x1_ok & y1_ok)

    val = (c00 * (1 - fx) * (1 - fy) +
           c10 * fx * (1 - fy) +
           c01 * (1 - fx) * fy +
           c11 * fx * fy)

    return val[..., 0], val[..., 1], val[..., 2], val[..., 3]


def _snapped_sin_cos(angle):
    """Return (sin, cos) of `angle` degrees, snapped to exactly 0 / +1 / -1 when within 1e-6.

    Without snapping, e.g. sin(360 deg) is about 1.7e-7 instead of 0, which
    makes axis-aligned rotations blend neighbouring pixels.
    """
    theta = angle * jnp.pi / 180.0

    def snap(v):
        v = jnp.where(jnp.abs(v) < 1e-6, 0.0, v)
        v = jnp.where(jnp.abs(v - 1.0) < 1e-6, 1.0, v)
        return jnp.where(jnp.abs(v + 1.0) < 1e-6, -1.0, v)

    return snap(jnp.sin(theta)), snap(jnp.cos(theta))


def rotate_strip_jax(
    strip_image: jnp.ndarray,
    angle: float,
) -> jnp.ndarray:
    """Rotate an RGBA strip by angle (degrees), counter-clockwise.

    Output buffer is sized to contain the full rotated strip: each side is
    ceil(sqrt(w^2 + h^2)), bumped by one where needed to match the parity
    of the corresponding source side. The size is computed at trace time
    from the strip's static shape. Pixels not covered by the rotated strip
    are zero in all four channels.

    Interpolated values are rounded to the nearest integer (half up) before
    the uint8 cast, so fully opaque regions stay at 255 instead of dropping
    to 254 through float error followed by truncation.

    Args:
        strip_image: (H, W, 4) RGBA uint8
        angle: Rotation angle in degrees

    Returns:
        (out_h, out_w, 4) RGBA uint8 - rotated strip
    """
    sh, sw = strip_image.shape[:2]

    # Output size: diagonal of original strip (compile-time constant).
    # Matching the source parity keeps the centring offset (out - src) / 2
    # an integer, so identity rotations do not land on half-pixel offsets.
    diag = int(math.ceil(math.sqrt(sw * sw + sh * sh)))
    out_w = diag + int((diag % 2) != (sw % 2))
    out_h = diag + int((diag % 2) != (sh % 2))

    # Pixel-centre convention: (dim - 1) / 2 maps integer coordinates to
    # integer coordinates under an identity rotation.
    src_cx = (sw - 1) / 2.0
    src_cy = (sh - 1) / 2.0
    dst_cx = (out_w - 1) / 2.0
    dst_cy = (out_h - 1) / 2.0

    sin_t, cos_t = _snapped_sin_cos(angle)

    # Inverse mapping: for every output pixel, find where it comes from.
    rows, cols = jnp.meshgrid(jnp.arange(out_h), jnp.arange(out_w), indexing='ij')
    x_centered = cols.astype(jnp.float32) - dst_cx
    y_centered = rows.astype(jnp.float32) - dst_cy

    src_x = cos_t * x_centered - sin_t * y_centered + src_cx
    src_y = sin_t * x_centered + cos_t * y_centered + src_cy

    channels = _sample_rgba(
        strip_image.astype(jnp.float32), src_x.flatten(), src_y.flatten())
    rgba = jnp.stack(channels, axis=-1).reshape(out_h, out_w, 4)

    # Clamp, then round half up (same floor(x + 0.5) convention as the
    # compositing code) rather than truncating toward zero.
    return jnp.floor(jnp.clip(rgba, 0, 255) + 0.5).astype(jnp.uint8)
============================================================================= +# Shadow Compositing +# ============================================================================= + +def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray: + """Blur a single-channel alpha array using Gaussian convolution. + + Args: + alpha: (H, W) float32 alpha channel + radius: Blur radius (compile-time constant) + + Returns: + (H, W) float32 blurred alpha + """ + size = radius * 2 + 1 + x = jnp.arange(size, dtype=jnp.float32) - radius + sigma = max(radius / 2.0, 0.5) + gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + # Use JAX conv with SAME padding + h, w = alpha.shape + data_4d = alpha.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(size, size, 1, 1) + + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + ) + return result.reshape(h, w) + + +def place_text_strip_shadow_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + shadow_offset_x: float = 3.0, + shadow_offset_y: float = 3.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.5, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Place text strip with a drop shadow. + + Composites the strip twice: first as shadow (offset, colored, optionally blurred), + then the text itself on top. 
+ + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + shadow_offset_x/y: Shadow offset in pixels + shadow_color: (3,) RGB color for shadow (default black) + shadow_opacity: Shadow opacity 0-1 + shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time) + Other args same as place_text_strip_jax + + Returns: + Composited frame + """ + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Shadow pass --- + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + # Shadow alpha from strip alpha + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + # Blur the alpha channel for soft shadow + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1) + + # Shadow RGB: solid shadow color + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Text pass --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0 + + color_norm = color.astype(jnp.float32) / 255.0 + tinted = strip_rgb * color_norm + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, 
dst_y, sh, sw) + + return frame + + +# ============================================================================= +# Combined FX Pipeline +# ============================================================================= + +def place_text_strip_fx_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int = 0, + bearing_x: float = 0.0, + color: jnp.ndarray = None, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + gradient_map: jnp.ndarray = None, + angle: float = 0.0, + shadow_offset_x: float = 0.0, + shadow_offset_y: float = 0.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.0, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Combined text placement with gradient, rotation, and shadow. + + Pipeline order: + 1. Build color layer (solid color or gradient map) + 2. Rotate strip + color layer if angle != 0 + 3. Composite shadow if shadow_opacity > 0 + 4. Composite text + + Note: angle and shadow_blur_radius should be compile-time constants + for optimal JIT performance (they affect buffer shapes/kernel sizes). 
+ + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + x, y: Anchor point position + color: (3,) RGB color (ignored if gradient_map provided) + opacity: Text opacity + gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color + angle: Rotation angle in degrees (0 = no rotation) + shadow_offset_x/y: Shadow offset + shadow_color: (3,) RGB shadow color + shadow_opacity: Shadow opacity (0 = no shadow) + shadow_blur_radius: Shadow blur radius + + Returns: + Composited frame + """ + if color is None: + color = jnp.array([255, 255, 255], dtype=jnp.float32) + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Step 1: Build color layer --- + if gradient_map is not None: + color_layer = gradient_map # (sh, sw, 3) float32 [0, 1] + else: + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3)) + + # --- Step 2: Rotate if needed --- + # angle is expected to be a compile-time constant or static value + # We check at Python level to avoid tracing issues with dynamic shapes + use_rotation = not isinstance(angle, (int, float)) or angle != 0.0 + + if use_rotation: + # Rotate the strip + rotated_strip = rotate_strip_jax(strip_image, angle) + rh, rw = rotated_strip.shape[:2] + + # Rotate the color layer by building a 4-channel color+dummy image + # Actually, just re-create color layer at rotated size + if gradient_map is not None: + # Rotate gradient map: pack into 3-channel "image", rotate via sampling + grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8) + # Create RGBA from gradient (alpha=255 everywhere) + grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2) + rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle) + color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0 + else: + # Solid color: just broadcast to rotated 
size + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3)) + + # Update anchor point for rotation (pixel-center convention) + # Rotate the anchor offset around the strip center + theta = angle * jnp.pi / 180.0 + cos_t = jnp.cos(theta) + sin_t = jnp.sin(theta) + cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t) + cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t) + cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t) + sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t) + + # Original anchor relative to strip pixel center + src_cx = (sw - 1) / 2.0 + src_cy = (sh - 1) / 2.0 + dst_cx = (rw - 1) / 2.0 + dst_cy = (rh - 1) / 2.0 + + ax_rel = anchor_x - src_cx + ay_rel = anchor_y - src_cy + + # Rotate anchor point (forward rotation, not inverse) + new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx + new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy + + strip_image = rotated_strip + anchor_x = new_ax + anchor_y = new_ay + sh, sw = rh, rw + + # --- Step 3: Shadow --- + has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0 + if has_shadow: + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] + + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = 
jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Step 4: Composite text --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + tinted = strip_rgb * color_layer + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw) + + return frame + + # ============================================================================= # S-Expression Primitive Bindings # ============================================================================= @@ -788,6 +1446,119 @@ def bind_typography_primitives(env: dict) -> dict: stroke_width=strip.stroke_width ) + # --- Gradient primitives --- + + def prim_linear_gradient(strip, color1, color2, angle=0.0): + """Create linear gradient color map for a text strip. Compile-time.""" + grad = make_linear_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(angle)) + return jnp.asarray(grad) + + def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5): + """Create radial gradient color map for a text strip. Compile-time.""" + grad = make_radial_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False, + center_x=0.5, center_y=0.5): + """Create multi-stop gradient for a text strip. Compile-time. 
+ + stops: list of (position, (R, G, B)) tuples + """ + parsed_stops = [] + for s in stops: + pos = float(s[0]) + color_tuple = tuple(int(c) for c in s[1]) + parsed_stops.append((pos, color_tuple)) + grad = make_multi_stop_gradient(strip.width, strip.height, + parsed_stops, float(angle), + bool(radial), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0): + """Place text strip with gradient coloring. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + return place_text_strip_gradient_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + gradient_map, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # --- Rotation primitive --- + + def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255), + opacity=1.0, angle=0.0): + """Place text strip with rotation. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + angle=float(angle), + ) + + # --- Shadow primitive --- + + def prim_place_text_strip_shadow(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + shadow_offset_x=3.0, shadow_offset_y=3.0, + shadow_color=(0, 0, 0), shadow_opacity=0.5, + shadow_blur_radius=0): + """Place text strip with shadow. 
Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_shadow_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur_radius), + ) + + # --- Combined FX primitive --- + + def prim_place_text_strip_fx(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + gradient=None, angle=0.0, + shadow_offset_x=0.0, shadow_offset_y=0.0, + shadow_color=(0, 0, 0), shadow_opacity=0.0, + shadow_blur=0): + """Place text strip with all effects. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + gradient_map=gradient, + angle=float(angle), + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur), + ) + # Add to environment env.update({ # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects) @@ -814,6 +1585,17 @@ def bind_typography_primitives(env: dict) -> dict: 'text-strip-anchor-x': prim_text_strip_anchor_x, 'text-strip-anchor-y': prim_text_strip_anchor_y, 'place-text-strip': prim_place_text_strip, + # Gradient primitives + 'linear-gradient': prim_linear_gradient, + 'radial-gradient': prim_radial_gradient, + 
'multi-stop-gradient': prim_multi_stop_gradient, + 'place-text-strip-gradient': prim_place_text_strip_gradient, + # Rotation + 'place-text-strip-rotated': prim_place_text_strip_rotated, + # Shadow + 'place-text-strip-shadow': prim_place_text_strip_shadow, + # Combined FX + 'place-text-strip-fx': prim_place_text_strip_fx, }) return env diff --git a/test_typography_fx.py b/test_typography_fx.py new file mode 100644 index 0000000..e57186e --- /dev/null +++ b/test_typography_fx.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 +""" +Tests for typography FX: gradients, rotation, shadow, and combined effects. +""" + +import numpy as np +import jax +import jax.numpy as jnp +from PIL import Image + +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font, + make_linear_gradient, make_radial_gradient, make_multi_stop_gradient, + place_text_strip_gradient_jax, rotate_strip_jax, + place_text_strip_shadow_jax, place_text_strip_fx_jax, + bind_typography_primitives, +) + + +def make_frame(w=400, h=200): + """Create a dark gray test frame.""" + return jnp.full((h, w, 3), 40, dtype=jnp.uint8) + + +def get_strip(text="Hello", font_size=48): + """Get a pre-rendered text strip.""" + return render_text_strip(text, None, font_size) + + +def has_visible_pixels(frame, threshold=50): + """Check if frame has pixels above threshold.""" + return int(frame.max()) > threshold + + +def save_debug(name, frame): + """Save frame for visual inspection.""" + arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame + Image.fromarray(arr).save(f"/tmp/fx_{name}.png") + + +# ============================================================================= +# Gradient Tests +# ============================================================================= + +def test_linear_gradient_shape(): + grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255)) + assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}" + assert grad.dtype in 
(np.float32, np.float64), f"Expected float, got {grad.dtype}" + # Left edge should be red-ish, right edge blue-ish + assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}" + assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}" + print("PASS: test_linear_gradient_shape") + return True + + +def test_linear_gradient_angle(): + # 90 degrees: top-to-bottom + grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0) + # Top row should be red, bottom row should be blue + assert grad[0, 50, 0] > 0.8, "Top should be red" + assert grad[-1, 50, 2] > 0.8, "Bottom should be blue" + print("PASS: test_linear_gradient_angle") + return True + + +def test_radial_gradient_shape(): + grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128)) + assert grad.shape == (100, 100, 3) + # Center should be yellow (color1) + assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)" + assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)" + # Corner should be closer to dark blue (color2) + assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue" + print("PASS: test_radial_gradient_shape") + return True + + +def test_multi_stop_gradient(): + stops = [ + (0.0, (255, 0, 0)), + (0.5, (0, 255, 0)), + (1.0, (0, 0, 255)), + ] + grad = make_multi_stop_gradient(100, 10, stops) + assert grad.shape == (10, 100, 3) + # Left: red, Middle: green, Right: blue + assert grad[5, 0, 0] > 0.8, "Left should be red" + assert grad[5, 50, 1] > 0.8, "Middle should be green" + assert grad[5, -1, 2] > 0.8, "Right should be blue" + print("PASS: test_multi_stop_gradient") + return True + + +def test_place_gradient(): + """Test gradient text rendering produces visible output.""" + frame = make_frame() + strip = get_strip() + grad = make_linear_gradient(strip.width, strip.height, + (255, 0, 0), (0, 0, 255)) + grad_jax = jnp.asarray(grad) + strip_img = jnp.asarray(strip.image) + + result = place_text_strip_gradient_jax( + 
frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        grad_jax, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )

    assert result.shape == frame.shape
    # Should have visible colored pixels
    diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert diff.max() > 50, "Gradient text should be visible"
    save_debug("gradient", result)
    print("PASS: test_place_gradient")
    return True


# =============================================================================
# Rotation Tests
# =============================================================================

def test_rotate_strip_identity() -> bool:
    """Rotate the default strip by 0 degrees and check its content survives.

    Asserts that the rotated image has 4 channels, is at least as tall and as
    wide as the source strip, and that its alpha channel still peaks above 200.

    Returns:
        True when every assertion holds (a failed assert raises instead).
    """
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)

    rotated = rotate_strip_jax(strip_img, 0.0)
    # The output canvas must not be smaller than the input.
    # NOTE(review): presumably it is padded to the strip's diagonal - confirm
    # against rotate_strip_jax.
    assert rotated.shape[2] == 4, "Should be RGBA"
    assert rotated.shape[0] >= strip.height
    assert rotated.shape[1] >= strip.width

    # Channel index 3 is checked as alpha: a peak above 200 means the text
    # was not wiped out by the resampling.
    assert rotated[:, :, 3].max() > 200, "Should have visible alpha"
    print("PASS: test_rotate_strip_identity")
    return True


def test_rotate_strip_90() -> bool:
    """Rotate the default strip by 90 degrees and check it is still visible.

    Asserts that the result has 4 channels and that its alpha channel peaks
    above 200, then saves the result under the debug name "rotated_90".

    Returns:
        True when every assertion holds (a failed assert raises instead).
    """
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)

    rotated = rotate_strip_jax(strip_img, 90.0)
    assert rotated.shape[2] == 4
    # Should still have visible content
    assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha"
    # Converted to a NumPy array before saving; the other tests in this file
    # hand save_debug the JAX result directly.
    save_debug("rotated_90", np.array(rotated))
    print("PASS: test_rotate_strip_90")
    return True


def test_rotate_360_exact() -> bool:
    """360-degree rotation must be pixel-exact (regression test for trig snapping).

    Crops a window of the original strip size out of the middle of the rotated
    output and requires a maximum absolute difference of 0 against the source.

    Returns:
        True when the crop matches exactly (a failed assert raises instead).
    """
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    sh, sw = strip.height, strip.width

    rotated = rotate_strip_jax(strip_img, 360.0)
    rh, rw = rotated.shape[:2]
    # Offsets of a centred window of the source size inside the rotated canvas.
    # NOTE(review): assumes rotate_strip_jax pads symmetrically - confirm.
    off_y = (rh - sh) // 2
    off_x = (rw - sw) // 2

    crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw])
    orig = np.array(strip_img)
    # Widen to int16 before subtracting so negative differences are representable
    # (presumably the images are uint8 - confirm).
    d = np.abs(crop.astype(np.int16) - orig.astype(np.int16))
    max_diff = int(d.max())
    assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}"
    print("PASS: test_rotate_360_exact")
    return True


def test_place_rotated() -> bool:
    """Place the default strip rotated by 30 degrees and check it is visible.

    Calls place_text_strip_fx_jax with a yellow color, full opacity and
    angle=30.0, then asserts the output keeps the frame's shape and differs
    from the untouched frame by more than 50 somewhere.

    Returns:
        True when every assertion holds (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 0], dtype=jnp.float32)

    result = place_text_strip_fx_jax(
        frame, strip_img, 200.0, 100.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        color=color, opacity=1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        angle=30.0,
    )

    assert result.shape == frame.shape
    # Signed 16-bit difference against the input frame; any value above 50
    # counts as "something was drawn".
    diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert diff.max() > 50, "Rotated text should be visible"
    save_debug("rotated_30", result)
    print("PASS: test_place_rotated")
    return True


# =============================================================================
# Shadow Tests
# =============================================================================

def test_shadow_basic() -> bool:
    """Place white text with a black shadow offset by (5, 5) at opacity 0.8.

    Asserts that the output keeps the frame's shape and contains a value
    above 200.

    Returns:
        True when every assertion holds (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)

    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=5.0, shadow_offset_y=5.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.8,
    )

    assert result.shape == frame.shape
    # NOTE(review): only the bright (text) side is asserted below; nothing in
    # this test checks that the dark shadow pixels were actually drawn.
    assert result.max() > 200, "Should have bright text"
    save_debug("shadow_basic", result)
    print("PASS: test_shadow_basic")
    return True


def test_shadow_blur() -> bool:
    """Place white text with a black shadow using shadow_blur_radius=3.

    The shadow is offset by (4, 4) at opacity 0.7. Only the output shape is
    asserted; the rendered image is saved under the debug name "shadow_blur"
    for manual inspection.

    Returns:
        True when the shape assertion holds (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)

    result = place_text_strip_shadow_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        shadow_offset_x=4.0, shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.7,
        shadow_blur_radius=3,
    )

    # NOTE(review): this is a smoke test - the blur itself is not verified.
    assert result.shape == frame.shape
    save_debug("shadow_blur", result)
    print("PASS: test_shadow_blur")
    return True


# =============================================================================
# Combined FX Tests
# =============================================================================

def test_fx_combined() -> bool:
    """Test combined gradient + shadow + rotation.

    Renders the strip "FX Test" (second get_strip argument 64) onto a frame
    built with make_frame(500, 300), passing a linear gradient map, angle=15.0
    and a blurred black shadow in one place_text_strip_fx_jax call. Asserts
    that the output keeps the frame's shape and differs from the input frame
    by more than 50 somewhere.

    Returns:
        True when every assertion holds (a failed assert raises instead).
    """
    frame = make_frame(500, 300)
    strip = get_strip("FX Test", 64)
    strip_img = jnp.asarray(strip.image)

    # The gradient map is sized to the strip; no angle is passed, so
    # make_linear_gradient uses its default.
    grad = make_linear_gradient(strip.width, strip.height,
                                (255, 100, 0), (0, 100, 255))
    grad_jax = jnp.asarray(grad)

    result = place_text_strip_fx_jax(
        frame, strip_img, 250.0, 150.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        gradient_map=grad_jax,
        angle=15.0,
        shadow_offset_x=4.0, shadow_offset_y=4.0,
        shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32),
        shadow_opacity=0.6,
        shadow_blur_radius=2,
    )

    assert result.shape == frame.shape
    diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16))
    assert diff.max() > 50, "Combined FX should produce visible output"
    save_debug("fx_combined", result)
    print("PASS: test_fx_combined")
    return True


def test_fx_no_effects() -> bool:
    """FX function with no effects should match basic place_text_strip_jax.

    Renders the same strip at the same position once through
    place_text_strip_fx_jax (no gradient, angle or shadow arguments) and once
    through place_text_strip_jax, and requires the two outputs to be identical.

    Returns:
        True when the outputs match exactly (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)

    # Using FX function with defaults
    result_fx = place_text_strip_fx_jax(
        frame, strip_img, 50.0, 100.0,
        baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
        color=color, opacity=1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )

    # Using original function
    result_orig = place_text_strip_jax(
        frame, strip_img, 50.0, 100.0,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
    )

    # A maximum difference of 0 means the two code paths agree on every value.
    diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16))
    max_diff = int(diff.max())
    assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}"
    print("PASS: test_fx_no_effects")
    return True


# =============================================================================
# S-Expression Binding Tests
# =============================================================================

def test_sexp_bindings() -> bool:
    """Test that all new primitives are registered.

    Calls bind_typography_primitives on an empty dict and asserts that each
    gradient and text-placement primitive name listed below is present as a
    key afterwards.

    Returns:
        True when every name is bound (a failed assert raises instead).
    """
    env = {}
    bind_typography_primitives(env)

    expected = [
        'linear-gradient', 'radial-gradient', 'multi-stop-gradient',
        'place-text-strip-gradient', 'place-text-strip-rotated',
        'place-text-strip-shadow', 'place-text-strip-fx',
    ]
    for name in expected:
        assert name in env, f"Missing binding: {name}"

    print("PASS: test_sexp_bindings")
    return True


def test_sexp_gradient_primitive() -> bool:
    """Test gradient primitive via binding.

    Renders a strip through the bound 'render-text-strip' primitive, passes
    that strip to 'linear-gradient', and asserts the gradient has shape
    (strip.height, strip.width, 3).

    Returns:
        True when the shape assertion holds (a failed assert raises instead).
    """
    env = {}
    bind_typography_primitives(env)

    strip = env['render-text-strip']("Test", 36)
    # The bound primitive is given the strip object itself rather than
    # explicit width/height values.
    grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255))

    assert grad.shape == (strip.height, strip.width, 3)
    print("PASS: test_sexp_gradient_primitive")
    return True


def test_sexp_fx_primitive() -> bool:
    """Test combined FX primitive via binding.

    Calls the bound 'place-text-strip-fx' primitive with a strip object, a
    plain RGB tuple for the color and integer shadow offsets, and asserts that
    the output keeps the frame's shape.

    Returns:
        True when the shape assertion holds (a failed assert raises instead).
    """
    env = {}
    bind_typography_primitives(env)

    strip = env['render-text-strip']("FX", 36)
    frame = make_frame()

    result = env['place-text-strip-fx'](
        frame, strip, 100.0, 80.0,
        color=(255, 200, 0), opacity=0.9,
        shadow_offset_x=3, shadow_offset_y=3,
        shadow_opacity=0.5,
    )
    assert result.shape == frame.shape
    print("PASS: test_sexp_fx_primitive")
    return True


# =============================================================================
# JIT Compilation Test
# =============================================================================

def test_jit_fx() -> bool:
    """Test that place_text_strip_fx_jax can be JIT compiled.

    Wraps the call in a jax.jit function taking (frame, x, y, opacity), calls
    it twice with different positions and opacities, and asserts that both
    outputs keep the frame's shape.

    Returns:
        True when both shape assertions hold (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)
    shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)

    # NOTE(review): no static arguments are declared on this jit. Only frame,
    # x, y and opacity are jit arguments; the strip image, colors and the
    # shadow settings (including the blur radius) are closed over or passed as
    # Python literals, and no angle is passed at all.
    @jax.jit
    def render(frame, x, y, opacity):
        return place_text_strip_fx_jax(
            frame, strip_img, x, y,
            baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
            color=color, opacity=opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            shadow_offset_x=3.0, shadow_offset_y=3.0,
            shadow_color=shadow_color,
            shadow_opacity=0.5,
            shadow_blur_radius=2,
        )

    # First call traces; the second passes arguments of the same types and is
    # expected to reuse the compiled function (not asserted here).
    result1 = render(frame, 50.0, 100.0, 1.0)
    result2 = render(frame, 60.0, 90.0, 0.8)

    assert result1.shape == frame.shape
    assert result2.shape == frame.shape
    print("PASS: test_jit_fx")
    return True


def test_jit_gradient() -> bool:
    """Test that gradient placement can be JIT compiled.

    Builds a red-to-blue linear gradient sized to the strip, wraps
    place_text_strip_gradient_jax in a jax.jit function taking (frame, x, y),
    calls it once and asserts that the output keeps the frame's shape.

    Returns:
        True when the shape assertion holds (a failed assert raises instead).
    """
    frame = make_frame()
    strip = get_strip()
    strip_img = jnp.asarray(strip.image)
    grad = jnp.asarray(make_linear_gradient(strip.width, strip.height,
                                            (255, 0, 0), (0, 0, 255)))

    # The strip image and gradient map are closed over; only frame, x and y
    # are jit arguments.
    @jax.jit
    def render(frame, x, y):
        return place_text_strip_gradient_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            grad, 1.0,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        )

    result = render(frame, 50.0, 100.0)
    assert result.shape == frame.shape
    print("PASS: test_jit_gradient")
    return True


# =============================================================================
# Main
# =============================================================================

def main() -> bool:
    """Run every typography FX test in order and print a pass/fail summary.

    Each test is called with no arguments. A test passes when it returns a
    truthy value without raising. An exception is reported with its traceback;
    a falsy return value is reported by name, so neither kind of failure is
    silent. One failing test never stops the remaining tests from running.

    Returns:
        True if every test passed, False otherwise.
    """
    # Imported once here instead of on every failing iteration of the loop.
    import traceback

    print("=" * 60)
    print("Typography FX Tests")
    print("=" * 60)

    tests = [
        # Gradients
        test_linear_gradient_shape,
        test_linear_gradient_angle,
        test_radial_gradient_shape,
        test_multi_stop_gradient,
        test_place_gradient,
        # Rotation
        test_rotate_strip_identity,
        test_rotate_strip_90,
        test_rotate_360_exact,
        test_place_rotated,
        # Shadow
        test_shadow_basic,
        test_shadow_blur,
        # Combined FX
        test_fx_combined,
        test_fx_no_effects,
        # S-expression bindings
        test_sexp_bindings,
        test_sexp_gradient_primitive,
        test_sexp_fx_primitive,
        # JIT compilation
        test_jit_fx,
        test_jit_gradient,
    ]

    results = []
    for test in tests:
        try:
            outcome = test()
        except Exception as e:
            # Deliberately broad: this is the top-level runner boundary, and
            # AssertionError (the normal failure mode) is an Exception too.
            print(f"FAIL: {test.__name__}: {e}")
            traceback.print_exc()
            outcome = False
        else:
            if not outcome:
                # Without this line a test that forgets `return True` would be
                # counted as failed but never named in the output.
                print(f"FAIL: {test.__name__}: returned {outcome!r} "
                      "(expected a truthy result)")
        # Normalise to bool so the tally below counts tests rather than
        # adding up whatever values they happened to return.
        results.append(bool(outcome))

    print("=" * 60)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print("=" * 60)
    return passed == total


if __name__ == "__main__":
    import sys
    # Exit status 0 on full success, 1 otherwise, so CI can gate on it.
    sys.exit(0 if main() else 1)