All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
4629 lines
178 KiB
Python
4629 lines
178 KiB
Python
"""
|
||
Sexp to JAX Compiler.
|
||
|
||
Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU.
|
||
Uses XLA compilation via @jax.jit for automatic kernel fusion.
|
||
|
||
Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles
|
||
S-expressions directly to JAX operations which XLA then optimizes.
|
||
|
||
Usage:
|
||
from streaming.sexp_to_jax import compile_effect
|
||
|
||
effect_code = '''
|
||
(effect "threshold"
|
||
:params ((threshold :default 128))
|
||
:body (let ((g (gray frame)))
|
||
(rgb (where (> g threshold) 255 0)
|
||
(where (> g threshold) 255 0)
|
||
(where (> g threshold) 255 0))))
|
||
'''
|
||
|
||
run_effect = compile_effect(effect_code)
|
||
output = run_effect(frame, threshold=128)
|
||
"""
|
||
|
||
import jax
|
||
import jax.numpy as jnp
|
||
from jax import lax
|
||
from functools import partial
|
||
from typing import Any, Dict, List, Callable, Optional, Tuple
|
||
import hashlib
|
||
import numpy as np
|
||
|
||
# Import parser
|
||
import sys
|
||
from pathlib import Path
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
|
||
|
||
# Import typography primitives
|
||
from streaming.jax_typography import bind_typography_primitives
|
||
|
||
|
||
# =============================================================================
|
||
# Compilation Cache
|
||
# =============================================================================
|
||
|
||
# Module-level cache mapping a string key to a compiled effect callable.
# NOTE(review): the key format is not established in this section; presumably
# a hash of the effect source (hashlib is imported) — confirm in compile_effect.
_COMPILED_EFFECTS: Dict[str, Callable] = {}
|
||
|
||
|
||
# =============================================================================
|
||
# Font Atlas for ASCII Effects
|
||
# =============================================================================
|
||
|
||
# Character sets for ASCII rendering
|
||
# Named character ramps for ASCII rendering. Each string is ordered from the
# darkest glyph (a leading space) to the brightest; _get_alphabet_string treats
# any name not listed here as a literal custom ramp.
ASCII_ALPHABETS = {
    'standard': ' .:-=+*#%@',
    'blocks': ' ░▒▓█',
    'simple': ' .:oO@',
    'digits': ' 0123456789',
    'binary': ' 01',
    'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$',
}

# Cache for font atlases: (alphabet, char_size, font_name) -> uint8 numpy atlas
# of shape (len(alphabet), char_size, char_size, 3), filled by _create_font_atlas.
_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {}
|
||
|
||
|
||
def _create_font_atlas(alphabet: str, char_size: int, font_name: Optional[str] = None) -> np.ndarray:
    """
    Create a font atlas with all characters pre-rendered.

    Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time.
    Successful PIL-rendered atlases are memoized in _FONT_ATLAS_CACHE keyed by
    (alphabet, char_size, font_name).

    Args:
        alphabet: String of characters to render (ordered by brightness, dark to light)
        char_size: Size of each square character cell in pixels
        font_name: Optional font name/path, tried before the built-in monospace
            candidates (uses a default monospace font if None)

    Returns:
        NumPy uint8 array of shape (num_chars, char_size, char_size, 3) with
        rendered characters. Each character is white on a black background with
        its visible ink centered in the cell. If Pillow is unavailable, or no
        font at all can be loaded, a textured block atlas from
        _create_block_atlas is returned instead (not cached). An empty alphabet
        yields an array of shape (0, char_size, char_size, 3).
    """
    cache_key = (alphabet, char_size, font_name)
    if cache_key in _FONT_ATLAS_CACHE:
        return _FONT_ATLAS_CACHE[cache_key]

    if not alphabet:
        # np.stack() cannot stack zero tiles; return a well-formed empty atlas.
        return np.zeros((0, char_size, char_size, 3), dtype=np.uint8)

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        # Fallback: create simple block-based atlas without PIL
        return _create_block_atlas(alphabet, char_size)

    font_size = int(char_size * 0.9)  # Slightly smaller than cell

    # Try the caller's font first, then common monospace fonts per platform.
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf',
        '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf',
        '/System/Library/Fonts/Menlo.ttc',  # macOS
        '/System/Library/Fonts/Monaco.dfont',  # macOS
        'C:\\Windows\\Fonts\\consola.ttf',  # Windows
    ]

    font = None
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue

    if font is None:
        try:
            font = ImageFont.load_default()
        except Exception:
            # Ultimate fallback to blocks
            return _create_block_atlas(alphabet, char_size)

    tiles = []
    for char in alphabet:
        img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0))
        draw = ImageDraw.Draw(img)

        try:
            left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
        except AttributeError:
            # Older Pillow without textbbox: only the size is known, so assume
            # the ink starts at the draw origin.
            left, top = 0, 0
            right, bottom = draw.textsize(char, font=font)

        # textbbox is measured relative to the draw origin and includes the
        # glyph's bearing / ascender offset. Subtracting (left, top) centers the
        # visible ink rather than the origin; without it, glyphs that sit near
        # the baseline (e.g. '.', ',', '_') are pushed below the cell and lost.
        x = (char_size - (right - left)) // 2 - left
        y = (char_size - (bottom - top)) // 2 - top

        # Draw white character on black background
        draw.text((x, y), char, fill=(255, 255, 255), font=font)

        # Convert to numpy array (NOT jax array - avoids tracer issues)
        tiles.append(np.array(img, dtype=np.uint8))

    atlas = np.stack(tiles, axis=0)
    _FONT_ATLAS_CACHE[cache_key] = atlas
    return atlas
|
||
|
||
|
||
def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray:
    """
    Create a simple block-based atlas without fonts.

    Fallback for when Pillow or a usable font is unavailable. Uses numpy to
    avoid tracer issues. Tile i is a flat grey of brightness
    int(255 * i / max(len(alphabet) - 1, 1)), so the first character is black
    and the last is white. Tiles with 0.2 < i / len(alphabet) < 0.8 also get a
    one-pixel checkerboard of +/- 20% of their brightness (clipped to 0..255).

    Returns:
        uint8 array of shape (len(alphabet), char_size, char_size, 3). An empty
        alphabet yields shape (0, char_size, char_size, 3).
    """
    num_chars = len(alphabet)
    if num_chars == 0:
        # np.stack() cannot stack zero tiles; return a well-formed empty atlas.
        return np.zeros((0, char_size, char_size, 3), dtype=np.uint8)

    # The checker mask depends only on char_size, so build it once.
    y_coords, x_coords = np.mgrid[:char_size, :char_size]
    checker = ((x_coords + y_coords) % 2 == 0)[:, :, None]

    atlas = []
    for i in range(num_chars):
        # Brightness proportional to position in alphabet
        brightness = int(255 * i / max(num_chars - 1, 1))
        img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8)

        # Checkerboard texture for mid-range characters
        if 0.2 < i / num_chars < 0.8:
            variation = int(brightness * 0.2)
            widened = img.astype(np.int16)  # avoid uint8 wrap-around
            img = np.where(checker,
                           np.clip(widened + variation, 0, 255).astype(np.uint8),
                           np.clip(widened - variation, 0, 255).astype(np.uint8))

        atlas.append(img)

    return np.stack(atlas, axis=0)
|
||
|
||
|
||
def _get_alphabet_string(alphabet_name: str) -> str:
    """
    Resolve an alphabet name to its character ramp.

    Known names are looked up in ASCII_ALPHABETS; anything else is treated as
    a literal, caller-supplied character string and returned unchanged.
    """
    return ASCII_ALPHABETS.get(alphabet_name, alphabet_name)
|
||
|
||
|
||
# =============================================================================
|
||
# Text Rendering with Font Atlas (JAX-compatible)
|
||
# =============================================================================
|
||
|
||
# Default character set for text rendering (printable ASCII)
|
||
# Default character set for text rendering: printable ASCII, from space (32)
# through '~' (126). Atlas indices follow this order.
TEXT_CHARSET = ''.join(chr(i) for i in range(32, 127))  # space to ~

# Cache for text font atlases, keyed by (font_name, font_size). Each value is
# the 7-tuple built by _create_text_atlas:
# (atlas, char_to_idx, char_widths, cell_height, baseline_offset, origin_x,
#  left_bearings).
_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {}
|
||
|
||
|
||
def _create_text_atlas(font_name: Optional[str] = None, font_size: int = 32) -> tuple:
    """
    Create a font atlas for general text rendering with proper baseline alignment.

    Every character of TEXT_CHARSET is drawn white on its own transparent RGBA
    tile, all tiles sharing one size. Results are memoized in _TEXT_ATLAS_CACHE
    keyed by (font_name, font_size).

    Font Metrics (from typography):
        - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l)
        - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y)
        - Baseline: the line text "sits" on - all characters align to this
        - Em-square: the design space, typically = ascender + descender

    Args:
        font_name: Optional font path/name tried before the built-in candidates;
            Pillow's default font is used if nothing loads.
        font_size: Font size in pixels.

    Returns:
        (atlas, char_to_idx, char_widths, cell_height, baseline_offset,
         origin_x, left_bearings)
        - atlas: uint8 numpy array (num_chars, cell_height, cell_width, 4) RGBA
        - char_to_idx: dict mapping character to atlas index
        - char_widths: int32 numpy array of the horizontal advance of each character
        - cell_height: height of every tile (ascent + descent + 2px padding)
        - baseline_offset: ascent + 1, the intended baseline row within a tile
        - origin_x: x position of the text origin inside every tile
        - left_bearings: int32 numpy array, x offset from the origin to each
          glyph's first ink column (negative if the glyph extends left of it)

    Raises:
        ImportError: if Pillow is not installed.
    """
    cache_key = (font_name, font_size)
    if cache_key in _TEXT_ATLAS_CACHE:
        return _TEXT_ATLAS_CACHE[cache_key]

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError as err:
        raise ImportError("PIL/Pillow required for text rendering") from err

    # Load font - match drawing.py's font order for consistency
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',  # Same order as drawing.py
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
        '/System/Library/Fonts/Helvetica.ttc',
        '/System/Library/Fonts/Arial.ttf',
        'C:\\Windows\\Fonts\\arial.ttf',
    ]

    font = None
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue

    if font is None:
        font = ImageFont.load_default()

    # getmetrics() returns (ascent, descent): pixels from the baseline to the
    # top of the tallest / bottom of the lowest character.
    ascent, descent = font.getmetrics()

    # Cell dimensions come from font metrics, not per-character bounding boxes.
    cell_height = ascent + descent + 2  # +2 for padding
    baseline_y = ascent + 1  # Baseline position within cell (1px padding from top)

    # Horizontal advance of every character (proper proportional spacing).
    char_widths_dict = {}
    for char in TEXT_CHARSET:
        try:
            char_widths_dict[char] = int(font.getlength(char))
        except Exception:
            # Best effort: fonts without getlength get a nominal half-em advance.
            char_widths_dict[char] = font_size // 2
    max_width = max(char_widths_dict.values())

    # Position to draw at within each tile. The left margin leaves room for
    # glyphs with a negative left bearing.
    draw_x = 5
    draw_y = 0  # Top of cell (PIL default without anchor)
    # NOTE(review): with PIL's default anchor the draw y is the ascender line,
    # so the glyph baseline lands on row `ascent`, one row above baseline_y.
    # Left unchanged because tiles are meant to match prim_text pixel-for-pixel;
    # confirm the expected alignment before adjusting either value.

    # The tile must hold the left margin plus the widest advance plus a matching
    # right margin. A width of only max_width + 2 clipped the right edge of the
    # widest glyphs (W, M, @), because they are drawn starting at draw_x.
    cell_width = draw_x + max_width + draw_x

    # Draw same way as prim_text for pixel-perfect match
    char_to_idx = {}
    char_widths = []  # Advance widths
    char_left_bearings = []  # x offset from origin to first ink pixel
    tiles = []

    for i, char in enumerate(TEXT_CHARSET):
        char_to_idx[char] = i
        char_widths.append(char_widths_dict[char])

        img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Draw at (draw_x, draw_y) with no anchor, like prim_text. This fixes
        # the text origin; glyph ink may extend left or right of it.
        draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)

        bbox = draw.textbbox((draw_x, draw_y), char, font=font)
        char_left_bearings.append(bbox[0] - draw_x)

        tiles.append(np.array(img, dtype=np.uint8))

    atlas = np.stack(tiles, axis=0)  # (num_chars, cell_height, cell_width, 4)
    char_widths = np.array(char_widths, dtype=np.int32)
    char_left_bearings = np.array(char_left_bearings, dtype=np.int32)

    # draw_x is returned so rendering knows where the origin sits in each tile.
    result = (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings)
    _TEXT_ATLAS_CACHE[cache_key] = result
    return result
