- Add IPFSHLSOutput class that uploads segments to IPFS as they're created
- Update streaming task to use IPFS HLS output for distributed streaming
- Add /ipfs-stream endpoint to get IPFS playlist URL
- Update /stream endpoint to redirect to IPFS when available
- Add GPU persistence mode (STREAMING_GPU_PERSIST=1) to keep frames on GPU
- Add hardware video decoding (NVDEC) support for faster video processing
- Add GPU-accelerated primitive libraries: blending_gpu, color_ops_gpu, geometry_gpu
- Add streaming_gpu module with GPUFrame class for tracking CPU/GPU data location
- Add Dockerfile.gpu for building GPU-enabled worker image

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
716 lines
27 KiB
Python
716 lines
27 KiB
Python
"""
|
|
S-Expression to WGSL Compiler
|
|
|
|
Compiles sexp effect definitions to WGSL compute shaders for GPU execution.
|
|
The compilation happens at effect upload time (AOT), not at runtime.
|
|
|
|
Architecture:
|
|
- Parse sexp AST
|
|
- Analyze primitives used
|
|
- Generate WGSL compute shader
|
|
|
|
Shader Categories:
|
|
1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel)
|
|
2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample)
|
|
3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors)
|
|
"""
|
|
|
|
from typing import Dict, List, Any, Optional, Tuple, Set
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
import math
|
|
|
|
from .parser import parse, parse_all, Symbol, Keyword
|
|
|
|
|
|
@dataclass
class WGSLParam:
    """One effect parameter, exposed to the shader as a uniform member.

    Instances are produced while parsing the ``:params`` block of a
    ``define-effect`` form.
    """

    # Parameter name exactly as written in the sexp source.
    name: str
    # WGSL type of the uniform member, e.g. "f32", "i32" or "u32".
    wgsl_type: str
    # Value given by ``:default`` in the sexp source, stored unconverted.
    default: Any
|
|
|
|
|
|
@dataclass
|
|
class CompiledEffect:
|
|
"""Result of compiling an sexp effect to WGSL."""
|
|
name: str
|
|
wgsl_code: str
|
|
params: List[WGSLParam]
|
|
workgroup_size: Tuple[int, int, int] = (16, 16, 1)
|
|
# Metadata for runtime
|
|
uses_time: bool = False
|
|
uses_sampling: bool = False # Needs texture sampler
|
|
category: str = "per_pixel" # per_pixel, geometric, neighborhood
|
|
|
|
|
|
@dataclass
class CompilerContext:
    """Mutable state for compiling a single effect.

    A fresh context is expected per compilation so that state from one
    effect never leaks into the next.
    """

    # Name from the define-effect form ("" until one is seen).
    effect_name: str = ""
    # Parsed :params entries keyed by sexp parameter name, in declaration order.
    params: Dict[str, 'WGSLParam'] = field(default_factory=dict)
    # sexp let-bound name -> WGSL text the compiler emits for that name.
    locals: Dict[str, str] = field(default_factory=dict)
    # Library names listed in (require-primitives ...) forms.
    required_libs: Set[str] = field(default_factory=set)
    # Set while compiling when the body references the time parameter.
    uses_time: bool = False
    # Set while compiling when a primitive needs bilinear sampling.
    uses_sampling: bool = False
    # Counter backing fresh_temp(); monotonically increasing.
    temp_counter: int = 0
    # Body form of the define-effect. Declared here (instead of being
    # attached ad hoc) so that an effect without a body reads None rather
    # than raising AttributeError.
    body_expr: Any = None

    def fresh_temp(self) -> str:
        """Return a new temporary name ("_t1", "_t2", ...), unique within this context."""
        self.temp_counter += 1
        return f"_t{self.temp_counter}"
