Add IPFS HLS streaming and GPU optimizations
- Add IPFSHLSOutput class that uploads segments to IPFS as they're created - Update streaming task to use IPFS HLS output for distributed streaming - Add /ipfs-stream endpoint to get IPFS playlist URL - Update /stream endpoint to redirect to IPFS when available - Add GPU persistence mode (STREAMING_GPU_PERSIST=1) to keep frames on GPU - Add hardware video decoding (NVDEC) support for faster video processing - Add GPU-accelerated primitive libraries: blending_gpu, color_ops_gpu, geometry_gpu - Add streaming_gpu module with GPUFrame class for tracking CPU/GPU data location - Add Dockerfile.gpu for building GPU-enabled worker image Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -385,9 +385,9 @@ def _serialize_pretty(expr: List, indent: int) -> str:
|
||||
|
||||
|
||||
def parse_file(path: str) -> Any:
|
||||
"""Parse an S-expression file."""
|
||||
"""Parse an S-expression file (supports multiple top-level expressions)."""
|
||||
with open(path, 'r') as f:
|
||||
return parse(f.read())
|
||||
return parse_all(f.read())
|
||||
|
||||
|
||||
def to_sexp(obj: Any) -> str:
|
||||
|
||||
220
sexp_effects/primitive_libs/blending_gpu.py
Normal file
220
sexp_effects/primitive_libs/blending_gpu.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
GPU-Accelerated Blending Primitives Library
|
||||
|
||||
Uses CuPy for CUDA-accelerated image blending and compositing.
|
||||
Keeps frames on GPU when STREAMING_GPU_PERSIST=1 for maximum performance.
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
try:
|
||||
import cupy as cp
|
||||
GPU_AVAILABLE = True
|
||||
print("[blending_gpu] CuPy GPU acceleration enabled")
|
||||
except ImportError:
|
||||
cp = np
|
||||
GPU_AVAILABLE = False
|
||||
print("[blending_gpu] CuPy not available, using CPU fallback")
|
||||
|
||||
# GPU persistence mode - keep frames on GPU between operations
|
||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "1") == "1"
|
||||
if GPU_AVAILABLE and GPU_PERSIST:
|
||||
print("[blending_gpu] GPU persistence enabled - frames stay on GPU")
|
||||
|
||||
|
||||
def _to_gpu(img):
    """Upload *img* to the GPU when CuPy is present; otherwise return it unchanged."""
    if not GPU_AVAILABLE:
        return img
    if isinstance(img, cp.ndarray):
        return img
    return cp.asarray(img)
|
||||
|
||||
|
||||
def _to_cpu(img):
    """Download *img* back to host memory, unless persistence mode keeps it on GPU."""
    if GPU_PERSIST or not GPU_AVAILABLE:
        return img
    if isinstance(img, cp.ndarray):
        return cp.asnumpy(img)
    return img
|
||||
|
||||
|
||||
def _get_xp(img):
    """Return the array module (cupy or numpy) that owns the given image."""
    return cp if GPU_AVAILABLE and isinstance(img, cp.ndarray) else np
|
||||
|
||||
|
||||
def prim_blend_images(a, b, alpha):
    """Blend two images: a * (1-alpha) + b * alpha, with alpha clamped to [0, 1]."""
    alpha = min(1.0, max(0.0, float(alpha)))

    if not GPU_AVAILABLE:
        mixed = a.astype(float) * (1 - alpha) + b.astype(float) * alpha
        return mixed.astype(np.uint8)

    lhs = _to_gpu(a).astype(cp.float32)
    rhs = _to_gpu(b).astype(cp.float32)
    return _to_cpu((lhs * (1 - alpha) + rhs * alpha).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_blend_mode(a, b, mode):
    """Blend using Photoshop-style blend modes.

    Args:
        a: base layer image (uint8 array).
        b: blend layer image (uint8 array, same shape as *a* — assumed, not checked).
        mode: one of "multiply", "screen", "overlay", "soft-light", "hard-light",
            "color-dodge", "color-burn", "difference", "exclusion", "add",
            "subtract", "darken", "lighten". Any other value behaves as
            "normal" and simply returns *b*.

    Returns:
        uint8 array; a CuPy array when GPU persistence is enabled (see _to_cpu).
    """
    # Normalize both layers to [0, 1] floats; xp is the active array module.
    if GPU_AVAILABLE:
        a_gpu = _to_gpu(a).astype(cp.float32) / 255
        b_gpu = _to_gpu(b).astype(cp.float32) / 255
        xp = cp
    else:
        a_gpu = a.astype(float) / 255
        b_gpu = b.astype(float) / 255
        xp = np

    if mode == "multiply":
        result = a_gpu * b_gpu
    elif mode == "screen":
        result = 1 - (1 - a_gpu) * (1 - b_gpu)
    elif mode == "overlay":
        # Overlay branches on the BASE layer: dark base multiplies, light base screens.
        mask = a_gpu < 0.5
        result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu))
    elif mode == "soft-light":
        mask = b_gpu < 0.5
        result = xp.where(mask,
                          a_gpu - (1 - 2 * b_gpu) * a_gpu * (1 - a_gpu),
                          a_gpu + (2 * b_gpu - 1) * (xp.sqrt(a_gpu) - a_gpu))
    elif mode == "hard-light":
        # Same formula as overlay, but branching on the BLEND layer instead.
        mask = b_gpu < 0.5
        result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu))
    elif mode == "color-dodge":
        # +0.001 guards against division by zero where the blend layer is pure white.
        result = xp.clip(a_gpu / (1 - b_gpu + 0.001), 0, 1)
    elif mode == "color-burn":
        # +0.001 guards against division by zero where the blend layer is pure black.
        result = 1 - xp.clip((1 - a_gpu) / (b_gpu + 0.001), 0, 1)
    elif mode == "difference":
        result = xp.abs(a_gpu - b_gpu)
    elif mode == "exclusion":
        result = a_gpu + b_gpu - 2 * a_gpu * b_gpu
    elif mode == "add":
        result = xp.clip(a_gpu + b_gpu, 0, 1)
    elif mode == "subtract":
        result = xp.clip(a_gpu - b_gpu, 0, 1)
    elif mode == "darken":
        result = xp.minimum(a_gpu, b_gpu)
    elif mode == "lighten":
        result = xp.maximum(a_gpu, b_gpu)
    else:
        # Default to normal (just return b)
        result = b_gpu

    # Re-quantize to 8-bit; _to_cpu is a no-op under GPU persistence.
    result = (result * 255).astype(xp.uint8)
    return _to_cpu(result)
|
||||
|
||||
|
||||
def prim_mask(img, mask_img):
    """Apply grayscale mask to image (white=opaque, black=transparent)."""
    if not GPU_AVAILABLE:
        channel = mask_img[:, :, 0] if len(mask_img.shape) == 3 else mask_img
        weights = (channel.astype(float) / 255)[:, :, np.newaxis]
        return (img.astype(float) * weights).astype(np.uint8)

    img_dev = _to_gpu(img)
    mask_dev = _to_gpu(mask_img)

    # Multi-channel masks contribute only their first channel.
    channel = mask_dev[:, :, 0] if len(mask_dev.shape) == 3 else mask_dev
    weights = (channel.astype(cp.float32) / 255)[:, :, cp.newaxis]
    return _to_cpu((img_dev.astype(cp.float32) * weights).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_alpha_composite(base, overlay, alpha_channel):
    """Composite overlay onto base using alpha channel."""
    if not GPU_AVAILABLE:
        channel = alpha_channel[:, :, 0] if len(alpha_channel.shape) == 3 else alpha_channel
        alpha = (channel.astype(float) / 255)[:, :, np.newaxis]
        out = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha
        return out.astype(np.uint8)

    base_dev = _to_gpu(base)
    over_dev = _to_gpu(overlay)
    alpha_dev = _to_gpu(alpha_channel)

    # Multi-channel alpha inputs contribute only their first channel.
    channel = alpha_dev[:, :, 0] if len(alpha_dev.shape) == 3 else alpha_dev
    alpha = (channel.astype(cp.float32) / 255)[:, :, cp.newaxis]
    out = base_dev.astype(cp.float32) * (1 - alpha) + over_dev.astype(cp.float32) * alpha
    return _to_cpu(out.astype(cp.uint8))
|
||||
|
||||
|
||||
def _clip_overlay_rect(x, y, ow, oh, bw, bh):
    """Clip an overlay of size (ow, oh) placed at (x, y) to base bounds (bw, bh).

    Returns (sx1, sy1, sx2, sy2, dx1, dy1): the source rectangle within the
    overlay and the destination origin within the base. The rectangle is
    empty (sx2 <= sx1 or sy2 <= sy1) when the overlay is fully off-screen.
    """
    sx1 = max(0, -x)
    sy1 = max(0, -y)
    dx1 = max(0, x)
    dy1 = max(0, y)
    sx2 = min(ow, bw - x)
    sy2 = min(oh, bh - y)
    return sx1, sy1, sx2, sy2, dx1, dy1


def prim_overlay(base, overlay, x, y, alpha=1.0):
    """Overlay image at position (x, y) with optional alpha.

    Negative or out-of-range coordinates are handled by clipping the overlay
    against the base bounds; a fully off-screen overlay returns an unmodified
    copy of the base. Runs on GPU when CuPy is available, on CPU otherwise.

    Fix: the GPU and CPU paths previously duplicated the entire clipping
    computation line-for-line; the shared logic now lives in _clip_overlay_rect
    so the two paths cannot drift apart.
    """
    x, y = int(x), int(y)

    if GPU_AVAILABLE:
        canvas = _to_gpu(base).copy()
        patch = _to_gpu(overlay)
        work_dtype = cp.float32
        out_dtype = cp.uint8
    else:
        canvas = base.copy()
        patch = overlay
        work_dtype = float  # CPU path historically blended in float64
        out_dtype = np.uint8

    oh, ow = patch.shape[:2]
    bh, bw = canvas.shape[:2]
    sx1, sy1, sx2, sy2, dx1, dy1 = _clip_overlay_rect(x, y, ow, oh, bw, bh)

    if sx2 > sx1 and sy2 > sy1:
        src = patch[sy1:sy2, sx1:sx2]
        dst = canvas[dy1:dy1 + (sy2 - sy1), dx1:dx1 + (sx2 - sx1)]
        blended = dst.astype(work_dtype) * (1 - alpha) + src.astype(work_dtype) * alpha
        canvas[dy1:dy1 + (sy2 - sy1), dx1:dx1 + (sx2 - sx1)] = blended.astype(out_dtype)

    # Match original behavior: only the GPU path routes through _to_cpu.
    return _to_cpu(canvas) if GPU_AVAILABLE else canvas
|
||||
|
||||
|
||||
# Public primitive table: maps s-expression operator names to the
# implementations defined above.
PRIMITIVES = {
    # Basic blending
    'blend-images': prim_blend_images,
    'blend-mode': prim_blend_mode,

    # Masking
    'mask': prim_mask,
    'alpha-composite': prim_alpha_composite,

    # Overlay
    'overlay': prim_overlay,
}
|
||||
280
sexp_effects/primitive_libs/color_ops_gpu.py
Normal file
280
sexp_effects/primitive_libs/color_ops_gpu.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
GPU-Accelerated Color Operations Library
|
||||
|
||||
Uses CuPy for CUDA-accelerated color transforms.
|
||||
|
||||
Performance Mode:
|
||||
- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations
|
||||
- This dramatically improves performance by avoiding CPU<->GPU transfers
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
try:
|
||||
import cupy as cp
|
||||
GPU_AVAILABLE = True
|
||||
print("[color_ops_gpu] CuPy GPU acceleration enabled")
|
||||
except ImportError:
|
||||
cp = np
|
||||
GPU_AVAILABLE = False
|
||||
print("[color_ops_gpu] CuPy not available, using CPU fallback")
|
||||
|
||||
# GPU persistence mode - keep frames on GPU between operations
|
||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "1") == "1"
|
||||
if GPU_AVAILABLE and GPU_PERSIST:
|
||||
print("[color_ops_gpu] GPU persistence enabled - frames stay on GPU")
|
||||
|
||||
|
||||
def _to_gpu(img):
    """Return *img* as a device array when CuPy support is active."""
    needs_upload = GPU_AVAILABLE and not isinstance(img, cp.ndarray)
    return cp.asarray(img) if needs_upload else img
