Files
celery/streaming/sexp_to_jax.py
gilesb 7411aa74c4
Some checks failed
GPU Worker CI/CD / test (push) Has been cancelled
GPU Worker CI/CD / deploy (push) Has been cancelled
Add JAX backend with frame-varying random keys
- Add sexp_to_jax.py: JAX compiler for S-expression effects
- Use jax.random.fold_in for deterministic but varying random per frame
- Pass seed from recipe config through to JAX effects
- Fix NVENC detection to do actual encode test
- Add set_random_seed for deterministic Python random

The fold_in approach allows frame_num to be traced (not static)
while still producing different random patterns per frame,
fixing the interference pattern issue.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 11:07:02 +00:00

3639 lines
141 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Sexp to JAX Compiler.
Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU.
Uses XLA compilation via @jax.jit for automatic kernel fusion.
Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles
S-expressions directly to JAX operations which XLA then optimizes.
Usage:
from streaming.sexp_to_jax import compile_effect
effect_code = '''
(effect "threshold"
:params ((threshold :default 128))
:body (let ((g (gray frame)))
(rgb (where (> g threshold) 255 0)
(where (> g threshold) 255 0)
(where (> g threshold) 255 0))))
'''
run_effect = compile_effect(effect_code)
output = run_effect(frame, threshold=128)
"""
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Any, Dict, List, Callable, Optional, Tuple
import hashlib
import numpy as np
# Import parser
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword
# =============================================================================
# Compilation Cache
# =============================================================================
# Module-level mapping from a string key to a compiled effect callable.
# NOTE(review): nothing in this section reads or writes it; presumably the key
# is a hash of the effect source (hashlib is imported) — confirm against the
# code that populates it. No eviction is visible here, so entries persist for
# the lifetime of the module.
_COMPILED_EFFECTS: Dict[str, Callable] = {}
# =============================================================================
# Font Atlas for ASCII Effects
# =============================================================================
# Character sets for ASCII rendering. Every string starts with a space (the
# darkest cell) and is meant to be ordered dark-to-light, which is the order
# _create_font_atlas expects for its ``alphabet`` argument.
# NOTE(review): 'blocks' uses non-ASCII shade characters; rendering them
# requires a font that actually contains those glyphs.
ASCII_ALPHABETS: Dict[str, str] = {
    'standard': ' .:-=+*#%@',
    'blocks': ' ░▒▓█',
    'simple': ' .:oO@',
    'digits': ' 0123456789',
    'binary': ' 01',
    'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$',
}
# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array.
# Module-level; no eviction is performed by the code in this section.
_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {}
def _create_font_atlas(alphabet: str, char_size: int, font_name: Optional[str] = None) -> np.ndarray:
    """
    Create a font atlas with every character of ``alphabet`` pre-rendered.

    Uses NumPy arrays (not JAX) so it is safe to call while tracing/compiling
    without producing tracers.

    Args:
        alphabet: Characters to render, ordered by brightness (dark to light).
        char_size: Side length in pixels of each square character cell.
        font_name: Optional font path/name, tried before a built-in list of
            common monospace fonts.

    Returns:
        uint8 array of shape (len(alphabet), char_size, char_size, 3); each
        cell holds a white glyph centred on a black background. If PIL or all
        fonts are unavailable, a procedural block atlas is returned instead.
        Every result (including the fallback) is memoised per
        (alphabet, char_size, font_name).
    """
    cache_key = (alphabet, char_size, font_name)
    if cache_key in _FONT_ATLAS_CACHE:
        return _FONT_ATLAS_CACHE[cache_key]

    def _block_fallback() -> np.ndarray:
        # Font-free atlas; cached like a rendered one so the fallback decision
        # and the work are not repeated on every call.
        fallback = _create_block_atlas(alphabet, char_size)
        _FONT_ATLAS_CACHE[cache_key] = fallback
        return fallback

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        return _block_fallback()

    font_size = int(char_size * 0.9)  # Slightly smaller than the cell
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf',
        '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf',
        '/System/Library/Fonts/Menlo.ttc',  # macOS
        '/System/Library/Fonts/Monaco.dfont',  # macOS
        'C:\\Windows\\Fonts\\consola.ttf',  # Windows
    ]
    font = None
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (OSError, ValueError):
            # Missing/unreadable file, or an unusable size for this font.
            continue
    if font is None:
        try:
            font = ImageFont.load_default()
        except Exception:  # best effort: any failure -> procedural blocks
            return _block_fallback()

    glyphs = []
    for char in alphabet:
        img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0))
        draw = ImageDraw.Draw(img)
        try:
            left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
        except AttributeError:
            # Older Pillow without textbbox: textsize reports no origin offset.
            text_width, text_height = draw.textsize(char, font=font)
            left, top, right, bottom = 0, 0, text_width, text_height
        # The glyph's ink box starts at (left, top) relative to the draw anchor
        # (e.g. ascender space above lowercase letters). Subtract that origin so
        # the ink itself, not the anchor, is centred in the cell.
        x = (char_size - (right - left)) // 2 - left
        y = (char_size - (bottom - top)) // 2 - top
        draw.text((x, y), char, fill=(255, 255, 255), font=font)
        # Convert to numpy array (NOT jax array - avoids tracer issues)
        glyphs.append(np.array(img, dtype=np.uint8))
    atlas = np.stack(glyphs, axis=0)
    _FONT_ATLAS_CACHE[cache_key] = atlas
    return atlas
def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray:
"""
Create a simple block-based atlas without fonts.
Uses numpy to avoid tracer issues.
"""
num_chars = len(alphabet)
atlas = []
for i, char in enumerate(alphabet):
# Brightness proportional to position in alphabet
brightness = int(255 * i / max(num_chars - 1, 1))
# Create a simple pattern based on character
img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8)
# Add some texture/pattern for visual interest
# Checkerboard pattern for mid-range characters
if 0.2 < i / num_chars < 0.8:
y_coords, x_coords = np.mgrid[:char_size, :char_size]
checker = ((x_coords + y_coords) % 2 == 0)
variation = int(brightness * 0.2)
img = np.where(checker[:, :, None],
np.clip(img.astype(np.int16) + variation, 0, 255).astype(np.uint8),
np.clip(img.astype(np.int16) - variation, 0, 255).astype(np.uint8))
atlas.append(img)
return np.stack(atlas, axis=0)
def _get_alphabet_string(alphabet_name: str) -> str:
    """Resolve ``alphabet_name`` via ASCII_ALPHABETS.

    Unknown names are treated as a literal custom character string and are
    returned unchanged.
    """
    return ASCII_ALPHABETS.get(alphabet_name, alphabet_name)
# =============================================================================
# JAX Primitives - True primitives that can't be derived
# =============================================================================
def jax_width(frame):
    """Return the frame width: the size of axis 1 of an (H, W, ...) array."""
    width_axis = 1
    return frame.shape[width_axis]
def jax_height(frame):
    """Return the frame height: the size of axis 0 of an (H, W, ...) array."""
    height_axis = 0
    return frame.shape[height_axis]
def jax_channel(frame, idx):
    """Return channel ``idx`` of an (H, W, C) frame as a flat float32 vector.

    ``idx`` goes through ``int()`` because indexing needs a static value, so
    it must be concrete (not a tracer).
    """
    plane = frame[:, :, int(idx)]
    return plane.ravel().astype(jnp.float32)
def jax_merge_channels(r, g, b, shape):
    """Assemble flat r/g/b vectors into an (H, W, 3) uint8 frame.

    Each channel is clipped to 0-255, reshaped to ``shape == (H, W)`` and cast
    to uint8 (truncating any fractional part).
    """
    height, width = shape
    planes = [
        jnp.clip(channel, 0, 255).reshape(height, width).astype(jnp.uint8)
        for channel in (r, g, b)
    ]
    return jnp.stack(planes, axis=2)
def jax_iota(n):
    """Return the float32 ramp [0, 1, ..., n-1]."""
    return jnp.arange(0, n, dtype=jnp.float32)
def jax_repeat(x, n):
    """Repeat every element ``n`` times in place: [a, b] -> [a, a, b, b]."""
    repeated = jnp.repeat(x, repeats=n)
    return repeated
def jax_tile(x, n):
    """Concatenate ``n`` copies of the whole array: [a, b] -> [a, b, a, b]."""
    tiled = jnp.tile(x, n)
    return tiled
def jax_gather(data, indices):
    """Parallel lookup ``data.flatten()[indices]``.

    Indices are cast to int32 and clamped to ``[0, size - 1]``, so out-of-range
    lookups return the first or last element instead of failing.
    """
    flat = data.flatten()
    last = len(flat) - 1
    safe_idx = jnp.clip(indices.astype(jnp.int32), 0, last)
    return flat[safe_idx]
def jax_scatter(indices, values, size):
    """Write ``values`` into a zeroed float32 vector of length ``size``.

    Indices are cast to int32 and clamped into range. With duplicate indices
    one of the writes survives ("last write wins" is the intent).
    """
    target = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    return jnp.zeros(size, dtype=jnp.float32).at[target].set(values)
def jax_scatter_add(indices, values, size):
    """Accumulate ``values`` into a zeroed float32 vector of length ``size``.

    Indices are cast to int32 and clamped into range; duplicates are summed.
    """
    target = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    return jnp.zeros(size, dtype=jnp.float32).at[target].add(values)
def jax_group_reduce(values, group_indices, num_groups, op='mean'):
    """Reduce ``values`` into ``num_groups`` buckets by per-element group id.

    Args:
        values: Per-element values.
        group_indices: Per-element group ids (cast to int32).
        num_groups: Number of output groups; must be a concrete int.
        op: 'sum', 'mean', 'max' or 'min'. Groups that receive no elements
            yield 0 for 'mean'; for 'max'/'min' a bucket still at its -inf/+inf
            initial value is reported as 0.

    Raises:
        ValueError: for any other ``op``.
    """
    groups = group_indices.astype(jnp.int32)

    def filled(value):
        return jnp.full(num_groups, value, dtype=jnp.float32)

    if op == 'sum':
        return filled(0.0).at[groups].add(values)
    if op == 'mean':
        totals = filled(0.0).at[groups].add(values)
        counts = filled(0.0).at[groups].add(1.0)
        return jnp.where(counts > 0, totals / counts, 0.0)
    if op == 'max':
        peaks = filled(-jnp.inf).at[groups].max(values)
        return jnp.where(peaks == -jnp.inf, 0.0, peaks)
    if op == 'min':
        floors = filled(jnp.inf).at[groups].min(values)
        return jnp.where(floors == jnp.inf, 0.0, floors)
    raise ValueError(f"Unknown reduce op: {op}")
def jax_where(cond, true_val, false_val):
    """Element-wise select: ``true_val`` where ``cond`` holds, else ``false_val``."""
    selected = jnp.where(cond, true_val, false_val)
    return selected
def jax_cell_indices(frame, cell_size):
    """Return, for every pixel (row-major order), the index of its grid cell.

    The frame is divided into ``H // cell_size`` rows by ``W // cell_size``
    columns of square cells, numbered row-major. Pixels in a leftover strip on
    the right or bottom edge (when the size is not a multiple of
    ``cell_size``) belong to the last cell of their own row or column.

    Args:
        frame: Array of shape (H, W, ...).
        cell_size: Cell side in pixels; converted with int(), so it must be
            concrete.

    Returns:
        float32 vector of length H*W holding cell indices.
    """
    h, w = frame.shape[:2]
    cell_size = int(cell_size)
    rows = h // cell_size
    cols = w // cell_size
    # For each pixel, compute its (row, col) in the cell grid
    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)
    # Clamp row and column separately, as jax_pool_frame does. Clamping only the
    # flat index would let an overflowing column wrap into the next cell row.
    cell_row = jnp.clip(y_coords // cell_size, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cell_size, 0, cols - 1)
    return (cell_row * cols + cell_col).astype(jnp.float32)
def jax_pool_frame(frame, cell_size):
    """
    Average an (H, W, 3) frame over a grid of square cells.

    Args:
        frame: (H, W, 3) image.
        cell_size: Cell side in pixels; converted with int(), so it must be
            concrete.

    Returns:
        Tuple (cell_r, cell_g, cell_b, cell_lum) of float32 vectors of length
        (H // cell_size) * (W // cell_size), in row-major cell order. Pixels in
        a leftover edge strip are folded into the last cell of their row or
        column. Luminance uses 0.299/0.587/0.114 weights.
    """
    h, w = frame.shape[:2]
    cs = int(cell_size)
    rows = h // cs
    cols = w // cs
    num_cells = rows * cols
    # Cell index for each pixel (row and column clamped independently)
    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)
    cell_row = jnp.clip(y_coords // cs, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cs, 0, cols - 1)
    cell_idx = (cell_row * cols + cell_col).astype(jnp.int32)
    # Pixel count per cell is the same for every channel; compute it once.
    counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0)
    has_pixels = counts > 0

    def pool_channel(channel):
        data = frame[:, :, channel].flatten().astype(jnp.float32)
        sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data)
        return jnp.where(has_pixels, sums / counts, 0.0)

    r_pooled, g_pooled, b_pooled = (pool_channel(c) for c in range(3))
    lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled
    return (r_pooled, g_pooled, b_pooled, lum)
def jax_sample(frame, x, y):
    """Bilinearly sample an (H, W, >=3) frame at float coordinates (x, y).

    Coordinates are clamped to the image bounds, so edge pixels repeat
    outward. Returns a tuple (r, g, b) of float arrays shaped like ``x``/``y``.
    """
    h, w = frame.shape[:2]
    xc = jnp.clip(x, 0, w - 1)
    yc = jnp.clip(y, 0, h - 1)
    left = jnp.floor(xc).astype(jnp.int32)
    top = jnp.floor(yc).astype(jnp.int32)
    right = jnp.clip(left + 1, 0, w - 1)
    bottom = jnp.clip(top + 1, 0, h - 1)
    wx = xc - left.astype(jnp.float32)
    wy = yc - top.astype(jnp.float32)

    def blend_channel(channel):
        def texel(row, col):
            return frame[row, col, channel].astype(jnp.float32)

        return (texel(top, left) * (1 - wx) * (1 - wy) +
                texel(top, right) * wx * (1 - wy) +
                texel(bottom, left) * (1 - wx) * wy +
                texel(bottom, right) * wx * wy)

    return tuple(blend_channel(c) for c in range(3))
# =============================================================================
# Convolution Operations
# =============================================================================
def jax_convolve2d(data, kernel):
    """Filter a single-channel image with a 2-D kernel, keeping its size.

    Note: ``lax.conv_general_dilated`` computes cross-correlation (the kernel
    is not flipped), like OpenCV ``filter2D``. For symmetric kernels this is
    identical to true convolution. Borders use 'SAME' padding with zeros.

    Args:
        data: (H, W) array.
        kernel: (kH, kW) array.

    Returns:
        (H, W) filtered array in the promoted floating dtype of the inputs.
    """
    h, w = data.shape
    kh, kw = kernel.shape
    # lax convolution requires both operands to share one dtype; promote them
    # (at least to float32) so integer kernels or uint8 data are accepted.
    dtype = jnp.result_type(data, kernel, jnp.float32)
    # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c)
    data_4d = data.astype(dtype).reshape(1, h, w, 1)
    kernel_4d = kernel.astype(dtype).reshape(kh, kw, 1, 1)
    result = lax.conv_general_dilated(
        data_4d, kernel_4d,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )
    return result.reshape(h, w)
def jax_blur(frame, radius=1):
    """Gaussian blur of an (H, W, 3) frame.

    Args:
        frame: (H, W, 3) image.
        radius: Kernel half-width in pixels. The kernel is ``2*int(radius)+1``
            taps wide with sigma ``radius / 2``. Must be concrete (it goes
            through ``int()``). If ``int(radius) < 1`` the frame is returned
            unchanged.

    Returns:
        Blurred uint8 frame. Borders are zero-padded by the convolution, so a
        ``radius``-wide rim darkens slightly.
    """
    r = int(radius)
    if r < 1:
        # A single-tap kernel is an identity, and sigma = radius/2 may be 0
        # (0/0 -> NaN kernel), so skip the convolution entirely.
        return frame
    sigma = radius / 2
    # Integer offsets centred on 0; using a fractional radius here would make
    # the kernel asymmetric and shift the image.
    offsets = jnp.arange(2 * r + 1, dtype=jnp.float32) - r
    gaussian_1d = jnp.exp(-offsets ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    kernel = jnp.outer(gaussian_1d, gaussian_1d)
    channels = [
        jnp.clip(
            jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255
        ).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
def jax_sharpen(frame, amount=1.0):
    """Sharpen with a Laplacian kernel: ``identity + amount * laplacian``.

    The kernel weights always sum to 1, so flat regions keep their brightness.
    ``amount=1`` gives the classic [[0,-1,0],[-1,5,-1],[0,-1,0]] kernel and
    ``amount=0`` is the identity. ``amount`` may be a traced value.

    Returns:
        uint8 (H, W, 3) frame; each channel is filtered independently and
        clipped to 0-255.
    """
    identity = jnp.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=jnp.float32)
    laplacian = jnp.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=jnp.float32)
    # Linear blend: laplacian sums to 0 and identity to 1, so the sum stays 1
    # for any amount.
    kernel = identity + amount * laplacian
    channels = [
        jnp.clip(
            jax_convolve2d(frame[:, :, c].astype(jnp.float32), kernel), 0, 255
        ).astype(jnp.uint8)
        for c in range(3)
    ]
    return jnp.stack(channels, axis=2)
def jax_edge_detect(frame):
    """Sobel edge detection on the frame's luma.

    The frame is converted to grey with 0.299/0.587/0.114 weights, filtered
    with horizontal and vertical Sobel kernels, and the gradient magnitude
    (clipped to 0-255) is returned replicated into an (H, W, 3) uint8 frame.
    """
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    sobel_h = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    grad_x = jax_convolve2d(luma, sobel_h)
    grad_y = jax_convolve2d(luma, sobel_h.T)
    magnitude = jnp.sqrt(grad_x**2 + grad_y**2)
    plane = jnp.clip(magnitude, 0, 255).astype(jnp.uint8)
    return jnp.stack((plane,) * 3, axis=2)
def jax_emboss(frame):
    """Emboss the frame's luma and return it as a grey (H, W, 3) uint8 frame.

    The luma (0.299/0.587/0.114 weights) is filtered with a diagonal emboss
    kernel and offset by 128 so flat areas come out mid-grey.
    """
    weights = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32)
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    relief = jnp.clip(jax_convolve2d(luma, weights) + 128, 0, 255).astype(jnp.uint8)
    return jnp.stack([relief] * 3, axis=2)
# =============================================================================
# Color Space Conversion
# =============================================================================
def jax_rgb_to_hsv(r, g, b):
    """Convert RGB to HSV.

    Args:
        r, g, b: Channel values in 0-255 (arrays or scalars).

    Returns:
        (h, s, v) with h in degrees [0, 360) and s, v scaled to 0-255. Grey
        pixels (max == min) get h = 0; black pixels get s = 0.
    """
    r, g, b = r / 255.0, g / 255.0, b / 255.0
    max_c = jnp.maximum(jnp.maximum(r, g), b)
    min_c = jnp.minimum(jnp.minimum(r, g), b)
    diff = max_c - min_c
    # Value
    v = max_c
    # jnp.where evaluates both branches. A 0/0 in the discarded branch still
    # produces NaN gradients (and trips jax_debug_nans), so divide by safe
    # denominators. The selected results are unchanged.
    safe_max = jnp.where(max_c > 0, max_c, 1.0)
    safe_diff = jnp.where(diff == 0, 1.0, diff)
    # Saturation
    s = jnp.where(max_c > 0, diff / safe_max, 0.0)
    # Hue
    h = jnp.where(diff == 0, 0.0,
                  jnp.where(max_c == r, (60 * ((g - b) / safe_diff) + 360) % 360,
                            jnp.where(max_c == g, 60 * ((b - r) / safe_diff) + 120,
                                      60 * ((r - g) / safe_diff) + 240)))
    return h, s * 255, v * 255
def jax_hsv_to_rgb(h, s, v):
    """Convert HSV back to RGB.

    Args:
        h: Hue in degrees (wrapped into [0, 360)); must be an array.
        s, v: Saturation and value in 0-255.

    Returns:
        (r, g, b) float arrays in 0-255.
    """
    hue = h % 360
    sat, val = s / 255.0, v / 255.0
    chroma = val * sat
    secondary = chroma * (1 - jnp.abs((hue / 60) % 2 - 1))
    offset = val - chroma
    sector = (hue / 60).astype(jnp.int32) % 6
    # (r, g, b) composition for each 60-degree hue sector 0..5.
    layout = (
        (chroma, secondary, 0),
        (secondary, chroma, 0),
        (0, chroma, secondary),
        (0, secondary, chroma),
        (secondary, 0, chroma),
        (chroma, 0, secondary),
    )

    def component(k):
        # Sector 5 is the fall-through value; wrap sectors 4..0 around it so the
        # outermost test is sector 0.
        picked = layout[5][k]
        for idx in (4, 3, 2, 1, 0):
            picked = jnp.where(sector == idx, layout[idx][k], picked)
        return picked

    red, green, blue = (component(k) for k in range(3))
    return (red + offset) * 255, (green + offset) * 255, (blue + offset) * 255
def jax_adjust_saturation(frame, factor):
    """Scale HSV saturation of an (H, W, 3) frame by ``factor`` (1.0 = unchanged).

    Saturation is clamped to 0-255 after scaling. Returns a uint8 frame of the
    same height and width.
    """
    rows, cols = frame.shape[:2]
    r, g, b = (frame[:, :, c].flatten().astype(jnp.float32) for c in range(3))
    hue, sat, val = jax_rgb_to_hsv(r, g, b)
    sat = jnp.clip(sat * factor, 0, 255)
    planes = [jnp.clip(p, 0, 255).reshape(rows, cols).astype(jnp.uint8)
              for p in jax_hsv_to_rgb(hue, sat, val)]
    return jnp.stack(planes, axis=2)
def jax_shift_hue(frame, degrees):
    """Rotate the hue of every pixel by ``degrees`` (result wrapped to [0, 360)).

    Works in HSV space and returns a uint8 frame of the same height and width.
    """
    rows, cols = frame.shape[:2]
    r, g, b = (frame[:, :, c].flatten().astype(jnp.float32) for c in range(3))
    hue, sat, val = jax_rgb_to_hsv(r, g, b)
    hue = (hue + degrees) % 360
    planes = [jnp.clip(p, 0, 255).reshape(rows, cols).astype(jnp.uint8)
              for p in jax_hsv_to_rgb(hue, sat, val)]
    return jnp.stack(planes, axis=2)
# =============================================================================
# Color Adjustment Operations
# =============================================================================
def jax_adjust_brightness(frame, amount):
    """Add ``amount`` (typically -255..255) to every channel, saturating to uint8."""
    shifted = jnp.add(frame.astype(jnp.float32), amount)
    return shifted.clip(0, 255).astype(jnp.uint8)
def jax_adjust_contrast(frame, factor):
    """Scale distance from mid-grey (128) by ``factor`` (1.0 = unchanged), saturating to uint8."""
    centred = frame.astype(jnp.float32) - 128
    stretched = centred * factor + 128
    return stretched.clip(0, 255).astype(jnp.uint8)
def jax_invert(frame):
    """Return the photographic negative: each value v becomes 255 - v.

    Intended for uint8 frames; the result follows normal array promotion
    (uint8 in, uint8 out).
    """
    negative = 255 - frame
    return negative
def jax_posterize(frame, levels):
    """Quantise each channel to ``levels`` evenly spaced values.

    ``levels`` goes through ``int()`` (so it must be concrete) and is raised
    to at least 2. Returns a uint8 frame.
    """
    n_levels = max(int(levels), 2)
    step = 255.0 / (n_levels - 1)
    quantised = jnp.round(frame.astype(jnp.float32) / step) * step
    return jnp.clip(quantised, 0, 255).astype(jnp.uint8)
def jax_threshold(frame, level, invert=False):
    """Binary threshold on luma (0.299/0.587/0.114 weights).

    Pixels with luma >= ``level`` become white (255), others black (0).
    With ``invert`` (a Python bool), pixels with luma < ``level`` become white
    instead. Returns a grey (H, W, 3) uint8 frame.
    """
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    mask = (luma < level) if invert else (luma >= level)
    plane = jnp.where(mask, 255, 0).astype(jnp.uint8)
    return jnp.stack([plane] * 3, axis=2)
def jax_sepia(frame):
    """Apply the classic sepia colour matrix; results saturate at 255 (uint8)."""
    red, green, blue = (frame[:, :, c].astype(jnp.float32) for c in range(3))
    # One row of (r, g, b) weights per output channel.
    coeffs = (
        (0.393, 0.769, 0.189),
        (0.349, 0.686, 0.168),
        (0.272, 0.534, 0.131),
    )
    toned = [
        jnp.clip(red * kr + green * kg + blue * kb, 0, 255).astype(jnp.uint8)
        for kr, kg, kb in coeffs
    ]
    return jnp.stack(toned, axis=2)
def jax_grayscale(frame):
    """Convert to grey using 0.299/0.587/0.114 luma weights (truncated to uint8)."""
    weights = (0.299, 0.587, 0.114)
    luma = (frame[:, :, 0].astype(jnp.float32) * weights[0] +
            frame[:, :, 1].astype(jnp.float32) * weights[1] +
            frame[:, :, 2].astype(jnp.float32) * weights[2])
    plane = luma.astype(jnp.uint8)
    return jnp.stack([plane] * 3, axis=2)
# =============================================================================
# Geometry Operations
# =============================================================================
def jax_flip_horizontal(frame):
    """Mirror an (H, W, C) frame left-to-right by reversing axis 1."""
    reverse = slice(None, None, -1)
    return frame[:, reverse, :]
def jax_flip_vertical(frame):
    """Mirror an (H, W, C) frame top-to-bottom by reversing axis 0."""
    reverse = slice(None, None, -1)
    return frame[reverse, :, :]
def jax_rotate(frame, angle, center_x=None, center_y=None):
    """Rotate frame by ``angle`` degrees about (center_x, center_y).

    Follows the OpenCV getRotationMatrix2D convention: positive angle means
    counter-clockwise rotation. The centre defaults to (w/2, h/2). Sampling is
    bilinear. Output pixels whose source location lies outside the image are
    black, approximating cv2.warpAffine with a constant border.

    Returns:
        uint8 frame with the same height and width as ``frame``.
    """
    h, w = frame.shape[:2]
    if center_x is None:
        center_x = w / 2
    if center_y is None:
        center_y = h / 2
    # Convert to radians
    theta = angle * jnp.pi / 180
    cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
    # Create coordinate grids
    y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w)
    x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w)
    # OpenCV getRotationMatrix2D gives FORWARD transform M = [[cos,sin],[-sin,cos]]
    # For sampling we need INVERSE: M^-1 = [[cos,-sin],[sin,cos]]
    # So: src_x = cos(θ)*(x-cx) - sin(θ)*(y-cy) + cx
    #     src_y = sin(θ)*(x-cx) + cos(θ)*(y-cy) + cy
    x_centered = x_coords - center_x
    y_centered = y_coords - center_y
    src_x = cos_t * x_centered - sin_t * y_centered + center_x
    src_y = sin_t * x_centered + cos_t * y_centered + center_y
    # Inclusive upper bounds: a source coordinate of exactly w-1 / h-1 is the
    # last real pixel (jax_sample clamps its +1 neighbour). A strict "<" would
    # black out the last column and row even at angle 0.
    valid = (src_x >= 0) & (src_x <= w - 1) & (src_y >= 0) & (src_y <= h - 1)
    valid_flat = valid.flatten()
    # Sample using bilinear interpolation
    r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten())
    # Zero out-of-bounds pixels
    r = jnp.where(valid_flat, r, 0)
    g = jnp.where(valid_flat, g, 0)
    b = jnp.where(valid_flat, b, 0)
    return jnp.stack([
        jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
        jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
        jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
    ], axis=2)
def jax_scale(frame, scale_x, scale_y=None):
    """Zoom the frame about its centre by (scale_x, scale_y).

    ``scale_y`` defaults to ``scale_x``. Uses inverse mapping with bilinear
    sampling. Output pixels whose source lies outside the image are black,
    like OpenCV's constant border.

    Returns:
        uint8 frame with the same height and width as ``frame``.
    """
    if scale_y is None:
        scale_y = scale_x
    h, w = frame.shape[:2]
    center_x, center_y = w / 2, h / 2
    # Create coordinate grids
    y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w)
    x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w)
    # Scale from center (inverse mapping: dst -> src)
    src_x = (x_coords - center_x) / scale_x + center_x
    src_y = (y_coords - center_y) / scale_y + center_y
    # Inclusive upper bounds so the last column/row survive (e.g. at scale 1).
    valid = (src_x >= 0) & (src_x <= w - 1) & (src_y >= 0) & (src_y <= h - 1)
    valid_flat = valid.flatten()
    r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten())
    # Zero out-of-bounds pixels
    r = jnp.where(valid_flat, r, 0)
    g = jnp.where(valid_flat, g, 0)
    b = jnp.where(valid_flat, b, 0)
    return jnp.stack([
        jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
        jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
        jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
    ], axis=2)
def jax_resize(frame, new_width, new_height):
    """Resize frame to (new_height, new_width) with bilinear sampling.

    Uses corner-aligned mapping: output corner pixels sample the input
    corners exactly. A single output column or row samples source index 0.
    Sizes go through ``int()``, so they must be concrete.

    Returns:
        uint8 frame of shape (new_height, new_width, 3).
    """
    h, w = frame.shape[:2]
    new_h, new_w = int(new_height), int(new_width)
    # Create coordinate grids for new size
    y_coords = jnp.repeat(jnp.arange(new_h), new_w)
    x_coords = jnp.tile(jnp.arange(new_w), new_h)
    # Map to source coordinates. max(..., 1) avoids 0/0 when an output
    # dimension is 1 (its only coordinate is 0, so it maps to 0).
    src_x = x_coords * (w - 1) / max(new_w - 1, 1)
    src_y = y_coords * (h - 1) / max(new_h - 1, 1)
    r, g, b = jax_sample(frame, src_x, src_y)
    return jnp.stack([
        jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8),
        jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8),
        jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8)
    ], axis=2)
# =============================================================================
# Blending Operations
# =============================================================================
def _resize_to_match(frame1, frame2):
"""Resize frame2 to match frame1's dimensions if they differ.
Uses jax.image.resize for bilinear interpolation.
Returns frame2 resized to frame1's shape.
"""
h1, w1 = frame1.shape[:2]
h2, w2 = frame2.shape[:2]
# If same size, return as-is
if h1 == h2 and w1 == w2:
return frame2
# Resize frame2 to match frame1
# jax.image.resize expects (height, width, channels) and target shape
return jax.image.resize(
frame2.astype(jnp.float32),
(h1, w1, frame2.shape[2]),
method='bilinear'
).astype(jnp.uint8)
def jax_blend(frame1, frame2, alpha):
    """Linearly blend two frames: alpha=0 -> frame1, alpha=1 -> frame2.

    frame2 is first resized to frame1's height and width if they differ.
    The result is clipped to 0-255 before the uint8 cast, consistent with the
    other blend modes. This way an ``alpha`` outside [0, 1] saturates instead
    of relying on an out-of-range float->uint8 cast, which is
    implementation-defined.
    """
    frame2 = _resize_to_match(frame1, frame2)
    mixed = (frame1.astype(jnp.float32) * (1 - alpha) +
             frame2.astype(jnp.float32) * alpha)
    return jnp.clip(mixed, 0, 255).astype(jnp.uint8)
def jax_blend_add(frame1, frame2):
    """Additive blend saturating at 255; frame2 is resized to frame1's size if needed."""
    matched = _resize_to_match(frame1, frame2)
    total = jnp.add(frame1.astype(jnp.float32), matched.astype(jnp.float32))
    return jnp.clip(total, 0, 255).astype(jnp.uint8)
def jax_blend_multiply(frame1, frame2):
    """Multiply blend (f1 * f2 / 255); frame2 is resized to frame1's size if needed."""
    matched = _resize_to_match(frame1, frame2)
    product = frame1.astype(jnp.float32) * matched.astype(jnp.float32) / 255
    return jnp.clip(product, 0, 255).astype(jnp.uint8)
def jax_blend_screen(frame1, frame2):
    """Screen blend: 1 - (1 - a)(1 - b) on 0-1 values; frame2 resized to frame1 if needed."""
    matched = _resize_to_match(frame1, frame2)
    base, top = (img.astype(jnp.float32) / 255 for img in (frame1, matched))
    screened = 1 - (1 - base) * (1 - top)
    return jnp.clip(screened * 255, 0, 255).astype(jnp.uint8)
def jax_blend_overlay(frame1, frame2):
    """Overlay blend keyed on frame1.

    Where frame1 < 0.5 (on a 0-1 scale) the result is multiply (2ab); elsewhere
    it is screen (1 - 2(1-a)(1-b)). frame2 is resized to frame1 if needed.
    """
    matched = _resize_to_match(frame1, frame2)
    base, top = (img.astype(jnp.float32) / 255 for img in (frame1, matched))
    multiplied = 2 * base * top
    screened = 1 - 2 * (1 - base) * (1 - top)
    combined = jnp.where(base < 0.5, multiplied, screened)
    return jnp.clip(combined * 255, 0, 255).astype(jnp.uint8)
# =============================================================================
# Utility
# =============================================================================
def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0):
    """Build a JAX PRNG key that varies per frame and per operation.

    Args:
        seed: Base seed for determinism; must be a concrete int.
        frame_num: Frame number; may be a traced value, so a jitted effect can
            take a different frame each call without recompiling.
        op_id: Per-operation id; must be a concrete int.

    Returns:
        A JAX PRNG key.
    """
    # The concrete parts are mixed arithmetically: op_id is scaled by the prime
    # 1000003 so small seeds and op ids do not collide.
    root = jax.random.PRNGKey(seed + 1000003 * op_id)
    # fold_in accepts a traced integer, which keeps frame_num dynamic under jit.
    per_frame = jax.random.fold_in(root, frame_num)
    return per_frame