|
||
|
||
|
||
def jax_text_render(frame, text: str, x: int, y: int,
                    font_name: str = None, font_size: int = 32,
                    color=(255, 255, 255), opacity: float = 1.0,
                    align: str = "left", valign: str = "baseline",
                    shadow: bool = False, shadow_color=(0, 0, 0),
                    shadow_offset: int = 2):
    """
    Render text onto frame using font atlas (JAX-compatible).

    This is designed to be called from within a JIT-compiled function.
    All layout work (atlas lookup, advance widths, alignment, strip assembly)
    happens at trace time with numpy/PIL. Only the finished RGBA text strip is
    converted to a JAX array for compositing. Because of this, text, x, y,
    font_name, font_size, align, valign, shadow and shadow_offset must be
    concrete Python values.

    Typography notes:
        - Baseline: The line text "sits" on. Most characters rest on this line.
        - Ascender: Top of tall letters (b, d, h, k, l) - above baseline
        - Descender: Bottom of letters like g, j, p, q, y - below baseline
        - For normal text, use valign="baseline" and y = the baseline position

    Args:
        frame: Input frame (H, W, 3)
        text: Text string to render. Characters missing from the atlas advance
            by the width of a space and draw nothing.
        x, y: Position reference point (affected by align/valign)
        font_name: Font to use (None = default)
        font_size: Font size in pixels
        color: RGB tuple (0-255)
        opacity: 0.0 to 1.0
        align: Horizontal alignment relative to x:
            "left" - text starts at x
            "center" - text centered on x
            "right" - text ends at x
        valign: Vertical alignment relative to y:
            "baseline" - text baseline at y (default, like normal text)
            "top" - top of the text cell at y
            "middle" - text vertically centered on y
            "bottom" - bottom of the text cell at y
        shadow: Whether to draw drop shadow
        shadow_color: Shadow RGB color
        shadow_offset: Shadow offset in pixels

    Returns:
        uint8 frame with text rendered. If text is empty or produces no visible
        pixels, frame is returned unchanged.
    """
    if not text:
        return frame

    # Get or create font atlas (this happens at trace time, uses numpy). The
    # atlas stays in numpy; only the assembled strip is sent to the device.
    (atlas_np, char_to_idx, char_widths_np, char_height,
     baseline_offset, origin_x, _left_bearings) = _create_text_atlas(font_name, font_size)
    cell_width = atlas_np.shape[2]

    # Cursor position of every character (text is static at trace time).
    space_idx = char_to_idx.get(' ', 0)
    char_x_offsets = [0]
    total_width = 0
    for char in text:
        total_width += int(char_widths_np[char_to_idx.get(char, space_idx)])
        char_x_offsets.append(total_width)

    # Actual text dimensions using proportional widths
    text_width = total_width
    text_height = char_height

    # Adjust position for horizontal alignment
    if align == "center":
        x = x - text_width // 2
    elif align == "right":
        x = x - text_width

    # Adjust position for vertical alignment.
    # baseline_offset = pixels from top of cell to baseline.
    if valign == "baseline":
        y = y - baseline_offset
    elif valign == "middle":
        y = y - text_height // 2
    elif valign == "bottom":
        y = y - text_height
    # valign == "top" needs no adjustment

    x, y = int(x), int(y)

    # Build the text strip at trace time. Each atlas tile stores its glyph with
    # the text origin at column origin_x. To put a character at cursor cx, the
    # tile is blitted at (strip_padding + cx - origin_x), so the origin lands
    # on strip column strip_padding + cx. The leading padding leaves room for
    # negative left bearings.
    strip_padding = origin_x
    text_strip_np = np.zeros((char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8)

    for i, char in enumerate(text):
        idx = char_to_idx.get(char)
        if idx is None:
            continue  # unknown character: spacing only, nothing drawn
        strip_x = strip_padding + char_x_offsets[i] - origin_x
        if strip_x < 0:
            continue
        end_x = min(strip_x + cell_width, text_strip_np.shape[1])
        tile_end = end_x - strip_x
        text_strip_np[:, strip_x:end_x] = np.maximum(
            text_strip_np[:, strip_x:end_x], atlas_np[idx][:, :tile_end])

    # Trim the strip:
    # - Left side: trim to first visible pixel (handles negative left bearing)
    # - Right side: use computed text_width (preserve advance width spacing)
    cols_with_content = np.any(text_strip_np[:, :, 3] > 0, axis=0)
    if not cols_with_content.any():
        # No visible content, return original frame
        return frame
    first_col = int(np.argmax(cols_with_content))
    right_col = strip_padding + text_width
    # Shift x by the amount trimmed from the left
    x = x + first_col - strip_padding
    text_strip_np = text_strip_np[:, first_col:right_col]

    # Convert to JAX
    text_strip = jnp.asarray(text_strip_np)
    color = jnp.array(color, dtype=jnp.float32)
    shadow_color = jnp.array(shadow_color, dtype=jnp.float32)

    # Apply color tint to text strip (white text * color)
    glyph_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0
    glyph_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0
    text_rgb = glyph_rgb * color
    text_alpha = glyph_alpha * opacity

    result = frame.astype(jnp.float32)

    # Draw shadow first if enabled (half the text opacity)
    if shadow:
        sx, sy = x + shadow_offset, y + shadow_offset
        shadow_rgb = glyph_rgb * shadow_color
        shadow_alpha = glyph_alpha * opacity * 0.5
        result = _composite_text_strip(result, shadow_rgb, shadow_alpha, sx, sy)

    # Draw main text
    result = _composite_text_strip(result, text_rgb, text_alpha, x, y)

    return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def _composite_text_strip(frame, text_rgb, text_alpha, x, y):
    """
    Composite text strip onto frame at position (x, y).

    Uses alpha blending: result = text * alpha + frame * (1 - alpha). The blend
    is applied only where the strip overlaps the frame; all other pixels are
    returned unchanged. The strip may lie partly or entirely outside the frame.

    x and y are converted with int(), so they must be concrete (static at trace
    time); the frame and strip arrays themselves may be traced.

    Args:
        frame: (H, W, 3) array to draw onto.
        text_rgb: (th, tw, 3) float strip colours.
        text_alpha: (th, tw) float strip coverage in [0, 1].
        x, y: Frame position of the strip's top-left pixel (may be negative).

    Returns:
        JAX array shaped like frame, promoted to at least float32.
    """
    h, w = frame.shape[:2]
    th, tw = text_rgb.shape[:2]
    x, y = int(x), int(y)

    base = jnp.asarray(frame)
    base = base.astype(jnp.promote_types(base.dtype, jnp.float32))

    # Overlap between the strip and the frame, in frame coordinates.
    y0, y1 = max(0, y), min(h, y + th)
    x0, x1 = max(0, x), min(w, x + tw)
    if y1 <= y0 or x1 <= x0:
        return base  # nothing visible

    # Same overlap in strip coordinates.
    sy0, sx0 = y0 - y, x0 - x
    sy1, sx1 = sy0 + (y1 - y0), sx0 + (x1 - x0)
    rgb = text_rgb[sy0:sy1, sx0:sx1]
    alpha = text_alpha[sy0:sy1, sx0:sx1][:, :, jnp.newaxis]

    region = base[y0:y1, x0:x1]
    return base.at[y0:y1, x0:x1].set(rgb * alpha + region * (1.0 - alpha))
|
||
|
||
|
||
def jax_text_size(text: str, font_name: str = None, font_size: int = 32) -> tuple:
    """
    Measure text dimensions (width, height).

    Width is the sum of the atlas advance widths. Characters not in the atlas
    are measured as a space. Height is the atlas cell height. Runs at compile
    time, so it can feed layout decisions.

    Returns:
        (width, height) tuple in pixels
    """
    atlas_info = _create_text_atlas(font_name, font_size)
    lookup, advances, line_height = atlas_info[1], atlas_info[2], atlas_info[3]

    fallback = lookup.get(' ', 0)
    width = sum(int(advances[lookup.get(ch, fallback)]) for ch in text)
    return (width, line_height)
|
||
|
||
|
||
def jax_font_metrics(font_name: str = None, font_size: int = 32) -> dict:
    """
    Get font metrics for layout calculations.

    Derived from the text atlas cell: the cell carries 1px of padding above
    the ascender and below the descender, and that padding is removed here.

    Typography terms:
        - ascent: pixels from baseline to top of tallest glyph (b, d, h, etc.)
        - descent: pixels from baseline to bottom of lowest glyph (g, j, p, etc.)
        - height: full cell height = ascent + descent + padding
        - baseline: position of baseline from top of text cell

    Returns:
        dict with keys: ascent, descent, height, baseline
    """
    atlas_info = _create_text_atlas(font_name, font_size)
    cell_height, baseline = atlas_info[3], atlas_info[4]

    return {
        'ascent': baseline - 1,
        'descent': cell_height - baseline - 1,
        'height': cell_height,
        'baseline': baseline,
    }
|
||
|
||
|
||
def jax_fit_text_size(text: str, max_width: int, max_height: int,
                      font_name: str = None, min_size: int = 8, max_size: int = 200) -> int:
    """
    Calculate the largest font size in [min_size, max_size] at which text fits
    within max_width x max_height.

    Binary search. It assumes text dimensions grow with font size. If nothing
    in the range fits, min_size is returned anyway.
    """
    lo, hi = min_size, max_size
    fitted = min_size

    while lo <= hi:
        probe = (lo + hi) // 2
        width, height = jax_text_size(text, font_name, probe)
        if width > max_width or height > max_height:
            hi = probe - 1
        else:
            fitted, lo = probe, probe + 1

    return fitted
|
||
|
||
|
||
# =============================================================================
|
||
# JAX Primitives - True primitives that can't be derived
|
||
# =============================================================================
|
||
|
||
def jax_width(frame):
    """Frame width: number of columns (axis 1 of an (H, W, ...) frame)."""
    dims = frame.shape
    return dims[1]
|
||
|
||
|
||
def jax_height(frame):
    """Frame height: number of rows (axis 0 of an (H, W, ...) frame)."""
    dims = frame.shape
    return dims[0]
|
||
|
||
|
||
def jax_channel(frame, idx):
    """Extract one colour channel of an (H, W, C) frame as a flat float32 array."""
    channel = int(idx)  # indexing needs a static Python int
    plane = frame[:, :, channel]
    return plane.reshape(-1).astype(jnp.float32)
|
||
|
||
|
||
def jax_merge_channels(r, g, b, shape):
    """
    Merge flat R, G, B channels back into an (h, w, 3) uint8 frame.

    Each channel is clipped to [0, 255] before being reshaped to `shape` and
    cast to uint8.
    """
    h, w = shape

    def _to_plane(channel):
        return jnp.clip(channel, 0, 255).reshape(h, w).astype(jnp.uint8)

    return jnp.stack([_to_plane(c) for c in (r, g, b)], axis=2)
|
||
|
||
|
||
def jax_iota(n):
    """Return the float32 sequence 0, 1, ..., n-1."""
    return jnp.arange(0, n, 1, dtype=jnp.float32)
|
||
|
||
|
||
def jax_repeat(x, n):
    """Repeat each element n times in place: [a, b] -> [a, a, b, b]."""
    return jnp.repeat(x, repeats=n)
|
||
|
||
|
||
def jax_tile(x, n):
    """Concatenate n copies of the whole array: [a, b] -> [a, b, a, b]."""
    return jnp.tile(x, reps=n)
|
||
|
||
|
||
def jax_gather(data, indices):
    """
    Parallel index lookup into the flattened data.

    Indices are truncated to int32 and clamped into [0, data.size - 1], so
    out-of-range reads return the first or last element.
    """
    flat = data.flatten()
    last = len(flat) - 1
    positions = jnp.clip(indices.astype(jnp.int32), 0, last)
    return flat[positions]
|
||
|
||
|
||
def jax_scatter(indices, values, size):
    """
    Parallel index write (last write wins).

    Writes values into a zero-initialised float32 vector of length `size`.
    Indices are truncated to int32 and clamped into [0, size - 1]. When several
    writes target the same slot, the one appearing last in `indices` is kept.

    JAX/XLA scatter does not define an order for conflicting updates, so a
    plain `.at[idx].set(values)` does not guarantee this. It is enforced
    explicitly: find the last writer per slot, then drop every other write.

    Args:
        indices: Array-like of target positions.
        values: Values to write; broadcast to the shape of indices.
        size: Static output length.

    Returns:
        float32 array of length size.
    """
    idx = jnp.clip(jnp.asarray(indices).astype(jnp.int32), 0, size - 1)
    vals = jnp.broadcast_to(jnp.asarray(values, dtype=jnp.float32), idx.shape).ravel()
    idx = idx.ravel()

    order = jnp.arange(idx.shape[0], dtype=jnp.int32)
    # Highest write position that targets each slot (-1 = untouched).
    last_writer = jnp.full(size, -1, dtype=jnp.int32).at[idx].max(order)
    is_last = last_writer[idx] == order
    # Route losing writes out of bounds so mode='drop' discards them.
    target = jnp.where(is_last, idx, size)
    return jnp.zeros(size, dtype=jnp.float32).at[target].set(vals, mode='drop')
|
||
|
||
|
||
def jax_scatter_add(indices, values, size):
    """
    Parallel index accumulate into a zero float32 vector of length `size`.

    Indices are truncated to int32 and clamped into [0, size - 1]; values that
    share an index are summed.
    """
    safe_idx = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    accumulator = jnp.zeros(size, dtype=jnp.float32)
    return accumulator.at[safe_idx].add(values)
|
||
|
||
|
||
def jax_group_reduce(values, group_indices, num_groups, op='mean'):
    """
    Reduce values by group.

    Args:
        values: Values to reduce, one per entry of group_indices.
        group_indices: Group id of each value (cast to int32).
        num_groups: Static number of output groups.
        op: One of 'sum', 'mean', 'max', 'min'.

    Returns:
        float32 array of length num_groups. Groups that receive no values
        are 0.

    Raises:
        ValueError: if op is not a supported reduction.
    """
    if op not in ('sum', 'mean', 'max', 'min'):
        raise ValueError(f"Unknown reduce op: {op}")

    grp = group_indices.astype(jnp.int32)
    zeros = jnp.zeros(num_groups, dtype=jnp.float32)

    if op == 'sum':
        return zeros.at[grp].add(values)

    # Membership counts identify empty groups explicitly. Testing the ±inf
    # sentinel instead would misreport a group whose real max/min is ±inf.
    counts = zeros.at[grp].add(1.0)
    occupied = counts > 0

    if op == 'mean':
        sums = zeros.at[grp].add(values)
        return jnp.where(occupied, sums / counts, 0.0)

    if op == 'max':
        result = jnp.full(num_groups, -jnp.inf, dtype=jnp.float32).at[grp].max(values)
    else:  # 'min'
        result = jnp.full(num_groups, jnp.inf, dtype=jnp.float32).at[grp].min(values)
    return jnp.where(occupied, result, 0.0)
|
||
|
||
|
||
def jax_where(cond, true_val, false_val):
    """Element-wise select: true_val where cond holds, false_val elsewhere."""
    selected = jnp.where(cond, true_val, false_val)
    return selected
|
||
|
||
|
||
def jax_cell_indices(frame, cell_size):
    """
    Compute, for each pixel, the index of the grid cell containing it.

    The frame is split into (H // cell_size) x (W // cell_size) cells, numbered
    row-major. Pixels in leftover bottom rows or right columns (when H or W is
    not a multiple of cell_size) belong to the last cell row or column, as in
    jax_pool_frame.

    Returns:
        float32 array of length H * W, in row-major pixel order.
    """
    h, w = frame.shape[:2]
    cell_size = int(cell_size)

    rows = h // cell_size
    cols = w // cell_size

    # Clamp each axis separately. Clamping only the flat index let pixels in the
    # leftover right-hand columns spill into the first cell of the next row.
    cell_row = jnp.clip(jnp.arange(h) // cell_size, 0, rows - 1)
    cell_col = jnp.clip(jnp.arange(w) // cell_size, 0, cols - 1)
    cell_idx = cell_row[:, None] * cols + cell_col[None, :]

    return cell_idx.reshape(-1).astype(jnp.float32)
|
||
|
||
|
||
def jax_pool_frame(frame, cell_size):
    """
    Pool frame to cell values by averaging each channel over square cells.

    The frame is divided into (H // cell_size) x (W // cell_size) cells.
    Leftover pixels at the bottom and right edges are folded into the last
    cell row and column.

    Returns tuple: (cell_r, cell_g, cell_b, cell_lum), each a float32 array
    with one entry per cell in row-major order. Luminance uses the weights
    0.299, 0.587, 0.114 on the pooled channel means.
    """
    height, width = frame.shape[:2]
    cs = int(cell_size)
    n_rows = height // cs
    n_cols = width // cs
    n_cells = n_rows * n_cols

    # Owning cell of each pixel, in row-major pixel order.
    pixel_rows = jnp.clip(jnp.repeat(jnp.arange(height), width) // cs, 0, n_rows - 1)
    pixel_cols = jnp.clip(jnp.tile(jnp.arange(width), height) // cs, 0, n_cols - 1)
    owner = (pixel_rows * n_cols + pixel_cols).astype(jnp.int32)

    # Pixel count per cell, shared by all three channels.
    population = jnp.zeros(n_cells, dtype=jnp.float32).at[owner].add(1.0)

    def _cell_mean(channel):
        samples = frame[:, :, channel].flatten().astype(jnp.float32)
        totals = jnp.zeros(n_cells, dtype=jnp.float32).at[owner].add(samples)
        return jnp.where(population > 0, totals / population, 0.0)

    red, green, blue = (_cell_mean(c) for c in range(3))
    luma = 0.299 * red + 0.587 * green + 0.114 * blue

    return (red, green, blue, luma)
|
||
|
||
|
||
# =============================================================================
|
||
# Scan (Prefix Operations) - JAX implementations
|
||
# =============================================================================
|
||
|
||
def jax_scan_add(x, axis=None):
    """
    Cumulative sum (prefix sum).

    With axis=None the input is flattened first. Otherwise the scan runs along
    int(axis) and the shape is kept.
    """
    if axis is None:
        return jnp.cumsum(x.flatten())
    return jnp.cumsum(x, axis=int(axis))
|
||
|
||
|
||
def jax_scan_mul(x, axis=None):
    """
    Cumulative product.

    With axis=None the input is flattened first. Otherwise the scan runs along
    int(axis) and the shape is kept.
    """
    if axis is None:
        return jnp.cumprod(x.flatten())
    return jnp.cumprod(x, axis=int(axis))
|
||
|
||
|
||
def jax_scan_max(x, axis=None):
    """
    Cumulative maximum (running max).

    With axis=None the input is flattened and scanned along axis 0. Otherwise
    the scan runs along int(axis).
    """
    if axis is None:
        return lax.cummax(x.flatten(), axis=0)
    return lax.cummax(x, axis=int(axis))
|
||
|
||
|
||
def jax_scan_min(x, axis=None):
    """
    Cumulative minimum (running min).

    With axis=None the input is flattened and scanned along axis 0. Otherwise
    the scan runs along int(axis).
    """
    if axis is None:
        return lax.cummin(x.flatten(), axis=0)
    return lax.cummin(x, axis=int(axis))
|
||
|
||
|
||
# =============================================================================
|
||
# Outer Product - JAX implementations
|
||
# =============================================================================
|
||
|
||
def jax_outer(x, y, op='*'):
    """
    Outer product with configurable operation.

    Both inputs are flattened. result[i, j] = op(x.flat[i], y.flat[j]), where op
    is one of '*', '+', '-', '/', 'max', 'min'. Any unrecognised op falls back
    to multiplication.
    """
    a = x.flatten()
    b = y.flatten()
    col, row = a[:, None], b[None, :]

    if op == '+':
        return col + row
    if op == '-':
        return col - row
    if op == '/':
        return col / row
    if op == 'max':
        return jnp.maximum(col, row)
    if op == 'min':
        return jnp.minimum(col, row)
    # '*' and any unknown op
    return jnp.outer(a, b)
|
||
|
||
|
||
def jax_outer_add(x, y):
    """Outer sum: result[i, j] = x.flat[i] + y.flat[j]."""
    return jax_outer(x, y, op='+')
|
||
|
||
|
||
def jax_outer_mul(x, y):
    """Outer product: result[i, j] = x.flat[i] * y.flat[j]."""
    return jax_outer(x, y, op='*')
|
||
|
||
|
||
def jax_outer_max(x, y):
    """Outer max: result[i, j] = max(x.flat[i], y.flat[j])."""
    return jax_outer(x, y, op='max')
|
||
|
||
|
||
def jax_outer_min(x, y):
    """Outer min: result[i, j] = min(x.flat[i], y.flat[j])."""
    return jax_outer(x, y, op='min')
|
||
|
||
|
||
# =============================================================================
|
||
# Reduce with Axis - JAX implementations
|
||
# =============================================================================
|
||
|
||
def jax_reduce_axis(x, op='sum', axis=0):
    """Reduce ``x`` along one axis with a named reduction.

    Supported ``op`` values: 'sum'/'+', 'mean', 'max', 'min', 'prod'/'*'
    and 'std'. Unknown names fall back to 'sum'.
    """
    axis = int(axis)
    reducer = {
        'sum': jnp.sum,
        '+': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
        'prod': jnp.prod,
        '*': jnp.prod,
        'std': jnp.std,
    }.get(op, jnp.sum)
    return reducer(x, axis=axis)
|
||
|
||
|
||
def jax_sum_axis(x, axis=0):
    """Sum ``x`` over the single axis ``axis`` (coerced to int)."""
    ax = int(axis)
    return jnp.sum(x, axis=ax)
|
||
|
||
|
||
def jax_mean_axis(x, axis=0):
    """Arithmetic mean of ``x`` over the single axis ``axis`` (coerced to int)."""
    ax = int(axis)
    return jnp.mean(x, axis=ax)
|
||
|
||
|
||
def jax_max_axis(x, axis=0):
    """Maximum of ``x`` over the single axis ``axis`` (coerced to int)."""
    ax = int(axis)
    return jnp.max(x, axis=ax)
|
||
|
||
|
||
def jax_min_axis(x, axis=0):
    """Minimum of ``x`` over the single axis ``axis`` (coerced to int)."""
    ax = int(axis)
    return jnp.min(x, axis=ax)
|
||
|
||
|
||
# =============================================================================
|
||
# Windowed Operations - JAX implementations
|
||
# =============================================================================
|
||
|
||
def jax_window(x, size, op='mean', stride=1):
    """
    Sliding-window reduction with 'valid' boundaries (no padding).

    For 1D arrays: windows of ``size`` consecutive elements, advanced by
    ``stride``.
    For arrays with ndim >= 2: square ``size x size`` windows slide over the
    first two axes. Any trailing axes (e.g. the channels of an (H, W, C)
    frame) are carried through and reduced independently per channel.

    Args:
        x: Input array.
        size: Window edge length (coerced to int).
        op: 'sum', 'mean', 'max' or 'min'; anything else falls back to 'mean'.
        stride: Step between window origins (coerced to int).

    Returns:
        1D input: array of length ``(n - size) // stride + 1``.
        nD input: array of shape ``(out_h, out_w) + x.shape[2:]``.
    """
    size = int(size)
    stride = int(stride)

    if x.ndim == 1:
        # 1D sliding window using convolution trick
        n = len(x)
        if op == 'sum':
            kernel = jnp.ones(size)
            return jnp.convolve(x, kernel, mode='valid')[::stride]
        elif op == 'mean':
            kernel = jnp.ones(size) / size
            return jnp.convolve(x, kernel, mode='valid')[::stride]
        else:
            # For max/min, materialise each window with a dynamic slice
            out_n = (n - size) // stride + 1
            indices = jnp.arange(out_n) * stride
            windows = jax.vmap(lambda i: lax.dynamic_slice(x, (i,), (size,)))(indices)
            if op == 'max':
                return jnp.max(windows, axis=1)
            elif op == 'min':
                return jnp.min(windows, axis=1)
            else:
                return jnp.mean(windows, axis=1)

    # 2D (or higher) sliding window over the first two axes
    h, w = x.shape[:2]
    trailing = tuple(x.shape[2:])
    # dynamic_slice needs one start index and one slice size per axis, so
    # trailing axes start at 0 and are taken whole.
    trailing_starts = (0,) * len(trailing)
    out_h = (h - size) // stride + 1
    out_w = (w - size) // stride + 1

    def extract_window(ij):
        i, j = ij // out_w, ij % out_w
        return lax.dynamic_slice(
            x,
            (i * stride, j * stride) + trailing_starts,
            (size, size) + trailing,
        )

    windows = jax.vmap(extract_window)(jnp.arange(out_h * out_w))

    reducer = {
        'sum': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
    }.get(op, jnp.mean)
    # Axes 1 and 2 of `windows` are the window rows/cols; axis 0 enumerates windows
    result = reducer(windows, axis=(1, 2))

    return result.reshape((out_h, out_w) + trailing)
|
||
|
||
|
||
def jax_window_sum(x, size, stride=1):
    """Sliding-window sum; thin alias of ``jax_window(..., op='sum')``."""
    return jax_window(x, size, op='sum', stride=stride)
|
||
|
||
|
||
def jax_window_mean(x, size, stride=1):
    """Sliding-window mean; thin alias of ``jax_window(..., op='mean')``."""
    return jax_window(x, size, op='mean', stride=stride)
|
||
|
||
|
||
def jax_window_max(x, size, stride=1):
    """Sliding-window maximum; thin alias of ``jax_window(..., op='max')``."""
    return jax_window(x, size, op='max', stride=stride)
|
||
|
||
|
||
def jax_window_min(x, size, stride=1):
    """Sliding-window minimum; thin alias of ``jax_window(..., op='min')``."""
    return jax_window(x, size, op='min', stride=stride)
|
||
|
||
|
||
def jax_integral_image(frame):
    """
    Summed-area table of a frame.

    A 3D frame is first reduced to one channel by averaging over its last
    axis. Entry [i, j] of the result holds the float32 sum of all pixels in
    rows 0..i and columns 0..j (inclusive), so any rectangular box sum can
    be read off with a constant number of lookups.
    """
    plane = frame.astype(jnp.float32)
    if frame.ndim == 3:
        plane = jnp.mean(plane, axis=2)
    row_prefix = jnp.cumsum(plane, axis=0)
    return jnp.cumsum(row_prefix, axis=1)
|
||
|
||
|
||
def jax_sample(frame, x, y):
    """Bilinearly sample channels 0-2 of ``frame`` at coordinates (x, y).

    Out-of-bounds neighbours contribute 0 (like OpenCV ``cv2.remap`` with
    INTER_LINEAR and BORDER_CONSTANT), and the bilinear blend includes those
    zeros.

    Returns:
        Tuple ``(r, g, b)`` of float32 arrays shaped like ``x``/``y``.
    """
    h, w = frame.shape[:2]

    left = jnp.floor(x).astype(jnp.int32)
    top = jnp.floor(y).astype(jnp.int32)
    right = left + 1
    bottom = top + 1

    fx = x - left.astype(jnp.float32)
    fy = y - top.astype(jnp.float32)

    # The four neighbours in the fixed order (x0,y0), (x1,y0), (x0,y1), (x1,y1).
    # Each gets an in-bounds mask plus clamped indices that are safe to
    # gather with (masked-out values are replaced by 0 below).
    corners = []
    for xi, yi in ((left, top), (right, top), (left, bottom), (right, bottom)):
        inside = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
        corners.append((inside, jnp.clip(yi, 0, h - 1), jnp.clip(xi, 0, w - 1)))

    x_weights = (1 - fx, fx, 1 - fx, fx)
    y_weights = (1 - fy, 1 - fy, fy, fy)

    def blend(channel):
        terms = []
        for (inside, ys, xs), wx, wy in zip(corners, x_weights, y_weights):
            tap = jnp.where(inside, frame[ys, xs, channel].astype(jnp.float32), 0.0)
            terms.append(tap * wx * wy)
        return terms[0] + terms[1] + terms[2] + terms[3]

    return blend(0), blend(1), blend(2)
|
||
|
||
|
||
# =============================================================================
|
||
# Convolution Operations
|
||
# =============================================================================
|
||
|
||
def jax_convolve2d(data, kernel):
    """2D filter of a single channel with zero ('SAME') padding.

    Note: like OpenCV's ``filter2D``, this computes cross-correlation (the
    kernel is not flipped), which is what ``lax.conv_general_dilated`` does.
    For symmetric kernels the distinction does not matter.

    Args:
        data: 2D array of shape (H, W).
        kernel: 2D array or nested sequence of shape (kH, kW).

    Returns:
        Array of shape (H, W). Integer inputs are promoted to float32;
        float inputs keep their (possibly wider) float dtype.
    """
    data = jnp.asarray(data)
    # conv_general_dilated requires both operands to share one dtype, and
    # integer convolutions are poorly supported on accelerators. Promote
    # everything to a common float type (float32 unless data is wider).
    dtype = jnp.promote_types(data.dtype, jnp.float32)
    data = data.astype(dtype)
    kernel = jnp.asarray(kernel, dtype=dtype)

    h, w = data.shape
    kh, kw = kernel.shape

    # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c)
    data_4d = data.reshape(1, h, w, 1)
    kernel_4d = kernel.reshape(kh, kw, 1, 1)

    result = lax.conv_general_dilated(
        data_4d, kernel_4d,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )

    return result.reshape(h, w)
|
||
|
||
|
||
def jax_blur(frame, radius=1):
    """Gaussian blur of channels 0-2 of ``frame``.

    The separable kernel has ``2 * int(radius) + 1`` taps per axis and is
    centred on its middle tap, with sigma = radius / 2. A radius below 1
    gives a single-tap identity kernel, so the frame passes through
    unchanged apart from the uint8 conversion.

    Args:
        frame: (H, W, C) image with C >= 3. ``radius`` must be concrete (it
            is converted with ``int``).

    Returns:
        (H, W, 3) uint8 image.
    """
    taps = int(radius)
    if taps < 1:
        # sigma would be 0 (0/0 -> NaN kernel); fall back to identity.
        kernel = jnp.ones((1, 1), dtype=jnp.float32)
    else:
        # Offsets must be relative to the integer centre tap; subtracting a
        # fractional radius here would shift the whole image.
        offsets = jnp.arange(2 * taps + 1, dtype=jnp.float32) - taps
        sigma = radius / 2
        gaussian_1d = jnp.exp(-offsets ** 2 / (2 * sigma ** 2))
        gaussian_1d = gaussian_1d / gaussian_1d.sum()
        kernel = jnp.outer(gaussian_1d, gaussian_1d)

    channels = [
        jnp.clip(jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
|
||
|
||
|
||
def jax_sharpen(frame, amount=1.0):
    """Sharpen channels 0-2 with a Laplacian kernel.

    The kernel is ``identity + amount * laplacian``::

        [[ 0,   -a,     0],
         [-a, 1 + 4a,  -a],
         [ 0,   -a,     0]]

    Its taps always sum to 1, so flat regions keep their brightness for any
    ``amount``. amount=0 is the identity, and amount=1 is the classic
    [[0,-1,0],[-1,5,-1],[0,-1,0]] kernel.

    Returns:
        (H, W, 3) uint8 image.
    """
    identity = jnp.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=jnp.float32)
    laplacian = jnp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=jnp.float32)
    # Apply `amount` exactly once. The previous formulation scaled an
    # amount-dependent kernel by amount again, so taps summed to (2a-1)^2
    # (black image at a=0.5, over-bright for a>1).
    kernel = identity + amount * laplacian

    channels = [
        jnp.clip(jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
|
||
|
||
|
||
def jax_edge_detect(frame):
    """Sobel gradient magnitude of the BT.601 luma, replicated to 3 channels.

    Returns:
        (H, W, 3) uint8 image; magnitudes above 255 are clipped.
    """
    kernel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    kernel_y = kernel_x.T

    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)

    grad_x = jax_convolve2d(luma, kernel_x)
    grad_y = jax_convolve2d(luma, kernel_y)

    magnitude = jnp.sqrt(grad_x**2 + grad_y**2)
    magnitude = jnp.clip(magnitude, 0, 255).astype(jnp.uint8)
    return jnp.repeat(magnitude[:, :, None], 3, axis=2)
|
||
|
||
|
||
def jax_emboss(frame):
    """Emboss the BT.601 luma around mid-grey (128), replicated to 3 channels.

    Returns:
        (H, W, 3) uint8 image.
    """
    relief = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32)

    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)

    shaded = jnp.clip(jax_convolve2d(luma, relief) + 128, 0, 255).astype(jnp.uint8)
    return jnp.repeat(shaded[:, :, None], 3, axis=2)
|
||
|
||
|
||
# =============================================================================
|
||
# Color Space Conversion
|
||
# =============================================================================
|
||
|
||
def jax_rgb_to_hsv(r, g, b):
    """Convert RGB to HSV.

    Args:
        r, g, b: Arrays in the 0-255 range.

    Returns:
        Tuple ``(h, s, v)``. h is in degrees [0, 360); s and v are in 0-255.
        Grey pixels (including black) get h = 0, and black gets s = 0.
    """
    r, g, b = r / 255.0, g / 255.0, b / 255.0

    max_c = jnp.maximum(jnp.maximum(r, g), b)
    min_c = jnp.minimum(jnp.minimum(r, g), b)
    diff = max_c - min_c

    # Value
    v = max_c

    # Safe denominators: jnp.where evaluates both branches, and a 0/0 in the
    # discarded branch is harmless for the forward value but turns gradients
    # into NaN. Substituting 1 where the branch is unused keeps both finite.
    safe_max = jnp.where(max_c > 0, max_c, 1.0)
    safe_diff = jnp.where(diff == 0, 1.0, diff)

    # Saturation
    s = jnp.where(max_c > 0, diff / safe_max, 0.0)

    # Hue
    h = jnp.where(diff == 0, 0.0,
                  jnp.where(max_c == r, (60 * ((g - b) / safe_diff) + 360) % 360,
                            jnp.where(max_c == g, 60 * ((b - r) / safe_diff) + 120,
                                      60 * ((r - g) / safe_diff) + 240)))

    return h, s * 255, v * 255
|
||
|
||
|
||
def jax_hsv_to_rgb(h, s, v):
    """Convert HSV back to RGB.

    Args:
        h: Hue in degrees (array); wrapped into [0, 360).
        s, v: Saturation and value in the 0-255 range.

    Returns:
        Tuple ``(r, g, b)`` of arrays in the 0-255 range.
    """
    hue = h % 360
    sat = s / 255.0
    val = v / 255.0

    chroma = val * sat
    second = chroma * (1 - jnp.abs((hue / 60) % 2 - 1))
    offset = val - chroma

    sector = (hue / 60).astype(jnp.int32) % 6
    # Conditions for sectors 0..4; sector 5 is handled by each `default`.
    in_sector = [sector == k for k in range(5)]

    red = jnp.select(in_sector, [chroma, second, 0, 0, second], default=chroma)
    green = jnp.select(in_sector, [second, chroma, chroma, second, 0], default=0)
    blue = jnp.select(in_sector, [0, 0, second, chroma, chroma], default=second)

    return (red + offset) * 255, (green + offset) * 255, (blue + offset) * 255
|
||
|
||
|
||
def jax_adjust_saturation(frame, factor):
    """Scale HSV saturation by ``factor`` (1.0 leaves the image unchanged).

    Works on channels 0-2 and returns an (H, W, 3) uint8 image.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (frame[:, :, c].flatten().astype(jnp.float32) for c in range(3))

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    sat = jnp.clip(sat * factor, 0, 255)
    out = jax_hsv_to_rgb(hue, sat, val)

    return jnp.stack(
        [jnp.clip(ch, 0, 255).reshape(rows, cols).astype(jnp.uint8) for ch in out],
        axis=2,
    )
|
||
|
||
|
||
def jax_shift_hue(frame, degrees):
    """Rotate HSV hue by ``degrees`` (wrapping at 360).

    Works on channels 0-2 and returns an (H, W, 3) uint8 image.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (frame[:, :, c].flatten().astype(jnp.float32) for c in range(3))

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    hue = (hue + degrees) % 360
    out = jax_hsv_to_rgb(hue, sat, val)

    return jnp.stack(
        [jnp.clip(ch, 0, 255).reshape(rows, cols).astype(jnp.uint8) for ch in out],
        axis=2,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Color Adjustment Operations
|
||
# =============================================================================
|
||
|
||
def jax_adjust_brightness(frame, amount):
    """Add ``amount`` (typically -255..255) to every sample, saturating to uint8."""
    shifted = jnp.add(frame.astype(jnp.float32), amount)
    return jnp.clip(shifted, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_adjust_contrast(frame, factor):
    """Scale distance from mid-grey (128) by ``factor``; 1.0 leaves it unchanged."""
    centred = frame.astype(jnp.float32) - 128
    stretched = centred * factor + 128
    return jnp.clip(stretched, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_invert(frame):
    """Photographic negative: each sample becomes 255 minus itself."""
    negative = 255 - frame
    return negative
|
||
|
||
|
||
def jax_posterize(frame, levels):
    """Quantise each channel to ``levels`` evenly spaced values (at least 2)."""
    levels = max(int(levels), 2)
    step = 255.0 / (levels - 1)
    snapped = jnp.round(frame.astype(jnp.float32) / step) * step
    return jnp.clip(snapped, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_threshold(frame, level, invert=False):
    """Binarise BT.601 luma at ``level`` into a 3-channel uint8 image.

    Pixels with luma >= level become 255 (or luma < level when ``invert``);
    all others become 0.
    """
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)

    lit = (luma < level) if invert else (luma >= level)
    binary = jnp.where(lit, 255, 0).astype(jnp.uint8)
    return jnp.stack([binary] * 3, axis=2)
|
||
|
||
|
||
def jax_sepia(frame):
    """Apply the classic sepia colour matrix to channels 0-2 (uint8 output)."""
    red = frame[:, :, 0].astype(jnp.float32)
    green = frame[:, :, 1].astype(jnp.float32)
    blue = frame[:, :, 2].astype(jnp.float32)

    # One row of (r, g, b) weights per output channel.
    matrix = (
        (0.393, 0.769, 0.189),
        (0.349, 0.686, 0.168),
        (0.272, 0.534, 0.131),
    )
    toned = [
        jnp.clip(red * wr + green * wg + blue * wb, 0, 255).astype(jnp.uint8)
        for wr, wg, wb in matrix
    ]
    return jnp.stack(toned, axis=2)
|
||
|
||
|
||
def jax_grayscale(frame):
    """BT.601 luma (truncated to uint8), replicated into 3 channels."""
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    return jnp.stack([luma.astype(jnp.uint8)] * 3, axis=2)
|
||
|
||
|
||
# =============================================================================
|
||
# Geometry Operations
|
||
# =============================================================================
|
||
|
||
def jax_flip_horizontal(frame):
    """Mirror left-right by reversing axis 1.

    Works for single-channel (H, W) as well as (H, W, C) frames. The former
    ``[:, ::-1, :]`` index required exactly three axes.
    """
    return frame[:, ::-1]
|
||
|
||
|
||
def jax_flip_vertical(frame):
    """Mirror top-bottom by reversing axis 0.

    Works for single-channel (H, W) as well as (H, W, C) frames. The former
    ``[::-1, :, :]`` index required exactly three axes.
    """
    return frame[::-1]
|
||
|
||
|
||
def jax_rotate(frame, angle, center_x=None, center_y=None):
    """Rotate channels 0-2 of ``frame`` by ``angle`` degrees (OpenCV convention).

    Positive angles rotate counter-clockwise about (center_x, center_y),
    which defaults to (w/2, h/2). Pixels sampled from outside the source
    become black.

    Returns:
        (H, W, 3) uint8 image.
    """
    h, w = frame.shape[:2]
    cx = w / 2 if center_x is None else center_x
    cy = h / 2 if center_y is None else center_y

    theta = angle * jnp.pi / 180
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)

    # Destination pixel grid: row indices and column indices, each (h, w).
    ys, xs = jnp.indices((h, w))

    # cv2.getRotationMatrix2D gives the forward map [[cos, sin], [-sin, cos]].
    # Sampling needs the inverse, [[cos, -sin], [sin, cos]], applied about
    # the centre.
    dx = xs - cx
    dy = ys - cy
    src_x = cos_t * dx - sin_t * dy + cx
    src_y = sin_t * dx + cos_t * dy + cy

    # jax_sample yields 0 for out-of-bounds taps (BORDER_CONSTANT)
    sampled = jax_sample(frame, src_x.ravel(), src_y.ravel())

    return jnp.stack(
        [jnp.clip(ch, 0, 255).reshape(h, w).astype(jnp.uint8) for ch in sampled],
        axis=2,
    )
|
||
|
||
|
||
def jax_scale(frame, scale_x, scale_y=None):
    """Zoom channels 0-2 of ``frame`` about its centre.

    ``scale_y`` defaults to ``scale_x``. Destination pixels whose source lies
    outside the frame become black, as with OpenCV's BORDER_CONSTANT.

    Returns:
        (H, W, 3) uint8 image.
    """
    sy = scale_x if scale_y is None else scale_y

    h, w = frame.shape[:2]
    cx, cy = w / 2, h / 2

    ys, xs = jnp.indices((h, w))

    # Inverse mapping: for each destination pixel, find its source point
    src_x = (xs - cx) / scale_x + cx
    src_y = (ys - cy) / sy + cy

    sampled = jax_sample(frame, src_x.ravel(), src_y.ravel())

    return jnp.stack(
        [jnp.clip(ch, 0, 255).reshape(h, w).astype(jnp.uint8) for ch in sampled],
        axis=2,
    )
|
||
|
||
|
||
def jax_resize(frame, new_width, new_height):
    """Resize channels 0-2 of ``frame`` to (new_height, new_width).

    Uses corner-aligned bilinear mapping: output index k along an axis maps
    to source coordinate ``k * (src_len - 1) / (dst_len - 1)``, so the first
    and last output pixels land exactly on the source corners. For a target
    length of 1 the single output pixel samples source index 0; this used to
    divide by zero and produce a black image.

    Raises:
        ValueError: if a target dimension is smaller than 1.

    Returns:
        (new_height, new_width, 3) uint8 image.
    """
    h, w = frame.shape[:2]
    new_h, new_w = int(new_height), int(new_width)
    if new_h < 1 or new_w < 1:
        raise ValueError(f"resize target must be at least 1x1, got {new_w}x{new_h}")

    # Row-major enumeration of output pixels
    y_coords = jnp.repeat(jnp.arange(new_h), new_w)
    x_coords = jnp.tile(jnp.arange(new_w), new_h)

    # Multiply before dividing so the last sample lands exactly on w-1 / h-1
    # (a pre-computed ratio could overshoot by one ulp and pull in a zero tap).
    if new_w > 1:
        src_x = x_coords * (w - 1) / (new_w - 1)
    else:
        src_x = jnp.zeros(x_coords.shape, dtype=jnp.float32)
    if new_h > 1:
        src_y = y_coords * (h - 1) / (new_h - 1)
    else:
        src_y = jnp.zeros(y_coords.shape, dtype=jnp.float32)

    sampled = jax_sample(frame, src_x, src_y)

    return jnp.stack(
        [jnp.clip(ch, 0, 255).reshape(new_h, new_w).astype(jnp.uint8) for ch in sampled],
        axis=2,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Blending Operations
|
||
# =============================================================================
|
||
|
||
def _resize_to_match(frame1, frame2):
|
||
"""Resize frame2 to match frame1's dimensions if they differ.
|
||
|
||
Uses jax.image.resize for bilinear interpolation.
|
||
Returns frame2 resized to frame1's shape.
|
||
"""
|
||
h1, w1 = frame1.shape[:2]
|
||
h2, w2 = frame2.shape[:2]
|
||
|
||
# If same size, return as-is
|
||
if h1 == h2 and w1 == w2:
|
||
return frame2
|
||
|
||
# Resize frame2 to match frame1
|
||
# jax.image.resize expects (height, width, channels) and target shape
|
||
return jax.image.resize(
|
||
frame2.astype(jnp.float32),
|
||
(h1, w1, frame2.shape[2]),
|
||
method='bilinear'
|
||
).astype(jnp.uint8)
|
||
|
||
|
||
def jax_blend(frame1, frame2, alpha):
    """Linear cross-fade: alpha=0 gives frame1, alpha=1 gives frame2.

    frame2 is resized to frame1's size if they differ. The result is clipped
    to [0, 255] before the uint8 cast, matching the other blend modes, so an
    alpha outside [0, 1] saturates instead of wrapping around.
    """
    frame2 = _resize_to_match(frame1, frame2)
    mixed = (frame1.astype(jnp.float32) * (1 - alpha) +
             frame2.astype(jnp.float32) * alpha)
    return jnp.clip(mixed, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_blend_add(frame1, frame2):
    """Additive blend (saturating at 255); frame2 is resized to frame1 if needed."""
    other = _resize_to_match(frame1, frame2)
    total = jnp.add(frame1.astype(jnp.float32), other.astype(jnp.float32))
    return jnp.clip(total, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_blend_multiply(frame1, frame2):
    """Multiply blend (a * b / 255); frame2 is resized to frame1 if needed."""
    other = _resize_to_match(frame1, frame2)
    base = frame1.astype(jnp.float32)
    top = other.astype(jnp.float32)
    product = base * top / 255
    return jnp.clip(product, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_blend_screen(frame1, frame2):
    """Screen blend, 1 - (1 - a)(1 - b) on normalised values.

    frame2 is resized to frame1 if needed.
    """
    other = _resize_to_match(frame1, frame2)
    a = frame1.astype(jnp.float32) / 255
    b = other.astype(jnp.float32) / 255
    screened = 1 - (1 - a) * (1 - b)
    return jnp.clip(screened * 255, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
def jax_blend_overlay(frame1, frame2):
    """Overlay blend: multiply where frame1 is dark, screen where it is light.

    frame2 is resized to frame1 if needed.
    """
    other = _resize_to_match(frame1, frame2)
    a = frame1.astype(jnp.float32) / 255
    b = other.astype(jnp.float32) / 255
    dark = 2 * a * b
    light = 1 - 2 * (1 - a) * (1 - b)
    return jnp.clip(jnp.where(a < 0.5, dark, light) * 255, 0, 255).astype(jnp.uint8)
|
||
|
||
|
||
# =============================================================================
|
||
# Utility
|
||
# =============================================================================
|
||
|
||
def make_jax_key(seed: int = 42, frame_num=0, op_id: int = 0):
    """Build a PRNG key that is unique per (seed, op_id, frame_num).

    ``seed`` and ``op_id`` are combined on the host into the base key, so both
    must be concrete Python ints. ``frame_num`` is mixed in with
    ``jax.random.fold_in``, which accepts traced values, so a jitted effect
    does not need recompiling for every frame.

    Args:
        seed: Base seed for determinism (must be concrete).
        frame_num: Frame number for variation (may be traced).
        op_id: Operation ID for variation (must be concrete).

    Returns:
        A JAX PRNG key.
    """
    # op_id is spread out with a large multiplier so neighbouring op ids
    # yield distant base seeds.
    combined_seed = seed + op_id * 1000003
    return jax.random.fold_in(jax.random.PRNGKey(combined_seed), frame_num)
|
||
|
||
|
||
def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42):
    """Uniform random float in [lo, hi), deterministic per (seed, op_id, frame)."""
    key = make_jax_key(seed, frame_num, op_id)
    unit = jax.random.uniform(key)
    return lo + unit * (hi - lo)
|
||
|
||
|
||
def jax_is_nil(x):
    """True only for the sexp ``nil`` value (Python None); 0, [] and '' are not nil."""
    is_none = x is None
    return is_none
|
||
|
||
|
||
# =============================================================================
|
||
# S-expression to JAX Compiler
|
||
# =============================================================================
|
||
|
||
class JaxCompiler:
|
||
"""Compiles S-expressions to JAX functions."""
|
||
|
||
def __init__(self):
|
||
self.env = {} # Variable bindings during compilation
|
||
self.params = {} # Effect parameters
|
||
self.primitives = {} # Loaded primitive libraries
|
||
self.derived = {} # Loaded derived functions
|
||
|
||
def load_derived(self, path: str):
    """Load derived operations from a .sexp file into ``self.derived``.

    Every top-level ``(define ...)`` form with at least a name and a value
    is evaluated via ``_eval_define`` into ``self.derived``; other forms are
    ignored.

    The file is decoded as UTF-8 explicitly. Sexp sources may contain
    non-ASCII symbols such as ``λ`` (which ``_eval`` accepts as ``lambda``),
    and the platform's default encoding is not guaranteed to be UTF-8.
    """
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    for expr in parse_all(code):
        if not (isinstance(expr, list) and len(expr) >= 3):
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            self._eval_define(expr[1:], self.derived)
|
||
|
||
def compile_effect(self, sexp) -> Callable:
    """
    Compile an effect S-expression to a JAX function.

    Supports both formats:
        (effect "name" :params (...) :body ...)
        (define-effect name :params (...) body)

    Keyword/value pairs other than :params and :body are skipped. The first
    non-keyword item is taken as the body if no :body keyword was seen.

    Args:
        sexp: Parsed S-expression

    Returns:
        JIT-compiled function: (frame, **params) -> frame. Parameters that
        _parse_params marks as static (string/bool), plus 'seed', are passed
        to jax.jit as static_argnames, so their values must be hashable.

    Raises:
        ValueError: if sexp is not a list of length >= 2, does not start with
            a symbol, is neither 'effect' nor 'define-effect', or has no body.
    """
    if not isinstance(sexp, list) or len(sexp) < 2:
        raise ValueError("Effect must be a list")

    head = sexp[0]
    if not isinstance(head, Symbol):
        raise ValueError("Effect must start with a symbol")

    form = head.name

    # Handle both 'effect' and 'define-effect' formats
    if form == 'effect':
        # (effect "name" :params (...) :body ...)
        # NOTE(review): len(sexp) >= 2 is already guaranteed above, so the
        # "unnamed" fallback cannot trigger.
        name = sexp[1] if len(sexp) > 1 else "unnamed"
        if isinstance(name, Symbol):
            name = name.name
        start_idx = 2
    elif form == 'define-effect':
        # (define-effect name :params (...) body)
        name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1])
        start_idx = 2
    else:
        raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'")

    params_spec = []
    body = None

    # Scan the remaining items: keywords consume themselves plus one value
    # (except :params/:body at the very end); other items are body candidates.
    i = start_idx
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword):
            if item.name == 'params' and i + 1 < len(sexp):
                params_spec = sexp[i + 1]
                i += 2
            elif item.name == 'body' and i + 1 < len(sexp):
                body = sexp[i + 1]
                i += 2
            elif item.name in ('desc', 'type', 'range'):
                # Skip metadata keywords
                i += 2
            else:
                i += 2  # Skip unknown keywords with their values
        else:
            # Assume it's the body if we haven't seen one
            if body is None:
                body = item
            i += 1

    if body is None:
        raise ValueError(f"Effect '{name}' must have a body")

    # Extract parameter names, defaults, and static params (strings, bools)
    param_info, static_params = self._parse_params(params_spec)

    # Capture derived functions for the closure. This is a shallow snapshot:
    # definitions loaded after this call are not visible to this effect.
    derived_fns = self.derived.copy()

    # Create the JAX function. It closes over `self`, `body` and `param_info`,
    # and builds a fresh evaluation environment on every call.
    def effect_fn(frame, **kwargs):
        # Set up environment
        h, w = frame.shape[:2]
        # Get frame_num for deterministic random variation
        frame_num = kwargs.get('frame_num', 0)
        # Get seed from recipe config (passed via kwargs)
        seed = kwargs.get('seed', 42)
        env = {
            'frame': frame,
            'width': w,
            'height': h,
            '_shape': (h, w),
            # Time variables (default to 0, can be overridden via kwargs).
            # 't', '_time' and 'time' fall back to one another.
            't': kwargs.get('t', kwargs.get('_time', 0.0)),
            '_time': kwargs.get('_time', kwargs.get('t', 0.0)),
            'time': kwargs.get('time', kwargs.get('t', 0.0)),
            # Frame number for random key generation (several spellings)
            'frame_num': frame_num,
            'frame-num': frame_num,
            '_frame_num': frame_num,
            # Seed from recipe for deterministic random
            '_seed': seed,
            # Counter for unique random keys within same frame
            # (presumably advanced by random primitives elsewhere; not
            # touched in this method)
            '_rand_op_counter': 0,
            # Common constants
            'pi': jnp.pi,
            'PI': jnp.pi,
        }

        # Add derived functions (these can shadow the built-in names above)
        env.update(derived_fns)

        # Add typography primitives
        bind_typography_primitives(env)

        # Add parameters with defaults. Explicit kwargs win; list defaults
        # are unevaluated sexps evaluated against the env built so far
        # (so they may refer to earlier parameters).
        for pname, pdefault in param_info.items():
            if pname in kwargs:
                env[pname] = kwargs[pname]
            elif isinstance(pdefault, list):
                # Unevaluated S-expression default - evaluate it
                env[pname] = self._eval(pdefault, env)
            else:
                env[pname] = pdefault

        # Evaluate body
        result = self._eval(body, env)

        # Ensure result is a frame
        if isinstance(result, tuple) and len(result) == 3:
            # RGB tuple - merge to frame
            r, g, b = result
            return jax_merge_channels(r, g, b, (h, w))
        # NOTE(review): a result without `.ndim` (e.g. a Python scalar or
        # None) raises AttributeError here.
        elif result.ndim == 3:
            return result
        else:
            # Single channel - replicate to RGB
            h, w = env['_shape']
            gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8)
            return jnp.stack([gray, gray, gray], axis=2)

    # JIT compile with static args for string/bool parameters and seed
    # seed must be static for PRNGKey, but frame_num can be traced via fold_in
    all_static = set(static_params) | {'seed'}
    return jax.jit(effect_fn, static_argnames=list(all_static))
|
||
|
||
def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]:
|
||
"""Parse parameter specifications.
|
||
|
||
Returns:
|
||
Tuple of (param_defaults, static_params)
|
||
- param_defaults: Dict mapping param names to default values
|
||
- static_params: Set of param names that should be static (strings, bools)
|
||
"""
|
||
result = {}
|
||
static_params = set()
|
||
if not isinstance(params_spec, list):
|
||
return result, static_params
|
||
|
||
for param in params_spec:
|
||
if isinstance(param, Symbol):
|
||
result[param.name] = 0.0
|
||
elif isinstance(param, list) and len(param) >= 1:
|
||
pname = param[0].name if isinstance(param[0], Symbol) else str(param[0])
|
||
pdefault = 0.0
|
||
ptype = None
|
||
|
||
# Look for :default and :type keywords
|
||
i = 1
|
||
while i < len(param):
|
||
if isinstance(param[i], Keyword):
|
||
kw = param[i].name
|
||
if kw == 'default' and i + 1 < len(param):
|
||
pdefault = param[i + 1]
|
||
if isinstance(pdefault, Symbol):
|
||
if pdefault.name == 'nil':
|
||
pdefault = None
|
||
elif pdefault.name == 'true':
|
||
pdefault = True
|
||
elif pdefault.name == 'false':
|
||
pdefault = False
|
||
i += 2
|
||
elif kw == 'type' and i + 1 < len(param):
|
||
ptype = param[i + 1]
|
||
if isinstance(ptype, Symbol):
|
||
ptype = ptype.name
|
||
i += 2
|
||
else:
|
||
i += 1
|
||
else:
|
||
i += 1
|
||
|
||
result[pname] = pdefault
|
||
|
||
# Mark string and bool parameters as static (can't be traced by JAX)
|
||
if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)):
|
||
static_params.add(pname)
|
||
|
||
return result, static_params
|
||
|
||
def _eval(self, expr, env: Dict[str, Any]) -> Any:
|
||
"""Evaluate an S-expression in the given environment."""
|
||
|
||
# Already-evaluated values (e.g., from threading macros)
|
||
# JAX arrays, NumPy arrays, tuples, etc.
|
||
if hasattr(expr, 'shape'): # JAX/NumPy array
|
||
return expr
|
||
if isinstance(expr, tuple): # e.g., (r, g, b) from rgb
|
||
return expr
|
||
|
||
# Literals - keep as Python numbers for static operations
|
||
if isinstance(expr, (int, float)):
|
||
return expr
|
||
|
||
if isinstance(expr, str):
|
||
return expr
|
||
|
||
# Symbols - variable lookup
|
||
if isinstance(expr, Symbol):
|
||
name = expr.name
|
||
if name in env:
|
||
return env[name]
|
||
if name == 'nil':
|
||
return None
|
||
if name == 'true':
|
||
return True
|
||
if name == 'false':
|
||
return False
|
||
raise NameError(f"Unknown symbol: {name}")
|
||
|
||
# Lists - function calls
|
||
if isinstance(expr, list) and len(expr) > 0:
|
||
head = expr[0]
|
||
|
||
if isinstance(head, Symbol):
|
||
op = head.name
|
||
args = expr[1:]
|
||
|
||
# Special forms
|
||
if op == 'let' or op == 'let*':
|
||
return self._eval_let(args, env)
|
||
if op == 'if':
|
||
return self._eval_if(args, env)
|
||
if op == 'lambda' or op == 'λ':
|
||
return self._eval_lambda(args, env)
|
||
if op == 'define':
|
||
return self._eval_define(args, env)
|
||
|
||
# Built-in operations
|
||
return self._eval_call(op, args, env)
|
||
|
||
# Empty list
|
||
if isinstance(expr, list) and len(expr) == 0:
|
||
return []
|
||
|
||
raise ValueError(f"Cannot evaluate: {expr}")
|
||
|
||
def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]):
|
||
"""Extract a keyword argument from args list.
|
||
|
||
Looks for :key value pattern in args and evaluates the value.
|
||
Returns default if not found.
|
||
"""
|
||
i = 0
|
||
while i < len(args):
|
||
if isinstance(args[i], Keyword) and args[i].name == key:
|
||
if i + 1 < len(args):
|
||
val = self._eval(args[i + 1], env)
|
||
# Handle Symbol values (e.g., :op 'sum -> 'sum')
|
||
if isinstance(val, Symbol):
|
||
return val.name
|
||
return val
|
||
return default
|
||
i += 1
|
||
return default
|
||
|
||
def _eval_let(self, args, env: Dict[str, Any]) -> Any:
|
||
"""Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body)."""
|
||
if len(args) < 2:
|
||
raise ValueError("let requires bindings and body")
|
||
|
||
bindings = args[0]
|
||
body = args[1]
|
||
|
||
new_env = env.copy()
|
||
|
||
# Handle both ((var val) ...) and [var val var2 val2 ...] syntax
|
||
if isinstance(bindings, list):
|
||
# Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...)
|
||
if bindings and isinstance(bindings[0], Symbol):
|
||
# Flat list: [var val var2 val2 ...]
|
||
i = 0
|
||
while i < len(bindings) - 1:
|
||
var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
|
||
val = self._eval(bindings[i + 1], new_env)
|
||
new_env[var] = val
|
||
i += 2
|
||
else:
|
||
# Nested list: ((var val) (var2 val2) ...)
|
||
for binding in bindings:
|
||
if isinstance(binding, list) and len(binding) >= 2:
|
||
var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
|
||
val = self._eval(binding[1], new_env)
|
||
new_env[var] = val
|
||
|
||
return self._eval(body, new_env)
|
||
|
||
def _eval_if(self, args, env: Dict[str, Any]) -> Any:
|
||
"""Evaluate (if cond then else)."""
|
||
if len(args) < 2:
|
||
raise ValueError("if requires condition and then-branch")
|
||
|
||
cond = self._eval(args[0], env)
|
||
|
||
# Handle None as falsy (important for optional params like overlay)
|
||
if cond is None:
|
||
return self._eval(args[2], env) if len(args) > 2 else None
|
||
|
||
# For Python scalar bools, use normal Python if
|
||
# This allows side effects and None values
|
||
if isinstance(cond, bool):
|
||
if cond:
|
||
return self._eval(args[1], env)
|
||
else:
|
||
return self._eval(args[2], env) if len(args) > 2 else None
|
||
|
||
# For NumPy/JAX scalar bools with concrete values
|
||
if hasattr(cond, 'item') and cond.shape == ():
|
||
try:
|
||
if bool(cond.item()):
|
||
return self._eval(args[1], env)
|
||
else:
|
||
return self._eval(args[2], env) if len(args) > 2 else None
|
||
except:
|
||
pass # Fall through to jnp.where for traced values
|
||
|
||
# For traced values, evaluate both branches and use jnp.where
|
||
then_val = self._eval(args[1], env)
|
||
else_val = self._eval(args[2], env) if len(args) > 2 else 0.0
|
||
|
||
# Handle None by converting to zeros
|
||
if then_val is None:
|
||
then_val = 0.0
|
||
if else_val is None:
|
||
else_val = 0.0
|
||
|
||
# Convert lists to tuples
|
||
if isinstance(then_val, list):
|
||
then_val = tuple(then_val)
|
||
if isinstance(else_val, list):
|
||
else_val = tuple(else_val)
|
||
|
||
# Handle tuple results (e.g., from rgb in map-pixels)
|
||
if isinstance(then_val, tuple) and isinstance(else_val, tuple):
|
||
return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val))
|
||
|
||
return jnp.where(cond, then_val, else_val)
|
||
|
||
def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable:
|
||
"""Evaluate (lambda (params) body)."""
|
||
if len(args) < 2:
|
||
raise ValueError("lambda requires parameters and body")
|
||
|
||
params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]]
|
||
body = args[1]
|
||
captured_env = env.copy()
|
||
|
||
def fn(*fn_args):
|
||
local_env = captured_env.copy()
|
||
for pname, pval in zip(params, fn_args):
|
||
local_env[pname] = pval
|
||
return self._eval(body, local_env)
|
||
|
||
return fn
|
||
|
||
def _eval_define(self, args, env: Dict[str, Any]) -> Any:
    """Evaluate ``(define name value)`` or ``(define (name params...) body...)``.

    Variable form: evaluates *value* in ``env``, stores it under *name* in
    ``env`` (mutating the caller's environment) and returns it.

    Function form: builds a closure over a shallow snapshot of ``env``,
    stores it under *name* in ``env`` and returns it.  The function is also
    bound inside its own snapshot so that its body can call itself
    recursively.  On each call, positional arguments are bound to the
    parameter names pairwise (``zip`` semantics).  The body forms are then
    evaluated in order and the value of the last one is returned.

    Raises:
        ValueError: if the name or value/body is missing.
    """
    if len(args) < 2:
        raise ValueError("define requires name and value")

    name_part = args[0]

    if isinstance(name_part, list):
        # Function definition: (define (name params...) body...)
        head = name_part[0]
        fn_name = head.name if isinstance(head, Symbol) else str(head)
        params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]]
        body_forms = args[1:]
        captured_env = env.copy()

        def fn(*fn_args):
            local_env = captured_env.copy()
            local_env.update(zip(params, fn_args))
            result = None
            for form in body_forms:
                result = self._eval(form, local_env)
            return result

        # The snapshot was taken before the name existed; without this the
        # body could not refer to the function itself (no recursion).
        captured_env[fn_name] = fn
        env[fn_name] = fn
        return fn

    # Variable definition: (define name value)
    var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part)
    val = self._eval(args[1], env)
    env[var_name] = val
    return val