|
||||
|
||||
|
||||
def _to_cpu(img):
    """Convert back to NumPy unless GPU persistence keeps frames device-side."""
    should_download = (
        not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray)
    )
    return cp.asnumpy(img) if should_download else img
|
||||
|
||||
|
||||
def prim_invert(img):
    """Invert image colors (255 - value, per channel)."""
    if not GPU_AVAILABLE:
        return 255 - img
    return _to_cpu(255 - _to_gpu(img))
|
||||
|
||||
|
||||
def prim_grayscale(img):
    """Convert to 3-channel grayscale using standard luminance weights.

    Non-3D inputs (already single-channel) are returned unchanged.
    """
    if img.ndim != 3:
        return img

    if not GPU_AVAILABLE:
        luma = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]
        luma = np.clip(luma, 0, 255).astype(np.uint8)
        return np.stack([luma, luma, luma], axis=2)

    frame = _to_gpu(img.astype(np.float32))
    luma = 0.299 * frame[:, :, 0] + 0.587 * frame[:, :, 1] + 0.114 * frame[:, :, 2]
    luma = cp.clip(luma, 0, 255).astype(cp.uint8)
    # Replicate the single luminance plane across 3 channels.
    return _to_cpu(cp.stack([luma, luma, luma], axis=2))
