Make JAX the primary fused-pipeline path for CPU/GPU parity

JAX via XLA produces identical output on CPU and GPU. Previously,
hand-written CUDA kernels were preferred on GPU, causing visual
differences vs the JAX CPU fallback. Now JAX is always used first,
with legacy CuPy/GPUFrame as fallback only when JAX is unavailable.

Also adds comprehensive CLAUDE.md for the monorepo.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
giles
2026-02-25 19:31:53 +00:00
parent b788f1f778
commit 4c2e716558
2 changed files with 119 additions and 44 deletions
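
The fallback order described in the commit message can be sketched as follows. This is a minimal illustration under stated assumptions, not the code from this commit: `HAVE_JAX` and `brighten` are hypothetical names, `HAVE_JAX` plays the role of `_FUSED_JAX_AVAILABLE`, and plain NumPy stands in for the legacy CuPy/GPUFrame path.

```python
import numpy as np

# Probe for JAX once at import time. HAVE_JAX is a hypothetical flag that
# mirrors the role of _FUSED_JAX_AVAILABLE in the changed file.
try:
    import jax.numpy as jnp
    HAVE_JAX = True
except ImportError:
    jnp = None
    HAVE_JAX = False


def brighten(img, factor=1.0):
    """Scale pixel values, preferring the JAX path whenever JAX is importable."""
    arr = np.asarray(img, dtype=np.float32)
    if HAVE_JAX:
        # The same code runs on CPU or GPU; XLA handles device dispatch.
        out = jnp.clip(jnp.asarray(arr) * factor, 0.0, 255.0)
        return np.asarray(out).astype(np.uint8)
    # Fallback, reached only when JAX cannot be imported. NumPy is an
    # assumption here, standing in for the legacy CuPy/GPUFrame path.
    return np.clip(arr * factor, 0.0, 255.0).astype(np.uint8)


frame = np.full((2, 2, 3), 100, dtype=np.uint8)
result = brighten(frame, factor=1.5)
```

The point of the ordering is that the availability check for the preferred backend comes first and returns early, so the fallback is never reached on a machine where JAX is installed.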

CLAUDE.md Normal file

@@ -0,0 +1,74 @@
# Art DAG Monorepo
Federated content-addressed DAG execution engine for distributed media processing with ActivityPub ownership and provenance tracking.
## Project Structure
```
core/ # DAG engine (artdag package) - nodes, effects, analysis, planning
l1/ # L1 Celery rendering server (FastAPI + Celery + Redis + PostgreSQL)
l2/ # L2 ActivityPub registry (FastAPI + PostgreSQL)
common/ # Shared templates, middleware, models (artdag_common package)
client/ # CLI client
test/ # Integration & e2e tests
```
## Tech Stack
Python 3.11+, FastAPI, Celery, Redis, PostgreSQL (asyncpg for L1), SQLAlchemy, Pydantic, JAX (CPU/GPU), IPFS/Kubo, Docker Swarm, HTMX + Jinja2 for web UI.
## Key Commands
### Testing
```bash
cd l1 && pytest tests/ # L1 unit tests
cd core && pytest tests/ # Core unit tests
cd test && python run.py # Full integration pipeline
```
- pytest uses `asyncio_mode = "auto"` for async tests
- Test files: `test_*.py`, fixtures in `conftest.py`
### Linting & Type Checking (L1)
```bash
cd l1 && ruff check . # Lint (E, F, I, UP rules)
cd l1 && mypy app/types.py app/routers/recipes.py tests/
```
- Line length: 100 chars (E501 ignored)
- Mypy: strict on `app/types.py`, `app/routers/recipes.py`, `tests/`; gradual elsewhere
- Mypy ignores imports for: celery, redis, artdag, artdag_common, ipfs_client
### Docker
```bash
docker build -f l1/Dockerfile -t celery-l1-server:latest .
docker build -f l1/Dockerfile.gpu -t celery-l1-gpu:latest .
docker build -f l2/Dockerfile -t l2-server:latest .
./deploy.sh # Build, push, deploy stacks
```
## Architecture Patterns
- **3-Phase Execution**: Analyze -> Plan -> Execute (tasks in `l1/tasks/`)
- **Content-Addressed**: All data identified by SHA3-256 hashes or IPFS CIDs
- **Services Pattern**: Business logic in `app/services/`, API endpoints in `app/routers/`
- **Types Module**: Pydantic models and TypedDicts in `app/types.py`
- **Celery Tasks**: In `l1/tasks/`, decorated with `@app.task`
- **S-Expression Effects**: Composable effect language in `l1/sexp_effects/`
- **Storage**: Local filesystem, S3, or IPFS backends (`storage_providers.py`)
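
As a rough illustration of the content-addressed pattern listed above, the sketch below derives an identifier from a payload's SHA3-256 digest. `content_id` is a hypothetical helper, not a function from the `artdag` package, and the project's real hashing and CID handling may differ.

```python
import hashlib


def content_id(data: bytes) -> str:
    """Return the hex SHA3-256 digest of a payload (hypothetical helper)."""
    return hashlib.sha3_256(data).hexdigest()


# Identical bytes always map to the same identifier; different bytes do not.
a = content_id(b"frame-bytes")
b = content_id(b"frame-bytes")
c = content_id(b"other-bytes")
```

Because the identifier depends only on the bytes, two nodes that produce the same output can be deduplicated without comparing the payloads themselves.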
## Auth
- L1 <-> L2: scoped JWT tokens (no shared secrets)
- L2: password + OAuth SSO, token revocation in Redis (30-day expiry)
- Federation: ActivityPub RSA signatures (`core/artdag/activitypub/`)
## Key Config Files
- `l1/pyproject.toml` - mypy, pytest, ruff config for L1
- `l1/celery_app.py` - Celery initialization
- `l1/database.py` / `l2/db.py` - SQLAlchemy models
- `l1/docker-compose.yml` / `l2/docker-compose.yml` - Swarm stacks
## Tools
- Use Context7 MCP for up-to-date library documentation
- Playwright MCP is available for browser automation/testing


@@ -979,51 +979,52 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
     # Update effects list to exclude resize ops
     effects_list = other_effects
-    if not _FUSED_KERNELS_AVAILABLE:
-        if _FUSED_JAX_AVAILABLE:
-            # JAX path: convert to JAX array, apply effects, convert back to numpy
-            if _FUSED_CALL_COUNT <= 3:
-                print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
-            # Extract numpy array from GPUFrame if needed
-            if isinstance(img, GPUFrame):
-                arr = img.cpu if not img.is_on_gpu else img.gpu.get()
-            elif hasattr(img, 'get'):
-                arr = img.get()  # CuPy to numpy
+    # JAX is the primary path — same code on CPU and GPU, XLA handles dispatch
+    if _FUSED_JAX_AVAILABLE:
+        if _FUSED_CALL_COUNT <= 3:
+            print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
+        # Extract numpy array from GPUFrame if needed
+        if isinstance(img, GPUFrame):
+            arr = img.cpu if not img.is_on_gpu else img.gpu.get()
+        elif hasattr(img, 'get'):
+            arr = img.get()  # CuPy to numpy
+        else:
+            arr = np.asarray(img)
+        result = jnp.array(arr)
+        for effect in effects_list:
+            op = effect['op']
+            if op == 'rotate':
+                angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
+                result = _jax_fused_fns['rotate'](result, angle=angle)
+            elif op == 'zoom':
+                amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
+                result = _jax_fused_fns['zoom'](result, amount=amount)
+            elif op == 'hue_shift':
+                degrees = effect.get('degrees', 0)
+                if abs(degrees) > 0.1:
+                    result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
+            elif op == 'ripple':
+                amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
+                if amplitude > 0.1:
+                    result = _jax_fused_fns['ripple'](result,
+                        amplitude=amplitude,
+                        frequency=effect.get('frequency', 8),
+                        decay=effect.get('decay', 2),
+                        phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
+                        cx=effect.get('center_x'),
+                        cy=effect.get('center_y'))
+            elif op == 'brightness':
+                factor = effect.get('factor', 1.0)
+                result = _jax_fused_fns['brightness'](result, factor=factor)
+            elif op == 'invert':
+                amount = effect.get('amount', 0)
+                if amount > 0.5:
+                    result = _jax_fused_fns['invert'](result)
             else:
-                arr = np.asarray(img)
-            result = jnp.array(arr)
-            for effect in effects_list:
-                op = effect['op']
-                if op == 'rotate':
-                    angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
-                    result = _jax_fused_fns['rotate'](result, angle=angle)
-                elif op == 'zoom':
-                    amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
-                    result = _jax_fused_fns['zoom'](result, amount=amount)
-                elif op == 'hue_shift':
-                    degrees = effect.get('degrees', 0)
-                    if abs(degrees) > 0.1:
-                        result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
-                elif op == 'ripple':
-                    amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
-                    if amplitude > 0.1:
-                        result = _jax_fused_fns['ripple'](result,
-                            amplitude=amplitude,
-                            frequency=effect.get('frequency', 8),
-                            decay=effect.get('decay', 2),
-                            phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
-                            cx=effect.get('center_x'),
-                            cy=effect.get('center_y'))
-                elif op == 'brightness':
-                    factor = effect.get('factor', 1.0)
-                    result = _jax_fused_fns['brightness'](result, factor=factor)
-                elif op == 'invert':
-                    amount = effect.get('amount', 0)
-                    if amount > 0.5:
-                        result = _jax_fused_fns['invert'](result)
-                else:
-                    raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
-            return np.asarray(result)
+                raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
+        return np.asarray(result)
+    if not _FUSED_KERNELS_AVAILABLE:
         # Legacy CuPy/GPUFrame fallback
         print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)
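
To make the effect loop in this hunk easier to follow, here is a small sketch of how a per-frame value in `dynamic_params` takes precedence over the static value in an effect dict, and how near-zero effects are skipped. The helper `resolve_effects` and the sample values are hypothetical. The keys (`rotate_angle`, `ripple_amplitude`) and the `0.1` thresholds are taken from the diff, and only three of the six ops are covered.

```python
def resolve_effects(effects_list, **dynamic_params):
    """Return the (op, kwargs) pairs a loop like the one above would apply.

    Hypothetical helper: it mirrors the parameter lookup in the diff but
    records the resolved arguments instead of calling a JAX kernel.
    """
    resolved = []
    for effect in effects_list:
        op = effect['op']
        if op == 'rotate':
            # A dynamic value wins over the static one in the effect dict.
            angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
            resolved.append((op, {'angle': angle}))
        elif op == 'hue_shift':
            degrees = effect.get('degrees', 0)
            if abs(degrees) > 0.1:  # skip shifts too small to matter
                resolved.append((op, {'degrees': degrees}))
        elif op == 'ripple':
            amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
            if amplitude > 0.1:  # skip a ripple with negligible amplitude
                resolved.append((op, {'amplitude': amplitude}))
        else:
            raise ValueError(f"Unsupported fused pipeline operation: '{op}'")
    return resolved


effects = [
    {'op': 'rotate', 'angle': 5},
    {'op': 'hue_shift', 'degrees': 0.05},
    {'op': 'ripple', 'amplitude': 10},
]
plan = resolve_effects(effects, rotate_angle=30, ripple_amplitude=0.0)
print(plan)  # [('rotate', {'angle': 30})]
```

Here the rotation uses the dynamic angle of 30 instead of the static 5, the hue shift is dropped because 0.05 is below the threshold, and the ripple is dropped because the dynamic amplitude overrides the static 10 with 0.0.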