def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42):
    """Uniform random float in [lo, hi), deterministic for (seed, frame_num, op_id)."""
    key = make_jax_key(seed, frame_num, op_id)
    span = hi - lo
    return lo + jax.random.uniform(key) * span
def jax_is_nil(x):
    """Return True only when ``x`` is None (the sexp ``nil``); falsy values like 0 or [] are not nil."""
    is_none = x is None
    return is_none
# =============================================================================
# S-expression to JAX Compiler
# =============================================================================
class JaxCompiler:
"""Compiles S-expressions to JAX functions."""
def __init__(self):
    """Start with empty compilation state.

    Four separate dicts hold, in order: variable bindings used during
    compilation, effect parameters, loaded primitive libraries, and derived
    functions (filled by load_derived).
    """
    self.env, self.params, self.primitives, self.derived = {}, {}, {}, {}
def load_derived(self, path: str):
    """Load derived operations from a .sexp file into ``self.derived``.

    Only top-level ``(define ...)`` forms with at least three elements are
    evaluated (via ``_eval_define`` against ``self.derived``). Every other
    top-level form is ignored.

    Args:
        path: Path of the .sexp source file.
    """
    # Decode explicitly as UTF-8: sexp sources may contain non-ASCII symbols
    # such as 'λ' (accepted by _eval), which the platform's default encoding
    # could mis-decode.
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()
    for expr in parse_all(code):
        if not (isinstance(expr, list) and len(expr) >= 3):
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            self._eval_define(expr[1:], self.derived)
def compile_effect(self, sexp) -> Callable:
    """
    Compile an effect S-expression to a jitted JAX function.

    Supports both formats:
        (effect "name" :params (...) :body ...)
        (define-effect name :params (...) body)

    A bare (non-keyword) item is taken as the body if no body has been seen
    yet. Other keywords (e.g. :desc) are skipped together with their value.

    Args:
        sexp: Parsed S-expression (a list whose head is a Symbol).

    Returns:
        ``jax.jit``-compiled function ``(frame, **kwargs) -> frame``. Besides
        the effect's own parameters, kwargs may carry 't' / '_time' / 'time',
        'frame_num' and 'seed'. 'seed' and string/bool parameters are static
        jit arguments, so each distinct value triggers a retrace.
        'frame_num' stays traced.

    Raises:
        ValueError: malformed effect form or missing body.
    """
    if not isinstance(sexp, list) or len(sexp) < 2:
        raise ValueError("Effect must be a list")
    head = sexp[0]
    if not isinstance(head, Symbol):
        raise ValueError("Effect must start with a symbol")
    form = head.name
    # Handle both 'effect' and 'define-effect' formats
    if form == 'effect':
        # (effect "name" :params (...) :body ...)
        # The "unnamed" fallback cannot trigger: len(sexp) >= 2 is checked above.
        name = sexp[1] if len(sexp) > 1 else "unnamed"
        if isinstance(name, Symbol):
            name = name.name
        start_idx = 2
    elif form == 'define-effect':
        # (define-effect name :params (...) body)
        name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1])
        start_idx = 2
    else:
        raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'")
    params_spec = []
    body = None
    i = start_idx
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword):
            if item.name == 'params' and i + 1 < len(sexp):
                params_spec = sexp[i + 1]
                i += 2
            elif item.name == 'body' and i + 1 < len(sexp):
                body = sexp[i + 1]
                i += 2
            elif item.name in ('desc', 'type', 'range'):
                # Skip metadata keywords
                i += 2
            else:
                i += 2  # Skip unknown keywords with their values
        else:
            # Assume it's the body if we haven't seen one
            if body is None:
                body = item
            i += 1
    if body is None:
        raise ValueError(f"Effect '{name}' must have a body")
    # Extract parameter names, defaults, and static params (strings, bools)
    param_info, static_params = self._parse_params(params_spec)
    # Snapshot derived functions now, so later load_derived() calls do not
    # change an already-compiled effect.
    derived_fns = self.derived.copy()
    # Create the JAX function (traced by jax.jit; runs once per new static
    # argument combination / input shape)
    def effect_fn(frame, **kwargs):
        # Set up environment
        h, w = frame.shape[:2]
        # Get frame_num for deterministic random variation (may be traced)
        frame_num = kwargs.get('frame_num', 0)
        # Get seed from recipe config (passed via kwargs; static under jit)
        seed = kwargs.get('seed', 42)
        env = {
            'frame': frame,
            'width': w,
            'height': h,
            '_shape': (h, w),
            # Time variables (default to 0, can be overridden via kwargs);
            # each alias falls back to 't' or '_time'
            't': kwargs.get('t', kwargs.get('_time', 0.0)),
            '_time': kwargs.get('_time', kwargs.get('t', 0.0)),
            'time': kwargs.get('time', kwargs.get('t', 0.0)),
            # Frame number for random key generation (several spellings)
            'frame_num': frame_num,
            'frame-num': frame_num,
            '_frame_num': frame_num,
            # Seed from recipe for deterministic random
            '_seed': seed,
            # Counter for unique random keys within same frame
            '_rand_op_counter': 0,
            # Common constants
            'pi': jnp.pi,
            'PI': jnp.pi,
        }
        # Add derived functions.
        # NOTE(review): applied after the built-ins above, so a derived function
        # with the same name (e.g. 'width') would shadow them — confirm intended.
        env.update(derived_fns)
        # Add parameters with defaults
        for pname, pdefault in param_info.items():
            if pname in kwargs:
                env[pname] = kwargs[pname]
            elif isinstance(pdefault, list):
                # Unevaluated S-expression default - evaluate it in the env
                # built so far (so it can see frame, time, earlier params).
                env[pname] = self._eval(pdefault, env)
            else:
                # NOTE(review): a Symbol default other than nil/true/false is
                # stored here unevaluated — confirm that is intended.
                env[pname] = pdefault
        # Evaluate body
        result = self._eval(body, env)
        # Ensure result is an (H, W, 3) frame
        if isinstance(result, tuple) and len(result) == 3:
            # RGB tuple - merge to frame
            r, g, b = result
            return jax_merge_channels(r, g, b, (h, w))
        elif result.ndim == 3:
            return result
        else:
            # Single channel - replicate to RGB
            # (a Python scalar result has no .ndim and would raise here)
            h, w = env['_shape']
            gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8)
            return jnp.stack([gray, gray, gray], axis=2)
    # JIT compile with static args for string/bool parameters and seed.
    # seed must be static for PRNGKey, but frame_num can be traced via fold_in
    all_static = set(static_params) | {'seed'}
    return jax.jit(effect_fn, static_argnames=list(all_static))
