Files
celery/streaming/sexp_to_jax.py
gilesb fc9597456f
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:41:19 +00:00

4629 lines
178 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Sexp to JAX Compiler.
Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU.
Uses XLA compilation via @jax.jit for automatic kernel fusion.
Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles
S-expressions directly to JAX operations which XLA then optimizes.
Usage:
from streaming.sexp_to_jax import compile_effect
effect_code = '''
(effect "threshold"
:params ((threshold :default 128))
:body (let ((g (gray frame)))
(rgb (where (> g threshold) 255 0)
(where (> g threshold) 255 0)
(where (> g threshold) 255 0))))
'''
run_effect = compile_effect(effect_code)
output = run_effect(frame, threshold=128)
"""
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Any, Dict, List, Callable, Optional, Tuple
import hashlib
import numpy as np
# Import parser
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
# Import typography primitives
from streaming.jax_typography import bind_typography_primitives
# =============================================================================
# Compilation Cache
# =============================================================================
# Module-level cache of compiled effects: string key -> compiled callable.
# NOTE(review): the key scheme and the code that fills this dict are not
# visible in this part of the file - confirm before relying on either.
_COMPILED_EFFECTS: Dict[str, Callable] = {}
# =============================================================================
# Font Atlas for ASCII Effects
# =============================================================================
# Character sets for ASCII rendering, keyed by alphabet name.
# _get_alphabet_string() resolves names through this table and treats any
# other string as a custom alphabet. _create_font_atlas() documents its
# alphabet argument as ordered by brightness, dark to light.
ASCII_ALPHABETS = {
    'standard': ' .:-=+*#%@',
    'blocks': ' ░▒▓█',
    'simple': ' .:oO@',
    'digits': ' 0123456789',
    'binary': ' 01',
    'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$',
}
# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array.
# Read and written by _create_font_atlas() so a given combination is only
# rendered once per process.
_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {}
def _create_font_atlas(alphabet: str, char_size: int, font_name: str = None) -> np.ndarray:
    """
    Create a font atlas with all characters pre-rendered.

    Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time.
    Results (including the block-atlas fallback) are memoised in
    _FONT_ATLAS_CACHE under (alphabet, char_size, font_name).

    Args:
        alphabet: String of characters to render (ordered by brightness, dark to light)
        char_size: Size of each character cell in pixels
        font_name: Optional font name/path (uses default monospace if None)

    Returns:
        NumPy array of shape (num_chars, char_size, char_size, 3) with rendered
        characters. Each character is white on black background.

    Raises:
        ValueError: if alphabet is empty (there is nothing to stack).
    """
    cache_key = (alphabet, char_size, font_name)
    cached = _FONT_ATLAS_CACHE.get(cache_key)
    if cached is not None:
        return cached

    if not alphabet:
        raise ValueError("alphabet must contain at least one character")

    def _block_fallback() -> np.ndarray:
        # Cache the fallback too, otherwise it is rebuilt on every call.
        blocks = _create_block_atlas(alphabet, char_size)
        _FONT_ATLAS_CACHE[cache_key] = blocks
        return blocks

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        # Fallback: create simple block-based atlas without PIL
        return _block_fallback()

    font_size = int(char_size * 0.9)  # Slightly smaller than cell

    # Try the requested font first, then common monospace fonts per platform.
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf',
        '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf',
        '/System/Library/Fonts/Menlo.ttc',  # macOS
        '/System/Library/Fonts/Monaco.dfont',  # macOS
        'C:\\Windows\\Fonts\\consola.ttf',  # Windows
    ]
    font = None
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue

    if font is None:
        try:
            font = ImageFont.load_default()
        except Exception:
            # Best-effort: any failure to get a usable font degrades to
            # blocks, but KeyboardInterrupt/SystemExit still propagate.
            return _block_fallback()

    tiles = []
    for char in alphabet:
        img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Measure the glyph. textbbox() reports where the ink lands when the
        # text is drawn at (0, 0), so its left/top offsets must be subtracted
        # from the draw position for the ink to be centred in the cell.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
        except AttributeError:
            # Older PIL versions: only a size is available, no offsets.
            text_width, text_height = draw.textsize(char, font=font)
            left = top = 0
        else:
            text_width = right - left
            text_height = bottom - top

        x = (char_size - text_width) // 2 - left
        y = (char_size - text_height) // 2 - top

        # Draw white character on black background
        draw.text((x, y), char, fill=(255, 255, 255), font=font)

        # Convert to numpy array (NOT jax array - avoids tracer issues)
        tiles.append(np.array(img, dtype=np.uint8))

    atlas = np.stack(tiles, axis=0)
    _FONT_ATLAS_CACHE[cache_key] = atlas
    return atlas
def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray:
    """
    Build a font-free atlas of grey tiles, one per alphabet entry.

    Tile brightness ramps linearly from 0 (first entry) to 255 (last entry).
    Entries whose relative position lies strictly between 0.2 and 0.8 get a
    +/-20% checkerboard so neighbouring levels are easier to tell apart.
    Pure numpy, so it is safe to call while JAX is tracing.

    Returns:
        uint8 array of shape (len(alphabet), char_size, char_size, 3).
    """
    count = len(alphabet)
    ramp_steps = max(count - 1, 1)
    tiles = []
    for position in range(count):
        level = int(255 * position / ramp_steps)
        tile = np.full((char_size, char_size, 3), level, dtype=np.uint8)

        if 0.2 < position / count < 0.8:
            rows, cols = np.mgrid[:char_size, :char_size]
            on_even_square = ((cols + rows) % 2 == 0)[:, :, None]
            delta = int(level * 0.2)
            # Widen to int16 so the +/- delta cannot wrap around in uint8.
            widened = tile.astype(np.int16)
            brighter = np.clip(widened + delta, 0, 255).astype(np.uint8)
            darker = np.clip(widened - delta, 0, 255).astype(np.uint8)
            tile = np.where(on_even_square, brighter, darker)

        tiles.append(tile)
    return np.stack(tiles, axis=0)
def _get_alphabet_string(alphabet_name: str) -> str:
    """Resolve a named alphabet via ASCII_ALPHABETS; unknown names are returned
    unchanged and treated as a custom character string."""
    return ASCII_ALPHABETS.get(alphabet_name, alphabet_name)
# =============================================================================
# Text Rendering with Font Atlas (JAX-compatible)
# =============================================================================
# Default character set for text rendering: the printable ASCII characters,
# code points 32 (space) through 126 (~), in code-point order.
# _create_text_atlas() uses the position in this string as the atlas index.
TEXT_CHARSET = ''.join(chr(i) for i in range(32, 127)) # space to ~
# Cache for text font atlases, keyed by (font_name, font_size). Each value is
# the 7-tuple returned by _create_text_atlas():
# (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings)
_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {}
def _create_text_atlas(font_name: str = None, font_size: int = 32) -> tuple:
    """
    Create a font atlas for general text rendering with proper baseline alignment.

    Font Metrics (from typography):
    - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l)
    - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y)
    - Baseline: the line text "sits" on - all characters align to this
    - Em-square: the design space, typically = ascender + descender

    Results are memoised in _TEXT_ATLAS_CACHE under (font_name, font_size).

    Returns:
        A 7-tuple
        (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x,
         char_left_bearings):
        - atlas: uint8 array (num_chars, cell_height, cell_width, 4) RGBA,
          one tile per character of TEXT_CHARSET
        - char_to_idx: dict mapping character to atlas index
        - char_widths: int32 array of advance widths, one per character
        - cell_height: tile height (ascent + descent + 2)
        - baseline_y: ascent + 1
        - draw_x: x position of the text origin inside every tile
        - char_left_bearings: int32 array, bbox left edge minus draw_x

    Raises:
        ImportError: if PIL/Pillow is not installed.
    """
    cache_key = (font_name, font_size)
    if cache_key in _TEXT_ATLAS_CACHE:
        return _TEXT_ATLAS_CACHE[cache_key]

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError as exc:
        raise ImportError("PIL/Pillow required for text rendering") from exc

    # Load font - match drawing.py's font order for consistency
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',  # Same order as drawing.py
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
        '/System/Library/Fonts/Helvetica.ttc',
        '/System/Library/Fonts/Arial.ttf',
        'C:\\Windows\\Fonts\\arial.ttf',
    ]
    font = None
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue
    if font is None:
        font = ImageFont.load_default()

    # getmetrics() returns (ascent, descent) relative to the baseline. Cell
    # height comes from these font-wide metrics, not per-glyph bounding boxes.
    ascent, descent = font.getmetrics()
    cell_height = ascent + descent + 2  # +2 for padding
    # NOTE(review): baseline_y assumes 1px of padding above the glyph, but
    # glyphs are drawn at draw_y = 0 below - confirm which one the renderer
    # is meant to match before changing either.
    baseline_y = ascent + 1

    # Horizontal advance per character (proper character spacing).
    fallback_advance = font_size // 2
    char_widths_dict = {}
    for char in TEXT_CHARSET:
        try:
            advance = int(font.getlength(char))
        except Exception:
            # Best-effort: fonts without getlength() (or that fail on a
            # glyph) get a nominal half-em advance.
            advance = fallback_advance
        char_widths_dict[char] = advance
    max_width = max(char_widths_dict.values(), default=0)

    # Position to draw at within each tile (with margin for negative bearings)
    draw_x = 5  # Margin for chars with negative left bearing
    draw_y = 0  # Top of cell (PIL default without anchor)

    # The tile must hold the origin margin AND the widest advance; sizing it
    # from max_width alone clips the right edge of the widest glyphs, because
    # every glyph is drawn starting at draw_x rather than at column 0.
    cell_width = draw_x + max_width + 2  # +2 for padding

    # Create atlas with all characters - draw same way as prim_text
    char_to_idx = {}
    char_widths = []  # Advance widths
    char_left_bearings = []  # x offset from origin to first pixel
    atlas = []
    for i, char in enumerate(TEXT_CHARSET):
        char_to_idx[char] = i
        char_widths.append(char_widths_dict[char])

        img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        # Draw at the text origin with no anchor; glyphs may extend left or
        # right of the origin.
        draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)

        # Left bearing: how far the bbox's left edge is from the origin.
        bbox = draw.textbbox((draw_x, draw_y), char, font=font)
        char_left_bearings.append(bbox[0] - draw_x)

        atlas.append(np.array(img, dtype=np.uint8))

    atlas = np.stack(atlas, axis=0)  # (num_chars, cell_height, cell_width, 4)
    char_widths = np.array(char_widths, dtype=np.int32)
    char_left_bearings = np.array(char_left_bearings, dtype=np.int32)

    # draw_x is returned so rendering knows where the origin sits in a tile.
    result = (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings)
    _TEXT_ATLAS_CACHE[cache_key] = result
    return result
def jax_text_render(frame, text: str, x: int, y: int,
                    font_name: str = None, font_size: int = 32,
                    color=(255, 255, 255), opacity: float = 1.0,
                    align: str = "left", valign: str = "baseline",
                    shadow: bool = False, shadow_color=(0, 0, 0),
                    shadow_offset: int = 2):
    """
    Render text onto frame using font atlas (JAX-compatible).

    The text strip is laid out with numpy/PIL, then converted to a JAX array
    and alpha-blended onto the frame. text, x, y, the font arguments and the
    alignment arguments must be concrete Python values: they drive numpy
    layout and are passed through int().

    Typography notes:
    - Baseline: The line text "sits" on. Most characters rest on this line.
    - Ascender: Top of tall letters (b, d, h, k, l) - above baseline
    - Descender: Bottom of letters like g, j, p, q, y - below baseline
    - For normal text, use valign="baseline" and y = the baseline position

    Args:
        frame: Input frame (H, W, 3)
        text: Text string to render; characters missing from the atlas
            advance like a space and draw nothing
        x, y: Position reference point (affected by align/valign)
        font_name: Font to use (None = default)
        font_size: Font size in pixels
        color: RGB tuple (0-255)
        opacity: 0.0 to 1.0
        align: Horizontal alignment relative to x:
            "left" - text starts at x
            "center" - text centered on x
            "right" - text ends at x
        valign: Vertical alignment relative to y:
            "baseline" - text baseline at y (default, like normal text)
            "top" - top of ascenders at y
            "middle" - text vertically centered on y
            "bottom" - bottom of descenders at y
        shadow: Whether to draw drop shadow
        shadow_color: Shadow RGB color
        shadow_offset: Shadow offset in pixels

    Returns:
        uint8 frame with text rendered. The input frame is returned unchanged
        when text is empty or has no visible pixels.
    """
    if not text:
        return frame

    # The atlas stays in numpy; only the finished strip goes to the device.
    (atlas_np, char_to_idx, char_widths_np, char_height,
     baseline_offset, origin_x, _left_bearings) = _create_text_atlas(font_name, font_size)
    cell_width = atlas_np.shape[2]

    # Cursor position of every character, from proportional advance widths.
    space_idx = char_to_idx.get(' ', 0)
    cursor_positions = []
    text_width = 0
    for char in text:
        cursor_positions.append(text_width)
        text_width += int(char_widths_np[char_to_idx.get(char, space_idx)])
    text_height = char_height

    # Horizontal alignment
    if align == "center":
        x = x - text_width // 2
    elif align == "right":
        x = x - text_width

    # Vertical alignment; valign == "top" needs no adjustment
    if valign == "baseline":
        # y specifies baseline position, so top of text cell is above it
        y = y - baseline_offset
    elif valign == "middle":
        y = y - text_height // 2
    elif valign == "bottom":
        y = y - text_height

    x, y = int(x), int(y)

    # Each tile has its glyph origin at column origin_x. To put that origin
    # at cursor cx the tile is blitted at (strip_padding + cx - origin_x);
    # strip_padding leaves room for negative left bearings.
    strip_padding = origin_x
    text_strip_np = np.zeros(
        (char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8)
    strip_cols = text_strip_np.shape[1]
    for cx, char in zip(cursor_positions, text):
        idx = char_to_idx.get(char)
        if idx is None:
            continue
        strip_x = strip_padding + cx - origin_x
        if strip_x < 0:
            continue
        end_x = min(strip_x + cell_width, strip_cols)
        tile = atlas_np[idx][:, :end_x - strip_x]
        # maximum() lets overlapping tiles keep each other's ink.
        text_strip_np[:, strip_x:end_x] = np.maximum(
            text_strip_np[:, strip_x:end_x], tile)

    # Trim: left to the first visible column (handles negative bearing),
    # right to the logical advance width (preserves spacing).
    cols_with_content = np.any(text_strip_np[:, :, 3] > 0, axis=0)
    if not cols_with_content.any():
        # No visible content, return original frame
        return frame
    first_col = int(np.argmax(cols_with_content))
    right_col = strip_padding + text_width
    x = x + first_col - strip_padding
    text_strip_np = text_strip_np[:, first_col:right_col]

    text_strip = jnp.asarray(text_strip_np)
    coverage = text_strip[:, :, :3].astype(jnp.float32) / 255.0
    text_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity

    result = frame.astype(jnp.float32)

    # Shadow goes underneath the text, at half the text's opacity.
    if shadow:
        shadow_rgb = coverage * jnp.array(shadow_color, dtype=jnp.float32)
        result = _composite_text_strip(
            result, shadow_rgb, text_alpha * 0.5,
            x + shadow_offset, y + shadow_offset)

    # Tint the white glyphs with the requested colour and blend.
    text_rgb = coverage * jnp.array(color, dtype=jnp.float32)
    result = _composite_text_strip(result, text_rgb, text_alpha, x, y)
    return jnp.clip(result, 0, 255).astype(jnp.uint8)
def _composite_text_strip(frame, text_rgb, text_alpha, x, y):
    """
    Composite text strip onto frame at position (x, y).

    Uses alpha blending: result = text * alpha + frame * (1 - alpha).
    The strip is cropped to the frame, placed into a zero full-frame layer,
    and blended; pixels outside the strip have alpha 0 and keep the frame's
    value.

    Args:
        frame: (H, W, 3) image; may be a traced JAX array.
        text_rgb: (th, tw, 3) float32 colour values for the strip.
        text_alpha: (th, tw) float32 coverage for the strip.
        x, y: top-left position of the strip in frame coordinates. These
            must be concrete numbers (they go through Python max/min/int),
            so they cannot be traced values.

    Returns:
        Blended float image with the frame's height and width.
    """
    h, w = frame.shape[:2]
    th, tw = text_rgb.shape[:2]

    # Destination rectangle, clamped to the frame.
    y_start = int(max(0, y))
    y_end = int(min(h, y + th))
    x_start = int(max(0, x))
    x_end = int(min(w, x + tw))

    layer_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32)
    layer_alpha = jnp.zeros((h, w), dtype=jnp.float32)

    if y_end > y_start and x_end > x_start:
        # Matching source rectangle inside the strip (skips the part that
        # hangs off the top/left edge of the frame).
        src_y = int(max(0, -y))
        src_x = int(max(0, -x))
        rows = y_end - y_start
        cols = x_end - x_start
        layer_rgb = lax.dynamic_update_slice(
            layer_rgb, text_rgb[src_y:src_y + rows, src_x:src_x + cols],
            (y_start, x_start, 0))
        layer_alpha = lax.dynamic_update_slice(
            layer_alpha, text_alpha[src_y:src_y + rows, src_x:src_x + cols],
            (y_start, x_start))

    alpha_3d = layer_alpha[:, :, jnp.newaxis]
    return layer_rgb * alpha_3d + frame * (1.0 - alpha_3d)
def jax_text_size(text: str, font_name: str = None, font_size: int = 32) -> tuple:
    """
    Measure how much space a string takes when rendered with the atlas font.

    Width is the sum of per-character advance widths; characters missing
    from the atlas are measured as a space. Height is the atlas cell height.
    Uses only the (cached) atlas, so it is usable at compile time for layout.

    Returns:
        (width, height) tuple in pixels
    """
    atlas_info = _create_text_atlas(font_name, font_size)
    char_to_idx, char_widths, char_height = atlas_info[1], atlas_info[2], atlas_info[3]
    fallback_idx = char_to_idx.get(' ', 0)
    width = sum(
        int(char_widths[char_to_idx.get(ch, fallback_idx)]) for ch in text
    )
    return (width, char_height)
def jax_font_metrics(font_name: str = None, font_size: int = 32) -> dict:
    """
    Report the vertical metrics of the atlas font for layout calculations.

    The atlas stores a padded cell height and a baseline offset; the 1px
    padding on each side is removed again to recover ascent and descent.

    Returns:
        dict with keys:
        - ascent: baseline offset minus 1
        - descent: cell height minus baseline offset minus 1
        - height: cell height (padding included)
        - baseline: baseline offset from the top of the text cell
    """
    atlas_info = _create_text_atlas(font_name, font_size)
    cell_height = atlas_info[3]
    baseline = atlas_info[4]
    metrics = {
        'ascent': baseline - 1,
        'descent': cell_height - baseline - 1,
        'height': cell_height,
        'baseline': baseline,
    }
    return metrics
def jax_fit_text_size(text: str, max_width: int, max_height: int,
                      font_name: str = None, min_size: int = 8, max_size: int = 200) -> int:
    """
    Find the largest font size in [min_size, max_size] whose rendered text
    fits inside max_width x max_height, by binary search over jax_text_size.

    Returns:
        The best fitting size. If no size in the range fits, min_size is
        returned anyway.
    """
    lo, hi = min_size, max_size
    best = min_size
    while lo <= hi:
        candidate = (lo + hi) // 2
        width, height = jax_text_size(text, font_name, candidate)
        fits = width <= max_width and height <= max_height
        if not fits:
            hi = candidate - 1
            continue
        best = candidate
        lo = candidate + 1
    return best
# =============================================================================
# JAX Primitives - True primitives that can't be derived
# =============================================================================
def jax_width(frame):
    """Return the frame's width, i.e. the size of axis 1."""
    width = frame.shape[1]
    return width
