Files
celery/streaming/jax_typography.py
gilesb fc9597456f
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:41:19 +00:00

861 lines
30 KiB
Python

"""
JAX Typography Primitives

Two approaches for text rendering, both compile to JAX/GPU:

## 1. TextStrip - Pixel-perfect static text

Pre-render entire strings at compile time using PIL.
Perfect sub-pixel anti-aliasing, exact match with PIL.
Use for: static titles, labels, any text without per-character effects.

S-expression:

    (let ((strip (render-text-strip "Hello World" 48)))
      (place-text-strip frame strip x y :color white))

## 2. Glyph-by-glyph - Dynamic text effects

Individual glyph placement for wave, arc, audio-reactive effects.
Each character can have independent position, color, opacity.
Note: slight anti-aliasing differences vs PIL due to integer positioning.
Note: glyph bearings are measured with PIL's default "la" (left-ascender)
text anchor, so the y passed to place-glyph is the top (ascender line) of
the text line, not its baseline.

S-expression:

    ; Wave text - y oscillates with character index
    (let ((glyphs (text-glyphs "Wavy" 48)))
      (first
        (fold glyphs (list frame 0)
          (lambda (acc g)
            (let ((frm (first acc))
                  (cursor (second acc))
                  (i (length acc)))  ; NOTE: acc is always a 2-item list, so
                                     ; i is constant (2) here, not a real
                                     ; index; thread an explicit counter
                                     ; through acc for a per-character phase
              (list
                (place-glyph frm (glyph-image g)
                             (+ x cursor)
                             (+ y (* amplitude (sin (* i frequency))))
                             (glyph-bearing-x g) (glyph-bearing-y g)
                             white 1.0)
                (+ cursor (glyph-advance g))))))))

    ; Audio-reactive spacing
    (let ((glyphs (text-glyphs "Bass" 48))
          (bass (audio-band 0 200)))
      (first
        (fold glyphs (list frame 0)
          (lambda (acc g)
            (let ((frm (first acc))
                  (cursor (second acc)))
              (list
                (place-glyph frm (glyph-image g)
                             (+ x cursor) y
                             (glyph-bearing-x g) (glyph-bearing-y g)
                             white 1.0)
                (+ cursor (glyph-advance g) (* bass 20))))))))

Kerning support:

    ; With kerning adjustment
    (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
"""
import numpy as np
import jax.numpy as jnp
from typing import Tuple, Dict, Any
from dataclasses import dataclass
# =============================================================================
# Glyph Data (computed at compile time)
# =============================================================================
@dataclass
class GlyphData:
"""Glyph data computed at compile time.
Attributes:
char: The character
image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
advance: Horizontal advance (distance to next glyph origin)
bearing_x: Left side bearing (x offset from origin to first pixel)
bearing_y: Top bearing (y offset from baseline to top of glyph)
width: Image width
height: Image height
"""
char: str
image: np.ndarray # (H, W, 4) RGBA uint8
advance: float
bearing_x: float
bearing_y: float
width: int
height: int
# Font cache: (font_name, font_size) -> {char: GlyphData}
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}
# Font metrics cache: (font_name, font_size) -> (ascent, descent)
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}
# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
# Kerning adjustment is added to advance: new_advance = advance + kerning
# Typically negative (characters move closer together)
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}
def _load_font(font_name: str = None, font_size: int = 32):
    """Load a font at ``font_size`` pixels. Called at compile time.

    Tries ``font_name`` (if given), then a few common system fonts. If none
    can be opened, falls back to Pillow's built-in default font, requesting
    ``font_size`` where the installed Pillow supports it (Pillow >= 10.1);
    older Pillow versions only offer a fixed-size bitmap default.

    Args:
        font_name: Path to a TrueType/OpenType font file, or None.
        font_size: Requested size in pixels.

    Returns:
        A PIL font object.
    """
    from PIL import ImageFont

    candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
    ]
    for path in candidates:
        if path is None:
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except OSError:  # IOError is an alias of OSError
            continue

    # Last resort. Without ``size`` the default font ignores font_size
    # entirely, so ask for it; older Pillow rejects the keyword.
    try:
        return ImageFont.load_default(size=font_size)
    except TypeError:
        return ImageFont.load_default()