|
||
|
||
def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any:
|
||
"""Evaluate a function call."""
|
||
|
||
# Check if it's a user-defined function
|
||
if op in env and callable(env[op]):
|
||
fn = env[op]
|
||
eval_args = [self._eval(a, env) for a in args]
|
||
return fn(*eval_args)
|
||
|
||
# Arithmetic
|
||
if op == '+':
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0] if vals else 0.0
|
||
for v in vals[1:]:
|
||
result = result + v
|
||
return result
|
||
|
||
if op == '-':
|
||
if len(args) == 1:
|
||
return -self._eval(args[0], env)
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0]
|
||
for v in vals[1:]:
|
||
result = result - v
|
||
return result
|
||
|
||
if op == '*':
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0] if vals else 1.0
|
||
for v in vals[1:]:
|
||
result = result * v
|
||
return result
|
||
|
||
if op == '/':
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0]
|
||
for v in vals[1:]:
|
||
result = result / v
|
||
return result
|
||
|
||
if op == 'mod':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return a % b
|
||
|
||
if op == 'pow' or op == '**':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return jnp.power(a, b)
|
||
|
||
# Comparison
|
||
if op == '<':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return a < b
|
||
if op == '>':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return a > b
|
||
if op == '<=':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return a <= b
|
||
if op == '>=':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
return a >= b
|
||
if op == '=' or op == '==':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
# For scalar Python types, return Python bool to enable trace-time if
|
||
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
||
return bool(a == b)
|
||
return a == b
|
||
if op == '!=' or op == '<>':
|
||
a, b = self._eval(args[0], env), self._eval(args[1], env)
|
||
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
||
return bool(a != b)
|
||
return a != b
|
||
|
||
# Logic
|
||
if op == 'and':
|
||
vals = [self._eval(a, env) for a in args]
|
||
# Use Python and for concrete Python bools (e.g., shape comparisons)
|
||
if all(isinstance(v, (bool, np.bool_)) for v in vals):
|
||
result = True
|
||
for v in vals:
|
||
result = result and bool(v)
|
||
return result
|
||
# Otherwise use JAX logical_and
|
||
result = vals[0]
|
||
for v in vals[1:]:
|
||
result = jnp.logical_and(result, v)
|
||
return result
|
||
|
||
if op == 'or':
|
||
# Lisp-style or: returns first truthy value, not boolean
|
||
# (or a b c) returns a if a is truthy, else b if b is truthy, else c
|
||
for arg in args:
|
||
val = self._eval(arg, env)
|
||
# Check if value is truthy
|
||
if val is None:
|
||
continue
|
||
if isinstance(val, (bool, np.bool_)):
|
||
if val:
|
||
return val
|
||
continue
|
||
if isinstance(val, (int, float)):
|
||
if val:
|
||
return val
|
||
continue
|
||
if hasattr(val, 'shape'):
|
||
# JAX/numpy array - return it (considered truthy)
|
||
return val
|
||
# For other types, check truthiness
|
||
if val:
|
||
return val
|
||
# All values were falsy, return the last one
|
||
return self._eval(args[-1], env) if args else None
|
||
|
||
if op == 'not':
|
||
val = self._eval(args[0], env)
|
||
if isinstance(val, (bool, np.bool_)):
|
||
return not bool(val)
|
||
return jnp.logical_not(val)
|
||
|
||
# Math functions
|
||
if op == 'sqrt':
|
||
return jnp.sqrt(self._eval(args[0], env))
|
||
if op == 'sin':
|
||
return jnp.sin(self._eval(args[0], env))
|
||
if op == 'cos':
|
||
return jnp.cos(self._eval(args[0], env))
|
||
if op == 'tan':
|
||
return jnp.tan(self._eval(args[0], env))
|
||
if op == 'exp':
|
||
return jnp.exp(self._eval(args[0], env))
|
||
if op == 'log':
|
||
return jnp.log(self._eval(args[0], env))
|
||
if op == 'abs':
|
||
x = self._eval(args[0], env)
|
||
if isinstance(x, (int, float)):
|
||
return abs(x)
|
||
return jnp.abs(x)
|
||
if op == 'floor':
|
||
import math
|
||
x = self._eval(args[0], env)
|
||
if isinstance(x, (int, float)):
|
||
return math.floor(x)
|
||
return jnp.floor(x)
|
||
if op == 'ceil':
|
||
import math
|
||
x = self._eval(args[0], env)
|
||
if isinstance(x, (int, float)):
|
||
return math.ceil(x)
|
||
return jnp.ceil(x)
|
||
if op == 'round':
|
||
x = self._eval(args[0], env)
|
||
if isinstance(x, (int, float)):
|
||
return round(x)
|
||
return jnp.round(x)
|
||
|
||
# Frame primitives
|
||
if op == 'width':
|
||
return env['width']
|
||
if op == 'height':
|
||
return env['height']
|
||
|
||
if op == 'channel':
|
||
frame = self._eval(args[0], env)
|
||
idx = self._eval(args[1], env)
|
||
# idx should be a Python int (literal from S-expression)
|
||
return jax_channel(frame, idx)
|
||
|
||
if op == 'merge-channels' or op == 'rgb':
|
||
r = self._eval(args[0], env)
|
||
g = self._eval(args[1], env)
|
||
b = self._eval(args[2], env)
|
||
# For scalars (e.g., in map-pixels), return tuple
|
||
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
|
||
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
|
||
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
|
||
if r_is_scalar and g_is_scalar and b_is_scalar:
|
||
return (r, g, b)
|
||
return jax_merge_channels(r, g, b, env['_shape'])
|
||
|
||
if op == 'sample':
|
||
frame = self._eval(args[0], env)
|
||
x = self._eval(args[1], env)
|
||
y = self._eval(args[2], env)
|
||
return jax_sample(frame, x, y)
|
||
|
||
if op == 'cell-indices':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = self._eval(args[1], env)
|
||
return jax_cell_indices(frame, cell_size)
|
||
|
||
if op == 'pool-frame':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = self._eval(args[1], env)
|
||
return jax_pool_frame(frame, cell_size)
|
||
|
||
# Xector primitives
|
||
if op == 'iota':
|
||
n = self._eval(args[0], env)
|
||
return jax_iota(int(n))
|
||
|
||
if op == 'repeat':
|
||
x = self._eval(args[0], env)
|
||
n = self._eval(args[1], env)
|
||
return jax_repeat(x, int(n))
|
||
|
||
if op == 'tile':
|
||
x = self._eval(args[0], env)
|
||
n = self._eval(args[1], env)
|
||
return jax_tile(x, int(n))
|
||
|
||
if op == 'gather':
|
||
data = self._eval(args[0], env)
|
||
indices = self._eval(args[1], env)
|
||
return jax_gather(data, indices)
|
||
|
||
if op == 'scatter':
|
||
indices = self._eval(args[0], env)
|
||
values = self._eval(args[1], env)
|
||
size = int(self._eval(args[2], env))
|
||
return jax_scatter(indices, values, size)
|
||
|
||
if op == 'scatter-add':
|
||
indices = self._eval(args[0], env)
|
||
values = self._eval(args[1], env)
|
||
size = int(self._eval(args[2], env))
|
||
return jax_scatter_add(indices, values, size)
|
||
|
||
if op == 'group-reduce':
|
||
values = self._eval(args[0], env)
|
||
groups = self._eval(args[1], env)
|
||
num_groups = int(self._eval(args[2], env))
|
||
reduce_op = args[3] if len(args) > 3 else 'mean'
|
||
if isinstance(reduce_op, Symbol):
|
||
reduce_op = reduce_op.name
|
||
return jax_group_reduce(values, groups, num_groups, reduce_op)
|
||
|
||
if op == 'where':
|
||
cond = self._eval(args[0], env)
|
||
true_val = self._eval(args[1], env)
|
||
false_val = self._eval(args[2], env)
|
||
# Handle None values
|
||
if true_val is None:
|
||
true_val = 0.0
|
||
if false_val is None:
|
||
false_val = 0.0
|
||
return jax_where(cond, true_val, false_val)
|
||
|
||
if op == 'len' or op == 'length':
|
||
x = self._eval(args[0], env)
|
||
if isinstance(x, (list, tuple)):
|
||
return len(x)
|
||
return x.size
|
||
|
||
# Beta reductions
|
||
if op in ('β+', 'beta+', 'sum'):
|
||
return jnp.sum(self._eval(args[0], env))
|
||
if op in ('β*', 'beta*', 'product'):
|
||
return jnp.prod(self._eval(args[0], env))
|
||
if op in ('βmin', 'beta-min'):
|
||
return jnp.min(self._eval(args[0], env))
|
||
if op in ('βmax', 'beta-max'):
|
||
return jnp.max(self._eval(args[0], env))
|
||
if op in ('βmean', 'beta-mean', 'mean'):
|
||
return jnp.mean(self._eval(args[0], env))
|
||
if op in ('βstd', 'beta-std'):
|
||
return jnp.std(self._eval(args[0], env))
|
||
if op in ('βany', 'beta-any'):
|
||
return jnp.any(self._eval(args[0], env))
|
||
if op in ('βall', 'beta-all'):
|
||
return jnp.all(self._eval(args[0], env))
|
||
|
||
# Scan (prefix) operations
|
||
if op in ('scan+', 'scan-add'):
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', None, env)
|
||
return jax_scan_add(x, axis)
|
||
if op in ('scan*', 'scan-mul'):
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', None, env)
|
||
return jax_scan_mul(x, axis)
|
||
if op == 'scan-max':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', None, env)
|
||
return jax_scan_max(x, axis)
|
||
if op == 'scan-min':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', None, env)
|
||
return jax_scan_min(x, axis)
|
||
|
||
# Outer product operations
|
||
if op == 'outer':
|
||
x = self._eval(args[0], env)
|
||
y = self._eval(args[1], env)
|
||
op_type = self._eval_kwarg(args, 'op', '*', env)
|
||
return jax_outer(x, y, op_type)
|
||
if op in ('outer+', 'outer-add'):
|
||
x = self._eval(args[0], env)
|
||
y = self._eval(args[1], env)
|
||
return jax_outer_add(x, y)
|
||
if op in ('outer*', 'outer-mul'):
|
||
x = self._eval(args[0], env)
|
||
y = self._eval(args[1], env)
|
||
return jax_outer_mul(x, y)
|
||
if op == 'outer-max':
|
||
x = self._eval(args[0], env)
|
||
y = self._eval(args[1], env)
|
||
return jax_outer_max(x, y)
|
||
if op == 'outer-min':
|
||
x = self._eval(args[0], env)
|
||
y = self._eval(args[1], env)
|
||
return jax_outer_min(x, y)
|
||
|
||
# Reduce with axis operations
|
||
if op == 'reduce-axis':
|
||
x = self._eval(args[0], env)
|
||
reduce_op = self._eval_kwarg(args, 'op', 'sum', env)
|
||
axis = self._eval_kwarg(args, 'axis', 0, env)
|
||
return jax_reduce_axis(x, reduce_op, axis)
|
||
if op == 'sum-axis':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', 0, env)
|
||
return jax_sum_axis(x, axis)
|
||
if op == 'mean-axis':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', 0, env)
|
||
return jax_mean_axis(x, axis)
|
||
if op == 'max-axis':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', 0, env)
|
||
return jax_max_axis(x, axis)
|
||
if op == 'min-axis':
|
||
x = self._eval(args[0], env)
|
||
axis = self._eval_kwarg(args, 'axis', 0, env)
|
||
return jax_min_axis(x, axis)
|
||
|
||
# Windowed operations
|
||
if op == 'window':
|
||
x = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env))
|
||
win_op = self._eval_kwarg(args, 'op', 'mean', env)
|
||
stride = int(self._eval_kwarg(args, 'stride', 1, env))
|
||
return jax_window(x, size, win_op, stride)
|
||
if op == 'window-sum':
|
||
x = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env))
|
||
stride = int(self._eval_kwarg(args, 'stride', 1, env))
|
||
return jax_window_sum(x, size, stride)
|
||
if op == 'window-mean':
|
||
x = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env))
|
||
stride = int(self._eval_kwarg(args, 'stride', 1, env))
|
||
return jax_window_mean(x, size, stride)
|
||
if op == 'window-max':
|
||
x = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env))
|
||
stride = int(self._eval_kwarg(args, 'stride', 1, env))
|
||
return jax_window_max(x, size, stride)
|
||
if op == 'window-min':
|
||
x = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env))
|
||
stride = int(self._eval_kwarg(args, 'stride', 1, env))
|
||
return jax_window_min(x, size, stride)
|
||
|
||
# Integral image
|
||
if op == 'integral-image':
|
||
frame = self._eval(args[0], env)
|
||
return jax_integral_image(frame)
|
||
|
||
# Convenience - min/max of two values (handle both scalars and arrays)
|
||
if op == 'min' or op == 'min2':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
# Use Python min/max for scalar Python values to preserve type
|
||
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
||
return min(a, b)
|
||
return jnp.minimum(jnp.asarray(a), jnp.asarray(b))
|
||
if op == 'max' or op == 'max2':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
# Use Python min/max for scalar Python values to preserve type
|
||
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
||
return max(a, b)
|
||
return jnp.maximum(jnp.asarray(a), jnp.asarray(b))
|
||
if op == 'clamp':
|
||
x = self._eval(args[0], env)
|
||
lo = self._eval(args[1], env)
|
||
hi = self._eval(args[2], env)
|
||
return jnp.clip(x, lo, hi)
|
||
|
||
# List operations
|
||
if op == 'list':
|
||
return tuple(self._eval(a, env) for a in args)
|
||
|
||
if op == 'nth':
|
||
seq = self._eval(args[0], env)
|
||
idx = int(self._eval(args[1], env))
|
||
if isinstance(seq, (list, tuple)):
|
||
return seq[idx] if 0 <= idx < len(seq) else None
|
||
return seq[idx] # For arrays
|
||
|
||
if op == 'first':
|
||
seq = self._eval(args[0], env)
|
||
return seq[0] if len(seq) > 0 else None
|
||
|
||
if op == 'second':
|
||
seq = self._eval(args[0], env)
|
||
return seq[1] if len(seq) > 1 else None
|
||
|
||
# Random (JAX-compatible)
|
||
# Get frame_num for deterministic variation - can be traced, fold_in handles it
|
||
frame_num = env.get('_frame_num', env.get('frame_num', 0))
|
||
# Convert to int32 for fold_in if needed (but keep as JAX array if traced)
|
||
if frame_num is None:
|
||
frame_num = 0
|
||
elif isinstance(frame_num, (int, float)):
|
||
frame_num = int(frame_num)
|
||
# If it's a JAX array, leave it as-is for tracing
|
||
|
||
# Increment operation counter for unique keys within same frame
|
||
op_counter = env.get('_rand_op_counter', 0)
|
||
env['_rand_op_counter'] = op_counter + 1
|
||
|
||
if op == 'rand' or op == 'rand-x':
|
||
# For size-based random
|
||
if args:
|
||
size = self._eval(args[0], env)
|
||
if hasattr(size, 'shape'):
|
||
# For frames (3D), use h*w (channel size), not h*w*c
|
||
if size.ndim == 3:
|
||
n = size.shape[0] * size.shape[1] # h * w
|
||
shape = (n,)
|
||
else:
|
||
n = size.size
|
||
shape = size.shape
|
||
elif hasattr(size, 'size'):
|
||
n = size.size
|
||
shape = (n,)
|
||
else:
|
||
n = int(size)
|
||
shape = (n,)
|
||
# Use deterministic key that varies with frame
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
return jax.random.uniform(key, shape).flatten()
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
return jax.random.uniform(key, ())
|
||
|
||
if op == 'randn' or op == 'randn-x':
|
||
# Normal random
|
||
if args:
|
||
size = self._eval(args[0], env)
|
||
if hasattr(size, 'shape'):
|
||
# For frames (3D), use h*w (channel size), not h*w*c
|
||
if size.ndim == 3:
|
||
n = size.shape[0] * size.shape[1] # h * w
|
||
else:
|
||
n = size.size
|
||
elif hasattr(size, 'size'):
|
||
n = size.size
|
||
else:
|
||
n = int(size)
|
||
mean = self._eval(args[1], env) if len(args) > 1 else 0.0
|
||
std = self._eval(args[2], env) if len(args) > 2 else 1.0
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
return jax.random.normal(key, (n,)) * std + mean
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
return jax.random.normal(key, ())
|
||
|
||
if op == 'rand-range' or op == 'core:rand-range':
|
||
lo = self._eval(args[0], env)
|
||
hi = self._eval(args[1], env)
|
||
seed = env.get('_seed', 42)
|
||
return jax_rand_range(lo, hi, frame_num, op_counter, seed)
|
||
|
||
# =====================================================================
|
||
# Convolution operations
|
||
# =====================================================================
|
||
if op == 'blur' or op == 'image:blur':
|
||
frame = self._eval(args[0], env)
|
||
radius = self._eval(args[1], env) if len(args) > 1 else 1
|
||
# Convert traced value to concrete for kernel size
|
||
if hasattr(radius, 'item'):
|
||
radius = int(radius.item())
|
||
elif hasattr(radius, '__float__'):
|
||
radius = int(float(radius))
|
||
else:
|
||
radius = int(radius)
|
||
return jax_blur(frame, max(1, radius))
|
||
|
||
if op == 'gaussian':
|
||
first_arg = self._eval(args[0], env)
|
||
# Check if first arg is a frame (blur) or scalar (random)
|
||
if hasattr(first_arg, 'shape') and first_arg.ndim == 3:
|
||
# Frame - apply gaussian blur
|
||
sigma = self._eval(args[1], env) if len(args) > 1 else 1.0
|
||
radius = max(1, int(sigma * 3))
|
||
return jax_blur(first_arg, radius)
|
||
else:
|
||
# Scalar args - generate gaussian random value
|
||
mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg
|
||
std = self._eval(args[1], env) if len(args) > 1 else 1.0
|
||
# Return a single random value
|
||
op_counter = env.get('_rand_op_counter', 0)
|
||
env['_rand_op_counter'] = op_counter + 1
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
return jax.random.normal(key, ()) * std + mean
|
||
|
||
if op == 'sharpen' or op == 'image:sharpen':
|
||
frame = self._eval(args[0], env)
|
||
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
|
||
return jax_sharpen(frame, amount)
|
||
|
||
if op == 'edge-detect' or op == 'image:edge-detect':
|
||
frame = self._eval(args[0], env)
|
||
return jax_edge_detect(frame)
|
||
|
||
if op == 'emboss':
|
||
frame = self._eval(args[0], env)
|
||
return jax_emboss(frame)
|
||
|
||
if op == 'convolve':
|
||
frame = self._eval(args[0], env)
|
||
kernel = self._eval(args[1], env)
|
||
# Convert kernel to array if it's a list
|
||
if isinstance(kernel, (list, tuple)):
|
||
kernel = jnp.array(kernel, dtype=jnp.float32)
|
||
h, w = frame.shape[:2]
|
||
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel)
|
||
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel)
|
||
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel)
|
||
return jnp.stack([
|
||
jnp.clip(r, 0, 255).astype(jnp.uint8),
|
||
jnp.clip(g, 0, 255).astype(jnp.uint8),
|
||
jnp.clip(b, 0, 255).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'add-noise':
|
||
frame = self._eval(args[0], env)
|
||
amount = self._eval(args[1], env) if len(args) > 1 else 0.1
|
||
h, w = frame.shape[:2]
|
||
# Use frame-varying key for noise
|
||
op_counter = env.get('_rand_op_counter', 0)
|
||
env['_rand_op_counter'] = op_counter + 1
|
||
seed = env.get('_seed', 42)
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1]
|
||
result = frame.astype(jnp.float32) + noise * amount * 255
|
||
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||
|
||
if op == 'translate':
|
||
frame = self._eval(args[0], env)
|
||
dx = self._eval(args[1], env)
|
||
dy = self._eval(args[2], env) if len(args) > 2 else 0
|
||
h, w = frame.shape[:2]
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w)
|
||
src_x = (x_coords - dx).flatten()
|
||
src_y = (y_coords - dy).flatten()
|
||
r, g, b = jax_sample(frame, src_x, src_y)
|
||
return jnp.stack([
|
||
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'image:crop' or op == 'crop':
|
||
frame = self._eval(args[0], env)
|
||
x = int(self._eval(args[1], env))
|
||
y = int(self._eval(args[2], env))
|
||
w = int(self._eval(args[3], env))
|
||
h = int(self._eval(args[4], env))
|
||
return frame[y:y+h, x:x+w, :]
|
||
|
||
if op == 'dilate':
|
||
frame = self._eval(args[0], env)
|
||
size = int(self._eval(args[1], env)) if len(args) > 1 else 3
|
||
# Simple dilation using max pooling approximation
|
||
kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size)
|
||
h, w = frame.shape[:2]
|
||
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size)
|
||
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size)
|
||
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size)
|
||
return jnp.stack([
|
||
jnp.clip(r, 0, 255).astype(jnp.uint8),
|
||
jnp.clip(g, 0, 255).astype(jnp.uint8),
|
||
jnp.clip(b, 0, 255).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'map-rows':
|
||
frame = self._eval(args[0], env)
|
||
fn = args[1] # S-expression function
|
||
h, w = frame.shape[:2]
|
||
# For each row, apply the function
|
||
results = []
|
||
for row_idx in range(h):
|
||
row_env = env.copy()
|
||
row_env['row'] = frame[row_idx, :, :]
|
||
row_env['row-idx'] = row_idx
|
||
|
||
# Check if fn is a lambda
|
||
if isinstance(fn, list) and len(fn) >= 2:
|
||
head = fn[0]
|
||
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
|
||
params = fn[1]
|
||
body = fn[2]
|
||
# Bind lambda params to y and row
|
||
if len(params) >= 1:
|
||
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
|
||
row_env[param_name] = row_idx # y
|
||
if len(params) >= 2:
|
||
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
|
||
row_env[param_name] = frame[row_idx, :, :] # row
|
||
result_row = self._eval(body, row_env)
|
||
results.append(result_row)
|
||
continue
|
||
|
||
result_row = self._eval(fn, row_env)
|
||
# If result is a function, call it
|
||
if callable(result_row):
|
||
result_row = result_row(row_idx, frame[row_idx, :, :])
|
||
results.append(result_row)
|
||
return jnp.stack(results, axis=0)
|
||
|
||
# =====================================================================
|
||
# Text rendering operations
|
||
# =====================================================================
|
||
if op == 'text':
|
||
frame = self._eval(args[0], env)
|
||
text_str = self._eval(args[1], env)
|
||
if isinstance(text_str, Symbol):
|
||
text_str = text_str.name
|
||
text_str = str(text_str)
|
||
|
||
# Extract keyword arguments
|
||
x = self._eval_kwarg(args, 'x', None, env)
|
||
y = self._eval_kwarg(args, 'y', None, env)
|
||
font_size = self._eval_kwarg(args, 'font-size', 32, env)
|
||
font_name = self._eval_kwarg(args, 'font-name', None, env)
|
||
color = self._eval_kwarg(args, 'color', (255, 255, 255), env)
|
||
opacity = self._eval_kwarg(args, 'opacity', 1.0, env)
|
||
align = self._eval_kwarg(args, 'align', 'left', env)
|
||
valign = self._eval_kwarg(args, 'valign', 'top', env)
|
||
shadow = self._eval_kwarg(args, 'shadow', False, env)
|
||
shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env)
|
||
shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env)
|
||
fit = self._eval_kwarg(args, 'fit', False, env)
|
||
width = self._eval_kwarg(args, 'width', None, env)
|
||
height = self._eval_kwarg(args, 'height', None, env)
|
||
|
||
# Handle color as list/tuple
|
||
if isinstance(color, (list, tuple)):
|
||
color = tuple(int(c) for c in color[:3])
|
||
if isinstance(shadow_color, (list, tuple)):
|
||
shadow_color = tuple(int(c) for c in shadow_color[:3])
|
||
|
||
h, w_frame = frame.shape[:2]
|
||
|
||
# Default position to 0,0 or center based on alignment
|
||
if x is None:
|
||
if align == 'center':
|
||
x = w_frame // 2
|
||
elif align == 'right':
|
||
x = w_frame
|
||
else:
|
||
x = 0
|
||
if y is None:
|
||
if valign == 'middle':
|
||
y = h // 2
|
||
elif valign == 'bottom':
|
||
y = h
|
||
else:
|
||
y = 0
|
||
|
||
# Auto-fit text to bounds
|
||
if fit and width is not None and height is not None:
|
||
font_size = jax_fit_text_size(text_str, int(width), int(height),
|
||
font_name, min_size=8, max_size=200)
|
||
|
||
return jax_text_render(frame, text_str, int(x), int(y),
|
||
font_name=font_name, font_size=int(font_size),
|
||
color=color, opacity=float(opacity),
|
||
align=str(align), valign=str(valign),
|
||
shadow=bool(shadow), shadow_color=shadow_color,
|
||
shadow_offset=int(shadow_offset))
|
||
|
||
if op == 'text-size':
|
||
text_str = self._eval(args[0], env)
|
||
if isinstance(text_str, Symbol):
|
||
text_str = text_str.name
|
||
text_str = str(text_str)
|
||
font_size = self._eval_kwarg(args, 'font-size', 32, env)
|
||
font_name = self._eval_kwarg(args, 'font-name', None, env)
|
||
return jax_text_size(text_str, font_name, int(font_size))
|
||
|
||
if op == 'fit-text-size':
|
||
text_str = self._eval(args[0], env)
|
||
if isinstance(text_str, Symbol):
|
||
text_str = text_str.name
|
||
text_str = str(text_str)
|
||
max_width = int(self._eval(args[1], env))
|
||
max_height = int(self._eval(args[2], env))
|
||
font_name = self._eval_kwarg(args, 'font-name', None, env)
|
||
return jax_fit_text_size(text_str, max_width, max_height, font_name)
|
||
|
||
# =====================================================================
|
||
# Color operations
|
||
# =====================================================================
|
||
if op == 'rgb->hsv' or op == 'rgb-to-hsv':
|
||
# Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple
|
||
if len(args) == 1:
|
||
c = self._eval(args[0], env)
|
||
if isinstance(c, tuple) and len(c) == 3:
|
||
r, g, b = c
|
||
else:
|
||
# Assume it's a list-like
|
||
r, g, b = c[0], c[1], c[2]
|
||
else:
|
||
r = self._eval(args[0], env)
|
||
g = self._eval(args[1], env)
|
||
b = self._eval(args[2], env)
|
||
return jax_rgb_to_hsv(r, g, b)
|
||
|
||
if op == 'hsv->rgb' or op == 'hsv-to-rgb':
|
||
# Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list)
|
||
if len(args) == 1:
|
||
hsv = self._eval(args[0], env)
|
||
if isinstance(hsv, (tuple, list)) and len(hsv) >= 3:
|
||
h, s, v = hsv[0], hsv[1], hsv[2]
|
||
else:
|
||
h, s, v = hsv[0], hsv[1], hsv[2]
|
||
else:
|
||
h = self._eval(args[0], env)
|
||
s = self._eval(args[1], env)
|
||
v = self._eval(args[2], env)
|
||
return jax_hsv_to_rgb(h, s, v)
|
||
|
||
if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness':
|
||
frame = self._eval(args[0], env)
|
||
amount = self._eval(args[1], env)
|
||
return jax_adjust_brightness(frame, amount)
|
||
|
||
if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast':
|
||
frame = self._eval(args[0], env)
|
||
factor = self._eval(args[1], env)
|
||
return jax_adjust_contrast(frame, factor)
|
||
|
||
if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation':
|
||
frame = self._eval(args[0], env)
|
||
factor = self._eval(args[1], env)
|
||
return jax_adjust_saturation(frame, factor)
|
||
|
||
if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift':
|
||
frame = self._eval(args[0], env)
|
||
degrees = self._eval(args[1], env)
|
||
return jax_shift_hue(frame, degrees)
|
||
|
||
if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img':
|
||
frame = self._eval(args[0], env)
|
||
return jax_invert(frame)
|
||
|
||
if op == 'posterize' or op == 'color_ops:posterize':
|
||
frame = self._eval(args[0], env)
|
||
levels = self._eval(args[1], env)
|
||
return jax_posterize(frame, levels)
|
||
|
||
if op == 'threshold' or op == 'color_ops:threshold':
|
||
frame = self._eval(args[0], env)
|
||
level = self._eval(args[1], env)
|
||
invert = self._eval(args[2], env) if len(args) > 2 else False
|
||
return jax_threshold(frame, level, invert)
|
||
|
||
if op == 'sepia' or op == 'color_ops:sepia':
|
||
frame = self._eval(args[0], env)
|
||
return jax_sepia(frame)
|
||
|
||
if op == 'grayscale' or op == 'image:grayscale':
|
||
frame = self._eval(args[0], env)
|
||
return jax_grayscale(frame)
|
||
|
||
# =====================================================================
|
||
# Geometry operations
|
||
# =====================================================================
|
||
if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img':
|
||
frame = self._eval(args[0], env)
|
||
direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal'
|
||
if direction == 'vertical' or direction == 'v':
|
||
return jax_flip_vertical(frame)
|
||
return jax_flip_horizontal(frame)
|
||
|
||
if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v':
|
||
frame = self._eval(args[0], env)
|
||
return jax_flip_vertical(frame)
|
||
|
||
if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img':
|
||
frame = self._eval(args[0], env)
|
||
angle = self._eval(args[1], env)
|
||
return jax_rotate(frame, angle)
|
||
|
||
if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img':
|
||
frame = self._eval(args[0], env)
|
||
scale_x = self._eval(args[1], env)
|
||
scale_y = self._eval(args[2], env) if len(args) > 2 else None
|
||
return jax_scale(frame, scale_x, scale_y)
|
||
|
||
if op == 'resize' or op == 'image:resize':
|
||
frame = self._eval(args[0], env)
|
||
new_w = self._eval(args[1], env)
|
||
new_h = self._eval(args[2], env)
|
||
return jax_resize(frame, new_w, new_h)
|
||
|
||
# =====================================================================
|
||
# Geometry distortion effects
|
||
# =====================================================================
|
||
if op == 'geometry:fisheye-coords' or op == 'fisheye':
|
||
# Signature: (w h strength cx cy zoom_correct) or (frame strength)
|
||
first_arg = self._eval(args[0], env)
|
||
if not hasattr(first_arg, 'shape'):
|
||
# (w h strength cx cy zoom_correct) signature
|
||
w = int(first_arg)
|
||
h = int(self._eval(args[1], env))
|
||
strength = self._eval(args[2], env) if len(args) > 2 else 0.5
|
||
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
|
||
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
|
||
frame = None
|
||
else:
|
||
frame = first_arg
|
||
strength = self._eval(args[1], env) if len(args) > 1 else 0.5
|
||
h, w = frame.shape[:2]
|
||
cx, cy = w / 2, h / 2
|
||
|
||
max_r = jnp.sqrt(float(cx*cx + cy*cy))
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
dx = x_coords - cx
|
||
dy = y_coords - cy
|
||
r = jnp.sqrt(dx*dx + dy*dy)
|
||
theta = jnp.arctan2(dy, dx)
|
||
|
||
# Fisheye distortion
|
||
r_new = r + strength * r * (1 - r / max_r)
|
||
|
||
src_x = r_new * jnp.cos(theta) + cx
|
||
src_y = r_new * jnp.sin(theta) + cy
|
||
|
||
if frame is None:
|
||
return {'x': src_x, 'y': src_y}
|
||
|
||
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'geometry:swirl-coords' or op == 'swirl':
|
||
first_arg = self._eval(args[0], env)
|
||
if not hasattr(first_arg, 'shape'):
|
||
w = int(first_arg)
|
||
h = int(self._eval(args[1], env))
|
||
amount = self._eval(args[2], env) if len(args) > 2 else 1.0
|
||
frame = None
|
||
else:
|
||
frame = first_arg
|
||
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
|
||
h, w = frame.shape[:2]
|
||
|
||
cx, cy = w / 2, h / 2
|
||
max_r = jnp.sqrt(float(cx*cx + cy*cy))
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
dx = x_coords - cx
|
||
dy = y_coords - cy
|
||
r = jnp.sqrt(dx*dx + dy*dy)
|
||
theta = jnp.arctan2(dy, dx)
|
||
|
||
swirl_angle = amount * (1 - r / max_r)
|
||
new_theta = theta + swirl_angle
|
||
|
||
src_x = r * jnp.cos(new_theta) + cx
|
||
src_y = r * jnp.sin(new_theta) + cy
|
||
|
||
if frame is None:
|
||
return {'x': src_x, 'y': src_y}
|
||
|
||
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
# Wave effect (frame-first signature for simple usage)
|
||
if op == 'wave-distort':
|
||
first_arg = self._eval(args[0], env)
|
||
frame = first_arg
|
||
amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0
|
||
amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0
|
||
freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1
|
||
freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1
|
||
h, w = frame.shape[:2]
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y)
|
||
src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x)
|
||
|
||
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'geometry:ripple-displace' or op == 'ripple':
|
||
# Match Python prim_ripple_displace signature:
|
||
# (w h freq amp cx cy decay phase) or (frame ...)
|
||
first_arg = self._eval(args[0], env)
|
||
if not hasattr(first_arg, 'shape'):
|
||
# Coordinate-only mode: (w h freq amp cx cy decay phase)
|
||
w = int(first_arg)
|
||
h = int(self._eval(args[1], env))
|
||
freq = self._eval(args[2], env) if len(args) > 2 else 5.0
|
||
amp = self._eval(args[3], env) if len(args) > 3 else 10.0
|
||
cx = self._eval(args[4], env) if len(args) > 4 else w / 2
|
||
cy = self._eval(args[5], env) if len(args) > 5 else h / 2
|
||
decay = self._eval(args[6], env) if len(args) > 6 else 0.0
|
||
phase = self._eval(args[7], env) if len(args) > 7 else 0.0
|
||
frame = None
|
||
else:
|
||
# Frame mode: (frame :amplitude A :frequency F :center_x CX ...)
|
||
frame = first_arg
|
||
h, w = frame.shape[:2]
|
||
# Parse keyword args
|
||
amp = 10.0
|
||
freq = 5.0
|
||
cx = w / 2
|
||
cy = h / 2
|
||
decay = 0.0
|
||
phase = 0.0
|
||
i = 1
|
||
while i < len(args):
|
||
if isinstance(args[i], Keyword):
|
||
kw = args[i].name
|
||
val = self._eval(args[i + 1], env) if i + 1 < len(args) else None
|
||
if kw == 'amplitude':
|
||
amp = val
|
||
elif kw == 'frequency':
|
||
freq = val
|
||
elif kw == 'center_x':
|
||
cx = val * w if val <= 1 else val # normalized or absolute
|
||
elif kw == 'center_y':
|
||
cy = val * h if val <= 1 else val
|
||
elif kw == 'decay':
|
||
decay = val
|
||
elif kw == 'speed':
|
||
# speed affects phase via time
|
||
t = env.get('t', 0)
|
||
phase = t * val * 2 * jnp.pi
|
||
elif kw == 'phase':
|
||
phase = val
|
||
i += 2
|
||
else:
|
||
i += 1
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
dx = x_coords - cx
|
||
dy = y_coords - cy
|
||
dist = jnp.sqrt(dx*dx + dy*dy)
|
||
|
||
# Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp
|
||
max_dim = jnp.maximum(w, h)
|
||
ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp
|
||
|
||
# Apply decay (when decay=0, exp(0)=1 so no effect)
|
||
decay_factor = jnp.exp(-decay * dist / max_dim)
|
||
ripple = ripple * decay_factor
|
||
|
||
# Radial displacement - use ADDITION to match Python prim_ripple_displace
|
||
# Python (primitives.py line 2890-2891):
|
||
# map_x = x_coords + ripple * norm_dx
|
||
# map_y = y_coords + ripple * norm_dy
|
||
# where norm_dx = dx/dist = cos(angle), norm_dy = dy/dist = sin(angle)
|
||
angle = jnp.arctan2(dy, dx)
|
||
src_x = x_coords + ripple * jnp.cos(angle)
|
||
src_y = y_coords + ripple * jnp.sin(angle)
|
||
|
||
if frame is None:
|
||
return {'x': src_x, 'y': src_y}
|
||
|
||
# Sample using bilinear interpolation (jax_sample clamps coords,
|
||
# matching OpenCV's default remap behavior)
|
||
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
|
||
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'geometry:coords-x' or op == 'coords-x':
|
||
# Extract x coordinates from coord dict
|
||
coords = self._eval(args[0], env)
|
||
if isinstance(coords, dict):
|
||
return coords.get('x', coords.get('map_x'))
|
||
return coords[0] if isinstance(coords, (list, tuple)) else coords
|
||
|
||
if op == 'geometry:coords-y' or op == 'coords-y':
|
||
# Extract y coordinates from coord dict
|
||
coords = self._eval(args[0], env)
|
||
if isinstance(coords, dict):
|
||
return coords.get('y', coords.get('map_y'))
|
||
return coords[1] if isinstance(coords, (list, tuple)) else coords
|
||
|
||
if op == 'geometry:remap' or op == 'remap':
|
||
# Remap image using coordinate maps: (frame map_x map_y)
|
||
# OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords
|
||
frame = self._eval(args[0], env)
|
||
map_x = self._eval(args[1], env)
|
||
map_y = self._eval(args[2], env)
|
||
|
||
h, w = frame.shape[:2]
|
||
|
||
# Flatten coordinate maps
|
||
src_x = map_x.flatten()
|
||
src_y = map_y.flatten()
|
||
|
||
# Sample using bilinear interpolation (jax_sample clamps coords internally,
|
||
# matching OpenCV's default behavior)
|
||
r_out, g_out, b_out = jax_sample(frame, src_x, src_y)
|
||
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope':
|
||
# Two signatures: (frame segments) or (w h segments cx cy)
|
||
if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'):
|
||
# (w h segments cx cy) signature
|
||
w = int(self._eval(args[0], env))
|
||
h = int(self._eval(args[1], env))
|
||
segments = int(self._eval(args[2], env)) if len(args) > 2 else 6
|
||
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
|
||
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
|
||
frame = None
|
||
else:
|
||
frame = self._eval(args[0], env)
|
||
segments = int(self._eval(args[1], env)) if len(args) > 1 else 6
|
||
h, w = frame.shape[:2]
|
||
cx, cy = w / 2, h / 2
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
dx = x_coords - cx
|
||
dy = y_coords - cy
|
||
r = jnp.sqrt(dx*dx + dy*dy)
|
||
theta = jnp.arctan2(dy, dx)
|
||
|
||
# Mirror into segments
|
||
segment_angle = 2 * jnp.pi / segments
|
||
theta_mod = theta % segment_angle
|
||
theta_mirror = jnp.where(
|
||
(jnp.floor(theta / segment_angle) % 2) == 0,
|
||
theta_mod,
|
||
segment_angle - theta_mod
|
||
)
|
||
|
||
src_x = r * jnp.cos(theta_mirror) + cx
|
||
src_y = r * jnp.sin(theta_mirror) + cy
|
||
|
||
if frame is None:
|
||
# Return coordinate arrays
|
||
return {'x': src_x, 'y': src_y}
|
||
|
||
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
# Geometry coordinate extraction
|
||
if op == 'geometry:coords-x':
|
||
coords = self._eval(args[0], env)
|
||
if isinstance(coords, dict):
|
||
return coords['x']
|
||
return coords[0] if isinstance(coords, tuple) else coords
|
||
|
||
if op == 'geometry:coords-y':
|
||
coords = self._eval(args[0], env)
|
||
if isinstance(coords, dict):
|
||
return coords['y']
|
||
return coords[1] if isinstance(coords, tuple) else coords
|
||
|
||
if op == 'geometry:remap' or op == 'remap':
|
||
frame = self._eval(args[0], env)
|
||
x_coords = self._eval(args[1], env)
|
||
y_coords = self._eval(args[2], env)
|
||
h, w = frame.shape[:2]
|
||
r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
# =====================================================================
|
||
# Blending operations
|
||
# =====================================================================
|
||
if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images':
|
||
frame1 = self._eval(args[0], env)
|
||
frame2 = self._eval(args[1], env)
|
||
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
|
||
return jax_blend(frame1, frame2, alpha)
|
||
|
||
if op == 'blend-add':
|
||
frame1 = self._eval(args[0], env)
|
||
frame2 = self._eval(args[1], env)
|
||
return jax_blend_add(frame1, frame2)
|
||
|
||
if op == 'blend-multiply':
|
||
frame1 = self._eval(args[0], env)
|
||
frame2 = self._eval(args[1], env)
|
||
return jax_blend_multiply(frame1, frame2)
|
||
|
||
if op == 'blend-screen':
|
||
frame1 = self._eval(args[0], env)
|
||
frame2 = self._eval(args[1], env)
|
||
return jax_blend_screen(frame1, frame2)
|
||
|
||
if op == 'blend-overlay':
|
||
frame1 = self._eval(args[0], env)
|
||
frame2 = self._eval(args[1], env)
|
||
return jax_blend_overlay(frame1, frame2)
|
||
|
||
# =====================================================================
|
||
# Image dimension queries (namespaced aliases)
|
||
# =====================================================================
|
||
if op == 'image:width':
|
||
if args:
|
||
frame = self._eval(args[0], env)
|
||
return frame.shape[1] # width is second dimension (h, w, c)
|
||
return env['width']
|
||
|
||
if op == 'image:height':
|
||
if args:
|
||
frame = self._eval(args[0], env)
|
||
return frame.shape[0] # height is first dimension (h, w, c)
|
||
return env['height']
|
||
|
||
# =====================================================================
|
||
# Utility
|
||
# =====================================================================
|
||
if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?':
|
||
x = self._eval(args[0], env)
|
||
return jax_is_nil(x)
|
||
|
||
# =====================================================================
|
||
# Xector channel operations (shortcuts)
|
||
# =====================================================================
|
||
if op == 'red':
|
||
val = self._eval(args[0], env)
|
||
# Works on frames or pixel tuples
|
||
if isinstance(val, tuple):
|
||
return val[0]
|
||
elif hasattr(val, 'shape') and val.ndim == 3:
|
||
return jax_channel(val, 0)
|
||
else:
|
||
return val # Assume it's already a channel
|
||
|
||
if op == 'green':
|
||
val = self._eval(args[0], env)
|
||
if isinstance(val, tuple):
|
||
return val[1]
|
||
elif hasattr(val, 'shape') and val.ndim == 3:
|
||
return jax_channel(val, 1)
|
||
else:
|
||
return val
|
||
|
||
if op == 'blue':
|
||
val = self._eval(args[0], env)
|
||
if isinstance(val, tuple):
|
||
return val[2]
|
||
elif hasattr(val, 'shape') and val.ndim == 3:
|
||
return jax_channel(val, 2)
|
||
else:
|
||
return val
|
||
|
||
if op == 'gray' or op == 'luminance':
|
||
val = self._eval(args[0], env)
|
||
# Handle tuple (r, g, b) from map-pixels
|
||
if isinstance(val, tuple) and len(val) == 3:
|
||
r, g, b = val
|
||
return r * 0.299 + g * 0.587 + b * 0.114
|
||
# Handle frame
|
||
frame = val
|
||
r = frame[:, :, 0].flatten().astype(jnp.float32)
|
||
g = frame[:, :, 1].flatten().astype(jnp.float32)
|
||
b = frame[:, :, 2].flatten().astype(jnp.float32)
|
||
return r * 0.299 + g * 0.587 + b * 0.114
|
||
|
||
if op == 'rgb':
|
||
r = self._eval(args[0], env)
|
||
g = self._eval(args[1], env)
|
||
b = self._eval(args[2], env)
|
||
# For scalars (e.g., in map-pixels), return tuple
|
||
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
|
||
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
|
||
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
|
||
if r_is_scalar and g_is_scalar and b_is_scalar:
|
||
return (r, g, b)
|
||
return jax_merge_channels(r, g, b, env['_shape'])
|
||
|
||
# =====================================================================
|
||
# Coordinate operations
|
||
# =====================================================================
|
||
if op == 'x-coords':
|
||
frame = self._eval(args[0], env)
|
||
h, w = frame.shape[:2]
|
||
return jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
|
||
|
||
if op == 'y-coords':
|
||
frame = self._eval(args[0], env)
|
||
h, w = frame.shape[:2]
|
||
return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
|
||
|
||
if op == 'dist-from-center':
|
||
frame = self._eval(args[0], env)
|
||
h, w = frame.shape[:2]
|
||
cx, cy = w / 2, h / 2
|
||
x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx
|
||
y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy
|
||
return jnp.sqrt(x*x + y*y)
|
||
|
||
# =====================================================================
|
||
# Alpha operations (element-wise on xectors)
|
||
# =====================================================================
|
||
if op == 'α/' or op == 'alpha/':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a / b
|
||
|
||
if op == 'α+' or op == 'alpha+':
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0]
|
||
for v in vals[1:]:
|
||
result = result + v
|
||
return result
|
||
|
||
if op == 'α*' or op == 'alpha*':
|
||
vals = [self._eval(a, env) for a in args]
|
||
result = vals[0]
|
||
for v in vals[1:]:
|
||
result = result * v
|
||
return result
|
||
|
||
if op == 'α-' or op == 'alpha-':
|
||
if len(args) == 1:
|
||
return -self._eval(args[0], env)
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a - b
|
||
|
||
if op == 'αclamp' or op == 'alpha-clamp':
|
||
x = self._eval(args[0], env)
|
||
lo = self._eval(args[1], env)
|
||
hi = self._eval(args[2], env)
|
||
return jnp.clip(x, lo, hi)
|
||
|
||
if op == 'αmin' or op == 'alpha-min':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return jnp.minimum(a, b)
|
||
|
||
if op == 'αmax' or op == 'alpha-max':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return jnp.maximum(a, b)
|
||
|
||
if op == 'αsqrt' or op == 'alpha-sqrt':
|
||
return jnp.sqrt(self._eval(args[0], env))
|
||
|
||
if op == 'αsin' or op == 'alpha-sin':
|
||
return jnp.sin(self._eval(args[0], env))
|
||
|
||
if op == 'αcos' or op == 'alpha-cos':
|
||
return jnp.cos(self._eval(args[0], env))
|
||
|
||
if op == 'αmod' or op == 'alpha-mod':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a % b
|
||
|
||
if op == 'α²' or op == 'αsq' or op == 'alpha-sq':
|
||
x = self._eval(args[0], env)
|
||
return x * x
|
||
|
||
if op == 'α<' or op == 'alpha<':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a < b
|
||
|
||
if op == 'α>' or op == 'alpha>':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a > b
|
||
|
||
if op == 'α<=' or op == 'alpha<=':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a <= b
|
||
|
||
if op == 'α>=' or op == 'alpha>=':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a >= b
|
||
|
||
if op == 'α=' or op == 'alpha=':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return a == b
|
||
|
||
if op == 'αfloor' or op == 'alpha-floor':
|
||
return jnp.floor(self._eval(args[0], env))
|
||
|
||
if op == 'αceil' or op == 'alpha-ceil':
|
||
return jnp.ceil(self._eval(args[0], env))
|
||
|
||
if op == 'αround' or op == 'alpha-round':
|
||
return jnp.round(self._eval(args[0], env))
|
||
|
||
if op == 'αabs' or op == 'alpha-abs':
|
||
return jnp.abs(self._eval(args[0], env))
|
||
|
||
if op == 'αexp' or op == 'alpha-exp':
|
||
return jnp.exp(self._eval(args[0], env))
|
||
|
||
if op == 'αlog' or op == 'alpha-log':
|
||
return jnp.log(self._eval(args[0], env))
|
||
|
||
if op == 'αor' or op == 'alpha-or':
|
||
# Element-wise logical OR
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return jnp.logical_or(a, b)
|
||
|
||
if op == 'αand' or op == 'alpha-and':
|
||
# Element-wise logical AND
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return jnp.logical_and(a, b)
|
||
|
||
if op == 'αnot' or op == 'alpha-not':
|
||
# Element-wise logical NOT
|
||
return jnp.logical_not(self._eval(args[0], env))
|
||
|
||
if op == 'αxor' or op == 'alpha-xor':
|
||
# Element-wise logical XOR
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
return jnp.logical_xor(a, b)
|
||
|
||
# =====================================================================
|
||
# Threading/arrow operations
|
||
# =====================================================================
|
||
if op == '->':
|
||
# Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b)
|
||
val = self._eval(args[0], env)
|
||
for form in args[1:]:
|
||
if isinstance(form, list):
|
||
# Insert val as first argument
|
||
fn_name = form[0].name if isinstance(form[0], Symbol) else form[0]
|
||
new_args = [val] + [self._eval(a, env) for a in form[1:]]
|
||
val = self._eval_call(fn_name, [val] + form[1:], env)
|
||
else:
|
||
# Simple function call
|
||
fn_name = form.name if isinstance(form, Symbol) else form
|
||
val = self._eval_call(fn_name, [args[0]], env)
|
||
return val
|
||
|
||
# =====================================================================
|
||
# Range and iteration
|
||
# =====================================================================
|
||
if op == 'range':
|
||
if len(args) == 1:
|
||
end = int(self._eval(args[0], env))
|
||
return list(range(end))
|
||
elif len(args) == 2:
|
||
start = int(self._eval(args[0], env))
|
||
end = int(self._eval(args[1], env))
|
||
return list(range(start, end))
|
||
else:
|
||
start = int(self._eval(args[0], env))
|
||
end = int(self._eval(args[1], env))
|
||
step = int(self._eval(args[2], env))
|
||
return list(range(start, end, step))
|
||
|
||
if op == 'reduce' or op == 'fold':
|
||
# (reduce seq init fn) - left fold
|
||
seq = self._eval(args[0], env)
|
||
acc = self._eval(args[1], env)
|
||
fn = args[2] # Lambda S-expression
|
||
|
||
# Handle lambda
|
||
if isinstance(fn, list) and len(fn) >= 3:
|
||
head = fn[0]
|
||
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
|
||
params = fn[1]
|
||
body = fn[2]
|
||
for item in seq:
|
||
fn_env = env.copy()
|
||
if len(params) >= 1:
|
||
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
|
||
fn_env[param_name] = acc
|
||
if len(params) >= 2:
|
||
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
|
||
fn_env[param_name] = item
|
||
acc = self._eval(body, fn_env)
|
||
return acc
|
||
|
||
# Fallback - try evaluating fn and calling it
|
||
fn_eval = self._eval(fn, env)
|
||
if callable(fn_eval):
|
||
for item in seq:
|
||
acc = fn_eval(acc, item)
|
||
return acc
|
||
|
||
if op == 'fold-indexed':
|
||
# (fold-indexed seq init fn) - fold with index
|
||
# fn takes (acc item index) or (acc item index cursor) for typography
|
||
seq = self._eval(args[0], env)
|
||
acc = self._eval(args[1], env)
|
||
fn = args[2] # Lambda S-expression
|
||
|
||
# Handle lambda
|
||
if isinstance(fn, list) and len(fn) >= 3:
|
||
head = fn[0]
|
||
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
|
||
params = fn[1]
|
||
body = fn[2]
|
||
for idx, item in enumerate(seq):
|
||
fn_env = env.copy()
|
||
if len(params) >= 1:
|
||
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
|
||
fn_env[param_name] = acc
|
||
if len(params) >= 2:
|
||
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
|
||
fn_env[param_name] = item
|
||
if len(params) >= 3:
|
||
param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2])
|
||
fn_env[param_name] = idx
|
||
acc = self._eval(body, fn_env)
|
||
return acc
|
||
|
||
# Fallback
|
||
fn_eval = self._eval(fn, env)
|
||
if callable(fn_eval):
|
||
for idx, item in enumerate(seq):
|
||
acc = fn_eval(acc, item, idx)
|
||
return acc
|
||
|
||
# =====================================================================
|
||
# Map-pixels (apply function to each pixel)
|
||
# =====================================================================
|
||
if op == 'map-pixels':
|
||
frame = self._eval(args[0], env)
|
||
fn = args[1] # Lambda or S-expression
|
||
h, w = frame.shape[:2]
|
||
|
||
# Extract channels
|
||
r = frame[:, :, 0].flatten().astype(jnp.float32)
|
||
g = frame[:, :, 1].flatten().astype(jnp.float32)
|
||
b = frame[:, :, 2].flatten().astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
|
||
y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
|
||
|
||
# Set up pixel environment
|
||
pixel_env = env.copy()
|
||
pixel_env['r'] = r
|
||
pixel_env['g'] = g
|
||
pixel_env['b'] = b
|
||
pixel_env['x'] = x_coords
|
||
pixel_env['y'] = y_coords
|
||
# Also provide c (color) as a tuple for lambda (x y c) style
|
||
pixel_env['c'] = (r, g, b)
|
||
|
||
# If fn is a lambda, we need to handle it specially
|
||
if isinstance(fn, list) and len(fn) >= 2:
|
||
head = fn[0]
|
||
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
|
||
# Lambda: (lambda (x y c) body)
|
||
params = fn[1]
|
||
body = fn[2]
|
||
# Bind parameters
|
||
if len(params) >= 1:
|
||
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
|
||
pixel_env[param_name] = x_coords
|
||
if len(params) >= 2:
|
||
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
|
||
pixel_env[param_name] = y_coords
|
||
if len(params) >= 3:
|
||
param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2])
|
||
pixel_env[param_name] = (r, g, b)
|
||
result = self._eval(body, pixel_env)
|
||
else:
|
||
result = self._eval(fn, pixel_env)
|
||
else:
|
||
result = self._eval(fn, pixel_env)
|
||
|
||
if isinstance(result, tuple) and len(result) == 3:
|
||
nr, ng, nb = result
|
||
return jax_merge_channels(nr, ng, nb, (h, w))
|
||
elif hasattr(result, 'shape') and result.ndim == 3:
|
||
return result
|
||
else:
|
||
# Single channel result
|
||
if hasattr(result, 'flatten'):
|
||
result = result.flatten()
|
||
gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
return jnp.stack([gray, gray, gray], axis=2)
|
||
|
||
# =====================================================================
|
||
# State operations (return unchanged for stateless JIT)
|
||
# =====================================================================
|
||
if op == 'state-get':
|
||
key = self._eval(args[0], env)
|
||
default = self._eval(args[1], env) if len(args) > 1 else None
|
||
return default # State not supported in JIT, return default
|
||
|
||
if op == 'state-set':
|
||
return None # No-op in JIT
|
||
|
||
# =====================================================================
|
||
# Cell/grid operations
|
||
# =====================================================================
|
||
if op == 'local-x-norm':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
x = jnp.tile(jnp.arange(w), h)
|
||
return (x % cell_size) / max(1, cell_size - 1)
|
||
|
||
if op == 'local-y-norm':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
y = jnp.repeat(jnp.arange(h), w)
|
||
return (y % cell_size) / max(1, cell_size - 1)
|
||
|
||
if op == 'local-x':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
x = jnp.tile(jnp.arange(w), h)
|
||
return (x % cell_size).astype(jnp.float32)
|
||
|
||
if op == 'local-y':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
y = jnp.repeat(jnp.arange(h), w)
|
||
return (y % cell_size).astype(jnp.float32)
|
||
|
||
if op == 'cell-row':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
y = jnp.repeat(jnp.arange(h), w)
|
||
return jnp.floor(y / cell_size)
|
||
|
||
if op == 'cell-col':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
h, w = frame.shape[:2]
|
||
x = jnp.tile(jnp.arange(w), h)
|
||
return jnp.floor(x / cell_size)
|
||
|
||
if op == 'num-rows':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
return frame.shape[0] // cell_size
|
||
|
||
if op == 'num-cols':
|
||
frame = self._eval(args[0], env)
|
||
cell_size = int(self._eval(args[1], env))
|
||
return frame.shape[1] // cell_size
|
||
|
||
# =====================================================================
|
||
# Control flow
|
||
# =====================================================================
|
||
if op == 'cond':
|
||
# (cond (test1 expr1) (test2 expr2) ... (else exprN))
|
||
# For JAX compatibility, build a nested jnp.where structure
|
||
# Start from the else clause and work backwards
|
||
|
||
# Collect clauses
|
||
clauses = []
|
||
else_expr = None
|
||
for clause in args:
|
||
if isinstance(clause, list) and len(clause) >= 2:
|
||
test = clause[0]
|
||
if isinstance(test, Symbol) and test.name == 'else':
|
||
else_expr = clause[1]
|
||
else:
|
||
clauses.append((test, clause[1]))
|
||
|
||
# If no else, default to None/0
|
||
if else_expr is not None:
|
||
result = self._eval(else_expr, env)
|
||
else:
|
||
result = 0
|
||
|
||
# Build nested where from last to first
|
||
for test_expr, val_expr in reversed(clauses):
|
||
cond_val = self._eval(test_expr, env)
|
||
then_val = self._eval(val_expr, env)
|
||
|
||
# Check if condition is array or scalar
|
||
if hasattr(cond_val, 'shape') and cond_val.shape != ():
|
||
# Array condition - use jnp.where
|
||
result = jnp.where(cond_val, then_val, result)
|
||
else:
|
||
# Scalar - can use Python if
|
||
if cond_val:
|
||
result = then_val
|
||
|
||
return result
|
||
|
||
if op == 'set!' or op == 'set':
|
||
# Mutation - not really supported in JAX, but we can update env
|
||
var = args[0].name if isinstance(args[0], Symbol) else str(args[0])
|
||
val = self._eval(args[1], env)
|
||
env[var] = val
|
||
return val
|
||
|
||
if op == 'begin' or op == 'do':
|
||
# Evaluate all expressions, return last
|
||
result = None
|
||
for expr in args:
|
||
result = self._eval(expr, env)
|
||
return result
|
||
|
||
# =====================================================================
|
||
# Additional math
|
||
# =====================================================================
|
||
if op == 'sq' or op == 'square':
|
||
x = self._eval(args[0], env)
|
||
return x * x
|
||
|
||
if op == 'lerp':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
t = self._eval(args[2], env)
|
||
return a * (1 - t) + b * t
|
||
|
||
if op == 'smoothstep':
|
||
edge0 = self._eval(args[0], env)
|
||
edge1 = self._eval(args[1], env)
|
||
x = self._eval(args[2], env)
|
||
t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1)
|
||
return t * t * (3 - 2 * t)
|
||
|
||
if op == 'atan2':
|
||
y = self._eval(args[0], env)
|
||
x = self._eval(args[1], env)
|
||
return jnp.arctan2(y, x)
|
||
|
||
if op == 'fract' or op == 'frac':
|
||
x = self._eval(args[0], env)
|
||
return x - jnp.floor(x)
|
||
|
||
# =====================================================================
|
||
# Frame copy and construction operations
|
||
# =====================================================================
|
||
if op == 'pixel':
|
||
# Get pixel at (x, y) from frame
|
||
frame = self._eval(args[0], env)
|
||
x = self._eval(args[1], env)
|
||
y = self._eval(args[2], env)
|
||
h, w = frame.shape[:2]
|
||
# Convert to int and clip to bounds
|
||
if isinstance(x, (int, float)):
|
||
x = max(0, min(int(x), w - 1))
|
||
else:
|
||
x = jnp.clip(x, 0, w - 1).astype(jnp.int32)
|
||
if isinstance(y, (int, float)):
|
||
y = max(0, min(int(y), h - 1))
|
||
else:
|
||
y = jnp.clip(y, 0, h - 1).astype(jnp.int32)
|
||
r = frame[y, x, 0]
|
||
g = frame[y, x, 1]
|
||
b = frame[y, x, 2]
|
||
return (r, g, b)
|
||
|
||
if op == 'copy':
|
||
frame = self._eval(args[0], env)
|
||
return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame)
|
||
|
||
if op == 'make-image':
|
||
w = int(self._eval(args[0], env))
|
||
h = int(self._eval(args[1], env))
|
||
if len(args) > 2:
|
||
color = self._eval(args[2], env)
|
||
if isinstance(color, (list, tuple)):
|
||
r, g, b = int(color[0]), int(color[1]), int(color[2])
|
||
else:
|
||
r = g = b = int(color)
|
||
else:
|
||
r = g = b = 0
|
||
img = jnp.zeros((h, w, 3), dtype=jnp.uint8)
|
||
img = img.at[:, :, 0].set(r)
|
||
img = img.at[:, :, 1].set(g)
|
||
img = img.at[:, :, 2].set(b)
|
||
return img
|
||
|
||
if op == 'paste':
|
||
dest = self._eval(args[0], env)
|
||
src = self._eval(args[1], env)
|
||
x = int(self._eval(args[2], env))
|
||
y = int(self._eval(args[3], env))
|
||
sh, sw = src.shape[:2]
|
||
dh, dw = dest.shape[:2]
|
||
# Clip to dest bounds
|
||
x1, y1 = max(0, x), max(0, y)
|
||
x2, y2 = min(dw, x + sw), min(dh, y + sh)
|
||
sx1, sy1 = x1 - x, y1 - y
|
||
sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1)
|
||
result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest)
|
||
result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :])
|
||
return result
|
||
|
||
# =====================================================================
|
||
# Blending operations
|
||
# =====================================================================
|
||
if op == 'blending:blend-images' or op == 'blend-images':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
|
||
return jax_blend(a, b, alpha)
|
||
|
||
if op == 'blending:blend-mode' or op == 'blend-mode':
|
||
a = self._eval(args[0], env)
|
||
b = self._eval(args[1], env)
|
||
mode = self._eval(args[2], env) if len(args) > 2 else 'add'
|
||
if mode == 'add':
|
||
return jax_blend_add(a, b)
|
||
elif mode == 'multiply':
|
||
return jax_blend_multiply(a, b)
|
||
elif mode == 'screen':
|
||
return jax_blend_screen(a, b)
|
||
elif mode == 'overlay':
|
||
return jax_blend_overlay(a, b)
|
||
elif mode == 'lighten':
|
||
return jnp.maximum(a, b)
|
||
elif mode == 'darken':
|
||
return jnp.minimum(a, b)
|
||
elif mode == 'difference':
|
||
return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8)
|
||
else:
|
||
return jax_blend(a, b, 0.5)
|
||
|
||
# =====================================================================
|
||
# Geometry coordinate operations
|
||
# =====================================================================
|
||
if op == 'geometry:wave-coords' or op == 'wave-coords':
|
||
w = int(self._eval(args[0], env))
|
||
h = int(self._eval(args[1], env))
|
||
axis = self._eval(args[2], env) if len(args) > 2 else 'x'
|
||
freq = self._eval(args[3], env) if len(args) > 3 else 1.0
|
||
amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0
|
||
phase = self._eval(args[5], env) if len(args) > 5 else 0.0
|
||
|
||
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
|
||
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
|
||
|
||
if axis == 'x' or axis == 'horizontal':
|
||
# Wave displaces X based on Y
|
||
offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
|
||
src_x = x_coords + offset
|
||
src_y = y_coords
|
||
elif axis == 'y' or axis == 'vertical':
|
||
# Wave displaces Y based on X
|
||
offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
|
||
src_x = x_coords
|
||
src_y = y_coords + offset
|
||
else: # both
|
||
offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
|
||
offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
|
||
src_x = x_coords + offset_x
|
||
src_y = y_coords + offset_y
|
||
|
||
return {'x': src_x, 'y': src_y}
|
||
|
||
if op == 'geometry:coords-x' or op == 'coords-x':
|
||
coords = self._eval(args[0], env)
|
||
return coords['x']
|
||
|
||
if op == 'geometry:coords-y' or op == 'coords-y':
|
||
coords = self._eval(args[0], env)
|
||
return coords['y']
|
||
|
||
if op == 'geometry:remap' or op == 'remap':
|
||
frame = self._eval(args[0], env)
|
||
x = self._eval(args[1], env)
|
||
y = self._eval(args[2], env)
|
||
h, w = frame.shape[:2]
|
||
r, g, b = jax_sample(frame, x.flatten(), y.flatten())
|
||
return jnp.stack([
|
||
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
|
||
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
|
||
], axis=2)
|
||
|
||
# =====================================================================
|
||
# Glitch effects
|
||
# =====================================================================
|
||
if op == 'pixelsort':
|
||
frame = self._eval(args[0], env)
|
||
sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness'
|
||
thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50
|
||
thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200
|
||
angle = int(self._eval(args[4], env)) if len(args) > 4 else 0
|
||
reverse = self._eval(args[5], env) if len(args) > 5 else False
|
||
|
||
h, w = frame.shape[:2]
|
||
result = frame.copy()
|
||
|
||
# Get luminance for thresholding
|
||
lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
|
||
frame[:, :, 1].astype(jnp.float32) * 0.587 +
|
||
frame[:, :, 2].astype(jnp.float32) * 0.114)
|
||
|
||
# Sort each row
|
||
for y in range(h):
|
||
row_lum = lum[y, :]
|
||
row = frame[y, :, :]
|
||
|
||
# Find mask of pixels to sort
|
||
mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi)
|
||
|
||
# Get indices where we should sort
|
||
sort_indices = jnp.where(mask, jnp.arange(w), -1)
|
||
|
||
# Simple sort by luminance for the row
|
||
if sort_by == 'lightness':
|
||
sort_key = row_lum
|
||
elif sort_by == 'hue':
|
||
# Approximate hue from RGB
|
||
sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32),
|
||
row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32)))
|
||
else:
|
||
sort_key = row_lum
|
||
|
||
# Sort pixels in masked region
|
||
sorted_indices = jnp.argsort(sort_key)
|
||
if reverse:
|
||
sorted_indices = sorted_indices[::-1]
|
||
|
||
# Apply partial sort (only where mask is true)
|
||
# This is a simplified version - full pixelsort is more complex
|
||
result = result.at[y, :, :].set(row[sorted_indices])
|
||
|
||
return result
|
||
|
||
if op == 'datamosh':
|
||
frame = self._eval(args[0], env)
|
||
prev = self._eval(args[1], env)
|
||
block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32
|
||
corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3
|
||
max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50
|
||
color_corrupt = self._eval(args[5], env) if len(args) > 5 else True
|
||
|
||
h, w = frame.shape[:2]
|
||
|
||
# Use deterministic random for JIT with frame variation
|
||
seed = env.get('_seed', 42)
|
||
op_counter = env.get('_rand_op_counter', 0)
|
||
env['_rand_op_counter'] = op_counter + 1
|
||
key = make_jax_key(seed, frame_num, op_counter)
|
||
|
||
num_blocks_y = h // block_size
|
||
num_blocks_x = w // block_size
|
||
total_blocks = num_blocks_y * num_blocks_x
|
||
|
||
# Pre-generate all random values at once (vectorized)
|
||
key, k1, k2, k3, k4, k5 = jax.random.split(key, 6)
|
||
corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption
|
||
offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1)
|
||
offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1)
|
||
channels = jax.random.randint(k4, (total_blocks,), 0, 3)
|
||
color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51)
|
||
|
||
# Create coordinate grids for blocks
|
||
by_grid = jnp.arange(num_blocks_y)
|
||
bx_grid = jnp.arange(num_blocks_x)
|
||
|
||
# Create block coordinate arrays
|
||
by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...]
|
||
bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...]
|
||
|
||
# Create pixel coordinate grids
|
||
y_coords, x_coords = jnp.mgrid[:h, :w]
|
||
|
||
# Determine which block each pixel belongs to
|
||
pixel_block_y = y_coords // block_size
|
||
pixel_block_x = x_coords // block_size
|
||
pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x
|
||
|
||
# Clamp to valid block indices (for pixels outside the block grid)
|
||
pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1)
|
||
|
||
# Get the corrupt mask, offsets for each pixel's block
|
||
pixel_corrupt = corrupt_mask[pixel_block_idx]
|
||
pixel_offset_y = offsets_y[pixel_block_idx]
|
||
pixel_offset_x = offsets_x[pixel_block_idx]
|
||
|
||
# Calculate source coordinates with offset (clamped)
|
||
src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1)
|
||
src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1)
|
||
|
||
# Sample from previous frame at offset positions
|
||
prev_sampled = prev[src_y, src_x, :]
|
||
|
||
# Where corrupt mask is true, use prev_sampled; else use frame
|
||
result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame)
|
||
|
||
# Apply color corruption to corrupted blocks
|
||
if color_corrupt:
|
||
pixel_channel = channels[pixel_block_idx]
|
||
pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16)
|
||
|
||
# Create per-channel shift arrays (only shift the selected channel)
|
||
shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0)
|
||
shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0)
|
||
shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0)
|
||
|
||
result_int = result.astype(jnp.int16)
|
||
result_int = result_int.at[:, :, 0].add(shift_r)
|
||
result_int = result_int.at[:, :, 1].add(shift_g)
|
||
result_int = result_int.at[:, :, 2].add(shift_b)
|
||
result = jnp.clip(result_int, 0, 255).astype(jnp.uint8)
|
||
|
||
return result
|
||
|
||
# =====================================================================
|
||
# ASCII Art Operations (using pre-rendered font atlas)
|
||
# =====================================================================
|
||
|
||
if op == 'cell-sample':
|
||
# (cell-sample frame char_size) -> (colors, luminances)
|
||
# Downsample frame into cells, return average colors and luminances
|
||
frame = self._eval(args[0], env)
|
||
char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8
|
||
|
||
h, w = frame.shape[:2]
|
||
num_rows = h // char_size
|
||
num_cols = w // char_size
|
||
|
||
# Crop to exact multiple of char_size
|
||
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
|
||
|
||
# Reshape to (num_rows, char_size, num_cols, char_size, 3)
|
||
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
|
||
|
||
# Average over char_size dimensions -> (num_rows, num_cols, 3)
|
||
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
|
||
|
||
# Compute luminance per cell
|
||
colors_float = colors.astype(jnp.float32)
|
||
luminances = (0.299 * colors_float[:, :, 0] +
|
||
0.587 * colors_float[:, :, 1] +
|
||
0.114 * colors_float[:, :, 2]) / 255.0
|
||
|
||
return (colors, luminances)
|
||
|
||
if op == 'luminance-to-chars':
|
||
# (luminance-to-chars luminances alphabet contrast) -> char_indices
|
||
# Map luminance values to character indices
|
||
luminances = self._eval(args[0], env)
|
||
alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard'
|
||
contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5
|
||
|
||
# Get alphabet string
|
||
alpha_str = _get_alphabet_string(alphabet)
|
||
num_chars = len(alpha_str)
|
||
|
||
# Apply contrast
|
||
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
|
||
|
||
# Map to character indices (0 = darkest, num_chars-1 = brightest)
|
||
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
|
||
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
|
||
|
||
return char_indices
|
||
|
||
if op == 'render-char-grid':
|
||
# (render-char-grid frame chars colors char_size color_mode background_color invert_colors)
|
||
# Render character grid using font atlas
|
||
frame = self._eval(args[0], env)
|
||
char_indices = self._eval(args[1], env)
|
||
colors = self._eval(args[2], env)
|
||
char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8
|
||
color_mode = self._eval(args[4], env) if len(args) > 4 else 'color'
|
||
background_color = self._eval(args[5], env) if len(args) > 5 else 'black'
|
||
invert_colors = self._eval(args[6], env) if len(args) > 6 else 0
|
||
|
||
h, w = frame.shape[:2]
|
||
num_rows, num_cols = char_indices.shape
|
||
|
||
# Get the alphabet used (stored in env or default)
|
||
alphabet = env.get('_ascii_alphabet', 'standard')
|
||
alpha_str = _get_alphabet_string(alphabet)
|
||
|
||
# Get or create font atlas
|
||
font_atlas = _create_font_atlas(alpha_str, char_size)
|
||
|
||
# Parse background color
|
||
if background_color == 'black':
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
elif background_color == 'white':
|
||
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
||
else:
|
||
# Try to parse hex color
|
||
try:
|
||
if background_color.startswith('#'):
|
||
bg_hex = background_color[1:]
|
||
bg = jnp.array([int(bg_hex[0:2], 16),
|
||
int(bg_hex[2:4], 16),
|
||
int(bg_hex[4:6], 16)], dtype=jnp.uint8)
|
||
else:
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
except:
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
|
||
# Create output image starting with background
|
||
output_h = num_rows * char_size
|
||
output_w = num_cols * char_size
|
||
result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy()
|
||
|
||
# Gather characters from atlas based on indices
|
||
# char_indices shape: (num_rows, num_cols)
|
||
# font_atlas shape: (num_chars, char_size, char_size, 3)
|
||
# Convert numpy atlas to JAX for indexing with traced indices
|
||
font_atlas_jax = jnp.asarray(font_atlas)
|
||
flat_indices = char_indices.flatten()
|
||
char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3)
|
||
|
||
# Reshape to grid
|
||
char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3)
|
||
|
||
# Create coordinate grids for output pixels
|
||
y_out, x_out = jnp.mgrid[:output_h, :output_w]
|
||
cell_row = y_out // char_size
|
||
cell_col = x_out // char_size
|
||
local_y = y_out % char_size
|
||
local_x = x_out % char_size
|
||
|
||
# Clamp to valid ranges
|
||
cell_row = jnp.clip(cell_row, 0, num_rows - 1)
|
||
cell_col = jnp.clip(cell_col, 0, num_cols - 1)
|
||
|
||
# Get character pixel values
|
||
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
|
||
|
||
# Get character brightness (for masking)
|
||
char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0
|
||
|
||
# Handle color modes
|
||
if color_mode == 'mono':
|
||
# White characters on background
|
||
fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
||
fg = jnp.broadcast_to(fg_color, char_pixels.shape)
|
||
elif color_mode == 'invert':
|
||
# Inverted cell colors
|
||
cell_colors = colors[cell_row, cell_col]
|
||
fg = 255 - cell_colors
|
||
else:
|
||
# 'color' mode - use cell colors
|
||
fg = colors[cell_row, cell_col]
|
||
|
||
# Blend foreground onto background based on character brightness
|
||
if invert_colors:
|
||
# Swap fg and bg
|
||
fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg
|
||
result = (fg.astype(jnp.float32) * (1 - char_brightness) +
|
||
bg_broadcast.astype(jnp.float32) * char_brightness)
|
||
else:
|
||
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
|
||
result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) +
|
||
fg.astype(jnp.float32) * char_brightness)
|
||
|
||
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||
|
||
# Resize back to original frame size if needed
|
||
if result.shape[0] != h or result.shape[1] != w:
|
||
# Simple nearest-neighbor resize
|
||
y_scale = result.shape[0] / h
|
||
x_scale = result.shape[1] / w
|
||
y_src = (jnp.arange(h) * y_scale).astype(jnp.int32)
|
||
x_src = (jnp.arange(w) * x_scale).astype(jnp.int32)
|
||
y_src = jnp.clip(y_src, 0, result.shape[0] - 1)
|
||
x_src = jnp.clip(x_src, 0, result.shape[1] - 1)
|
||
result = result[y_src[:, None], x_src[None, :], :]
|
||
|
||
return result
|
||
|
||
if op == 'ascii-fx-zone':
|
||
# Complex ASCII effect with per-zone expressions
|
||
# (ascii-fx-zone frame :cols cols :alphabet alphabet ...)
|
||
frame = self._eval(args[0], env)
|
||
|
||
# Parse keyword arguments
|
||
kwargs = {}
|
||
i = 1
|
||
while i < len(args):
|
||
if isinstance(args[i], Keyword):
|
||
key = args[i].name
|
||
if i + 1 < len(args):
|
||
kwargs[key] = args[i + 1]
|
||
i += 2
|
||
else:
|
||
i += 1
|
||
|
||
# Get parameters
|
||
cols = int(self._eval(kwargs.get('cols', 80), env))
|
||
char_size_param = kwargs.get('char_size')
|
||
alphabet = self._eval(kwargs.get('alphabet', 'standard'), env)
|
||
color_mode = self._eval(kwargs.get('color_mode', 'color'), env)
|
||
background = self._eval(kwargs.get('background', 'black'), env)
|
||
contrast = float(self._eval(kwargs.get('contrast', 1.5), env))
|
||
|
||
h, w = frame.shape[:2]
|
||
|
||
# Calculate char_size from cols if not specified
|
||
if char_size_param is not None:
|
||
char_size_val = self._eval(char_size_param, env)
|
||
if char_size_val is not None:
|
||
char_size = int(char_size_val)
|
||
else:
|
||
char_size = w // cols
|
||
else:
|
||
char_size = w // cols
|
||
char_size = max(4, min(char_size, 64))
|
||
|
||
# Store alphabet for render-char-grid to use
|
||
env['_ascii_alphabet'] = alphabet
|
||
|
||
# Cell sampling
|
||
num_rows = h // char_size
|
||
num_cols = w // char_size
|
||
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
|
||
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
|
||
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
|
||
|
||
# Compute luminances
|
||
colors_float = colors.astype(jnp.float32)
|
||
luminances = (0.299 * colors_float[:, :, 0] +
|
||
0.587 * colors_float[:, :, 1] +
|
||
0.114 * colors_float[:, :, 2]) / 255.0
|
||
|
||
# Get alphabet and map luminances to chars
|
||
alpha_str = _get_alphabet_string(alphabet)
|
||
num_chars = len(alpha_str)
|
||
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
|
||
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
|
||
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
|
||
|
||
# Handle optional per-zone effects (char_hue, char_saturation, etc.)
|
||
# These would modify colors based on zone position
|
||
char_hue = kwargs.get('char_hue')
|
||
char_saturation = kwargs.get('char_saturation')
|
||
char_brightness = kwargs.get('char_brightness')
|
||
|
||
if char_hue is not None or char_saturation is not None or char_brightness is not None:
|
||
# Create zone coordinate arrays for expression evaluation
|
||
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
|
||
row_norm = row_coords / max(num_rows - 1, 1)
|
||
col_norm = col_coords / max(num_cols - 1, 1)
|
||
|
||
# Bind zone variables
|
||
zone_env = env.copy()
|
||
zone_env['zone-row'] = row_coords
|
||
zone_env['zone-col'] = col_coords
|
||
zone_env['zone-row-norm'] = row_norm
|
||
zone_env['zone-col-norm'] = col_norm
|
||
zone_env['zone-lum'] = luminances
|
||
|
||
# Apply color modifications (simplified - full version would use HSV)
|
||
if char_brightness is not None:
|
||
brightness_mult = self._eval(char_brightness, zone_env)
|
||
if brightness_mult is not None:
|
||
colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None],
|
||
0, 255).astype(jnp.uint8)
|
||
|
||
# Render using font atlas
|
||
font_atlas = _create_font_atlas(alpha_str, char_size)
|
||
|
||
# Parse background color
|
||
if background == 'black':
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
elif background == 'white':
|
||
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
||
else:
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
|
||
# Gather characters - convert numpy atlas to JAX for traced indexing
|
||
font_atlas_jax = jnp.asarray(font_atlas)
|
||
flat_indices = char_indices.flatten()
|
||
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
|
||
|
||
# Create output
|
||
output_h = num_rows * char_size
|
||
output_w = num_cols * char_size
|
||
|
||
y_out, x_out = jnp.mgrid[:output_h, :output_w]
|
||
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
|
||
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
|
||
local_y = y_out % char_size
|
||
local_x = x_out % char_size
|
||
|
||
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
|
||
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
|
||
|
||
if color_mode == 'mono':
|
||
fg = jnp.full_like(char_pixels, 255)
|
||
else:
|
||
fg = colors[cell_row, cell_col]
|
||
|
||
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
|
||
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
|
||
fg.astype(jnp.float32) * char_bright)
|
||
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||
|
||
# Resize to original dimensions
|
||
if result.shape[0] != h or result.shape[1] != w:
|
||
y_scale = result.shape[0] / h
|
||
x_scale = result.shape[1] / w
|
||
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
|
||
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
|
||
result = result[y_src[:, None], x_src[None, :], :]
|
||
|
||
return result
|
||
|
||
if op == 'render-char-grid-fx':
|
||
# Enhanced render with per-character effects
|
||
# (render-char-grid-fx frame chars colors luminances char_size
|
||
# color_mode bg_color invert_colors
|
||
# char_jitter char_scale char_rotation char_hue_shift
|
||
# jitter_source scale_source rotation_source hue_source)
|
||
frame = self._eval(args[0], env)
|
||
char_indices = self._eval(args[1], env)
|
||
colors = self._eval(args[2], env)
|
||
luminances = self._eval(args[3], env)
|
||
char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8
|
||
color_mode = self._eval(args[5], env) if len(args) > 5 else 'color'
|
||
background_color = self._eval(args[6], env) if len(args) > 6 else 'black'
|
||
invert_colors = self._eval(args[7], env) if len(args) > 7 else 0
|
||
|
||
# Per-char effect amounts
|
||
char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0
|
||
char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0
|
||
char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0
|
||
char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0
|
||
|
||
# Modulation sources
|
||
jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none'
|
||
scale_source = self._eval(args[13], env) if len(args) > 13 else 'none'
|
||
rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none'
|
||
hue_source = self._eval(args[15], env) if len(args) > 15 else 'none'
|
||
|
||
h, w = frame.shape[:2]
|
||
num_rows, num_cols = char_indices.shape
|
||
|
||
# Get alphabet
|
||
alphabet = env.get('_ascii_alphabet', 'standard')
|
||
alpha_str = _get_alphabet_string(alphabet)
|
||
font_atlas = _create_font_atlas(alpha_str, char_size)
|
||
|
||
# Parse background
|
||
if background_color == 'black':
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
elif background_color == 'white':
|
||
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
||
else:
|
||
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
|
||
|
||
# Create modulation values based on source
|
||
def get_modulation(source, lums, num_rows, num_cols):
    """Return a per-cell modulation map for a character grid.

    The result has one value per grid cell (``num_rows`` x ``num_cols``),
    chosen by *source*:

    - ``'luminance'``     -> ``lums`` unchanged
    - ``'inv_luminance'`` -> ``1.0 - lums``
    - ``'position_x'``    -> column index scaled to [0, 1]
    - ``'position_y'``    -> row index scaled to [0, 1]
    - ``'position_diag'`` -> mean of the scaled row and column
    - ``'center_dist'``   -> distance from the grid centre (0.5, 0.5) in
      scaled coordinates, divided by 0.707 and clipped to [0, 1]
    - ``'random'``        -> uniform samples from a key built from the
      enclosing ``env['_seed']`` (default 42), ``frame_num`` and
      ``env['_rand_op_counter']``; bumps that counter as a side effect
    - anything else       -> an all-zero grid
    """
    # Scaled cell coordinates; max(..., 1) keeps single-row / single-column
    # grids from dividing by zero.
    grid_rows, grid_cols = jnp.mgrid[:num_rows, :num_cols]
    y_norm = grid_rows / max(num_rows - 1, 1)
    x_norm = grid_cols / max(num_cols - 1, 1)

    if source == 'luminance':
        return lums
    if source == 'inv_luminance':
        return 1.0 - lums
    if source == 'position_x':
        return x_norm
    if source == 'position_y':
        return y_norm
    if source == 'position_diag':
        return (y_norm + x_norm) / 2
    if source == 'center_dist':
        centre_y, centre_x = 0.5, 0.5
        radius = jnp.sqrt((y_norm - centre_y)**2 + (x_norm - centre_x)**2)
        # 0.707 ~ sqrt(0.5), the corner-to-centre distance in scaled coords.
        return jnp.clip(radius / 0.707, 0, 1)
    if source == 'random':
        # Consume one slot of the shared per-frame random-op counter so each
        # random source gets a distinct, reproducible key.
        base_seed = env.get('_seed', 42)
        counter = env.get('_rand_op_counter', 0)
        env['_rand_op_counter'] = counter + 1
        rng_key = make_jax_key(base_seed, frame_num, counter)
        return jax.random.uniform(rng_key, (num_rows, num_cols))
    return jnp.zeros((num_rows, num_cols))