def jax_height(frame):
    """Return the frame's height, i.e. the size of axis 0."""
    height = frame.shape[0]
    return height
def jax_channel(frame, idx):
    """Pull one channel out of an (H, W, C) frame as a flat float32 vector.

    idx is converted with int(), so it has to be a concrete value.
    """
    plane = frame[:, :, int(idx)]
    return plane.reshape(-1).astype(jnp.float32)
def jax_merge_channels(r, g, b, shape):
    """Rebuild an (h, w, 3) uint8 frame from three flat channel vectors.

    Each channel is clipped to [0, 255], reshaped to shape == (h, w) and cast
    to uint8 before the three planes are stacked along the last axis.
    """
    h, w = shape
    planes = [
        jnp.clip(channel, 0, 255).reshape(h, w).astype(jnp.uint8)
        for channel in (r, g, b)
    ]
    return jnp.stack(planes, axis=2)
def jax_iota(n):
    """Return the float32 index vector [0, 1, 2, ..., n-1]."""
    indices = jnp.arange(n, dtype=jnp.float32)
    return indices
def jax_repeat(x, n):
    """Duplicate every element n times in place: [a, b] -> [a, a, b, b]."""
    repeated = jnp.repeat(x, n)
    return repeated
def jax_tile(x, n):
    """Concatenate n copies of the whole array: [a, b] -> [a, b, a, b]."""
    tiled = jnp.tile(x, n)
    return tiled
def jax_gather(data, indices):
    """Look up many positions of the flattened data at once.

    Indices are cast to int32 and clipped into [0, data.size - 1], so
    out-of-range lookups return the first or last element.
    """
    flat = data.flatten()
    last_valid = len(flat) - 1
    safe_idx = jnp.clip(indices.astype(jnp.int32), 0, last_valid)
    return flat[safe_idx]
def jax_scatter(indices, values, size):
    """Write values into a zero float32 vector of length size at indices.

    Indices are cast to int32 and clipped into [0, size - 1].
    NOTE(review): the original comment says "last write wins" for duplicate
    indices; that ordering comes from JAX's scatter semantics - confirm
    before depending on it.
    """
    safe_idx = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    target = jnp.zeros(size, dtype=jnp.float32)
    return target.at[safe_idx].set(values)
def jax_scatter_add(indices, values, size):
    """Accumulate values into a zero float32 vector of length size.

    Indices are cast to int32 and clipped into [0, size - 1]; values that
    land on the same index are summed.
    """
    safe_idx = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    accumulator = jnp.zeros(size, dtype=jnp.float32)
    return accumulator.at[safe_idx].add(values)
def jax_group_reduce(values, group_indices, num_groups, op='mean'):
    """Combine values that share a group index into one number per group.

    Args:
        values: values to reduce.
        group_indices: group id of each value (cast to int32).
        num_groups: length of the result vector.
        op: 'sum', 'mean', 'max' or 'min'.

    Returns:
        float32 vector of length num_groups. Groups that received no value
        are 0.0 for every op.

    Raises:
        ValueError: for any other op.
    """
    grp = group_indices.astype(jnp.int32)

    if op == 'sum':
        return jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values)

    if op == 'mean':
        totals = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values)
        members = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(1.0)
        return jnp.where(members > 0, totals / members, 0.0)

    if op == 'max':
        # Start from -inf so any real value wins; untouched slots become 0.
        floor = -jnp.inf
        best = jnp.full(num_groups, floor, dtype=jnp.float32).at[grp].max(values)
        return jnp.where(best == floor, 0.0, best)

    if op == 'min':
        ceiling = jnp.inf
        best = jnp.full(num_groups, ceiling, dtype=jnp.float32).at[grp].min(values)
        return jnp.where(best == ceiling, 0.0, best)

    raise ValueError(f"Unknown reduce op: {op}")
def jax_where(cond, true_val, false_val):
    """Element-wise choice: true_val where cond holds, false_val elsewhere."""
    selected = jnp.where(cond, true_val, false_val)
    return selected
def jax_cell_indices(frame, cell_size):
    """Compute the cell index of every pixel, in row-major pixel order.

    The frame is divided into (h // cell_size) x (w // cell_size) cells.
    Leftover pixels on the right and bottom edges (when the frame size is
    not a multiple of cell_size) are assigned to the last cell of their row
    or column, the same way jax_pool_frame assigns them. Clipping only the
    combined index, as before, sent right-edge pixels to the first cell of
    the next row.

    Args:
        frame: array whose first two dimensions are (h, w).
        cell_size: cell edge length in pixels; converted with int().

    Returns:
        float32 vector of length h * w holding row * cols + col per pixel.
    """
    h, w = frame.shape[:2]
    cell_size = int(cell_size)
    rows = h // cell_size
    cols = w // cell_size

    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)

    # Clip row and column independently so edge pixels stay in their own
    # row/column of cells.
    cell_row = jnp.clip(y_coords // cell_size, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cell_size, 0, cols - 1)
    return (cell_row * cols + cell_col).astype(jnp.float32)
def jax_pool_frame(frame, cell_size):
    """
    Pool frame to per-cell mean values.

    The frame is divided into (h // cell_size) x (w // cell_size) cells;
    leftover edge pixels are folded into the last cell of their row/column.

    Args:
        frame: (H, W, C) image with at least three channels (0, 1, 2 are read).
        cell_size: cell edge length in pixels; converted with int().

    Returns:
        tuple (cell_r, cell_g, cell_b, cell_lum) of float32 vectors with one
        entry per cell in row-major order. cell_lum is
        0.299 * r + 0.587 * g + 0.114 * b of the pooled values.
    """
    h, w = frame.shape[:2]
    cs = int(cell_size)
    rows = h // cs
    cols = w // cs
    num_cells = rows * cols

    # Cell index of every pixel (row-major pixel order).
    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)
    cell_row = jnp.clip(y_coords // cs, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cs, 0, cols - 1)
    cell_idx = (cell_row * cols + cell_col).astype(jnp.int32)

    # Pixels per cell are the same for every channel, so count them once.
    counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0)

    def pool_channel(channel):
        data = frame[:, :, channel].flatten().astype(jnp.float32)
        sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data)
        return jnp.where(counts > 0, sums / counts, 0.0)

    r_pooled = pool_channel(0)
    g_pooled = pool_channel(1)
    b_pooled = pool_channel(2)
    lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled
    return (r_pooled, g_pooled, b_pooled, lum)
# =============================================================================
# Scan (Prefix Operations) - JAX implementations
# =============================================================================
def jax_scan_add(x, axis=None):
    """Running sum: over the flattened array by default, else along axis."""
    if axis is None:
        return jnp.cumsum(x.flatten())
    return jnp.cumsum(x, axis=int(axis))
def jax_scan_mul(x, axis=None):
    """Running product: over the flattened array by default, else along axis."""
    if axis is None:
        return jnp.cumprod(x.flatten())
    return jnp.cumprod(x, axis=int(axis))
def jax_scan_max(x, axis=None):
    """Running maximum: over the flattened array by default, else along axis."""
    if axis is None:
        return lax.cummax(x.flatten(), axis=0)
    return lax.cummax(x, axis=int(axis))
def jax_scan_min(x, axis=None):
    """Running minimum: over the flattened array by default, else along axis."""
    if axis is None:
        return lax.cummin(x.flatten(), axis=0)
    return lax.cummin(x, axis=int(axis))
# =============================================================================
# Outer Product - JAX implementations
# =============================================================================
def jax_outer(x, y, op='*'):
    """Apply a binary operation to every pair (x[i], y[j]).

    Both inputs are flattened first, so the result has shape
    (x.size, y.size). Supported ops are '*', '+', '-', '/', 'max' and 'min';
    any other value behaves like '*'.
    """
    left = x.flatten()
    right = y.flatten()

    if op == '+':
        return left[:, None] + right[None, :]
    if op == '-':
        return left[:, None] - right[None, :]
    if op == '/':
        return left[:, None] / right[None, :]
    if op == 'max':
        return jnp.maximum(left[:, None], right[None, :])
    if op == 'min':
        return jnp.minimum(left[:, None], right[None, :])
    # '*' and every unrecognised op
    return jnp.outer(left, right)
def jax_outer_add(x, y):
    """Pairwise sums x[i] + y[j]; shorthand for jax_outer with op '+'."""
    return jax_outer(x, y, op='+')
def jax_outer_mul(x, y):
    """Pairwise products x[i] * y[j]; shorthand for jax_outer with op '*'."""
    return jax_outer(x, y, op='*')
def jax_outer_max(x, y):
    """Pairwise maxima of x[i] and y[j]; shorthand for jax_outer with op 'max'."""
    return jax_outer(x, y, op='max')
def jax_outer_min(x, y):
    """Pairwise minima of x[i] and y[j]; shorthand for jax_outer with op 'min'."""
    return jax_outer(x, y, op='min')
# =============================================================================
# Reduce with Axis - JAX implementations
# =============================================================================
def jax_reduce_axis(x, op='sum', axis=0):
    """Collapse one axis of x with the named reduction.

    Recognised ops: 'sum' / '+', 'mean', 'max', 'min', 'prod' / '*', 'std'.
    Any other op name behaves like 'sum'. axis is converted with int().
    """
    axis = int(axis)
    reducers = {
        'sum': jnp.sum,
        '+': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
        'prod': jnp.prod,
        '*': jnp.prod,
        'std': jnp.std,
    }
    reducer = reducers.get(op, jnp.sum)
    return reducer(x, axis=axis)
def jax_sum_axis(x, axis=0):
    """Add up x along one axis (axis is converted with int())."""
    which = int(axis)
    return jnp.sum(x, axis=which)
def jax_mean_axis(x, axis=0):
    """Average x along one axis (axis is converted with int())."""
    which = int(axis)
    return jnp.mean(x, axis=which)
def jax_max_axis(x, axis=0):
    """Largest value of x along one axis (axis is converted with int())."""
    which = int(axis)
    return jnp.max(x, axis=which)
def jax_min_axis(x, axis=0):
    """Smallest value of x along one axis (axis is converted with int())."""
    which = int(axis)
    return jnp.min(x, axis=which)
# =============================================================================
# Windowed Operations - JAX implementations
# =============================================================================
def jax_window(x, size, op='mean', stride=1):
    """
    Reduce every sliding window of x to a single value.

    1D input: windows of length size. 'sum' and 'mean' are computed as a
    'valid' convolution that is then subsampled by stride; 'max' and 'min'
    gather the windows explicitly. Any other op behaves like 'mean' over the
    gathered windows.

    Other input: size x size windows over the first two axes, gathered with
    dynamic_slice using two start indices. 'sum', 'mean', 'max', 'min' are
    supported and any other op behaves like 'mean'. The result has shape
    (out_h, out_w).

    size and stride are converted with int().
    """
    size = int(size)
    stride = int(stride)

    if x.ndim == 1:
        if op == 'sum' or op == 'mean':
            box = jnp.ones(size)
            if op == 'mean':
                box = box / size
            return jnp.convolve(x, box, mode='valid')[::stride]

        length = len(x)
        count = (length - size) // stride + 1
        starts = jnp.arange(count) * stride
        stacked = jax.vmap(lambda start: lax.dynamic_slice(x, (start,), (size,)))(starts)
        reduce_1d = {'max': jnp.max, 'min': jnp.min}.get(op, jnp.mean)
        return reduce_1d(stacked, axis=1)

    height, width = x.shape[:2]
    out_h = (height - size) // stride + 1
    out_w = (width - size) // stride + 1

    def window_at(flat_index):
        # Decode the flat window number into its (row, column) position.
        row, col = flat_index // out_w, flat_index % out_w
        return lax.dynamic_slice(x, (row * stride, col * stride), (size, size))

    stacked = jax.vmap(window_at)(jnp.arange(out_h * out_w))
    reduce_2d = {
        'sum': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
    }.get(op, jnp.mean)
    return reduce_2d(stacked, axis=(1, 2)).reshape(out_h, out_w)
def jax_window_sum(x, size, stride=1):
    """Sum of every sliding window; shorthand for jax_window with op 'sum'."""
    return jax_window(x, size, op='sum', stride=stride)
def jax_window_mean(x, size, stride=1):
    """Mean of every sliding window; shorthand for jax_window with op 'mean'."""
    return jax_window(x, size, op='mean', stride=stride)
def jax_window_max(x, size, stride=1):
    """Maximum of every sliding window; shorthand for jax_window with op 'max'."""
    return jax_window(x, size, op='max', stride=stride)
def jax_window_min(x, size, stride=1):
    """Minimum of every sliding window; shorthand for jax_window with op 'min'."""
    return jax_window(x, size, op='min', stride=stride)
def jax_integral_image(frame):
    """
    Build the summed-area table of a frame.

    A 3D frame is first reduced to one float32 plane by averaging over its
    last axis; any other input is used as float32 directly. Entry (i, j) of
    the result is the sum of all plane values at rows <= i and columns <= j.
    """
    as_float = frame.astype(jnp.float32)
    plane = jnp.mean(as_float, axis=2) if frame.ndim == 3 else as_float
    column_sums = jnp.cumsum(plane, axis=0)
    return jnp.cumsum(column_sums, axis=1)
def jax_sample(frame, x, y):
    """Bilinearly sample channels 0, 1 and 2 of frame at float coordinates.

    Intended to match OpenCV cv2.remap with INTER_LINEAR and BORDER_CONSTANT:
    each of the four neighbouring pixels that falls outside the frame
    contributes 0, and those zeros take part in the blend.

    Returns:
        (r, g, b) float arrays shaped like x and y.
    """
    h, w = frame.shape[:2]

    # The four neighbours and the fractional position between them.
    x0 = jnp.floor(x).astype(jnp.int32)
    y0 = jnp.floor(y).astype(jnp.int32)
    x1 = x0 + 1
    y1 = y0 + 1
    fx = x - x0.astype(jnp.float32)
    fy = y - y0.astype(jnp.float32)

    # In-bounds tests per coordinate, combined per corner below.
    x0_ok = (x0 >= 0) & (x0 < w)
    x1_ok = (x1 >= 0) & (x1 < w)
    y0_ok = (y0 >= 0) & (y0 < h)
    y1_ok = (y1 >= 0) & (y1 < h)

    # Clamped copies keep the gather legal; masked values are discarded.
    x0_c = jnp.clip(x0, 0, w - 1)
    x1_c = jnp.clip(x1, 0, w - 1)
    y0_c = jnp.clip(y0, 0, h - 1)
    y1_c = jnp.clip(y1, 0, h - 1)

    def corner(channel, rows, cols, inside):
        value = frame[rows, cols, channel].astype(jnp.float32)
        return jnp.where(inside, value, 0.0)

    def blend(channel):
        top_left = corner(channel, y0_c, x0_c, x0_ok & y0_ok)
        top_right = corner(channel, y0_c, x1_c, x1_ok & y0_ok)
        bottom_left = corner(channel, y1_c, x0_c, x0_ok & y1_ok)
        bottom_right = corner(channel, y1_c, x1_c, x1_ok & y1_ok)
        return (top_left * (1 - fx) * (1 - fy) +
                top_right * fx * (1 - fy) +
                bottom_left * (1 - fx) * fy +
                bottom_right * fx * fy)

    return blend(0), blend(1), blend(2)