def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
    """Return the {char: GlyphData} table for a font, rendering it on first use.

    Called at compile time. Every printable ASCII character (codes 32-126) is
    drawn white on a transparent background, cropped to its ink bounding box
    plus a 2 px border. The font's (ascent, descent) is recorded in
    ``_METRICS_CACHE`` as a side effect. Results are memoised per
    (font_name, font_size).
    """
    key = (font_name, font_size)
    if key in _GLYPH_CACHE:
        return _GLYPH_CACHE[key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()
    _METRICS_CACHE[key] = (ascent, descent)

    border = 2  # must equal the padding place_glyph_jax subtracts
    measure = ImageDraw.Draw(Image.new('RGBA', (200, 200), (0, 0, 0, 0)))

    table = {}
    for code in range(32, 127):
        ch = chr(code)
        # textbbox at (0, 0) uses PIL's default "la" anchor: x=0 is the pen
        # origin and y=0 is the ascender line (not the baseline).
        left, top, right, bottom = measure.textbbox((0, 0), ch, font=font)
        advance = font.getlength(ch)

        w = max(int(right - left) + 2 * border, 1)
        h = max(int(bottom - top) + 2 * border, 1)
        canvas = Image.new('RGBA', (w, h), (0, 0, 0, 0))
        # Shift so the ink bbox starts exactly `border` pixels in.
        ImageDraw.Draw(canvas).text(
            (border - left, border - top), ch,
            fill=(255, 255, 255, 255), font=font,
        )

        table[ch] = GlyphData(
            char=ch,
            image=np.array(canvas, dtype=np.uint8),
            advance=float(advance),
            bearing_x=float(left),
            bearing_y=float(-top),  # ascender line -> ink top, up is positive
            width=w,
            height=h,
        )

    _GLYPH_CACHE[key] = table
    return table
def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
    """Return the (char, char) -> kerning table for a font, building it once.

    Called at compile time. For every ordered pair of printable ASCII
    characters the adjustment is measured as

        kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)

    i.e. how far the pair's combined advance deviates from the sum of the
    single-character advances. Add it to ``a``'s advance when ``b`` follows;
    it is usually negative. Pairs within 0.01 px of zero are not stored.
    """
    key = (font_name, font_size)
    cached = _KERNING_CACHE.get(key)
    if cached is not None:
        return cached

    font = _load_font(font_name, font_size)
    printable = [chr(code) for code in range(32, 127)]
    widths = {ch: font.getlength(ch) for ch in printable}

    table = {}
    for left in printable:
        for right in printable:
            delta = font.getlength(left + right) - (widths[left] + widths[right])
            if abs(delta) > 0.01:  # skip zero entries to keep the table small
                table[(left, right)] = delta

    _KERNING_CACHE[key] = table
    return table
def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
    """Return the kerning adjustment for ``char2`` following ``char1``. Compile-time.

    The value is meant to be added to ``char1``'s advance. It is typically
    negative (the pair is drawn closer together) and 0.0 for pairs with no
    recorded adjustment.

    Usage in S-expression:
        (+ (glyph-advance g1) (glyph-kerning g1 g2 font-size))
    """
    pair_table = _get_kerning_cache(font_name, font_size)
    return pair_table.get((char1, char2), 0.0)
@dataclass
class TextStrip:
"""Pre-rendered text strip with proper sub-pixel anti-aliasing.
Rendered at compile time using PIL for exact matching.
At runtime, just composite onto frame at integer positions.
Attributes:
text: The original text
image: RGBA image as numpy array (H, W, 4)
width: Strip width
height: Strip height
baseline_y: Y position of baseline within the strip
bearing_x: Left side bearing of first character
anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right)
anchor_y: Y offset for anchor point (depends on anchor type)
stroke_width: Stroke width used when rendering
"""
text: str
image: np.ndarray
width: int
height: int
baseline_y: int
bearing_x: float
anchor_x: float = 0.0
anchor_y: float = 0.0
stroke_width: int = 0
# Text strip cache: cache_key -> TextStrip
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}
def render_text_strip(
    text: str,
    font_name: str = None,
    font_size: int = 32,
    stroke_width: int = 0,
    stroke_fill: tuple = None,
    anchor: str = "la",  # left-ascender (PIL default is "la")
    multiline: bool = False,
    line_spacing: int = 4,
    align: str = "left",
) -> TextStrip:
    """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.

    The text is drawn in white (plus optional stroke) into a padded RGBA image
    cropped to the text's bounding box; the tint colour is applied later at
    composite time. Results are memoised in ``_TEXT_STRIP_CACHE`` keyed on
    every argument.

    Args:
        text: Text to render
        font_name: Path to font file (None for default)
        font_size: Font size in pixels
        stroke_width: Outline width in pixels (0 for no outline)
        stroke_fill: Outline colour as an (R,G,B) / (R,G,B,A) tuple or list,
            or a PIL colour string; default opaque black
        anchor: PIL anchor code - first char: l=left, m=middle, r=right;
            second char: a=ascender, t=top, m=middle, s=baseline, d=descender
        multiline: If True, handle newlines in text
        line_spacing: Extra pixels between lines (for multiline)
        align: 'left', 'center', 'right' (for multiline)

    Returns:
        TextStrip with pre-rendered text. ``baseline_y`` is the row of the
        (first line's) baseline; ``anchor_x``/``anchor_y`` locate the
        requested anchor point inside the strip.
    """
    # A list colour (e.g. built by the S-expression layer) is unhashable and
    # would make the cache lookup below raise TypeError; use a tuple instead.
    if isinstance(stroke_fill, list):
        stroke_fill = tuple(stroke_fill)

    # Build cache key from all parameters
    cache_key = (text, font_name, font_size, stroke_width, stroke_fill,
                 anchor, multiline, line_spacing, align)
    if cache_key in _TEXT_STRIP_CACHE:
        return _TEXT_STRIP_CACHE[cache_key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()

    # Default stroke fill to opaque black and promote RGB sequences to RGBA.
    # Colour-name strings are passed to PIL untouched ("red" also has length
    # 3 and must not be split into characters).
    if stroke_fill is None:
        stroke_fill = (0, 0, 0, 255)
    elif not isinstance(stroke_fill, str) and len(stroke_fill) == 3:
        stroke_fill = (*stroke_fill, 255)

    # Measure the text bbox (accounting for stroke) at origin (0, 0) with
    # PIL's default "la" anchor: x=0 is the pen origin and y=0 is the
    # ascender line of the first line.
    temp_draw = ImageDraw.Draw(Image.new('RGBA', (1, 1)))
    if multiline:
        bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing,
                                            stroke_width=stroke_width)
    else:
        bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width)
    x_min, y_min, x_max, y_max = bbox

    # Create image with padding (extra for stroke)
    padding = 2 + stroke_width
    img_width = max(int(x_max - x_min) + padding * 2, 1)
    img_height = max(int(y_max - y_min) + padding * 2, 1)
    img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Shift so the bbox's top-left lands at (padding, padding); the text's
    # origin (pen origin / ascender line) is then at (draw_x, draw_y).
    draw_x = padding - x_min
    draw_y = padding - y_min
    if multiline:
        draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
                            spacing=line_spacing, align=align,
                            stroke_width=stroke_width, stroke_fill=stroke_fill)
    else:
        draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
                  stroke_width=stroke_width, stroke_fill=stroke_fill)

    # y=0 in text coordinates is the ascender line, so the (first line's)
    # baseline sits `ascent` pixels below draw_y (same as the 's' anchor).
    baseline_y = int(draw_y + ascent)

    img_array = np.array(img, dtype=np.uint8)

    h_anchor = anchor[0] if len(anchor) > 0 else 'l'
    v_anchor = anchor[1] if len(anchor) > 1 else 'a'

    # Horizontal offset
    if h_anchor == 'm':
        # Middle: centre of the actually-rendered (non-transparent) columns,
        # used for pixel-perfect matching with PIL.
        cols = np.flatnonzero(img_array[:, :, 3].max(axis=0) > 0)
        anchor_x = (cols[0] + cols[-1]) / 2.0 if cols.size else img_width / 2.0
    elif h_anchor == 'r':
        # PIL's right anchor is the end of the advance (for multi-line text:
        # of the widest line), not the ink edge. getlength() cannot measure
        # text containing newlines, so measure line by line.
        advance = max(font.getlength(line) for line in text.split("\n"))
        anchor_x = float(draw_x + advance)
    else:  # 'l' and unrecognised codes: the pen origin
        anchor_x = float(draw_x)

    # Vertical offset, from font metrics (ascent/descent):
    # - 'a' (ascender): at the ascender line -> draw_y in strip
    # - 't' (top): at top of text bounding box -> padding in strip
    # - 'm' (middle): (ascent + descent) / 2 below the ascender line
    # - 's' (baseline): ascent below the ascender line
    # - 'd' (descender): ascent + descent below the ascender line
    # NOTE(review): for multi-line text PIL additionally shifts 'm'/'d' by the
    # height of the extra lines; that is not modelled here.
    if v_anchor == 't':
        anchor_y = float(padding)
    elif v_anchor == 'm':
        anchor_y = float(draw_y + (ascent + descent) / 2.0)
    elif v_anchor == 's':
        anchor_y = float(draw_y + ascent)
    elif v_anchor == 'd':
        anchor_y = float(draw_y + ascent + descent)
    else:  # 'a' and unrecognised codes: the ascender line
        anchor_y = float(draw_y)

    strip = TextStrip(
        text=text,
        image=img_array,
        width=img_width,
        height=img_height,
        baseline_y=baseline_y,
        bearing_x=float(x_min),
        anchor_x=anchor_x,
        anchor_y=anchor_y,
        stroke_width=stroke_width,
    )
    _TEXT_STRIP_CACHE[cache_key] = strip
    return strip
