diff --git a/sexp_effects/primitive_libs/core.py b/sexp_effects/primitive_libs/core.py index 352cbd3..34b580a 100644 --- a/sexp_effects/primitive_libs/core.py +++ b/sexp_effects/primitive_libs/core.py @@ -52,15 +52,33 @@ def prim_max(*args): def prim_round(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.round(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.round(x) return round(x) def prim_floor(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.floor(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.floor(x) import math return math.floor(x) def prim_ceil(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.ceil(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.ceil(x) import math return math.ceil(x) @@ -193,6 +211,11 @@ def prim_range(*args): import random _rng = random.Random() +def set_random_seed(seed): + """Set the random seed for deterministic output.""" + global _rng + _rng = random.Random(seed) + def prim_rand(): """Return random float in [0, 1).""" return _rng.random() diff --git a/streaming/output.py b/streaming/output.py index 86439da..b2a4e85 100644 --- a/streaming/output.py +++ b/streaming/output.py @@ -37,19 +37,39 @@ _nvenc_available: Optional[bool] = None def check_nvenc_available() -> bool: - """Check if NVENC hardware encoding is available.""" + """Check if NVENC hardware encoding is available and working. + + Does a real encode test to catch cases where nvenc is listed + but CUDA libraries aren't loaded. + """ global _nvenc_available if _nvenc_available is not None: return _nvenc_available try: + # First check if encoder is listed result = subprocess.run( ["ffmpeg", "-encoders"], capture_output=True, text=True, timeout=5 ) - _nvenc_available = "h264_nvenc" in result.stdout + if "h264_nvenc" not in result.stdout: + _nvenc_available = False + return _nvenc_available + + # Actually try to encode a small test frame + result = subprocess.run( + ["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1", + "-c:v", "h264_nvenc", "-f", "null", "-"], + capture_output=True, + text=True, + timeout=10 + ) + _nvenc_available = result.returncode == 0 + if not _nvenc_available: + import sys + print("NVENC listed but not working, falling back to libx264", file=sys.stderr) except Exception: _nvenc_available = False diff --git a/streaming/sexp_to_jax.py b/streaming/sexp_to_jax.py new file mode 100644 index 0000000..a268586 --- /dev/null +++ b/streaming/sexp_to_jax.py @@ -0,0 +1,3638 @@ +""" +Sexp to JAX Compiler. + +Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU. +Uses XLA compilation via @jax.jit for automatic kernel fusion. + +Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles +S-expressions directly to JAX operations which XLA then optimizes. + +Usage: + from streaming.sexp_to_jax import compile_effect + + effect_code = ''' + (effect "threshold" + :params ((threshold :default 128)) + :body (let ((g (gray frame))) + (rgb (where (> g threshold) 255 0) + (where (> g threshold) 255 0) + (where (> g threshold) 255 0)))) + ''' + + run_effect = compile_effect(effect_code) + output = run_effect(frame, threshold=128) +""" + +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial +from typing import Any, Dict, List, Callable, Optional, Tuple +import hashlib +import numpy as np + +# Import parser +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) +from sexp_effects.parser import parse, parse_all, Symbol, Keyword + + +# ============================================================================= +# Compilation Cache +# ============================================================================= + +_COMPILED_EFFECTS: Dict[str, Callable] = {} + + +# ============================================================================= +# Font Atlas for ASCII Effects +# ============================================================================= + +# Character sets for ASCII rendering +ASCII_ALPHABETS = { + 'standard': ' .:-=+*#%@', + 'blocks': ' ░▒▓█', + 'simple': ' .:oO@', + 'digits': ' 0123456789', + 'binary': ' 01', + 'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$', +} + +# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array +_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {} + + +def _create_font_atlas(alphabet: str, char_size: int, font_name: str = None) -> np.ndarray: + """ + Create a font atlas with all characters pre-rendered. + + Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time. + + Args: + alphabet: String of characters to render (ordered by brightness, dark to light) + char_size: Size of each character cell in pixels + font_name: Optional font name/path (uses default monospace if None) + + Returns: + NumPy array of shape (num_chars, char_size, char_size, 3) with rendered characters + Each character is white on black background. + """ + cache_key = (alphabet, char_size, font_name) + if cache_key in _FONT_ATLAS_CACHE: + return _FONT_ATLAS_CACHE[cache_key] + + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError: + # Fallback: create simple block-based atlas without PIL + return _create_block_atlas(alphabet, char_size) + + num_chars = len(alphabet) + atlas = [] + + # Try to load a monospace font + font = None + font_size = int(char_size * 0.9) # Slightly smaller than cell + + # Try various monospace fonts + font_candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf', + '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf', + '/System/Library/Fonts/Menlo.ttc', # macOS + '/System/Library/Fonts/Monaco.dfont', # macOS + 'C:\\Windows\\Fonts\\consola.ttf', # Windows + ] + + for font_path in font_candidates: + if font_path is None: + continue + try: + font = ImageFont.truetype(font_path, font_size) + break + except (IOError, OSError): + continue + + if font is None: + # Use default font + try: + font = ImageFont.load_default() + except: + # Ultimate fallback to blocks + return _create_block_atlas(alphabet, char_size) + + for char in alphabet: + # Create image for this character + img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Get text bounding box for centering + try: + bbox = draw.textbbox((0, 0), char, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + except AttributeError: + # Older PIL versions + text_width, text_height = draw.textsize(char, font=font) + + # Center the character + x = (char_size - text_width) // 2 + y = (char_size - text_height) // 2 + + # Draw white character on black background + draw.text((x, y), char, fill=(255, 255, 255), font=font) + + # Convert to numpy array (NOT jax array - avoids tracer issues) + char_array = np.array(img, dtype=np.uint8) + atlas.append(char_array) + + atlas = np.stack(atlas, axis=0) + _FONT_ATLAS_CACHE[cache_key] = atlas + return atlas + + +def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray: + """ + Create a simple block-based atlas without fonts. + Uses numpy to avoid tracer issues. + """ + num_chars = len(alphabet) + atlas = [] + + for i, char in enumerate(alphabet): + # Brightness proportional to position in alphabet + brightness = int(255 * i / max(num_chars - 1, 1)) + + # Create a simple pattern based on character + img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8) + + # Add some texture/pattern for visual interest + # Checkerboard pattern for mid-range characters + if 0.2 < i / num_chars < 0.8: + y_coords, x_coords = np.mgrid[:char_size, :char_size] + checker = ((x_coords + y_coords) % 2 == 0) + variation = int(brightness * 0.2) + img = np.where(checker[:, :, None], + np.clip(img.astype(np.int16) + variation, 0, 255).astype(np.uint8), + np.clip(img.astype(np.int16) - variation, 0, 255).astype(np.uint8)) + + atlas.append(img) + + return np.stack(atlas, axis=0) + + +def _get_alphabet_string(alphabet_name: str) -> str: + """Get the character string for a named alphabet or return as-is if custom.""" + if alphabet_name in ASCII_ALPHABETS: + return ASCII_ALPHABETS[alphabet_name] + return alphabet_name # Assume it's a custom character string + + +# ============================================================================= +# JAX Primitives - True primitives that can't be derived +# ============================================================================= + +def jax_width(frame): + """Frame width.""" + return frame.shape[1] + + +def jax_height(frame): + """Frame height.""" + return frame.shape[0] + + +def jax_channel(frame, idx): + """Extract channel by index as flat array.""" + # idx must be a static int for indexing + return frame[:, :, int(idx)].flatten().astype(jnp.float32) + + +def jax_merge_channels(r, g, b, shape): + """Merge RGB channels back to frame.""" + h, w = shape + r_img = jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8) + g_img = jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8) + b_img = jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([r_img, g_img, b_img], axis=2) + + +def jax_iota(n): + """Generate [0, 1, 2, ..., n-1].""" + return jnp.arange(n, dtype=jnp.float32) + + +def jax_repeat(x, n): + """Repeat each element n times: [a,b] -> [a,a,b,b].""" + return jnp.repeat(x, n) + + +def jax_tile(x, n): + """Tile array n times: [a,b] -> [a,b,a,b].""" + return jnp.tile(x, n) + + +def jax_gather(data, indices): + """Parallel index lookup.""" + flat_data = data.flatten() + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, len(flat_data) - 1) + return flat_data[idx_clipped] + + +def jax_scatter(indices, values, size): + """Parallel index write (last write wins).""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].set(values) + + +def jax_scatter_add(indices, values, size): + """Parallel index accumulate.""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].add(values) + + +def jax_group_reduce(values, group_indices, num_groups, op='mean'): + """Reduce values by group.""" + grp = group_indices.astype(jnp.int32) + + if op == 'sum': + result = jnp.zeros(num_groups, dtype=jnp.float32) + return result.at[grp].add(values) + elif op == 'mean': + sums = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values) + counts = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + elif op == 'max': + result = jnp.full(num_groups, -jnp.inf, dtype=jnp.float32) + result = result.at[grp].max(values) + return jnp.where(result == -jnp.inf, 0.0, result) + elif op == 'min': + result = jnp.full(num_groups, jnp.inf, dtype=jnp.float32) + result = result.at[grp].min(values) + return jnp.where(result == jnp.inf, 0.0, result) + else: + raise ValueError(f"Unknown reduce op: {op}") + + +def jax_where(cond, true_val, false_val): + """Conditional select.""" + return jnp.where(cond, true_val, false_val) + + +def jax_cell_indices(frame, cell_size): + """Compute cell index for each pixel.""" + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + + cell_row = y_coords // cell_size + cell_col = x_coords // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + return jnp.clip(cell_idx, 0, rows * cols - 1).astype(jnp.float32) + + +def jax_pool_frame(frame, cell_size): + """ + Pool frame to cell values. + Returns tuple: (cell_r, cell_g, cell_b, cell_lum) + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + cell_row = jnp.clip(y_coords // cs, 0, rows - 1) + cell_col = jnp.clip(x_coords // cs, 0, cols - 1) + cell_idx = (cell_row * cols + cell_col).astype(jnp.int32) + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(jnp.float32) + g_flat = frame[:, :, 1].flatten().astype(jnp.float32) + b_flat = frame[:, :, 2].flatten().astype(jnp.float32) + + # Pool each channel (mean) + def pool_channel(data): + sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data) + counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + return (r_pooled, g_pooled, b_pooled, lum) + + +def jax_sample(frame, x, y): + """Bilinear sample at (x, y) coordinates.""" + h, w = frame.shape[:2] + + # Clamp coordinates + x = jnp.clip(x, 0, w - 1) + y = jnp.clip(y, 0, h - 1) + + # Get integer and fractional parts + x0 = jnp.floor(x).astype(jnp.int32) + y0 = jnp.floor(y).astype(jnp.int32) + x1 = jnp.clip(x0 + 1, 0, w - 1) + y1 = jnp.clip(y0 + 1, 0, h - 1) + + fx = x - x0.astype(jnp.float32) + fy = y - y0.astype(jnp.float32) + + # Bilinear interpolation for each channel + def interp_channel(c): + c00 = frame[y0, x0, c].astype(jnp.float32) + c10 = frame[y0, x1, c].astype(jnp.float32) + c01 = frame[y1, x0, c].astype(jnp.float32) + c11 = frame[y1, x1, c].astype(jnp.float32) + + return (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + r = interp_channel(0) + g = interp_channel(1) + b = interp_channel(2) + + return r, g, b + + +# ============================================================================= +# Convolution Operations +# ============================================================================= + +def jax_convolve2d(data, kernel): + """2D convolution on a single channel.""" + # data shape: (H, W), kernel shape: (kH, kW) + # Use JAX's conv with appropriate padding + h, w = data.shape + kh, kw = kernel.shape + + # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c) + data_4d = data.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(kh, kw, 1, 1) + + # Convolve with 'SAME' padding + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + + return result.reshape(h, w) + + +def jax_blur(frame, radius=1): + """Gaussian blur.""" + # Create gaussian kernel + size = int(radius) * 2 + 1 + x = jnp.arange(size) - radius + gaussian_1d = jnp.exp(-x**2 / (2 * (radius/2)**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_sharpen(frame, amount=1.0): + """Sharpen using unsharp mask.""" + kernel = jnp.array([ + [0, -1, 0], + [-1, 5, -1], + [0, -1, 0] + ], dtype=jnp.float32) + + # Adjust kernel based on amount + center = 4 * amount + 1 + kernel = kernel.at[1, 1].set(center) + kernel = kernel * amount + jnp.array([[0,0,0],[0,1,0],[0,0,0]]) * (1 - amount) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_edge_detect(frame): + """Sobel edge detection.""" + # Sobel kernels + sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) + sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) + + # Convert to grayscale first + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + gx = jax_convolve2d(gray, sobel_x) + gy = jax_convolve2d(gray, sobel_y) + + edges = jnp.sqrt(gx**2 + gy**2) + edges = jnp.clip(edges, 0, 255).astype(jnp.uint8) + + return jnp.stack([edges, edges, edges], axis=2) + + +def jax_emboss(frame): + """Emboss effect.""" + kernel = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32) + + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + embossed = jax_convolve2d(gray, kernel) + 128 + embossed = jnp.clip(embossed, 0, 255).astype(jnp.uint8) + + return jnp.stack([embossed, embossed, embossed], axis=2) + + +# ============================================================================= +# Color Space Conversion +# ============================================================================= + +def jax_rgb_to_hsv(r, g, b): + """Convert RGB to HSV. All inputs/outputs are 0-255 range.""" + r, g, b = r / 255.0, g / 255.0, b / 255.0 + + max_c = jnp.maximum(jnp.maximum(r, g), b) + min_c = jnp.minimum(jnp.minimum(r, g), b) + diff = max_c - min_c + + # Value + v = max_c + + # Saturation + s = jnp.where(max_c > 0, diff / max_c, 0.0) + + # Hue + h = jnp.where(diff == 0, 0.0, + jnp.where(max_c == r, (60 * ((g - b) / diff) + 360) % 360, + jnp.where(max_c == g, 60 * ((b - r) / diff) + 120, + 60 * ((r - g) / diff) + 240))) + + return h, s * 255, v * 255 + + +def jax_hsv_to_rgb(h, s, v): + """Convert HSV to RGB. H is 0-360, S and V are 0-255.""" + h = h % 360 + s, v = s / 255.0, v / 255.0 + + c = v * s + x = c * (1 - jnp.abs((h / 60) % 2 - 1)) + m = v - c + + h_sector = (h / 60).astype(jnp.int32) % 6 + + r = jnp.where(h_sector == 0, c, + jnp.where(h_sector == 1, x, + jnp.where(h_sector == 2, 0, + jnp.where(h_sector == 3, 0, + jnp.where(h_sector == 4, x, c))))) + + g = jnp.where(h_sector == 0, x, + jnp.where(h_sector == 1, c, + jnp.where(h_sector == 2, c, + jnp.where(h_sector == 3, x, + jnp.where(h_sector == 4, 0, 0))))) + + b = jnp.where(h_sector == 0, 0, + jnp.where(h_sector == 1, 0, + jnp.where(h_sector == 2, x, + jnp.where(h_sector == 3, c, + jnp.where(h_sector == 4, c, x))))) + + return (r + m) * 255, (g + m) * 255, (b + m) * 255 + + +def jax_adjust_saturation(frame, factor): + """Adjust saturation by factor (1.0 = unchanged).""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + s = jnp.clip(s * factor, 0, 255) + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +def jax_shift_hue(frame, degrees): + """Shift hue by degrees.""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + h = (h + degrees) % 360 + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Color Adjustment Operations +# ============================================================================= + +def jax_adjust_brightness(frame, amount): + """Adjust brightness by amount (-255 to 255).""" + result = frame.astype(jnp.float32) + amount + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_adjust_contrast(frame, factor): + """Adjust contrast by factor (1.0 = unchanged).""" + result = (frame.astype(jnp.float32) - 128) * factor + 128 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_invert(frame): + """Invert colors.""" + return 255 - frame + + +def jax_posterize(frame, levels): + """Reduce to N color levels per channel.""" + levels = int(levels) + if levels < 2: + levels = 2 + step = 255.0 / (levels - 1) + result = jnp.round(frame.astype(jnp.float32) / step) * step + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_threshold(frame, level, invert=False): + """Binary threshold.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + if invert: + binary = jnp.where(gray < level, 255, 0).astype(jnp.uint8) + else: + binary = jnp.where(gray >= level, 255, 0).astype(jnp.uint8) + + return jnp.stack([binary, binary, binary], axis=2) + + +def jax_sepia(frame): + """Apply sepia tone.""" + r = frame[:, :, 0].astype(jnp.float32) + g = frame[:, :, 1].astype(jnp.float32) + b = frame[:, :, 2].astype(jnp.float32) + + new_r = r * 0.393 + g * 0.769 + b * 0.189 + new_g = r * 0.349 + g * 0.686 + b * 0.168 + new_b = r * 0.272 + g * 0.534 + b * 0.131 + + return jnp.stack([ + jnp.clip(new_r, 0, 255).astype(jnp.uint8), + jnp.clip(new_g, 0, 255).astype(jnp.uint8), + jnp.clip(new_b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_grayscale(frame): + """Convert to grayscale.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + gray = gray.astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + +# ============================================================================= +# Geometry Operations +# ============================================================================= + +def jax_flip_horizontal(frame): + """Flip horizontally.""" + return frame[:, ::-1, :] + + +def jax_flip_vertical(frame): + """Flip vertically.""" + return frame[::-1, :, :] + + +def jax_rotate(frame, angle, center_x=None, center_y=None): + """Rotate frame by angle (degrees), matching OpenCV convention. + + Positive angle = counter-clockwise rotation. + """ + h, w = frame.shape[:2] + if center_x is None: + center_x = w / 2 + if center_y is None: + center_y = h / 2 + + # Convert to radians + theta = angle * jnp.pi / 180 + cos_t, sin_t = jnp.cos(theta), jnp.sin(theta) + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # OpenCV getRotationMatrix2D gives FORWARD transform M = [[cos,sin],[-sin,cos]] + # For sampling we need INVERSE: M^-1 = [[cos,-sin],[sin,cos]] + # So: src_x = cos(θ)*(x-cx) - sin(θ)*(y-cy) + cx + # src_y = sin(θ)*(x-cx) + cos(θ)*(y-cy) + cy + x_centered = x_coords - center_x + y_centered = y_coords - center_y + + src_x = cos_t * x_centered - sin_t * y_centered + center_x + src_y = sin_t * x_centered + cos_t * y_centered + center_y + + # Mask for valid coordinates (out-of-bounds -> black, matching OpenCV) + valid = (src_x >= 0) & (src_x < w - 1) & (src_y >= 0) & (src_y < h - 1) + valid_flat = valid.flatten() + + # Sample using bilinear interpolation + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + # Zero out-of-bounds pixels (matching OpenCV warpAffine behavior) + r = jnp.where(valid_flat, r, 0) + g = jnp.where(valid_flat, g, 0) + b = jnp.where(valid_flat, b, 0) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_scale(frame, scale_x, scale_y=None): + """Scale frame (zoom). Matches OpenCV behavior with black out-of-bounds.""" + if scale_y is None: + scale_y = scale_x + + h, w = frame.shape[:2] + center_x, center_y = w / 2, h / 2 + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # Scale from center (inverse mapping: dst -> src) + src_x = (x_coords - center_x) / scale_x + center_x + src_y = (y_coords - center_y) / scale_y + center_y + + # Mask for valid coordinates (out-of-bounds -> black) + valid = (src_x >= 0) & (src_x < w - 1) & (src_y >= 0) & (src_y < h - 1) + valid_flat = valid.flatten() + + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + # Zero out-of-bounds pixels + r = jnp.where(valid_flat, r, 0) + g = jnp.where(valid_flat, g, 0) + b = jnp.where(valid_flat, b, 0) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_resize(frame, new_width, new_height): + """Resize frame to new dimensions.""" + h, w = frame.shape[:2] + new_h, new_w = int(new_height), int(new_width) + + # Create coordinate grids for new size + y_coords = jnp.repeat(jnp.arange(new_h), new_w) + x_coords = jnp.tile(jnp.arange(new_w), new_h) + + # Map to source coordinates + src_x = x_coords * (w - 1) / (new_w - 1) + src_y = y_coords * (h - 1) / (new_h - 1) + + r, g, b = jax_sample(frame, src_x, src_y) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Blending Operations +# ============================================================================= + +def _resize_to_match(frame1, frame2): + """Resize frame2 to match frame1's dimensions if they differ. + + Uses jax.image.resize for bilinear interpolation. + Returns frame2 resized to frame1's shape. + """ + h1, w1 = frame1.shape[:2] + h2, w2 = frame2.shape[:2] + + # If same size, return as-is + if h1 == h2 and w1 == w2: + return frame2 + + # Resize frame2 to match frame1 + # jax.image.resize expects (height, width, channels) and target shape + return jax.image.resize( + frame2.astype(jnp.float32), + (h1, w1, frame2.shape[2]), + method='bilinear' + ).astype(jnp.uint8) + + +def jax_blend(frame1, frame2, alpha): + """Blend two frames. alpha=0 -> frame1, alpha=1 -> frame2. + + Auto-resizes frame2 to match frame1 if dimensions differ. + """ + frame2 = _resize_to_match(frame1, frame2) + return (frame1.astype(jnp.float32) * (1 - alpha) + + frame2.astype(jnp.float32) * alpha).astype(jnp.uint8) + + +def jax_blend_add(frame1, frame2): + """Additive blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) + frame2.astype(jnp.float32) + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_multiply(frame1, frame2): + """Multiply blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) * frame2.astype(jnp.float32) / 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_screen(frame1, frame2): + """Screen blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = 1 - (1 - f1) * (1 - f2) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +def jax_blend_overlay(frame1, frame2): + """Overlay blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = jnp.where(f1 < 0.5, + 2 * f1 * f2, + 1 - 2 * (1 - f1) * (1 - f2)) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +# ============================================================================= +# Utility +# ============================================================================= + +def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0): + """Create a JAX random key that varies with frame and operation. + + Uses jax.random.fold_in to mix frame_num (which may be traced) into the key. + This allows JIT compilation without recompiling for each frame. + + Args: + seed: Base seed for determinism (must be concrete) + frame_num: Frame number for variation (can be traced) + op_id: Operation ID for variation (must be concrete) + + Returns: + JAX PRNGKey + """ + # Create base key from seed and op_id (both concrete) + base_key = jax.random.PRNGKey(seed + op_id * 1000003) + # Fold in frame_num (can be traced value) + return jax.random.fold_in(base_key, frame_num) + + +def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42): + """Random float in [lo, hi), varies with frame.""" + key = make_jax_key(seed, frame_num, op_id) + return lo + jax.random.uniform(key) * (hi - lo) + + +def jax_is_nil(x): + """Check if value is None/nil.""" + return x is None + + +# ============================================================================= +# S-expression to JAX Compiler +# ============================================================================= + +class JaxCompiler: + """Compiles S-expressions to JAX functions.""" + + def __init__(self): + self.env = {} # Variable bindings during compilation + self.params = {} # Effect parameters + self.primitives = {} # Loaded primitive libraries + self.derived = {} # Loaded derived functions + + def load_derived(self, path: str): + """Load derived operations from a .sexp file.""" + with open(path, 'r') as f: + code = f.read() + exprs = parse_all(code) + + # Evaluate all define expressions to populate derived functions + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + self._eval_define(expr[1:], self.derived) + + def compile_effect(self, sexp) -> Callable: + """ + Compile an effect S-expression to a JAX function. + + Supports both formats: + (effect "name" :params (...) :body ...) + (define-effect name :params (...) body) + + Args: + sexp: Parsed S-expression + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + if not isinstance(sexp, list) or len(sexp) < 2: + raise ValueError("Effect must be a list") + + head = sexp[0] + if not isinstance(head, Symbol): + raise ValueError("Effect must start with a symbol") + + form = head.name + + # Handle both 'effect' and 'define-effect' formats + if form == 'effect': + # (effect "name" :params (...) :body ...) + name = sexp[1] if len(sexp) > 1 else "unnamed" + if isinstance(name, Symbol): + name = name.name + start_idx = 2 + elif form == 'define-effect': + # (define-effect name :params (...) body) + name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1]) + start_idx = 2 + else: + raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'") + + params_spec = [] + body = None + + i = start_idx + while i < len(sexp): + item = sexp[i] + if isinstance(item, Keyword): + if item.name == 'params' and i + 1 < len(sexp): + params_spec = sexp[i + 1] + i += 2 + elif item.name == 'body' and i + 1 < len(sexp): + body = sexp[i + 1] + i += 2 + elif item.name in ('desc', 'type', 'range'): + # Skip metadata keywords + i += 2 + else: + i += 2 # Skip unknown keywords with their values + else: + # Assume it's the body if we haven't seen one + if body is None: + body = item + i += 1 + + if body is None: + raise ValueError(f"Effect '{name}' must have a body") + + # Extract parameter names, defaults, and static params (strings, bools) + param_info, static_params = self._parse_params(params_spec) + + # Capture derived functions for the closure + derived_fns = self.derived.copy() + + # Create the JAX function + def effect_fn(frame, **kwargs): + # Set up environment + h, w = frame.shape[:2] + # Get frame_num for deterministic random variation + frame_num = kwargs.get('frame_num', 0) + # Get seed from recipe config (passed via kwargs) + seed = kwargs.get('seed', 42) + env = { + 'frame': frame, + 'width': w, + 'height': h, + '_shape': (h, w), + # Time variables (default to 0, can be overridden via kwargs) + 't': kwargs.get('t', kwargs.get('_time', 0.0)), + '_time': kwargs.get('_time', kwargs.get('t', 0.0)), + 'time': kwargs.get('time', kwargs.get('t', 0.0)), + # Frame number for random key generation + 'frame_num': frame_num, + 'frame-num': frame_num, + '_frame_num': frame_num, + # Seed from recipe for deterministic random + '_seed': seed, + # Counter for unique random keys within same frame + '_rand_op_counter': 0, + # Common constants + 'pi': jnp.pi, + 'PI': jnp.pi, + } + + # Add derived functions + env.update(derived_fns) + + # Add parameters with defaults + for pname, pdefault in param_info.items(): + if pname in kwargs: + env[pname] = kwargs[pname] + elif isinstance(pdefault, list): + # Unevaluated S-expression default - evaluate it + env[pname] = self._eval(pdefault, env) + else: + env[pname] = pdefault + + # Evaluate body + result = self._eval(body, env) + + # Ensure result is a frame + if isinstance(result, tuple) and len(result) == 3: + # RGB tuple - merge to frame + r, g, b = result + return jax_merge_channels(r, g, b, (h, w)) + elif result.ndim == 3: + return result + else: + # Single channel - replicate to RGB + h, w = env['_shape'] + gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # JIT compile with static args for string/bool parameters and seed + # seed must be static for PRNGKey, but frame_num can be traced via fold_in + all_static = set(static_params) | {'seed'} + return jax.jit(effect_fn, static_argnames=list(all_static)) + + def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]: + """Parse parameter specifications. + + Returns: + Tuple of (param_defaults, static_params) + - param_defaults: Dict mapping param names to default values + - static_params: Set of param names that should be static (strings, bools) + """ + result = {} + static_params = set() + if not isinstance(params_spec, list): + return result, static_params + + for param in params_spec: + if isinstance(param, Symbol): + result[param.name] = 0.0 + elif isinstance(param, list) and len(param) >= 1: + pname = param[0].name if isinstance(param[0], Symbol) else str(param[0]) + pdefault = 0.0 + ptype = None + + # Look for :default and :type keywords + i = 1 + while i < len(param): + if isinstance(param[i], Keyword): + kw = param[i].name + if kw == 'default' and i + 1 < len(param): + pdefault = param[i + 1] + if isinstance(pdefault, Symbol): + if pdefault.name == 'nil': + pdefault = None + elif pdefault.name == 'true': + pdefault = True + elif pdefault.name == 'false': + pdefault = False + i += 2 + elif kw == 'type' and i + 1 < len(param): + ptype = param[i + 1] + if isinstance(ptype, Symbol): + ptype = ptype.name + i += 2 + else: + i += 1 + else: + i += 1 + + result[pname] = pdefault + + # Mark string and bool parameters as static (can't be traced by JAX) + if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)): + static_params.add(pname) + + return result, static_params + + def _eval(self, expr, env: Dict[str, Any]) -> Any: + """Evaluate an S-expression in the given environment.""" + + # Already-evaluated values (e.g., from threading macros) + # JAX arrays, NumPy arrays, tuples, etc. + if hasattr(expr, 'shape'): # JAX/NumPy array + return expr + if isinstance(expr, tuple): # e.g., (r, g, b) from rgb + return expr + + # Literals - keep as Python numbers for static operations + if isinstance(expr, (int, float)): + return expr + + if isinstance(expr, str): + return expr + + # Symbols - variable lookup + if isinstance(expr, Symbol): + name = expr.name + if name in env: + return env[name] + if name == 'nil': + return None + if name == 'true': + return True + if name == 'false': + return False + raise NameError(f"Unknown symbol: {name}") + + # Lists - function calls + if isinstance(expr, list) and len(expr) > 0: + head = expr[0] + + if isinstance(head, Symbol): + op = head.name + args = expr[1:] + + # Special forms + if op == 'let' or op == 'let*': + return self._eval_let(args, env) + if op == 'if': + return self._eval_if(args, env) + if op == 'lambda' or op == 'λ': + return self._eval_lambda(args, env) + if op == 'define': + return self._eval_define(args, env) + + # Built-in operations + return self._eval_call(op, args, env) + + # Empty list + if isinstance(expr, list) and len(expr) == 0: + return [] + + raise ValueError(f"Cannot evaluate: {expr}") + + def _eval_let(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body).""" + if len(args) < 2: + raise ValueError("let requires bindings and body") + + bindings = args[0] + body = args[1] + + new_env = env.copy() + + # Handle both ((var val) ...) and [var val var2 val2 ...] syntax + if isinstance(bindings, list): + # Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...) + if bindings and isinstance(bindings[0], Symbol): + # Flat list: [var val var2 val2 ...] + i = 0 + while i < len(bindings) - 1: + var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[var] = val + i += 2 + else: + # Nested list: ((var val) (var2 val2) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[var] = val + + return self._eval(body, new_env) + + def _eval_if(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + if len(args) < 2: + raise ValueError("if requires condition and then-branch") + + cond = self._eval(args[0], env) + + # Handle None as falsy (important for optional params like overlay) + if cond is None: + return self._eval(args[2], env) if len(args) > 2 else None + + # For Python scalar bools, use normal Python if + # This allows side effects and None values + if isinstance(cond, bool): + if cond: + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + + # For NumPy/JAX scalar bools with concrete values + if hasattr(cond, 'item') and cond.shape == (): + try: + if bool(cond.item()): + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + except: + pass # Fall through to jnp.where for traced values + + # For traced values, evaluate both branches and use jnp.where + then_val = self._eval(args[1], env) + else_val = self._eval(args[2], env) if len(args) > 2 else 0.0 + + # Handle None by converting to zeros + if then_val is None: + then_val = 0.0 + if else_val is None: + else_val = 0.0 + + # Convert lists to tuples + if isinstance(then_val, list): + then_val = tuple(then_val) + if isinstance(else_val, list): + else_val = tuple(else_val) + + # Handle tuple results (e.g., from rgb in map-pixels) + if isinstance(then_val, tuple) and isinstance(else_val, tuple): + return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val)) + + return jnp.where(cond, then_val, else_val) + + def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable: + """Evaluate (lambda (params) body).""" + if len(args) < 2: + raise ValueError("lambda requires parameters and body") + + params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + return fn + + def _eval_define(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (define name value) or (define (name params) body).""" + if len(args) < 2: + raise ValueError("define requires name and value") + + name_part = args[0] + + if isinstance(name_part, list): + # Function definition: (define (name params) body) + fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + env[fn_name] = fn + return fn + else: + # Variable definition + var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part) + val = self._eval(args[1], env) + env[var_name] = val + return val + + def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any: + """Evaluate a function call.""" + + # Check if it's a user-defined function + if op in env and callable(env[op]): + fn = env[op] + eval_args = [self._eval(a, env) for a in args] + return fn(*eval_args) + + # Arithmetic + if op == '+': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 0.0 + for v in vals[1:]: + result = result + v + return result + + if op == '-': + if len(args) == 1: + return -self._eval(args[0], env) + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result - v + return result + + if op == '*': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 1.0 + for v in vals[1:]: + result = result * v + return result + + if op == '/': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result / v + return result + + if op == 'mod': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a % b + + if op == 'pow' or op == '**': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return jnp.power(a, b) + + # Comparison + if op == '<': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a < b + if op == '>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a > b + if op == '<=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a <= b + if op == '>=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a >= b + if op == '=' or op == '==': + a, b = self._eval(args[0], env), self._eval(args[1], env) + # For scalar Python types, return Python bool to enable trace-time if + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a == b) + return a == b + if op == '!=' or op == '<>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a != b) + return a != b + + # Logic + if op == 'and': + vals = [self._eval(a, env) for a in args] + # Use Python and for concrete Python bools (e.g., shape comparisons) + if all(isinstance(v, (bool, np.bool_)) for v in vals): + result = True + for v in vals: + result = result and bool(v) + return result + # Otherwise use JAX logical_and + result = vals[0] + for v in vals[1:]: + result = jnp.logical_and(result, v) + return result + + if op == 'or': + vals = [self._eval(a, env) for a in args] + # Use Python or for concrete Python bools + if all(isinstance(v, (bool, np.bool_)) for v in vals): + result = False + for v in vals: + result = result or bool(v) + return result + # Otherwise use JAX logical_or + result = vals[0] + for v in vals[1:]: + result = jnp.logical_or(result, v) + return result + + if op == 'not': + val = self._eval(args[0], env) + if isinstance(val, (bool, np.bool_)): + return not bool(val) + return jnp.logical_not(val) + + # Math functions + if op == 'sqrt': + return jnp.sqrt(self._eval(args[0], env)) + if op == 'sin': + return jnp.sin(self._eval(args[0], env)) + if op == 'cos': + return jnp.cos(self._eval(args[0], env)) + if op == 'tan': + return jnp.tan(self._eval(args[0], env)) + if op == 'exp': + return jnp.exp(self._eval(args[0], env)) + if op == 'log': + return jnp.log(self._eval(args[0], env)) + if op == 'abs': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return abs(x) + return jnp.abs(x) + if op == 'floor': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.floor(x) + return jnp.floor(x) + if op == 'ceil': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.ceil(x) + return jnp.ceil(x) + if op == 'round': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return round(x) + return jnp.round(x) + + # Frame primitives + if op == 'width': + return env['width'] + if op == 'height': + return env['height'] + + if op == 'channel': + frame = self._eval(args[0], env) + idx = self._eval(args[1], env) + # idx should be a Python int (literal from S-expression) + return jax_channel(frame, idx) + + if op == 'merge-channels' or op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + if op == 'sample': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + return jax_sample(frame, x, y) + + if op == 'cell-indices': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_cell_indices(frame, cell_size) + + if op == 'pool-frame': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_pool_frame(frame, cell_size) + + # Xector primitives + if op == 'iota': + n = self._eval(args[0], env) + return jax_iota(int(n)) + + if op == 'repeat': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_repeat(x, int(n)) + + if op == 'tile': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_tile(x, int(n)) + + if op == 'gather': + data = self._eval(args[0], env) + indices = self._eval(args[1], env) + return jax_gather(data, indices) + + if op == 'scatter': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter(indices, values, size) + + if op == 'scatter-add': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter_add(indices, values, size) + + if op == 'group-reduce': + values = self._eval(args[0], env) + groups = self._eval(args[1], env) + num_groups = int(self._eval(args[2], env)) + reduce_op = args[3] if len(args) > 3 else 'mean' + if isinstance(reduce_op, Symbol): + reduce_op = reduce_op.name + return jax_group_reduce(values, groups, num_groups, reduce_op) + + if op == 'where': + cond = self._eval(args[0], env) + true_val = self._eval(args[1], env) + false_val = self._eval(args[2], env) + # Handle None values + if true_val is None: + true_val = 0.0 + if false_val is None: + false_val = 0.0 + return jax_where(cond, true_val, false_val) + + if op == 'len' or op == 'length': + x = self._eval(args[0], env) + if isinstance(x, (list, tuple)): + return len(x) + return x.size + + # Beta reductions + if op in ('β+', 'beta+', 'sum'): + return jnp.sum(self._eval(args[0], env)) + if op in ('β*', 'beta*', 'product'): + return jnp.prod(self._eval(args[0], env)) + if op in ('βmin', 'beta-min'): + return jnp.min(self._eval(args[0], env)) + if op in ('βmax', 'beta-max'): + return jnp.max(self._eval(args[0], env)) + if op in ('βmean', 'beta-mean', 'mean'): + return jnp.mean(self._eval(args[0], env)) + if op in ('βstd', 'beta-std'): + return jnp.std(self._eval(args[0], env)) + if op in ('βany', 'beta-any'): + return jnp.any(self._eval(args[0], env)) + if op in ('βall', 'beta-all'): + return jnp.all(self._eval(args[0], env)) + + # Convenience - min/max of two values (handle both scalars and arrays) + if op == 'min' or op == 'min2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return min(a, b) + return jnp.minimum(jnp.asarray(a), jnp.asarray(b)) + if op == 'max' or op == 'max2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return max(a, b) + return jnp.maximum(jnp.asarray(a), jnp.asarray(b)) + if op == 'clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + # List operations + if op == 'list': + return tuple(self._eval(a, env) for a in args) + + if op == 'nth': + seq = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[idx] if 0 <= idx < len(seq) else None + return seq[idx] # For arrays + + if op == 'first': + seq = self._eval(args[0], env) + return seq[0] if len(seq) > 0 else None + + if op == 'second': + seq = self._eval(args[0], env) + return seq[1] if len(seq) > 1 else None + + # Random (JAX-compatible) + # Get frame_num for deterministic variation - can be traced, fold_in handles it + frame_num = env.get('_frame_num', env.get('frame_num', 0)) + # Convert to int32 for fold_in if needed (but keep as JAX array if traced) + if frame_num is None: + frame_num = 0 + elif isinstance(frame_num, (int, float)): + frame_num = int(frame_num) + # If it's a JAX array, leave it as-is for tracing + + # Increment operation counter for unique keys within same frame + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + + if op == 'rand' or op == 'rand-x': + # For size-based random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + shape = (n,) + else: + n = size.size + shape = size.shape + elif hasattr(size, 'size'): + n = size.size + shape = (n,) + else: + n = int(size) + shape = (n,) + # Use deterministic key that varies with frame + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, shape).flatten() + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, ()) + + if op == 'randn' or op == 'randn-x': + # Normal random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + else: + n = size.size + elif hasattr(size, 'size'): + n = size.size + else: + n = int(size) + mean = self._eval(args[1], env) if len(args) > 1 else 0.0 + std = self._eval(args[2], env) if len(args) > 2 else 1.0 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, (n,)) * std + mean + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) + + if op == 'rand-range' or op == 'core:rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + seed = env.get('_seed', 42) + return jax_rand_range(lo, hi, frame_num, op_counter, seed) + + # ===================================================================== + # Convolution operations + # ===================================================================== + if op == 'blur' or op == 'image:blur': + frame = self._eval(args[0], env) + radius = self._eval(args[1], env) if len(args) > 1 else 1 + # Convert traced value to concrete for kernel size + if hasattr(radius, 'item'): + radius = int(radius.item()) + elif hasattr(radius, '__float__'): + radius = int(float(radius)) + else: + radius = int(radius) + return jax_blur(frame, max(1, radius)) + + if op == 'gaussian': + first_arg = self._eval(args[0], env) + # Check if first arg is a frame (blur) or scalar (random) + if hasattr(first_arg, 'shape') and first_arg.ndim == 3: + # Frame - apply gaussian blur + sigma = self._eval(args[1], env) if len(args) > 1 else 1.0 + radius = max(1, int(sigma * 3)) + return jax_blur(first_arg, radius) + else: + # Scalar args - generate gaussian random value + mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg + std = self._eval(args[1], env) if len(args) > 1 else 1.0 + # Return a single random value + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) * std + mean + + if op == 'sharpen' or op == 'image:sharpen': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + return jax_sharpen(frame, amount) + + if op == 'edge-detect' or op == 'image:edge-detect': + frame = self._eval(args[0], env) + return jax_edge_detect(frame) + + if op == 'emboss': + frame = self._eval(args[0], env) + return jax_emboss(frame) + + if op == 'convolve': + frame = self._eval(args[0], env) + kernel = self._eval(args[1], env) + # Convert kernel to array if it's a list + if isinstance(kernel, (list, tuple)): + kernel = jnp.array(kernel, dtype=jnp.float32) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'add-noise': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 0.1 + h, w = frame.shape[:2] + # Use frame-varying key for noise + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1] + result = frame.astype(jnp.float32) + noise * amount * 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + if op == 'translate': + frame = self._eval(args[0], env) + dx = self._eval(args[1], env) + dy = self._eval(args[2], env) if len(args) > 2 else 0 + h, w = frame.shape[:2] + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + src_x = (x_coords - dx).flatten() + src_y = (y_coords - dy).flatten() + r, g, b = jax_sample(frame, src_x, src_y) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'image:crop' or op == 'crop': + frame = self._eval(args[0], env) + x = int(self._eval(args[1], env)) + y = int(self._eval(args[2], env)) + w = int(self._eval(args[3], env)) + h = int(self._eval(args[4], env)) + return frame[y:y+h, x:x+w, :] + + if op == 'dilate': + frame = self._eval(args[0], env) + size = int(self._eval(args[1], env)) if len(args) > 1 else 3 + # Simple dilation using max pooling approximation + kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'map-rows': + frame = self._eval(args[0], env) + fn = args[1] # S-expression function + h, w = frame.shape[:2] + # For each row, apply the function + results = [] + for row_idx in range(h): + row_env = env.copy() + row_env['row'] = frame[row_idx, :, :] + row_env['row-idx'] = row_idx + + # Check if fn is a lambda + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + # Bind lambda params to y and row + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + row_env[param_name] = row_idx # y + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + row_env[param_name] = frame[row_idx, :, :] # row + result_row = self._eval(body, row_env) + results.append(result_row) + continue + + result_row = self._eval(fn, row_env) + # If result is a function, call it + if callable(result_row): + result_row = result_row(row_idx, frame[row_idx, :, :]) + results.append(result_row) + return jnp.stack(results, axis=0) + + # ===================================================================== + # Color operations + # ===================================================================== + if op == 'rgb->hsv' or op == 'rgb-to-hsv': + # Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple + if len(args) == 1: + c = self._eval(args[0], env) + if isinstance(c, tuple) and len(c) == 3: + r, g, b = c + else: + # Assume it's a list-like + r, g, b = c[0], c[1], c[2] + else: + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + return jax_rgb_to_hsv(r, g, b) + + if op == 'hsv->rgb' or op == 'hsv-to-rgb': + # Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list) + if len(args) == 1: + hsv = self._eval(args[0], env) + if isinstance(hsv, (tuple, list)) and len(hsv) >= 3: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h = self._eval(args[0], env) + s = self._eval(args[1], env) + v = self._eval(args[2], env) + return jax_hsv_to_rgb(h, s, v) + + if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) + return jax_adjust_brightness(frame, amount) + + if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_contrast(frame, factor) + + if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_saturation(frame, factor) + + if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift': + frame = self._eval(args[0], env) + degrees = self._eval(args[1], env) + return jax_shift_hue(frame, degrees) + + if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img': + frame = self._eval(args[0], env) + return jax_invert(frame) + + if op == 'posterize' or op == 'color_ops:posterize': + frame = self._eval(args[0], env) + levels = self._eval(args[1], env) + return jax_posterize(frame, levels) + + if op == 'threshold' or op == 'color_ops:threshold': + frame = self._eval(args[0], env) + level = self._eval(args[1], env) + invert = self._eval(args[2], env) if len(args) > 2 else False + return jax_threshold(frame, level, invert) + + if op == 'sepia' or op == 'color_ops:sepia': + frame = self._eval(args[0], env) + return jax_sepia(frame) + + if op == 'grayscale' or op == 'image:grayscale': + frame = self._eval(args[0], env) + return jax_grayscale(frame) + + # ===================================================================== + # Geometry operations + # ===================================================================== + if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-img': + frame = self._eval(args[0], env) + direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal' + if direction == 'vertical' or direction == 'v': + return jax_flip_vertical(frame) + return jax_flip_horizontal(frame) + + if op == 'flip-vertical' or op == 'flip-v': + frame = self._eval(args[0], env) + return jax_flip_vertical(frame) + + if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img': + frame = self._eval(args[0], env) + angle = self._eval(args[1], env) + return jax_rotate(frame, angle) + + if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img': + frame = self._eval(args[0], env) + scale_x = self._eval(args[1], env) + scale_y = self._eval(args[2], env) if len(args) > 2 else None + return jax_scale(frame, scale_x, scale_y) + + if op == 'resize' or op == 'image:resize': + frame = self._eval(args[0], env) + new_w = self._eval(args[1], env) + new_h = self._eval(args[2], env) + return jax_resize(frame, new_w, new_h) + + # ===================================================================== + # Geometry distortion effects + # ===================================================================== + if op == 'geometry:fisheye-coords' or op == 'fisheye': + # Signature: (w h strength cx cy zoom_correct) or (frame strength) + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + # (w h strength cx cy zoom_correct) signature + w = int(first_arg) + h = int(self._eval(args[1], env)) + strength = self._eval(args[2], env) if len(args) > 2 else 0.5 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = first_arg + strength = self._eval(args[1], env) if len(args) > 1 else 0.5 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Fisheye distortion + r_new = r + strength * r * (1 - r / max_r) + + src_x = r_new * jnp.cos(theta) + cx + src_y = r_new * jnp.sin(theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:swirl-coords' or op == 'swirl': + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + w = int(first_arg) + h = int(self._eval(args[1], env)) + amount = self._eval(args[2], env) if len(args) > 2 else 1.0 + frame = None + else: + frame = first_arg + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + h, w = frame.shape[:2] + + cx, cy = w / 2, h / 2 + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + swirl_angle = amount * (1 - r / max_r) + new_theta = theta + swirl_angle + + src_x = r * jnp.cos(new_theta) + cx + src_y = r * jnp.sin(new_theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Wave effect (frame-first signature for simple usage) + if op == 'wave-distort': + first_arg = self._eval(args[0], env) + frame = first_arg + amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0 + amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0 + freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1 + freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1 + h, w = frame.shape[:2] + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y) + src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x) + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:ripple-displace' or op == 'ripple': + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + w = int(first_arg) + h = int(self._eval(args[1], env)) + amplitude = self._eval(args[2], env) if len(args) > 2 else 10.0 + frequency = self._eval(args[3], env) if len(args) > 3 else 0.05 + frame = None + else: + frame = first_arg + amplitude = self._eval(args[1], env) if len(args) > 1 else 10.0 + frequency = self._eval(args[2], env) if len(args) > 2 else 0.05 + h, w = frame.shape[:2] + + cx, cy = w / 2, h / 2 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + dist = jnp.sqrt(dx*dx + dy*dy) + + displacement = amplitude * jnp.sin(dist * frequency) + angle = jnp.arctan2(dy, dx) + + src_x = x_coords + displacement * jnp.cos(angle) + src_y = y_coords + displacement * jnp.sin(angle) + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope': + # Two signatures: (frame segments) or (w h segments cx cy) + if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'): + # (w h segments cx cy) signature + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + segments = int(self._eval(args[2], env)) if len(args) > 2 else 6 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = self._eval(args[0], env) + segments = int(self._eval(args[1], env)) if len(args) > 1 else 6 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Mirror into segments + segment_angle = 2 * jnp.pi / segments + theta_mod = theta % segment_angle + theta_mirror = jnp.where( + (jnp.floor(theta / segment_angle) % 2) == 0, + theta_mod, + segment_angle - theta_mod + ) + + src_x = r * jnp.cos(theta_mirror) + cx + src_y = r * jnp.sin(theta_mirror) + cy + + if frame is None: + # Return coordinate arrays + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Geometry coordinate extraction + if op == 'geometry:coords-x': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['x'] + return coords[0] if isinstance(coords, tuple) else coords + + if op == 'geometry:coords-y': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['y'] + return coords[1] if isinstance(coords, tuple) else coords + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x_coords = self._eval(args[1], env) + y_coords = self._eval(args[2], env) + h, w = frame.shape[:2] + r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(frame1, frame2, alpha) + + if op == 'blend-add': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_add(frame1, frame2) + + if op == 'blend-multiply': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_multiply(frame1, frame2) + + if op == 'blend-screen': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_screen(frame1, frame2) + + if op == 'blend-overlay': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_overlay(frame1, frame2) + + # ===================================================================== + # Image dimension queries (namespaced aliases) + # ===================================================================== + if op == 'image:width': + if args: + frame = self._eval(args[0], env) + return frame.shape[1] # width is second dimension (h, w, c) + return env['width'] + + if op == 'image:height': + if args: + frame = self._eval(args[0], env) + return frame.shape[0] # height is first dimension (h, w, c) + return env['height'] + + # ===================================================================== + # Utility + # ===================================================================== + if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?': + x = self._eval(args[0], env) + return jax_is_nil(x) + + # ===================================================================== + # Xector channel operations (shortcuts) + # ===================================================================== + if op == 'red': + val = self._eval(args[0], env) + # Works on frames or pixel tuples + if isinstance(val, tuple): + return val[0] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 0) + else: + return val # Assume it's already a channel + + if op == 'green': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[1] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 1) + else: + return val + + if op == 'blue': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[2] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 2) + else: + return val + + if op == 'gray' or op == 'luminance': + val = self._eval(args[0], env) + # Handle tuple (r, g, b) from map-pixels + if isinstance(val, tuple) and len(val) == 3: + r, g, b = val + return r * 0.299 + g * 0.587 + b * 0.114 + # Handle frame + frame = val + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + return r * 0.299 + g * 0.587 + b * 0.114 + + if op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + # ===================================================================== + # Coordinate operations + # ===================================================================== + if op == 'x-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + + if op == 'y-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + if op == 'dist-from-center': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx + y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy + return jnp.sqrt(x*x + y*y) + + # ===================================================================== + # Alpha operations (element-wise on xectors) + # ===================================================================== + if op == 'α/' or op == 'alpha/': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a / b + + if op == 'α+' or op == 'alpha+': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result + v + return result + + if op == 'α*' or op == 'alpha*': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result * v + return result + + if op == 'α-' or op == 'alpha-': + if len(args) == 1: + return -self._eval(args[0], env) + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a - b + + if op == 'αclamp' or op == 'alpha-clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + if op == 'αmin' or op == 'alpha-min': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.minimum(a, b) + + if op == 'αmax' or op == 'alpha-max': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.maximum(a, b) + + if op == 'αsqrt' or op == 'alpha-sqrt': + return jnp.sqrt(self._eval(args[0], env)) + + if op == 'αsin' or op == 'alpha-sin': + return jnp.sin(self._eval(args[0], env)) + + if op == 'αcos' or op == 'alpha-cos': + return jnp.cos(self._eval(args[0], env)) + + if op == 'αmod' or op == 'alpha-mod': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a % b + + if op == 'α²' or op == 'αsq' or op == 'alpha-sq': + x = self._eval(args[0], env) + return x * x + + if op == 'α<' or op == 'alpha<': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a < b + + if op == 'α>' or op == 'alpha>': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a > b + + if op == 'α<=' or op == 'alpha<=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a <= b + + if op == 'α>=' or op == 'alpha>=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a >= b + + if op == 'α=' or op == 'alpha=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a == b + + if op == 'αfloor' or op == 'alpha-floor': + return jnp.floor(self._eval(args[0], env)) + + if op == 'αceil' or op == 'alpha-ceil': + return jnp.ceil(self._eval(args[0], env)) + + if op == 'αround' or op == 'alpha-round': + return jnp.round(self._eval(args[0], env)) + + if op == 'αabs' or op == 'alpha-abs': + return jnp.abs(self._eval(args[0], env)) + + if op == 'αexp' or op == 'alpha-exp': + return jnp.exp(self._eval(args[0], env)) + + if op == 'αlog' or op == 'alpha-log': + return jnp.log(self._eval(args[0], env)) + + if op == 'αor' or op == 'alpha-or': + # Element-wise logical OR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_or(a, b) + + if op == 'αand' or op == 'alpha-and': + # Element-wise logical AND + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_and(a, b) + + if op == 'αnot' or op == 'alpha-not': + # Element-wise logical NOT + return jnp.logical_not(self._eval(args[0], env)) + + if op == 'αxor' or op == 'alpha-xor': + # Element-wise logical XOR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_xor(a, b) + + # ===================================================================== + # Threading/arrow operations + # ===================================================================== + if op == '->': + # Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b) + val = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list): + # Insert val as first argument + fn_name = form[0].name if isinstance(form[0], Symbol) else form[0] + new_args = [val] + [self._eval(a, env) for a in form[1:]] + val = self._eval_call(fn_name, [val] + form[1:], env) + else: + # Simple function call + fn_name = form.name if isinstance(form, Symbol) else form + val = self._eval_call(fn_name, [args[0]], env) + return val + + # ===================================================================== + # Range and iteration + # ===================================================================== + if op == 'range': + if len(args) == 1: + end = int(self._eval(args[0], env)) + return list(range(end)) + elif len(args) == 2: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + return list(range(start, end)) + else: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + step = int(self._eval(args[2], env)) + return list(range(start, end, step)) + + if op == 'reduce' or op == 'fold': + # (reduce seq init fn) - left fold + seq = self._eval(args[0], env) + acc = self._eval(args[1], env) + fn = args[2] # Lambda S-expression + + # Handle lambda + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + for item in seq: + fn_env = env.copy() + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + fn_env[param_name] = acc + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + fn_env[param_name] = item + acc = self._eval(body, fn_env) + return acc + + # Fallback - try evaluating fn and calling it + fn_eval = self._eval(fn, env) + if callable(fn_eval): + for item in seq: + acc = fn_eval(acc, item) + return acc + + # ===================================================================== + # Map-pixels (apply function to each pixel) + # ===================================================================== + if op == 'map-pixels': + frame = self._eval(args[0], env) + fn = args[1] # Lambda or S-expression + h, w = frame.shape[:2] + + # Extract channels + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + # Set up pixel environment + pixel_env = env.copy() + pixel_env['r'] = r + pixel_env['g'] = g + pixel_env['b'] = b + pixel_env['x'] = x_coords + pixel_env['y'] = y_coords + # Also provide c (color) as a tuple for lambda (x y c) style + pixel_env['c'] = (r, g, b) + + # If fn is a lambda, we need to handle it specially + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + # Lambda: (lambda (x y c) body) + params = fn[1] + body = fn[2] + # Bind parameters + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + pixel_env[param_name] = x_coords + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + pixel_env[param_name] = y_coords + if len(params) >= 3: + param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) + pixel_env[param_name] = (r, g, b) + result = self._eval(body, pixel_env) + else: + result = self._eval(fn, pixel_env) + else: + result = self._eval(fn, pixel_env) + + if isinstance(result, tuple) and len(result) == 3: + nr, ng, nb = result + return jax_merge_channels(nr, ng, nb, (h, w)) + elif hasattr(result, 'shape') and result.ndim == 3: + return result + else: + # Single channel result + if hasattr(result, 'flatten'): + result = result.flatten() + gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # ===================================================================== + # State operations (return unchanged for stateless JIT) + # ===================================================================== + if op == 'state-get': + key = self._eval(args[0], env) + default = self._eval(args[1], env) if len(args) > 1 else None + return default # State not supported in JIT, return default + + if op == 'state-set': + return None # No-op in JIT + + # ===================================================================== + # Cell/grid operations + # ===================================================================== + if op == 'local-x-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size) / max(1, cell_size - 1) + + if op == 'local-y-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size) / max(1, cell_size - 1) + + if op == 'local-x': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size).astype(jnp.float32) + + if op == 'local-y': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size).astype(jnp.float32) + + if op == 'cell-row': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return jnp.floor(y / cell_size) + + if op == 'cell-col': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return jnp.floor(x / cell_size) + + if op == 'num-rows': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[0] // cell_size + + if op == 'num-cols': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[1] // cell_size + + # ===================================================================== + # Control flow + # ===================================================================== + if op == 'cond': + # (cond (test1 expr1) (test2 expr2) ... (else exprN)) + # For JAX compatibility, build a nested jnp.where structure + # Start from the else clause and work backwards + + # Collect clauses + clauses = [] + else_expr = None + for clause in args: + if isinstance(clause, list) and len(clause) >= 2: + test = clause[0] + if isinstance(test, Symbol) and test.name == 'else': + else_expr = clause[1] + else: + clauses.append((test, clause[1])) + + # If no else, default to None/0 + if else_expr is not None: + result = self._eval(else_expr, env) + else: + result = 0 + + # Build nested where from last to first + for test_expr, val_expr in reversed(clauses): + cond_val = self._eval(test_expr, env) + then_val = self._eval(val_expr, env) + + # Check if condition is array or scalar + if hasattr(cond_val, 'shape') and cond_val.shape != (): + # Array condition - use jnp.where + result = jnp.where(cond_val, then_val, result) + else: + # Scalar - can use Python if + if cond_val: + result = then_val + + return result + + if op == 'set!' or op == 'set': + # Mutation - not really supported in JAX, but we can update env + var = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + val = self._eval(args[1], env) + env[var] = val + return val + + if op == 'begin' or op == 'do': + # Evaluate all expressions, return last + result = None + for expr in args: + result = self._eval(expr, env) + return result + + # ===================================================================== + # Additional math + # ===================================================================== + if op == 'sq' or op == 'square': + x = self._eval(args[0], env) + return x * x + + if op == 'lerp': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + t = self._eval(args[2], env) + return a * (1 - t) + b * t + + if op == 'smoothstep': + edge0 = self._eval(args[0], env) + edge1 = self._eval(args[1], env) + x = self._eval(args[2], env) + t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1) + return t * t * (3 - 2 * t) + + if op == 'atan2': + y = self._eval(args[0], env) + x = self._eval(args[1], env) + return jnp.arctan2(y, x) + + if op == 'fract' or op == 'frac': + x = self._eval(args[0], env) + return x - jnp.floor(x) + + # ===================================================================== + # Frame copy and construction operations + # ===================================================================== + if op == 'pixel': + # Get pixel at (x, y) from frame + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + # Convert to int and clip to bounds + if isinstance(x, (int, float)): + x = max(0, min(int(x), w - 1)) + else: + x = jnp.clip(x, 0, w - 1).astype(jnp.int32) + if isinstance(y, (int, float)): + y = max(0, min(int(y), h - 1)) + else: + y = jnp.clip(y, 0, h - 1).astype(jnp.int32) + r = frame[y, x, 0] + g = frame[y, x, 1] + b = frame[y, x, 2] + return (r, g, b) + + if op == 'copy': + frame = self._eval(args[0], env) + return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame) + + if op == 'make-image': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + if len(args) > 2: + color = self._eval(args[2], env) + if isinstance(color, (list, tuple)): + r, g, b = int(color[0]), int(color[1]), int(color[2]) + else: + r = g = b = int(color) + else: + r = g = b = 0 + img = jnp.zeros((h, w, 3), dtype=jnp.uint8) + img = img.at[:, :, 0].set(r) + img = img.at[:, :, 1].set(g) + img = img.at[:, :, 2].set(b) + return img + + if op == 'paste': + dest = self._eval(args[0], env) + src = self._eval(args[1], env) + x = int(self._eval(args[2], env)) + y = int(self._eval(args[3], env)) + sh, sw = src.shape[:2] + dh, dw = dest.shape[:2] + # Clip to dest bounds + x1, y1 = max(0, x), max(0, y) + x2, y2 = min(dw, x + sw), min(dh, y + sh) + sx1, sy1 = x1 - x, y1 - y + sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1) + result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest) + result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :]) + return result + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blending:blend-images' or op == 'blend-images': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(a, b, alpha) + + if op == 'blending:blend-mode' or op == 'blend-mode': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + mode = self._eval(args[2], env) if len(args) > 2 else 'add' + if mode == 'add': + return jax_blend_add(a, b) + elif mode == 'multiply': + return jax_blend_multiply(a, b) + elif mode == 'screen': + return jax_blend_screen(a, b) + elif mode == 'overlay': + return jax_blend_overlay(a, b) + elif mode == 'lighten': + return jnp.maximum(a, b) + elif mode == 'darken': + return jnp.minimum(a, b) + elif mode == 'difference': + return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8) + else: + return jax_blend(a, b, 0.5) + + # ===================================================================== + # Geometry coordinate operations + # ===================================================================== + if op == 'geometry:wave-coords' or op == 'wave-coords': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + axis = self._eval(args[2], env) if len(args) > 2 else 'x' + freq = self._eval(args[3], env) if len(args) > 3 else 1.0 + amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0 + phase = self._eval(args[5], env) if len(args) > 5 else 0.0 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + if axis == 'x' or axis == 'horizontal': + # Wave displaces X based on Y + offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + src_x = x_coords + offset + src_y = y_coords + elif axis == 'y' or axis == 'vertical': + # Wave displaces Y based on X + offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + src_y = y_coords + offset + else: # both + offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + offset_x + src_y = y_coords + offset_y + + return {'x': src_x, 'y': src_y} + + if op == 'geometry:coords-x' or op == 'coords-x': + coords = self._eval(args[0], env) + return coords['x'] + + if op == 'geometry:coords-y' or op == 'coords-y': + coords = self._eval(args[0], env) + return coords['y'] + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + r, g, b = jax_sample(frame, x.flatten(), y.flatten()) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Glitch effects + # ===================================================================== + if op == 'pixelsort': + frame = self._eval(args[0], env) + sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness' + thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50 + thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200 + angle = int(self._eval(args[4], env)) if len(args) > 4 else 0 + reverse = self._eval(args[5], env) if len(args) > 5 else False + + h, w = frame.shape[:2] + result = frame.copy() + + # Get luminance for thresholding + lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + # Sort each row + for y in range(h): + row_lum = lum[y, :] + row = frame[y, :, :] + + # Find mask of pixels to sort + mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi) + + # Get indices where we should sort + sort_indices = jnp.where(mask, jnp.arange(w), -1) + + # Simple sort by luminance for the row + if sort_by == 'lightness': + sort_key = row_lum + elif sort_by == 'hue': + # Approximate hue from RGB + sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32), + row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32))) + else: + sort_key = row_lum + + # Sort pixels in masked region + sorted_indices = jnp.argsort(sort_key) + if reverse: + sorted_indices = sorted_indices[::-1] + + # Apply partial sort (only where mask is true) + # This is a simplified version - full pixelsort is more complex + result = result.at[y, :, :].set(row[sorted_indices]) + + return result + + if op == 'datamosh': + frame = self._eval(args[0], env) + prev = self._eval(args[1], env) + block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32 + corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3 + max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50 + color_corrupt = self._eval(args[5], env) if len(args) > 5 else True + + h, w = frame.shape[:2] + + # Use deterministic random for JIT with frame variation + seed = env.get('_seed', 42) + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + key = make_jax_key(seed, frame_num, op_counter) + + num_blocks_y = h // block_size + num_blocks_x = w // block_size + total_blocks = num_blocks_y * num_blocks_x + + # Pre-generate all random values at once (vectorized) + key, k1, k2, k3, k4, k5 = jax.random.split(key, 6) + corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption + offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1) + offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1) + channels = jax.random.randint(k4, (total_blocks,), 0, 3) + color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51) + + # Create coordinate grids for blocks + by_grid = jnp.arange(num_blocks_y) + bx_grid = jnp.arange(num_blocks_x) + + # Create block coordinate arrays + by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...] + bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...] + + # Create pixel coordinate grids + y_coords, x_coords = jnp.mgrid[:h, :w] + + # Determine which block each pixel belongs to + pixel_block_y = y_coords // block_size + pixel_block_x = x_coords // block_size + pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x + + # Clamp to valid block indices (for pixels outside the block grid) + pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1) + + # Get the corrupt mask, offsets for each pixel's block + pixel_corrupt = corrupt_mask[pixel_block_idx] + pixel_offset_y = offsets_y[pixel_block_idx] + pixel_offset_x = offsets_x[pixel_block_idx] + + # Calculate source coordinates with offset (clamped) + src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1) + src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1) + + # Sample from previous frame at offset positions + prev_sampled = prev[src_y, src_x, :] + + # Where corrupt mask is true, use prev_sampled; else use frame + result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame) + + # Apply color corruption to corrupted blocks + if color_corrupt: + pixel_channel = channels[pixel_block_idx] + pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16) + + # Create per-channel shift arrays (only shift the selected channel) + shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0) + shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0) + shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0) + + result_int = result.astype(jnp.int16) + result_int = result_int.at[:, :, 0].add(shift_r) + result_int = result_int.at[:, :, 1].add(shift_g) + result_int = result_int.at[:, :, 2].add(shift_b) + result = jnp.clip(result_int, 0, 255).astype(jnp.uint8) + + return result + + # ===================================================================== + # ASCII Art Operations (using pre-rendered font atlas) + # ===================================================================== + + if op == 'cell-sample': + # (cell-sample frame char_size) -> (colors, luminances) + # Downsample frame into cells, return average colors and luminances + frame = self._eval(args[0], env) + char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8 + + h, w = frame.shape[:2] + num_rows = h // char_size + num_cols = w // char_size + + # Crop to exact multiple of char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + + # Reshape to (num_rows, char_size, num_cols, char_size, 3) + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + + # Average over char_size dimensions -> (num_rows, num_cols, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminance per cell + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + return (colors, luminances) + + if op == 'luminance-to-chars': + # (luminance-to-chars luminances alphabet contrast) -> char_indices + # Map luminance values to character indices + luminances = self._eval(args[0], env) + alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard' + contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5 + + # Get alphabet string + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Apply contrast + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + + # Map to character indices (0 = darkest, num_chars-1 = brightest) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + return char_indices + + if op == 'render-char-grid': + # (render-char-grid frame chars colors char_size color_mode background_color invert_colors) + # Render character grid using font atlas + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8 + color_mode = self._eval(args[4], env) if len(args) > 4 else 'color' + background_color = self._eval(args[5], env) if len(args) > 5 else 'black' + invert_colors = self._eval(args[6], env) if len(args) > 6 else 0 + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get the alphabet used (stored in env or default) + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + + # Get or create font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + # Try to parse hex color + try: + if background_color.startswith('#'): + bg_hex = background_color[1:] + bg = jnp.array([int(bg_hex[0:2], 16), + int(bg_hex[2:4], 16), + int(bg_hex[4:6], 16)], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + except: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create output image starting with background + output_h = num_rows * char_size + output_w = num_cols * char_size + result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy() + + # Gather characters from atlas based on indices + # char_indices shape: (num_rows, num_cols) + # font_atlas shape: (num_chars, char_size, char_size, 3) + # Convert numpy atlas to JAX for indexing with traced indices + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3) + + # Reshape to grid + char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create coordinate grids for output pixels + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = y_out // char_size + cell_col = x_out // char_size + local_y = y_out % char_size + local_x = x_out % char_size + + # Clamp to valid ranges + cell_row = jnp.clip(cell_row, 0, num_rows - 1) + cell_col = jnp.clip(cell_col, 0, num_cols - 1) + + # Get character pixel values + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + + # Get character brightness (for masking) + char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Handle color modes + if color_mode == 'mono': + # White characters on background + fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8) + fg = jnp.broadcast_to(fg_color, char_pixels.shape) + elif color_mode == 'invert': + # Inverted cell colors + cell_colors = colors[cell_row, cell_col] + fg = 255 - cell_colors + else: + # 'color' mode - use cell colors + fg = colors[cell_row, cell_col] + + # Blend foreground onto background based on character brightness + if invert_colors: + # Swap fg and bg + fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg + result = (fg.astype(jnp.float32) * (1 - char_brightness) + + bg_broadcast.astype(jnp.float32) * char_brightness) + else: + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) + + fg.astype(jnp.float32) * char_brightness) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize back to original frame size if needed + if result.shape[0] != h or result.shape[1] != w: + # Simple nearest-neighbor resize + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = (jnp.arange(h) * y_scale).astype(jnp.int32) + x_src = (jnp.arange(w) * x_scale).astype(jnp.int32) + y_src = jnp.clip(y_src, 0, result.shape[0] - 1) + x_src = jnp.clip(x_src, 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'ascii-fx-zone': + # Complex ASCII effect with per-zone expressions + # (ascii-fx-zone frame :cols cols :alphabet alphabet ...) + frame = self._eval(args[0], env) + + # Parse keyword arguments + kwargs = {} + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + if i + 1 < len(args): + kwargs[key] = args[i + 1] + i += 2 + else: + i += 1 + + # Get parameters + cols = int(self._eval(kwargs.get('cols', 80), env)) + char_size_param = kwargs.get('char_size') + alphabet = self._eval(kwargs.get('alphabet', 'standard'), env) + color_mode = self._eval(kwargs.get('color_mode', 'color'), env) + background = self._eval(kwargs.get('background', 'black'), env) + contrast = float(self._eval(kwargs.get('contrast', 1.5), env)) + + h, w = frame.shape[:2] + + # Calculate char_size from cols if not specified + if char_size_param is not None: + char_size_val = self._eval(char_size_param, env) + if char_size_val is not None: + char_size = int(char_size_val) + else: + char_size = w // cols + else: + char_size = w // cols + char_size = max(4, min(char_size, 64)) + + # Store alphabet for render-char-grid to use + env['_ascii_alphabet'] = alphabet + + # Cell sampling + num_rows = h // char_size + num_cols = w // char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminances + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + # Get alphabet and map luminances to chars + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + # Handle optional per-zone effects (char_hue, char_saturation, etc.) + # These would modify colors based on zone position + char_hue = kwargs.get('char_hue') + char_saturation = kwargs.get('char_saturation') + char_brightness = kwargs.get('char_brightness') + + if char_hue is not None or char_saturation is not None or char_brightness is not None: + # Create zone coordinate arrays for expression evaluation + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + # Bind zone variables + zone_env = env.copy() + zone_env['zone-row'] = row_coords + zone_env['zone-col'] = col_coords + zone_env['zone-row-norm'] = row_norm + zone_env['zone-col-norm'] = col_norm + zone_env['zone-lum'] = luminances + + # Apply color modifications (simplified - full version would use HSV) + if char_brightness is not None: + brightness_mult = self._eval(char_brightness, zone_env) + if brightness_mult is not None: + colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None], + 0, 255).astype(jnp.uint8) + + # Render using font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original dimensions + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'render-char-grid-fx': + # Enhanced render with per-character effects + # (render-char-grid-fx frame chars colors luminances char_size + # color_mode bg_color invert_colors + # char_jitter char_scale char_rotation char_hue_shift + # jitter_source scale_source rotation_source hue_source) + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + luminances = self._eval(args[3], env) + char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8 + color_mode = self._eval(args[5], env) if len(args) > 5 else 'color' + background_color = self._eval(args[6], env) if len(args) > 6 else 'black' + invert_colors = self._eval(args[7], env) if len(args) > 7 else 0 + + # Per-char effect amounts + char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0 + char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0 + char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0 + char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0 + + # Modulation sources + jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none' + scale_source = self._eval(args[13], env) if len(args) > 13 else 'none' + rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none' + hue_source = self._eval(args[15], env) if len(args) > 15 else 'none' + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get alphabet + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create modulation values based on source + def get_modulation(source, lums, num_rows, num_cols): + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + if source == 'luminance': + return lums + elif source == 'inv_luminance': + return 1.0 - lums + elif source == 'position_x': + return col_norm + elif source == 'position_y': + return row_norm + elif source == 'position_diag': + return (row_norm + col_norm) / 2 + elif source == 'center_dist': + cy, cx = 0.5, 0.5 + dist = jnp.sqrt((row_norm - cy)**2 + (col_norm - cx)**2) + return jnp.clip(dist / 0.707, 0, 1) # Normalize by max diagonal + elif source == 'random': + # Use frame-varying key for random source + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key = make_jax_key(seed, frame_num, op_ctr) + return jax.random.uniform(key, (num_rows, num_cols)) + else: + return jnp.zeros((num_rows, num_cols)) + + # Get modulation values + jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols) + scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols) + rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols) + hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + # Apply jitter if enabled + if char_jitter > 0: + jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter + # Use frame-varying key for jitter + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2) + # Generate deterministic jitter per cell + jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1) + jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1) + offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + local_y = jnp.clip(local_y + offset_y, 0, char_size - 1) + local_x = jnp.clip(local_x + offset_x, 0, char_size - 1) + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Get foreground colors + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + # Apply hue shift if enabled + if char_hue_shift > 0 and color_mode == 'color': + hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift + # Simple hue rotation via channel cycling + fg_float = fg.astype(jnp.float32) + shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB + # Simplified: blend channels based on shift + r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2] + shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac + # Just do a simple tint for now + fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8) + + # Blend + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + if invert_colors: + result = (fg.astype(jnp.float32) * (1 - char_bright) + + bg_broadcast.astype(jnp.float32) * char_bright) + else: + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'alphabet-char': + # (alphabet-char alphabet-name index) -> char_index in that alphabet + alphabet = self._eval(args[0], env) + index = self._eval(args[1], env) + + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Handle both scalar and array indices + if hasattr(index, 'shape'): + index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1) + else: + index = max(0, min(int(index), num_chars - 1)) + + return index + + if op == 'map-char-grid': + # (map-char-grid base-chars luminances (lambda (r c ch lum) ...)) + # Map over character grid, allowing per-cell character selection + base_chars = self._eval(args[0], env) + luminances = self._eval(args[1], env) + fn = args[2] # Lambda expression + + num_rows, num_cols = base_chars.shape + + # For JAX compatibility, we can't use Python loops with traced values + # Instead, we'll evaluate the lambda for the whole grid at once + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + + # Create grid coordinates + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + + # Bind parameters for whole-grid evaluation + fn_env = env.copy() + + # Params: (r c ch lum) + if len(params) >= 1: + fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords + if len(params) >= 2: + fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords + if len(params) >= 3: + fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars + if len(params) >= 4: + # Luminances scaled to 0-255 range + fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32) + + # Evaluate body - should return new character indices + result = self._eval(body, fn_env) + if hasattr(result, 'shape'): + return result.astype(jnp.int32) + return base_chars + + return base_chars + + # ===================================================================== + # List operations + # ===================================================================== + if op == 'take': + seq = self._eval(args[0], env) + n = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[:n] + return seq[:n] # Works for arrays too + + if op == 'cons': + item = self._eval(args[0], env) + seq = self._eval(args[1], env) + if isinstance(seq, list): + return [item] + seq + elif isinstance(seq, tuple): + return (item,) + seq + return jnp.concatenate([jnp.array([item]), seq]) + + if op == 'roll': + arr = self._eval(args[0], env) + shift = self._eval(args[1], env) + axis = self._eval(args[2], env) if len(args) > 2 else 0 + # Convert to int for concrete values, keep as-is for JAX traced values + if isinstance(shift, (int, float)): + shift = int(shift) + elif hasattr(shift, 'astype'): + shift = shift.astype(jnp.int32) + if isinstance(axis, (int, float)): + axis = int(axis) + return jnp.roll(arr, shift, axis=axis) + + # ===================================================================== + # Pi constant + # ===================================================================== + if op == 'pi': + return jnp.pi + + raise ValueError(f"Unknown operation: {op}") + + +# ============================================================================= +# Public API +# ============================================================================= + +def compile_effect(code: str) -> Callable: + """ + Compile an S-expression effect to a JAX function. + + Args: + code: S-expression effect code + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + # Check cache + cache_key = hashlib.md5(code.encode()).hexdigest() + if cache_key in _COMPILED_EFFECTS: + return _COMPILED_EFFECTS[cache_key] + + # Parse and compile + sexp = parse(code) + compiler = JaxCompiler() + fn = compiler.compile_effect(sexp) + + _COMPILED_EFFECTS[cache_key] = fn + return fn + + +def compile_effect_file(path: str, derived_paths: List[str] = None) -> Callable: + """ + Compile an effect from a .sexp file. + + Args: + path: Path to the .sexp effect file + derived_paths: Optional list of paths to derived.sexp files to load + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + with open(path, 'r') as f: + code = f.read() + + # Parse all expressions in file + exprs = parse_all(code) + + # Create compiler + compiler = JaxCompiler() + + # Load derived files if specified + if derived_paths: + for dp in derived_paths: + compiler.load_derived(dp) + + # Process expressions - find require statements and the effect + effect_sexp = None + effect_dir = Path(path).parent + + for expr in exprs: + if not isinstance(expr, list) or len(expr) < 2: + continue + + head = expr[0] + if not isinstance(head, Symbol): + continue + + if head.name == 'require': + # (require "derived") or (require "path/to/file") + req_path = expr[1] + if isinstance(req_path, str): + # Resolve relative to effect file + if not req_path.endswith('.sexp'): + req_path = req_path + '.sexp' + full_path = effect_dir / req_path + if not full_path.exists(): + # Try sexp_effects directory + full_path = Path(__file__).parent.parent / 'sexp_effects' / req_path + if full_path.exists(): + compiler.load_derived(str(full_path)) + + elif head.name == 'require-primitives': + # (require-primitives "lib") - currently ignored for JAX + # JAX has all primitives built-in + pass + + elif head.name in ('effect', 'define-effect'): + effect_sexp = expr + + if effect_sexp is None: + raise ValueError(f"No effect definition found in {path}") + + return compiler.compile_effect(effect_sexp) + + +def load_derived(derived_path: str = None) -> Dict[str, Callable]: + """ + Load derived operations from derived.sexp. + + Returns dict of compiled functions that can be used in effects. + """ + if derived_path is None: + derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp' + + with open(derived_path, 'r') as f: + code = f.read() + + exprs = parse_all(code) + compiler = JaxCompiler() + env = {} + + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + compiler._eval_define(expr[1:], env) + + return env + + +# ============================================================================= +# Test / Demo +# ============================================================================= + +if __name__ == '__main__': + import numpy as np + + # Test effect + test_effect = ''' + (effect "threshold-test" + :params ((threshold :default 128)) + :body (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (gray (+ (* r 0.299) (* g 0.587) (* b 0.114))) + (mask (where (> gray threshold) 255 0))) + (merge-channels mask mask mask))) + ''' + + print("Compiling effect...") + run_effect = compile_effect(test_effect) + + # Create test frame + print("Creating test frame...") + frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) + + # Run effect + print("Running effect (first run includes JIT compilation)...") + import time + + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms") + + # Second run should be faster + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"Second run (cached): {(t1-t0)*1000:.2f}ms") + + # Multiple runs + t0 = time.time() + for _ in range(100): + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"100 runs: {(t1-t0)*1000:.2f}ms total, {(t1-t0)*10:.2f}ms avg") + + print(f"\nResult shape: {result.shape}, dtype: {result.dtype}") + print("Done!") diff --git a/streaming/stream_sexp_generic.py b/streaming/stream_sexp_generic.py index 8cf0ae9..bc0160e 100644 --- a/streaming/stream_sexp_generic.py +++ b/streaming/stream_sexp_generic.py @@ -28,12 +28,31 @@ import math import numpy as np from pathlib import Path from dataclasses import dataclass -from typing import Dict, List, Any, Optional, Tuple +from typing import Dict, List, Any, Optional, Tuple, Callable # Use local sexp_effects parser (supports namespaced symbols like math:sin) sys.path.insert(0, str(Path(__file__).parent.parent)) from sexp_effects.parser import parse, parse_all, Symbol, Keyword +# JAX backend (optional - loaded on demand) +_JAX_AVAILABLE = False +_jax_compiler = None + +def _init_jax(): + """Lazily initialize JAX compiler.""" + global _JAX_AVAILABLE, _jax_compiler + if _jax_compiler is not None: + return _JAX_AVAILABLE + try: + from streaming.sexp_to_jax import JaxCompiler, compile_effect_file + _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file} + _JAX_AVAILABLE = True + print("JAX backend initialized", file=sys.stderr) + except ImportError as e: + print(f"JAX backend not available: {e}", file=sys.stderr) + _JAX_AVAILABLE = False + return _JAX_AVAILABLE + @dataclass class Context: @@ -51,7 +70,7 @@ class StreamInterpreter: and calls primitives. """ - def __init__(self, sexp_path: str, actor_id: Optional[str] = None): + def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False): self.sexp_path = Path(sexp_path) self.sexp_dir = self.sexp_path.parent self.actor_id = actor_id # For friendly name resolution @@ -74,6 +93,17 @@ class StreamInterpreter: self.primitives: Dict[str, Any] = {} self.effects: Dict[str, dict] = {} self.macros: Dict[str, dict] = {} + + # JAX backend for accelerated effect evaluation + self.use_jax = use_jax + self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects + self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects + if use_jax: + if _init_jax(): + print("JAX acceleration enabled", file=sys.stderr) + else: + print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr) + self.use_jax = False # Try multiple locations for primitive_libs possible_paths = [ self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir @@ -307,8 +337,13 @@ class StreamInterpreter: i += 1 self.effects[name] = {'params': params, 'body': body} + self.jax_effect_paths[name] = effect_path # Track source for JAX compilation print(f"Effect: {name}", file=sys.stderr) + # Try to compile with JAX if enabled + if self.use_jax and _JAX_AVAILABLE: + self._compile_jax_effect(name, effect_path) + elif cmd == 'defmacro': name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] @@ -387,8 +422,62 @@ class StreamInterpreter: } print(f"Scan: {name}", file=sys.stderr) + def _compile_jax_effect(self, name: str, effect_path: Path): + """Compile an effect with JAX for accelerated execution.""" + if not _JAX_AVAILABLE or name in self.jax_effects: + return + + try: + compile_effect_file = _jax_compiler['compile_effect_file'] + jax_fn = compile_effect_file(str(effect_path)) + self.jax_effects[name] = jax_fn + print(f" [JAX compiled: {name}]", file=sys.stderr) + except Exception as e: + # Silently fall back to interpreter for unsupported effects + if 'Unknown operation' not in str(e): + print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr) + + def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]: + """Apply a JAX-compiled effect to a frame.""" + if name not in self.jax_effects: + return None + + try: + jax_fn = self.jax_effects[name] + # Ensure frame is numpy array + if hasattr(frame, 'cpu'): + frame = frame.cpu + elif hasattr(frame, 'get'): + frame = frame.get() + + # Get seed from config for deterministic random + seed = self.config.get('seed', 42) + + # Call JAX function with parameters + result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params) + + # Convert result back to numpy if needed + if hasattr(result, 'block_until_ready'): + result.block_until_ready() # Ensure computation is complete + if hasattr(result, '__array__'): + result = np.asarray(result) + + return result + except Exception as e: + # Fall back to interpreter on error + print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr) + return None + def _init(self): """Initialize from sexp - load primitives, effects, defs, scans.""" + # Set random seed for deterministic output + seed = self.config.get('seed', 42) + try: + from sexp_effects.primitive_libs.core import set_random_seed + set_random_seed(seed) + except ImportError: + pass + # Load external config files first (they can override recipe definitions) if self.sources_config: self._load_config_file(self.sources_config) @@ -780,6 +869,7 @@ class StreamInterpreter: effect_env[pname] = pdef.get('default', 0) positional_idx = 0 + frame_val = None i = 0 while i < len(args): if isinstance(args[i], Keyword): @@ -791,11 +881,24 @@ class StreamInterpreter: val = self._eval(args[i], env) if positional_idx == 0: effect_env['frame'] = val + frame_val = val elif positional_idx - 1 < len(param_names): effect_env[param_names[positional_idx - 1]] = val positional_idx += 1 i += 1 + # Try JAX-accelerated execution first + if self.use_jax and op in self.jax_effects and frame_val is not None: + # Build params dict for JAX (exclude 'frame') + jax_params = {k: v for k, v in effect_env.items() + if k != 'frame' and k in effect['params']} + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num) + if result is not None: + return result + # Fall through to interpreter if JAX fails + return self._eval(effect['body'], effect_env) # === Primitives === @@ -1049,9 +1152,9 @@ class StreamInterpreter: def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None, - sources_config: str = None, audio_config: str = None): + sources_config: str = None, audio_config: str = None, use_jax: bool = False): """Run a streaming sexp.""" - interp = StreamInterpreter(sexp_path) + interp = StreamInterpreter(sexp_path, use_jax=use_jax) if fps: interp.config['fps'] = fps if sources_config: @@ -1070,7 +1173,9 @@ if __name__ == "__main__": parser.add_argument("--fps", type=float, default=None) parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file") parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file") + parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects") args = parser.parse_args() run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps, - sources_config=args.sources_config, audio_config=args.audio_config) + sources_config=args.sources_config, audio_config=args.audio_config, + use_jax=args.jax)