# =============================================================================
# Convolution Operations
# =============================================================================
def jax_convolve2d(data, kernel):
    """Apply a 2D kernel to a single-channel image with 'SAME' padding.

    The work is done by ``lax.conv_general_dilated``, which applies the
    kernel without flipping it (cross-correlation, the usual ML convention).

    ``lax.conv_general_dilated`` requires both operands to share a dtype.
    When they already match they are used untouched; when they differ
    (e.g. an integer image with a float kernel) both are promoted to
    float32 instead of raising.

    Args:
        data: Array of shape (H, W).
        kernel: Array of shape (kH, kW).

    Returns:
        Array of shape (H, W).
    """
    if data.dtype != kernel.dtype:
        data = data.astype(jnp.float32)
        kernel = kernel.astype(jnp.float32)

    h, w = data.shape
    kh, kw = kernel.shape

    # conv expects (batch, H, W, channels) and (kH, kW, in_c, out_c)
    data_4d = data.reshape(1, h, w, 1)
    kernel_4d = kernel.reshape(kh, kw, 1, 1)

    result = lax.conv_general_dilated(
        data_4d, kernel_4d,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )
    return result.reshape(h, w)
def jax_blur(frame, radius=1):
    """Gaussian blur of the first three channels of ``frame``.

    The kernel spans ``2 * int(radius) + 1`` taps with sigma ``radius / 2``
    and is always centred on the integer radius, so a fractional radius no
    longer produces a lopsided kernel. A radius below 1 has no neighbours
    to blend (and previously divided by zero for ``radius=0``), so the frame
    is returned unblurred.

    Args:
        frame: Array indexed as ``frame[row, col, channel]``.
        radius: Blur radius in pixels; must be a concrete (non-traced)
            number because it determines the kernel size.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    half = int(radius)
    if half < 1:
        # Nothing to blur; still honour the uint8 3-channel output contract
        return jnp.clip(frame[:, :, :3].astype(jnp.float32), 0, 255).astype(jnp.uint8)

    size = half * 2 + 1
    offsets = jnp.arange(size) - half
    sigma = radius / 2
    gaussian_1d = jnp.exp(-offsets**2 / (2 * sigma**2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    # float32 keeps the kernel dtype aligned with the float32 channel data
    kernel = jnp.outer(gaussian_1d, gaussian_1d).astype(jnp.float32)

    channels = [
        jnp.clip(
            jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255
        ).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
def jax_sharpen(frame, amount=1.0):
    """Sharpen the first three channels of ``frame`` with a 3x3 kernel.

    The kernel is ``identity + amount * laplacian``: centre ``1 + 4*amount``
    and ``-amount`` on the four direct neighbours. It therefore always sums
    to 1, so flat regions keep their brightness for every ``amount``.
    ``amount=0`` is the identity and ``amount=1`` is the classic
    ``[[0,-1,0],[-1,5,-1],[0,-1,0]]`` kernel.

    The previous formulation applied ``amount`` twice, giving a kernel that
    summed to ``(2*amount - 1)**2`` (black at 0.5, 9x gain at 2).

    Args:
        frame: Array indexed as ``frame[row, col, channel]``.
        amount: Sharpening strength.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    identity = jnp.array([
        [0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]
    ], dtype=jnp.float32)
    laplacian = jnp.array([
        [0, -1, 0],
        [-1, 4, -1],
        [0, -1, 0]
    ], dtype=jnp.float32)
    kernel = (identity + amount * laplacian).astype(jnp.float32)

    channels = [
        jnp.clip(
            jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255
        ).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
def jax_edge_detect(frame):
    """Sobel edge magnitude, returned as a grey 3-channel uint8 frame.

    The frame is reduced to luminance (0.299 R + 0.587 G + 0.114 B),
    filtered with the horizontal and vertical Sobel kernels, and the
    gradient magnitude is clipped to 0-255.
    """
    red, green, blue = (frame[:, :, c].astype(jnp.float32) for c in range(3))
    luma = red * 0.299 + green * 0.587 + blue * 0.114

    horizontal = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)
    vertical = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32)

    gx = jax_convolve2d(luma, horizontal)
    gy = jax_convolve2d(luma, vertical)

    magnitude = jnp.clip(jnp.sqrt(gx**2 + gy**2), 0, 255).astype(jnp.uint8)
    return jnp.stack([magnitude] * 3, axis=2)
def jax_emboss(frame):
    """Emboss effect, returned as a grey 3-channel uint8 frame.

    The luminance plane is filtered with a diagonal emboss kernel and
    offset by 128 so that flat areas come out mid-grey.
    """
    red, green, blue = (frame[:, :, c].astype(jnp.float32) for c in range(3))
    luma = red * 0.299 + green * 0.587 + blue * 0.114

    emboss_kernel = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32)
    relief = jax_convolve2d(luma, emboss_kernel) + 128
    relief = jnp.clip(relief, 0, 255).astype(jnp.uint8)
    return jnp.stack([relief] * 3, axis=2)
# =============================================================================
# Color Space Conversion
# =============================================================================
def jax_rgb_to_hsv(r, g, b):
    """Convert RGB to HSV.

    Args:
        r, g, b: Channel values in the 0-255 range.

    Returns:
        Tuple ``(h, s, v)`` where ``h`` is in degrees (0-360) and ``s`` and
        ``v`` are in the 0-255 range. Grey pixels get hue 0 and black
        pixels get saturation 0.
    """
    r, g, b = r / 255.0, g / 255.0, b / 255.0
    max_c = jnp.maximum(jnp.maximum(r, g), b)
    min_c = jnp.minimum(jnp.minimum(r, g), b)
    diff = max_c - min_c

    # jnp.where evaluates both branches, so substitute harmless denominators
    # where the real one is zero; this keeps NaN/inf out of the unselected
    # lanes (relevant for jax_debug_nans and for gradients).
    is_gray = diff == 0
    has_value = max_c > 0
    safe_diff = jnp.where(is_gray, 1.0, diff)
    safe_max = jnp.where(has_value, max_c, 1.0)

    # Value
    v = max_c

    # Saturation
    s = jnp.where(has_value, diff / safe_max, 0.0)

    # Hue: sector chosen by which channel holds the maximum
    h = jnp.where(is_gray, 0.0,
        jnp.where(max_c == r, (60 * ((g - b) / safe_diff) + 360) % 360,
        jnp.where(max_c == g, 60 * ((b - r) / safe_diff) + 120,
                  60 * ((r - g) / safe_diff) + 240)))

    return h, s * 255, v * 255
def jax_hsv_to_rgb(h, s, v):
    """Convert HSV to RGB.

    Args:
        h: Hue in degrees (wrapped into 0-360); must be an array because
            ``.astype`` is called on it.
        s, v: Saturation and value in the 0-255 range.

    Returns:
        Tuple ``(r, g, b)`` scaled to the 0-255 range.
    """
    hue = h % 360
    sat, val = s / 255.0, v / 255.0

    chroma = val * sat
    second = chroma * (1 - jnp.abs((hue / 60) % 2 - 1))
    offset = val - chroma

    sector = (hue / 60).astype(jnp.int32) % 6

    def by_sector(*options):
        # options[k] is the value for sector k; sector 5 is the fallback
        picked = options[5]
        for k in (4, 3, 2, 1, 0):
            picked = jnp.where(sector == k, options[k], picked)
        return picked

    red = by_sector(chroma, second, 0, 0, second, chroma)
    green = by_sector(second, chroma, chroma, second, 0, 0)
    blue = by_sector(0, 0, second, chroma, chroma, second)

    return (red + offset) * 255, (green + offset) * 255, (blue + offset) * 255
def jax_adjust_saturation(frame, factor):
    """Scale the saturation of ``frame`` by ``factor`` (1.0 = unchanged).

    Pixels are converted to HSV, S is multiplied by ``factor`` and clipped
    to 0-255, then converted back to a uint8 RGB frame.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (
        frame[:, :, c].flatten().astype(jnp.float32) for c in range(3)
    )

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    sat = jnp.clip(sat * factor, 0, 255)

    planes = [
        jnp.clip(plane, 0, 255).reshape(rows, cols).astype(jnp.uint8)
        for plane in jax_hsv_to_rgb(hue, sat, val)
    ]
    return jnp.stack(planes, axis=2)
def jax_shift_hue(frame, degrees):
    """Rotate the hue of every pixel by ``degrees``.

    Pixels are converted to HSV, the hue is offset and wrapped modulo 360,
    then converted back to a uint8 RGB frame.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (
        frame[:, :, c].flatten().astype(jnp.float32) for c in range(3)
    )

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    hue = (hue + degrees) % 360

    planes = [
        jnp.clip(plane, 0, 255).reshape(rows, cols).astype(jnp.uint8)
        for plane in jax_hsv_to_rgb(hue, sat, val)
    ]
    return jnp.stack(planes, axis=2)
# =============================================================================
# Color Adjustment Operations
# =============================================================================
def jax_adjust_brightness(frame, amount):
    """Add ``amount`` to every value, saturating at 0 and 255.

    Returns a uint8 array with the same shape as ``frame``.
    """
    as_float = frame.astype(jnp.float32)
    shifted = as_float + amount
    clipped = jnp.clip(shifted, 0, 255)
    return clipped.astype(jnp.uint8)
def jax_adjust_contrast(frame, factor):
    """Scale values around the mid-point 128 by ``factor`` (1.0 = unchanged).

    Returns a uint8 array with the same shape as ``frame``.
    """
    centred = frame.astype(jnp.float32) - 128
    stretched = centred * factor + 128
    clipped = jnp.clip(stretched, 0, 255)
    return clipped.astype(jnp.uint8)
def jax_invert(frame):
    """Return the colour negative of ``frame`` (each value v becomes 255 - v)."""
    negative = 255 - frame
    return negative
def jax_posterize(frame, levels):
    """Reduce each channel to ``levels`` evenly spaced values.

    ``levels`` is truncated to an integer and clamped to at least 2. A
    concrete value is handled with plain Python arithmetic; a traced value
    (e.g. a numeric effect parameter inside ``jax.jit``) cannot go through
    ``int()``, so it is truncated and clamped with array operations instead.

    Args:
        frame: Image array with values in the 0-255 range.
        levels: Number of output levels per channel.

    Returns:
        uint8 array with the same shape as ``frame``.
    """
    try:
        count = max(int(levels), 2)
    except TypeError:
        # JAX raises a TypeError subclass when int() is applied to a tracer;
        # anything else that is not array-like is a genuine caller error.
        if not hasattr(levels, 'astype'):
            raise
        count = jnp.maximum(levels.astype(jnp.int32), 2)

    step = 255.0 / (count - 1)
    result = jnp.round(frame.astype(jnp.float32) / step) * step
    return jnp.clip(result, 0, 255).astype(jnp.uint8)
def jax_threshold(frame, level, invert=False):
    """Binary threshold on luminance.

    Pixels whose luminance (0.299 R + 0.587 G + 0.114 B) is at or above
    ``level`` become 255 and the rest 0; with ``invert`` set, pixels
    strictly below ``level`` become 255 instead. ``invert`` is tested with a
    Python ``if``, so it must be a concrete value.

    Returns:
        uint8 array of shape (H, W, 3) with identical channels.
    """
    red, green, blue = (frame[:, :, c].astype(jnp.float32) for c in range(3))
    luma = red * 0.299 + green * 0.587 + blue * 0.114

    mask = (luma < level) if invert else (luma >= level)
    binary = jnp.where(mask, 255, 0).astype(jnp.uint8)
    return jnp.stack([binary] * 3, axis=2)
def jax_sepia(frame):
    """Apply a sepia tone using a fixed 3x3 colour-mixing matrix.

    Returns a uint8 array of shape (H, W, 3).
    """
    red, green, blue = (frame[:, :, c].astype(jnp.float32) for c in range(3))

    def mix(wr, wg, wb):
        # Weighted sum of the input channels, saturated to uint8
        blended = red * wr + green * wg + blue * wb
        return jnp.clip(blended, 0, 255).astype(jnp.uint8)

    return jnp.stack([
        mix(0.393, 0.769, 0.189),
        mix(0.349, 0.686, 0.168),
        mix(0.272, 0.534, 0.131),
    ], axis=2)
def jax_grayscale(frame):
    """Convert to a grey 3-channel uint8 frame.

    Luminance is 0.299 R + 0.587 G + 0.114 B. The value is clipped to
    0-255 before the uint8 cast, like the other colour operations, so
    out-of-range float inputs saturate instead of relying on the cast's
    overflow behaviour.

    Returns:
        uint8 array of shape (H, W, 3) with identical channels.
    """
    gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    gray = jnp.clip(gray, 0, 255).astype(jnp.uint8)
    return jnp.stack([gray, gray, gray], axis=2)
# =============================================================================
# Geometry Operations
# =============================================================================
def jax_flip_horizontal(frame):
    """Mirror the frame left-to-right (reverse the column axis)."""
    keep = slice(None)
    reverse = slice(None, None, -1)
    return frame[keep, reverse, keep]
def jax_flip_vertical(frame):
    """Mirror the frame top-to-bottom (reverse the row axis)."""
    keep = slice(None)
    reverse = slice(None, None, -1)
    return frame[reverse, keep, keep]
def jax_rotate(frame, angle, center_x=None, center_y=None):
    """Rotate ``frame`` by ``angle`` degrees about a centre point.

    Follows the OpenCV convention: a positive angle rotates
    counter-clockwise. The centre defaults to (w / 2, h / 2). Output pixels
    are filled by inverse-mapping each destination coordinate back into the
    source and sampling it with ``jax_sample``.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    h, w = frame.shape[:2]
    center_x = w / 2 if center_x is None else center_x
    center_y = h / 2 if center_y is None else center_y

    theta = angle * jnp.pi / 180
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)

    # Destination coordinates relative to the centre; the row vector and the
    # column vector broadcast to the full (h, w) grid below.
    dx = jnp.arange(w)[None, :] - center_x
    dy = jnp.arange(h)[:, None] - center_y

    # Inverse of OpenCV's forward matrix [[cos, sin], [-sin, cos]]:
    #   src_x = cos*(x-cx) - sin*(y-cy) + cx
    #   src_y = sin*(x-cx) + cos*(y-cy) + cy
    src_x = cos_t * dx - sin_t * dy + center_x
    src_y = sin_t * dx + cos_t * dy + center_y

    r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten())
    planes = [
        jnp.clip(plane, 0, 255).reshape(h, w).astype(jnp.uint8)
        for plane in (r, g, b)
    ]
    return jnp.stack(planes, axis=2)
def jax_scale(frame, scale_x, scale_y=None):
    """Zoom ``frame`` about its centre, keeping the output size.

    ``scale_y`` defaults to ``scale_x``. Each destination pixel is mapped
    back to the source by dividing its centre-relative offset by the scale
    and is then sampled with ``jax_sample``.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    scale_y = scale_x if scale_y is None else scale_y

    h, w = frame.shape[:2]
    center_x, center_y = w / 2, h / 2

    # Row-major pixel index -> (column, row)
    flat = jnp.arange(h * w)
    cols = flat % w
    rows = flat // w

    # Inverse mapping: destination -> source
    src_x = (cols - center_x) / scale_x + center_x
    src_y = (rows - center_y) / scale_y + center_y

    r, g, b = jax_sample(frame, src_x, src_y)
    planes = [
        jnp.clip(plane, 0, 255).reshape(h, w).astype(jnp.uint8)
        for plane in (r, g, b)
    ]
    return jnp.stack(planes, axis=2)
def jax_resize(frame, new_width, new_height):
    """Resize ``frame`` to ``new_width`` x ``new_height`` by bilinear sampling.

    Destination pixels are mapped so that the first and last rows/columns of
    the output line up with the first and last rows/columns of the source.
    When a target dimension is 1 there is no span to interpolate over, so
    the single row/column samples source index 0 (previously this divided
    by zero and produced NaN coordinates).

    Args:
        frame: Array indexed as ``frame[row, col, channel]``.
        new_width, new_height: Target size; must be concrete values.

    Returns:
        uint8 array of shape (new_height, new_width, 3).
    """
    h, w = frame.shape[:2]
    new_h, new_w = int(new_height), int(new_width)

    # Coordinate grids for the new size, flattened in row-major order
    y_coords = jnp.repeat(jnp.arange(new_h), new_w)
    x_coords = jnp.tile(jnp.arange(new_w), new_h)

    # Map to source coordinates; max(..., 1) guards the size-1 case
    src_x = x_coords * (w - 1) / max(new_w - 1, 1)
    src_y = y_coords * (h - 1) / max(new_h - 1, 1)

    r, g, b = jax_sample(frame, src_x, src_y)
    return jnp.stack([
        jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8),
        jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8),
        jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8)
    ], axis=2)
# =============================================================================
# Blending Operations
# =============================================================================
def _resize_to_match(frame1, frame2):
    """Return ``frame2`` at ``frame1``'s height and width.

    If the two frames already agree in height and width, ``frame2`` is
    returned untouched. Otherwise it is resized with bilinear
    ``jax.image.resize`` (keeping its own channel count) and cast to uint8.
    """
    target_hw = tuple(frame1.shape[:2])
    if tuple(frame2.shape[:2]) == target_hw:
        return frame2

    out_shape = target_hw + (frame2.shape[2],)
    resized = jax.image.resize(
        frame2.astype(jnp.float32),
        out_shape,
        method='bilinear'
    )
    return resized.astype(jnp.uint8)
