""" S-Expression to WGSL Compiler Compiles sexp effect definitions to WGSL compute shaders for GPU execution. The compilation happens at effect upload time (AOT), not at runtime. Architecture: - Parse sexp AST - Analyze primitives used - Generate WGSL compute shader Shader Categories: 1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel) 2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample) 3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors) """ from typing import Dict, List, Any, Optional, Tuple, Set from dataclasses import dataclass, field from pathlib import Path import math from .parser import parse, parse_all, Symbol, Keyword @dataclass class WGSLParam: """A shader parameter (uniform).""" name: str wgsl_type: str # f32, i32, u32, vec2f, etc. default: Any @dataclass class CompiledEffect: """Result of compiling an sexp effect to WGSL.""" name: str wgsl_code: str params: List[WGSLParam] workgroup_size: Tuple[int, int, int] = (16, 16, 1) # Metadata for runtime uses_time: bool = False uses_sampling: bool = False # Needs texture sampler category: str = "per_pixel" # per_pixel, geometric, neighborhood @dataclass class CompilerContext: """Context during compilation.""" effect_name: str = "" params: Dict[str, WGSLParam] = field(default_factory=dict) locals: Dict[str, str] = field(default_factory=dict) # local var -> wgsl expr required_libs: Set[str] = field(default_factory=set) uses_time: bool = False uses_sampling: bool = False temp_counter: int = 0 def fresh_temp(self) -> str: """Generate a fresh temporary variable name.""" self.temp_counter += 1 return f"_t{self.temp_counter}" class SexpToWGSLCompiler: """ Compiles S-expression effect definitions to WGSL compute shaders. """ # Map sexp types to WGSL types TYPE_MAP = { 'int': 'i32', 'float': 'f32', 'bool': 'u32', # WGSL doesn't have bool in storage 'string': None, # Strings handled specially } # Per-pixel primitives that can be compiled directly PER_PIXEL_PRIMITIVES = { 'color_ops:invert-img', 'color_ops:grayscale', 'color_ops:sepia', 'color_ops:adjust', 'color_ops:adjust-brightness', 'color_ops:shift-hsv', 'color_ops:quantize', } # Geometric primitives (coordinate remapping) GEOMETRIC_PRIMITIVES = { 'geometry:scale-img', 'geometry:rotate-img', 'geometry:translate', 'geometry:flip-h', 'geometry:flip-v', 'geometry:remap', } def __init__(self): self.ctx: Optional[CompilerContext] = None def compile_file(self, path: str) -> CompiledEffect: """Compile an effect from a .sexp file.""" with open(path, 'r') as f: content = f.read() exprs = parse_all(content) return self.compile(exprs) def compile_string(self, sexp_code: str) -> CompiledEffect: """Compile an effect from an sexp string.""" exprs = parse_all(sexp_code) return self.compile(exprs) def compile(self, expr: Any) -> CompiledEffect: """Compile a parsed sexp expression.""" self.ctx = CompilerContext() # Handle multiple top-level expressions (require-primitives, define-effect) if isinstance(expr, list) and expr and isinstance(expr[0], list): for e in expr: self._process_toplevel(e) else: self._process_toplevel(expr) # Generate the WGSL shader wgsl = self._generate_wgsl() # Determine category based on primitives used category = self._determine_category() return CompiledEffect( name=self.ctx.effect_name, wgsl_code=wgsl, params=list(self.ctx.params.values()), uses_time=self.ctx.uses_time, uses_sampling=self.ctx.uses_sampling, category=category, ) def _process_toplevel(self, expr: Any): """Process a top-level expression.""" if not isinstance(expr, list) or not expr: return head = expr[0] if isinstance(head, Symbol): if head.name == 'require-primitives': # Track required primitive libraries for lib in expr[1:]: lib_name = lib.name if isinstance(lib, Symbol) else str(lib) self.ctx.required_libs.add(lib_name) elif head.name == 'define-effect': self._compile_effect_def(expr) def _compile_effect_def(self, expr: list): """Compile a define-effect form.""" # (define-effect name :params (...) body) self.ctx.effect_name = expr[1].name if isinstance(expr[1], Symbol) else str(expr[1]) # Parse :params and body i = 2 body = None while i < len(expr): item = expr[i] if isinstance(item, Keyword) and item.name == 'params': self._parse_params(expr[i + 1]) i += 2 elif isinstance(item, Keyword): i += 2 # Skip other keywords else: body = item i += 1 if body: self.ctx.body_expr = body def _parse_params(self, params_list: list): """Parse the :params block.""" for param_def in params_list: if not isinstance(param_def, list): continue name = param_def[0].name if isinstance(param_def[0], Symbol) else str(param_def[0]) # Parse keyword args param_type = 'float' default = 0 i = 1 while i < len(param_def): item = param_def[i] if isinstance(item, Keyword): if i + 1 < len(param_def): val = param_def[i + 1] if item.name == 'type': param_type = val.name if isinstance(val, Symbol) else str(val) elif item.name == 'default': default = val i += 2 else: i += 1 wgsl_type = self.TYPE_MAP.get(param_type, 'f32') if wgsl_type: self.ctx.params[name] = WGSLParam(name, wgsl_type, default) def _determine_category(self) -> str: """Determine shader category based on primitives used.""" for lib in self.ctx.required_libs: if lib == 'geometry': return 'geometric' if lib == 'filters': return 'neighborhood' return 'per_pixel' def _generate_wgsl(self) -> str: """Generate the complete WGSL shader code.""" lines = [] # Header comment lines.append(f"// WGSL Shader: {self.ctx.effect_name}") lines.append(f"// Auto-generated from sexp effect definition") lines.append("") # Bindings lines.append("@group(0) @binding(0) var input: array;") lines.append("@group(0) @binding(1) var output: array;") lines.append("") # Params struct if self.ctx.params: lines.append("struct Params {") lines.append(" width: u32,") lines.append(" height: u32,") lines.append(" time: f32,") for param in self.ctx.params.values(): lines.append(f" {param.name}: {param.wgsl_type},") lines.append("}") lines.append("@group(0) @binding(2) var params: Params;") else: lines.append("struct Params {") lines.append(" width: u32,") lines.append(" height: u32,") lines.append(" time: f32,") lines.append("}") lines.append("@group(0) @binding(2) var params: Params;") lines.append("") # Helper functions lines.extend(self._generate_helpers()) lines.append("") # Main compute shader lines.append("@compute @workgroup_size(16, 16, 1)") lines.append("fn main(@builtin(global_invocation_id) gid: vec3) {") lines.append(" let x = gid.x;") lines.append(" let y = gid.y;") lines.append(" if (x >= params.width || y >= params.height) { return; }") lines.append(" let idx = y * params.width + x;") lines.append("") # Compile the effect body body_code = self._compile_expr(self.ctx.body_expr) lines.append(f" // Effect: {self.ctx.effect_name}") lines.append(body_code) lines.append("}") return "\n".join(lines) def _generate_helpers(self) -> List[str]: """Generate WGSL helper functions.""" helpers = [] # Pack/unpack RGB from u32 helpers.append("fn unpack_rgb(packed: u32) -> vec3 {") helpers.append(" let r = f32((packed >> 16u) & 0xFFu) / 255.0;") helpers.append(" let g = f32((packed >> 8u) & 0xFFu) / 255.0;") helpers.append(" let b = f32(packed & 0xFFu) / 255.0;") helpers.append(" return vec3(r, g, b);") helpers.append("}") helpers.append("") helpers.append("fn pack_rgb(rgb: vec3) -> u32 {") helpers.append(" let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0);") helpers.append(" let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0);") helpers.append(" let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0);") helpers.append(" return (r << 16u) | (g << 8u) | b;") helpers.append("}") helpers.append("") # Bilinear sampling for geometric transforms if self.ctx.uses_sampling or 'geometry' in self.ctx.required_libs: helpers.append("fn sample_bilinear(sx: f32, sy: f32) -> vec3 {") helpers.append(" let w = f32(params.width);") helpers.append(" let h = f32(params.height);") helpers.append(" let cx = clamp(sx, 0.0, w - 1.001);") helpers.append(" let cy = clamp(sy, 0.0, h - 1.001);") helpers.append(" let x0 = u32(cx);") helpers.append(" let y0 = u32(cy);") helpers.append(" let x1 = min(x0 + 1u, params.width - 1u);") helpers.append(" let y1 = min(y0 + 1u, params.height - 1u);") helpers.append(" let fx = cx - f32(x0);") helpers.append(" let fy = cy - f32(y0);") helpers.append(" let c00 = unpack_rgb(input[y0 * params.width + x0]);") helpers.append(" let c10 = unpack_rgb(input[y0 * params.width + x1]);") helpers.append(" let c01 = unpack_rgb(input[y1 * params.width + x0]);") helpers.append(" let c11 = unpack_rgb(input[y1 * params.width + x1]);") helpers.append(" let top = mix(c00, c10, fx);") helpers.append(" let bot = mix(c01, c11, fx);") helpers.append(" return mix(top, bot, fy);") helpers.append("}") helpers.append("") # HSV conversion for color effects if 'color_ops' in self.ctx.required_libs or 'color' in self.ctx.required_libs: helpers.append("fn rgb_to_hsv(rgb: vec3) -> vec3 {") helpers.append(" let mx = max(max(rgb.r, rgb.g), rgb.b);") helpers.append(" let mn = min(min(rgb.r, rgb.g), rgb.b);") helpers.append(" let d = mx - mn;") helpers.append(" var h = 0.0;") helpers.append(" if (d > 0.0) {") helpers.append(" if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }") helpers.append(" else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }") helpers.append(" else { h = 4.0 + (rgb.r - rgb.g) / d; }") helpers.append(" h = h / 6.0;") helpers.append(" if (h < 0.0) { h = h + 1.0; }") helpers.append(" }") helpers.append(" let s = select(0.0, d / mx, mx > 0.0);") helpers.append(" return vec3(h, s, mx);") helpers.append("}") helpers.append("") helpers.append("fn hsv_to_rgb(hsv: vec3) -> vec3 {") helpers.append(" let h = hsv.x * 6.0;") helpers.append(" let s = hsv.y;") helpers.append(" let v = hsv.z;") helpers.append(" let c = v * s;") helpers.append(" let x = c * (1.0 - abs(h % 2.0 - 1.0));") helpers.append(" let m = v - c;") helpers.append(" var rgb: vec3;") helpers.append(" if (h < 1.0) { rgb = vec3(c, x, 0.0); }") helpers.append(" else if (h < 2.0) { rgb = vec3(x, c, 0.0); }") helpers.append(" else if (h < 3.0) { rgb = vec3(0.0, c, x); }") helpers.append(" else if (h < 4.0) { rgb = vec3(0.0, x, c); }") helpers.append(" else if (h < 5.0) { rgb = vec3(x, 0.0, c); }") helpers.append(" else { rgb = vec3(c, 0.0, x); }") helpers.append(" return rgb + vec3(m, m, m);") helpers.append("}") helpers.append("") return helpers def _compile_expr(self, expr: Any, indent: int = 4) -> str: """Compile an sexp expression to WGSL code.""" ind = " " * indent # Literals if isinstance(expr, (int, float)): return f"{ind}// literal: {expr}" if isinstance(expr, str): return f'{ind}// string: "{expr}"' # Symbol reference if isinstance(expr, Symbol): name = expr.name if name == 'frame': return f"{ind}let rgb = unpack_rgb(input[idx]);" if name == 't' or name == '_time': self.ctx.uses_time = True return f"{ind}let t = params.time;" if name in self.ctx.params: return f"{ind}let {name} = params.{name};" if name in self.ctx.locals: return f"{ind}// local: {name}" return f"{ind}// unknown symbol: {name}" # List (function call or special form) if isinstance(expr, list) and expr: head = expr[0] if isinstance(head, Symbol): form = head.name # Special forms if form == 'let' or form == 'let*': return self._compile_let(expr, indent) if form == 'if': return self._compile_if(expr, indent) if form == 'or': # (or a b) - return a if truthy, else b return self._compile_or(expr, indent) # Primitive calls if ':' in form: return self._compile_primitive_call(expr, indent) # Arithmetic if form in ('+', '-', '*', '/'): return self._compile_arithmetic(expr, indent) if form in ('>', '<', '>=', '<=', '='): return self._compile_comparison(expr, indent) if form == 'max': return self._compile_builtin('max', expr[1:], indent) if form == 'min': return self._compile_builtin('min', expr[1:], indent) return f"{ind}// unhandled: {expr}" def _compile_let(self, expr: list, indent: int) -> str: """Compile let/let* binding form.""" ind = " " * indent lines = [] bindings = expr[1] body = expr[2] # Parse bindings (Clojure style: [x 1 y 2] or Scheme style: ((x 1) (y 2))) pairs = [] if bindings and isinstance(bindings[0], Symbol): # Clojure style i = 0 while i < len(bindings) - 1: name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) value = bindings[i + 1] pairs.append((name, value)) i += 2 else: # Scheme style for binding in bindings: name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) value = binding[1] pairs.append((name, value)) # Compile bindings for name, value in pairs: val_code = self._expr_to_wgsl(value) lines.append(f"{ind}let {name} = {val_code};") self.ctx.locals[name] = val_code # Compile body body_lines = self._compile_body(body, indent) lines.append(body_lines) return "\n".join(lines) def _compile_body(self, body: Any, indent: int) -> str: """Compile the body of an effect (the final image expression).""" ind = " " * indent # Most effects end with a primitive call that produces the output if isinstance(body, list) and body: head = body[0] if isinstance(head, Symbol) and ':' in head.name: return self._compile_primitive_call(body, indent) # If body is just 'frame', pass through if isinstance(body, Symbol) and body.name == 'frame': return f"{ind}output[idx] = input[idx];" return f"{ind}// body: {body}" def _compile_primitive_call(self, expr: list, indent: int) -> str: """Compile a primitive function call.""" ind = " " * indent head = expr[0] prim_name = head.name if isinstance(head, Symbol) else str(head) args = expr[1:] # Per-pixel color operations if prim_name == 'color_ops:invert-img': return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let result = vec3(1.0, 1.0, 1.0) - rgb; {ind}output[idx] = pack_rgb(result);""" if prim_name == 'color_ops:grayscale': return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b; {ind}let result = vec3(gray, gray, gray); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'color_ops:adjust-brightness': amount = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let adj = f32({amount}) / 255.0; {ind}let result = clamp(rgb + vec3(adj, adj, adj), vec3(0.0, 0.0, 0.0), vec3(1.0, 1.0, 1.0)); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'color_ops:adjust': # (adjust img brightness contrast) brightness = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" contrast = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let centered = rgb - vec3(0.5, 0.5, 0.5); {ind}let contrasted = centered * {contrast}; {ind}let brightened = contrasted + vec3(0.5, 0.5, 0.5) + vec3({brightness}/255.0); {ind}let result = clamp(brightened, vec3(0.0), vec3(1.0)); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'color_ops:sepia': intensity = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b; {ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b; {ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b; {ind}let sepia = vec3(sepia_r, sepia_g, sepia_b); {ind}let result = mix(rgb, sepia, {intensity}); {ind}output[idx] = pack_rgb(clamp(result, vec3(0.0), vec3(1.0)));""" if prim_name == 'color_ops:shift-hsv': h_shift = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" s_mult = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" v_mult = self._expr_to_wgsl(args[3]) if len(args) > 3 else "1.0" return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}var hsv = rgb_to_hsv(rgb); {ind}hsv.x = fract(hsv.x + {h_shift} / 360.0); {ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0); {ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0); {ind}let result = hsv_to_rgb(hsv); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'color_ops:quantize': levels = self._expr_to_wgsl(args[1]) if len(args) > 1 else "8.0" return f"""{ind}let rgb = unpack_rgb(input[idx]); {ind}let lvl = max(2.0, {levels}); {ind}let result = floor(rgb * lvl) / lvl; {ind}output[idx] = pack_rgb(result);""" # Geometric transforms if prim_name == 'geometry:scale-img': sx = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" sy = self._expr_to_wgsl(args[2]) if len(args) > 2 else sx self.ctx.uses_sampling = True return f"""{ind}let w = f32(params.width); {ind}let h = f32(params.height); {ind}let cx = w / 2.0; {ind}let cy = h / 2.0; {ind}let sx = f32(x) - cx; {ind}let sy = f32(y) - cy; {ind}let src_x = sx / {sx} + cx; {ind}let src_y = sy / {sy} + cy; {ind}let result = sample_bilinear(src_x, src_y); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'geometry:rotate-img': angle = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" self.ctx.uses_sampling = True return f"""{ind}let w = f32(params.width); {ind}let h = f32(params.height); {ind}let cx = w / 2.0; {ind}let cy = h / 2.0; {ind}let angle_rad = {angle} * 3.14159265 / 180.0; {ind}let cos_a = cos(-angle_rad); {ind}let sin_a = sin(-angle_rad); {ind}let dx = f32(x) - cx; {ind}let dy = f32(y) - cy; {ind}let src_x = dx * cos_a - dy * sin_a + cx; {ind}let src_y = dx * sin_a + dy * cos_a + cy; {ind}let result = sample_bilinear(src_x, src_y); {ind}output[idx] = pack_rgb(result);""" if prim_name == 'geometry:flip-h': return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x); {ind}output[idx] = input[src_idx];""" if prim_name == 'geometry:flip-v': return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x; {ind}output[idx] = input[src_idx];""" # Image library if prim_name == 'image:blur': radius = self._expr_to_wgsl(args[1]) if len(args) > 1 else "5" # Box blur approximation (separable would be better) return f"""{ind}let radius = i32({radius}); {ind}var sum = vec3(0.0, 0.0, 0.0); {ind}var count = 0.0; {ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{ {ind} for (var dx = -radius; dx <= radius; dx = dx + 1) {{ {ind} let sx = i32(x) + dx; {ind} let sy = i32(y) + dy; {ind} if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{ {ind} let sidx = u32(sy) * params.width + u32(sx); {ind} sum = sum + unpack_rgb(input[sidx]); {ind} count = count + 1.0; {ind} }} {ind} }} {ind}}} {ind}let result = sum / count; {ind}output[idx] = pack_rgb(result);""" # Fallback - passthrough return f"""{ind}// Unimplemented primitive: {prim_name} {ind}output[idx] = input[idx];""" def _compile_if(self, expr: list, indent: int) -> str: """Compile if expression.""" ind = " " * indent cond = self._expr_to_wgsl(expr[1]) then_expr = expr[2] else_expr = expr[3] if len(expr) > 3 else None lines = [] lines.append(f"{ind}if ({cond}) {{") lines.append(self._compile_body(then_expr, indent + 4)) if else_expr: lines.append(f"{ind}}} else {{") lines.append(self._compile_body(else_expr, indent + 4)) lines.append(f"{ind}}}") return "\n".join(lines) def _compile_or(self, expr: list, indent: int) -> str: """Compile or expression - returns first truthy value.""" # For numeric context, (or a b) means "a if a != 0 else b" a = self._expr_to_wgsl(expr[1]) b = self._expr_to_wgsl(expr[2]) if len(expr) > 2 else "0.0" return f"select({b}, {a}, {a} != 0.0)" def _compile_arithmetic(self, expr: list, indent: int) -> str: """Compile arithmetic expression to inline WGSL.""" op = expr[0].name operands = [self._expr_to_wgsl(arg) for arg in expr[1:]] if len(operands) == 1: if op == '-': return f"(-{operands[0]})" return operands[0] return f"({f' {op} '.join(operands)})" def _compile_comparison(self, expr: list, indent: int) -> str: """Compile comparison expression.""" op = expr[0].name if op == '=': op = '==' a = self._expr_to_wgsl(expr[1]) b = self._expr_to_wgsl(expr[2]) return f"({a} {op} {b})" def _compile_builtin(self, fn: str, args: list, indent: int) -> str: """Compile builtin function call.""" compiled_args = [self._expr_to_wgsl(arg) for arg in args] return f"{fn}({', '.join(compiled_args)})" def _expr_to_wgsl(self, expr: Any) -> str: """Convert an expression to inline WGSL code.""" if isinstance(expr, (int, float)): # Ensure floats have decimal point if isinstance(expr, float) or '.' not in str(expr): return f"{float(expr)}" return str(expr) if isinstance(expr, str): return f'"{expr}"' if isinstance(expr, Symbol): name = expr.name if name == 'frame': return "rgb" # Assume rgb is already loaded if name == 't' or name == '_time': self.ctx.uses_time = True return "params.time" if name == 'pi': return "3.14159265" if name in self.ctx.params: return f"params.{name}" if name in self.ctx.locals: return name return name if isinstance(expr, list) and expr: head = expr[0] if isinstance(head, Symbol): form = head.name # Arithmetic if form in ('+', '-', '*', '/'): return self._compile_arithmetic(expr, 0) # Comparison if form in ('>', '<', '>=', '<=', '='): return self._compile_comparison(expr, 0) # Builtins if form in ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt'): args = [self._expr_to_wgsl(a) for a in expr[1:]] return f"{form}({', '.join(args)})" if form == 'or': return self._compile_or(expr, 0) # Image dimension queries if form == 'image:width': return "f32(params.width)" if form == 'image:height': return "f32(params.height)" return f"/* unknown: {expr} */" def compile_effect(sexp_code: str) -> CompiledEffect: """Convenience function to compile an sexp effect string.""" compiler = SexpToWGSLCompiler() return compiler.compile_string(sexp_code) def compile_effect_file(path: str) -> CompiledEffect: """Convenience function to compile an sexp effect file.""" compiler = SexpToWGSLCompiler() return compiler.compile_file(path)