"""
Sexp to JAX Compiler.

Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU.
Uses XLA compilation via @jax.jit for automatic kernel fusion.

Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles
S-expressions directly to JAX operations which XLA then optimizes.

Usage:
    from streaming.sexp_to_jax import compile_effect

    effect_code = '''
    (effect "threshold"
      :params ((threshold :default 128))
      :body (let ((g (gray frame)))
              (rgb (where (> g threshold) 255 0)
                   (where (> g threshold) 255 0)
                   (where (> g threshold) 255 0))))
    '''

    run_effect = compile_effect(effect_code)
    output = run_effect(frame, threshold=128)
"""

import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Any, Dict, List, Callable, Optional, Tuple
import hashlib
import numpy as np

# Import parser
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sexp_effects.parser import parse, parse_all, Symbol, Keyword

# Import typography primitives
from streaming.jax_typography import bind_typography_primitives


# =============================================================================
# Compilation Cache
# =============================================================================

_COMPILED_EFFECTS: Dict[str, Callable] = {}


# =============================================================================
# Font Atlas for ASCII Effects
# =============================================================================

# Character sets for ASCII rendering
ASCII_ALPHABETS = {
    'standard': ' .:-=+*#%@',
    'blocks': ' ░▒▓█',
    'simple': ' .:oO@',
    'digits': ' 0123456789',
    'binary': ' 01',
    'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$',
}

# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array
_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {}


def _create_font_atlas(alphabet: str, char_size: int, font_name: Optional[str] = None) -> np.ndarray:
    """
    Create a font atlas with all characters pre-rendered.

    Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time.
    Results rendered with PIL are memoised in _FONT_ATLAS_CACHE; the block-atlas
    fallback paths are not cached.

    Args:
        alphabet: String of characters to render (ordered by brightness, dark to light)
        char_size: Size of each character cell in pixels
        font_name: Optional font name/path (uses default monospace if None)

    Returns:
        NumPy array of shape (num_chars, char_size, char_size, 3) with rendered
        characters. Each character is white on black background.
    """
    cache_key = (alphabet, char_size, font_name)
    if cache_key in _FONT_ATLAS_CACHE:
        return _FONT_ATLAS_CACHE[cache_key]

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        # Fallback: create simple block-based atlas without PIL
        return _create_block_atlas(alphabet, char_size)

    atlas = []

    # Try to load a monospace font
    font = None
    font_size = int(char_size * 0.9)  # Slightly smaller than cell

    # Try various monospace fonts
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf',
        '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf',
        '/System/Library/Fonts/Menlo.ttc',  # macOS
        '/System/Library/Fonts/Monaco.dfont',  # macOS
        'C:\\Windows\\Fonts\\consola.ttf',  # Windows
    ]

    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue

    if font is None:
        # Use default font
        try:
            font = ImageFont.load_default()
        except Exception:
            # Ultimate fallback to blocks. Deliberately broad (any PIL failure),
            # but no longer a bare except, so KeyboardInterrupt/SystemExit
            # still propagate.
            return _create_block_atlas(alphabet, char_size)

    for char in alphabet:
        # Create image for this character
        img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Get text bounding box for centering
        try:
            bbox = draw.textbbox((0, 0), char, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        except AttributeError:
            # Older PIL versions
            text_width, text_height = draw.textsize(char, font=font)

        # Center the character
        # NOTE(review): the bbox origin offset (bbox[0], bbox[1]) is not
        # subtracted here, so glyphs are centred by size only - confirm whether
        # this must match another renderer before changing it.
        x = (char_size - text_width) // 2
        y = (char_size - text_height) // 2

        # Draw white character on black background
        draw.text((x, y), char, fill=(255, 255, 255), font=font)

        # Convert to numpy array (NOT jax array - avoids tracer issues)
        char_array = np.array(img, dtype=np.uint8)
        atlas.append(char_array)

    atlas = np.stack(atlas, axis=0)
    _FONT_ATLAS_CACHE[cache_key] = atlas
    return atlas


def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray:
    """
    Create a simple block-based atlas without fonts.

    Uses numpy to avoid tracer issues. Each cell is a flat gray tile whose
    brightness grows with the character's position in `alphabet`; mid-range
    characters additionally get a one-pixel checkerboard texture.

    Args:
        alphabet: Characters to represent (only the count and order are used)
        char_size: Size of each character cell in pixels

    Returns:
        NumPy uint8 array of shape (num_chars, char_size, char_size, 3)
    """
    num_chars = len(alphabet)
    atlas = []

    for i, char in enumerate(alphabet):
        # Brightness proportional to position in alphabet
        brightness = int(255 * i / max(num_chars - 1, 1))

        # Create a simple pattern based on character
        img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8)

        # Add some texture/pattern for visual interest
        # Checkerboard pattern for mid-range characters
        if 0.2 < i / num_chars < 0.8:
            y_coords, x_coords = np.mgrid[:char_size, :char_size]
            checker = ((x_coords + y_coords) % 2 == 0)
            variation = int(brightness * 0.2)
            # Widen to int16 before adding/subtracting so uint8 cannot wrap.
            img = np.where(checker[:, :, None],
                           np.clip(img.astype(np.int16) + variation, 0, 255).astype(np.uint8),
                           np.clip(img.astype(np.int16) - variation, 0, 255).astype(np.uint8))

        atlas.append(img)

    return np.stack(atlas, axis=0)


def _get_alphabet_string(alphabet_name: str) -> str:
    """Get the character string for a named alphabet or return as-is if custom."""
    if alphabet_name in ASCII_ALPHABETS:
        return ASCII_ALPHABETS[alphabet_name]
    return alphabet_name  # Assume it's a custom character string


# =============================================================================
# Text Rendering with Font Atlas (JAX-compatible)
# =============================================================================

# Default character set for text rendering (printable ASCII)
TEXT_CHARSET = ''.join(chr(i) for i in range(32, 127))  # space to ~

# Cache for text font atlases: (font_name, font_size) -> (atlas,
#   char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings)
_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {}


def _create_text_atlas(font_name: Optional[str] = None, font_size: int = 32) -> tuple:
    """
    Create a font atlas for general text rendering with proper baseline alignment.

    Font Metrics (from typography):
    - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l)
    - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y)
    - Baseline: the line text "sits" on - all characters align to this
    - Em-square: the design space, typically = ascender + descender

    Results are memoised in _TEXT_ATLAS_CACHE keyed by (font_name, font_size).

    Returns:
        (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x,
         char_left_bearings)
        - atlas: numpy uint8 array (num_chars, cell_height, cell_width, 4) RGBA
        - char_to_idx: dict mapping character to atlas index
        - char_widths: numpy int32 array of advance width for each character
        - cell_height: height of character cells (ascent + descent + 2px padding)
        - baseline_y: pixels from top of cell to baseline (= ascent + 1)
        - draw_x: x offset of the text origin inside every tile
        - char_left_bearings: numpy int32 array, bbox left edge minus draw_x

    Raises:
        ImportError: if PIL/Pillow is not installed.
    """
    cache_key = (font_name, font_size)
    if cache_key in _TEXT_ATLAS_CACHE:
        return _TEXT_ATLAS_CACHE[cache_key]

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        raise ImportError("PIL/Pillow required for text rendering")

    # Load font - match drawing.py's font order for consistency
    font = None
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',  # Same order as drawing.py
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
        '/System/Library/Fonts/Helvetica.ttc',
        '/System/Library/Fonts/Arial.ttf',
        'C:\\Windows\\Fonts\\arial.ttf',
    ]

    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue

    if font is None:
        font = ImageFont.load_default()

    # Get font metrics - this is the key to proper text layout
    # getmetrics() returns (ascent, descent) where:
    #   ascent = pixels from baseline to top of tallest character
    #   descent = pixels from baseline to bottom of lowest character
    ascent, descent = font.getmetrics()

    # Cell dimensions based on font metrics (not per-character bounding boxes)
    cell_height = ascent + descent + 2  # +2 for padding
    baseline_y = ascent + 1  # Baseline position within cell (1px padding from top)

    # Find max character width
    max_width = 0
    char_widths_dict = {}
    for char in TEXT_CHARSET:
        try:
            # Use getlength for horizontal advance (proper character spacing)
            advance = int(font.getlength(char))
        except Exception:
            # Best-effort fallback (e.g. fonts/Pillow versions without
            # getlength); no longer a bare except.
            advance = font_size // 2
        char_widths_dict[char] = advance
        max_width = max(max_width, advance)

    # Position to draw at within each tile (with margin for negative bearings)
    draw_x = 5  # Margin for chars with negative left bearing
    draw_y = 0  # Top of cell (PIL default without anchor)

    # Tile width must hold the left margin plus the widest advance; without
    # the margin term the right edge of the widest glyphs was cut off.
    cell_width = max_width + draw_x + 2  # +2 for padding

    # Create atlas with all characters - draw same way as prim_text for pixel-perfect match
    char_to_idx = {}
    char_widths = []  # Advance widths
    char_left_bearings = []  # Left bearing (x offset from origin to first pixel)
    atlas = []

    for i, char in enumerate(TEXT_CHARSET):
        char_to_idx[char] = i
        char_widths.append(char_widths_dict[char])

        # Create RGBA image for this character
        img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Draw same way as prim_text - at (draw_x, draw_y), no anchor
        # This positions the text origin, and glyphs may extend left/right from there
        draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)

        # Get bbox to find left bearing
        bbox = draw.textbbox((draw_x, draw_y), char, font=font)
        left_bearing = bbox[0] - draw_x  # How far left of origin the glyph extends
        char_left_bearings.append(left_bearing)

        # Convert to numpy
        char_array = np.array(img, dtype=np.uint8)
        atlas.append(char_array)

    atlas = np.stack(atlas, axis=0)  # (num_chars, cell_height, cell_width, 4)
    char_widths = np.array(char_widths, dtype=np.int32)
    char_left_bearings = np.array(char_left_bearings, dtype=np.int32)

    # Return draw_x (origin offset within tile) so rendering knows where origin is
    result = (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings)
    _TEXT_ATLAS_CACHE[cache_key] = result
    return result


def jax_text_render(frame, text: str, x: int, y: int,
                    font_name: Optional[str] = None, font_size: int = 32,
                    color=(255, 255, 255), opacity: float = 1.0,
                    align: str = "left", valign: str = "baseline",
                    shadow: bool = False, shadow_color=(0, 0, 0),
                    shadow_offset: int = 2):
    """
    Render text onto frame using font atlas (JAX-compatible).

    This is designed to be called from within a JIT-compiled function.
    The text strip is assembled at trace time (using numpy/PIL), then converted
    to a JAX array for the actual compositing. `text`, `x`, `y` and the layout
    arguments must therefore be concrete Python values, not traced arrays.

    Typography notes:
    - Baseline: The line text "sits" on. Most characters rest on this line.
    - Ascender: Top of tall letters (b, d, h, k, l) - above baseline
    - Descender: Bottom of letters like g, j, p, q, y - below baseline
    - For normal text, use valign="baseline" and y = the baseline position

    Args:
        frame: Input frame (H, W, 3)
        text: Text string to render
        x, y: Position reference point (affected by align/valign)
        font_name: Font to use (None = default)
        font_size: Font size in pixels
        color: RGB tuple (0-255)
        opacity: 0.0 to 1.0
        align: Horizontal alignment relative to x:
            "left" - text starts at x
            "center" - text centered on x
            "right" - text ends at x
        valign: Vertical alignment relative to y:
            "baseline" - text baseline at y (default, like normal text)
            "top" - top of ascenders at y
            "middle" - text vertically centered on y
            "bottom" - bottom of descenders at y
            Any other value is treated like "top".
        shadow: Whether to draw drop shadow
        shadow_color: Shadow RGB color
        shadow_offset: Shadow offset in pixels

    Returns:
        Frame with text rendered (uint8), or the input frame unchanged when
        `text` is empty or produces no visible pixels.
    """
    if not text:
        return frame

    # Get or create font atlas (this happens at trace time, uses numpy)
    (atlas_np, char_to_idx, char_widths_np, char_height,
     baseline_offset, origin_x, _left_bearings) = _create_text_atlas(font_name, font_size)

    # Tile width; read from the numpy atlas so the whole atlas is not uploaded
    # to the device just to query its shape.
    cell_width = atlas_np.shape[2]

    # Compute the cursor position of every character (at trace time, text is
    # static so we can pre-compute). Unknown characters advance like a space.
    space_idx = char_to_idx.get(' ', 0)
    char_x_offsets = [0]  # Starting x position for each character
    total_width = 0
    for char in text:
        idx = char_to_idx.get(char, space_idx)
        total_width += int(char_widths_np[idx])
        char_x_offsets.append(total_width)

    # Actual text dimensions using proportional widths
    text_width = total_width
    text_height = char_height

    # Adjust position for horizontal alignment
    if align == "center":
        x = x - text_width // 2
    elif align == "right":
        x = x - text_width

    # Adjust position for vertical alignment
    # baseline_offset = pixels from top of cell to baseline
    if valign == "baseline":
        # y specifies baseline position, so top of text cell is above it
        y = y - baseline_offset
    elif valign == "middle":
        y = y - text_height // 2
    elif valign == "bottom":
        y = y - text_height
    # valign == "top" needs no adjustment (default)

    # Ensure position is integer
    x, y = int(x), int(y)

    # Create text strip with proper character spacing at trace time (using numpy)
    # This ensures proportional fonts render correctly
    #
    # The atlas stores each character drawn at (origin_x, 0) in its tile.
    # To place a character at cursor position 'cx':
    #   - The tile's origin_x should align with cx in the strip
    #   - So we blit tile to strip starting at (cx - origin_x)
    #
    # Add padding for characters with negative left bearings
    strip_padding = origin_x  # Extra space at start for negative bearings
    text_strip_np = np.zeros(
        (char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8)

    for i, char in enumerate(text):
        if char not in char_to_idx:
            # Unknown characters only advance the cursor; nothing is drawn.
            continue
        char_tile = atlas_np[char_to_idx[char]]  # (char_height, cell_width, 4)
        # Position tile so its origin aligns with cursor position. Because
        # strip_padding == origin_x this is never negative.
        strip_x = strip_padding + char_x_offsets[i] - origin_x
        end_x = min(strip_x + cell_width, text_strip_np.shape[1])
        tile_end = end_x - strip_x
        # max-blend so overlapping tiles do not erase each other's pixels
        text_strip_np[:, strip_x:end_x] = np.maximum(
            text_strip_np[:, strip_x:end_x], char_tile[:, :tile_end])

    # Trim the strip:
    # - Left side: trim to first visible pixel (handles negative left bearing)
    # - Right side: use computed text_width (preserve advance width spacing)
    cols_with_content = np.any(text_strip_np[:, :, 3] > 0, axis=0)
    if not cols_with_content.any():
        # No visible content, return original frame
        return frame

    first_col = int(np.argmax(cols_with_content))
    # Right edge: use the computed text width from the strip's logical end
    right_col = strip_padding + text_width
    # Adjust x to account for the left trim offset
    x = x + first_col - strip_padding
    text_strip_np = text_strip_np[:, first_col:right_col]

    # Convert to JAX
    text_strip = jnp.asarray(text_strip_np)

    # Convert color to array
    color = jnp.array(color, dtype=jnp.float32)
    shadow_color = jnp.array(shadow_color, dtype=jnp.float32)

    # Apply color tint to text strip (white text * color)
    text_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * color
    text_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity

    # Start with frame as float
    result = frame.astype(jnp.float32)

    # Draw shadow first if enabled
    if shadow:
        sx, sy = x + shadow_offset, y + shadow_offset
        shadow_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * shadow_color
        shadow_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity * 0.5
        result = _composite_text_strip(result, shadow_rgb, shadow_alpha, sx, sy)

    # Draw main text
    result = _composite_text_strip(result, text_rgb, text_alpha, x, y)

    return jnp.clip(result, 0, 255).astype(jnp.uint8)


def _composite_text_strip(frame, text_rgb, text_alpha, x, y):
    """
    Composite text strip onto frame at position (x, y).

    Uses alpha blending: result = text * alpha + frame * (1 - alpha)

    The strip is pasted into a full-frame zero layer and blended; this is less
    efficient than blending only the covered region but keeps all shapes
    static. `x` and `y` must be concrete numbers (they are used in Python-level
    bounds arithmetic), while `frame` may be a traced array.

    Args:
        frame: float frame (H, W, 3)
        text_rgb: float strip (th, tw, 3)
        text_alpha: float strip (th, tw), 0..1
        x, y: top-left position of the strip in frame coordinates

    Returns:
        Blended float frame (H, W, 3); the strip is clipped to the frame and a
        strip entirely outside the frame leaves the frame values unchanged.
    """
    h, w = frame.shape[:2]
    th, tw = text_rgb.shape[:2]

    # Full-frame layers, transparent everywhere except where the strip lands
    padded_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32)
    padded_alpha = jnp.zeros((h, w), dtype=jnp.float32)

    # Destination region (in frame), clamped to frame bounds
    y_start = int(max(0, y))
    y_end = int(min(h, y + th))
    x_start = int(max(0, x))
    x_end = int(min(w, x + tw))

    # Matching source region (in text strip)
    src_y_start = int(max(0, -y))
    src_y_end = src_y_start + (y_end - y_start)
    src_x_start = int(max(0, -x))
    src_x_end = src_x_start + (x_end - x_start)

    # Only proceed if there's a valid region
    if y_end > y_start and x_end > x_start:
        # Extract the valid portion of text
        valid_rgb = text_rgb[src_y_start:src_y_end, src_x_start:src_x_end]
        valid_alpha = text_alpha[src_y_start:src_y_end, src_x_start:src_x_end]

        # Use lax.dynamic_update_slice for JAX compatibility
        padded_rgb = lax.dynamic_update_slice(padded_rgb, valid_rgb, (y_start, x_start, 0))
        padded_alpha = lax.dynamic_update_slice(padded_alpha, valid_alpha, (y_start, x_start))

    # Alpha composite: result = text * alpha + frame * (1 - alpha)
    alpha_3d = padded_alpha[:, :, jnp.newaxis]
    return padded_rgb * alpha_3d + frame * (1.0 - alpha_3d)


