Files
celery/sexp_effects/wgsl_compiler.py
giles 86830019ad Add IPFS HLS streaming and GPU optimizations
- Add IPFSHLSOutput class that uploads segments to IPFS as they're created
- Update streaming task to use IPFS HLS output for distributed streaming
- Add /ipfs-stream endpoint to get IPFS playlist URL
- Update /stream endpoint to redirect to IPFS when available
- Add GPU persistence mode (STREAMING_GPU_PERSIST=1) to keep frames on GPU
- Add hardware video decoding (NVDEC) support for faster video processing
- Add GPU-accelerated primitive libraries: blending_gpu, color_ops_gpu, geometry_gpu
- Add streaming_gpu module with GPUFrame class for tracking CPU/GPU data location
- Add Dockerfile.gpu for building GPU-enabled worker image

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:23:16 +00:00

716 lines
27 KiB
Python

"""
S-Expression to WGSL Compiler
Compiles sexp effect definitions to WGSL compute shaders for GPU execution.
The compilation happens at effect upload time (AOT), not at runtime.
Architecture:
- Parse sexp AST
- Analyze primitives used
- Generate WGSL compute shader
Shader Categories:
1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel)
2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample)
3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors)
"""
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from pathlib import Path
import math
from .parser import parse, parse_all, Symbol, Keyword
@dataclass
class WGSLParam:
"""A shader parameter (uniform)."""
name: str
wgsl_type: str # f32, i32, u32, vec2f, etc.
default: Any
@dataclass
class CompiledEffect:
"""Result of compiling an sexp effect to WGSL."""
name: str
wgsl_code: str
params: List[WGSLParam]
workgroup_size: Tuple[int, int, int] = (16, 16, 1)
# Metadata for runtime
uses_time: bool = False
uses_sampling: bool = False # Needs texture sampler
category: str = "per_pixel" # per_pixel, geometric, neighborhood
@dataclass
class CompilerContext:
"""Context during compilation."""
effect_name: str = ""
params: Dict[str, WGSLParam] = field(default_factory=dict)
locals: Dict[str, str] = field(default_factory=dict) # local var -> wgsl expr
required_libs: Set[str] = field(default_factory=set)
uses_time: bool = False
uses_sampling: bool = False
temp_counter: int = 0
def fresh_temp(self) -> str:
"""Generate a fresh temporary variable name."""
self.temp_counter += 1
return f"_t{self.temp_counter}"
class SexpToWGSLCompiler:
"""
Compiles S-expression effect definitions to WGSL compute shaders.
"""
# Map sexp types to WGSL types
TYPE_MAP = {
'int': 'i32',
'float': 'f32',
'bool': 'u32', # WGSL doesn't have bool in storage
'string': None, # Strings handled specially
}
# Per-pixel primitives that can be compiled directly
PER_PIXEL_PRIMITIVES = {
'color_ops:invert-img',
'color_ops:grayscale',
'color_ops:sepia',
'color_ops:adjust',
'color_ops:adjust-brightness',
'color_ops:shift-hsv',
'color_ops:quantize',
}
# Geometric primitives (coordinate remapping)
GEOMETRIC_PRIMITIVES = {
'geometry:scale-img',
'geometry:rotate-img',
'geometry:translate',
'geometry:flip-h',
'geometry:flip-v',
'geometry:remap',
}
def __init__(self):
self.ctx: Optional[CompilerContext] = None
def compile_file(self, path: str) -> CompiledEffect:
"""Compile an effect from a .sexp file."""
with open(path, 'r') as f:
content = f.read()
exprs = parse_all(content)
return self.compile(exprs)
def compile_string(self, sexp_code: str) -> CompiledEffect:
"""Compile an effect from an sexp string."""
exprs = parse_all(sexp_code)
return self.compile(exprs)
def compile(self, expr: Any) -> CompiledEffect:
"""Compile a parsed sexp expression."""
self.ctx = CompilerContext()
# Handle multiple top-level expressions (require-primitives, define-effect)
if isinstance(expr, list) and expr and isinstance(expr[0], list):
for e in expr:
self._process_toplevel(e)
else:
self._process_toplevel(expr)
# Generate the WGSL shader
wgsl = self._generate_wgsl()
# Determine category based on primitives used
category = self._determine_category()
return CompiledEffect(
name=self.ctx.effect_name,
wgsl_code=wgsl,
params=list(self.ctx.params.values()),
uses_time=self.ctx.uses_time,
uses_sampling=self.ctx.uses_sampling,
category=category,
)
def _process_toplevel(self, expr: Any):
"""Process a top-level expression."""
if not isinstance(expr, list) or not expr:
return
head = expr[0]
if isinstance(head, Symbol):
if head.name == 'require-primitives':
# Track required primitive libraries
for lib in expr[1:]:
lib_name = lib.name if isinstance(lib, Symbol) else str(lib)
self.ctx.required_libs.add(lib_name)
elif head.name == 'define-effect':
self._compile_effect_def(expr)
def _compile_effect_def(self, expr: list):
"""Compile a define-effect form."""
# (define-effect name :params (...) body)
self.ctx.effect_name = expr[1].name if isinstance(expr[1], Symbol) else str(expr[1])
# Parse :params and body
i = 2
body = None
while i < len(expr):
item = expr[i]
if isinstance(item, Keyword) and item.name == 'params':
self._parse_params(expr[i + 1])
i += 2
elif isinstance(item, Keyword):
i += 2 # Skip other keywords
else:
body = item
i += 1
if body:
self.ctx.body_expr = body
def _parse_params(self, params_list: list):
"""Parse the :params block."""
for param_def in params_list:
if not isinstance(param_def, list):
continue
name = param_def[0].name if isinstance(param_def[0], Symbol) else str(param_def[0])
# Parse keyword args
param_type = 'float'
default = 0
i = 1
while i < len(param_def):
item = param_def[i]
if isinstance(item, Keyword):
if i + 1 < len(param_def):
val = param_def[i + 1]
if item.name == 'type':
param_type = val.name if isinstance(val, Symbol) else str(val)
elif item.name == 'default':
default = val
i += 2
else:
i += 1
wgsl_type = self.TYPE_MAP.get(param_type, 'f32')
if wgsl_type:
self.ctx.params[name] = WGSLParam(name, wgsl_type, default)
def _determine_category(self) -> str:
"""Determine shader category based on primitives used."""
for lib in self.ctx.required_libs:
if lib == 'geometry':
return 'geometric'
if lib == 'filters':
return 'neighborhood'
return 'per_pixel'
def _generate_wgsl(self) -> str:
"""Generate the complete WGSL shader code."""
lines = []
# Header comment
lines.append(f"// WGSL Shader: {self.ctx.effect_name}")
lines.append(f"// Auto-generated from sexp effect definition")
lines.append("")
# Bindings
lines.append("@group(0) @binding(0) var<storage, read> input: array<u32>;")
lines.append("@group(0) @binding(1) var<storage, read_write> output: array<u32>;")
lines.append("")
# Params struct
if self.ctx.params:
lines.append("struct Params {")
lines.append(" width: u32,")
lines.append(" height: u32,")
lines.append(" time: f32,")
for param in self.ctx.params.values():
lines.append(f" {param.name}: {param.wgsl_type},")
lines.append("}")
lines.append("@group(0) @binding(2) var<uniform> params: Params;")
else:
lines.append("struct Params {")
lines.append(" width: u32,")
lines.append(" height: u32,")
lines.append(" time: f32,")
lines.append("}")
lines.append("@group(0) @binding(2) var<uniform> params: Params;")
lines.append("")
# Helper functions
lines.extend(self._generate_helpers())
lines.append("")
# Main compute shader
lines.append("@compute @workgroup_size(16, 16, 1)")
lines.append("fn main(@builtin(global_invocation_id) gid: vec3<u32>) {")
lines.append(" let x = gid.x;")
lines.append(" let y = gid.y;")
lines.append(" if (x >= params.width || y >= params.height) { return; }")
lines.append(" let idx = y * params.width + x;")
lines.append("")
# Compile the effect body
body_code = self._compile_expr(self.ctx.body_expr)
lines.append(f" // Effect: {self.ctx.effect_name}")
lines.append(body_code)
lines.append("}")
return "\n".join(lines)
def _generate_helpers(self) -> List[str]:
"""Generate WGSL helper functions."""
helpers = []
# Pack/unpack RGB from u32
helpers.append("fn unpack_rgb(packed: u32) -> vec3<f32> {")
helpers.append(" let r = f32((packed >> 16u) & 0xFFu) / 255.0;")
helpers.append(" let g = f32((packed >> 8u) & 0xFFu) / 255.0;")
helpers.append(" let b = f32(packed & 0xFFu) / 255.0;")
helpers.append(" return vec3<f32>(r, g, b);")
helpers.append("}")
helpers.append("")
helpers.append("fn pack_rgb(rgb: vec3<f32>) -> u32 {")
helpers.append(" let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0);")
helpers.append(" let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0);")
helpers.append(" let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0);")
helpers.append(" return (r << 16u) | (g << 8u) | b;")
helpers.append("}")
helpers.append("")
# Bilinear sampling for geometric transforms
if self.ctx.uses_sampling or 'geometry' in self.ctx.required_libs:
helpers.append("fn sample_bilinear(sx: f32, sy: f32) -> vec3<f32> {")
helpers.append(" let w = f32(params.width);")
helpers.append(" let h = f32(params.height);")
helpers.append(" let cx = clamp(sx, 0.0, w - 1.001);")
helpers.append(" let cy = clamp(sy, 0.0, h - 1.001);")
helpers.append(" let x0 = u32(cx);")
helpers.append(" let y0 = u32(cy);")
helpers.append(" let x1 = min(x0 + 1u, params.width - 1u);")
helpers.append(" let y1 = min(y0 + 1u, params.height - 1u);")
helpers.append(" let fx = cx - f32(x0);")
helpers.append(" let fy = cy - f32(y0);")
helpers.append(" let c00 = unpack_rgb(input[y0 * params.width + x0]);")
helpers.append(" let c10 = unpack_rgb(input[y0 * params.width + x1]);")
helpers.append(" let c01 = unpack_rgb(input[y1 * params.width + x0]);")
helpers.append(" let c11 = unpack_rgb(input[y1 * params.width + x1]);")
helpers.append(" let top = mix(c00, c10, fx);")
helpers.append(" let bot = mix(c01, c11, fx);")
helpers.append(" return mix(top, bot, fy);")
helpers.append("}")
helpers.append("")
# HSV conversion for color effects
if 'color_ops' in self.ctx.required_libs or 'color' in self.ctx.required_libs:
helpers.append("fn rgb_to_hsv(rgb: vec3<f32>) -> vec3<f32> {")
helpers.append(" let mx = max(max(rgb.r, rgb.g), rgb.b);")
helpers.append(" let mn = min(min(rgb.r, rgb.g), rgb.b);")
helpers.append(" let d = mx - mn;")
helpers.append(" var h = 0.0;")
helpers.append(" if (d > 0.0) {")
helpers.append(" if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }")
helpers.append(" else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }")
helpers.append(" else { h = 4.0 + (rgb.r - rgb.g) / d; }")
helpers.append(" h = h / 6.0;")
helpers.append(" if (h < 0.0) { h = h + 1.0; }")
helpers.append(" }")
helpers.append(" let s = select(0.0, d / mx, mx > 0.0);")
helpers.append(" return vec3<f32>(h, s, mx);")
helpers.append("}")
helpers.append("")
helpers.append("fn hsv_to_rgb(hsv: vec3<f32>) -> vec3<f32> {")
helpers.append(" let h = hsv.x * 6.0;")
helpers.append(" let s = hsv.y;")
helpers.append(" let v = hsv.z;")
helpers.append(" let c = v * s;")
helpers.append(" let x = c * (1.0 - abs(h % 2.0 - 1.0));")
helpers.append(" let m = v - c;")
helpers.append(" var rgb: vec3<f32>;")
helpers.append(" if (h < 1.0) { rgb = vec3<f32>(c, x, 0.0); }")
helpers.append(" else if (h < 2.0) { rgb = vec3<f32>(x, c, 0.0); }")
helpers.append(" else if (h < 3.0) { rgb = vec3<f32>(0.0, c, x); }")
helpers.append(" else if (h < 4.0) { rgb = vec3<f32>(0.0, x, c); }")
helpers.append(" else if (h < 5.0) { rgb = vec3<f32>(x, 0.0, c); }")
helpers.append(" else { rgb = vec3<f32>(c, 0.0, x); }")
helpers.append(" return rgb + vec3<f32>(m, m, m);")
helpers.append("}")
helpers.append("")
return helpers
def _compile_expr(self, expr: Any, indent: int = 4) -> str:
"""Compile an sexp expression to WGSL code."""
ind = " " * indent
# Literals
if isinstance(expr, (int, float)):
return f"{ind}// literal: {expr}"
if isinstance(expr, str):
return f'{ind}// string: "{expr}"'
# Symbol reference
if isinstance(expr, Symbol):
name = expr.name
if name == 'frame':
return f"{ind}let rgb = unpack_rgb(input[idx]);"
if name == 't' or name == '_time':
self.ctx.uses_time = True
return f"{ind}let t = params.time;"
if name in self.ctx.params:
return f"{ind}let {name} = params.{name};"
if name in self.ctx.locals:
return f"{ind}// local: {name}"
return f"{ind}// unknown symbol: {name}"
# List (function call or special form)
if isinstance(expr, list) and expr:
head = expr[0]
if isinstance(head, Symbol):
form = head.name
# Special forms
if form == 'let' or form == 'let*':
return self._compile_let(expr, indent)
if form == 'if':
return self._compile_if(expr, indent)
if form == 'or':
# (or a b) - return a if truthy, else b
return self._compile_or(expr, indent)
# Primitive calls
if ':' in form:
return self._compile_primitive_call(expr, indent)
# Arithmetic
if form in ('+', '-', '*', '/'):
return self._compile_arithmetic(expr, indent)
if form in ('>', '<', '>=', '<=', '='):
return self._compile_comparison(expr, indent)
if form == 'max':
return self._compile_builtin('max', expr[1:], indent)
if form == 'min':
return self._compile_builtin('min', expr[1:], indent)
return f"{ind}// unhandled: {expr}"
def _compile_let(self, expr: list, indent: int) -> str:
"""Compile let/let* binding form."""
ind = " " * indent
lines = []
bindings = expr[1]
body = expr[2]
# Parse bindings (Clojure style: [x 1 y 2] or Scheme style: ((x 1) (y 2)))
pairs = []
if bindings and isinstance(bindings[0], Symbol):
# Clojure style
i = 0
while i < len(bindings) - 1:
name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
value = bindings[i + 1]
pairs.append((name, value))
i += 2
else:
# Scheme style
for binding in bindings:
name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
value = binding[1]
pairs.append((name, value))
# Compile bindings
for name, value in pairs:
val_code = self._expr_to_wgsl(value)
lines.append(f"{ind}let {name} = {val_code};")
self.ctx.locals[name] = val_code
# Compile body
body_lines = self._compile_body(body, indent)
lines.append(body_lines)
return "\n".join(lines)
def _compile_body(self, body: Any, indent: int) -> str:
"""Compile the body of an effect (the final image expression)."""
ind = " " * indent
# Most effects end with a primitive call that produces the output
if isinstance(body, list) and body:
head = body[0]
if isinstance(head, Symbol) and ':' in head.name:
return self._compile_primitive_call(body, indent)
# If body is just 'frame', pass through
if isinstance(body, Symbol) and body.name == 'frame':
return f"{ind}output[idx] = input[idx];"
return f"{ind}// body: {body}"
def _compile_primitive_call(self, expr: list, indent: int) -> str:
"""Compile a primitive function call."""
ind = " " * indent
head = expr[0]
prim_name = head.name if isinstance(head, Symbol) else str(head)
args = expr[1:]
# Per-pixel color operations
if prim_name == 'color_ops:invert-img':
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let result = vec3<f32>(1.0, 1.0, 1.0) - rgb;
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'color_ops:grayscale':
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b;
{ind}let result = vec3<f32>(gray, gray, gray);
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'color_ops:adjust-brightness':
amount = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let adj = f32({amount}) / 255.0;
{ind}let result = clamp(rgb + vec3<f32>(adj, adj, adj), vec3<f32>(0.0, 0.0, 0.0), vec3<f32>(1.0, 1.0, 1.0));
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'color_ops:adjust':
# (adjust img brightness contrast)
brightness = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
contrast = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0"
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let centered = rgb - vec3<f32>(0.5, 0.5, 0.5);
{ind}let contrasted = centered * {contrast};
{ind}let brightened = contrasted + vec3<f32>(0.5, 0.5, 0.5) + vec3<f32>({brightness}/255.0);
{ind}let result = clamp(brightened, vec3<f32>(0.0), vec3<f32>(1.0));
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'color_ops:sepia':
intensity = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0"
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b;
{ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b;
{ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b;
{ind}let sepia = vec3<f32>(sepia_r, sepia_g, sepia_b);
{ind}let result = mix(rgb, sepia, {intensity});
{ind}output[idx] = pack_rgb(clamp(result, vec3<f32>(0.0), vec3<f32>(1.0)));"""
if prim_name == 'color_ops:shift-hsv':
h_shift = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
s_mult = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0"
v_mult = self._expr_to_wgsl(args[3]) if len(args) > 3 else "1.0"
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}var hsv = rgb_to_hsv(rgb);
{ind}hsv.x = fract(hsv.x + {h_shift} / 360.0);
{ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0);
{ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0);
{ind}let result = hsv_to_rgb(hsv);
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'color_ops:quantize':
levels = self._expr_to_wgsl(args[1]) if len(args) > 1 else "8.0"
return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let lvl = max(2.0, {levels});
{ind}let result = floor(rgb * lvl) / lvl;
{ind}output[idx] = pack_rgb(result);"""
# Geometric transforms
if prim_name == 'geometry:scale-img':
sx = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0"
sy = self._expr_to_wgsl(args[2]) if len(args) > 2 else sx
self.ctx.uses_sampling = True
return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let sx = f32(x) - cx;
{ind}let sy = f32(y) - cy;
{ind}let src_x = sx / {sx} + cx;
{ind}let src_y = sy / {sy} + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'geometry:rotate-img':
angle = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
self.ctx.uses_sampling = True
return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let angle_rad = {angle} * 3.14159265 / 180.0;
{ind}let cos_a = cos(-angle_rad);
{ind}let sin_a = sin(-angle_rad);
{ind}let dx = f32(x) - cx;
{ind}let dy = f32(y) - cy;
{ind}let src_x = dx * cos_a - dy * sin_a + cx;
{ind}let src_y = dx * sin_a + dy * cos_a + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""
if prim_name == 'geometry:flip-h':
return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x);
{ind}output[idx] = input[src_idx];"""
if prim_name == 'geometry:flip-v':
return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x;
{ind}output[idx] = input[src_idx];"""
# Image library
if prim_name == 'image:blur':
radius = self._expr_to_wgsl(args[1]) if len(args) > 1 else "5"
# Box blur approximation (separable would be better)
return f"""{ind}let radius = i32({radius});
{ind}var sum = vec3<f32>(0.0, 0.0, 0.0);
{ind}var count = 0.0;
{ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{
{ind} for (var dx = -radius; dx <= radius; dx = dx + 1) {{
{ind} let sx = i32(x) + dx;
{ind} let sy = i32(y) + dy;
{ind} if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{
{ind} let sidx = u32(sy) * params.width + u32(sx);
{ind} sum = sum + unpack_rgb(input[sidx]);
{ind} count = count + 1.0;
{ind} }}
{ind} }}
{ind}}}
{ind}let result = sum / count;
{ind}output[idx] = pack_rgb(result);"""
# Fallback - passthrough
return f"""{ind}// Unimplemented primitive: {prim_name}
{ind}output[idx] = input[idx];"""
def _compile_if(self, expr: list, indent: int) -> str:
"""Compile if expression."""
ind = " " * indent
cond = self._expr_to_wgsl(expr[1])
then_expr = expr[2]
else_expr = expr[3] if len(expr) > 3 else None
lines = []
lines.append(f"{ind}if ({cond}) {{")
lines.append(self._compile_body(then_expr, indent + 4))
if else_expr:
lines.append(f"{ind}}} else {{")
lines.append(self._compile_body(else_expr, indent + 4))
lines.append(f"{ind}}}")
return "\n".join(lines)
def _compile_or(self, expr: list, indent: int) -> str:
"""Compile or expression - returns first truthy value."""
# For numeric context, (or a b) means "a if a != 0 else b"
a = self._expr_to_wgsl(expr[1])
b = self._expr_to_wgsl(expr[2]) if len(expr) > 2 else "0.0"
return f"select({b}, {a}, {a} != 0.0)"
def _compile_arithmetic(self, expr: list, indent: int) -> str:
"""Compile arithmetic expression to inline WGSL."""
op = expr[0].name
operands = [self._expr_to_wgsl(arg) for arg in expr[1:]]
if len(operands) == 1:
if op == '-':
return f"(-{operands[0]})"
return operands[0]
return f"({f' {op} '.join(operands)})"
def _compile_comparison(self, expr: list, indent: int) -> str:
"""Compile comparison expression."""
op = expr[0].name
if op == '=':
op = '=='
a = self._expr_to_wgsl(expr[1])
b = self._expr_to_wgsl(expr[2])
return f"({a} {op} {b})"
def _compile_builtin(self, fn: str, args: list, indent: int) -> str:
"""Compile builtin function call."""
compiled_args = [self._expr_to_wgsl(arg) for arg in args]
return f"{fn}({', '.join(compiled_args)})"
def _expr_to_wgsl(self, expr: Any) -> str:
"""Convert an expression to inline WGSL code."""
if isinstance(expr, (int, float)):
# Ensure floats have decimal point
if isinstance(expr, float) or '.' not in str(expr):
return f"{float(expr)}"
return str(expr)
if isinstance(expr, str):
return f'"{expr}"'
if isinstance(expr, Symbol):
name = expr.name
if name == 'frame':
return "rgb" # Assume rgb is already loaded
if name == 't' or name == '_time':
self.ctx.uses_time = True
return "params.time"
if name == 'pi':
return "3.14159265"
if name in self.ctx.params:
return f"params.{name}"
if name in self.ctx.locals:
return name
return name
if isinstance(expr, list) and expr:
head = expr[0]
if isinstance(head, Symbol):
form = head.name
# Arithmetic
if form in ('+', '-', '*', '/'):
return self._compile_arithmetic(expr, 0)
# Comparison
if form in ('>', '<', '>=', '<=', '='):
return self._compile_comparison(expr, 0)
# Builtins
if form in ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt'):
args = [self._expr_to_wgsl(a) for a in expr[1:]]
return f"{form}({', '.join(args)})"
if form == 'or':
return self._compile_or(expr, 0)
# Image dimension queries
if form == 'image:width':
return "f32(params.width)"
if form == 'image:height':
return "f32(params.height)"
return f"/* unknown: {expr} */"
def compile_effect(sexp_code: str) -> CompiledEffect:
"""Convenience function to compile an sexp effect string."""
compiler = SexpToWGSLCompiler()
return compiler.compile_string(sexp_code)
def compile_effect_file(path: str) -> CompiledEffect:
"""Convenience function to compile an sexp effect file."""
compiler = SexpToWGSLCompiler()
return compiler.compile_file(path)