"""
JAX Typography Primitives

Two approaches for text rendering, both compile to JAX/GPU:

## 1. TextStrip - Pixel-perfect static text

Pre-render entire strings at compile time using PIL.
Perfect sub-pixel anti-aliasing, exact match with PIL.
Use for: static titles, labels, any text without per-character effects.

S-expression:
    (let ((strip (render-text-strip "Hello World" 48)))
      (place-text-strip frame strip x y :color white))

## 2. Glyph-by-glyph - Dynamic text effects

Individual glyph placement for wave, arc, audio-reactive effects.
Each character can have independent position, color, opacity.
Note: slight anti-aliasing differences vs PIL due to integer positioning.

S-expression:
    ; Wave text - y oscillates with character index
    (let ((glyphs (text-glyphs "Wavy" 48)))
      (first (fold glyphs (list frame 0)
        (lambda (acc g)
          (let ((frm (first acc))
                (cursor (second acc))
                (i (length acc)))  ; approximate index
            (list (place-glyph frm (glyph-image g)
                               (+ x cursor)
                               (+ y (* amplitude (sin (* i frequency))))
                               (glyph-bearing-x g) (glyph-bearing-y g)
                               white 1.0)
                  (+ cursor (glyph-advance g))))))))

    ; Audio-reactive spacing
    (let ((glyphs (text-glyphs "Bass" 48))
          (bass (audio-band 0 200)))
      (first (fold glyphs (list frame 0)
        (lambda (acc g)
          (let ((frm (first acc))
                (cursor (second acc)))
            (list (place-glyph frm (glyph-image g)
                               (+ x cursor) y
                               (glyph-bearing-x g) (glyph-bearing-y g)
                               white 1.0)
                  (+ cursor (glyph-advance g) (* bass 20))))))))

Kerning support:
    ; With kerning adjustment
    (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))

Coordinate convention: glyph bearings, and the draw origin used when
rendering strips, follow PIL's default "la" anchor. They are measured from
the left edge and the *ascender line* of the text, not from the baseline.
"""
import math
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from typing import Tuple, Dict, Any, List, Optional
from dataclasses import dataclass


# =============================================================================
# Glyph Data (computed at compile time)
# =============================================================================

# Transparent margin (pixels) drawn around every cached glyph image.
# place_glyph_jax subtracts the same margin when positioning, so the two
# must always agree.
_GLYPH_PADDING = 2

# Characters that get pre-rendered glyphs and kerning pairs: printable ASCII.
_CHARSET = ''.join(chr(i) for i in range(32, 127))


@dataclass
class GlyphData:
    """Glyph data computed at compile time.

    Attributes:
        char: The character.
        image: RGBA image as numpy array (H, W, 4), converted to JAX at
            runtime. The glyph is drawn in white and tinted when composited.
        advance: Horizontal advance (distance to next glyph origin).
        bearing_x: Left edge of the glyph's bbox, relative to the origin.
        bearing_y: Negated top of the glyph's bbox, relative to the origin.
            The origin is PIL's default "la" anchor (the ascender line), so
            this is NOT a baseline-relative bearing.
        width: Image width, including padding on both sides.
        height: Image height, including padding on both sides.
    """
    char: str
    image: np.ndarray  # (H, W, 4) RGBA uint8
    advance: float
    bearing_x: float
    bearing_y: float
    width: int
    height: int


# Font cache: (font_name, font_size) -> {char: GlyphData}
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}

# Font metrics cache: (font_name, font_size) -> (ascent, descent)
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}

# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
# Kerning adjustment is added to advance: new_advance = advance + kerning
# Typically negative (characters move closer together)
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}


def _load_font(font_name: str = None, font_size: int = 32):
    """Load a font. Called at compile time.

    Tries ``font_name`` first, then a few common system sans-serif fonts.
    If none of those can be opened, returns ``ImageFont.load_default()``;
    note that this call does not pass ``font_size``.
    """
    from PIL import ImageFont

    candidates = (
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
    )
    for path in candidates:
        if path is None:
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except OSError:  # missing/unreadable font file: try the next one
            continue
    return ImageFont.load_default()


def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
    """Get or create the glyph cache for a font. Called at compile time.

    Renders every character of ``_CHARSET`` once, in white, into its own
    RGBA image with a ``_GLYPH_PADDING``-pixel transparent margin. As a side
    effect, it records the font's ``(ascent, descent)`` in ``_METRICS_CACHE``.
    """
    cache_key = (font_name, font_size)
    if cache_key in _GLYPH_CACHE:
        return _GLYPH_CACHE[cache_key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()
    _METRICS_CACHE[cache_key] = (ascent, descent)

    # Scratch canvas used only for measuring bounding boxes.
    measure = ImageDraw.Draw(Image.new('RGBA', (200, 200), (0, 0, 0, 0)))
    padding = _GLYPH_PADDING

    glyphs = {}
    for char in _CHARSET:
        # bbox is relative to PIL's default "la" origin (left, ascender line).
        x_min, y_min, x_max, y_max = measure.textbbox((0, 0), char, font=font)
        advance = font.getlength(char)

        img_w = max(int(x_max - x_min) + padding * 2, 1)
        img_h = max(int(y_max - y_min) + padding * 2, 1)
        glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
        # Shift the drawing so the bbox's top-left lands at (padding, padding).
        ImageDraw.Draw(glyph_img).text(
            (padding - x_min, padding - y_min), char,
            fill=(255, 255, 255, 255), font=font,
        )

        glyphs[char] = GlyphData(
            char=char,
            image=np.array(glyph_img, dtype=np.uint8),
            advance=float(advance),
            bearing_x=float(x_min),
            # Negated bbox top, measured from the ascender line (the "la"
            # origin), not from the baseline.
            bearing_y=float(-y_min),
            width=img_w,
            height=img_h,
        )

    _GLYPH_CACHE[cache_key] = glyphs
    return glyphs


def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
    """Get or create the kerning cache for a font. Called at compile time.

    Kerning is computed as:
        kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)

    This gives the adjustment needed when placing 'b' after 'a'. It is
    typically negative (characters move closer together). Only pairs whose
    absolute adjustment exceeds 0.01 px are stored.
    """
    cache_key = (font_name, font_size)
    if cache_key in _KERNING_CACHE:
        return _KERNING_CACHE[cache_key]

    font = _load_font(font_name, font_size)

    # Individual lengths are reused by every pair, so measure each once.
    char_lengths = {c: font.getlength(c) for c in _CHARSET}

    kerning = {}
    for c1 in _CHARSET:
        for c2 in _CHARSET:
            kern = font.getlength(c1 + c2) - (char_lengths[c1] + char_lengths[c2])
            # Only store non-negligible kerning to save memory.
            if abs(kern) > 0.01:
                kerning[(c1, c2)] = kern

    _KERNING_CACHE[cache_key] = kerning
    return kerning


def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
    """Get the kerning adjustment between two characters. Compile-time.

    Returns the adjustment to add to char1's advance when char2 follows.
    It is typically negative (characters move closer). Pairs without a
    stored adjustment return 0.0.

    Usage in S-expression:
        (+ (glyph-advance g1) (kerning g1 g2))
    """
    return _get_kerning_cache(font_name, font_size).get((char1, char2), 0.0)


@dataclass
class TextStrip:
    """Pre-rendered text strip with proper sub-pixel anti-aliasing.

    Rendered at compile time using PIL for exact matching. At runtime it is
    simply composited onto the frame at integer positions.

    Attributes:
        text: The original text.
        image: RGBA image as numpy array (H, W, 4). Text is drawn in white;
            the stroke, if any, is drawn in ``stroke_fill``.
        width: Strip width.
        height: Strip height.
        baseline_y: Row of the (first line's) baseline within the strip.
        bearing_x: Left edge of the text bbox relative to the draw origin.
        anchor_x: X offset of the anchor point within the strip.
        anchor_y: Y offset of the anchor point within the strip.
        stroke_width: Stroke width used when rendering.
    """
    text: str
    image: np.ndarray
    width: int
    height: int
    baseline_y: int
    bearing_x: float
    anchor_x: float = 0.0
    anchor_y: float = 0.0
    stroke_width: int = 0


# Text strip cache: cache_key -> TextStrip
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}