|
||||
|
||||
|
||||
def prim_brightness(img, factor=1.0):
    """Adjust brightness by multiplying every channel by *factor*."""
    if not GPU_AVAILABLE:
        return np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    frame = _to_gpu(img.astype(np.float32))
    return _to_cpu(cp.clip(frame * factor, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_contrast(img, factor=1.0):
    """Adjust contrast by scaling channel values around the 128 midpoint."""
    if not GPU_AVAILABLE:
        scaled = (img.astype(np.float32) - 128) * factor + 128
        return np.clip(scaled, 0, 255).astype(np.uint8)
    frame = _to_gpu(img.astype(np.float32))
    return _to_cpu(cp.clip((frame - 128) * factor + 128, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
# CUDA kernel for HSV hue shift.
# Compiled once at import time (only when CuPy is present). One thread per
# pixel; operates IN PLACE on a packed 3-byte-per-pixel uint8 RGB buffer:
# RGB -> HSV, rotate H by `shift` degrees, HSV -> RGB.
if GPU_AVAILABLE:
    _hue_shift_kernel = cp.RawKernel(r'''
    extern "C" __global__
    void hue_shift(unsigned char* img, int width, int height, float shift) {
        int x = blockDim.x * blockIdx.x + threadIdx.x;
        int y = blockDim.y * blockIdx.y + threadIdx.y;

        if (x >= width || y >= height) return;

        int idx = (y * width + x) * 3;

        // Get RGB
        float r = img[idx] / 255.0f;
        float g = img[idx + 1] / 255.0f;
        float b = img[idx + 2] / 255.0f;

        // RGB to HSV
        float max_c = fmaxf(r, fmaxf(g, b));
        float min_c = fminf(r, fminf(g, b));
        float delta = max_c - min_c;

        float h = 0.0f, s = 0.0f, v = max_c;

        if (delta > 0.00001f) {
            s = delta / max_c;

            if (max_c == r) {
                h = 60.0f * fmodf((g - b) / delta, 6.0f);
            } else if (max_c == g) {
                h = 60.0f * ((b - r) / delta + 2.0f);
            } else {
                h = 60.0f * ((r - g) / delta + 4.0f);
            }

            if (h < 0) h += 360.0f;
        }

        // Shift hue
        h = fmodf(h + shift, 360.0f);
        if (h < 0) h += 360.0f;

        // HSV to RGB
        float c = v * s;
        float x_val = c * (1.0f - fabsf(fmodf(h / 60.0f, 2.0f) - 1.0f));
        float m = v - c;

        float r_out, g_out, b_out;
        if (h < 60) {
            r_out = c; g_out = x_val; b_out = 0;
        } else if (h < 120) {
            r_out = x_val; g_out = c; b_out = 0;
        } else if (h < 180) {
            r_out = 0; g_out = c; b_out = x_val;
        } else if (h < 240) {
            r_out = 0; g_out = x_val; b_out = c;
        } else if (h < 300) {
            r_out = x_val; g_out = 0; b_out = c;
        } else {
            r_out = c; g_out = 0; b_out = x_val;
        }

        img[idx] = (unsigned char)fminf(255.0f, (r_out + m) * 255.0f);
        img[idx + 1] = (unsigned char)fminf(255.0f, (g_out + m) * 255.0f);
        img[idx + 2] = (unsigned char)fminf(255.0f, (b_out + m) * 255.0f);
    }
    ''', 'hue_shift')
|
||||
|
||||
|
||||
def prim_hue_shift(img, shift=0.0):
    """Shift hue by *shift* degrees.

    Inputs that are not 3-channel RGB are returned unchanged. The CPU
    fallback uses OpenCV's HSV conversion; the GPU path launches the custom
    CUDA kernel on a copy of the frame.
    """
    if img.ndim != 3 or img.shape[2] != 3:
        return img

    if not GPU_AVAILABLE:
        import cv2
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        # OpenCV 8-bit hue is stored as degrees/2 (range [0, 180)), hence shift/2.
        hsv[:, :, 0] = (hsv[:, :, 0].astype(np.float32) + shift / 2) % 180
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    h, w = img.shape[:2]
    # Copy: the kernel mutates its buffer in place; don't clobber the caller's frame.
    img_gpu = _to_gpu(img.astype(np.uint8)).copy()

    # 16x16 thread blocks, grid rounded up to cover the full image.
    block = (16, 16)
    grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1])

    _hue_shift_kernel(grid, block, (img_gpu, np.int32(w), np.int32(h), np.float32(shift)))

    return _to_cpu(img_gpu)
|
||||
|
||||
|
||||
def prim_saturate(img, factor=1.0):
    """Adjust saturation by *factor* (1.0 leaves the image unchanged)."""
    if img.ndim != 3:
        return img

    if not GPU_AVAILABLE:
        import cv2
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)
        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)

    # GPU path: blend toward / extrapolate away from per-pixel luminance
    # instead of a full HSV round-trip.
    frame = _to_gpu(img.astype(np.float32))
    luma = 0.299 * frame[:, :, 0] + 0.587 * frame[:, :, 1] + 0.114 * frame[:, :, 2]
    luma = luma[:, :, cp.newaxis]

    if factor < 1.0:
        # Desaturate: blend toward gray
        mixed = frame * factor + luma * (1 - factor)
    else:
        # Oversaturate: extrapolate away from gray
        mixed = luma + (frame - luma) * factor

    return _to_cpu(cp.clip(mixed, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_blend(img1, img2, alpha=0.5):
    """Linear blend of two images: img1 * (1-alpha) + img2 * alpha."""
    if not GPU_AVAILABLE:
        mixed = img1.astype(np.float32) * (1 - alpha) + img2.astype(np.float32) * alpha
        return np.clip(mixed, 0, 255).astype(np.uint8)

    lhs = _to_gpu(img1.astype(np.float32))
    rhs = _to_gpu(img2.astype(np.float32))
    mixed = lhs * (1 - alpha) + rhs * alpha
    return _to_cpu(cp.clip(mixed, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_add(img1, img2):
    """Add two images channel-wise, clamped to [0, 255]."""
    if not GPU_AVAILABLE:
        # int16 prevents uint8 wrap-around before the clamp.
        total = img1.astype(np.int16) + img2.astype(np.int16)
        return np.clip(total, 0, 255).astype(np.uint8)

    total = _to_gpu(img1).astype(np.int16) + _to_gpu(img2).astype(np.int16)
    return _to_cpu(cp.clip(total, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_multiply(img1, img2):
    """Multiply blend: (img1 * img2) / 255, clamped to [0, 255]."""
    if not GPU_AVAILABLE:
        prod = (img1.astype(np.float32) * img2.astype(np.float32)) / 255.0
        return np.clip(prod, 0, 255).astype(np.uint8)

    prod = (_to_gpu(img1).astype(np.float32) * _to_gpu(img2).astype(np.float32)) / 255.0
    return _to_cpu(cp.clip(prod, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
def prim_screen(img1, img2):
    """Screen blend mode: 1 - (1 - a) * (1 - b) in normalized [0, 1] space."""
    if not GPU_AVAILABLE:
        a = img1.astype(np.float32) / 255.0
        b = img2.astype(np.float32) / 255.0
        screened = 1.0 - (1.0 - a) * (1.0 - b)
        return np.clip(screened * 255, 0, 255).astype(np.uint8)

    a = _to_gpu(img1).astype(np.float32) / 255.0
    b = _to_gpu(img2).astype(np.float32) / 255.0
    screened = 1.0 - (1.0 - a) * (1.0 - b)
    return _to_cpu(cp.clip(screened * 255, 0, 255).astype(cp.uint8))
|
||||
|
||||
|
||||
# Import CPU primitives as fallbacks
def _get_cpu_primitives():
    """Get all primitives from CPU color_ops module as fallbacks.

    Imported inside the function rather than at module top — presumably to
    avoid an import cycle at load time; TODO confirm against package layout.
    """
    from sexp_effects.primitive_libs import color_ops
    return color_ops.PRIMITIVES
|
||||
|
||||
|
||||
# Export table: start with the full CPU color_ops primitive set, then
# override the operations below with GPU-accelerated implementations.
PRIMITIVES = _get_cpu_primitives().copy()

# Override specific primitives with GPU-accelerated versions
PRIMITIVES.update({
    'invert': prim_invert,
    'grayscale': prim_grayscale,
    'brightness': prim_brightness,
    'contrast': prim_contrast,
    'hue-shift': prim_hue_shift,
    'saturate': prim_saturate,
    'blend': prim_blend,
    'add': prim_add,
    'multiply': prim_multiply,
    'screen': prim_screen,
})
|
||||
409
sexp_effects/primitive_libs/geometry_gpu.py
Normal file
409
sexp_effects/primitive_libs/geometry_gpu.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
GPU-Accelerated Geometry Primitives Library
|
||||
|
||||
Uses CuPy for CUDA-accelerated image transforms.
|
||||
Falls back to CPU if GPU unavailable.
|
||||
|
||||
Performance Mode:
|
||||
- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations
|
||||
- This dramatically improves performance by avoiding CPU<->GPU transfers
|
||||
- Frames only transfer to CPU at final output
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
try:
|
||||
import cupy as cp
|
||||
from cupyx.scipy import ndimage as cpndimage
|
||||
GPU_AVAILABLE = True
|
||||
print("[geometry_gpu] CuPy GPU acceleration enabled")
|
||||
except ImportError:
|
||||
cp = np
|
||||
GPU_AVAILABLE = False
|
||||
print("[geometry_gpu] CuPy not available, using CPU fallback")
|
||||
|
||||
# GPU persistence mode - keep frames on GPU between operations
|
||||
# Set STREAMING_GPU_PERSIST=1 for maximum performance
|
||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "1") == "1"
|
||||
if GPU_AVAILABLE and GPU_PERSIST:
|
||||
print("[geometry_gpu] GPU persistence enabled - frames stay on GPU")
|
||||
|
||||
|
||||
def _to_gpu(img):
    """Upload to device memory when CuPy is available and *img* is host-side."""
    if GPU_AVAILABLE and not isinstance(img, cp.ndarray):
        img = cp.asarray(img)
    return img
|
||||
|
||||
|
||||
def _to_cpu(img):
    """Bring a device array back to the host, unless GPU persistence is on."""
    if GPU_PERSIST:
        return img
    if GPU_AVAILABLE and isinstance(img, cp.ndarray):
        return cp.asnumpy(img)
    return img
|
||||
|
||||
|
||||
def _ensure_output_format(img):
    """Ensure output is in correct format based on GPU_PERSIST setting.

    Currently just an alias for _to_cpu, kept as a separate name to express
    intent at call sites. (Not referenced elsewhere in this view.)
    """
    return _to_cpu(img)
|
||||
|
||||
|
||||
def prim_rotate(img, angle, cx=None, cy=None):
    """Rotate image by *angle* degrees around center (cx, cy).

    NOTE(review): the GPU path delegates to cupyx ndimage.rotate, which
    rotates about the array center — cx/cy are computed here but never
    passed to it (the CPU/OpenCV path does honor them). TODO: confirm
    whether off-center rotation is required on the GPU path.
    """
    if not GPU_AVAILABLE:
        # Fallback to OpenCV
        import cv2
        h, w = img.shape[:2]
        if cx is None:
            cx = w / 2
        if cy is None:
            cy = h / 2
        M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
        return cv2.warpAffine(img, M, (w, h))

    img_gpu = _to_gpu(img)
    h, w = img_gpu.shape[:2]

    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2

    # Use cupyx.scipy.ndimage.rotate
    # Note: scipy uses different angle convention
    rotated = cpndimage.rotate(img_gpu, angle, reshape=False, order=1)

    return _to_cpu(rotated)
|
||||
|
||||
|
||||
def prim_scale(img, sx, sy, cx=None, cy=None):
    """Scale image by (sx, sy) around center (cx, cy), keeping output size.

    NOTE(review): the GPU path zooms about the image center and then
    center-crops/pads back to the original size — cx/cy are computed but
    not used there (the CPU/OpenCV path does honor them). TODO: confirm intent.
    """
    if not GPU_AVAILABLE:
        import cv2
        h, w = img.shape[:2]
        if cx is None:
            cx = w / 2
        if cy is None:
            cy = h / 2
        # Affine matrix that scales about (cx, cy): the translation term
        # keeps the chosen center fixed after scaling.
        M = np.float32([
            [sx, 0, cx * (1 - sx)],
            [0, sy, cy * (1 - sy)]
        ])
        return cv2.warpAffine(img, M, (w, h))

    img_gpu = _to_gpu(img)
    h, w = img_gpu.shape[:2]

    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2

    # Use cupyx.scipy.ndimage.zoom
    if img_gpu.ndim == 3:
        zoom_factors = (sy, sx, 1)  # Don't zoom color channels
    else:
        zoom_factors = (sy, sx)

    zoomed = cpndimage.zoom(img_gpu, zoom_factors, order=1)

    # Crop/pad to original size
    zh, zw = zoomed.shape[:2]
    result = cp.zeros_like(img_gpu)

    # Calculate offsets: center-crop when the zoomed image is larger,
    # center-pad (black borders) when it is smaller.
    src_y = max(0, (zh - h) // 2)
    src_x = max(0, (zw - w) // 2)
    dst_y = max(0, (h - zh) // 2)
    dst_x = max(0, (w - zw) // 2)

    copy_h = min(h - dst_y, zh - src_y)
    copy_w = min(w - dst_x, zw - src_x)

    result[dst_y:dst_y+copy_h, dst_x:dst_x+copy_w] = zoomed[src_y:src_y+copy_h, src_x:src_x+copy_w]

    return _to_cpu(result)
|
||||
|
||||
|
||||
def prim_translate(img, dx, dy):
    """Translate image by (dx, dy) pixels."""
    if not GPU_AVAILABLE:
        import cv2
        rows, cols = img.shape[:2]
        matrix = np.float32([[1, 0, dx], [0, 1, dy]])
        return cv2.warpAffine(img, matrix, (cols, rows))

    frame = _to_gpu(img)
    # cupyx shift takes (row, col[, channel]) offsets; channels stay fixed.
    offsets = (dy, dx, 0) if frame.ndim == 3 else (dy, dx)
    return _to_cpu(cpndimage.shift(frame, offsets, order=1))
|
||||
|
||||
|
||||
def prim_flip_h(img):
    """Mirror the image left-right."""
    if not GPU_AVAILABLE:
        return np.flip(img, axis=1)
    return _to_cpu(cp.flip(_to_gpu(img), axis=1))
|
||||
|
||||
|
||||
def prim_flip_v(img):
    """Mirror the image top-bottom."""
    if not GPU_AVAILABLE:
        return np.flip(img, axis=0)
    return _to_cpu(cp.flip(_to_gpu(img), axis=0))
|
||||
|
||||
|
||||
def prim_flip(img, direction="horizontal"):
|
||||
"""Flip image in given direction."""
|
||||
if direction in ("horizontal", "h"):
|
||||
return prim_flip_h(img)
|
||||
elif direction in ("vertical", "v"):
|
||||
return prim_flip_v(img)
|
||||
elif direction in ("both", "hv", "vh"):
|
||||
if GPU_AVAILABLE:
|
||||
img_gpu = _to_gpu(img)
|
||||
return _to_cpu(cp.flip(cp.flip(img_gpu, axis=0), axis=1))
|
||||
return np.flip(np.flip(img, axis=0), axis=1)
|
||||
return img
|
||||
|
||||
|
||||
# CUDA kernel for ripple effect.
# Compiled once at import time (only when CuPy is present). One thread per
# destination pixel: computes a radial sine displacement with exponential
# falloff from (cx, cy), then samples the source with bilinear interpolation.
if GPU_AVAILABLE:
    _ripple_kernel = cp.RawKernel(r'''
    extern "C" __global__
    void ripple(const unsigned char* src, unsigned char* dst,
                int width, int height, int channels,
                float amplitude, float frequency, float decay,
                float speed, float time, float cx, float cy) {
        int x = blockDim.x * blockIdx.x + threadIdx.x;
        int y = blockDim.y * blockIdx.y + threadIdx.y;

        if (x >= width || y >= height) return;

        // Distance from center
        float dx = x - cx;
        float dy = y - cy;
        float dist = sqrtf(dx * dx + dy * dy);

        // Ripple displacement
        float wave = sinf(dist * frequency * 0.1f - time * speed) * amplitude;
        float falloff = expf(-dist * decay * 0.01f);
        float displacement = wave * falloff;

        // Direction from center
        float len = dist + 0.0001f; // Avoid division by zero
        float dir_x = dx / len;
        float dir_y = dy / len;

        // Source coordinates
        float src_x = x - dir_x * displacement;
        float src_y = y - dir_y * displacement;

        // Clamp to bounds
        src_x = fmaxf(0.0f, fminf(width - 1.0f, src_x));
        src_y = fmaxf(0.0f, fminf(height - 1.0f, src_y));

        // Bilinear interpolation
        int x0 = (int)src_x;
        int y0 = (int)src_y;
        int x1 = min(x0 + 1, width - 1);
        int y1 = min(y0 + 1, height - 1);

        float fx = src_x - x0;
        float fy = src_y - y0;

        for (int c = 0; c < channels; c++) {
            float v00 = src[(y0 * width + x0) * channels + c];
            float v10 = src[(y0 * width + x1) * channels + c];
            float v01 = src[(y1 * width + x0) * channels + c];
            float v11 = src[(y1 * width + x1) * channels + c];

            float v0 = v00 * (1 - fx) + v10 * fx;
            float v1 = v01 * (1 - fx) + v11 * fx;
            float val = v0 * (1 - fy) + v1 * fy;

            dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val));
        }
    }
    ''', 'ripple')
|
||||
|
||||
|
||||
def prim_ripple(img, amplitude=10.0, frequency=8.0, decay=2.0, speed=5.0,
                time=0.0, center_x=None, center_y=None):
    """Apply ripple distortion effect.

    amplitude: max displacement in pixels; frequency: ring density;
    decay: radial falloff rate; speed/time: animate the wave phase;
    center defaults to the image center. Grayscale (2-D) inputs are
    supported and returned as 2-D.
    """
    h, w = img.shape[:2]
    channels = img.shape[2] if img.ndim == 3 else 1

    if center_x is None:
        center_x = w / 2
    if center_y is None:
        center_y = h / 2

    if not GPU_AVAILABLE:
        # CPU fallback using coordinate mapping — same math as the CUDA
        # kernel: sinusoidal wave attenuated by exponential radial falloff.
        import cv2
        y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32)

        dx = x_coords - center_x
        dy = y_coords - center_y
        dist = np.sqrt(dx**2 + dy**2)

        wave = np.sin(dist * frequency * 0.1 - time * speed) * amplitude
        falloff = np.exp(-dist * decay * 0.01)
        displacement = wave * falloff

        length = dist + 0.0001  # avoid division by zero at the center pixel
        dir_x = dx / length
        dir_y = dy / length

        map_x = (x_coords - dir_x * displacement).astype(np.float32)
        map_y = (y_coords - dir_y * displacement).astype(np.float32)

        return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR)

    # GPU implementation
    img_gpu = _to_gpu(img.astype(np.uint8))
    if img_gpu.ndim == 2:
        # The kernel expects an explicit channel axis; add one for grayscale.
        img_gpu = img_gpu[:, :, cp.newaxis]
        channels = 1

    dst = cp.zeros_like(img_gpu)

    # 16x16 thread blocks, grid rounded up to cover the full image.
    block = (16, 16)
    grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1])

    _ripple_kernel(grid, block, (
        img_gpu, dst,
        np.int32(w), np.int32(h), np.int32(channels),
        np.float32(amplitude), np.float32(frequency), np.float32(decay),
        np.float32(speed), np.float32(time),
        np.float32(center_x), np.float32(center_y)
    ))

    result = _to_cpu(dst)
    if channels == 1:
        # Drop the synthetic channel axis added above.
        result = result[:, :, 0]
    return result
|
||||
|
||||
|
||||
# CUDA kernel for fast rotation with bilinear interpolation.
# Compiled once at import time (only when CuPy is present). One thread per
# destination pixel: inverse-rotates its coordinates about (cx, cy) and
# bilinearly samples the source; out-of-bounds samples are written as black.
if GPU_AVAILABLE:
    _rotate_kernel = cp.RawKernel(r'''
    extern "C" __global__
    void rotate_img(const unsigned char* src, unsigned char* dst,
                    int width, int height, int channels,
                    float cos_a, float sin_a, float cx, float cy) {
        int x = blockDim.x * blockIdx.x + threadIdx.x;
        int y = blockDim.y * blockIdx.y + threadIdx.y;

        if (x >= width || y >= height) return;

        // Translate to center, rotate, translate back
        float dx = x - cx;
        float dy = y - cy;

        float src_x = cos_a * dx + sin_a * dy + cx;
        float src_y = -sin_a * dx + cos_a * dy + cy;

        // Check bounds
        if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) {
            for (int c = 0; c < channels; c++) {
                dst[(y * width + x) * channels + c] = 0;
            }
            return;
        }

        // Bilinear interpolation
        int x0 = (int)src_x;
        int y0 = (int)src_y;
        int x1 = x0 + 1;
        int y1 = y0 + 1;

        float fx = src_x - x0;
        float fy = src_y - y0;

        for (int c = 0; c < channels; c++) {
            float v00 = src[(y0 * width + x0) * channels + c];
            float v10 = src[(y0 * width + x1) * channels + c];
            float v01 = src[(y1 * width + x0) * channels + c];
            float v11 = src[(y1 * width + x1) * channels + c];

            float v0 = v00 * (1 - fx) + v10 * fx;
            float v1 = v01 * (1 - fx) + v11 * fx;
            float val = v0 * (1 - fy) + v1 * fy;

            dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val));
        }
    }
    ''', 'rotate_img')
|
||||
|
||||
|
||||
def prim_rotate_gpu(img, angle, cx=None, cy=None):
    """Fast GPU rotation using custom CUDA kernel.

    Unlike the cupyx-based prim_rotate, the kernel honors the (cx, cy)
    rotation center. Pixels mapping outside the source become black.
    Grayscale (2-D) inputs are supported and returned as 2-D.
    """
    if not GPU_AVAILABLE:
        return prim_rotate(img, angle, cx, cy)

    h, w = img.shape[:2]
    channels = img.shape[2] if img.ndim == 3 else 1

    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2

    img_gpu = _to_gpu(img.astype(np.uint8))
    if img_gpu.ndim == 2:
        # The kernel expects an explicit channel axis; add one for grayscale.
        img_gpu = img_gpu[:, :, cp.newaxis]
        channels = 1

    dst = cp.zeros_like(img_gpu)

    # Convert angle to radians; the kernel takes precomputed sin/cos.
    rad = np.radians(angle)
    cos_a = np.cos(rad)
    sin_a = np.sin(rad)

    # 16x16 thread blocks, grid rounded up to cover the full image.
    block = (16, 16)
    grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1])

    _rotate_kernel(grid, block, (
        img_gpu, dst,
        np.int32(w), np.int32(h), np.int32(channels),
        np.float32(cos_a), np.float32(sin_a),
        np.float32(cx), np.float32(cy)
    ))

    result = _to_cpu(dst)
    if channels == 1:
        # Drop the synthetic channel axis added above.
        result = result[:, :, 0]
    return result
|
||||
|
||||
|
||||
# Import CPU primitives as fallbacks for functions we don't GPU-accelerate
def _get_cpu_primitives():
    """Get all primitives from CPU geometry module as fallbacks.

    Imported inside the function rather than at module top — presumably to
    avoid an import cycle at load time; TODO confirm against package layout.
    """
    from sexp_effects.primitive_libs import geometry
    return geometry.PRIMITIVES
|
||||
|
||||
|
||||
# Export table: start with the full CPU geometry primitive set, then
# override the operations below with GPU-accelerated implementations.
PRIMITIVES = _get_cpu_primitives().copy()

# Override specific primitives with GPU-accelerated versions
PRIMITIVES.update({
    'translate': prim_translate,
    # Custom CUDA kernel when available (honors cx/cy); cupyx fallback otherwise.
    'rotate-img': prim_rotate_gpu if GPU_AVAILABLE else prim_rotate,
    'scale-img': prim_scale,
    'flip-h': prim_flip_h,
    'flip-v': prim_flip_v,
    'flip': prim_flip,
    # Note: ripple-displace uses CPU version (different API - returns coords, not image)
})
|
||||
@@ -8,10 +8,16 @@ import cv2
|
||||
|
||||
|
||||
def prim_width(img):
    """Return the pixel width (shape[1]) of an image array.

    Rejects bare Python sequences with a descriptive TypeError so that
    interpreter plumbing errors surface close to their cause.
    """
    if not isinstance(img, (list, tuple)):
        return img.shape[1]
    raise TypeError(f"image:width expects an image array, got {type(img).__name__} with {len(img)} elements")
|
||||
|
||||
|
||||
def prim_height(img):
    """Return the pixel height (shape[0]) of an image array.

    Raises:
        TypeError: if given a list/tuple instead of an image array.
    """
    if isinstance(img, (list, tuple)):
        # Fixed: removed a leftover DEBUG print to stderr, and stopped
        # dumping the entire offending value into the exception message
        # (now consistent with prim_width's message style).
        raise TypeError(f"image:height expects an image array, got {type(img).__name__} with {len(img)} elements")
    return img.shape[0]
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,52 @@ Streaming primitives for video/audio processing.
|
||||
|
||||
These primitives handle video source reading and audio analysis,
|
||||
keeping the interpreter completely generic.
|
||||
|
||||
GPU Acceleration:
|
||||
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
||||
- Hardware video decoding (NVDEC) is used when available
|
||||
- Dramatically improves performance on GPU nodes
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
try:
|
||||
import cupy as cp
|
||||
CUPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
cp = None
|
||||
CUPY_AVAILABLE = False
|
||||
|
||||
# GPU persistence mode - output CuPy arrays instead of numpy
# NOTE(review): the default is "1", so GPU persistence is ON by default
# whenever CuPy imports — the module docstring reads as if it were opt-in
# ("Set STREAMING_GPU_PERSIST=1 to ..."). Confirm the intended default.
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "1") == "1" and CUPY_AVAILABLE
|
||||
|
||||
# Check for hardware decode support (cached)
|
||||
_HWDEC_AVAILABLE = None
|
||||
|
||||
|
||||
def _check_hwdec():
|
||||
"""Check if NVIDIA hardware decode is available."""
|
||||
global _HWDEC_AVAILABLE
|
||||
if _HWDEC_AVAILABLE is not None:
|
||||
return _HWDEC_AVAILABLE
|
||||
|
||||
try:
|
||||
result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
|
||||
if result.returncode != 0:
|
||||
_HWDEC_AVAILABLE = False
|
||||
return False
|
||||
result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
|
||||
_HWDEC_AVAILABLE = "cuda" in result.stdout
|
||||
except Exception:
|
||||
_HWDEC_AVAILABLE = False
|
||||
|
||||
return _HWDEC_AVAILABLE
|
||||
|
||||
|
||||
class VideoSource:
|
||||
"""Video source with persistent streaming pipe for fast sequential reads."""
|
||||
@@ -57,7 +96,10 @@ class VideoSource:
|
||||
print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)
|
||||
|
||||
def _start_stream(self, seek_time: float = 0):
|
||||
"""Start or restart the ffmpeg streaming process."""
|
||||
"""Start or restart the ffmpeg streaming process.
|
||||
|
||||
Uses NVIDIA hardware decoding (NVDEC) when available for better performance.
|
||||
"""
|
||||
if self._proc:
|
||||
self._proc.kill()
|
||||
self._proc = None
|
||||
@@ -67,15 +109,23 @@ class VideoSource:
|
||||
raise FileNotFoundError(f"Video file not found: {self.path}")
|
||||
|
||||
w, h = self._frame_size
|
||||
cmd = [
|
||||
"ffmpeg", "-v", "error", # Show errors instead of quiet
|
||||
|
||||
# Build ffmpeg command with optional hardware decode
|
||||
cmd = ["ffmpeg", "-v", "error"]
|
||||
|
||||
# Use hardware decode if available (significantly faster)
|
||||
if _check_hwdec():
|
||||
cmd.extend(["-hwaccel", "cuda"])
|
||||
|
||||
cmd.extend([
|
||||
"-ss", f"{seek_time:.3f}",
|
||||
"-i", str(self.path),
|
||||
"-f", "rawvideo", "-pix_fmt", "rgb24",
|
||||
"-s", f"{w}x{h}",
|
||||
"-r", str(self.fps), # Output at specified fps
|
||||
"-"
|
||||
]
|
||||
])
|
||||
|
||||
self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
self._stream_time = seek_time
|
||||
|
||||
@@ -88,8 +138,11 @@ class VideoSource:
|
||||
if err:
|
||||
print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr)
|
||||
|
||||
def _read_frame_from_stream(self) -> np.ndarray:
|
||||
"""Read one frame from the stream."""
|
||||
def _read_frame_from_stream(self):
|
||||
"""Read one frame from the stream.
|
||||
|
||||
Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise.
|
||||
"""
|
||||
w, h = self._frame_size
|
||||
frame_size = w * h * 3
|
||||
|
||||
@@ -100,7 +153,12 @@ class VideoSource:
|
||||
if len(data) < frame_size:
|
||||
return None
|
||||
|
||||
return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
||||
frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
||||
|
||||
# Transfer to GPU if persistence mode enabled
|
||||
if GPU_PERSIST:
|
||||
return cp.asarray(frame)
|
||||
return frame
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Read frame (uses last cached or t=0)."""
|
||||
@@ -120,6 +178,9 @@ class VideoSource:
|
||||
seek_time = t
|
||||
if self._duration and self._duration > 0:
|
||||
seek_time = t % self._duration
|
||||
# If we're within 0.1s of the end, wrap to beginning to avoid EOF issues
|
||||
if seek_time > self._duration - 0.1:
|
||||
seek_time = 0.0
|
||||
|
||||
# Decide whether to seek or continue streaming
|
||||
# Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead
|
||||
@@ -138,24 +199,59 @@ class VideoSource:
|
||||
self._start_stream(seek_time)
|
||||
|
||||
# Skip frames to reach target time
|
||||
skip_retries = 0
|
||||
while self._stream_time + self._frame_time <= seek_time:
|
||||
frame = self._read_frame_from_stream()
|
||||
if frame is None:
|
||||
# Stream ended, restart from seek point
|
||||
# Stream ended or failed - restart from seek point
|
||||
import time
|
||||
skip_retries += 1
|
||||
if skip_retries > 3:
|
||||
# Give up skipping, just start fresh at seek_time
|
||||
self._start_stream(seek_time)
|
||||
time.sleep(0.1)
|
||||
break
|
||||
self._start_stream(seek_time)
|
||||
break
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
self._stream_time += self._frame_time
|
||||
skip_retries = 0 # Reset on successful read
|
||||
|
||||
# Read the target frame
|
||||
frame = self._read_frame_from_stream()
|
||||
if frame is None:
|
||||
# Read the target frame with retry logic
|
||||
frame = None
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
frame = self._read_frame_from_stream()
|
||||
if frame is not None:
|
||||
break
|
||||
|
||||
# Stream failed - try restarting
|
||||
import sys
|
||||
import time
|
||||
print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr)
|
||||
|
||||
# Check for ffmpeg errors
|
||||
if self._proc and self._proc.stderr:
|
||||
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
|
||||
if err:
|
||||
raise RuntimeError(f"Failed to read video frame from {self.path.name}: {err}")
|
||||
raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} - file may be corrupted or inaccessible")
|
||||
try:
|
||||
import select
|
||||
readable, _, _ = select.select([self._proc.stderr], [], [], 0.1)
|
||||
if readable:
|
||||
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
|
||||
if err:
|
||||
print(f"ffmpeg error: {err}", file=sys.stderr)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Wait a bit and restart
|
||||
time.sleep(0.1)
|
||||
self._start_stream(seek_time)
|
||||
|
||||
# Give ffmpeg time to start
|
||||
time.sleep(0.1)
|
||||
|
||||
if frame is None:
|
||||
import sys
|
||||
raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")
|
||||
else:
|
||||
self._stream_time += self._frame_time
|
||||
|
||||
|
||||
502
sexp_effects/primitive_libs/streaming_gpu.py
Normal file
502
sexp_effects/primitive_libs/streaming_gpu.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""
|
||||
GPU-Accelerated Streaming Primitives
|
||||
|
||||
Provides GPU-native video source and frame processing.
|
||||
Frames stay on GPU memory throughout the pipeline for maximum performance.
|
||||
|
||||
Architecture:
|
||||
- GPUFrame: Wrapper that tracks whether data is on CPU or GPU
|
||||
- GPUVideoSource: Hardware-accelerated decode to GPU memory
|
||||
- GPU primitives operate directly on GPU frames
|
||||
- Transfer to CPU only at final output
|
||||
|
||||
Requirements:
|
||||
- CuPy for CUDA support
|
||||
- FFmpeg with NVDEC support (for hardware decode)
|
||||
- NVIDIA GPU with CUDA capability
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
# Try to import CuPy
|
||||
try:
|
||||
import cupy as cp
|
||||
GPU_AVAILABLE = True
|
||||
except ImportError:
|
||||
cp = None
|
||||
GPU_AVAILABLE = False
|
||||
|
||||
# Check for hardware decode support
_HWDEC_AVAILABLE: Optional[bool] = None


def check_hwdec_available() -> bool:
    """Return True if NVIDIA hardware decode (NVDEC via CUDA) is usable.

    The probe runs ``nvidia-smi`` (GPU present) and ``ffmpeg -hwaccels``
    ("cuda" listed); the result is cached in the module-level
    ``_HWDEC_AVAILABLE`` flag. Any failure counts as unavailable.
    """
    global _HWDEC_AVAILABLE
    if _HWDEC_AVAILABLE is not None:
        return _HWDEC_AVAILABLE

    available = False
    try:
        smi = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
        if smi.returncode == 0:
            accels = subprocess.run(
                ["ffmpeg", "-hwaccels"],
                capture_output=True,
                text=True,
                timeout=5,
            )
            available = "cuda" in accels.stdout
    except Exception:
        available = False

    _HWDEC_AVAILABLE = available
    return _HWDEC_AVAILABLE
|
||||
|
||||
|
||||
class GPUFrame:
    """
    Frame container that tracks data location (CPU/GPU).

    Enables zero-copy operations when data is already on the right device.
    Lazy transfer - only moves data when actually needed.
    """

    def __init__(self, data: Union[np.ndarray, 'cp.ndarray'], on_gpu: bool = None):
        """Wrap frame pixels.

        Args:
            data: numpy array, CuPy array, or any array-like (list/tuple).
            on_gpu: force placement (True=GPU, False=CPU); None auto-detects
                from the array type. GPU placement silently degrades to CPU
                when CuPy is unavailable.
        """
        # Exactly one of these holds the data initially; the other is
        # filled lazily by the .cpu / .gpu properties.
        self._cpu_data: Optional[np.ndarray] = None
        self._gpu_data = None  # Optional[cp.ndarray]

        if on_gpu is None:
            # Auto-detect based on type
            if GPU_AVAILABLE and isinstance(data, cp.ndarray):
                self._gpu_data = data
            else:
                self._cpu_data = np.asarray(data)
        elif on_gpu and GPU_AVAILABLE:
            self._gpu_data = cp.asarray(data) if not isinstance(data, cp.ndarray) else data
        else:
            # CPU placement (explicitly requested, or GPU unavailable).
            # BUG FIX: the original called cp.asnumpy() for ANY non-numpy
            # input, which crashed with AttributeError when CuPy is not
            # installed (cp is None) and the caller passed e.g. a list.
            if GPU_AVAILABLE and isinstance(data, cp.ndarray):
                self._cpu_data = cp.asnumpy(data)
            else:
                self._cpu_data = np.asarray(data)

    @property
    def cpu(self) -> np.ndarray:
        """Get frame as numpy array (transfers from GPU if needed)."""
        if self._cpu_data is None:
            if self._gpu_data is not None and GPU_AVAILABLE:
                self._cpu_data = cp.asnumpy(self._gpu_data)
            else:
                raise ValueError("No frame data available")
        return self._cpu_data

    @property
    def gpu(self):
        """Get frame as CuPy array (transfers to GPU if needed).

        Raises:
            RuntimeError: if CuPy/CUDA is unavailable.
            ValueError: if the frame holds no data on either device.
        """
        if not GPU_AVAILABLE:
            raise RuntimeError("GPU not available")
        if self._gpu_data is None:
            if self._cpu_data is not None:
                self._gpu_data = cp.asarray(self._cpu_data)
            else:
                raise ValueError("No frame data available")
        return self._gpu_data

    @property
    def is_on_gpu(self) -> bool:
        """Check if data is currently on GPU."""
        return self._gpu_data is not None

    @property
    def shape(self) -> Tuple[int, ...]:
        """Get frame shape (from whichever copy is resident)."""
        if self._gpu_data is not None:
            return self._gpu_data.shape
        return self._cpu_data.shape

    @property
    def dtype(self):
        """Get frame dtype (from whichever copy is resident)."""
        if self._gpu_data is not None:
            return self._gpu_data.dtype
        return self._cpu_data.dtype

    def numpy(self) -> np.ndarray:
        """Alias for cpu property."""
        return self.cpu

    def cupy(self):
        """Alias for gpu property."""
        return self.gpu

    def free_cpu(self):
        """Free CPU memory (keep GPU only). No-op if there is no GPU copy."""
        if self._gpu_data is not None:
            self._cpu_data = None

    def free_gpu(self):
        """Free GPU memory (keep CPU only). No-op if there is no CPU copy."""
        if self._cpu_data is not None:
            self._gpu_data = None
|
||||
|
||||
|
||||
class GPUVideoSource:
    """
    GPU-accelerated video source using hardware decode.

    Uses NVDEC for hardware video decoding when available,
    keeping decoded frames in GPU memory for zero-copy processing.

    Falls back to CPU decode if hardware decode unavailable.
    """

    def __init__(self, path: str, fps: float = 30, prefer_gpu: bool = True):
        """Open a video file for sequential frame reads at ``fps``.

        prefer_gpu is honored only when both CuPy and NVDEC are present.
        """
        self.path = Path(path)
        self.fps = fps
        self.prefer_gpu = prefer_gpu and GPU_AVAILABLE and check_hwdec_available()

        self._frame_size: Optional[Tuple[int, int]] = None
        self._duration: Optional[float] = None
        self._proc = None                      # ffmpeg decode subprocess
        self._stream_time = 0.0                # timestamp the pipe is positioned at
        self._frame_time = 1.0 / fps           # seconds per output frame
        self._last_read_time = -1              # sentinel: no frame read yet
        self._cached_frame: Optional[GPUFrame] = None

        # Get video info
        self._probe_video()

        print(f"[GPUVideoSource] {self.path.name}: {self._frame_size}, "
              f"duration={self._duration:.1f}s, gpu={self.prefer_gpu}", file=sys.stderr)

    def _probe_video(self):
        """Probe video file for metadata (size and duration via ffprobe).

        Falls back to 720x720 / 60s when probing yields nothing.
        """
        cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
               "-show_streams", "-show_format", str(self.path)]
        result = subprocess.run(cmd, capture_output=True, text=True)
        info = json.loads(result.stdout)

        for stream in info.get("streams", []):
            if stream.get("codec_type") == "video":
                self._frame_size = (stream.get("width", 720), stream.get("height", 720))
                if "duration" in stream:
                    self._duration = float(stream["duration"])
                elif "tags" in stream and "DURATION" in stream["tags"]:
                    # Matroska-style HH:MM:SS.fraction duration tag.
                    dur_str = stream["tags"]["DURATION"]
                    parts = dur_str.split(":")
                    if len(parts) == 3:
                        h, m, s = parts
                        self._duration = int(h) * 3600 + int(m) * 60 + float(s)
                # Only the first video stream is used.
                break

        # Container-level duration as a secondary source.
        if self._duration is None and "format" in info:
            if "duration" in info["format"]:
                self._duration = float(info["format"]["duration"])

        if not self._frame_size:
            self._frame_size = (720, 720)
        if not self._duration:
            self._duration = 60.0

    def _start_stream(self, seek_time: float = 0):
        """Start ffmpeg decode process (killing any previous one)."""
        if self._proc:
            self._proc.kill()
            self._proc = None

        if not self.path.exists():
            raise FileNotFoundError(f"Video file not found: {self.path}")

        w, h = self._frame_size

        # Build ffmpeg command
        cmd = ["ffmpeg", "-v", "error"]

        # Hardware decode if available
        if self.prefer_gpu:
            cmd.extend(["-hwaccel", "cuda"])

        # Raw RGB24 frames at the requested size/rate, written to stdout.
        cmd.extend([
            "-ss", f"{seek_time:.3f}",
            "-i", str(self.path),
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-s", f"{w}x{h}",
            "-r", str(self.fps),
            "-"
        ])

        self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        self._stream_time = seek_time

    def _read_frame_raw(self) -> Optional[np.ndarray]:
        """Read one frame from ffmpeg pipe; None on EOF / dead process."""
        w, h = self._frame_size
        frame_size = w * h * 3

        if not self._proc or self._proc.poll() is not None:
            return None

        data = self._proc.stdout.read(frame_size)
        if len(data) < frame_size:
            return None

        # Copy so the buffer outlives the pipe read.
        return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()

    def read_at(self, t: float) -> Optional[GPUFrame]:
        """
        Read frame at specific time.

        Returns GPUFrame with data on GPU if GPU mode enabled.
        Returns the previously cached frame (possibly None) if the
        stream cannot produce a frame.
        """
        # Cache check
        if t == self._last_read_time and self._cached_frame is not None:
            return self._cached_frame

        # Loop time for shorter videos
        seek_time = t
        if self._duration and self._duration > 0:
            seek_time = t % self._duration
            # Avoid EOF flakiness right at the end: wrap to the start.
            if seek_time > self._duration - 0.1:
                seek_time = 0.0

        # Determine if we need to seek: dead/absent process, going
        # backwards, or jumping more than 2s ahead of the stream position.
        need_seek = (
            self._proc is None or
            self._proc.poll() is not None or
            seek_time < self._stream_time - self._frame_time or
            seek_time > self._stream_time + 2.0
        )

        if need_seek:
            self._start_stream(seek_time)

        # Skip frames to reach target
        while self._stream_time + self._frame_time <= seek_time:
            frame = self._read_frame_raw()
            if frame is None:
                self._start_stream(seek_time)
                break
            self._stream_time += self._frame_time

        # Read target frame
        frame_np = self._read_frame_raw()
        if frame_np is None:
            return self._cached_frame

        self._stream_time += self._frame_time
        self._last_read_time = t

        # Create GPUFrame - transfer to GPU if in GPU mode
        self._cached_frame = GPUFrame(frame_np, on_gpu=self.prefer_gpu)

        # Free CPU copy if on GPU (saves memory)
        if self.prefer_gpu and self._cached_frame.is_on_gpu:
            self._cached_frame.free_cpu()

        return self._cached_frame

    def read(self) -> Optional[GPUFrame]:
        """Read current frame (cached frame, else frame at t=0)."""
        if self._cached_frame is not None:
            return self._cached_frame
        return self.read_at(0)

    @property
    def size(self) -> Tuple[int, int]:
        # (width, height) as probed (or the 720x720 fallback).
        return self._frame_size

    @property
    def duration(self) -> float:
        # Seconds, as probed (or the 60s fallback).
        return self._duration

    def close(self):
        """Close the video source (kills the ffmpeg subprocess)."""
        if self._proc:
            self._proc.kill()
            self._proc = None
|
||||
|
||||
|
||||
# GPU-aware primitive functions
|
||||
|
||||
def gpu_blend(frame_a: GPUFrame, frame_b: GPUFrame, alpha: float = 0.5) -> GPUFrame:
    """
    Alpha-blend two frames: ``alpha * a + (1 - alpha) * b``.

    Runs on the GPU when CuPy is available (no CPU transfer);
    otherwise blends with numpy on the CPU.
    """
    if GPU_AVAILABLE:
        xp, on_gpu = cp, True
        a = frame_a.gpu.astype(cp.float32)
        b = frame_b.gpu.astype(cp.float32)
    else:
        xp, on_gpu = np, False
        a = frame_a.cpu.astype(np.float32)
        b = frame_b.cpu.astype(np.float32)

    blended = (a * alpha + b * (1 - alpha)).astype(xp.uint8)
    return GPUFrame(blended, on_gpu=on_gpu)
|
||||
|
||||
|
||||
def gpu_resize(frame: GPUFrame, size: Tuple[int, int]) -> GPUFrame:
    """Resize frame to ``size`` = (width, height).

    GPU frames use cupyx.scipy.ndimage.zoom (bilinear, order=1);
    CPU frames use cv2.resize.

    Args:
        frame: source frame.
        size: target (width, height).

    Returns:
        Resized GPUFrame on the same device as the chosen path.
    """
    if not GPU_AVAILABLE or not frame.is_on_gpu:
        # FIX: import cv2 only on the CPU path — the original imported it
        # unconditionally at the top of the function, making the GPU path
        # fail when OpenCV is not installed even though it never uses it.
        import cv2
        resized = cv2.resize(frame.cpu, size)
        return GPUFrame(resized, on_gpu=False)

    # CuPy doesn't have built-in resize, use scipy zoom
    from cupyx.scipy import ndimage as cpndimage

    gpu_data = frame.gpu
    h, w = gpu_data.shape[:2]
    target_w, target_h = size

    zoom_y = target_h / h
    zoom_x = target_w / w

    if gpu_data.ndim == 3:
        # Keep the channel axis unscaled.
        resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x, 1), order=1)
    else:
        resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x), order=1)

    return GPUFrame(resized, on_gpu=True)
|
||||
|
||||
|
||||
def gpu_rotate(frame: GPUFrame, angle: float) -> GPUFrame:
    """Rotate a frame about its center by ``angle`` degrees.

    Output dimensions match the input (no reshape). GPU frames rotate
    via cupyx.scipy.ndimage; CPU frames via OpenCV warpAffine.
    """
    if GPU_AVAILABLE and frame.is_on_gpu:
        from cupyx.scipy import ndimage as cpndimage
        rotated = cpndimage.rotate(frame.gpu, angle, reshape=False, order=1)
        return GPUFrame(rotated, on_gpu=True)

    import cv2
    src = frame.cpu
    h, w = src.shape[:2]
    matrix = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    return GPUFrame(cv2.warpAffine(src, matrix, (w, h)), on_gpu=False)
|
||||
|
||||
|
||||
def gpu_brightness(frame: GPUFrame, factor: float) -> GPUFrame:
    """Scale pixel intensities by ``factor``, clamped to [0, 255]."""
    on_gpu = GPU_AVAILABLE and frame.is_on_gpu
    xp = cp if on_gpu else np
    data = (frame.gpu if on_gpu else frame.cpu).astype(xp.float32)
    scaled = xp.clip(data * factor, 0, 255).astype(xp.uint8)
    return GPUFrame(scaled, on_gpu=on_gpu)
|
||||
|
||||
|
||||
def gpu_composite(frames: list, weights: list = None) -> GPUFrame:
    """
    Composite multiple frames with weights.

    All frames processed on GPU for efficiency (GPU path is taken when
    CuPy is available and at least one frame is already on the GPU).

    Args:
        frames: non-empty list of GPUFrame.
        weights: per-frame weights; defaults to a uniform average and is
            re-normalized to sum to 1 when the sum is positive.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if not frames:
        raise ValueError("No frames to composite")

    # Single frame: nothing to blend.
    if len(frames) == 1:
        return frames[0]

    if weights is None:
        weights = [1.0 / len(frames)] * len(frames)

    # Normalize weights
    total = sum(weights)
    if total > 0:
        weights = [w / total for w in weights]

    # NOTE(review): zip() silently truncates if len(weights) != len(frames);
    # presumably callers always pass matching lengths — confirm.
    use_gpu = GPU_AVAILABLE and any(f.is_on_gpu for f in frames)

    if use_gpu:
        # All on GPU — accumulate in float32, matching the first frame's shape.
        target_shape = frames[0].gpu.shape
        result = cp.zeros(target_shape, dtype=cp.float32)

        for frame, weight in zip(frames, weights):
            gpu_data = frame.gpu.astype(cp.float32)
            if gpu_data.shape != target_shape:
                # Resize to match
                from cupyx.scipy import ndimage as cpndimage
                h, w = target_shape[:2]
                fh, fw = gpu_data.shape[:2]
                zoom_factors = (h/fh, w/fw, 1) if gpu_data.ndim == 3 else (h/fh, w/fw)
                gpu_data = cpndimage.zoom(gpu_data, zoom_factors, order=1)
            result += gpu_data * weight

        return GPUFrame(cp.clip(result, 0, 255).astype(cp.uint8), on_gpu=True)
    else:
        # All on CPU
        import cv2
        target_shape = frames[0].cpu.shape
        result = np.zeros(target_shape, dtype=np.float32)

        for frame, weight in zip(frames, weights):
            cpu_data = frame.cpu.astype(np.float32)
            if cpu_data.shape != target_shape:
                # cv2.resize takes (width, height).
                cpu_data = cv2.resize(cpu_data, (target_shape[1], target_shape[0]))
            result += cpu_data * weight

        return GPUFrame(np.clip(result, 0, 255).astype(np.uint8), on_gpu=False)
|
||||
|
||||
|
||||
# Primitive registration for streaming interpreter
|
||||
|
||||
def get_primitives():
    """
    Build the GPU-aware primitive table for the sexp interpreter.

    The wrappers accept either raw arrays or GPUFrame instances and
    return numpy arrays, so the interpreter stays device-agnostic.
    """
    def _as_frame(obj):
        # Wrap raw arrays in a GPUFrame; pass existing frames through.
        return obj if isinstance(obj, GPUFrame) else GPUFrame(obj)

    def prim_make_video_source_gpu(path: str, fps: float = 30):
        """Create GPU-accelerated video source."""
        return GPUVideoSource(path, fps, prefer_gpu=True)

    def prim_gpu_blend(a, b, alpha=0.5):
        """Blend two frames (returns numpy for compatibility)."""
        return gpu_blend(_as_frame(a), _as_frame(b), alpha).cpu

    def prim_gpu_rotate(img, angle):
        """Rotate image."""
        return gpu_rotate(_as_frame(img), angle).cpu

    def prim_gpu_brightness(img, factor):
        """Adjust brightness."""
        return gpu_brightness(_as_frame(img), factor).cpu

    return {
        'streaming-gpu:make-video-source': prim_make_video_source_gpu,
        'gpu:blend': prim_gpu_blend,
        'gpu:rotate': prim_gpu_rotate,
        'gpu:brightness': prim_gpu_brightness,
    }
|
||||
|
||||
|
||||
# Export
# Public API: availability flag, frame/source classes, GPU-aware frame ops,
# and the interpreter primitive-table factory.
__all__ = [
    'GPU_AVAILABLE',
    'GPUFrame',
    'GPUVideoSource',
    'gpu_blend',
    'gpu_resize',
    'gpu_rotate',
    'gpu_brightness',
    'gpu_composite',
    'get_primitives',
    'check_hwdec_available',
]
|
||||
715
sexp_effects/wgsl_compiler.py
Normal file
715
sexp_effects/wgsl_compiler.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""
|
||||
S-Expression to WGSL Compiler
|
||||
|
||||
Compiles sexp effect definitions to WGSL compute shaders for GPU execution.
|
||||
The compilation happens at effect upload time (AOT), not at runtime.
|
||||
|
||||
Architecture:
|
||||
- Parse sexp AST
|
||||
- Analyze primitives used
|
||||
- Generate WGSL compute shader
|
||||
|
||||
Shader Categories:
|
||||
1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel)
|
||||
2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample)
|
||||
3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors)
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
import math
|
||||
|
||||
from .parser import parse, parse_all, Symbol, Keyword
|
||||
|
||||
|
||||
@dataclass
class WGSLParam:
    """A shader parameter (uniform).

    One entry per effect parameter; emitted as a field of the generated
    WGSL ``Params`` uniform struct.
    """
    # Field name as it appears in the generated Params struct.
    name: str
    wgsl_type: str  # f32, i32, u32, vec2f, etc.
    # Default value used when the caller supplies none.
    default: Any
