Make JAX the primary fused-pipeline path for CPU/GPU parity
JAX via XLA produces identical output on CPU and GPU. Previously CUDA hand-written kernels were preferred on GPU, causing visual differences vs the JAX CPU fallback. Now JAX is always used first, with legacy CuPy/GPUFrame as fallback only when JAX is unavailable. Also adds comprehensive CLAUDE.md for the monorepo. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
74
CLAUDE.md
Normal file
74
CLAUDE.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Art DAG Monorepo
|
||||
|
||||
Federated content-addressed DAG execution engine for distributed media processing with ActivityPub ownership and provenance tracking.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
core/ # DAG engine (artdag package) - nodes, effects, analysis, planning
|
||||
l1/ # L1 Celery rendering server (FastAPI + Celery + Redis + PostgreSQL)
|
||||
l2/ # L2 ActivityPub registry (FastAPI + PostgreSQL)
|
||||
common/ # Shared templates, middleware, models (artdag_common package)
|
||||
client/ # CLI client
|
||||
test/ # Integration & e2e tests
|
||||
```
|
||||
|
||||
## Tech Stack
|
||||
|
||||
Python 3.11+, FastAPI, Celery, Redis, PostgreSQL (asyncpg for L1), SQLAlchemy, Pydantic, JAX (CPU/GPU), IPFS/Kubo, Docker Swarm, HTMX + Jinja2 for web UI.
|
||||
|
||||
## Key Commands
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
cd l1 && pytest tests/ # L1 unit tests
|
||||
cd core && pytest tests/ # Core unit tests
|
||||
cd test && python run.py # Full integration pipeline
|
||||
```
|
||||
- pytest uses `asyncio_mode = "auto"` for async tests
|
||||
- Test files: `test_*.py`, fixtures in `conftest.py`
|
||||
|
||||
### Linting & Type Checking (L1)
|
||||
```bash
|
||||
cd l1 && ruff check . # Lint (E, F, I, UP rules)
|
||||
cd l1 && mypy app/types.py app/routers/recipes.py tests/
|
||||
```
|
||||
- Line length: 100 chars (E501 ignored)
|
||||
- Mypy: strict on `app/types.py`, `app/routers/recipes.py`, `tests/`; gradual elsewhere
|
||||
- Mypy ignores imports for: celery, redis, artdag, artdag_common, ipfs_client
|
||||
|
||||
### Docker
|
||||
```bash
|
||||
docker build -f l1/Dockerfile -t celery-l1-server:latest .
|
||||
docker build -f l1/Dockerfile.gpu -t celery-l1-gpu:latest .
|
||||
docker build -f l2/Dockerfile -t l2-server:latest .
|
||||
./deploy.sh # Build, push, deploy stacks
|
||||
```
|
||||
|
||||
## Architecture Patterns
|
||||
|
||||
- **3-Phase Execution**: Analyze -> Plan -> Execute (tasks in `l1/tasks/`)
|
||||
- **Content-Addressed**: All data identified by SHA3-256 hashes or IPFS CIDs
|
||||
- **Services Pattern**: Business logic in `app/services/`, API endpoints in `app/routers/`
|
||||
- **Types Module**: Pydantic models and TypedDicts in `app/types.py`
|
||||
- **Celery Tasks**: In `l1/tasks/`, decorated with `@app.task`
|
||||
- **S-Expression Effects**: Composable effect language in `l1/sexp_effects/`
|
||||
- **Storage**: Local filesystem, S3, or IPFS backends (`storage_providers.py`)
|
||||
|
||||
## Auth
|
||||
|
||||
- L1 <-> L2: scoped JWT tokens (no shared secrets)
|
||||
- L2: password + OAuth SSO, token revocation in Redis (30-day expiry)
|
||||
- Federation: ActivityPub RSA signatures (`core/artdag/activitypub/`)
|
||||
|
||||
## Key Config Files
|
||||
|
||||
- `l1/pyproject.toml` - mypy, pytest, ruff config for L1
|
||||
- `l1/celery_app.py` - Celery initialization
|
||||
- `l1/database.py` / `l2/db.py` - SQLAlchemy models
|
||||
- `l1/docker-compose.yml` / `l2/docker-compose.yml` - Swarm stacks
|
||||
|
||||
## Tools
|
||||
|
||||
- Use Context7 MCP for up-to-date library documentation
|
||||
- Playwright MCP is available for browser automation/testing
|
||||
@@ -979,51 +979,52 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params):
|
||||
# Update effects list to exclude resize ops
|
||||
effects_list = other_effects
|
||||
|
||||
if not _FUSED_KERNELS_AVAILABLE:
|
||||
if _FUSED_JAX_AVAILABLE:
|
||||
# JAX path: convert to JAX array, apply effects, convert back to numpy
|
||||
if _FUSED_CALL_COUNT <= 3:
|
||||
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
|
||||
# Extract numpy array from GPUFrame if needed
|
||||
if isinstance(img, GPUFrame):
|
||||
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
|
||||
elif hasattr(img, 'get'):
|
||||
arr = img.get() # CuPy to numpy
|
||||
# JAX is the primary path — same code on CPU and GPU, XLA handles dispatch
|
||||
if _FUSED_JAX_AVAILABLE:
|
||||
if _FUSED_CALL_COUNT <= 3:
|
||||
print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr)
|
||||
# Extract numpy array from GPUFrame if needed
|
||||
if isinstance(img, GPUFrame):
|
||||
arr = img.cpu if not img.is_on_gpu else img.gpu.get()
|
||||
elif hasattr(img, 'get'):
|
||||
arr = img.get() # CuPy to numpy
|
||||
else:
|
||||
arr = np.asarray(img)
|
||||
result = jnp.array(arr)
|
||||
for effect in effects_list:
|
||||
op = effect['op']
|
||||
if op == 'rotate':
|
||||
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
|
||||
result = _jax_fused_fns['rotate'](result, angle=angle)
|
||||
elif op == 'zoom':
|
||||
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
|
||||
result = _jax_fused_fns['zoom'](result, amount=amount)
|
||||
elif op == 'hue_shift':
|
||||
degrees = effect.get('degrees', 0)
|
||||
if abs(degrees) > 0.1:
|
||||
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
|
||||
elif op == 'ripple':
|
||||
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||
if amplitude > 0.1:
|
||||
result = _jax_fused_fns['ripple'](result,
|
||||
amplitude=amplitude,
|
||||
frequency=effect.get('frequency', 8),
|
||||
decay=effect.get('decay', 2),
|
||||
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
|
||||
cx=effect.get('center_x'),
|
||||
cy=effect.get('center_y'))
|
||||
elif op == 'brightness':
|
||||
factor = effect.get('factor', 1.0)
|
||||
result = _jax_fused_fns['brightness'](result, factor=factor)
|
||||
elif op == 'invert':
|
||||
amount = effect.get('amount', 0)
|
||||
if amount > 0.5:
|
||||
result = _jax_fused_fns['invert'](result)
|
||||
else:
|
||||
arr = np.asarray(img)
|
||||
result = jnp.array(arr)
|
||||
for effect in effects_list:
|
||||
op = effect['op']
|
||||
if op == 'rotate':
|
||||
angle = dynamic_params.get('rotate_angle', effect.get('angle', 0))
|
||||
result = _jax_fused_fns['rotate'](result, angle=angle)
|
||||
elif op == 'zoom':
|
||||
amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0))
|
||||
result = _jax_fused_fns['zoom'](result, amount=amount)
|
||||
elif op == 'hue_shift':
|
||||
degrees = effect.get('degrees', 0)
|
||||
if abs(degrees) > 0.1:
|
||||
result = _jax_fused_fns['hue_shift'](result, degrees=degrees)
|
||||
elif op == 'ripple':
|
||||
amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))
|
||||
if amplitude > 0.1:
|
||||
result = _jax_fused_fns['ripple'](result,
|
||||
amplitude=amplitude,
|
||||
frequency=effect.get('frequency', 8),
|
||||
decay=effect.get('decay', 2),
|
||||
phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)),
|
||||
cx=effect.get('center_x'),
|
||||
cy=effect.get('center_y'))
|
||||
elif op == 'brightness':
|
||||
factor = effect.get('factor', 1.0)
|
||||
result = _jax_fused_fns['brightness'](result, factor=factor)
|
||||
elif op == 'invert':
|
||||
amount = effect.get('amount', 0)
|
||||
if amount > 0.5:
|
||||
result = _jax_fused_fns['invert'](result)
|
||||
else:
|
||||
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||
return np.asarray(result)
|
||||
raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize")
|
||||
return np.asarray(result)
|
||||
|
||||
if not _FUSED_KERNELS_AVAILABLE:
|
||||
|
||||
# Legacy CuPy/GPUFrame fallback
|
||||
print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user