def jax_text_size(text: str, font_name: Optional[str] = None, font_size: int = 32) -> tuple:
    """
    Measure text dimensions (width, height).

    This can be called at compile time to get text dimensions for layout.
    Characters missing from the atlas are measured as a space.

    Returns:
        (width, height) tuple in pixels
    """
    _, char_to_idx, char_widths, char_height, _, _, _ = _create_text_atlas(font_name, font_size)

    # Sum actual character (advance) widths
    space_idx = char_to_idx.get(' ', 0)
    total_width = sum(int(char_widths[char_to_idx.get(c, space_idx)]) for c in text)
    return (total_width, char_height)


def jax_font_metrics(font_name: Optional[str] = None, font_size: int = 32) -> dict:
    """
    Get font metrics for layout calculations.

    Typography terms:
    - ascent: pixels from baseline to top of tallest glyph (b, d, h, etc.)
    - descent: pixels from baseline to bottom of lowest glyph (g, j, p, etc.)
    - height: total height = ascent + descent (plus padding)
    - baseline: position of baseline from top of text cell

    Returns:
        dict with keys: ascent, descent, height, baseline
    """
    _, _, _, char_height, baseline_offset, _, _ = _create_text_atlas(font_name, font_size)

    # baseline_offset is pixels from top to baseline (= ascent + padding)
    # char_height is ascent + descent + 2px padding
    ascent = baseline_offset - 1  # remove padding
    descent = char_height - baseline_offset - 1  # remove padding

    return {
        'ascent': ascent,
        'descent': descent,
        'height': char_height,
        'baseline': baseline_offset,
    }


def jax_fit_text_size(text: str, max_width: int, max_height: int,
                      font_name: Optional[str] = None,
                      min_size: int = 8, max_size: int = 200) -> int:
    """
    Calculate font size to fit text within bounds.

    Binary search for largest size that fits. Returns `min_size` when no size
    in [min_size, max_size] fits.
    """
    best_size = min_size
    low, high = min_size, max_size

    while low <= high:
        mid = (low + high) // 2
        w, h = jax_text_size(text, font_name, mid)

        if w <= max_width and h <= max_height:
            best_size = mid
            low = mid + 1
        else:
            high = mid - 1

    return best_size


# =============================================================================
# JAX Primitives - True primitives that can't be derived
# =============================================================================

def jax_width(frame):
    """Frame width."""
    return frame.shape[1]


def jax_height(frame):
    """Frame height."""
    return frame.shape[0]


def jax_channel(frame, idx):
    """Extract channel by index as flat float32 array."""
    # idx must be a static int for indexing
    return frame[:, :, int(idx)].flatten().astype(jnp.float32)


def jax_merge_channels(r, g, b, shape):
    """Merge flat RGB channels back to a uint8 frame of the given (h, w) shape."""
    h, w = shape
    r_img = jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8)
    g_img = jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8)
    b_img = jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
    return jnp.stack([r_img, g_img, b_img], axis=2)


def jax_iota(n):
    """Generate [0, 1, 2, ..., n-1] as float32."""
    return jnp.arange(n, dtype=jnp.float32)


def jax_repeat(x, n):
    """Repeat each element n times: [a,b] -> [a,a,b,b]."""
    return jnp.repeat(x, n)


def jax_tile(x, n):
    """Tile array n times: [a,b] -> [a,b,a,b]."""
    return jnp.tile(x, n)


def jax_gather(data, indices):
    """Parallel index lookup into flattened data; indices are clipped to range."""
    flat_data = data.flatten()
    idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, len(flat_data) - 1)
    return flat_data[idx_clipped]


def jax_scatter(indices, values, size):
    """Parallel index write into a zero float32 array of length `size`.

    Indices are clipped to range. For duplicate indices one of the writes
    wins; NOTE(review): JAX does not appear to guarantee which one - confirm
    before relying on "last write wins".
    """
    result = jnp.zeros(size, dtype=jnp.float32)
    idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    return result.at[idx_clipped].set(values)


def jax_scatter_add(indices, values, size):
    """Parallel index accumulate into a zero float32 array of length `size`."""
    result = jnp.zeros(size, dtype=jnp.float32)
    idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1)
    return result.at[idx_clipped].add(values)


def jax_group_reduce(values, group_indices, num_groups, op='mean'):
    """Reduce values by group.

    Args:
        values: flat array of values
        group_indices: group id for each value (cast to int32)
        num_groups: number of output groups
        op: 'sum', 'mean', 'max' or 'min'

    Returns:
        float32 array of length num_groups; groups that received no value are 0.

    Raises:
        ValueError: for an unknown `op`.
    """
    grp = group_indices.astype(jnp.int32)

    if op == 'sum':
        result = jnp.zeros(num_groups, dtype=jnp.float32)
        return result.at[grp].add(values)
    elif op == 'mean':
        sums = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values)
        counts = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(1.0)
        return jnp.where(counts > 0, sums / counts, 0.0)
    elif op == 'max':
        result = jnp.full(num_groups, -jnp.inf, dtype=jnp.float32)
        result = result.at[grp].max(values)
        # Untouched groups still hold the -inf sentinel; report them as 0
        return jnp.where(result == -jnp.inf, 0.0, result)
    elif op == 'min':
        result = jnp.full(num_groups, jnp.inf, dtype=jnp.float32)
        result = result.at[grp].min(values)
        return jnp.where(result == jnp.inf, 0.0, result)
    else:
        raise ValueError(f"Unknown reduce op: {op}")


def jax_where(cond, true_val, false_val):
    """Conditional select."""
    return jnp.where(cond, true_val, false_val)


