From 07cae101ad1af511fe063e7441ca64495a2a271d Mon Sep 17 00:00:00 2001
From: giles
Date: Wed, 25 Feb 2026 15:35:13 +0000
Subject: [PATCH] Use JAX for fused pipeline fallback on CPU instead of GPUFrame path

When CUDA fused kernels aren't available, the fused-pipeline primitive
now uses JAX ops (jax_rotate, jax_scale, jax_shift_hue, etc.) instead of
falling back to one-by-one CuPy/GPUFrame operations. Legacy GPUFrame
path retained as last resort when JAX is also unavailable.

_FUSED_JAX_AVAILABLE is only set once every JAX import has succeeded and
the dispatch table is populated, so a partial import failure selects the
legacy path instead of raising KeyError at call time.

Co-Authored-By: Claude Opus 4.6
---
 .../primitive_libs/streaming_gpu.py           | 111 ++++++++++++++++--
 1 file changed, 102 insertions(+), 9 deletions(-)

diff --git a/l1/sexp_effects/primitive_libs/streaming_gpu.py b/l1/sexp_effects/primitive_libs/streaming_gpu.py
index f2aa7ea..a2374f5 100644
--- a/l1/sexp_effects/primitive_libs/streaming_gpu.py
+++ b/l1/sexp_effects/primitive_libs/streaming_gpu.py
@@ -842,8 +842,9 @@ def _get_cpu_primitives():
 PRIMITIVES = _get_cpu_primitives().copy()
 
 
-# Try to import fused kernel compiler
+# Try to import fused kernel compiler (CUDA first, then JAX fallback)
 _FUSED_KERNELS_AVAILABLE = False
+_FUSED_JAX_AVAILABLE = False
 _compile_frame_pipeline = None
 _compile_autonomous_pipeline = None
 try:
@@ -853,7 +854,56 @@ try:
     _FUSED_KERNELS_AVAILABLE = True
     print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr)
 except ImportError as e:
-    print(f"[streaming_gpu] Fused kernels not available: {e}", file=sys.stderr)
+    print(f"[streaming_gpu] Fused CUDA kernels not available: {e}", file=sys.stderr)
+
+# JAX fallback for fused pipeline on CPU
+_jax_fused_fns = {}
+try:
+    from streaming.sexp_to_jax import (
+        jax_rotate, jax_scale, jax_shift_hue, jax_invert,
+        jax_adjust_brightness, jax_adjust_contrast, jax_resize, jax_sample,
+    )
+    import jax.numpy as jnp
+
+    def _jax_ripple(img, amplitude=10, frequency=8, decay=2, phase=0, cx=None, cy=None):
+        """JAX ripple displacement matching the CUDA fused pipeline."""
+        h, w = img.shape[:2]
+        if cx is None:
+            cx = w / 2
+        if cy is None:
+            cy = h / 2
+        y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
+        x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
+        dx = x_coords - cx
+        dy = y_coords - cy
+        dist = jnp.sqrt(dx*dx + dy*dy)
+        max_dim = jnp.maximum(w, h).astype(jnp.float32)
+        ripple = jnp.sin(2 * jnp.pi * frequency * dist / max_dim + phase) * amplitude
+        decay_factor = jnp.exp(-decay * dist / max_dim)
+        ripple = ripple * decay_factor
+        angle = jnp.arctan2(dy, dx)
+        src_x = x_coords + ripple * jnp.cos(angle)
+        src_y = y_coords + ripple * jnp.sin(angle)
+        r, g, b = jax_sample(img, src_x.flatten(), src_y.flatten())
+        return jnp.stack([
+            jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
+            jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
+            jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
+        ], axis=2)
+
+    _jax_fused_fns = {
+        'rotate': lambda img, **kw: jax_rotate(img, kw.get('angle', 0)),
+        'zoom': lambda img, **kw: jax_scale(img, kw.get('amount', 1.0)),
+        'hue_shift': lambda img, **kw: jax_shift_hue(img, kw.get('degrees', 0)),
+        'invert': lambda img, **kw: jax_invert(img),
+        'brightness': lambda img, **kw: jax_adjust_contrast(img, kw.get('factor', 1.0)),
+        'ripple': lambda img, **kw: _jax_ripple(img, **kw),
+    }
+    # Enable the JAX path only after every import succeeded and the table above is populated.
+    _FUSED_JAX_AVAILABLE = True
+    print("[streaming_gpu] JAX fused fallback loaded", file=sys.stderr)
+except ImportError as e:
+    print(f"[streaming_gpu] JAX fallback not available: {e}", file=sys.stderr)
 
 
 # Fused pipeline cache