def render_text_strip(
    text: str,
    font_name: str = None,
    font_size: int = 32,
    stroke_width: int = 0,
    stroke_fill: tuple = None,
    anchor: str = "la",  # left-ascender (PIL default is "la")
    multiline: bool = False,
    line_spacing: int = 4,
    align: str = "left",
) -> TextStrip:
    """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.

    Args:
        text: Text to render.
        font_name: Path to font file (None for default).
        font_size: Font size in pixels.
        stroke_width: Outline width in pixels (0 for no outline).
        stroke_fill: Outline color as (R,G,B) or (R,G,B,A); any sequence is
            accepted. Default is opaque black.
        anchor: Two-letter PIL-style anchor code. PIL always draws with its
            default "la" anchor; the requested anchor only sets
            ``anchor_x``/``anchor_y``.
            First char (horizontal):
                l=left (draw origin)
                m=middle (centre of the inked columns)
                r=right (right edge of the text bbox)
            Second char (vertical):
                a=ascender
                t=top of bbox
                m=middle
                s=baseline
                d=descender (using the first line's font metrics)
            Unrecognized codes fall back to 'l' / 'a'.
        multiline: If True, handle newlines in text.
        line_spacing: Extra pixels between lines (for multiline).
        align: 'left', 'center', 'right' (for multiline).

    Returns:
        TextStrip with pre-rendered text.
    """
    # Normalize stroke_fill before building the cache key. The key must be
    # hashable (callers may pass a list), and PIL expects a tuple color.
    if stroke_fill is None:
        stroke_fill = (0, 0, 0, 255)
    else:
        stroke_fill = tuple(stroke_fill)
        if len(stroke_fill) == 3:
            stroke_fill = (*stroke_fill, 255)

    cache_key = (text, font_name, font_size, stroke_width, stroke_fill,
                 anchor, multiline, line_spacing, align)
    if cache_key in _TEXT_STRIP_CACHE:
        return _TEXT_STRIP_CACHE[cache_key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()

    # Measure the text bbox (including stroke) relative to the "la" origin.
    # Measurement uses the same spacing/alignment as the drawing below.
    measure = ImageDraw.Draw(Image.new('RGBA', (1, 1)))
    if multiline:
        bbox = measure.multiline_textbbox((0, 0), text, font=font,
                                          spacing=line_spacing, align=align,
                                          stroke_width=stroke_width)
    else:
        bbox = measure.textbbox((0, 0), text, font=font, stroke_width=stroke_width)
    x_min, y_min, x_max, y_max = bbox

    # Image with padding (extra for stroke)
    padding = 2 + stroke_width
    img_width = max(int(x_max - x_min) + padding * 2, 1)
    img_height = max(int(y_max - y_min) + padding * 2, 1)
    img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Text-space point (u, v) lands at image pixel (u + draw_x, v + draw_y);
    # the bbox's top-left therefore lands at (padding, padding).
    draw_x = padding - x_min
    draw_y = padding - y_min
    if multiline:
        draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255),
                            font=font, spacing=line_spacing, align=align,
                            stroke_width=stroke_width, stroke_fill=stroke_fill)
    else:
        draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255),
                  font=font, stroke_width=stroke_width, stroke_fill=stroke_fill)

    img_array = np.array(img, dtype=np.uint8)

    h_anchor = anchor[0] if len(anchor) > 0 else 'l'
    v_anchor = anchor[1] if len(anchor) > 1 else 'a'

    # Horizontal offset
    if h_anchor == 'm':
        # Use the centre of the actually-inked columns, not the bbox, for
        # pixel-perfect matching. Fully blank strips fall back to the image
        # centre.
        inked_cols = np.flatnonzero(img_array[:, :, 3].max(axis=0) > 0)
        if inked_cols.size:
            anchor_x = float((inked_cols[0] + inked_cols[-1]) / 2.0)
        else:
            anchor_x = img_width / 2.0
    elif h_anchor == 'r':
        # Right edge of the text bbox: text-space x_max maps to draw_x + x_max.
        anchor_x = float(draw_x + x_max)
    else:
        # 'l' and unrecognized codes: the draw origin.
        anchor_x = float(draw_x)

    # Vertical offset. The origin row (draw_y) is the ascender line.
    # getmetrics() gives ascent (ascender -> baseline) and descent
    # (baseline -> descender) distances.
    anchor_y = {
        'a': float(draw_y),                               # ascender line
        't': float(padding),                              # top of bbox
        'm': float(draw_y + (ascent + descent) / 2.0),    # middle of em box
        's': float(draw_y + ascent),                      # baseline
        'd': float(draw_y + ascent + descent),            # descender line
    }.get(v_anchor, float(draw_y))

    strip = TextStrip(
        text=text,
        image=img_array,
        width=img_width,
        height=img_height,
        # Baseline sits `ascent` pixels below the ascender line (draw_y),
        # consistent with the 's' anchor above.
        baseline_y=int(draw_y + ascent),
        bearing_x=float(x_min),
        anchor_x=anchor_x,
        anchor_y=anchor_y,
        stroke_width=stroke_width,
    )
    _TEXT_STRIP_CACHE[cache_key] = strip
    return strip


# =============================================================================
# Compile-time functions (called during S-expression compilation)
# =============================================================================

