Use JAX for fused pipeline fallback on CPU instead of GPUFrame path
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 2m41s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 2m41s
When CUDA fused kernels aren't available, the fused-pipeline primitive now uses JAX ops (jax_rotate, jax_scale, jax_shift_hue, etc.) instead of falling back to one-by-one CuPy/GPUFrame operations. Legacy GPUFrame path retained as last resort when JAX is also unavailable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -842,8 +842,9 @@ def _get_cpu_primitives():
|
|||||||
|
|
||||||
PRIMITIVES = _get_cpu_primitives().copy()
|
PRIMITIVES = _get_cpu_primitives().copy()
|
||||||
|
|
||||||
# Try to import fused kernel compiler
|
# Try to import fused kernel compiler (CUDA first, then JAX fallback)
|
||||||
_FUSED_KERNELS_AVAILABLE = False
|
_FUSED_KERNELS_AVAILABLE = False
|
||||||
|
_FUSED_JAX_AVAILABLE = False
|
||||||
_compile_frame_pipeline = None
|
_compile_frame_pipeline = None
|
||||||
_compile_autonomous_pipeline = None
|
_compile_autonomous_pipeline = None
|
||||||
try:
|
try:
|
||||||
@@ -853,7 +854,56 @@ try:
|
|||||||
_FUSED_KERNELS_AVAILABLE = True
|
_FUSED_KERNELS_AVAILABLE = True
|
||||||
print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr)
|
print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"[streaming_gpu] Fused kernels not available: {e}", file=sys.stderr)
|
print(f"[streaming_gpu] Fused CUDA kernels not available: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# JAX fallback for fused pipeline on CPU
|
||||||
|
_jax_fused_fns = {}
|
||||||
|
try:
    # Import every JAX dependency up front so that a missing symbol aborts
    # the whole fallback before any module-level state is changed.
    from streaming.sexp_to_jax import (
        jax_rotate, jax_scale, jax_shift_hue, jax_invert,
        jax_adjust_brightness, jax_adjust_contrast, jax_resize,
        jax_sample,
    )
    import jax.numpy as jnp

    def _jax_ripple(img, amplitude=10, frequency=8, decay=2, phase=0, cx=None, cy=None):
        """JAX ripple displacement matching the CUDA fused pipeline.

        Each output pixel is sampled from a source position displaced
        radially (relative to the ripple centre) by a sine wave that is
        attenuated exponentially with distance from that centre.

        Args:
            img: Image array; only ``img.shape[:2]`` (height, width) is read
                here, the pixel data is consumed by ``jax_sample``.
            amplitude: Peak displacement applied to the sampling coordinates.
            frequency: Number of sine cycles across ``max(width, height)``.
            decay: Exponential attenuation rate over normalised distance.
            phase: Phase offset added to the sine argument.
            cx: Ripple centre x coordinate; defaults to ``width / 2``.
            cy: Ripple centre y coordinate; defaults to ``height / 2``.

        Returns:
            A ``(height, width, 3)`` ``uint8`` array built from the three
            channels returned by ``jax_sample``, each clipped to 0..255.
        """
        h, w = img.shape[:2]
        if cx is None:
            cx = w / 2
        if cy is None:
            cy = h / 2
        # Per-pixel coordinate grids: y_coords[i, j] == i, x_coords[i, j] == j.
        y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32)
        x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32)
        dx = x_coords - cx
        dy = y_coords - cy
        dist = jnp.sqrt(dx*dx + dy*dy)
        # Distances are normalised by the larger image dimension.
        max_dim = jnp.maximum(w, h).astype(jnp.float32)
        ripple = jnp.sin(2 * jnp.pi * frequency * dist / max_dim + phase) * amplitude
        decay_factor = jnp.exp(-decay * dist / max_dim)
        ripple = ripple * decay_factor
        # Displace along the direction from the centre to the pixel.
        angle = jnp.arctan2(dy, dx)
        src_x = x_coords + ripple * jnp.cos(angle)
        src_y = y_coords + ripple * jnp.sin(angle)
        r, g, b = jax_sample(img, src_x.flatten(), src_y.flatten())
        return jnp.stack([
            jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8),
            jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8),
            jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8)
        ], axis=2)

    # Dispatch table: fused-pipeline op name -> JAX implementation.
    # NOTE(review): 'brightness' is routed to jax_adjust_contrast although
    # jax_adjust_brightness is imported above - confirm this is intentional.
    _jax_fused_fns = {
        'rotate': lambda img, **kw: jax_rotate(img, kw.get('angle', 0)),
        'zoom': lambda img, **kw: jax_scale(img, kw.get('amount', 1.0)),
        'hue_shift': lambda img, **kw: jax_shift_hue(img, kw.get('degrees', 0)),
        'invert': lambda img, **kw: jax_invert(img),
        'brightness': lambda img, **kw: jax_adjust_contrast(img, kw.get('factor', 1.0)),
        'ripple': lambda img, **kw: _jax_ripple(img, **kw),
    }

    # Set the flag only after every import succeeded and the dispatch table
    # exists; otherwise a late ImportError would leave the flag True with an
    # empty table and the JAX path would fail with KeyError.
    _FUSED_JAX_AVAILABLE = True
    print("[streaming_gpu] JAX fused fallback loaded", file=sys.stderr)