|
||
|
||
# Get modulation values
|
||
jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols)
|
||
scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols)
|
||
rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols)
|
||
hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols)
|
||
|
||
# Gather characters - convert numpy atlas to JAX for traced indexing
|
||
font_atlas_jax = jnp.asarray(font_atlas)
|
||
flat_indices = char_indices.flatten()
|
||
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
|
||
|
||
# Create output
|
||
output_h = num_rows * char_size
|
||
output_w = num_cols * char_size
|
||
|
||
y_out, x_out = jnp.mgrid[:output_h, :output_w]
|
||
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
|
||
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
|
||
local_y = y_out % char_size
|
||
local_x = x_out % char_size
|
||
|
||
# Apply jitter if enabled
|
||
if char_jitter > 0:
|
||
jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter
|
||
# Use frame-varying key for jitter
|
||
seed = env.get('_seed', 42)
|
||
op_ctr = env.get('_rand_op_counter', 0)
|
||
env['_rand_op_counter'] = op_ctr + 1
|
||
key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2)
|
||
# Generate deterministic jitter per cell
|
||
jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1)
|
||
jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1)
|
||
offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
|
||
offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
|
||
local_y = jnp.clip(local_y + offset_y, 0, char_size - 1)
|
||
local_x = jnp.clip(local_x + offset_x, 0, char_size - 1)
|
||
|
||
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
|
||
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
|
||
|
||
# Get foreground colors
|
||
if color_mode == 'mono':
|
||
fg = jnp.full_like(char_pixels, 255)
|
||
else:
|
||
fg = colors[cell_row, cell_col]
|
||
|
||
# Apply hue shift if enabled
|
||
if char_hue_shift > 0 and color_mode == 'color':
|
||
hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift
|
||
# Simple hue rotation via channel cycling
|
||
fg_float = fg.astype(jnp.float32)
|
||
shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB
|
||
# Simplified: blend channels based on shift
|
||
r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2]
|
||
shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac
|
||
# Just do a simple tint for now
|
||
fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8)
|
||
|
||
# Blend
|
||
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
|
||
if invert_colors:
|
||
result = (fg.astype(jnp.float32) * (1 - char_bright) +
|
||
bg_broadcast.astype(jnp.float32) * char_bright)
|
||
else:
|
||
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
|
||
fg.astype(jnp.float32) * char_bright)
|
||
|
||
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||
|
||
# Resize to original
|
||
if result.shape[0] != h or result.shape[1] != w:
|
||
y_scale = result.shape[0] / h
|
||
x_scale = result.shape[1] / w
|
||
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
|
||
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
|
||
result = result[y_src[:, None], x_src[None, :], :]
|
||
|
||
return result
|
||
|
||
if op == 'alphabet-char':
|
||
# (alphabet-char alphabet-name index) -> char_index in that alphabet
|
||
alphabet = self._eval(args[0], env)
|
||
index = self._eval(args[1], env)
|
||
|
||
alpha_str = _get_alphabet_string(alphabet)
|
||
num_chars = len(alpha_str)
|
||
|
||
# Handle both scalar and array indices
|
||
if hasattr(index, 'shape'):
|
||
index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1)
|
||
else:
|
||
index = max(0, min(int(index), num_chars - 1))
|
||
|
||
return index
|
||
|
||
if op == 'map-char-grid':
|
||
# (map-char-grid base-chars luminances (lambda (r c ch lum) ...))
|
||
# Map over character grid, allowing per-cell character selection
|
||
base_chars = self._eval(args[0], env)
|
||
luminances = self._eval(args[1], env)
|
||
fn = args[2] # Lambda expression
|
||
|
||
num_rows, num_cols = base_chars.shape
|
||
|
||
# For JAX compatibility, we can't use Python loops with traced values
|
||
# Instead, we'll evaluate the lambda for the whole grid at once
|
||
if isinstance(fn, list) and len(fn) >= 3:
|
||
head = fn[0]
|
||
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
|
||
params = fn[1]
|
||
body = fn[2]
|
||
|
||
# Create grid coordinates
|
||
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
|
||
|
||
# Bind parameters for whole-grid evaluation
|
||
fn_env = env.copy()
|
||
|
||
# Params: (r c ch lum)
|
||
if len(params) >= 1:
|
||
fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords
|
||
if len(params) >= 2:
|
||
fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords
|
||
if len(params) >= 3:
|
||
fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars
|
||
if len(params) >= 4:
|
||
# Luminances scaled to 0-255 range
|
||
fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32)
|
||
|
||
# Evaluate body - should return new character indices
|
||
result = self._eval(body, fn_env)
|
||
if hasattr(result, 'shape'):
|
||
return result.astype(jnp.int32)
|
||
return base_chars
|
||
|
||
return base_chars
|
||
|
||
# =====================================================================
|
||
# List operations
|
||
# =====================================================================
|
||
if op == 'take':
|
||
seq = self._eval(args[0], env)
|
||
n = int(self._eval(args[1], env))
|
||
if isinstance(seq, (list, tuple)):
|
||
return seq[:n]
|
||
return seq[:n] # Works for arrays too
|
||
|
||
if op == 'cons':
|
||
item = self._eval(args[0], env)
|
||
seq = self._eval(args[1], env)
|
||
if isinstance(seq, list):
|
||
return [item] + seq
|
||
elif isinstance(seq, tuple):
|
||
return (item,) + seq
|
||
return jnp.concatenate([jnp.array([item]), seq])
|
||
|
||
if op == 'roll':
|
||
arr = self._eval(args[0], env)
|
||
shift = self._eval(args[1], env)
|
||
axis = self._eval(args[2], env) if len(args) > 2 else 0
|
||
# Convert to int for concrete values, keep as-is for JAX traced values
|
||
if isinstance(shift, (int, float)):
|
||
shift = int(shift)
|
||
elif hasattr(shift, 'astype'):
|
||
shift = shift.astype(jnp.int32)
|
||
if isinstance(axis, (int, float)):
|
||
axis = int(axis)
|
||
return jnp.roll(arr, shift, axis=axis)
|
||
|
||
# =====================================================================
|
||
# Pi constant
|
||
# =====================================================================
|
||
if op == 'pi':
|
||
return jnp.pi
|
||
|
||
raise ValueError(f"Unknown operation: {op}")
|
||
|
||
|
||
# =============================================================================
|
||
# Public API
|
||
# =============================================================================
|
||
|
||
def compile_effect(code: str) -> Callable:
    """
    Compile an S-expression effect source string to a JAX function.

    Compiled callables are memoised in the module-level ``_COMPILED_EFFECTS``
    mapping, keyed by the MD5 hex digest of ``code``. Compiling the same
    source text again therefore returns the previously built callable
    without re-parsing.

    Args:
        code: S-expression effect code

    Returns:
        JIT-compiled function: (frame, **params) -> frame
    """
    # MD5 is used only as a compact cache key here, not for security.
    digest = hashlib.md5(code.encode()).hexdigest()

    if digest not in _COMPILED_EFFECTS:
        effect_ast = parse(code)
        _COMPILED_EFFECTS[digest] = JaxCompiler().compile_effect(effect_ast)

    return _COMPILED_EFFECTS[digest]
