Make JAX the primary fused-pipeline path for CPU/GPU parity

JAX via XLA produces identical output on CPU and GPU. Previously
CUDA hand-written kernels were preferred on GPU, causing visual
differences vs the JAX CPU fallback. Now JAX is always used first,
with legacy CuPy/GPUFrame as fallback only when JAX is unavailable.

Also adds comprehensive CLAUDE.md for the monorepo.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
giles
2026-02-25 19:31:53 +00:00
parent b788f1f778
commit 4c2e716558
2 changed files with 119 additions and 44 deletions

View File

@@ -979,51 +979,52 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
# Update effects list to exclude resize ops
effects_list = other_effects
if not _FUSED_KERNELS_AVAILABLE:
if _FUSED_JAX_AVAILABLE:
# JAX path: convert to JAX array, apply effects, convert back to numpy
if _FUSED_CALL_COUNT <= 3:
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
# Extract numpy array from GPUFrame if needed
if isinstance(img, GPUFrame):
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
elif hasattr(img, 'get'):
arr = img.get() # CuPy to numpy
# JAX is the primary path — same code on CPU and GPU, XLA handles dispatch
if _FUSED_JAX_AVAILABLE:
if _FUSED_CALL_COUNT <= 3:
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
# Extract numpy array from GPUFrame if needed
if isinstance(img, GPUFrame):
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
elif hasattr(img, 'get'):
arr = img.get() # CuPy to numpy
else:
arr = np.asarray(img)
result = jnp.array(arr)
for effect in effects_list:
op = effect['op']
if op == 'rotate':
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
result = _jax_fused_fns['rotate'](result, angle=angle)
elif op == 'zoom':
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
result = _jax_fused_fns['zoom'](result, amount=amount)
elif op == 'hue_shift':
degrees = effect.get('degrees', 0)
if abs(degrees) > 0.1:
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
elif op == 'ripple':
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
if amplitude > 0.1:
result = _jax_fused_fns['ripple'](result,
amplitude=amplitude,
frequency=effect.get('frequency', 8),
decay=effect.get('decay', 2),
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
cx=effect.get('center_x'),
cy=effect.get('center_y'))
elif op == 'brightness':
factor = effect.get('factor', 1.0)
result = _jax_fused_fns['brightness'](result, factor=factor)
elif op == 'invert':
amount = effect.get('amount', 0)
if amount > 0.5:
result = _jax_fused_fns['invert'](result)
else:
arr = np.asarray(img)
result = jnp.array(arr)
for effect in effects_list:
op = effect['op']
if op == 'rotate':
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
result = _jax_fused_fns['rotate'](result, angle=angle)
elif op == 'zoom':
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
result = _jax_fused_fns['zoom'](result, amount=amount)
elif op == 'hue_shift':
degrees = effect.get('degrees', 0)
if abs(degrees) > 0.1:
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
elif op == 'ripple':
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
if amplitude > 0.1:
result = _jax_fused_fns['ripple'](result,
amplitude=amplitude,
frequency=effect.get('frequency', 8),
decay=effect.get('decay', 2),
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
cx=effect.get('center_x'),
cy=effect.get('center_y'))
elif op == 'brightness':
factor = effect.get('factor', 1.0)
result = _jax_fused_fns['brightness'](result, factor=factor)
elif op == 'invert':
amount = effect.get('amount', 0)
if amount > 0.5:
result = _jax_fused_fns['invert'](result)
else:
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
return np.asarray(result)
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
return np.asarray(result)
if not _FUSED_KERNELS_AVAILABLE:
# Legacy CuPy/GPUFrame fallback
print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)