def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
    """Get glyph data for a single character. Compile-time.

    Characters outside the cached charset fall back to the space glyph.
    """
    cache = _get_glyph_cache(font_name, font_size)
    return cache.get(char, cache.get(' '))


def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
    """Get glyph data for a string. Compile-time.

    Characters outside the cached charset fall back to the space glyph.
    """
    cache = _get_glyph_cache(font_name, font_size)
    space = cache.get(' ')
    return [cache.get(c, space) for c in text]


def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
    """Get font ascent (ascender line to baseline). Compile-time."""
    _get_glyph_cache(font_name, font_size)  # populates _METRICS_CACHE
    return _METRICS_CACHE[(font_name, font_size)][0]


def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
    """Get font descent (baseline to descender line). Compile-time."""
    _get_glyph_cache(font_name, font_size)  # populates _METRICS_CACHE
    return _METRICS_CACHE[(font_name, font_size)][1]


# =============================================================================
# JAX Runtime Primitives
# =============================================================================

def _round_to_pixel(v) -> jnp.ndarray:
    """Round half up to an int32 pixel index: floor(v + 0.5).

    jnp.round would use banker's rounding (half-to-even), which makes .5
    positions alternate direction.
    """
    return jnp.floor(v + 0.5).astype(jnp.int32)


def _tinted_layer(image: jnp.ndarray, color_map, opacity):
    """Split an RGBA uint8 image into a tinted RGB layer and effective alpha.

    Args:
        image: (h, w, 4) RGBA uint8, white ink on transparent background.
        color_map: Color in [0, 1], broadcastable to (h, w, 3). This is
            either a (3,) solid color or an (h, w, 3) gradient map.
        opacity: Overall opacity 0-1 (may be a traced value).

    Returns:
        A pair (tinted, alpha): tinted is (h, w, 3) float32 and alpha is
        (h, w, 1) float32, both in [0, 1].
    """
    rgb = image[:, :, :3].astype(jnp.float32) / 255.0
    # Match PIL's integer alpha math:
    #     (coverage * int(opacity * 255) + 127) // 255
    # jnp.round (banker's rounding) matches Python's round() used by PIL.
    opacity_int = jnp.round(opacity * 255)
    coverage = image[:, :, 3:4].astype(jnp.float32)
    alpha = jnp.floor(coverage * opacity_int / 255.0 + 0.5) / 255.0
    return rgb * color_map, alpha


def _composite_strip_onto_frame(
    frame: jnp.ndarray,
    strip_rgb: jnp.ndarray,
    strip_alpha: jnp.ndarray,
    dst_x: jnp.ndarray,
    dst_y: jnp.ndarray,
    sh: int,
    sw: int,
) -> jnp.ndarray:
    """Core compositing: place a tinted+alpha strip onto a frame.

    XLA's dynamic_update_slice clamps start indices so the update always
    fits, which would silently shift a partially off-screen strip. To avoid
    that, the strip is written into a buffer padded by (sh, sw) on every
    side, and the frame-sized window is sliced back out. Content falling
    outside the frame is then clipped correctly instead of being shifted.

    Args:
        frame: (H, W, 3) RGB uint8.
        strip_rgb: (sh, sw, 3) float32 in [0, 1], the pre-tinted strip RGB.
        strip_alpha: (sh, sw, 1) float32 in [0, 1], the effective alpha.
        dst_x, dst_y: int32 destination of the strip's top-left pixel.
        sh, sw: Strip dimensions (compile-time constants).

    Returns:
        Composited frame (H, W, 3) uint8.
    """
    h, w = frame.shape[:2]
    buf_h = h + 2 * sh
    buf_w = w + 2 * sw
    rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
    alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)

    # Frame pixel (0, 0) sits at buffer (sh, sw).
    place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
    place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
    rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0))
    alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))

    rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
    alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]

    # PIL-compatible integer alpha blending:
    #     result = (src * alpha + dst * (255 - alpha) + 127) // 255
    src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    dst_int = frame.astype(jnp.int32)
    result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
    return jnp.clip(result, 0, 255).astype(jnp.uint8)