|
||||
|
||||
|
||||
@dataclass
class CompiledEffect:
    """Result of compiling an sexp effect to WGSL."""
    # Effect name from the (define-effect name ...) form.
    name: str
    # Generated WGSL compute-shader source text.
    wgsl_code: str
    # Uniform parameters the shader exposes.
    params: List[WGSLParam]
    workgroup_size: Tuple[int, int, int] = (16, 16, 1)
    # Metadata for runtime
    uses_time: bool = False
    uses_sampling: bool = False  # Needs texture sampler
    category: str = "per_pixel"  # per_pixel, geometric, neighborhood
|
||||
|
||||
|
||||
@dataclass
class CompilerContext:
    """Context during compilation (one instance per compile() call)."""
    effect_name: str = ""
    # Declared uniform parameters, keyed by name (insertion-ordered).
    params: Dict[str, WGSLParam] = field(default_factory=dict)
    locals: Dict[str, str] = field(default_factory=dict)  # local var -> wgsl expr
    # Primitive libraries named in (require-primitives ...).
    required_libs: Set[str] = field(default_factory=set)
    uses_time: bool = False
    uses_sampling: bool = False
    # Counter backing fresh_temp(); names are _t1, _t2, ...
    temp_counter: int = 0

    def fresh_temp(self) -> str:
        """Generate a fresh temporary variable name."""
        self.temp_counter += 1
        return f"_t{self.temp_counter}"
