Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
860
streaming/jax_typography.py
Normal file
860
streaming/jax_typography.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
JAX Typography Primitives
|
||||
|
||||
Two approaches for text rendering, both compile to JAX/GPU:
|
||||
|
||||
## 1. TextStrip - Pixel-perfect static text
|
||||
Pre-render entire strings at compile time using PIL.
|
||||
Perfect sub-pixel anti-aliasing, exact match with PIL.
|
||||
Use for: static titles, labels, any text without per-character effects.
|
||||
|
||||
S-expression:
|
||||
(let ((strip (render-text-strip "Hello World" 48)))
|
||||
(place-text-strip frame strip x y :color white))
|
||||
|
||||
## 2. Glyph-by-glyph - Dynamic text effects
|
||||
Individual glyph placement for wave, arc, audio-reactive effects.
|
||||
Each character can have independent position, color, opacity.
|
||||
Note: slight anti-aliasing differences vs PIL due to integer positioning.
|
||||
|
||||
S-expression:
|
||||
; Wave text - y oscillates with character index
|
||||
(let ((glyphs (text-glyphs "Wavy" 48)))
|
||||
(first
|
||||
(fold glyphs (list frame 0)
|
||||
(lambda (acc g)
|
||||
(let ((frm (first acc))
|
||||
(cursor (second acc))
|
||||
(i (length acc))) ; approximate index
|
||||
(list
|
||||
(place-glyph frm (glyph-image g)
|
||||
(+ x cursor)
|
||||
(+ y (* amplitude (sin (* i frequency))))
|
||||
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||
white 1.0)
|
||||
(+ cursor (glyph-advance g))))))))
|
||||
|
||||
; Audio-reactive spacing
|
||||
(let ((glyphs (text-glyphs "Bass" 48))
|
||||
(bass (audio-band 0 200)))
|
||||
(first
|
||||
(fold glyphs (list frame 0)
|
||||
(lambda (acc g)
|
||||
(let ((frm (first acc))
|
||||
(cursor (second acc)))
|
||||
(list
|
||||
(place-glyph frm (glyph-image g)
|
||||
(+ x cursor) y
|
||||
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||
white 1.0)
|
||||
(+ cursor (glyph-advance g) (* bass 20))))))))
|
||||
|
||||
Kerning support:
|
||||
; With kerning adjustment
|
||||
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from typing import Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Glyph Data (computed at compile time)
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class GlyphData:
|
||||
"""Glyph data computed at compile time.
|
||||
|
||||
Attributes:
|
||||
char: The character
|
||||
image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
|
||||
advance: Horizontal advance (distance to next glyph origin)
|
||||
bearing_x: Left side bearing (x offset from origin to first pixel)
|
||||
bearing_y: Top bearing (y offset from baseline to top of glyph)
|
||||
width: Image width
|
||||
height: Image height
|
||||
"""
|
||||
char: str
|
||||
image: np.ndarray # (H, W, 4) RGBA uint8
|
||||
advance: float
|
||||
bearing_x: float
|
||||
bearing_y: float
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
# Font cache: (font_name, font_size) -> {char: GlyphData}
|
||||
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}
|
||||
|
||||
# Font metrics cache: (font_name, font_size) -> (ascent, descent)
|
||||
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}
|
||||
|
||||
# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
|
||||
# Kerning adjustment is added to advance: new_advance = advance + kerning
|
||||
# Typically negative (characters move closer together)
|
||||
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}
|
||||
|
||||
|
||||
def _load_font(font_name: str = None, font_size: int = 32):
|
||||
"""Load a font. Called at compile time."""
|
||||
from PIL import ImageFont
|
||||
|
||||
candidates = [
|
||||
font_name,
|
||||
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||
'/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
|
||||
'/usr/share/fonts/truetype/freefont/FreeSans.ttf',
|
||||
]
|
||||
|
||||
for path in candidates:
|
||||
if path is None:
|
||||
continue
|
||||
try:
|
||||
return ImageFont.truetype(path, font_size)
|
||||
except (IOError, OSError):
|
||||
continue
|
||||
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
|
||||
"""Get or create glyph cache for a font. Called at compile time."""
|
||||
cache_key = (font_name, font_size)
|
||||
|
||||
if cache_key in _GLYPH_CACHE:
|
||||
return _GLYPH_CACHE[cache_key]
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
ascent, descent = font.getmetrics()
|
||||
_METRICS_CACHE[cache_key] = (ascent, descent)
|
||||
|
||||
glyphs = {}
|
||||
charset = ''.join(chr(i) for i in range(32, 127))
|
||||
|
||||
temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0))
|
||||
temp_draw = ImageDraw.Draw(temp_img)
|
||||
|
||||
for char in charset:
|
||||
# Get metrics
|
||||
bbox = temp_draw.textbbox((0, 0), char, font=font)
|
||||
advance = font.getlength(char)
|
||||
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
|
||||
# Create glyph image with padding
|
||||
padding = 2
|
||||
img_w = max(int(x_max - x_min) + padding * 2, 1)
|
||||
img_h = max(int(y_max - y_min) + padding * 2, 1)
|
||||
|
||||
glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
|
||||
glyph_draw = ImageDraw.Draw(glyph_img)
|
||||
|
||||
# Draw at position accounting for bbox offset
|
||||
draw_x = padding - x_min
|
||||
draw_y = padding - y_min
|
||||
glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
glyphs[char] = GlyphData(
|
||||
char=char,
|
||||
image=np.array(glyph_img, dtype=np.uint8),
|
||||
advance=float(advance),
|
||||
bearing_x=float(x_min),
|
||||
bearing_y=float(-y_min), # Distance from baseline to top
|
||||
width=img_w,
|
||||
height=img_h,
|
||||
)
|
||||
|
||||
_GLYPH_CACHE[cache_key] = glyphs
|
||||
return glyphs
|
||||
|
||||
|
||||
def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
|
||||
"""Get or create kerning cache for a font. Called at compile time.
|
||||
|
||||
Kerning is computed as:
|
||||
kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)
|
||||
|
||||
This gives the adjustment needed when placing 'b' after 'a'.
|
||||
Typically negative (characters move closer together).
|
||||
"""
|
||||
cache_key = (font_name, font_size)
|
||||
|
||||
if cache_key in _KERNING_CACHE:
|
||||
return _KERNING_CACHE[cache_key]
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
kerning = {}
|
||||
|
||||
# Compute kerning for all printable ASCII pairs
|
||||
charset = ''.join(chr(i) for i in range(32, 127))
|
||||
|
||||
# Pre-compute individual character lengths
|
||||
char_lengths = {c: font.getlength(c) for c in charset}
|
||||
|
||||
# Compute kerning for each pair
|
||||
for c1 in charset:
|
||||
for c2 in charset:
|
||||
pair_length = font.getlength(c1 + c2)
|
||||
individual_sum = char_lengths[c1] + char_lengths[c2]
|
||||
kern = pair_length - individual_sum
|
||||
|
||||
# Only store non-zero kerning to save memory
|
||||
if abs(kern) > 0.01:
|
||||
kerning[(c1, c2)] = kern
|
||||
|
||||
_KERNING_CACHE[cache_key] = kerning
|
||||
return kerning
|
||||
|
||||
|
||||
def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get kerning adjustment between two characters. Compile-time.
|
||||
|
||||
Returns the adjustment to add to char1's advance when char2 follows.
|
||||
Typically negative (characters move closer).
|
||||
|
||||
Usage in S-expression:
|
||||
(+ (glyph-advance g1) (kerning g1 g2))
|
||||
"""
|
||||
kerning_cache = _get_kerning_cache(font_name, font_size)
|
||||
return kerning_cache.get((char1, char2), 0.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextStrip:
|
||||
"""Pre-rendered text strip with proper sub-pixel anti-aliasing.
|
||||
|
||||
Rendered at compile time using PIL for exact matching.
|
||||
At runtime, just composite onto frame at integer positions.
|
||||
|
||||
Attributes:
|
||||
text: The original text
|
||||
image: RGBA image as numpy array (H, W, 4)
|
||||
width: Strip width
|
||||
height: Strip height
|
||||
baseline_y: Y position of baseline within the strip
|
||||
bearing_x: Left side bearing of first character
|
||||
anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right)
|
||||
anchor_y: Y offset for anchor point (depends on anchor type)
|
||||
stroke_width: Stroke width used when rendering
|
||||
"""
|
||||
text: str
|
||||
image: np.ndarray
|
||||
width: int
|
||||
height: int
|
||||
baseline_y: int
|
||||
bearing_x: float
|
||||
anchor_x: float = 0.0
|
||||
anchor_y: float = 0.0
|
||||
stroke_width: int = 0
|
||||
|
||||
|
||||
# Text strip cache: cache_key -> TextStrip
|
||||
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}
|
||||
|
||||
|
||||
def render_text_strip(
|
||||
text: str,
|
||||
font_name: str = None,
|
||||
font_size: int = 32,
|
||||
stroke_width: int = 0,
|
||||
stroke_fill: tuple = None,
|
||||
anchor: str = "la", # left-ascender (PIL default is "la")
|
||||
multiline: bool = False,
|
||||
line_spacing: int = 4,
|
||||
align: str = "left",
|
||||
) -> TextStrip:
|
||||
"""Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.
|
||||
|
||||
Args:
|
||||
text: Text to render
|
||||
font_name: Path to font file (None for default)
|
||||
font_size: Font size in pixels
|
||||
stroke_width: Outline width in pixels (0 for no outline)
|
||||
stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black
|
||||
anchor: PIL anchor code - first char: h=left, m=middle, r=right
|
||||
second char: a=ascender, t=top, m=middle, s=baseline, d=descender
|
||||
multiline: If True, handle newlines in text
|
||||
line_spacing: Extra pixels between lines (for multiline)
|
||||
align: 'left', 'center', 'right' (for multiline)
|
||||
|
||||
Returns:
|
||||
TextStrip with pre-rendered text
|
||||
"""
|
||||
# Build cache key from all parameters
|
||||
cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align)
|
||||
if cache_key in _TEXT_STRIP_CACHE:
|
||||
return _TEXT_STRIP_CACHE[cache_key]
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
ascent, descent = font.getmetrics()
|
||||
|
||||
# Default stroke fill to black
|
||||
if stroke_fill is None:
|
||||
stroke_fill = (0, 0, 0, 255)
|
||||
elif len(stroke_fill) == 3:
|
||||
stroke_fill = (*stroke_fill, 255)
|
||||
|
||||
# Get text bbox (accounting for stroke)
|
||||
temp = Image.new('RGBA', (1, 1))
|
||||
temp_draw = ImageDraw.Draw(temp)
|
||||
|
||||
if multiline:
|
||||
bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing,
|
||||
stroke_width=stroke_width)
|
||||
else:
|
||||
bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width)
|
||||
|
||||
# bbox is (left, top, right, bottom) relative to origin
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
|
||||
# Create image with padding (extra for stroke)
|
||||
padding = 2 + stroke_width
|
||||
img_width = max(int(x_max - x_min) + padding * 2, 1)
|
||||
img_height = max(int(y_max - y_min) + padding * 2, 1)
|
||||
|
||||
# Create RGBA image
|
||||
img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Draw text at position that puts it in the image
|
||||
draw_x = padding - x_min
|
||||
draw_y = padding - y_min
|
||||
|
||||
if multiline:
|
||||
draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
|
||||
spacing=line_spacing, align=align,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill)
|
||||
else:
|
||||
draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill)
|
||||
|
||||
# Baseline is at y=0 in text coordinates, which is at draw_y in image
|
||||
baseline_y = draw_y
|
||||
|
||||
# Convert to numpy for pixel analysis
|
||||
img_array = np.array(img, dtype=np.uint8)
|
||||
|
||||
# Calculate anchor offsets
|
||||
# For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching
|
||||
h_anchor = anchor[0] if len(anchor) > 0 else 'l'
|
||||
v_anchor = anchor[1] if len(anchor) > 1 else 'a'
|
||||
|
||||
# Find actual pixel bounds (for middle anchor calculations)
|
||||
alpha = img_array[:, :, 3]
|
||||
nonzero_cols = np.where(alpha.max(axis=0) > 0)[0]
|
||||
nonzero_rows = np.where(alpha.max(axis=1) > 0)[0]
|
||||
|
||||
if len(nonzero_cols) > 0:
|
||||
pixel_x_min = nonzero_cols.min()
|
||||
pixel_x_max = nonzero_cols.max()
|
||||
pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0
|
||||
else:
|
||||
pixel_x_center = img_width / 2.0
|
||||
|
||||
if len(nonzero_rows) > 0:
|
||||
pixel_y_min = nonzero_rows.min()
|
||||
pixel_y_max = nonzero_rows.max()
|
||||
pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0
|
||||
else:
|
||||
pixel_y_center = img_height / 2.0
|
||||
|
||||
# Horizontal offset
|
||||
text_width = x_max - x_min
|
||||
if h_anchor == 'l': # left edge of text
|
||||
anchor_x = float(draw_x)
|
||||
elif h_anchor == 'm': # middle - use actual pixel center for perfect matching
|
||||
anchor_x = pixel_x_center
|
||||
elif h_anchor == 'r': # right edge of text
|
||||
anchor_x = float(draw_x + text_width)
|
||||
else:
|
||||
anchor_x = float(draw_x)
|
||||
|
||||
# Vertical offset
|
||||
# PIL anchor positions are based on font metrics (ascent/descent):
|
||||
# - 'a' (ascender): at the ascender line → draw_y in strip
|
||||
# - 't' (top): at top of text bounding box → padding in strip
|
||||
# - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender
|
||||
# - 's' (baseline): at baseline = ascent below ascender
|
||||
# - 'd' (descender): at descender line = ascent + descent below ascender
|
||||
|
||||
if v_anchor == 'a': # ascender
|
||||
anchor_y = float(draw_y)
|
||||
elif v_anchor == 't': # top of bbox
|
||||
anchor_y = float(padding)
|
||||
elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation)
|
||||
anchor_y = float(draw_y + (ascent + descent) / 2.0)
|
||||
elif v_anchor == 's': # baseline
|
||||
anchor_y = float(draw_y + ascent)
|
||||
elif v_anchor == 'd': # descender
|
||||
anchor_y = float(draw_y + ascent + descent)
|
||||
else:
|
||||
anchor_y = float(draw_y) # default to ascender
|
||||
|
||||
strip = TextStrip(
|
||||
text=text,
|
||||
image=img_array,
|
||||
width=img_width,
|
||||
height=img_height,
|
||||
baseline_y=baseline_y,
|
||||
bearing_x=float(x_min),
|
||||
anchor_x=anchor_x,
|
||||
anchor_y=anchor_y,
|
||||
stroke_width=stroke_width,
|
||||
)
|
||||
|
||||
_TEXT_STRIP_CACHE[cache_key] = strip
|
||||
return strip
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Compile-time functions (called during S-expression compilation)
|
||||
# =============================================================================
|
||||
|
||||
def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
|
||||
"""Get glyph data for a single character. Compile-time."""
|
||||
cache = _get_glyph_cache(font_name, font_size)
|
||||
return cache.get(char, cache.get(' '))
|
||||
|
||||
|
||||
def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
|
||||
"""Get glyph data for a string. Compile-time."""
|
||||
cache = _get_glyph_cache(font_name, font_size)
|
||||
space = cache.get(' ')
|
||||
return [cache.get(c, space) for c in text]
|
||||
|
||||
|
||||
def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get font ascent. Compile-time."""
|
||||
_get_glyph_cache(font_name, font_size) # Ensure cache exists
|
||||
return _METRICS_CACHE[(font_name, font_size)][0]
|
||||
|
||||
|
||||
def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get font descent. Compile-time."""
|
||||
_get_glyph_cache(font_name, font_size)
|
||||
return _METRICS_CACHE[(font_name, font_size)][1]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JAX Runtime Primitives
|
||||
# =============================================================================
|
||||
|
||||
def place_glyph_jax(
|
||||
frame: jnp.ndarray,
|
||||
glyph_image: jnp.ndarray, # (H, W, 4) RGBA
|
||||
x: float,
|
||||
y: float,
|
||||
bearing_x: float,
|
||||
bearing_y: float,
|
||||
color: jnp.ndarray, # (3,) RGB 0-255
|
||||
opacity: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Place a glyph onto a frame. This is the core JAX primitive.
|
||||
|
||||
All positioning math can use traced values (x, y from audio, time, etc.)
|
||||
The glyph_image is static (determined at compile time).
|
||||
|
||||
Args:
|
||||
frame: (H, W, 3) RGB frame
|
||||
glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array)
|
||||
x: X position of glyph origin (baseline point)
|
||||
y: Y position of baseline
|
||||
bearing_x: Left side bearing
|
||||
bearing_y: Top bearing (from baseline to top)
|
||||
color: RGB color array
|
||||
opacity: Opacity 0-1
|
||||
|
||||
Returns:
|
||||
Frame with glyph composited
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
gh, gw = glyph_image.shape[:2]
|
||||
|
||||
# Calculate destination position
|
||||
# bearing_x: how far right of origin the glyph starts (can be negative)
|
||||
# bearing_y: how far up from baseline the glyph extends
|
||||
padding = 2 # Must match padding used in glyph creation
|
||||
dst_x = x + bearing_x - padding
|
||||
dst_y = y - bearing_y - padding
|
||||
|
||||
# Extract glyph RGB and alpha
|
||||
glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
|
||||
opacity_int = jnp.round(opacity * 255)
|
||||
glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32)
|
||||
glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||
|
||||
# Apply color tint (glyph is white, multiply by color)
|
||||
color_normalized = color.astype(jnp.float32) / 255.0
|
||||
tinted = glyph_rgb * color_normalized
|
||||
|
||||
from jax.lax import dynamic_update_slice
|
||||
|
||||
# Use padded buffer to avoid XLA's dynamic_update_slice clamping
|
||||
buf_h = h + 2 * gh
|
||||
buf_w = w + 2 * gw
|
||||
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
|
||||
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
|
||||
|
||||
dst_x_int = dst_x.astype(jnp.int32)
|
||||
dst_y_int = dst_y.astype(jnp.int32)
|
||||
place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32)
|
||||
place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32)
|
||||
|
||||
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
|
||||
alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0))
|
||||
|
||||
rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :]
|
||||
alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :]
|
||||
|
||||
# Alpha composite using PIL-compatible integer arithmetic
|
||||
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
|
||||
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
|
||||
dst_int = frame.astype(jnp.int32)
|
||||
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
|
||||
|
||||
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||
|
||||
|
||||
def place_text_strip_jax(
|
||||
frame: jnp.ndarray,
|
||||
strip_image: jnp.ndarray, # (H, W, 4) RGBA
|
||||
x: float,
|
||||
y: float,
|
||||
baseline_y: int,
|
||||
bearing_x: float,
|
||||
color: jnp.ndarray,
|
||||
opacity: float = 1.0,
|
||||
anchor_x: float = 0.0,
|
||||
anchor_y: float = 0.0,
|
||||
stroke_width: int = 0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Place a pre-rendered text strip onto a frame.
|
||||
|
||||
The strip was rendered at compile time with proper sub-pixel anti-aliasing.
|
||||
This just composites it at the specified position.
|
||||
|
||||
Args:
|
||||
frame: (H, W, 3) RGB frame
|
||||
strip_image: (sh, sw, 4) RGBA text strip
|
||||
x: X position for anchor point
|
||||
y: Y position for anchor point
|
||||
baseline_y: Y position of baseline within the strip
|
||||
bearing_x: Left side bearing
|
||||
color: RGB color
|
||||
opacity: Opacity 0-1
|
||||
anchor_x: X offset of anchor point within strip
|
||||
anchor_y: Y offset of anchor point within strip
|
||||
stroke_width: Stroke width used when rendering (affects padding)
|
||||
|
||||
Returns:
|
||||
Frame with text composited
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
sh, sw = strip_image.shape[:2]
|
||||
|
||||
# Calculate destination position
|
||||
# Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame
|
||||
# anchor_x/anchor_y already account for the anchor position within the strip
|
||||
# Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding)
|
||||
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
|
||||
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
|
||||
|
||||
# Extract strip RGB and alpha
|
||||
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
|
||||
# Use jnp.round (banker's rounding) to match Python's round() used by PIL
|
||||
opacity_int = jnp.round(opacity * 255)
|
||||
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
|
||||
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||
|
||||
# Apply color tint
|
||||
color_normalized = color.astype(jnp.float32) / 255.0
|
||||
tinted = strip_rgb * color_normalized
|
||||
|
||||
from jax.lax import dynamic_update_slice
|
||||
|
||||
# Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior.
|
||||
# XLA clamps indices so the update fits, which silently shifts the strip.
|
||||
# By placing into a buffer padded by strip dimensions, then extracting the
|
||||
# frame-sized region, we get correct clipping for both overflow and underflow.
|
||||
buf_h = h + 2 * sh
|
||||
buf_w = w + 2 * sw
|
||||
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
|
||||
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
|
||||
|
||||
# Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer
|
||||
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
|
||||
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
|
||||
|
||||
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
|
||||
alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
|
||||
|
||||
# Extract frame-sized region (sh, sw are compile-time constants from strip shape)
|
||||
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
|
||||
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
|
||||
|
||||
# Alpha composite using PIL-compatible integer arithmetic:
|
||||
# result = (src * alpha + dst * (255 - alpha) + 127) // 255
|
||||
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
|
||||
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
|
||||
dst_int = frame.astype(jnp.int32)
|
||||
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
|
||||
|
||||
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||
|
||||
|
||||
def place_glyph_simple(
|
||||
frame: jnp.ndarray,
|
||||
glyph: GlyphData,
|
||||
x: float,
|
||||
y: float,
|
||||
color: tuple = (255, 255, 255),
|
||||
opacity: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Convenience wrapper that takes GlyphData directly.
|
||||
Converts glyph image to JAX array.
|
||||
|
||||
For S-expression use, prefer place_glyph_jax with pre-converted arrays.
|
||||
"""
|
||||
glyph_jax = jnp.asarray(glyph.image)
|
||||
color_jax = jnp.array(color, dtype=jnp.float32)
|
||||
|
||||
return place_glyph_jax(
|
||||
frame, glyph_jax, x, y,
|
||||
glyph.bearing_x, glyph.bearing_y,
|
||||
color_jax, opacity
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# S-Expression Primitive Bindings
|
||||
# =============================================================================
|
||||
|
||||
def bind_typography_primitives(env: dict) -> dict:
|
||||
"""
|
||||
Add typography primitives to an S-expression environment.
|
||||
|
||||
Primitives added:
|
||||
(text-glyphs text font-size) -> list of glyph data
|
||||
(glyph-image g) -> JAX array (H, W, 4)
|
||||
(glyph-advance g) -> float
|
||||
(glyph-bearing-x g) -> float
|
||||
(glyph-bearing-y g) -> float
|
||||
(glyph-width g) -> int
|
||||
(glyph-height g) -> int
|
||||
(font-ascent font-size) -> float
|
||||
(font-descent font-size) -> float
|
||||
(place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame
|
||||
"""
|
||||
|
||||
def prim_text_glyphs(text, font_size=32, font_name=None):
|
||||
"""Get list of glyph data for text. Compile-time."""
|
||||
return get_glyphs(str(text), font_name, int(font_size))
|
||||
|
||||
def prim_glyph_image(glyph):
|
||||
"""Get glyph image as JAX array."""
|
||||
return jnp.asarray(glyph.image)
|
||||
|
||||
def prim_glyph_advance(glyph):
|
||||
"""Get glyph advance width."""
|
||||
return glyph.advance
|
||||
|
||||
def prim_glyph_bearing_x(glyph):
|
||||
"""Get glyph left side bearing."""
|
||||
return glyph.bearing_x
|
||||
|
||||
def prim_glyph_bearing_y(glyph):
|
||||
"""Get glyph top bearing."""
|
||||
return glyph.bearing_y
|
||||
|
||||
def prim_glyph_width(glyph):
|
||||
"""Get glyph image width."""
|
||||
return glyph.width
|
||||
|
||||
def prim_glyph_height(glyph):
|
||||
"""Get glyph image height."""
|
||||
return glyph.height
|
||||
|
||||
def prim_font_ascent(font_size=32, font_name=None):
|
||||
"""Get font ascent."""
|
||||
return get_font_ascent(font_name, int(font_size))
|
||||
|
||||
def prim_font_descent(font_size=32, font_name=None):
|
||||
"""Get font descent."""
|
||||
return get_font_descent(font_name, int(font_size))
|
||||
|
||||
def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
|
||||
color=(255, 255, 255), opacity=1.0):
|
||||
"""Place glyph on frame. Runtime JAX operation."""
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y,
|
||||
color_arr, opacity)
|
||||
|
||||
def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
|
||||
"""Get kerning adjustment between two glyphs. Compile-time.
|
||||
|
||||
Returns adjustment to add to glyph1's advance when glyph2 follows.
|
||||
Typically negative (characters move closer).
|
||||
|
||||
Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||
"""
|
||||
return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))
|
||||
|
||||
def prim_char_kerning(char1, char2, font_size=32, font_name=None):
|
||||
"""Get kerning adjustment between two characters. Compile-time."""
|
||||
return get_kerning(str(char1), str(char2), font_name, int(font_size))
|
||||
|
||||
# TextStrip primitives for pre-rendered text with proper anti-aliasing
|
||||
def prim_render_text_strip(text, font_size=32, font_name=None):
|
||||
"""Render text to a strip at compile time. Perfect anti-aliasing."""
|
||||
return render_text_strip(str(text), font_name, int(font_size))
|
||||
|
||||
def prim_render_text_strip_styled(
|
||||
text, font_size=32, font_name=None,
|
||||
stroke_width=0, stroke_fill=None,
|
||||
anchor="la", multiline=False, line_spacing=4, align="left"
|
||||
):
|
||||
"""Render styled text to a strip. Supports stroke, anchors, multiline.
|
||||
|
||||
Args:
|
||||
text: Text to render
|
||||
font_size: Size in pixels
|
||||
font_name: Path to font file
|
||||
stroke_width: Outline width (0 = no outline)
|
||||
stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
|
||||
anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
|
||||
multiline: If True, handle newlines
|
||||
line_spacing: Extra pixels between lines
|
||||
align: "left", "center", "right" for multiline
|
||||
"""
|
||||
return render_text_strip(
|
||||
str(text), font_name, int(font_size),
|
||||
stroke_width=int(stroke_width),
|
||||
stroke_fill=stroke_fill,
|
||||
anchor=str(anchor),
|
||||
multiline=bool(multiline),
|
||||
line_spacing=int(line_spacing),
|
||||
align=str(align),
|
||||
)
|
||||
|
||||
def prim_text_strip_image(strip):
|
||||
"""Get text strip image as JAX array."""
|
||||
return jnp.asarray(strip.image)
|
||||
|
||||
def prim_text_strip_width(strip):
|
||||
"""Get text strip width."""
|
||||
return strip.width
|
||||
|
||||
def prim_text_strip_height(strip):
|
||||
"""Get text strip height."""
|
||||
return strip.height
|
||||
|
||||
def prim_text_strip_baseline_y(strip):
|
||||
"""Get text strip baseline Y position."""
|
||||
return strip.baseline_y
|
||||
|
||||
def prim_text_strip_bearing_x(strip):
|
||||
"""Get text strip left bearing."""
|
||||
return strip.bearing_x
|
||||
|
||||
def prim_text_strip_anchor_x(strip):
|
||||
"""Get text strip anchor X offset."""
|
||||
return strip.anchor_x
|
||||
|
||||
def prim_text_strip_anchor_y(strip):
|
||||
"""Get text strip anchor Y offset."""
|
||||
return strip.anchor_y
|
||||
|
||||
def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
|
||||
"""Place pre-rendered text strip on frame. Runtime JAX operation."""
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
return place_text_strip_jax(
|
||||
frame, strip_img, x, y,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color_arr, opacity,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
stroke_width=strip.stroke_width
|
||||
)
|
||||
|
||||
# Add to environment
|
||||
env.update({
|
||||
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
|
||||
'text-glyphs': prim_text_glyphs,
|
||||
'glyph-image': prim_glyph_image,
|
||||
'glyph-advance': prim_glyph_advance,
|
||||
'glyph-bearing-x': prim_glyph_bearing_x,
|
||||
'glyph-bearing-y': prim_glyph_bearing_y,
|
||||
'glyph-width': prim_glyph_width,
|
||||
'glyph-height': prim_glyph_height,
|
||||
'glyph-kerning': prim_glyph_kerning,
|
||||
'char-kerning': prim_char_kerning,
|
||||
'font-ascent': prim_font_ascent,
|
||||
'font-descent': prim_font_descent,
|
||||
'place-glyph': prim_place_glyph,
|
||||
# TextStrip primitives (for pixel-perfect static text)
|
||||
'render-text-strip': prim_render_text_strip,
|
||||
'render-text-strip-styled': prim_render_text_strip_styled,
|
||||
'text-strip-image': prim_text_strip_image,
|
||||
'text-strip-width': prim_text_strip_width,
|
||||
'text-strip-height': prim_text_strip_height,
|
||||
'text-strip-baseline-y': prim_text_strip_baseline_y,
|
||||
'text-strip-bearing-x': prim_text_strip_bearing_x,
|
||||
'text-strip-anchor-x': prim_text_strip_anchor_x,
|
||||
'text-strip-anchor-y': prim_text_strip_anchor_y,
|
||||
'place-text-strip': prim_place_text_strip,
|
||||
})
|
||||
|
||||
return env
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Example: Render text using primitives (for testing)
|
||||
# =============================================================================
|
||||
|
||||
def render_text_primitives(
|
||||
frame: jnp.ndarray,
|
||||
text: str,
|
||||
x: float,
|
||||
y: float,
|
||||
font_size: int = 32,
|
||||
color: tuple = (255, 255, 255),
|
||||
opacity: float = 1.0,
|
||||
use_kerning: bool = True,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Render text using the primitives.
|
||||
This is what an S-expression would compile to.
|
||||
|
||||
Args:
|
||||
use_kerning: If True, apply kerning adjustments between characters
|
||||
"""
|
||||
glyphs = get_glyphs(text, None, font_size)
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
|
||||
cursor = x
|
||||
for i, g in enumerate(glyphs):
|
||||
glyph_jax = jnp.asarray(g.image)
|
||||
frame = place_glyph_jax(
|
||||
frame, glyph_jax, cursor, y,
|
||||
g.bearing_x, g.bearing_y,
|
||||
color_arr, opacity
|
||||
)
|
||||
# Advance cursor with optional kerning
|
||||
advance = g.advance
|
||||
if use_kerning and i + 1 < len(glyphs):
|
||||
advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size)
|
||||
cursor = cursor + advance
|
||||
|
||||
return frame
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,7 @@ Context (ctx) is passed explicitly to frame evaluation:
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import hashlib
|
||||
@@ -62,6 +63,38 @@ class Context:
|
||||
fps: float = 30.0
|
||||
|
||||
|
||||
class DeferredEffectChain:
|
||||
"""
|
||||
Represents a chain of JAX effects that haven't been executed yet.
|
||||
|
||||
Allows effects to be accumulated through let bindings and fused
|
||||
into a single JIT-compiled function when the result is needed.
|
||||
"""
|
||||
__slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')
|
||||
|
||||
def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
|
||||
self.effects = effects # List of effect names, innermost first
|
||||
self.params_list = params_list # List of param dicts, matching effects
|
||||
self.base_frame = base_frame # The actual frame array at the start
|
||||
self.t = t
|
||||
self.frame_num = frame_num
|
||||
|
||||
def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
|
||||
"""Add another effect to the chain (outermost)."""
|
||||
return DeferredEffectChain(
|
||||
self.effects + [effect_name],
|
||||
self.params_list + [params],
|
||||
self.base_frame,
|
||||
self.t,
|
||||
self.frame_num
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Allow shape check without forcing execution."""
|
||||
return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None
|
||||
|
||||
|
||||
class StreamInterpreter:
|
||||
"""
|
||||
Fully generic streaming sexp interpreter.
|
||||
@@ -98,6 +131,9 @@ class StreamInterpreter:
|
||||
self.use_jax = use_jax
|
||||
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
||||
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
||||
self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains
|
||||
self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains
|
||||
self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env
|
||||
if use_jax:
|
||||
if _init_jax():
|
||||
print("JAX acceleration enabled", file=sys.stderr)
|
||||
@@ -238,6 +274,8 @@ class StreamInterpreter:
|
||||
"""Load primitives from a Python library file.
|
||||
|
||||
Prefers GPU-accelerated versions (*_gpu.py) when available.
|
||||
Uses cached modules from sys.modules to ensure consistent state
|
||||
(e.g., same RNG instance for all interpreters).
|
||||
"""
|
||||
import importlib.util
|
||||
|
||||
@@ -264,9 +302,26 @@ class StreamInterpreter:
|
||||
if not lib_path:
|
||||
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
# Use cached module if already imported to preserve state (e.g., RNG)
|
||||
# This is critical for deterministic random number sequences
|
||||
# Check multiple possible module keys (standard import paths and our cache)
|
||||
possible_keys = [
|
||||
f"sexp_effects.primitive_libs.{actual_lib_name}",
|
||||
f"sexp_primitives.{actual_lib_name}",
|
||||
]
|
||||
|
||||
module = None
|
||||
for key in possible_keys:
|
||||
if key in sys.modules:
|
||||
module = sys.modules[key]
|
||||
break
|
||||
|
||||
if module is None:
|
||||
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
# Cache for future use under our key
|
||||
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
|
||||
|
||||
# Check if this is a GPU-accelerated module
|
||||
is_gpu = actual_lib_name.endswith('_gpu')
|
||||
@@ -452,30 +507,353 @@ class StreamInterpreter:
|
||||
|
||||
try:
|
||||
jax_fn = self.jax_effects[name]
|
||||
# Ensure frame is numpy array
|
||||
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
|
||||
# JAX handles numpy and JAX arrays natively, no conversion needed
|
||||
if hasattr(frame, 'cpu'):
|
||||
frame = frame.cpu
|
||||
elif hasattr(frame, 'get'):
|
||||
frame = frame.get()
|
||||
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||
frame = frame.get() # CuPy array -> numpy
|
||||
|
||||
# Get seed from config for deterministic random
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Call JAX function with parameters
|
||||
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||
|
||||
# Convert result back to numpy if needed
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready() # Ensure computation is complete
|
||||
if hasattr(result, '__array__'):
|
||||
result = np.asarray(result)
|
||||
|
||||
return result
|
||||
# Return JAX array directly - don't block or convert per-effect
|
||||
# Conversion to numpy happens once at frame write time
|
||||
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||
except Exception as e:
|
||||
# Fall back to interpreter on error
|
||||
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _is_jax_effect_expr(self, expr) -> bool:
|
||||
"""Check if an expression is a JAX-compiled effect call."""
|
||||
if not isinstance(expr, list) or not expr:
|
||||
return False
|
||||
head = expr[0]
|
||||
if not isinstance(head, Symbol):
|
||||
return False
|
||||
return head.name in self.jax_effects
|
||||
|
||||
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
|
||||
"""
|
||||
Extract a chain of JAX effects from an expression.
|
||||
|
||||
Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
|
||||
effect_names and params_list are in execution order (innermost first).
|
||||
|
||||
For (effect1 (effect2 frame :p1 v1) :p2 v2):
|
||||
Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
|
||||
"""
|
||||
if not self._is_jax_effect_expr(expr):
|
||||
return None
|
||||
|
||||
chain = []
|
||||
params_list = []
|
||||
current = expr
|
||||
|
||||
while self._is_jax_effect_expr(current):
|
||||
head = current[0]
|
||||
effect_name = head.name
|
||||
args = current[1:]
|
||||
|
||||
# Extract params for this effect
|
||||
effect = self.effects[effect_name]
|
||||
effect_params = {}
|
||||
for pname, pdef in effect['params'].items():
|
||||
effect_params[pname] = pdef.get('default', 0)
|
||||
|
||||
# Find the frame argument (first positional) and other params
|
||||
frame_arg = None
|
||||
i = 0
|
||||
while i < len(args):
|
||||
if isinstance(args[i], Keyword):
|
||||
pname = args[i].name
|
||||
if pname in effect['params'] and i + 1 < len(args):
|
||||
effect_params[pname] = self._eval(args[i + 1], env)
|
||||
i += 2
|
||||
else:
|
||||
if frame_arg is None:
|
||||
frame_arg = args[i] # First positional is frame
|
||||
i += 1
|
||||
|
||||
chain.append(effect_name)
|
||||
params_list.append(effect_params)
|
||||
|
||||
if frame_arg is None:
|
||||
return None # No frame argument found
|
||||
|
||||
# Check if frame_arg is another effect call
|
||||
if self._is_jax_effect_expr(frame_arg):
|
||||
current = frame_arg
|
||||
else:
|
||||
# End of chain - frame_arg is the base frame
|
||||
# Reverse to get innermost-first execution order
|
||||
chain.reverse()
|
||||
params_list.reverse()
|
||||
return (chain, params_list, frame_arg)
|
||||
|
||||
return None
|
||||
|
||||
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
|
||||
"""Generate a cache key for an effect chain.
|
||||
|
||||
Includes static param values in the key since they affect compilation.
|
||||
"""
|
||||
parts = []
|
||||
for name, params in zip(effect_names, params_list):
|
||||
param_parts = []
|
||||
for pname in sorted(params.keys()):
|
||||
pval = params[pname]
|
||||
# Include static values in key (strings, bools)
|
||||
if isinstance(pval, (str, bool)):
|
||||
param_parts.append(f"{pname}={pval}")
|
||||
else:
|
||||
param_parts.append(pname)
|
||||
parts.append(f"{name}:{','.join(param_parts)}")
|
||||
return '|'.join(parts)
|
||||
|
||||
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||
"""
|
||||
Compile a chain of effects into a single fused JAX function.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||
params_list: List of param dicts for each effect (used to detect static types)
|
||||
|
||||
Returns:
|
||||
JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame
|
||||
"""
|
||||
if not _JAX_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
import jax
|
||||
|
||||
# Get the individual JAX functions
|
||||
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||
|
||||
# Pre-extract param names and identify static params from actual values
|
||||
effect_param_names = []
|
||||
static_params = ['seed'] # seed is always static
|
||||
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||
param_names = list(params.keys())
|
||||
effect_param_names.append(param_names)
|
||||
# Check actual values to identify static types
|
||||
for pname, pval in params.items():
|
||||
if isinstance(pval, (str, bool)):
|
||||
static_params.append(f"_p{i}_{pname}")
|
||||
|
||||
def fused_fn(frame, t, frame_num, seed, **kwargs):
|
||||
result = frame
|
||||
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||
# Extract params for this effect from kwargs
|
||||
effect_kwargs = {}
|
||||
for pname in param_names:
|
||||
key = f"_p{i}_{pname}"
|
||||
if key in kwargs:
|
||||
effect_kwargs[pname] = kwargs[key]
|
||||
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||
return result
|
||||
|
||||
# JIT with static params (seed + any string/bool params)
|
||||
return jax.jit(fused_fn, static_argnames=static_params)
|
||||
except Exception as e:
|
||||
print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
|
||||
"""Apply a chain of effects, using fused compilation if available."""
|
||||
chain_key = self._get_chain_key(effect_names, params_list)
|
||||
|
||||
# Try to get or compile fused chain
|
||||
if chain_key not in self.jax_fused_chains:
|
||||
fused_fn = self._compile_effect_chain(effect_names, params_list)
|
||||
self.jax_fused_chains[chain_key] = fused_fn
|
||||
if fused_fn:
|
||||
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
|
||||
|
||||
fused_fn = self.jax_fused_chains.get(chain_key)
|
||||
|
||||
if fused_fn is not None:
|
||||
# Build kwargs with prefixed param names
|
||||
kwargs = {}
|
||||
for i, params in enumerate(params_list):
|
||||
for pname, pval in params.items():
|
||||
kwargs[f"_p{i}_{pname}"] = pval
|
||||
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Handle GPU frames
|
||||
if hasattr(frame, 'cpu'):
|
||||
frame = frame.cpu
|
||||
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||
frame = frame.get()
|
||||
|
||||
try:
|
||||
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"Fused chain error: {e}", file=sys.stderr)
|
||||
|
||||
# Fall back to sequential application
|
||||
result = frame
|
||||
for name, params in zip(effect_names, params_list):
|
||||
result = self._apply_jax_effect(name, result, params, t, frame_num)
|
||||
if result is None:
|
||||
return None
|
||||
return result
|
||||
|
||||
def _force_deferred(self, deferred: DeferredEffectChain):
|
||||
"""Execute a deferred effect chain and return the actual array."""
|
||||
if len(deferred.effects) == 0:
|
||||
return deferred.base_frame
|
||||
|
||||
return self._apply_effect_chain(
|
||||
deferred.effects,
|
||||
deferred.params_list,
|
||||
deferred.base_frame,
|
||||
deferred.t,
|
||||
deferred.frame_num
|
||||
)
|
||||
|
||||
def _maybe_force(self, value):
|
||||
"""Force a deferred chain if needed, otherwise return as-is."""
|
||||
if isinstance(value, DeferredEffectChain):
|
||||
return self._force_deferred(value)
|
||||
return value
|
||||
|
||||
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||
"""
|
||||
Compile a vmapped version of an effect chain for batch processing.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||
params_list: List of param dicts (used to detect static types)
|
||||
|
||||
Returns:
|
||||
Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames
|
||||
Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar
|
||||
"""
|
||||
if not _JAX_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Get the individual JAX functions
|
||||
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||
|
||||
# Pre-extract param info
|
||||
effect_param_names = []
|
||||
static_params = set()
|
||||
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||
param_names = list(params.keys())
|
||||
effect_param_names.append(param_names)
|
||||
for pname, pval in params.items():
|
||||
if isinstance(pval, (str, bool)):
|
||||
static_params.add(f"_p{i}_{pname}")
|
||||
|
||||
# Single-frame function (will be vmapped)
|
||||
def single_frame_fn(frame, t, frame_num, seed, **kwargs):
|
||||
result = frame
|
||||
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||
effect_kwargs = {}
|
||||
for pname in param_names:
|
||||
key = f"_p{i}_{pname}"
|
||||
if key in kwargs:
|
||||
effect_kwargs[pname] = kwargs[key]
|
||||
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||
return result
|
||||
|
||||
# Return unbatched function - we'll vmap at call time with proper in_axes
|
||||
return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params))
|
||||
except Exception as e:
|
||||
print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
|
||||
frames: list, ts: list, frame_nums: list) -> Optional[list]:
|
||||
"""
|
||||
Apply an effect chain to a batch of frames using vmap.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names
|
||||
params_list_batch: List of params_list for each frame in batch
|
||||
frames: List of input frames
|
||||
ts: List of time values
|
||||
frame_nums: List of frame numbers
|
||||
|
||||
Returns:
|
||||
List of output frames, or None on failure
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
# Use first frame's params for chain key (assume same structure)
|
||||
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
|
||||
batch_key = f"batch:{chain_key}"
|
||||
|
||||
# Compile batched version if needed
|
||||
if batch_key not in self.jax_batched_chains:
|
||||
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
|
||||
self.jax_batched_chains[batch_key] = batched_fn
|
||||
if batched_fn:
|
||||
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
|
||||
|
||||
batched_fn = self.jax_batched_chains.get(batch_key)
|
||||
|
||||
if batched_fn is not None:
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Stack frames into batch array
|
||||
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
|
||||
ts_array = jnp.array(ts)
|
||||
frame_nums_array = jnp.array(frame_nums)
|
||||
|
||||
# Build kwargs - all numeric params as arrays for vmap
|
||||
kwargs = {}
|
||||
static_kwargs = {} # Non-vmapped (strings, bools)
|
||||
|
||||
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
|
||||
for j, pname in enumerate(params_list_batch[0][i].keys()):
|
||||
key = f"_p{i}_{pname}"
|
||||
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
|
||||
|
||||
first = values[0]
|
||||
if isinstance(first, (str, bool)):
|
||||
# Static params - not vmapped
|
||||
static_kwargs[key] = first
|
||||
elif isinstance(first, (int, float)):
|
||||
# Always batch numeric params for simplicity
|
||||
kwargs[key] = jnp.array(values)
|
||||
elif hasattr(first, 'shape'):
|
||||
kwargs[key] = jnp.stack(values)
|
||||
else:
|
||||
kwargs[key] = jnp.array(values)
|
||||
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Create wrapper that unpacks the params dict
|
||||
def single_call(frame, t, frame_num, params_dict):
|
||||
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
|
||||
|
||||
# vmap over frame, t, frame_num, and the params dict (as pytree)
|
||||
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# Stack kwargs into a dict of arrays (pytree with matching structure)
|
||||
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
|
||||
|
||||
# Unstack results
|
||||
return [results[i] for i in range(len(frames))]
|
||||
except Exception as e:
|
||||
print(f"Batched chain error: {e}", file=sys.stderr)
|
||||
|
||||
# Fall back to sequential
|
||||
return None
|
||||
|
||||
def _init(self):
|
||||
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
||||
# Set random seed for deterministic output
|
||||
@@ -869,6 +1247,22 @@ class StreamInterpreter:
|
||||
# === Effects ===
|
||||
|
||||
if op in self.effects:
|
||||
# Try to detect and fuse effect chains for JAX acceleration
|
||||
if self.use_jax and op in self.jax_effects:
|
||||
chain_info = self._extract_effect_chain(expr, env)
|
||||
if chain_info is not None:
|
||||
effect_names, params_list, base_frame_expr = chain_info
|
||||
# Only use chain if we have 2+ effects (worth fusing)
|
||||
if len(effect_names) >= 2:
|
||||
base_frame = self._eval(base_frame_expr, env)
|
||||
if base_frame is not None and hasattr(base_frame, 'shape'):
|
||||
t = env.get('t', 0.0)
|
||||
frame_num = env.get('frame-num', 0)
|
||||
result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
|
||||
if result is not None:
|
||||
return result
|
||||
# Fall through if chain application fails
|
||||
|
||||
effect = self.effects[op]
|
||||
effect_env = dict(env)
|
||||
|
||||
@@ -895,17 +1289,28 @@ class StreamInterpreter:
|
||||
positional_idx += 1
|
||||
i += 1
|
||||
|
||||
# Try JAX-accelerated execution first
|
||||
# Try JAX-accelerated execution with deferred chaining
|
||||
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
||||
# Build params dict for JAX (exclude 'frame')
|
||||
jax_params = {k: v for k, v in effect_env.items()
|
||||
jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
|
||||
if k != 'frame' and k in effect['params']}
|
||||
t = env.get('t', 0.0)
|
||||
frame_num = env.get('frame-num', 0)
|
||||
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
|
||||
if result is not None:
|
||||
return result
|
||||
# Fall through to interpreter if JAX fails
|
||||
|
||||
# Check if input is a deferred chain - if so, extend it
|
||||
if isinstance(frame_val, DeferredEffectChain):
|
||||
return frame_val.extend(op, jax_params)
|
||||
|
||||
# Check if input is a valid frame - create new deferred chain
|
||||
if hasattr(frame_val, 'shape'):
|
||||
return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
|
||||
|
||||
# Fall through to interpreter if not a valid frame
|
||||
|
||||
# Force any deferred frame before interpreter evaluation
|
||||
if isinstance(frame_val, DeferredEffectChain):
|
||||
frame_val = self._force_deferred(frame_val)
|
||||
effect_env['frame'] = frame_val
|
||||
|
||||
return self._eval(effect['body'], effect_env)
|
||||
|
||||
@@ -922,10 +1327,15 @@ class StreamInterpreter:
|
||||
if isinstance(args[i], Keyword):
|
||||
k = args[i].name
|
||||
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
|
||||
# Force deferred chains before passing to primitives
|
||||
v = self._maybe_force(v)
|
||||
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
|
||||
i += 2
|
||||
else:
|
||||
evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim))
|
||||
val = self._eval(args[i], env)
|
||||
# Force deferred chains before passing to primitives
|
||||
val = self._maybe_force(val)
|
||||
evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
|
||||
i += 1
|
||||
try:
|
||||
if kwargs:
|
||||
@@ -1152,6 +1562,61 @@ class StreamInterpreter:
|
||||
eval_times = []
|
||||
write_times = []
|
||||
|
||||
# Batch accumulation for JAX
|
||||
batch_deferred = [] # Accumulated DeferredEffectChains
|
||||
batch_times = [] # Corresponding time values
|
||||
batch_start_frame = 0
|
||||
|
||||
def flush_batch():
|
||||
"""Execute accumulated batch and write results."""
|
||||
nonlocal batch_deferred, batch_times
|
||||
if not batch_deferred:
|
||||
return
|
||||
|
||||
t_flush = time.time()
|
||||
|
||||
# Check if all chains have same structure (can batch)
|
||||
first = batch_deferred[0]
|
||||
can_batch = (
|
||||
self.use_jax and
|
||||
len(batch_deferred) >= 2 and
|
||||
all(d.effects == first.effects for d in batch_deferred)
|
||||
)
|
||||
|
||||
if can_batch:
|
||||
# Try batched execution
|
||||
frames = [d.base_frame for d in batch_deferred]
|
||||
ts = [d.t for d in batch_deferred]
|
||||
frame_nums = [d.frame_num for d in batch_deferred]
|
||||
params_batch = [d.params_list for d in batch_deferred]
|
||||
|
||||
results = self._apply_batched_chain(
|
||||
first.effects, params_batch, frames, ts, frame_nums
|
||||
)
|
||||
|
||||
if results is not None:
|
||||
# Write batched results
|
||||
for result, t in zip(results, batch_times):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, t)
|
||||
batch_deferred = []
|
||||
batch_times = []
|
||||
return
|
||||
|
||||
# Fall back to sequential execution
|
||||
for deferred, t in zip(batch_deferred, batch_times):
|
||||
result = self._force_deferred(deferred)
|
||||
if result is not None and hasattr(result, 'shape'):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, t)
|
||||
|
||||
batch_deferred = []
|
||||
batch_times = []
|
||||
|
||||
for frame_num in range(start_frame, n_frames):
|
||||
if not out.is_open:
|
||||
break
|
||||
@@ -1182,8 +1647,23 @@ class StreamInterpreter:
|
||||
eval_times.append(time.time() - t1)
|
||||
|
||||
t2 = time.time()
|
||||
if result is not None and hasattr(result, 'shape'):
|
||||
out.write(result, ctx.t)
|
||||
if result is not None:
|
||||
if isinstance(result, DeferredEffectChain):
|
||||
# Accumulate for batching
|
||||
batch_deferred.append(result)
|
||||
batch_times.append(ctx.t)
|
||||
|
||||
# Flush when batch is full
|
||||
if len(batch_deferred) >= self.jax_batch_size:
|
||||
flush_batch()
|
||||
else:
|
||||
# Not deferred - flush any pending batch first, then write
|
||||
flush_batch()
|
||||
if hasattr(result, 'shape'):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, ctx.t)
|
||||
write_times.append(time.time() - t2)
|
||||
|
||||
frame_elapsed = time.time() - frame_start
|
||||
@@ -1219,6 +1699,9 @@ class StreamInterpreter:
|
||||
except Exception as e:
|
||||
print(f"Warning: progress callback failed: {e}", file=sys.stderr)
|
||||
|
||||
# Flush any remaining batch
|
||||
flush_batch()
|
||||
|
||||
finally:
|
||||
out.close()
|
||||
# Store output for access to properties like playlist_cid
|
||||
|
||||
Reference in New Issue
Block a user