@@ -930,9 +980,53 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
     effects_list = other_effects
 
     if not _FUSED_KERNELS_AVAILABLE:
-        # Fallback: apply effects one by one
-        print(f"[FUSED FALLBACK] Using fallback path for {len(effects_list)} effects", file=sys.stderr)
-        # Wrap in GPUFrame if needed (GPU functions expect GPUFrame objects)
+        if _FUSED_JAX_AVAILABLE:
+            # JAX path: convert to JAX array, apply effects, convert back to numpy
+            if _FUSED_CALL_COUNT <= 3:
+                print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
+            # Extract numpy array from GPUFrame if needed
+            if isinstance(img, GPUFrame):
+                arr = img.cpu if not img.is_on_gpu else img.gpu.get()
+            elif hasattr(img, 'get'):
+                arr = img.get()  # CuPy to numpy
+            else:
+                arr = np.asarray(img)
+            result = jnp.array(arr)
+            for effect in effects_list:
+                op = effect['op']
+                if op == 'rotate':
+                    angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
+                    result = _jax_fused_fns['rotate'](result, angle=angle)
+                elif op == 'zoom':
+                    amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
+                    result = _jax_fused_fns['zoom'](result, amount=amount)
+                elif op == 'hue_shift':
+                    degrees = effect.get('degrees', 0)
+                    if abs(degrees) > 0.1:
+                        result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
+                elif op == 'ripple':
+                    amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
+                    if amplitude > 0.1:
+                        result = _jax_fused_fns['ripple'](result,
+                            amplitude=amplitude,
+                            frequency=effect.get('frequency', 8),
+                            decay=effect.get('decay', 2),
+                            phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
+                            cx=effect.get('center_x'),
+                            cy=effect.get('center_y'))
+                elif op == 'brightness':
+                    factor = effect.get('factor', 1.0)
+                    result = _jax_fused_fns['brightness'](result, factor=factor)
+                elif op == 'invert':
+                    amount = effect.get('amount', 0)
+                    if amount > 0.5:
+                        result = _jax_fused_fns['invert'](result)
+                else:
+                    raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
+            return np.asarray(result)
+
+        # Legacy CuPy/GPUFrame fallback
+        print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)
         if isinstance(img, GPUFrame):
             result = img
         else:
@@ -948,11 +1042,11 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
                 result = gpu_zoom(result, amount)
             elif op == 'hue_shift':
                 degrees = effect.get('degrees', 0)
-                if abs(degrees) > 0.1:  # Only apply if significant shift
+                if abs(degrees) > 0.1:
                     result = gpu_hue_shift(result, degrees)
             elif op == 'ripple':
                 amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
-                if amplitude > 0.1:  # Only apply if amplitude is significant
+                if amplitude > 0.1:
                     result = gpu_ripple(result,
                         amplitude=amplitude,
                         frequency=effect.get('frequency', 8),
@@ -965,11 +1059,10 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
                 result = gpu_contrast(result, factor, 0)
             elif op == 'invert':
                 amount = effect.get('amount', 0)
-                if amount > 0.5:  # Only invert if amount > 0.5
+                if amount > 0.5:
                     result = gpu_invert(result)
             else:
                 raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
-        # Return raw array, not GPUFrame (downstream expects arrays with .flags attribute)
         if isinstance(result, GPUFrame):
             return result.gpu if result.is_on_gpu else result.cpu
         return result