Make JAX the primary fused-pipeline path for CPU/GPU parity
JAX via XLA produces identical output on CPU and GPU. Previously, hand-written CUDA kernels were preferred on GPU, which caused visual differences compared with the JAX CPU fallback. Now JAX is always tried first, and the legacy CuPy/GPUFrame path is used as a fallback only when JAX is unavailable.

Also adds a comprehensive CLAUDE.md for the monorepo.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -979,51 +979,52 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
|
||||
# Update effects list to exclude resize ops
|
||||
effects_list = other_effects
|
||||
|
||||
if not _FUSED_KERNELS_AVAILABLE:
|
||||
if _FUSED_JAX_AVAILABLE:
|
||||
# JAX path: convert to JAX array, apply effects, convert back to numpy
|
||||
if _FUSED_CALL_COUNT <= 3:
|
||||
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
|
||||
# Extract numpy array from GPUFrame if needed
|
||||
if isinstance(img, GPUFrame):
|
||||
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
|
||||
elif hasattr(img, 'get'):
|
||||
arr = img.get() # CuPy to numpy
|
||||
# JAX is the primary path — same code on CPU and GPU, XLA handles dispatch
|
||||
if _FUSED_JAX_AVAILABLE:
|
||||
if _FUSED_CALL_COUNT <= 3:
|
||||
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
|
||||
# Extract numpy array from GPUFrame if needed
|
||||
if isinstance(img, GPUFrame):
|
||||
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
|
||||
elif hasattr(img, 'get'):
|
||||
arr = img.get() # CuPy to numpy
|
||||
else:
|
||||
arr = np.asarray(img)
|
||||
result = jnp.array(arr)
|
||||
for effect in effects_list:
|
||||
op = effect['op']
|
||||
if op == 'rotate':
|
||||
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
|
||||
result = _jax_fused_fns['rotate'](result, angle=angle)
|
||||
elif op == 'zoom':
|
||||
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
|
||||
result = _jax_fused_fns['zoom'](result, amount=amount)
|
||||
elif op == 'hue_shift':
|
||||
degrees = effect.get('degrees', 0)
|
||||
if abs(degrees) > 0.1:
|
||||
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
|
||||
elif op == 'ripple':
|
||||
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||
if amplitude > 0.1:
|
||||
result = _jax_fused_fns['ripple'](result,
|
||||
amplitude=amplitude,
|
||||
frequency=effect.get('frequency', 8),
|
||||
decay=effect.get('decay', 2),
|
||||
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
|
||||
cx=effect.get('center_x'),
|
||||
cy=effect.get('center_y'))
|
||||
elif op == 'brightness':
|
||||
factor = effect.get('factor', 1.0)
|
||||
result = _jax_fused_fns['brightness'](result, factor=factor)
|
||||
elif op == 'invert':
|
||||
amount = effect.get('amount', 0)
|
||||
if amount > 0.5:
|
||||
result = _jax_fused_fns['invert'](result)
|
||||
else:
|
||||
arr = np.asarray(img)
|
||||
result = jnp.array(arr)
|
||||
for effect in effects_list:
|
||||
op = effect['op']
|
||||
if op == 'rotate':
|
||||
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
|
||||
result = _jax_fused_fns['rotate'](result, angle=angle)
|
||||
elif op == 'zoom':
|
||||
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
|
||||
result = _jax_fused_fns['zoom'](result, amount=amount)
|
||||
elif op == 'hue_shift':
|
||||
degrees = effect.get('degrees', 0)
|
||||
if abs(degrees) > 0.1:
|
||||
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
|
||||
elif op == 'ripple':
|
||||
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||
if amplitude > 0.1:
|
||||
result = _jax_fused_fns['ripple'](result,
|
||||
amplitude=amplitude,
|
||||
frequency=effect.get('frequency', 8),
|
||||
decay=effect.get('decay', 2),
|
||||
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
|
||||
cx=effect.get('center_x'),
|
||||
cy=effect.get('center_y'))
|
||||
elif op == 'brightness':
|
||||
factor = effect.get('factor', 1.0)
|
||||
result = _jax_fused_fns['brightness'](result, factor=factor)
|
||||
elif op == 'invert':
|
||||
amount = effect.get('amount', 0)
|
||||
if amount > 0.5:
|
||||
result = _jax_fused_fns['invert'](result)
|
||||
else:
|
||||
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||
return np.asarray(result)
|
||||
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||
return np.asarray(result)
|
||||
|
||||
if not _FUSED_KERNELS_AVAILABLE:
|
||||
|
||||
# Legacy CuPy/GPUFrame fallback
|
||||
print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user