# =============================================================================
# Compile-time functions (called during S-expression compilation)
# =============================================================================
def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
    """Look up the pre-rendered glyph for ``char``. Compile-time.

    Characters outside the cached set fall back to the space glyph.
    """
    glyphs = _get_glyph_cache(font_name, font_size)
    fallback = glyphs.get(' ')
    return glyphs.get(char, fallback)
def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
    """Look up one glyph per character of ``text``. Compile-time.

    Characters without a cached glyph are replaced by the space glyph, so the
    result always has ``len(text)`` entries.
    """
    glyphs = _get_glyph_cache(font_name, font_size)
    blank = glyphs.get(' ')
    return [glyphs.get(ch, blank) for ch in text]
def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
    """Return the font's ascent as reported by PIL's getmetrics(). Compile-time."""
    # Building the glyph cache is what fills _METRICS_CACHE for this font.
    _get_glyph_cache(font_name, font_size)
    ascent, _descent = _METRICS_CACHE[(font_name, font_size)]
    return ascent
def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
    """Return the font's descent as reported by PIL's getmetrics(). Compile-time."""
    # Building the glyph cache is what fills _METRICS_CACHE for this font.
    _get_glyph_cache(font_name, font_size)
    _ascent, descent = _METRICS_CACHE[(font_name, font_size)]
    return descent