def place_glyph_jax(
    frame: jnp.ndarray,
    glyph_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    bearing_x: float,
    bearing_y: float,
    color: jnp.ndarray,  # (3,) RGB 0-255
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Place a glyph onto a frame.

    This is the core JAX primitive. All positioning math can use traced
    values (x, y from audio, time, etc.), as can Python numbers. The
    glyph_image is static (determined at compile time).

    Args:
        frame: (H, W, 3) RGB frame.
        glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array),
            including the ``_GLYPH_PADDING`` margin.
        x: X position of the glyph origin.
        y: Y position of the glyph origin. With glyphs from get_glyph(s),
            this is the ascender line (PIL's default "la" anchor), not the
            baseline.
        bearing_x: Left side bearing.
        bearing_y: Top bearing (negated bbox top relative to the origin).
        color: RGB color array, 0-255.
        opacity: Opacity 0-1.

    Returns:
        Frame with glyph composited.
    """
    gh, gw = glyph_image.shape[:2]

    # Top-left of the padded glyph image in frame coordinates. jnp.floor
    # accepts Python floats as well as traced values. It also floors
    # negative positions correctly, where int truncation would round them
    # toward zero and shift the glyph by a pixel.
    dst_x = jnp.floor(x + bearing_x - _GLYPH_PADDING).astype(jnp.int32)
    dst_y = jnp.floor(y - bearing_y - _GLYPH_PADDING).astype(jnp.int32)

    # The glyph is white, so multiplying by the color tints it.
    tinted, alpha = _tinted_layer(glyph_image, color.astype(jnp.float32) / 255.0, opacity)
    return _composite_strip_onto_frame(frame, tinted, alpha, dst_x, dst_y, gh, gw)


def place_text_strip_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    color: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
) -> jnp.ndarray:
    """
    Place a pre-rendered text strip onto a frame.

    The strip was rendered at compile time with proper sub-pixel
    anti-aliasing. This function only composites it at the specified
    position.

    Args:
        frame: (H, W, 3) RGB frame.
        strip_image: (sh, sw, 4) RGBA text strip.
        x: X position for the anchor point.
        y: Y position for the anchor point.
        baseline_y: Baseline row within the strip. Accepted for interface
            compatibility; not used for positioning.
        bearing_x: Left side bearing. Accepted for interface compatibility;
            not used for positioning.
        color: RGB color, 0-255.
        opacity: Opacity 0-1.
        anchor_x: X offset of the anchor point within the strip.
        anchor_y: Y offset of the anchor point within the strip.
        stroke_width: Stroke width used when rendering. Accepted for
            interface compatibility; not used for positioning.

    Returns:
        Frame with text composited.
    """
    sh, sw = strip_image.shape[:2]
    # The anchor point (anchor_x, anchor_y) of the strip lands at (x, y).
    dst_x = _round_to_pixel(x - anchor_x)
    dst_y = _round_to_pixel(y - anchor_y)
    tinted, alpha = _tinted_layer(strip_image, color.astype(jnp.float32) / 255.0, opacity)
    return _composite_strip_onto_frame(frame, tinted, alpha, dst_x, dst_y, sh, sw)


def place_glyph_simple(
    frame: jnp.ndarray,
    glyph: GlyphData,
    x: float,
    y: float,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Convenience wrapper that takes GlyphData directly.

    Converts the glyph image and color to JAX arrays on every call. For
    S-expression use, prefer place_glyph_jax with pre-converted arrays.
    """
    glyph_jax = jnp.asarray(glyph.image)
    color_jax = jnp.array(color, dtype=jnp.float32)
    return place_glyph_jax(
        frame, glyph_jax, x, y, glyph.bearing_x, glyph.bearing_y, color_jax, opacity
    )


# =============================================================================
# Gradient Functions (compile-time: generate color maps from strip dimensions)
# =============================================================================

def _normalized_grid(width: int, height: int):
    """Return (nx, ny), two (height, width) grids of pixel coords in [0, 1]."""
    ys = np.arange(height, dtype=np.float32)
    xs = np.arange(width, dtype=np.float32)
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    return xx / max(width - 1, 1), yy / max(height - 1, 1)


def _linear_gradient_t(nx, ny, angle: float):
    """Gradient parameter t in [0, 1] along the direction `angle` (degrees)."""
    theta = angle * np.pi / 180.0
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Project (nx - 0.5, ny - 0.5) onto the direction vector.
    proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t
    # The maximum projection is 0.5 * (|cos| + |sin|). Since |cos| + |sin| >= 1,
    # this is always at least 0.5, so it is never zero.
    max_proj = 0.5 * (abs(cos_t) + abs(sin_t))
    return np.clip((proj / max_proj + 1.0) / 2.0, 0.0, 1.0)


def _radial_gradient_t(nx, ny, center_x: float, center_y: float):
    """Gradient parameter t in [0, 1]: distance from the centre, with the
    farthest corner at 1."""
    dx = nx - center_x
    dy = ny - center_y
    max_dist = np.sqrt(max(center_x, 1 - center_x) ** 2 + max(center_y, 1 - center_y) ** 2)
    return np.clip(np.sqrt(dx ** 2 + dy ** 2) / max(max_dist, 1e-6), 0.0, 1.0)


def _lerp_colors(c1: np.ndarray, c2: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Per-pixel linear interpolation from c1 (t=0) to c2 (t=1)."""
    return c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None]


def make_linear_gradient(
    width: int,
    height: int,
    color1: tuple,
    color2: tuple,
    angle: float = 0.0,
) -> np.ndarray:
    """Create a linear gradient color map.

    Args:
        width, height: Dimensions of the gradient (match strip dimensions).
        color1: Start color (R, G, B) 0-255.
        color2: End color (R, G, B) 0-255.
        angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom).

    Returns:
        (height, width, 3) float32 array with values in [0, 1].
    """
    c1 = np.array(color1[:3], dtype=np.float32) / 255.0
    c2 = np.array(color2[:3], dtype=np.float32) / 255.0
    nx, ny = _normalized_grid(width, height)
    t = _linear_gradient_t(nx, ny, angle)
    # Explicit cast: float64 trig scalars can promote the result under NumPy 2.
    return _lerp_colors(c1, c2, t).astype(np.float32)


def make_radial_gradient(
    width: int,
    height: int,
    color1: tuple,
    color2: tuple,
    center_x: float = 0.5,
    center_y: float = 0.5,
) -> np.ndarray:
    """Create a radial gradient color map.

    Args:
        width, height: Dimensions.
        color1: Inner color (R, G, B).
        color2: Outer color (R, G, B), reached at the farthest corner.
        center_x, center_y: Center position in [0, 1] (0.5 = center).

    Returns:
        (height, width, 3) float32 array with values in [0, 1].
    """
    c1 = np.array(color1[:3], dtype=np.float32) / 255.0
    c2 = np.array(color2[:3], dtype=np.float32) / 255.0
    nx, ny = _normalized_grid(width, height)
    t = _radial_gradient_t(nx, ny, center_x, center_y)
    return _lerp_colors(c1, c2, t).astype(np.float32)


def make_multi_stop_gradient(
    width: int,
    height: int,
    stops: list,
    angle: float = 0.0,
    radial: bool = False,
    center_x: float = 0.5,
    center_y: float = 0.5,
) -> np.ndarray:
    """Create a multi-stop gradient color map.

    Args:
        width, height: Dimensions.
        stops: List of (position, (R, G, B)) tuples, with position in [0, 1]
            and any order. Before the first stop the first color is used;
            after the last stop the last color is used.
        angle: Gradient angle in degrees (for linear mode).
        radial: If True, use a radial gradient.
        center_x, center_y: Center for the radial gradient.

    Returns:
        (height, width, 3) float32 array with values in [0, 1]. A single
        stop yields a solid color; no stops yields black.
    """
    if len(stops) < 2:
        if len(stops) == 1:
            c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0
            return np.broadcast_to(c, (height, width, 3)).copy()
        return np.zeros((height, width, 3), dtype=np.float32)

    stops = sorted(stops, key=lambda s: s[0])

    nx, ny = _normalized_grid(width, height)
    if radial:
        t = _radial_gradient_t(nx, ny, center_x, center_y)
    else:
        t = _linear_gradient_t(nx, ny, angle)

    colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops])
    positions = np.array([s[0] for s in stops], dtype=np.float32)

    # Piecewise linear: start from the first color. Each segment then
    # overwrites every pixel at or beyond its start position, so the last
    # segment reached wins, and pixels past the final stop clamp to the
    # last color.
    gradient = np.broadcast_to(colors[0], (height, width, 3)).copy()
    for i in range(len(stops) - 1):
        p0, p1 = positions[i], positions[i + 1]
        if p1 <= p0:
            continue  # zero-length segment (duplicate positions)
        seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0)
        seg_color = _lerp_colors(colors[i], colors[i + 1], seg_t)
        gradient = np.where((t >= p0)[:, :, None], seg_color, gradient)

    return gradient.astype(np.float32)


def place_text_strip_gradient_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    gradient_map: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
) -> jnp.ndarray:
    """Place a text strip with gradient coloring instead of a solid color.

    Args:
        frame: (H, W, 3) RGB frame.
        strip_image: (sh, sw, 4) RGBA text strip.
        gradient_map: (sh, sw, 3) float32 color map in [0, 1].
        Other args are the same as place_text_strip_jax; baseline_y,
        bearing_x and stroke_width are unused.

    Returns:
        Composited frame.
    """
    sh, sw = strip_image.shape[:2]
    dst_x = _round_to_pixel(x - anchor_x)
    dst_y = _round_to_pixel(y - anchor_y)
    tinted, alpha = _tinted_layer(strip_image, gradient_map, opacity)
    return _composite_strip_onto_frame(frame, tinted, alpha, dst_x, dst_y, sh, sw)


# =============================================================================
# Strip Rotation (RGBA bilinear interpolation)
# =============================================================================

def _sample_rgba(strip, x, y):
    """Bilinear sample all 4 RGBA channels from a strip.

    Taps outside the strip contribute zero (transparent black).

    Args:
        strip: (H, W, 4) RGBA float32.
        x, y: Coordinate arrays of the same shape (callers pass flattened).

    Returns:
        (r, g, b, a), each the same shape as x.
    """
    h, w = strip.shape[:2]
    x0 = jnp.floor(x).astype(jnp.int32)
    y0 = jnp.floor(y).astype(jnp.int32)
    x1 = x0 + 1
    y1 = y0 + 1
    fx = (x - x0.astype(jnp.float32))[..., None]
    fy = (y - y0.astype(jnp.float32))[..., None]

    def tap(xi, yi):
        # Gather all 4 channels at once; out-of-range taps are zeroed.
        valid = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
        texel = strip[jnp.clip(yi, 0, h - 1), jnp.clip(xi, 0, w - 1)]
        return jnp.where(valid[..., None], texel, 0.0)

    val = (tap(x0, y0) * (1 - fx) * (1 - fy) +
           tap(x1, y0) * fx * (1 - fy) +
           tap(x0, y1) * (1 - fx) * fy +
           tap(x1, y1) * fx * fy)
    return val[..., 0], val[..., 1], val[..., 2], val[..., 3]


def _snap_unit_trig(v: jnp.ndarray) -> jnp.ndarray:
    """Snap a sin/cos value within 1e-6 of 0, 1 or -1 to that exact value.

    Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0. That causes
    bilinear blending at pixel edges and off-by-one pixel values on what
    should be exact quarter-turn rotations.
    """
    for target in (0.0, 1.0, -1.0):
        v = jnp.where(jnp.abs(v - target) < 1e-6, target, v)
    return v


def rotate_strip_jax(
    strip_image: jnp.ndarray,
    angle: float,
) -> jnp.ndarray:
    """Rotate an RGBA strip by angle (degrees), counter-clockwise on screen.

    The output buffer is sized to contain the full rotated strip. The
    output size is ceil(sqrt(w^2 + h^2)), adjusted to match the source
    parity, and is computed at trace time from the strip's static shape.

    Args:
        strip_image: (H, W, 4) RGBA uint8.
        angle: Rotation angle in degrees (may be traced).

    Returns:
        (out_h, out_w, 4) RGBA uint8 rotated strip.
    """
    sh, sw = strip_image.shape[:2]

    # Output size: the diagonal of the original strip (compile-time
    # constant). Output dimensions get the same parity as the source, so
    # the center offset (out - src) / 2 is always an integer. Otherwise an
    # identity rotation would place content at half-pixel offsets.
    diag = int(math.ceil(math.sqrt(sw * sw + sh * sh)))
    out_w = diag + ((diag % 2) != (sw % 2))
    out_h = diag + ((diag % 2) != (sh % 2))

    # Centers use the pixel-center convention (dim - 1) / 2, so integer
    # coords map to integer coords for the identity rotation.
    src_cx = (sw - 1) / 2.0
    src_cy = (sh - 1) / 2.0
    dst_cx = (out_w - 1) / 2.0
    dst_cy = (out_h - 1) / 2.0

    theta = angle * jnp.pi / 180.0
    cos_t = _snap_unit_trig(jnp.cos(theta))
    sin_t = _snap_unit_trig(jnp.sin(theta))

    # Inverse mapping: for each output pixel, find its source coordinate.
    y_coords, x_coords = jnp.indices((out_h, out_w))
    x_centered = x_coords.astype(jnp.float32) - dst_cx
    y_centered = y_coords.astype(jnp.float32) - dst_cy
    src_x = cos_t * x_centered - sin_t * y_centered + src_cx
    src_y = sin_t * x_centered + cos_t * y_centered + src_cy

    strip_f = strip_image.astype(jnp.float32)
    r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten())
    rgba = jnp.stack([r, g, b, a], axis=-1).reshape(out_h, out_w, 4)
    return jnp.clip(rgba, 0, 255).astype(jnp.uint8)


# =============================================================================
# Shadow Compositing
# =============================================================================
def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray:
    """Blur a single-channel alpha array with a normalized Gaussian kernel.

    Args:
        alpha: (H, W) float32 alpha channel.
        radius: Blur radius in pixels (a Python int, i.e. a compile-time
            constant). The kernel is (2*radius + 1) square, with
            sigma = max(radius / 2, 0.5).

    Returns:
        (H, W) float32 blurred alpha. 'SAME' convolution treats everything
        outside the input as zero, so the output has the input's shape, and
        any blur spilling past the edges is dropped. Pad the input first if
        those tails must be kept.
    """
    size = radius * 2 + 1
    taps = jnp.arange(size, dtype=jnp.float32) - radius
    sigma = max(radius / 2.0, 0.5)
    gaussian_1d = jnp.exp(-taps ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    kernel = jnp.outer(gaussian_1d, gaussian_1d)

    h, w = alpha.shape
    result = lax.conv_general_dilated(
        alpha.reshape(1, h, w, 1),
        kernel.reshape(size, size, 1, 1),
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )
    return result.reshape(h, w)


def place_text_strip_shadow_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    color: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
    shadow_offset_x: float = 3.0,
    shadow_offset_y: float = 3.0,
    shadow_color: jnp.ndarray = None,
    shadow_opacity: float = 0.5,
    shadow_blur_radius: int = 0,
) -> jnp.ndarray:
    """Place a text strip with a drop shadow.

    The strip is composited twice. The first pass is the shadow (offset,
    colored, optionally blurred). The second pass is the text itself, on
    top.

    Args:
        frame: (H, W, 3) RGB frame.
        strip_image: (sh, sw, 4) RGBA text strip.
        shadow_offset_x/y: Shadow offset in pixels.
        shadow_color: (3,) RGB color for the shadow (default black).
        shadow_opacity: Shadow opacity 0-1.
        shadow_blur_radius: Gaussian blur radius for the shadow (0 = sharp;
            a Python int, i.e. compile-time). A blurred shadow is rendered
            into a layer enlarged by this radius on every side, so the soft
            edge is not clipped to the strip's own bounds.
        Other args are the same as place_text_strip_jax; baseline_y,
        bearing_x and stroke_width are unused.

    Returns:
        Composited frame.
    """
    if shadow_color is None:
        shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)

    sh, sw = strip_image.shape[:2]
    strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)

    # --- Shadow pass ---
    shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
    shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)
    shadow_opacity_int = jnp.round(shadow_opacity * 255)

    if shadow_blur_radius > 0:
        r = shadow_blur_radius
        # The strip has only a few pixels of transparent margin, so blurring
        # it in place would cut the shadow's soft edge off at the strip
        # boundary. Zero-pad by the blur radius first. Interior values are
        # unchanged, since 'SAME' convolution already treats the outside as
        # zero. Then shift the enlarged layer up-left by r so it still lines
        # up with the text.
        padded_alpha = jnp.pad(strip_a_raw / 255.0, r)
        blurred_alpha = _blur_alpha_channel(padded_alpha, r)
        shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
        shadow_dst_x = shadow_dst_x - r
        shadow_dst_y = shadow_dst_y - r
    else:
        shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
    shadow_alpha = shadow_alpha[:, :, None]  # (shadow_h, shadow_w, 1)
    shadow_h, shadow_w = shadow_alpha.shape[:2]

    # Shadow RGB: solid shadow color over the whole shadow layer.
    shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
    shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (shadow_h, shadow_w, 3))
    frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha,
                                        shadow_dst_x, shadow_dst_y, shadow_h, shadow_w)

    # --- Text pass ---
    dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
    strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
    # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
    opacity_int = jnp.round(opacity * 255)
    text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0
    color_norm = color.astype(jnp.float32) / 255.0
    tinted = strip_rgb * color_norm
    frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)

    return frame
# =============================================================================
# Combined FX Pipeline
# =============================================================================

def _rotated_anchor(anchor_x, anchor_y, angle, src_h, src_w, dst_h, dst_w):
    """Map an anchor point from an unrotated strip into its rotated buffer.

    ``rotate_strip_jax`` fills each output pixel by *inverse* mapping it back
    into the source strip, using pixel-center coordinates::

        src = M @ (out - dst_center) + src_center,   M = [[c, -s], [s, c]]

    A source point therefore lands at
    ``out = M.T @ (src - src_center) + dst_center``. This helper applies
    that forward mapping (the transpose of the sampling matrix), with the
    same snapping of near-0/±1 trig values, so the anchor tracks the pixel
    the rotated glyph data actually occupies.

    Args:
        anchor_x, anchor_y: Anchor in source-strip pixel coordinates.
        angle: Rotation angle in degrees (Python number or traced scalar).
        src_h, src_w: Shape of the unrotated strip.
        dst_h, dst_w: Shape of the rotated buffer.

    Returns:
        ``(new_anchor_x, new_anchor_y)`` in rotated-buffer pixel coordinates.
    """
    theta = angle * jnp.pi / 180.0
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    # Snap trig values near 0/±1 exactly (same thresholds as
    # rotate_strip_jax), so right-angle rotations give integer anchors.
    for exact in (0.0, 1.0, -1.0):
        cos_t = jnp.where(jnp.abs(cos_t - exact) < 1e-6, exact, cos_t)
        sin_t = jnp.where(jnp.abs(sin_t - exact) < 1e-6, exact, sin_t)

    src_cx = (src_w - 1) / 2.0
    src_cy = (src_h - 1) / 2.0
    dst_cx = (dst_w - 1) / 2.0
    dst_cy = (dst_h - 1) / 2.0
    ax_rel = anchor_x - src_cx
    ay_rel = anchor_y - src_cy

    # Forward rotation = transpose of the inverse sampling matrix.
    new_ax = cos_t * ax_rel + sin_t * ay_rel + dst_cx
    new_ay = -sin_t * ax_rel + cos_t * ay_rel + dst_cy
    return new_ax, new_ay


def place_text_strip_fx_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,
    x: float,
    y: float,
    baseline_y: int = 0,
    bearing_x: float = 0.0,
    color: jnp.ndarray = None,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
    gradient_map: jnp.ndarray = None,
    angle: float = 0.0,
    shadow_offset_x: float = 0.0,
    shadow_offset_y: float = 0.0,
    shadow_color: jnp.ndarray = None,
    shadow_opacity: float = 0.0,
    shadow_blur_radius: int = 0,
) -> jnp.ndarray:
    """Combined text placement with gradient, rotation, and shadow.

    Pipeline order:
        1. Build the color layer (solid color, or the gradient map).
        2. Rotate the strip and color layer if ``angle`` is not 0. The anchor
           is re-mapped so that it stays on the same strip pixel, which makes
           it the rotation pivot at ``(x, y)``.
        3. Composite the shadow if ``shadow_opacity > 0``.
        4. Composite the text.

    The branches in steps 2 and 3 are decided in Python. A Python number 0 or
    non-positive ``shadow_opacity`` skips the step; a traced value always
    runs it. ``shadow_blur_radius`` must be a Python int, because it sets the
    blur kernel size.

    NOTE: the shadow mask is strip-sized (the rotated buffer when rotating),
    so a blurred shadow is cut off at that buffer's edges.

    ``baseline_y``, ``bearing_x`` and ``stroke_width`` are accepted for
    signature parity with the other ``place_text_strip*`` functions; this
    function does not read them.

    Args:
        frame: (H, W, 3) RGB frame.
        strip_image: (sh, sw, 4) RGBA text strip.
        x, y: Position of the anchor point in the frame.
        color: (3,) RGB color, 0-255 (white when None; ignored when
            ``gradient_map`` is given).
        opacity: Text opacity, 0-1.
        anchor_x, anchor_y: Anchor offset inside the (unrotated) strip.
        gradient_map: (sh, sw, 3) float32 color map in [0, 1], or None for a
            solid color.
        angle: Rotation angle in degrees (0 = no rotation).
        shadow_offset_x, shadow_offset_y: Shadow offset in pixels.
        shadow_color: (3,) RGB shadow color, 0-255 (black when None).
        shadow_opacity: Shadow opacity (0 = no shadow).
        shadow_blur_radius: Shadow blur radius (0 = hard shadow).

    Returns:
        The composited frame.
    """
    if color is None:
        color = jnp.array([255, 255, 255], dtype=jnp.float32)
    if shadow_color is None:
        shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32)

    sh, sw = strip_image.shape[:2]

    # --- Step 1: Build color layer ---
    if gradient_map is not None:
        color_layer = gradient_map  # (sh, sw, 3) float32 in [0, 1]
    else:
        color_norm = color.astype(jnp.float32) / 255.0
        color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3))

    # --- Step 2: Rotate if needed ---
    # Decided in Python: a literal 0 skips rotation entirely. A traced angle
    # always rotates; the rotated buffer shape depends only on (sh, sw).
    use_rotation = not isinstance(angle, (int, float)) or angle != 0.0
    if use_rotation:
        rotated_strip = rotate_strip_jax(strip_image, angle)
        rh, rw = rotated_strip.shape[:2]

        if gradient_map is not None:
            # Pack the gradient into an opaque RGBA image so it can go through
            # the same sampler as the strip (quantized to uint8 on the way).
            grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8)
            grad_rgba = jnp.concatenate(
                [grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2
            )
            rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle)
            color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0
        else:
            # A solid color is rotation-invariant: broadcast to the new size.
            color_norm = color.astype(jnp.float32) / 255.0
            color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3))

        # Keep the anchor on the same strip pixel inside the rotated buffer.
        anchor_x, anchor_y = _rotated_anchor(anchor_x, anchor_y, angle, sh, sw, rh, rw)
        strip_image = rotated_strip
        sh, sw = rh, rw

    # --- Step 3: Shadow ---
    has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0
    if has_shadow:
        shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32)
        shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32)

        shadow_opacity_int = jnp.round(shadow_opacity * 255)
        strip_a_raw = strip_image[:, :, 3].astype(jnp.float32)
        if shadow_blur_radius > 0:
            blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius)
            shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0
        else:
            shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0
        shadow_alpha = shadow_alpha[:, :, None]

        shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0
        shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3))
        frame = _composite_strip_onto_frame(
            frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw
        )

    # --- Step 4: Composite text ---
    dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)

    strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
    opacity_int = jnp.round(opacity * 255)
    strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
    text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0

    tinted = strip_rgb * color_layer
    frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw)

    return frame


# =============================================================================
# S-Expression Primitive Bindings
# =============================================================================

def bind_typography_primitives(env: dict) -> dict:
    """
    Add typography primitives to an S-expression environment.

    ``env`` is updated in place and also returned.

    Glyph-by-glyph primitives (wave, arc, audio-reactive effects):
        (text-glyphs text font-size [font-name]) -> list of glyph data
        (glyph-image g) -> JAX array (H, W, 4)
        (glyph-advance g) -> float
        (glyph-bearing-x g) -> float
        (glyph-bearing-y g) -> float
        (glyph-width g) -> int
        (glyph-height g) -> int
        (glyph-kerning g1 g2 font-size [font-name]) -> float
        (char-kerning c1 c2 font-size [font-name]) -> float
        (font-ascent font-size [font-name]) -> float
        (font-descent font-size [font-name]) -> float
        (place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame

    TextStrip primitives (pixel-perfect static text):
        (render-text-strip text font-size [font-name]) -> strip
        (render-text-strip-styled text font-size ...) -> strip
        (text-strip-image s), (text-strip-width s), (text-strip-height s),
        (text-strip-baseline-y s), (text-strip-bearing-x s),
        (text-strip-anchor-x s), (text-strip-anchor-y s)
        (place-text-strip frame strip x y color opacity) -> frame

    Gradients, rotation, shadow and combined effects:
        (linear-gradient strip c1 c2 angle) -> color map
        (radial-gradient strip c1 c2 cx cy) -> color map
        (multi-stop-gradient strip stops angle radial cx cy) -> color map
        (place-text-strip-gradient frame strip x y gradient opacity) -> frame
        (place-text-strip-rotated frame strip x y color opacity angle) -> frame
        (place-text-strip-shadow frame strip x y color opacity ...) -> frame
        (place-text-strip-fx frame strip x y color opacity gradient angle ...) -> frame
    """

    def prim_text_glyphs(text, font_size=32, font_name=None):
        """Get list of glyph data for text. Compile-time."""
        return get_glyphs(str(text), font_name, int(font_size))

    def prim_glyph_image(glyph):
        """Get glyph image as JAX array."""
        return jnp.asarray(glyph.image)

    def prim_glyph_advance(glyph):
        """Get glyph advance width."""
        return glyph.advance

    def prim_glyph_bearing_x(glyph):
        """Get glyph left side bearing."""
        return glyph.bearing_x

    def prim_glyph_bearing_y(glyph):
        """Get glyph top bearing."""
        return glyph.bearing_y

    def prim_glyph_width(glyph):
        """Get glyph image width."""
        return glyph.width

    def prim_glyph_height(glyph):
        """Get glyph image height."""
        return glyph.height

    def prim_font_ascent(font_size=32, font_name=None):
        """Get font ascent."""
        return get_font_ascent(font_name, int(font_size))

    def prim_font_descent(font_size=32, font_name=None):
        """Get font descent."""
        return get_font_descent(font_name, int(font_size))

    def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
                         color=(255, 255, 255), opacity=1.0):
        """Place glyph on frame. Runtime JAX operation."""
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y,
                               color_arr, opacity)

    def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
        """Get kerning adjustment between two glyphs. Compile-time.

        Returns adjustment to add to glyph1's advance when glyph2 follows.
        Typically negative (characters move closer).

        Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
        """
        return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))

    def prim_char_kerning(char1, char2, font_size=32, font_name=None):
        """Get kerning adjustment between two characters. Compile-time."""
        return get_kerning(str(char1), str(char2), font_name, int(font_size))

    # TextStrip primitives for pre-rendered text with proper anti-aliasing
    def prim_render_text_strip(text, font_size=32, font_name=None):
        """Render text to a strip at compile time. Perfect anti-aliasing."""
        return render_text_strip(str(text), font_name, int(font_size))

    def prim_render_text_strip_styled(
        text, font_size=32, font_name=None,
        stroke_width=0, stroke_fill=None,
        anchor="la", multiline=False, line_spacing=4, align="left"
    ):
        """Render styled text to a strip. Supports stroke, anchors, multiline.

        Args:
            text: Text to render
            font_size: Size in pixels
            font_name: Path to font file
            stroke_width: Outline width (0 = no outline)
            stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
            anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
            multiline: If True, handle newlines
            line_spacing: Extra pixels between lines
            align: "left", "center", "right" for multiline
        """
        return render_text_strip(
            str(text), font_name, int(font_size),
            stroke_width=int(stroke_width),
            stroke_fill=stroke_fill,
            anchor=str(anchor),
            multiline=bool(multiline),
            line_spacing=int(line_spacing),
            align=str(align),
        )

    def prim_text_strip_image(strip):
        """Get text strip image as JAX array."""
        return jnp.asarray(strip.image)

    def prim_text_strip_width(strip):
        """Get text strip width."""
        return strip.width

    def prim_text_strip_height(strip):
        """Get text strip height."""
        return strip.height

    def prim_text_strip_baseline_y(strip):
        """Get text strip baseline Y position."""
        return strip.baseline_y

    def prim_text_strip_bearing_x(strip):
        """Get text strip left bearing."""
        return strip.bearing_x

    def prim_text_strip_anchor_x(strip):
        """Get text strip anchor X offset."""
        return strip.anchor_x

    def prim_text_strip_anchor_y(strip):
        """Get text strip anchor Y offset."""
        return strip.anchor_y

    def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
        """Place pre-rendered text strip on frame. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_text_strip_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            color_arr, opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width
        )

    # --- Gradient primitives ---

    def prim_linear_gradient(strip, color1, color2, angle=0.0):
        """Create linear gradient color map for a text strip. Compile-time."""
        grad = make_linear_gradient(strip.width, strip.height,
                                    tuple(int(c) for c in color1),
                                    tuple(int(c) for c in color2),
                                    float(angle))
        return jnp.asarray(grad)

    def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5):
        """Create radial gradient color map for a text strip. Compile-time."""
        grad = make_radial_gradient(strip.width, strip.height,
                                    tuple(int(c) for c in color1),
                                    tuple(int(c) for c in color2),
                                    float(center_x), float(center_y))
        return jnp.asarray(grad)

    def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False,
                                 center_x=0.5, center_y=0.5):
        """Create multi-stop gradient for a text strip. Compile-time.

        stops: list of (position, (R, G, B)) tuples
        """
        parsed_stops = [
            (float(s[0]), tuple(int(c) for c in s[1]))
            for s in stops
        ]
        grad = make_multi_stop_gradient(strip.width, strip.height,
                                        parsed_stops, float(angle),
                                        bool(radial),
                                        float(center_x), float(center_y))
        return jnp.asarray(grad)

    def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0):
        """Place text strip with gradient coloring. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        return place_text_strip_gradient_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            gradient_map, opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width
        )

    # --- Rotation primitive ---

    def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255),
                                      opacity=1.0, angle=0.0):
        """Place text strip with rotation. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_text_strip_fx_jax(
            frame, strip_img, x, y,
            baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
            color=color_arr, opacity=opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width,
            angle=float(angle),
        )

    # --- Shadow primitive ---

    def prim_place_text_strip_shadow(frame, strip, x, y,
                                     color=(255, 255, 255), opacity=1.0,
                                     shadow_offset_x=3.0, shadow_offset_y=3.0,
                                     shadow_color=(0, 0, 0), shadow_opacity=0.5,
                                     shadow_blur_radius=0):
        """Place text strip with shadow. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
        return place_text_strip_shadow_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            color_arr, opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width,
            shadow_offset_x=float(shadow_offset_x),
            shadow_offset_y=float(shadow_offset_y),
            shadow_color=shadow_color_arr,
            shadow_opacity=float(shadow_opacity),
            shadow_blur_radius=int(shadow_blur_radius),
        )

    # --- Combined FX primitive ---

    def prim_place_text_strip_fx(frame, strip, x, y,
                                 color=(255, 255, 255), opacity=1.0,
                                 gradient=None, angle=0.0,
                                 shadow_offset_x=0.0, shadow_offset_y=0.0,
                                 shadow_color=(0, 0, 0), shadow_opacity=0.0,
                                 shadow_blur=0):
        """Place text strip with all effects. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32)
        return place_text_strip_fx_jax(
            frame, strip_img, x, y,
            baseline_y=strip.baseline_y, bearing_x=strip.bearing_x,
            color=color_arr, opacity=opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width,
            gradient_map=gradient,
            angle=float(angle),
            shadow_offset_x=float(shadow_offset_x),
            shadow_offset_y=float(shadow_offset_y),
            shadow_color=shadow_color_arr,
            shadow_opacity=float(shadow_opacity),
            shadow_blur_radius=int(shadow_blur),
        )

    # Add to environment
    env.update({
        # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
        'text-glyphs': prim_text_glyphs,
        'glyph-image': prim_glyph_image,
        'glyph-advance': prim_glyph_advance,
        'glyph-bearing-x': prim_glyph_bearing_x,
        'glyph-bearing-y': prim_glyph_bearing_y,
        'glyph-width': prim_glyph_width,
        'glyph-height': prim_glyph_height,
        'glyph-kerning': prim_glyph_kerning,
        'char-kerning': prim_char_kerning,
        'font-ascent': prim_font_ascent,
        'font-descent': prim_font_descent,
        'place-glyph': prim_place_glyph,
        # TextStrip primitives (for pixel-perfect static text)
        'render-text-strip': prim_render_text_strip,
        'render-text-strip-styled': prim_render_text_strip_styled,
        'text-strip-image': prim_text_strip_image,
        'text-strip-width': prim_text_strip_width,
        'text-strip-height': prim_text_strip_height,
        'text-strip-baseline-y': prim_text_strip_baseline_y,
        'text-strip-bearing-x': prim_text_strip_bearing_x,
        'text-strip-anchor-x': prim_text_strip_anchor_x,
        'text-strip-anchor-y': prim_text_strip_anchor_y,
        'place-text-strip': prim_place_text_strip,
        # Gradient primitives
        'linear-gradient': prim_linear_gradient,
        'radial-gradient': prim_radial_gradient,
        'multi-stop-gradient': prim_multi_stop_gradient,
        'place-text-strip-gradient': prim_place_text_strip_gradient,
        # Rotation
        'place-text-strip-rotated': prim_place_text_strip_rotated,
        # Shadow
        'place-text-strip-shadow': prim_place_text_strip_shadow,
        # Combined FX
        'place-text-strip-fx': prim_place_text_strip_fx,
    })

    return env