|
||
|
||
|
||
def compile_effect_file(path: str, derived_paths: Optional[List[str]] = None) -> Callable:
    """
    Compile an effect from a .sexp file.

    Besides the effect definition itself, the file may contain:

    * ``(require "name")``: ``.sexp`` is appended if missing. The file is
      looked up next to the effect file first, then in the ``sexp_effects``
      directory beside this package. If found, it is loaded with
      ``JaxCompiler.load_derived``. If it is not found, a warning is issued
      and processing continues.
    * ``(require-primitives "lib")``: ignored, because the JAX backend has
      all primitives built in.

    If several ``effect`` / ``define-effect`` forms are present, the last one
    is compiled.

    Args:
        path: Path to the .sexp effect file
        derived_paths: Optional list of paths to derived.sexp files. These
            are loaded before any ``require`` forms in the file are processed.

    Returns:
        JIT-compiled function: (frame, **params) -> frame

    Raises:
        ValueError: if the file contains no ``effect`` / ``define-effect`` form.
    """
    import warnings

    # Read as UTF-8 explicitly. Effect sources may use non-ASCII symbols such
    # as 'λ', which a non-UTF-8 platform default encoding would mis-decode.
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    # Parse all top-level expressions in the file.
    exprs = parse_all(code)

    compiler = JaxCompiler()

    # Explicitly supplied derived files are loaded first.
    for dp in derived_paths or ():
        compiler.load_derived(dp)

    effect_sexp = None
    effect_dir = Path(path).parent
    fallback_dir = Path(__file__).parent.parent / 'sexp_effects'

    for expr in exprs:
        if not isinstance(expr, list) or len(expr) < 2:
            continue

        head = expr[0]
        if not isinstance(head, Symbol):
            continue

        if head.name == 'require':
            # (require "derived") or (require "path/to/file")
            req_path = expr[1]
            if not isinstance(req_path, str):
                continue
            if not req_path.endswith('.sexp'):
                req_path = req_path + '.sexp'

            # Resolve relative to the effect file, then the shared directory.
            full_path = effect_dir / req_path
            if not full_path.exists():
                full_path = fallback_dir / req_path

            if full_path.exists():
                compiler.load_derived(str(full_path))
            else:
                # Stay best-effort, but make the missing dependency visible.
                # Otherwise the effect may fail later with a less obvious error.
                warnings.warn(
                    f"{path}: required file {req_path!r} not found in "
                    f"{effect_dir} or {fallback_dir}; skipping",
                    stacklevel=2,
                )

        elif head.name == 'require-primitives':
            # (require-primitives "lib") is ignored for JAX, which has all
            # primitives built in.
            pass

        elif head.name in ('effect', 'define-effect'):
            effect_sexp = expr

    if effect_sexp is None:
        raise ValueError(f"No effect definition found in {path}")

    return compiler.compile_effect(effect_sexp)