|
||||
|
||||
|
||||
class SexpToWGSLCompiler:
    """
    Compiles S-expression effect definitions to WGSL compute shaders.

    Usage: compile_file() / compile_string() / compile(); each call builds
    a fresh CompilerContext and returns a CompiledEffect.
    """

    # Map sexp types to WGSL types
    TYPE_MAP = {
        'int': 'i32',
        'float': 'f32',
        'bool': 'u32',  # WGSL doesn't have bool in storage
        'string': None,  # Strings handled specially
    }

    # Per-pixel primitives that can be compiled directly
    # (1 thread per pixel, no neighbor sampling needed).
    PER_PIXEL_PRIMITIVES = {
        'color_ops:invert-img',
        'color_ops:grayscale',
        'color_ops:sepia',
        'color_ops:adjust',
        'color_ops:adjust-brightness',
        'color_ops:shift-hsv',
        'color_ops:quantize',
    }

    # Geometric primitives (coordinate remapping)
    GEOMETRIC_PRIMITIVES = {
        'geometry:scale-img',
        'geometry:rotate-img',
        'geometry:translate',
        'geometry:flip-h',
        'geometry:flip-v',
        'geometry:remap',
    }
|
||||
|
||||
    def __init__(self):
        # Per-compilation state; a fresh CompilerContext is created by compile().
        self.ctx: Optional[CompilerContext] = None