def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]:
    """Parse an effect's ``:params`` list.

    Each entry is either a bare Symbol (default 0.0) or a list
    ``(name [:default value] [:type type] ...)``. Symbol defaults nil/true/false
    become None/True/False. Any other keyword (and each non-keyword item) is
    skipped one element at a time.

    Returns:
        Tuple ``(param_defaults, static_params)``:
        - param_defaults: dict of parameter name -> default value.
        - static_params: names that must be static under jit (declared
          ``:type`` string/bool, or a str/bool default).
    """
    defaults: Dict[str, Any] = {}
    static_names = set()
    if not isinstance(params_spec, list):
        return defaults, static_names
    literal_symbols = {'nil': None, 'true': True, 'false': False}
    for spec in params_spec:
        if isinstance(spec, Symbol):
            defaults[spec.name] = 0.0
            continue
        if not (isinstance(spec, list) and len(spec) >= 1):
            continue
        head = spec[0]
        name = head.name if isinstance(head, Symbol) else str(head)
        default_value = 0.0
        type_name = None
        n_items = len(spec)
        pos = 1
        while pos < n_items:
            token = spec[pos]
            keyword = token.name if isinstance(token, Keyword) else None
            has_value = pos + 1 < n_items
            if keyword == 'default' and has_value:
                default_value = spec[pos + 1]
                if isinstance(default_value, Symbol) and default_value.name in literal_symbols:
                    default_value = literal_symbols[default_value.name]
                pos += 2
            elif keyword == 'type' and has_value:
                type_name = spec[pos + 1]
                if isinstance(type_name, Symbol):
                    type_name = type_name.name
                pos += 2
            else:
                pos += 1
        defaults[name] = default_value
        # Strings and bools cannot be traced by JAX, so they must be static.
        if type_name in ('string', 'bool') or isinstance(default_value, (str, bool)):
            static_names.add(name)
    return defaults, static_names