# =============================================================================
# JAX Runtime Primitives
# =============================================================================
def place_glyph_jax(
    frame: jnp.ndarray,
    glyph_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    bearing_x: float,
    bearing_y: float,
    color: jnp.ndarray,  # (3,) RGB 0-255
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Place a glyph onto a frame. This is the core JAX primitive.

    All positioning math can use traced values (x, y from audio, time, etc.);
    plain Python numbers are accepted as well. The glyph image shape is
    static (determined at compile time).

    Args:
        frame: (H, W, 3) RGB frame
        glyph_image: (gh, gw, 4) RGBA glyph with a 2 px transparent border
        x: X position of the glyph origin (pen position)
        y: Y position of the vertical reference line bearing_y is measured
           from. For glyphs from the glyph cache this is the ascender line
           (PIL's default "la" anchor), not the baseline.
        bearing_x: Offset from the origin to the left edge of the ink
        bearing_y: Offset from the reference line up to the top of the ink
        color: RGB color array (0-255)
        opacity: Opacity 0-1

    Returns:
        uint8 frame with the glyph composited; parts of the glyph outside
        the frame are clipped.
    """
    from jax.lax import dynamic_update_slice

    h, w = frame.shape[:2]
    gh, gw = glyph_image.shape[:2]

    # Top-left corner of the glyph image in frame coordinates; the image was
    # drawn with a transparent border of `padding` pixels, so undo that.
    padding = 2  # Must match padding used in glyph creation
    dst_x = x + bearing_x - padding
    dst_y = y - bearing_y - padding
    # jnp.asarray: plain Python floats have no .astype. jnp.floor: truncating
    # toward zero would map both -0.5 and 0.5 to 0, so glyphs would stall for
    # a pixel when crossing the left/top frame edge.
    dst_x_int = jnp.floor(jnp.asarray(dst_x)).astype(jnp.int32)
    dst_y_int = jnp.floor(jnp.asarray(dst_y)).astype(jnp.int32)

    glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0
    # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
    opacity_int = jnp.round(opacity * 255)
    glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32)
    glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0

    # Apply color tint (glyph is white, multiply by color)
    color_normalized = color.astype(jnp.float32) / 255.0
    tinted = glyph_rgb * color_normalized

    # XLA clamps dynamic_update_slice indices so the update fits, which would
    # drag off-frame glyphs back on-screen. Paste into a buffer padded by one
    # glyph size on each side, then crop the frame window.
    buf_h = h + 2 * gh
    buf_w = w + 2 * gw
    rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
    alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
    place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32)
    place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32)
    rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
    alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0))
    rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :]
    alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :]

    # Alpha composite using PIL-compatible integer arithmetic
    src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    dst_int = frame.astype(jnp.int32)
    result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
    return jnp.clip(result, 0, 255).astype(jnp.uint8)
def place_text_strip_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    color: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
) -> jnp.ndarray:
    """
    Composite a pre-rendered text strip onto a frame.

    The strip already carries PIL's sub-pixel anti-aliasing; this only tints
    it and blends it in at an integer position such that the strip point
    (anchor_x, anchor_y) lands on (x, y), rounded half-up.

    Args:
        frame: (H, W, 3) RGB frame
        strip_image: (sh, sw, 4) RGBA text strip
        x: Target X for the anchor point
        y: Target Y for the anchor point
        baseline_y: Baseline row inside the strip (not used by the placement math)
        bearing_x: Left side bearing (not used by the placement math)
        color: RGB tint (0-255)
        opacity: Opacity 0-1
        anchor_x: X offset of the anchor point within the strip
        anchor_y: Y offset of the anchor point within the strip
        stroke_width: Stroke width used when rendering (not used by the
            placement math)

    Returns:
        uint8 frame with the text composited; off-frame parts are clipped.
    """
    from jax.lax import dynamic_update_slice

    frame_h, frame_w = frame.shape[:2]
    strip_h, strip_w = strip_image.shape[:2]

    # Round half-up with floor(v + 0.5); jnp.round would round half-to-even.
    left = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    top = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)

    # PIL-style alpha: the 0-255 coverage is scaled by round(opacity * 255)
    # (half-to-even, like Python's round) and rounded half-up.
    opacity_255 = jnp.round(opacity * 255)
    coverage = strip_image[:, :, 3:4].astype(jnp.float32)
    alpha = jnp.floor(coverage * opacity_255 / 255.0 + 0.5) / 255.0

    # The strip ink is white (stroke pixels carry their own colour), so the
    # tint is a plain per-channel multiply.
    rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
    rgb = rgb * (color.astype(jnp.float32) / 255.0)

    # XLA clamps dynamic_update_slice start indices so the update fits, which
    # would silently drag an off-frame strip back on-screen. Paste into a
    # canvas padded by one strip size on every side, then crop the frame
    # window: whatever falls outside the frame lands in the margin.
    canvas_hw = (frame_h + 2 * strip_h, frame_w + 2 * strip_w)
    start = (
        jnp.maximum(top + strip_h, 0).astype(jnp.int32),
        jnp.maximum(left + strip_w, 0).astype(jnp.int32),
        0,
    )
    rgb_canvas = dynamic_update_slice(
        jnp.zeros(canvas_hw + (3,), dtype=jnp.float32), rgb, start)
    alpha_canvas = dynamic_update_slice(
        jnp.zeros(canvas_hw + (1,), dtype=jnp.float32), alpha, start)
    rgb_window = rgb_canvas[strip_h:strip_h + frame_h, strip_w:strip_w + frame_w, :]
    alpha_window = alpha_canvas[strip_h:strip_h + frame_h, strip_w:strip_w + frame_w, :]

    # Integer blend matching PIL: (src * a + dst * (255 - a) + 127) // 255
    src = jnp.floor(rgb_window * 255 + 0.5).astype(jnp.int32)
    a = jnp.floor(alpha_window * 255 + 0.5).astype(jnp.int32)
    dst = frame.astype(jnp.int32)
    blended = (src * a + dst * (255 - a) + 127) // 255
    return jnp.clip(blended, 0, 255).astype(jnp.uint8)
def place_glyph_simple(
    frame: jnp.ndarray,
    glyph: GlyphData,
    x: float,
    y: float,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Convenience wrapper that takes GlyphData directly.

    Converts the glyph image and colour to JAX arrays and the position to
    JAX float32 scalars, then delegates to place_glyph_jax.
    For S-expression use, prefer place_glyph_jax with pre-converted arrays.

    Args:
        frame: (H, W, 3) RGB frame
        glyph: Glyph data from the glyph cache
        x: X position of the glyph origin
        y: Y position of the glyph's vertical reference line
        color: RGB tint 0-255
        opacity: Opacity 0-1

    Returns:
        Frame with the glyph composited.
    """
    glyph_jax = jnp.asarray(glyph.image)
    color_jax = jnp.array(color, dtype=jnp.float32)
    # Plain Python floats have no .astype, which place_glyph_jax may call on
    # the derived position; JAX scalars work whether or not x/y are traced.
    x_jax = jnp.asarray(x, dtype=jnp.float32)
    y_jax = jnp.asarray(y, dtype=jnp.float32)
    return place_glyph_jax(
        frame, glyph_jax, x_jax, y_jax,
        glyph.bearing_x, glyph.bearing_y,
        color_jax, opacity,
    )
# =============================================================================
# S-Expression Primitive Bindings
# =============================================================================
def bind_typography_primitives(env: dict) -> dict:
    """
    Add typography primitives to an S-expression environment.

    ``env`` is updated in place and also returned. Arguments in [brackets]
    are optional.

    Glyph-by-glyph primitives (wave, arc, audio-reactive effects):
        (text-glyphs text [font-size] [font-name]) -> list of glyph data
        (glyph-image g) -> JAX array (H, W, 4)
        (glyph-advance g) -> float
        (glyph-bearing-x g) -> float
        (glyph-bearing-y g) -> float
        (glyph-width g) -> int
        (glyph-height g) -> int
        (glyph-kerning g1 g2 [font-size] [font-name]) -> float
        (char-kerning c1 c2 [font-size] [font-name]) -> float
        (font-ascent [font-size] [font-name]) -> float
        (font-descent [font-size] [font-name]) -> float
        (place-glyph frame glyph-img x y bearing-x bearing-y [color] [opacity]) -> frame

    TextStrip primitives (pixel-perfect static text):
        (render-text-strip text [font-size] [font-name]) -> strip
        (render-text-strip-styled text [font-size] [font-name] [stroke-width]
            [stroke-fill] [anchor] [multiline] [line-spacing] [align]) -> strip
        (text-strip-image s) -> JAX array (H, W, 4)
        (text-strip-width s), (text-strip-height s) -> int
        (text-strip-baseline-y s) -> int
        (text-strip-bearing-x s), (text-strip-anchor-x s),
            (text-strip-anchor-y s) -> float
        (place-text-strip frame strip x y [color] [opacity]) -> frame
    """
    def prim_text_glyphs(text, font_size=32, font_name=None):
        """Get list of glyph data for text. Compile-time."""
        return get_glyphs(str(text), font_name, int(font_size))

    def prim_glyph_image(glyph):
        """Get glyph image as JAX array."""
        return jnp.asarray(glyph.image)

    def prim_glyph_advance(glyph):
        """Get glyph advance width."""
        return glyph.advance

    def prim_glyph_bearing_x(glyph):
        """Get glyph left side bearing."""
        return glyph.bearing_x

    def prim_glyph_bearing_y(glyph):
        """Get glyph top bearing."""
        return glyph.bearing_y

    def prim_glyph_width(glyph):
        """Get glyph image width."""
        return glyph.width

    def prim_glyph_height(glyph):
        """Get glyph image height."""
        return glyph.height

    def prim_font_ascent(font_size=32, font_name=None):
        """Get font ascent."""
        return get_font_ascent(font_name, int(font_size))

    def prim_font_descent(font_size=32, font_name=None):
        """Get font descent."""
        return get_font_descent(font_name, int(font_size))

    def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
                         color=(255, 255, 255), opacity=1.0):
        """Place glyph on frame. Runtime JAX operation."""
        color_arr = jnp.array(color, dtype=jnp.float32)
        # Coerce positions to JAX scalars so plain numbers from the
        # interpreter work (place_glyph_jax may call .astype on them).
        return place_glyph_jax(frame, glyph_img,
                               jnp.asarray(x, dtype=jnp.float32),
                               jnp.asarray(y, dtype=jnp.float32),
                               bearing_x, bearing_y, color_arr, opacity)

    def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
        """Get kerning adjustment between two glyphs. Compile-time.

        Returns adjustment to add to glyph1's advance when glyph2 follows.
        Typically negative (characters move closer).
        Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
        """
        return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))

    def prim_char_kerning(char1, char2, font_size=32, font_name=None):
        """Get kerning adjustment between two characters. Compile-time."""
        return get_kerning(str(char1), str(char2), font_name, int(font_size))

    # TextStrip primitives for pre-rendered text with proper anti-aliasing
    def prim_render_text_strip(text, font_size=32, font_name=None):
        """Render text to a strip at compile time. Perfect anti-aliasing."""
        return render_text_strip(str(text), font_name, int(font_size))

    def prim_render_text_strip_styled(
        text, font_size=32, font_name=None,
        stroke_width=0, stroke_fill=None,
        anchor="la", multiline=False, line_spacing=4, align="left"
    ):
        """Render styled text to a strip. Supports stroke, anchors, multiline.

        Args:
            text: Text to render
            font_size: Size in pixels
            font_name: Path to font file
            stroke_width: Outline width (0 = no outline)
            stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
            anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
            multiline: If True, handle newlines
            line_spacing: Extra pixels between lines
            align: "left", "center", "right" for multiline
        """
        # Interpreter lists are unhashable and would break the strip cache key.
        if isinstance(stroke_fill, list):
            stroke_fill = tuple(stroke_fill)
        return render_text_strip(
            str(text), font_name, int(font_size),
            stroke_width=int(stroke_width),
            stroke_fill=stroke_fill,
            anchor=str(anchor),
            multiline=bool(multiline),
            line_spacing=int(line_spacing),
            align=str(align),
        )

    def prim_text_strip_image(strip):
        """Get text strip image as JAX array."""
        return jnp.asarray(strip.image)

    def prim_text_strip_width(strip):
        """Get text strip width."""
        return strip.width

    def prim_text_strip_height(strip):
        """Get text strip height."""
        return strip.height

    def prim_text_strip_baseline_y(strip):
        """Get text strip baseline Y position."""
        return strip.baseline_y

    def prim_text_strip_bearing_x(strip):
        """Get text strip left bearing."""
        return strip.bearing_x

    def prim_text_strip_anchor_x(strip):
        """Get text strip anchor X offset."""
        return strip.anchor_x

    def prim_text_strip_anchor_y(strip):
        """Get text strip anchor Y offset."""
        return strip.anchor_y

    def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
        """Place pre-rendered text strip on frame. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_text_strip_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            color_arr, opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width,
        )

    # Add to environment
    env.update({
        # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
        'text-glyphs': prim_text_glyphs,
        'glyph-image': prim_glyph_image,
        'glyph-advance': prim_glyph_advance,
        'glyph-bearing-x': prim_glyph_bearing_x,
        'glyph-bearing-y': prim_glyph_bearing_y,
        'glyph-width': prim_glyph_width,
        'glyph-height': prim_glyph_height,
        'glyph-kerning': prim_glyph_kerning,
        'char-kerning': prim_char_kerning,
        'font-ascent': prim_font_ascent,
        'font-descent': prim_font_descent,
        'place-glyph': prim_place_glyph,
        # TextStrip primitives (for pixel-perfect static text)
        'render-text-strip': prim_render_text_strip,
        'render-text-strip-styled': prim_render_text_strip_styled,
        'text-strip-image': prim_text_strip_image,
        'text-strip-width': prim_text_strip_width,
        'text-strip-height': prim_text_strip_height,
        'text-strip-baseline-y': prim_text_strip_baseline_y,
        'text-strip-bearing-x': prim_text_strip_bearing_x,
        'text-strip-anchor-x': prim_text_strip_anchor_x,
        'text-strip-anchor-y': prim_text_strip_anchor_y,
        'place-text-strip': prim_place_text_strip,
    })
    return env
# =============================================================================
# Example: Render text using primitives (for testing)
# =============================================================================
def render_text_primitives(
    frame: jnp.ndarray,
    text: str,
    x: float,
    y: float,
    font_size: int = 32,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
    use_kerning: bool = True,
    font_name: str = None,
) -> jnp.ndarray:
    """
    Render text glyph-by-glyph using the primitives.

    This is what an S-expression would compile to. Glyphs are fetched at
    compile time, then each is placed with place_glyph_jax while a pen
    cursor advances by the glyph advance (plus kerning, if enabled).

    Args:
        frame: (H, W, 3) RGB frame
        text: Text to render; characters without a cached glyph render as space
        x: X position of the first glyph's origin
        y: Y position handed to place_glyph_jax for every glyph (the line's
           ascender line for glyphs from the glyph cache)
        font_size: Font size in pixels
        color: RGB tint 0-255
        opacity: Opacity 0-1
        use_kerning: If True, apply kerning adjustments between characters
        font_name: Path to font file (None uses the default font search)

    Returns:
        Frame with the text composited.
    """
    glyphs = get_glyphs(text, font_name, font_size)
    color_arr = jnp.array(color, dtype=jnp.float32)
    # Plain Python floats have no .astype, which place_glyph_jax may call on
    # the derived position, so positions are passed as JAX scalars.
    y_arr = jnp.asarray(y, dtype=jnp.float32)
    cursor = x
    for i, g in enumerate(glyphs):
        frame = place_glyph_jax(
            frame, jnp.asarray(g.image),
            jnp.asarray(cursor, dtype=jnp.float32), y_arr,
            g.bearing_x, g.bearing_y,
            color_arr, opacity,
        )
        # Advance cursor with optional kerning against the next glyph
        advance = g.advance
        if use_kerning and i + 1 < len(glyphs):
            advance += get_kerning(g.char, glyphs[i + 1].char, font_name, font_size)
        cursor = cursor + advance
    return frame