def jax_cell_indices(frame, cell_size):
    """Compute cell index for each pixel.

    The frame is divided into (h // cell_size) x (w // cell_size) cells,
    numbered row-major. Pixels in the partial rows/columns left over when the
    frame size is not a multiple of cell_size are assigned to the last full
    row/column, matching jax_pool_frame.

    Returns:
        Flat float32 array of length h * w (row-major pixel order).
    """
    h, w = frame.shape[:2]
    cell_size = int(cell_size)

    rows = h // cell_size
    cols = w // cell_size

    # For each pixel, compute its cell index
    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)

    # Clip each axis separately: an unclipped column index equal to `cols`
    # would alias into the first cell of the next row.
    cell_row = jnp.clip(y_coords // cell_size, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cell_size, 0, cols - 1)
    cell_idx = cell_row * cols + cell_col

    # Clip to valid range
    return jnp.clip(cell_idx, 0, rows * cols - 1).astype(jnp.float32)


def jax_pool_frame(frame, cell_size):
    """
    Pool frame to cell values (per-cell mean).

    Returns tuple: (cell_r, cell_g, cell_b, cell_lum), each a flat float32
    array with one entry per cell in row-major order.
    """
    h, w = frame.shape[:2]
    cs = int(cell_size)
    rows = h // cs
    cols = w // cs
    num_cells = rows * cols

    # Compute cell indices for each pixel
    y_coords = jnp.repeat(jnp.arange(h), w)
    x_coords = jnp.tile(jnp.arange(w), h)
    cell_row = jnp.clip(y_coords // cs, 0, rows - 1)
    cell_col = jnp.clip(x_coords // cs, 0, cols - 1)
    cell_idx = (cell_row * cols + cell_col).astype(jnp.int32)

    # Extract channels
    r_flat = frame[:, :, 0].flatten().astype(jnp.float32)
    g_flat = frame[:, :, 1].flatten().astype(jnp.float32)
    b_flat = frame[:, :, 2].flatten().astype(jnp.float32)

    # Pool each channel (mean)
    def pool_channel(data):
        sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data)
        counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0)
        return jnp.where(counts > 0, sums / counts, 0.0)

    r_pooled = pool_channel(r_flat)
    g_pooled = pool_channel(g_flat)
    b_pooled = pool_channel(b_flat)
    lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled

    return (r_pooled, g_pooled, b_pooled, lum)


# =============================================================================
# Scan (Prefix Operations) - JAX implementations
# =============================================================================

def jax_scan_add(x, axis=None):
    """Cumulative sum (prefix sum); flattens first when axis is None."""
    if axis is not None:
        return jnp.cumsum(x, axis=int(axis))
    return jnp.cumsum(x.flatten())


def jax_scan_mul(x, axis=None):
    """Cumulative product; flattens first when axis is None."""
    if axis is not None:
        return jnp.cumprod(x, axis=int(axis))
    return jnp.cumprod(x.flatten())


def jax_scan_max(x, axis=None):
    """Cumulative maximum; flattens first when axis is None."""
    if axis is not None:
        return lax.cummax(x, axis=int(axis))
    return lax.cummax(x.flatten(), axis=0)


def jax_scan_min(x, axis=None):
    """Cumulative minimum; flattens first when axis is None."""
    if axis is not None:
        return lax.cummin(x, axis=int(axis))
    return lax.cummin(x.flatten(), axis=0)


# =============================================================================
# Outer Product - JAX implementations
# =============================================================================

def jax_outer(x, y, op='*'):
    """Pairwise combination of two flattened arrays.

    Entry [i, j] of the result combines x.flatten()[i] with y.flatten()[j]
    using `op`: one of '*', '+', '-', '/', 'max', 'min'. Any other value of
    `op` behaves like '*'.
    """
    left = x.flatten()
    right = y.flatten()

    def _multiply(a, b):
        return jnp.outer(a, b)

    def _add(a, b):
        return a[:, None] + b[None, :]

    def _subtract(a, b):
        return a[:, None] - b[None, :]

    def _divide(a, b):
        return a[:, None] / b[None, :]

    def _maximum(a, b):
        return jnp.maximum(a[:, None], b[None, :])

    def _minimum(a, b):
        return jnp.minimum(a[:, None], b[None, :])

    table = {
        '*': _multiply,
        '+': _add,
        '-': _subtract,
        '/': _divide,
        'max': _maximum,
        'min': _minimum,
    }
    # Unrecognised operators fall back to the plain outer product.
    combine = table.get(op, _multiply)
    return combine(left, right)


def jax_outer_add(x, y):
    """Outer sum: result[i, j] = x[i] + y[j]."""
    return jax_outer(x, y, '+')


def jax_outer_mul(x, y):
    """Outer product: result[i, j] = x[i] * y[j]."""
    return jax_outer(x, y, '*')


def jax_outer_max(x, y):
    """Outer max: result[i, j] = max(x[i], y[j])."""
    return jax_outer(x, y, 'max')


def jax_outer_min(x, y):
    """Outer min: result[i, j] = min(x[i], y[j])."""
    return jax_outer(x, y, 'min')


# =============================================================================
# Reduce with Axis - JAX implementations
# =============================================================================

def jax_reduce_axis(x, op='sum', axis=0):
    """Collapse one axis of `x` with the named reduction.

    `op` is one of 'sum'/'+', 'mean', 'max', 'min', 'prod'/'*', 'std'; any
    other value behaves like 'sum'. `axis` is converted with int().
    """
    axis = int(axis)
    reducers = {
        'sum': jnp.sum,
        '+': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
        'prod': jnp.prod,
        '*': jnp.prod,
        'std': jnp.std,
    }
    # Unrecognised operators fall back to a sum.
    reducer = reducers.get(op, jnp.sum)
    return reducer(x, axis=axis)


def jax_sum_axis(x, axis=0):
    """Sum of `x` over the given axis."""
    which = int(axis)
    return jnp.sum(x, axis=which)


def jax_mean_axis(x, axis=0):
    """Mean of `x` over the given axis."""
    which = int(axis)
    return jnp.mean(x, axis=which)


def jax_max_axis(x, axis=0):
    """Maximum of `x` over the given axis."""
    which = int(axis)
    return jnp.max(x, axis=which)


def jax_min_axis(x, axis=0):
    """Minimum of `x` over the given axis."""
    which = int(axis)
    return jnp.min(x, axis=which)


# =============================================================================
# Windowed Operations - JAX implementations
============================================================================= def jax_window(x, size, op='mean', stride=1): """ Sliding window operation. For 1D arrays: standard sliding window For 2D arrays: 2D sliding window (size x size) """ size = int(size) stride = int(stride) if x.ndim == 1: # 1D sliding window using convolution trick n = len(x) if op == 'sum': kernel = jnp.ones(size) return jnp.convolve(x, kernel, mode='valid')[::stride] elif op == 'mean': kernel = jnp.ones(size) / size return jnp.convolve(x, kernel, mode='valid')[::stride] else: # For max/min, use manual approach out_n = (n - size) // stride + 1 indices = jnp.arange(out_n) * stride windows = jax.vmap(lambda i: lax.dynamic_slice(x, (i,), (size,)))(indices) if op == 'max': return jnp.max(windows, axis=1) elif op == 'min': return jnp.min(windows, axis=1) else: return jnp.mean(windows, axis=1) else: # 2D sliding window h, w = x.shape[:2] out_h = (h - size) // stride + 1 out_w = (w - size) // stride + 1 # Extract all windows using vmap def extract_window(ij): i, j = ij // out_w, ij % out_w return lax.dynamic_slice(x, (i * stride, j * stride), (size, size)) indices = jnp.arange(out_h * out_w) windows = jax.vmap(extract_window)(indices) if op == 'sum': result = jnp.sum(windows, axis=(1, 2)) elif op == 'mean': result = jnp.mean(windows, axis=(1, 2)) elif op == 'max': result = jnp.max(windows, axis=(1, 2)) elif op == 'min': result = jnp.min(windows, axis=(1, 2)) else: result = jnp.mean(windows, axis=(1, 2)) return result.reshape(out_h, out_w) def jax_window_sum(x, size, stride=1): """Sliding window sum.""" return jax_window(x, size, 'sum', stride) def jax_window_mean(x, size, stride=1): """Sliding window mean.""" return jax_window(x, size, 'mean', stride) def jax_window_max(x, size, stride=1): """Sliding window max.""" return jax_window(x, size, 'max', stride) def jax_window_min(x, size, stride=1): """Sliding window min.""" return jax_window(x, size, 'min', stride) def jax_integral_image(frame): """ 
Compute integral image (summed area table). Enables O(1) box blur at any radius. """ if frame.ndim == 3: # Convert to grayscale gray = jnp.mean(frame.astype(jnp.float32), axis=2) else: gray = frame.astype(jnp.float32) # Cumsum along both axes return jnp.cumsum(jnp.cumsum(gray, axis=0), axis=1) def jax_sample(frame, x, y): """Bilinear sample at (x, y) coordinates. Matches OpenCV cv2.remap with INTER_LINEAR and BORDER_CONSTANT (default): out-of-bounds samples return 0, then bilinear blend includes those zeros. """ h, w = frame.shape[:2] # Get integer coords for the 4 sample points x0 = jnp.floor(x).astype(jnp.int32) y0 = jnp.floor(y).astype(jnp.int32) x1 = x0 + 1 y1 = y0 + 1 fx = x - x0.astype(jnp.float32) fy = y - y0.astype(jnp.float32) # Check which sample points are in bounds valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) # Clamp indices for safe array access (values will be masked anyway) x0_safe = jnp.clip(x0, 0, w - 1) x1_safe = jnp.clip(x1, 0, w - 1) y0_safe = jnp.clip(y0, 0, h - 1) y1_safe = jnp.clip(y1, 0, h - 1) # Bilinear interpolation for each channel def interp_channel(c): # Sample with 0 for out-of-bounds (BORDER_CONSTANT) c00 = jnp.where(valid00, frame[y0_safe, x0_safe, c].astype(jnp.float32), 0.0) c10 = jnp.where(valid10, frame[y0_safe, x1_safe, c].astype(jnp.float32), 0.0) c01 = jnp.where(valid01, frame[y1_safe, x0_safe, c].astype(jnp.float32), 0.0) c11 = jnp.where(valid11, frame[y1_safe, x1_safe, c].astype(jnp.float32), 0.0) return (c00 * (1 - fx) * (1 - fy) + c10 * fx * (1 - fy) + c01 * (1 - fx) * fy + c11 * fx * fy) r = interp_channel(0) g = interp_channel(1) b = interp_channel(2) return r, g, b # ============================================================================= # Convolution Operations # ============================================================================= def 
jax_convolve2d(data, kernel): """2D convolution on a single channel.""" # data shape: (H, W), kernel shape: (kH, kW) # Use JAX's conv with appropriate padding h, w = data.shape kh, kw = kernel.shape # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c) data_4d = data.reshape(1, h, w, 1) kernel_4d = kernel.reshape(kh, kw, 1, 1) # Convolve with 'SAME' padding result = lax.conv_general_dilated( data_4d, kernel_4d, window_strides=(1, 1), padding='SAME', dimension_numbers=('NHWC', 'HWIO', 'NHWC') ) return result.reshape(h, w) def jax_blur(frame, radius=1): """Gaussian blur.""" # Create gaussian kernel size = int(radius) * 2 + 1 x = jnp.arange(size) - radius gaussian_1d = jnp.exp(-x**2 / (2 * (radius/2)**2)) gaussian_1d = gaussian_1d / gaussian_1d.sum() kernel = jnp.outer(gaussian_1d, gaussian_1d) h, w = frame.shape[:2] r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) return jnp.stack([ jnp.clip(r, 0, 255).astype(jnp.uint8), jnp.clip(g, 0, 255).astype(jnp.uint8), jnp.clip(b, 0, 255).astype(jnp.uint8) ], axis=2) def jax_sharpen(frame, amount=1.0): """Sharpen using unsharp mask.""" kernel = jnp.array([ [0, -1, 0], [-1, 5, -1], [0, -1, 0] ], dtype=jnp.float32) # Adjust kernel based on amount center = 4 * amount + 1 kernel = kernel.at[1, 1].set(center) kernel = kernel * amount + jnp.array([[0,0,0],[0,1,0],[0,0,0]]) * (1 - amount) h, w = frame.shape[:2] r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) return jnp.stack([ jnp.clip(r, 0, 255).astype(jnp.uint8), jnp.clip(g, 0, 255).astype(jnp.uint8), jnp.clip(b, 0, 255).astype(jnp.uint8) ], axis=2) def jax_edge_detect(frame): """Sobel edge detection.""" # Sobel kernels sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], 
def jax_emboss(frame):
    """Emboss the luma channel and replicate it to three channels.

    Args:
        frame: (H, W, C>=3) image.

    Returns:
        (H, W, 3) uint8 frame; the filter response is offset by 128 so a
        flat region maps to mid-gray.
    """
    emboss_kernel = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]],
                              dtype=jnp.float32)
    gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    relief = jax_convolve2d(gray, emboss_kernel) + 128
    relief = jnp.clip(relief, 0, 255).astype(jnp.uint8)
    return jnp.stack([relief, relief, relief], axis=2)


# =============================================================================
# Color Space Conversion
# =============================================================================

def jax_rgb_to_hsv(r, g, b):
    """Convert RGB (0-255) to HSV.

    Args:
        r, g, b: Channel values in the 0-255 range (arrays or scalars).

    Returns:
        Tuple ``(h, s, v)`` with hue in degrees [0, 360) and saturation and
        value scaled to 0-255. Achromatic pixels get hue 0 and black pixels
        get saturation 0.
    """
    r, g, b = r / 255.0, g / 255.0, b / 255.0

    max_c = jnp.maximum(jnp.maximum(r, g), b)
    min_c = jnp.minimum(jnp.minimum(r, g), b)
    diff = max_c - min_c

    v = max_c

    # Divide by 1 wherever the true denominator is 0. Those lanes are masked
    # out below anyway, but this keeps NaN/inf out of the intermediates.
    achromatic = diff == 0
    safe_diff = jnp.where(achromatic, 1.0, diff)
    safe_max = jnp.where(max_c > 0, max_c, 1.0)

    s = jnp.where(max_c > 0, diff / safe_max, 0.0)

    h = jnp.where(achromatic, 0.0,
        jnp.where(max_c == r, (60 * ((g - b) / safe_diff) + 360) % 360,
        jnp.where(max_c == g, 60 * ((b - r) / safe_diff) + 120,
                  60 * ((r - g) / safe_diff) + 240)))

    return h, s * 255, v * 255


def jax_hsv_to_rgb(h, s, v):
    """Convert HSV to RGB.

    Args:
        h: Hue in degrees (any real value; wrapped into [0, 360)). May be an
            array or a plain Python number.
        s: Saturation, 0-255.
        v: Value, 0-255.

    Returns:
        Tuple ``(r, g, b)`` of float values in the 0-255 range.
    """
    # asarray lets a plain Python number be used as the hue; without it the
    # ``.astype`` call below fails on ``float``.
    h = jnp.asarray(h) % 360
    s, v = s / 255.0, v / 255.0

    c = v * s
    x = c * (1 - jnp.abs((h / 60) % 2 - 1))
    m = v - c

    sector = (h / 60).astype(jnp.int32) % 6

    def pick(by_sector):
        """Select the entry of a 6-tuple that matches each pixel's sector."""
        chosen = by_sector[5]
        for k in (4, 3, 2, 1, 0):
            chosen = jnp.where(sector == k, by_sector[k], chosen)
        return chosen

    r = pick((c, x, 0, 0, x, c))
    g = pick((x, c, c, x, 0, 0))
    b = pick((0, 0, x, c, c, x))

    return (r + m) * 255, (g + m) * 255, (b + m) * 255


def jax_adjust_saturation(frame, factor):
    """Scale saturation by ``factor`` (1.0 leaves the frame unchanged).

    Args:
        frame: (H, W, C>=3) image with 0-255 values.
        factor: Multiplier applied to the HSV saturation before clipping it
            back to 0-255.

    Returns:
        (H, W, 3) uint8 frame.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (
        frame[:, :, c].flatten().astype(jnp.float32) for c in range(3)
    )

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    sat = jnp.clip(sat * factor, 0, 255)
    planes = jax_hsv_to_rgb(hue, sat, val)

    return jnp.stack([
        jnp.clip(p, 0, 255).reshape(rows, cols).astype(jnp.uint8)
        for p in planes
    ], axis=2)


def jax_shift_hue(frame, degrees):
    """Rotate the hue of every pixel by ``degrees``.

    Args:
        frame: (H, W, C>=3) image with 0-255 values.
        degrees: Hue offset in degrees; the result is wrapped modulo 360.

    Returns:
        (H, W, 3) uint8 frame.
    """
    rows, cols = frame.shape[:2]
    red, green, blue = (
        frame[:, :, c].flatten().astype(jnp.float32) for c in range(3)
    )

    hue, sat, val = jax_rgb_to_hsv(red, green, blue)
    hue = (hue + degrees) % 360
    planes = jax_hsv_to_rgb(hue, sat, val)

    return jnp.stack([
        jnp.clip(p, 0, 255).reshape(rows, cols).astype(jnp.uint8)
        for p in planes
    ], axis=2)
# =============================================================================
# Color Adjustment Operations
# =============================================================================

def jax_adjust_brightness(frame, amount):
    """Add ``amount`` (-255 to 255) to every value, clipping to 0-255 uint8."""
    shifted = frame.astype(jnp.float32) + amount
    return jnp.clip(shifted, 0, 255).astype(jnp.uint8)


def jax_adjust_contrast(frame, factor):
    """Scale values around mid-gray 128 by ``factor`` (1.0 = unchanged)."""
    centred = frame.astype(jnp.float32) - 128
    stretched = centred * factor + 128
    return jnp.clip(stretched, 0, 255).astype(jnp.uint8)


def jax_invert(frame):
    """Return the negative of the frame (``255 - value``)."""
    return 255 - frame


def jax_posterize(frame, levels):
    """Quantize each channel to ``levels`` evenly spaced values.

    ``levels`` is truncated to an int and raised to at least 2.
    """
    levels = max(int(levels), 2)
    step = 255.0 / (levels - 1)
    # Snap every value to the nearest multiple of ``step``.
    snapped = jnp.round(frame.astype(jnp.float32) / step) * step
    return jnp.clip(snapped, 0, 255).astype(jnp.uint8)


def jax_threshold(frame, level, invert=False):
    """Binary threshold on luma.

    Pixels with luma >= ``level`` become 255 and the rest 0; with ``invert``
    truthy, pixels with luma < ``level`` become 255 instead. The mask is
    replicated to three channels.
    """
    gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    mask = gray < level if invert else gray >= level
    binary = jnp.where(mask, 255, 0).astype(jnp.uint8)
    return jnp.stack([binary, binary, binary], axis=2)


def jax_sepia(frame):
    """Apply a sepia tone via a fixed 3x3 colour mix, clipped to uint8."""
    r = frame[:, :, 0].astype(jnp.float32)
    g = frame[:, :, 1].astype(jnp.float32)
    b = frame[:, :, 2].astype(jnp.float32)

    mixed = (
        r * 0.393 + g * 0.769 + b * 0.189,  # red out
        r * 0.349 + g * 0.686 + b * 0.168,  # green out
        r * 0.272 + g * 0.534 + b * 0.131,  # blue out
    )
    return jnp.stack(
        [jnp.clip(plane, 0, 255).astype(jnp.uint8) for plane in mixed],
        axis=2,
    )


def jax_grayscale(frame):
    """Convert to luma (0.299 R + 0.587 G + 0.114 B) on three equal channels."""
    luma = (frame[:, :, 0].astype(jnp.float32) * 0.299 +
            frame[:, :, 1].astype(jnp.float32) * 0.587 +
            frame[:, :, 2].astype(jnp.float32) * 0.114)
    luma = luma.astype(jnp.uint8)
    return jnp.stack([luma, luma, luma], axis=2)
# =============================================================================
# Geometry Operations
# =============================================================================

def jax_flip_horizontal(frame):
    """Mirror the frame left-to-right by reversing the column axis."""
    reversed_order = slice(None, None, -1)
    return frame[:, reversed_order, :]


def jax_flip_vertical(frame):
    """Mirror the frame top-to-bottom by reversing the row axis."""
    reversed_order = slice(None, None, -1)
    return frame[reversed_order, :, :]


def jax_rotate(frame, angle, center_x=None, center_y=None):
    """Rotate the frame by ``angle`` degrees, following OpenCV's convention.

    A positive angle rotates counter-clockwise. Each destination pixel is
    mapped back to its source position and sampled bilinearly; positions
    outside the frame come back as 0 from ``jax_sample``.

    Args:
        frame: (H, W, C>=3) image.
        angle: Rotation in degrees.
        center_x: Rotation centre column; defaults to ``w / 2``.
        center_y: Rotation centre row; defaults to ``h / 2``.

    Returns:
        (H, W, 3) uint8 frame.
    """
    h, w = frame.shape[:2]
    center_x = w / 2 if center_x is None else center_x
    center_y = h / 2 if center_y is None else center_y

    theta = angle * jnp.pi / 180
    cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)

    # Destination pixel grid: row index and column index for every pixel.
    row_idx, col_idx = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing='ij')
    dx = col_idx - center_x
    dy = row_idx - center_y

    # getRotationMatrix2D describes the forward map [[cos, sin], [-sin, cos]];
    # sampling needs its inverse [[cos, -sin], [sin, cos]].
    src_x = cos_t * dx - sin_t * dy + center_x
    src_y = sin_t * dx + cos_t * dy + center_y

    planes = jax_sample(frame, src_x.flatten(), src_y.flatten())
    return jnp.stack([
        jnp.clip(p, 0, 255).reshape(h, w).astype(jnp.uint8) for p in planes
    ], axis=2)


def jax_scale(frame, scale_x, scale_y=None):
    """Zoom the frame about its centre, with black outside the source.

    Args:
        frame: (H, W, C>=3) image.
        scale_x: Horizontal zoom factor.
        scale_y: Vertical zoom factor; defaults to ``scale_x``.

    Returns:
        (H, W, 3) uint8 frame of the same size as the input.
    """
    if scale_y is None:
        scale_y = scale_x

    h, w = frame.shape[:2]
    center_x, center_y = w / 2, h / 2

    row_idx, col_idx = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing='ij')

    # Inverse mapping: for each destination pixel find the source position.
    src_x = (col_idx - center_x) / scale_x + center_x
    src_y = (row_idx - center_y) / scale_y + center_y

    planes = jax_sample(frame, src_x.flatten(), src_y.flatten())
    return jnp.stack([
        jnp.clip(p, 0, 255).reshape(h, w).astype(jnp.uint8) for p in planes
    ], axis=2)
Matches OpenCV behavior with black out-of-bounds.""" if scale_y is None: scale_y = scale_x h, w = frame.shape[:2] center_x, center_y = w / 2, h / 2 # Create coordinate grids y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) # Scale from center (inverse mapping: dst -> src) src_x = (x_coords - center_x) / scale_x + center_x src_y = (y_coords - center_y) / scale_y + center_y # Sample using bilinear interpolation # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) def jax_resize(frame, new_width, new_height): """Resize frame to new dimensions.""" h, w = frame.shape[:2] new_h, new_w = int(new_height), int(new_width) # Create coordinate grids for new size y_coords = jnp.repeat(jnp.arange(new_h), new_w) x_coords = jnp.tile(jnp.arange(new_w), new_h) # Map to source coordinates src_x = x_coords * (w - 1) / (new_w - 1) src_y = y_coords * (h - 1) / (new_h - 1) r, g, b = jax_sample(frame, src_x, src_y) return jnp.stack([ jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8) ], axis=2) # ============================================================================= # Blending Operations # ============================================================================= def _resize_to_match(frame1, frame2): """Resize frame2 to match frame1's dimensions if they differ. Uses jax.image.resize for bilinear interpolation. Returns frame2 resized to frame1's shape. 
""" h1, w1 = frame1.shape[:2] h2, w2 = frame2.shape[:2] # If same size, return as-is if h1 == h2 and w1 == w2: return frame2 # Resize frame2 to match frame1 # jax.image.resize expects (height, width, channels) and target shape return jax.image.resize( frame2.astype(jnp.float32), (h1, w1, frame2.shape[2]), method='bilinear' ).astype(jnp.uint8) def jax_blend(frame1, frame2, alpha): """Blend two frames. alpha=0 -> frame1, alpha=1 -> frame2. Auto-resizes frame2 to match frame1 if dimensions differ. """ frame2 = _resize_to_match(frame1, frame2) return (frame1.astype(jnp.float32) * (1 - alpha) + frame2.astype(jnp.float32) * alpha).astype(jnp.uint8) def jax_blend_add(frame1, frame2): """Additive blend. Auto-resizes frame2 to match frame1.""" frame2 = _resize_to_match(frame1, frame2) result = frame1.astype(jnp.float32) + frame2.astype(jnp.float32) return jnp.clip(result, 0, 255).astype(jnp.uint8) def jax_blend_multiply(frame1, frame2): """Multiply blend. Auto-resizes frame2 to match frame1.""" frame2 = _resize_to_match(frame1, frame2) result = frame1.astype(jnp.float32) * frame2.astype(jnp.float32) / 255 return jnp.clip(result, 0, 255).astype(jnp.uint8) def jax_blend_screen(frame1, frame2): """Screen blend. Auto-resizes frame2 to match frame1.""" frame2 = _resize_to_match(frame1, frame2) f1 = frame1.astype(jnp.float32) / 255 f2 = frame2.astype(jnp.float32) / 255 result = 1 - (1 - f1) * (1 - f2) return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) def jax_blend_overlay(frame1, frame2): """Overlay blend. 
# =============================================================================
# Utility
# =============================================================================

def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0):
    """Create a JAX random key that varies with frame and operation.

    ``seed`` and ``op_id`` are combined into the base ``PRNGKey`` and
    ``frame_num`` is mixed in with ``jax.random.fold_in``, so the frame
    number may be a traced value and changing it does not force a recompile.

    Args:
        seed: Base seed for determinism (must be concrete).
        frame_num: Frame number for variation (can be traced).
        op_id: Operation ID for variation (must be concrete).

    Returns:
        JAX PRNG key.
    """
    base_key = jax.random.PRNGKey(seed + op_id * 1000003)
    return jax.random.fold_in(base_key, frame_num)


def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42):
    """Random float in [lo, hi) derived from ``make_jax_key``."""
    key = make_jax_key(seed, frame_num, op_id)
    return lo + jax.random.uniform(key) * (hi - lo)


def jax_is_nil(x):
    """Return True when ``x`` is None (the S-expression ``nil``)."""
    return x is None


# =============================================================================
# S-expression to JAX Compiler
# =============================================================================

class JaxCompiler:
    """Compiles S-expressions to JAX functions."""

    def __init__(self):
        self.env = {}         # Variable bindings during compilation
        self.params = {}      # Effect parameters
        self.primitives = {}  # Loaded primitive libraries
        self.derived = {}     # Loaded derived functions

    def load_derived(self, path: str):
        """Load derived operations from a .sexp file into ``self.derived``.

        Every top-level ``(define ...)`` form with at least three elements
        is evaluated with ``self.derived`` as its environment; other forms
        are ignored.

        Args:
            path: Path of the S-expression source file. It is read as UTF-8
                because the language accepts non-ASCII symbols (e.g. ``λ``,
                ``β+``) and the platform default encoding may not decode
                them.
        """
        with open(path, 'r', encoding='utf-8') as f:
            code = f.read()
        exprs = parse_all(code)

        for expr in exprs:
            if not (isinstance(expr, list) and len(expr) >= 3):
                continue
            head = expr[0]
            if isinstance(head, Symbol) and head.name == 'define':
                self._eval_define(expr[1:], self.derived)