except ImportError as e:
    print(f"[streaming_gpu] JAX fallback not available: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
# Fused pipeline cache
|
# Fused pipeline cache
|
||||||
@@ -930,9 +980,53 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
|
|||||||
effects_list = other_effects
|
effects_list = other_effects
|
||||||
|
|
||||||
if not _FUSED_KERNELS_AVAILABLE:
|
if not _FUSED_KERNELS_AVAILABLE:
|
||||||
# Fallback: apply effects one by one
|
if _FUSED_JAX_AVAILABLE:
|
||||||
print(f"[FUSED FALLBACK] Using fallback path for {len(effects_list)} effects", file=sys.stderr)
|
# JAX path: convert to JAX array, apply effects, convert back to numpy
|
||||||
# Wrap in GPUFrame if needed (GPU functions expect GPUFrame objects)
|
if _FUSED_CALL_COUNT <= 3:
|
||||||
|
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
|
||||||
|
# Extract numpy array from GPUFrame if needed
|
||||||
|
if isinstance(img, GPUFrame):
|
||||||
|
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
|
||||||
|
elif hasattr(img, 'get'):
|
||||||
|
arr = img.get() # CuPy to numpy
|
||||||
|
else:
|
||||||
|
arr = np.asarray(img)
|
||||||
|
result = jnp.array(arr)
|
||||||
|
for effect in effects_list:
|
||||||
|
op = effect['op']
|
||||||
|
if op == 'rotate':
|
||||||
|
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
|
||||||
|
result = _jax_fused_fns['rotate'](result, angle=angle)
|
||||||
|
elif op == 'zoom':
|
||||||
|
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
|
||||||
|
result = _jax_fused_fns['zoom'](result, amount=amount)
|
||||||
|
elif op == 'hue_shift':
|
||||||
|
degrees = effect.get('degrees', 0)
|
||||||
|
if abs(degrees) > 0.1:
|
||||||
|
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
|
||||||
|
elif op == 'ripple':
|
||||||
|
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||||
|
if amplitude > 0.1:
|
||||||
|
result = _jax_fused_fns['ripple'](result,
|
||||||
|
amplitude=amplitude,
|
||||||
|
frequency=effect.get('frequency', 8),
|
||||||
|
decay=effect.get('decay', 2),
|
||||||
|
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
|
||||||
|
cx=effect.get('center_x'),
|
||||||
|
cy=effect.get('center_y'))
|
||||||
|
elif op == 'brightness':
|
||||||
|
factor = effect.get('factor', 1.0)
|
||||||
|
result = _jax_fused_fns['brightness'](result, factor=factor)
|
||||||
|
elif op == 'invert':
|
||||||
|
amount = effect.get('amount', 0)
|
||||||
|
if amount > 0.5:
|
||||||
|
result = _jax_fused_fns['invert'](result)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||||
|
return np.asarray(result)
|
||||||
|
|
||||||
|
# Legacy CuPy/GPUFrame fallback
|
||||||
|
print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)
|
||||||
if isinstance(img, GPUFrame):
|
if isinstance(img, GPUFrame):
|
||||||
result = img
|
result = img
|
||||||
else:
|
else:
|
||||||
@@ -948,11 +1042,11 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
|
|||||||
result = gpu_zoom(result, amount)
|
result = gpu_zoom(result, amount)
|
||||||
elif op == 'hue_shift':
|
elif op == 'hue_shift':
|
||||||
degrees = effect.get('degrees', 0)
|
degrees = effect.get('degrees', 0)
|
||||||
if abs(degrees) > 0.1: # Only apply if significant shift
|
if abs(degrees) > 0.1:
|
||||||
result = gpu_hue_shift(result, degrees)
|
result = gpu_hue_shift(result, degrees)
|
||||||
elif op == 'ripple':
|
elif op == 'ripple':
|
||||||
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||||
if amplitude > 0.1: # Only apply if amplitude is significant
|
if amplitude > 0.1:
|
||||||
result = gpu_ripple(result,
|
result = gpu_ripple(result,
|
||||||
amplitude=amplitude,
|
amplitude=amplitude,
|
||||||
frequency=effect.get('frequency', 8),
|
frequency=effect.get('frequency', 8),
|
||||||
@@ -965,11 +1059,10 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
|
|||||||
result = gpu_contrast(result, factor, 0)
|
result = gpu_contrast(result, factor, 0)
|
||||||
elif op == 'invert':
|
elif op == 'invert':
|
||||||
amount = effect.get('amount', 0)
|
amount = effect.get('amount', 0)
|
||||||
if amount > 0.5: # Only invert if amount > 0.5
|
if amount > 0.5:
|
||||||
result = gpu_invert(result)
|
result = gpu_invert(result)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||||
# Return raw array, not GPUFrame (downstream expects arrays with .flags attribute)
|
|
||||||
if isinstance(result, GPUFrame):
|
if isinstance(result, GPUFrame):
|
||||||
return result.gpu if result.is_on_gpu else result.cpu
|
return result.gpu if result.is_on_gpu else result.cpu
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user