|
||
|
||
|
||
def load_derived(derived_path: Optional[str] = None) -> Dict[str, Callable]:
    """
    Load derived operations from a derived.sexp file.

    Every top-level ``(define ...)`` form that has at least a name and a value
    is evaluated with ``JaxCompiler._eval_define`` into a fresh environment
    dict. All other top-level forms are ignored.

    Args:
        derived_path: Path to the file to load. Any path accepted by
            ``open()`` works. Defaults to ``sexp_effects/derived.sexp``
            beside this package.

    Returns:
        The environment dict populated by the ``define`` forms. Values are
        presumably mostly callables; confirm against ``_eval_define``.
    """
    if derived_path is None:
        derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp'

    # Read as UTF-8 explicitly. Sources may contain non-ASCII symbols such as
    # 'λ', which a non-UTF-8 platform default encoding would mis-decode.
    with open(derived_path, 'r', encoding='utf-8') as f:
        code = f.read()

    exprs = parse_all(code)
    compiler = JaxCompiler()
    env = {}

    for expr in exprs:
        # (define name value ...) has at least three elements.
        if not isinstance(expr, list) or len(expr) < 3:
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            compiler._eval_define(expr[1:], env)

    return env
|
||
|
||
|
||
# =============================================================================
|
||
# Test / Demo
|
||
# =============================================================================
|
||
|
||
if __name__ == '__main__':
    import time

    # Threshold effect used as a smoke test and micro-benchmark.
    test_effect = '''
    (effect "threshold-test"
      :params ((threshold :default 128))
      :body (let* ((r (channel frame 0))
                   (g (channel frame 1))
                   (b (channel frame 2))
                   (gray (+ (* r 0.299) (* g 0.587) (* b 0.114)))
                   (mask (where (> gray threshold) 255 0)))
              (merge-channels mask mask mask)))
    '''

    def _timed_runs(runs: int = 1):
        """Run the effect ``runs`` times and return (last result, elapsed seconds).

        JAX dispatches work asynchronously. The clock is stopped only after
        the final result is ready, so the timing covers the computation
        itself and not just dispatch.
        """
        start = time.perf_counter()
        out = None
        for _ in range(runs):
            out = run_effect(frame, threshold=128)
        jax.block_until_ready(out)
        return out, time.perf_counter() - start

    print("Compiling effect...")
    run_effect = compile_effect(test_effect)

    # Create test frame (numpy is imported at module level).
    print("Creating test frame...")
    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    print("Running effect (first run includes JIT compilation)...")
    result, elapsed = _timed_runs()
    print(f"First run (with JIT): {elapsed * 1000:.2f}ms")

    # The second run should be faster because compilation has already happened.
    result, elapsed = _timed_runs()
    print(f"Second run (cached): {elapsed * 1000:.2f}ms")

    # Multiple runs
    n_runs = 100
    result, elapsed = _timed_runs(n_runs)
    print(f"100 runs: {elapsed * 1000:.2f}ms total, {elapsed * 1000 / n_runs:.2f}ms avg")

    print(f"\nResult shape: {result.shape}, dtype: {result.dtype}")
    print("Done!")
|