|
||||
|
||||
def compile_file(self, path: str) -> CompiledEffect:
|
||||
"""Compile an effect from a .sexp file."""
|
||||
with open(path, 'r') as f:
|
||||
content = f.read()
|
||||
exprs = parse_all(content)
|
||||
return self.compile(exprs)
|
||||
|
||||
def compile_string(self, sexp_code: str) -> CompiledEffect:
|
||||
"""Compile an effect from an sexp string."""
|
||||
exprs = parse_all(sexp_code)
|
||||
return self.compile(exprs)
|
||||
|
||||
    def compile(self, expr: Any) -> CompiledEffect:
        """Compile a parsed sexp expression.

        Accepts either a single top-level form or a list of forms
        (e.g. require-primitives followed by define-effect).
        """
        self.ctx = CompilerContext()

        # Handle multiple top-level expressions (require-primitives, define-effect).
        # Heuristic: a list whose first element is itself a list is treated as a
        # sequence of forms rather than one form.
        if isinstance(expr, list) and expr and isinstance(expr[0], list):
            for e in expr:
                self._process_toplevel(e)
        else:
            self._process_toplevel(expr)

        # Generate the WGSL shader
        wgsl = self._generate_wgsl()

        # Determine category based on primitives used
        category = self._determine_category()

        return CompiledEffect(
            name=self.ctx.effect_name,
            wgsl_code=wgsl,
            params=list(self.ctx.params.values()),
            uses_time=self.ctx.uses_time,
            uses_sampling=self.ctx.uses_sampling,
            category=category,
        )
|
||||
|
||||
def _process_toplevel(self, expr: Any):
|
||||
"""Process a top-level expression."""
|
||||
if not isinstance(expr, list) or not expr:
|
||||
return
|
||||
|
||||
head = expr[0]
|
||||
if isinstance(head, Symbol):
|
||||
if head.name == 'require-primitives':
|
||||
# Track required primitive libraries
|
||||
for lib in expr[1:]:
|
||||
lib_name = lib.name if isinstance(lib, Symbol) else str(lib)
|
||||
self.ctx.required_libs.add(lib_name)
|
||||
|
||||
elif head.name == 'define-effect':
|
||||
self._compile_effect_def(expr)
|
||||
|
||||
    def _compile_effect_def(self, expr: list):
        """Compile a define-effect form.

        Records the effect name, parses the :params block, and stashes the
        body expression on the context for _generate_wgsl.
        """
        # (define-effect name :params (...) body)
        self.ctx.effect_name = expr[1].name if isinstance(expr[1], Symbol) else str(expr[1])

        # Parse :params and body
        i = 2
        body = None
        while i < len(expr):
            item = expr[i]
            if isinstance(item, Keyword) and item.name == 'params':
                self._parse_params(expr[i + 1])
                i += 2
            elif isinstance(item, Keyword):
                i += 2  # Skip other keywords
            else:
                # Last non-keyword item wins as the body.
                body = item
                i += 1

        # NOTE(review): ctx.body_expr stays unset when no body is found,
        # but _generate_wgsl reads it unconditionally — confirm effect
        # definitions always include a body.
        if body:
            self.ctx.body_expr = body