Evaluate all define expressions to populate derived functions for expr in exprs: if isinstance(expr, list) and len(expr) >= 3: head = expr[0] if isinstance(head, Symbol) and head.name == 'define': self._eval_define(expr[1:], self.derived) def compile_effect(self, sexp) -> Callable: """ Compile an effect S-expression to a JAX function. Supports both formats: (effect "name" :params (...) :body ...) (define-effect name :params (...) body) Args: sexp: Parsed S-expression Returns: JIT-compiled function: (frame, **params) -> frame """ if not isinstance(sexp, list) or len(sexp) < 2: raise ValueError("Effect must be a list") head = sexp[0] if not isinstance(head, Symbol): raise ValueError("Effect must start with a symbol") form = head.name # Handle both 'effect' and 'define-effect' formats if form == 'effect': # (effect "name" :params (...) :body ...) name = sexp[1] if len(sexp) > 1 else "unnamed" if isinstance(name, Symbol): name = name.name start_idx = 2 elif form == 'define-effect': # (define-effect name :params (...) 
body) name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1]) start_idx = 2 else: raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'") params_spec = [] body = None i = start_idx while i < len(sexp): item = sexp[i] if isinstance(item, Keyword): if item.name == 'params' and i + 1 < len(sexp): params_spec = sexp[i + 1] i += 2 elif item.name == 'body' and i + 1 < len(sexp): body = sexp[i + 1] i += 2 elif item.name in ('desc', 'type', 'range'): # Skip metadata keywords i += 2 else: i += 2 # Skip unknown keywords with their values else: # Assume it's the body if we haven't seen one if body is None: body = item i += 1 if body is None: raise ValueError(f"Effect '{name}' must have a body") # Extract parameter names, defaults, and static params (strings, bools) param_info, static_params = self._parse_params(params_spec) # Capture derived functions for the closure derived_fns = self.derived.copy() # Create the JAX function def effect_fn(frame, **kwargs): # Set up environment h, w = frame.shape[:2] # Get frame_num for deterministic random variation frame_num = kwargs.get('frame_num', 0) # Get seed from recipe config (passed via kwargs) seed = kwargs.get('seed', 42) env = { 'frame': frame, 'width': w, 'height': h, '_shape': (h, w), # Time variables (default to 0, can be overridden via kwargs) 't': kwargs.get('t', kwargs.get('_time', 0.0)), '_time': kwargs.get('_time', kwargs.get('t', 0.0)), 'time': kwargs.get('time', kwargs.get('t', 0.0)), # Frame number for random key generation 'frame_num': frame_num, 'frame-num': frame_num, '_frame_num': frame_num, # Seed from recipe for deterministic random '_seed': seed, # Counter for unique random keys within same frame '_rand_op_counter': 0, # Common constants 'pi': jnp.pi, 'PI': jnp.pi, } # Add derived functions env.update(derived_fns) # Add typography primitives bind_typography_primitives(env) # Add parameters with defaults for pname, pdefault in param_info.items(): if pname in kwargs: env[pname] = 
kwargs[pname] elif isinstance(pdefault, list): # Unevaluated S-expression default - evaluate it env[pname] = self._eval(pdefault, env) else: env[pname] = pdefault # Evaluate body result = self._eval(body, env) # Ensure result is a frame if isinstance(result, tuple) and len(result) == 3: # RGB tuple - merge to frame r, g, b = result return jax_merge_channels(r, g, b, (h, w)) elif result.ndim == 3: return result else: # Single channel - replicate to RGB h, w = env['_shape'] gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8) return jnp.stack([gray, gray, gray], axis=2) # JIT compile with static args for string/bool parameters and seed # seed must be static for PRNGKey, but frame_num can be traced via fold_in all_static = set(static_params) | {'seed'} return jax.jit(effect_fn, static_argnames=list(all_static)) def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]: """Parse parameter specifications. Returns: Tuple of (param_defaults, static_params) - param_defaults: Dict mapping param names to default values - static_params: Set of param names that should be static (strings, bools) """ result = {} static_params = set() if not isinstance(params_spec, list): return result, static_params for param in params_spec: if isinstance(param, Symbol): result[param.name] = 0.0 elif isinstance(param, list) and len(param) >= 1: pname = param[0].name if isinstance(param[0], Symbol) else str(param[0]) pdefault = 0.0 ptype = None # Look for :default and :type keywords i = 1 while i < len(param): if isinstance(param[i], Keyword): kw = param[i].name if kw == 'default' and i + 1 < len(param): pdefault = param[i + 1] if isinstance(pdefault, Symbol): if pdefault.name == 'nil': pdefault = None elif pdefault.name == 'true': pdefault = True elif pdefault.name == 'false': pdefault = False i += 2 elif kw == 'type' and i + 1 < len(param): ptype = param[i + 1] if isinstance(ptype, Symbol): ptype = ptype.name i += 2 else: i += 1 else: i += 1 result[pname] = pdefault # Mark 
string and bool parameters as static (can't be traced by JAX) if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)): static_params.add(pname) return result, static_params def _eval(self, expr, env: Dict[str, Any]) -> Any: """Evaluate an S-expression in the given environment.""" # Already-evaluated values (e.g., from threading macros) # JAX arrays, NumPy arrays, tuples, etc. if hasattr(expr, 'shape'): # JAX/NumPy array return expr if isinstance(expr, tuple): # e.g., (r, g, b) from rgb return expr # Literals - keep as Python numbers for static operations if isinstance(expr, (int, float)): return expr if isinstance(expr, str): return expr # Symbols - variable lookup if isinstance(expr, Symbol): name = expr.name if name in env: return env[name] if name == 'nil': return None if name == 'true': return True if name == 'false': return False raise NameError(f"Unknown symbol: {name}") # Lists - function calls if isinstance(expr, list) and len(expr) > 0: head = expr[0] if isinstance(head, Symbol): op = head.name args = expr[1:] # Special forms if op == 'let' or op == 'let*': return self._eval_let(args, env) if op == 'if': return self._eval_if(args, env) if op == 'lambda' or op == 'λ': return self._eval_lambda(args, env) if op == 'define': return self._eval_define(args, env) # Built-in operations return self._eval_call(op, args, env) # Empty list if isinstance(expr, list) and len(expr) == 0: return [] raise ValueError(f"Cannot evaluate: {expr}") def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]): """Extract a keyword argument from args list. Looks for :key value pattern in args and evaluates the value. Returns default if not found. 
""" i = 0 while i < len(args): if isinstance(args[i], Keyword) and args[i].name == key: if i + 1 < len(args): val = self._eval(args[i + 1], env) # Handle Symbol values (e.g., :op 'sum -> 'sum') if isinstance(val, Symbol): return val.name return val return default i += 1 return default def _eval_let(self, args, env: Dict[str, Any]) -> Any: """Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body).""" if len(args) < 2: raise ValueError("let requires bindings and body") bindings = args[0] body = args[1] new_env = env.copy() # Handle both ((var val) ...) and [var val var2 val2 ...] syntax if isinstance(bindings, list): # Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...) if bindings and isinstance(bindings[0], Symbol): # Flat list: [var val var2 val2 ...] i = 0 while i < len(bindings) - 1: var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) val = self._eval(bindings[i + 1], new_env) new_env[var] = val i += 2 else: # Nested list: ((var val) (var2 val2) ...) 
for binding in bindings: if isinstance(binding, list) and len(binding) >= 2: var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) val = self._eval(binding[1], new_env) new_env[var] = val return self._eval(body, new_env) def _eval_if(self, args, env: Dict[str, Any]) -> Any: """Evaluate (if cond then else).""" if len(args) < 2: raise ValueError("if requires condition and then-branch") cond = self._eval(args[0], env) # Handle None as falsy (important for optional params like overlay) if cond is None: return self._eval(args[2], env) if len(args) > 2 else None # For Python scalar bools, use normal Python if # This allows side effects and None values if isinstance(cond, bool): if cond: return self._eval(args[1], env) else: return self._eval(args[2], env) if len(args) > 2 else None # For NumPy/JAX scalar bools with concrete values if hasattr(cond, 'item') and cond.shape == (): try: if bool(cond.item()): return self._eval(args[1], env) else: return self._eval(args[2], env) if len(args) > 2 else None except: pass # Fall through to jnp.where for traced values # For traced values, evaluate both branches and use jnp.where then_val = self._eval(args[1], env) else_val = self._eval(args[2], env) if len(args) > 2 else 0.0 # Handle None by converting to zeros if then_val is None: then_val = 0.0 if else_val is None: else_val = 0.0 # Convert lists to tuples if isinstance(then_val, list): then_val = tuple(then_val) if isinstance(else_val, list): else_val = tuple(else_val) # Handle tuple results (e.g., from rgb in map-pixels) if isinstance(then_val, tuple) and isinstance(else_val, tuple): return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val)) return jnp.where(cond, then_val, else_val) def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable: """Evaluate (lambda (params) body).""" if len(args) < 2: raise ValueError("lambda requires parameters and body") params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]] body = args[1] 
captured_env = env.copy() def fn(*fn_args): local_env = captured_env.copy() for pname, pval in zip(params, fn_args): local_env[pname] = pval return self._eval(body, local_env) return fn def _eval_define(self, args, env: Dict[str, Any]) -> Any: """Evaluate (define name value) or (define (name params) body).""" if len(args) < 2: raise ValueError("define requires name and value") name_part = args[0] if isinstance(name_part, list): # Function definition: (define (name params) body) fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0]) params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]] body = args[1] captured_env = env.copy() def fn(*fn_args): local_env = captured_env.copy() for pname, pval in zip(params, fn_args): local_env[pname] = pval return self._eval(body, local_env) env[fn_name] = fn return fn else: # Variable definition var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part) val = self._eval(args[1], env) env[var_name] = val return val def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any: """Evaluate a function call.""" # Check if it's a user-defined function if op in env and callable(env[op]): fn = env[op] eval_args = [self._eval(a, env) for a in args] return fn(*eval_args) # Arithmetic if op == '+': vals = [self._eval(a, env) for a in args] result = vals[0] if vals else 0.0 for v in vals[1:]: result = result + v return result if op == '-': if len(args) == 1: return -self._eval(args[0], env) vals = [self._eval(a, env) for a in args] result = vals[0] for v in vals[1:]: result = result - v return result if op == '*': vals = [self._eval(a, env) for a in args] result = vals[0] if vals else 1.0 for v in vals[1:]: result = result * v return result if op == '/': vals = [self._eval(a, env) for a in args] result = vals[0] for v in vals[1:]: result = result / v return result if op == 'mod': a, b = self._eval(args[0], env), self._eval(args[1], env) return a % b if op == 
'pow' or op == '**': a, b = self._eval(args[0], env), self._eval(args[1], env) return jnp.power(a, b) # Comparison if op == '<': a, b = self._eval(args[0], env), self._eval(args[1], env) return a < b if op == '>': a, b = self._eval(args[0], env), self._eval(args[1], env) return a > b if op == '<=': a, b = self._eval(args[0], env), self._eval(args[1], env) return a <= b if op == '>=': a, b = self._eval(args[0], env), self._eval(args[1], env) return a >= b if op == '=' or op == '==': a, b = self._eval(args[0], env), self._eval(args[1], env) # For scalar Python types, return Python bool to enable trace-time if if isinstance(a, (int, float)) and isinstance(b, (int, float)): return bool(a == b) return a == b if op == '!=' or op == '<>': a, b = self._eval(args[0], env), self._eval(args[1], env) if isinstance(a, (int, float)) and isinstance(b, (int, float)): return bool(a != b) return a != b # Logic if op == 'and': vals = [self._eval(a, env) for a in args] # Use Python and for concrete Python bools (e.g., shape comparisons) if all(isinstance(v, (bool, np.bool_)) for v in vals): result = True for v in vals: result = result and bool(v) return result # Otherwise use JAX logical_and result = vals[0] for v in vals[1:]: result = jnp.logical_and(result, v) return result if op == 'or': # Lisp-style or: returns first truthy value, not boolean # (or a b c) returns a if a is truthy, else b if b is truthy, else c for arg in args: val = self._eval(arg, env) # Check if value is truthy if val is None: continue if isinstance(val, (bool, np.bool_)): if val: return val continue if isinstance(val, (int, float)): if val: return val continue if hasattr(val, 'shape'): # JAX/numpy array - return it (considered truthy) return val # For other types, check truthiness if val: return val # All values were falsy, return the last one return self._eval(args[-1], env) if args else None if op == 'not': val = self._eval(args[0], env) if isinstance(val, (bool, np.bool_)): return not bool(val) return 
jnp.logical_not(val) # Math functions if op == 'sqrt': return jnp.sqrt(self._eval(args[0], env)) if op == 'sin': return jnp.sin(self._eval(args[0], env)) if op == 'cos': return jnp.cos(self._eval(args[0], env)) if op == 'tan': return jnp.tan(self._eval(args[0], env)) if op == 'exp': return jnp.exp(self._eval(args[0], env)) if op == 'log': return jnp.log(self._eval(args[0], env)) if op == 'abs': x = self._eval(args[0], env) if isinstance(x, (int, float)): return abs(x) return jnp.abs(x) if op == 'floor': import math x = self._eval(args[0], env) if isinstance(x, (int, float)): return math.floor(x) return jnp.floor(x) if op == 'ceil': import math x = self._eval(args[0], env) if isinstance(x, (int, float)): return math.ceil(x) return jnp.ceil(x) if op == 'round': x = self._eval(args[0], env) if isinstance(x, (int, float)): return round(x) return jnp.round(x) # Frame primitives if op == 'width': return env['width'] if op == 'height': return env['height'] if op == 'channel': frame = self._eval(args[0], env) idx = self._eval(args[1], env) # idx should be a Python int (literal from S-expression) return jax_channel(frame, idx) if op == 'merge-channels' or op == 'rgb': r = self._eval(args[0], env) g = self._eval(args[1], env) b = self._eval(args[2], env) # For scalars (e.g., in map-pixels), return tuple r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) if r_is_scalar and g_is_scalar and b_is_scalar: return (r, g, b) return jax_merge_channels(r, g, b, env['_shape']) if op == 'sample': frame = self._eval(args[0], env) x = self._eval(args[1], env) y = self._eval(args[2], env) return jax_sample(frame, x, y) if op == 'cell-indices': frame = self._eval(args[0], env) cell_size = self._eval(args[1], env) return jax_cell_indices(frame, cell_size) if op == 'pool-frame': frame = 
self._eval(args[0], env) cell_size = self._eval(args[1], env) return jax_pool_frame(frame, cell_size) # Xector primitives if op == 'iota': n = self._eval(args[0], env) return jax_iota(int(n)) if op == 'repeat': x = self._eval(args[0], env) n = self._eval(args[1], env) return jax_repeat(x, int(n)) if op == 'tile': x = self._eval(args[0], env) n = self._eval(args[1], env) return jax_tile(x, int(n)) if op == 'gather': data = self._eval(args[0], env) indices = self._eval(args[1], env) return jax_gather(data, indices) if op == 'scatter': indices = self._eval(args[0], env) values = self._eval(args[1], env) size = int(self._eval(args[2], env)) return jax_scatter(indices, values, size) if op == 'scatter-add': indices = self._eval(args[0], env) values = self._eval(args[1], env) size = int(self._eval(args[2], env)) return jax_scatter_add(indices, values, size) if op == 'group-reduce': values = self._eval(args[0], env) groups = self._eval(args[1], env) num_groups = int(self._eval(args[2], env)) reduce_op = args[3] if len(args) > 3 else 'mean' if isinstance(reduce_op, Symbol): reduce_op = reduce_op.name return jax_group_reduce(values, groups, num_groups, reduce_op) if op == 'where': cond = self._eval(args[0], env) true_val = self._eval(args[1], env) false_val = self._eval(args[2], env) # Handle None values if true_val is None: true_val = 0.0 if false_val is None: false_val = 0.0 return jax_where(cond, true_val, false_val) if op == 'len' or op == 'length': x = self._eval(args[0], env) if isinstance(x, (list, tuple)): return len(x) return x.size # Beta reductions if op in ('β+', 'beta+', 'sum'): return jnp.sum(self._eval(args[0], env)) if op in ('β*', 'beta*', 'product'): return jnp.prod(self._eval(args[0], env)) if op in ('βmin', 'beta-min'): return jnp.min(self._eval(args[0], env)) if op in ('βmax', 'beta-max'): return jnp.max(self._eval(args[0], env)) if op in ('βmean', 'beta-mean', 'mean'): return jnp.mean(self._eval(args[0], env)) if op in ('βstd', 'beta-std'): return 
jnp.std(self._eval(args[0], env)) if op in ('βany', 'beta-any'): return jnp.any(self._eval(args[0], env)) if op in ('βall', 'beta-all'): return jnp.all(self._eval(args[0], env)) # Scan (prefix) operations if op in ('scan+', 'scan-add'): x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', None, env) return jax_scan_add(x, axis) if op in ('scan*', 'scan-mul'): x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', None, env) return jax_scan_mul(x, axis) if op == 'scan-max': x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', None, env) return jax_scan_max(x, axis) if op == 'scan-min': x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', None, env) return jax_scan_min(x, axis) # Outer product operations if op == 'outer': x = self._eval(args[0], env) y = self._eval(args[1], env) op_type = self._eval_kwarg(args, 'op', '*', env) return jax_outer(x, y, op_type) if op in ('outer+', 'outer-add'): x = self._eval(args[0], env) y = self._eval(args[1], env) return jax_outer_add(x, y) if op in ('outer*', 'outer-mul'): x = self._eval(args[0], env) y = self._eval(args[1], env) return jax_outer_mul(x, y) if op == 'outer-max': x = self._eval(args[0], env) y = self._eval(args[1], env) return jax_outer_max(x, y) if op == 'outer-min': x = self._eval(args[0], env) y = self._eval(args[1], env) return jax_outer_min(x, y) # Reduce with axis operations if op == 'reduce-axis': x = self._eval(args[0], env) reduce_op = self._eval_kwarg(args, 'op', 'sum', env) axis = self._eval_kwarg(args, 'axis', 0, env) return jax_reduce_axis(x, reduce_op, axis) if op == 'sum-axis': x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', 0, env) return jax_sum_axis(x, axis) if op == 'mean-axis': x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', 0, env) return jax_mean_axis(x, axis) if op == 'max-axis': x = self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', 0, env) return jax_max_axis(x, axis) if op == 'min-axis': x 
= self._eval(args[0], env) axis = self._eval_kwarg(args, 'axis', 0, env) return jax_min_axis(x, axis) # Windowed operations if op == 'window': x = self._eval(args[0], env) size = int(self._eval(args[1], env)) win_op = self._eval_kwarg(args, 'op', 'mean', env) stride = int(self._eval_kwarg(args, 'stride', 1, env)) return jax_window(x, size, win_op, stride) if op == 'window-sum': x = self._eval(args[0], env) size = int(self._eval(args[1], env)) stride = int(self._eval_kwarg(args, 'stride', 1, env)) return jax_window_sum(x, size, stride) if op == 'window-mean': x = self._eval(args[0], env) size = int(self._eval(args[1], env)) stride = int(self._eval_kwarg(args, 'stride', 1, env)) return jax_window_mean(x, size, stride) if op == 'window-max': x = self._eval(args[0], env) size = int(self._eval(args[1], env)) stride = int(self._eval_kwarg(args, 'stride', 1, env)) return jax_window_max(x, size, stride) if op == 'window-min': x = self._eval(args[0], env) size = int(self._eval(args[1], env)) stride = int(self._eval_kwarg(args, 'stride', 1, env)) return jax_window_min(x, size, stride) # Integral image if op == 'integral-image': frame = self._eval(args[0], env) return jax_integral_image(frame) # Convenience - min/max of two values (handle both scalars and arrays) if op == 'min' or op == 'min2': a = self._eval(args[0], env) b = self._eval(args[1], env) # Use Python min/max for scalar Python values to preserve type if isinstance(a, (int, float)) and isinstance(b, (int, float)): return min(a, b) return jnp.minimum(jnp.asarray(a), jnp.asarray(b)) if op == 'max' or op == 'max2': a = self._eval(args[0], env) b = self._eval(args[1], env) # Use Python min/max for scalar Python values to preserve type if isinstance(a, (int, float)) and isinstance(b, (int, float)): return max(a, b) return jnp.maximum(jnp.asarray(a), jnp.asarray(b)) if op == 'clamp': x = self._eval(args[0], env) lo = self._eval(args[1], env) hi = self._eval(args[2], env) return jnp.clip(x, lo, hi) # List operations if 
op == 'list': return tuple(self._eval(a, env) for a in args) if op == 'nth': seq = self._eval(args[0], env) idx = int(self._eval(args[1], env)) if isinstance(seq, (list, tuple)): return seq[idx] if 0 <= idx < len(seq) else None return seq[idx] # For arrays if op == 'first': seq = self._eval(args[0], env) return seq[0] if len(seq) > 0 else None if op == 'second': seq = self._eval(args[0], env) return seq[1] if len(seq) > 1 else None # Random (JAX-compatible) # Get frame_num for deterministic variation - can be traced, fold_in handles it frame_num = env.get('_frame_num', env.get('frame_num', 0)) # Convert to int32 for fold_in if needed (but keep as JAX array if traced) if frame_num is None: frame_num = 0 elif isinstance(frame_num, (int, float)): frame_num = int(frame_num) # If it's a JAX array, leave it as-is for tracing # Increment operation counter for unique keys within same frame op_counter = env.get('_rand_op_counter', 0) env['_rand_op_counter'] = op_counter + 1 if op == 'rand' or op == 'rand-x': # For size-based random if args: size = self._eval(args[0], env) if hasattr(size, 'shape'): # For frames (3D), use h*w (channel size), not h*w*c if size.ndim == 3: n = size.shape[0] * size.shape[1] # h * w shape = (n,) else: n = size.size shape = size.shape elif hasattr(size, 'size'): n = size.size shape = (n,) else: n = int(size) shape = (n,) # Use deterministic key that varies with frame seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) return jax.random.uniform(key, shape).flatten() seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) return jax.random.uniform(key, ()) if op == 'randn' or op == 'randn-x': # Normal random if args: size = self._eval(args[0], env) if hasattr(size, 'shape'): # For frames (3D), use h*w (channel size), not h*w*c if size.ndim == 3: n = size.shape[0] * size.shape[1] # h * w else: n = size.size elif hasattr(size, 'size'): n = size.size else: n = int(size) mean = self._eval(args[1], env) if 
len(args) > 1 else 0.0 std = self._eval(args[2], env) if len(args) > 2 else 1.0 seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) return jax.random.normal(key, (n,)) * std + mean seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) return jax.random.normal(key, ()) if op == 'rand-range' or op == 'core:rand-range': lo = self._eval(args[0], env) hi = self._eval(args[1], env) seed = env.get('_seed', 42) return jax_rand_range(lo, hi, frame_num, op_counter, seed) # ===================================================================== # Convolution operations # ===================================================================== if op == 'blur' or op == 'image:blur': frame = self._eval(args[0], env) radius = self._eval(args[1], env) if len(args) > 1 else 1 # Convert traced value to concrete for kernel size if hasattr(radius, 'item'): radius = int(radius.item()) elif hasattr(radius, '__float__'): radius = int(float(radius)) else: radius = int(radius) return jax_blur(frame, max(1, radius)) if op == 'gaussian': first_arg = self._eval(args[0], env) # Check if first arg is a frame (blur) or scalar (random) if hasattr(first_arg, 'shape') and first_arg.ndim == 3: # Frame - apply gaussian blur sigma = self._eval(args[1], env) if len(args) > 1 else 1.0 radius = max(1, int(sigma * 3)) return jax_blur(first_arg, radius) else: # Scalar args - generate gaussian random value mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg std = self._eval(args[1], env) if len(args) > 1 else 1.0 # Return a single random value op_counter = env.get('_rand_op_counter', 0) env['_rand_op_counter'] = op_counter + 1 seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) return jax.random.normal(key, ()) * std + mean if op == 'sharpen' or op == 'image:sharpen': frame = self._eval(args[0], env) amount = self._eval(args[1], env) if len(args) > 1 else 1.0 return jax_sharpen(frame, amount) if op == 
'edge-detect' or op == 'image:edge-detect': frame = self._eval(args[0], env) return jax_edge_detect(frame) if op == 'emboss': frame = self._eval(args[0], env) return jax_emboss(frame) if op == 'convolve': frame = self._eval(args[0], env) kernel = self._eval(args[1], env) # Convert kernel to array if it's a list if isinstance(kernel, (list, tuple)): kernel = jnp.array(kernel, dtype=jnp.float32) h, w = frame.shape[:2] r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) return jnp.stack([ jnp.clip(r, 0, 255).astype(jnp.uint8), jnp.clip(g, 0, 255).astype(jnp.uint8), jnp.clip(b, 0, 255).astype(jnp.uint8) ], axis=2) if op == 'add-noise': frame = self._eval(args[0], env) amount = self._eval(args[1], env) if len(args) > 1 else 0.1 h, w = frame.shape[:2] # Use frame-varying key for noise op_counter = env.get('_rand_op_counter', 0) env['_rand_op_counter'] = op_counter + 1 seed = env.get('_seed', 42) key = make_jax_key(seed, frame_num, op_counter) noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1] result = frame.astype(jnp.float32) + noise * amount * 255 return jnp.clip(result, 0, 255).astype(jnp.uint8) if op == 'translate': frame = self._eval(args[0], env) dx = self._eval(args[1], env) dy = self._eval(args[2], env) if len(args) > 2 else 0 h, w = frame.shape[:2] y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) src_x = (x_coords - dx).flatten() src_y = (y_coords - dy).flatten() r, g, b = jax_sample(frame, src_x, src_y) return jnp.stack([ jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) if op == 'image:crop' or op == 'crop': frame = self._eval(args[0], env) x = int(self._eval(args[1], env)) y = int(self._eval(args[2], env)) w = int(self._eval(args[3], env)) 
h = int(self._eval(args[4], env)) return frame[y:y+h, x:x+w, :] if op == 'dilate': frame = self._eval(args[0], env) size = int(self._eval(args[1], env)) if len(args) > 1 else 3 # Simple dilation using max pooling approximation kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size) h, w = frame.shape[:2] r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size) g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size) b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size) return jnp.stack([ jnp.clip(r, 0, 255).astype(jnp.uint8), jnp.clip(g, 0, 255).astype(jnp.uint8), jnp.clip(b, 0, 255).astype(jnp.uint8) ], axis=2) if op == 'map-rows': frame = self._eval(args[0], env) fn = args[1] # S-expression function h, w = frame.shape[:2] # For each row, apply the function results = [] for row_idx in range(h): row_env = env.copy() row_env['row'] = frame[row_idx, :, :] row_env['row-idx'] = row_idx # Check if fn is a lambda if isinstance(fn, list) and len(fn) >= 2: head = fn[0] if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): params = fn[1] body = fn[2] # Bind lambda params to y and row if len(params) >= 1: param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) row_env[param_name] = row_idx # y if len(params) >= 2: param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) row_env[param_name] = frame[row_idx, :, :] # row result_row = self._eval(body, row_env) results.append(result_row) continue result_row = self._eval(fn, row_env) # If result is a function, call it if callable(result_row): result_row = result_row(row_idx, frame[row_idx, :, :]) results.append(result_row) return jnp.stack(results, axis=0) # ===================================================================== # Text rendering operations # ===================================================================== if op == 'text': frame = self._eval(args[0], env) text_str = 
self._eval(args[1], env) if isinstance(text_str, Symbol): text_str = text_str.name text_str = str(text_str) # Extract keyword arguments x = self._eval_kwarg(args, 'x', None, env) y = self._eval_kwarg(args, 'y', None, env) font_size = self._eval_kwarg(args, 'font-size', 32, env) font_name = self._eval_kwarg(args, 'font-name', None, env) color = self._eval_kwarg(args, 'color', (255, 255, 255), env) opacity = self._eval_kwarg(args, 'opacity', 1.0, env) align = self._eval_kwarg(args, 'align', 'left', env) valign = self._eval_kwarg(args, 'valign', 'top', env) shadow = self._eval_kwarg(args, 'shadow', False, env) shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env) shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env) fit = self._eval_kwarg(args, 'fit', False, env) width = self._eval_kwarg(args, 'width', None, env) height = self._eval_kwarg(args, 'height', None, env) # Handle color as list/tuple if isinstance(color, (list, tuple)): color = tuple(int(c) for c in color[:3]) if isinstance(shadow_color, (list, tuple)): shadow_color = tuple(int(c) for c in shadow_color[:3]) h, w_frame = frame.shape[:2] # Default position to 0,0 or center based on alignment if x is None: if align == 'center': x = w_frame // 2 elif align == 'right': x = w_frame else: x = 0 if y is None: if valign == 'middle': y = h // 2 elif valign == 'bottom': y = h else: y = 0 # Auto-fit text to bounds if fit and width is not None and height is not None: font_size = jax_fit_text_size(text_str, int(width), int(height), font_name, min_size=8, max_size=200) return jax_text_render(frame, text_str, int(x), int(y), font_name=font_name, font_size=int(font_size), color=color, opacity=float(opacity), align=str(align), valign=str(valign), shadow=bool(shadow), shadow_color=shadow_color, shadow_offset=int(shadow_offset)) if op == 'text-size': text_str = self._eval(args[0], env) if isinstance(text_str, Symbol): text_str = text_str.name text_str = str(text_str) font_size = self._eval_kwarg(args, 
'font-size', 32, env) font_name = self._eval_kwarg(args, 'font-name', None, env) return jax_text_size(text_str, font_name, int(font_size)) if op == 'fit-text-size': text_str = self._eval(args[0], env) if isinstance(text_str, Symbol): text_str = text_str.name text_str = str(text_str) max_width = int(self._eval(args[1], env)) max_height = int(self._eval(args[2], env)) font_name = self._eval_kwarg(args, 'font-name', None, env) return jax_fit_text_size(text_str, max_width, max_height, font_name) # ===================================================================== # Color operations # ===================================================================== if op == 'rgb->hsv' or op == 'rgb-to-hsv': # Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple if len(args) == 1: c = self._eval(args[0], env) if isinstance(c, tuple) and len(c) == 3: r, g, b = c else: # Assume it's a list-like r, g, b = c[0], c[1], c[2] else: r = self._eval(args[0], env) g = self._eval(args[1], env) b = self._eval(args[2], env) return jax_rgb_to_hsv(r, g, b) if op == 'hsv->rgb' or op == 'hsv-to-rgb': # Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list) if len(args) == 1: hsv = self._eval(args[0], env) if isinstance(hsv, (tuple, list)) and len(hsv) >= 3: h, s, v = hsv[0], hsv[1], hsv[2] else: h, s, v = hsv[0], hsv[1], hsv[2] else: h = self._eval(args[0], env) s = self._eval(args[1], env) v = self._eval(args[2], env) return jax_hsv_to_rgb(h, s, v) if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness': frame = self._eval(args[0], env) amount = self._eval(args[1], env) return jax_adjust_brightness(frame, amount) if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast': frame = self._eval(args[0], env) factor = self._eval(args[1], env) return jax_adjust_contrast(frame, factor) if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation': frame = self._eval(args[0], env) factor = self._eval(args[1], env) return jax_adjust_saturation(frame, factor) if op == 
'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift': frame = self._eval(args[0], env) degrees = self._eval(args[1], env) return jax_shift_hue(frame, degrees) if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img': frame = self._eval(args[0], env) return jax_invert(frame) if op == 'posterize' or op == 'color_ops:posterize': frame = self._eval(args[0], env) levels = self._eval(args[1], env) return jax_posterize(frame, levels) if op == 'threshold' or op == 'color_ops:threshold': frame = self._eval(args[0], env) level = self._eval(args[1], env) invert = self._eval(args[2], env) if len(args) > 2 else False return jax_threshold(frame, level, invert) if op == 'sepia' or op == 'color_ops:sepia': frame = self._eval(args[0], env) return jax_sepia(frame) if op == 'grayscale' or op == 'image:grayscale': frame = self._eval(args[0], env) return jax_grayscale(frame) # ===================================================================== # Geometry operations # ===================================================================== if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img': frame = self._eval(args[0], env) direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal' if direction == 'vertical' or direction == 'v': return jax_flip_vertical(frame) return jax_flip_horizontal(frame) if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v': frame = self._eval(args[0], env) return jax_flip_vertical(frame) if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img': frame = self._eval(args[0], env) angle = self._eval(args[1], env) return jax_rotate(frame, angle) if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img': frame = self._eval(args[0], env) scale_x = self._eval(args[1], env) scale_y = self._eval(args[2], env) if len(args) > 2 else None return jax_scale(frame, scale_x, scale_y) if op == 'resize' or op == 'image:resize': frame = self._eval(args[0], 
env) new_w = self._eval(args[1], env) new_h = self._eval(args[2], env) return jax_resize(frame, new_w, new_h) # ===================================================================== # Geometry distortion effects # ===================================================================== if op == 'geometry:fisheye-coords' or op == 'fisheye': # Signature: (w h strength cx cy zoom_correct) or (frame strength) first_arg = self._eval(args[0], env) if not hasattr(first_arg, 'shape'): # (w h strength cx cy zoom_correct) signature w = int(first_arg) h = int(self._eval(args[1], env)) strength = self._eval(args[2], env) if len(args) > 2 else 0.5 cx = self._eval(args[3], env) if len(args) > 3 else w / 2 cy = self._eval(args[4], env) if len(args) > 4 else h / 2 frame = None else: frame = first_arg strength = self._eval(args[1], env) if len(args) > 1 else 0.5 h, w = frame.shape[:2] cx, cy = w / 2, h / 2 max_r = jnp.sqrt(float(cx*cx + cy*cy)) y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) dx = x_coords - cx dy = y_coords - cy r = jnp.sqrt(dx*dx + dy*dy) theta = jnp.arctan2(dy, dx) # Fisheye distortion r_new = r + strength * r * (1 - r / max_r) src_x = r_new * jnp.cos(theta) + cx src_y = r_new * jnp.sin(theta) + cy if frame is None: return {'x': src_x, 'y': src_y} r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) if op == 'geometry:swirl-coords' or op == 'swirl': first_arg = self._eval(args[0], env) if not hasattr(first_arg, 'shape'): w = int(first_arg) h = int(self._eval(args[1], env)) amount = self._eval(args[2], env) if len(args) > 2 else 1.0 frame = None else: frame = first_arg amount = self._eval(args[1], env) if len(args) > 1 else 1.0 h, w = frame.shape[:2] cx, 
cy = w / 2, h / 2 max_r = jnp.sqrt(float(cx*cx + cy*cy)) y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) dx = x_coords - cx dy = y_coords - cy r = jnp.sqrt(dx*dx + dy*dy) theta = jnp.arctan2(dy, dx) swirl_angle = amount * (1 - r / max_r) new_theta = theta + swirl_angle src_x = r * jnp.cos(new_theta) + cx src_y = r * jnp.sin(new_theta) + cy if frame is None: return {'x': src_x, 'y': src_y} r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) # Wave effect (frame-first signature for simple usage) if op == 'wave-distort': first_arg = self._eval(args[0], env) frame = first_arg amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0 amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0 freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1 freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1 h, w = frame.shape[:2] y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y) src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x) r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) if op == 'geometry:ripple-displace' or op == 'ripple': # Match Python prim_ripple_displace signature: # (w h freq amp cx cy decay phase) or (frame ...) 
first_arg = self._eval(args[0], env) if not hasattr(first_arg, 'shape'): # Coordinate-only mode: (w h freq amp cx cy decay phase) w = int(first_arg) h = int(self._eval(args[1], env)) freq = self._eval(args[2], env) if len(args) > 2 else 5.0 amp = self._eval(args[3], env) if len(args) > 3 else 10.0 cx = self._eval(args[4], env) if len(args) > 4 else w / 2 cy = self._eval(args[5], env) if len(args) > 5 else h / 2 decay = self._eval(args[6], env) if len(args) > 6 else 0.0 phase = self._eval(args[7], env) if len(args) > 7 else 0.0 frame = None else: # Frame mode: (frame :amplitude A :frequency F :center_x CX ...) frame = first_arg h, w = frame.shape[:2] # Parse keyword args amp = 10.0 freq = 5.0 cx = w / 2 cy = h / 2 decay = 0.0 phase = 0.0 i = 1 while i < len(args): if isinstance(args[i], Keyword): kw = args[i].name val = self._eval(args[i + 1], env) if i + 1 < len(args) else None if kw == 'amplitude': amp = val elif kw == 'frequency': freq = val elif kw == 'center_x': cx = val * w if val <= 1 else val # normalized or absolute elif kw == 'center_y': cy = val * h if val <= 1 else val elif kw == 'decay': decay = val elif kw == 'speed': # speed affects phase via time t = env.get('t', 0) phase = t * val * 2 * jnp.pi elif kw == 'phase': phase = val i += 2 else: i += 1 y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) dx = x_coords - cx dy = y_coords - cy dist = jnp.sqrt(dx*dx + dy*dy) # Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp max_dim = jnp.maximum(w, h) ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp # Apply decay (when decay=0, exp(0)=1 so no effect) decay_factor = jnp.exp(-decay * dist / max_dim) ripple = ripple * decay_factor # Radial displacement - use ADDITION to match Python prim_ripple_displace # Python (primitives.py line 2890-2891): # map_x = x_coords + ripple * norm_dx # map_y = y_coords + ripple * norm_dy # where norm_dx = 
dx/dist = cos(angle), norm_dy = dy/dist = sin(angle) angle = jnp.arctan2(dy, dx) src_x = x_coords + ripple * jnp.cos(angle) src_y = y_coords + ripple * jnp.sin(angle) if frame is None: return {'x': src_x, 'y': src_y} # Sample using bilinear interpolation (jax_sample clamps coords, # matching OpenCV's default remap behavior) r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) if op == 'geometry:coords-x' or op == 'coords-x': # Extract x coordinates from coord dict coords = self._eval(args[0], env) if isinstance(coords, dict): return coords.get('x', coords.get('map_x')) return coords[0] if isinstance(coords, (list, tuple)) else coords if op == 'geometry:coords-y' or op == 'coords-y': # Extract y coordinates from coord dict coords = self._eval(args[0], env) if isinstance(coords, dict): return coords.get('y', coords.get('map_y')) return coords[1] if isinstance(coords, (list, tuple)) else coords if op == 'geometry:remap' or op == 'remap': # Remap image using coordinate maps: (frame map_x map_y) # OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords frame = self._eval(args[0], env) map_x = self._eval(args[1], env) map_y = self._eval(args[2], env) h, w = frame.shape[:2] # Flatten coordinate maps src_x = map_x.flatten() src_y = map_y.flatten() # Sample using bilinear interpolation (jax_sample clamps coords internally, # matching OpenCV's default behavior) r_out, g_out, b_out = jax_sample(frame, src_x, src_y) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope': # Two signatures: (frame segments) or (w h segments cx cy) if len(args) >= 3 
and not hasattr(self._eval(args[0], env), 'shape'): # (w h segments cx cy) signature w = int(self._eval(args[0], env)) h = int(self._eval(args[1], env)) segments = int(self._eval(args[2], env)) if len(args) > 2 else 6 cx = self._eval(args[3], env) if len(args) > 3 else w / 2 cy = self._eval(args[4], env) if len(args) > 4 else h / 2 frame = None else: frame = self._eval(args[0], env) segments = int(self._eval(args[1], env)) if len(args) > 1 else 6 h, w = frame.shape[:2] cx, cy = w / 2, h / 2 y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) dx = x_coords - cx dy = y_coords - cy r = jnp.sqrt(dx*dx + dy*dy) theta = jnp.arctan2(dy, dx) # Mirror into segments segment_angle = 2 * jnp.pi / segments theta_mod = theta % segment_angle theta_mirror = jnp.where( (jnp.floor(theta / segment_angle) % 2) == 0, theta_mod, segment_angle - theta_mod ) src_x = r * jnp.cos(theta_mirror) + cx src_y = r * jnp.sin(theta_mirror) + cy if frame is None: # Return coordinate arrays return {'x': src_x, 'y': src_y} r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) # Geometry coordinate extraction if op == 'geometry:coords-x': coords = self._eval(args[0], env) if isinstance(coords, dict): return coords['x'] return coords[0] if isinstance(coords, tuple) else coords if op == 'geometry:coords-y': coords = self._eval(args[0], env) if isinstance(coords, dict): return coords['y'] return coords[1] if isinstance(coords, tuple) else coords if op == 'geometry:remap' or op == 'remap': frame = self._eval(args[0], env) x_coords = self._eval(args[1], env) y_coords = self._eval(args[2], env) h, w = frame.shape[:2] r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten()) 
return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) # ===================================================================== # Blending operations # ===================================================================== if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images': frame1 = self._eval(args[0], env) frame2 = self._eval(args[1], env) alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 return jax_blend(frame1, frame2, alpha) if op == 'blend-add': frame1 = self._eval(args[0], env) frame2 = self._eval(args[1], env) return jax_blend_add(frame1, frame2) if op == 'blend-multiply': frame1 = self._eval(args[0], env) frame2 = self._eval(args[1], env) return jax_blend_multiply(frame1, frame2) if op == 'blend-screen': frame1 = self._eval(args[0], env) frame2 = self._eval(args[1], env) return jax_blend_screen(frame1, frame2) if op == 'blend-overlay': frame1 = self._eval(args[0], env) frame2 = self._eval(args[1], env) return jax_blend_overlay(frame1, frame2) # ===================================================================== # Image dimension queries (namespaced aliases) # ===================================================================== if op == 'image:width': if args: frame = self._eval(args[0], env) return frame.shape[1] # width is second dimension (h, w, c) return env['width'] if op == 'image:height': if args: frame = self._eval(args[0], env) return frame.shape[0] # height is first dimension (h, w, c) return env['height'] # ===================================================================== # Utility # ===================================================================== if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?': x = self._eval(args[0], env) return jax_is_nil(x) # ===================================================================== # Xector channel operations (shortcuts) 
# ===================================================================== if op == 'red': val = self._eval(args[0], env) # Works on frames or pixel tuples if isinstance(val, tuple): return val[0] elif hasattr(val, 'shape') and val.ndim == 3: return jax_channel(val, 0) else: return val # Assume it's already a channel if op == 'green': val = self._eval(args[0], env) if isinstance(val, tuple): return val[1] elif hasattr(val, 'shape') and val.ndim == 3: return jax_channel(val, 1) else: return val if op == 'blue': val = self._eval(args[0], env) if isinstance(val, tuple): return val[2] elif hasattr(val, 'shape') and val.ndim == 3: return jax_channel(val, 2) else: return val if op == 'gray' or op == 'luminance': val = self._eval(args[0], env) # Handle tuple (r, g, b) from map-pixels if isinstance(val, tuple) and len(val) == 3: r, g, b = val return r * 0.299 + g * 0.587 + b * 0.114 # Handle frame frame = val r = frame[:, :, 0].flatten().astype(jnp.float32) g = frame[:, :, 1].flatten().astype(jnp.float32) b = frame[:, :, 2].flatten().astype(jnp.float32) return r * 0.299 + g * 0.587 + b * 0.114 if op == 'rgb': r = self._eval(args[0], env) g = self._eval(args[1], env) b = self._eval(args[2], env) # For scalars (e.g., in map-pixels), return tuple r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) if r_is_scalar and g_is_scalar and b_is_scalar: return (r, g, b) return jax_merge_channels(r, g, b, env['_shape']) # ===================================================================== # Coordinate operations # ===================================================================== if op == 'x-coords': frame = self._eval(args[0], env) h, w = frame.shape[:2] return jnp.tile(jnp.arange(w, dtype=jnp.float32), h) if op == 'y-coords': frame = self._eval(args[0], env) h, w = frame.shape[:2] 
return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) if op == 'dist-from-center': frame = self._eval(args[0], env) h, w = frame.shape[:2] cx, cy = w / 2, h / 2 x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy return jnp.sqrt(x*x + y*y) # ===================================================================== # Alpha operations (element-wise on xectors) # ===================================================================== if op == 'α/' or op == 'alpha/': a = self._eval(args[0], env) b = self._eval(args[1], env) return a / b if op == 'α+' or op == 'alpha+': vals = [self._eval(a, env) for a in args] result = vals[0] for v in vals[1:]: result = result + v return result if op == 'α*' or op == 'alpha*': vals = [self._eval(a, env) for a in args] result = vals[0] for v in vals[1:]: result = result * v return result if op == 'α-' or op == 'alpha-': if len(args) == 1: return -self._eval(args[0], env) a = self._eval(args[0], env) b = self._eval(args[1], env) return a - b if op == 'αclamp' or op == 'alpha-clamp': x = self._eval(args[0], env) lo = self._eval(args[1], env) hi = self._eval(args[2], env) return jnp.clip(x, lo, hi) if op == 'αmin' or op == 'alpha-min': a = self._eval(args[0], env) b = self._eval(args[1], env) return jnp.minimum(a, b) if op == 'αmax' or op == 'alpha-max': a = self._eval(args[0], env) b = self._eval(args[1], env) return jnp.maximum(a, b) if op == 'αsqrt' or op == 'alpha-sqrt': return jnp.sqrt(self._eval(args[0], env)) if op == 'αsin' or op == 'alpha-sin': return jnp.sin(self._eval(args[0], env)) if op == 'αcos' or op == 'alpha-cos': return jnp.cos(self._eval(args[0], env)) if op == 'αmod' or op == 'alpha-mod': a = self._eval(args[0], env) b = self._eval(args[1], env) return a % b if op == 'α²' or op == 'αsq' or op == 'alpha-sq': x = self._eval(args[0], env) return x * x if op == 'α<' or op == 'alpha<': a = self._eval(args[0], env) b = self._eval(args[1], env) return a < b if op == 'α>' 
or op == 'alpha>': a = self._eval(args[0], env) b = self._eval(args[1], env) return a > b if op == 'α<=' or op == 'alpha<=': a = self._eval(args[0], env) b = self._eval(args[1], env) return a <= b if op == 'α>=' or op == 'alpha>=': a = self._eval(args[0], env) b = self._eval(args[1], env) return a >= b if op == 'α=' or op == 'alpha=': a = self._eval(args[0], env) b = self._eval(args[1], env) return a == b if op == 'αfloor' or op == 'alpha-floor': return jnp.floor(self._eval(args[0], env)) if op == 'αceil' or op == 'alpha-ceil': return jnp.ceil(self._eval(args[0], env)) if op == 'αround' or op == 'alpha-round': return jnp.round(self._eval(args[0], env)) if op == 'αabs' or op == 'alpha-abs': return jnp.abs(self._eval(args[0], env)) if op == 'αexp' or op == 'alpha-exp': return jnp.exp(self._eval(args[0], env)) if op == 'αlog' or op == 'alpha-log': return jnp.log(self._eval(args[0], env)) if op == 'αor' or op == 'alpha-or': # Element-wise logical OR a = self._eval(args[0], env) b = self._eval(args[1], env) return jnp.logical_or(a, b) if op == 'αand' or op == 'alpha-and': # Element-wise logical AND a = self._eval(args[0], env) b = self._eval(args[1], env) return jnp.logical_and(a, b) if op == 'αnot' or op == 'alpha-not': # Element-wise logical NOT return jnp.logical_not(self._eval(args[0], env)) if op == 'αxor' or op == 'alpha-xor': # Element-wise logical XOR a = self._eval(args[0], env) b = self._eval(args[1], env) return jnp.logical_xor(a, b) # ===================================================================== # Threading/arrow operations # ===================================================================== if op == '->': # Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b) val = self._eval(args[0], env) for form in args[1:]: if isinstance(form, list): # Insert val as first argument fn_name = form[0].name if isinstance(form[0], Symbol) else form[0] new_args = [val] + [self._eval(a, env) for a in form[1:]] val = self._eval_call(fn_name, [val] + form[1:], env) 
else: # Simple function call fn_name = form.name if isinstance(form, Symbol) else form val = self._eval_call(fn_name, [args[0]], env) return val # ===================================================================== # Range and iteration # ===================================================================== if op == 'range': if len(args) == 1: end = int(self._eval(args[0], env)) return list(range(end)) elif len(args) == 2: start = int(self._eval(args[0], env)) end = int(self._eval(args[1], env)) return list(range(start, end)) else: start = int(self._eval(args[0], env)) end = int(self._eval(args[1], env)) step = int(self._eval(args[2], env)) return list(range(start, end, step)) if op == 'reduce' or op == 'fold': # (reduce seq init fn) - left fold seq = self._eval(args[0], env) acc = self._eval(args[1], env) fn = args[2] # Lambda S-expression # Handle lambda if isinstance(fn, list) and len(fn) >= 3: head = fn[0] if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): params = fn[1] body = fn[2] for item in seq: fn_env = env.copy() if len(params) >= 1: param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) fn_env[param_name] = acc if len(params) >= 2: param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) fn_env[param_name] = item acc = self._eval(body, fn_env) return acc # Fallback - try evaluating fn and calling it fn_eval = self._eval(fn, env) if callable(fn_eval): for item in seq: acc = fn_eval(acc, item) return acc if op == 'fold-indexed': # (fold-indexed seq init fn) - fold with index # fn takes (acc item index) or (acc item index cursor) for typography seq = self._eval(args[0], env) acc = self._eval(args[1], env) fn = args[2] # Lambda S-expression # Handle lambda if isinstance(fn, list) and len(fn) >= 3: head = fn[0] if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): params = fn[1] body = fn[2] for idx, item in enumerate(seq): fn_env = env.copy() if len(params) >= 1: param_name = 
params[0].name if isinstance(params[0], Symbol) else str(params[0]) fn_env[param_name] = acc if len(params) >= 2: param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) fn_env[param_name] = item if len(params) >= 3: param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) fn_env[param_name] = idx acc = self._eval(body, fn_env) return acc # Fallback fn_eval = self._eval(fn, env) if callable(fn_eval): for idx, item in enumerate(seq): acc = fn_eval(acc, item, idx) return acc # ===================================================================== # Map-pixels (apply function to each pixel) # ===================================================================== if op == 'map-pixels': frame = self._eval(args[0], env) fn = args[1] # Lambda or S-expression h, w = frame.shape[:2] # Extract channels r = frame[:, :, 0].flatten().astype(jnp.float32) g = frame[:, :, 1].flatten().astype(jnp.float32) b = frame[:, :, 2].flatten().astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) # Set up pixel environment pixel_env = env.copy() pixel_env['r'] = r pixel_env['g'] = g pixel_env['b'] = b pixel_env['x'] = x_coords pixel_env['y'] = y_coords # Also provide c (color) as a tuple for lambda (x y c) style pixel_env['c'] = (r, g, b) # If fn is a lambda, we need to handle it specially if isinstance(fn, list) and len(fn) >= 2: head = fn[0] if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): # Lambda: (lambda (x y c) body) params = fn[1] body = fn[2] # Bind parameters if len(params) >= 1: param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) pixel_env[param_name] = x_coords if len(params) >= 2: param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) pixel_env[param_name] = y_coords if len(params) >= 3: param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) 
pixel_env[param_name] = (r, g, b) result = self._eval(body, pixel_env) else: result = self._eval(fn, pixel_env) else: result = self._eval(fn, pixel_env) if isinstance(result, tuple) and len(result) == 3: nr, ng, nb = result return jax_merge_channels(nr, ng, nb, (h, w)) elif hasattr(result, 'shape') and result.ndim == 3: return result else: # Single channel result if hasattr(result, 'flatten'): result = result.flatten() gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8) return jnp.stack([gray, gray, gray], axis=2) # ===================================================================== # State operations (return unchanged for stateless JIT) # ===================================================================== if op == 'state-get': key = self._eval(args[0], env) default = self._eval(args[1], env) if len(args) > 1 else None return default # State not supported in JIT, return default if op == 'state-set': return None # No-op in JIT # ===================================================================== # Cell/grid operations # ===================================================================== if op == 'local-x-norm': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) h, w = frame.shape[:2] x = jnp.tile(jnp.arange(w), h) return (x % cell_size) / max(1, cell_size - 1) if op == 'local-y-norm': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) h, w = frame.shape[:2] y = jnp.repeat(jnp.arange(h), w) return (y % cell_size) / max(1, cell_size - 1) if op == 'local-x': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) h, w = frame.shape[:2] x = jnp.tile(jnp.arange(w), h) return (x % cell_size).astype(jnp.float32) if op == 'local-y': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) h, w = frame.shape[:2] y = jnp.repeat(jnp.arange(h), w) return (y % cell_size).astype(jnp.float32) if op == 'cell-row': frame = self._eval(args[0], env) cell_size = 
int(self._eval(args[1], env)) h, w = frame.shape[:2] y = jnp.repeat(jnp.arange(h), w) return jnp.floor(y / cell_size) if op == 'cell-col': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) h, w = frame.shape[:2] x = jnp.tile(jnp.arange(w), h) return jnp.floor(x / cell_size) if op == 'num-rows': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) return frame.shape[0] // cell_size if op == 'num-cols': frame = self._eval(args[0], env) cell_size = int(self._eval(args[1], env)) return frame.shape[1] // cell_size # ===================================================================== # Control flow # ===================================================================== if op == 'cond': # (cond (test1 expr1) (test2 expr2) ... (else exprN)) # For JAX compatibility, build a nested jnp.where structure # Start from the else clause and work backwards # Collect clauses clauses = [] else_expr = None for clause in args: if isinstance(clause, list) and len(clause) >= 2: test = clause[0] if isinstance(test, Symbol) and test.name == 'else': else_expr = clause[1] else: clauses.append((test, clause[1])) # If no else, default to None/0 if else_expr is not None: result = self._eval(else_expr, env) else: result = 0 # Build nested where from last to first for test_expr, val_expr in reversed(clauses): cond_val = self._eval(test_expr, env) then_val = self._eval(val_expr, env) # Check if condition is array or scalar if hasattr(cond_val, 'shape') and cond_val.shape != (): # Array condition - use jnp.where result = jnp.where(cond_val, then_val, result) else: # Scalar - can use Python if if cond_val: result = then_val return result if op == 'set!' 
or op == 'set': # Mutation - not really supported in JAX, but we can update env var = args[0].name if isinstance(args[0], Symbol) else str(args[0]) val = self._eval(args[1], env) env[var] = val return val if op == 'begin' or op == 'do': # Evaluate all expressions, return last result = None for expr in args: result = self._eval(expr, env) return result # ===================================================================== # Additional math # ===================================================================== if op == 'sq' or op == 'square': x = self._eval(args[0], env) return x * x if op == 'lerp': a = self._eval(args[0], env) b = self._eval(args[1], env) t = self._eval(args[2], env) return a * (1 - t) + b * t if op == 'smoothstep': edge0 = self._eval(args[0], env) edge1 = self._eval(args[1], env) x = self._eval(args[2], env) t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1) return t * t * (3 - 2 * t) if op == 'atan2': y = self._eval(args[0], env) x = self._eval(args[1], env) return jnp.arctan2(y, x) if op == 'fract' or op == 'frac': x = self._eval(args[0], env) return x - jnp.floor(x) # ===================================================================== # Frame copy and construction operations # ===================================================================== if op == 'pixel': # Get pixel at (x, y) from frame frame = self._eval(args[0], env) x = self._eval(args[1], env) y = self._eval(args[2], env) h, w = frame.shape[:2] # Convert to int and clip to bounds if isinstance(x, (int, float)): x = max(0, min(int(x), w - 1)) else: x = jnp.clip(x, 0, w - 1).astype(jnp.int32) if isinstance(y, (int, float)): y = max(0, min(int(y), h - 1)) else: y = jnp.clip(y, 0, h - 1).astype(jnp.int32) r = frame[y, x, 0] g = frame[y, x, 1] b = frame[y, x, 2] return (r, g, b) if op == 'copy': frame = self._eval(args[0], env) return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame) if op == 'make-image': w = int(self._eval(args[0], env)) h = int(self._eval(args[1], env)) 
if len(args) > 2: color = self._eval(args[2], env) if isinstance(color, (list, tuple)): r, g, b = int(color[0]), int(color[1]), int(color[2]) else: r = g = b = int(color) else: r = g = b = 0 img = jnp.zeros((h, w, 3), dtype=jnp.uint8) img = img.at[:, :, 0].set(r) img = img.at[:, :, 1].set(g) img = img.at[:, :, 2].set(b) return img if op == 'paste': dest = self._eval(args[0], env) src = self._eval(args[1], env) x = int(self._eval(args[2], env)) y = int(self._eval(args[3], env)) sh, sw = src.shape[:2] dh, dw = dest.shape[:2] # Clip to dest bounds x1, y1 = max(0, x), max(0, y) x2, y2 = min(dw, x + sw), min(dh, y + sh) sx1, sy1 = x1 - x, y1 - y sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1) result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest) result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :]) return result # ===================================================================== # Blending operations # ===================================================================== if op == 'blending:blend-images' or op == 'blend-images': a = self._eval(args[0], env) b = self._eval(args[1], env) alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 return jax_blend(a, b, alpha) if op == 'blending:blend-mode' or op == 'blend-mode': a = self._eval(args[0], env) b = self._eval(args[1], env) mode = self._eval(args[2], env) if len(args) > 2 else 'add' if mode == 'add': return jax_blend_add(a, b) elif mode == 'multiply': return jax_blend_multiply(a, b) elif mode == 'screen': return jax_blend_screen(a, b) elif mode == 'overlay': return jax_blend_overlay(a, b) elif mode == 'lighten': return jnp.maximum(a, b) elif mode == 'darken': return jnp.minimum(a, b) elif mode == 'difference': return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8) else: return jax_blend(a, b, 0.5) # ===================================================================== # Geometry coordinate operations # 
===================================================================== if op == 'geometry:wave-coords' or op == 'wave-coords': w = int(self._eval(args[0], env)) h = int(self._eval(args[1], env)) axis = self._eval(args[2], env) if len(args) > 2 else 'x' freq = self._eval(args[3], env) if len(args) > 3 else 1.0 amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0 phase = self._eval(args[5], env) if len(args) > 5 else 0.0 y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) if axis == 'x' or axis == 'horizontal': # Wave displaces X based on Y offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) src_x = x_coords + offset src_y = y_coords elif axis == 'y' or axis == 'vertical': # Wave displaces Y based on X offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) src_x = x_coords src_y = y_coords + offset else: # both offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) src_x = x_coords + offset_x src_y = y_coords + offset_y return {'x': src_x, 'y': src_y} if op == 'geometry:coords-x' or op == 'coords-x': coords = self._eval(args[0], env) return coords['x'] if op == 'geometry:coords-y' or op == 'coords-y': coords = self._eval(args[0], env) return coords['y'] if op == 'geometry:remap' or op == 'remap': frame = self._eval(args[0], env) x = self._eval(args[1], env) y = self._eval(args[2], env) h, w = frame.shape[:2] r, g, b = jax_sample(frame, x.flatten(), y.flatten()) return jnp.stack([ jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) ], axis=2) # ===================================================================== # Glitch effects # ===================================================================== if op == 'pixelsort': frame = 
self._eval(args[0], env) sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness' thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50 thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200 angle = int(self._eval(args[4], env)) if len(args) > 4 else 0 reverse = self._eval(args[5], env) if len(args) > 5 else False h, w = frame.shape[:2] result = frame.copy() # Get luminance for thresholding lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 + frame[:, :, 1].astype(jnp.float32) * 0.587 + frame[:, :, 2].astype(jnp.float32) * 0.114) # Sort each row for y in range(h): row_lum = lum[y, :] row = frame[y, :, :] # Find mask of pixels to sort mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi) # Get indices where we should sort sort_indices = jnp.where(mask, jnp.arange(w), -1) # Simple sort by luminance for the row if sort_by == 'lightness': sort_key = row_lum elif sort_by == 'hue': # Approximate hue from RGB sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32), row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32))) else: sort_key = row_lum # Sort pixels in masked region sorted_indices = jnp.argsort(sort_key) if reverse: sorted_indices = sorted_indices[::-1] # Apply partial sort (only where mask is true) # This is a simplified version - full pixelsort is more complex result = result.at[y, :, :].set(row[sorted_indices]) return result if op == 'datamosh': frame = self._eval(args[0], env) prev = self._eval(args[1], env) block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32 corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3 max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50 color_corrupt = self._eval(args[5], env) if len(args) > 5 else True h, w = frame.shape[:2] # Use deterministic random for JIT with frame variation seed = env.get('_seed', 42) op_counter = env.get('_rand_op_counter', 0) 
# --- datamosh: per-block corruption tables and per-pixel lookup ---------
# Reserve this op's slot in the per-trace random-op counter, then derive
# the PRNG key for it.
# NOTE(review): `frame_num` is not assigned anywhere in the visible part
# of this branch - confirm it is bound earlier in the enclosing function.
env['_rand_op_counter'] = op_counter + 1
key = make_jax_key(seed, frame_num, op_counter)

# Guard against a zero/negative block size and against blocks larger than
# the frame: keep at least one block per axis so the random tables below
# are never empty when they are indexed.
block_size = max(1, block_size)
num_blocks_y = max(1, h // block_size)
num_blocks_x = max(1, w // block_size)
total_blocks = num_blocks_y * num_blocks_x

# Pre-generate all random values at once (vectorized). The split count is
# kept at 6 so the sub-keys (and therefore the output) are unchanged.
key, k1, k2, k3, k4, k5 = jax.random.split(key, 6)
corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption
offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1)
offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1)
channels = jax.random.randint(k4, (total_blocks,), 0, 3)
color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51)

# Create pixel coordinate grids
y_coords, x_coords = jnp.mgrid[:h, :w]

# Determine which block each pixel belongs to. Clamp each axis separately:
# pixels in the right/bottom remainder (outside the block grid) belong to
# the nearest edge block. Clamping only the flattened index would let a
# right-edge pixel wrap into the first block of the next block row.
pixel_block_y = jnp.clip(y_coords // block_size, 0, num_blocks_y - 1)
pixel_block_x = jnp.clip(x_coords // block_size, 0, num_blocks_x - 1)
pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x

# Get the corrupt mask, offsets for each pixel's block
pixel_corrupt = corrupt_mask[pixel_block_idx]
pixel_offset_y = offsets_y[pixel_block_idx]
pixel_offset_x = offsets_x[pixel_block_idx]

# Calculate source coordinates with offset (clamped to the frame)
src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1)
src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1)

# Sample from previous frame at offset positions
prev_sampled = prev[src_y, src_x, :]

# Where corrupt mask is true, use prev_sampled; else use frame
result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame)

# Apply color corruption to corrupted blocks
if color_corrupt:
    pixel_channel = channels[pixel_block_idx]
pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16) # Create per-channel shift arrays (only shift the selected channel) shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0) shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0) shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0) result_int = result.astype(jnp.int16) result_int = result_int.at[:, :, 0].add(shift_r) result_int = result_int.at[:, :, 1].add(shift_g) result_int = result_int.at[:, :, 2].add(shift_b) result = jnp.clip(result_int, 0, 255).astype(jnp.uint8) return result # ===================================================================== # ASCII Art Operations (using pre-rendered font atlas) # ===================================================================== if op == 'cell-sample': # (cell-sample frame char_size) -> (colors, luminances) # Downsample frame into cells, return average colors and luminances frame = self._eval(args[0], env) char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8 h, w = frame.shape[:2] num_rows = h // char_size num_cols = w // char_size # Crop to exact multiple of char_size cropped = frame[:num_rows * char_size, :num_cols * char_size, :] # Reshape to (num_rows, char_size, num_cols, char_size, 3) reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) # Average over char_size dimensions -> (num_rows, num_cols, 3) colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) # Compute luminance per cell colors_float = colors.astype(jnp.float32) luminances = (0.299 * colors_float[:, :, 0] + 0.587 * colors_float[:, :, 1] + 0.114 * colors_float[:, :, 2]) / 255.0 return (colors, luminances) if op == 'luminance-to-chars': # (luminance-to-chars luminances alphabet contrast) -> char_indices # Map luminance values to character indices luminances = self._eval(args[0], env) alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard' contrast = float(self._eval(args[2], env)) if 
len(args) > 2 else 1.5 # Get alphabet string alpha_str = _get_alphabet_string(alphabet) num_chars = len(alpha_str) # Apply contrast lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) # Map to character indices (0 = darkest, num_chars-1 = brightest) char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) char_indices = jnp.clip(char_indices, 0, num_chars - 1) return char_indices if op == 'render-char-grid': # (render-char-grid frame chars colors char_size color_mode background_color invert_colors) # Render character grid using font atlas frame = self._eval(args[0], env) char_indices = self._eval(args[1], env) colors = self._eval(args[2], env) char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8 color_mode = self._eval(args[4], env) if len(args) > 4 else 'color' background_color = self._eval(args[5], env) if len(args) > 5 else 'black' invert_colors = self._eval(args[6], env) if len(args) > 6 else 0 h, w = frame.shape[:2] num_rows, num_cols = char_indices.shape # Get the alphabet used (stored in env or default) alphabet = env.get('_ascii_alphabet', 'standard') alpha_str = _get_alphabet_string(alphabet) # Get or create font atlas font_atlas = _create_font_atlas(alpha_str, char_size) # Parse background color if background_color == 'black': bg = jnp.array([0, 0, 0], dtype=jnp.uint8) elif background_color == 'white': bg = jnp.array([255, 255, 255], dtype=jnp.uint8) else: # Try to parse hex color try: if background_color.startswith('#'): bg_hex = background_color[1:] bg = jnp.array([int(bg_hex[0:2], 16), int(bg_hex[2:4], 16), int(bg_hex[4:6], 16)], dtype=jnp.uint8) else: bg = jnp.array([0, 0, 0], dtype=jnp.uint8) except: bg = jnp.array([0, 0, 0], dtype=jnp.uint8) # Create output image starting with background output_h = num_rows * char_size output_w = num_cols * char_size result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy() # Gather characters from atlas based on indices # char_indices shape: (num_rows, num_cols) # font_atlas 
shape: (num_chars, char_size, char_size, 3) # Convert numpy atlas to JAX for indexing with traced indices font_atlas_jax = jnp.asarray(font_atlas) flat_indices = char_indices.flatten() char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3) # Reshape to grid char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3) # Create coordinate grids for output pixels y_out, x_out = jnp.mgrid[:output_h, :output_w] cell_row = y_out // char_size cell_col = x_out // char_size local_y = y_out % char_size local_x = x_out % char_size # Clamp to valid ranges cell_row = jnp.clip(cell_row, 0, num_rows - 1) cell_col = jnp.clip(cell_col, 0, num_cols - 1) # Get character pixel values char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] # Get character brightness (for masking) char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0 # Handle color modes if color_mode == 'mono': # White characters on background fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8) fg = jnp.broadcast_to(fg_color, char_pixels.shape) elif color_mode == 'invert': # Inverted cell colors cell_colors = colors[cell_row, cell_col] fg = 255 - cell_colors else: # 'color' mode - use cell colors fg = colors[cell_row, cell_col] # Blend foreground onto background based on character brightness if invert_colors: # Swap fg and bg fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg result = (fg.astype(jnp.float32) * (1 - char_brightness) + bg_broadcast.astype(jnp.float32) * char_brightness) else: bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) + fg.astype(jnp.float32) * char_brightness) result = jnp.clip(result, 0, 255).astype(jnp.uint8) # Resize back to original frame size if needed if result.shape[0] != h or result.shape[1] != w: # Simple nearest-neighbor resize y_scale = result.shape[0] / h x_scale = result.shape[1] / w y_src = (jnp.arange(h) * 
y_scale).astype(jnp.int32) x_src = (jnp.arange(w) * x_scale).astype(jnp.int32) y_src = jnp.clip(y_src, 0, result.shape[0] - 1) x_src = jnp.clip(x_src, 0, result.shape[1] - 1) result = result[y_src[:, None], x_src[None, :], :] return result if op == 'ascii-fx-zone': # Complex ASCII effect with per-zone expressions # (ascii-fx-zone frame :cols cols :alphabet alphabet ...) frame = self._eval(args[0], env) # Parse keyword arguments kwargs = {} i = 1 while i < len(args): if isinstance(args[i], Keyword): key = args[i].name if i + 1 < len(args): kwargs[key] = args[i + 1] i += 2 else: i += 1 # Get parameters cols = int(self._eval(kwargs.get('cols', 80), env)) char_size_param = kwargs.get('char_size') alphabet = self._eval(kwargs.get('alphabet', 'standard'), env) color_mode = self._eval(kwargs.get('color_mode', 'color'), env) background = self._eval(kwargs.get('background', 'black'), env) contrast = float(self._eval(kwargs.get('contrast', 1.5), env)) h, w = frame.shape[:2] # Calculate char_size from cols if not specified if char_size_param is not None: char_size_val = self._eval(char_size_param, env) if char_size_val is not None: char_size = int(char_size_val) else: char_size = w // cols else: char_size = w // cols char_size = max(4, min(char_size, 64)) # Store alphabet for render-char-grid to use env['_ascii_alphabet'] = alphabet # Cell sampling num_rows = h // char_size num_cols = w // char_size cropped = frame[:num_rows * char_size, :num_cols * char_size, :] reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) # Compute luminances colors_float = colors.astype(jnp.float32) luminances = (0.299 * colors_float[:, :, 0] + 0.587 * colors_float[:, :, 1] + 0.114 * colors_float[:, :, 2]) / 255.0 # Get alphabet and map luminances to chars alpha_str = _get_alphabet_string(alphabet) num_chars = len(alpha_str) lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) char_indices = (lum_adjusted * 
(num_chars - 1)).astype(jnp.int32) char_indices = jnp.clip(char_indices, 0, num_chars - 1) # Handle optional per-zone effects (char_hue, char_saturation, etc.) # These would modify colors based on zone position char_hue = kwargs.get('char_hue') char_saturation = kwargs.get('char_saturation') char_brightness = kwargs.get('char_brightness') if char_hue is not None or char_saturation is not None or char_brightness is not None: # Create zone coordinate arrays for expression evaluation row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] row_norm = row_coords / max(num_rows - 1, 1) col_norm = col_coords / max(num_cols - 1, 1) # Bind zone variables zone_env = env.copy() zone_env['zone-row'] = row_coords zone_env['zone-col'] = col_coords zone_env['zone-row-norm'] = row_norm zone_env['zone-col-norm'] = col_norm zone_env['zone-lum'] = luminances # Apply color modifications (simplified - full version would use HSV) if char_brightness is not None: brightness_mult = self._eval(char_brightness, zone_env) if brightness_mult is not None: colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None], 0, 255).astype(jnp.uint8) # Render using font atlas font_atlas = _create_font_atlas(alpha_str, char_size) # Parse background color if background == 'black': bg = jnp.array([0, 0, 0], dtype=jnp.uint8) elif background == 'white': bg = jnp.array([255, 255, 255], dtype=jnp.uint8) else: bg = jnp.array([0, 0, 0], dtype=jnp.uint8) # Gather characters - convert numpy atlas to JAX for traced indexing font_atlas_jax = jnp.asarray(font_atlas) flat_indices = char_indices.flatten() char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) # Create output output_h = num_rows * char_size output_w = num_cols * char_size y_out, x_out = jnp.mgrid[:output_h, :output_w] cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) local_y = y_out % char_size local_x = x_out % char_size char_pixels = 
char_tiles[cell_row, cell_col, local_y, local_x] char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 if color_mode == 'mono': fg = jnp.full_like(char_pixels, 255) else: fg = colors[cell_row, cell_col] bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + fg.astype(jnp.float32) * char_bright) result = jnp.clip(result, 0, 255).astype(jnp.uint8) # Resize to original dimensions if result.shape[0] != h or result.shape[1] != w: y_scale = result.shape[0] / h x_scale = result.shape[1] / w y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) result = result[y_src[:, None], x_src[None, :], :] return result if op == 'render-char-grid-fx': # Enhanced render with per-character effects # (render-char-grid-fx frame chars colors luminances char_size # color_mode bg_color invert_colors # char_jitter char_scale char_rotation char_hue_shift # jitter_source scale_source rotation_source hue_source) frame = self._eval(args[0], env) char_indices = self._eval(args[1], env) colors = self._eval(args[2], env) luminances = self._eval(args[3], env) char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8 color_mode = self._eval(args[5], env) if len(args) > 5 else 'color' background_color = self._eval(args[6], env) if len(args) > 6 else 'black' invert_colors = self._eval(args[7], env) if len(args) > 7 else 0 # Per-char effect amounts char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0 char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0 char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0 char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0 # Modulation sources jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none' scale_source = self._eval(args[13], env) if len(args) > 13 else 'none' 
rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none'
hue_source = self._eval(args[15], env) if len(args) > 15 else 'none'

h, w = frame.shape[:2]
num_rows, num_cols = char_indices.shape

# Get alphabet (falls back to 'standard' when no '_ascii_alphabet' entry
# is present in env) and the matching font atlas.
alphabet = env.get('_ascii_alphabet', 'standard')
alpha_str = _get_alphabet_string(alphabet)
font_atlas = _create_font_atlas(alpha_str, char_size)

# Parse background: 'black', 'white' or '#rrggbb' (the same forms that
# render-char-grid accepts). Anything else, including malformed hex or a
# non-string value, falls back to black.
bg_rgb = (0, 0, 0)
if isinstance(background_color, str):
    if background_color == 'white':
        bg_rgb = (255, 255, 255)
    elif background_color.startswith('#'):
        bg_hex = background_color[1:]
        try:
            parsed = (int(bg_hex[0:2], 16),
                      int(bg_hex[2:4], 16),
                      int(bg_hex[4:6], 16))
        except ValueError:
            parsed = None
        # int(..., 16) accepts a sign, so reject out-of-range components
        # instead of letting them overflow uint8.
        if parsed is not None and all(0 <= c <= 255 for c in parsed):
            bg_rgb = parsed
bg = jnp.array(bg_rgb, dtype=jnp.uint8)

# Create modulation values based on source
def get_modulation(source, lums, num_rows, num_cols):
    """Return the per-cell modulation grid selected by `source`.

    Args:
        source: One of 'luminance', 'inv_luminance', 'position_x',
            'position_y', 'position_diag', 'center_dist', 'random';
            any other value yields an all-zero grid.
        lums: Per-cell luminance values, returned as-is for 'luminance'
            and as `1.0 - lums` for 'inv_luminance'.
        num_rows: Number of character rows in the grid.
        num_cols: Number of character columns in the grid.

    Returns:
        `lums` (or its inverse) for the luminance sources, otherwise an
        array of shape (num_rows, num_cols).

    The 'random' source reads `_seed` and `_rand_op_counter` from the
    enclosing `env` and increments the counter, so each call draws from
    a different key.
    """
    row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols]
    # max(..., 1) avoids dividing by zero for a single-row/column grid.
    row_norm = row_coords / max(num_rows - 1, 1)
    col_norm = col_coords / max(num_cols - 1, 1)

    if source == 'luminance':
        return lums
    elif source == 'inv_luminance':
        return 1.0 - lums
    elif source == 'position_x':
        return col_norm
    elif source == 'position_y':
        return row_norm
    elif source == 'position_diag':
        return (row_norm + col_norm) / 2
    elif source == 'center_dist':
        cy, cx = 0.5, 0.5
        dist = jnp.sqrt((row_norm - cy)**2 + (col_norm - cx)**2)
        return jnp.clip(dist / 0.707, 0, 1)  # Normalize by max diagonal
    elif source == 'random':
        # Use frame-varying key for random source
        # NOTE(review): `frame_num` comes from the enclosing scope and is
        # not assigned in the visible code - confirm it is bound there.
        seed = env.get('_seed', 42)
        op_ctr = env.get('_rand_op_counter', 0)
        env['_rand_op_counter'] = op_ctr + 1
        key = make_jax_key(seed, frame_num, op_ctr)
        return jax.random.uniform(key, (num_rows, num_cols))
    else:
        return jnp.zeros((num_rows, num_cols))

# Get modulation values. All four are evaluated in this fixed order
# because the 'random' source advances the shared op counter on each call.
jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols)
scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols)
rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols)
hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols)

# Gather characters - convert numpy atlas to JAX for traced indexing
font_atlas_jax = jnp.asarray(font_atlas)
flat_indices = char_indices.flatten() char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) # Create output output_h = num_rows * char_size output_w = num_cols * char_size y_out, x_out = jnp.mgrid[:output_h, :output_w] cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) local_y = y_out % char_size local_x = x_out % char_size # Apply jitter if enabled if char_jitter > 0: jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter # Use frame-varying key for jitter seed = env.get('_seed', 42) op_ctr = env.get('_rand_op_counter', 0) env['_rand_op_counter'] = op_ctr + 1 key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2) # Generate deterministic jitter per cell jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1) jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1) offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32) offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32) local_y = jnp.clip(local_y + offset_y, 0, char_size - 1) local_x = jnp.clip(local_x + offset_x, 0, char_size - 1) char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 # Get foreground colors if color_mode == 'mono': fg = jnp.full_like(char_pixels, 255) else: fg = colors[cell_row, cell_col] # Apply hue shift if enabled if char_hue_shift > 0 and color_mode == 'color': hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift # Simple hue rotation via channel cycling fg_float = fg.astype(jnp.float32) shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB # Simplified: blend channels based on shift r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2] shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac # Just do a simple tint for now fg = jnp.clip(fg_float + 
hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8) # Blend bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) if invert_colors: result = (fg.astype(jnp.float32) * (1 - char_bright) + bg_broadcast.astype(jnp.float32) * char_bright) else: result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + fg.astype(jnp.float32) * char_bright) result = jnp.clip(result, 0, 255).astype(jnp.uint8) # Resize to original if result.shape[0] != h or result.shape[1] != w: y_scale = result.shape[0] / h x_scale = result.shape[1] / w y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) result = result[y_src[:, None], x_src[None, :], :] return result if op == 'alphabet-char': # (alphabet-char alphabet-name index) -> char_index in that alphabet alphabet = self._eval(args[0], env) index = self._eval(args[1], env) alpha_str = _get_alphabet_string(alphabet) num_chars = len(alpha_str) # Handle both scalar and array indices if hasattr(index, 'shape'): index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1) else: index = max(0, min(int(index), num_chars - 1)) return index if op == 'map-char-grid': # (map-char-grid base-chars luminances (lambda (r c ch lum) ...)) # Map over character grid, allowing per-cell character selection base_chars = self._eval(args[0], env) luminances = self._eval(args[1], env) fn = args[2] # Lambda expression num_rows, num_cols = base_chars.shape # For JAX compatibility, we can't use Python loops with traced values # Instead, we'll evaluate the lambda for the whole grid at once if isinstance(fn, list) and len(fn) >= 3: head = fn[0] if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): params = fn[1] body = fn[2] # Create grid coordinates row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] # Bind parameters for whole-grid evaluation fn_env = env.copy() # Params: (r c ch lum) if len(params) >= 1: fn_env[params[0].name if 
isinstance(params[0], Symbol) else params[0]] = row_coords if len(params) >= 2: fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords if len(params) >= 3: fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars if len(params) >= 4: # Luminances scaled to 0-255 range fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32) # Evaluate body - should return new character indices result = self._eval(body, fn_env) if hasattr(result, 'shape'): return result.astype(jnp.int32) return base_chars return base_chars # ===================================================================== # List operations # ===================================================================== if op == 'take': seq = self._eval(args[0], env) n = int(self._eval(args[1], env)) if isinstance(seq, (list, tuple)): return seq[:n] return seq[:n] # Works for arrays too if op == 'cons': item = self._eval(args[0], env) seq = self._eval(args[1], env) if isinstance(seq, list): return [item] + seq elif isinstance(seq, tuple): return (item,) + seq return jnp.concatenate([jnp.array([item]), seq]) if op == 'roll': arr = self._eval(args[0], env) shift = self._eval(args[1], env) axis = self._eval(args[2], env) if len(args) > 2 else 0 # Convert to int for concrete values, keep as-is for JAX traced values if isinstance(shift, (int, float)): shift = int(shift) elif hasattr(shift, 'astype'): shift = shift.astype(jnp.int32) if isinstance(axis, (int, float)): axis = int(axis) return jnp.roll(arr, shift, axis=axis) # ===================================================================== # Pi constant # ===================================================================== if op == 'pi': return jnp.pi raise ValueError(f"Unknown operation: {op}") # ============================================================================= # Public API # ============================================================================= 
def compile_effect(code: str) -> Callable:
    """
    Compile an S-expression effect to a JAX function.

    Compiled effects are memoized in the module-level ``_COMPILED_EFFECTS``
    dict, keyed by the MD5 hex digest of ``code``, so compiling the exact
    same source text twice returns the same callable.

    Args:
        code: S-expression effect code

    Returns:
        JIT-compiled function: (frame, **params) -> frame
    """
    # The digest is only a cache key for the source text, not a security measure.
    cache_key = hashlib.md5(code.encode()).hexdigest()
    if cache_key in _COMPILED_EFFECTS:
        return _COMPILED_EFFECTS[cache_key]

    # Parse and compile
    sexp = parse(code)
    compiler = JaxCompiler()
    fn = compiler.compile_effect(sexp)

    _COMPILED_EFFECTS[cache_key] = fn
    return fn


def compile_effect_file(path: str, derived_paths: Optional[List[str]] = None) -> Callable:
    """
    Compile an effect from a .sexp file.

    Top-level forms are scanned in order:

    * ``(require "name")`` loads ``name.sexp`` (the suffix is appended when
      missing) via ``compiler.load_derived``. The file is looked up next to
      the effect file first, then in the ``sexp_effects`` directory beside
      this module's parent package. If it is found in neither location a
      ``RuntimeWarning`` is emitted and compilation continues.
    * ``(require-primitives ...)`` is ignored.
    * ``(effect ...)`` / ``(define-effect ...)`` is the effect to compile; when
      several are present the last one wins.

    Unlike ``compile_effect``, the result is not stored in the compilation
    cache.

    Args:
        path: Path to the .sexp effect file
        derived_paths: Optional list of paths to derived.sexp files to load
            before the file's own ``require`` forms are processed

    Returns:
        JIT-compiled function: (frame, **params) -> frame

    Raises:
        ValueError: If the file contains no effect definition.
    """
    import warnings

    # Effect sources may contain non-ASCII symbols (e.g. 'λ'), so decode as
    # UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        code = f.read()

    # Parse all expressions in file
    exprs = parse_all(code)

    # Create compiler
    compiler = JaxCompiler()

    # Load derived files if specified
    if derived_paths:
        for dp in derived_paths:
            compiler.load_derived(dp)

    # Process expressions - find require statements and the effect
    effect_sexp = None
    effect_dir = Path(path).parent

    for expr in exprs:
        if not isinstance(expr, list) or len(expr) < 2:
            continue

        head = expr[0]
        if not isinstance(head, Symbol):
            continue

        if head.name == 'require':
            # (require "derived") or (require "path/to/file")
            req_path = expr[1]
            if isinstance(req_path, str):
                # Resolve relative to effect file
                if not req_path.endswith('.sexp'):
                    req_path = req_path + '.sexp'
                full_path = effect_dir / req_path
                if not full_path.exists():
                    # Try sexp_effects directory
                    full_path = Path(__file__).parent.parent / 'sexp_effects' / req_path
                if full_path.exists():
                    compiler.load_derived(str(full_path))
                else:
                    # Stay best-effort (the effect may not need the file), but
                    # make the miss visible instead of dropping it silently.
                    warnings.warn(
                        f"Required file {req_path!r} not found for effect {path}",
                        RuntimeWarning,
                        stacklevel=2,
                    )

        elif head.name == 'require-primitives':
            # (require-primitives "lib") - currently ignored for JAX
            # JAX has all primitives built-in
            pass

        elif head.name in ('effect', 'define-effect'):
            effect_sexp = expr

    if effect_sexp is None:
        raise ValueError(f"No effect definition found in {path}")

    return compiler.compile_effect(effect_sexp)


def load_derived(derived_path: Optional[str] = None) -> Dict[str, Callable]:
    """
    Load derived operations from derived.sexp.

    Every top-level ``(define ...)`` form with at least two arguments is
    handed to ``JaxCompiler._eval_define`` together with a fresh environment
    dict; all other forms are skipped.

    Args:
        derived_path: Path to the file to load. Defaults to
            ``sexp_effects/derived.sexp`` beside this module's parent package.

    Returns:
        The environment dict that was passed to ``_eval_define`` (dict of
        compiled functions that can be used in effects).
    """
    if derived_path is None:
        derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp'

    # Decode explicitly as UTF-8; sexp sources may contain non-ASCII symbols.
    with open(derived_path, 'r', encoding='utf-8') as f:
        code = f.read()

    exprs = parse_all(code)

    compiler = JaxCompiler()
    env = {}

    for expr in exprs:
        if isinstance(expr, list) and len(expr) >= 3:
            head = expr[0]
            if isinstance(head, Symbol) and head.name == 'define':
                compiler._eval_define(expr[1:], env)

    return env


# =============================================================================
# Test / Demo
# =============================================================================

if __name__ == '__main__':
    import time

    import numpy as np

    def _wait(value):
        """Block until ``value`` is computed, if it exposes ``block_until_ready``."""
        # JAX dispatches work asynchronously; without this the timers below
        # could measure only how long it takes to enqueue the computation.
        ready = getattr(value, 'block_until_ready', None)
        return ready() if ready is not None else value

    # Test effect
    test_effect = '''
    (effect "threshold-test"
      :params ((threshold :default 128))
      :body (let* ((r (channel frame 0))
                   (g (channel frame 1))
                   (b (channel frame 2))
                   (gray (+ (* r 0.299) (* g 0.587) (* b 0.114)))
                   (mask (where (> gray threshold) 255 0)))
              (merge-channels mask mask mask)))
    '''

    print("Compiling effect...")
    run_effect = compile_effect(test_effect)

    # Create test frame
    print("Creating test frame...")
    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    # Run effect
    print("Running effect (first run includes JIT compilation)...")
    t0 = time.perf_counter()
    result = _wait(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms")

    # Second run should be faster
    t0 = time.perf_counter()
    result = _wait(run_effect(frame, threshold=128))
    t1 = time.perf_counter()
    print(f"Second run (cached): {(t1-t0)*1000:.2f}ms")

    # Multiple runs
    num_runs = 100
    t0 = time.perf_counter()
    for _ in range(num_runs):
        result = run_effect(frame, threshold=128)
    # Wait once after the loop so all queued work is included in the total.
    result = _wait(result)
    t1 = time.perf_counter()
    total_ms = (t1 - t0) * 1000
    print(f"{num_runs} runs: {total_ms:.2f}ms total, {total_ms / num_runs:.2f}ms avg")

    print(f"\nResult shape: {result.shape}, dtype: {result.dtype}")
    print("Done!")