|
|
|
|
|
|
class SexpToWGSLCompiler:
    """
    Compiles S-expression effect definitions to WGSL compute shaders.

    A compiler instance can be reused: every call to ``compile`` (and the
    ``compile_file`` / ``compile_string`` wrappers) starts from a fresh
    ``CompilerContext``.

    The generated shader reads pixels packed as 0xRRGGBB u32 values from the
    ``input`` storage buffer, writes packed pixels to ``output``, and gets
    ``width``, ``height``, ``time`` and the effect's own parameters from the
    ``params`` uniform. One shader invocation handles one pixel.
    """

    # Map sexp types to WGSL types (None = not passed to the shader).
    TYPE_MAP = {
        'int': 'i32',
        'float': 'f32',
        'bool': 'u32',  # WGSL doesn't have bool in storage
        'string': None,  # strings are skipped when building the Params struct
    }

    # Per-pixel primitives that can be compiled directly
    PER_PIXEL_PRIMITIVES = {
        'color_ops:invert-img',
        'color_ops:grayscale',
        'color_ops:sepia',
        'color_ops:adjust',
        'color_ops:adjust-brightness',
        'color_ops:shift-hsv',
        'color_ops:quantize',
    }

    # Geometric primitives (coordinate remapping)
    GEOMETRIC_PRIMITIVES = {
        'geometry:scale-img',
        'geometry:rotate-img',
        'geometry:translate',
        'geometry:flip-h',
        'geometry:flip-v',
        'geometry:remap',
    }

    # Primitives that read a neighbourhood of source pixels.
    NEIGHBORHOOD_PRIMITIVES = {
        'image:blur',
    }

    # Emitted as @workgroup_size and reported via CompiledEffect.workgroup_size,
    # so the two can never disagree.
    WORKGROUP_SIZE = (16, 16, 1)

    # Built-in Params members; effect params with these names get a prefix
    # so the struct never has duplicate members.
    _RESERVED_PARAM_FIELDS = frozenset({'width', 'height', 'time'})
    _ARITHMETIC_OPS = ('+', '-', '*', '/')
    _COMPARISON_OPS = ('>', '<', '>=', '<=', '=')
    # sexp functions emitted as the WGSL builtin of the same name.
    _INLINE_BUILTINS = ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt')
    # WGSL max/min take exactly two arguments; longer calls are folded.
    _BINARY_FOLD_BUILTINS = ('max', 'min')

    def __init__(self):
        self.ctx: Optional[CompilerContext] = None
        # Primitive names compiled for the current effect; drives helper
        # emission and category detection.
        self._primitives_used: Set[str] = set()
        # Every WGSL identifier emitted for a let-binding. Used so that
        # rebinding a name never redeclares an identifier (a WGSL error).
        self._declared_locals: Set[str] = set()

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------

    def compile_file(self, path: str) -> 'CompiledEffect':
        """Compile an effect from a .sexp file (read as UTF-8)."""
        content = Path(path).read_text(encoding='utf-8')
        return self.compile(parse_all(content))

    def compile_string(self, sexp_code: str) -> 'CompiledEffect':
        """Compile an effect from an sexp string."""
        return self.compile(parse_all(sexp_code))

    def compile(self, expr: Any) -> 'CompiledEffect':
        """
        Compile parsed sexp code into a ``CompiledEffect``.

        Args:
            expr: either a single top-level form or a list of top-level forms
                (what ``parse_all`` returns). ``require-primitives`` and
                ``define-effect`` forms are processed; others are ignored.

        Returns:
            The WGSL source together with params and runtime metadata.
        """
        self.ctx = CompilerContext()
        self._primitives_used = set()
        self._declared_locals = set()

        # A list whose first element is itself a list is a sequence of forms.
        if isinstance(expr, list) and expr and isinstance(expr[0], list):
            for form in expr:
                self._process_toplevel(form)
        else:
            self._process_toplevel(expr)

        # Must run before _determine_category: compiling the body records
        # the primitives used and sets uses_sampling / uses_time.
        wgsl = self._generate_wgsl()
        category = self._determine_category()

        return CompiledEffect(
            name=self.ctx.effect_name,
            wgsl_code=wgsl,
            params=list(self.ctx.params.values()),
            workgroup_size=self.WORKGROUP_SIZE,
            uses_time=self.ctx.uses_time,
            uses_sampling=self.ctx.uses_sampling,
            category=category,
        )

    # ------------------------------------------------------------------
    # Top-level forms
    # ------------------------------------------------------------------

    @staticmethod
    def _symbol_name(value: Any) -> str:
        """Return a Symbol's name, or ``str(value)`` for anything else."""
        return value.name if isinstance(value, Symbol) else str(value)

    def _process_toplevel(self, expr: Any):
        """Process one top-level form (require-primitives / define-effect)."""
        if not isinstance(expr, list) or not expr:
            return

        head = expr[0]
        if not isinstance(head, Symbol):
            return
        if head.name == 'require-primitives':
            # Track required primitive libraries
            for lib in expr[1:]:
                self.ctx.required_libs.add(self._symbol_name(lib))
        elif head.name == 'define-effect':
            self._compile_effect_def(expr)

    def _compile_effect_def(self, expr: list):
        """
        Compile ``(define-effect name :params (...) body)``.

        Records the effect name, parses ``:params`` and stores the last
        non-keyword form as ``ctx.body_expr``. Any other keyword is skipped
        together with the value that follows it.
        """
        self.ctx.effect_name = self._symbol_name(expr[1]) if len(expr) > 1 else ""

        body = None
        i = 2
        while i < len(expr):
            item = expr[i]
            if isinstance(item, Keyword):
                # Guard: a trailing ':params' with no value is ignored.
                if item.name == 'params' and i + 1 < len(expr):
                    self._parse_params(expr[i + 1])
                i += 2
            else:
                body = item
                i += 1

        if body is not None:
            self.ctx.body_expr = body

    def _parse_params(self, params_list: Any):
        """
        Parse a :params block into ``self.ctx.params``.

        Each entry looks like ``(name :type int :default 5)``. ``:type``
        defaults to float and ``:default`` to 0. Unknown types become f32;
        types mapping to None in TYPE_MAP (strings) are not added.
        """
        if not isinstance(params_list, (list, tuple)):
            return

        for param_def in params_list:
            if not isinstance(param_def, list) or not param_def:
                continue

            name = self._symbol_name(param_def[0])
            param_type = 'float'
            default = 0

            i = 1
            while i < len(param_def):
                item = param_def[i]
                if not isinstance(item, Keyword):
                    i += 1
                    continue
                if i + 1 < len(param_def):
                    value = param_def[i + 1]
                    if item.name == 'type':
                        param_type = self._symbol_name(value)
                    elif item.name == 'default':
                        default = value
                i += 2

            wgsl_type = self.TYPE_MAP.get(param_type, 'f32')
            if wgsl_type:
                self.ctx.params[name] = WGSLParam(name, wgsl_type, default)

    def _determine_category(self) -> str:
        """
        Classify the shader as 'geometric', 'neighborhood' or 'per_pixel'.

        Uses both the required libraries and the primitives actually
        compiled. The checks are ordered explicitly so that the result is
        deterministic; the previous iteration over a set depended on hash
        order when both libraries were required.
        """
        libs = self.ctx.required_libs
        used = self._primitives_used
        if 'geometry' in libs or used & self.GEOMETRIC_PRIMITIVES:
            return 'geometric'
        if 'filters' in libs or used & self.NEIGHBORHOOD_PRIMITIVES:
            return 'neighborhood'
        return 'per_pixel'

    # ------------------------------------------------------------------
    # Identifiers
    # ------------------------------------------------------------------

    @staticmethod
    def _sanitize_ident(name: str) -> str:
        """Replace every character that is not ASCII alphanumeric or '_' with '_'."""
        return ''.join(
            ch if (ch.isascii() and ch.isalnum()) or ch == '_' else '_'
            for ch in str(name)
        )

    def _param_field(self, name: str) -> str:
        """WGSL member name in the Params struct for sexp parameter ``name``."""
        field_name = self._sanitize_ident(name)
        if (not field_name
                or field_name[0].isdigit()
                or field_name == '_'
                or field_name.startswith('__')
                or field_name in self._RESERVED_PARAM_FIELDS):
            field_name = 'p_' + field_name
        return field_name

    def _param_ref(self, name: str) -> str:
        """WGSL expression reading parameter ``name`` as an f32."""
        ref = f"params.{self._param_field(name)}"
        # Literals are emitted as f32, so int/bool uniforms are converted;
        # otherwise mixed expressions such as `params.n * 2.0` fail to type-check.
        if self.ctx.params[name].wgsl_type == 'f32':
            return ref
        return f"f32({ref})"

    def _fresh_local_ident(self, name: str) -> str:
        """
        Allocate a unique WGSL identifier for a let-bound sexp name.

        The 'v_' prefix keeps user names from colliding with identifiers the
        compiler emits itself (x, y, idx, rgb, w, h, result, ...).
        """
        base = 'v_' + self._sanitize_ident(name)
        ident = base
        while ident in self._declared_locals:
            ident = base + self.ctx.fresh_temp()
        self._declared_locals.add(ident)
        return ident

    # ------------------------------------------------------------------
    # Shader assembly
    # ------------------------------------------------------------------

    def _generate_wgsl(self) -> str:
        """Generate the complete WGSL shader code."""
        ctx = self.ctx

        # Compile the body FIRST. This records which primitives and features
        # it uses, and the helper section below depends on that information.
        body_code = self._compile_body(getattr(ctx, 'body_expr', None), 4)

        lines = [
            f"// WGSL Shader: {ctx.effect_name}",
            "// Auto-generated from sexp effect definition",
            "",
            "@group(0) @binding(0) var<storage, read> input: array<u32>;",
            "@group(0) @binding(1) var<storage, read_write> output: array<u32>;",
            "",
            # Built-in members first; effect params follow in declaration order.
            "struct Params {",
            " width: u32,",
            " height: u32,",
            " time: f32,",
        ]
        lines.extend(
            f" {self._param_field(param.name)}: {param.wgsl_type},"
            for param in ctx.params.values()
        )
        lines += [
            "}",
            "@group(0) @binding(2) var<uniform> params: Params;",
            "",
        ]

        lines.extend(self._generate_helpers())
        lines.append("")

        wx, wy, wz = self.WORKGROUP_SIZE
        lines += [
            f"@compute @workgroup_size({wx}, {wy}, {wz})",
            "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {",
            " let x = gid.x;",
            " let y = gid.y;",
            " if (x >= params.width || y >= params.height) { return; }",
            " let idx = y * params.width + x;",
            "",
            f" // Effect: {ctx.effect_name}",
            body_code,
            "}",
        ]
        return "\n".join(lines)

    def _generate_helpers(self) -> List[str]:
        """
        Return the WGSL helper functions needed by the compiled body.

        pack/unpack are always emitted. ``sample_bilinear`` is emitted when
        sampling was requested or the geometry library is required. The HSV
        converters are emitted when a color library is required or
        ``color_ops:shift-hsv`` was compiled. This must run after the body
        has been compiled.
        """
        ctx = self.ctx
        helpers = [
            "fn unpack_rgb(packed: u32) -> vec3<f32> {",
            " let r = f32((packed >> 16u) & 0xFFu) / 255.0;",
            " let g = f32((packed >> 8u) & 0xFFu) / 255.0;",
            " let b = f32(packed & 0xFFu) / 255.0;",
            " return vec3<f32>(r, g, b);",
            "}",
            "",
            # Round to nearest rather than truncating: after the /255 -> *255
            # round trip a channel can land slightly below its integer value,
            # and truncation would then darken it even for identity operations.
            "fn pack_rgb(rgb: vec3<f32>) -> u32 {",
            " let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0 + 0.5);",
            " let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0 + 0.5);",
            " let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0 + 0.5);",
            " return (r << 16u) | (g << 8u) | b;",
            "}",
            "",
        ]

        # Bilinear sampling for geometric transforms
        if ctx.uses_sampling or 'geometry' in ctx.required_libs:
            helpers += [
                "fn sample_bilinear(sx: f32, sy: f32) -> vec3<f32> {",
                " let w = f32(params.width);",
                " let h = f32(params.height);",
                # max(...) keeps low <= high for 1-pixel-wide/high images.
                " let cx = clamp(sx, 0.0, max(w - 1.001, 0.0));",
                " let cy = clamp(sy, 0.0, max(h - 1.001, 0.0));",
                " let x0 = u32(cx);",
                " let y0 = u32(cy);",
                " let x1 = min(x0 + 1u, params.width - 1u);",
                " let y1 = min(y0 + 1u, params.height - 1u);",
                " let fx = cx - f32(x0);",
                " let fy = cy - f32(y0);",
                " let c00 = unpack_rgb(input[y0 * params.width + x0]);",
                " let c10 = unpack_rgb(input[y0 * params.width + x1]);",
                " let c01 = unpack_rgb(input[y1 * params.width + x0]);",
                " let c11 = unpack_rgb(input[y1 * params.width + x1]);",
                " let top = mix(c00, c10, fx);",
                " let bot = mix(c01, c11, fx);",
                " return mix(top, bot, fy);",
                "}",
                "",
            ]

        # HSV conversion for color effects
        needs_hsv = (
            'color_ops' in ctx.required_libs
            or 'color' in ctx.required_libs
            or 'color_ops:shift-hsv' in self._primitives_used
        )
        if needs_hsv:
            helpers += [
                "fn rgb_to_hsv(rgb: vec3<f32>) -> vec3<f32> {",
                " let mx = max(max(rgb.r, rgb.g), rgb.b);",
                " let mn = min(min(rgb.r, rgb.g), rgb.b);",
                " let d = mx - mn;",
                " var h = 0.0;",
                " if (d > 0.0) {",
                " if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }",
                " else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }",
                " else { h = 4.0 + (rgb.r - rgb.g) / d; }",
                " h = h / 6.0;",
                " if (h < 0.0) { h = h + 1.0; }",
                " }",
                " let s = select(0.0, d / mx, mx > 0.0);",
                " return vec3<f32>(h, s, mx);",
                "}",
                "",
                "fn hsv_to_rgb(hsv: vec3<f32>) -> vec3<f32> {",
                " let h = hsv.x * 6.0;",
                " let s = hsv.y;",
                " let v = hsv.z;",
                " let c = v * s;",
                " let x = c * (1.0 - abs(h % 2.0 - 1.0));",
                " let m = v - c;",
                " var rgb: vec3<f32>;",
                " if (h < 1.0) { rgb = vec3<f32>(c, x, 0.0); }",
                " else if (h < 2.0) { rgb = vec3<f32>(x, c, 0.0); }",
                " else if (h < 3.0) { rgb = vec3<f32>(0.0, c, x); }",
                " else if (h < 4.0) { rgb = vec3<f32>(0.0, x, c); }",
                " else if (h < 5.0) { rgb = vec3<f32>(x, 0.0, c); }",
                " else { rgb = vec3<f32>(c, 0.0, x); }",
                " return rgb + vec3<f32>(m, m, m);",
                "}",
                "",
            ]

        return helpers

    # ------------------------------------------------------------------
    # Statement-level compilation
    # ------------------------------------------------------------------

    def _compile_expr(self, expr: Any, indent: int = 4) -> str:
        """
        Render an arbitrary sexp expression as WGSL statement text.

        Effect bodies are compiled with ``_compile_body``, which always writes
        ``output[idx]``. This method is no longer used for that purpose and is
        kept for compatibility.
        """
        ind = " " * indent

        # Literals
        if isinstance(expr, (int, float)):
            return f"{ind}// literal: {expr}"
        if isinstance(expr, str):
            return f'{ind}// string: "{expr}"'

        # Symbol reference
        if isinstance(expr, Symbol):
            name = expr.name
            if name == 'frame':
                return f"{ind}let rgb = unpack_rgb(input[idx]);"
            if name in ('t', '_time'):
                self.ctx.uses_time = True
                return f"{ind}let t = params.time;"
            if name in self.ctx.params:
                field_name = self._param_field(name)
                return f"{ind}let {field_name} = params.{field_name};"
            if name in self.ctx.locals:
                return f"{ind}// local: {name}"
            return f"{ind}// unknown symbol: {name}"

        # List (function call or special form)
        if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
            form = expr[0].name
            if form in ('let', 'let*'):
                return self._compile_let(expr, indent)
            if form == 'if':
                return self._compile_if(expr, indent)
            if form == 'or':
                return self._compile_or(expr, indent)
            if ':' in form:
                return self._compile_primitive_call(expr, indent)
            if form in self._ARITHMETIC_OPS:
                return self._compile_arithmetic(expr, indent)
            if form in self._COMPARISON_OPS:
                return self._compile_comparison(expr, indent)
            if form in ('max', 'min'):
                return self._compile_builtin(form, expr[1:], indent)

        return f"{ind}// unhandled: {expr}"

    def _compile_body(self, body: Any, indent: int) -> str:
        """
        Compile an image-producing expression into statements that write ``output[idx]``.

        Handles primitive calls, ``let``/``let*``, ``if`` and ``frame``.
        A missing (None) or unsupported body passes the pixel through, so
        the output buffer is always written.
        """
        ind = " " * indent
        passthrough = f"{ind}output[idx] = input[idx];"

        if body is None:
            return passthrough

        if isinstance(body, list) and body and isinstance(body[0], Symbol):
            form = body[0].name
            if ':' in form:
                return self._compile_primitive_call(body, indent)
            if form in ('let', 'let*'):
                return self._compile_let(body, indent)
            if form == 'if':
                return self._compile_if(body, indent)

        # If body is just 'frame', pass through
        if isinstance(body, Symbol) and body.name == 'frame':
            return passthrough

        return f"{ind}// body: {body}\n{passthrough}"

    def _compile_let(self, expr: list, indent: int) -> str:
        """
        Compile a let/let* form. Bindings are evaluated sequentially.

        Accepts Clojure-style ``[x 1 y 2]`` or Scheme-style ``((x 1) (y 2))``
        bindings. The last body form is compiled with ``_compile_body``.
        Bindings are visible only inside the form; the previous local scope
        is restored afterwards.
        """
        ind = " " * indent
        bindings = expr[1] if len(expr) > 1 else []
        body = expr[-1] if len(expr) > 2 else None

        pairs = []
        if bindings and isinstance(bindings[0], Symbol):
            # Clojure style; a trailing unpaired name is ignored.
            for i in range(0, len(bindings) - 1, 2):
                pairs.append((self._symbol_name(bindings[i]), bindings[i + 1]))
        else:
            # Scheme style
            for binding in bindings:
                if isinstance(binding, (list, tuple)) and len(binding) >= 2:
                    pairs.append((self._symbol_name(binding[0]), binding[1]))

        saved_locals = dict(self.ctx.locals)
        lines = []
        for name, value in pairs:
            # Compile the value before binding, so (let* [a (+ a 1)]) sees the old `a`.
            val_code = self._expr_to_wgsl(value)
            ident = self._fresh_local_ident(name)
            lines.append(f"{ind}let {ident} = {val_code};")
            self.ctx.locals[name] = ident

        lines.append(self._compile_body(body, indent))
        self.ctx.locals = saved_locals
        return "\n".join(lines)

    def _compile_if(self, expr: list, indent: int) -> str:
        """
        Compile ``(if cond then [else])`` into a WGSL if/else statement.

        A non-comparison condition is coerced to bool with "non-zero is
        true". A missing else branch passes the pixel through.
        """
        ind = " " * indent
        cond_expr = expr[1] if len(expr) > 1 else 0
        then_expr = expr[2] if len(expr) > 2 else None
        else_expr = expr[3] if len(expr) > 3 else None

        cond = self._expr_to_wgsl(cond_expr)
        is_comparison = (
            isinstance(cond_expr, list)
            and bool(cond_expr)
            and isinstance(cond_expr[0], Symbol)
            and cond_expr[0].name in self._COMPARISON_OPS
        )
        if not is_comparison:
            # WGSL `if` requires a bool; f32() accepts f32/i32/u32/bool operands.
            cond = f"f32({cond}) != 0.0"

        return "\n".join([
            f"{ind}if ({cond}) {{",
            self._compile_body(then_expr, indent + 4),
            f"{ind}}} else {{",
            self._compile_body(else_expr, indent + 4),
            f"{ind}}}",
        ])

    # ------------------------------------------------------------------
    # Primitives
    # ------------------------------------------------------------------

    def _arg(self, args: list, index: int, default: str) -> str:
        """Compile ``args[index]`` to WGSL, or return ``default`` when absent."""
        return self._expr_to_wgsl(args[index]) if len(args) > index else default

    def _compile_primitive_call(self, expr: list, indent: int) -> str:
        """
        Compile a primitive image call such as ``(color_ops:sepia frame 0.5)``.

        ``args[0]`` is the image argument. It is not compiled: every branch
        reads the current pixel from ``input``. The remaining arguments are
        the primitive's numeric parameters. Every branch writes
        ``output[idx]``; unknown primitives pass the pixel through.
        """
        ind = " " * indent
        prim_name = self._symbol_name(expr[0])
        args = expr[1:]
        # Recorded for helper emission and category detection.
        self._primitives_used.add(prim_name)

        # Per-pixel color operations
        if prim_name == 'color_ops:invert-img':
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let result = vec3<f32>(1.0, 1.0, 1.0) - rgb;
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:grayscale':
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b;
{ind}let result = vec3<f32>(gray, gray, gray);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:adjust-brightness':
            amount = self._arg(args, 1, "0.0")
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let adj = f32({amount}) / 255.0;
{ind}let result = clamp(rgb + vec3<f32>(adj, adj, adj), vec3<f32>(0.0, 0.0, 0.0), vec3<f32>(1.0, 1.0, 1.0));
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:adjust':
            # (adjust img brightness contrast)
            brightness = self._arg(args, 1, "0.0")
            contrast = self._arg(args, 2, "1.0")
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let centered = rgb - vec3<f32>(0.5, 0.5, 0.5);
{ind}let contrasted = centered * {contrast};
{ind}let brightened = contrasted + vec3<f32>(0.5, 0.5, 0.5) + vec3<f32>({brightness}/255.0);
{ind}let result = clamp(brightened, vec3<f32>(0.0), vec3<f32>(1.0));
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:sepia':
            intensity = self._arg(args, 1, "1.0")
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b;
{ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b;
{ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b;
{ind}let sepia = vec3<f32>(sepia_r, sepia_g, sepia_b);
{ind}let result = mix(rgb, sepia, {intensity});
{ind}output[idx] = pack_rgb(clamp(result, vec3<f32>(0.0), vec3<f32>(1.0)));"""

        if prim_name == 'color_ops:shift-hsv':
            h_shift = self._arg(args, 1, "0.0")
            s_mult = self._arg(args, 2, "1.0")
            v_mult = self._arg(args, 3, "1.0")
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}var hsv = rgb_to_hsv(rgb);
{ind}hsv.x = fract(hsv.x + {h_shift} / 360.0);
{ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0);
{ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0);
{ind}let result = hsv_to_rgb(hsv);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:quantize':
            levels = self._arg(args, 1, "8.0")
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let lvl = max(2.0, {levels});
{ind}let result = floor(rgb * lvl) / lvl;
{ind}output[idx] = pack_rgb(result);"""

        # Geometric transforms (inverse-map the destination pixel to a source point)
        if prim_name == 'geometry:scale-img':
            scale_x = self._arg(args, 1, "1.0")
            scale_y = self._arg(args, 2, scale_x)
            self.ctx.uses_sampling = True
            return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let sx = f32(x) - cx;
{ind}let sy = f32(y) - cy;
{ind}let src_x = sx / {scale_x} + cx;
{ind}let src_y = sy / {scale_y} + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'geometry:rotate-img':
            angle = self._arg(args, 1, "0.0")  # degrees
            self.ctx.uses_sampling = True
            return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let angle_rad = {angle} * 3.14159265 / 180.0;
{ind}let cos_a = cos(-angle_rad);
{ind}let sin_a = sin(-angle_rad);
{ind}let dx = f32(x) - cx;
{ind}let dy = f32(y) - cy;
{ind}let src_x = dx * cos_a - dy * sin_a + cx;
{ind}let src_y = dx * sin_a + dy * cos_a + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'geometry:flip-h':
            return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x);
{ind}output[idx] = input[src_idx];"""

        if prim_name == 'geometry:flip-v':
            return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x;
{ind}output[idx] = input[src_idx];"""

        # Image library
        if prim_name == 'image:blur':
            radius = self._arg(args, 1, "5")
            # Box blur over in-bounds pixels (separable would be faster).
            return f"""{ind}let radius = i32({radius});
{ind}var sum = vec3<f32>(0.0, 0.0, 0.0);
{ind}var count = 0.0;
{ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{
{ind} for (var dx = -radius; dx <= radius; dx = dx + 1) {{
{ind} let sx = i32(x) + dx;
{ind} let sy = i32(y) + dy;
{ind} if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{
{ind} let sidx = u32(sy) * params.width + u32(sx);
{ind} sum = sum + unpack_rgb(input[sidx]);
{ind} count = count + 1.0;
{ind} }}
{ind} }}
{ind}}}
{ind}let result = sum / count;
{ind}output[idx] = pack_rgb(result);"""

        # Fallback - passthrough
        return f"""{ind}// Unimplemented primitive: {prim_name}
{ind}output[idx] = input[idx];"""

    # ------------------------------------------------------------------
    # Inline (expression-level) compilation
    # ------------------------------------------------------------------

    def _compile_or(self, expr: list, indent: int) -> str:
        """
        Compile ``(or a b ...)``: the first non-zero operand, else the last one.

        Folded right into nested ``select`` calls. A single operand is
        compared against 0.0, and ``(or)`` yields 0.0.
        """
        operands = [self._expr_to_wgsl(arg) for arg in expr[1:]]
        if not operands:
            return "0.0"
        if len(operands) == 1:
            result, rest = "0.0", operands
        else:
            result, rest = operands[-1], operands[:-1]
        for operand in reversed(rest):
            result = f"select({result}, {operand}, {operand} != 0.0)"
        return result

    def _compile_arithmetic(self, expr: list, indent: int) -> str:
        """
        Compile ``(+|-|*|/ ...)`` to a parenthesized inline WGSL expression.

        Unary forms follow Scheme: ``(- a)`` negates and ``(/ a)`` is ``1/a``.
        With no operands the result is 1.0 for ``*`` and 0.0 otherwise.
        """
        op = expr[0].name
        operands = [self._expr_to_wgsl(arg) for arg in expr[1:]]

        if not operands:
            return "1.0" if op == '*' else "0.0"
        if len(operands) == 1:
            if op == '-':
                return f"(-{operands[0]})"
            if op == '/':
                return f"(1.0 / {operands[0]})"
            return operands[0]
        return "(" + f" {op} ".join(operands) + ")"

    def _compile_comparison(self, expr: list, indent: int) -> str:
        """Compile a binary comparison; sexp ``=`` becomes WGSL ``==``."""
        op = '==' if expr[0].name == '=' else expr[0].name
        lhs = self._expr_to_wgsl(expr[1])
        rhs = self._expr_to_wgsl(expr[2])
        return f"({lhs} {op} {rhs})"

    def _compile_builtin(self, fn: str, args: list, indent: int) -> str:
        """
        Compile a call to a WGSL builtin of the same name.

        ``max``/``min`` are binary in WGSL, so calls with a different arity
        are folded left: ``(max a b c)`` becomes ``max(max(a, b), c)``, and a
        single argument is returned as-is.
        """
        compiled = [self._expr_to_wgsl(arg) for arg in args]
        if fn in self._BINARY_FOLD_BUILTINS and compiled:
            result = compiled[0]
            for operand in compiled[1:]:
                result = f"{fn}({result}, {operand})"
            return result
        return f"{fn}({', '.join(compiled)})"

    def _expr_to_wgsl(self, expr: Any) -> str:
        """
        Convert a value-producing sexp expression to inline WGSL.

        Numbers are emitted as f32 literals, and negative ones are
        parenthesized so that unary minus never produces the `--` token.
        Let-bound names shadow time and params. Unknown forms become a
        ``/* unknown */`` comment.
        """
        if isinstance(expr, (int, float)):
            text = str(float(expr))  # always has a decimal point or exponent
            return f"({text})" if text.startswith('-') else text

        if isinstance(expr, str):
            return f'"{expr}"'

        if isinstance(expr, Symbol):
            name = expr.name
            if name == 'frame':
                return "rgb"  # Assume rgb is already loaded
            if name in self.ctx.locals:
                return self.ctx.locals[name]
            if name in ('t', '_time'):
                self.ctx.uses_time = True
                return "params.time"
            if name == 'pi':
                return "3.14159265"
            if name in self.ctx.params:
                return self._param_ref(name)
            return name

        if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
            form = expr[0].name
            if form in self._ARITHMETIC_OPS:
                return self._compile_arithmetic(expr, 0)
            if form in self._COMPARISON_OPS:
                return self._compile_comparison(expr, 0)
            if form in self._INLINE_BUILTINS:
                return self._compile_builtin(form, expr[1:], 0)
            if form == 'or':
                return self._compile_or(expr, 0)
            # Image dimension queries
            if form == 'image:width':
                return "f32(params.width)"
            if form == 'image:height':
                return "f32(params.height)"

        return f"/* unknown: {expr} */"
|
|
|
|
|
|
def compile_effect(sexp_code: str) -> CompiledEffect:
    """Compile sexp effect source text to WGSL using a one-off compiler instance."""
    return SexpToWGSLCompiler().compile_string(sexp_code)
|
|
|
|
|
|
def compile_effect_file(path: str) -> CompiledEffect:
    """Compile the sexp effect stored at ``path`` using a one-off compiler instance."""
    return SexpToWGSLCompiler().compile_file(path)
|