def _eval(self, expr, env: Dict[str, Any]) -> Any:
"""Evaluate an S-expression in the given environment."""
# Already-evaluated values (e.g., from threading macros)
# JAX arrays, NumPy arrays, tuples, etc.
if hasattr(expr, 'shape'): # JAX/NumPy array
return expr
if isinstance(expr, tuple): # e.g., (r, g, b) from rgb
return expr
# Literals - keep as Python numbers for static operations
if isinstance(expr, (int, float)):
return expr
if isinstance(expr, str):
return expr
# Symbols - variable lookup
if isinstance(expr, Symbol):
name = expr.name
if name in env:
return env[name]
if name == 'nil':
return None
if name == 'true':
return True
if name == 'false':
return False
raise NameError(f"Unknown symbol: {name}")
# Lists - function calls
if isinstance(expr, list) and len(expr) > 0:
head = expr[0]
if isinstance(head, Symbol):
op = head.name
args = expr[1:]
# Special forms
if op == 'let' or op == 'let*':
return self._eval_let(args, env)
if op == 'if':
return self._eval_if(args, env)
if op == 'lambda' or op == 'λ':
return self._eval_lambda(args, env)
if op == 'define':
return self._eval_define(args, env)
# Built-in operations
return self._eval_call(op, args, env)
# Empty list
if isinstance(expr, list) and len(expr) == 0:
return []
raise ValueError(f"Cannot evaluate: {expr}")
def _eval_let(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body)."""
if len(args) < 2:
raise ValueError("let requires bindings and body")
bindings = args[0]
body = args[1]
new_env = env.copy()
# Handle both ((var val) ...) and [var val var2 val2 ...] syntax
if isinstance(bindings, list):
# Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...)
if bindings and isinstance(bindings[0], Symbol):
# Flat list: [var val var2 val2 ...]
i = 0
while i < len(bindings) - 1:
var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
val = self._eval(bindings[i + 1], new_env)
new_env[var] = val
i += 2
else:
# Nested list: ((var val) (var2 val2) ...)
for binding in bindings:
if isinstance(binding, list) and len(binding) >= 2:
var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
val = self._eval(binding[1], new_env)
new_env[var] = val
return self._eval(body, new_env)
def _eval_if(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (if cond then else)."""
if len(args) < 2:
raise ValueError("if requires condition and then-branch")
cond = self._eval(args[0], env)
# Handle None as falsy (important for optional params like overlay)
if cond is None:
return self._eval(args[2], env) if len(args) > 2 else None
# For Python scalar bools, use normal Python if
# This allows side effects and None values
if isinstance(cond, bool):
if cond:
return self._eval(args[1], env)
else:
return self._eval(args[2], env) if len(args) > 2 else None
# For NumPy/JAX scalar bools with concrete values
if hasattr(cond, 'item') and cond.shape == ():
try:
if bool(cond.item()):
return self._eval(args[1], env)
else:
return self._eval(args[2], env) if len(args) > 2 else None
except:
pass # Fall through to jnp.where for traced values
# For traced values, evaluate both branches and use jnp.where
then_val = self._eval(args[1], env)
else_val = self._eval(args[2], env) if len(args) > 2 else 0.0
# Handle None by converting to zeros
if then_val is None:
then_val = 0.0
if else_val is None:
else_val = 0.0
# Convert lists to tuples
if isinstance(then_val, list):
then_val = tuple(then_val)
if isinstance(else_val, list):
else_val = tuple(else_val)
# Handle tuple results (e.g., from rgb in map-pixels)
if isinstance(then_val, tuple) and isinstance(else_val, tuple):
return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val))
return jnp.where(cond, then_val, else_val)
def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable:
"""Evaluate (lambda (params) body)."""
if len(args) < 2:
raise ValueError("lambda requires parameters and body")
params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]]
body = args[1]
captured_env = env.copy()
def fn(*fn_args):
local_env = captured_env.copy()
for pname, pval in zip(params, fn_args):
local_env[pname] = pval
return self._eval(body, local_env)
return fn
def _eval_define(self, args, env: Dict[str, Any]) -> Any:
"""Evaluate (define name value) or (define (name params) body)."""
if len(args) < 2:
raise ValueError("define requires name and value")
name_part = args[0]
if isinstance(name_part, list):
# Function definition: (define (name params) body)
fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0])
params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]]
body = args[1]
captured_env = env.copy()
def fn(*fn_args):
local_env = captured_env.copy()
for pname, pval in zip(params, fn_args):
local_env[pname] = pval
return self._eval(body, local_env)
env[fn_name] = fn
return fn
else:
# Variable definition
var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part)
val = self._eval(args[1], env)
env[var_name] = val
return val
def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any:
"""Evaluate a function call."""
# Check if it's a user-defined function
if op in env and callable(env[op]):
fn = env[op]
eval_args = [self._eval(a, env) for a in args]
return fn(*eval_args)
# Arithmetic
if op == '+':
vals = [self._eval(a, env) for a in args]
result = vals[0] if vals else 0.0
for v in vals[1:]:
result = result + v
return result
if op == '-':
if len(args) == 1:
return -self._eval(args[0], env)
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result - v
return result
if op == '*':
vals = [self._eval(a, env) for a in args]
result = vals[0] if vals else 1.0
for v in vals[1:]:
result = result * v
return result
if op == '/':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result / v
return result
if op == 'mod':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a % b
if op == 'pow' or op == '**':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return jnp.power(a, b)
# Comparison
if op == '<':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a < b
if op == '>':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a > b
if op == '<=':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a <= b
if op == '>=':
a, b = self._eval(args[0], env), self._eval(args[1], env)
return a >= b
if op == '=' or op == '==':
a, b = self._eval(args[0], env), self._eval(args[1], env)
# For scalar Python types, return Python bool to enable trace-time if
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return bool(a == b)
return a == b
if op == '!=' or op == '<>':
a, b = self._eval(args[0], env), self._eval(args[1], env)
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return bool(a != b)
return a != b
# Logic
if op == 'and':
vals = [self._eval(a, env) for a in args]
# Use Python and for concrete Python bools (e.g., shape comparisons)
if all(isinstance(v, (bool, np.bool_)) for v in vals):
result = True
for v in vals:
result = result and bool(v)
return result
# Otherwise use JAX logical_and
result = vals[0]
for v in vals[1:]:
result = jnp.logical_and(result, v)
return result
if op == 'or':
vals = [self._eval(a, env) for a in args]
# Use Python or for concrete Python bools
if all(isinstance(v, (bool, np.bool_)) for v in vals):
result = False
for v in vals:
result = result or bool(v)
return result
# Otherwise use JAX logical_or
result = vals[0]
for v in vals[1:]:
result = jnp.logical_or(result, v)
return result
if op == 'not':
val = self._eval(args[0], env)
if isinstance(val, (bool, np.bool_)):
return not bool(val)
return jnp.logical_not(val)
# Math functions
if op == 'sqrt':
return jnp.sqrt(self._eval(args[0], env))
if op == 'sin':
return jnp.sin(self._eval(args[0], env))
if op == 'cos':
return jnp.cos(self._eval(args[0], env))
if op == 'tan':
return jnp.tan(self._eval(args[0], env))
if op == 'exp':
return jnp.exp(self._eval(args[0], env))
if op == 'log':
return jnp.log(self._eval(args[0], env))
if op == 'abs':
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return abs(x)
return jnp.abs(x)
if op == 'floor':
import math
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return math.floor(x)
return jnp.floor(x)
if op == 'ceil':
import math
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return math.ceil(x)
return jnp.ceil(x)
if op == 'round':
x = self._eval(args[0], env)
if isinstance(x, (int, float)):
return round(x)
return jnp.round(x)
# Frame primitives
if op == 'width':
return env['width']
if op == 'height':
return env['height']
if op == 'channel':
frame = self._eval(args[0], env)
idx = self._eval(args[1], env)
# idx should be a Python int (literal from S-expression)
return jax_channel(frame, idx)
if op == 'merge-channels' or op == 'rgb':
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
# For scalars (e.g., in map-pixels), return tuple
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
if r_is_scalar and g_is_scalar and b_is_scalar:
return (r, g, b)
return jax_merge_channels(r, g, b, env['_shape'])
if op == 'sample':
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
return jax_sample(frame, x, y)
if op == 'cell-indices':
frame = self._eval(args[0], env)
cell_size = self._eval(args[1], env)
return jax_cell_indices(frame, cell_size)
if op == 'pool-frame':
frame = self._eval(args[0], env)
cell_size = self._eval(args[1], env)
return jax_pool_frame(frame, cell_size)
# Xector primitives
if op == 'iota':
n = self._eval(args[0], env)
return jax_iota(int(n))
if op == 'repeat':
x = self._eval(args[0], env)
n = self._eval(args[1], env)
return jax_repeat(x, int(n))
if op == 'tile':
x = self._eval(args[0], env)
n = self._eval(args[1], env)
return jax_tile(x, int(n))
if op == 'gather':
data = self._eval(args[0], env)
indices = self._eval(args[1], env)
return jax_gather(data, indices)
if op == 'scatter':
indices = self._eval(args[0], env)
values = self._eval(args[1], env)
size = int(self._eval(args[2], env))
return jax_scatter(indices, values, size)
if op == 'scatter-add':
indices = self._eval(args[0], env)
values = self._eval(args[1], env)
size = int(self._eval(args[2], env))
return jax_scatter_add(indices, values, size)
if op == 'group-reduce':
values = self._eval(args[0], env)
groups = self._eval(args[1], env)
num_groups = int(self._eval(args[2], env))
reduce_op = args[3] if len(args) > 3 else 'mean'
if isinstance(reduce_op, Symbol):
reduce_op = reduce_op.name
return jax_group_reduce(values, groups, num_groups, reduce_op)
if op == 'where':
cond = self._eval(args[0], env)
true_val = self._eval(args[1], env)
false_val = self._eval(args[2], env)
# Handle None values
if true_val is None:
true_val = 0.0
if false_val is None:
false_val = 0.0
return jax_where(cond, true_val, false_val)
if op == 'len' or op == 'length':
x = self._eval(args[0], env)
if isinstance(x, (list, tuple)):
return len(x)
return x.size
# Beta reductions
if op in ('β+', 'beta+', 'sum'):
return jnp.sum(self._eval(args[0], env))
if op in ('β*', 'beta*', 'product'):
return jnp.prod(self._eval(args[0], env))
if op in ('βmin', 'beta-min'):
return jnp.min(self._eval(args[0], env))
if op in ('βmax', 'beta-max'):
return jnp.max(self._eval(args[0], env))
if op in ('βmean', 'beta-mean', 'mean'):
return jnp.mean(self._eval(args[0], env))
if op in ('βstd', 'beta-std'):
return jnp.std(self._eval(args[0], env))
if op in ('βany', 'beta-any'):
return jnp.any(self._eval(args[0], env))
if op in ('βall', 'beta-all'):
return jnp.all(self._eval(args[0], env))
# Convenience - min/max of two values (handle both scalars and arrays)
if op == 'min' or op == 'min2':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
# Use Python min/max for scalar Python values to preserve type
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return min(a, b)
return jnp.minimum(jnp.asarray(a), jnp.asarray(b))
if op == 'max' or op == 'max2':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
# Use Python min/max for scalar Python values to preserve type
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return max(a, b)
return jnp.maximum(jnp.asarray(a), jnp.asarray(b))
if op == 'clamp':
x = self._eval(args[0], env)
lo = self._eval(args[1], env)
hi = self._eval(args[2], env)
return jnp.clip(x, lo, hi)
# List operations
if op == 'list':
return tuple(self._eval(a, env) for a in args)
if op == 'nth':
seq = self._eval(args[0], env)
idx = int(self._eval(args[1], env))
if isinstance(seq, (list, tuple)):
return seq[idx] if 0 <= idx < len(seq) else None
return seq[idx] # For arrays
if op == 'first':
seq = self._eval(args[0], env)
return seq[0] if len(seq) > 0 else None
if op == 'second':
seq = self._eval(args[0], env)
return seq[1] if len(seq) > 1 else None
# Random (JAX-compatible)
# Get frame_num for deterministic variation - can be traced, fold_in handles it
frame_num = env.get('_frame_num', env.get('frame_num', 0))
# Convert to int32 for fold_in if needed (but keep as JAX array if traced)
if frame_num is None:
frame_num = 0
elif isinstance(frame_num, (int, float)):
frame_num = int(frame_num)
# If it's a JAX array, leave it as-is for tracing
# Increment operation counter for unique keys within same frame
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
if op == 'rand' or op == 'rand-x':
# For size-based random
if args:
size = self._eval(args[0], env)
if hasattr(size, 'shape'):
# For frames (3D), use h*w (channel size), not h*w*c
if size.ndim == 3:
n = size.shape[0] * size.shape[1] # h * w
shape = (n,)
else:
n = size.size
shape = size.shape
elif hasattr(size, 'size'):
n = size.size
shape = (n,)
else:
n = int(size)
shape = (n,)
# Use deterministic key that varies with frame
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.uniform(key, shape).flatten()
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.uniform(key, ())
if op == 'randn' or op == 'randn-x':
# Normal random
if args:
size = self._eval(args[0], env)
if hasattr(size, 'shape'):
# For frames (3D), use h*w (channel size), not h*w*c
if size.ndim == 3:
n = size.shape[0] * size.shape[1] # h * w
else:
n = size.size
elif hasattr(size, 'size'):
n = size.size
else:
n = int(size)
mean = self._eval(args[1], env) if len(args) > 1 else 0.0
std = self._eval(args[2], env) if len(args) > 2 else 1.0
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, (n,)) * std + mean
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, ())
if op == 'rand-range' or op == 'core:rand-range':
lo = self._eval(args[0], env)
hi = self._eval(args[1], env)
seed = env.get('_seed', 42)
return jax_rand_range(lo, hi, frame_num, op_counter, seed)
# =====================================================================
# Convolution operations
# =====================================================================
if op == 'blur' or op == 'image:blur':
frame = self._eval(args[0], env)
radius = self._eval(args[1], env) if len(args) > 1 else 1
# Convert traced value to concrete for kernel size
if hasattr(radius, 'item'):
radius = int(radius.item())
elif hasattr(radius, '__float__'):
radius = int(float(radius))
else:
radius = int(radius)
return jax_blur(frame, max(1, radius))
if op == 'gaussian':
first_arg = self._eval(args[0], env)
# Check if first arg is a frame (blur) or scalar (random)
if hasattr(first_arg, 'shape') and first_arg.ndim == 3:
# Frame - apply gaussian blur
sigma = self._eval(args[1], env) if len(args) > 1 else 1.0
radius = max(1, int(sigma * 3))
return jax_blur(first_arg, radius)
else:
# Scalar args - generate gaussian random value
mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg
std = self._eval(args[1], env) if len(args) > 1 else 1.0
# Return a single random value
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
return jax.random.normal(key, ()) * std + mean
if op == 'sharpen' or op == 'image:sharpen':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
return jax_sharpen(frame, amount)
if op == 'edge-detect' or op == 'image:edge-detect':
frame = self._eval(args[0], env)
return jax_edge_detect(frame)
if op == 'emboss':
frame = self._eval(args[0], env)
return jax_emboss(frame)
if op == 'convolve':
frame = self._eval(args[0], env)
kernel = self._eval(args[1], env)
# Convert kernel to array if it's a list
if isinstance(kernel, (list, tuple)):
kernel = jnp.array(kernel, dtype=jnp.float32)
h, w = frame.shape[:2]
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel)
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel)
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel)
return jnp.stack([
jnp.clip(r, 0, 255).astype(jnp.uint8),
jnp.clip(g, 0, 255).astype(jnp.uint8),
jnp.clip(b, 0, 255).astype(jnp.uint8)
], axis=2)
if op == 'add-noise':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env) if len(args) > 1 else 0.1
h, w = frame.shape[:2]
# Use frame-varying key for noise
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
seed = env.get('_seed', 42)
key = make_jax_key(seed, frame_num, op_counter)
noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1]
result = frame.astype(jnp.float32) + noise * amount * 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
if op == 'translate':
frame = self._eval(args[0], env)
dx = self._eval(args[1], env)
dy = self._eval(args[2], env) if len(args) > 2 else 0
h, w = frame.shape[:2]
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w)
src_x = (x_coords - dx).flatten()
src_y = (y_coords - dy).flatten()
r, g, b = jax_sample(frame, src_x, src_y)
return jnp.stack([
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'image:crop' or op == 'crop':
frame = self._eval(args[0], env)
x = int(self._eval(args[1], env))
y = int(self._eval(args[2], env))
w = int(self._eval(args[3], env))
h = int(self._eval(args[4], env))
return frame[y:y+h, x:x+w, :]
if op == 'dilate':
frame = self._eval(args[0], env)
size = int(self._eval(args[1], env)) if len(args) > 1 else 3
# Simple dilation using max pooling approximation
kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size)
h, w = frame.shape[:2]
r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size)
g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size)
b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size)
return jnp.stack([
jnp.clip(r, 0, 255).astype(jnp.uint8),
jnp.clip(g, 0, 255).astype(jnp.uint8),
jnp.clip(b, 0, 255).astype(jnp.uint8)
], axis=2)
if op == 'map-rows':
frame = self._eval(args[0], env)
fn = args[1] # S-expression function
h, w = frame.shape[:2]
# For each row, apply the function
results = []
for row_idx in range(h):
row_env = env.copy()
row_env['row'] = frame[row_idx, :, :]
row_env['row-idx'] = row_idx
# Check if fn is a lambda
if isinstance(fn, list) and len(fn) >= 2:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
# Bind lambda params to y and row
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
row_env[param_name] = row_idx # y
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
row_env[param_name] = frame[row_idx, :, :] # row
result_row = self._eval(body, row_env)
results.append(result_row)
continue
result_row = self._eval(fn, row_env)
# If result is a function, call it
if callable(result_row):
result_row = result_row(row_idx, frame[row_idx, :, :])
results.append(result_row)
return jnp.stack(results, axis=0)
# =====================================================================
# Color operations
# =====================================================================
if op == 'rgb->hsv' or op == 'rgb-to-hsv':
# Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple
if len(args) == 1:
c = self._eval(args[0], env)
if isinstance(c, tuple) and len(c) == 3:
r, g, b = c
else:
# Assume it's a list-like
r, g, b = c[0], c[1], c[2]
else:
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
return jax_rgb_to_hsv(r, g, b)
if op == 'hsv->rgb' or op == 'hsv-to-rgb':
# Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list)
if len(args) == 1:
hsv = self._eval(args[0], env)
if isinstance(hsv, (tuple, list)) and len(hsv) >= 3:
h, s, v = hsv[0], hsv[1], hsv[2]
else:
h, s, v = hsv[0], hsv[1], hsv[2]
else:
h = self._eval(args[0], env)
s = self._eval(args[1], env)
v = self._eval(args[2], env)
return jax_hsv_to_rgb(h, s, v)
if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness':
frame = self._eval(args[0], env)
amount = self._eval(args[1], env)
return jax_adjust_brightness(frame, amount)
if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast':
frame = self._eval(args[0], env)
factor = self._eval(args[1], env)
return jax_adjust_contrast(frame, factor)
if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation':
frame = self._eval(args[0], env)
factor = self._eval(args[1], env)
return jax_adjust_saturation(frame, factor)
if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift':
frame = self._eval(args[0], env)
degrees = self._eval(args[1], env)
return jax_shift_hue(frame, degrees)
if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img':
frame = self._eval(args[0], env)
return jax_invert(frame)
if op == 'posterize' or op == 'color_ops:posterize':
frame = self._eval(args[0], env)
levels = self._eval(args[1], env)
return jax_posterize(frame, levels)
if op == 'threshold' or op == 'color_ops:threshold':
frame = self._eval(args[0], env)
level = self._eval(args[1], env)
invert = self._eval(args[2], env) if len(args) > 2 else False
return jax_threshold(frame, level, invert)
if op == 'sepia' or op == 'color_ops:sepia':
frame = self._eval(args[0], env)
return jax_sepia(frame)
if op == 'grayscale' or op == 'image:grayscale':
frame = self._eval(args[0], env)
return jax_grayscale(frame)
# =====================================================================
# Geometry operations
# =====================================================================
if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-img':
frame = self._eval(args[0], env)
direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal'
if direction == 'vertical' or direction == 'v':
return jax_flip_vertical(frame)
return jax_flip_horizontal(frame)
if op == 'flip-vertical' or op == 'flip-v':
frame = self._eval(args[0], env)
return jax_flip_vertical(frame)
if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img':
frame = self._eval(args[0], env)
angle = self._eval(args[1], env)
return jax_rotate(frame, angle)
if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img':
frame = self._eval(args[0], env)
scale_x = self._eval(args[1], env)
scale_y = self._eval(args[2], env) if len(args) > 2 else None
return jax_scale(frame, scale_x, scale_y)
if op == 'resize' or op == 'image:resize':
frame = self._eval(args[0], env)
new_w = self._eval(args[1], env)
new_h = self._eval(args[2], env)
return jax_resize(frame, new_w, new_h)
# =====================================================================
# Geometry distortion effects
# =====================================================================
if op == 'geometry:fisheye-coords' or op == 'fisheye':
# Signature: (w h strength cx cy zoom_correct) or (frame strength)
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
# (w h strength cx cy zoom_correct) signature
w = int(first_arg)
h = int(self._eval(args[1], env))
strength = self._eval(args[2], env) if len(args) > 2 else 0.5
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
frame = None
else:
frame = first_arg
strength = self._eval(args[1], env) if len(args) > 1 else 0.5
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
max_r = jnp.sqrt(float(cx*cx + cy*cy))
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
# Fisheye distortion
r_new = r + strength * r * (1 - r / max_r)
src_x = r_new * jnp.cos(theta) + cx
src_y = r_new * jnp.sin(theta) + cy
if frame is None:
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'geometry:swirl-coords' or op == 'swirl':
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
w = int(first_arg)
h = int(self._eval(args[1], env))
amount = self._eval(args[2], env) if len(args) > 2 else 1.0
frame = None
else:
frame = first_arg
amount = self._eval(args[1], env) if len(args) > 1 else 1.0
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
max_r = jnp.sqrt(float(cx*cx + cy*cy))
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
swirl_angle = amount * (1 - r / max_r)
new_theta = theta + swirl_angle
src_x = r * jnp.cos(new_theta) + cx
src_y = r * jnp.sin(new_theta) + cy
if frame is None:
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# Wave effect (frame-first signature for simple usage)
if op == 'wave-distort':
first_arg = self._eval(args[0], env)
frame = first_arg
amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0
amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0
freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1
freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1
h, w = frame.shape[:2]
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y)
src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x)
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'geometry:ripple-displace' or op == 'ripple':
first_arg = self._eval(args[0], env)
if not hasattr(first_arg, 'shape'):
w = int(first_arg)
h = int(self._eval(args[1], env))
amplitude = self._eval(args[2], env) if len(args) > 2 else 10.0
frequency = self._eval(args[3], env) if len(args) > 3 else 0.05
frame = None
else:
frame = first_arg
amplitude = self._eval(args[1], env) if len(args) > 1 else 10.0
frequency = self._eval(args[2], env) if len(args) > 2 else 0.05
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
dist = jnp.sqrt(dx*dx + dy*dy)
displacement = amplitude * jnp.sin(dist * frequency)
angle = jnp.arctan2(dy, dx)
src_x = x_coords + displacement * jnp.cos(angle)
src_y = y_coords + displacement * jnp.sin(angle)
if frame is None:
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope':
# Two signatures: (frame segments) or (w h segments cx cy)
if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'):
# (w h segments cx cy) signature
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
segments = int(self._eval(args[2], env)) if len(args) > 2 else 6
cx = self._eval(args[3], env) if len(args) > 3 else w / 2
cy = self._eval(args[4], env) if len(args) > 4 else h / 2
frame = None
else:
frame = self._eval(args[0], env)
segments = int(self._eval(args[1], env)) if len(args) > 1 else 6
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
dx = x_coords - cx
dy = y_coords - cy
r = jnp.sqrt(dx*dx + dy*dy)
theta = jnp.arctan2(dy, dx)
# Mirror into segments
segment_angle = 2 * jnp.pi / segments
theta_mod = theta % segment_angle
theta_mirror = jnp.where(
(jnp.floor(theta / segment_angle) % 2) == 0,
theta_mod,
segment_angle - theta_mod
)
src_x = r * jnp.cos(theta_mirror) + cx
src_y = r * jnp.sin(theta_mirror) + cy
if frame is None:
# Return coordinate arrays
return {'x': src_x, 'y': src_y}
r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# Geometry coordinate extraction
if op == 'geometry:coords-x':
coords = self._eval(args[0], env)
if isinstance(coords, dict):
return coords['x']
return coords[0] if isinstance(coords, tuple) else coords
if op == 'geometry:coords-y':
coords = self._eval(args[0], env)
if isinstance(coords, dict):
return coords['y']
return coords[1] if isinstance(coords, tuple) else coords
if op == 'geometry:remap' or op == 'remap':
frame = self._eval(args[0], env)
x_coords = self._eval(args[1], env)
y_coords = self._eval(args[2], env)
h, w = frame.shape[:2]
r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten())
return jnp.stack([
jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# =====================================================================
# Blending operations
# =====================================================================
if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
return jax_blend(frame1, frame2, alpha)
if op == 'blend-add':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_add(frame1, frame2)
if op == 'blend-multiply':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_multiply(frame1, frame2)
if op == 'blend-screen':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_screen(frame1, frame2)
if op == 'blend-overlay':
frame1 = self._eval(args[0], env)
frame2 = self._eval(args[1], env)
return jax_blend_overlay(frame1, frame2)
# =====================================================================
# Image dimension queries (namespaced aliases)
# =====================================================================
if op == 'image:width':
if args:
frame = self._eval(args[0], env)
return frame.shape[1] # width is second dimension (h, w, c)
return env['width']
if op == 'image:height':
if args:
frame = self._eval(args[0], env)
return frame.shape[0] # height is first dimension (h, w, c)
return env['height']
# =====================================================================
# Utility
# =====================================================================
if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?':
x = self._eval(args[0], env)
return jax_is_nil(x)
# =====================================================================
# Xector channel operations (shortcuts)
# =====================================================================
if op == 'red':
val = self._eval(args[0], env)
# Works on frames or pixel tuples
if isinstance(val, tuple):
return val[0]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 0)
else:
return val # Assume it's already a channel
if op == 'green':
val = self._eval(args[0], env)
if isinstance(val, tuple):
return val[1]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 1)
else:
return val
if op == 'blue':
val = self._eval(args[0], env)
if isinstance(val, tuple):
return val[2]
elif hasattr(val, 'shape') and val.ndim == 3:
return jax_channel(val, 2)
else:
return val
if op == 'gray' or op == 'luminance':
val = self._eval(args[0], env)
# Handle tuple (r, g, b) from map-pixels
if isinstance(val, tuple) and len(val) == 3:
r, g, b = val
return r * 0.299 + g * 0.587 + b * 0.114
# Handle frame
frame = val
r = frame[:, :, 0].flatten().astype(jnp.float32)
g = frame[:, :, 1].flatten().astype(jnp.float32)
b = frame[:, :, 2].flatten().astype(jnp.float32)
return r * 0.299 + g * 0.587 + b * 0.114
if op == 'rgb':
r = self._eval(args[0], env)
g = self._eval(args[1], env)
b = self._eval(args[2], env)
# For scalars (e.g., in map-pixels), return tuple
r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ())
g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ())
b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ())
if r_is_scalar and g_is_scalar and b_is_scalar:
return (r, g, b)
return jax_merge_channels(r, g, b, env['_shape'])
# =====================================================================
# Coordinate operations
# =====================================================================
if op == 'x-coords':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
return jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
if op == 'y-coords':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
if op == 'dist-from-center':
frame = self._eval(args[0], env)
h, w = frame.shape[:2]
cx, cy = w / 2, h / 2
x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx
y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy
return jnp.sqrt(x*x + y*y)
# =====================================================================
# Alpha operations (element-wise on xectors)
# =====================================================================
if op == 'α/' or op == 'alpha/':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a / b
if op == 'α+' or op == 'alpha+':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result + v
return result
if op == 'α*' or op == 'alpha*':
vals = [self._eval(a, env) for a in args]
result = vals[0]
for v in vals[1:]:
result = result * v
return result
if op == 'α-' or op == 'alpha-':
if len(args) == 1:
return -self._eval(args[0], env)
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a - b
if op == 'αclamp' or op == 'alpha-clamp':
x = self._eval(args[0], env)
lo = self._eval(args[1], env)
hi = self._eval(args[2], env)
return jnp.clip(x, lo, hi)
if op == 'αmin' or op == 'alpha-min':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.minimum(a, b)
if op == 'αmax' or op == 'alpha-max':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.maximum(a, b)
if op == 'αsqrt' or op == 'alpha-sqrt':
return jnp.sqrt(self._eval(args[0], env))
if op == 'αsin' or op == 'alpha-sin':
return jnp.sin(self._eval(args[0], env))
if op == 'αcos' or op == 'alpha-cos':
return jnp.cos(self._eval(args[0], env))
if op == 'αmod' or op == 'alpha-mod':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a % b
if op == 'α²' or op == 'αsq' or op == 'alpha-sq':
x = self._eval(args[0], env)
return x * x
if op == 'α<' or op == 'alpha<':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a < b
if op == 'α>' or op == 'alpha>':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a > b
if op == 'α<=' or op == 'alpha<=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a <= b
if op == 'α>=' or op == 'alpha>=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a >= b
if op == 'α=' or op == 'alpha=':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return a == b
if op == 'αfloor' or op == 'alpha-floor':
return jnp.floor(self._eval(args[0], env))
if op == 'αceil' or op == 'alpha-ceil':
return jnp.ceil(self._eval(args[0], env))
if op == 'αround' or op == 'alpha-round':
return jnp.round(self._eval(args[0], env))
if op == 'αabs' or op == 'alpha-abs':
return jnp.abs(self._eval(args[0], env))
if op == 'αexp' or op == 'alpha-exp':
return jnp.exp(self._eval(args[0], env))
if op == 'αlog' or op == 'alpha-log':
return jnp.log(self._eval(args[0], env))
if op == 'αor' or op == 'alpha-or':
# Element-wise logical OR
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_or(a, b)
if op == 'αand' or op == 'alpha-and':
# Element-wise logical AND
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_and(a, b)
if op == 'αnot' or op == 'alpha-not':
# Element-wise logical NOT
return jnp.logical_not(self._eval(args[0], env))
if op == 'αxor' or op == 'alpha-xor':
# Element-wise logical XOR
a = self._eval(args[0], env)
b = self._eval(args[1], env)
return jnp.logical_xor(a, b)
# =====================================================================
# Threading/arrow operations
# =====================================================================
if op == '->':
# Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b)
val = self._eval(args[0], env)
for form in args[1:]:
if isinstance(form, list):
# Insert val as first argument
fn_name = form[0].name if isinstance(form[0], Symbol) else form[0]
new_args = [val] + [self._eval(a, env) for a in form[1:]]
val = self._eval_call(fn_name, [val] + form[1:], env)
else:
# Simple function call
fn_name = form.name if isinstance(form, Symbol) else form
val = self._eval_call(fn_name, [args[0]], env)
return val
# =====================================================================
# Range and iteration
# =====================================================================
if op == 'range':
if len(args) == 1:
end = int(self._eval(args[0], env))
return list(range(end))
elif len(args) == 2:
start = int(self._eval(args[0], env))
end = int(self._eval(args[1], env))
return list(range(start, end))
else:
start = int(self._eval(args[0], env))
end = int(self._eval(args[1], env))
step = int(self._eval(args[2], env))
return list(range(start, end, step))
if op == 'reduce' or op == 'fold':
# (reduce seq init fn) - left fold
seq = self._eval(args[0], env)
acc = self._eval(args[1], env)
fn = args[2] # Lambda S-expression
# Handle lambda
if isinstance(fn, list) and len(fn) >= 3:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
for item in seq:
fn_env = env.copy()
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
fn_env[param_name] = acc
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
fn_env[param_name] = item
acc = self._eval(body, fn_env)
return acc
# Fallback - try evaluating fn and calling it
fn_eval = self._eval(fn, env)
if callable(fn_eval):
for item in seq:
acc = fn_eval(acc, item)
return acc
# =====================================================================
# Map-pixels (apply function to each pixel)
# =====================================================================
if op == 'map-pixels':
frame = self._eval(args[0], env)
fn = args[1] # Lambda or S-expression
h, w = frame.shape[:2]
# Extract channels
r = frame[:, :, 0].flatten().astype(jnp.float32)
g = frame[:, :, 1].flatten().astype(jnp.float32)
b = frame[:, :, 2].flatten().astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h)
y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w)
# Set up pixel environment
pixel_env = env.copy()
pixel_env['r'] = r
pixel_env['g'] = g
pixel_env['b'] = b
pixel_env['x'] = x_coords
pixel_env['y'] = y_coords
# Also provide c (color) as a tuple for lambda (x y c) style
pixel_env['c'] = (r, g, b)
# If fn is a lambda, we need to handle it specially
if isinstance(fn, list) and len(fn) >= 2:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
# Lambda: (lambda (x y c) body)
params = fn[1]
body = fn[2]
# Bind parameters
if len(params) >= 1:
param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0])
pixel_env[param_name] = x_coords
if len(params) >= 2:
param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1])
pixel_env[param_name] = y_coords
if len(params) >= 3:
param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2])
pixel_env[param_name] = (r, g, b)
result = self._eval(body, pixel_env)
else:
result = self._eval(fn, pixel_env)
else:
result = self._eval(fn, pixel_env)
if isinstance(result, tuple) and len(result) == 3:
nr, ng, nb = result
return jax_merge_channels(nr, ng, nb, (h, w))
elif hasattr(result, 'shape') and result.ndim == 3:
return result
else:
# Single channel result
if hasattr(result, 'flatten'):
result = result.flatten()
gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8)
return jnp.stack([gray, gray, gray], axis=2)
# =====================================================================
# State operations (return unchanged for stateless JIT)
# =====================================================================
if op == 'state-get':
key = self._eval(args[0], env)
default = self._eval(args[1], env) if len(args) > 1 else None
return default # State not supported in JIT, return default
if op == 'state-set':
return None # No-op in JIT
# =====================================================================
# Cell/grid operations
# =====================================================================
if op == 'local-x-norm':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return (x % cell_size) / max(1, cell_size - 1)
if op == 'local-y-norm':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return (y % cell_size) / max(1, cell_size - 1)
if op == 'local-x':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return (x % cell_size).astype(jnp.float32)
if op == 'local-y':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return (y % cell_size).astype(jnp.float32)
if op == 'cell-row':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
y = jnp.repeat(jnp.arange(h), w)
return jnp.floor(y / cell_size)
if op == 'cell-col':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
h, w = frame.shape[:2]
x = jnp.tile(jnp.arange(w), h)
return jnp.floor(x / cell_size)
if op == 'num-rows':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
return frame.shape[0] // cell_size
if op == 'num-cols':
frame = self._eval(args[0], env)
cell_size = int(self._eval(args[1], env))
return frame.shape[1] // cell_size
# =====================================================================
# Control flow
# =====================================================================
if op == 'cond':
# (cond (test1 expr1) (test2 expr2) ... (else exprN))
# For JAX compatibility, build a nested jnp.where structure
# Start from the else clause and work backwards
# Collect clauses
clauses = []
else_expr = None
for clause in args:
if isinstance(clause, list) and len(clause) >= 2:
test = clause[0]
if isinstance(test, Symbol) and test.name == 'else':
else_expr = clause[1]
else:
clauses.append((test, clause[1]))
# If no else, default to None/0
if else_expr is not None:
result = self._eval(else_expr, env)
else:
result = 0
# Build nested where from last to first
for test_expr, val_expr in reversed(clauses):
cond_val = self._eval(test_expr, env)
then_val = self._eval(val_expr, env)
# Check if condition is array or scalar
if hasattr(cond_val, 'shape') and cond_val.shape != ():
# Array condition - use jnp.where
result = jnp.where(cond_val, then_val, result)
else:
# Scalar - can use Python if
if cond_val:
result = then_val
return result
if op == 'set!' or op == 'set':
# Mutation - not really supported in JAX, but we can update env
var = args[0].name if isinstance(args[0], Symbol) else str(args[0])
val = self._eval(args[1], env)
env[var] = val
return val
if op == 'begin' or op == 'do':
# Evaluate all expressions, return last
result = None
for expr in args:
result = self._eval(expr, env)
return result
# =====================================================================
# Additional math
# =====================================================================
if op == 'sq' or op == 'square':
x = self._eval(args[0], env)
return x * x
if op == 'lerp':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
t = self._eval(args[2], env)
return a * (1 - t) + b * t
if op == 'smoothstep':
edge0 = self._eval(args[0], env)
edge1 = self._eval(args[1], env)
x = self._eval(args[2], env)
t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1)
return t * t * (3 - 2 * t)
if op == 'atan2':
y = self._eval(args[0], env)
x = self._eval(args[1], env)
return jnp.arctan2(y, x)
if op == 'fract' or op == 'frac':
x = self._eval(args[0], env)
return x - jnp.floor(x)
# =====================================================================
# Frame copy and construction operations
# =====================================================================
if op == 'pixel':
# Get pixel at (x, y) from frame
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
h, w = frame.shape[:2]
# Convert to int and clip to bounds
if isinstance(x, (int, float)):
x = max(0, min(int(x), w - 1))
else:
x = jnp.clip(x, 0, w - 1).astype(jnp.int32)
if isinstance(y, (int, float)):
y = max(0, min(int(y), h - 1))
else:
y = jnp.clip(y, 0, h - 1).astype(jnp.int32)
r = frame[y, x, 0]
g = frame[y, x, 1]
b = frame[y, x, 2]
return (r, g, b)
if op == 'copy':
frame = self._eval(args[0], env)
return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame)
if op == 'make-image':
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
if len(args) > 2:
color = self._eval(args[2], env)
if isinstance(color, (list, tuple)):
r, g, b = int(color[0]), int(color[1]), int(color[2])
else:
r = g = b = int(color)
else:
r = g = b = 0
img = jnp.zeros((h, w, 3), dtype=jnp.uint8)
img = img.at[:, :, 0].set(r)
img = img.at[:, :, 1].set(g)
img = img.at[:, :, 2].set(b)
return img
if op == 'paste':
dest = self._eval(args[0], env)
src = self._eval(args[1], env)
x = int(self._eval(args[2], env))
y = int(self._eval(args[3], env))
sh, sw = src.shape[:2]
dh, dw = dest.shape[:2]
# Clip to dest bounds
x1, y1 = max(0, x), max(0, y)
x2, y2 = min(dw, x + sw), min(dh, y + sh)
sx1, sy1 = x1 - x, y1 - y
sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1)
result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest)
result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :])
return result
# =====================================================================
# Blending operations
# =====================================================================
if op == 'blending:blend-images' or op == 'blend-images':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
alpha = self._eval(args[2], env) if len(args) > 2 else 0.5
return jax_blend(a, b, alpha)
if op == 'blending:blend-mode' or op == 'blend-mode':
a = self._eval(args[0], env)
b = self._eval(args[1], env)
mode = self._eval(args[2], env) if len(args) > 2 else 'add'
if mode == 'add':
return jax_blend_add(a, b)
elif mode == 'multiply':
return jax_blend_multiply(a, b)
elif mode == 'screen':
return jax_blend_screen(a, b)
elif mode == 'overlay':
return jax_blend_overlay(a, b)
elif mode == 'lighten':
return jnp.maximum(a, b)
elif mode == 'darken':
return jnp.minimum(a, b)
elif mode == 'difference':
return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8)
else:
return jax_blend(a, b, 0.5)
# =====================================================================
# Geometry coordinate operations
# =====================================================================
if op == 'geometry:wave-coords' or op == 'wave-coords':
w = int(self._eval(args[0], env))
h = int(self._eval(args[1], env))
axis = self._eval(args[2], env) if len(args) > 2 else 'x'
freq = self._eval(args[3], env) if len(args) > 3 else 1.0
amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0
phase = self._eval(args[5], env) if len(args) > 5 else 0.0
y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
if axis == 'x' or axis == 'horizontal':
# Wave displaces X based on Y
offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
src_x = x_coords + offset
src_y = y_coords
elif axis == 'y' or axis == 'vertical':
# Wave displaces Y based on X
offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
src_x = x_coords
src_y = y_coords + offset
else: # both
offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase)
offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase)
src_x = x_coords + offset_x
src_y = y_coords + offset_y
return {'x': src_x, 'y': src_y}
if op == 'geometry:coords-x' or op == 'coords-x':
coords = self._eval(args[0], env)
return coords['x']
if op == 'geometry:coords-y' or op == 'coords-y':
coords = self._eval(args[0], env)
return coords['y']
if op == 'geometry:remap' or op == 'remap':
frame = self._eval(args[0], env)
x = self._eval(args[1], env)
y = self._eval(args[2], env)
h, w = frame.shape[:2]
r, g, b = jax_sample(frame, x.flatten(), y.flatten())
return jnp.stack([
jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
], axis=2)
# =====================================================================
# Glitch effects
# =====================================================================
if op == 'pixelsort':
frame = self._eval(args[0], env)
sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness'
thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50
thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200
angle = int(self._eval(args[4], env)) if len(args) > 4 else 0
reverse = self._eval(args[5], env) if len(args) > 5 else False
h, w = frame.shape[:2]
result = frame.copy()
# Get luminance for thresholding
lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
frame[:, :, 1].astype(jnp.float32) * 0.587 +
frame[:, :, 2].astype(jnp.float32) * 0.114)
# Sort each row
for y in range(h):
row_lum = lum[y, :]
row = frame[y, :, :]
# Find mask of pixels to sort
mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi)
# Get indices where we should sort
sort_indices = jnp.where(mask, jnp.arange(w), -1)
# Simple sort by luminance for the row
if sort_by == 'lightness':
sort_key = row_lum
elif sort_by == 'hue':
# Approximate hue from RGB
sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32),
row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32)))
else:
sort_key = row_lum
# Sort pixels in masked region
sorted_indices = jnp.argsort(sort_key)
if reverse:
sorted_indices = sorted_indices[::-1]
# Apply partial sort (only where mask is true)
# This is a simplified version - full pixelsort is more complex
result = result.at[y, :, :].set(row[sorted_indices])
return result
if op == 'datamosh':
frame = self._eval(args[0], env)
prev = self._eval(args[1], env)
block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32
corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3
max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50
color_corrupt = self._eval(args[5], env) if len(args) > 5 else True
h, w = frame.shape[:2]
# Use deterministic random for JIT with frame variation
seed = env.get('_seed', 42)
op_counter = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_counter + 1
key = make_jax_key(seed, frame_num, op_counter)
num_blocks_y = h // block_size
num_blocks_x = w // block_size
total_blocks = num_blocks_y * num_blocks_x
# Pre-generate all random values at once (vectorized)
key, k1, k2, k3, k4, k5 = jax.random.split(key, 6)
corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption
offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1)
offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1)
channels = jax.random.randint(k4, (total_blocks,), 0, 3)
color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51)
# Create coordinate grids for blocks
by_grid = jnp.arange(num_blocks_y)
bx_grid = jnp.arange(num_blocks_x)
# Create block coordinate arrays
by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...]
bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...]
# Create pixel coordinate grids
y_coords, x_coords = jnp.mgrid[:h, :w]
# Determine which block each pixel belongs to
pixel_block_y = y_coords // block_size
pixel_block_x = x_coords // block_size
pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x
# Clamp to valid block indices (for pixels outside the block grid)
pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1)
# Get the corrupt mask, offsets for each pixel's block
pixel_corrupt = corrupt_mask[pixel_block_idx]
pixel_offset_y = offsets_y[pixel_block_idx]
pixel_offset_x = offsets_x[pixel_block_idx]
# Calculate source coordinates with offset (clamped)
src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1)
src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1)
# Sample from previous frame at offset positions
prev_sampled = prev[src_y, src_x, :]
# Where corrupt mask is true, use prev_sampled; else use frame
result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame)
# Apply color corruption to corrupted blocks
if color_corrupt:
pixel_channel = channels[pixel_block_idx]
pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16)
# Create per-channel shift arrays (only shift the selected channel)
shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0)
shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0)
shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0)
result_int = result.astype(jnp.int16)
result_int = result_int.at[:, :, 0].add(shift_r)
result_int = result_int.at[:, :, 1].add(shift_g)
result_int = result_int.at[:, :, 2].add(shift_b)
result = jnp.clip(result_int, 0, 255).astype(jnp.uint8)
return result
# =====================================================================
# ASCII Art Operations (using pre-rendered font atlas)
# =====================================================================
if op == 'cell-sample':
# (cell-sample frame char_size) -> (colors, luminances)
# Downsample frame into cells, return average colors and luminances
frame = self._eval(args[0], env)
char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8
h, w = frame.shape[:2]
num_rows = h // char_size
num_cols = w // char_size
# Crop to exact multiple of char_size
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
# Reshape to (num_rows, char_size, num_cols, char_size, 3)
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
# Average over char_size dimensions -> (num_rows, num_cols, 3)
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
# Compute luminance per cell
colors_float = colors.astype(jnp.float32)
luminances = (0.299 * colors_float[:, :, 0] +
0.587 * colors_float[:, :, 1] +
0.114 * colors_float[:, :, 2]) / 255.0
return (colors, luminances)
if op == 'luminance-to-chars':
# (luminance-to-chars luminances alphabet contrast) -> char_indices
# Map luminance values to character indices
luminances = self._eval(args[0], env)
alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard'
contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5
# Get alphabet string
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
# Apply contrast
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
# Map to character indices (0 = darkest, num_chars-1 = brightest)
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
return char_indices
if op == 'render-char-grid':
# (render-char-grid frame chars colors char_size color_mode background_color invert_colors)
# Render character grid using font atlas
frame = self._eval(args[0], env)
char_indices = self._eval(args[1], env)
colors = self._eval(args[2], env)
char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8
color_mode = self._eval(args[4], env) if len(args) > 4 else 'color'
background_color = self._eval(args[5], env) if len(args) > 5 else 'black'
invert_colors = self._eval(args[6], env) if len(args) > 6 else 0
h, w = frame.shape[:2]
num_rows, num_cols = char_indices.shape
# Get the alphabet used (stored in env or default)
alphabet = env.get('_ascii_alphabet', 'standard')
alpha_str = _get_alphabet_string(alphabet)
# Get or create font atlas
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background color
if background_color == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background_color == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
# Try to parse hex color
try:
if background_color.startswith('#'):
bg_hex = background_color[1:]
bg = jnp.array([int(bg_hex[0:2], 16),
int(bg_hex[2:4], 16),
int(bg_hex[4:6], 16)], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
except:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Create output image starting with background
output_h = num_rows * char_size
output_w = num_cols * char_size
result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy()
# Gather characters from atlas based on indices
# char_indices shape: (num_rows, num_cols)
# font_atlas shape: (num_chars, char_size, char_size, 3)
# Convert numpy atlas to JAX for indexing with traced indices
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3)
# Reshape to grid
char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3)
# Create coordinate grids for output pixels
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = y_out // char_size
cell_col = x_out // char_size
local_y = y_out % char_size
local_x = x_out % char_size
# Clamp to valid ranges
cell_row = jnp.clip(cell_row, 0, num_rows - 1)
cell_col = jnp.clip(cell_col, 0, num_cols - 1)
# Get character pixel values
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
# Get character brightness (for masking)
char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0
# Handle color modes
if color_mode == 'mono':
# White characters on background
fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
fg = jnp.broadcast_to(fg_color, char_pixels.shape)
elif color_mode == 'invert':
# Inverted cell colors
cell_colors = colors[cell_row, cell_col]
fg = 255 - cell_colors
else:
# 'color' mode - use cell colors
fg = colors[cell_row, cell_col]
# Blend foreground onto background based on character brightness
if invert_colors:
# Swap fg and bg
fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg
result = (fg.astype(jnp.float32) * (1 - char_brightness) +
bg_broadcast.astype(jnp.float32) * char_brightness)
else:
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) +
fg.astype(jnp.float32) * char_brightness)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize back to original frame size if needed
if result.shape[0] != h or result.shape[1] != w:
# Simple nearest-neighbor resize
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = (jnp.arange(h) * y_scale).astype(jnp.int32)
x_src = (jnp.arange(w) * x_scale).astype(jnp.int32)
y_src = jnp.clip(y_src, 0, result.shape[0] - 1)
x_src = jnp.clip(x_src, 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'ascii-fx-zone':
# Complex ASCII effect with per-zone expressions
# (ascii-fx-zone frame :cols cols :alphabet alphabet ...)
frame = self._eval(args[0], env)
# Parse keyword arguments
kwargs = {}
i = 1
while i < len(args):
if isinstance(args[i], Keyword):
key = args[i].name
if i + 1 < len(args):
kwargs[key] = args[i + 1]
i += 2
else:
i += 1
# Get parameters
cols = int(self._eval(kwargs.get('cols', 80), env))
char_size_param = kwargs.get('char_size')
alphabet = self._eval(kwargs.get('alphabet', 'standard'), env)
color_mode = self._eval(kwargs.get('color_mode', 'color'), env)
background = self._eval(kwargs.get('background', 'black'), env)
contrast = float(self._eval(kwargs.get('contrast', 1.5), env))
h, w = frame.shape[:2]
# Calculate char_size from cols if not specified
if char_size_param is not None:
char_size_val = self._eval(char_size_param, env)
if char_size_val is not None:
char_size = int(char_size_val)
else:
char_size = w // cols
else:
char_size = w // cols
char_size = max(4, min(char_size, 64))
# Store alphabet for render-char-grid to use
env['_ascii_alphabet'] = alphabet
# Cell sampling
num_rows = h // char_size
num_cols = w // char_size
cropped = frame[:num_rows * char_size, :num_cols * char_size, :]
reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3)
colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8)
# Compute luminances
colors_float = colors.astype(jnp.float32)
luminances = (0.299 * colors_float[:, :, 0] +
0.587 * colors_float[:, :, 1] +
0.114 * colors_float[:, :, 2]) / 255.0
# Get alphabet and map luminances to chars
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1)
char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32)
char_indices = jnp.clip(char_indices, 0, num_chars - 1)
# Handle optional per-zone effects (char_hue, char_saturation, etc.)
# These would modify colors based on zone position
char_hue = kwargs.get('char_hue')
char_saturation = kwargs.get('char_saturation')
char_brightness = kwargs.get('char_brightness')
if char_hue is not None or char_saturation is not None or char_brightness is not None:
# Create zone coordinate arrays for expression evaluation
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
row_norm = row_coords / max(num_rows - 1, 1)
col_norm = col_coords / max(num_cols - 1, 1)
# Bind zone variables
zone_env = env.copy()
zone_env['zone-row'] = row_coords
zone_env['zone-col'] = col_coords
zone_env['zone-row-norm'] = row_norm
zone_env['zone-col-norm'] = col_norm
zone_env['zone-lum'] = luminances
# Apply color modifications (simplified - full version would use HSV)
if char_brightness is not None:
brightness_mult = self._eval(char_brightness, zone_env)
if brightness_mult is not None:
colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None],
0, 255).astype(jnp.uint8)
# Render using font atlas
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background color
if background == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Gather characters - convert numpy atlas to JAX for traced indexing
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
# Create output
output_h = num_rows * char_size
output_w = num_cols * char_size
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
local_y = y_out % char_size
local_x = x_out % char_size
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
if color_mode == 'mono':
fg = jnp.full_like(char_pixels, 255)
else:
fg = colors[cell_row, cell_col]
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
fg.astype(jnp.float32) * char_bright)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize to original dimensions
if result.shape[0] != h or result.shape[1] != w:
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'render-char-grid-fx':
# Enhanced render with per-character effects
# (render-char-grid-fx frame chars colors luminances char_size
# color_mode bg_color invert_colors
# char_jitter char_scale char_rotation char_hue_shift
# jitter_source scale_source rotation_source hue_source)
frame = self._eval(args[0], env)
char_indices = self._eval(args[1], env)
colors = self._eval(args[2], env)
luminances = self._eval(args[3], env)
char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8
color_mode = self._eval(args[5], env) if len(args) > 5 else 'color'
background_color = self._eval(args[6], env) if len(args) > 6 else 'black'
invert_colors = self._eval(args[7], env) if len(args) > 7 else 0
# Per-char effect amounts
char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0
char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0
char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0
char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0
# Modulation sources
jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none'
scale_source = self._eval(args[13], env) if len(args) > 13 else 'none'
rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none'
hue_source = self._eval(args[15], env) if len(args) > 15 else 'none'
h, w = frame.shape[:2]
num_rows, num_cols = char_indices.shape
# Get alphabet
alphabet = env.get('_ascii_alphabet', 'standard')
alpha_str = _get_alphabet_string(alphabet)
font_atlas = _create_font_atlas(alpha_str, char_size)
# Parse background
if background_color == 'black':
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
elif background_color == 'white':
bg = jnp.array([255, 255, 255], dtype=jnp.uint8)
else:
bg = jnp.array([0, 0, 0], dtype=jnp.uint8)
# Create modulation values based on source
def get_modulation(source, lums, num_rows, num_cols):
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
row_norm = row_coords / max(num_rows - 1, 1)
col_norm = col_coords / max(num_cols - 1, 1)
if source == 'luminance':
return lums
elif source == 'inv_luminance':
return 1.0 - lums
elif source == 'position_x':
return col_norm
elif source == 'position_y':
return row_norm
elif source == 'position_diag':
return (row_norm + col_norm) / 2
elif source == 'center_dist':
cy, cx = 0.5, 0.5
dist = jnp.sqrt((row_norm - cy)**2 + (col_norm - cx)**2)
return jnp.clip(dist / 0.707, 0, 1) # Normalize by max diagonal
elif source == 'random':
# Use frame-varying key for random source
seed = env.get('_seed', 42)
op_ctr = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_ctr + 1
key = make_jax_key(seed, frame_num, op_ctr)
return jax.random.uniform(key, (num_rows, num_cols))
else:
return jnp.zeros((num_rows, num_cols))
# Get modulation values
jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols)
scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols)
rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols)
hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols)
# Gather characters - convert numpy atlas to JAX for traced indexing
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten()
char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3)
# Create output
output_h = num_rows * char_size
output_w = num_cols * char_size
y_out, x_out = jnp.mgrid[:output_h, :output_w]
cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1)
cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1)
local_y = y_out % char_size
local_x = x_out % char_size
# Apply jitter if enabled
if char_jitter > 0:
jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter
# Use frame-varying key for jitter
seed = env.get('_seed', 42)
op_ctr = env.get('_rand_op_counter', 0)
env['_rand_op_counter'] = op_ctr + 1
key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2)
# Generate deterministic jitter per cell
jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1)
jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1)
offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32)
local_y = jnp.clip(local_y + offset_y, 0, char_size - 1)
local_x = jnp.clip(local_x + offset_x, 0, char_size - 1)
char_pixels = char_tiles[cell_row, cell_col, local_y, local_x]
char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0
# Get foreground colors
if color_mode == 'mono':
fg = jnp.full_like(char_pixels, 255)
else:
fg = colors[cell_row, cell_col]
# Apply hue shift if enabled
if char_hue_shift > 0 and color_mode == 'color':
hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift
# Simple hue rotation via channel cycling
fg_float = fg.astype(jnp.float32)
shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB
# Simplified: blend channels based on shift
r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2]
shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac
# Just do a simple tint for now
fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8)
# Blend
bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape)
if invert_colors:
result = (fg.astype(jnp.float32) * (1 - char_bright) +
bg_broadcast.astype(jnp.float32) * char_bright)
else:
result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) +
fg.astype(jnp.float32) * char_bright)
result = jnp.clip(result, 0, 255).astype(jnp.uint8)
# Resize to original
if result.shape[0] != h or result.shape[1] != w:
y_scale = result.shape[0] / h
x_scale = result.shape[1] / w
y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1)
x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1)
result = result[y_src[:, None], x_src[None, :], :]
return result
if op == 'alphabet-char':
# (alphabet-char alphabet-name index) -> char_index in that alphabet
alphabet = self._eval(args[0], env)
index = self._eval(args[1], env)
alpha_str = _get_alphabet_string(alphabet)
num_chars = len(alpha_str)
# Handle both scalar and array indices
if hasattr(index, 'shape'):
index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1)
else:
index = max(0, min(int(index), num_chars - 1))
return index
if op == 'map-char-grid':
# (map-char-grid base-chars luminances (lambda (r c ch lum) ...))
# Map over character grid, allowing per-cell character selection
base_chars = self._eval(args[0], env)
luminances = self._eval(args[1], env)
fn = args[2] # Lambda expression
num_rows, num_cols = base_chars.shape
# For JAX compatibility, we can't use Python loops with traced values
# Instead, we'll evaluate the lambda for the whole grid at once
if isinstance(fn, list) and len(fn) >= 3:
head = fn[0]
if isinstance(head, Symbol) and head.name in ('lambda', 'λ'):
params = fn[1]
body = fn[2]
# Create grid coordinates
row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
# Bind parameters for whole-grid evaluation
fn_env = env.copy()
# Params: (r c ch lum)
if len(params) >= 1:
fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords
if len(params) >= 2:
fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords
if len(params) >= 3:
fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars
if len(params) >= 4:
# Luminances scaled to 0-255 range
fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32)
# Evaluate body - should return new character indices
result = self._eval(body, fn_env)
if hasattr(result, 'shape'):
return result.astype(jnp.int32)
return base_chars
return base_chars
# =====================================================================
# List operations
# =====================================================================
if op == 'take':
seq = self._eval(args[0], env)
n = int(self._eval(args[1], env))
if isinstance(seq, (list, tuple)):
return seq[:n]
return seq[:n] # Works for arrays too
if op == 'cons':
item = self._eval(args[0], env)
seq = self._eval(args[1], env)
if isinstance(seq, list):
return [item] + seq
elif isinstance(seq, tuple):
return (item,) + seq
return jnp.concatenate([jnp.array([item]), seq])
if op == 'roll':
arr = self._eval(args[0], env)
shift = self._eval(args[1], env)
axis = self._eval(args[2], env) if len(args) > 2 else 0
# Convert to int for concrete values, keep as-is for JAX traced values
if isinstance(shift, (int, float)):
shift = int(shift)
elif hasattr(shift, 'astype'):
shift = shift.astype(jnp.int32)
if isinstance(axis, (int, float)):
axis = int(axis)
return jnp.roll(arr, shift, axis=axis)
# =====================================================================
# Pi constant
# =====================================================================
if op == 'pi':
return jnp.pi
raise ValueError(f"Unknown operation: {op}")
# =============================================================================
# Public API
# =============================================================================
def compile_effect(code: str) -> Callable:
    """
    Turn S-expression effect source into a callable, memoising by source text.

    The MD5 hex digest of ``code`` keys the module-level ``_COMPILED_EFFECTS``
    cache, so repeated calls with identical text return the very same object
    without parsing or compiling again.

    Args:
        code: Text of one S-expression effect definition.

    Returns:
        The object produced by ``JaxCompiler.compile_effect`` for the parsed
        form; it is invoked as ``fn(frame, **params) -> frame``.
    """
    digest = hashlib.md5(code.encode()).hexdigest()

    # Cache hit: skip parsing and compilation entirely.
    if digest in _COMPILED_EFFECTS:
        return _COMPILED_EFFECTS[digest]

    tree = parse(code)
    effect_fn = JaxCompiler().compile_effect(tree)
    _COMPILED_EFFECTS[digest] = effect_fn
    return effect_fn