def jax_blend(frame1, frame2, alpha):
    """Linear blend of two frames: alpha=0 -> frame1, alpha=1 -> frame2.

    ``frame2`` is resized to ``frame1``'s dimensions first if they differ.
    The mix is clipped to 0-255 before the uint8 cast, like the other blend
    modes, so an ``alpha`` outside [0, 1] saturates instead of overflowing.

    Returns:
        uint8 array with ``frame1``'s shape.
    """
    frame2 = _resize_to_match(frame1, frame2)
    mixed = (frame1.astype(jnp.float32) * (1 - alpha) +
             frame2.astype(jnp.float32) * alpha)
    return jnp.clip(mixed, 0, 255).astype(jnp.uint8)
def jax_blend_add(frame1, frame2):
    """Additive blend: per-value sum, saturated at 255.

    ``frame2`` is resized to ``frame1``'s dimensions first if they differ.
    """
    overlay = _resize_to_match(frame1, frame2)
    total = frame1.astype(jnp.float32) + overlay.astype(jnp.float32)
    clipped = jnp.clip(total, 0, 255)
    return clipped.astype(jnp.uint8)
def jax_blend_multiply(frame1, frame2):
    """Multiply blend: per-value product divided by 255.

    ``frame2`` is resized to ``frame1``'s dimensions first if they differ.
    """
    overlay = _resize_to_match(frame1, frame2)
    product = frame1.astype(jnp.float32) * overlay.astype(jnp.float32) / 255
    clipped = jnp.clip(product, 0, 255)
    return clipped.astype(jnp.uint8)
def jax_blend_screen(frame1, frame2):
    """Screen blend: 1 - (1 - a)(1 - b) on values normalised to 0-1.

    ``frame2`` is resized to ``frame1``'s dimensions first if they differ.
    """
    overlay = _resize_to_match(frame1, frame2)
    base = frame1.astype(jnp.float32) / 255
    top = overlay.astype(jnp.float32) / 255
    screened = 1 - (1 - base) * (1 - top)
    return jnp.clip(screened * 255, 0, 255).astype(jnp.uint8)
def jax_blend_overlay(frame1, frame2):
    """Overlay blend on values normalised to 0-1.

    Where ``frame1`` is below 0.5 the result is ``2ab``; elsewhere it is
    ``1 - 2(1 - a)(1 - b)``. ``frame2`` is resized to ``frame1``'s
    dimensions first if they differ.
    """
    overlay = _resize_to_match(frame1, frame2)
    base = frame1.astype(jnp.float32) / 255
    top = overlay.astype(jnp.float32) / 255

    darkened = 2 * base * top
    lightened = 1 - 2 * (1 - base) * (1 - top)
    combined = jnp.where(base < 0.5, darkened, lightened)

    return jnp.clip(combined * 255, 0, 255).astype(jnp.uint8)
# =============================================================================
# Utility
# =============================================================================
def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0):
    """Derive a JAX random key from a seed, an operation id and a frame number.

    ``seed`` and ``op_id`` are combined into one integer that seeds
    ``jax.random.PRNGKey``; ``frame_num`` is then mixed in with
    ``jax.random.fold_in``, which lets it be a traced value so JIT-compiled
    code does not need recompiling per frame.

    Args:
        seed: Base seed for determinism (must be concrete)
        frame_num: Frame number for variation (can be traced)
        op_id: Operation ID for variation (must be concrete)

    Returns:
        JAX PRNGKey
    """
    # Large odd multiplier keeps different op_ids on distinct seeds
    stream_seed = seed + op_id * 1000003
    return jax.random.fold_in(jax.random.PRNGKey(stream_seed), frame_num)
def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42):
    """Random float in [lo, hi), varying with frame number and operation id.

    The key comes from ``make_jax_key(seed, frame_num, op_id)``; one uniform
    draw is scaled onto the requested range.
    """
    key = make_jax_key(seed, frame_num, op_id)
    unit = jax.random.uniform(key)
    return lo + unit * (hi - lo)
def jax_is_nil(x):
    """Return True when ``x`` is None (the S-expression ``nil``), else False."""
    if x is None:
        return True
    return False
# =============================================================================
# S-expression to JAX Compiler
# =============================================================================
class JaxCompiler:
"""Compiles S-expressions to JAX functions."""
def __init__(self):
self.env = {} # Variable bindings during compilation
self.params = {} # Effect parameters
self.primitives = {} # Loaded primitive libraries
self.derived = {} # Loaded derived functions
def load_derived(self, path: str):
    """Load derived operations from a .sexp file.

    Every top-level ``(define ...)`` form with at least three items is
    evaluated into ``self.derived``; all other top-level forms are ignored.

    The file is read as UTF-8 explicitly: the language accepts non-ASCII
    symbols such as ``λ``, so relying on the platform's default encoding
    could fail or mis-decode on non-UTF-8 locales.

    Args:
        path: Path of the .sexp file to read.
    """
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    for expr in parse_all(code):
        if not (isinstance(expr, list) and len(expr) >= 3):
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            self._eval_define(expr[1:], self.derived)
def compile_effect(self, sexp) -> Callable:
    """
    Compile an effect S-expression to a JAX function.

    Supports both formats:
        (effect "name" :params (...) :body ...)
        (define-effect name :params (...) body)

    The form is scanned once from index 2: ``:params`` and ``:body`` are
    honoured, every other keyword is skipped together with the item that
    follows it, and the first non-keyword item is taken as the body when no
    body has been found yet.

    Args:
        sexp: Parsed S-expression (a list whose first item is a Symbol)

    Returns:
        JIT-compiled function: (frame, **kwargs) -> frame. Besides the
        effect's own parameters, the function reads ``frame_num``, ``seed``
        and the time aliases ``t`` / ``_time`` / ``time`` from kwargs.

    Raises:
        ValueError: If ``sexp`` is not a list of at least two items, does
            not start with a Symbol, uses a head other than ``effect`` /
            ``define-effect``, or has no body.
    """
    if not isinstance(sexp, list) or len(sexp) < 2:
        raise ValueError("Effect must be a list")
    head = sexp[0]
    if not isinstance(head, Symbol):
        raise ValueError("Effect must start with a symbol")
    form = head.name
    # Handle both 'effect' and 'define-effect' formats
    if form == 'effect':
        # (effect "name" :params (...) :body ...)
        name = sexp[1] if len(sexp) > 1 else "unnamed"
        if isinstance(name, Symbol):
            name = name.name
        start_idx = 2
    elif form == 'define-effect':
        # (define-effect name :params (...) body)
        name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1])
        start_idx = 2
    else:
        raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'")
    params_spec = []
    body = None
    i = start_idx
    # Walk the remaining items: keywords consume the following item too
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword):
            if item.name == 'params' and i + 1 < len(sexp):
                params_spec = sexp[i + 1]
                i += 2
            elif item.name == 'body' and i + 1 < len(sexp):
                body = sexp[i + 1]
                i += 2
            elif item.name in ('desc', 'type', 'range'):
                # Skip metadata keywords
                i += 2
            else:
                i += 2  # Skip unknown keywords with their values
        else:
            # Assume it's the body if we haven't seen one
            if body is None:
                body = item
            i += 1
    if body is None:
        raise ValueError(f"Effect '{name}' must have a body")
    # Extract parameter names, defaults, and static params (strings, bools)
    param_info, static_params = self._parse_params(params_spec)
    # Snapshot the derived functions so later changes to self.derived do
    # not affect this compiled effect
    derived_fns = self.derived.copy()
    # Create the JAX function
    def effect_fn(frame, **kwargs):
        # Set up environment
        h, w = frame.shape[:2]
        # Get frame_num for deterministic random variation
        frame_num = kwargs.get('frame_num', 0)
        # Get seed from recipe config (passed via kwargs)
        seed = kwargs.get('seed', 42)
        env = {
            'frame': frame,
            'width': w,
            'height': h,
            '_shape': (h, w),
            # Time variables (default to 0, can be overridden via kwargs)
            # NOTE(review): 't' falls back to '_time', '_time' and 'time'
            # fall back to 't', but 'time' never falls back to '_time' --
            # passing only '_time' leaves 'time' at 0.0. Confirm intended.
            't': kwargs.get('t', kwargs.get('_time', 0.0)),
            '_time': kwargs.get('_time', kwargs.get('t', 0.0)),
            'time': kwargs.get('time', kwargs.get('t', 0.0)),
            # Frame number for random key generation
            'frame_num': frame_num,
            'frame-num': frame_num,
            '_frame_num': frame_num,
            # Seed from recipe for deterministic random
            '_seed': seed,
            # Counter for unique random keys within same frame
            '_rand_op_counter': 0,
            # Common constants
            'pi': jnp.pi,
            'PI': jnp.pi,
        }
        # Add derived functions (these may overwrite the built-in names above)
        env.update(derived_fns)
        # Add typography primitives
        bind_typography_primitives(env)
        # Add parameters with defaults; an explicit kwarg always wins
        for pname, pdefault in param_info.items():
            if pname in kwargs:
                env[pname] = kwargs[pname]
            elif isinstance(pdefault, list):
                # Unevaluated S-expression default - evaluate it
                env[pname] = self._eval(pdefault, env)
            else:
                env[pname] = pdefault
        # Evaluate body
        result = self._eval(body, env)
        # Ensure result is a frame
        if isinstance(result, tuple) and len(result) == 3:
            # RGB tuple - merge to frame
            r, g, b = result
            return jax_merge_channels(r, g, b, (h, w))
        # NOTE(review): a plain Python number has no .ndim, so a body that
        # evaluates to a bare scalar would raise AttributeError here.
        elif result.ndim == 3:
            return result
        else:
            # Single channel - replicate to RGB
            h, w = env['_shape']
            gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8)
            return jnp.stack([gray, gray, gray], axis=2)
    # JIT compile with static args for string/bool parameters and seed
    # seed must be static for PRNGKey, but frame_num can be traced via fold_in
    all_static = set(static_params) | {'seed'}
    return jax.jit(effect_fn, static_argnames=list(all_static))