|
||||
|
||||
    def _parse_params(self, params_list: list):
        """Parse the :params block.

        Each entry looks like (name :type float :default 0). Parameters
        whose type maps to None in TYPE_MAP (e.g. 'string') are skipped.
        """
        for param_def in params_list:
            if not isinstance(param_def, list):
                continue

            name = param_def[0].name if isinstance(param_def[0], Symbol) else str(param_def[0])

            # Parse keyword args
            param_type = 'float'
            default = 0

            i = 1
            while i < len(param_def):
                item = param_def[i]
                if isinstance(item, Keyword):
                    # Keywords advance by 2 even when the value is missing
                    # (a trailing keyword is simply ignored).
                    if i + 1 < len(param_def):
                        val = param_def[i + 1]
                        if item.name == 'type':
                            param_type = val.name if isinstance(val, Symbol) else str(val)
                        elif item.name == 'default':
                            default = val
                    i += 2
                else:
                    i += 1

            # Unknown type names fall back to f32; None means "skip".
            wgsl_type = self.TYPE_MAP.get(param_type, 'f32')
            if wgsl_type:
                self.ctx.params[name] = WGSLParam(name, wgsl_type, default)
|
||||
|
||||
def _determine_category(self) -> str:
|
||||
"""Determine shader category based on primitives used."""
|
||||
for lib in self.ctx.required_libs:
|
||||
if lib == 'geometry':
|
||||
return 'geometric'
|
||||
if lib == 'filters':
|
||||
return 'neighborhood'
|
||||
return 'per_pixel'
|
||||
|
||||
def _generate_wgsl(self) -> str:
|
||||
"""Generate the complete WGSL shader code."""
|
||||
lines = []
|
||||
|
||||
# Header comment
|
||||
lines.append(f"// WGSL Shader: {self.ctx.effect_name}")
|
||||
lines.append(f"// Auto-generated from sexp effect definition")
|
||||
lines.append("")
|
||||
|
||||
# Bindings
|
||||
lines.append("@group(0) @binding(0) var<storage, read> input: array<u32>;")
|
||||
lines.append("@group(0) @binding(1) var<storage, read_write> output: array<u32>;")
|
||||
lines.append("")
|
||||
|
||||
# Params struct
|
||||
if self.ctx.params:
|
||||
lines.append("struct Params {")
|
||||
lines.append(" width: u32,")
|
||||
lines.append(" height: u32,")
|
||||
lines.append(" time: f32,")
|
||||
for param in self.ctx.params.values():
|
||||
lines.append(f" {param.name}: {param.wgsl_type},")
|
||||
lines.append("}")
|
||||
lines.append("@group(0) @binding(2) var<uniform> params: Params;")
|
||||
else:
|
||||
lines.append("struct Params {")
|
||||
lines.append(" width: u32,")
|
||||
lines.append(" height: u32,")
|
||||
lines.append(" time: f32,")
|
||||
lines.append("}")
|
||||
lines.append("@group(0) @binding(2) var<uniform> params: Params;")
|
||||
lines.append("")
|
||||
|
||||
# Helper functions
|
||||
lines.extend(self._generate_helpers())
|
||||
lines.append("")
|
||||
|
||||
# Main compute shader
|
||||
lines.append("@compute @workgroup_size(16, 16, 1)")
|
||||
lines.append("fn main(@builtin(global_invocation_id) gid: vec3<u32>) {")
|
||||
lines.append(" let x = gid.x;")
|
||||
lines.append(" let y = gid.y;")
|
||||
lines.append(" if (x >= params.width || y >= params.height) { return; }")
|
||||
lines.append(" let idx = y * params.width + x;")
|
||||
lines.append("")
|
||||
|
||||
# Compile the effect body
|
||||
body_code = self._compile_expr(self.ctx.body_expr)
|
||||
lines.append(f" // Effect: {self.ctx.effect_name}")
|
||||
lines.append(body_code)
|
||||
lines.append("}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
    def _generate_helpers(self) -> List[str]:
        """Generate WGSL helper functions.

        Always emits the packed-u32 <-> vec3<f32> pixel conversions.  The
        bilinear sampler is emitted only when sampling or the geometry lib
        is in use, and the HSV conversions only when a color lib is used.

        Returns:
            The helper source as a list of WGSL lines.
        """
        helpers = []

        # Pack/unpack RGB from u32 (0xRRGGBB channel layout; no alpha).
        helpers.append("fn unpack_rgb(packed: u32) -> vec3<f32> {")
        helpers.append(" let r = f32((packed >> 16u) & 0xFFu) / 255.0;")
        helpers.append(" let g = f32((packed >> 8u) & 0xFFu) / 255.0;")
        helpers.append(" let b = f32(packed & 0xFFu) / 255.0;")
        helpers.append(" return vec3<f32>(r, g, b);")
        helpers.append("}")
        helpers.append("")

        # Inverse conversion: clamps each channel before quantizing to 8 bits.
        helpers.append("fn pack_rgb(rgb: vec3<f32>) -> u32 {")
        helpers.append(" let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0);")
        helpers.append(" let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0);")
        helpers.append(" let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0);")
        helpers.append(" return (r << 16u) | (g << 8u) | b;")
        helpers.append("}")
        helpers.append("")

        # Bilinear sampling for geometric transforms.  Coordinates are
        # clamped to dim - 1.001 so that x0+1 / y0+1 stay in bounds.
        if self.ctx.uses_sampling or 'geometry' in self.ctx.required_libs:
            helpers.append("fn sample_bilinear(sx: f32, sy: f32) -> vec3<f32> {")
            helpers.append(" let w = f32(params.width);")
            helpers.append(" let h = f32(params.height);")
            helpers.append(" let cx = clamp(sx, 0.0, w - 1.001);")
            helpers.append(" let cy = clamp(sy, 0.0, h - 1.001);")
            helpers.append(" let x0 = u32(cx);")
            helpers.append(" let y0 = u32(cy);")
            helpers.append(" let x1 = min(x0 + 1u, params.width - 1u);")
            helpers.append(" let y1 = min(y0 + 1u, params.height - 1u);")
            helpers.append(" let fx = cx - f32(x0);")
            helpers.append(" let fy = cy - f32(y0);")
            helpers.append(" let c00 = unpack_rgb(input[y0 * params.width + x0]);")
            helpers.append(" let c10 = unpack_rgb(input[y0 * params.width + x1]);")
            helpers.append(" let c01 = unpack_rgb(input[y1 * params.width + x0]);")
            helpers.append(" let c11 = unpack_rgb(input[y1 * params.width + x1]);")
            helpers.append(" let top = mix(c00, c10, fx);")
            helpers.append(" let bot = mix(c01, c11, fx);")
            helpers.append(" return mix(top, bot, fy);")
            helpers.append("}")
            helpers.append("")

        # HSV conversion for color effects.  h is normalised to [0, 1).
        if 'color_ops' in self.ctx.required_libs or 'color' in self.ctx.required_libs:
            helpers.append("fn rgb_to_hsv(rgb: vec3<f32>) -> vec3<f32> {")
            helpers.append(" let mx = max(max(rgb.r, rgb.g), rgb.b);")
            helpers.append(" let mn = min(min(rgb.r, rgb.g), rgb.b);")
            helpers.append(" let d = mx - mn;")
            helpers.append(" var h = 0.0;")
            helpers.append(" if (d > 0.0) {")
            helpers.append(" if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }")
            helpers.append(" else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }")
            helpers.append(" else { h = 4.0 + (rgb.r - rgb.g) / d; }")
            helpers.append(" h = h / 6.0;")
            helpers.append(" if (h < 0.0) { h = h + 1.0; }")
            helpers.append(" }")
            helpers.append(" let s = select(0.0, d / mx, mx > 0.0);")
            helpers.append(" return vec3<f32>(h, s, mx);")
            helpers.append("}")
            helpers.append("")

            # Standard sector-based HSV -> RGB reconstruction.
            helpers.append("fn hsv_to_rgb(hsv: vec3<f32>) -> vec3<f32> {")
            helpers.append(" let h = hsv.x * 6.0;")
            helpers.append(" let s = hsv.y;")
            helpers.append(" let v = hsv.z;")
            helpers.append(" let c = v * s;")
            helpers.append(" let x = c * (1.0 - abs(h % 2.0 - 1.0));")
            helpers.append(" let m = v - c;")
            helpers.append(" var rgb: vec3<f32>;")
            helpers.append(" if (h < 1.0) { rgb = vec3<f32>(c, x, 0.0); }")
            helpers.append(" else if (h < 2.0) { rgb = vec3<f32>(x, c, 0.0); }")
            helpers.append(" else if (h < 3.0) { rgb = vec3<f32>(0.0, c, x); }")
            helpers.append(" else if (h < 4.0) { rgb = vec3<f32>(0.0, x, c); }")
            helpers.append(" else if (h < 5.0) { rgb = vec3<f32>(x, 0.0, c); }")
            helpers.append(" else { rgb = vec3<f32>(c, 0.0, x); }")
            helpers.append(" return rgb + vec3<f32>(m, m, m);")
            helpers.append("}")
            helpers.append("")

        return helpers
def _compile_expr(self, expr: Any, indent: int = 4) -> str:
|
||||
"""Compile an sexp expression to WGSL code."""
|
||||
ind = " " * indent
|
||||
|
||||
# Literals
|
||||
if isinstance(expr, (int, float)):
|
||||
return f"{ind}// literal: {expr}"
|
||||
|
||||
if isinstance(expr, str):
|
||||
return f'{ind}// string: "{expr}"'
|
||||
|
||||
# Symbol reference
|
||||
if isinstance(expr, Symbol):
|
||||
name = expr.name
|
||||
if name == 'frame':
|
||||
return f"{ind}let rgb = unpack_rgb(input[idx]);"
|
||||
if name == 't' or name == '_time':
|
||||
self.ctx.uses_time = True
|
||||
return f"{ind}let t = params.time;"
|
||||
if name in self.ctx.params:
|
||||
return f"{ind}let {name} = params.{name};"
|
||||
if name in self.ctx.locals:
|
||||
return f"{ind}// local: {name}"
|
||||
return f"{ind}// unknown symbol: {name}"
|
||||
|
||||
# List (function call or special form)
|
||||
if isinstance(expr, list) and expr:
|
||||
head = expr[0]
|
||||
|
||||
if isinstance(head, Symbol):
|
||||
form = head.name
|
||||
|
||||
# Special forms
|
||||
if form == 'let' or form == 'let*':
|
||||
return self._compile_let(expr, indent)
|
||||
|
||||
if form == 'if':
|
||||
return self._compile_if(expr, indent)
|
||||
|
||||
if form == 'or':
|
||||
# (or a b) - return a if truthy, else b
|
||||
return self._compile_or(expr, indent)
|
||||
|
||||
# Primitive calls
|
||||
if ':' in form:
|
||||
return self._compile_primitive_call(expr, indent)
|
||||
|
||||
# Arithmetic
|
||||
if form in ('+', '-', '*', '/'):
|
||||
return self._compile_arithmetic(expr, indent)
|
||||
|
||||
if form in ('>', '<', '>=', '<=', '='):
|
||||
return self._compile_comparison(expr, indent)
|
||||
|
||||
if form == 'max':
|
||||
return self._compile_builtin('max', expr[1:], indent)
|
||||
|
||||
if form == 'min':
|
||||
return self._compile_builtin('min', expr[1:], indent)
|
||||
|
||||
return f"{ind}// unhandled: {expr}"
|
||||
|
||||
def _compile_let(self, expr: list, indent: int) -> str:
|
||||
"""Compile let/let* binding form."""
|
||||
ind = " " * indent
|
||||
lines = []
|
||||
|
||||
bindings = expr[1]
|
||||
body = expr[2]
|
||||
|
||||
# Parse bindings (Clojure style: [x 1 y 2] or Scheme style: ((x 1) (y 2)))
|
||||
pairs = []
|
||||
if bindings and isinstance(bindings[0], Symbol):
|
||||
# Clojure style
|
||||
i = 0
|
||||
while i < len(bindings) - 1:
|
||||
name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i])
|
||||
value = bindings[i + 1]
|
||||
pairs.append((name, value))
|
||||
i += 2
|
||||
else:
|
||||
# Scheme style
|
||||
for binding in bindings:
|
||||
name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0])
|
||||
value = binding[1]
|
||||
pairs.append((name, value))
|
||||
|
||||
# Compile bindings
|
||||
for name, value in pairs:
|
||||
val_code = self._expr_to_wgsl(value)
|
||||
lines.append(f"{ind}let {name} = {val_code};")
|
||||
self.ctx.locals[name] = val_code
|
||||
|
||||
# Compile body
|
||||
body_lines = self._compile_body(body, indent)
|
||||
lines.append(body_lines)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _compile_body(self, body: Any, indent: int) -> str:
|
||||
"""Compile the body of an effect (the final image expression)."""
|
||||
ind = " " * indent
|
||||
|
||||
# Most effects end with a primitive call that produces the output
|
||||
if isinstance(body, list) and body:
|
||||
head = body[0]
|
||||
if isinstance(head, Symbol) and ':' in head.name:
|
||||
return self._compile_primitive_call(body, indent)
|
||||
|
||||
# If body is just 'frame', pass through
|
||||
if isinstance(body, Symbol) and body.name == 'frame':
|
||||
return f"{ind}output[idx] = input[idx];"
|
||||
|
||||
return f"{ind}// body: {body}"
|
||||
|
||||
    def _compile_primitive_call(self, expr: list, indent: int) -> str:
        """Compile a namespaced primitive call (e.g. color_ops:grayscale).

        Each supported primitive expands to a fixed WGSL template that reads
        input[idx] and writes output[idx].  args[0] (the image expression) is
        never consulted - the input buffer is implicit; remaining args become
        inline WGSL expressions via _expr_to_wgsl.  Unknown primitives fall
        back to a passthrough copy.
        """
        ind = " " * indent
        head = expr[0]
        prim_name = head.name if isinstance(head, Symbol) else str(head)
        args = expr[1:]

        # Per-pixel color operations
        if prim_name == 'color_ops:invert-img':
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let result = vec3<f32>(1.0, 1.0, 1.0) - rgb;
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:grayscale':
            # Rec.601 luma weights.
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b;
{ind}let result = vec3<f32>(gray, gray, gray);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:adjust-brightness':
            # Brightness argument is in 0-255 units; normalised to [0,1] here.
            amount = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let adj = f32({amount}) / 255.0;
{ind}let result = clamp(rgb + vec3<f32>(adj, adj, adj), vec3<f32>(0.0, 0.0, 0.0), vec3<f32>(1.0, 1.0, 1.0));
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:adjust':
            # (adjust img brightness contrast)
            brightness = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
            contrast = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0"
            # Contrast pivots around mid-gray (0.5) before brightness shift.
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let centered = rgb - vec3<f32>(0.5, 0.5, 0.5);
{ind}let contrasted = centered * {contrast};
{ind}let brightened = contrasted + vec3<f32>(0.5, 0.5, 0.5) + vec3<f32>({brightness}/255.0);
{ind}let result = clamp(brightened, vec3<f32>(0.0), vec3<f32>(1.0));
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:sepia':
            # Intensity blends between the original and the sepia matrix.
            intensity = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0"
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b;
{ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b;
{ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b;
{ind}let sepia = vec3<f32>(sepia_r, sepia_g, sepia_b);
{ind}let result = mix(rgb, sepia, {intensity});
{ind}output[idx] = pack_rgb(clamp(result, vec3<f32>(0.0), vec3<f32>(1.0)));"""

        if prim_name == 'color_ops:shift-hsv':
            # Hue shift is given in degrees; saturation/value are multipliers.
            h_shift = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
            s_mult = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0"
            v_mult = self._expr_to_wgsl(args[3]) if len(args) > 3 else "1.0"
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}var hsv = rgb_to_hsv(rgb);
{ind}hsv.x = fract(hsv.x + {h_shift} / 360.0);
{ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0);
{ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0);
{ind}let result = hsv_to_rgb(hsv);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'color_ops:quantize':
            # At least two levels, or every pixel collapses to one value.
            levels = self._expr_to_wgsl(args[1]) if len(args) > 1 else "8.0"
            return f"""{ind}let rgb = unpack_rgb(input[idx]);
{ind}let lvl = max(2.0, {levels});
{ind}let result = floor(rgb * lvl) / lvl;
{ind}output[idx] = pack_rgb(result);"""

        # Geometric transforms
        if prim_name == 'geometry:scale-img':
            # NOTE(review): the generated WGSL divides by the scale
            # expression, so a zero scale produces inf/NaN sample coords,
            # and the local lets named sx/sy would shadow a user param of
            # the same name inside the interpolated expression - confirm
            # upstream validation prevents both.
            sx = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0"
            sy = self._expr_to_wgsl(args[2]) if len(args) > 2 else sx
            self.ctx.uses_sampling = True
            return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let sx = f32(x) - cx;
{ind}let sy = f32(y) - cy;
{ind}let src_x = sx / {sx} + cx;
{ind}let src_y = sy / {sy} + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'geometry:rotate-img':
            # Inverse rotation (negated angle) about the image center,
            # sampling the source position for each destination pixel.
            angle = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0"
            self.ctx.uses_sampling = True
            return f"""{ind}let w = f32(params.width);
{ind}let h = f32(params.height);
{ind}let cx = w / 2.0;
{ind}let cy = h / 2.0;
{ind}let angle_rad = {angle} * 3.14159265 / 180.0;
{ind}let cos_a = cos(-angle_rad);
{ind}let sin_a = sin(-angle_rad);
{ind}let dx = f32(x) - cx;
{ind}let dy = f32(y) - cy;
{ind}let src_x = dx * cos_a - dy * sin_a + cx;
{ind}let src_y = dx * sin_a + dy * cos_a + cy;
{ind}let result = sample_bilinear(src_x, src_y);
{ind}output[idx] = pack_rgb(result);"""

        if prim_name == 'geometry:flip-h':
            # Mirror across the vertical axis: read from the opposite column.
            return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x);
{ind}output[idx] = input[src_idx];"""

        if prim_name == 'geometry:flip-v':
            # Mirror across the horizontal axis: read from the opposite row.
            return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x;
{ind}output[idx] = input[src_idx];"""

        # Image library
        if prim_name == 'image:blur':
            radius = self._expr_to_wgsl(args[1]) if len(args) > 1 else "5"
            # Box blur approximation (separable would be better)
            return f"""{ind}let radius = i32({radius});
{ind}var sum = vec3<f32>(0.0, 0.0, 0.0);
{ind}var count = 0.0;
{ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{
{ind}    for (var dx = -radius; dx <= radius; dx = dx + 1) {{
{ind}        let sx = i32(x) + dx;
{ind}        let sy = i32(y) + dy;
{ind}        if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{
{ind}            let sidx = u32(sy) * params.width + u32(sx);
{ind}            sum = sum + unpack_rgb(input[sidx]);
{ind}            count = count + 1.0;
{ind}        }}
{ind}    }}
{ind}}}
{ind}let result = sum / count;
{ind}output[idx] = pack_rgb(result);"""

        # Fallback - passthrough
        return f"""{ind}// Unimplemented primitive: {prim_name}
{ind}output[idx] = input[idx];"""
def _compile_if(self, expr: list, indent: int) -> str:
|
||||
"""Compile if expression."""
|
||||
ind = " " * indent
|
||||
cond = self._expr_to_wgsl(expr[1])
|
||||
then_expr = expr[2]
|
||||
else_expr = expr[3] if len(expr) > 3 else None
|
||||
|
||||
lines = []
|
||||
lines.append(f"{ind}if ({cond}) {{")
|
||||
lines.append(self._compile_body(then_expr, indent + 4))
|
||||
if else_expr:
|
||||
lines.append(f"{ind}}} else {{")
|
||||
lines.append(self._compile_body(else_expr, indent + 4))
|
||||
lines.append(f"{ind}}}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _compile_or(self, expr: list, indent: int) -> str:
|
||||
"""Compile or expression - returns first truthy value."""
|
||||
# For numeric context, (or a b) means "a if a != 0 else b"
|
||||
a = self._expr_to_wgsl(expr[1])
|
||||
b = self._expr_to_wgsl(expr[2]) if len(expr) > 2 else "0.0"
|
||||
return f"select({b}, {a}, {a} != 0.0)"
|
||||
|
||||
def _compile_arithmetic(self, expr: list, indent: int) -> str:
|
||||
"""Compile arithmetic expression to inline WGSL."""
|
||||
op = expr[0].name
|
||||
operands = [self._expr_to_wgsl(arg) for arg in expr[1:]]
|
||||
|
||||
if len(operands) == 1:
|
||||
if op == '-':
|
||||
return f"(-{operands[0]})"
|
||||
return operands[0]
|
||||
|
||||
return f"({f' {op} '.join(operands)})"
|
||||
|
||||
def _compile_comparison(self, expr: list, indent: int) -> str:
|
||||
"""Compile comparison expression."""
|
||||
op = expr[0].name
|
||||
if op == '=':
|
||||
op = '=='
|
||||
a = self._expr_to_wgsl(expr[1])
|
||||
b = self._expr_to_wgsl(expr[2])
|
||||
return f"({a} {op} {b})"
|
||||
|
||||
def _compile_builtin(self, fn: str, args: list, indent: int) -> str:
|
||||
"""Compile builtin function call."""
|
||||
compiled_args = [self._expr_to_wgsl(arg) for arg in args]
|
||||
return f"{fn}({', '.join(compiled_args)})"
|
||||
|
||||
def _expr_to_wgsl(self, expr: Any) -> str:
|
||||
"""Convert an expression to inline WGSL code."""
|
||||
if isinstance(expr, (int, float)):
|
||||
# Ensure floats have decimal point
|
||||
if isinstance(expr, float) or '.' not in str(expr):
|
||||
return f"{float(expr)}"
|
||||
return str(expr)
|
||||
|
||||
if isinstance(expr, str):
|
||||
return f'"{expr}"'
|
||||
|
||||
if isinstance(expr, Symbol):
|
||||
name = expr.name
|
||||
if name == 'frame':
|
||||
return "rgb" # Assume rgb is already loaded
|
||||
if name == 't' or name == '_time':
|
||||
self.ctx.uses_time = True
|
||||
return "params.time"
|
||||
if name == 'pi':
|
||||
return "3.14159265"
|
||||
if name in self.ctx.params:
|
||||
return f"params.{name}"
|
||||
if name in self.ctx.locals:
|
||||
return name
|
||||
return name
|
||||
|
||||
if isinstance(expr, list) and expr:
|
||||
head = expr[0]
|
||||
if isinstance(head, Symbol):
|
||||
form = head.name
|
||||
|
||||
# Arithmetic
|
||||
if form in ('+', '-', '*', '/'):
|
||||
return self._compile_arithmetic(expr, 0)
|
||||
|
||||
# Comparison
|
||||
if form in ('>', '<', '>=', '<=', '='):
|
||||
return self._compile_comparison(expr, 0)
|
||||
|
||||
# Builtins
|
||||
if form in ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt'):
|
||||
args = [self._expr_to_wgsl(a) for a in expr[1:]]
|
||||
return f"{form}({', '.join(args)})"
|
||||
|
||||
if form == 'or':
|
||||
return self._compile_or(expr, 0)
|
||||
|
||||
# Image dimension queries
|
||||
if form == 'image:width':
|
||||
return "f32(params.width)"
|
||||
if form == 'image:height':
|
||||
return "f32(params.height)"
|
||||
|
||||
return f"/* unknown: {expr} */"
|
||||
|
||||
|
||||
def compile_effect(sexp_code: str) -> CompiledEffect:
    """Compile an sexp effect source string into a CompiledEffect."""
    return SexpToWGSLCompiler().compile_string(sexp_code)
def compile_effect_file(path: str) -> CompiledEffect:
    """Compile an sexp effect definition file into a CompiledEffect."""
    return SexpToWGSLCompiler().compile_file(path)
Reference in New Issue
Block a user