def compile_effect_file(path: str, derived_paths: Optional[List[str]] = None) -> Callable:
    """
    Compile the effect defined in a ``.sexp`` file.

    The file may contain, in any order:
      * ``(require "name")`` forms. Each names a derived-definitions file,
        resolved first relative to the effect file's directory and then
        relative to the ``sexp_effects`` directory next to this package. A
        ``.sexp`` suffix is appended when missing. Resolved files are loaded
        into the compiler. Unresolvable requires are skipped with a warning.
      * ``(require-primitives "lib")`` forms. These are accepted and ignored.
      * an ``(effect ...)`` or ``(define-effect ...)`` form, which is the
        effect to compile. If several are present, the last one wins.

    Args:
        path: Path to the .sexp effect file (read as UTF-8).
        derived_paths: Optional paths of extra derived-definition files. They
            are loaded before the file's own ``require`` forms.

    Returns:
        The function produced by ``JaxCompiler.compile_effect`` for the
        effect, called as ``fn(frame, **params) -> frame``.

    Raises:
        ValueError: If the file contains no effect definition.
    """
    import warnings

    # Read as UTF-8 explicitly. The platform default encoding varies, and
    # effect sources may contain non-ASCII symbols such as 'λ'.
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    # Parse every top-level form in the file.
    exprs = parse_all(code)
    compiler = JaxCompiler()

    # Caller-supplied derived files are loaded before the file's own requires.
    for dp in derived_paths or ():
        compiler.load_derived(dp)

    effect_sexp = None
    effect_dir = Path(path).parent
    for expr in exprs:
        # Only forms headed by a symbol with at least one argument matter here.
        if not isinstance(expr, list) or len(expr) < 2:
            continue
        head = expr[0]
        if not isinstance(head, Symbol):
            continue

        if head.name == 'require':
            # (require "derived") or (require "path/to/file")
            req_path = expr[1]
            if not isinstance(req_path, str):
                continue
            if not req_path.endswith('.sexp'):
                req_path = req_path + '.sexp'
            # Resolve relative to the effect file, then fall back to the
            # sexp_effects directory shipped alongside this package.
            full_path = effect_dir / req_path
            if not full_path.exists():
                full_path = Path(__file__).parent.parent / 'sexp_effects' / req_path
            if full_path.exists():
                compiler.load_derived(str(full_path))
            else:
                # Keep skipping as before, but say so. A silently missing
                # definition could otherwise surface much later as an
                # "Unknown operation" error.
                warnings.warn(
                    f"compile_effect_file: could not resolve (require {expr[1]!r}) "
                    f"from {path}; skipping",
                    stacklevel=2,
                )
        elif head.name == 'require-primitives':
            # (require-primitives "lib") is currently ignored for JAX, because
            # the JAX backend has its primitives built in.
            pass
        elif head.name in ('effect', 'define-effect'):
            # A later definition replaces an earlier one.
            effect_sexp = expr

    if effect_sexp is None:
        raise ValueError(f"No effect definition found in {path}")
    return compiler.compile_effect(effect_sexp)