# =============================================================================
# Example: Render text using primitives (for testing)
# =============================================================================

def render_text_primitives(
    frame: jnp.ndarray,
    text: str,
    x: float,
    y: float,
    font_size: int = 32,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
    use_kerning: bool = True,
) -> jnp.ndarray:
    """
    Render text glyph-by-glyph using the primitives.

    This is what an S-expression would compile to. Glyphs come from the
    default font (``font_name=None``). Each glyph is placed at the current
    cursor, and then the cursor advances by the glyph's advance, plus the
    kerning with the following glyph when requested.

    Args:
        frame: (H, W, 3) RGB frame.
        text: Text to render.
        x, y: Pen origin of the first glyph.
        font_size: Font size in pixels.
        color: RGB text color, 0-255.
        opacity: Text opacity, 0-1.
        use_kerning: If True, apply kerning adjustments between characters.

    Returns:
        The composited frame.
    """
    glyphs = get_glyphs(text, None, font_size)
    color_arr = jnp.array(color, dtype=jnp.float32)

    cursor = x
    for i, g in enumerate(glyphs):
        glyph_jax = jnp.asarray(g.image)
        frame = place_glyph_jax(
            frame, glyph_jax,
            cursor, y,
            g.bearing_x, g.bearing_y,
            color_arr, opacity
        )
        # Advance cursor with optional kerning against the next glyph.
        advance = g.advance
        if use_kerning and i + 1 < len(glyphs):
            advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size)
        cursor = cursor + advance

    return frame