def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]:
"""Parse parameter specifications.
Returns:
Tuple of (param_defaults, static_params)
- param_defaults: Dict mapping param names to default values
- static_params: Set of param names that should be static (strings, bools)
"""
result = {}
static_params = set()
if not isinstance(params_spec, list):
return result, static_params
for param in params_spec:
if isinstance(param, Symbol):
result[param.name] = 0.0
elif isinstance(param, list) and len(param) >= 1:
pname = param[0].name if isinstance(param[0], Symbol) else str(param[0])
pdefault = 0.0
ptype = None
# Look for :default and :type keywords
i = 1
while i < len(param):
if isinstance(param[i], Keyword):
kw = param[i].name
if kw == 'default' and i + 1 < len(param):
pdefault = param[i + 1]
if isinstance(pdefault, Symbol):
if pdefault.name == 'nil':
pdefault = None
elif pdefault.name == 'true':
pdefault = True
elif pdefault.name == 'false':
pdefault = False
i += 2
elif kw == 'type' and i + 1 < len(param):
ptype = param[i + 1]
if isinstance(ptype, Symbol):
ptype = ptype.name
i += 2
else:
i += 1
else:
i += 1
result[pname] = pdefault
# Mark string and bool parameters as static (can't be traced by JAX)
if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)):
static_params.add(pname)
return result, static_params
def _eval(self, expr, env: Dict[str, Any]) -> Any:
"""Evaluate an S-expression in the given environment."""
# Already-evaluated values (e.g., from threading macros)
# JAX arrays, NumPy arrays, tuples, etc.
if hasattr(expr, 'shape'): # JAX/NumPy array
return expr
if isinstance(expr, tuple): # e.g., (r, g, b) from rgb
return expr
# Literals - keep as Python numbers for static operations
if isinstance(expr, (int, float)):
return expr
if isinstance(expr, str):
return expr
# Symbols - variable lookup
if isinstance(expr, Symbol):
name = expr.name
if name in env:
return env[name]
if name == 'nil':
return None
if name == 'true':
return True
if name == 'false':
return False
raise NameError(f"Unknown symbol: {name}")
# Lists - function calls
if isinstance(expr, list) and len(expr) > 0:
head = expr[0]
if isinstance(head, Symbol):
op = head.name
args = expr[1:]
# Special forms
if op == 'let' or op == 'let*':
return self._eval_let(args, env)
if op == 'if':
return self._eval_if(args, env)
if op == 'lambda' or op == 'λ':
return self._eval_lambda(args, env)
if op == 'define':
return self._eval_define(args, env)
# Built-in operations
return self._eval_call(op, args, env)
# Empty list
if isinstance(expr, list) and len(expr) == 0:
return []
raise ValueError(f"Cannot evaluate: {expr}")
def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]):
"""Extract a keyword argument from args list.
Looks for :key value pattern in args and evaluates the value.
Returns default if not found.
"""
i = 0
while i < len(args):
if isinstance(args[i], Keyword) and args[i].name == key:
if i + 1 < len(args):
val = self._eval(args[i + 1], env)
# Handle Symbol values (e.g., :op 'sum -> 'sum')
if isinstance(val, Symbol):
return val.name
return val
return default
i += 1
return default
def _eval_let(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body)."""
if len(args) < 2:
raise ValueError("let requires bindings and body")
bindings = args[0]
body = args[1]
new_env = env.copy()
# Handle both ((var val) ...) and [var val var2 val2 ...] syntax
if isinstance(bindings, list):
# Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...)
if bindings and isinstance(bindings[0], Symbol):
# Flat list: [var val var2 val2 ...]
i = 0
while i < len(bindings) - 1:
var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
val = self._eval(bindings[i + 1], new_env)
new_env[var] = val
i += 2
else:
# Nested list: ((var val) (var2 val2) ...)
for binding in bindings:
if isinstance(binding, list) and len(binding) >= 2:
var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
val = self._eval(binding[1], new_env)
new_env[var] = val
return self._eval(body, new_env)
def _eval_if(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (if cond then else)."""
if len(args) < 2:
raise ValueError("if requires condition and then-branch")
cond = self._eval(args[0], env)
# Handle None as falsy (important for optional params like overlay)
if cond is None:
return self._eval(args[2], env) if len(args) > 2 else None
# For Python scalar bools, use normal Python if
# This allows side effects and None values
if isinstance(cond, bool):
if cond:
return self._eval(args[1], env)
else:
return self._eval(args[2], env) if len(args) > 2 else None
# For NumPy/JAX scalar bools with concrete values
if hasattr(cond, 'item') and cond.shape == ():
try:
if bool(cond.item()):
return self._eval(args[1], env)
else:
return self._eval(args[2], env) if len(args) > 2 else None
except:
pass # Fall through to jnp.where for traced values
# For traced values, evaluate both branches and use jnp.where
then_val = self._eval(args[1], env)
else_val = self._eval(args[2], env) if len(args) > 2 else 0.0
# Handle None by converting to zeros
if then_val is None:
then_val = 0.0
if else_val is None:
else_val = 0.0
# Convert lists to tuples
if isinstance(then_val, list):
then_val = tuple(then_val)
if isinstance(else_val, list):
else_val = tuple(else_val)
# Handle tuple results (e.g., from rgb in map-pixels)
if isinstance(then_val, tuple) and isinstance(else_val, tuple):
return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val))
return jnp.where(cond, then_val, else_val)
def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable:
"""Evaluate (lambda (params) body)."""
if len(args) < 2:
raise ValueError("lambda requires parameters and body")
params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]]
body = args[1]
captured_env = env.copy()
def fn(*fn_args):
local_env = captured_env.copy()
for pname, pval in zip(params, fn_args):
local_env[pname] = pval
return self._eval(body, local_env)
return fn
def _eval_define(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (define name value) or (define (name params) body)."""
if len(args) < 2:
raise ValueError("define requires name and value")
name_part = args[0]
if isinstance(name_part, list):
# Function definition: (define (name params) body)
fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0])
params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]]
body = args[1]
captured_env = env.copy()
def fn(*fn_args):
local_env = captured_env.copy()
for pname, pval in zip(params, fn_args):
local_env[pname] = pval
return self._eval(body, local_env)
env[fn_name] = fn
return fn
else:
# Variable definition
var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part)
val = self._eval(args[1], env)
env[var_name] = val
return val
def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any:
"""Evaluate a function call."""
# Check if it's a user-defined function
if op in env and callable(env[op]):
fn = env[op]
eval_args = [self._eval(a, env) for a in args]
return fn(*eval_args)
# Arithmetic
if op == '+':
vals = [self._eval(a, env) for a in args]
result = vals[0] if vals else 0.0
for v in vals[1:]:
result = result + v
return result
if op == '-':
if len(args) == 1:
return -self._eval(args[0], env)
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result - v
return result
if op == '*':
vals = [self._eval(a, env) for a in args]
result = vals[0] if vals else 1.0
for v in vals[1:]:
result = result * v
return result
if op == '/':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result / v
return result
if op == 'mod':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a % b
if op == 'pow' or op == '**':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return jnp.power(a, b)
# Comparison
if op == '<':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a < b
if op == '>':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a > b
if op == '<=':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a <= b
if op == '>=':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a >= b
if op == '=' or op == '==':
a, b = self._eval(args[0], env), self._eval(args[1], env)
# For scalar Python types, return Python bool to enable trace-time if
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return bool(a == b)
return a == b
if op == '!=' or op == '<>':
a, b = self._eval(args[0], env), self._eval(args[1], env)
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return bool(a != b)
return a != b
# Logic
if op == 'and':
vals = [self._eval(a, env) for a in args]
# Use Python and for concrete Python bools (e.g., shape comparisons)
if all(isinstance(v, (bool, np.bool_)) for v in vals):
result = True
for v in vals:
result = result and bool(v)
return result
# Otherwise use JAX logical_and
result = vals[0]
for v in vals[1:]:
result = jnp.logical_and(result, v)
return result
if op == 'or':
# Lisp-style or: returns first truthy value, not boolean
# (or a b c) returns a if a is truthy, else b if b is truthy, else c
for arg in args:
val = self._eval(arg, env)
# Check if value is truthy
if val is None:
continue
if isinstance(val, (bool, np.bool_)):
if val:
return val
continue
if isinstance(val, (int, float)):
if val:
return val
continue
if hasattr(val, 'shape'):
# JAX/numpy array - return it (considered truthy)
return val
# For other types, check truthiness
if val:
return val
# All values were falsy, return the last one
return self._eval(args[-1], env) if args else None
if op == 'not':
val = self._eval(args[0], env)
if isinstance(val, (bool, np.bool_)):
return not bool(val)
return jnp.logical_not(val)
# Math functions
if op == 'sqrt':
return jnp.sqrt(self._eval(args[0], env))
if op == 'sin':
return jnp.sin(self._eval(args[0], env))
if op == 'cos':
return jnp.cos(self._eval(args[0], env))
if op == 'tan':
return jnp.tan(self._eval(args[0], env))
if op == 'exp':
return jnp.exp(self._eval(args[0], env))
if op == 'log':
return jnp.log(self._eval(args[0], env))
if op == 'abs':
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return abs(x)
return jnp.abs(x)
if op == 'floor':
import math
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return math.floor(x)
return jnp.floor(x)
if op == 'ceil':
import math
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return math.ceil(x)
return jnp.ceil(x)
if op == 'round':
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return round(x)
return jnp.round(x)
# Frame primitives
if op == 'width':
return env['width']
if op == 'height':
return env['height']
if op == 'channel':
frame = self._eval(args[0], env)
idx = self._eval(args[1], env)
# idx should be a Python int (literal from S-expression)
return jax_channel(frame, idx)
if op == 'merge-channels' or op == 'rgb':
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
# For scalars (e.g., in map-pixels), return tuple
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
if r_is_scalar and g_is_scalar and b_is_scalar:
return (r, g, b)
return jax_merge_channels(r, g, b, env['_shape'])
if op == 'sample':
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
return jax_sample(frame, x, y)
if op == 'cell-indices':
frame = self._eval(args[0], env)
cell_size = self._eval(args[1], env)
return jax_cell_indices(frame, cell_size)
if op == 'pool-frame':
frame = self._eval(args[0], env)
cell_size = self._eval(args[1], env)
return jax_pool_frame(frame, cell_size)
# Xector primitives
if op == 'iota':
n = self._eval(args[0], env)
return jax_iota(int(n))
if op == 'repeat':
x = self._eval(args[0], env)
n = self._eval(args[1], env)
return jax_repeat(x, int(n))
if op == 'tile':
x = self._eval(args[0], env)
n = self._eval(args[1], env)
return jax_tile(x, int(n))
if op == 'gather':
data = self._eval(args[0], env)
indices = self._eval(args[1], env)
return jax_gather(data, indices)
if op == 'scatter':
indices = self._eval(args[0], env)
values = self._eval(args[1], env)
size = int(self._eval(args[2], env))
return jax_scatter(indices, values, size)
if op == 'scatter-add':
indices = self._eval(args[0], env)
values = self._eval(args[1], env)
size = int(self._eval(args[2], env))
return jax_scatter_add(indices, values, size)
if op == 'group-reduce':
values = self._eval(args[0], env)
groups = self._eval(args[1], env)
num_groups = int(self._eval(args[2], env))
reduce_op = args[3] if len(args) > 3 else 'mean'
if isinstance(reduce_op, Symbol):
reduce_op = reduce_op.name
return jax_group_reduce(values, groups, num_groups, reduce_op)
if op == 'where':
cond = self._eval(args[0], env)
true_val = self._eval(args[1], env)
false_val = self._eval(args[2], env)
# Handle None values
if true_val is None:
true_val = 0.0
if false_val is None:
false_val = 0.0
return jax_where(cond, true_val, false_val)
if op == 'len' or op == 'length':
x = self._eval(args[0], env)
if isinstance(x, (list, tuple)):
return len(x)
return x.size
# Beta reductions
if op in ('β+', 'beta+', 'sum'):
return jnp.sum(self._eval(args[0], env))
if op in ('β*', 'beta*', 'product'):
return jnp.prod(self._eval(args[0], env))
if op in ('βmin', 'beta-min'):
return jnp.min(self._eval(args[0], env))
if op in ('βmax', 'beta-max'):
return jnp.max(self._eval(args[0], env))
if op in ('βmean', 'beta-mean', 'mean'):
return jnp.mean(self._eval(args[0], env))
if op in ('βstd', 'beta-std'):
return jnp.std(self._eval(args[0], env))
if op in ('βany', 'beta-any'):
return jnp.any(self._eval(args[0], env))
if op in ('βall', 'beta-all'):
return jnp.all(self._eval(args[0], env))
# Scan (prefix) operations
if op in ('scan+', 'scan-add'):
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', None, env)
return jax_scan_add(x, axis)
if op in ('scan*', 'scan-mul'):
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', None, env)
return jax_scan_mul(x, axis)
if op == 'scan-max':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', None, env)
return jax_scan_max(x, axis)
if op == 'scan-min':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', None, env)
return jax_scan_min(x, axis)
# Outer product operations
if op == 'outer':
x = self._eval(args[0], env)
y = self._eval(args[1], env)
op_type = self._eval_kwarg(args, 'op', '*', env)
return jax_outer(x, y, op_type)
if op in ('outer+', 'outer-add'):
x = self._eval(args[0], env)
y = self._eval(args[1], env)
return jax_outer_add(x, y)
if op in ('outer*', 'outer-mul'):
x = self._eval(args[0], env)
y = self._eval(args[1], env)
return jax_outer_mul(x, y)
if op == 'outer-max':
x = self._eval(args[0], env)
y = self._eval(args[1], env)
return jax_outer_max(x, y)
if op == 'outer-min':
x = self._eval(args[0], env)
y = self._eval(args[1], env)
return jax_outer_min(x, y)
# Reduce with axis operations
if op == 'reduce-axis':
x = self._eval(args[0], env)
reduce_op = self._eval_kwarg(args, 'op', 'sum', env)
axis = self._eval_kwarg(args, 'axis', 0, env)
return jax_reduce_axis(x, reduce_op, axis)
if op == 'sum-axis':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', 0, env)
return jax_sum_axis(x, axis)
if op == 'mean-axis':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', 0, env)
return jax_mean_axis(x, axis)
if op == 'max-axis':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', 0, env)
return jax_max_axis(x, axis)
if op == 'min-axis':
x = self._eval(args[0], env)
axis = self._eval_kwarg(args, 'axis', 0, env)
return jax_min_axis(x, axis)
# Windowed operations
if op == 'window':
x = self._eval(args[0], env)
size = int(self._eval(args[1], env))
win_op = self._eval_kwarg(args, 'op', 'mean', env)
stride = int(self._eval_kwarg(args, 'stride', 1, env))
return jax_window(x, size, win_op, stride)
if op == 'window-sum':
x = self._eval(args[0], env)
size = int(self._eval(args[1], env))
stride = int(self._eval_kwarg(args, 'stride', 1, env))
return jax_window_sum(x, size, stride)
if op == 'window-mean':
x = self._eval(args[0], env)
size = int(self._eval(args[1], env))
stride = int(self._eval_kwarg(args, 'stride', 1, env))
return jax_window_mean(x, size, stride)
if op == 'window-max':
x = self._eval(args[0], env)
size = int(self._eval(args[1], env))
stride = int(self._eval_kwarg(args, 'stride', 1, env))
return jax_window_max(x, size, stride)
if op == 'window-min':
x = self._eval(args[0], env)
size = int(self._eval(args[1], env))
stride = int(self._eval_kwarg(args, 'stride', 1, env))
return jax_window_min(x, size, stride)
# Integral image
if op == 'integral-image':
frame = self._eval(args[0], env)
return jax_integral_image(frame)
# Convenience - min/max of two values (handle both scalars and arrays)
if op == 'min' or op == 'min2':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
# Use Python min/max for scalar Python values to preserve type
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return min(a, b)
return jnp.minimum(jnp.asarray(a), jnp.asarray(b))
if op == 'max' or op == 'max2':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
# Use Python min/max for scalar Python values to preserve type
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return max(a, b)
return jnp.maximum(jnp.asarray(a), jnp.asarray(b))
if op == 'clamp':
x = self._eval(args[0], env)
lo = self._eval(args[1], env)
hi = self._eval(args[2], env)
return jnp.clip(x, lo, hi)
# List operations
if op == 'list':
return tuple(self._eval(a, env) for a in args)
if op == 'nth':
seq = self._eval(args[0], env)
idx = int(self._eval(args[1], env))
if isinstance(seq, (list, tuple)):
return seq[idx] if 0 <= idx < len(seq) else None
return seq[idx] # For arrays
if op == 'first':
seq = self._eval(args[0], env)
return seq[0] if len(seq) > 0 else None
if op == 'second':
seq = self._eval(args[0], env)
return seq[1] if len(seq) > 1 else None
# Random (JAX-compatible)
# Shared setup for the random ops below: it produces `frame_num` and
# `op_counter`, which those branches pass to make_jax_key / jax_rand_range.
# NOTE(review): these statements are not guarded by an `if op == ...` test in
# the visible text, so they appear to run (and bump the counter) for every op
# that reaches this point in the dispatch chain, not only random ops - confirm
# against the original indentation.
# Get frame_num for deterministic variation - can be traced, fold_in handles it
# Lookup order: env['_frame_num'], then env['frame_num'], then 0.
frame_num = env.get('_frame_num', env.get('frame_num', 0))
# Convert to int32 for fold_in if needed (but keep as JAX array if traced)
if frame_num is None:
frame_num = 0
elif isinstance(frame_num, (int, float)):
frame_num = int(frame_num)
# If it's a JAX array, leave it as-is for tracing
# Increment operation counter for unique keys within same frame
# `op_counter` keeps the pre-increment value; env (a shared, mutable dict) is
# updated in place so the next evaluation sees a different counter.
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
if op == 'rand' or op == 'rand-x':
# For size-based random
if args:
size = self._eval(args[0], env)
if hasattr(size, 'shape'):
# For frames (3D), use h*w (channel size), not h*w*c
if size.ndim == 3:
n = size.shape[0] * size.shape[1] # h * w
shape = (n,)
else:
n = size.size
shape = size.shape
elif hasattr(size, 'size'):
n = size.size
shape = (n,)
else:
n = int(size)
shape = (n,)
# Use deterministic key that varies with frame
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.uniform(key, shape).flatten()
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.uniform(key, ())
if op == 'randn' or op == 'randn-x':
# Normal random
if args:
size = self._eval(args[0], env)
if hasattr(size, 'shape'):
# For frames (3D), use h*w (channel size), not h*w*c
if size.ndim == 3:
n = size.shape[0] * size.shape[1] # h * w
else:
n = size.size
elif hasattr(size, 'size'):
n = size.size
else:
n = int(size)
mean = self._eval(args[1], env) if len(args) > 1 else 0.0
std = self._eval(args[2], env) if len(args) > 2 else 1.0
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, (n,)) * std + mean
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, ())
if op == 'rand-range' or op == 'core:rand-range':
lo = self._eval(args[0], env)
hi = self._eval(args[1], env)
seed = env.get('_seed', 42)
return jax_rand_range(lo, hi, frame_num, op_counter, seed)
# =====================================================================
# Convolution operations
# =====================================================================
if op == 'blur' or op == 'image:blur':
frame = self._eval(args[0], env)
radius = self._eval(args[1], env) if len(args) > 1 else 1
# Convert traced value to concrete for kernel size
if hasattr(radius, 'item'):
radius = int(radius.item())
elif hasattr(radius, '__float__'):
radius = int(float(radius))
else:
radius = int(radius)
return jax_blur(frame, max(1, radius))
if op == 'gaussian':
first_arg = self._eval(args[0], env)
# Check if first arg is a frame (blur) or scalar (random)
if hasattr(first_arg, 'shape') and first_arg.ndim == 3:
# Frame - apply gaussian blur
sigma = self._eval(args[1], env) if len(args) > 1 else 1.0
radius = max(1, int(sigma * 3))
return jax_blur(first_arg, radius)
else:
# Scalar args - generate gaussian random value
mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg
std = self._eval(args[1], env) if len(args) > 1 else 1.0
# Return a single random value
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, ()) * std + mean
if op == 'sharpen' or op == 'image:sharpen':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
return jax_sharpen(frame, amount)
if op == 'edge-detect' or op == 'image:edge-detect':
frame = self._eval(args[0], env)
return jax_edge_detect(frame)
if op == 'emboss':
frame = self._eval(args[0], env)
return jax_emboss(frame)
if op == 'convolve':
frame = self._eval(args[0], env)
kernel = self._eval(args[1], env)
# Convert kernel to array if it's a list
if isinstance(kernel, (list, tuple)):
kernel = jnp.array(kernel, dtype=jnp.float32)
h, w = frame.shape[:2]
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel)
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel)
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel)
return jnp.stack([
jnp.clip(r, 0, 255).astype(jnp.uint8),
jnp.clip(g, 0, 255).astype(jnp.uint8),
jnp.clip(b, 0, 255).astype(jnp.uint8)
], axis=2)
if op == 'add-noise':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env) if len(args) > 1 else 0.1
h, w = frame.shape[:2]
# Use frame-varying key for noise
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1]
result = frame.astype(jnp.float32) + noise * amount * 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
if op == 'translate':
frame = self._eval(args[0], env)
dx = self._eval(args[1], env)
dy = self._eval(args[2], env) if len(args) > 2 else 0
h, w = frame.shape[:2]
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w)
src_x = (x_coords - dx).flatten()
src_y = (y_coords - dy).flatten()
r, g, b = jax_sample(frame, src_x, src_y)
return jnp.stack([
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'image:crop' or op == 'crop':
frame = self._eval(args[0], env)
x = int(self._eval(args[1], env))
y = int(self._eval(args[2], env))
w = int(self._eval(args[3], env))
h = int(self._eval(args[4], env))
return frame[y:y+h, x:x+w, :]
if op == 'dilate':
frame = self._eval(args[0], env)
size = int(self._eval(args[1], env)) if len(args) > 1 else 3
# Simple dilation using max pooling approximation
kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size)
h, w = frame.shape[:2]
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size)
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size)
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size)
return jnp.stack([
jnp.clip(r, 0, 255).astype(jnp.uint8),
jnp.clip(g, 0, 255).astype(jnp.uint8),
jnp.clip(b, 0, 255).astype(jnp.uint8)
], axis=2)
if op == 'map-rows':
frame = self._eval(args[0], env)
fn = args[1] # S-expression function
h, w = frame.shape[:2]
# For each row, apply the function
results = []
for row_idx in range(h):
row_env = env.copy()
row_env['row'] = frame[row_idx, :, :]
row_env['row-idx'] = row_idx
# Check if fn is a lambda
if isinstance(fn, list) and len(fn) >= 2:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
# Bind lambda params to y and row
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
row_env[param_name] = row_idx # y
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
row_env[param_name] = frame[row_idx, :, :] # row
result_row = self._eval(body, row_env)
results.append(result_row)
continue
result_row = self._eval(fn, row_env)
# If result is a function, call it
if callable(result_row):
result_row = result_row(row_idx, frame[row_idx, :, :])
results.append(result_row)
return jnp.stack(results, axis=0)
# =====================================================================
# Text rendering operations
# =====================================================================
if op == 'text':
frame = self._eval(args[0], env)
text_str = self._eval(args[1], env)
if isinstance(text_str, Symbol):
text_str = text_str.name
text_str = str(text_str)
# Extract keyword arguments
x = self._eval_kwarg(args, 'x', None, env)
y = self._eval_kwarg(args, 'y', None, env)
font_size = self._eval_kwarg(args, 'font-size', 32, env)
font_name = self._eval_kwarg(args, 'font-name', None, env)
color = self._eval_kwarg(args, 'color', (255, 255, 255), env)
opacity = self._eval_kwarg(args, 'opacity', 1.0, env)
align = self._eval_kwarg(args, 'align', 'left', env)
valign = self._eval_kwarg(args, 'valign', 'top', env)
shadow = self._eval_kwarg(args, 'shadow', False, env)
shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env)
shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env)
fit = self._eval_kwarg(args, 'fit', False, env)
width = self._eval_kwarg(args, 'width', None, env)
height = self._eval_kwarg(args, 'height', None, env)
# Handle color as list/tuple
if isinstance(color, (list, tuple)):
color = tuple(int(c) for c in color[:3])
if isinstance(shadow_color, (list, tuple)):
shadow_color = tuple(int(c) for c in shadow_color[:3])
h, w_frame = frame.shape[:2]
# Default position to 0,0 or center based on alignment
if x is None:
if align == 'center':
x = w_frame // 2
elif align == 'right':
x = w_frame
else:
x = 0
if y is None:
if valign == 'middle':
y = h // 2
elif valign == 'bottom':
y = h
else:
y = 0
# Auto-fit text to bounds
if fit and width is not None and height is not None:
font_size = jax_fit_text_size(text_str, int(width), int(height),
font_name, min_size=8, max_size=200)
return jax_text_render(frame, text_str, int(x), int(y),
font_name=font_name, font_size=int(font_size),
color=color, opacity=float(opacity),
align=str(align), valign=str(valign),
shadow=bool(shadow), shadow_color=shadow_color,
shadow_offset=int(shadow_offset))
if op == 'text-size':
text_str = self._eval(args[0], env)
if isinstance(text_str, Symbol):
text_str = text_str.name
text_str = str(text_str)
font_size = self._eval_kwarg(args, 'font-size', 32, env)
font_name = self._eval_kwarg(args, 'font-name', None, env)
return jax_text_size(text_str, font_name, int(font_size))
if op == 'fit-text-size':
text_str = self._eval(args[0], env)
if isinstance(text_str, Symbol):
text_str = text_str.name
text_str = str(text_str)
max_width = int(self._eval(args[1], env))
max_height = int(self._eval(args[2], env))
font_name = self._eval_kwarg(args, 'font-name', None, env)
return jax_fit_text_size(text_str, max_width, max_height, font_name)
# =====================================================================
# Color operations
# =====================================================================
if op == 'rgb->hsv' or op == 'rgb-to-hsv':
# Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple
if len(args) == 1:
c = self._eval(args[0], env)
if isinstance(c, tuple) and len(c) == 3:
r, g, b = c
else:
# Assume it's a list-like
r, g, b = c[0], c[1], c[2]
else:
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
return jax_rgb_to_hsv(r, g, b)
if op == 'hsv->rgb' or op == 'hsv-to-rgb':
# Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list)
if len(args) == 1:
hsv = self._eval(args[0], env)
if isinstance(hsv, (tuple, list)) and len(hsv) >= 3:
h, s, v = hsv[0], hsv[1], hsv[2]
else:
h, s, v = hsv[0], hsv[1], hsv[2]
else:
h = self._eval(args[0], env)
s = self._eval(args[1], env)
v = self._eval(args[2], env)
return jax_hsv_to_rgb(h, s, v)
if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env)
return jax_adjust_brightness(frame, amount)
if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast':
frame = self._eval(args[0], env)
factor = self._eval(args[1], env)
return jax_adjust_contrast(frame, factor)
if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation':
frame = self._eval(args[0], env)
factor = self._eval(args[1], env)
return jax_adjust_saturation(frame, factor)
if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift':
frame = self._eval(args[0], env)
degrees = self._eval(args[1], env)
return jax_shift_hue(frame, degrees)
if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img':
frame = self._eval(args[0], env)
return jax_invert(frame)
if op == 'posterize' or op == 'color_ops:posterize':
frame = self._eval(args[0], env)
levels = self._eval(args[1], env)
return jax_posterize(frame, levels)
if op == 'threshold' or op == 'color_ops:threshold':
frame = self._eval(args[0], env)
level = self._eval(args[1], env)
invert = self._eval(args[2], env) if len(args) > 2 else False
return jax_threshold(frame, level, invert)
if op == 'sepia' or op == 'color_ops:sepia':
frame = self._eval(args[0], env)
return jax_sepia(frame)
if op == 'grayscale' or op == 'image:grayscale':
frame = self._eval(args[0], env)
return jax_grayscale(frame)
# =====================================================================
# Geometry operations
# =====================================================================
if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img':
frame = self._eval(args[0], env)
direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal'
if direction == 'vertical' or direction == 'v':
return jax_flip_vertical(frame)
return jax_flip_horizontal(frame)
if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v':
frame = self._eval(args[0], env)
return jax_flip_vertical(frame)
if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img':
frame = self._eval(args[0], env)
angle = self._eval(args[1], env)
return jax_rotate(frame, angle)
if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img':
frame = self._eval(args[0], env)
scale_x = self._eval(args[1], env)
scale_y = self._eval(args[2], env) if len(args) > 2 else None
return jax_scale(frame, scale_x, scale_y)
if op == 'resize' or op == 'image:resize':
frame = self._eval(args[0], env)
new_w = self._eval(args[1], env)
new_h = self._eval(args[2], env)
return jax_resize(frame, new_w, new_h)
# =====================================================================
# Geometry distortion effects
# =====================================================================
# Fisheye distortion. Two call shapes, chosen by whether the first evaluated
# operand has a `.shape` attribute:
#   * coordinate mode - returns a dict {'x': src_x, 'y': src_y} of source maps
#   * frame mode      - samples the frame at the distorted coordinates and
#                       returns an (h, w, 3) uint8 image
if op == 'geometry:fisheye-coords' or op == 'fisheye':
# Signature: (w h strength cx cy zoom_correct) or (frame strength)
# NOTE(review): zoom_correct is listed in the signature but args[5] is never
# read in this branch.
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
# (w h strength cx cy zoom_correct) signature
w = int(first_arg)
h = int(self._eval(args[1], env))
strength = self._eval(args[2], env) if len(args) > 2 else 0.5
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
frame = None
else:
frame = first_arg
strength = self._eval(args[1], env) if len(args) > 1 else 0.5
h, w = frame.shape[:2]
# NOTE(review): presumably still inside the else-branch (frame mode always
# centres on the frame); if it were outside, it would override the cx/cy
# parsed in coordinate mode - confirm against the original indentation.
cx, cy = w / 2, h / 2
# Distance from the centre to the (0, 0) corner.
max_r = jnp.sqrt(float(cx*cx + cy*cy))
# Per-pixel coordinate grids of shape (h, w).
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
# Polar coordinates relative to the centre.
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
# Fisheye distortion
# The radial offset is zero at r = 0 and at r = max_r.
r_new = r + strength * r * (1 - r / max_r)
src_x = r_new * jnp.cos(theta) + cx
src_y = r_new * jnp.sin(theta) + cy
if frame is None:
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# Swirl distortion. Two call shapes, chosen by whether the first evaluated
# operand has a `.shape` attribute:
#   * (w h [amount])   - returns a dict {'x': src_x, 'y': src_y}
#   * (frame [amount]) - returns the resampled (h, w, 3) uint8 image
# amount defaults to 1.0 in both shapes.
if op == 'geometry:swirl-coords' or op == 'swirl':
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
w = int(first_arg)
h = int(self._eval(args[1], env))
amount = self._eval(args[2], env) if len(args) > 2 else 1.0
frame = None
else:
frame = first_arg
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
h, w = frame.shape[:2]
# NOTE(review): the coordinate branch never sets cx/cy, so this assignment
# presumably sits after the if/else and applies to both modes - confirm
# against the original indentation.
cx, cy = w / 2, h / 2
# Distance from the centre to the (0, 0) corner.
max_r = jnp.sqrt(float(cx*cx + cy*cy))
# Per-pixel coordinate grids of shape (h, w).
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
# Polar coordinates relative to the centre.
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
# Rotation is `amount` radians at the centre and falls linearly to 0 at max_r.
swirl_angle = amount * (1 - r / max_r)
new_theta = theta + swirl_angle
src_x = r * jnp.cos(new_theta) + cx
src_y = r * jnp.sin(new_theta) + cy
if frame is None:
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# Wave effect (frame-first signature for simple usage)
if op == 'wave-distort':
first_arg = self._eval(args[0], env)
frame = first_arg
amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0
amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0
freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1
freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1
h, w = frame.shape[:2]
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y)
src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x)
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# Ripple displacement. Two call shapes, chosen by whether the first evaluated
# operand has a `.shape` attribute:
#   * coordinate mode (positional args) - returns {'x': src_x, 'y': src_y}
#   * frame mode (keyword args)         - returns the resampled (h, w, 3)
#                                         uint8 image
if op == 'geometry:ripple-displace' or op == 'ripple':
# Match Python prim_ripple_displace signature:
# (w h freq amp cx cy decay phase) or (frame ...)
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
# Coordinate-only mode: (w h freq amp cx cy decay phase)
w = int(first_arg)
h = int(self._eval(args[1], env))
freq = self._eval(args[2], env) if len(args) > 2 else 5.0
amp = self._eval(args[3], env) if len(args) > 3 else 10.0
cx = self._eval(args[4], env) if len(args) > 4 else w / 2
cy = self._eval(args[5], env) if len(args) > 5 else h / 2
decay = self._eval(args[6], env) if len(args) > 6 else 0.0
phase = self._eval(args[7], env) if len(args) > 7 else 0.0
frame = None
else:
# Frame mode: (frame :amplitude A :frequency F :center_x CX ...)
frame = first_arg
h, w = frame.shape[:2]
# Parse keyword args
amp = 10.0
freq = 5.0
cx = w / 2
cy = h / 2
decay = 0.0
phase = 0.0
# Walk the remaining args: a Keyword consumes the following arg as its
# value; non-keyword args are skipped one at a time. Unrecognised keywords
# are ignored (their value is still evaluated).
i = 1
while i < len(args):
if isinstance(args[i], Keyword):
kw = args[i].name
# val is None when the keyword is the last argument.
val = self._eval(args[i + 1], env) if i + 1 < len(args) else None
if kw == 'amplitude':
amp = val
elif kw == 'frequency':
freq = val
elif kw == 'center_x':
# Values <= 1 (including negatives) are treated as a fraction of the
# frame width; larger values as absolute pixels.
# NOTE(review): `val <= 1` would fail if val is None - confirm callers
# always supply a value.
cx = val * w if val <= 1 else val # normalized or absolute
elif kw == 'center_y':
cy = val * h if val <= 1 else val
elif kw == 'decay':
decay = val
elif kw == 'speed':
# speed affects phase via time
# Reads 't' from env (0 if absent); this overwrites any earlier :phase.
t = env.get('t', 0)
phase = t * val * 2 * jnp.pi
elif kw == 'phase':
phase = val
i += 2
else:
i += 1
# Per-pixel coordinate grids of shape (h, w).
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
dist = jnp.sqrt(dx*dx + dy*dy)
# Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp
max_dim = jnp.maximum(w, h)
ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp
# Apply decay (when decay=0, exp(0)=1 so no effect)
decay_factor = jnp.exp(-decay * dist / max_dim)
ripple = ripple * decay_factor
# Radial displacement - use ADDITION to match Python prim_ripple_displace
# Python (primitives.py line 2890-2891):
# map_x = x_coords + ripple * norm_dx
# map_y = y_coords + ripple * norm_dy
# where norm_dx = dx/dist = cos(angle), norm_dy = dy/dist = sin(angle)
# Using arctan2 instead of dividing by dist avoids a division by zero at the
# centre pixel.
angle = jnp.arctan2(dy, dx)
src_x = x_coords + ripple * jnp.cos(angle)
src_y = y_coords + ripple * jnp.sin(angle)
if frame is None:
return {'x': src_x, 'y': src_y}
# Sample using bilinear interpolation (jax_sample clamps coords,
# matching OpenCV's default remap behavior)
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'geometry:coords-x' or op == 'coords-x':
# Extract x coordinates from coord dict
coords = self._eval(args[0], env)
if isinstance(coords, dict):
return coords.get('x', coords.get('map_x'))
return coords[0] if isinstance(coords, (list, tuple)) else coords
if op == 'geometry:coords-y' or op == 'coords-y':
# Extract y coordinates from coord dict
coords = self._eval(args[0], env)
if isinstance(coords, dict):
return coords.get('y', coords.get('map_y'))
return coords[1] if isinstance(coords, (list, tuple)) else coords
if op == 'geometry:remap' or op == 'remap':
# Remap image using coordinate maps: (frame map_x map_y)
# OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords
frame = self._eval(args[0], env)
map_x = self._eval(args[1], env)
map_y = self._eval(args[2], env)
h, w = frame.shape[:2]
# Flatten coordinate maps
src_x = map_x.flatten()
src_y = map_y.flatten()
# Sample using bilinear interpolation (jax_sample clamps coords internally,
# matching OpenCV's default behavior)
r_out, g_out, b_out = jax_sample(frame, src_x, src_y)
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope':
# Two signatures: (frame segments) or (w h segments cx cy)
if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'):
# (w h segments cx cy) signature
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
segments = int(self._eval(args[2], env)) if len(args) > 2 else 6
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
frame = None
else:
frame = self._eval(args[0], env)
segments = int(self._eval(args[1], env)) if len(args) > 1 else 6
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
# Mirror into segments
segment_angle = 2 * jnp.pi / segments
theta_mod = theta % segment_angle
theta_mirror = jnp.where(
(jnp.floor(theta / segment_angle) % 2) == 0,
theta_mod,
segment_angle - theta_mod
)
src_x = r * jnp.cos(theta_mirror) + cx
src_y = r * jnp.sin(theta_mirror) + cy
if frame is None:
# Return coordinate arrays
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# NOTE: 'geometry:coords-x', 'geometry:coords-y' and 'geometry:remap' / 'remap'
# are dispatched by the branches that precede the kaleidoscope handler, each of
# which returns on every path. A second set of handlers for the same op names
# used to follow here; it could never be reached and has been removed.
# =====================================================================
# Blending operations
# =====================================================================
if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
return jax_blend(frame1, frame2, alpha)
if op == 'blend-add':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_add(frame1, frame2)
if op == 'blend-multiply':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_multiply(frame1, frame2)
if op == 'blend-screen':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_screen(frame1, frame2)
if op == 'blend-overlay':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_overlay(frame1, frame2)
# =====================================================================
# Image dimension queries (namespaced aliases)
# =====================================================================
if op == 'image:width':
if args:
frame = self._eval(args[0], env)
return frame.shape[1] # width is second dimension (h, w, c)
return env['width']
if op == 'image:height':
if args:
frame = self._eval(args[0], env)
return frame.shape[0] # height is first dimension (h, w, c)
return env['height']
# =====================================================================
# Utility
# =====================================================================
if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?':
x = self._eval(args[0], env)
return jax_is_nil(x)
# =====================================================================
# Xector channel operations (shortcuts)
# =====================================================================
if op == 'red':
val = self._eval(args[0], env)
# Works on frames or pixel tuples
if isinstance(val, tuple):
return val[0]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 0)
else:
return val # Assume it's already a channel
if op == 'green':
val = self._eval(args[0], env)
if isinstance(val, tuple):
return val[1]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 1)
else:
return val
if op == 'blue':
val = self._eval(args[0], env)
if isinstance(val, tuple):
return val[2]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 2)
else:
return val
if op == 'gray' or op == 'luminance':
val = self._eval(args[0], env)
# Handle tuple (r, g, b) from map-pixels
if isinstance(val, tuple) and len(val) == 3:
r, g, b = val
return r * 0.299 + g * 0.587 + b * 0.114
# Handle frame
frame = val
r = frame[:, :, 0].flatten().astype(jnp.float32)
g = frame[:, :, 1].flatten().astype(jnp.float32)
b = frame[:, :, 2].flatten().astype(jnp.float32)
return r * 0.299 + g * 0.587 + b * 0.114
if op == 'rgb':
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
# For scalars (e.g., in map-pixels), return tuple
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
if r_is_scalar and g_is_scalar and b_is_scalar:
return (r, g, b)
return jax_merge_channels(r, g, b, env['_shape'])
# =====================================================================
# Coordinate operations
# =====================================================================
if op == 'x-coords':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
return jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
if op == 'y-coords':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
if op == 'dist-from-center':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx
y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy
return jnp.sqrt(x*x + y*y)
# =====================================================================
# Alpha operations (element-wise on xectors)
# =====================================================================
if op == 'α/' or op == 'alpha/':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a / b
if op == 'α+' or op == 'alpha+':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result + v
return result
if op == 'α*' or op == 'alpha*':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result * v
return result
if op == 'α-' or op == 'alpha-':
if len(args) == 1:
return -self._eval(args[0], env)
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a - b
if op == 'αclamp' or op == 'alpha-clamp':
x = self._eval(args[0], env)
lo = self._eval(args[1], env)
hi = self._eval(args[2], env)
return jnp.clip(x, lo, hi)
if op == 'αmin' or op == 'alpha-min':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.minimum(a, b)
if op == 'αmax' or op == 'alpha-max':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.maximum(a, b)
if op == 'αsqrt' or op == 'alpha-sqrt':
return jnp.sqrt(self._eval(args[0], env))
if op == 'αsin' or op == 'alpha-sin':
return jnp.sin(self._eval(args[0], env))
if op == 'αcos' or op == 'alpha-cos':
return jnp.cos(self._eval(args[0], env))
if op == 'αmod' or op == 'alpha-mod':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a % b
if op == 'α²' or op == 'αsq' or op == 'alpha-sq':
x = self._eval(args[0], env)
return x * x
if op == 'α<' or op == 'alpha<':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a < b
if op == 'α>' or op == 'alpha>':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a > b
if op == 'α<=' or op == 'alpha<=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a <= b
if op == 'α>=' or op == 'alpha>=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a >= b
if op == 'α=' or op == 'alpha=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a == b
if op == 'αfloor' or op == 'alpha-floor':
return jnp.floor(self._eval(args[0], env))
if op == 'αceil' or op == 'alpha-ceil':
return jnp.ceil(self._eval(args[0], env))
if op == 'αround' or op == 'alpha-round':
return jnp.round(self._eval(args[0], env))
if op == 'αabs' or op == 'alpha-abs':
return jnp.abs(self._eval(args[0], env))
if op == 'αexp' or op == 'alpha-exp':
return jnp.exp(self._eval(args[0], env))
if op == 'αlog' or op == 'alpha-log':
return jnp.log(self._eval(args[0], env))
if op == 'αor' or op == 'alpha-or':
# Element-wise logical OR
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_or(a, b)
if op == 'αand' or op == 'alpha-and':
# Element-wise logical AND
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_and(a, b)
if op == 'αnot' or op == 'alpha-not':
# Element-wise logical NOT
return jnp.logical_not(self._eval(args[0], env))
if op == 'αxor' or op == 'alpha-xor':
# Element-wise logical XOR
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_xor(a, b)
# =====================================================================
# Threading/arrow operations
# =====================================================================
if op == '->':
# Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b)
val = self._eval(args[0], env)
for form in args[1:]:
if isinstance(form, list):
# Insert val as first argument
fn_name = form[0].name if isinstance(form[0], Symbol) else form[0]
new_args = [val] + [self._eval(a, env) for a in form[1:]]
val = self._eval_call(fn_name, [val] + form[1:], env)
else:
# Simple function call
fn_name = form.name if isinstance(form, Symbol) else form
val = self._eval_call(fn_name, [args[0]], env)
return val
# =====================================================================
# Range and iteration
# =====================================================================
if op == 'range':
if len(args) == 1:
end = int(self._eval(args[0], env))
return list(range(end))
elif len(args) == 2:
start = int(self._eval(args[0], env))
end = int(self._eval(args[1], env))
return list(range(start, end))
else:
start = int(self._eval(args[0], env))
end = int(self._eval(args[1], env))
step = int(self._eval(args[2], env))
return list(range(start, end, step))
if op == 'reduce' or op == 'fold':
# (reduce seq init fn) - left fold
seq = self._eval(args[0], env)
acc = self._eval(args[1], env)
fn = args[2] # Lambda S-expression
# Handle lambda
if isinstance(fn, list) and len(fn) >= 3:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
for item in seq:
fn_env = env.copy()
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
fn_env[param_name] = acc
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
fn_env[param_name] = item
acc = self._eval(body, fn_env)
return acc
# Fallback - try evaluating fn and calling it
fn_eval = self._eval(fn, env)
if callable(fn_eval):
for item in seq:
acc = fn_eval(acc, item)
return acc
if op == 'fold-indexed':
# (fold-indexed seq init fn) - fold with index
# fn takes (acc item index) or (acc item index cursor) for typography
seq = self._eval(args[0], env)
acc = self._eval(args[1], env)
fn = args[2] # Lambda S-expression
# Handle lambda
if isinstance(fn, list) and len(fn) >= 3:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
for idx, item in enumerate(seq):
fn_env = env.copy()
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
fn_env[param_name] = acc
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
fn_env[param_name] = item
if len(params) >= 3:
param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2])
fn_env[param_name] = idx
acc = self._eval(body, fn_env)
return acc
# Fallback
fn_eval = self._eval(fn, env)
if callable(fn_eval):
for idx, item in enumerate(seq):
acc = fn_eval(acc, item, idx)
return acc
# =====================================================================
# Map-pixels (apply function to each pixel)
# =====================================================================
if op == 'map-pixels':
frame = self._eval(args[0], env)
fn = args[1] # Lambda or S-expression
h, w = frame.shape[:2]
# Extract channels
r = frame[:, :, 0].flatten().astype(jnp.float32)
g = frame[:, :, 1].flatten().astype(jnp.float32)
b = frame[:, :, 2].flatten().astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
# Set up pixel environment
pixel_env = env.copy()
pixel_env['r'] = r
pixel_env['g'] = g
pixel_env['b'] = b
pixel_env['x'] = x_coords
pixel_env['y'] = y_coords
# Also provide c (color) as a tuple for lambda (x y c) style
pixel_env['c'] = (r, g, b)
# If fn is a lambda, we need to handle it specially
if isinstance(fn, list) and len(fn) >= 2:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
# Lambda: (lambda (x y c) body)
params = fn[1]
body = fn[2]
# Bind parameters
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
pixel_env[param_name] = x_coords
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
pixel_env[param_name] = y_coords
if len(params) >= 3:
param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2])
pixel_env[param_name] = (r, g, b)
result = self._eval(body, pixel_env)
else:
result = self._eval(fn, pixel_env)
else:
result = self._eval(fn, pixel_env)
if isinstance(result, tuple) and len(result) == 3:
nr, ng, nb = result
return jax_merge_channels(nr, ng, nb, (h, w))
elif hasattr(result, 'shape') and result.ndim == 3:
return result
else:
# Single channel result
if hasattr(result, 'flatten'):
result = result.flatten()
gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8)
return jnp.stack([gray, gray, gray], axis=2)
# =====================================================================
# State operations (stubs for stateless JIT: state-get returns its default, state-set is a no-op)
# =====================================================================
if op == 'state-get':
key = self._eval(args[0], env)
default = self._eval(args[1], env) if len(args) > 1 else None
return default # State not supported in JIT, return default
if op == 'state-set':
return None # No-op in JIT
# =====================================================================
# Cell/grid operations
# =====================================================================
if op == 'local-x-norm':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return (x % cell_size) / max(1, cell_size - 1)
if op == 'local-y-norm':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return (y % cell_size) / max(1, cell_size - 1)
if op == 'local-x':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return (x % cell_size).astype(jnp.float32)
if op == 'local-y':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return (y % cell_size).astype(jnp.float32)
if op == 'cell-row':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return jnp.floor(y / cell_size)
if op == 'cell-col':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return jnp.floor(x / cell_size)
if op == 'num-rows':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
return frame.shape[0] // cell_size
if op == 'num-cols':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
return frame.shape[1] // cell_size
# =====================================================================
# Control flow
# =====================================================================
if op == 'cond':
# (cond (test1 expr1) (test2 expr2) ... (else exprN))
# For JAX compatibility, build a nested jnp.where structure
# Start from the else clause and work backwards
# Collect clauses
clauses = []
else_expr = None
for clause in args:
if isinstance(clause, list) and len(clause) >= 2:
test = clause[0]
if isinstance(test, Symbol) and test.name == 'else':
else_expr = clause[1]
else:
clauses.append((test, clause[1]))
# If there is no else clause, the result defaults to 0
if else_expr is not None:
result = self._eval(else_expr, env)
else:
result = 0
# Build nested where from last to first
for test_expr, val_expr in reversed(clauses):
cond_val = self._eval(test_expr, env)
then_val = self._eval(val_expr, env)
# Check if condition is array or scalar
if hasattr(cond_val, 'shape') and cond_val.shape != ():
# Array condition - use jnp.where
result = jnp.where(cond_val, then_val, result)
else:
# Scalar - can use Python if
if cond_val:
result = then_val
return result
if op == 'set!' or op == 'set':
# Mutation - not really supported in JAX, but we can update env
var = args[0].name if isinstance(args[0], Symbol) else str(args[0])
val = self._eval(args[1], env)
env[var] = val
return val
if op == 'begin' or op == 'do':
# Evaluate all expressions, return last
result = None
for expr in args:
result = self._eval(expr, env)
return result
# =====================================================================
# Additional math
# =====================================================================
if op == 'sq' or op == 'square':
x = self._eval(args[0], env)
return x * x
if op == 'lerp':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
t = self._eval(args[2], env)
return a * (1 - t) + b * t
if op == 'smoothstep':
edge0 = self._eval(args[0], env)
edge1 = self._eval(args[1], env)
x = self._eval(args[2], env)
t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1)
return t * t * (3 - 2 * t)
if op == 'atan2':
y = self._eval(args[0], env)
x = self._eval(args[1], env)
return jnp.arctan2(y, x)
if op == 'fract' or op == 'frac':
x = self._eval(args[0], env)
return x - jnp.floor(x)
# =====================================================================
# Frame copy and construction operations
# =====================================================================
if op == 'pixel':
# Get pixel at (x, y) from frame
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
h, w = frame.shape[:2]
# Convert to int and clip to bounds
if isinstance(x, (int, float)):
x = max(0, min(int(x), w - 1))
else:
x = jnp.clip(x, 0, w - 1).astype(jnp.int32)
if isinstance(y, (int, float)):
y = max(0, min(int(y), h - 1))
else:
y = jnp.clip(y, 0, h - 1).astype(jnp.int32)
r = frame[y, x, 0]
g = frame[y, x, 1]
b = frame[y, x, 2]
return (r, g, b)
if op == 'copy':
frame = self._eval(args[0], env)
return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame)
if op == 'make-image':
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
if len(args) > 2:
color = self._eval(args[2], env)
if isinstance(color, (list, tuple)):
r, g, b = int(color[0]), int(color[1]), int(color[2])
else:
r = g = b = int(color)
else:
r = g = b = 0
img = jnp.zeros((h, w, 3), dtype=jnp.uint8)
img = img.at[:, :, 0].set(r)
img = img.at[:, :, 1].set(g)
img = img.at[:, :, 2].set(b)
return img
if op == 'paste':
dest = self._eval(args[0], env)
src = self._eval(args[1], env)
x = int(self._eval(args[2], env))
y = int(self._eval(args[3], env))
sh, sw = src.shape[:2]
dh, dw = dest.shape[:2]
# Clip to dest bounds
x1, y1 = max(0, x), max(0, y)
x2, y2 = min(dw, x + sw), min(dh, y + sh)
sx1, sy1 = x1 - x, y1 - y
sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1)
result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest)
result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :])
return result
# =====================================================================
# Blending operations
# =====================================================================
if op == 'blending:blend-images' or op == 'blend-images':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
return jax_blend(a, b, alpha)
if op == 'blending:blend-mode' or op == 'blend-mode':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
mode = self._eval(args[2], env) if len(args) > 2 else 'add'
if mode == 'add':
return jax_blend_add(a, b)
elif mode == 'multiply':
return jax_blend_multiply(a, b)
elif mode == 'screen':
return jax_blend_screen(a, b)
elif mode == 'overlay':
return jax_blend_overlay(a, b)
elif mode == 'lighten':
return jnp.maximum(a, b)
elif mode == 'darken':
return jnp.minimum(a, b)
elif mode == 'difference':
return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8)
else:
return jax_blend(a, b, 0.5)
# =====================================================================
# Geometry coordinate operations
# =====================================================================
if op == 'geometry:wave-coords' or op == 'wave-coords':
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
axis = self._eval(args[2], env) if len(args) > 2 else 'x'
freq = self._eval(args[3], env) if len(args) > 3 else 1.0
amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0
phase = self._eval(args[5], env) if len(args) > 5 else 0.0
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
if axis == 'x' or axis == 'horizontal':
# Wave displaces X based on Y
offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
src_x = x_coords + offset
src_y = y_coords
elif axis == 'y' or axis == 'vertical':
# Wave displaces Y based on X
offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
src_x = x_coords
src_y = y_coords + offset
else: # both
offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
src_x = x_coords + offset_x
src_y = y_coords + offset_y
return {'x': src_x, 'y': src_y}
if op == 'geometry:coords-x' or op == 'coords-x':
coords = self._eval(args[0], env)
return coords['x']
if op == 'geometry:coords-y' or op == 'coords-y':
coords = self._eval(args[0], env)
return coords['y']
if op == 'geometry:remap' or op == 'remap':
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
h, w = frame.shape[:2]
r, g, b = jax_sample(frame, x.flatten(), y.flatten())
return jnp.stack([
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# =====================================================================
# Glitch effects
# =====================================================================
if op == 'pixelsort':
frame = self._eval(args[0], env)
sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness'
thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50
thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200
angle = int(self._eval(args[4], env)) if len(args) > 4 else 0
reverse = self._eval(args[5], env) if len(args) > 5 else False
h, w = frame.shape[:2]
result = frame.copy()
# Get luminance for thresholding
lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
frame[:, :, 1].astype(jnp.float32) * 0.587 +
frame[:, :, 2].astype(jnp.float32) * 0.114)
# Sort each row
for y in range(h):
row_lum = lum[y, :]
row = frame[y, :, :]
# Find mask of pixels to sort
mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi)
# Get indices where we should sort
sort_indices = jnp.where(mask, jnp.arange(w), -1)
# Pick the sort key for this row: luminance by default, approximate hue when sort_by == 'hue'
if sort_by == 'lightness':
sort_key = row_lum
elif sort_by == 'hue':
# Approximate hue from RGB
sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32),
row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32)))
else:
sort_key = row_lum
# Sort the whole row by the chosen key (the threshold mask is not applied here)
sorted_indices = jnp.argsort(sort_key)
if reverse:
sorted_indices = sorted_indices[::-1]
# NOTE: simplified version - the entire row is replaced by its sorted order;
# a full pixelsort would reorder only the pixels where mask is true
result = result.at[y, :, :].set(row[sorted_indices])
return result
if op == 'datamosh':
frame = self._eval(args[0], env)
prev = self._eval(args[1], env)
block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32
corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3
max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50
color_corrupt = self._eval(args[5], env) if len(args) > 5 else True
h, w = frame.shape[:2]
# Use deterministic random for JIT with frame variation
seed = env.get('_seed', 42)
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
key = make_jax_key(seed, frame_num, op_counter)
num_blocks_y = h // block_size
num_blocks_x = w // block_size
total_blocks = num_blocks_y * num_blocks_x
# Pre-generate all random values at once (vectorized)
key, k1, k2, k3, k4, k5 = jax.random.split(key, 6)
corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption
offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1)
offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1)
channels = jax.random.randint(k4, (total_blocks,), 0, 3)
color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51)
# Create coordinate grids for blocks
by_grid = jnp.arange(num_blocks_y)
bx_grid = jnp.arange(num_blocks_x)
# Create block coordinate arrays
by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...]
bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...]
# Create pixel coordinate grids
y_coords, x_coords = jnp.mgrid[:h, :w]
# Determine which block each pixel belongs to
pixel_block_y = y_coords // block_size
pixel_block_x = x_coords // block_size
pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x
# Clamp to valid block indices (for pixels outside the block grid)
pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1)
# Get the corrupt mask, offsets for each pixel's block
pixel_corrupt = corrupt_mask[pixel_block_idx]
pixel_offset_y = offsets_y[pixel_block_idx]
pixel_offset_x = offsets_x[pixel_block_idx]
# Calculate source coordinates with offset (clamped)
src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1)
src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1)
# Sample from previous frame at offset positions
prev_sampled = prev[src_y, src_x, :]
# Where corrupt mask is true, use prev_sampled; else use frame
result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame)
# Apply color corruption to corrupted blocks
if color_corrupt:
pixel_channel = channels[pixel_block_idx]
pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16)
# Create per-channel shift arrays (only shift the selected channel)
shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0)
shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0)
shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0)
result_int = result.astype(jnp.int16)
result_int = result_int.at[:, :, 0].add(shift_r)
result_int = result_int.at[:, :, 1].add(shift_g)
result_int = result_int.at[:, :, 2].add(shift_b)
result = jnp.clip(result_int, 0, 255).astype(jnp.uint8)
return result
# =====================================================================
# ASCII Art Operations (using pre-rendered font atlas)
# =====================================================================
if op == 'cell-sample':
# (cell-sample frame char_size) -> (colors, luminances)
# Downsample frame into cells, return average colors and luminances
frame = self._eval(args[0], env)
char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8
h, w = frame.shape[:2]
num_rows = h // char_size
num_cols = w // char_size
# Crop to exact multiple of char_size
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
# Reshape to (num_rows, char_size, num_cols, char_size, 3)
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
# Average over char_size dimensions -> (num_rows, num_cols, 3)
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
# Compute luminance per cell
colors_float = colors.astype(jnp.float32)
luminances = (0.299 * colors_float[:, :, 0] +
0.587 * colors_float[:, :, 1] +
0.114 * colors_float[:, :, 2]) / 255.0
return (colors, luminances)
if op == 'luminance-to-chars':
# (luminance-to-chars luminances alphabet contrast) -> char_indices
# Map luminance values to character indices
luminances = self._eval(args[0], env)
alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard'
contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5
# Get alphabet string
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
# Apply contrast
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
# Map to character indices (0 = darkest, num_chars-1 = brightest)
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
return char_indices
if op == 'render-char-grid':
# (render-char-grid frame chars colors char_size color_mode background_color invert_colors)
# Render character grid using font atlas
frame = self._eval(args[0], env)
char_indices = self._eval(args[1], env)
colors = self._eval(args[2], env)
char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8
color_mode = self._eval(args[4], env) if len(args) > 4 else 'color'
background_color = self._eval(args[5], env) if len(args) > 5 else 'black'
invert_colors = self._eval(args[6], env) if len(args) > 6 else 0
h, w = frame.shape[:2]
num_rows, num_cols = char_indices.shape
# Get the alphabet used (stored in env or default)
alphabet = env.get('_ascii_alphabet', 'standard')
alpha_str = _get_alphabet_string(alphabet)
# Get or create font atlas
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background color
if background_color == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background_color == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
# Try to parse hex color
try:
if background_color.startswith('#'):
bg_hex = background_color[1:]
bg = jnp.array([int(bg_hex[0:2], 16),
int(bg_hex[2:4], 16),
int(bg_hex[4:6], 16)], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
except:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Create output image starting with background
output_h = num_rows * char_size
output_w = num_cols * char_size
result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy()
# Gather characters from atlas based on indices
# char_indices shape: (num_rows, num_cols)
# font_atlas shape: (num_chars, char_size, char_size, 3)
# Convert numpy atlas to JAX for indexing with traced indices
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3)
# Reshape to grid
char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3)
# Create coordinate grids for output pixels
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = y_out // char_size
cell_col = x_out // char_size
local_y = y_out % char_size
local_x = x_out % char_size
# Clamp to valid ranges
cell_row = jnp.clip(cell_row, 0, num_rows - 1)
cell_col = jnp.clip(cell_col, 0, num_cols - 1)
# Get character pixel values
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
# Get character brightness (for masking)
char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0
# Handle color modes
if color_mode == 'mono':
# White characters on background
fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
fg = jnp.broadcast_to(fg_color, char_pixels.shape)
elif color_mode == 'invert':
# Inverted cell colors
cell_colors = colors[cell_row, cell_col]
fg = 255 - cell_colors
else:
# 'color' mode - use cell colors
fg = colors[cell_row, cell_col]
# Blend foreground onto background based on character brightness
if invert_colors:
# Swap fg and bg
fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg
result = (fg.astype(jnp.float32) * (1 - char_brightness) +
bg_broadcast.astype(jnp.float32) * char_brightness)
else:
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) +
fg.astype(jnp.float32) * char_brightness)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize back to original frame size if needed
if result.shape[0] != h or result.shape[1] != w:
# Simple nearest-neighbor resize
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = (jnp.arange(h) * y_scale).astype(jnp.int32)
x_src = (jnp.arange(w) * x_scale).astype(jnp.int32)
y_src = jnp.clip(y_src, 0, result.shape[0] - 1)
x_src = jnp.clip(x_src, 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'ascii-fx-zone':
# Complex ASCII effect with per-zone expressions
# (ascii-fx-zone frame :cols cols :alphabet alphabet ...)
frame = self._eval(args[0], env)
# Parse keyword arguments
kwargs = {}
i = 1
while i < len(args):
if isinstance(args[i], Keyword):
key = args[i].name
if i + 1 < len(args):
kwargs[key] = args[i + 1]
i += 2
else:
i += 1
# Get parameters
cols = int(self._eval(kwargs.get('cols', 80), env))
char_size_param = kwargs.get('char_size')
alphabet = self._eval(kwargs.get('alphabet', 'standard'), env)
color_mode = self._eval(kwargs.get('color_mode', 'color'), env)
background = self._eval(kwargs.get('background', 'black'), env)
contrast = float(self._eval(kwargs.get('contrast', 1.5), env))
h, w = frame.shape[:2]
# Calculate char_size from cols if not specified
if char_size_param is not None:
char_size_val = self._eval(char_size_param, env)
if char_size_val is not None:
char_size = int(char_size_val)
else:
char_size = w // cols
else:
char_size = w // cols
char_size = max(4, min(char_size, 64))
# Store alphabet for render-char-grid to use
env['_ascii_alphabet'] = alphabet
# Cell sampling
num_rows = h // char_size
num_cols = w // char_size
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
# Compute luminances
colors_float = colors.astype(jnp.float32)
luminances = (0.299 * colors_float[:, :, 0] +
0.587 * colors_float[:, :, 1] +
0.114 * colors_float[:, :, 2]) / 255.0
# Get alphabet and map luminances to chars
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
# Handle optional per-zone effects (char_hue, char_saturation, etc.)
# These would modify colors based on zone position
char_hue = kwargs.get('char_hue')
char_saturation = kwargs.get('char_saturation')
char_brightness = kwargs.get('char_brightness')
if char_hue is not None or char_saturation is not None or char_brightness is not None:
# Create zone coordinate arrays for expression evaluation
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
row_norm = row_coords / max(num_rows - 1, 1)
col_norm = col_coords / max(num_cols - 1, 1)
# Bind zone variables
zone_env = env.copy()
zone_env['zone-row'] = row_coords
zone_env['zone-col'] = col_coords
zone_env['zone-row-norm'] = row_norm
zone_env['zone-col-norm'] = col_norm
zone_env['zone-lum'] = luminances
# Apply color modifications (simplified - full version would use HSV)
if char_brightness is not None:
brightness_mult = self._eval(char_brightness, zone_env)
if brightness_mult is not None:
colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None],
0, 255).astype(jnp.uint8)
# Render using font atlas
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background color
if background == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Gather characters - convert numpy atlas to JAX for traced indexing
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
# Create output
output_h = num_rows * char_size
output_w = num_cols * char_size
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
local_y = y_out % char_size
local_x = x_out % char_size
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
if color_mode == 'mono':
fg = jnp.full_like(char_pixels, 255)
else:
fg = colors[cell_row, cell_col]
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
fg.astype(jnp.float32) * char_bright)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize to original dimensions
if result.shape[0] != h or result.shape[1] != w:
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'render-char-grid-fx':
# Enhanced render with per-character effects
# (render-char-grid-fx frame chars colors luminances char_size
# color_mode bg_color invert_colors
# char_jitter char_scale char_rotation char_hue_shift
# jitter_source scale_source rotation_source hue_source)
frame = self._eval(args[0], env)
char_indices = self._eval(args[1], env)
colors = self._eval(args[2], env)
luminances = self._eval(args[3], env)
char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8
color_mode = self._eval(args[5], env) if len(args) > 5 else 'color'
background_color = self._eval(args[6], env) if len(args) > 6 else 'black'
invert_colors = self._eval(args[7], env) if len(args) > 7 else 0
# Per-char effect amounts
char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0
char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0
char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0
char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0
# Modulation sources
jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none'
scale_source = self._eval(args[13], env) if len(args) > 13 else 'none'
rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none'
hue_source = self._eval(args[15], env) if len(args) > 15 else 'none'
h, w = frame.shape[:2]
num_rows, num_cols = char_indices.shape
# Get alphabet
alphabet = env.get('_ascii_alphabet', 'standard')
alpha_str = _get_alphabet_string(alphabet)
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background
if background_color == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background_color == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Create modulation values based on source
def get_modulation(source, lums, num_rows, num_cols):
    """
    Build the per-cell modulation grid selected by ``source``.

    Args:
        source: Name of the modulation source. Recognised values are
            'luminance', 'inv_luminance', 'position_x', 'position_y',
            'position_diag', 'center_dist' and 'random'.
        lums: Per-cell luminance values, returned as-is for 'luminance'
            and inverted (``1.0 - lums``) for 'inv_luminance'.
        num_rows: Number of character rows in the grid.
        num_cols: Number of character columns in the grid.

    Returns:
        The modulation values for the requested source. Any unrecognised
        source (including 'none') yields a ``(num_rows, num_cols)`` grid
        of zeros.
    """
    # Normalised cell coordinates; the max(..., 1) guard avoids dividing
    # by zero when the grid has a single row or column.
    grid_rows, grid_cols = jnp.mgrid[:num_rows, :num_cols]
    y_frac = grid_rows / max(num_rows - 1, 1)
    x_frac = grid_cols / max(num_cols - 1, 1)

    if source == 'luminance':
        return lums
    if source == 'inv_luminance':
        return 1.0 - lums
    if source == 'position_x':
        return x_frac
    if source == 'position_y':
        return y_frac
    if source == 'position_diag':
        return (y_frac + x_frac) / 2
    if source == 'center_dist':
        # Distance from the grid centre (0.5, 0.5), normalised by the
        # approximate centre-to-corner distance and clamped to [0, 1].
        centre_y, centre_x = 0.5, 0.5
        radial = jnp.sqrt((y_frac - centre_y)**2 + (x_frac - centre_x)**2)
        return jnp.clip(radial / 0.707, 0, 1)
    if source == 'random':
        # Each random draw consumes one slot of the environment's
        # operation counter so successive draws use different keys.
        seed = env.get('_seed', 42)
        draw_index = env.get('_rand_op_counter', 0)
        env['_rand_op_counter'] = draw_index + 1
        rng_key = make_jax_key(seed, frame_num, draw_index)
        return jax.random.uniform(rng_key, (num_rows, num_cols))
    return jnp.zeros((num_rows, num_cols))
# Get modulation values
jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols)
scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols)
rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols)
hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols)
# Gather characters - convert numpy atlas to JAX for traced indexing
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
# Create output
output_h = num_rows * char_size
output_w = num_cols * char_size
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
local_y = y_out % char_size
local_x = x_out % char_size
# Apply jitter if enabled
if char_jitter > 0:
jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter
# Use frame-varying key for jitter
seed = env.get('_seed', 42)
op_ctr = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_ctr + 1
key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2)
# Generate deterministic jitter per cell
jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1)
jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1)
offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
local_y = jnp.clip(local_y + offset_y, 0, char_size - 1)
local_x = jnp.clip(local_x + offset_x, 0, char_size - 1)
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
# Get foreground colors
if color_mode == 'mono':
fg = jnp.full_like(char_pixels, 255)
else:
fg = colors[cell_row, cell_col]
# Apply hue shift if enabled
if char_hue_shift > 0 and color_mode == 'color':
hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift
# Simple hue rotation via channel cycling
fg_float = fg.astype(jnp.float32)
shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB
# Simplified: blend channels based on shift
r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2]
shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac
# Just do a simple tint for now
fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8)
# Blend
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
if invert_colors:
result = (fg.astype(jnp.float32) * (1 - char_bright) +
bg_broadcast.astype(jnp.float32) * char_bright)
else:
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
fg.astype(jnp.float32) * char_bright)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize to original
if result.shape[0] != h or result.shape[1] != w:
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'alphabet-char':
# (alphabet-char alphabet-name index) -> char_index in that alphabet
alphabet = self._eval(args[0], env)
index = self._eval(args[1], env)
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
# Handle both scalar and array indices
if hasattr(index, 'shape'):
index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1)
else:
index = max(0, min(int(index), num_chars - 1))
return index
if op == 'map-char-grid':
# (map-char-grid base-chars luminances (lambda (r c ch lum) ...))
# Map over character grid, allowing per-cell character selection
base_chars = self._eval(args[0], env)
luminances = self._eval(args[1], env)
fn = args[2] # Lambda expression
num_rows, num_cols = base_chars.shape
# For JAX compatibility, we can't use Python loops with traced values
# Instead, we'll evaluate the lambda for the whole grid at once
if isinstance(fn, list) and len(fn) >= 3:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
# Create grid coordinates
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
# Bind parameters for whole-grid evaluation
fn_env = env.copy()
# Params: (r c ch lum)
if len(params) >= 1:
fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords
if len(params) >= 2:
fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords
if len(params) >= 3:
fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars
if len(params) >= 4:
# Luminances scaled to 0-255 range
fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32)
# Evaluate body - should return new character indices
result = self._eval(body, fn_env)
if hasattr(result, 'shape'):
return result.astype(jnp.int32)
return base_chars
return base_chars
# =====================================================================
# List operations
# =====================================================================
if op == 'take':
seq = self._eval(args[0], env)
n = int(self._eval(args[1], env))
if isinstance(seq, (list, tuple)):
return seq[:n]
return seq[:n] # Works for arrays too
if op == 'cons':
item = self._eval(args[0], env)
seq = self._eval(args[1], env)
if isinstance(seq, list):
return [item] + seq
elif isinstance(seq, tuple):
return (item,) + seq
return jnp.concatenate([jnp.array([item]), seq])
if op == 'roll':
arr = self._eval(args[0], env)
shift = self._eval(args[1], env)
axis = self._eval(args[2], env) if len(args) > 2 else 0
# Convert to int for concrete values, keep as-is for JAX traced values
if isinstance(shift, (int, float)):
shift = int(shift)
elif hasattr(shift, 'astype'):
shift = shift.astype(jnp.int32)
if isinstance(axis, (int, float)):
axis = int(axis)
return jnp.roll(arr, shift, axis=axis)
# =====================================================================
# Pi constant
# =====================================================================
if op == 'pi':
return jnp.pi
raise ValueError(f"Unknown operation: {op}")
# =============================================================================
# Public API
# =============================================================================
def compile_effect(code: str) -> Callable:
    """
    Compile S-expression effect source into a JAX function, with caching.

    Compiled functions are memoized in the module-level
    ``_COMPILED_EFFECTS`` table, keyed by the MD5 hex digest of the source
    text, so compiling the same source again returns the stored function.

    Args:
        code: S-expression effect code

    Returns:
        JIT-compiled function: (frame, **params) -> frame
    """
    digest = hashlib.md5(code.encode()).hexdigest()

    # Fast path: this exact source text has been compiled before.
    try:
        return _COMPILED_EFFECTS[digest]
    except KeyError:
        pass

    # Slow path: parse the source, compile it, and remember the result.
    tree = parse(code)
    compiled = JaxCompiler().compile_effect(tree)
    _COMPILED_EFFECTS[digest] = compiled
    return compiled
def compile_effect_file(path: str, derived_paths: Optional[List[str]] = None) -> Callable:
    """
    Compile an effect from a .sexp file.

    Top-level forms are handled as follows:

    * ``(require "name")`` - the named file (``.sexp`` is appended when
      missing) is looked up next to the effect file, then in the
      ``sexp_effects`` directory two levels above this module, and loaded
      into the compiler as derived definitions. A warning is emitted when
      neither location contains the file.
    * ``(require-primitives ...)`` - ignored.
    * ``(effect ...)`` / ``(define-effect ...)`` - remembered as the effect
      to compile; if several are present, the last one wins.

    Args:
        path: Path to the .sexp effect file
        derived_paths: Optional list of paths to derived.sexp files to load
            before the file's own ``require`` forms are processed

    Returns:
        JIT-compiled function: (frame, **params) -> frame

    Raises:
        ValueError: If the file contains no effect definition.
    """
    import warnings

    # Effect sources may contain non-ASCII symbols (the compiler accepts 'λ'
    # as a lambda alias), so decode explicitly instead of relying on the
    # platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    # Parse all expressions in file
    exprs = parse_all(code)

    compiler = JaxCompiler()

    # Load explicitly requested derived files first
    for dp in derived_paths or ():
        compiler.load_derived(dp)

    effect_sexp = None
    effect_dir = Path(path).parent
    fallback_dir = Path(__file__).parent.parent / 'sexp_effects'

    for expr in exprs:
        if not isinstance(expr, list) or len(expr) < 2:
            continue
        head = expr[0]
        if not isinstance(head, Symbol):
            continue

        if head.name == 'require':
            # (require "derived") or (require "path/to/file")
            req_path = expr[1]
            if not isinstance(req_path, str):
                continue
            if not req_path.endswith('.sexp'):
                req_path = req_path + '.sexp'
            # Resolve relative to the effect file, then the shared directory
            full_path = effect_dir / req_path
            if not full_path.exists():
                full_path = fallback_dir / req_path
            if full_path.exists():
                compiler.load_derived(str(full_path))
            else:
                # Previously dropped silently, which later surfaced as an
                # unrelated "Unknown operation" error at compile/run time.
                warnings.warn(
                    f"Required file {req_path!r} not found for effect {path}",
                    stacklevel=2,
                )
        elif head.name == 'require-primitives':
            # (require-primitives "lib") - currently ignored for JAX
            # JAX has all primitives built-in
            pass
        elif head.name in ('effect', 'define-effect'):
            effect_sexp = expr

    if effect_sexp is None:
        raise ValueError(f"No effect definition found in {path}")
    return compiler.compile_effect(effect_sexp)
def load_derived(derived_path: Optional[str] = None) -> Dict[str, Callable]:
    """
    Load derived operations from derived.sexp.

    Every top-level ``(define ...)`` form in the file is evaluated with a
    fresh ``JaxCompiler`` into a new environment dict; all other forms are
    ignored.

    Args:
        derived_path: Path to the file to load. Defaults to
            ``sexp_effects/derived.sexp`` two directory levels above this
            module.

    Returns:
        The environment dict populated by the evaluated definitions.
    """
    if derived_path is None:
        derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp'

    # Decode explicitly: definitions may contain non-ASCII symbols such as
    # 'λ', and the platform's default encoding is not guaranteed to be UTF-8.
    with open(derived_path, 'r', encoding='utf-8') as f:
        code = f.read()

    exprs = parse_all(code)
    compiler = JaxCompiler()
    env = {}
    for expr in exprs:
        if not isinstance(expr, list) or len(expr) < 3:
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            # _eval_define receives the form without its leading 'define'
            compiler._eval_define(expr[1:], env)
    return env
# =============================================================================
# Test / Demo
# =============================================================================
if __name__ == '__main__':
    # Smoke test / micro-benchmark: compile a threshold effect and time it.
    import time

    import numpy as np

    def _wait_for(value):
        """Block until an asynchronously dispatched result is materialised.

        JAX dispatches work asynchronously, so without this the timings
        below could measure dispatch rather than execution. Values without
        a ``block_until_ready`` method are returned unchanged.
        """
        ready = getattr(value, 'block_until_ready', None)
        return ready() if callable(ready) else value

    # Test effect
    test_effect = '''
(effect "threshold-test"
:params ((threshold :default 128))
:body (let* ((r (channel frame 0))
(g (channel frame 1))
(b (channel frame 2))
(gray (+ (* r 0.299) (* g 0.587) (* b 0.114)))
(mask (where (> gray threshold) 255 0)))
(merge-channels mask mask mask)))
'''
    print("Compiling effect...")
    run_effect = compile_effect(test_effect)

    # Create test frame
    print("Creating test frame...")
    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    # Run effect; perf_counter is monotonic and intended for interval timing
    print("Running effect (first run includes JIT compilation)...")
    t0 = time.perf_counter()
    result = _wait_for(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms")

    # Second run should be faster
    t0 = time.perf_counter()
    result = _wait_for(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"Second run (cached): {(t1-t0)*1000:.2f}ms")

    # Multiple runs; wait on each result so every run is fully counted
    t0 = time.perf_counter()
    for _ in range(100):
        result = _wait_for(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    # total_seconds * 1000 / 100 runs == total_seconds * 10 (ms per run)
    print(f"100 runs: {(t1-t0)*1000:.2f}ms total, {(t1-t0)*10:.2f}ms avg")

    print(f"\nResult shape: {result.shape}, dtype: {result.dtype}")
    print("Done!")