def load_derived(derived_path: Optional[str] = None) -> Dict[str, Callable]:
    """
    Evaluate the top-level ``(define ...)`` forms of a derived-definitions file.

    Every other top-level form, and any ``define`` with fewer than two
    arguments, is ignored. Each qualifying form is handed to a fresh
    ``JaxCompiler``'s ``_eval_define`` together with a shared environment dict,
    and that dict is returned.

    Args:
        derived_path: Path (str or ``pathlib.Path``) of the file to load,
            read as UTF-8. Defaults to ``sexp_effects/derived.sexp`` next to
            this package.

    Returns:
        The environment populated by ``JaxCompiler._eval_define``, mapping
        defined names to their bound values. Presumably these are mostly
        compiled functions; confirm against ``_eval_define``.
    """
    if derived_path is None:
        derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp'

    # Explicit UTF-8: the locale default encoding is platform-dependent.
    with open(derived_path, 'r', encoding='utf-8') as f:
        code = f.read()

    exprs = parse_all(code)
    compiler = JaxCompiler()
    env: Dict[str, Any] = {}
    for expr in exprs:
        # (define name value ...) needs at least three elements.
        if not (isinstance(expr, list) and len(expr) >= 3):
            continue
        head = expr[0]
        if isinstance(head, Symbol) and head.name == 'define':
            compiler._eval_define(expr[1:], env)
    return env
# =============================================================================
# Test / Demo
# =============================================================================
if __name__ == '__main__':
    import time

    def _wait(x):
        """Return *x* once any pending asynchronous JAX computation on it is done."""
        # JAX dispatches device work asynchronously, so a call can return
        # before its result is computed. Without syncing, the timings below
        # would only measure dispatch overhead.
        ready = getattr(x, 'block_until_ready', None)
        if ready is not None:
            ready()
        return x

    # Test effect: threshold a weighted grayscale of the frame to black/white.
    test_effect = '''
    (effect "threshold-test"
      :params ((threshold :default 128))
      :body (let* ((r (channel frame 0))
                   (g (channel frame 1))
                   (b (channel frame 2))
                   (gray (+ (* r 0.299) (* g 0.587) (* b 0.114)))
                   (mask (where (> gray threshold) 255 0)))
              (merge-channels mask mask mask)))
    '''

    print("Compiling effect...")
    run_effect = compile_effect(test_effect)

    # Create test frame (numpy is imported at module level).
    print("Creating test frame...")
    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    print("Running effect (first run includes JIT compilation)...")
    t0 = time.perf_counter()
    result = _wait(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms")

    # Second run should be faster
    t0 = time.perf_counter()
    result = _wait(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"Second run (cached): {(t1-t0)*1000:.2f}ms")

    # Multiple runs, each synced so the average reflects real execution time.
    n_runs = 100
    t0 = time.perf_counter()
    for _ in range(n_runs):
        result = _wait(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    total_ms = (t1 - t0) * 1000
    print(f"{n_runs} runs: {total_ms:.2f}ms total, {total_ms / n_runs:.2f}ms avg")

    print(f"\nResult shape: {result.shape}, dtype: {result.dtype}")
    print("Done!")