diff --git a/l1/.dockerignore b/l1/.dockerignore new file mode 100644 index 0000000..f48a442 --- /dev/null +++ b/l1/.dockerignore @@ -0,0 +1,22 @@ +# Don't copy local clones - Dockerfile will clone fresh +artdag-effects/ + +# Python cache +__pycache__/ +*.py[cod] +*.egg-info/ +.pytest_cache/ + +# Virtual environments +.venv/ +venv/ + +# Local env +.env + +# Git +.git/ + +# IDE +.vscode/ +.idea/ diff --git a/l1/.env.example b/l1/.env.example new file mode 100644 index 0000000..0b0e063 --- /dev/null +++ b/l1/.env.example @@ -0,0 +1,20 @@ +# L1 Server Configuration + +# PostgreSQL password (REQUIRED - no default) +POSTGRES_PASSWORD=changeme-generate-with-openssl-rand-hex-16 + +# Admin token for purge operations (REQUIRED - no default) +# Generate with: openssl rand -hex 32 +ADMIN_TOKEN=changeme-generate-with-openssl-rand-hex-32 + +# L1 host IP/hostname for GPU worker cross-VPC access +L1_HOST=your-l1-server-ip + +# This L1 server's public URL (sent to L2 when publishing) +L1_PUBLIC_URL=https://l1.artdag.rose-ash.com + +# L2 server URL (for authentication and publishing) +L2_SERVER=https://artdag.rose-ash.com + +# L2 domain for ActivityPub actor IDs (e.g., @user@domain) +L2_DOMAIN=artdag.rose-ash.com diff --git a/l1/.env.gpu b/l1/.env.gpu new file mode 100644 index 0000000..9253dcd --- /dev/null +++ b/l1/.env.gpu @@ -0,0 +1,11 @@ +# GPU worker env - connects to L1 host via public IP (cross-VPC) +REDIS_URL=redis://138.68.142.139:16379/5 +DATABASE_URL=postgresql://artdag:f960bcc61d8b2155a1d57f7dd72c1c58@138.68.142.139:15432/artdag +IPFS_API=/ip4/138.68.142.139/tcp/15001 +IPFS_GATEWAYS=https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link +IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs +CACHE_DIR=/data/cache +C_FORCE_ROOT=true +ARTDAG_CLUSTER_KEY= +NVIDIA_VISIBLE_DEVICES=all +STREAMING_GPU_PERSIST=0 diff --git a/l1/.gitea/workflows/ci.yml b/l1/.gitea/workflows/ci.yml new file mode 100644 index 0000000..a79f66e --- /dev/null +++ b/l1/.gitea/workflows/ci.yml @@ -0,0 +1,63 @@ +name: Build and Deploy + +on: + push: + +env: + REGISTRY: registry.rose-ash.com:5000 + IMAGE_CPU: celery-l1-server + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install tools + run: | + apt-get update && apt-get install -y --no-install-recommends openssh-client + + - name: Set up SSH + env: + SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + mkdir -p ~/.ssh + echo "$SSH_KEY" > ~/.ssh/id_rsa + chmod 600 ~/.ssh/id_rsa + ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true + + - name: Pull latest code on server + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + BRANCH: ${{ github.ref_name }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + git fetch origin $BRANCH + git checkout $BRANCH + git reset --hard origin/$BRANCH + " + + - name: Build and push image + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + docker build --build-arg CACHEBUST=\$(date +%s) -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} . + docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest + docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} + " + + - name: Deploy stack + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + docker stack deploy -c docker-compose.yml celery + echo 'Waiting for services to update...' + sleep 10 + docker stack services celery + " diff --git a/l1/.gitignore b/l1/.gitignore new file mode 100644 index 0000000..3ca2eb4 --- /dev/null +++ b/l1/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.py[cod] +.pytest_cache/ +*.egg-info/ +.venv/ +venv/ +.env +artdag-effects/ diff --git a/l1/Dockerfile b/l1/Dockerfile new file mode 100644 index 0000000..90a770d --- /dev/null +++ b/l1/Dockerfile @@ -0,0 +1,31 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install git and ffmpeg (for video transcoding) +RUN apt-get update && apt-get install -y --no-install-recommends git ffmpeg && rm -rf /var/lib/apt/lists/* + +# Install dependencies +COPY requirements.txt . +ARG CACHEBUST=1 +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application +COPY . . + +# Clone effects repo +RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects + +# Build client tarball for download +RUN ./build-client.sh + +# Create cache directory +RUN mkdir -p /data/cache + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV EFFECTS_PATH=/app/artdag-effects +ENV PYTHONPATH=/app + +# Default command runs the server +CMD ["python", "server.py"] diff --git a/l1/Dockerfile.gpu b/l1/Dockerfile.gpu new file mode 100644 index 0000000..967f788 --- /dev/null +++ b/l1/Dockerfile.gpu @@ -0,0 +1,98 @@ +# GPU-enabled worker image +# Multi-stage build: use devel image for compiling, runtime for final image + +# Stage 1: Build decord with CUDA +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 AS builder + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 \ + python3.11-venv \ + python3.11-dev \ + python3-pip \ + git \ + cmake \ + build-essential \ + pkg-config \ + libavcodec-dev \ + libavformat-dev \ + libavutil-dev \ + libavdevice-dev \ + libavfilter-dev \ + libswresample-dev \ + libswscale-dev \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3.11 /usr/bin/python3 \ + && ln -sf /usr/bin/python3 /usr/bin/python + +# Download Video Codec SDK headers for NVDEC/NVCUVID +RUN git clone https://github.com/FFmpeg/nv-codec-headers.git /tmp/nv-codec-headers && \ + cd /tmp/nv-codec-headers && make install && rm -rf /tmp/nv-codec-headers + +# Create stub for libnvcuvid (real library comes from driver at runtime) +RUN echo 'void* __nvcuvid_stub__;' | gcc -shared -x c - -o /usr/local/cuda/lib64/libnvcuvid.so + +# Build decord with CUDA support +RUN git clone --recursive https://github.com/dmlc/decord /tmp/decord && \ + cd /tmp/decord && \ + mkdir build && cd build && \ + cmake .. -DUSE_CUDA=ON -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES="70;75;80;86;89;90" && \ + make -j$(nproc) && \ + cd ../python && pip install --target=/decord-install . + +# Stage 2: Runtime image +FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 + +WORKDIR /app + +# Install Python 3.11 and system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 \ + python3.11-venv \ + python3-pip \ + git \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3.11 /usr/bin/python3 \ + && ln -sf /usr/bin/python3 /usr/bin/python + +# Upgrade pip +RUN python3 -m pip install --upgrade pip + +# Install CPU dependencies first +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Install GPU-specific dependencies (CuPy for CUDA 12.x) +RUN pip install --no-cache-dir cupy-cuda12x + +# Install PyNvVideoCodec for zero-copy GPU encoding +RUN pip install --no-cache-dir PyNvVideoCodec + +# Copy decord from builder stage +COPY --from=builder /decord-install /usr/local/lib/python3.11/dist-packages/ +COPY --from=builder /tmp/decord/build/libdecord.so /usr/local/lib/ +RUN ldconfig + +# Clone effects repo (before COPY so it gets cached) +RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects + +# Copy application (this invalidates cache for any code change) +COPY . . + +# Create cache directory +RUN mkdir -p /data/cache + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV EFFECTS_PATH=/app/artdag-effects +ENV PYTHONPATH=/app +# GPU persistence enabled - frames stay on GPU throughout pipeline +ENV STREAMING_GPU_PERSIST=1 +# Preload libnvcuvid for decord NVDEC GPU decode +ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnvcuvid.so +# Use cluster's public IPFS gateway for HLS segment URLs +ENV IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs + +# Default command runs celery worker +CMD ["celery", "-A", "celery_app", "worker", "--loglevel=info", "-E", "-Q", "gpu,celery"] diff --git a/l1/README.md b/l1/README.md new file mode 100644 index 0000000..d387437 --- /dev/null +++ b/l1/README.md @@ -0,0 +1,329 @@ +# Art DAG L1 Server + +L1 rendering server for the Art DAG system. Manages distributed rendering jobs via Celery workers with content-addressable caching and optional IPFS integration. + +## Features + +- **3-Phase Execution**: Analyze → Plan → Execute pipeline for recipe-based rendering +- **Content-Addressable Caching**: IPFS CIDs with deduplication +- **IPFS Integration**: Optional IPFS-primary mode for distributed storage +- **Storage Providers**: S3, IPFS, and local storage backends +- **DAG Visualization**: Interactive graph visualization of execution plans +- **SPA-Style Navigation**: Smooth URL-based navigation without full page reloads +- **L2 Federation**: Publish outputs to ActivityPub registry + +## Dependencies + +- **artdag** (GitHub): Core DAG execution engine +- **artdag-effects** (rose-ash): Effect implementations +- **artdag-common**: Shared templates and middleware +- **Redis**: Message broker, result backend, and run persistence +- **PostgreSQL**: Metadata storage +- **IPFS** (optional): Distributed content storage + +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Start Redis +redis-server + +# Start a worker +celery -A celery_app worker --loglevel=info -E + +# Start the L1 server +python server.py +``` + +## Docker Swarm Deployment + +```bash +docker stack deploy -c docker-compose.yml artdag +``` + +The stack includes: +- **redis**: Message broker (Redis 7) +- **postgres**: Metadata database (PostgreSQL 16) +- **ipfs**: IPFS node (Kubo) +- **l1-server**: FastAPI web server +- **l1-worker**: Celery workers (2 replicas) +- **flower**: Celery task monitoring + +## Configuration + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `HOST` | `0.0.0.0` | Server bind address | +| `PORT` | `8000` | Server port | +| `REDIS_URL` | `redis://localhost:6379/5` | Redis connection | +| `DATABASE_URL` | **(required)** | PostgreSQL connection | +| `CACHE_DIR` | `~/.artdag/cache` | Local cache directory | +| `IPFS_API` | `/dns/localhost/tcp/5001` | IPFS API multiaddr | +| `IPFS_GATEWAY_URL` | `https://ipfs.io/ipfs` | Public IPFS gateway | +| `IPFS_PRIMARY` | `false` | Enable IPFS-primary mode | +| `L1_PUBLIC_URL` | `http://localhost:8100` | Public URL for redirects | +| `L2_SERVER` | - | L2 ActivityPub server URL | +| `L2_DOMAIN` | - | L2 domain for federation | +| `ARTDAG_CLUSTER_KEY` | - | Cluster key for trust domains | + +### IPFS-Primary Mode + +When `IPFS_PRIMARY=true`, all content is stored on IPFS: +- Input files are added to IPFS on upload +- Analysis results stored as JSON on IPFS +- Execution plans stored on IPFS +- Step outputs pinned to IPFS +- Local cache becomes a read-through cache + +This enables distributed execution across multiple L1 nodes sharing the same IPFS network. + +## Web UI + +| Path | Description | +|------|-------------| +| `/` | Home page with server info | +| `/runs` | View and manage rendering runs | +| `/run/{id}` | Run detail with tabs: Plan, Analysis, Artifacts | +| `/run/{id}/plan` | Interactive DAG visualization | +| `/run/{id}/analysis` | Audio/video analysis data | +| `/run/{id}/artifacts` | Cached step outputs | +| `/recipes` | Browse and run available recipes | +| `/recipe/{id}` | Recipe detail page | +| `/recipe/{id}/dag` | Recipe DAG visualization | +| `/media` | Browse cached media files | +| `/storage` | Manage storage providers | +| `/auth` | Receive auth token from L2 | +| `/logout` | Log out | +| `/download/client` | Download CLI client | + +## API Reference + +Interactive docs: http://localhost:8100/docs + +### Runs + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/runs` | Start a rendering run | +| GET | `/runs` | List all runs (paginated) | +| GET | `/runs/{run_id}` | Get run status | +| DELETE | `/runs/{run_id}` | Delete a run | +| GET | `/api/run/{run_id}` | Get run as JSON | +| GET | `/api/run/{run_id}/plan` | Get execution plan JSON | +| GET | `/api/run/{run_id}/analysis` | Get analysis data JSON | + +### Recipes + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/recipes/upload` | Upload recipe YAML | +| GET | `/recipes` | List recipes (paginated) | +| GET | `/recipes/{recipe_id}` | Get recipe details | +| DELETE | `/recipes/{recipe_id}` | Delete recipe | +| POST | `/recipes/{recipe_id}/run` | Execute recipe | + +### Cache + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/cache/{cid}` | Get cached content (with preview) | +| GET | `/cache/{cid}/raw` | Download raw content | +| GET | `/cache/{cid}/mp4` | Get MP4 video | +| GET | `/cache/{cid}/meta` | Get content metadata | +| PATCH | `/cache/{cid}/meta` | Update metadata | +| POST | `/cache/{cid}/publish` | Publish to L2 | +| DELETE | `/cache/{cid}` | Delete from cache | +| POST | `/cache/import?path=` | Import local file | +| POST | `/cache/upload` | Upload file | +| GET | `/media` | Browse media gallery | + +### IPFS + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/ipfs/{cid}` | Redirect to IPFS gateway | +| GET | `/ipfs/{cid}/raw` | Fetch raw content from IPFS | + +### Storage Providers + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/storage` | List storage providers | +| POST | `/storage` | Add provider (form) | +| POST | `/storage/add` | Add provider (JSON) | +| GET | `/storage/{id}` | Get provider details | +| PATCH | `/storage/{id}` | Update provider | +| DELETE | `/storage/{id}` | Delete provider | +| POST | `/storage/{id}/test` | Test connection | +| GET | `/storage/type/{type}` | Get form for provider type | + +### 3-Phase API + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/api/plan` | Generate execution plan | +| POST | `/api/execute` | Execute a plan | +| POST | `/api/run-recipe` | Full pipeline (analyze+plan+execute) | + +### Authentication + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/auth` | Receive auth token from L2 | +| GET | `/logout` | Log out | +| POST | `/auth/revoke` | Revoke a specific token | +| POST | `/auth/revoke-user` | Revoke all user tokens | + +## 3-Phase Execution + +Recipes are executed in three phases: + +### Phase 1: Analyze +Extract features from input files: +- **Audio/Video**: Tempo, beat times, energy levels +- Results cached by CID + +### Phase 2: Plan +Generate an execution plan: +- Parse recipe YAML +- Resolve dependencies between steps +- Compute cache IDs for each step +- Skip already-cached steps + +### Phase 3: Execute +Run the plan level by level: +- Steps at each level run in parallel +- Results cached with content-addressable hashes +- Progress tracked in Redis + +## Recipe Format + +Recipes define reusable DAG pipelines: + +```yaml +name: beat-sync +version: "1.0" +description: "Synchronize video to audio beats" + +inputs: + video: + type: video + description: "Source video" + audio: + type: audio + description: "Audio track" + +steps: + - id: analyze_audio + type: ANALYZE + inputs: [audio] + config: + features: [beats, energy] + + - id: sync_video + type: BEAT_SYNC + inputs: [video, analyze_audio] + config: + mode: stretch + +output: sync_video +``` + +## Storage + +### Local Cache +- Location: `~/.artdag/cache/` (or `CACHE_DIR`) +- Content-addressed by IPFS CID +- Subdirectories: `plans/`, `analysis/` + +### Redis +- Database 5 (configurable via `REDIS_URL`) +- Keys: + - `artdag:run:*` - Run state + - `artdag:recipe:*` - Recipe definitions + - `artdag:revoked:*` - Token revocation + - `artdag:user_tokens:*` - User token tracking + +### PostgreSQL +- Content metadata +- Storage provider configurations +- Provenance records + +## Authentication + +L1 servers authenticate via L2 (ActivityPub registry). No shared secrets required. + +### Flow +1. User clicks "Attach" on L2's Renderers page +2. L2 creates a scoped token bound to this L1 +3. User redirected to L1's `/auth?auth_token=...` +4. L1 calls L2's `/auth/verify` to validate +5. L1 sets local cookie and records token + +### Token Revocation +- Tokens tracked per-user in Redis +- L2 calls `/auth/revoke-user` on logout +- Revoked hashes stored with 30-day expiry +- Every request checks revocation list + +## CLI Usage + +```bash +# Quick render (effect mode) +python render.py dog cat --sync + +# Submit async +python render.py dog cat + +# Run a recipe +curl -X POST http://localhost:8100/recipes/beat-sync/run \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{"inputs": {"video": "abc123...", "audio": "def456..."}}' +``` + +## Architecture + +``` +L1 Server (FastAPI) + │ + ├── Web UI (Jinja2 + HTMX + Tailwind) + │ + ├── POST /runs → Celery tasks + │ │ + │ └── celery_app.py + │ ├── tasks/analyze.py (Phase 1) + │ ├── tasks/execute.py (Phase 3 steps) + │ └── tasks/orchestrate.py (Full pipeline) + │ + ├── cache_manager.py + │ │ + │ ├── Local filesystem (CACHE_DIR) + │ ├── IPFS (ipfs_client.py) + │ └── S3/Storage providers + │ + └── database.py (PostgreSQL metadata) +``` + +## Provenance + +Every render produces a provenance record: + +```json +{ + "task_id": "celery-task-uuid", + "rendered_at": "2026-01-07T...", + "rendered_by": "@giles@artdag.rose-ash.com", + "output": {"name": "...", "cid": "Qm..."}, + "inputs": [...], + "effects": [...], + "infrastructure": { + "software": {"name": "infra:artdag", "cid": "Qm..."}, + "hardware": {"name": "infra:giles-hp", "cid": "Qm..."} + } +} +``` diff --git a/l1/app/__init__.py b/l1/app/__init__.py new file mode 100644 index 0000000..408983b --- /dev/null +++ b/l1/app/__init__.py @@ -0,0 +1,237 @@ +""" +Art-DAG L1 Server Application Factory. + +Creates and configures the FastAPI application with all routers and middleware. +""" + +import secrets +import time +from pathlib import Path +from urllib.parse import quote + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles + +from artdag_common import create_jinja_env +from artdag_common.middleware.auth import get_user_from_cookie + +from .config import settings + +# Paths that should never trigger a silent auth check +_SKIP_PREFIXES = ("/auth/", "/static/", "/api/", "/ipfs/", "/download/", "/inbox", "/health", "/internal/", "/oembed") +_SILENT_CHECK_COOLDOWN = 300 # 5 minutes +_DEVICE_COOKIE = "artdag_did" +_DEVICE_COOKIE_MAX_AGE = 30 * 24 * 3600 # 30 days + +# Derive external base URL from oauth_redirect_uri (e.g. https://celery-artdag.rose-ash.com) +_EXTERNAL_BASE = settings.oauth_redirect_uri.rsplit("/auth/callback", 1)[0] + + +def _external_url(request: Request) -> str: + """Build external URL from request path + query, using configured base domain.""" + url = f"{_EXTERNAL_BASE}{request.url.path}" + if request.url.query: + url += f"?{request.url.query}" + return url + + +def create_app() -> FastAPI: + """ + Create and configure the L1 FastAPI application. + + Returns: + Configured FastAPI instance + """ + app = FastAPI( + title="Art-DAG L1 Server", + description="Content-addressed media processing with distributed execution", + version="1.0.0", + ) + + # Database lifecycle events + from database import init_db, close_db + + @app.on_event("startup") + async def startup(): + await init_db() + + @app.on_event("shutdown") + async def shutdown(): + await close_db() + + # Silent auth check — auto-login via prompt=none OAuth + # NOTE: registered BEFORE device_id so device_id is outermost (runs first) + @app.middleware("http") + async def silent_auth_check(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _SKIP_PREFIXES) + or request.headers.get("hx-request") # skip HTMX + ): + return await call_next(request) + + # Already logged in — but verify account hasn't logged out + if get_user_from_cookie(request): + device_id = getattr(request.state, "device_id", None) + if device_id: + try: + from .dependencies import get_redis_client + r = get_redis_client() + if not r.get(f"did_auth:{device_id}"): + # Account logged out — clear our cookie + response = await call_next(request) + response.delete_cookie("artdag_session") + response.delete_cookie("pnone_at") + return response + except Exception: + pass + return await call_next(request) + + # Check cooldown — don't re-check within 5 minutes + pnone_at = request.cookies.get("pnone_at") + if pnone_at: + try: + pnone_ts = float(pnone_at) + if (time.time() - pnone_ts) < _SILENT_CHECK_COOLDOWN: + # But first check if account signalled a login via inbox delivery + device_id = getattr(request.state, "device_id", None) + if device_id: + try: + from .dependencies import get_redis_client + r = get_redis_client() + auth_ts = r.get(f"did_auth:{device_id}") + if auth_ts and float(auth_ts) > pnone_ts: + # Login happened since our last check — retry + current_url = _external_url(request) + return RedirectResponse( + url=f"/auth/login?prompt=none&next={quote(current_url, safe='')}", + status_code=302, + ) + except Exception: + pass + return await call_next(request) + except (ValueError, TypeError): + pass + + # Redirect to silent OAuth check + current_url = _external_url(request) + return RedirectResponse( + url=f"/auth/login?prompt=none&next={quote(current_url, safe='')}", + status_code=302, + ) + + # Device ID middleware — track browser identity across domains + # Registered AFTER silent_auth_check so it's outermost (always runs) + @app.middleware("http") + async def device_id_middleware(request: Request, call_next): + did = request.cookies.get(_DEVICE_COOKIE) + if did: + request.state.device_id = did + request.state._new_device_id = False + else: + request.state.device_id = secrets.token_urlsafe(32) + request.state._new_device_id = True + + response = await call_next(request) + + if getattr(request.state, "_new_device_id", False): + response.set_cookie( + key=_DEVICE_COOKIE, + value=request.state.device_id, + max_age=_DEVICE_COOKIE_MAX_AGE, + httponly=True, + samesite="lax", + secure=True, + ) + return response + + # Coop fragment pre-fetch — inject nav-tree, auth-menu, cart-mini into + # request.state for full-page HTML renders. Skips HTMX, API, and + # internal paths. Failures are silent (fragments default to ""). + _FRAG_SKIP = ("/auth/", "/api/", "/internal/", "/health", "/oembed", + "/ipfs/", "/download/", "/inbox", "/static/") + + @app.middleware("http") + async def coop_fragments_middleware(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _FRAG_SKIP) + or request.headers.get("hx-request") + or request.headers.get(fragments.FRAGMENT_HEADER) + ): + request.state.nav_tree_html = "" + request.state.auth_menu_html = "" + request.state.cart_mini_html = "" + return await call_next(request) + + from artdag_common.fragments import fetch_fragments as _fetch_frags + + user = get_user_from_cookie(request) + auth_params = {"email": user.email} if user and user.email else {} + nav_params = {"app_name": "artdag", "path": path} + + try: + nav_tree_html, auth_menu_html, cart_mini_html = await _fetch_frags([ + ("blog", "nav-tree", nav_params), + ("account", "auth-menu", auth_params or None), + ("cart", "cart-mini", None), + ]) + except Exception: + nav_tree_html = auth_menu_html = cart_mini_html = "" + + request.state.nav_tree_html = nav_tree_html + request.state.auth_menu_html = auth_menu_html + request.state.cart_mini_html = cart_mini_html + + return await call_next(request) + + # Initialize Jinja2 templates + template_dir = Path(__file__).parent / "templates" + app.state.templates = create_jinja_env(template_dir) + + # Custom 404 handler + @app.exception_handler(404) + async def not_found_handler(request: Request, exc): + from artdag_common.middleware import wants_html + if wants_html(request): + from artdag_common import render + return render(app.state.templates, "404.html", request, + user=None, + status_code=404, + ) + return JSONResponse({"detail": "Not found"}, status_code=404) + + # Include routers + from .routers import auth, storage, api, recipes, cache, runs, home, effects, inbox, fragments, oembed + + # Home and auth routers (root level) + app.include_router(home.router, tags=["home"]) + app.include_router(auth.router, prefix="/auth", tags=["auth"]) + app.include_router(inbox.router, tags=["inbox"]) + app.include_router(fragments.router, tags=["fragments"]) + app.include_router(oembed.router, tags=["oembed"]) + + # Feature routers + app.include_router(storage.router, prefix="/storage", tags=["storage"]) + app.include_router(api.router, prefix="/api", tags=["api"]) + + # Runs and recipes routers + app.include_router(runs.router, prefix="/runs", tags=["runs"]) + app.include_router(recipes.router, prefix="/recipes", tags=["recipes"]) + + # Cache router - handles /cache and /media + app.include_router(cache.router, prefix="/cache", tags=["cache"]) + # Also mount cache router at /media for convenience + app.include_router(cache.router, prefix="/media", tags=["media"]) + + # Effects router + app.include_router(effects.router, prefix="/effects", tags=["effects"]) + + return app + + +# Create the default app instance +app = create_app() diff --git a/l1/app/config.py b/l1/app/config.py new file mode 100644 index 0000000..8aa94d7 --- /dev/null +++ b/l1/app/config.py @@ -0,0 +1,116 @@ +""" +L1 Server Configuration. + +Environment-based configuration with sensible defaults. +All config should go through this module - no direct os.environ calls elsewhere. +""" + +import os +import sys +from pathlib import Path +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Settings: + """Application settings loaded from environment.""" + + # Server + host: str = field(default_factory=lambda: os.environ.get("HOST", "0.0.0.0")) + port: int = field(default_factory=lambda: int(os.environ.get("PORT", "8000"))) + debug: bool = field(default_factory=lambda: os.environ.get("DEBUG", "").lower() == "true") + + # Cache (use /data/cache in Docker via env var, ~/.artdag/cache locally) + cache_dir: Path = field( + default_factory=lambda: Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache"))) + ) + + # Redis + redis_url: str = field( + default_factory=lambda: os.environ.get("REDIS_URL", "redis://localhost:6379/5") + ) + + # Database + database_url: str = field( + default_factory=lambda: os.environ.get("DATABASE_URL", "") + ) + + # IPFS + ipfs_api: str = field( + default_factory=lambda: os.environ.get("IPFS_API", "/dns/localhost/tcp/5001") + ) + ipfs_gateway_url: str = field( + default_factory=lambda: os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ) + + # OAuth SSO (replaces L2 auth) + oauth_authorize_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_AUTHORIZE_URL", "https://account.rose-ash.com/auth/oauth/authorize") + ) + oauth_token_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_TOKEN_URL", "https://account.rose-ash.com/auth/oauth/token") + ) + oauth_client_id: str = field( + default_factory=lambda: os.environ.get("OAUTH_CLIENT_ID", "artdag") + ) + oauth_redirect_uri: str = field( + default_factory=lambda: os.environ.get("OAUTH_REDIRECT_URI", "https://celery-artdag.rose-ash.com/auth/callback") + ) + oauth_logout_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_LOGOUT_URL", "https://account.rose-ash.com/auth/sso-logout/") + ) + secret_key: str = field( + default_factory=lambda: os.environ.get("SECRET_KEY", "change-me-in-production") + ) + + # GPU/Streaming settings + streaming_gpu_persist: bool = field( + default_factory=lambda: os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" + ) + ipfs_gateways: str = field( + default_factory=lambda: os.environ.get( + "IPFS_GATEWAYS", "https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link" + ) + ) + + # Derived paths + @property + def plan_cache_dir(self) -> Path: + return self.cache_dir / "plans" + + @property + def analysis_cache_dir(self) -> Path: + return self.cache_dir / "analysis" + + def ensure_dirs(self) -> None: + """Create required directories.""" + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.plan_cache_dir.mkdir(parents=True, exist_ok=True) + self.analysis_cache_dir.mkdir(parents=True, exist_ok=True) + + + def log_config(self, logger=None) -> None: + """Log all configuration values for debugging.""" + output = logger.info if logger else lambda x: print(x, file=sys.stderr) + output("=" * 60) + output("CONFIGURATION") + output("=" * 60) + output(f" cache_dir: {self.cache_dir}") + output(f" redis_url: {self.redis_url}") + output(f" database_url: {self.database_url[:50]}...") + output(f" ipfs_api: {self.ipfs_api}") + output(f" ipfs_gateway_url: {self.ipfs_gateway_url}") + output(f" ipfs_gateways: {self.ipfs_gateways[:50]}...") + output(f" streaming_gpu_persist: {self.streaming_gpu_persist}") + output(f" oauth_client_id: {self.oauth_client_id}") + output(f" oauth_authorize_url: {self.oauth_authorize_url}") + output("=" * 60) + + +# Singleton settings instance +settings = Settings() + +# Log config on import if DEBUG or SHOW_CONFIG is set +if os.environ.get("DEBUG") or os.environ.get("SHOW_CONFIG"): + settings.log_config() diff --git a/l1/app/dependencies.py b/l1/app/dependencies.py new file mode 100644 index 0000000..fc59947 --- /dev/null +++ b/l1/app/dependencies.py @@ -0,0 +1,186 @@ +""" +FastAPI dependency injection container. + +Provides shared resources and services to route handlers. +""" + +from functools import lru_cache +from typing import Optional +import asyncio + +from fastapi import Request, Depends, HTTPException +from jinja2 import Environment + +from artdag_common.middleware.auth import UserContext, get_user_from_cookie, get_user_from_header + +from .config import settings + + +# Lazy imports to avoid circular dependencies +_redis_client = None +_cache_manager = None +_database = None + + +def get_redis_client(): + """Get the Redis client singleton.""" + global _redis_client + if _redis_client is None: + import redis + _redis_client = redis.from_url(settings.redis_url, decode_responses=True) + return _redis_client + + +def get_cache_manager(): + """Get the cache manager singleton.""" + global _cache_manager + if _cache_manager is None: + from cache_manager import get_cache_manager as _get_cache_manager + _cache_manager = _get_cache_manager() + return _cache_manager + + +def get_database(): + """Get the database singleton.""" + global _database + if _database is None: + import database + _database = database + return _database + + +def get_templates(request: Request) -> Environment: + """Get the Jinja2 environment from app state.""" + return request.app.state.templates + + +async def get_current_user(request: Request) -> Optional[UserContext]: + """ + Get the current user from request (cookie or header). + + This is a permissive dependency - returns None if not authenticated. + Use require_auth for routes that require authentication. + """ + # Try header first (API clients) + ctx = get_user_from_header(request) + if ctx: + return ctx + + # Fall back to cookie (browser) + return get_user_from_cookie(request) + + +async def require_auth(request: Request) -> UserContext: + """ + Require authentication for a route. + + Raises: + HTTPException 401 if not authenticated + HTTPException 302 redirect to login for HTML requests + """ + ctx = await get_current_user(request) + if ctx is None: + # Check if HTML request for redirect + accept = request.headers.get("accept", "") + if "text/html" in accept: + raise HTTPException( + status_code=302, + headers={"Location": "/auth/login"} + ) + raise HTTPException(status_code=401, detail="Authentication required") + return ctx + + +async def get_user_context_from_cookie(request: Request) -> Optional[UserContext]: + """ + Legacy compatibility: get user from cookie. + + Validates token with L2 server if configured. + """ + ctx = get_user_from_cookie(request) + if ctx is None: + return None + + # If L2 server configured, could validate token here + # For now, trust the cookie + return ctx + + +# Service dependencies (lazy loading) + +def get_run_service(): + """Get the run service.""" + from .services.run_service import RunService + return RunService( + database=get_database(), + redis=get_redis_client(), + cache=get_cache_manager(), + ) + + +def get_recipe_service(): + """Get the recipe service.""" + from .services.recipe_service import RecipeService + return RecipeService( + redis=get_redis_client(), # Kept for API compatibility, not used + cache=get_cache_manager(), + ) + + +def get_cache_service(): + """Get the cache service.""" + from .services.cache_service import CacheService + return CacheService( + cache_manager=get_cache_manager(), + database=get_database(), + ) + + +async def get_nav_counts(actor_id: Optional[str] = None) -> dict: + """ + Get counts for navigation bar display. + + Returns dict with: runs, recipes, effects, media, storage + """ + counts = {} + + try: + import database + counts["media"] = await database.count_user_items(actor_id) if actor_id else 0 + except Exception: + pass + + try: + recipe_service = get_recipe_service() + recipes = await recipe_service.list_recipes(actor_id) + counts["recipes"] = len(recipes) + except Exception: + pass + + try: + run_service = get_run_service() + runs = await run_service.list_runs(actor_id) + counts["runs"] = len(runs) + except Exception: + pass + + try: + # Effects are stored in _effects/ directory, not in cache + from pathlib import Path + cache_mgr = get_cache_manager() + effects_dir = Path(cache_mgr.cache_dir) / "_effects" + if effects_dir.exists(): + counts["effects"] = len([d for d in effects_dir.iterdir() if d.is_dir()]) + else: + counts["effects"] = 0 + except Exception: + pass + + try: + import database + storage_providers = await database.get_user_storage_providers(actor_id) if actor_id else [] + counts["storage"] = len(storage_providers) if storage_providers else 0 + except Exception: + pass + + return counts diff --git a/l1/app/repositories/__init__.py b/l1/app/repositories/__init__.py new file mode 100644 index 0000000..7985294 --- /dev/null +++ b/l1/app/repositories/__init__.py @@ -0,0 +1,10 @@ +""" +L1 Server Repositories. + +Data access layer for persistence operations. +""" + +# TODO: Implement repositories +# - RunRepository - Redis-backed run storage +# - RecipeRepository - Redis-backed recipe storage +# - CacheRepository - Filesystem + PostgreSQL cache metadata diff --git a/l1/app/routers/__init__.py b/l1/app/routers/__init__.py new file mode 100644 index 0000000..f0a9d54 --- /dev/null +++ b/l1/app/routers/__init__.py @@ -0,0 +1,23 @@ +""" +L1 Server Routers. + +Each router handles a specific domain of functionality. +""" + +from . import auth +from . import storage +from . import api +from . import recipes +from . import cache +from . import runs +from . import home + +__all__ = [ + "auth", + "storage", + "api", + "recipes", + "cache", + "runs", + "home", +] diff --git a/l1/app/routers/api.py b/l1/app/routers/api.py new file mode 100644 index 0000000..5288342 --- /dev/null +++ b/l1/app/routers/api.py @@ -0,0 +1,257 @@ +""" +3-phase API routes for L1 server. + +Provides the plan/execute/run-recipe endpoints for programmatic access. +""" + +import hashlib +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from artdag_common.middleware.auth import UserContext +from ..dependencies import require_auth, get_redis_client, get_cache_manager + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Redis key prefix +RUNS_KEY_PREFIX = "artdag:run:" + + +class PlanRequest(BaseModel): + recipe_sexp: str + input_hashes: Dict[str, str] + + +class ExecutePlanRequest(BaseModel): + plan_json: str + run_id: Optional[str] = None + + +class RecipeRunRequest(BaseModel): + recipe_sexp: str + input_hashes: Dict[str, str] + + +def compute_run_id(input_hashes: List[str], recipe: str, recipe_hash: str = None) -> str: + """Compute deterministic run_id from inputs and recipe.""" + data = { + "inputs": sorted(input_hashes), + "recipe": recipe_hash or f"effect:{recipe}", + "version": "1", + } + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha3_256(json_str.encode()).hexdigest() + + +@router.post("/plan") +async def generate_plan_endpoint( + request: PlanRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Generate an execution plan without executing it. + + Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model. + Returns the plan with cache status for each step. + """ + from tasks.orchestrate import generate_plan + + try: + task = generate_plan.delay( + recipe_sexp=request.recipe_sexp, + input_hashes=request.input_hashes, + ) + + # Wait for result (plan generation is usually fast) + result = task.get(timeout=60) + + return { + "status": result.get("status"), + "recipe": result.get("recipe"), + "plan_id": result.get("plan_id"), + "total_steps": result.get("total_steps"), + "cached_steps": result.get("cached_steps"), + "pending_steps": result.get("pending_steps"), + "steps": result.get("steps"), + } + except Exception as e: + logger.error(f"Plan generation failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/execute") +async def execute_plan_endpoint( + request: ExecutePlanRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Execute a pre-generated execution plan. + + Phase 3 (Execute) of the 3-phase model. + Submits the plan to Celery for parallel execution. + """ + from tasks.orchestrate import run_plan + + run_id = request.run_id or str(uuid.uuid4()) + + try: + task = run_plan.delay( + plan_json=request.plan_json, + run_id=run_id, + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + } + except Exception as e: + logger.error(f"Plan execution failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/run-recipe") +async def run_recipe_endpoint( + request: RecipeRunRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Run a complete recipe through all 3 phases. + + 1. Analyze: Extract features from inputs + 2. Plan: Generate execution plan with cache IDs + 3. Execute: Run steps with parallel execution + + Returns immediately with run_id. Poll /api/run/{run_id} for status. + """ + from tasks.orchestrate import run_recipe + from artdag.sexp import compile_string + import database + + redis = get_redis_client() + cache = get_cache_manager() + + # Parse recipe name from S-expression + try: + compiled = compile_string(request.recipe_sexp) + recipe_name = compiled.name or "unknown" + except Exception: + recipe_name = "unknown" + + # Compute deterministic run_id + run_id = compute_run_id( + list(request.input_hashes.values()), + recipe_name, + hashlib.sha3_256(request.recipe_sexp.encode()).hexdigest() + ) + + # Check if already completed + cached = await database.get_run_cache(run_id) + if cached: + output_cid = cached.get("output_cid") + if cache.has_content(output_cid): + return { + "status": "completed", + "run_id": run_id, + "output_cid": output_cid, + "output_ipfs_cid": cache.get_ipfs_cid(output_cid), + "cached": True, + } + + # Submit to Celery + try: + task = run_recipe.delay( + recipe_sexp=request.recipe_sexp, + input_hashes=request.input_hashes, + run_id=run_id, + ) + + # Store run status in Redis + run_data = { + "run_id": run_id, + "status": "pending", + "recipe": recipe_name, + "inputs": list(request.input_hashes.values()), + "celery_task_id": task.id, + "created_at": datetime.now(timezone.utc).isoformat(), + "username": ctx.actor_id, + } + redis.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, + json.dumps(run_data) + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + "recipe": recipe_name, + } + except Exception as e: + logger.error(f"Recipe run failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/run/{run_id}") +async def get_run_status( + run_id: str, + ctx: UserContext = Depends(require_auth), +): + """Get status of a recipe execution run.""" + import database + from celery.result import AsyncResult + + redis = get_redis_client() + + # Check Redis for run status + run_data = redis.get(f"{RUNS_KEY_PREFIX}{run_id}") + if run_data: + data = json.loads(run_data) + + # If pending, check Celery task status + if data.get("status") == "pending" and data.get("celery_task_id"): + result = AsyncResult(data["celery_task_id"]) + + if result.ready(): + if result.successful(): + task_result = result.get() + data["status"] = task_result.get("status", "completed") + data["output_cid"] = task_result.get("output_cache_id") + data["output_ipfs_cid"] = task_result.get("output_ipfs_cid") + data["total_steps"] = task_result.get("total_steps") + data["cached"] = task_result.get("cached") + data["executed"] = task_result.get("executed") + + # Update Redis + redis.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, + json.dumps(data) + ) + else: + data["status"] = "failed" + data["error"] = str(result.result) + else: + data["celery_status"] = result.status + + return data + + # Check database cache + cached = await database.get_run_cache(run_id) + if cached: + return { + "run_id": run_id, + "status": "completed", + "output_cid": cached.get("output_cid"), + "cached": True, + } + + raise HTTPException(status_code=404, detail="Run not found") diff --git a/l1/app/routers/auth.py b/l1/app/routers/auth.py new file mode 100644 index 0000000..c447f3d --- /dev/null +++ b/l1/app/routers/auth.py @@ -0,0 +1,165 @@ +""" +Authentication routes — OAuth2 authorization code flow via account.rose-ash.com. + +GET /auth/login — redirect to account OAuth authorize +GET /auth/callback — exchange code for user info, set session cookie +GET /auth/logout — clear cookie, redirect through account SSO logout +""" + +import secrets +import time + +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import RedirectResponse +from itsdangerous import URLSafeSerializer + +from artdag_common.middleware.auth import UserContext, set_auth_cookie, clear_auth_cookie + +from ..config import settings + +router = APIRouter() + +_signer = None + + +def _get_signer() -> URLSafeSerializer: + global _signer + if _signer is None: + _signer = URLSafeSerializer(settings.secret_key, salt="oauth-state") + return _signer + + +@router.get("/login") +async def login(request: Request): + """Store state + next in signed cookie, redirect to account OAuth authorize.""" + next_url = request.query_params.get("next", "/") + prompt = request.query_params.get("prompt", "") + state = secrets.token_urlsafe(32) + + signer = _get_signer() + state_payload = signer.dumps({"state": state, "next": next_url, "prompt": prompt}) + + device_id = getattr(request.state, "device_id", "") + authorize_url = ( + f"{settings.oauth_authorize_url}" + f"?client_id={settings.oauth_client_id}" + f"&redirect_uri={settings.oauth_redirect_uri}" + f"&device_id={device_id}" + f"&state={state}" + ) + if prompt: + authorize_url += f"&prompt={prompt}" + + response = RedirectResponse(url=authorize_url, status_code=302) + response.set_cookie( + key="oauth_state", + value=state_payload, + max_age=600, # 10 minutes + httponly=True, + samesite="lax", + secure=True, + ) + return response + + +@router.get("/callback") +async def callback(request: Request): + """Validate state, exchange code via token endpoint, set session cookie.""" + code = request.query_params.get("code", "") + state = request.query_params.get("state", "") + error = request.query_params.get("error", "") + account_did = request.query_params.get("account_did", "") + + # Adopt account's device ID as our own (one identity across all apps) + if account_did: + request.state.device_id = account_did + request.state._new_device_id = True # device_id middleware will set cookie + + # Recover state from signed cookie + state_cookie = request.cookies.get("oauth_state", "") + signer = _get_signer() + try: + payload = signer.loads(state_cookie) if state_cookie else {} + except Exception: + payload = {} + + next_url = payload.get("next", "/") + + # Handle prompt=none rejection (user not logged in on account) + if error == "login_required": + response = RedirectResponse(url=next_url, status_code=302) + response.delete_cookie("oauth_state") + # Set cooldown cookie — don't re-check for 5 minutes + response.set_cookie( + key="pnone_at", + value=str(time.time()), + max_age=300, + httponly=True, + samesite="lax", + secure=True, + ) + # Set device cookie if adopted + if account_did: + response.set_cookie( + key="artdag_did", + value=account_did, + max_age=30 * 24 * 3600, + httponly=True, + samesite="lax", + secure=True, + ) + return response + + # Normal callback — validate state + code + if not state_cookie or not code or not state: + return RedirectResponse(url="/", status_code=302) + + if payload.get("state") != state: + return RedirectResponse(url="/", status_code=302) + + # Exchange code for user info via account's token endpoint + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + settings.oauth_token_url, + json={ + "code": code, + "client_id": settings.oauth_client_id, + "redirect_uri": settings.oauth_redirect_uri, + }, + ) + except httpx.HTTPError: + return RedirectResponse(url="/", status_code=302) + + if resp.status_code != 200: + return RedirectResponse(url="/", status_code=302) + + data = resp.json() + if "error" in data: + return RedirectResponse(url="/", status_code=302) + + # Map OAuth response to artdag UserContext + # Note: account token endpoint returns user.email as "username" + display_name = data.get("display_name", "") + username = data.get("username", "") + email = username # OAuth response "username" is the user's email + actor_id = f"@{username}" + + user = UserContext(username=username, actor_id=actor_id, email=email) + + response = RedirectResponse(url=next_url, status_code=302) + set_auth_cookie(response, user) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") + return response + + +@router.get("/logout") +async def logout(): + """Clear session cookie, redirect through account SSO logout.""" + response = RedirectResponse(url=settings.oauth_logout_url, status_code=302) + clear_auth_cookie(response) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") + return response diff --git a/l1/app/routers/cache.py b/l1/app/routers/cache.py new file mode 100644 index 0000000..dc03d44 --- /dev/null +++ b/l1/app/routers/cache.py @@ -0,0 +1,515 @@ +""" +Cache and media routes for L1 server. + +Handles content retrieval, metadata, media preview, and publishing. +""" + +import logging +from pathlib import Path +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File, Form +from fastapi.responses import HTMLResponse, FileResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_redis_client, + get_cache_manager, get_current_user +) +from ..services.auth_service import AuthService +from ..services.cache_service import CacheService + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class UpdateMetadataRequest(BaseModel): + title: Optional[str] = None + description: Optional[str] = None + tags: Optional[list] = None + custom: Optional[Dict[str, Any]] = None + + +def get_cache_service(): + """Get cache service instance.""" + import database + return CacheService(database, get_cache_manager()) + + +@router.get("/{cid}") +async def get_cached( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get cached content by hash. Content negotiation: HTML for browsers, JSON for APIs.""" + ctx = await get_current_user(request) + + # Pass actor_id to get friendly name and user-specific metadata + actor_id = ctx.actor_id if ctx else None + cache_item = await cache_service.get_cache_item(cid, actor_id=actor_id) + if not cache_item: + if wants_html(request): + templates = get_templates(request) + return render(templates, "cache/not_found.html", request, + cid=cid, + user=ctx, + active_tab="media", + ) + raise HTTPException(404, f"Content {cid} not in cache") + + # JSON response + if wants_json(request): + return cache_item + + # HTML response + if not ctx: + from fastapi.responses import RedirectResponse + return RedirectResponse(url="/auth", status_code=302) + + # Check access + has_access = await cache_service.check_access(cid, ctx.actor_id, ctx.username) + if not has_access: + raise HTTPException(403, "Access denied") + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "cache/detail.html", request, + cache=cache_item, + user=ctx, + nav_counts=nav_counts, + active_tab="media", + ) + + +@router.get("/{cid}/raw") +async def get_cached_raw( + cid: str, + cache_service: CacheService = Depends(get_cache_service), +): + """Get raw cached content (file download).""" + file_path, media_type, filename = await cache_service.get_raw_file(cid) + + if not file_path: + raise HTTPException(404, f"Content {cid} not in cache") + + return FileResponse(file_path, media_type=media_type, filename=filename) + + +@router.get("/{cid}/mp4") +async def get_cached_mp4( + cid: str, + cache_service: CacheService = Depends(get_cache_service), +): + """Get cached content as MP4 (transcodes MKV on first request).""" + mp4_path, error = await cache_service.get_as_mp4(cid) + + if error: + raise HTTPException(400 if "not a video" in error else 404, error) + + return FileResponse(mp4_path, media_type="video/mp4") + + +@router.get("/{cid}/meta") +async def get_metadata( + cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Get content metadata.""" + meta = await cache_service.get_metadata(cid, ctx.actor_id) + if meta is None: + raise HTTPException(404, "Content not found") + return meta + + +@router.patch("/{cid}/meta") +async def update_metadata( + cid: str, + req: UpdateMetadataRequest, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Update content metadata.""" + success, error = await cache_service.update_metadata( + cid=cid, + actor_id=ctx.actor_id, + title=req.title, + description=req.description, + tags=req.tags, + custom=req.custom, + ) + + if error: + raise HTTPException(400, error) + + return {"updated": True} + + +@router.post("/{cid}/publish") +async def publish_content( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Publish content to L2 and IPFS.""" + ipfs_cid, error = await cache_service.publish_to_l2( + cid=cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Published: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "published": True} + + +@router.delete("/{cid}") +async def delete_content( + cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Delete content from cache.""" + success, error = await cache_service.delete_content(cid, ctx.actor_id) + + if error: + raise HTTPException(400 if "Cannot" in error or "pinned" in error else 404, error) + + return {"deleted": True} + + +@router.post("/import") +async def import_from_ipfs( + ipfs_cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Import content from IPFS.""" + cid, error = await cache_service.import_from_ipfs(ipfs_cid, ctx.actor_id) + + if error: + raise HTTPException(400, error) + + return {"cid": cid, "imported": True} + + +@router.post("/upload/chunk") +async def upload_chunk( + request: Request, + chunk: UploadFile = File(...), + upload_id: str = Form(...), + chunk_index: int = Form(...), + total_chunks: int = Form(...), + filename: str = Form(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Upload a file chunk. Assembles file when all chunks received.""" + import tempfile + import os + + # Create temp dir for this upload + chunk_dir = Path(tempfile.gettempdir()) / "uploads" / upload_id + chunk_dir.mkdir(parents=True, exist_ok=True) + + # Save this chunk + chunk_path = chunk_dir / f"chunk_{chunk_index:05d}" + chunk_data = await chunk.read() + chunk_path.write_bytes(chunk_data) + + # Check if all chunks received + received = len(list(chunk_dir.glob("chunk_*"))) + + if received < total_chunks: + return {"status": "partial", "received": received, "total": total_chunks} + + # All chunks received - assemble file + final_path = chunk_dir / filename + with open(final_path, 'wb') as f: + for i in range(total_chunks): + cp = chunk_dir / f"chunk_{i:05d}" + f.write(cp.read_bytes()) + cp.unlink() # Clean up chunk + + # Read assembled file + content = final_path.read_bytes() + final_path.unlink() + chunk_dir.rmdir() + + # Now do the normal upload flow + cid, ipfs_cid, error = await cache_service.upload_content( + content=content, + filename=filename, + actor_id=ctx.actor_id, + ) + + if error: + raise HTTPException(400, error) + + # Assign friendly name + final_cid = ipfs_cid or cid + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=final_cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, + filename=filename, + ) + + return { + "status": "complete", + "cid": final_cid, + "friendly_name": friendly_entry["friendly_name"], + "filename": filename, + "size": len(content), + "uploaded": True, + } + + +@router.post("/upload") +async def upload_content( + file: UploadFile = File(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Upload content to cache and IPFS. + + Args: + file: The file to upload + display_name: Optional custom name for the media (used as friendly name) + """ + content = await file.read() + cid, ipfs_cid, error = await cache_service.upload_content( + content=content, + filename=file.filename, + actor_id=ctx.actor_id, + ) + + if error: + raise HTTPException(400, error) + + # Assign friendly name (use IPFS CID if available, otherwise local hash) + final_cid = ipfs_cid or cid + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=final_cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, # Use custom name if provided + filename=file.filename, + ) + + return { + "cid": final_cid, + "content_hash": cid, # Legacy, for backwards compatibility + "friendly_name": friendly_entry["friendly_name"], + "filename": file.filename, + "size": len(content), + "uploaded": True, + } + + +# Media listing endpoint +@router.get("") +async def list_media( + request: Request, + offset: int = 0, + limit: int = 24, + media_type: Optional[str] = None, + cache_service: CacheService = Depends(get_cache_service), + ctx: UserContext = Depends(require_auth), +): + """List all media in cache.""" + items = await cache_service.list_media( + actor_id=ctx.actor_id, + username=ctx.username, + offset=offset, + limit=limit, + media_type=media_type, + ) + has_more = len(items) >= limit + + if wants_json(request): + return {"items": items, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "cache/media_list.html", request, + items=items, + user=ctx, + nav_counts=nav_counts, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="media", + ) + + +# HTMX metadata form +@router.get("/{cid}/meta-form", response_class=HTMLResponse) +async def get_metadata_form( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get metadata editing form (HTMX).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + meta = await cache_service.get_metadata(cid, ctx.actor_id) + + return HTMLResponse(f''' +

Metadata

+
+
+ + +
+
+ + +
+ +
+ ''') + + +@router.patch("/{cid}/meta", response_class=HTMLResponse) +async def update_metadata_htmx( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Update metadata (HTMX form handler).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + form_data = await request.form() + + success, error = await cache_service.update_metadata( + cid=cid, + actor_id=ctx.actor_id, + title=form_data.get("title"), + description=form_data.get("description"), + ) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse(''' +
Metadata saved!
+ + ''') + + +# Friendly name editing +@router.get("/{cid}/name-form", response_class=HTMLResponse) +async def get_name_form( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get friendly name editing form (HTMX).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + # Get current friendly name + from ..services.naming_service import get_naming_service + naming = get_naming_service() + entry = await naming.get_by_cid(ctx.actor_id, cid) + current_name = entry.get("base_name", "") if entry else "" + + return HTMLResponse(f''' +
+
+ + +

A name to reference this media in recipes

+
+
+ + +
+
+ ''') + + +@router.post("/{cid}/name", response_class=HTMLResponse) +async def update_friendly_name( + cid: str, + request: Request, +): + """Update friendly name (HTMX form handler).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + form_data = await request.form() + display_name = form_data.get("display_name", "").strip() + + if not display_name: + return HTMLResponse('
Name cannot be empty
') + + from ..services.naming_service import get_naming_service + naming = get_naming_service() + + try: + entry = await naming.assign_name( + cid=cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, + ) + + return HTMLResponse(f''' +
Name updated!
+ + ''') + except Exception as e: + return HTMLResponse(f'
Error: {e}
') diff --git a/l1/app/routers/effects.py b/l1/app/routers/effects.py new file mode 100644 index 0000000..994a925 --- /dev/null +++ b/l1/app/routers/effects.py @@ -0,0 +1,415 @@ +""" +Effects routes for L1 server. + +Handles effect upload, listing, and metadata. +Effects are S-expression files stored in IPFS like all other content-addressed data. +""" + +import json +import logging +import re +import time +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File, Form +from fastapi.responses import HTMLResponse, PlainTextResponse + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_redis_client, + get_cache_manager, +) +from ..services.auth_service import AuthService +import ipfs_client + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def get_effects_dir() -> Path: + """Get effects storage directory.""" + cache_mgr = get_cache_manager() + effects_dir = Path(cache_mgr.cache_dir) / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + return effects_dir + + +def parse_effect_metadata(source: str) -> dict: + """ + Parse effect metadata from S-expression source code. + + Extracts metadata from comment headers (;; @key value format) + or from (defeffect name ...) form. + """ + metadata = { + "name": "", + "version": "1.0.0", + "author": "", + "temporal": False, + "description": "", + "params": [], + } + + # Parse comment-based metadata (;; @key value) + for line in source.split("\n"): + stripped = line.strip() + if not stripped.startswith(";"): + # Stop parsing metadata at first non-comment line + if stripped and not stripped.startswith("("): + continue + if stripped.startswith("("): + break + + # Remove comment prefix + comment = stripped.lstrip(";").strip() + + if comment.startswith("@effect "): + metadata["name"] = comment[8:].strip() + elif comment.startswith("@name "): + metadata["name"] = comment[6:].strip() + elif comment.startswith("@version "): + metadata["version"] = comment[9:].strip() + elif comment.startswith("@author "): + metadata["author"] = comment[8:].strip() + elif comment.startswith("@temporal"): + val = comment[9:].strip().lower() if len(comment) > 9 else "true" + metadata["temporal"] = val in ("true", "yes", "1", "") + elif comment.startswith("@description "): + metadata["description"] = comment[13:].strip() + elif comment.startswith("@param "): + # Format: @param name type [description] + parts = comment[7:].split(None, 2) + if len(parts) >= 2: + param = {"name": parts[0], "type": parts[1]} + if len(parts) > 2: + param["description"] = parts[2] + metadata["params"].append(param) + + # Also try to extract name from (defeffect "name" ...) or (effect "name" ...) + if not metadata["name"]: + name_match = re.search(r'\((defeffect|effect)\s+"([^"]+)"', source) + if name_match: + metadata["name"] = name_match.group(2) + + # Try to extract name from first (define ...) form + if not metadata["name"]: + define_match = re.search(r'\(define\s+(\w+)', source) + if define_match: + metadata["name"] = define_match.group(1) + + return metadata + + +@router.post("/upload") +async def upload_effect( + file: UploadFile = File(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), +): + """ + Upload an S-expression effect to IPFS. + + Parses metadata from comment headers. + Returns IPFS CID for use in recipes. + + Args: + file: The .sexp effect file + display_name: Optional custom friendly name for the effect + """ + content = await file.read() + + try: + source = content.decode("utf-8") + except UnicodeDecodeError: + raise HTTPException(400, "Effect must be valid UTF-8 text") + + # Parse metadata from sexp source + try: + meta = parse_effect_metadata(source) + except Exception as e: + logger.warning(f"Failed to parse effect metadata: {e}") + meta = {"name": file.filename or "unknown"} + + if not meta.get("name"): + meta["name"] = Path(file.filename).stem if file.filename else "unknown" + + # Store effect source in IPFS + cid = ipfs_client.add_bytes(content) + if not cid: + raise HTTPException(500, "Failed to store effect in IPFS") + + # Also keep local cache for fast worker access + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + effect_dir.mkdir(parents=True, exist_ok=True) + (effect_dir / "effect.sexp").write_text(source, encoding="utf-8") + + # Store metadata (locally and in IPFS) + full_meta = { + "cid": cid, + "meta": meta, + "uploader": ctx.actor_id, + "uploaded_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "filename": file.filename, + } + (effect_dir / "metadata.json").write_text(json.dumps(full_meta, indent=2)) + + # Also store metadata in IPFS for discoverability + meta_cid = ipfs_client.add_json(full_meta) + + # Track ownership in item_types + import database + await database.save_item_metadata( + cid=cid, + actor_id=ctx.actor_id, + item_type="effect", + filename=file.filename, + ) + + # Assign friendly name (use custom display_name if provided, else from metadata) + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=cid, + actor_id=ctx.actor_id, + item_type="effect", + display_name=display_name or meta.get("name"), + filename=file.filename, + ) + + logger.info(f"Uploaded effect '{meta.get('name')}' cid={cid} friendly_name='{friendly_entry['friendly_name']}' by {ctx.actor_id}") + + return { + "cid": cid, + "metadata_cid": meta_cid, + "name": meta.get("name"), + "friendly_name": friendly_entry["friendly_name"], + "version": meta.get("version"), + "temporal": meta.get("temporal", False), + "params": meta.get("params", []), + "uploaded": True, + } + + +@router.get("/{cid}") +async def get_effect( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Get effect metadata by CID.""" + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + metadata_path = effect_dir / "metadata.json" + + # Try local cache first + if metadata_path.exists(): + meta = json.loads(metadata_path.read_text()) + else: + # Fetch from IPFS + source_bytes = ipfs_client.get_bytes(cid) + if not source_bytes: + raise HTTPException(404, f"Effect {cid[:16]}... not found") + + # Cache locally + effect_dir.mkdir(parents=True, exist_ok=True) + source = source_bytes.decode("utf-8") + (effect_dir / "effect.sexp").write_text(source) + + # Parse metadata from source + parsed_meta = parse_effect_metadata(source) + meta = {"cid": cid, "meta": parsed_meta} + (effect_dir / "metadata.json").write_text(json.dumps(meta, indent=2)) + + # Add friendly name if available + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(ctx.actor_id, cid) + if friendly: + meta["friendly_name"] = friendly["friendly_name"] + meta["base_name"] = friendly["base_name"] + meta["version_id"] = friendly["version_id"] + + if wants_json(request): + return meta + + # HTML response + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "effects/detail.html", request, + effect=meta, + user=ctx, + nav_counts=nav_counts, + active_tab="effects", + ) + + +@router.get("/{cid}/source") +async def get_effect_source( + cid: str, + ctx: UserContext = Depends(require_auth), +): + """Get effect source code.""" + effects_dir = get_effects_dir() + source_path = effects_dir / cid / "effect.sexp" + + # Try local cache first (check both .sexp and legacy .py) + if source_path.exists(): + return PlainTextResponse(source_path.read_text()) + + legacy_path = effects_dir / cid / "effect.py" + if legacy_path.exists(): + return PlainTextResponse(legacy_path.read_text()) + + # Fetch from IPFS + source_bytes = ipfs_client.get_bytes(cid) + if not source_bytes: + raise HTTPException(404, f"Effect {cid[:16]}... not found") + + # Cache locally + source_path.parent.mkdir(parents=True, exist_ok=True) + source = source_bytes.decode("utf-8") + source_path.write_text(source) + + return PlainTextResponse(source) + + +@router.get("") +async def list_effects( + request: Request, + offset: int = 0, + limit: int = 20, + ctx: UserContext = Depends(require_auth), +): + """List user's effects with pagination.""" + import database + effects_dir = get_effects_dir() + effects = [] + + # Get user's effect CIDs from item_types + user_items = await database.get_user_items(ctx.actor_id, item_type="effect", limit=1000) + effect_cids = [item["cid"] for item in user_items] + + # Get naming service for friendly name lookup + from ..services.naming_service import get_naming_service + naming = get_naming_service() + + for cid in effect_cids: + effect_dir = effects_dir / cid + metadata_path = effect_dir / "metadata.json" + if metadata_path.exists(): + try: + meta = json.loads(metadata_path.read_text()) + # Add friendly name if available + friendly = await naming.get_by_cid(ctx.actor_id, cid) + if friendly: + meta["friendly_name"] = friendly["friendly_name"] + meta["base_name"] = friendly["base_name"] + effects.append(meta) + except json.JSONDecodeError: + pass + + # Sort by upload time (newest first) + effects.sort(key=lambda e: e.get("uploaded_at", ""), reverse=True) + + # Apply pagination + total = len(effects) + paginated_effects = effects[offset:offset + limit] + has_more = offset + limit < total + + if wants_json(request): + return {"effects": paginated_effects, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "effects/list.html", request, + effects=paginated_effects, + user=ctx, + nav_counts=nav_counts, + active_tab="effects", + offset=offset, + limit=limit, + has_more=has_more, + ) + + +@router.post("/{cid}/publish") +async def publish_effect( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Publish effect to L2 ActivityPub server.""" + from ..services.cache_service import CacheService + import database + + # Verify effect exists + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + if not effect_dir.exists(): + error = "Effect not found" + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(404, error) + + # Use cache service to publish + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + logger.info(f"Published effect {cid[:16]}... to L2 by {ctx.actor_id}") + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "cid": cid, "published": True} + + +@router.delete("/{cid}") +async def delete_effect( + cid: str, + ctx: UserContext = Depends(require_auth), +): + """Remove user's ownership link to an effect.""" + import database + + # Remove user's ownership link from item_types + await database.delete_item_type(cid, ctx.actor_id, "effect") + + # Remove friendly name + await database.delete_friendly_name(ctx.actor_id, cid) + + # Check if anyone still owns this effect + remaining_owners = await database.get_item_types(cid) + + # Only delete local files if no one owns it anymore + if not remaining_owners: + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + if effect_dir.exists(): + import shutil + shutil.rmtree(effect_dir) + + # Unpin from IPFS + ipfs_client.unpin(cid) + logger.info(f"Garbage collected effect {cid[:16]}... (no remaining owners)") + + logger.info(f"Removed effect {cid[:16]}... ownership for {ctx.actor_id}") + return {"deleted": True} diff --git a/l1/app/routers/fragments.py b/l1/app/routers/fragments.py new file mode 100644 index 0000000..5d6d821 --- /dev/null +++ b/l1/app/routers/fragments.py @@ -0,0 +1,143 @@ +""" +Art-DAG fragment endpoints. + +Exposes HTML fragments at ``/internal/fragments/{type}`` for consumption +by coop apps via the fragment client. +""" + +import os + +from fastapi import APIRouter, Request, Response + +router = APIRouter() + +# Registry of fragment handlers: type -> async callable(request) returning HTML str +_handlers: dict[str, object] = {} + +FRAGMENT_HEADER = "X-Fragment-Request" + + +@router.get("/internal/fragments/{fragment_type}") +async def get_fragment(fragment_type: str, request: Request): + if not request.headers.get(FRAGMENT_HEADER): + return Response(content="", status_code=403) + + handler = _handlers.get(fragment_type) + if handler is None: + return Response(content="", media_type="text/html", status_code=200) + html = await handler(request) + return Response(content=html, media_type="text/html", status_code=200) + + +# --- nav-item fragment --- + +async def _nav_item_handler(request: Request) -> str: + from artdag_common import render_fragment + + templates = request.app.state.templates + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + return render_fragment(templates, "fragments/nav_item.html", artdag_url=artdag_url) + + +_handlers["nav-item"] = _nav_item_handler + + +# --- link-card fragment --- + +async def _link_card_handler(request: Request) -> str: + from artdag_common import render_fragment + import database + + templates = request.app.state.templates + cid = request.query_params.get("cid", "") + content_type = request.query_params.get("type", "media") + slug = request.query_params.get("slug", "") + keys_raw = request.query_params.get("keys", "") + + # Batch mode: return multiple cards separated by markers + if keys_raw: + keys = [k.strip() for k in keys_raw.split(",") if k.strip()] + parts = [] + for key in keys: + parts.append(f"") + card_html = await _render_single_link_card( + templates, key, content_type, + ) + parts.append(card_html) + return "\n".join(parts) + + # Single mode: use cid or slug + lookup_cid = cid or slug + if not lookup_cid: + return "" + return await _render_single_link_card(templates, lookup_cid, content_type) + + +async def _render_single_link_card(templates, cid: str, content_type: str) -> str: + import database + from artdag_common import render_fragment + + if not cid: + return "" + + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + + # Try item_types first (has metadata) + item = await database.get_item_types(cid) + # get_item_types returns a list; pick best match for content_type + meta = None + if item: + for it in item: + if it.get("type") == content_type: + meta = it + break + if not meta: + meta = item[0] + + # Try friendly name for display + friendly = None + if meta and meta.get("actor_id"): + friendly = await database.get_friendly_name_by_cid(meta["actor_id"], cid) + + # Try run cache if type is "run" + run = None + if content_type == "run": + run = await database.get_run_cache(cid) + + title = "" + description = "" + link = "" + + if friendly: + title = friendly.get("display_name") or friendly.get("base_name", cid[:12]) + elif meta: + title = meta.get("filename") or meta.get("description", cid[:12]) + elif run: + title = f"Run {cid[:12]}" + else: + title = cid[:16] + + if meta: + description = meta.get("description", "") + + if content_type == "run": + link = f"{artdag_url}/runs/{cid}" + elif content_type == "recipe": + link = f"{artdag_url}/recipes/{cid}" + elif content_type == "effect": + link = f"{artdag_url}/effects/{cid}" + else: + link = f"{artdag_url}/cache/{cid}" + + return render_fragment( + templates, "fragments/link_card.html", + title=title, + description=description, + link=link, + cid=cid, + content_type=content_type, + artdag_url=artdag_url, + ) + + +_handlers["link-card"] = _link_card_handler diff --git a/l1/app/routers/home.py b/l1/app/routers/home.py new file mode 100644 index 0000000..4b89b94 --- /dev/null +++ b/l1/app/routers/home.py @@ -0,0 +1,253 @@ +""" +Home and root routes for L1 server. +""" + +from pathlib import Path + +import markdown +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse + +from artdag_common import render +from artdag_common.middleware import wants_html + +from ..dependencies import get_templates, get_current_user + +router = APIRouter() + + +@router.get("/health") +async def health(): + """Health check endpoint — always returns 200.""" + return {"status": "ok"} + + +async def get_user_stats(actor_id: str) -> dict: + """Get stats for a user.""" + import database + from ..services.run_service import RunService + from ..dependencies import get_redis_client, get_cache_manager + + stats = {} + + try: + # Count only actual media types (video, image, audio), not effects/recipes + media_count = 0 + for media_type in ["video", "image", "audio", "unknown"]: + media_count += await database.count_user_items(actor_id, item_type=media_type) + stats["media"] = media_count + except Exception: + stats["media"] = 0 + + try: + # Count user's recipes from database (ownership-based) + stats["recipes"] = await database.count_user_items(actor_id, item_type="recipe") + except Exception: + stats["recipes"] = 0 + + try: + run_service = RunService(database, get_redis_client(), get_cache_manager()) + runs = await run_service.list_runs(actor_id) + stats["runs"] = len(runs) + except Exception: + stats["runs"] = 0 + + try: + storage_providers = await database.get_user_storage_providers(actor_id) + stats["storage"] = len(storage_providers) if storage_providers else 0 + except Exception: + stats["storage"] = 0 + + try: + # Count user's effects from database (ownership-based) + stats["effects"] = await database.count_user_items(actor_id, item_type="effect") + except Exception: + stats["effects"] = 0 + + return stats + + +@router.get("/api/stats") +async def api_stats(request: Request): + """Get user stats as JSON for CLI and API clients.""" + user = await get_current_user(request) + if not user: + raise HTTPException(401, "Authentication required") + + stats = await get_user_stats(user.actor_id) + return stats + + +@router.delete("/api/clear-data") +async def clear_user_data(request: Request): + """ + Clear all user L1 data except storage configuration. + + Deletes: runs, recipes, effects, media/cache items. + Preserves: storage provider configurations. + """ + import logging + logger = logging.getLogger(__name__) + + user = await get_current_user(request) + if not user: + raise HTTPException(401, "Authentication required") + + import database + from ..services.recipe_service import RecipeService + from ..services.run_service import RunService + from ..dependencies import get_redis_client, get_cache_manager + + actor_id = user.actor_id + username = user.username + deleted = { + "runs": 0, + "recipes": 0, + "effects": 0, + "media": 0, + } + errors = [] + + # Delete all runs + try: + run_service = RunService(database, get_redis_client(), get_cache_manager()) + runs = await run_service.list_runs(actor_id, offset=0, limit=10000) + for run in runs: + try: + await run_service.discard_run(run["run_id"], actor_id, username) + deleted["runs"] += 1 + except Exception as e: + errors.append(f"Run {run['run_id']}: {e}") + except Exception as e: + errors.append(f"Failed to list runs: {e}") + + # Delete all recipes + try: + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipes = await recipe_service.list_recipes(actor_id, offset=0, limit=10000) + for recipe in recipes: + try: + success, error = await recipe_service.delete_recipe(recipe["recipe_id"], actor_id) + if success: + deleted["recipes"] += 1 + else: + errors.append(f"Recipe {recipe['recipe_id']}: {error}") + except Exception as e: + errors.append(f"Recipe {recipe['recipe_id']}: {e}") + except Exception as e: + errors.append(f"Failed to list recipes: {e}") + + # Delete all effects (uses ownership model) + cache_manager = get_cache_manager() + try: + # Get user's effects from item_types + effect_items = await database.get_user_items(actor_id, item_type="effect", limit=10000) + for item in effect_items: + cid = item.get("cid") + if cid: + try: + # Remove ownership link + await database.delete_item_type(cid, actor_id, "effect") + await database.delete_friendly_name(actor_id, cid) + + # Check if orphaned + remaining = await database.get_item_types(cid) + if not remaining: + # Garbage collect + effects_dir = Path(cache_manager.cache_dir) / "_effects" / cid + if effects_dir.exists(): + import shutil + shutil.rmtree(effects_dir) + import ipfs_client + ipfs_client.unpin(cid) + deleted["effects"] += 1 + except Exception as e: + errors.append(f"Effect {cid[:16]}...: {e}") + except Exception as e: + errors.append(f"Failed to delete effects: {e}") + + # Delete all media/cache items for user (uses ownership model) + try: + from ..services.cache_service import CacheService + cache_service = CacheService(database, cache_manager) + + # Get user's media items (video, image, audio) + for media_type in ["video", "image", "audio", "unknown"]: + items = await database.get_user_items(actor_id, item_type=media_type, limit=10000) + for item in items: + cid = item.get("cid") + if cid: + try: + success, error = await cache_service.delete_content(cid, actor_id) + if success: + deleted["media"] += 1 + elif error: + errors.append(f"Media {cid[:16]}...: {error}") + except Exception as e: + errors.append(f"Media {cid[:16]}...: {e}") + except Exception as e: + errors.append(f"Failed to delete media: {e}") + + logger.info(f"Cleared data for {actor_id}: {deleted}") + if errors: + logger.warning(f"Errors during clear: {errors[:10]}") # Log first 10 errors + + return { + "message": "User data cleared", + "deleted": deleted, + "errors": errors[:10] if errors else [], # Return first 10 errors + "storage_preserved": True, + } + + +@router.get("/") +async def home(request: Request): + """ + Home page - show README and stats. + """ + user = await get_current_user(request) + + # Load README + readme_html = "" + try: + readme_path = Path(__file__).parent.parent.parent / "README.md" + if readme_path.exists(): + readme_html = markdown.markdown(readme_path.read_text(), extensions=['tables', 'fenced_code']) + except Exception: + pass + + # Get stats for current user + stats = {} + if user: + stats = await get_user_stats(user.actor_id) + + templates = get_templates(request) + return render(templates, "home.html", request, + user=user, + readme_html=readme_html, + stats=stats, + nav_counts=stats, # Reuse stats for nav counts + active_tab="home", + ) + + +@router.get("/login") +async def login_redirect(request: Request): + """Redirect to OAuth login flow.""" + return RedirectResponse(url="/auth/login", status_code=302) + + +# Client tarball path +CLIENT_TARBALL = Path(__file__).parent.parent.parent / "artdag-client.tar.gz" + + +@router.get("/download/client") +async def download_client(): + """Download the Art DAG CLI client.""" + if not CLIENT_TARBALL.exists(): + raise HTTPException(404, "Client package not found. Run build-client.sh to create it.") + return FileResponse( + CLIENT_TARBALL, + media_type="application/gzip", + filename="artdag-client.tar.gz" + ) diff --git a/l1/app/routers/inbox.py b/l1/app/routers/inbox.py new file mode 100644 index 0000000..d6fa37c --- /dev/null +++ b/l1/app/routers/inbox.py @@ -0,0 +1,125 @@ +"""AP-style inbox endpoint for receiving signed activities from the coop. + +POST /inbox — verify HTTP Signature, dispatch by activity type. +""" +from __future__ import annotations + +import logging +import time + +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from ..dependencies import get_redis_client +from ..utils.http_signatures import verify_request_signature, parse_key_id + +log = logging.getLogger(__name__) +router = APIRouter() + +# Cache fetched public keys in Redis for 24 hours +_KEY_CACHE_TTL = 86400 + + +async def _fetch_actor_public_key(actor_url: str) -> str | None: + """Fetch an actor's public key, with Redis caching.""" + redis = get_redis_client() + cache_key = f"actor_pubkey:{actor_url}" + + # Check cache + cached = redis.get(cache_key) + if cached: + return cached + + # Fetch actor JSON + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + actor_url, + headers={"Accept": "application/activity+json, application/ld+json"}, + ) + if resp.status_code != 200: + log.warning("Failed to fetch actor %s: %d", actor_url, resp.status_code) + return None + data = resp.json() + except Exception: + log.warning("Error fetching actor %s", actor_url, exc_info=True) + return None + + pub_key_pem = (data.get("publicKey") or {}).get("publicKeyPem") + if not pub_key_pem: + log.warning("No publicKey in actor %s", actor_url) + return None + + # Cache it + redis.set(cache_key, pub_key_pem, ex=_KEY_CACHE_TTL) + return pub_key_pem + + +@router.post("/inbox") +async def inbox(request: Request): + """Receive signed AP activities from the coop platform.""" + sig_header = request.headers.get("signature", "") + if not sig_header: + return JSONResponse({"error": "missing signature"}, status_code=401) + + # Read body + body = await request.body() + + # Verify HTTP Signature + actor_url = parse_key_id(sig_header) + if not actor_url: + return JSONResponse({"error": "invalid keyId"}, status_code=401) + + pub_key = await _fetch_actor_public_key(actor_url) + if not pub_key: + return JSONResponse({"error": "could not fetch public key"}, status_code=401) + + req_headers = dict(request.headers) + path = request.url.path + valid = verify_request_signature( + public_key_pem=pub_key, + signature_header=sig_header, + method="POST", + path=path, + headers=req_headers, + ) + if not valid: + log.warning("Invalid signature from %s", actor_url) + return JSONResponse({"error": "invalid signature"}, status_code=401) + + # Parse and dispatch + try: + activity = await request.json() + except Exception: + return JSONResponse({"error": "invalid json"}, status_code=400) + + activity_type = activity.get("type", "") + log.info("Inbox received: %s from %s", activity_type, actor_url) + + if activity_type == "rose:DeviceAuth": + _handle_device_auth(activity) + + # Always 202 — AP convention + return JSONResponse({"status": "accepted"}, status_code=202) + + +def _handle_device_auth(activity: dict) -> None: + """Set or delete did_auth:{device_id} in local Redis.""" + obj = activity.get("object", {}) + device_id = obj.get("device_id", "") + action = obj.get("action", "") + + if not device_id: + log.warning("rose:DeviceAuth missing device_id") + return + + redis = get_redis_client() + if action == "login": + redis.set(f"did_auth:{device_id}", str(time.time()), ex=30 * 24 * 3600) + log.info("did_auth set for device %s...", device_id[:16]) + elif action == "logout": + redis.delete(f"did_auth:{device_id}") + log.info("did_auth cleared for device %s...", device_id[:16]) + else: + log.warning("rose:DeviceAuth unknown action: %s", action) diff --git a/l1/app/routers/oembed.py b/l1/app/routers/oembed.py new file mode 100644 index 0000000..615dfda --- /dev/null +++ b/l1/app/routers/oembed.py @@ -0,0 +1,74 @@ +"""Art-DAG oEmbed endpoint. + +Returns oEmbed JSON responses for Art-DAG content (media, recipes, effects, runs). +""" + +import os + +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +router = APIRouter() + + +@router.get("/oembed") +async def oembed(request: Request): + url = request.query_params.get("url", "") + if not url: + return JSONResponse({"error": "url parameter required"}, status_code=400) + + # Parse URL to extract content type and CID + # URL patterns: /cache/{cid}, /recipes/{cid}, /effects/{cid}, /runs/{cid} + from urllib.parse import urlparse + + parsed = urlparse(url) + parts = [p for p in parsed.path.strip("/").split("/") if p] + + if len(parts) < 2: + return JSONResponse({"error": "could not parse content URL"}, status_code=404) + + content_type = parts[0].rstrip("s") # recipes -> recipe, runs -> run + cid = parts[1] + + import database + + title = cid[:16] + thumbnail_url = None + + # Look up metadata + items = await database.get_item_types(cid) + if items: + meta = items[0] + title = meta.get("filename") or meta.get("description") or title + + # Try friendly name + actor_id = meta.get("actor_id") + if actor_id: + friendly = await database.get_friendly_name_by_cid(actor_id, cid) + if friendly: + title = friendly.get("display_name") or friendly.get("base_name", title) + + # Media items get a thumbnail + if meta.get("type") == "media": + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + thumbnail_url = f"{artdag_url}/cache/{cid}/raw" + + elif content_type == "run": + run = await database.get_run_cache(cid) + if run: + title = f"Run {cid[:12]}" + + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + + resp = { + "version": "1.0", + "type": "link", + "title": title, + "provider_name": "art-dag", + "provider_url": artdag_url, + "url": url, + } + if thumbnail_url: + resp["thumbnail_url"] = thumbnail_url + + return JSONResponse(resp) diff --git a/l1/app/routers/recipes.py b/l1/app/routers/recipes.py new file mode 100644 index 0000000..1a55397 --- /dev/null +++ b/l1/app/routers/recipes.py @@ -0,0 +1,686 @@ +""" +Recipe management routes for L1 server. + +Handles recipe upload, listing, viewing, and execution. +""" + +import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import require_auth, get_current_user, get_templates, get_redis_client, get_cache_manager +from ..services.auth_service import AuthService +from ..services.recipe_service import RecipeService +from ..types import ( + CompiledNode, TransformedNode, Registry, Recipe, + is_variable_input, get_effect_cid, +) + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class RecipeUploadRequest(BaseModel): + content: str # S-expression or YAML + name: Optional[str] = None + description: Optional[str] = None + + +class RecipeRunRequest(BaseModel): + """Request to run a recipe with variable inputs.""" + inputs: Dict[str, str] = {} # Map input names to CIDs + + +def get_recipe_service() -> RecipeService: + """Get recipe service instance.""" + return RecipeService(get_redis_client(), get_cache_manager()) + + +def transform_node( + node: CompiledNode, + assets: Dict[str, Dict[str, Any]], + effects: Dict[str, Dict[str, Any]], +) -> TransformedNode: + """ + Transform a compiled node to artdag execution format. + + - Resolves asset references to CIDs for SOURCE nodes + - Resolves effect references to CIDs for EFFECT nodes + - Renames 'type' to 'node_type', 'id' to 'node_id' + """ + node_id = node.get("id", "") + config = dict(node.get("config", {})) # Copy to avoid mutation + + # Resolve asset references for SOURCE nodes + if node.get("type") == "SOURCE" and "asset" in config: + asset_name = config["asset"] + if asset_name in assets: + config["cid"] = assets[asset_name].get("cid") + + # Resolve effect references for EFFECT nodes + if node.get("type") == "EFFECT" and "effect" in config: + effect_name = config["effect"] + if effect_name in effects: + config["cid"] = effects[effect_name].get("cid") + + return { + "node_id": node_id, + "node_type": node.get("type", "EFFECT"), + "config": config, + "inputs": node.get("inputs", []), + "name": node.get("name"), + } + + +def build_input_name_mapping( + nodes: Dict[str, TransformedNode], +) -> Dict[str, str]: + """ + Build a mapping from input names to node IDs for variable inputs. + + Variable inputs can be referenced by: + - node_id directly + - config.name (e.g., "Second Video") + - snake_case version (e.g., "second_video") + - kebab-case version (e.g., "second-video") + - node.name (def binding name) + """ + input_name_to_node: Dict[str, str] = {} + + for node_id, node in nodes.items(): + if node.get("node_type") != "SOURCE": + continue + + config = node.get("config", {}) + if not is_variable_input(config): + continue + + # Map by node_id + input_name_to_node[node_id] = node_id + + # Map by config.name + name = config.get("name") + if name: + input_name_to_node[name] = node_id + input_name_to_node[name.lower().replace(" ", "_")] = node_id + input_name_to_node[name.lower().replace(" ", "-")] = node_id + + # Map by node.name (def binding) + node_name = node.get("name") + if node_name: + input_name_to_node[node_name] = node_id + input_name_to_node[node_name.replace("-", "_")] = node_id + + return input_name_to_node + + +def bind_inputs( + nodes: Dict[str, TransformedNode], + input_name_to_node: Dict[str, str], + user_inputs: Dict[str, str], +) -> List[str]: + """ + Bind user-provided input CIDs to source nodes. + + Returns list of warnings for inputs that couldn't be bound. + """ + warnings: List[str] = [] + + for input_name, cid in user_inputs.items(): + # Try direct node ID match first + if input_name in nodes: + node = nodes[input_name] + if node.get("node_type") == "SOURCE": + node["config"]["cid"] = cid + logger.info(f"Bound input {input_name} directly to node, cid={cid[:16]}...") + continue + + # Try input name lookup + if input_name in input_name_to_node: + node_id = input_name_to_node[input_name] + node = nodes[node_id] + node["config"]["cid"] = cid + logger.info(f"Bound input {input_name} via lookup to node {node_id}, cid={cid[:16]}...") + continue + + # Input not found + warnings.append(f"Input '{input_name}' not found in recipe") + logger.warning(f"Input {input_name} not found in nodes or input_name_to_node") + + return warnings + + +async def resolve_friendly_names_in_registry( + registry: dict, + actor_id: str, +) -> dict: + """ + Resolve friendly names to CIDs in the registry. + + Friendly names are identified by containing a space (e.g., "brightness 01hw3x9k") + or by not being a valid CID format. + """ + from ..services.naming_service import get_naming_service + import re + + naming = get_naming_service() + resolved = {"assets": {}, "effects": {}} + + # CID patterns: IPFS CID (Qm..., bafy...) or SHA256 hash (64 hex chars) + cid_pattern = re.compile(r'^(Qm[a-zA-Z0-9]{44}|bafy[a-zA-Z0-9]+|[a-f0-9]{64})$') + + for asset_name, asset_info in registry.get("assets", {}).items(): + cid = asset_info.get("cid", "") + if cid and not cid_pattern.match(cid): + # Looks like a friendly name, resolve it + resolved_cid = await naming.resolve(actor_id, cid, item_type="media") + if resolved_cid: + asset_info = dict(asset_info) + asset_info["cid"] = resolved_cid + asset_info["_resolved_from"] = cid + resolved["assets"][asset_name] = asset_info + + for effect_name, effect_info in registry.get("effects", {}).items(): + cid = effect_info.get("cid", "") + if cid and not cid_pattern.match(cid): + # Looks like a friendly name, resolve it + resolved_cid = await naming.resolve(actor_id, cid, item_type="effect") + if resolved_cid: + effect_info = dict(effect_info) + effect_info["cid"] = resolved_cid + effect_info["_resolved_from"] = cid + resolved["effects"][effect_name] = effect_info + + return resolved + + +async def prepare_dag_for_execution( + recipe: Recipe, + user_inputs: Dict[str, str], + actor_id: str = None, +) -> Tuple[str, List[str]]: + """ + Prepare a recipe DAG for execution by transforming nodes and binding inputs. + + Resolves friendly names to CIDs if actor_id is provided. + Returns (dag_json, warnings). + """ + recipe_dag = recipe.get("dag") + if not recipe_dag or not isinstance(recipe_dag, dict): + raise ValueError("Recipe has no DAG definition") + + # Deep copy to avoid mutating original + dag_copy = json.loads(json.dumps(recipe_dag)) + nodes = dag_copy.get("nodes", {}) + + # Get registry for resolving references + registry = recipe.get("registry", {}) + + # Resolve friendly names to CIDs + if actor_id and registry: + registry = await resolve_friendly_names_in_registry(registry, actor_id) + + assets = registry.get("assets", {}) if registry else {} + effects = registry.get("effects", {}) if registry else {} + + # Transform nodes from list to dict if needed + if isinstance(nodes, list): + nodes_dict: Dict[str, TransformedNode] = {} + for node in nodes: + node_id = node.get("id") + if node_id: + nodes_dict[node_id] = transform_node(node, assets, effects) + nodes = nodes_dict + dag_copy["nodes"] = nodes + + # Build input name mapping and bind user inputs + input_name_to_node = build_input_name_mapping(nodes) + logger.info(f"Input name to node mapping: {input_name_to_node}") + logger.info(f"User-provided inputs: {user_inputs}") + + warnings = bind_inputs(nodes, input_name_to_node, user_inputs) + + # Log final SOURCE node configs for debugging + for nid, n in nodes.items(): + if n.get("node_type") == "SOURCE": + logger.info(f"Final SOURCE node {nid}: config={n.get('config')}") + + # Transform output to output_id + if "output" in dag_copy: + dag_copy["output_id"] = dag_copy.pop("output") + + # Add metadata if not present + if "metadata" not in dag_copy: + dag_copy["metadata"] = {} + + return json.dumps(dag_copy), warnings + + +@router.post("/upload") +async def upload_recipe( + file: UploadFile = File(...), + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Upload a new recipe from S-expression or YAML file.""" + import yaml + + # Read content from the uploaded file + content = (await file.read()).decode("utf-8") + + # Detect format (skip comments starting with ;) + def is_sexp_format(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(') + return False + + is_sexp = is_sexp_format(content) + + try: + from artdag.sexp import compile_string, ParseError, CompileError + SEXP_AVAILABLE = True + except ImportError: + SEXP_AVAILABLE = False + + recipe_name = None + recipe_version = "1.0" + recipe_description = None + variable_inputs = [] + fixed_inputs = [] + + if is_sexp: + if not SEXP_AVAILABLE: + raise HTTPException(500, "S-expression recipes require artdag.sexp module (not installed on server)") + # Parse S-expression + try: + compiled = compile_string(content) + recipe_name = compiled.name + recipe_version = compiled.version + recipe_description = compiled.description + + for node in compiled.nodes: + if node.get("type") == "SOURCE": + config = node.get("config", {}) + if config.get("input"): + variable_inputs.append(config.get("name", node.get("id"))) + elif config.get("asset"): + fixed_inputs.append(config.get("asset")) + except Exception as e: + raise HTTPException(400, f"Parse error: {e}") + else: + # Parse YAML + try: + recipe_data = yaml.safe_load(content) + recipe_name = recipe_data.get("name") + recipe_version = recipe_data.get("version", "1.0") + recipe_description = recipe_data.get("description") + + inputs = recipe_data.get("inputs", {}) + for input_name, input_def in inputs.items(): + if isinstance(input_def, dict) and input_def.get("fixed"): + fixed_inputs.append(input_name) + else: + variable_inputs.append(input_name) + except yaml.YAMLError as e: + raise HTTPException(400, f"Invalid YAML: {e}") + + # Use filename as recipe name if not specified + if not recipe_name and file.filename: + recipe_name = file.filename.rsplit(".", 1)[0] + + recipe_id, error = await recipe_service.upload_recipe( + content=content, + uploader=ctx.actor_id, + name=recipe_name, + description=recipe_description, + ) + + if error: + raise HTTPException(400, error) + + return { + "recipe_id": recipe_id, + "name": recipe_name or "unnamed", + "version": recipe_version, + "variable_inputs": variable_inputs, + "fixed_inputs": fixed_inputs, + "message": "Recipe uploaded successfully", + } + + +@router.get("") +async def list_recipes( + request: Request, + offset: int = 0, + limit: int = 20, + recipe_service: RecipeService = Depends(get_recipe_service), + ctx: UserContext = Depends(require_auth), +): + """List available recipes.""" + recipes = await recipe_service.list_recipes(ctx.actor_id, offset=offset, limit=limit) + has_more = len(recipes) >= limit + + if wants_json(request): + return {"recipes": recipes, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "recipes/list.html", request, + recipes=recipes, + user=ctx, + nav_counts=nav_counts, + active_tab="recipes", + offset=offset, + limit=limit, + has_more=has_more, + ) + + +@router.get("/{recipe_id}") +async def get_recipe( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), + ctx: UserContext = Depends(require_auth), +): + """Get recipe details.""" + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + # Add friendly name if available + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(ctx.actor_id, recipe_id) + if friendly: + recipe["friendly_name"] = friendly["friendly_name"] + recipe["base_name"] = friendly["base_name"] + recipe["version_id"] = friendly["version_id"] + + if wants_json(request): + return recipe + + # Build DAG elements for visualization and convert nodes to steps format + dag_elements = [] + steps = [] + node_colors = { + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + } + + # Debug: log recipe structure + logger.info(f"Recipe keys: {list(recipe.keys())}") + + # Get nodes from dag - can be list or dict, can be under "dag" or directly on recipe + dag = recipe.get("dag", {}) + logger.info(f"DAG type: {type(dag)}, keys: {list(dag.keys()) if isinstance(dag, dict) else 'not dict'}") + nodes = dag.get("nodes", []) if isinstance(dag, dict) else [] + logger.info(f"Nodes from dag.nodes: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + # Also check for nodes directly on recipe (alternative formats) + if not nodes: + nodes = recipe.get("nodes", []) + logger.info(f"Nodes from recipe.nodes: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + if not nodes: + nodes = recipe.get("pipeline", []) + logger.info(f"Nodes from recipe.pipeline: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + if not nodes: + nodes = recipe.get("steps", []) + logger.info(f"Nodes from recipe.steps: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + logger.info(f"Final nodes count: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + # Convert list of nodes to steps format + if isinstance(nodes, list): + for node in nodes: + node_id = node.get("id", "") + node_type = node.get("type", "EFFECT") + inputs = node.get("inputs", []) + config = node.get("config", {}) + + steps.append({ + "id": node_id, + "name": node_id, + "type": node_type, + "inputs": inputs, + "params": config, + }) + + dag_elements.append({ + "data": { + "id": node_id, + "label": node_id, + "color": node_colors.get(node_type, "#6b7280"), + } + }) + for inp in inputs: + if isinstance(inp, str): + dag_elements.append({ + "data": {"source": inp, "target": node_id} + }) + elif isinstance(nodes, dict): + for node_id, node in nodes.items(): + node_type = node.get("type", "EFFECT") + inputs = node.get("inputs", []) + config = node.get("config", {}) + + steps.append({ + "id": node_id, + "name": node_id, + "type": node_type, + "inputs": inputs, + "params": config, + }) + + dag_elements.append({ + "data": { + "id": node_id, + "label": node_id, + "color": node_colors.get(node_type, "#6b7280"), + } + }) + for inp in inputs: + if isinstance(inp, str): + dag_elements.append({ + "data": {"source": inp, "target": node_id} + }) + + # Add steps to recipe for template + recipe["steps"] = steps + + # Use S-expression source if available + if "sexp" not in recipe: + recipe["sexp"] = "; No S-expression source available" + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "recipes/detail.html", request, + recipe=recipe, + dag_elements=dag_elements, + user=ctx, + nav_counts=nav_counts, + active_tab="recipes", + ) + + +@router.delete("/{recipe_id}") +async def delete_recipe( + recipe_id: str, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Delete a recipe.""" + success, error = await recipe_service.delete_recipe(recipe_id, ctx.actor_id) + if error: + raise HTTPException(400 if "Cannot" in error else 404, error) + return {"deleted": True, "recipe_id": recipe_id} + + +@router.post("/{recipe_id}/run") +async def run_recipe( + recipe_id: str, + req: RecipeRunRequest, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Run a recipe with given inputs.""" + from ..services.run_service import RunService + from ..dependencies import get_cache_manager + import database + + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + try: + # Create run using run service + run_service = RunService(database, get_redis_client(), get_cache_manager()) + + # Prepare DAG for execution (transform nodes, bind inputs, resolve friendly names) + dag_json = None + if recipe.get("dag"): + dag_json, warnings = await prepare_dag_for_execution(recipe, req.inputs, actor_id=ctx.actor_id) + for warning in warnings: + logger.warning(warning) + + run, error = await run_service.create_run( + recipe=recipe_id, # Use recipe hash as primary identifier + inputs=req.inputs, + use_dag=True, + dag_json=dag_json, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + recipe_name=recipe.get("name"), # Store name for display + recipe_sexp=recipe.get("sexp"), # S-expression for code-addressed execution + ) + + if error: + raise HTTPException(400, error) + + if not run: + raise HTTPException(500, "Run creation returned no result") + + return { + "run_id": run["run_id"] if isinstance(run, dict) else run.run_id, + "status": run.get("status", "pending") if isinstance(run, dict) else run.status, + "message": "Recipe execution started", + } + except HTTPException: + raise + except Exception as e: + logger.exception(f"Error running recipe {recipe_id}") + raise HTTPException(500, f"Run failed: {e}") + + +@router.get("/{recipe_id}/dag") +async def recipe_dag( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Get recipe DAG visualization data.""" + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + dag_elements = [] + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + } + + for i, step in enumerate(recipe.get("steps", [])): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + dag_elements.append({ + "data": {"source": inp, "target": step_id} + }) + + return {"elements": dag_elements} + + +@router.delete("/{recipe_id}/ui", response_class=HTMLResponse) +async def ui_discard_recipe( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), +): + """HTMX handler: discard a recipe.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
', status_code=401) + + success, error = await recipe_service.delete_recipe(recipe_id, ctx.actor_id) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse( + '
Recipe deleted
' + '' + ) + + +@router.post("/{recipe_id}/publish") +async def publish_recipe( + recipe_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Publish recipe to L2 and IPFS.""" + from ..services.cache_service import CacheService + from ..dependencies import get_cache_manager + import database + + # Verify recipe exists + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + # Use cache service to publish (recipes are stored in cache) + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=recipe_id, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "published": True} diff --git a/l1/app/routers/runs.py b/l1/app/routers/runs.py new file mode 100644 index 0000000..29c7d25 --- /dev/null +++ b/l1/app/routers/runs.py @@ -0,0 +1,1704 @@ +""" +Run management routes for L1 server. + +Handles run creation, status, listing, and detail views. +""" + +import asyncio +import json +import logging +from datetime import datetime, timezone +from typing import List, Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_current_user, + get_redis_client, get_cache_manager +) +from ..services.run_service import RunService + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def plan_to_sexp(plan: dict, recipe_name: str = None) -> str: + """Convert a plan to S-expression format for display.""" + if not plan or not plan.get("steps"): + return ";; No plan available" + + lines = [] + lines.append(f'(plan "{recipe_name or "unknown"}"') + + # Group nodes by type for cleaner output + steps = plan.get("steps", []) + + for step in steps: + step_id = step.get("id", "?") + step_type = step.get("type", "EFFECT") + inputs = step.get("inputs", []) + config = step.get("config", {}) + + # Build the step S-expression + if step_type == "SOURCE": + if config.get("input"): + # Variable input + input_name = config.get("name", config.get("input", "input")) + lines.append(f' (source :input "{input_name}")') + elif config.get("asset"): + # Fixed asset + lines.append(f' (source {config.get("asset", step_id)})') + else: + lines.append(f' (source {step_id})') + elif step_type == "EFFECT": + effect_name = config.get("effect", step_id) + if inputs: + inp_str = " ".join(inputs) + lines.append(f' (-> {inp_str} (effect {effect_name}))') + else: + lines.append(f' (effect {effect_name})') + elif step_type == "SEQUENCE": + if inputs: + inp_str = " ".join(inputs) + lines.append(f' (sequence {inp_str})') + else: + lines.append(f' (sequence)') + else: + # Generic node + if inputs: + inp_str = " ".join(inputs) + lines.append(f' ({step_type.lower()} {inp_str})') + else: + lines.append(f' ({step_type.lower()} {step_id})') + + lines.append(')') + return "\n".join(lines) + +RUNS_KEY_PREFIX = "artdag:run:" + + +class RunRequest(BaseModel): + recipe: str + inputs: List[str] + output_name: Optional[str] = None + use_dag: bool = True + dag_json: Optional[str] = None + + +class RunStatus(BaseModel): + run_id: str + status: str + recipe: Optional[str] = None + inputs: Optional[List[str]] = None + output_name: Optional[str] = None + created_at: Optional[str] = None + completed_at: Optional[str] = None + output_cid: Optional[str] = None + username: Optional[str] = None + provenance_cid: Optional[str] = None + celery_task_id: Optional[str] = None + error: Optional[str] = None + plan_id: Optional[str] = None + plan_name: Optional[str] = None + step_results: Optional[Dict[str, Any]] = None + all_outputs: Optional[List[str]] = None + effects_commit: Optional[str] = None + effect_url: Optional[str] = None + infrastructure: Optional[Dict[str, Any]] = None + + +class StreamRequest(BaseModel): + """Request to run a streaming recipe.""" + recipe_sexp: str # The recipe S-expression content + output_name: str = "output.mp4" + duration: Optional[float] = None # Duration in seconds + fps: Optional[float] = None # FPS override + sources_sexp: Optional[str] = None # Sources config S-expression + audio_sexp: Optional[str] = None # Audio config S-expression + + +def get_run_service(): + """Get run service instance.""" + import database + return RunService(database, get_redis_client(), get_cache_manager()) + + +@router.post("", response_model=RunStatus) +async def create_run( + request: RunRequest, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Start a new rendering run. Checks cache before executing.""" + run, error = await run_service.create_run( + recipe=request.recipe, + inputs=request.inputs, + output_name=request.output_name, + use_dag=request.use_dag, + dag_json=request.dag_json, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + ) + + if error: + raise HTTPException(400, error) + + return run + + +@router.post("/stream", response_model=RunStatus) +async def create_stream_run( + request: StreamRequest, + req: Request, + ctx: UserContext = Depends(get_current_user), +): + """Start a streaming video render. + + The recipe_sexp should be a complete streaming recipe with + (stream ...) form defining the pipeline. + + Assets can be referenced by CID or friendly name in the recipe. + Requires authentication OR admin token in X-Admin-Token header. + """ + import uuid + import tempfile + import os + from pathlib import Path + import database + from tasks.streaming import run_stream + + # Check for admin token if no user auth + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = req.headers.get("X-Admin-Token") + admin_actor_id = req.headers.get("X-Actor-Id", "admin@local") + + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Use context actor_id or admin actor_id + actor_id = ctx.actor_id if ctx else admin_actor_id + + # Generate run ID + run_id = str(uuid.uuid4()) + + # Store recipe in cache so it appears on /recipes page + recipe_id = None + try: + cache_manager = get_cache_manager() + with tempfile.NamedTemporaryFile(delete=False, suffix=".sexp", mode="w") as tmp: + tmp.write(request.recipe_sexp) + tmp_path = Path(tmp.name) + + cached, ipfs_cid = cache_manager.put(tmp_path, node_type="recipe", move=True) + recipe_id = cached.cid + + # Extract recipe name from S-expression (look for (stream "name" ...) pattern) + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', request.recipe_sexp) + recipe_name = name_match.group(1) if name_match else f"stream-{run_id[:8]}" + + # Track ownership in item_types + await database.save_item_metadata( + cid=recipe_id, + actor_id=actor_id, + item_type="recipe", + description=f"Streaming recipe: {recipe_name}", + filename=f"{recipe_name}.sexp", + ) + + # Assign friendly name + from ..services.naming_service import get_naming_service + naming = get_naming_service() + await naming.assign_name( + cid=recipe_id, + actor_id=actor_id, + item_type="recipe", + display_name=recipe_name, + ) + + logger.info(f"Stored streaming recipe {recipe_id[:16]}... as '{recipe_name}'") + except Exception as e: + logger.warning(f"Failed to store recipe in cache: {e}") + # Continue anyway - run will still work, just won't appear in /recipes + + # Submit Celery task to GPU queue for hardware-accelerated rendering + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=request.recipe_sexp, + output_name=request.output_name, + duration=request.duration, + fps=request.fps, + actor_id=actor_id, + sources_sexp=request.sources_sexp, + audio_sexp=request.audio_sexp, + ), + queue='gpu', + ) + + # Store in database for durability + pending = await database.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe_id or "streaming", # Use recipe CID if available + inputs=[], # Streaming recipes don't have traditional inputs + actor_id=actor_id, + dag_json=request.recipe_sexp, # Store recipe content for viewing + output_name=request.output_name, + ) + + logger.info(f"Started stream run {run_id} with task {task.id}") + + return RunStatus( + run_id=run_id, + status="pending", + recipe=recipe_id or "streaming", + created_at=pending.get("created_at"), + celery_task_id=task.id, + ) + + +@router.get("/{run_id}") +async def get_run( + request: Request, + run_id: str, + run_service: RunService = Depends(get_run_service), +): + """Get status of a run.""" + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, f"Run {run_id} not found") + + # Only render HTML if browser explicitly requests it + if wants_html(request): + # Extract username from actor_id (format: @user@server) + actor_id = run.get("actor_id", "") + if actor_id and actor_id.startswith("@"): + parts = actor_id[1:].split("@") + run["username"] = parts[0] if parts else "Unknown" + else: + run["username"] = actor_id or "Unknown" + + # Helper to normalize input refs to just node IDs + def normalize_inputs(inputs): + """Convert input refs (may be dicts or strings) to list of node IDs.""" + result = [] + for inp in inputs: + if isinstance(inp, dict): + node_id = inp.get("node") or inp.get("input") or inp.get("id") + else: + node_id = inp + if node_id: + result.append(node_id) + return result + + # Try to load the recipe to show the plan + plan = None + plan_sexp = None # Native S-expression if available + recipe_ipfs_cid = None + recipe_id = run.get("recipe") + # Check for valid recipe ID (64-char hash, IPFS CIDv0 "Qm...", or CIDv1 "bafy...") + is_valid_recipe_id = recipe_id and ( + len(recipe_id) == 64 or + recipe_id.startswith("Qm") or + recipe_id.startswith("bafy") + ) + if is_valid_recipe_id: + try: + from ..services.recipe_service import RecipeService + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipe = await recipe_service.get_recipe(recipe_id) + if recipe: + # Use native S-expression if available (code is data!) + if recipe.get("sexp"): + plan_sexp = recipe["sexp"] + # Get IPFS CID for the recipe + recipe_ipfs_cid = recipe.get("ipfs_cid") + + # Build steps for DAG visualization + dag = recipe.get("dag", {}) + nodes = dag.get("nodes", []) + + steps = [] + if isinstance(nodes, list): + for node in nodes: + node_id = node.get("id", "") + steps.append({ + "id": node_id, + "name": node_id, + "type": node.get("type", "EFFECT"), + "status": "completed", # Run completed + "inputs": normalize_inputs(node.get("inputs", [])), + "config": node.get("config", {}), + }) + elif isinstance(nodes, dict): + for node_id, node in nodes.items(): + steps.append({ + "id": node_id, + "name": node_id, + "type": node.get("type", "EFFECT"), + "status": "completed", + "inputs": normalize_inputs(node.get("inputs", [])), + "config": node.get("config", {}), + }) + + if steps: + plan = {"steps": steps} + run["total_steps"] = len(steps) + run["executed"] = len(steps) if run.get("status") == "completed" else 0 + + # Use recipe name instead of hash for display (if not already set) + if recipe.get("name") and not run.get("recipe_name"): + run["recipe_name"] = recipe["name"] + except Exception as e: + logger.warning(f"Failed to load recipe for plan: {e}") + + # Handle streaming runs - detect by recipe_sexp content or legacy "streaming" marker + recipe_sexp_content = run.get("recipe_sexp") + is_streaming = run.get("recipe") == "streaming" # Legacy marker + if not is_streaming and recipe_sexp_content: + # Check if content starts with (stream after skipping comments + for line in recipe_sexp_content.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + is_streaming = stripped.startswith('(stream') + break + if is_streaming and recipe_sexp_content and not plan: + plan_sexp = recipe_sexp_content + plan = { + "steps": [{ + "id": "stream", + "type": "STREAM", + "name": "Streaming Recipe", + "inputs": [], + "config": {}, + "status": "completed" if run.get("status") == "completed" else "pending", + }] + } + run["total_steps"] = 1 + run["executed"] = 1 if run.get("status") == "completed" else 0 + + # Helper to convert simple type to MIME type prefix for template + def type_to_mime(simple_type: str) -> str: + if simple_type == "video": + return "video/mp4" + elif simple_type == "image": + return "image/jpeg" + elif simple_type == "audio": + return "audio/mpeg" + return None + + # Build artifacts list from output and inputs + artifacts = [] + output_media_type = None + if run.get("output_cid"): + # Detect media type using magic bytes, fall back to database item_type + output_cid = run["output_cid"] + media_type = None + + # Streaming runs (with ipfs_cid) are always video/mp4 + if run.get("ipfs_cid"): + media_type = "video/mp4" + output_media_type = media_type + else: + try: + from ..services.run_service import detect_media_type + cache_path = get_cache_manager().get_by_cid(output_cid) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + media_type = type_to_mime(simple_type) + output_media_type = media_type + except Exception: + pass + # Fall back to database item_type if local detection failed + if not media_type: + try: + import database + item_types = await database.get_item_types(output_cid, run.get("actor_id")) + if item_types: + media_type = type_to_mime(item_types[0].get("type")) + output_media_type = media_type + except Exception: + pass + artifacts.append({ + "cid": output_cid, + "step_name": "Output", + "media_type": media_type or "application/octet-stream", + }) + + # Build inputs list with media types + run_inputs = [] + if run.get("inputs"): + from ..services.run_service import detect_media_type + cache_manager = get_cache_manager() + for i, input_hash in enumerate(run["inputs"]): + media_type = None + try: + cache_path = cache_manager.get_by_cid(input_hash) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + media_type = type_to_mime(simple_type) + except Exception: + pass + run_inputs.append({ + "cid": input_hash, + "name": f"Input {i + 1}", + "media_type": media_type, + }) + + # Build DAG elements for visualization + dag_elements = [] + if plan and plan.get("steps"): + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + for i, step in enumerate(plan["steps"]): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + # Handle both string and dict inputs + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": { + "source": source, + "target": step_id, + } + }) + + # Use native S-expression if available, otherwise generate from plan + if not plan_sexp and plan: + plan_sexp = plan_to_sexp(plan, run.get("recipe_name")) + + from ..dependencies import get_nav_counts + user = await get_current_user(request) + nav_counts = await get_nav_counts(user.actor_id if user else None) + + templates = get_templates(request) + return render(templates, "runs/detail.html", request, + run=run, + plan=plan, + artifacts=artifacts, + run_inputs=run_inputs, + dag_elements=dag_elements, + output_media_type=output_media_type, + plan_sexp=plan_sexp, + recipe_ipfs_cid=recipe_ipfs_cid, + nav_counts=nav_counts, + active_tab="runs", + ) + + # Default to JSON for API clients + return run + + +@router.delete("/{run_id}") +async def discard_run( + run_id: str, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Discard (delete) a run and its outputs.""" + success, error = await run_service.discard_run(run_id, ctx.actor_id, ctx.username) + if error: + raise HTTPException(400 if "Cannot" in error else 404, error) + return {"discarded": True, "run_id": run_id} + + +@router.get("") +async def list_runs( + request: Request, + offset: int = 0, + limit: int = 20, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(get_current_user), +): + """List all runs for the current user.""" + import os + + # Check for admin token if no user auth + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = request.headers.get("X-Admin-Token") + admin_actor_id = request.headers.get("X-Actor-Id") + + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Use context actor_id or admin actor_id + actor_id = ctx.actor_id if ctx else admin_actor_id + if not actor_id: + raise HTTPException(400, "X-Actor-Id header required with admin token") + + runs = await run_service.list_runs(actor_id, offset=offset, limit=limit) + has_more = len(runs) >= limit + + if wants_json(request): + return {"runs": runs, "offset": offset, "limit": limit, "has_more": has_more} + + # Add media info for inline previews (only for HTML) + cache_manager = get_cache_manager() + from ..services.run_service import detect_media_type + + def type_to_mime(simple_type: str) -> str: + if simple_type == "video": + return "video/mp4" + elif simple_type == "image": + return "image/jpeg" + elif simple_type == "audio": + return "audio/mpeg" + return None + + for run in runs: + # Add output media info + if run.get("output_cid"): + # Streaming runs (with ipfs_cid) are always video/mp4 + if run.get("ipfs_cid"): + run["output_media_type"] = "video/mp4" + else: + try: + cache_path = cache_manager.get_by_cid(run["output_cid"]) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + run["output_media_type"] = type_to_mime(simple_type) + except Exception: + pass + + # Add input media info (first 3 inputs) + input_previews = [] + inputs = run.get("inputs", []) + if isinstance(inputs, list): + for input_hash in inputs[:3]: + preview = {"cid": input_hash, "media_type": None} + try: + cache_path = cache_manager.get_by_cid(input_hash) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + preview["media_type"] = type_to_mime(simple_type) + except Exception: + pass + input_previews.append(preview) + run["input_previews"] = input_previews + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(actor_id) + + templates = get_templates(request) + return render(templates, "runs/list.html", request, + runs=runs, + user=ctx or {"actor_id": actor_id}, + nav_counts=nav_counts, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="runs", + ) + + +@router.get("/{run_id}/detail") +async def run_detail( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Run detail page with tabs for plan/analysis/artifacts.""" + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, f"Run {run_id} not found") + + # Get plan, artifacts, and analysis + plan = await run_service.get_run_plan(run_id) + artifacts = await run_service.get_run_artifacts(run_id) + analysis = await run_service.get_run_analysis(run_id) + + # Build DAG elements for visualization + dag_elements = [] + if plan and plan.get("steps"): + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + for i, step in enumerate(plan["steps"]): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + # Add edges from inputs (handle both string and dict formats) + for inp in step.get("inputs", []): + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": { + "source": source, + "target": step_id, + } + }) + + if wants_json(request): + return { + "run": run, + "plan": plan, + "artifacts": artifacts, + "analysis": analysis, + } + + # Extract plan_sexp for streaming runs + plan_sexp = plan.get("sexp") if plan else None + + templates = get_templates(request) + return render(templates, "runs/detail.html", request, + run=run, + plan=plan, + plan_sexp=plan_sexp, + artifacts=artifacts, + analysis=analysis, + dag_elements=dag_elements, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/plan") +async def run_plan( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Plan visualization as interactive DAG.""" + plan = await run_service.get_run_plan(run_id) + if not plan: + raise HTTPException(404, "Plan not found for this run") + + if wants_json(request): + return plan + + # Build DAG elements + dag_elements = [] + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + + for i, step in enumerate(plan.get("steps", [])): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + # Handle both string and dict formats + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": {"source": source, "target": step_id} + }) + + templates = get_templates(request) + return render(templates, "runs/plan.html", request, + run_id=run_id, + plan=plan, + dag_elements=dag_elements, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/artifacts") +async def run_artifacts( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Get artifacts list for a run.""" + artifacts = await run_service.get_run_artifacts(run_id) + + if wants_json(request): + return {"artifacts": artifacts} + + templates = get_templates(request) + return render(templates, "runs/artifacts.html", request, + run_id=run_id, + artifacts=artifacts, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/plan/node/{cache_id}", response_class=HTMLResponse) +async def plan_node_detail( + run_id: str, + cache_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), +): + """HTMX partial: Get plan node detail by cache_id.""" + from artdag_common import render_fragment + + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('

Login required

', status_code=401) + + run = await run_service.get_run(run_id) + if not run: + return HTMLResponse('

Run not found

', status_code=404) + + plan = await run_service.get_run_plan(run_id) + if not plan: + return HTMLResponse('

Plan not found

') + + # Build lookups + steps_by_cache_id = {} + steps_by_step_id = {} + for s in plan.get("steps", []): + if s.get("cache_id"): + steps_by_cache_id[s["cache_id"]] = s + if s.get("step_id"): + steps_by_step_id[s["step_id"]] = s + + step = steps_by_cache_id.get(cache_id) + if not step: + return HTMLResponse(f'

Step not found

') + + cache_manager = get_cache_manager() + + # Node colors + node_colors = { + "SOURCE": "#3b82f6", "EFFECT": "#22c55e", "OUTPUT": "#a855f7", + "ANALYSIS": "#f59e0b", "_LIST": "#6366f1", "default": "#6b7280" + } + node_color = node_colors.get(step.get("node_type", "EFFECT"), node_colors["default"]) + + # Check cache status + has_cached = cache_manager.has_content(cache_id) if cache_id else False + + # Determine output media type + output_media_type = None + output_preview = False + if has_cached: + cache_path = cache_manager.get_content_path(cache_id) + if cache_path: + output_media_type = run_service.detect_media_type(cache_path) + output_preview = output_media_type in ('video', 'image', 'audio') + + # Check for IPFS CID + ipfs_cid = None + if run.step_results: + res = run.step_results.get(step.get("step_id")) + if isinstance(res, dict) and res.get("cid"): + ipfs_cid = res["cid"] + + # Build input previews + inputs = [] + for inp_step_id in step.get("input_steps", []): + inp_step = steps_by_step_id.get(inp_step_id) + if inp_step: + inp_cache_id = inp_step.get("cache_id", "") + inp_has_cached = cache_manager.has_content(inp_cache_id) if inp_cache_id else False + inp_media_type = None + if inp_has_cached: + inp_path = cache_manager.get_content_path(inp_cache_id) + if inp_path: + inp_media_type = run_service.detect_media_type(inp_path) + + inputs.append({ + "name": inp_step.get("name", inp_step_id[:12]), + "cache_id": inp_cache_id, + "media_type": inp_media_type, + "has_cached": inp_has_cached, + }) + + status = "cached" if (has_cached or ipfs_cid) else ("completed" if run.status == "completed" else "pending") + + templates = get_templates(request) + return HTMLResponse(render_fragment(templates, "runs/plan_node.html", + step=step, + cache_id=cache_id, + node_color=node_color, + status=status, + has_cached=has_cached, + output_preview=output_preview, + output_media_type=output_media_type, + ipfs_cid=ipfs_cid, + ipfs_gateway="https://ipfs.io/ipfs", + inputs=inputs, + config=step.get("config", {}), + )) + + +@router.delete("/{run_id}/ui", response_class=HTMLResponse) +async def ui_discard_run( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), +): + """HTMX handler: discard a run.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse( + '
Login required
', + status_code=401 + ) + + success, error = await run_service.discard_run(run_id, ctx.actor_id, ctx.username) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse( + '
Run discarded
' + '' + ) + + +@router.post("/{run_id}/publish") +async def publish_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Publish run output to L2 and IPFS.""" + from ..services.cache_service import CacheService + from ..dependencies import get_cache_manager + import database + + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, "Run not found") + + # Check if run has output + output_cid = run.get("output_cid") + if not output_cid: + error = "Run has no output to publish" + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + # Use cache service to publish the output + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=output_cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "output_cid": output_cid, "published": True} + + +@router.post("/rerun/{recipe_id}", response_class=HTMLResponse) +async def rerun_recipe( + recipe_id: str, + request: Request, +): + """HTMX handler: run a recipe again. + + Fetches the recipe by CID and starts a new streaming run. + """ + import uuid + import database + from tasks.streaming import run_stream + + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse( + '
Login required
', + status_code=401 + ) + + # Fetch the recipe + try: + from ..services.recipe_service import RecipeService + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + return HTMLResponse(f'
Recipe not found: {recipe_id[:16]}...
') + + # Get the S-expression content + recipe_sexp = recipe.get("sexp") + if not recipe_sexp: + return HTMLResponse('
Recipe has no S-expression content
') + + # Extract recipe name for output + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', recipe_sexp) + recipe_name = name_match.group(1) if name_match else f"stream" + + # Generate new run ID + run_id = str(uuid.uuid4()) + + # Extract duration from recipe if present (look for :duration pattern) + duration = None + duration_match = re.search(r':duration\s+(\d+(?:\.\d+)?)', recipe_sexp) + if duration_match: + duration = float(duration_match.group(1)) + + # Extract fps from recipe if present + fps = None + fps_match = re.search(r':fps\s+(\d+(?:\.\d+)?)', recipe_sexp) + if fps_match: + fps = float(fps_match.group(1)) + + # Submit Celery task to GPU queue + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=recipe_sexp, + output_name=f"{recipe_name}.mp4", + duration=duration, + fps=fps, + actor_id=ctx.actor_id, + sources_sexp=None, + audio_sexp=None, + ), + queue='gpu', + ) + + # Store in database + await database.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe_id, + inputs=[], + actor_id=ctx.actor_id, + dag_json=recipe_sexp, + output_name=f"{recipe_name}.mp4", + ) + + logger.info(f"Started rerun {run_id} for recipe {recipe_id[:16]}...") + + return HTMLResponse( + f'
Started new run
' + f'' + ) + + except Exception as e: + logger.error(f"Failed to rerun recipe {recipe_id}: {e}") + return HTMLResponse(f'
Error: {str(e)}
') + + +@router.delete("/admin/purge-failed") +async def purge_failed_runs( + request: Request, + ctx: UserContext = Depends(get_current_user), +): + """Delete all failed runs from pending_runs table. + + Requires authentication OR admin token in X-Admin-Token header. + """ + import database + import os + + # Check for admin token + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = request.headers.get("X-Admin-Token") + + # Require either valid auth or admin token + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Get all failed runs + failed_runs = await database.list_pending_runs(status="failed") + + deleted = [] + for run in failed_runs: + run_id = run.get("run_id") + try: + await database.delete_pending_run(run_id) + deleted.append(run_id) + except Exception as e: + logger.warning(f"Failed to delete run {run_id}: {e}") + + logger.info(f"Purged {len(deleted)} failed runs") + return {"purged": len(deleted), "run_ids": deleted} + + +@router.post("/{run_id}/pause") +async def pause_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Pause a running render. Waits for current segment to complete. + + The render will checkpoint at the next segment boundary and stop. + """ + import database + from celery_app import app as celery_app + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] != 'running': + raise HTTPException(400, f"Can only pause running renders (current status: {pending['status']})") + + # Revoke the Celery task (soft termination via SIGTERM - allows cleanup) + celery_task_id = pending.get('celery_task_id') + if celery_task_id: + celery_app.control.revoke(celery_task_id, terminate=True, signal='SIGTERM') + logger.info(f"Sent SIGTERM to task {celery_task_id} for run {run_id}") + + # Update status to 'paused' + await database.update_pending_run_status(run_id, 'paused') + + return { + "run_id": run_id, + "status": "paused", + "checkpoint_frame": pending.get('checkpoint_frame'), + } + + +@router.post("/{run_id}/resume") +async def resume_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Resume a paused or failed run from its last checkpoint. + + The render will continue from the checkpoint frame. + """ + import database + from tasks.streaming import run_stream + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] not in ('failed', 'paused'): + raise HTTPException(400, f"Can only resume failed/paused runs (current status: {pending['status']})") + + if not pending.get('checkpoint_frame'): + raise HTTPException(400, "No checkpoint available - use restart instead") + + if not pending.get('resumable', True): + raise HTTPException(400, "Run checkpoint is corrupted - use restart instead") + + # Submit new Celery task with resume=True + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=pending.get('dag_json', ''), # Recipe is stored in dag_json + output_name=pending.get('output_name', 'output.mp4'), + actor_id=pending.get('actor_id'), + resume=True, + ), + queue='gpu', + ) + + # Update status and celery_task_id + await database.update_pending_run_status(run_id, 'running') + + # Update the celery_task_id manually since create_pending_run isn't called + async with database.pool.acquire() as conn: + await conn.execute( + "UPDATE pending_runs SET celery_task_id = $2, updated_at = NOW() WHERE run_id = $1", + run_id, task.id + ) + + logger.info(f"Resumed run {run_id} from frame {pending.get('checkpoint_frame')} with task {task.id}") + + return { + "run_id": run_id, + "status": "running", + "celery_task_id": task.id, + "resumed_from_frame": pending.get('checkpoint_frame'), + } + + +@router.post("/{run_id}/restart") +async def restart_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Restart a failed/paused run from the beginning (discard checkpoint). + + All progress will be lost. Use resume instead to continue from checkpoint. + """ + import database + from tasks.streaming import run_stream + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] not in ('failed', 'paused'): + raise HTTPException(400, f"Can only restart failed/paused runs (current status: {pending['status']})") + + # Clear checkpoint data + await database.clear_run_checkpoint(run_id) + + # Submit new Celery task (without resume) + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=pending.get('dag_json', ''), # Recipe is stored in dag_json + output_name=pending.get('output_name', 'output.mp4'), + actor_id=pending.get('actor_id'), + resume=False, + ), + queue='gpu', + ) + + # Update status and celery_task_id + await database.update_pending_run_status(run_id, 'running') + + async with database.pool.acquire() as conn: + await conn.execute( + "UPDATE pending_runs SET celery_task_id = $2, updated_at = NOW() WHERE run_id = $1", + run_id, task.id + ) + + logger.info(f"Restarted run {run_id} from beginning with task {task.id}") + + return { + "run_id": run_id, + "status": "running", + "celery_task_id": task.id, + } + + +@router.get("/{run_id}/stream") +async def stream_run_output( + run_id: str, + request: Request, +): + """Stream the video output of a running render. + + For IPFS HLS streams, redirects to the IPFS gateway playlist. + For local HLS streams, redirects to the m3u8 playlist. + For legacy MP4 streams, returns the file directly. + """ + from fastapi.responses import StreamingResponse, FileResponse, RedirectResponse + from pathlib import Path + import os + import database + from celery_app import app as celery_app + + await database.init_db() + + # Check for IPFS HLS streaming first (distributed P2P streaming) + pending = await database.get_pending_run(run_id) + if pending and pending.get("celery_task_id"): + task_id = pending["celery_task_id"] + result = celery_app.AsyncResult(task_id) + if result.ready() and isinstance(result.result, dict): + ipfs_playlist_url = result.result.get("ipfs_playlist_url") + if ipfs_playlist_url: + logger.info(f"Redirecting to IPFS stream: {ipfs_playlist_url}") + return RedirectResponse(url=ipfs_playlist_url, status_code=302) + + cache_dir = os.environ.get("CACHE_DIR", "/data/cache") + stream_dir = Path(cache_dir) / "streaming" / run_id + + # Check for local HLS output + hls_playlist = stream_dir / "stream.m3u8" + if hls_playlist.exists(): + # Redirect to the HLS playlist endpoint + return RedirectResponse( + url=f"/runs/{run_id}/hls/stream.m3u8", + status_code=302 + ) + + # Fall back to legacy MP4 streaming + stream_path = stream_dir / "output.mp4" + if not stream_path.exists(): + raise HTTPException(404, "Stream not available yet") + + file_size = stream_path.stat().st_size + if file_size == 0: + raise HTTPException(404, "Stream not ready") + + return FileResponse( + path=str(stream_path), + media_type="video/mp4", + headers={ + "Accept-Ranges": "bytes", + "Cache-Control": "no-cache, no-store, must-revalidate", + "X-Content-Size": str(file_size), + } + ) + + +@router.get("/{run_id}/hls/{filename:path}") +async def serve_hls_content( + run_id: str, + filename: str, + request: Request, +): + """Serve HLS playlist and segments for live streaming. + + Serves stream.m3u8 (playlist) and segment_*.ts files. + The playlist updates as new segments are rendered. + + If files aren't found locally, proxies to the GPU worker (if configured). + """ + from fastapi.responses import FileResponse, StreamingResponse + from pathlib import Path + import os + import httpx + + cache_dir = os.environ.get("CACHE_DIR", "/data/cache") + stream_dir = Path(cache_dir) / "streaming" / run_id + file_path = stream_dir / filename + + # Security: ensure we're only serving files within stream_dir + try: + file_path_resolved = file_path.resolve() + stream_dir_resolved = stream_dir.resolve() + if stream_dir.exists() and not str(file_path_resolved).startswith(str(stream_dir_resolved)): + raise HTTPException(403, "Invalid path") + except Exception: + pass # Allow proxy fallback + + # Determine content type + if filename.endswith(".m3u8"): + media_type = "application/vnd.apple.mpegurl" + headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Access-Control-Allow-Origin": "*", + } + elif filename.endswith(".ts"): + media_type = "video/mp2t" + headers = { + "Cache-Control": "public, max-age=3600", + "Access-Control-Allow-Origin": "*", + } + else: + raise HTTPException(400, "Invalid file type") + + # For playlist requests, check IPFS first (redirect to IPFS gateway) + if filename == "stream.m3u8" and not file_path.exists(): + import database + from fastapi.responses import RedirectResponse + await database.init_db() + + # Check pending run for IPFS playlist + pending = await database.get_pending_run(run_id) + if pending and pending.get("ipfs_playlist_cid"): + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ipfs_url = f"{gateway}/{pending['ipfs_playlist_cid']}" + return RedirectResponse(url=ipfs_url, status_code=302) + + # Check completed run cache + run = await database.get_run_cache(run_id) + if run and run.get("ipfs_playlist_cid"): + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ipfs_url = f"{gateway}/{run['ipfs_playlist_cid']}" + return RedirectResponse(url=ipfs_url, status_code=302) + + # Try local file first + if file_path.exists(): + return FileResponse( + path=str(file_path), + media_type=media_type, + headers=headers, + ) + + # Fallback: proxy to GPU worker if configured + gpu_worker_url = os.environ.get("GPU_WORKER_STREAM_URL") + if gpu_worker_url: + # Proxy request to GPU worker + proxy_url = f"{gpu_worker_url}/{run_id}/{filename}" + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.get(proxy_url) + if resp.status_code == 200: + return StreamingResponse( + content=iter([resp.content]), + media_type=media_type, + headers=headers, + ) + except Exception as e: + logger.warning(f"GPU worker proxy failed: {e}") + + raise HTTPException(404, f"File not found: {filename}") + + +@router.get("/{run_id}/playlist.m3u8") +async def get_playlist(run_id: str, request: Request): + """Get live HLS master playlist for a streaming run. + + For multi-resolution streams: generates a master playlist with DYNAMIC quality URLs. + For single-resolution streams: returns the playlist directly from IPFS. + """ + import database + import os + import httpx + from fastapi.responses import Response + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + quality_playlists = pending.get("quality_playlists") + + # Multi-resolution stream: generate master playlist with dynamic quality URLs + if quality_playlists: + lines = ["#EXTM3U", "#EXT-X-VERSION:3"] + + for name, info in quality_playlists.items(): + if not info.get("cid"): + continue + + lines.append( + f"#EXT-X-STREAM-INF:BANDWIDTH={info['bitrate'] * 1000}," + f"RESOLUTION={info['width']}x{info['height']}," + f"NAME=\"{name}\"" + ) + # Use dynamic URL that fetches latest CID from database + lines.append(f"/runs/{run_id}/quality/{name}/playlist.m3u8") + + if len(lines) <= 2: + raise HTTPException(404, "No quality playlists available") + + playlist_content = "\n".join(lines) + "\n" + + else: + # Single-resolution stream: fetch directly from IPFS + ipfs_playlist_cid = pending.get("ipfs_playlist_cid") + if not ipfs_playlist_cid: + raise HTTPException(404, "HLS playlist not created - rendering likely failed") + + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={ipfs_playlist_cid}") + if resp.status_code != 200: + raise HTTPException(502, "Failed to fetch playlist from IPFS") + playlist_content = resp.text + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Rewrite IPFS URLs to use our proxy endpoint + import re + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + + playlist_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + playlist_content = re.sub( + r'/ipfs(?:-ts)?/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + + return Response( + content=playlist_content, + media_type="application/vnd.apple.mpegurl", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/quality/{quality}/playlist.m3u8") +async def get_quality_playlist(run_id: str, quality: str, request: Request): + """Get quality-level HLS playlist for a streaming run. + + Fetches the LATEST CID for this quality from the database, + so HLS.js always gets updated content. + """ + import database + import os + import httpx + from fastapi.responses import Response + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + quality_playlists = pending.get("quality_playlists") + if not quality_playlists or quality not in quality_playlists: + raise HTTPException(404, f"Quality '{quality}' not found") + + quality_cid = quality_playlists[quality].get("cid") + if not quality_cid: + raise HTTPException(404, f"Quality '{quality}' playlist not ready") + + # Fetch playlist from local IPFS node + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={quality_cid}") + if resp.status_code != 200: + raise HTTPException(502, f"Failed to fetch quality playlist from IPFS: {quality_cid}") + playlist_content = resp.text + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Rewrite segment URLs to use our proxy (segments are still static IPFS content) + import re + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + + # Replace absolute gateway URLs with our proxy + playlist_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + # Also handle /ipfs/ paths and /ipfs-ts/ paths + playlist_content = re.sub( + r'/ipfs(?:-ts)?/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + + return Response( + content=playlist_content, + media_type="application/vnd.apple.mpegurl", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/ipfs-proxy/{cid}") +async def proxy_ipfs_content(run_id: str, cid: str, request: Request): + """Proxy IPFS content with no-cache headers for live streaming. + + This allows HLS.js to poll for updated playlists through us rather than + hitting static IPFS URLs directly. + """ + import os + import httpx + from fastapi.responses import Response + + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={cid}") + if resp.status_code != 200: + raise HTTPException(502, f"Failed to fetch from IPFS: {cid}") + content = resp.content + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Determine content type + if cid.endswith('.m3u8') or b'#EXTM3U' in content[:20]: + media_type = "application/vnd.apple.mpegurl" + # Rewrite any IPFS URLs in sub-playlists too + import re + text_content = content.decode('utf-8', errors='replace') + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + text_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + text_content + ) + text_content = re.sub( + r'/ipfs/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + text_content + ) + content = text_content.encode('utf-8') + elif b'\x47' in content[:1]: # MPEG-TS sync byte + media_type = "video/mp2t" + else: + media_type = "application/octet-stream" + + return Response( + content=content, + media_type=media_type, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/ipfs-stream") +async def get_ipfs_stream_info(run_id: str, request: Request): + """Get IPFS streaming info for a run. + + Returns the IPFS playlist URL and segment info if available. + This allows clients to stream directly from IPFS gateways. + """ + from celery_app import app as celery_app + import database + import os + + await database.init_db() + + # Try to get pending run to find the Celery task ID + pending = await database.get_pending_run(run_id) + from fastapi.responses import JSONResponse + no_cache_headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + + if not pending: + # Try completed runs + run = await database.get_run_cache(run_id) + if not run: + raise HTTPException(404, "Run not found") + # For completed runs, check if we have IPFS info stored + ipfs_cid = run.get("ipfs_cid") + if ipfs_cid: + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + return JSONResponse( + content={ + "run_id": run_id, + "status": "completed", + "ipfs_video_url": f"{gateway}/{ipfs_cid}", + }, + headers=no_cache_headers + ) + raise HTTPException(404, "No IPFS stream info available") + + task_id = pending.get("celery_task_id") + if not task_id: + raise HTTPException(404, "No task ID for this run") + + # Get the Celery task result + result = celery_app.AsyncResult(task_id) + + from fastapi.responses import JSONResponse + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + no_cache_headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + + if result.ready(): + # Task is complete - check the result for IPFS playlist info + task_result = result.result + if isinstance(task_result, dict): + ipfs_playlist_cid = task_result.get("ipfs_playlist_cid") + ipfs_playlist_url = task_result.get("ipfs_playlist_url") + if ipfs_playlist_url: + return JSONResponse( + content={ + "run_id": run_id, + "status": "completed", + "ipfs_playlist_cid": ipfs_playlist_cid, + "ipfs_playlist_url": ipfs_playlist_url, + "segment_count": task_result.get("ipfs_segment_count", 0), + }, + headers=no_cache_headers + ) + + # Task is still running - check database for live playlist updates + ipfs_playlist_cid = pending.get("ipfs_playlist_cid") + + # Get task state and progress metadata + task_state = result.state + task_info = result.info if isinstance(result.info, dict) else {} + + response_data = { + "run_id": run_id, + "status": task_state.lower() if task_state else pending.get("status", "pending"), + } + + # Add progress metadata if available + if task_info: + if "progress" in task_info: + response_data["progress"] = task_info["progress"] + if "frame" in task_info: + response_data["frame"] = task_info["frame"] + if "total_frames" in task_info: + response_data["total_frames"] = task_info["total_frames"] + if "percent" in task_info: + response_data["percent"] = task_info["percent"] + + if ipfs_playlist_cid: + response_data["ipfs_playlist_cid"] = ipfs_playlist_cid + response_data["ipfs_playlist_url"] = f"{gateway}/{ipfs_playlist_cid}" + else: + response_data["message"] = "IPFS streaming info not yet available" + + # No caching for live streaming data + return JSONResponse( + content=response_data, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + ) diff --git a/l1/app/routers/storage.py b/l1/app/routers/storage.py new file mode 100644 index 0000000..b8f2fc8 --- /dev/null +++ b/l1/app/routers/storage.py @@ -0,0 +1,264 @@ +""" +Storage provider routes for L1 server. + +Manages user storage backends (Pinata, web3.storage, local, etc.) +""" + +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException, Form +from fastapi.responses import HTMLResponse, RedirectResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import get_database, get_current_user, require_auth, get_templates +from ..services.storage_service import StorageService, STORAGE_PROVIDERS_INFO, VALID_PROVIDER_TYPES + +router = APIRouter() + + +# Import storage_providers module +import storage_providers as sp_module + + +def get_storage_service(): + """Get storage service instance.""" + import database + return StorageService(database, sp_module) + + +class AddStorageRequest(BaseModel): + provider_type: str + config: Dict[str, Any] + capacity_gb: int = 5 + provider_name: Optional[str] = None + + +class UpdateStorageRequest(BaseModel): + config: Optional[Dict[str, Any]] = None + capacity_gb: Optional[int] = None + is_active: Optional[bool] = None + + +@router.get("") +async def list_storage( + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """List user's storage providers. HTML for browsers, JSON for API.""" + storages = await storage_service.list_storages(ctx.actor_id) + + if wants_json(request): + return {"storages": storages} + + # Render HTML template + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "storage/list.html", request, + storages=storages, + user=ctx, + nav_counts=nav_counts, + providers_info=STORAGE_PROVIDERS_INFO, + active_tab="storage", + ) + + +@router.post("") +async def add_storage( + req: AddStorageRequest, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Add a storage provider via API.""" + ctx = await require_auth(request) + + storage_id, error = await storage_service.add_storage( + actor_id=ctx.actor_id, + provider_type=req.provider_type, + config=req.config, + capacity_gb=req.capacity_gb, + provider_name=req.provider_name, + ) + + if error: + raise HTTPException(400, error) + + return {"id": storage_id, "message": "Storage provider added"} + + +@router.post("/add") +async def add_storage_form( + request: Request, + provider_type: str = Form(...), + provider_name: Optional[str] = Form(None), + description: Optional[str] = Form(None), + capacity_gb: int = Form(5), + api_key: Optional[str] = Form(None), + secret_key: Optional[str] = Form(None), + api_token: Optional[str] = Form(None), + project_id: Optional[str] = Form(None), + project_secret: Optional[str] = Form(None), + access_key: Optional[str] = Form(None), + bucket: Optional[str] = Form(None), + path: Optional[str] = Form(None), + storage_service: StorageService = Depends(get_storage_service), +): + """Add a storage provider via HTML form.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Not authenticated
', status_code=401) + + # Build config from form + form_data = { + "api_key": api_key, + "secret_key": secret_key, + "api_token": api_token, + "project_id": project_id, + "project_secret": project_secret, + "access_key": access_key, + "bucket": bucket, + "path": path, + } + config, error = storage_service.build_config_from_form(provider_type, form_data) + + if error: + return HTMLResponse(f'
{error}
') + + storage_id, error = await storage_service.add_storage( + actor_id=ctx.actor_id, + provider_type=provider_type, + config=config, + capacity_gb=capacity_gb, + provider_name=provider_name, + description=description, + ) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse(f''' +
Storage provider added successfully!
+ + ''') + + +@router.get("/{storage_id}") +async def get_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Get a specific storage provider.""" + ctx = await require_auth(request) + + storage = await storage_service.get_storage(storage_id, ctx.actor_id) + if not storage: + raise HTTPException(404, "Storage provider not found") + + return storage + + +@router.patch("/{storage_id}") +async def update_storage( + storage_id: int, + req: UpdateStorageRequest, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Update a storage provider.""" + ctx = await require_auth(request) + + success, error = await storage_service.update_storage( + storage_id=storage_id, + actor_id=ctx.actor_id, + config=req.config, + capacity_gb=req.capacity_gb, + is_active=req.is_active, + ) + + if error: + raise HTTPException(400, error) + + return {"message": "Storage provider updated"} + + +@router.delete("/{storage_id}") +async def delete_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """Remove a storage provider.""" + success, error = await storage_service.delete_storage(storage_id, ctx.actor_id) + + if error: + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse("") + + return {"message": "Storage provider removed"} + + +@router.post("/{storage_id}/test") +async def test_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Test storage provider connectivity.""" + ctx = await get_current_user(request) + if not ctx: + if wants_html(request): + return HTMLResponse('Not authenticated', status_code=401) + raise HTTPException(401, "Not authenticated") + + success, message = await storage_service.test_storage(storage_id, ctx.actor_id) + + if wants_html(request): + color = "green" if success else "red" + return HTMLResponse(f'{message}') + + return {"success": success, "message": message} + + +@router.get("/type/{provider_type}") +async def storage_type_page( + provider_type: str, + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """Page for managing storage configs of a specific type.""" + if provider_type not in STORAGE_PROVIDERS_INFO: + raise HTTPException(404, "Invalid provider type") + + storages = await storage_service.list_by_type(ctx.actor_id, provider_type) + provider_info = STORAGE_PROVIDERS_INFO[provider_type] + + if wants_json(request): + return { + "provider_type": provider_type, + "provider_info": provider_info, + "storages": storages, + } + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "storage/type.html", request, + provider_type=provider_type, + provider_info=provider_info, + storages=storages, + user=ctx, + nav_counts=nav_counts, + active_tab="storage", + ) diff --git a/l1/app/services/__init__.py b/l1/app/services/__init__.py new file mode 100644 index 0000000..76eba24 --- /dev/null +++ b/l1/app/services/__init__.py @@ -0,0 +1,15 @@ +""" +L1 Server Services. + +Business logic layer between routers and repositories. +""" + +from .run_service import RunService +from .recipe_service import RecipeService +from .cache_service import CacheService + +__all__ = [ + "RunService", + "RecipeService", + "CacheService", +] diff --git a/l1/app/services/auth_service.py b/l1/app/services/auth_service.py new file mode 100644 index 0000000..3f3ce26 --- /dev/null +++ b/l1/app/services/auth_service.py @@ -0,0 +1,138 @@ +""" +Auth Service - token management and user verification. +""" + +import hashlib +import base64 +import json +from typing import Optional, Dict, Any, TYPE_CHECKING + +import httpx + +from artdag_common.middleware.auth import UserContext +from ..config import settings + +if TYPE_CHECKING: + import redis + from starlette.requests import Request + + +# Token expiry (30 days to match token lifetime) +TOKEN_EXPIRY_SECONDS = 60 * 60 * 24 * 30 + +# Redis key prefixes +REVOKED_KEY_PREFIX = "artdag:revoked:" +USER_TOKENS_PREFIX = "artdag:user_tokens:" + + +class AuthService: + """Service for authentication and token management.""" + + def __init__(self, redis_client: "redis.Redis[bytes]") -> None: + self.redis = redis_client + + def register_user_token(self, username: str, token: str) -> None: + """Track a token for a user (for later revocation by username).""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{USER_TOKENS_PREFIX}{username}" + self.redis.sadd(key, token_hash) + self.redis.expire(key, TOKEN_EXPIRY_SECONDS) + + def revoke_token(self, token: str) -> bool: + """Add token to revocation set. Returns True if newly revoked.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) + return result is not None + + def revoke_token_hash(self, token_hash: str) -> bool: + """Add token hash to revocation set. Returns True if newly revoked.""" + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) + return result is not None + + def revoke_all_user_tokens(self, username: str) -> int: + """Revoke all tokens for a user. Returns count revoked.""" + key = f"{USER_TOKENS_PREFIX}{username}" + token_hashes = self.redis.smembers(key) + count = 0 + for token_hash in token_hashes: + if self.revoke_token_hash( + token_hash.decode() if isinstance(token_hash, bytes) else token_hash + ): + count += 1 + self.redis.delete(key) + return count + + def is_token_revoked(self, token: str) -> bool: + """Check if token has been revoked.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + return self.redis.exists(key) > 0 + + def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]: + """Decode JWT claims without verification.""" + try: + parts = token.split(".") + if len(parts) != 3: + return None + payload = parts[1] + # Add padding + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + return json.loads(base64.urlsafe_b64decode(payload)) + except (json.JSONDecodeError, ValueError): + return None + + def get_user_context_from_token(self, token: str) -> Optional[UserContext]: + """Extract user context from a token.""" + if self.is_token_revoked(token): + return None + + claims = self.decode_token_claims(token) + if not claims: + return None + + username = claims.get("username") or claims.get("sub") + actor_id = claims.get("actor_id") or claims.get("actor") + + if not username: + return None + + return UserContext( + username=username, + actor_id=actor_id or f"@{username}", + token=token, + l2_server=settings.l2_server, + ) + + async def verify_token_with_l2(self, token: str) -> Optional[UserContext]: + """Verify token with L2 server.""" + ctx = self.get_user_context_from_token(token) + if not ctx: + return None + + # If L2 server configured, verify token + if settings.l2_server: + try: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{settings.l2_server}/auth/verify", + headers={"Authorization": f"Bearer {token}"}, + timeout=5.0, + ) + if resp.status_code != 200: + return None + except httpx.RequestError: + # L2 unavailable, trust the token + pass + + return ctx + + def get_user_from_cookie(self, request: "Request") -> Optional[UserContext]: + """Extract user context from auth cookie.""" + token = request.cookies.get("auth_token") + if not token: + return None + return self.get_user_context_from_token(token) diff --git a/l1/app/services/cache_service.py b/l1/app/services/cache_service.py new file mode 100644 index 0000000..9b7bcd8 --- /dev/null +++ b/l1/app/services/cache_service.py @@ -0,0 +1,618 @@ +""" +Cache Service - business logic for cache and media management. +""" + +import asyncio +import json +import logging +import os +import subprocess +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +import httpx + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from database import Database + from cache_manager import L1CacheManager + + +def detect_media_type(cache_path: Path) -> str: + """Detect if file is image, video, or audio based on magic bytes.""" + try: + with open(cache_path, "rb") as f: + header = f.read(32) + except Exception: + return "unknown" + + # Video signatures + if header[:4] == b'\x1a\x45\xdf\xa3': # WebM/MKV + return "video" + if len(header) > 8 and header[4:8] == b'ftyp': # MP4/MOV + return "video" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'AVI ': # AVI + return "video" + + # Image signatures + if header[:8] == b'\x89PNG\r\n\x1a\n': # PNG + return "image" + if header[:2] == b'\xff\xd8': # JPEG + return "image" + if header[:6] in (b'GIF87a', b'GIF89a'): # GIF + return "image" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WEBP': # WebP + return "image" + + # Audio signatures + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WAVE': # WAV + return "audio" + if header[:3] == b'ID3' or header[:2] == b'\xff\xfb': # MP3 + return "audio" + if header[:4] == b'fLaC': # FLAC + return "audio" + + return "unknown" + + +def get_mime_type(path: Path) -> str: + """Get MIME type based on file magic bytes.""" + media_type = detect_media_type(path) + if media_type == "video": + try: + with open(path, "rb") as f: + header = f.read(12) + if header[:4] == b'\x1a\x45\xdf\xa3': + return "video/x-matroska" + return "video/mp4" + except Exception: + return "video/mp4" + elif media_type == "image": + try: + with open(path, "rb") as f: + header = f.read(8) + if header[:8] == b'\x89PNG\r\n\x1a\n': + return "image/png" + if header[:2] == b'\xff\xd8': + return "image/jpeg" + if header[:6] in (b'GIF87a', b'GIF89a'): + return "image/gif" + return "image/jpeg" + except Exception: + return "image/jpeg" + elif media_type == "audio": + return "audio/mpeg" + return "application/octet-stream" + + +class CacheService: + """ + Service for managing cached content. + + Handles content retrieval, metadata, and media type detection. + """ + + def __init__(self, database: "Database", cache_manager: "L1CacheManager") -> None: + self.db = database + self.cache = cache_manager + self.cache_dir = Path(os.environ.get("CACHE_DIR", "/tmp/artdag-cache")) + + async def get_cache_item(self, cid: str, actor_id: str = None) -> Optional[Dict[str, Any]]: + """Get cached item with full metadata for display.""" + # Get metadata from database first + meta = await self.db.load_item_metadata(cid, actor_id) + cache_item = await self.db.get_cache_item(cid) + + # Check if content exists locally + path = self.cache.get_by_cid(cid) if self.cache.has_content(cid) else None + + if path and path.exists(): + # Local file exists - detect type from file + media_type = detect_media_type(path) + mime_type = get_mime_type(path) + size = path.stat().st_size + else: + # File not local - check database for type info + # Try to get type from item_types table + media_type = "unknown" + mime_type = "application/octet-stream" + size = 0 + + if actor_id: + try: + item_types = await self.db.get_item_types(cid, actor_id) + if item_types: + media_type = item_types[0].get("type", "unknown") + if media_type == "video": + mime_type = "video/mp4" + elif media_type == "image": + mime_type = "image/png" + elif media_type == "audio": + mime_type = "audio/mpeg" + except Exception: + pass + + # If no local path but we have IPFS CID, content is available remotely + if not cache_item: + return None + + result = { + "cid": cid, + "path": str(path) if path else None, + "media_type": media_type, + "mime_type": mime_type, + "size": size, + "ipfs_cid": cache_item.get("ipfs_cid") if cache_item else None, + "meta": meta, + "remote_only": path is None or not path.exists(), + } + + # Unpack meta fields to top level for template convenience + if meta: + result["title"] = meta.get("title") + result["description"] = meta.get("description") + result["tags"] = meta.get("tags", []) + result["source_type"] = meta.get("source_type") + result["source_note"] = meta.get("source_note") + result["created_at"] = meta.get("created_at") + result["filename"] = meta.get("filename") + + # Get friendly name if actor_id provided + if actor_id: + from .naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(actor_id, cid) + if friendly: + result["friendly_name"] = friendly["friendly_name"] + result["base_name"] = friendly["base_name"] + result["version_id"] = friendly["version_id"] + + return result + + async def check_access(self, cid: str, actor_id: str, username: str) -> bool: + """Check if user has access to content.""" + user_hashes = await self._get_user_cache_hashes(username, actor_id) + return cid in user_hashes + + async def _get_user_cache_hashes(self, username: str, actor_id: Optional[str] = None) -> set: + """Get all cache hashes owned by or associated with a user.""" + match_values = [username] + if actor_id: + match_values.append(actor_id) + + hashes = set() + + # Query database for items owned by user + if actor_id: + try: + db_items = await self.db.get_user_items(actor_id) + for item in db_items: + hashes.add(item["cid"]) + except Exception: + pass + + # Legacy: Files uploaded by user (JSON metadata) + if self.cache_dir.exists(): + for f in self.cache_dir.iterdir(): + if f.name.endswith('.meta.json'): + try: + with open(f, 'r') as mf: + meta = json.load(mf) + if meta.get("uploader") in match_values: + hashes.add(f.name.replace('.meta.json', '')) + except Exception: + pass + + # Files from user's runs (inputs and outputs) + runs = await self._list_user_runs(username, actor_id) + for run in runs: + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + hashes.update(inputs) + if run.get("output_cid"): + hashes.add(run["output_cid"]) + + return hashes + + async def _list_user_runs(self, username: str, actor_id: Optional[str]) -> List[Dict]: + """List runs for a user (helper for access check).""" + from ..dependencies import get_redis_client + import json + + redis = get_redis_client() + runs = [] + cursor = 0 + prefix = "artdag:run:" + + while True: + cursor, keys = redis.scan(cursor=cursor, match=f"{prefix}*", count=100) + for key in keys: + data = redis.get(key) + if data: + run = json.loads(data) + if run.get("actor_id") in (username, actor_id) or run.get("username") in (username, actor_id): + runs.append(run) + if cursor == 0: + break + + return runs + + async def get_raw_file(self, cid: str) -> Tuple[Optional[Path], Optional[str], Optional[str]]: + """Get raw file path, media type, and filename for download.""" + if not self.cache.has_content(cid): + return None, None, None + + path = self.cache.get_by_cid(cid) + if not path or not path.exists(): + return None, None, None + + media_type = detect_media_type(path) + mime = get_mime_type(path) + + # Determine extension + ext = "bin" + if media_type == "video": + try: + with open(path, "rb") as f: + header = f.read(12) + if header[:4] == b'\x1a\x45\xdf\xa3': + ext = "mkv" + else: + ext = "mp4" + except Exception: + ext = "mp4" + elif media_type == "image": + try: + with open(path, "rb") as f: + header = f.read(8) + if header[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + else: + ext = "jpg" + except Exception: + ext = "jpg" + + filename = f"{cid}.{ext}" + return path, mime, filename + + async def get_as_mp4(self, cid: str) -> Tuple[Optional[Path], Optional[str]]: + """Get content as MP4, transcoding if necessary. Returns (path, error).""" + if not self.cache.has_content(cid): + return None, f"Content {cid} not in cache" + + path = self.cache.get_by_cid(cid) + if not path or not path.exists(): + return None, f"Content {cid} not in cache" + + # Check if video + media_type = detect_media_type(path) + if media_type != "video": + return None, "Content is not a video" + + # Check for cached MP4 + mp4_path = self.cache_dir / f"{cid}.mp4" + if mp4_path.exists(): + return mp4_path, None + + # Check if already MP4 format + try: + result = subprocess.run( + ["ffprobe", "-v", "error", "-select_streams", "v:0", + "-show_entries", "format=format_name", "-of", "csv=p=0", str(path)], + capture_output=True, text=True, timeout=10 + ) + if "mp4" in result.stdout.lower() or "mov" in result.stdout.lower(): + return path, None + except Exception: + pass + + # Transcode to MP4 + transcode_path = self.cache_dir / f"{cid}.transcoding.mp4" + try: + result = subprocess.run( + ["ffmpeg", "-y", "-i", str(path), + "-c:v", "libx264", "-preset", "fast", "-crf", "23", + "-c:a", "aac", "-b:a", "128k", + "-movflags", "+faststart", + str(transcode_path)], + capture_output=True, text=True, timeout=600 + ) + if result.returncode != 0: + return None, f"Transcoding failed: {result.stderr[:200]}" + + transcode_path.rename(mp4_path) + return mp4_path, None + + except subprocess.TimeoutExpired: + if transcode_path.exists(): + transcode_path.unlink() + return None, "Transcoding timed out" + except Exception as e: + if transcode_path.exists(): + transcode_path.unlink() + return None, f"Transcoding failed: {e}" + + async def get_metadata(self, cid: str, actor_id: str) -> Optional[Dict[str, Any]]: + """Get content metadata.""" + if not self.cache.has_content(cid): + return None + return await self.db.load_item_metadata(cid, actor_id) + + async def update_metadata( + self, + cid: str, + actor_id: str, + title: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + custom: Optional[Dict[str, Any]] = None, + ) -> Tuple[bool, Optional[str]]: + """Update content metadata. Returns (success, error).""" + if not self.cache.has_content(cid): + return False, "Content not found" + + # Build update dict + updates = {} + if title is not None: + updates["title"] = title + if description is not None: + updates["description"] = description + if tags is not None: + updates["tags"] = tags + if custom is not None: + updates["custom"] = custom + + try: + await self.db.update_item_metadata(cid, actor_id, **updates) + return True, None + except Exception as e: + return False, str(e) + + async def publish_to_l2( + self, + cid: str, + actor_id: str, + l2_server: str, + auth_token: str, + ) -> Tuple[Optional[str], Optional[str]]: + """Publish content to L2 and IPFS. Returns (ipfs_cid, error).""" + if not self.cache.has_content(cid): + return None, "Content not found" + + # Get IPFS CID + cache_item = await self.db.get_cache_item(cid) + ipfs_cid = cache_item.get("ipfs_cid") if cache_item else None + + # Get metadata for origin info + meta = await self.db.load_item_metadata(cid, actor_id) + origin = meta.get("origin") if meta else None + + if not origin or "type" not in origin: + return None, "Origin must be set before publishing" + + if not auth_token: + return None, "Authentication token required" + + # Call L2 publish-cache endpoint + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{l2_server}/assets/publish-cache", + headers={"Authorization": f"Bearer {auth_token}"}, + json={ + "cid": cid, + "ipfs_cid": ipfs_cid, + "asset_name": meta.get("title") or cid[:16], + "asset_type": detect_media_type(self.cache.get_by_cid(cid)), + "origin": origin, + "description": meta.get("description"), + "tags": meta.get("tags", []), + } + ) + resp.raise_for_status() + l2_result = resp.json() + except httpx.HTTPStatusError as e: + error_detail = str(e) + try: + error_detail = e.response.json().get("detail", str(e)) + except Exception: + pass + return None, f"L2 publish failed: {error_detail}" + except Exception as e: + return None, f"L2 publish failed: {e}" + + # Update local metadata with publish status + await self.db.save_l2_share( + cid=cid, + actor_id=actor_id, + l2_server=l2_server, + asset_name=meta.get("title") or cid[:16], + content_type=detect_media_type(self.cache.get_by_cid(cid)) + ) + await self.db.update_item_metadata( + cid=cid, + actor_id=actor_id, + pinned=True, + pin_reason="published" + ) + + return l2_result.get("ipfs_cid") or ipfs_cid, None + + async def delete_content(self, cid: str, actor_id: str) -> Tuple[bool, Optional[str]]: + """ + Remove user's ownership link to cached content. + + This removes the item_types entry linking the user to the content. + The cached file is only deleted if no other users own it. + Returns (success, error). + """ + import logging + logger = logging.getLogger(__name__) + + # Check if pinned for this user + meta = await self.db.load_item_metadata(cid, actor_id) + if meta and meta.get("pinned"): + pin_reason = meta.get("pin_reason", "unknown") + return False, f"Cannot discard pinned item (reason: {pin_reason})" + + # Get the item type to delete the right ownership entry + item_types = await self.db.get_item_types(cid, actor_id) + if not item_types: + return False, "You don't own this content" + + # Remove user's ownership links (all types for this user) + for item in item_types: + item_type = item.get("type", "media") + await self.db.delete_item_type(cid, actor_id, item_type) + + # Remove friendly name + await self.db.delete_friendly_name(actor_id, cid) + + # Check if anyone else still owns this content + remaining_owners = await self.db.get_item_types(cid) + + # Only delete the actual file if no one owns it anymore + if not remaining_owners: + # Check deletion rules via cache_manager + can_delete, reason = self.cache.can_delete(cid) + if can_delete: + # Delete via cache_manager + self.cache.delete_by_cid(cid) + + # Clean up legacy metadata files + meta_path = self.cache_dir / f"{cid}.meta.json" + if meta_path.exists(): + meta_path.unlink() + mp4_path = self.cache_dir / f"{cid}.mp4" + if mp4_path.exists(): + mp4_path.unlink() + + # Delete from database + await self.db.delete_cache_item(cid) + + logger.info(f"Garbage collected content {cid[:16]}... (no remaining owners)") + else: + logger.info(f"Content {cid[:16]}... orphaned but cannot delete: {reason}") + + logger.info(f"Removed content {cid[:16]}... ownership for {actor_id}") + return True, None + + async def import_from_ipfs(self, ipfs_cid: str, actor_id: str) -> Tuple[Optional[str], Optional[str]]: + """Import content from IPFS. Returns (cid, error).""" + try: + import ipfs_client + + # Download from IPFS + legacy_dir = self.cache_dir / "legacy" + legacy_dir.mkdir(parents=True, exist_ok=True) + tmp_path = legacy_dir / f"import-{ipfs_cid[:16]}" + + if not ipfs_client.get_file(ipfs_cid, str(tmp_path)): + return None, f"Could not fetch CID {ipfs_cid} from IPFS" + + # Detect media type before storing + media_type = detect_media_type(tmp_path) + + # Store in cache + cached, new_ipfs_cid = self.cache.put(tmp_path, node_type="import", move=True) + cid = new_ipfs_cid or cached.cid # Prefer IPFS CID + + # Save to database with detected media type + await self.db.create_cache_item(cid, new_ipfs_cid) + await self.db.save_item_metadata( + cid=cid, + actor_id=actor_id, + item_type=media_type, # Use detected type for filtering + filename=f"ipfs-{ipfs_cid[:16]}" + ) + + return cid, None + except Exception as e: + return None, f"Import failed: {e}" + + async def upload_content( + self, + content: bytes, + filename: str, + actor_id: str, + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Upload content to cache. Returns (cid, ipfs_cid, error). + + Files are stored locally first for fast response, then uploaded + to IPFS in the background. + """ + import tempfile + + try: + # Write to temp file + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + + # Detect media type (video/image/audio) before moving file + media_type = detect_media_type(tmp_path) + + # Store locally AND upload to IPFS synchronously + # This ensures the IPFS CID is available immediately for distributed access + cached, ipfs_cid = self.cache.put(tmp_path, node_type="upload", move=True, skip_ipfs=False) + cid = ipfs_cid or cached.cid # Prefer IPFS CID, fall back to local hash + + # Save to database with media category type + await self.db.create_cache_item(cached.cid, ipfs_cid) + await self.db.save_item_metadata( + cid=cid, + actor_id=actor_id, + item_type=media_type, + filename=filename + ) + + if ipfs_cid: + logger.info(f"Uploaded to IPFS: {ipfs_cid[:16]}...") + else: + logger.warning(f"IPFS upload failed, using local hash: {cid[:16]}...") + + return cid, ipfs_cid, None + except Exception as e: + return None, None, f"Upload failed: {e}" + + async def list_media( + self, + actor_id: Optional[str] = None, + username: Optional[str] = None, + offset: int = 0, + limit: int = 24, + media_type: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """List media items in cache.""" + # Get items from database (uses item_types table) + items = await self.db.get_user_items( + actor_id=actor_id or username, + item_type=media_type, # "video", "image", "audio", or None for all + limit=limit, + offset=offset, + ) + + # Add friendly names to items + if actor_id: + from .naming_service import get_naming_service + naming = get_naming_service() + for item in items: + cid = item.get("cid") + if cid: + friendly = await naming.get_by_cid(actor_id, cid) + if friendly: + item["friendly_name"] = friendly["friendly_name"] + item["base_name"] = friendly["base_name"] + + return items + + # Legacy compatibility methods + def has_content(self, cid: str) -> bool: + """Check if content exists in cache.""" + return self.cache.has_content(cid) + + def get_ipfs_cid(self, cid: str) -> Optional[str]: + """Get IPFS CID for cached content.""" + return self.cache.get_ipfs_cid(cid) diff --git a/l1/app/services/naming_service.py b/l1/app/services/naming_service.py new file mode 100644 index 0000000..5678ab2 --- /dev/null +++ b/l1/app/services/naming_service.py @@ -0,0 +1,234 @@ +""" +Naming service for friendly names. + +Handles: +- Name normalization (My Cool Effect -> my-cool-effect) +- Version ID generation (server-signed timestamps) +- Friendly name assignment and resolution +""" + +import hmac +import os +import re +import time +from typing import Optional, Tuple + +import database + + +# Base32 Crockford alphabet (excludes I, L, O, U to avoid confusion) +CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" + + +def _get_server_secret() -> bytes: + """Get server secret for signing version IDs.""" + secret = os.environ.get("SERVER_SECRET", "") + if not secret: + # Fall back to a derived secret from other env vars + # In production, SERVER_SECRET should be set explicitly + secret = os.environ.get("SECRET_KEY", "default-dev-secret") + return secret.encode("utf-8") + + +def _base32_crockford_encode(data: bytes) -> str: + """Encode bytes as base32-crockford (lowercase).""" + # Convert bytes to integer + num = int.from_bytes(data, "big") + if num == 0: + return CROCKFORD_ALPHABET[0] + + result = [] + while num > 0: + result.append(CROCKFORD_ALPHABET[num % 32]) + num //= 32 + + return "".join(reversed(result)) + + +def generate_version_id() -> str: + """ + Generate a version ID that is: + - Always increasing (timestamp-based prefix) + - Verifiable as originating from this server (HMAC suffix) + - Short and URL-safe (13 chars) + + Format: 6 bytes timestamp (ms) + 2 bytes HMAC = 8 bytes = 13 base32 chars + """ + timestamp_ms = int(time.time() * 1000) + timestamp_bytes = timestamp_ms.to_bytes(6, "big") + + # HMAC the timestamp with server secret + secret = _get_server_secret() + sig = hmac.new(secret, timestamp_bytes, "sha256").digest() + + # Combine: 6 bytes timestamp + 2 bytes HMAC signature + combined = timestamp_bytes + sig[:2] + + # Encode as base32-crockford + return _base32_crockford_encode(combined) + + +def normalize_name(name: str) -> str: + """ + Normalize a display name to a base name. + + - Lowercase + - Replace spaces and underscores with dashes + - Remove special characters (keep alphanumeric and dashes) + - Collapse multiple dashes + - Strip leading/trailing dashes + + Examples: + "My Cool Effect" -> "my-cool-effect" + "Brightness_V2" -> "brightness-v2" + "Test!!!Effect" -> "test-effect" + """ + # Lowercase + name = name.lower() + + # Replace spaces and underscores with dashes + name = re.sub(r"[\s_]+", "-", name) + + # Remove anything that's not alphanumeric or dash + name = re.sub(r"[^a-z0-9-]", "", name) + + # Collapse multiple dashes + name = re.sub(r"-+", "-", name) + + # Strip leading/trailing dashes + name = name.strip("-") + + return name or "unnamed" + + +def parse_friendly_name(friendly_name: str) -> Tuple[str, Optional[str]]: + """ + Parse a friendly name into base name and optional version. + + Args: + friendly_name: Name like "my-effect" or "my-effect 01hw3x9k" + + Returns: + Tuple of (base_name, version_id or None) + """ + parts = friendly_name.strip().split(" ", 1) + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + return base_name, version_id + + +def format_friendly_name(base_name: str, version_id: str) -> str: + """Format a base name and version into a full friendly name.""" + return f"{base_name} {version_id}" + + +def format_l2_name(actor_id: str, base_name: str, version_id: str) -> str: + """ + Format a friendly name for L2 sharing. + + Format: @user@domain base-name version-id + """ + return f"{actor_id} {base_name} {version_id}" + + +class NamingService: + """Service for managing friendly names.""" + + async def assign_name( + self, + cid: str, + actor_id: str, + item_type: str, + display_name: Optional[str] = None, + filename: Optional[str] = None, + ) -> dict: + """ + Assign a friendly name to content. + + Args: + cid: Content ID + actor_id: User ID + item_type: Type (recipe, effect, media) + display_name: Human-readable name (optional) + filename: Original filename (used as fallback for media) + + Returns: + Friendly name entry dict + """ + # Determine display name + if not display_name: + if filename: + # Use filename without extension + display_name = os.path.splitext(filename)[0] + else: + display_name = f"unnamed-{item_type}" + + # Normalize to base name + base_name = normalize_name(display_name) + + # Generate version ID + version_id = generate_version_id() + + # Create database entry + entry = await database.create_friendly_name( + actor_id=actor_id, + base_name=base_name, + version_id=version_id, + cid=cid, + item_type=item_type, + display_name=display_name, + ) + + return entry + + async def get_by_cid(self, actor_id: str, cid: str) -> Optional[dict]: + """Get friendly name entry by CID.""" + return await database.get_friendly_name_by_cid(actor_id, cid) + + async def resolve( + self, + actor_id: str, + name: str, + item_type: Optional[str] = None, + ) -> Optional[str]: + """ + Resolve a friendly name to a CID. + + Args: + actor_id: User ID + name: Friendly name ("base-name" or "base-name version") + item_type: Optional type filter + + Returns: + CID or None if not found + """ + return await database.resolve_friendly_name(actor_id, name, item_type) + + async def list_names( + self, + actor_id: str, + item_type: Optional[str] = None, + latest_only: bool = False, + ) -> list: + """List friendly names for a user.""" + return await database.list_friendly_names( + actor_id=actor_id, + item_type=item_type, + latest_only=latest_only, + ) + + async def delete(self, actor_id: str, cid: str) -> bool: + """Delete a friendly name entry.""" + return await database.delete_friendly_name(actor_id, cid) + + +# Module-level instance +_naming_service: Optional[NamingService] = None + + +def get_naming_service() -> NamingService: + """Get the naming service singleton.""" + global _naming_service + if _naming_service is None: + _naming_service = NamingService() + return _naming_service diff --git a/l1/app/services/recipe_service.py b/l1/app/services/recipe_service.py new file mode 100644 index 0000000..6b0a70d --- /dev/null +++ b/l1/app/services/recipe_service.py @@ -0,0 +1,337 @@ +""" +Recipe Service - business logic for recipe management. + +Recipes are S-expressions stored in the content-addressed cache (and IPFS). +The recipe ID is the content hash of the file. +""" + +import tempfile +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +from artdag.sexp import compile_string, parse, serialize, CompileError, ParseError + +if TYPE_CHECKING: + import redis + from cache_manager import L1CacheManager + +from ..types import Recipe, CompiledDAG, VisualizationDAG, VisNode, VisEdge + + +class RecipeService: + """ + Service for managing recipes. + + Recipes are S-expressions stored in the content-addressed cache. + """ + + def __init__(self, redis: "redis.Redis", cache: "L1CacheManager") -> None: + # Redis kept for compatibility but not used for recipe storage + self.redis = redis + self.cache = cache + + async def get_recipe(self, recipe_id: str) -> Optional[Recipe]: + """Get a recipe by ID (content hash).""" + import yaml + import logging + logger = logging.getLogger(__name__) + + # Get from cache (content-addressed storage) + logger.info(f"get_recipe: Looking up recipe_id={recipe_id[:16]}...") + path = self.cache.get_by_cid(recipe_id) + logger.info(f"get_recipe: cache.get_by_cid returned path={path}") + if not path or not path.exists(): + logger.warning(f"get_recipe: Recipe {recipe_id[:16]}... not found in cache") + return None + + with open(path) as f: + content = f.read() + + # Detect format - check if it starts with ( after skipping comments + def is_sexp_format(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(') + return False + + import logging + logger = logging.getLogger(__name__) + + if is_sexp_format(content): + # Detect if this is a streaming recipe (starts with (stream ...)) + def is_streaming_recipe(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(stream') + return False + + if is_streaming_recipe(content): + # Streaming recipes have different format - parse manually + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', content) + recipe_name = name_match.group(1) if name_match else "streaming" + + recipe_data = { + "name": recipe_name, + "sexp": content, + "format": "sexp", + "type": "streaming", + "dag": {"nodes": []}, # Streaming recipes don't have traditional DAG + } + logger.info(f"Parsed streaming recipe {recipe_id[:16]}..., name: {recipe_name}") + else: + # Parse traditional (recipe ...) S-expression + try: + compiled = compile_string(content) + recipe_data = compiled.to_dict() + recipe_data["sexp"] = content + recipe_data["format"] = "sexp" + logger.info(f"Parsed sexp recipe {recipe_id[:16]}..., keys: {list(recipe_data.keys())}") + except (ParseError, CompileError) as e: + logger.warning(f"Failed to parse sexp recipe {recipe_id[:16]}...: {e}") + return {"error": str(e), "recipe_id": recipe_id} + else: + # Parse YAML + try: + recipe_data = yaml.safe_load(content) + if not isinstance(recipe_data, dict): + return {"error": "Invalid YAML: expected dictionary", "recipe_id": recipe_id} + recipe_data["yaml"] = content + recipe_data["format"] = "yaml" + except yaml.YAMLError as e: + return {"error": f"YAML parse error: {e}", "recipe_id": recipe_id} + + # Add the recipe_id to the data for convenience + recipe_data["recipe_id"] = recipe_id + + # Get IPFS CID if available + ipfs_cid = self.cache.get_ipfs_cid(recipe_id) + if ipfs_cid: + recipe_data["ipfs_cid"] = ipfs_cid + + # Compute step_count from nodes (handle both formats) + if recipe_data.get("format") == "sexp": + nodes = recipe_data.get("dag", {}).get("nodes", []) + else: + # YAML format: nodes might be at top level or under dag + nodes = recipe_data.get("nodes", recipe_data.get("dag", {}).get("nodes", [])) + recipe_data["step_count"] = len(nodes) if isinstance(nodes, (list, dict)) else 0 + + return recipe_data + + async def list_recipes(self, actor_id: Optional[str] = None, offset: int = 0, limit: int = 20) -> List[Recipe]: + """ + List recipes owned by a user. + + Queries item_types table for user's recipe links. + """ + import logging + import database + logger = logging.getLogger(__name__) + + recipes = [] + + if not actor_id: + logger.warning("list_recipes called without actor_id") + return [] + + # Get user's recipe CIDs from item_types + user_items = await database.get_user_items(actor_id, item_type="recipe", limit=1000) + recipe_cids = [item["cid"] for item in user_items] + logger.info(f"Found {len(recipe_cids)} recipe CIDs for user {actor_id}") + + for cid in recipe_cids: + recipe = await self.get_recipe(cid) + if recipe and not recipe.get("error"): + recipes.append(recipe) + elif recipe and recipe.get("error"): + logger.warning(f"Recipe {cid[:16]}... has error: {recipe.get('error')}") + + # Add friendly names + from .naming_service import get_naming_service + naming = get_naming_service() + for recipe in recipes: + recipe_id = recipe.get("recipe_id") + if recipe_id: + friendly = await naming.get_by_cid(actor_id, recipe_id) + if friendly: + recipe["friendly_name"] = friendly["friendly_name"] + recipe["base_name"] = friendly["base_name"] + + # Sort by name + recipes.sort(key=lambda r: r.get("name", "")) + + return recipes[offset:offset + limit] + + async def upload_recipe( + self, + content: str, + uploader: str, + name: str = None, + description: str = None, + ) -> Tuple[Optional[str], Optional[str]]: + """ + Upload a recipe from S-expression content. + + The recipe is stored in the cache and pinned to IPFS. + Returns (recipe_id, error_message). + """ + # Validate S-expression + try: + compiled = compile_string(content) + except ParseError as e: + return None, f"Parse error: {e}" + except CompileError as e: + return None, f"Compile error: {e}" + + # Write to temp file for caching + import logging + logger = logging.getLogger(__name__) + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".sexp", mode="w") as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + + # Store in cache (content-addressed, auto-pins to IPFS) + logger.info(f"upload_recipe: Storing recipe in cache from {tmp_path}") + cached, ipfs_cid = self.cache.put(tmp_path, node_type="recipe", move=True) + recipe_id = ipfs_cid or cached.cid # Prefer IPFS CID + logger.info(f"upload_recipe: Stored recipe, cached.cid={cached.cid[:16]}..., ipfs_cid={ipfs_cid[:16] if ipfs_cid else None}, recipe_id={recipe_id[:16]}...") + + # Track ownership in item_types and assign friendly name + if uploader: + import database + display_name = name or compiled.name or "unnamed-recipe" + + # Create item_types entry (ownership link) + await database.save_item_metadata( + cid=recipe_id, + actor_id=uploader, + item_type="recipe", + description=description, + filename=f"{display_name}.sexp", + ) + + # Assign friendly name + from .naming_service import get_naming_service + naming = get_naming_service() + await naming.assign_name( + cid=recipe_id, + actor_id=uploader, + item_type="recipe", + display_name=display_name, + ) + + return recipe_id, None + + except Exception as e: + return None, f"Failed to cache recipe: {e}" + + async def delete_recipe(self, recipe_id: str, actor_id: str = None) -> Tuple[bool, Optional[str]]: + """ + Remove user's ownership link to a recipe. + + This removes the item_types entry linking the user to the recipe. + The cached file is only deleted if no other users own it. + Returns (success, error_message). + """ + import database + + if not actor_id: + return False, "actor_id required" + + # Remove user's ownership link + try: + await database.delete_item_type(recipe_id, actor_id, "recipe") + + # Also remove friendly name + await database.delete_friendly_name(actor_id, recipe_id) + + # Try to garbage collect if no one owns it anymore + # (delete_cache_item only deletes if no item_types remain) + await database.delete_cache_item(recipe_id) + + return True, None + except Exception as e: + return False, f"Failed to delete: {e}" + + def parse_recipe(self, content: str) -> CompiledDAG: + """Parse recipe S-expression content.""" + compiled = compile_string(content) + return compiled.to_dict() + + def build_dag(self, recipe: Recipe) -> VisualizationDAG: + """ + Build DAG visualization data from recipe. + + Returns nodes and edges for Cytoscape.js. + """ + vis_nodes: List[VisNode] = [] + edges: List[VisEdge] = [] + + dag = recipe.get("dag", {}) + dag_nodes = dag.get("nodes", []) + output_node = dag.get("output") + + # Handle list format (compiled S-expression) + if isinstance(dag_nodes, list): + for node_def in dag_nodes: + node_id = node_def.get("id") + node_type = node_def.get("type", "EFFECT") + + vis_nodes.append({ + "data": { + "id": node_id, + "label": node_id, + "nodeType": node_type, + "isOutput": node_id == output_node, + } + }) + + for input_ref in node_def.get("inputs", []): + if isinstance(input_ref, dict): + source = input_ref.get("node") or input_ref.get("input") + else: + source = input_ref + + if source: + edges.append({ + "data": { + "source": source, + "target": node_id, + } + }) + + # Handle dict format + elif isinstance(dag_nodes, dict): + for node_id, node_def in dag_nodes.items(): + node_type = node_def.get("type", "EFFECT") + + vis_nodes.append({ + "data": { + "id": node_id, + "label": node_id, + "nodeType": node_type, + "isOutput": node_id == output_node, + } + }) + + for input_ref in node_def.get("inputs", []): + if isinstance(input_ref, dict): + source = input_ref.get("node") or input_ref.get("input") + else: + source = input_ref + + if source: + edges.append({ + "data": { + "source": source, + "target": node_id, + } + }) + + return {"nodes": vis_nodes, "edges": edges} diff --git a/l1/app/services/run_service.py b/l1/app/services/run_service.py new file mode 100644 index 0000000..5bfe19d --- /dev/null +++ b/l1/app/services/run_service.py @@ -0,0 +1,1001 @@ +""" +Run Service - business logic for run management. + +Runs are content-addressed (run_id computed from inputs + recipe). +Completed runs are stored in PostgreSQL, not Redis. +In-progress runs are tracked via Celery task state. +""" + +import hashlib +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, Union, TYPE_CHECKING + +if TYPE_CHECKING: + import redis + from cache_manager import L1CacheManager + from database import Database + +from ..types import RunResult + + +def compute_run_id(input_hashes: Union[List[str], Dict[str, str]], recipe: str, recipe_hash: Optional[str] = None) -> str: + """ + Compute a deterministic run_id from inputs and recipe. + + The run_id is a SHA3-256 hash of: + - Sorted input content hashes + - Recipe identifier (recipe_hash if provided, else "effect:{recipe}") + + This makes runs content-addressable: same inputs + recipe = same run_id. + """ + # Handle both list and dict inputs + if isinstance(input_hashes, dict): + sorted_inputs = sorted(input_hashes.values()) + else: + sorted_inputs = sorted(input_hashes) + + data = { + "inputs": sorted_inputs, + "recipe": recipe_hash or f"effect:{recipe}", + "version": "1", + } + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha3_256(json_str.encode()).hexdigest() + + +def detect_media_type(cache_path: Path) -> str: + """Detect if file is image, video, or audio based on magic bytes.""" + try: + with open(cache_path, "rb") as f: + header = f.read(32) + except Exception: + return "unknown" + + # Video signatures + if header[:4] == b'\x1a\x45\xdf\xa3': # WebM/MKV + return "video" + if len(header) > 8 and header[4:8] == b'ftyp': # MP4/MOV/M4A + # Check for audio-only M4A + if len(header) > 11 and header[8:12] in (b'M4A ', b'm4a '): + return "audio" + return "video" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'AVI ': # AVI + return "video" + + # Image signatures + if header[:8] == b'\x89PNG\r\n\x1a\n': # PNG + return "image" + if header[:2] == b'\xff\xd8': # JPEG + return "image" + if header[:6] in (b'GIF87a', b'GIF89a'): # GIF + return "image" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WEBP': # WebP + return "image" + + # Audio signatures + if header[:3] == b'ID3' or header[:2] == b'\xff\xfb': # MP3 + return "audio" + if header[:4] == b'fLaC': # FLAC + return "audio" + if header[:4] == b'OggS': # Ogg (could be audio or video, assume audio) + return "audio" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WAVE': # WAV + return "audio" + + return "unknown" + + +class RunService: + """ + Service for managing recipe runs. + + Uses PostgreSQL for completed runs, Celery for task state. + Redis is only used for task_id mapping (ephemeral). + """ + + def __init__(self, database: "Database", redis: "redis.Redis[bytes]", cache: "L1CacheManager") -> None: + self.db = database + self.redis = redis # Only for task_id mapping + self.cache = cache + self.task_key_prefix = "artdag:task:" # run_id -> task_id mapping only + self.cache_dir = Path(os.environ.get("CACHE_DIR", "/tmp/artdag-cache")) + + def _ensure_inputs_list(self, inputs: Any) -> List[str]: + """Ensure inputs is a list, parsing JSON string if needed.""" + if inputs is None: + return [] + if isinstance(inputs, list): + return inputs + if isinstance(inputs, str): + try: + parsed = json.loads(inputs) + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + pass + return [] + return [] + + async def get_run(self, run_id: str) -> Optional[RunResult]: + """Get a run by ID. Checks database first, then Celery task state.""" + # Check database for completed run + cached = await self.db.get_run_cache(run_id) + if cached: + output_cid = cached.get("output_cid") + # Only return as completed if we have an output + # (runs with no output should be re-executed) + if output_cid: + # Also fetch recipe content from pending_runs for streaming runs + recipe_sexp = None + recipe_name = None + pending = await self.db.get_pending_run(run_id) + if pending: + recipe_sexp = pending.get("dag_json") + + # Extract recipe name from streaming recipe content + if recipe_sexp: + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', recipe_sexp) + if name_match: + recipe_name = name_match.group(1) + + return { + "run_id": run_id, + "status": "completed", + "recipe": cached.get("recipe"), + "recipe_name": recipe_name, + "inputs": self._ensure_inputs_list(cached.get("inputs")), + "output_cid": output_cid, + "ipfs_cid": cached.get("ipfs_cid"), + "ipfs_playlist_cid": cached.get("ipfs_playlist_cid") or (pending.get("ipfs_playlist_cid") if pending else None), + "provenance_cid": cached.get("provenance_cid"), + "plan_cid": cached.get("plan_cid"), + "actor_id": cached.get("actor_id"), + "created_at": cached.get("created_at"), + "completed_at": cached.get("created_at"), + "recipe_sexp": recipe_sexp, + } + + # Check database for pending run + pending = await self.db.get_pending_run(run_id) + if pending: + task_id = pending.get("celery_task_id") + if task_id: + # Check actual Celery task state + from celery.result import AsyncResult + from celery_app import app as celery_app + + result = AsyncResult(task_id, app=celery_app) + status = result.status.lower() + + # Normalize status + status_map = { + "pending": "pending", + "started": "running", + "rendering": "running", # Custom status from streaming task + "success": "completed", + "failure": "failed", + "retry": "running", + "revoked": "failed", + } + normalized_status = status_map.get(status, status) + + run_data = { + "run_id": run_id, + "status": normalized_status, + "celery_task_id": task_id, + "actor_id": pending.get("actor_id"), + "recipe": pending.get("recipe"), + "inputs": self._ensure_inputs_list(pending.get("inputs")), + "output_name": pending.get("output_name"), + "created_at": pending.get("created_at"), + "error": pending.get("error"), + "recipe_sexp": pending.get("dag_json"), # Recipe content for streaming runs + # Checkpoint fields for resumable renders + "checkpoint_frame": pending.get("checkpoint_frame"), + "checkpoint_t": pending.get("checkpoint_t"), + "total_frames": pending.get("total_frames"), + "resumable": pending.get("resumable", True), + # IPFS streaming info + "ipfs_playlist_cid": pending.get("ipfs_playlist_cid"), + "quality_playlists": pending.get("quality_playlists"), + } + + # If task completed, get result + if result.ready(): + if result.successful(): + task_result = result.result + if isinstance(task_result, dict): + # Check task's own success flag and output_cid + task_success = task_result.get("success", True) + output_cid = task_result.get("output_cid") + if task_success and output_cid: + run_data["status"] = "completed" + run_data["output_cid"] = output_cid + else: + run_data["status"] = "failed" + run_data["error"] = task_result.get("error", "No output produced") + else: + run_data["status"] = "completed" + else: + run_data["status"] = "failed" + run_data["error"] = str(result.result) + + return run_data + + # No task_id but have pending record - return from DB + return { + "run_id": run_id, + "status": pending.get("status", "pending"), + "recipe": pending.get("recipe"), + "inputs": self._ensure_inputs_list(pending.get("inputs")), + "output_name": pending.get("output_name"), + "actor_id": pending.get("actor_id"), + "created_at": pending.get("created_at"), + "error": pending.get("error"), + "recipe_sexp": pending.get("dag_json"), # Recipe content for streaming runs + # Checkpoint fields for resumable renders + "checkpoint_frame": pending.get("checkpoint_frame"), + "checkpoint_t": pending.get("checkpoint_t"), + "total_frames": pending.get("total_frames"), + "resumable": pending.get("resumable", True), + # IPFS streaming info + "ipfs_playlist_cid": pending.get("ipfs_playlist_cid"), + "quality_playlists": pending.get("quality_playlists"), + } + + # Fallback: Check Redis for backwards compatibility + task_data = self.redis.get(f"{self.task_key_prefix}{run_id}") + if task_data: + if isinstance(task_data, bytes): + task_data = task_data.decode() + + # Parse task data (supports both old format string and new JSON format) + try: + parsed = json.loads(task_data) + task_id = parsed.get("task_id") + task_actor_id = parsed.get("actor_id") + task_recipe = parsed.get("recipe") + task_recipe_name = parsed.get("recipe_name") + task_inputs = parsed.get("inputs") + # Ensure inputs is a list (might be JSON string) + if isinstance(task_inputs, str): + try: + task_inputs = json.loads(task_inputs) + except json.JSONDecodeError: + task_inputs = None + task_output_name = parsed.get("output_name") + task_created_at = parsed.get("created_at") + except json.JSONDecodeError: + # Old format: just the task_id string + task_id = task_data + task_actor_id = None + task_recipe = None + task_recipe_name = None + task_inputs = None + task_output_name = None + task_created_at = None + + # Get task state from Celery + from celery.result import AsyncResult + from celery_app import app as celery_app + + result = AsyncResult(task_id, app=celery_app) + status = result.status.lower() + + # Normalize Celery status names + status_map = { + "pending": "pending", + "started": "running", + "rendering": "running", # Custom status from streaming task + "success": "completed", + "failure": "failed", + "retry": "running", + "revoked": "failed", + } + normalized_status = status_map.get(status, status) + + run_data = { + "run_id": run_id, + "status": normalized_status, + "celery_task_id": task_id, + "actor_id": task_actor_id, + "recipe": task_recipe, + "recipe_name": task_recipe_name, + "inputs": self._ensure_inputs_list(task_inputs), + "output_name": task_output_name, + "created_at": task_created_at, + } + + # If task completed, get result + if result.ready(): + if result.successful(): + task_result = result.result + if isinstance(task_result, dict): + # Check task's own success flag and output_cid + task_success = task_result.get("success", True) + output_cid = task_result.get("output_cid") + if task_success and output_cid: + run_data["status"] = "completed" + run_data["output_cid"] = output_cid + else: + run_data["status"] = "failed" + run_data["error"] = task_result.get("error", "No output produced") + else: + run_data["status"] = "completed" + else: + run_data["status"] = "failed" + run_data["error"] = str(result.result) + + return run_data + + return None + + async def list_runs(self, actor_id: str, offset: int = 0, limit: int = 20) -> List[RunResult]: + """List runs for a user. Returns completed and pending runs from database.""" + # Get completed runs from database + completed_runs = await self.db.list_runs_by_actor(actor_id, offset=0, limit=limit + 50) + + # Get pending runs from database + pending_db = await self.db.list_pending_runs(actor_id=actor_id) + + # Convert pending runs to run format with live status check + pending = [] + for pr in pending_db: + run_id = pr.get("run_id") + # Skip if already in completed + if any(r.get("run_id") == run_id for r in completed_runs): + continue + + # Get live status - include pending, running, rendering, and failed runs + run = await self.get_run(run_id) + if run and run.get("status") in ("pending", "running", "rendering", "failed"): + pending.append(run) + + # Combine and sort + all_runs = pending + completed_runs + all_runs.sort(key=lambda r: r.get("created_at", ""), reverse=True) + + return all_runs[offset:offset + limit] + + async def create_run( + self, + recipe: str, + inputs: Union[List[str], Dict[str, str]], + output_name: Optional[str] = None, + use_dag: bool = True, + dag_json: Optional[str] = None, + actor_id: Optional[str] = None, + l2_server: Optional[str] = None, + recipe_name: Optional[str] = None, + recipe_sexp: Optional[str] = None, + ) -> Tuple[Optional[RunResult], Optional[str]]: + """ + Create a new rendering run. Checks cache before executing. + + If recipe_sexp is provided, uses the new S-expression execution path + which generates code-addressed cache IDs before execution. + + Returns (run_dict, error_message). + """ + import httpx + try: + from legacy_tasks import render_effect, execute_dag, build_effect_dag, execute_recipe + except ImportError as e: + return None, f"Celery tasks not available: {e}" + + # Handle both list and dict inputs + if isinstance(inputs, dict): + input_list = list(inputs.values()) + else: + input_list = inputs + + # Compute content-addressable run_id + run_id = compute_run_id(input_list, recipe) + + # Generate output name if not provided + if not output_name: + output_name = f"{recipe}-{run_id[:8]}" + + # Check database cache first (completed runs) + cached_run = await self.db.get_run_cache(run_id) + if cached_run: + output_cid = cached_run.get("output_cid") + if output_cid and self.cache.has_content(output_cid): + return { + "run_id": run_id, + "status": "completed", + "recipe": recipe, + "inputs": input_list, + "output_name": output_name, + "output_cid": output_cid, + "ipfs_cid": cached_run.get("ipfs_cid"), + "provenance_cid": cached_run.get("provenance_cid"), + "created_at": cached_run.get("created_at"), + "completed_at": cached_run.get("created_at"), + "actor_id": actor_id, + }, None + + # Check L2 if not in local cache + if l2_server: + try: + async with httpx.AsyncClient(timeout=10) as client: + l2_resp = await client.get(f"{l2_server}/assets/by-run-id/{run_id}") + if l2_resp.status_code == 200: + l2_data = l2_resp.json() + output_cid = l2_data.get("output_cid") + ipfs_cid = l2_data.get("ipfs_cid") + if output_cid and ipfs_cid: + # Pull from IPFS to local cache + try: + import ipfs_client + legacy_dir = self.cache_dir / "legacy" + legacy_dir.mkdir(parents=True, exist_ok=True) + recovery_path = legacy_dir / output_cid + if ipfs_client.get_file(ipfs_cid, str(recovery_path)): + # Save to database cache + await self.db.save_run_cache( + run_id=run_id, + output_cid=output_cid, + recipe=recipe, + inputs=input_list, + ipfs_cid=ipfs_cid, + provenance_cid=l2_data.get("provenance_cid"), + actor_id=actor_id, + ) + return { + "run_id": run_id, + "status": "completed", + "recipe": recipe, + "inputs": input_list, + "output_cid": output_cid, + "ipfs_cid": ipfs_cid, + "provenance_cid": l2_data.get("provenance_cid"), + "created_at": datetime.now(timezone.utc).isoformat(), + "actor_id": actor_id, + }, None + except Exception: + pass # IPFS recovery failed, continue to run + except Exception: + pass # L2 lookup failed, continue to run + + # Not cached - submit to Celery + try: + # Prefer S-expression execution path (code-addressed cache IDs) + if recipe_sexp: + # Convert inputs to dict if needed + if isinstance(inputs, dict): + input_hashes = inputs + else: + # Legacy list format - use positional names + input_hashes = {f"input_{i}": cid for i, cid in enumerate(input_list)} + + task = execute_recipe.delay(recipe_sexp, input_hashes, run_id) + elif use_dag or recipe == "dag": + if dag_json: + dag_data = dag_json + else: + dag = build_effect_dag(input_list, recipe) + dag_data = dag.to_json() + + task = execute_dag.delay(dag_data, run_id) + else: + if len(input_list) != 1: + return None, "Legacy mode only supports single-input recipes. Use use_dag=true for multi-input." + task = render_effect.delay(input_list[0], recipe, output_name) + + # Store pending run in database for durability + try: + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe, + inputs=input_list, + actor_id=actor_id, + dag_json=dag_json, + output_name=output_name, + ) + except Exception as e: + import logging + logging.getLogger(__name__).error(f"Failed to save pending run: {e}") + + # Also store in Redis for backwards compatibility (shorter TTL) + task_data = json.dumps({ + "task_id": task.id, + "actor_id": actor_id, + "recipe": recipe, + "recipe_name": recipe_name, + "inputs": input_list, + "output_name": output_name, + "created_at": datetime.now(timezone.utc).isoformat(), + }) + self.redis.setex( + f"{self.task_key_prefix}{run_id}", + 3600 * 4, # 4 hour TTL (database is primary now) + task_data + ) + + return { + "run_id": run_id, + "status": "running", + "recipe": recipe, + "recipe_name": recipe_name, + "inputs": input_list, + "output_name": output_name, + "celery_task_id": task.id, + "created_at": datetime.now(timezone.utc).isoformat(), + "actor_id": actor_id, + }, None + + except Exception as e: + return None, f"Failed to submit task: {e}" + + async def discard_run( + self, + run_id: str, + actor_id: str, + username: str, + ) -> Tuple[bool, Optional[str]]: + """ + Discard (delete) a run record and clean up outputs/intermediates. + + Outputs and intermediates are only deleted if not used by other runs. + """ + import logging + logger = logging.getLogger(__name__) + + run = await self.get_run(run_id) + if not run: + return False, f"Run {run_id} not found" + + # Check ownership + run_owner = run.get("actor_id") + if run_owner and run_owner not in (username, actor_id): + return False, "Access denied" + + # Clean up activity outputs/intermediates (only if orphaned) + # The activity_id is the same as run_id + try: + success, msg = self.cache.discard_activity_outputs_only(run_id) + if success: + logger.info(f"Cleaned up run {run_id}: {msg}") + else: + # Activity might not exist (old runs), that's OK + logger.debug(f"No activity cleanup for {run_id}: {msg}") + except Exception as e: + logger.warning(f"Failed to cleanup activity for {run_id}: {e}") + + # Remove task_id mapping from Redis + self.redis.delete(f"{self.task_key_prefix}{run_id}") + + # Remove from run_cache database table + try: + await self.db.delete_run_cache(run_id) + except Exception as e: + logger.warning(f"Failed to delete run_cache for {run_id}: {e}") + + # Remove pending run if exists + try: + await self.db.delete_pending_run(run_id) + except Exception: + pass + + return True, None + + def _dag_to_steps(self, dag: Dict[str, Any]) -> Dict[str, Any]: + """Convert DAG nodes dict format to steps list format. + + DAG format: {"nodes": {"id": {...}}, "output_id": "..."} + Steps format: {"steps": [{"id": "...", "type": "...", ...}], "output_id": "..."} + """ + if "steps" in dag: + # Already in steps format + return dag + + if "nodes" not in dag: + return dag + + nodes = dag.get("nodes", {}) + steps = [] + + # Sort by topological order (sources first, then by input dependencies) + def get_level(node_id: str, visited: set = None) -> int: + if visited is None: + visited = set() + if node_id in visited: + return 0 + visited.add(node_id) + node = nodes.get(node_id, {}) + inputs = node.get("inputs", []) + if not inputs: + return 0 + return 1 + max(get_level(inp, visited) for inp in inputs) + + sorted_ids = sorted(nodes.keys(), key=lambda nid: (get_level(nid), nid)) + + for node_id in sorted_ids: + node = nodes[node_id] + steps.append({ + "id": node_id, + "step_id": node_id, + "type": node.get("node_type", "EFFECT"), + "config": node.get("config", {}), + "inputs": node.get("inputs", []), + "name": node.get("name"), + "cache_id": node_id, # In code-addressed system, node_id IS the cache_id + }) + + return { + "steps": steps, + "output_id": dag.get("output_id"), + "metadata": dag.get("metadata", {}), + "format": "json", + } + + def _sexp_to_steps(self, sexp_content: str) -> Dict[str, Any]: + """Convert S-expression plan to steps list format for UI. + + Parses the S-expression plan format: + (plan :id :recipe :recipe-hash + (inputs (input_name hash) ...) + (step step_id :cache-id :level (node-type :key val ...)) + ... + :output ) + + Returns steps list compatible with UI visualization. + """ + try: + from artdag.sexp import parse, Symbol, Keyword + except ImportError: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + try: + parsed = parse(sexp_content) + except Exception: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + if not isinstance(parsed, list) or not parsed: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + steps = [] + output_step_id = None + plan_id = None + recipe_name = None + + # Parse plan structure + i = 0 + while i < len(parsed): + item = parsed[i] + + if isinstance(item, Keyword): + key = item.name + if i + 1 < len(parsed): + value = parsed[i + 1] + if key == "id": + plan_id = value + elif key == "recipe": + recipe_name = value + elif key == "output": + output_step_id = value + i += 2 + continue + + if isinstance(item, list) and item: + first = item[0] + if isinstance(first, Symbol) and first.name == "step": + # Parse step: (step step_id :cache-id :level (node-expr)) + step_id = item[1] if len(item) > 1 else None + cache_id = None + level = 0 + node_type = "EFFECT" + config = {} + inputs = [] + + j = 2 + while j < len(item): + part = item[j] + if isinstance(part, Keyword): + key = part.name + if j + 1 < len(item): + val = item[j + 1] + if key == "cache-id": + cache_id = val + elif key == "level": + level = val + j += 2 + continue + elif isinstance(part, list) and part: + # Node expression: (node-type :key val ...) + if isinstance(part[0], Symbol): + node_type = part[0].name.upper() + k = 1 + while k < len(part): + if isinstance(part[k], Keyword): + kname = part[k].name + if k + 1 < len(part): + kval = part[k + 1] + if kname == "inputs": + inputs = kval if isinstance(kval, list) else [kval] + else: + config[kname] = kval + k += 2 + continue + k += 1 + j += 1 + + steps.append({ + "id": step_id, + "step_id": step_id, + "type": node_type, + "config": config, + "inputs": inputs, + "cache_id": cache_id or step_id, + "level": level, + }) + + i += 1 + + return { + "sexp": sexp_content, + "steps": steps, + "output_id": output_step_id, + "plan_id": plan_id, + "recipe": recipe_name, + "format": "sexp", + } + + async def get_run_plan(self, run_id: str) -> Optional[Dict[str, Any]]: + """Get execution plan for a run. + + Plans are just node outputs - cached by content hash like everything else. + For streaming runs, returns the recipe content as the plan. + """ + # Get run to find plan_cache_id + run = await self.get_run(run_id) + if not run: + return None + + # For streaming runs, return the recipe as the plan + if run.get("recipe") == "streaming" and run.get("recipe_sexp"): + return { + "steps": [{"id": "stream", "type": "STREAM", "name": "Streaming Recipe"}], + "sexp": run.get("recipe_sexp"), + "format": "sexp", + } + + # Check plan_cid (stored in database) or plan_cache_id (legacy) + plan_cid = run.get("plan_cid") or run.get("plan_cache_id") + if plan_cid: + # Get plan from cache by content hash + plan_path = self.cache.get_by_cid(plan_cid) + if plan_path and plan_path.exists(): + with open(plan_path) as f: + content = f.read() + # Detect format + if content.strip().startswith("("): + # S-expression format - parse for UI + return self._sexp_to_steps(content) + else: + plan = json.loads(content) + return self._dag_to_steps(plan) + + # Fall back to legacy plans directory + sexp_path = self.cache_dir / "plans" / f"{run_id}.sexp" + if sexp_path.exists(): + with open(sexp_path) as f: + return self._sexp_to_steps(f.read()) + + json_path = self.cache_dir / "plans" / f"{run_id}.json" + if json_path.exists(): + with open(json_path) as f: + plan = json.load(f) + return self._dag_to_steps(plan) + + return None + + async def get_run_plan_sexp(self, run_id: str) -> Optional[str]: + """Get execution plan as S-expression string.""" + plan = await self.get_run_plan(run_id) + if plan and plan.get("format") == "sexp": + return plan.get("sexp") + return None + + async def get_run_artifacts(self, run_id: str) -> List[Dict[str, Any]]: + """Get all artifacts (inputs + outputs) for a run.""" + run = await self.get_run(run_id) + if not run: + return [] + + artifacts = [] + + def get_artifact_info(cid: str, role: str, name: str) -> Optional[Dict]: + if self.cache.has_content(cid): + path = self.cache.get_by_cid(cid) + if path and path.exists(): + return { + "cid": cid, + "size_bytes": path.stat().st_size, + "media_type": detect_media_type(path), + "role": role, + "step_name": name, + } + return None + + # Add inputs + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + for i, h in enumerate(inputs): + info = get_artifact_info(h, "input", f"Input {i + 1}") + if info: + artifacts.append(info) + + # Add output + if run.get("output_cid"): + info = get_artifact_info(run["output_cid"], "output", "Output") + if info: + artifacts.append(info) + + return artifacts + + async def get_run_analysis(self, run_id: str) -> List[Dict[str, Any]]: + """Get analysis data for each input in a run.""" + run = await self.get_run(run_id) + if not run: + return [] + + analysis_dir = self.cache_dir / "analysis" + results = [] + + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + + for i, input_hash in enumerate(inputs): + analysis_path = analysis_dir / f"{input_hash}.json" + analysis_data = None + + if analysis_path.exists(): + try: + with open(analysis_path) as f: + analysis_data = json.load(f) + except (json.JSONDecodeError, IOError): + pass + + results.append({ + "input_hash": input_hash, + "input_name": f"Input {i + 1}", + "has_analysis": analysis_data is not None, + "tempo": analysis_data.get("tempo") if analysis_data else None, + "beat_times": analysis_data.get("beat_times", []) if analysis_data else [], + "raw": analysis_data, + }) + + return results + + def detect_media_type(self, path: Path) -> str: + """Detect media type for a file path.""" + return detect_media_type(path) + + async def recover_pending_runs(self) -> Dict[str, Union[int, str]]: + """ + Recover pending runs after restart. + + Checks all pending runs in the database and: + - Updates status for completed tasks + - Re-queues orphaned tasks that can be retried + - Marks as failed if unrecoverable + + Returns counts of recovered, completed, failed runs. + """ + from celery.result import AsyncResult + from celery_app import app as celery_app + + try: + from legacy_tasks import execute_dag + except ImportError: + return {"error": "Celery tasks not available"} + + stats = {"recovered": 0, "completed": 0, "failed": 0, "still_running": 0} + + # Get all pending/running runs from database + pending_runs = await self.db.list_pending_runs() + + for run in pending_runs: + run_id = run.get("run_id") + task_id = run.get("celery_task_id") + status = run.get("status") + + if not task_id: + # No task ID - try to re-queue if we have dag_json + dag_json = run.get("dag_json") + if dag_json: + try: + new_task = execute_dag.delay(dag_json, run_id) + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=new_task.id, + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + actor_id=run.get("actor_id"), + dag_json=dag_json, + output_name=run.get("output_name"), + ) + stats["recovered"] += 1 + except Exception as e: + await self.db.update_pending_run_status( + run_id, "failed", f"Recovery failed: {e}" + ) + stats["failed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", "No DAG data for recovery" + ) + stats["failed"] += 1 + continue + + # Check Celery task state + result = AsyncResult(task_id, app=celery_app) + celery_status = result.status.lower() + + if result.ready(): + if result.successful(): + # Task completed - move to run_cache + task_result = result.result + if isinstance(task_result, dict) and task_result.get("output_cid"): + await self.db.save_run_cache( + run_id=run_id, + output_cid=task_result["output_cid"], + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + ipfs_cid=task_result.get("ipfs_cid"), + provenance_cid=task_result.get("provenance_cid"), + actor_id=run.get("actor_id"), + ) + await self.db.complete_pending_run(run_id) + stats["completed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", "Task completed but no output hash" + ) + stats["failed"] += 1 + else: + # Task failed + await self.db.update_pending_run_status( + run_id, "failed", str(result.result) + ) + stats["failed"] += 1 + elif celery_status in ("pending", "started", "retry"): + # Still running + stats["still_running"] += 1 + else: + # Unknown state - try to re-queue if we have dag_json + dag_json = run.get("dag_json") + if dag_json: + try: + new_task = execute_dag.delay(dag_json, run_id) + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=new_task.id, + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + actor_id=run.get("actor_id"), + dag_json=dag_json, + output_name=run.get("output_name"), + ) + stats["recovered"] += 1 + except Exception as e: + await self.db.update_pending_run_status( + run_id, "failed", f"Recovery failed: {e}" + ) + stats["failed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", f"Task in unknown state: {celery_status}" + ) + stats["failed"] += 1 + + return stats diff --git a/l1/app/services/storage_service.py b/l1/app/services/storage_service.py new file mode 100644 index 0000000..19d4e3c --- /dev/null +++ b/l1/app/services/storage_service.py @@ -0,0 +1,232 @@ +""" +Storage Service - business logic for storage provider management. +""" + +import json +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from database import Database + from storage_providers import StorageProvidersModule + + +STORAGE_PROVIDERS_INFO = { + "pinata": {"name": "Pinata", "desc": "1GB free, IPFS pinning", "color": "blue"}, + "web3storage": {"name": "web3.storage", "desc": "IPFS + Filecoin", "color": "green"}, + "nftstorage": {"name": "NFT.Storage", "desc": "Free for NFTs", "color": "pink"}, + "infura": {"name": "Infura IPFS", "desc": "5GB free", "color": "orange"}, + "filebase": {"name": "Filebase", "desc": "5GB free, S3+IPFS", "color": "cyan"}, + "storj": {"name": "Storj", "desc": "25GB free", "color": "indigo"}, + "local": {"name": "Local Storage", "desc": "Your own disk", "color": "purple"}, +} + +VALID_PROVIDER_TYPES = list(STORAGE_PROVIDERS_INFO.keys()) + + +class StorageService: + """Service for managing user storage providers.""" + + def __init__(self, database: "Database", storage_providers_module: "StorageProvidersModule") -> None: + self.db = database + self.providers = storage_providers_module + + async def list_storages(self, actor_id: str) -> List[Dict[str, Any]]: + """List all storage providers for a user with usage stats.""" + storages = await self.db.get_user_storage(actor_id) + + for storage in storages: + usage = await self.db.get_storage_usage(storage["id"]) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + + # Mask sensitive config keys for display + if storage.get("config"): + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + masked = {} + for k, v in config.items(): + if "key" in k.lower() or "token" in k.lower() or "secret" in k.lower(): + masked[k] = v[:4] + "..." + v[-4:] if len(str(v)) > 8 else "****" + else: + masked[k] = v + storage["config_display"] = masked + + return storages + + async def get_storage(self, storage_id: int, actor_id: str) -> Optional[Dict[str, Any]]: + """Get a specific storage provider.""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return None + if storage["actor_id"] != actor_id: + return None + + usage = await self.db.get_storage_usage(storage_id) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + + return storage + + async def add_storage( + self, + actor_id: str, + provider_type: str, + config: Dict[str, Any], + capacity_gb: int = 5, + provider_name: Optional[str] = None, + description: Optional[str] = None, + ) -> Tuple[Optional[int], Optional[str]]: + """Add a new storage provider. Returns (storage_id, error_message).""" + if provider_type not in VALID_PROVIDER_TYPES: + return None, f"Invalid provider type: {provider_type}" + + # Test connection before saving + provider = self.providers.create_provider(provider_type, { + **config, + "capacity_gb": capacity_gb + }) + if not provider: + return None, "Failed to create provider with given config" + + success, message = await provider.test_connection() + if not success: + return None, f"Provider connection failed: {message}" + + # Generate name if not provided + if not provider_name: + existing = await self.db.get_user_storage_by_type(actor_id, provider_type) + provider_name = f"{provider_type}-{len(existing) + 1}" + + storage_id = await self.db.add_user_storage( + actor_id=actor_id, + provider_type=provider_type, + provider_name=provider_name, + config=config, + capacity_gb=capacity_gb, + description=description + ) + + if not storage_id: + return None, "Failed to save storage provider" + + return storage_id, None + + async def update_storage( + self, + storage_id: int, + actor_id: str, + config: Optional[Dict[str, Any]] = None, + capacity_gb: Optional[int] = None, + is_active: Optional[bool] = None, + ) -> Tuple[bool, Optional[str]]: + """Update a storage provider. Returns (success, error_message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage provider not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + # Test new config if provided + if config: + existing_config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + new_config = {**existing_config, **config} + provider = self.providers.create_provider(storage["provider_type"], { + **new_config, + "capacity_gb": capacity_gb or storage["capacity_gb"] + }) + if provider: + success, message = await provider.test_connection() + if not success: + return False, f"Provider connection failed: {message}" + + success = await self.db.update_user_storage( + storage_id, + config=config, + capacity_gb=capacity_gb, + is_active=is_active + ) + + return success, None if success else "Failed to update storage provider" + + async def delete_storage(self, storage_id: int, actor_id: str) -> Tuple[bool, Optional[str]]: + """Delete a storage provider. Returns (success, error_message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage provider not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + success = await self.db.remove_user_storage(storage_id) + return success, None if success else "Failed to remove storage provider" + + async def test_storage(self, storage_id: int, actor_id: str) -> Tuple[bool, str]: + """Test storage provider connectivity. Returns (success, message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + provider = self.providers.create_provider(storage["provider_type"], { + **config, + "capacity_gb": storage["capacity_gb"] + }) + + if not provider: + return False, "Failed to create provider" + + return await provider.test_connection() + + async def list_by_type(self, actor_id: str, provider_type: str) -> List[Dict[str, Any]]: + """List storage providers of a specific type.""" + return await self.db.get_user_storage_by_type(actor_id, provider_type) + + def build_config_from_form(self, provider_type: str, form_data: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """Build provider config from form data. Returns (config, error).""" + api_key = form_data.get("api_key") + secret_key = form_data.get("secret_key") + api_token = form_data.get("api_token") + project_id = form_data.get("project_id") + project_secret = form_data.get("project_secret") + access_key = form_data.get("access_key") + bucket = form_data.get("bucket") + path = form_data.get("path") + + if provider_type == "pinata": + if not api_key or not secret_key: + return None, "Pinata requires API Key and Secret Key" + return {"api_key": api_key, "secret_key": secret_key}, None + + elif provider_type == "web3storage": + if not api_token: + return None, "web3.storage requires API Token" + return {"api_token": api_token}, None + + elif provider_type == "nftstorage": + if not api_token: + return None, "NFT.Storage requires API Token" + return {"api_token": api_token}, None + + elif provider_type == "infura": + if not project_id or not project_secret: + return None, "Infura requires Project ID and Project Secret" + return {"project_id": project_id, "project_secret": project_secret}, None + + elif provider_type == "filebase": + if not access_key or not secret_key or not bucket: + return None, "Filebase requires Access Key, Secret Key, and Bucket" + return {"access_key": access_key, "secret_key": secret_key, "bucket": bucket}, None + + elif provider_type == "storj": + if not access_key or not secret_key or not bucket: + return None, "Storj requires Access Key, Secret Key, and Bucket" + return {"access_key": access_key, "secret_key": secret_key, "bucket": bucket}, None + + elif provider_type == "local": + if not path: + return None, "Local storage requires a path" + return {"path": path}, None + + return None, f"Unknown provider type: {provider_type}" diff --git a/l1/app/templates/404.html b/l1/app/templates/404.html new file mode 100644 index 0000000..0cd9c70 --- /dev/null +++ b/l1/app/templates/404.html @@ -0,0 +1,14 @@ +{% extends "base.html" %} + +{% block title %}Not Found - Art-DAG L1{% endblock %} + +{% block content %} +
+

404

+

Page Not Found

+

The page you're looking for doesn't exist or has been moved.

+ + Go Home + +
+{% endblock %} diff --git a/l1/app/templates/base.html b/l1/app/templates/base.html new file mode 100644 index 0000000..9be32fb --- /dev/null +++ b/l1/app/templates/base.html @@ -0,0 +1,46 @@ +{% extends "_base.html" %} + +{% block brand %} +Rose Ash +| +Art-DAG +{% endblock %} + +{% block cart_mini %} +{% if request and request.state.cart_mini_html %} + {{ request.state.cart_mini_html | safe }} +{% endif %} +{% endblock %} + +{% block nav_tree %} +{% if request and request.state.nav_tree_html %} + {{ request.state.nav_tree_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu_mobile %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block sub_nav %} + +{% endblock %} diff --git a/l1/app/templates/cache/detail.html b/l1/app/templates/cache/detail.html new file mode 100644 index 0000000..da30119 --- /dev/null +++ b/l1/app/templates/cache/detail.html @@ -0,0 +1,182 @@ +{% extends "base.html" %} + +{% block title %}{{ cache.cid[:16] }} - Cache - Art-DAG L1{% endblock %} + +{% block content %} +
+ +
+ ← Media +

{{ cache.cid[:24] }}...

+
+ + +
+ {% if cache.mime_type and cache.mime_type.startswith('image/') %} + {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} + + {% elif cache.mime_type and cache.mime_type.startswith('video/') %} + {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} + + {% elif cache.mime_type and cache.mime_type.startswith('audio/') %} +
+ {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} +
+ + {% elif cache.mime_type == 'application/json' %} +
+
{{ cache.content_preview }}
+
+ + {% else %} +
+
{{ cache.mime_type or 'Unknown type' }}
+
{{ cache.size | filesizeformat if cache.size else 'Unknown size' }}
+
+ {% endif %} +
+ + +
+
+ Friendly Name + +
+ {% if cache.friendly_name %} +

{{ cache.friendly_name }}

+

Use in recipes: {{ cache.base_name }}

+ {% else %} +

No friendly name assigned. Click Edit to add one.

+ {% endif %} +
+ + +
+
+

Details

+ +
+ {% if cache.title or cache.description or cache.filename %} +
+ {% if cache.title %} +

{{ cache.title }}

+ {% elif cache.filename %} +

{{ cache.filename }}

+ {% endif %} + {% if cache.description %} +

{{ cache.description }}

+ {% endif %} +
+ {% else %} +

No title or description set. Click Edit to add metadata.

+ {% endif %} + {% if cache.tags %} +
+ {% for tag in cache.tags %} + {{ tag }} + {% endfor %} +
+ {% endif %} + {% if cache.source_type or cache.source_note %} +
+ {% if cache.source_type %}Source: {{ cache.source_type }}{% endif %} + {% if cache.source_note %} - {{ cache.source_note }}{% endif %} +
+ {% endif %} +
+ + +
+
+
CID
+
{{ cache.cid }}
+
+
+
Content Type
+
{{ cache.mime_type or 'Unknown' }}
+
+
+
Size
+
{{ cache.size | filesizeformat if cache.size else 'Unknown' }}
+
+
+
Created
+
{{ cache.created_at or 'Unknown' }}
+
+
+ + + {% if cache.ipfs_cid %} +
+
IPFS CID
+
+ {{ cache.ipfs_cid }} + + View on IPFS Gateway → + +
+
+ {% endif %} + + + {% if cache.runs %} +

Related Runs

+
+ {% for run in cache.runs %} + +
+ {{ run.run_id[:16] }}... + {{ run.created_at }} +
+
+ {% endfor %} +
+ {% endif %} + + +
+ + Download + + + +
+
+{% endblock %} diff --git a/l1/app/templates/cache/media_list.html b/l1/app/templates/cache/media_list.html new file mode 100644 index 0000000..0a436aa --- /dev/null +++ b/l1/app/templates/cache/media_list.html @@ -0,0 +1,325 @@ +{% extends "base.html" %} + +{% block title %}Media - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Media

+
+ + +
+
+ + + + + {% if items %} +
+ {% for item in items %} + {# Determine media category from type or filename #} + {% set is_image = item.type in ('image', 'image/jpeg', 'image/png', 'image/gif', 'image/webp') or (item.filename and item.filename.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))) %} + {% set is_video = item.type in ('video', 'video/mp4', 'video/webm', 'video/x-matroska') or (item.filename and item.filename.lower().endswith(('.mp4', '.mkv', '.webm', '.mov'))) %} + {% set is_audio = item.type in ('audio', 'audio/mpeg', 'audio/wav', 'audio/flac') or (item.filename and item.filename.lower().endswith(('.mp3', '.wav', '.flac', '.ogg'))) %} + + + + {% if is_image %} + + + {% elif is_video %} +
+ +
+
+ + + +
+
+
+ + {% elif is_audio %} +
+ + + + Audio +
+ + {% else %} +
+ {{ item.type or 'Media' }} +
+ {% endif %} + +
+ {% if item.friendly_name %} +
{{ item.friendly_name }}
+ {% else %} +
{{ item.cid[:16] }}...
+ {% endif %} + {% if item.filename %} +
{{ item.filename }}
+ {% endif %} +
+
+ {% endfor %} +
+ + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+ + + +

No media files yet

+

Run a recipe to generate media artifacts.

+
+ {% endif %} +
+ + +{% endblock %} diff --git a/l1/app/templates/cache/not_found.html b/l1/app/templates/cache/not_found.html new file mode 100644 index 0000000..600a77d --- /dev/null +++ b/l1/app/templates/cache/not_found.html @@ -0,0 +1,21 @@ +{% extends "base.html" %} + +{% block title %}Content Not Found - Art-DAG L1{% endblock %} + +{% block content %} +
+

404

+

Content Not Found

+

+ The content with hash {{ cid[:24] if cid else 'unknown' }}... was not found in the cache. +

+ +
+{% endblock %} diff --git a/l1/app/templates/effects/detail.html b/l1/app/templates/effects/detail.html new file mode 100644 index 0000000..572f586 --- /dev/null +++ b/l1/app/templates/effects/detail.html @@ -0,0 +1,203 @@ +{% extends "base.html" %} + +{% set meta = effect.meta or effect %} + +{% block title %}{{ meta.name or 'Effect' }} - Effects - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + + +{% endblock %} + +{% block content %} +
+ +
+ ← Effects +

{{ meta.name or 'Unnamed Effect' }}

+ v{{ meta.version or '1.0.0' }} + {% if meta.temporal %} + temporal + {% endif %} +
+ + {% if meta.author %} +

by {{ meta.author }}

+ {% endif %} + + {% if meta.description %} +

{{ meta.description }}

+ {% endif %} + + +
+ {% if effect.friendly_name %} +
+ Friendly Name +

{{ effect.friendly_name }}

+

Use in recipes: (effect {{ effect.base_name }})

+
+ {% endif %} +
+
+ Content ID (CID) +

{{ effect.cid }}

+
+ +
+ {% if effect.uploaded_at %} +
+ Uploaded: {{ effect.uploaded_at }} + {% if effect.uploader %} + by {{ effect.uploader }} + {% endif %} +
+ {% endif %} +
+ +
+ +
+ + {% if meta.params %} +
+
+ Parameters +
+
+ {% for param in meta.params %} +
+
+ {{ param.name }} + {{ param.type }} +
+ {% if param.description %} +

{{ param.description }}

+ {% endif %} +
+ {% if param.range %} + range: {{ param.range[0] }} - {{ param.range[1] }} + {% endif %} + {% if param.default is defined %} + default: {{ param.default }} + {% endif %} +
+
+ {% endfor %} +
+
+ {% endif %} + + +
+
+ Usage in Recipe +
+
+ {% if effect.base_name %} +
({{ effect.base_name }} ...)
+

+ Use the friendly name to reference this effect. +

+ {% else %} +
(effect :cid "{{ effect.cid }}")
+

+ Reference this effect by CID in your recipe. +

+ {% endif %} +
+
+
+ + +
+
+
+ Source Code (S-expression) + +
+
+
Loading...
+
+
+
+
+ + +
+ {% if effect.cid.startswith('Qm') or effect.cid.startswith('bafy') %} + + View on IPFS + + {% endif %} + + + +
+
+ + +{% endblock %} diff --git a/l1/app/templates/effects/list.html b/l1/app/templates/effects/list.html new file mode 100644 index 0000000..065d2bb --- /dev/null +++ b/l1/app/templates/effects/list.html @@ -0,0 +1,200 @@ +{% extends "base.html" %} + +{% block title %}Effects - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Effects

+ +
+ + + + +

+ Effects are S-expression files that define video processing operations. + Each effect is stored in IPFS and can be referenced by name in recipes. +

+ + {% if effects %} + + + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+ + + +

No effects uploaded yet.

+

+ Effects are S-expression files with metadata in comment headers. +

+ +
+ {% endif %} +
+ + +{% endblock %} diff --git a/l1/app/templates/fragments/link_card.html b/l1/app/templates/fragments/link_card.html new file mode 100644 index 0000000..ecc4450 --- /dev/null +++ b/l1/app/templates/fragments/link_card.html @@ -0,0 +1,22 @@ + +
+
+ {% if content_type == "recipe" %} + + {% elif content_type == "effect" %} + + {% elif content_type == "run" %} + + {% else %} + + {% endif %} +
+
+
{{ title }}
+ {% if description %} +
{{ description }}
+ {% endif %} +
{{ content_type }} · {{ cid[:12] }}…
+
+
+
diff --git a/l1/app/templates/fragments/nav_item.html b/l1/app/templates/fragments/nav_item.html new file mode 100644 index 0000000..e987cc5 --- /dev/null +++ b/l1/app/templates/fragments/nav_item.html @@ -0,0 +1,7 @@ + diff --git a/l1/app/templates/home.html b/l1/app/templates/home.html new file mode 100644 index 0000000..c6a23aa --- /dev/null +++ b/l1/app/templates/home.html @@ -0,0 +1,51 @@ +{% extends "base.html" %} + +{% block title %}Art-DAG L1{% endblock %} + +{% block content %} +
+

Art-DAG L1

+

Content-Addressable Media Processing

+ + + + {% if not user %} +
+

Sign in through your L2 server to access all features.

+ Sign In → +
+ {% endif %} + + {% if readme_html %} +
+ {{ readme_html | safe }} +
+ {% endif %} +
+{% endblock %} diff --git a/l1/app/templates/recipes/detail.html b/l1/app/templates/recipes/detail.html new file mode 100644 index 0000000..daf134a --- /dev/null +++ b/l1/app/templates/recipes/detail.html @@ -0,0 +1,265 @@ +{% extends "base.html" %} + +{% block title %}{{ recipe.name }} - Recipe - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + +{% endblock %} + +{% block content %} +
+ +
+ ← Recipes +

{{ recipe.name or 'Unnamed Recipe' }}

+ {% if recipe.version %} + v{{ recipe.version }} + {% endif %} +
+ + {% if recipe.description %} +

{{ recipe.description }}

+ {% endif %} + + +
+
+
+ Recipe ID +

{{ recipe.recipe_id[:16] }}...

+
+ {% if recipe.ipfs_cid %} +
+ IPFS CID +

{{ recipe.ipfs_cid[:16] }}...

+
+ {% endif %} +
+ Steps +

{{ recipe.step_count or recipe.steps|length }}

+
+ {% if recipe.author %} +
+ Author +

{{ recipe.author }}

+
+ {% endif %} +
+
+ + {% if recipe.type == 'streaming' %} + +
+
+ Streaming Recipe +
+

+ This recipe uses frame-by-frame streaming rendering. The pipeline is defined as an S-expression that generates frames dynamically. +

+
+ {% else %} + +
+
+ Pipeline DAG + {{ recipe.steps | length }} steps +
+
+
+ + +

Steps

+
+ {% for step in recipe.steps %} + {% set colors = { + 'effect': 'blue', + 'analyze': 'purple', + 'transform': 'green', + 'combine': 'orange', + 'output': 'cyan' + } %} + {% set color = colors.get(step.type, 'gray') %} + +
+
+
+ + {{ loop.index }} + + {{ step.name }} + + {{ step.type }} + +
+
+ + {% if step.inputs %} +
+ Inputs: {{ step.inputs | join(', ') }} +
+ {% endif %} + + {% if step.params %} +
+ {{ step.params | tojson }} +
+ {% endif %} +
+ {% endfor %} +
+ {% endif %} + + +

Recipe (S-expression)

+
+ {% if recipe.sexp %} +
{{ recipe.sexp }}
+ {% else %} +

No source available

+ {% endif %} +
+ + + + +
+ + {% if recipe.ipfs_cid %} + + View on IPFS + + {% elif recipe.recipe_id.startswith('Qm') or recipe.recipe_id.startswith('bafy') %} + + View on IPFS + + {% endif %} + + + +
+
+ + +{% endblock %} diff --git a/l1/app/templates/recipes/list.html b/l1/app/templates/recipes/list.html new file mode 100644 index 0000000..0cd484f --- /dev/null +++ b/l1/app/templates/recipes/list.html @@ -0,0 +1,136 @@ +{% extends "base.html" %} + +{% block title %}Recipes - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Recipes

+ +
+ +

+ Recipes define processing pipelines for audio and media. Each recipe is a DAG of effects. +

+ + {% if recipes %} + + + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+

No recipes available.

+

+ Recipes are S-expression files (.sexp) that define processing pipelines. +

+ +
+ {% endif %} +
+ +
+ + +{% endblock %} diff --git a/l1/app/templates/runs/_run_card.html b/l1/app/templates/runs/_run_card.html new file mode 100644 index 0000000..88a42a2 --- /dev/null +++ b/l1/app/templates/runs/_run_card.html @@ -0,0 +1,89 @@ +{# Run card partial - expects 'run' variable #} +{% set status_colors = { + 'completed': 'green', + 'running': 'blue', + 'pending': 'yellow', + 'failed': 'red', + 'cached': 'purple' +} %} +{% set color = status_colors.get(run.status, 'gray') %} + + +
+
+ {{ run.run_id[:12] }}... + + {{ run.status }} + + {% if run.cached %} + cached + {% endif %} +
+ {{ run.created_at }} +
+ +
+
+ + Recipe: {{ run.recipe_name or (run.recipe[:12] ~ '...' if run.recipe and run.recipe|length > 12 else run.recipe) or 'Unknown' }} + + {% if run.total_steps %} + + Steps: {{ run.executed or 0 }}/{{ run.total_steps }} + + {% endif %} +
+
+ + {# Media previews row #} +
+ {# Input previews #} + {% if run.input_previews %} +
+ In: + {% for inp in run.input_previews %} + {% if inp.media_type and inp.media_type.startswith('image/') %} + + {% elif inp.media_type and inp.media_type.startswith('video/') %} + + {% else %} +
?
+ {% endif %} + {% endfor %} + {% if run.inputs and run.inputs|length > 3 %} + +{{ run.inputs|length - 3 }} + {% endif %} +
+ {% elif run.inputs %} +
+ {{ run.inputs|length }} input(s) +
+ {% endif %} + + {# Arrow #} + -> + + {# Output preview - prefer IPFS URLs when available #} + {% if run.output_cid %} +
+ Out: + {% if run.output_media_type and run.output_media_type.startswith('image/') %} + + {% elif run.output_media_type and run.output_media_type.startswith('video/') %} + + {% else %} +
?
+ {% endif %} +
+ {% else %} + No output yet + {% endif %} + +
+ + {% if run.output_cid %} + {{ run.output_cid[:12] }}... + {% endif %} +
+
diff --git a/l1/app/templates/runs/artifacts.html b/l1/app/templates/runs/artifacts.html new file mode 100644 index 0000000..874188c --- /dev/null +++ b/l1/app/templates/runs/artifacts.html @@ -0,0 +1,62 @@ +{% extends "base.html" %} + +{% block title %}Run Artifacts{% endblock %} + +{% block content %} + + +

Run Artifacts

+ +{% if artifacts %} +
+ {% for artifact in artifacts %} +
+
+ + {{ artifact.role }} + + {{ artifact.step_name }} +
+ +
+

Content Hash

+

{{ artifact.hash }}

+
+ +
+ + {% if artifact.media_type == 'video' %}Video + {% elif artifact.media_type == 'image' %}Image + {% elif artifact.media_type == 'audio' %}Audio + {% else %}File{% endif %} + + {{ (artifact.size_bytes / 1024)|round(1) }} KB +
+ + +
+ {% endfor %} +
+{% else %} +
+

No artifacts found for this run.

+
+{% endif %} +{% endblock %} diff --git a/l1/app/templates/runs/detail.html b/l1/app/templates/runs/detail.html new file mode 100644 index 0000000..ae87dd3 --- /dev/null +++ b/l1/app/templates/runs/detail.html @@ -0,0 +1,1073 @@ +{% extends "base.html" %} + +{% block title %}Run {{ run.run_id[:12] }} - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + + +{% endblock %} + +{% block content %} +{% set status_colors = {'completed': 'green', 'running': 'blue', 'pending': 'yellow', 'failed': 'red', 'paused': 'yellow'} %} +{% set color = status_colors.get(run.status, 'gray') %} + +
+ +
+ ← Runs +

{{ run.run_id[:16] }}...

+ + {{ run.status }} + + {% if run.cached %} + Cached + {% endif %} + {% if run.error %} + {{ run.error }} + {% endif %} + {% if run.checkpoint_frame %} + + Checkpoint: {{ run.checkpoint_frame }}{% if run.total_frames %} / {{ run.total_frames }}{% endif %} frames + + {% endif %} +
+ + + {% if run.status == 'running' %} + + {% endif %} + + + {% if run.status in ['failed', 'paused'] %} + {% if run.checkpoint_frame %} + + {% endif %} + + {% endif %} + + {% if run.recipe %} + + {% endif %} + + + +
+ + +
+
+
Recipe
+
+ {% if run.recipe %} + + {{ run.recipe_name or (run.recipe[:16] ~ '...') }} + + {% else %} + Unknown + {% endif %} +
+
+
+
Steps
+
+ {% if run.recipe == 'streaming' %} + {% if run.status == 'completed' %}1 / 1{% else %}0 / 1{% endif %} + {% else %} + {{ run.executed or 0 }} / {{ run.total_steps or (plan.steps|length if plan and plan.steps else '?') }} + {% endif %} + {% if run.cached_steps %} + ({{ run.cached_steps }} cached) + {% endif %} +
+
+
+
Created
+
{{ run.created_at }}
+
+
+
User
+
{{ run.username or 'Unknown' }}
+
+
+ + + {% if run.status == 'rendering' or run.ipfs_playlist_cid or (run.status in ['paused', 'failed'] and run.checkpoint_frame) %} +
+
+

+ {% if run.status == 'rendering' %} + + Live Preview + {% elif run.status == 'paused' %} + + Partial Output (Paused) + {% elif run.status == 'failed' and run.checkpoint_frame %} + + Partial Output (Failed) + {% else %} + + Video + {% endif %} +

+
+ +
+ + +
+
Connecting...
+
+
+
+ +
+
+
+
Waiting for stream...
+
+
+
+
+ Stream: /runs/{{ run.run_id }}/playlist.m3u8 + +
+
+ + + {% endif %} + + +
+ +
+ + +
+ {% if plan %} +
+ +
+
+
+ +
+
+ Click a node to view details +
+ +
+
+ + +
+ {% for step in plan.steps %} + {% set step_color = 'green' if step.status == 'completed' or step.cache_id else ('purple' if step.cached else ('blue' if step.status == 'running' else 'gray')) %} +
+
+
+ + {{ loop.index }} + + {{ step.name }} + {{ step.type }} +
+
+ {% if step.cached %} + cached + {% elif step.status == 'completed' %} + completed + {% endif %} +
+
+ {% if step.cache_id %} + + {% endif %} +
+ {% endfor %} +
+ + + {% if plan_sexp %} +
+ + Recipe (S-expression) + {% if recipe_ipfs_cid %} + + ipfs://{{ recipe_ipfs_cid[:16] }}... + + {% endif %} + +
+
{{ plan_sexp }}
+
+
+ {% endif %} + + + + {% else %} +

No plan available for this run.

+ {% endif %} +
+ + + + + + + + + + + + {% if run.output_cid %} +
+

Output

+ + {# Inline media preview - prefer IPFS URLs when available #} +
+ {% if output_media_type and output_media_type.startswith('image/') %} + + Output + + {% elif output_media_type and output_media_type.startswith('video/') %} + {# HLS streams use the unified player above; show direct video for non-HLS #} + {% if run.ipfs_playlist_cid %} +
+ HLS stream available in player above. Use "From Start" to watch from beginning or "Live Edge" to follow rendering progress. +
+ {% else %} + {# Direct video file #} + + {% endif %} + {% elif output_media_type and output_media_type.startswith('audio/') %} + + {% else %} +
+
?
+
{{ output_media_type or 'Unknown media type' }}
+
+ {% endif %} +
+ +
+ + {% if run.ipfs_cid %}{{ run.ipfs_cid }}{% else %}{{ run.output_cid }}{% endif %} + +
+ {% if run.ipfs_playlist_cid %} + + HLS Playlist + + {% endif %} + {% if run.ipfs_cid %} + + View on IPFS Gateway + + {% endif %} +
+
+
+ {% endif %} +
+ + +{% endblock %} diff --git a/l1/app/templates/runs/list.html b/l1/app/templates/runs/list.html new file mode 100644 index 0000000..8d72415 --- /dev/null +++ b/l1/app/templates/runs/list.html @@ -0,0 +1,45 @@ +{% extends "base.html" %} + +{% block title %}Runs - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Execution Runs

+ Browse Recipes → +
+ + {% if runs %} +
+ {% for run in runs %} + {% include "runs/_run_card.html" %} + {% endfor %} +
+ + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+
+ + + +

No runs yet

+
+

Execute a recipe to see your runs here.

+ + Browse Recipes + +
+ {% endif %} +
+{% endblock %} diff --git a/l1/app/templates/runs/plan.html b/l1/app/templates/runs/plan.html new file mode 100644 index 0000000..f50090d --- /dev/null +++ b/l1/app/templates/runs/plan.html @@ -0,0 +1,99 @@ +{% extends "base.html" %} + +{% block title %}Run Plan - {{ run_id[:16] }}{% endblock %} + +{% block head %} + +{% endblock %} + +{% block content %} + + +

Execution Plan

+ +{% if plan %} +
+ +
+

DAG Visualization

+
+
+ + +
+

Steps ({{ plan.steps|length if plan.steps else 0 }})

+
+ {% for step in plan.get('steps', []) %} +
+
+ {{ step.name or step.id or 'Step ' ~ loop.index }} + + {{ step.status or ('cached' if step.cached else 'pending') }} + +
+ {% if step.cache_id %} +
+ {{ step.cache_id[:24] }}... +
+ {% endif %} +
+ {% else %} +

No steps defined

+ {% endfor %} +
+
+
+ + +{% else %} +
+

No execution plan available for this run.

+
+{% endif %} +{% endblock %} diff --git a/l1/app/templates/runs/plan_node.html b/l1/app/templates/runs/plan_node.html new file mode 100644 index 0000000..99e1658 --- /dev/null +++ b/l1/app/templates/runs/plan_node.html @@ -0,0 +1,99 @@ +{# Plan node detail panel - loaded via HTMX #} +{% set status_color = 'green' if status in ('cached', 'completed') else 'yellow' %} + +
+
+

{{ step.name or step.step_id[:20] }}

+
+ + {{ step.node_type or 'EFFECT' }} + + {{ status }} + Level {{ step.level or 0 }} +
+
+ +
+ +{# Output preview #} +{% if output_preview %} +
+
Output
+ {% if output_media_type == 'video' %} + + {% elif output_media_type == 'image' %} + + {% elif output_media_type == 'audio' %} + + {% endif %} +
+{% elif ipfs_cid %} +
+
Output (IPFS)
+ +
+{% endif %} + +{# Output link #} +{% if ipfs_cid %} + + {{ ipfs_cid[:24] }}... + View + +{% elif has_cached and cache_id %} + + {{ cache_id[:24] }}... + View + +{% endif %} + +{# Input media previews #} +{% if inputs %} + +{% endif %} + +{# Parameters/Config #} +{% if config %} +
+
Parameters
+
+ {% for key, value in config.items() %} +
+ {{ key }}: + {{ value if value is string else value|tojson }} +
+ {% endfor %} +
+
+{% endif %} + +{# Metadata #} +
+
Step ID: {{ step.step_id[:32] }}...
+
Cache ID: {{ cache_id[:32] }}...
+
diff --git a/l1/app/templates/storage/list.html b/l1/app/templates/storage/list.html new file mode 100644 index 0000000..a33f98a --- /dev/null +++ b/l1/app/templates/storage/list.html @@ -0,0 +1,90 @@ +{% extends "base.html" %} + +{% block title %}Storage Providers - Art-DAG L1{% endblock %} + +{% block content %} +
+

Storage Providers

+ +

+ Configure your IPFS pinning services. Data is pinned to your accounts, giving you full control. +

+ + + + + + {% if storages %} +

Your Storage Providers

+
+ {% for storage in storages %} + {% set info = providers_info.get(storage.provider_type, {'name': storage.provider_type, 'color': 'gray'}) %} +
+
+
+ {{ storage.provider_name or info.name }} + {% if storage.is_active %} + Active + {% else %} + Inactive + {% endif %} +
+
+ + +
+
+ +
+
+ Capacity: + {{ storage.capacity_gb }} GB +
+
+ Used: + {{ (storage.used_bytes / 1024 / 1024 / 1024) | round(2) }} GB +
+
+ Pins: + {{ storage.pin_count }} +
+
+ +
+
+ {% endfor %} +
+ {% else %} +
+

No storage providers configured yet.

+

Click on a provider above to add your first one.

+
+ {% endif %} +
+{% endblock %} diff --git a/l1/app/templates/storage/type.html b/l1/app/templates/storage/type.html new file mode 100644 index 0000000..851c633 --- /dev/null +++ b/l1/app/templates/storage/type.html @@ -0,0 +1,152 @@ +{% extends "base.html" %} + +{% block title %}{{ provider_info.name }} - Storage - Art-DAG L1{% endblock %} + +{% block content %} +
+
+ ← All Providers +

{{ provider_info.name }}

+
+ +

{{ provider_info.desc }}

+ + +
+

Add {{ provider_info.name }} Account

+ +
+ + +
+ + +
+ + {% if provider_type == 'pinata' %} +
+
+ + +
+
+ + +
+
+ + {% elif provider_type in ['web3storage', 'nftstorage'] %} +
+ + +
+ + {% elif provider_type == 'infura' %} +
+
+ + +
+
+ + +
+
+ + {% elif provider_type in ['filebase', 'storj'] %} +
+
+ + +
+
+ + +
+
+
+ + +
+ + {% elif provider_type == 'local' %} +
+ + +
+ {% endif %} + +
+ + +
+ +
+ +
+ +
+
+
+ + + {% if storages %} +

Configured Accounts

+
+ {% for storage in storages %} +
+
+
+ {{ storage.provider_name }} + {% if storage.is_active %} + Active + {% endif %} +
+
+ + +
+
+ + {% if storage.config_display %} +
+ {% for key, value in storage.config_display.items() %} + {{ key }}: {{ value }} + {% endfor %} +
+ {% endif %} + +
+
+ {% endfor %} +
+ {% endif %} +
+{% endblock %} diff --git a/l1/app/types.py b/l1/app/types.py new file mode 100644 index 0000000..15d81c6 --- /dev/null +++ b/l1/app/types.py @@ -0,0 +1,197 @@ +""" +Type definitions for Art DAG L1 server. + +Uses TypedDict for configuration structures to enable mypy checking. +""" + +from typing import Any, Dict, List, Optional, TypedDict, Union +from typing_extensions import NotRequired + + +# === Node Config Types === + +class SourceConfig(TypedDict, total=False): + """Config for SOURCE nodes.""" + cid: str # Content ID (IPFS CID or SHA3-256 hash) + asset: str # Asset name from registry + input: bool # True if this is a variable input + name: str # Human-readable name for variable inputs + description: str # Description for variable inputs + + +class EffectConfig(TypedDict, total=False): + """Config for EFFECT nodes.""" + effect: str # Effect name + cid: str # Effect CID (for cached/IPFS effects) + # Effect parameters are additional keys + intensity: float + level: float + + +class SequenceConfig(TypedDict, total=False): + """Config for SEQUENCE nodes.""" + transition: Dict[str, Any] # Transition config + + +class SegmentConfig(TypedDict, total=False): + """Config for SEGMENT nodes.""" + start: float + end: float + duration: float + + +# Union of all config types +NodeConfig = Union[SourceConfig, EffectConfig, SequenceConfig, SegmentConfig, Dict[str, Any]] + + +# === Node Types === + +class CompiledNode(TypedDict): + """Node as produced by the S-expression compiler.""" + id: str + type: str # "SOURCE", "EFFECT", "SEQUENCE", etc. + config: Dict[str, Any] + inputs: List[str] + name: NotRequired[str] + + +class TransformedNode(TypedDict): + """Node after transformation for artdag execution.""" + node_id: str + node_type: str + config: Dict[str, Any] + inputs: List[str] + name: NotRequired[str] + + +# === DAG Types === + +class CompiledDAG(TypedDict): + """DAG as produced by the S-expression compiler.""" + nodes: List[CompiledNode] + output: str + + +class TransformedDAG(TypedDict): + """DAG after transformation for artdag execution.""" + nodes: Dict[str, TransformedNode] + output_id: str + metadata: NotRequired[Dict[str, Any]] + + +# === Registry Types === + +class AssetEntry(TypedDict, total=False): + """Asset in the recipe registry.""" + cid: str + url: str + + +class EffectEntry(TypedDict, total=False): + """Effect in the recipe registry.""" + cid: str + url: str + temporal: bool + + +class Registry(TypedDict): + """Recipe registry containing assets and effects.""" + assets: Dict[str, AssetEntry] + effects: Dict[str, EffectEntry] + + +# === Visualization Types === + +class VisNodeData(TypedDict, total=False): + """Data for a visualization node (Cytoscape.js format).""" + id: str + label: str + nodeType: str + isOutput: bool + + +class VisNode(TypedDict): + """Visualization node wrapper.""" + data: VisNodeData + + +class VisEdgeData(TypedDict): + """Data for a visualization edge.""" + source: str + target: str + + +class VisEdge(TypedDict): + """Visualization edge wrapper.""" + data: VisEdgeData + + +class VisualizationDAG(TypedDict): + """DAG structure for Cytoscape.js visualization.""" + nodes: List[VisNode] + edges: List[VisEdge] + + +# === Recipe Types === + +class Recipe(TypedDict, total=False): + """Compiled recipe structure.""" + name: str + version: str + description: str + owner: str + registry: Registry + dag: CompiledDAG + recipe_id: str + ipfs_cid: str + sexp: str + step_count: int + error: str + + +# === API Request/Response Types === + +class RecipeRunInputs(TypedDict): + """Mapping of input names to CIDs for recipe execution.""" + # Keys are input names, values are CIDs + pass # Actually just Dict[str, str] + + +class RunResult(TypedDict, total=False): + """Result of a recipe run.""" + run_id: str + status: str # "pending", "running", "completed", "failed" + recipe: str + recipe_name: str + inputs: List[str] + output_cid: str + ipfs_cid: str + provenance_cid: str + error: str + created_at: str + completed_at: str + actor_id: str + celery_task_id: str + output_name: str + + +# === Helper functions for type narrowing === + +def is_source_node(node: TransformedNode) -> bool: + """Check if node is a SOURCE node.""" + return node.get("node_type") == "SOURCE" + + +def is_effect_node(node: TransformedNode) -> bool: + """Check if node is an EFFECT node.""" + return node.get("node_type") == "EFFECT" + + +def is_variable_input(config: Dict[str, Any]) -> bool: + """Check if a SOURCE node config represents a variable input.""" + return bool(config.get("input")) + + +def get_effect_cid(config: Dict[str, Any]) -> Optional[str]: + """Get effect CID from config, checking both 'cid' and 'hash' keys.""" + return config.get("cid") or config.get("hash") diff --git a/l1/app/utils/__init__.py b/l1/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/l1/app/utils/http_signatures.py b/l1/app/utils/http_signatures.py new file mode 100644 index 0000000..da1f105 --- /dev/null +++ b/l1/app/utils/http_signatures.py @@ -0,0 +1,84 @@ +"""HTTP Signature verification for incoming AP-style inbox requests. + +Implements the same RSA-SHA256 / PKCS1v15 scheme used by the coop's +shared/utils/http_signatures.py, but only the verification side. +""" +from __future__ import annotations + +import base64 +import re + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding + + +def verify_request_signature( + public_key_pem: str, + signature_header: str, + method: str, + path: str, + headers: dict[str, str], +) -> bool: + """Verify an incoming HTTP Signature. + + Args: + public_key_pem: PEM-encoded public key of the sender. + signature_header: Value of the ``Signature`` header. + method: HTTP method (GET, POST, etc.). + path: Request path (e.g. ``/inbox``). + headers: All request headers (case-insensitive keys). + + Returns: + True if the signature is valid. + """ + parts = _parse_signature_header(signature_header) + signed_headers = parts.get("headers", "date").split() + signature_b64 = parts.get("signature", "") + + # Reconstruct the signed string + lc_headers = {k.lower(): v for k, v in headers.items()} + lines: list[str] = [] + for h in signed_headers: + if h == "(request-target)": + lines.append(f"(request-target): {method.lower()} {path}") + else: + lines.append(f"{h}: {lc_headers.get(h, '')}") + + signed_string = "\n".join(lines) + + public_key = serialization.load_pem_public_key(public_key_pem.encode()) + try: + public_key.verify( + base64.b64decode(signature_b64), + signed_string.encode(), + padding.PKCS1v15(), + hashes.SHA256(), + ) + return True + except Exception: + return False + + +def parse_key_id(signature_header: str) -> str: + """Extract the keyId from a Signature header. + + keyId is typically ``https://domain/users/username#main-key``. + Returns the actor URL (strips ``#main-key``). + """ + parts = _parse_signature_header(signature_header) + key_id = parts.get("keyId", "") + return re.sub(r"#.*$", "", key_id) + + +def _parse_signature_header(header: str) -> dict[str, str]: + """Parse a Signature header into its component parts.""" + parts: dict[str, str] = {} + for part in header.split(","): + part = part.strip() + eq = part.find("=") + if eq < 0: + continue + key = part[:eq] + val = part[eq + 1:].strip('"') + parts[key] = val + return parts diff --git a/l1/artdag-client.tar.gz b/l1/artdag-client.tar.gz new file mode 100644 index 0000000..a4ec7f4 Binary files /dev/null and b/l1/artdag-client.tar.gz differ diff --git a/l1/build-client.sh b/l1/build-client.sh new file mode 100755 index 0000000..c9443b6 --- /dev/null +++ b/l1/build-client.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Build the artdag-client tarball +# This script is run during deployment to create the downloadable client package + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLIENT_REPO="https://git.rose-ash.com/art-dag/client.git" +TEMP_DIR=$(mktemp -d) +OUTPUT_FILE="$SCRIPT_DIR/artdag-client.tar.gz" + +echo "Building artdag-client.tar.gz..." + +# Clone the client repo +git clone --depth 1 "$CLIENT_REPO" "$TEMP_DIR/artdag-client" 2>/dev/null || { + echo "Failed to clone client repo, trying alternative..." + # Try GitHub if internal git fails + git clone --depth 1 "https://github.com/gilesbradshaw/art-client.git" "$TEMP_DIR/artdag-client" 2>/dev/null || { + echo "Error: Could not clone client repository" + rm -rf "$TEMP_DIR" + exit 1 + } +} + +# Remove .git directory +rm -rf "$TEMP_DIR/artdag-client/.git" +rm -rf "$TEMP_DIR/artdag-client/__pycache__" + +# Create tarball +cd "$TEMP_DIR" +tar -czf "$OUTPUT_FILE" artdag-client + +# Cleanup +rm -rf "$TEMP_DIR" + +echo "Created: $OUTPUT_FILE" +ls -lh "$OUTPUT_FILE" diff --git a/l1/cache_manager.py b/l1/cache_manager.py new file mode 100644 index 0000000..9474ca2 --- /dev/null +++ b/l1/cache_manager.py @@ -0,0 +1,872 @@ +# art-celery/cache_manager.py +""" +Cache management for Art DAG L1 server. + +Integrates artdag's Cache, ActivityStore, and ActivityManager to provide: +- Content-addressed caching with both node_id and cid +- Activity tracking for runs (input/output/intermediate relationships) +- Deletion rules enforcement (shared items protected) +- L2 ActivityPub integration for "shared" status checks +- IPFS as durable backing store (local cache as hot storage) +- Redis-backed indexes for multi-worker consistency +""" + +import hashlib +import json +import logging +import os +import shutil +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING + +import requests + +if TYPE_CHECKING: + import redis + +from artdag import Cache, CacheEntry, DAG, Node, NodeType +from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + +import ipfs_client + +logger = logging.getLogger(__name__) + + +def file_hash(path: Path, algorithm: str = "sha3_256") -> str: + """Compute local content hash (fallback when IPFS unavailable).""" + hasher = hashlib.new(algorithm) + actual_path = path.resolve() if path.is_symlink() else path + with open(actual_path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class CachedFile: + """ + A cached file with both identifiers. + + Provides a unified view combining: + - node_id: computation identity (for DAG caching) + - cid: file content identity (for external references) + """ + node_id: str + cid: str + path: Path + size_bytes: int + node_type: str + created_at: float + + @classmethod + def from_cache_entry(cls, entry: CacheEntry) -> "CachedFile": + return cls( + node_id=entry.node_id, + cid=entry.cid, + path=entry.output_path, + size_bytes=entry.size_bytes, + node_type=entry.node_type, + created_at=entry.created_at, + ) + + +class L2SharedChecker: + """ + Checks if content is shared (published) via L2 ActivityPub server. + + Caches results to avoid repeated API calls. + """ + + def __init__(self, l2_server: str, cache_ttl: int = 300): + self.l2_server = l2_server + self.cache_ttl = cache_ttl + self._cache: Dict[str, tuple[bool, float]] = {} + + def is_shared(self, cid: str) -> bool: + """Check if cid has been published to L2.""" + import time + now = time.time() + + # Check cache + if cid in self._cache: + is_shared, cached_at = self._cache[cid] + if now - cached_at < self.cache_ttl: + logger.debug(f"L2 check (cached): {cid[:16]}... = {is_shared}") + return is_shared + + # Query L2 + try: + url = f"{self.l2_server}/assets/by-hash/{cid}" + logger.info(f"L2 check: GET {url}") + resp = requests.get(url, timeout=5) + logger.info(f"L2 check response: {resp.status_code}") + is_shared = resp.status_code == 200 + except Exception as e: + logger.warning(f"Failed to check L2 for {cid}: {e}") + # On error, assume IS shared (safer - prevents accidental deletion) + is_shared = True + + self._cache[cid] = (is_shared, now) + return is_shared + + def invalidate(self, cid: str): + """Invalidate cache for a cid (call after publishing).""" + self._cache.pop(cid, None) + + def mark_shared(self, cid: str): + """Mark as shared without querying (call after successful publish).""" + import time + self._cache[cid] = (True, time.time()) + + +class L1CacheManager: + """ + Unified cache manager for Art DAG L1 server. + + Combines: + - artdag Cache for file storage + - ActivityStore for run tracking + - ActivityManager for deletion rules + - L2 integration for shared status + + Provides both node_id and cid based access. + """ + + def __init__( + self, + cache_dir: Path | str, + l2_server: str = "http://localhost:8200", + redis_client: Optional["redis.Redis"] = None, + ): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Redis for shared state between workers + self._redis = redis_client + self._redis_content_key = "artdag:content_index" + self._redis_ipfs_key = "artdag:ipfs_index" + + # artdag components + self.cache = Cache(self.cache_dir / "nodes") + self.activity_store = ActivityStore(self.cache_dir / "activities") + + # L2 shared checker + self.l2_checker = L2SharedChecker(l2_server) + + # Activity manager with L2-based is_shared + self.activity_manager = ActivityManager( + cache=self.cache, + activity_store=self.activity_store, + is_shared_fn=self._is_shared_by_node_id, + ) + + # Legacy files directory (for files uploaded directly by cid) + self.legacy_dir = self.cache_dir / "legacy" + self.legacy_dir.mkdir(parents=True, exist_ok=True) + + # ============ Redis Index (no JSON files) ============ + # + # Content index maps: CID (content hash or IPFS CID) -> node_id (code hash) + # IPFS index maps: node_id -> IPFS CID + # + # Database is the ONLY source of truth for cache_id -> ipfs_cid mapping. + # No fallbacks - failures raise exceptions. + + def _run_async(self, coro): + """Run async coroutine from sync context. + + Always creates a fresh event loop to avoid issues with Celery's + prefork workers where loops may be closed by previous tasks. + """ + import asyncio + + # Check if we're already in an async context + try: + asyncio.get_running_loop() + # We're in an async context - use a thread with its own loop + import threading + result = [None] + error = [None] + + def run_in_thread(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + result[0] = new_loop.run_until_complete(coro) + finally: + new_loop.close() + except Exception as e: + error[0] = e + + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join(timeout=30) + if error[0]: + raise error[0] + return result[0] + except RuntimeError: + # No running loop - create a fresh one (don't reuse potentially closed loops) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + def _set_content_index(self, cache_id: str, ipfs_cid: str): + """Set content index entry in database (cache_id -> ipfs_cid).""" + import database + + async def save_to_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + await conn.execute( + """ + INSERT INTO cache_items (cid, ipfs_cid) + VALUES ($1, $2) + ON CONFLICT (cid) DO UPDATE SET ipfs_cid = $2 + """, + cache_id, ipfs_cid + ) + finally: + await conn.close() + + self._run_async(save_to_db()) + logger.info(f"Indexed in database: {cache_id[:16]}... -> {ipfs_cid}") + + def _get_content_index(self, cache_id: str) -> Optional[str]: + """Get content index entry (cache_id -> ipfs_cid) from database.""" + import database + + async def get_from_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + row = await conn.fetchrow( + "SELECT ipfs_cid FROM cache_items WHERE cid = $1", + cache_id + ) + return {"ipfs_cid": row["ipfs_cid"]} if row else None + finally: + await conn.close() + + result = self._run_async(get_from_db()) + if result and result.get("ipfs_cid"): + return result["ipfs_cid"] + return None + + def _del_content_index(self, cache_id: str): + """Delete content index entry from database.""" + import database + + async def delete_from_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + await conn.execute("DELETE FROM cache_items WHERE cid = $1", cache_id) + finally: + await conn.close() + + self._run_async(delete_from_db()) + + def _set_ipfs_index(self, cid: str, ipfs_cid: str): + """Set IPFS index entry in Redis.""" + if self._redis: + try: + self._redis.hset(self._redis_ipfs_key, cid, ipfs_cid) + except Exception as e: + logger.warning(f"Failed to set IPFS index in Redis: {e}") + + def _get_ipfs_cid_from_index(self, cid: str) -> Optional[str]: + """Get IPFS CID from Redis.""" + if self._redis: + try: + val = self._redis.hget(self._redis_ipfs_key, cid) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Failed to get IPFS CID from Redis: {e}") + return None + + def get_ipfs_cid(self, cid: str) -> Optional[str]: + """Get IPFS CID for a content hash.""" + return self._get_ipfs_cid_from_index(cid) + + def _is_shared_by_node_id(self, cid: str) -> bool: + """Check if a cid is shared via L2.""" + return self.l2_checker.is_shared(cid) + + def _load_meta(self, cid: str) -> dict: + """Load metadata for a cached file.""" + meta_path = self.cache_dir / f"{cid}.meta.json" + if meta_path.exists(): + with open(meta_path) as f: + return json.load(f) + return {} + + def is_pinned(self, cid: str) -> tuple[bool, str]: + """ + Check if a cid is pinned (non-deletable). + + Returns: + (is_pinned, reason) tuple + """ + meta = self._load_meta(cid) + if meta.get("pinned"): + return True, meta.get("pin_reason", "published") + return False, "" + + def _save_meta(self, cid: str, **updates) -> dict: + """Save/update metadata for a cached file.""" + meta = self._load_meta(cid) + meta.update(updates) + meta_path = self.cache_dir / f"{cid}.meta.json" + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + return meta + + def pin(self, cid: str, reason: str = "published") -> None: + """Mark an item as pinned (non-deletable).""" + self._save_meta(cid, pinned=True, pin_reason=reason) + + # ============ File Storage ============ + + def put( + self, + source_path: Path, + node_type: str = "upload", + node_id: str = None, + cache_id: str = None, + execution_time: float = 0.0, + move: bool = False, + skip_ipfs: bool = False, + ) -> tuple[CachedFile, Optional[str]]: + """ + Store a file in the cache and optionally upload to IPFS. + + Files are stored by IPFS CID when skip_ipfs=False (default), or by + local content hash when skip_ipfs=True. The cache_id parameter creates + an index from cache_id -> CID for code-addressed lookups. + + Args: + source_path: Path to file to cache + node_type: Type of node (e.g., "upload", "source", "effect") + node_id: DEPRECATED - ignored, always uses CID + cache_id: Optional code-addressed cache ID to index + execution_time: How long the operation took + move: If True, move instead of copy + skip_ipfs: If True, skip IPFS upload and use local hash (faster for large files) + + Returns: + Tuple of (CachedFile with both node_id and cid, CID or None if skip_ipfs) + """ + if skip_ipfs: + # Use local content hash instead of IPFS CID (much faster) + cid = file_hash(source_path) + ipfs_cid = None + logger.info(f"put: Using local hash (skip_ipfs=True): {cid[:16]}...") + else: + # Upload to IPFS first to get the CID (primary identifier) + cid = ipfs_client.add_file(source_path) + if not cid: + raise RuntimeError(f"IPFS upload failed for {source_path}. IPFS is required.") + ipfs_cid = cid + + # Always store by IPFS CID (node_id parameter is deprecated) + node_id = cid + + # Check if already cached (by node_id) + existing = self.cache.get_entry(node_id) + if existing and existing.output_path.exists(): + return CachedFile.from_cache_entry(existing), ipfs_cid + + # Compute local hash BEFORE moving the file (for dual-indexing) + # Only needed if we uploaded to IPFS (to map local hash -> IPFS CID) + local_hash = None + if not skip_ipfs and self._is_ipfs_cid(cid): + local_hash = file_hash(source_path) + + # Store in local cache + logger.info(f"put: Storing in cache with node_id={node_id[:16]}...") + self.cache.put( + node_id=node_id, + source_path=source_path, + node_type=node_type, + execution_time=execution_time, + move=move, + ) + + entry = self.cache.get_entry(node_id) + logger.info(f"put: After cache.put, get_entry(node_id={node_id[:16]}...) returned entry={entry is not None}, path={entry.output_path if entry else None}") + + # Verify we can retrieve it + verify_path = self.cache.get(node_id) + logger.info(f"put: Verify cache.get(node_id={node_id[:16]}...) = {verify_path}") + + # Index by cache_id if provided (code-addressed cache lookup) + # This allows get_by_cid(cache_id) to find files stored by IPFS CID + if cache_id and cache_id != cid: + self._set_content_index(cache_id, cid) + logger.info(f"put: Indexed cache_id {cache_id[:16]}... -> IPFS {cid}") + + # Also index by local hash for content-based lookup + if local_hash and local_hash != cid: + self._set_content_index(local_hash, cid) + logger.debug(f"Indexed local hash {local_hash[:16]}... -> IPFS {cid}") + + logger.info(f"Cached: {cid[:16]}..." + (" (local only)" if skip_ipfs else " (IPFS)")) + + return CachedFile.from_cache_entry(entry), ipfs_cid if not skip_ipfs else None + + def get_by_node_id(self, node_id: str) -> Optional[Path]: + """Get cached file path by node_id.""" + return self.cache.get(node_id) + + def _is_ipfs_cid(self, identifier: str) -> bool: + """Check if identifier looks like an IPFS CID.""" + # CIDv0 starts with "Qm", CIDv1 starts with "bafy" or other multibase prefixes + return identifier.startswith("Qm") or identifier.startswith("bafy") or identifier.startswith("baf") + + def get_by_cid(self, cid: str) -> Optional[Path]: + """Get cached file path by cid or IPFS CID. Falls back to IPFS if not in local cache.""" + logger.info(f"get_by_cid: Looking for cid={cid[:16]}...") + + # Check index first (Redis then local) + node_id = self._get_content_index(cid) + logger.info(f"get_by_cid: Index lookup returned node_id={node_id[:16] if node_id else None}...") + if node_id: + path = self.cache.get(node_id) + logger.info(f"get_by_cid: cache.get(node_id={node_id[:16]}...) returned path={path}") + if path and path.exists(): + logger.info(f"get_by_cid: Found via index: {path}") + return path + + # artdag Cache doesn't know about entry - check filesystem directly + # Files are stored at {cache_dir}/nodes/{node_id}/output.* + nodes_dir = self.cache_dir / "nodes" / node_id + if nodes_dir.exists(): + for f in nodes_dir.iterdir(): + if f.name.startswith("output."): + logger.info(f"get_by_cid: Found on filesystem: {f}") + return f + + # For uploads, node_id == cid, so try direct lookup + # This works even if cache index hasn't been reloaded + path = self.cache.get(cid) + logger.info(f"get_by_cid: Direct cache.get({cid[:16]}...) returned: {path}") + if path and path.exists(): + self._set_content_index(cid, cid) + return path + + # Check filesystem directly for cid as node_id + nodes_dir = self.cache_dir / "nodes" / cid + if nodes_dir.exists(): + for f in nodes_dir.iterdir(): + if f.name.startswith("output."): + logger.info(f"get_by_cid: Found on filesystem (direct): {f}") + self._set_content_index(cid, cid) + return f + + # Scan cache entries (fallback for new structure) + entry = self.cache.find_by_cid(cid) + logger.info(f"get_by_cid: find_by_cid({cid[:16]}...) returned entry={entry}") + if entry and entry.output_path.exists(): + logger.info(f"get_by_cid: Found via scan: {entry.output_path}") + self._set_content_index(cid, entry.node_id) + return entry.output_path + + # Check legacy location (files stored directly as CACHE_DIR/{cid}) + legacy_path = self.cache_dir / cid + logger.info(f"get_by_cid: Checking legacy path: {legacy_path} exists={legacy_path.exists()}") + if legacy_path.exists() and legacy_path.is_file(): + logger.info(f"get_by_cid: Found at legacy path: {legacy_path}") + return legacy_path + + # Fetch from IPFS - this is the source of truth for all content + if self._is_ipfs_cid(cid): + logger.info(f"get_by_cid: Fetching from IPFS: {cid[:16]}...") + recovery_path = self.legacy_dir / cid + recovery_path.parent.mkdir(parents=True, exist_ok=True) + if ipfs_client.get_file(cid, str(recovery_path)): + logger.info(f"get_by_cid: Fetched from IPFS: {recovery_path}") + self._set_content_index(cid, cid) + return recovery_path + else: + logger.warning(f"get_by_cid: IPFS fetch failed for {cid[:16]}...") + + # Also try with a mapped IPFS CID if different from cid + ipfs_cid = self._get_ipfs_cid_from_index(cid) + if ipfs_cid and ipfs_cid != cid: + logger.info(f"get_by_cid: Fetching from IPFS via mapping: {ipfs_cid[:16]}...") + recovery_path = self.legacy_dir / cid + recovery_path.parent.mkdir(parents=True, exist_ok=True) + if ipfs_client.get_file(ipfs_cid, str(recovery_path)): + logger.info(f"get_by_cid: Fetched from IPFS: {recovery_path}") + return recovery_path + + return None + + def has_content(self, cid: str) -> bool: + """Check if content exists in cache.""" + return self.get_by_cid(cid) is not None + + def get_entry_by_cid(self, cid: str) -> Optional[CacheEntry]: + """Get cache entry by cid.""" + node_id = self._get_content_index(cid) + if node_id: + return self.cache.get_entry(node_id) + return self.cache.find_by_cid(cid) + + def list_all(self) -> List[CachedFile]: + """List all cached files.""" + files = [] + seen_hashes = set() + + # New cache structure entries + for entry in self.cache.list_entries(): + files.append(CachedFile.from_cache_entry(entry)) + if entry.cid: + seen_hashes.add(entry.cid) + + # Legacy files stored directly in cache_dir (old structure) + # These are files named by cid directly in CACHE_DIR + for f in self.cache_dir.iterdir(): + # Skip directories and special files + if not f.is_file(): + continue + # Skip metadata/auxiliary files + if f.suffix in ('.json', '.mp4'): + continue + # Skip if name doesn't look like a hash (64 hex chars) + if len(f.name) != 64 or not all(c in '0123456789abcdef' for c in f.name): + continue + # Skip if already seen via new cache + if f.name in seen_hashes: + continue + + files.append(CachedFile( + node_id=f.name, + cid=f.name, + path=f, + size_bytes=f.stat().st_size, + node_type="legacy", + created_at=f.stat().st_mtime, + )) + seen_hashes.add(f.name) + + return files + + def list_by_type(self, node_type: str) -> List[str]: + """ + List CIDs of all cached files of a specific type. + + Args: + node_type: Type to filter by (e.g., "recipe", "upload", "effect") + + Returns: + List of CIDs (IPFS CID if available, otherwise node_id) + """ + cids = [] + for entry in self.cache.list_entries(): + if entry.node_type == node_type: + # Return node_id which is the IPFS CID for uploaded content + cids.append(entry.node_id) + return cids + + # ============ Activity Tracking ============ + + def record_activity(self, dag: DAG, run_id: str = None) -> Activity: + """ + Record a DAG execution as an activity. + + Args: + dag: The executed DAG + run_id: Optional run ID to use as activity_id + + Returns: + The created Activity + """ + activity = Activity.from_dag(dag, activity_id=run_id) + self.activity_store.add(activity) + return activity + + def record_simple_activity( + self, + input_hashes: List[str], + output_cid: str, + run_id: str = None, + ) -> Activity: + """ + Record a simple (non-DAG) execution as an activity. + + For legacy single-effect runs that don't use full DAG execution. + Uses cid as node_id. + """ + activity = Activity( + activity_id=run_id or str(hash((tuple(input_hashes), output_cid))), + input_ids=sorted(input_hashes), + output_id=output_cid, + intermediate_ids=[], + created_at=datetime.now(timezone.utc).timestamp(), + status="completed", + ) + self.activity_store.add(activity) + return activity + + def get_activity(self, activity_id: str) -> Optional[Activity]: + """Get activity by ID.""" + return self.activity_store.get(activity_id) + + def list_activities(self) -> List[Activity]: + """List all activities.""" + return self.activity_store.list() + + def find_activities_by_inputs(self, input_hashes: List[str]) -> List[Activity]: + """Find activities with matching inputs (for UI grouping).""" + return self.activity_store.find_by_input_ids(input_hashes) + + # ============ Deletion Rules ============ + + def can_delete(self, cid: str) -> tuple[bool, str]: + """ + Check if a cached item can be deleted. + + Returns: + (can_delete, reason) tuple + """ + # Check if pinned (published or input to published) + pinned, reason = self.is_pinned(cid) + if pinned: + return False, f"Item is pinned ({reason})" + + # Find node_id for this content + node_id = self._get_content_index(cid) or cid + + # Check if it's an input or output of any activity + for activity in self.activity_store.list(): + if node_id in activity.input_ids: + return False, f"Item is input to activity {activity.activity_id}" + if node_id == activity.output_id: + return False, f"Item is output of activity {activity.activity_id}" + + return True, "OK" + + def can_discard_activity(self, activity_id: str) -> tuple[bool, str]: + """ + Check if an activity can be discarded. + + Returns: + (can_discard, reason) tuple + """ + activity = self.activity_store.get(activity_id) + if not activity: + return False, "Activity not found" + + # Check if any item is pinned + for node_id in activity.all_node_ids: + entry = self.cache.get_entry(node_id) + if entry: + pinned, reason = self.is_pinned(entry.cid) + if pinned: + return False, f"Item {node_id} is pinned ({reason})" + + return True, "OK" + + def delete_by_cid(self, cid: str) -> tuple[bool, str]: + """ + Delete a cached item by cid. + + Enforces deletion rules. + + Returns: + (success, message) tuple + """ + can_delete, reason = self.can_delete(cid) + if not can_delete: + return False, reason + + # Find and delete + node_id = self._get_content_index(cid) + if node_id: + self.cache.remove(node_id) + self._del_content_index(cid) + return True, "Deleted" + + # Try legacy + legacy_path = self.legacy_dir / cid + if legacy_path.exists(): + legacy_path.unlink() + return True, "Deleted (legacy)" + + return False, "Not found" + + def discard_activity(self, activity_id: str) -> tuple[bool, str]: + """ + Discard an activity and clean up its cache entries. + + Enforces deletion rules. + + Returns: + (success, message) tuple + """ + can_discard, reason = self.can_discard_activity(activity_id) + if not can_discard: + return False, reason + + success = self.activity_manager.discard_activity(activity_id) + if success: + return True, "Activity discarded" + return False, "Failed to discard" + + def _is_used_by_other_activities(self, node_id: str, exclude_activity_id: str) -> bool: + """Check if a node is used by any activity other than the excluded one.""" + for other_activity in self.activity_store.list(): + if other_activity.activity_id == exclude_activity_id: + continue + # Check if used as input, output, or intermediate + if node_id in other_activity.input_ids: + return True + if node_id == other_activity.output_id: + return True + if node_id in other_activity.intermediate_ids: + return True + return False + + def discard_activity_outputs_only(self, activity_id: str) -> tuple[bool, str]: + """ + Discard an activity, deleting only outputs and intermediates. + + Inputs (cache items, configs) are preserved. + Outputs/intermediates used by other activities are preserved. + + Returns: + (success, message) tuple + """ + activity = self.activity_store.get(activity_id) + if not activity: + return False, "Activity not found" + + # Check if output is pinned + if activity.output_id: + entry = self.cache.get_entry(activity.output_id) + if entry: + pinned, reason = self.is_pinned(entry.cid) + if pinned: + return False, f"Output is pinned ({reason})" + + deleted_outputs = 0 + preserved_shared = 0 + + # Delete output (only if not used by other activities) + if activity.output_id: + if self._is_used_by_other_activities(activity.output_id, activity_id): + preserved_shared += 1 + else: + entry = self.cache.get_entry(activity.output_id) + if entry: + # Remove from cache + self.cache.remove(activity.output_id) + # Remove from content index (Redis + local) + self._del_content_index(entry.cid) + # Delete from legacy dir if exists + legacy_path = self.legacy_dir / entry.cid + if legacy_path.exists(): + legacy_path.unlink() + deleted_outputs += 1 + + # Delete intermediates (only if not used by other activities) + for node_id in activity.intermediate_ids: + if self._is_used_by_other_activities(node_id, activity_id): + preserved_shared += 1 + continue + entry = self.cache.get_entry(node_id) + if entry: + self.cache.remove(node_id) + self._del_content_index(entry.cid) + legacy_path = self.legacy_dir / entry.cid + if legacy_path.exists(): + legacy_path.unlink() + deleted_outputs += 1 + + # Remove activity record (inputs remain in cache) + self.activity_store.remove(activity_id) + + msg = f"Activity discarded (deleted {deleted_outputs} outputs" + if preserved_shared > 0: + msg += f", preserved {preserved_shared} shared items" + msg += ")" + return True, msg + + def cleanup_intermediates(self) -> int: + """Delete all intermediate cache entries (reconstructible).""" + return self.activity_manager.cleanup_intermediates() + + def get_deletable_items(self) -> List[CachedFile]: + """Get all items that can be deleted.""" + deletable = [] + for entry in self.activity_manager.get_deletable_entries(): + deletable.append(CachedFile.from_cache_entry(entry)) + return deletable + + # ============ L2 Integration ============ + + def mark_published(self, cid: str): + """Mark a cid as published to L2.""" + self.l2_checker.mark_shared(cid) + + def invalidate_shared_cache(self, cid: str): + """Invalidate shared status cache (call if item might be unpublished).""" + self.l2_checker.invalidate(cid) + + # ============ Stats ============ + + def get_stats(self) -> dict: + """Get cache statistics.""" + stats = self.cache.get_stats() + return { + "total_entries": stats.total_entries, + "total_size_bytes": stats.total_size_bytes, + "hits": stats.hits, + "misses": stats.misses, + "hit_rate": stats.hit_rate, + "activities": len(self.activity_store), + } + + +# Singleton instance (initialized on first import with env vars) +_manager: Optional[L1CacheManager] = None + + +def get_cache_manager() -> L1CacheManager: + """Get the singleton cache manager instance.""" + global _manager + if _manager is None: + import redis + from urllib.parse import urlparse + + cache_dir = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache"))) + l2_server = os.environ.get("L2_SERVER", "http://localhost:8200") + + # Initialize Redis client for shared cache index + redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + parsed = urlparse(redis_url) + redis_client = redis.Redis( + host=parsed.hostname or 'localhost', + port=parsed.port or 6379, + db=int(parsed.path.lstrip('/') or 0), + socket_timeout=5, + socket_connect_timeout=5 + ) + + _manager = L1CacheManager(cache_dir=cache_dir, l2_server=l2_server, redis_client=redis_client) + return _manager + + +def reset_cache_manager(): + """Reset the singleton (for testing).""" + global _manager + _manager = None diff --git a/l1/celery_app.py b/l1/celery_app.py new file mode 100644 index 0000000..4a843d2 --- /dev/null +++ b/l1/celery_app.py @@ -0,0 +1,51 @@ +""" +Art DAG Celery Application + +Streaming video rendering for the Art DAG system. +Uses S-expression recipes with frame-by-frame processing. +""" + +import os +import sys +from celery import Celery +from celery.signals import worker_ready + +# Use central config +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from app.config import settings + +app = Celery( + 'art_celery', + broker=settings.redis_url, + backend=settings.redis_url, + include=['tasks', 'tasks.streaming', 'tasks.ipfs_upload'] +) + + +@worker_ready.connect +def log_config_on_startup(sender, **kwargs): + """Log configuration when worker starts.""" + print("=" * 60, file=sys.stderr) + print("WORKER STARTED - CONFIGURATION", file=sys.stderr) + print("=" * 60, file=sys.stderr) + settings.log_config() + print(f"Worker: {sender}", file=sys.stderr) + print("=" * 60, file=sys.stderr) + +app.conf.update( + result_expires=86400 * 7, # 7 days - allow time for recovery after restarts + task_serializer='json', + accept_content=['json', 'pickle'], # pickle needed for internal Celery messages + result_serializer='json', + event_serializer='json', + timezone='UTC', + enable_utc=True, + task_track_started=True, + task_acks_late=True, # Don't ack until task completes - survives worker restart + worker_prefetch_multiplier=1, + task_reject_on_worker_lost=True, # Re-queue if worker dies + task_acks_on_failure_or_timeout=True, # Ack failed tasks so they don't retry forever +) + +if __name__ == '__main__': + app.start() diff --git a/l1/check_redis.py b/l1/check_redis.py new file mode 100644 index 0000000..44f70aa --- /dev/null +++ b/l1/check_redis.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +"""Check Redis connectivity.""" + +import redis + +try: + r = redis.Redis(host='localhost', port=6379, db=0) + r.ping() + print("Redis: OK") +except redis.ConnectionError: + print("Redis: Not running") + print("Start with: sudo systemctl start redis-server") diff --git a/l1/claiming.py b/l1/claiming.py new file mode 100644 index 0000000..77fa1a0 --- /dev/null +++ b/l1/claiming.py @@ -0,0 +1,421 @@ +""" +Hash-based task claiming for distributed execution. + +Prevents duplicate work when multiple workers process the same plan. +Uses Redis Lua scripts for atomic claim operations. +""" + +import json +import logging +import os +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +import redis + +logger = logging.getLogger(__name__) + +REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + +# Key prefix for task claims +CLAIM_PREFIX = "artdag:claim:" + +# Default TTL for claims (5 minutes) +DEFAULT_CLAIM_TTL = 300 + +# TTL for completed results (1 hour) +COMPLETED_TTL = 3600 + + +class ClaimStatus(Enum): + """Status of a task claim.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + CACHED = "cached" + FAILED = "failed" + + +@dataclass +class ClaimInfo: + """Information about a task claim.""" + cache_id: str + status: ClaimStatus + worker_id: Optional[str] = None + task_id: Optional[str] = None + claimed_at: Optional[str] = None + completed_at: Optional[str] = None + output_path: Optional[str] = None + error: Optional[str] = None + + def to_dict(self) -> dict: + return { + "cache_id": self.cache_id, + "status": self.status.value, + "worker_id": self.worker_id, + "task_id": self.task_id, + "claimed_at": self.claimed_at, + "completed_at": self.completed_at, + "output_path": self.output_path, + "error": self.error, + } + + @classmethod + def from_dict(cls, data: dict) -> "ClaimInfo": + return cls( + cache_id=data["cache_id"], + status=ClaimStatus(data["status"]), + worker_id=data.get("worker_id"), + task_id=data.get("task_id"), + claimed_at=data.get("claimed_at"), + completed_at=data.get("completed_at"), + output_path=data.get("output_path"), + error=data.get("error"), + ) + + +# Lua script for atomic task claiming +# Returns 1 if claim successful, 0 if already claimed/completed +CLAIM_TASK_SCRIPT = """ +local key = KEYS[1] +local data = redis.call('GET', key) + +if data then + local status = cjson.decode(data) + local s = status['status'] + -- Already claimed, running, completed, or cached - don't claim + if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then + return 0 + end +end + +-- Claim the task +local claim_data = ARGV[1] +local ttl = tonumber(ARGV[2]) +redis.call('SETEX', key, ttl, claim_data) +return 1 +""" + +# Lua script for releasing a claim (e.g., on failure) +RELEASE_CLAIM_SCRIPT = """ +local key = KEYS[1] +local worker_id = ARGV[1] +local data = redis.call('GET', key) + +if data then + local status = cjson.decode(data) + -- Only release if we own the claim + if status['worker_id'] == worker_id then + redis.call('DEL', key) + return 1 + end +end +return 0 +""" + +# Lua script for updating claim status (claimed -> running -> completed) +UPDATE_STATUS_SCRIPT = """ +local key = KEYS[1] +local worker_id = ARGV[1] +local new_status = ARGV[2] +local new_data = ARGV[3] +local ttl = tonumber(ARGV[4]) + +local data = redis.call('GET', key) +if not data then + return 0 +end + +local status = cjson.decode(data) + +-- Only update if we own the claim +if status['worker_id'] ~= worker_id then + return 0 +end + +redis.call('SETEX', key, ttl, new_data) +return 1 +""" + + +class TaskClaimer: + """ + Manages hash-based task claiming for distributed execution. + + Uses Redis for coordination between workers. + Each task is identified by its cache_id (content-addressed). + """ + + def __init__(self, redis_url: str = None): + """ + Initialize the claimer. + + Args: + redis_url: Redis connection URL + """ + self.redis_url = redis_url or REDIS_URL + self._redis: Optional[redis.Redis] = None + self._claim_script = None + self._release_script = None + self._update_script = None + + @property + def redis(self) -> redis.Redis: + """Get Redis connection (lazy initialization).""" + if self._redis is None: + self._redis = redis.from_url(self.redis_url, decode_responses=True) + # Register Lua scripts + self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT) + self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT) + self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT) + return self._redis + + def _key(self, cache_id: str) -> str: + """Get Redis key for a cache_id.""" + return f"{CLAIM_PREFIX}{cache_id}" + + def claim( + self, + cache_id: str, + worker_id: str, + task_id: Optional[str] = None, + ttl: int = DEFAULT_CLAIM_TTL, + ) -> bool: + """ + Attempt to claim a task. + + Args: + cache_id: The cache ID of the task to claim + worker_id: Identifier for the claiming worker + task_id: Optional Celery task ID + ttl: Time-to-live for the claim in seconds + + Returns: + True if claim successful, False if already claimed + """ + claim_info = ClaimInfo( + cache_id=cache_id, + status=ClaimStatus.CLAIMED, + worker_id=worker_id, + task_id=task_id, + claimed_at=datetime.now(timezone.utc).isoformat(), + ) + + result = self._claim_script( + keys=[self._key(cache_id)], + args=[json.dumps(claim_info.to_dict()), ttl], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}") + return True + else: + logger.debug(f"Task {cache_id[:16]}... already claimed") + return False + + def update_status( + self, + cache_id: str, + worker_id: str, + status: ClaimStatus, + output_path: Optional[str] = None, + error: Optional[str] = None, + ttl: Optional[int] = None, + ) -> bool: + """ + Update the status of a claimed task. + + Args: + cache_id: The cache ID of the task + worker_id: Worker ID that owns the claim + status: New status + output_path: Path to output (for completed) + error: Error message (for failed) + ttl: New TTL (defaults based on status) + + Returns: + True if update successful + """ + if ttl is None: + if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED): + ttl = COMPLETED_TTL + else: + ttl = DEFAULT_CLAIM_TTL + + # Get existing claim info + existing = self.get_status(cache_id) + if not existing: + logger.warning(f"No claim found for {cache_id[:16]}...") + return False + + claim_info = ClaimInfo( + cache_id=cache_id, + status=status, + worker_id=worker_id, + task_id=existing.task_id, + claimed_at=existing.claimed_at, + completed_at=datetime.now(timezone.utc).isoformat() if status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED + ) else None, + output_path=output_path, + error=error, + ) + + result = self._update_script( + keys=[self._key(cache_id)], + args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Updated task {cache_id[:16]}... to {status.value}") + return True + else: + logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)") + return False + + def release(self, cache_id: str, worker_id: str) -> bool: + """ + Release a claim (e.g., on task failure before completion). + + Args: + cache_id: The cache ID of the task + worker_id: Worker ID that owns the claim + + Returns: + True if release successful + """ + result = self._release_script( + keys=[self._key(cache_id)], + args=[worker_id], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Released claim on {cache_id[:16]}...") + return True + return False + + def get_status(self, cache_id: str) -> Optional[ClaimInfo]: + """ + Get the current status of a task. + + Args: + cache_id: The cache ID of the task + + Returns: + ClaimInfo if task has been claimed, None otherwise + """ + data = self.redis.get(self._key(cache_id)) + if data: + return ClaimInfo.from_dict(json.loads(data)) + return None + + def is_completed(self, cache_id: str) -> bool: + """Check if a task is completed or cached.""" + info = self.get_status(cache_id) + return info is not None and info.status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED + ) + + def wait_for_completion( + self, + cache_id: str, + timeout: float = 300, + poll_interval: float = 0.5, + ) -> Optional[ClaimInfo]: + """ + Wait for a task to complete. + + Args: + cache_id: The cache ID of the task + timeout: Maximum time to wait in seconds + poll_interval: How often to check status + + Returns: + ClaimInfo if completed, None if timeout + """ + start_time = time.time() + while time.time() - start_time < timeout: + info = self.get_status(cache_id) + if info and info.status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED + ): + return info + time.sleep(poll_interval) + + logger.warning(f"Timeout waiting for {cache_id[:16]}...") + return None + + def mark_cached(self, cache_id: str, output_path: str) -> None: + """ + Mark a task as already cached (no processing needed). + + This is used when we discover the result already exists + before attempting to claim. + + Args: + cache_id: The cache ID of the task + output_path: Path to the cached output + """ + claim_info = ClaimInfo( + cache_id=cache_id, + status=ClaimStatus.CACHED, + output_path=output_path, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + self.redis.setex( + self._key(cache_id), + COMPLETED_TTL, + json.dumps(claim_info.to_dict()), + ) + + def clear_all(self) -> int: + """ + Clear all claims (for testing/reset). + + Returns: + Number of claims cleared + """ + pattern = f"{CLAIM_PREFIX}*" + keys = list(self.redis.scan_iter(match=pattern)) + if keys: + return self.redis.delete(*keys) + return 0 + + +# Global claimer instance +_claimer: Optional[TaskClaimer] = None + + +def get_claimer() -> TaskClaimer: + """Get the global TaskClaimer instance.""" + global _claimer + if _claimer is None: + _claimer = TaskClaimer() + return _claimer + + +def claim_task(cache_id: str, worker_id: str, task_id: str = None) -> bool: + """Convenience function to claim a task.""" + return get_claimer().claim(cache_id, worker_id, task_id) + + +def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool: + """Convenience function to mark a task as completed.""" + return get_claimer().update_status( + cache_id, worker_id, ClaimStatus.COMPLETED, output_path=output_path + ) + + +def fail_task(cache_id: str, worker_id: str, error: str) -> bool: + """Convenience function to mark a task as failed.""" + return get_claimer().update_status( + cache_id, worker_id, ClaimStatus.FAILED, error=error + ) diff --git a/l1/configs/audio-dizzy.sexp b/l1/configs/audio-dizzy.sexp new file mode 100644 index 0000000..dc16087 --- /dev/null +++ b/l1/configs/audio-dizzy.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Paths relative to working directory (project root) +(def music (streaming:make-audio-analyzer "dizzy.mp3")) + +;; Audio playback path (for sync with video output) +(audio-playback "dizzy.mp3") diff --git a/l1/configs/audio-halleluwah.sexp b/l1/configs/audio-halleluwah.sexp new file mode 100644 index 0000000..7d7bfae --- /dev/null +++ b/l1/configs/audio-halleluwah.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Using friendly name for asset resolution +(def music (streaming:make-audio-analyzer "woods-audio")) + +;; Audio playback path (for sync with video output) +(audio-playback "woods-audio") diff --git a/l1/configs/sources-default.sexp b/l1/configs/sources-default.sexp new file mode 100644 index 0000000..754bd92 --- /dev/null +++ b/l1/configs/sources-default.sexp @@ -0,0 +1,38 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Paths relative to working directory (project root) +(def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 +]) diff --git a/l1/configs/sources-woods-half.sexp b/l1/configs/sources-woods-half.sexp new file mode 100644 index 0000000..d2feff8 --- /dev/null +++ b/l1/configs/sources-woods-half.sexp @@ -0,0 +1,19 @@ +;; Half-resolution Woods Sources (960x540) +;; +;; Pass to recipe with: --sources configs/sources-woods-half.sexp + +(require-primitives "streaming") + +(def sources [ + (streaming:make-video-source "woods_half/1.webm" 30) + (streaming:make-video-source "woods_half/2.webm" 30) + (streaming:make-video-source "woods_half/3.webm" 30) + (streaming:make-video-source "woods_half/4.webm" 30) +]) + +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} +]) diff --git a/l1/configs/sources-woods.sexp b/l1/configs/sources-woods.sexp new file mode 100644 index 0000000..ab8dff4 --- /dev/null +++ b/l1/configs/sources-woods.sexp @@ -0,0 +1,39 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Using friendly names for asset resolution +(def sources [ + (streaming:make-video-source "woods-1" 10) + (streaming:make-video-source "woods-2" 10) + (streaming:make-video-source "woods-3" 10) + (streaming:make-video-source "woods-4" 10) + (streaming:make-video-source "woods-5" 10) + (streaming:make-video-source "woods-6" 10) + (streaming:make-video-source "woods-7" 10) + (streaming:make-video-source "woods-8" 10) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + +]) diff --git a/l1/database.py b/l1/database.py new file mode 100644 index 0000000..70187db --- /dev/null +++ b/l1/database.py @@ -0,0 +1,2144 @@ +# art-celery/database.py +""" +PostgreSQL database module for Art DAG L1 server. + +Provides connection pooling and CRUD operations for cache metadata. +""" + +import os +from datetime import datetime, timezone +from typing import List, Optional + +import asyncpg + +DATABASE_URL = os.getenv("DATABASE_URL") +if not DATABASE_URL: + raise RuntimeError("DATABASE_URL environment variable is required") + +pool: Optional[asyncpg.Pool] = None + +SCHEMA_SQL = """ +-- Core cache: just content hash and IPFS CID +-- Physical file storage - shared by all users +CREATE TABLE IF NOT EXISTS cache_items ( + cid VARCHAR(64) PRIMARY KEY, + ipfs_cid VARCHAR(128), + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Item types: per-user metadata (same item can be recipe AND media, per user) +-- actor_id format: @username@server (ActivityPub style) +CREATE TABLE IF NOT EXISTS item_types ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) REFERENCES cache_items(cid) ON DELETE CASCADE, + actor_id VARCHAR(255) NOT NULL, + type VARCHAR(50) NOT NULL, + path VARCHAR(255), + description TEXT, + source_type VARCHAR(20), + source_url TEXT, + source_note TEXT, + pinned BOOLEAN DEFAULT FALSE, + filename VARCHAR(255), + metadata JSONB DEFAULT '{}', + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(cid, actor_id, type, path) +); + + +-- Pin reasons: one-to-many from item_types +CREATE TABLE IF NOT EXISTS pin_reasons ( + id SERIAL PRIMARY KEY, + item_type_id INTEGER REFERENCES item_types(id) ON DELETE CASCADE, + reason VARCHAR(100) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- L2 shares: per-user shares (includes content_type for role when shared) +CREATE TABLE IF NOT EXISTS l2_shares ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) REFERENCES cache_items(cid) ON DELETE CASCADE, + actor_id VARCHAR(255) NOT NULL, + l2_server VARCHAR(255) NOT NULL, + asset_name VARCHAR(255) NOT NULL, + activity_id VARCHAR(128), + content_type VARCHAR(50) NOT NULL, + published_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + last_synced_at TIMESTAMP WITH TIME ZONE, + UNIQUE(cid, actor_id, l2_server, content_type) +); + + +-- Run cache: maps content-addressable run_id to output +-- run_id is a hash of (sorted inputs + recipe), making runs deterministic +CREATE TABLE IF NOT EXISTS run_cache ( + run_id VARCHAR(64) PRIMARY KEY, + output_cid VARCHAR(64) NOT NULL, + ipfs_cid VARCHAR(128), + provenance_cid VARCHAR(128), + plan_cid VARCHAR(128), + recipe VARCHAR(255) NOT NULL, + inputs JSONB NOT NULL, + actor_id VARCHAR(255), + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Pending/running runs: tracks in-progress work for durability +-- Allows runs to survive restarts and be recovered +CREATE TABLE IF NOT EXISTS pending_runs ( + run_id VARCHAR(64) PRIMARY KEY, + celery_task_id VARCHAR(128), + status VARCHAR(20) NOT NULL DEFAULT 'pending', -- pending, running, failed + recipe VARCHAR(255) NOT NULL, + inputs JSONB NOT NULL, + dag_json TEXT, + plan_cid VARCHAR(128), + output_name VARCHAR(255), + actor_id VARCHAR(255), + error TEXT, + ipfs_playlist_cid VARCHAR(128), -- For streaming: IPFS CID of HLS playlist + quality_playlists JSONB, -- For streaming: quality-level playlist CIDs {quality_name: {cid, width, height, bitrate}} + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Add ipfs_playlist_cid if table exists but column doesn't (migration) +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'ipfs_playlist_cid') THEN + ALTER TABLE pending_runs ADD COLUMN ipfs_playlist_cid VARCHAR(128); + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'quality_playlists') THEN + ALTER TABLE pending_runs ADD COLUMN quality_playlists JSONB; + END IF; + -- Checkpoint columns for resumable renders + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_frame') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_frame INTEGER; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_t') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_t FLOAT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_scans') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_scans JSONB; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'total_frames') THEN + ALTER TABLE pending_runs ADD COLUMN total_frames INTEGER; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'resumable') THEN + ALTER TABLE pending_runs ADD COLUMN resumable BOOLEAN DEFAULT TRUE; + END IF; +END $$; + +CREATE INDEX IF NOT EXISTS idx_pending_runs_status ON pending_runs(status); +CREATE INDEX IF NOT EXISTS idx_pending_runs_actor ON pending_runs(actor_id); + +-- User storage backends (synced from L2 or configured locally) +CREATE TABLE IF NOT EXISTS storage_backends ( + id SERIAL PRIMARY KEY, + actor_id VARCHAR(255) NOT NULL, + provider_type VARCHAR(50) NOT NULL, -- 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + provider_name VARCHAR(255), + description TEXT, + config JSONB NOT NULL DEFAULT '{}', + capacity_gb INTEGER NOT NULL, + used_bytes BIGINT DEFAULT 0, + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + synced_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Storage pins tracking (what's pinned where) +CREATE TABLE IF NOT EXISTS storage_pins ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) NOT NULL, + storage_id INTEGER NOT NULL REFERENCES storage_backends(id) ON DELETE CASCADE, + ipfs_cid VARCHAR(128), + pin_type VARCHAR(20) NOT NULL, -- 'user_content', 'donated', 'system' + size_bytes BIGINT, + pinned_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(cid, storage_id) +); + +-- Friendly names: human-readable versioned names for content +-- Version IDs are server-signed timestamps (always increasing, verifiable origin) +-- Names are per-user within L1; when shared to L2: @user@domain name version +CREATE TABLE IF NOT EXISTS friendly_names ( + id SERIAL PRIMARY KEY, + actor_id VARCHAR(255) NOT NULL, + base_name VARCHAR(255) NOT NULL, -- normalized: my-cool-effect + version_id VARCHAR(20) NOT NULL, -- server-signed timestamp: 01hw3x9kab2cd + cid VARCHAR(64) NOT NULL, -- content address + item_type VARCHAR(20) NOT NULL, -- recipe | effect | media + display_name VARCHAR(255), -- original "My Cool Effect" + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + + UNIQUE(actor_id, base_name, version_id), + UNIQUE(actor_id, cid) -- each CID has exactly one friendly name per user +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_item_types_cid ON item_types(cid); +CREATE INDEX IF NOT EXISTS idx_item_types_actor_id ON item_types(actor_id); +CREATE INDEX IF NOT EXISTS idx_item_types_type ON item_types(type); +CREATE INDEX IF NOT EXISTS idx_item_types_path ON item_types(path); +CREATE INDEX IF NOT EXISTS idx_pin_reasons_item_type ON pin_reasons(item_type_id); +CREATE INDEX IF NOT EXISTS idx_l2_shares_cid ON l2_shares(cid); +CREATE INDEX IF NOT EXISTS idx_l2_shares_actor_id ON l2_shares(actor_id); +CREATE INDEX IF NOT EXISTS idx_run_cache_output ON run_cache(output_cid); +CREATE INDEX IF NOT EXISTS idx_storage_backends_actor ON storage_backends(actor_id); +CREATE INDEX IF NOT EXISTS idx_storage_backends_type ON storage_backends(provider_type); +CREATE INDEX IF NOT EXISTS idx_storage_pins_hash ON storage_pins(cid); +CREATE INDEX IF NOT EXISTS idx_storage_pins_storage ON storage_pins(storage_id); +CREATE INDEX IF NOT EXISTS idx_friendly_names_actor ON friendly_names(actor_id); +CREATE INDEX IF NOT EXISTS idx_friendly_names_type ON friendly_names(item_type); +CREATE INDEX IF NOT EXISTS idx_friendly_names_base ON friendly_names(actor_id, base_name); +CREATE INDEX IF NOT EXISTS idx_friendly_names_latest ON friendly_names(actor_id, item_type, base_name, created_at DESC); +""" + + +async def init_db(): + """Initialize database connection pool and create schema. + + Raises: + asyncpg.PostgresError: If database connection fails + RuntimeError: If pool creation fails + """ + global pool + if pool is not None: + return # Already initialized + try: + pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + if pool is None: + raise RuntimeError(f"Failed to create database pool for {DATABASE_URL}") + async with pool.acquire() as conn: + await conn.execute(SCHEMA_SQL) + except asyncpg.PostgresError as e: + pool = None + raise RuntimeError(f"Database connection failed: {e}") from e + except Exception as e: + pool = None + raise RuntimeError(f"Database initialization failed: {e}") from e + + +async def close_db(): + """Close database connection pool.""" + global pool + if pool: + await pool.close() + pool = None + + +# ============ Cache Items ============ + +async def create_cache_item(cid: str, ipfs_cid: Optional[str] = None) -> dict: + """Create a cache item. Returns the created item.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO cache_items (cid, ipfs_cid) + VALUES ($1, $2) + ON CONFLICT (cid) DO UPDATE SET ipfs_cid = COALESCE($2, cache_items.ipfs_cid) + RETURNING cid, ipfs_cid, created_at + """, + cid, ipfs_cid + ) + return dict(row) + + +async def get_cache_item(cid: str) -> Optional[dict]: + """Get a cache item by content hash.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT cid, ipfs_cid, created_at FROM cache_items WHERE cid = $1", + cid + ) + return dict(row) if row else None + + +async def update_cache_item_ipfs_cid(cid: str, ipfs_cid: str) -> bool: + """Update the IPFS CID for a cache item.""" + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE cache_items SET ipfs_cid = $2 WHERE cid = $1", + cid, ipfs_cid + ) + return result == "UPDATE 1" + + +async def get_ipfs_cid(cid: str) -> Optional[str]: + """Get the IPFS CID for a cache item by its internal CID.""" + async with pool.acquire() as conn: + return await conn.fetchval( + "SELECT ipfs_cid FROM cache_items WHERE cid = $1", + cid + ) + + +async def delete_cache_item(cid: str) -> bool: + """Delete a cache item and all associated data (cascades).""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM cache_items WHERE cid = $1", + cid + ) + return result == "DELETE 1" + + +async def list_cache_items(limit: int = 100, offset: int = 0) -> List[dict]: + """List cache items with pagination.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT cid, ipfs_cid, created_at + FROM cache_items + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + """, + limit, offset + ) + return [dict(row) for row in rows] + + +# ============ Item Types ============ + +async def add_item_type( + cid: str, + actor_id: str, + item_type: str, + path: Optional[str] = None, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, +) -> dict: + """Add a type to a cache item for a user. Creates cache_item if needed.""" + async with pool.acquire() as conn: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + cid + ) + # Insert or update item_type + row = await conn.fetchrow( + """ + INSERT INTO item_types (cid, actor_id, type, path, description, source_type, source_url, source_note) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE($5, item_types.description), + source_type = COALESCE($6, item_types.source_type), + source_url = COALESCE($7, item_types.source_url), + source_note = COALESCE($8, item_types.source_note) + RETURNING id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + """, + cid, actor_id, item_type, path, description, source_type, source_url, source_note + ) + return dict(row) + + +async def get_item_types(cid: str, actor_id: Optional[str] = None) -> List[dict]: + """Get types for a cache item, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 + ORDER BY created_at + """, + cid, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 + ORDER BY created_at + """, + cid + ) + return [dict(row) for row in rows] + + +async def get_item_type(cid: str, actor_id: str, item_type: str, path: Optional[str] = None) -> Optional[dict]: + """Get a specific type for a cache item and user.""" + async with pool.acquire() as conn: + if path is None: + row = await conn.fetchrow( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """, + cid, actor_id, item_type + ) + else: + row = await conn.fetchrow( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path = $4 + """, + cid, actor_id, item_type, path + ) + return dict(row) if row else None + + +async def update_item_type( + item_type_id: int, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, +) -> bool: + """Update an item type's metadata.""" + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE item_types SET + description = COALESCE($2, description), + source_type = COALESCE($3, source_type), + source_url = COALESCE($4, source_url), + source_note = COALESCE($5, source_note) + WHERE id = $1 + """, + item_type_id, description, source_type, source_url, source_note + ) + return result == "UPDATE 1" + + +async def delete_item_type(cid: str, actor_id: str, item_type: str, path: Optional[str] = None) -> bool: + """Delete a specific type from a cache item for a user.""" + async with pool.acquire() as conn: + if path is None: + result = await conn.execute( + "DELETE FROM item_types WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL", + cid, actor_id, item_type + ) + else: + result = await conn.execute( + "DELETE FROM item_types WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path = $4", + cid, actor_id, item_type, path + ) + return result == "DELETE 1" + + +async def list_items_by_type(item_type: str, actor_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[dict]: + """List items of a specific type, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.actor_id = $2 + ORDER BY it.created_at DESC + LIMIT $3 OFFSET $4 + """, + item_type, actor_id, limit, offset + ) + else: + rows = await conn.fetch( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 + ORDER BY it.created_at DESC + LIMIT $2 OFFSET $3 + """, + item_type, limit, offset + ) + return [dict(row) for row in rows] + + +async def get_item_by_path(item_type: str, path: str, actor_id: Optional[str] = None) -> Optional[dict]: + """Get an item by its type and path (e.g., recipe:/effects/dog), optionally for a specific user.""" + async with pool.acquire() as conn: + if actor_id: + row = await conn.fetchrow( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.path = $2 AND it.actor_id = $3 + """, + item_type, path, actor_id + ) + else: + row = await conn.fetchrow( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.path = $2 + """, + item_type, path + ) + return dict(row) if row else None + + +# ============ Pinning ============ + +async def pin_item_type(item_type_id: int, reason: str) -> bool: + """Pin an item type with a reason.""" + async with pool.acquire() as conn: + async with conn.transaction(): + # Set pinned flag + await conn.execute( + "UPDATE item_types SET pinned = TRUE WHERE id = $1", + item_type_id + ) + # Add pin reason + await conn.execute( + "INSERT INTO pin_reasons (item_type_id, reason) VALUES ($1, $2)", + item_type_id, reason + ) + return True + + +async def unpin_item_type(item_type_id: int, reason: Optional[str] = None) -> bool: + """Remove a pin reason from an item type. If no reasons left, unpins the item.""" + async with pool.acquire() as conn: + async with conn.transaction(): + if reason: + # Remove specific reason + await conn.execute( + "DELETE FROM pin_reasons WHERE item_type_id = $1 AND reason = $2", + item_type_id, reason + ) + else: + # Remove all reasons + await conn.execute( + "DELETE FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + + # Check if any reasons remain + count = await conn.fetchval( + "SELECT COUNT(*) FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + + if count == 0: + await conn.execute( + "UPDATE item_types SET pinned = FALSE WHERE id = $1", + item_type_id + ) + return True + + +async def get_pin_reasons(item_type_id: int) -> List[dict]: + """Get all pin reasons for an item type.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT id, reason, created_at FROM pin_reasons WHERE item_type_id = $1 ORDER BY created_at", + item_type_id + ) + return [dict(row) for row in rows] + + +async def is_item_pinned(cid: str, item_type: Optional[str] = None) -> tuple[bool, List[str]]: + """Check if any type of a cache item is pinned. Returns (is_pinned, reasons).""" + async with pool.acquire() as conn: + if item_type: + rows = await conn.fetch( + """ + SELECT pr.reason + FROM pin_reasons pr + JOIN item_types it ON pr.item_type_id = it.id + WHERE it.cid = $1 AND it.type = $2 AND it.pinned = TRUE + """, + cid, item_type + ) + else: + rows = await conn.fetch( + """ + SELECT pr.reason + FROM pin_reasons pr + JOIN item_types it ON pr.item_type_id = it.id + WHERE it.cid = $1 AND it.pinned = TRUE + """, + cid + ) + reasons = [row["reason"] for row in rows] + return len(reasons) > 0, reasons + + +# ============ L2 Shares ============ + +async def add_l2_share( + cid: str, + actor_id: str, + l2_server: str, + asset_name: str, + content_type: str, +) -> dict: + """Add or update an L2 share for a user.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO l2_shares (cid, actor_id, l2_server, asset_name, content_type, last_synced_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + ON CONFLICT (cid, actor_id, l2_server, content_type) DO UPDATE SET + asset_name = $4, + last_synced_at = NOW() + RETURNING id, cid, actor_id, l2_server, asset_name, content_type, published_at, last_synced_at + """, + cid, actor_id, l2_server, asset_name, content_type + ) + return dict(row) + + +async def get_l2_shares(cid: str, actor_id: Optional[str] = None) -> List[dict]: + """Get L2 shares for a cache item, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares + WHERE cid = $1 AND actor_id = $2 + ORDER BY published_at + """, + cid, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares + WHERE cid = $1 + ORDER BY published_at + """, + cid + ) + return [dict(row) for row in rows] + + +async def delete_l2_share(cid: str, actor_id: str, l2_server: str, content_type: str) -> bool: + """Delete an L2 share for a user.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM l2_shares WHERE cid = $1 AND actor_id = $2 AND l2_server = $3 AND content_type = $4", + cid, actor_id, l2_server, content_type + ) + return result == "DELETE 1" + + +# ============ Cache Item Cleanup ============ + +async def has_remaining_references(cid: str) -> bool: + """Check if a cache item has any remaining item_types or l2_shares.""" + async with pool.acquire() as conn: + item_types_count = await conn.fetchval( + "SELECT COUNT(*) FROM item_types WHERE cid = $1", + cid + ) + if item_types_count > 0: + return True + + l2_shares_count = await conn.fetchval( + "SELECT COUNT(*) FROM l2_shares WHERE cid = $1", + cid + ) + return l2_shares_count > 0 + + +async def cleanup_orphaned_cache_item(cid: str) -> bool: + """Delete a cache item if it has no remaining references. Returns True if deleted.""" + async with pool.acquire() as conn: + # Only delete if no item_types or l2_shares reference it + result = await conn.execute( + """ + DELETE FROM cache_items + WHERE cid = $1 + AND NOT EXISTS (SELECT 1 FROM item_types WHERE cid = $1) + AND NOT EXISTS (SELECT 1 FROM l2_shares WHERE cid = $1) + """, + cid + ) + return result == "DELETE 1" + + +# ============ High-Level Metadata Functions ============ +# These provide a compatible interface to the old JSON-based save_cache_meta/load_cache_meta + +import json as _json + + +async def save_item_metadata( + cid: str, + actor_id: str, + item_type: str = "media", + filename: Optional[str] = None, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, + pinned: bool = False, + pin_reason: Optional[str] = None, + tags: Optional[List[str]] = None, + folder: Optional[str] = None, + collections: Optional[List[str]] = None, + **extra_metadata +) -> dict: + """ + Save or update item metadata in the database. + + Returns a dict with the item metadata (compatible with old JSON format). + """ + import logging + logger = logging.getLogger(__name__) + logger.info(f"save_item_metadata: cid={cid[:16] if cid else None}..., actor_id={actor_id}, item_type={item_type}") + # Build metadata JSONB for extra fields + metadata = {} + if tags: + metadata["tags"] = tags + if folder: + metadata["folder"] = folder + if collections: + metadata["collections"] = collections + metadata.update(extra_metadata) + + async with pool.acquire() as conn: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + cid + ) + + # Upsert item_type + row = await conn.fetchrow( + """ + INSERT INTO item_types (cid, actor_id, type, description, source_type, source_url, source_note, pinned, filename, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE(EXCLUDED.description, item_types.description), + source_type = COALESCE(EXCLUDED.source_type, item_types.source_type), + source_url = COALESCE(EXCLUDED.source_url, item_types.source_url), + source_note = COALESCE(EXCLUDED.source_note, item_types.source_note), + pinned = EXCLUDED.pinned, + filename = COALESCE(EXCLUDED.filename, item_types.filename), + metadata = item_types.metadata || EXCLUDED.metadata + RETURNING id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + """, + cid, actor_id, item_type, description, source_type, source_url, source_note, pinned, filename, _json.dumps(metadata) + ) + + item_type_id = row["id"] + logger.info(f"save_item_metadata: Created/updated item_type id={item_type_id} for cid={cid[:16]}...") + + # Handle pinning + if pinned and pin_reason: + # Add pin reason if not exists + await conn.execute( + """ + INSERT INTO pin_reasons (item_type_id, reason) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, + item_type_id, pin_reason + ) + + # Build response dict (compatible with old format) + result = { + "uploader": actor_id, + "uploaded_at": row["created_at"].isoformat() if row["created_at"] else None, + "filename": row["filename"], + "type": row["type"], + "description": row["description"], + "pinned": row["pinned"], + } + + # Add origin if present + if row["source_type"] or row["source_url"] or row["source_note"]: + result["origin"] = { + "type": row["source_type"], + "url": row["source_url"], + "note": row["source_note"] + } + + # Add metadata fields + if row["metadata"]: + meta = row["metadata"] if isinstance(row["metadata"], dict) else _json.loads(row["metadata"]) + if meta.get("tags"): + result["tags"] = meta["tags"] + if meta.get("folder"): + result["folder"] = meta["folder"] + if meta.get("collections"): + result["collections"] = meta["collections"] + + # Get pin reasons + if row["pinned"]: + reasons = await conn.fetch( + "SELECT reason FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + if reasons: + result["pin_reason"] = reasons[0]["reason"] + + return result + + +async def load_item_metadata(cid: str, actor_id: Optional[str] = None) -> dict: + """ + Load item metadata from the database. + + If actor_id is provided, returns metadata for that user's view of the item. + Otherwise, returns combined metadata from all users (for backwards compat). + + Returns a dict compatible with old JSON format. + """ + async with pool.acquire() as conn: + # Get cache item + cache_item = await conn.fetchrow( + "SELECT cid, ipfs_cid, created_at FROM cache_items WHERE cid = $1", + cid + ) + + if not cache_item: + return {} + + # Get item types + if actor_id: + item_types = await conn.fetch( + """ + SELECT id, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + FROM item_types WHERE cid = $1 AND actor_id = $2 + ORDER BY created_at + """, + cid, actor_id + ) + else: + item_types = await conn.fetch( + """ + SELECT id, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + FROM item_types WHERE cid = $1 + ORDER BY created_at + """, + cid + ) + + if not item_types: + return {"uploaded_at": cache_item["created_at"].isoformat() if cache_item["created_at"] else None} + + # Use first item type as primary (for backwards compat) + primary = item_types[0] + + result = { + "uploader": primary["actor_id"], + "uploaded_at": primary["created_at"].isoformat() if primary["created_at"] else None, + "filename": primary["filename"], + "type": primary["type"], + "description": primary["description"], + "pinned": any(it["pinned"] for it in item_types), + } + + # Add origin if present + if primary["source_type"] or primary["source_url"] or primary["source_note"]: + result["origin"] = { + "type": primary["source_type"], + "url": primary["source_url"], + "note": primary["source_note"] + } + + # Add metadata fields + if primary["metadata"]: + meta = primary["metadata"] if isinstance(primary["metadata"], dict) else _json.loads(primary["metadata"]) + if meta.get("tags"): + result["tags"] = meta["tags"] + if meta.get("folder"): + result["folder"] = meta["folder"] + if meta.get("collections"): + result["collections"] = meta["collections"] + + # Get pin reasons for pinned items + for it in item_types: + if it["pinned"]: + reasons = await conn.fetch( + "SELECT reason FROM pin_reasons WHERE item_type_id = $1", + it["id"] + ) + if reasons: + result["pin_reason"] = reasons[0]["reason"] + break + + # Get L2 shares + if actor_id: + shares = await conn.fetch( + """ + SELECT l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares WHERE cid = $1 AND actor_id = $2 + """, + cid, actor_id + ) + else: + shares = await conn.fetch( + """ + SELECT l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares WHERE cid = $1 + """, + cid + ) + + if shares: + result["l2_shares"] = [ + { + "l2_server": s["l2_server"], + "asset_name": s["asset_name"], + "activity_id": s["activity_id"], + "content_type": s["content_type"], + "published_at": s["published_at"].isoformat() if s["published_at"] else None, + "last_synced_at": s["last_synced_at"].isoformat() if s["last_synced_at"] else None, + } + for s in shares + ] + + # For backwards compat, also set "published" if shared + result["published"] = { + "to_l2": True, + "asset_name": shares[0]["asset_name"], + "activity_id": shares[0]["activity_id"], + "l2_server": shares[0]["l2_server"], + } + + return result + + +async def update_item_metadata( + cid: str, + actor_id: str, + item_type: str = "media", + **updates +) -> dict: + """ + Update specific fields of item metadata. + + Returns updated metadata dict. + """ + # Extract known fields from updates + description = updates.pop("description", None) + source_type = updates.pop("source_type", None) + source_url = updates.pop("source_url", None) + source_note = updates.pop("source_note", None) + + # Handle origin dict format + origin = updates.pop("origin", None) + if origin: + source_type = origin.get("type", source_type) + source_url = origin.get("url", source_url) + source_note = origin.get("note", source_note) + + pinned = updates.pop("pinned", None) + pin_reason = updates.pop("pin_reason", None) + filename = updates.pop("filename", None) + tags = updates.pop("tags", None) + folder = updates.pop("folder", None) + collections = updates.pop("collections", None) + + async with pool.acquire() as conn: + # Get existing item_type + existing = await conn.fetchrow( + """ + SELECT id, metadata FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """, + cid, actor_id, item_type + ) + + if not existing: + # Create new entry + return await save_item_metadata( + cid, actor_id, item_type, + filename=filename, description=description, + source_type=source_type, source_url=source_url, source_note=source_note, + pinned=pinned or False, pin_reason=pin_reason, + tags=tags, folder=folder, collections=collections, + **updates + ) + + # Build update query dynamically + set_parts = [] + params = [cid, actor_id, item_type] + param_idx = 4 + + if description is not None: + set_parts.append(f"description = ${param_idx}") + params.append(description) + param_idx += 1 + + if source_type is not None: + set_parts.append(f"source_type = ${param_idx}") + params.append(source_type) + param_idx += 1 + + if source_url is not None: + set_parts.append(f"source_url = ${param_idx}") + params.append(source_url) + param_idx += 1 + + if source_note is not None: + set_parts.append(f"source_note = ${param_idx}") + params.append(source_note) + param_idx += 1 + + if pinned is not None: + set_parts.append(f"pinned = ${param_idx}") + params.append(pinned) + param_idx += 1 + + if filename is not None: + set_parts.append(f"filename = ${param_idx}") + params.append(filename) + param_idx += 1 + + # Handle metadata updates + current_metadata = existing["metadata"] if isinstance(existing["metadata"], dict) else (_json.loads(existing["metadata"]) if existing["metadata"] else {}) + if tags is not None: + current_metadata["tags"] = tags + if folder is not None: + current_metadata["folder"] = folder + if collections is not None: + current_metadata["collections"] = collections + current_metadata.update(updates) + + if current_metadata: + set_parts.append(f"metadata = ${param_idx}") + params.append(_json.dumps(current_metadata)) + param_idx += 1 + + if set_parts: + query = f""" + UPDATE item_types SET {', '.join(set_parts)} + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """ + await conn.execute(query, *params) + + # Handle pin reason + if pinned and pin_reason: + await conn.execute( + """ + INSERT INTO pin_reasons (item_type_id, reason) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, + existing["id"], pin_reason + ) + + return await load_item_metadata(cid, actor_id) + + +async def save_l2_share( + cid: str, + actor_id: str, + l2_server: str, + asset_name: str, + content_type: str = "media", + activity_id: Optional[str] = None +) -> dict: + """Save an L2 share and return share info.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO l2_shares (cid, actor_id, l2_server, asset_name, activity_id, content_type, last_synced_at) + VALUES ($1, $2, $3, $4, $5, $6, NOW()) + ON CONFLICT (cid, actor_id, l2_server, content_type) DO UPDATE SET + asset_name = EXCLUDED.asset_name, + activity_id = COALESCE(EXCLUDED.activity_id, l2_shares.activity_id), + last_synced_at = NOW() + RETURNING l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + """, + cid, actor_id, l2_server, asset_name, activity_id, content_type + ) + return { + "l2_server": row["l2_server"], + "asset_name": row["asset_name"], + "activity_id": row["activity_id"], + "content_type": row["content_type"], + "published_at": row["published_at"].isoformat() if row["published_at"] else None, + "last_synced_at": row["last_synced_at"].isoformat() if row["last_synced_at"] else None, + } + + +async def get_user_items(actor_id: str, item_type: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[dict]: + """Get all items for a user, optionally filtered by type. Deduplicates by cid.""" + async with pool.acquire() as conn: + if item_type: + rows = await conn.fetch( + """ + SELECT * FROM ( + SELECT DISTINCT ON (it.cid) + it.cid, it.type, it.description, it.filename, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.actor_id = $1 AND it.type = $2 + ORDER BY it.cid, it.created_at DESC + ) deduped + ORDER BY created_at DESC + LIMIT $3 OFFSET $4 + """, + actor_id, item_type, limit, offset + ) + else: + rows = await conn.fetch( + """ + SELECT * FROM ( + SELECT DISTINCT ON (it.cid) + it.cid, it.type, it.description, it.filename, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.actor_id = $1 + ORDER BY it.cid, it.created_at DESC + ) deduped + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + """, + actor_id, limit, offset + ) + + return [ + { + "cid": r["cid"], + "type": r["type"], + "description": r["description"], + "filename": r["filename"], + "pinned": r["pinned"], + "created_at": r["created_at"].isoformat() if r["created_at"] else None, + "ipfs_cid": r["ipfs_cid"], + } + for r in rows + ] + + +async def count_user_items(actor_id: str, item_type: Optional[str] = None) -> int: + """Count unique items (by cid) for a user.""" + async with pool.acquire() as conn: + if item_type: + return await conn.fetchval( + "SELECT COUNT(DISTINCT cid) FROM item_types WHERE actor_id = $1 AND type = $2", + actor_id, item_type + ) + else: + return await conn.fetchval( + "SELECT COUNT(DISTINCT cid) FROM item_types WHERE actor_id = $1", + actor_id + ) + + +# ============ Run Cache ============ + +async def get_run_cache(run_id: str) -> Optional[dict]: + """Get cached run result by content-addressable run_id.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + FROM run_cache WHERE run_id = $1 + """, + run_id + ) + if row: + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +async def save_run_cache( + run_id: str, + output_cid: str, + recipe: str, + inputs: List[str], + ipfs_cid: Optional[str] = None, + provenance_cid: Optional[str] = None, + plan_cid: Optional[str] = None, + actor_id: Optional[str] = None, +) -> dict: + """Save run result to cache. Updates if run_id already exists.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO run_cache (run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (run_id) DO UPDATE SET + output_cid = EXCLUDED.output_cid, + ipfs_cid = COALESCE(EXCLUDED.ipfs_cid, run_cache.ipfs_cid), + provenance_cid = COALESCE(EXCLUDED.provenance_cid, run_cache.provenance_cid), + plan_cid = COALESCE(EXCLUDED.plan_cid, run_cache.plan_cid), + actor_id = COALESCE(EXCLUDED.actor_id, run_cache.actor_id), + recipe = COALESCE(EXCLUDED.recipe, run_cache.recipe), + inputs = COALESCE(EXCLUDED.inputs, run_cache.inputs) + RETURNING run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + """, + run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, _json.dumps(inputs), actor_id + ) + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + + +async def get_run_by_output(output_cid: str) -> Optional[dict]: + """Get run cache entry by output hash.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, recipe, inputs, actor_id, created_at + FROM run_cache WHERE output_cid = $1 + """, + output_cid + ) + if row: + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +def _parse_inputs(inputs_value): + """Parse inputs from database - may be JSON string, list, or None.""" + if inputs_value is None: + return [] + if isinstance(inputs_value, list): + return inputs_value + if isinstance(inputs_value, str): + try: + parsed = _json.loads(inputs_value) + if isinstance(parsed, list): + return parsed + return [] + except (_json.JSONDecodeError, TypeError): + return [] + return [] + + +async def list_runs_by_actor(actor_id: str, offset: int = 0, limit: int = 20) -> List[dict]: + """List completed runs for a user, ordered by creation time (newest first).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + FROM run_cache + WHERE actor_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + """, + actor_id, limit, offset + ) + return [ + { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": _parse_inputs(row["inputs"]), + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "status": "completed", + } + for row in rows + ] + + +async def delete_run_cache(run_id: str) -> bool: + """Delete a run from the cache.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM run_cache WHERE run_id = $1", + run_id + ) + return result == "DELETE 1" + + +async def delete_pending_run(run_id: str) -> bool: + """Delete a pending run.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM pending_runs WHERE run_id = $1", + run_id + ) + return result == "DELETE 1" + + +# ============ Storage Backends ============ + +async def get_user_storage(actor_id: str) -> List[dict]: + """Get all storage backends for a user.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE actor_id = $1 + ORDER BY provider_type, created_at""", + actor_id + ) + return [dict(row) for row in rows] + + +async def get_user_storage_by_type(actor_id: str, provider_type: str) -> List[dict]: + """Get storage backends of a specific type for a user.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE actor_id = $1 AND provider_type = $2 + ORDER BY created_at""", + actor_id, provider_type + ) + return [dict(row) for row in rows] + + +async def get_storage_by_id(storage_id: int) -> Optional[dict]: + """Get a storage backend by ID.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE id = $1""", + storage_id + ) + return dict(row) if row else None + + +async def add_user_storage( + actor_id: str, + provider_type: str, + provider_name: str, + config: dict, + capacity_gb: int, + description: Optional[str] = None +) -> Optional[int]: + """Add a storage backend for a user. Returns storage ID.""" + async with pool.acquire() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO storage_backends (actor_id, provider_type, provider_name, description, config, capacity_gb) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id""", + actor_id, provider_type, provider_name, description, _json.dumps(config), capacity_gb + ) + return row["id"] if row else None + except Exception: + return None + + +async def update_user_storage( + storage_id: int, + provider_name: Optional[str] = None, + description: Optional[str] = None, + config: Optional[dict] = None, + capacity_gb: Optional[int] = None, + is_active: Optional[bool] = None +) -> bool: + """Update a storage backend.""" + updates = [] + params = [] + param_num = 1 + + if provider_name is not None: + updates.append(f"provider_name = ${param_num}") + params.append(provider_name) + param_num += 1 + if description is not None: + updates.append(f"description = ${param_num}") + params.append(description) + param_num += 1 + if config is not None: + updates.append(f"config = ${param_num}") + params.append(_json.dumps(config)) + param_num += 1 + if capacity_gb is not None: + updates.append(f"capacity_gb = ${param_num}") + params.append(capacity_gb) + param_num += 1 + if is_active is not None: + updates.append(f"is_active = ${param_num}") + params.append(is_active) + param_num += 1 + + if not updates: + return False + + updates.append("synced_at = NOW()") + params.append(storage_id) + + async with pool.acquire() as conn: + result = await conn.execute( + f"UPDATE storage_backends SET {', '.join(updates)} WHERE id = ${param_num}", + *params + ) + return "UPDATE 1" in result + + +async def remove_user_storage(storage_id: int) -> bool: + """Remove a storage backend. Cascades to storage_pins.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM storage_backends WHERE id = $1", + storage_id + ) + return "DELETE 1" in result + + +async def get_storage_usage(storage_id: int) -> dict: + """Get storage usage stats.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """SELECT + COUNT(*) as pin_count, + COALESCE(SUM(size_bytes), 0) as used_bytes + FROM storage_pins WHERE storage_id = $1""", + storage_id + ) + return {"pin_count": row["pin_count"], "used_bytes": row["used_bytes"]} + + +async def get_all_active_storage() -> List[dict]: + """Get all active storage backends (for distributed pinning).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT sb.id, sb.actor_id, sb.provider_type, sb.provider_name, sb.description, + sb.config, sb.capacity_gb, sb.is_active, sb.created_at, sb.synced_at, + COALESCE(SUM(sp.size_bytes), 0) as used_bytes, + COUNT(sp.id) as pin_count + FROM storage_backends sb + LEFT JOIN storage_pins sp ON sb.id = sp.storage_id + WHERE sb.is_active = true + GROUP BY sb.id + ORDER BY sb.provider_type, sb.created_at""" + ) + return [dict(row) for row in rows] + + +async def add_storage_pin( + cid: str, + storage_id: int, + ipfs_cid: Optional[str], + pin_type: str, + size_bytes: int +) -> Optional[int]: + """Add a pin record. Returns pin ID.""" + async with pool.acquire() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO storage_pins (cid, storage_id, ipfs_cid, pin_type, size_bytes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (cid, storage_id) DO UPDATE SET + ipfs_cid = EXCLUDED.ipfs_cid, + pin_type = EXCLUDED.pin_type, + size_bytes = EXCLUDED.size_bytes, + pinned_at = NOW() + RETURNING id""", + cid, storage_id, ipfs_cid, pin_type, size_bytes + ) + return row["id"] if row else None + except Exception: + return None + + +async def remove_storage_pin(cid: str, storage_id: int) -> bool: + """Remove a pin record.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM storage_pins WHERE cid = $1 AND storage_id = $2", + cid, storage_id + ) + return "DELETE 1" in result + + +async def get_pins_for_content(cid: str) -> List[dict]: + """Get all storage locations where content is pinned.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT sp.*, sb.provider_type, sb.provider_name, sb.actor_id + FROM storage_pins sp + JOIN storage_backends sb ON sp.storage_id = sb.id + WHERE sp.cid = $1""", + cid + ) + return [dict(row) for row in rows] + + +# ============ Pending Runs ============ + +async def create_pending_run( + run_id: str, + celery_task_id: str, + recipe: str, + inputs: List[str], + actor_id: str, + dag_json: Optional[str] = None, + output_name: Optional[str] = None, +) -> dict: + """Create a pending run record for durability.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO pending_runs (run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id) + VALUES ($1, $2, 'running', $3, $4, $5, $6, $7) + ON CONFLICT (run_id) DO UPDATE SET + celery_task_id = EXCLUDED.celery_task_id, + status = 'running', + updated_at = NOW() + RETURNING run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id, created_at, updated_at + """, + run_id, celery_task_id, recipe, _json.dumps(inputs), dag_json, output_name, actor_id + ) + return { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "dag_json": row["dag_json"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + + +async def get_pending_run(run_id: str) -> Optional[dict]: + """Get a pending run by ID.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, celery_task_id, status, recipe, inputs, dag_json, plan_cid, output_name, actor_id, error, + ipfs_playlist_cid, quality_playlists, checkpoint_frame, checkpoint_t, checkpoint_scans, + total_frames, resumable, created_at, updated_at + FROM pending_runs WHERE run_id = $1 + """, + run_id + ) + if row: + # Parse inputs if it's a string (JSONB should auto-parse but be safe) + inputs = row["inputs"] + if isinstance(inputs, str): + inputs = _json.loads(inputs) + # Parse quality_playlists if it's a string + quality_playlists = row.get("quality_playlists") + if isinstance(quality_playlists, str): + quality_playlists = _json.loads(quality_playlists) + # Parse checkpoint_scans if it's a string + checkpoint_scans = row.get("checkpoint_scans") + if isinstance(checkpoint_scans, str): + checkpoint_scans = _json.loads(checkpoint_scans) + return { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": inputs, + "dag_json": row["dag_json"], + "plan_cid": row["plan_cid"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "error": row["error"], + "ipfs_playlist_cid": row["ipfs_playlist_cid"], + "quality_playlists": quality_playlists, + "checkpoint_frame": row.get("checkpoint_frame"), + "checkpoint_t": row.get("checkpoint_t"), + "checkpoint_scans": checkpoint_scans, + "total_frames": row.get("total_frames"), + "resumable": row.get("resumable", True), + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + return None + + +async def list_pending_runs(actor_id: Optional[str] = None, status: Optional[str] = None) -> List[dict]: + """List pending runs, optionally filtered by actor and/or status.""" + async with pool.acquire() as conn: + conditions = [] + params = [] + param_idx = 1 + + if actor_id: + conditions.append(f"actor_id = ${param_idx}") + params.append(actor_id) + param_idx += 1 + + if status: + conditions.append(f"status = ${param_idx}") + params.append(status) + param_idx += 1 + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + rows = await conn.fetch( + f""" + SELECT run_id, celery_task_id, status, recipe, inputs, output_name, actor_id, error, created_at, updated_at + FROM pending_runs + WHERE {where_clause} + ORDER BY created_at DESC + """, + *params + ) + results = [] + for row in rows: + # Parse inputs if it's a string + inputs = row["inputs"] + if isinstance(inputs, str): + inputs = _json.loads(inputs) + results.append({ + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": inputs, + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "error": row["error"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + }) + return results + + +async def update_pending_run_status(run_id: str, status: str, error: Optional[str] = None) -> bool: + """Update the status of a pending run.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + if error: + result = await conn.execute( + "UPDATE pending_runs SET status = $2, error = $3, updated_at = NOW() WHERE run_id = $1", + run_id, status, error + ) + else: + result = await conn.execute( + "UPDATE pending_runs SET status = $2, updated_at = NOW() WHERE run_id = $1", + run_id, status + ) + return "UPDATE 1" in result + + +async def update_pending_run_plan(run_id: str, plan_cid: str) -> bool: + """Update the plan_cid of a pending run (called when plan is generated).""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE pending_runs SET plan_cid = $2, updated_at = NOW() WHERE run_id = $1", + run_id, plan_cid + ) + return "UPDATE 1" in result + + +async def update_pending_run_playlist(run_id: str, ipfs_playlist_cid: str, quality_playlists: Optional[dict] = None) -> bool: + """Update the IPFS playlist CID of a streaming run. + + Args: + run_id: The run ID + ipfs_playlist_cid: Master playlist CID + quality_playlists: Dict of quality name -> {cid, width, height, bitrate} + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + if quality_playlists: + result = await conn.execute( + "UPDATE pending_runs SET ipfs_playlist_cid = $2, quality_playlists = $3, updated_at = NOW() WHERE run_id = $1", + run_id, ipfs_playlist_cid, _json.dumps(quality_playlists) + ) + else: + result = await conn.execute( + "UPDATE pending_runs SET ipfs_playlist_cid = $2, updated_at = NOW() WHERE run_id = $1", + run_id, ipfs_playlist_cid + ) + return "UPDATE 1" in result + + +async def update_pending_run_checkpoint( + run_id: str, + checkpoint_frame: int, + checkpoint_t: float, + checkpoint_scans: Optional[dict] = None, + total_frames: Optional[int] = None, +) -> bool: + """Update checkpoint state for a streaming run. + + Called at segment boundaries to enable resume after failures. + + Args: + run_id: The run ID + checkpoint_frame: Last completed frame at segment boundary + checkpoint_t: Time value for checkpoint frame + checkpoint_scans: Accumulated scan state {scan_name: state_dict} + total_frames: Total expected frames (for progress %) + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE pending_runs SET + checkpoint_frame = $2, + checkpoint_t = $3, + checkpoint_scans = $4, + total_frames = COALESCE($5, total_frames), + updated_at = NOW() + WHERE run_id = $1 + """, + run_id, + checkpoint_frame, + checkpoint_t, + _json.dumps(checkpoint_scans) if checkpoint_scans else None, + total_frames, + ) + return "UPDATE 1" in result + + +async def get_run_checkpoint(run_id: str) -> Optional[dict]: + """Get checkpoint data for resuming a run. + + Returns: + Dict with checkpoint_frame, checkpoint_t, checkpoint_scans, quality_playlists, etc. + or None if no checkpoint exists + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT checkpoint_frame, checkpoint_t, checkpoint_scans, total_frames, + quality_playlists, ipfs_playlist_cid, resumable + FROM pending_runs WHERE run_id = $1 + """, + run_id + ) + if row and row.get("checkpoint_frame") is not None: + # Parse JSONB fields + checkpoint_scans = row.get("checkpoint_scans") + if isinstance(checkpoint_scans, str): + checkpoint_scans = _json.loads(checkpoint_scans) + quality_playlists = row.get("quality_playlists") + if isinstance(quality_playlists, str): + quality_playlists = _json.loads(quality_playlists) + return { + "frame_num": row["checkpoint_frame"], + "t": row["checkpoint_t"], + "scans": checkpoint_scans or {}, + "total_frames": row.get("total_frames"), + "quality_playlists": quality_playlists, + "ipfs_playlist_cid": row.get("ipfs_playlist_cid"), + "resumable": row.get("resumable", True), + } + return None + + +async def clear_run_checkpoint(run_id: str) -> bool: + """Clear checkpoint data for a run (used on restart). + + Args: + run_id: The run ID + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE pending_runs SET + checkpoint_frame = NULL, + checkpoint_t = NULL, + checkpoint_scans = NULL, + quality_playlists = NULL, + ipfs_playlist_cid = NULL, + updated_at = NOW() + WHERE run_id = $1 + """, + run_id, + ) + return "UPDATE 1" in result + + +async def complete_pending_run(run_id: str) -> bool: + """Remove a pending run after it completes (moves to run_cache).""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM pending_runs WHERE run_id = $1", + run_id + ) + return "DELETE 1" in result + + +async def get_stale_pending_runs(older_than_hours: int = 24) -> List[dict]: + """Get pending runs that haven't been updated recently (for recovery).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id, created_at, updated_at + FROM pending_runs + WHERE status IN ('pending', 'running') + AND updated_at < NOW() - INTERVAL '%s hours' + ORDER BY created_at + """, + older_than_hours + ) + return [ + { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "dag_json": row["dag_json"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + for row in rows + ] + + +# ============ Friendly Names ============ + +async def create_friendly_name( + actor_id: str, + base_name: str, + version_id: str, + cid: str, + item_type: str, + display_name: Optional[str] = None, +) -> dict: + """ + Create a friendly name entry. + + Returns the created entry. + """ + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO friendly_names (actor_id, base_name, version_id, cid, item_type, display_name) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (actor_id, cid) DO UPDATE SET + base_name = EXCLUDED.base_name, + version_id = EXCLUDED.version_id, + display_name = EXCLUDED.display_name + RETURNING id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + """, + actor_id, base_name, version_id, cid, item_type, display_name + ) + return { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + + +async def get_friendly_name_by_cid(actor_id: str, cid: str) -> Optional[dict]: + """Get friendly name entry by CID.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 AND cid = $2 + """, + actor_id, cid + ) + if row: + return { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +async def resolve_friendly_name( + actor_id: str, + name: str, + item_type: Optional[str] = None, +) -> Optional[str]: + """ + Resolve a friendly name to a CID. + + Name can be: + - "base-name" -> resolves to latest version + - "base-name version-id" -> resolves to exact version + + Returns CID or None if not found. + """ + parts = name.strip().split(' ') + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + + async with pool.acquire() as conn: + if version_id: + # Exact version lookup + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = $1 AND base_name = $2 AND version_id = $3 + """ + params = [actor_id, base_name, version_id] + if item_type: + query += " AND item_type = $4" + params.append(item_type) + + return await conn.fetchval(query, *params) + else: + # Latest version lookup + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = $1 AND base_name = $2 + """ + params = [actor_id, base_name] + if item_type: + query += " AND item_type = $3" + params.append(item_type) + query += " ORDER BY created_at DESC LIMIT 1" + + return await conn.fetchval(query, *params) + + +async def list_friendly_names( + actor_id: str, + item_type: Optional[str] = None, + base_name: Optional[str] = None, + latest_only: bool = False, +) -> List[dict]: + """ + List friendly names for a user. + + Args: + actor_id: User ID + item_type: Filter by type (recipe, effect, media) + base_name: Filter by base name + latest_only: If True, only return latest version of each base name + """ + async with pool.acquire() as conn: + if latest_only: + # Use DISTINCT ON to get latest version of each base name + query = """ + SELECT DISTINCT ON (base_name) + id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 + """ + params = [actor_id] + if item_type: + query += " AND item_type = $2" + params.append(item_type) + if base_name: + query += f" AND base_name = ${len(params) + 1}" + params.append(base_name) + query += " ORDER BY base_name, created_at DESC" + else: + query = """ + SELECT id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 + """ + params = [actor_id] + if item_type: + query += " AND item_type = $2" + params.append(item_type) + if base_name: + query += f" AND base_name = ${len(params) + 1}" + params.append(base_name) + query += " ORDER BY base_name, created_at DESC" + + rows = await conn.fetch(query, *params) + return [ + { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + for row in rows + ] + + +async def delete_friendly_name(actor_id: str, cid: str) -> bool: + """Delete a friendly name entry by CID.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM friendly_names WHERE actor_id = $1 AND cid = $2", + actor_id, cid + ) + return "DELETE 1" in result + + +async def update_friendly_name_cid(actor_id: str, old_cid: str, new_cid: str) -> bool: + """ + Update a friendly name's CID (used when IPFS upload completes). + + This updates the CID from a local SHA256 hash to an IPFS CID, + ensuring assets can be fetched by remote workers via IPFS. + """ + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE friendly_names SET cid = $3 WHERE actor_id = $1 AND cid = $2", + actor_id, old_cid, new_cid + ) + return "UPDATE 1" in result + + +# ============================================================================= +# SYNCHRONOUS DATABASE FUNCTIONS (for use from non-async contexts like video streaming) +# ============================================================================= + +def resolve_friendly_name_sync( + actor_id: str, + name: str, + item_type: Optional[str] = None, +) -> Optional[str]: + """ + Synchronous version of resolve_friendly_name using psycopg2. + + Useful when calling from synchronous code (e.g., video streaming callbacks) + where async/await is not possible. + + Returns CID or None if not found. + """ + import psycopg2 + + parts = name.strip().split(' ') + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + + try: + conn = psycopg2.connect(DATABASE_URL) + cursor = conn.cursor() + + if version_id: + # Exact version lookup + if item_type: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND version_id = %s AND item_type = %s + """ + cursor.execute(query, (actor_id, base_name, version_id, item_type)) + else: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND version_id = %s + """ + cursor.execute(query, (actor_id, base_name, version_id)) + else: + # Latest version lookup + if item_type: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND item_type = %s + ORDER BY created_at DESC LIMIT 1 + """ + cursor.execute(query, (actor_id, base_name, item_type)) + else: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s + ORDER BY created_at DESC LIMIT 1 + """ + cursor.execute(query, (actor_id, base_name)) + + result = cursor.fetchone() + cursor.close() + conn.close() + + return result[0] if result else None + + except Exception as e: + import sys + print(f"resolve_friendly_name_sync ERROR: {e}", file=sys.stderr) + return None + + +def get_ipfs_cid_sync(cid: str) -> Optional[str]: + """ + Synchronous version of get_ipfs_cid using psycopg2. + + Returns the IPFS CID for a given internal CID, or None if not found. + """ + import psycopg2 + + try: + conn = psycopg2.connect(DATABASE_URL) + cursor = conn.cursor() + + cursor.execute( + "SELECT ipfs_cid FROM cache_items WHERE cid = %s", + (cid,) + ) + + result = cursor.fetchone() + cursor.close() + conn.close() + + return result[0] if result else None + + except Exception as e: + import sys + print(f"get_ipfs_cid_sync ERROR: {e}", file=sys.stderr) + return None diff --git a/l1/deploy.sh b/l1/deploy.sh new file mode 100755 index 0000000..a2d6e69 --- /dev/null +++ b/l1/deploy.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +cd "$(dirname "$0")" + +echo "=== Pulling latest code ===" +git pull + +echo "=== Building Docker image ===" +docker build --build-arg CACHEBUST=$(date +%s) -t registry.rose-ash.com:5000/celery-l1-server:latest . + +echo "=== Pushing to registry ===" +docker push registry.rose-ash.com:5000/celery-l1-server:latest + +echo "=== Redeploying celery stack ===" +docker stack deploy -c docker-compose.yml celery + +echo "=== Done ===" +docker stack services celery diff --git a/l1/diagnose_gpu.py b/l1/diagnose_gpu.py new file mode 100755 index 0000000..5136139 --- /dev/null +++ b/l1/diagnose_gpu.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +GPU Rendering Diagnostic Script + +Checks for common issues that cause GPU rendering slowdowns in art-dag. +Run this script to identify potential performance bottlenecks. +""" + +import sys +import subprocess +import os + +def print_section(title): + print(f"\n{'='*60}") + print(f" {title}") + print(f"{'='*60}") + +def check_pass(msg): + print(f" [PASS] {msg}") + +def check_fail(msg): + print(f" [FAIL] {msg}") + +def check_warn(msg): + print(f" [WARN] {msg}") + +def check_info(msg): + print(f" [INFO] {msg}") + +# ============================================================ +# 1. Check GPU Availability +# ============================================================ +print_section("1. GPU AVAILABILITY") + +# Check nvidia-smi +try: + result = subprocess.run(["nvidia-smi", "--query-gpu=name,memory.total,memory.free,utilization.gpu", + "--format=csv,noheader"], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + for line in result.stdout.strip().split('\n'): + check_pass(f"GPU found: {line}") + else: + check_fail("nvidia-smi failed - no GPU detected") +except FileNotFoundError: + check_fail("nvidia-smi not found - NVIDIA drivers not installed") +except Exception as e: + check_fail(f"nvidia-smi error: {e}") + +# ============================================================ +# 2. Check CuPy +# ============================================================ +print_section("2. CUPY (GPU ARRAY LIBRARY)") + +try: + import cupy as cp + check_pass(f"CuPy available, version {cp.__version__}") + + # Test basic GPU operation + try: + a = cp.zeros((100, 100), dtype=cp.uint8) + cp.cuda.Stream.null.synchronize() + check_pass("CuPy GPU operations working") + + # Check memory + mempool = cp.get_default_memory_pool() + check_info(f"GPU memory pool: {mempool.used_bytes() / 1024**2:.1f} MB used, " + f"{mempool.total_bytes() / 1024**2:.1f} MB total") + except Exception as e: + check_fail(f"CuPy GPU test failed: {e}") +except ImportError: + check_fail("CuPy not installed - GPU rendering disabled") + +# ============================================================ +# 3. Check PyNvVideoCodec (GPU Encoding) +# ============================================================ +print_section("3. PYNVVIDEOCODEC (GPU ENCODING)") + +try: + import PyNvVideoCodec as nvc + check_pass("PyNvVideoCodec available - zero-copy GPU encoding enabled") +except ImportError: + check_warn("PyNvVideoCodec not available - using FFmpeg NVENC (slower)") + +# ============================================================ +# 4. Check Decord GPU (Hardware Decode) +# ============================================================ +print_section("4. DECORD GPU (HARDWARE DECODE)") + +try: + import decord + from decord import gpu + ctx = gpu(0) + check_pass(f"Decord GPU (NVDEC) available - hardware video decode enabled") +except ImportError: + check_warn("Decord not installed - using FFmpeg decode") +except Exception as e: + check_warn(f"Decord GPU not available ({e}) - using FFmpeg decode") + +# ============================================================ +# 5. Check DLPack Support +# ============================================================ +print_section("5. DLPACK (ZERO-COPY TRANSFER)") + +try: + import decord + from decord import VideoReader, gpu + import cupy as cp + + # Need a test video file + test_video = None + for path in ["/data/cache", "/tmp"]: + if os.path.exists(path): + for f in os.listdir(path): + if f.endswith(('.mp4', '.webm', '.mkv')): + test_video = os.path.join(path, f) + break + if test_video: + break + + if test_video: + try: + vr = VideoReader(test_video, ctx=gpu(0)) + frame = vr[0] + dlpack = frame.to_dlpack() + gpu_frame = cp.from_dlpack(dlpack) + check_pass(f"DLPack zero-copy working (tested with {os.path.basename(test_video)})") + except Exception as e: + check_fail(f"DLPack FAILED: {e}") + check_info("This means every frame does GPU->CPU->GPU copy (SLOW)") + else: + check_warn("No test video found - cannot verify DLPack") +except ImportError: + check_warn("Cannot test DLPack - decord or cupy not available") + +# ============================================================ +# 6. Check Fast CUDA Kernels +# ============================================================ +print_section("6. FAST CUDA KERNELS (JIT COMPILED)") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from streaming.jit_compiler import ( + fast_rotate, fast_zoom, fast_blend, fast_hue_shift, + fast_invert, fast_ripple, get_fast_ops + ) + check_pass("Fast CUDA kernels loaded successfully") + + # Test one kernel + try: + import cupy as cp + test_img = cp.zeros((720, 1280, 3), dtype=cp.uint8) + result = fast_rotate(test_img, 45.0) + cp.cuda.Stream.null.synchronize() + check_pass("Fast rotate kernel working") + except Exception as e: + check_fail(f"Fast kernel execution failed: {e}") +except ImportError as e: + check_warn(f"Fast CUDA kernels not available: {e}") + check_info("Fallback to slower CuPy operations") + +# ============================================================ +# 7. Check Fused Pipeline Compiler +# ============================================================ +print_section("7. FUSED PIPELINE COMPILER") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from streaming.sexp_to_cuda import compile_frame_pipeline, compile_autonomous_pipeline + check_pass("Fused CUDA pipeline compiler available") +except ImportError as e: + check_warn(f"Fused pipeline compiler not available: {e}") + check_info("Using per-operation fallback (slower for multi-effect pipelines)") + +# ============================================================ +# 8. Check FFmpeg NVENC +# ============================================================ +print_section("8. FFMPEG NVENC (HARDWARE ENCODE)") + +try: + result = subprocess.run(["ffmpeg", "-encoders"], capture_output=True, text=True, timeout=5) + if "h264_nvenc" in result.stdout: + check_pass("FFmpeg h264_nvenc encoder available") + else: + check_warn("FFmpeg h264_nvenc not available - using libx264 (CPU)") + + if "hevc_nvenc" in result.stdout: + check_pass("FFmpeg hevc_nvenc encoder available") +except Exception as e: + check_fail(f"FFmpeg check failed: {e}") + +# ============================================================ +# 9. Check FFmpeg NVDEC +# ============================================================ +print_section("9. FFMPEG NVDEC (HARDWARE DECODE)") + +try: + result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5) + if "cuda" in result.stdout: + check_pass("FFmpeg CUDA hwaccel available") + else: + check_warn("FFmpeg CUDA hwaccel not available - using CPU decode") +except Exception as e: + check_fail(f"FFmpeg hwaccel check failed: {e}") + +# ============================================================ +# 10. Check Pipeline Cache Status +# ============================================================ +print_section("10. PIPELINE CACHE STATUS") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from sexp_effects.primitive_libs.streaming_gpu import ( + _FUSED_PIPELINE_CACHE, _AUTONOMOUS_PIPELINE_CACHE + ) + fused_count = len(_FUSED_PIPELINE_CACHE) + auto_count = len(_AUTONOMOUS_PIPELINE_CACHE) + + if fused_count > 0 or auto_count > 0: + check_info(f"Fused pipeline cache: {fused_count} entries") + check_info(f"Autonomous pipeline cache: {auto_count} entries") + if fused_count > 100 or auto_count > 100: + check_warn("Large pipeline cache - may cause memory pressure") + else: + check_info("Pipeline caches empty (no rendering done yet)") +except Exception as e: + check_info(f"Could not check pipeline cache: {e}") + +# ============================================================ +# Summary +# ============================================================ +print_section("SUMMARY") +print(""" +Optimal GPU rendering requires: + 1. [CRITICAL] CuPy with working GPU operations + 2. [CRITICAL] DLPack zero-copy transfer (decord -> CuPy) + 3. [HIGH] Fast CUDA kernels from jit_compiler + 4. [MEDIUM] Fused pipeline compiler for multi-effect recipes + 5. [MEDIUM] PyNvVideoCodec for zero-copy encoding + 6. [LOW] FFmpeg NVENC/NVDEC as fallback + +If DLPack is failing, check: + - decord version (needs 0.6.0+ with DLPack support) + - CuPy version compatibility + - CUDA toolkit version match + +If fast kernels are not loading: + - Check if streaming/jit_compiler.py exists + - Verify CUDA compiler (nvcc) is available +""") diff --git a/l1/docker-compose.gpu-dev.yml b/l1/docker-compose.gpu-dev.yml new file mode 100644 index 0000000..1facb3b --- /dev/null +++ b/l1/docker-compose.gpu-dev.yml @@ -0,0 +1,36 @@ +# GPU Worker Development Override +# +# Usage: docker stack deploy -c docker-compose.yml -c docker-compose.gpu-dev.yml celery +# Or for quick testing: docker-compose -f docker-compose.yml -f docker-compose.gpu-dev.yml up l1-gpu-worker +# +# Features: +# - Mounts source code for instant changes (no rebuild needed) +# - Uses watchmedo for auto-reload on file changes +# - Shows config on startup + +version: '3.8' + +services: + l1-gpu-worker: + # Override command to use watchmedo for auto-reload + command: > + sh -c " + pip install -q watchdog[watchmedo] 2>/dev/null || true; + echo '=== GPU WORKER DEV MODE ==='; + echo 'Source mounted - changes take effect on restart'; + echo 'Auto-reload enabled via watchmedo'; + env | grep -E 'STREAMING_GPU|IPFS_GATEWAY|REDIS|DATABASE' | sort; + echo '==========================='; + watchmedo auto-restart --directory=/app --pattern='*.py' --recursive -- \ + celery -A celery_app worker --loglevel=info -E -Q gpu,celery + " + environment: + # Development defaults (can override with .env) + - STREAMING_GPU_PERSIST=0 + - IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs + - SHOW_CONFIG=1 + volumes: + # Mount source code for hot reload + - ./:/app:ro + # Keep cache local + - gpu_cache:/data/cache diff --git a/l1/docker-compose.yml b/l1/docker-compose.yml new file mode 100644 index 0000000..301e439 --- /dev/null +++ b/l1/docker-compose.yml @@ -0,0 +1,191 @@ +version: "3.8" + +services: + redis: + image: redis:7-alpine + ports: + - target: 6379 + published: 16379 + mode: host # Bypass swarm routing mesh + volumes: + - redis_data:/data + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + postgres: + image: postgres:16-alpine + env_file: + - .env + environment: + - POSTGRES_USER=artdag + - POSTGRES_DB=artdag + ports: + - target: 5432 + published: 15432 + mode: host # Expose for GPU worker on different VPC + volumes: + - postgres_data:/var/lib/postgresql/data + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + ipfs: + image: ipfs/kubo:latest + ports: + - "4001:4001" # Swarm TCP + - "4001:4001/udp" # Swarm UDP + - target: 5001 + published: 15001 + mode: host # API port for GPU worker on different VPC + volumes: + - ipfs_data:/data/ipfs + - l1_cache:/data/cache:ro # Read-only access to cache for adding files + networks: + - celery + - externalnet # For gateway access + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + l1-server: + image: registry.rose-ash.com:5000/celery-l1-server:latest + env_file: + - .env + environment: + - REDIS_URL=redis://redis:6379/5 + # IPFS_API multiaddr - used for all IPFS operations (add, cat, pin) + - IPFS_API=/dns/ipfs/tcp/5001 + - CACHE_DIR=/data/cache + # Coop app internal URLs for fragment composition + - INTERNAL_URL_BLOG=http://blog:8000 + - INTERNAL_URL_CART=http://cart:8000 + - INTERNAL_URL_ACCOUNT=http://account:8000 + # DATABASE_URL, ADMIN_TOKEN, ARTDAG_CLUSTER_KEY, + # L2_SERVER, L2_DOMAIN, IPFS_GATEWAY_URL from .env file + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8100/health')"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 15s + volumes: + - l1_cache:/data/cache + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + - externalnet + deploy: + replicas: 1 + update_config: + order: start-first + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + l1-worker: + image: registry.rose-ash.com:5000/celery-l1-server:latest + command: sh -c "find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; celery -A celery_app worker --loglevel=info -E" + env_file: + - .env + environment: + - REDIS_URL=redis://redis:6379/5 + # IPFS_API multiaddr - used for all IPFS operations (add, cat, pin) + - IPFS_API=/dns/ipfs/tcp/5001 + - CACHE_DIR=/data/cache + - C_FORCE_ROOT=true + # DATABASE_URL, ARTDAG_CLUSTER_KEY from .env file + volumes: + - l1_cache:/data/cache + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + deploy: + replicas: 2 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + flower: + image: mher/flower:2.0 + command: celery --broker=redis://redis:6379/5 flower --port=5555 + environment: + - CELERY_BROKER_URL=redis://redis:6379/5 + - FLOWER_PORT=5555 + depends_on: + - redis + networks: + - celery + - externalnet + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + # GPU worker for streaming/rendering tasks + # Build: docker build -f Dockerfile.gpu -t registry.rose-ash.com:5000/celery-l1-gpu-server:latest . + # Requires: docker node update --label-add gpu=true + l1-gpu-worker: + image: registry.rose-ash.com:5000/celery-l1-gpu-server:latest + command: sh -c "cd /app && celery -A celery_app worker --loglevel=info -E -Q gpu,celery" + env_file: + - .env.gpu + volumes: + # Local cache - ephemeral, just for working files + - gpu_cache:/data/cache + # Note: No source mount - GPU worker uses code from image + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu == true + +volumes: + redis_data: + postgres_data: + ipfs_data: + l1_cache: + gpu_cache: # Ephemeral cache for GPU workers + +networks: + celery: + driver: overlay + externalnet: + external: true diff --git a/l1/effects/quick_test_explicit.sexp b/l1/effects/quick_test_explicit.sexp new file mode 100644 index 0000000..0a3698b --- /dev/null +++ b/l1/effects/quick_test_explicit.sexp @@ -0,0 +1,150 @@ +;; Quick Test - Fully Explicit Streaming Version +;; +;; The interpreter is completely generic - knows nothing about video/audio. +;; All domain logic is explicit via primitives. +;; +;; Run with built-in sources/audio: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp --fps 30 +;; +;; Run with external config files: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp \ +;; --sources configs/sources-default.sexp \ +;; --audio configs/audio-dizzy.sexp \ +;; --fps 30 + +(stream "quick_test_explicit" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives and effects + (include :path "../templates/standard-primitives.sexp") + (include :path "../templates/standard-effects.sexp") + + ;; Load reusable templates + (include :path "../templates/stream-process-pair.sexp") + (include :path "../templates/crossfade-zoom.sexp") + + ;; === SOURCES AS ARRAY === + (def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) + ]) + + ;; Per-pair config: [rot-dir, rot-a-max, rot-b-max, zoom-a-max, zoom-b-max] + ;; Pairs 3,6: reversed (negative rot-a, positive rot-b, shrink zoom-a, grow zoom-b) + ;; Pair 5: smaller ranges + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 + ]) + + ;; Audio analyzer + (def music (streaming:make-audio-analyzer "dizzy.mp3")) + + ;; Audio playback + (audio-playback "../dizzy.mp3") + + ;; === GLOBAL SCANS === + + ;; Cycle state: which source is active (recipe-specific) + ;; clen = beats per source (8-24 beats = ~4-12 seconds) + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Reusable scans from templates (require 'music' to be defined) + (include :path "../templates/scan-oscillating-spin.sexp") + (include :path "../templates/scan-ripple-drops.sexp") + + ;; === PER-PAIR STATE (dynamically sized based on sources) === + ;; Each pair has: inv-a, inv-b, hue-a, hue-b, mix, rot-angle + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [;; Invert toggles (10% chance, lasts 1-4 beats) + new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + ;; Hue shifts (10% chance, lasts 1-4 beats) - use countdown like invert + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + ;; Pick random hue value when triggering (stored separately) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + ;; Mix (holds for 1-10 beats, then picks 0, 0.5, or 1) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + ;; Rotation (accumulates, reverses direction when cycle completes) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + ;; Note: dir comes from pair-configs, but we store rotation state here + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic: last third of cycle crossfades to next + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Get pair states array (required by process-pair macro) + pair-states (bind pairs :states) + + ;; Process active pair using macro from template + active-frame (process-pair active) + + ;; Crossfade with zoom during transition (using macro) + result (if fading + (crossfade-zoom active-frame (process-pair next-idx) fade-amt) + active-frame) + + ;; Final: global spin + ripple + spun (rotate result :angle (bind spin :angle)) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50))] + + (ripple spun + :amplitude rip-amp + :center_x (bind ripple-state :cx) + :center_y (bind ripple-state :cy) + :frequency 8 + :decay 2 + :speed 5)))) diff --git a/l1/ipfs_client.py b/l1/ipfs_client.py new file mode 100644 index 0000000..3edf5b1 --- /dev/null +++ b/l1/ipfs_client.py @@ -0,0 +1,345 @@ +# art-celery/ipfs_client.py +""" +IPFS client for Art DAG L1 server. + +Provides functions to add, retrieve, and pin files on IPFS. +Uses direct HTTP API calls for compatibility with all Kubo versions. +""" + +import logging +import os +import re +from pathlib import Path +from typing import Optional, Union + +import requests + +logger = logging.getLogger(__name__) + +# IPFS API multiaddr - default to local, docker uses /dns/ipfs/tcp/5001 +IPFS_API = os.getenv("IPFS_API", "/ip4/127.0.0.1/tcp/5001") + +# Connection timeout in seconds (increased for large files) +IPFS_TIMEOUT = int(os.getenv("IPFS_TIMEOUT", "120")) + +# IPFS gateway URLs for fallback when local node doesn't have content +# Comma-separated list of gateway URLs (without /ipfs/ suffix) +IPFS_GATEWAYS = [g.strip() for g in os.getenv( + "IPFS_GATEWAYS", + "https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link" +).split(",") if g.strip()] + +# Gateway timeout (shorter than API timeout for faster fallback) +GATEWAY_TIMEOUT = int(os.getenv("GATEWAY_TIMEOUT", "30")) + + +def _multiaddr_to_url(multiaddr: str) -> str: + """Convert IPFS multiaddr to HTTP URL.""" + # Handle /dns/hostname/tcp/port format + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +# Base URL for IPFS API +IPFS_BASE_URL = _multiaddr_to_url(IPFS_API) + + +def add_file(file_path: Union[Path, str], pin: bool = True) -> Optional[str]: + """ + Add a file to IPFS and optionally pin it. + + Args: + file_path: Path to the file to add (Path object or string) + pin: Whether to pin the file (default: True) + + Returns: + IPFS CID (content identifier) or None on failure + """ + try: + # Ensure file_path is a Path object + if isinstance(file_path, str): + file_path = Path(file_path) + + url = f"{IPFS_BASE_URL}/api/v0/add" + params = {"pin": str(pin).lower()} + + with open(file_path, "rb") as f: + files = {"file": (file_path.name, f)} + response = requests.post(url, params=params, files=files, timeout=IPFS_TIMEOUT) + + response.raise_for_status() + result = response.json() + cid = result["Hash"] + logger.info(f"Added to IPFS: {file_path.name} -> {cid}") + return cid + except Exception as e: + logger.error(f"Failed to add to IPFS: {e}") + return None + + +def add_bytes(data: bytes, pin: bool = True) -> Optional[str]: + """ + Add bytes data to IPFS and optionally pin it. + + Args: + data: Bytes to add + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/add" + params = {"pin": str(pin).lower()} + files = {"file": ("data", data)} + + response = requests.post(url, params=params, files=files, timeout=IPFS_TIMEOUT) + response.raise_for_status() + result = response.json() + cid = result["Hash"] + + logger.info(f"Added bytes to IPFS: {len(data)} bytes -> {cid}") + return cid + except Exception as e: + logger.error(f"Failed to add bytes to IPFS: {e}") + return None + + +def add_json(data: dict, pin: bool = True) -> Optional[str]: + """ + Serialize dict to JSON and add to IPFS. + + Args: + data: Dictionary to serialize and store + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + import json + json_bytes = json.dumps(data, indent=2, sort_keys=True).encode('utf-8') + return add_bytes(json_bytes, pin=pin) + + +def add_string(content: str, pin: bool = True) -> Optional[str]: + """ + Add a string to IPFS and optionally pin it. + + Args: + content: String content to add (e.g., S-expression) + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + return add_bytes(content.encode('utf-8'), pin=pin) + + +def get_file(cid: str, dest_path: Union[Path, str]) -> bool: + """ + Retrieve a file from IPFS and save to destination. + + Args: + cid: IPFS CID to retrieve + dest_path: Path to save the file (Path object or string) + + Returns: + True on success, False on failure + """ + try: + data = get_bytes(cid) + if data is None: + return False + + # Ensure dest_path is a Path object + if isinstance(dest_path, str): + dest_path = Path(dest_path) + + dest_path.parent.mkdir(parents=True, exist_ok=True) + dest_path.write_bytes(data) + logger.info(f"Retrieved from IPFS: {cid} -> {dest_path}") + return True + except Exception as e: + logger.error(f"Failed to get from IPFS: {e}") + return False + + +def get_bytes_from_gateway(cid: str) -> Optional[bytes]: + """ + Retrieve bytes from IPFS via public gateways (fallback). + + Tries each configured gateway in order until one succeeds. + + Args: + cid: IPFS CID to retrieve + + Returns: + File content as bytes or None if all gateways fail + """ + for gateway in IPFS_GATEWAYS: + try: + url = f"{gateway}/ipfs/{cid}" + logger.info(f"Trying gateway: {url}") + response = requests.get(url, timeout=GATEWAY_TIMEOUT) + response.raise_for_status() + data = response.content + logger.info(f"Retrieved from gateway {gateway}: {cid} ({len(data)} bytes)") + return data + except Exception as e: + logger.warning(f"Gateway {gateway} failed for {cid}: {e}") + continue + + logger.error(f"All gateways failed for {cid}") + return None + + +def get_bytes(cid: str, use_gateway_fallback: bool = True) -> Optional[bytes]: + """ + Retrieve bytes data from IPFS. + + Tries local IPFS node first, then falls back to public gateways + if configured and use_gateway_fallback is True. + + Args: + cid: IPFS CID to retrieve + use_gateway_fallback: If True, try public gateways on local failure + + Returns: + File content as bytes or None on failure + """ + # Try local IPFS node first + try: + url = f"{IPFS_BASE_URL}/api/v0/cat" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + data = response.content + + logger.info(f"Retrieved from IPFS: {cid} ({len(data)} bytes)") + return data + except Exception as e: + logger.warning(f"Local IPFS failed for {cid}: {e}") + + # Try gateway fallback + if use_gateway_fallback and IPFS_GATEWAYS: + logger.info(f"Trying gateway fallback for {cid}") + return get_bytes_from_gateway(cid) + + logger.error(f"Failed to get bytes from IPFS: {e}") + return None + + +def pin(cid: str) -> bool: + """ + Pin a CID on IPFS. + + Args: + cid: IPFS CID to pin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/add" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Pinned on IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to pin on IPFS: {e}") + return False + + +def unpin(cid: str) -> bool: + """ + Unpin a CID on IPFS. + + Args: + cid: IPFS CID to unpin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/rm" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Unpinned on IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to unpin on IPFS: {e}") + return False + + +def is_pinned(cid: str) -> bool: + """ + Check if a CID is pinned on IPFS. + + Args: + cid: IPFS CID to check + + Returns: + True if pinned, False otherwise + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/ls" + params = {"arg": cid, "type": "recursive"} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + if response.status_code == 200: + result = response.json() + return cid in result.get("Keys", {}) + return False + except Exception as e: + logger.error(f"Failed to check pin status: {e}") + return False + + +def is_available() -> bool: + """ + Check if IPFS daemon is available. + + Returns: + True if IPFS is available, False otherwise + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def get_node_id() -> Optional[str]: + """ + Get this IPFS node's peer ID. + + Returns: + Peer ID string or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=IPFS_TIMEOUT) + response.raise_for_status() + return response.json().get("ID") + except Exception as e: + logger.error(f"Failed to get node ID: {e}") + return None diff --git a/l1/path_registry.py b/l1/path_registry.py new file mode 100644 index 0000000..985be18 --- /dev/null +++ b/l1/path_registry.py @@ -0,0 +1,477 @@ +""" +Path Registry - Maps human-friendly paths to content-addressed IDs. + +This module provides a bidirectional mapping between: +- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp") +- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes) + +The registry is useful for: +- Looking up effects by their friendly path name +- Resolving cids back to the original path for debugging +- Maintaining a stable naming scheme across cache updates + +Storage: +- Uses the existing item_types table in the database (path column) +- Caches in Redis for fast lookups across distributed workers + +The registry uses a system actor (@system@local) for global path mappings, +allowing effects to be resolved by path without requiring user context. +""" + +import logging +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# System actor for global path mappings (effects, recipes, analyzers) +SYSTEM_ACTOR = "@system@local" + + +@dataclass +class PathEntry: + """A registered path with its content-addressed ID.""" + path: str # Human-friendly path (relative or normalized) + cid: str # Content-addressed ID (IPFS CID or hash) + content_type: str # Type: "effect", "recipe", "analyzer", etc. + actor_id: str = SYSTEM_ACTOR # Owner (system for global) + description: Optional[str] = None + created_at: float = 0.0 + + +class PathRegistry: + """ + Registry for mapping paths to content-addressed IDs. + + Uses the existing item_types table for persistence and Redis + for fast lookups in distributed Celery workers. + """ + + def __init__(self, redis_client=None): + self._redis = redis_client + self._redis_path_to_cid_key = "artdag:path_to_cid" + self._redis_cid_to_path_key = "artdag:cid_to_path" + + def _run_async(self, coro): + """Run async coroutine from sync context.""" + import asyncio + + try: + loop = asyncio.get_running_loop() + import threading + result = [None] + error = [None] + + def run_in_thread(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + result[0] = new_loop.run_until_complete(coro) + finally: + new_loop.close() + except Exception as e: + error[0] = e + + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join(timeout=30) + if error[0]: + raise error[0] + return result[0] + except RuntimeError: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + def _normalize_path(self, path: str) -> str: + """Normalize a path for consistent storage.""" + # Remove leading ./ or / + path = path.lstrip('./') + # Normalize separators + path = path.replace('\\', '/') + # Remove duplicate slashes + while '//' in path: + path = path.replace('//', '/') + return path + + def register( + self, + path: str, + cid: str, + content_type: str = "effect", + actor_id: str = SYSTEM_ACTOR, + description: Optional[str] = None, + ) -> PathEntry: + """ + Register a path -> cid mapping. + + Args: + path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp") + cid: Content-addressed ID (IPFS CID or hash) + content_type: Type of content ("effect", "recipe", "analyzer") + actor_id: Owner (default: system for global mappings) + description: Optional description + + Returns: + The created PathEntry + """ + norm_path = self._normalize_path(path) + now = datetime.now(timezone.utc).timestamp() + + entry = PathEntry( + path=norm_path, + cid=cid, + content_type=content_type, + actor_id=actor_id, + description=description, + created_at=now, + ) + + # Store in database (item_types table) + self._save_to_db(entry) + + # Update Redis cache + self._update_redis_cache(norm_path, cid) + + logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...") + return entry + + def _save_to_db(self, entry: PathEntry): + """Save entry to database using item_types table.""" + import database + + async def save(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + entry.cid + ) + # Insert or update item_type with path + await conn.execute( + """ + INSERT INTO item_types (cid, actor_id, type, path, description) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE(EXCLUDED.description, item_types.description) + """, + entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description + ) + finally: + await conn.close() + + try: + self._run_async(save()) + except Exception as e: + logger.warning(f"Failed to save path registry to DB: {e}") + + def _update_redis_cache(self, path: str, cid: str): + """Update Redis cache with mapping.""" + if self._redis: + try: + self._redis.hset(self._redis_path_to_cid_key, path, cid) + self._redis.hset(self._redis_cid_to_path_key, cid, path) + except Exception as e: + logger.warning(f"Failed to update Redis cache: {e}") + + def get_cid(self, path: str, content_type: str = None) -> Optional[str]: + """ + Get the cid for a path. + + Args: + path: Human-friendly path + content_type: Optional type filter + + Returns: + The cid, or None if not found + """ + norm_path = self._normalize_path(path) + + # Try Redis first (fast path) + if self._redis: + try: + val = self._redis.hget(self._redis_path_to_cid_key, norm_path) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Redis lookup failed: {e}") + + # Fall back to database + return self._get_cid_from_db(norm_path, content_type) + + def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]: + """Get cid from database using item_types table.""" + import database + + async def get(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if content_type: + row = await conn.fetchrow( + "SELECT cid FROM item_types WHERE path = $1 AND type = $2", + path, content_type + ) + else: + row = await conn.fetchrow( + "SELECT cid FROM item_types WHERE path = $1", + path + ) + return row["cid"] if row else None + finally: + await conn.close() + + try: + result = self._run_async(get()) + # Update Redis cache if found + if result and self._redis: + self._update_redis_cache(path, result) + return result + except Exception as e: + logger.warning(f"Failed to get from DB: {e}") + return None + + def get_path(self, cid: str) -> Optional[str]: + """ + Get the path for a cid. + + Args: + cid: Content-addressed ID + + Returns: + The path, or None if not found + """ + # Try Redis first + if self._redis: + try: + val = self._redis.hget(self._redis_cid_to_path_key, cid) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Redis lookup failed: {e}") + + # Fall back to database + return self._get_path_from_db(cid) + + def _get_path_from_db(self, cid: str) -> Optional[str]: + """Get path from database using item_types table.""" + import database + + async def get(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + row = await conn.fetchrow( + "SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1", + cid + ) + return row["path"] if row else None + finally: + await conn.close() + + try: + result = self._run_async(get()) + # Update Redis cache if found + if result and self._redis: + self._update_redis_cache(result, cid) + return result + except Exception as e: + logger.warning(f"Failed to get from DB: {e}") + return None + + def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]: + """ + List all entries of a given type. + + Args: + content_type: Type to filter by ("effect", "recipe", etc.) + actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global) + + Returns: + List of PathEntry objects + """ + import database + + async def list_entries(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if actor_id: + rows = await conn.fetch( + """ + SELECT cid, path, type, actor_id, description, + EXTRACT(EPOCH FROM created_at) as created_at + FROM item_types + WHERE type = $1 AND actor_id = $2 AND path IS NOT NULL + ORDER BY path + """, + content_type, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT cid, path, type, actor_id, description, + EXTRACT(EPOCH FROM created_at) as created_at + FROM item_types + WHERE type = $1 AND path IS NOT NULL + ORDER BY path + """, + content_type + ) + return [ + PathEntry( + path=row["path"], + cid=row["cid"], + content_type=row["type"], + actor_id=row["actor_id"], + description=row["description"], + created_at=row["created_at"] or 0, + ) + for row in rows + ] + finally: + await conn.close() + + try: + return self._run_async(list_entries()) + except Exception as e: + logger.warning(f"Failed to list from DB: {e}") + return [] + + def delete(self, path: str, content_type: str = None) -> bool: + """ + Delete a path registration. + + Args: + path: The path to delete + content_type: Optional type filter + + Returns: + True if deleted, False if not found + """ + norm_path = self._normalize_path(path) + + # Get cid for Redis cleanup + cid = self.get_cid(norm_path, content_type) + + # Delete from database + deleted = self._delete_from_db(norm_path, content_type) + + # Clean up Redis + if deleted and cid and self._redis: + try: + self._redis.hdel(self._redis_path_to_cid_key, norm_path) + self._redis.hdel(self._redis_cid_to_path_key, cid) + except Exception as e: + logger.warning(f"Failed to clean up Redis: {e}") + + return deleted + + def _delete_from_db(self, path: str, content_type: str = None) -> bool: + """Delete from database.""" + import database + + async def delete(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if content_type: + result = await conn.execute( + "DELETE FROM item_types WHERE path = $1 AND type = $2", + path, content_type + ) + else: + result = await conn.execute( + "DELETE FROM item_types WHERE path = $1", + path + ) + return "DELETE" in result + finally: + await conn.close() + + try: + return self._run_async(delete()) + except Exception as e: + logger.warning(f"Failed to delete from DB: {e}") + return False + + def register_effect( + self, + path: str, + cid: str, + description: Optional[str] = None, + ) -> PathEntry: + """ + Convenience method to register an effect. + + Args: + path: Effect path (e.g., "effects/ascii_fx_zone.sexp") + cid: IPFS CID of the effect file + description: Optional description + + Returns: + The created PathEntry + """ + return self.register( + path=path, + cid=cid, + content_type="effect", + actor_id=SYSTEM_ACTOR, + description=description, + ) + + def get_effect_cid(self, path: str) -> Optional[str]: + """ + Get CID for an effect by path. + + Args: + path: Effect path + + Returns: + IPFS CID or None + """ + return self.get_cid(path, content_type="effect") + + def list_effects(self) -> List[PathEntry]: + """List all registered effects.""" + return self.list_by_type("effect", actor_id=SYSTEM_ACTOR) + + +# Singleton instance +_registry: Optional[PathRegistry] = None + + +def get_path_registry() -> PathRegistry: + """Get the singleton path registry instance.""" + global _registry + if _registry is None: + import redis + from urllib.parse import urlparse + + redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + parsed = urlparse(redis_url) + redis_client = redis.Redis( + host=parsed.hostname or 'localhost', + port=parsed.port or 6379, + db=int(parsed.path.lstrip('/') or 0), + socket_timeout=5, + socket_connect_timeout=5 + ) + + _registry = PathRegistry(redis_client=redis_client) + return _registry + + +def reset_path_registry(): + """Reset the singleton (for testing).""" + global _registry + _registry = None diff --git a/l1/pyproject.toml b/l1/pyproject.toml new file mode 100644 index 0000000..b358312 --- /dev/null +++ b/l1/pyproject.toml @@ -0,0 +1,51 @@ +[project] +name = "art-celery" +version = "0.1.0" +description = "Art DAG L1 Server and Celery Workers" +requires-python = ">=3.11" + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +strict_optional = true +no_implicit_optional = true + +# Start strict on new code, gradually enable for existing +files = [ + "app/types.py", + "app/routers/recipes.py", + "tests/", +] + +# Ignore missing imports for third-party packages without stubs +[[tool.mypy.overrides]] +module = [ + "celery.*", + "redis.*", + "artdag.*", + "artdag_common.*", + "ipfs_client.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] +asyncio_mode = "auto" +addopts = "-v --tb=short" +filterwarnings = [ + "ignore::DeprecationWarning", +] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "UP"] +ignore = ["E501"] # Line length handled separately diff --git a/l1/recipes/woods-lowres.sexp b/l1/recipes/woods-lowres.sexp new file mode 100644 index 0000000..55a1a6a --- /dev/null +++ b/l1/recipes/woods-lowres.sexp @@ -0,0 +1,223 @@ +;; Woods Recipe - OPTIMIZED VERSION +;; +;; Uses fused-pipeline for GPU acceleration when available, +;; falls back to individual primitives on CPU. +;; +;; Key optimizations: +;; 1. Uses streaming_gpu primitives with fast CUDA kernels +;; 2. Uses fused-pipeline to batch effects into single kernel passes +;; 3. GPU persistence - frames stay on GPU throughout pipeline + +(stream "woods-lowres" + :fps 30 + :width 640 + :height 360 + :seed 42 + + ;; Load standard primitives (includes proper asset resolution) + ;; Auto-selects GPU versions when available, falls back to CPU + (include :name "tpl-standard-primitives") + + ;; === SOURCES (using streaming: which has proper asset resolution) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio + (def music (streaming:make-audio-analyzer "woods-audio")) + (audio-playback "woods-audio") + + ;; === SCANS === + + ;; Cycle state + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Spin scan + (scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :speed 2} + :step (let [new-dir (if (< (core:rand) 0.05) (* dir -1) dir) + new-speed (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) speed)] + (dict :angle (+ angle (* new-dir new-speed)) + :dir new-dir + :speed new-speed))) + + ;; Ripple scan - raindrop style, all params randomized + ;; Higher freq = bigger gaps between waves (formula is dist/freq) + (scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 320 :cy 180 :freq 20 :decay 6 :amp-mult 1.0} + :step (let [new-gate (if (< (core:rand) 0.2) (+ 2 (core:rand-int 0 4)) (core:max 0 (- gate 1))) + triggered (> new-gate gate) + new-cx (if triggered (core:rand-int 50 590) cx) + new-cy (if triggered (core:rand-int 50 310) cy) + new-freq (if triggered (+ 15 (core:rand-int 0 20)) freq) + new-decay (if triggered (+ 5 (core:rand-int 0 4)) decay) + new-amp-mult (if triggered (+ 0.8 (* (core:rand) 1.2)) amp-mult)] + (dict :gate new-gate :cx new-cx :cy new-cy :freq new-freq :decay new-decay :amp-mult new-amp-mult))) + + ;; Pair states + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === OPTIMIZED PROCESS-PAIR MACRO === + ;; Uses fused-pipeline to batch rotate+hue+invert into single kernel + (defmacro process-pair-fast (idx) + (let [;; Get sources for this pair (with safe modulo indexing) + num-sources (len sources) + src-a (nth sources (mod (* idx 2) num-sources)) + src-b (nth sources (mod (+ (* idx 2) 1) num-sources)) + cfg (nth pair-configs idx) + pstate (nth (bind pairs :states) idx) + + ;; Read frames (GPU decode, stays on GPU) + frame-a (streaming:source-read src-a t) + frame-b (streaming:source-read src-b t) + + ;; Get state values + dir (get cfg :dir) + rot-max-a (get cfg :rot-a) + rot-max-b (get cfg :rot-b) + zoom-max-a (get cfg :zoom-a) + zoom-max-b (get cfg :zoom-b) + pair-angle (get pstate :angle) + inv-a-on (> (get pstate :inv-a) 0) + inv-b-on (> (get pstate :inv-b) 0) + hue-a-on (> (get pstate :hue-a) 0) + hue-b-on (> (get pstate :hue-b) 0) + hue-a-val (get pstate :hue-a-val) + hue-b-val (get pstate :hue-b-val) + mix-ratio (get pstate :mix) + + ;; Calculate rotation angles + angle-a (* dir pair-angle rot-max-a 0.01) + angle-b (* dir pair-angle rot-max-b 0.01) + + ;; Energy-driven zoom (maps audio energy 0-1 to 1-max) + zoom-a (core:map-range e 0 1 1 zoom-max-a) + zoom-b (core:map-range e 0 1 1 zoom-max-b) + + ;; Define effect pipelines for each source + ;; These get compiled to single CUDA kernels! + ;; First resize to target resolution, then apply effects + effects-a [{:op "resize" :width 640 :height 360} + {:op "zoom" :amount zoom-a} + {:op "rotate" :angle angle-a} + {:op "hue_shift" :degrees (if hue-a-on hue-a-val 0)} + {:op "invert" :amount (if inv-a-on 1 0)}] + effects-b [{:op "resize" :width 640 :height 360} + {:op "zoom" :amount zoom-b} + {:op "rotate" :angle angle-b} + {:op "hue_shift" :degrees (if hue-b-on hue-b-val 0)} + {:op "invert" :amount (if inv-b-on 1 0)}] + + ;; Apply fused pipelines (single kernel per source!) + processed-a (streaming:fused-pipeline frame-a effects-a) + processed-b (streaming:fused-pipeline frame-b effects-b)] + + ;; Blend the two processed frames + (blending:blend-images processed-a processed-b mix-ratio))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Process active pair with fused pipeline + active-frame (process-pair-fast active) + + ;; Crossfade with zoom during transition + ;; Old pair: zooms out (1.0 -> 2.0) and fades out + ;; New pair: starts small (0.1), zooms in (-> 1.0) and fades in + result (if fading + (let [next-frame (process-pair-fast next-idx) + ;; Active zooms out as it fades + active-zoom (+ 1.0 fade-amt) + active-zoomed (streaming:fused-pipeline active-frame + [{:op "zoom" :amount active-zoom}]) + ;; Next starts small and zooms in + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (streaming:fused-pipeline next-frame + [{:op "zoom" :amount next-zoom}])] + (blending:blend-images active-zoomed next-zoomed fade-amt)) + active-frame) + + ;; Final effects pipeline (fused!) + spin-angle (bind spin :angle) + ;; Ripple params - all randomized per ripple trigger + rip-gate (bind ripple-state :gate) + rip-amp-mult (bind ripple-state :amp-mult) + rip-amp (* rip-gate rip-amp-mult (core:map-range e 0 1 50 200)) + rip-cx (bind ripple-state :cx) + rip-cy (bind ripple-state :cy) + rip-freq (bind ripple-state :freq) + rip-decay (bind ripple-state :decay) + + ;; Fused final effects + final-effects [{:op "rotate" :angle spin-angle} + {:op "ripple" :amplitude rip-amp :frequency rip-freq :decay rip-decay + :phase (* now 5) :center_x rip-cx :center_y rip-cy}]] + + ;; Apply final fused pipeline + (streaming:fused-pipeline result final-effects + :rotate_angle spin-angle + :ripple_phase (* now 5) + :ripple_amplitude rip-amp)))) diff --git a/l1/recipes/woods-recipe-optimized.sexp b/l1/recipes/woods-recipe-optimized.sexp new file mode 100644 index 0000000..bec96b8 --- /dev/null +++ b/l1/recipes/woods-recipe-optimized.sexp @@ -0,0 +1,211 @@ +;; Woods Recipe - OPTIMIZED VERSION +;; +;; Uses fused-pipeline for GPU acceleration when available, +;; falls back to individual primitives on CPU. +;; +;; Key optimizations: +;; 1. Uses streaming_gpu primitives with fast CUDA kernels +;; 2. Uses fused-pipeline to batch effects into single kernel passes +;; 3. GPU persistence - frames stay on GPU throughout pipeline + +(stream "woods-recipe-optimized" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives (includes proper asset resolution) + ;; Auto-selects GPU versions when available, falls back to CPU + (include :name "tpl-standard-primitives") + + ;; === SOURCES (using streaming: which has proper asset resolution) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio + (def music (streaming:make-audio-analyzer "woods-audio")) + (audio-playback "woods-audio") + + ;; === SCANS === + + ;; Cycle state + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Spin scan + (scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :speed 2} + :step (let [new-dir (if (< (core:rand) 0.05) (* dir -1) dir) + new-speed (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) speed)] + (dict :angle (+ angle (* new-dir new-speed)) + :dir new-dir + :speed new-speed))) + + ;; Ripple scan + (scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 960 :cy 540} + :step (let [new-gate (if (< (core:rand) 0.15) (+ 3 (core:rand-int 0 5)) (core:max 0 (- gate 1))) + new-cx (if (> new-gate gate) (+ 200 (core:rand-int 0 1520)) cx) + new-cy (if (> new-gate gate) (+ 200 (core:rand-int 0 680)) cy)] + (dict :gate new-gate :cx new-cx :cy new-cy))) + + ;; Pair states + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === OPTIMIZED PROCESS-PAIR MACRO === + ;; Uses fused-pipeline to batch rotate+hue+invert into single kernel + (defmacro process-pair-fast (idx) + (let [;; Get sources for this pair (with safe modulo indexing) + num-sources (len sources) + src-a (nth sources (mod (* idx 2) num-sources)) + src-b (nth sources (mod (+ (* idx 2) 1) num-sources)) + cfg (nth pair-configs idx) + pstate (nth (bind pairs :states) idx) + + ;; Read frames (GPU decode, stays on GPU) + frame-a (streaming:source-read src-a t) + frame-b (streaming:source-read src-b t) + + ;; Get state values + dir (get cfg :dir) + rot-max-a (get cfg :rot-a) + rot-max-b (get cfg :rot-b) + zoom-max-a (get cfg :zoom-a) + zoom-max-b (get cfg :zoom-b) + pair-angle (get pstate :angle) + inv-a-on (> (get pstate :inv-a) 0) + inv-b-on (> (get pstate :inv-b) 0) + hue-a-on (> (get pstate :hue-a) 0) + hue-b-on (> (get pstate :hue-b) 0) + hue-a-val (get pstate :hue-a-val) + hue-b-val (get pstate :hue-b-val) + mix-ratio (get pstate :mix) + + ;; Calculate rotation angles + angle-a (* dir pair-angle rot-max-a 0.01) + angle-b (* dir pair-angle rot-max-b 0.01) + + ;; Energy-driven zoom (maps audio energy 0-1 to 1-max) + zoom-a (core:map-range e 0 1 1 zoom-max-a) + zoom-b (core:map-range e 0 1 1 zoom-max-b) + + ;; Define effect pipelines for each source + ;; These get compiled to single CUDA kernels! + effects-a [{:op "zoom" :amount zoom-a} + {:op "rotate" :angle angle-a} + {:op "hue_shift" :degrees (if hue-a-on hue-a-val 0)} + {:op "invert" :amount (if inv-a-on 1 0)}] + effects-b [{:op "zoom" :amount zoom-b} + {:op "rotate" :angle angle-b} + {:op "hue_shift" :degrees (if hue-b-on hue-b-val 0)} + {:op "invert" :amount (if inv-b-on 1 0)}] + + ;; Apply fused pipelines (single kernel per source!) + processed-a (streaming:fused-pipeline frame-a effects-a) + processed-b (streaming:fused-pipeline frame-b effects-b)] + + ;; Blend the two processed frames + (blending:blend-images processed-a processed-b mix-ratio))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Process active pair with fused pipeline + active-frame (process-pair-fast active) + + ;; Crossfade with zoom during transition + ;; Old pair: zooms out (1.0 -> 2.0) and fades out + ;; New pair: starts small (0.1), zooms in (-> 1.0) and fades in + result (if fading + (let [next-frame (process-pair-fast next-idx) + ;; Active zooms out as it fades + active-zoom (+ 1.0 fade-amt) + active-zoomed (streaming:fused-pipeline active-frame + [{:op "zoom" :amount active-zoom}]) + ;; Next starts small and zooms in + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (streaming:fused-pipeline next-frame + [{:op "zoom" :amount next-zoom}])] + (blending:blend-images active-zoomed next-zoomed fade-amt)) + active-frame) + + ;; Final effects pipeline (fused!) + spin-angle (bind spin :angle) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50)) + rip-cx (bind ripple-state :cx) + rip-cy (bind ripple-state :cy) + + ;; Fused final effects + final-effects [{:op "rotate" :angle spin-angle} + {:op "ripple" :amplitude rip-amp :frequency 8 :decay 2 + :phase (* now 5) :center_x rip-cx :center_y rip-cy}]] + + ;; Apply final fused pipeline + (streaming:fused-pipeline result final-effects + :rotate_angle spin-angle + :ripple_phase (* now 5) + :ripple_amplitude rip-amp)))) diff --git a/l1/recipes/woods-recipe.sexp b/l1/recipes/woods-recipe.sexp new file mode 100644 index 0000000..4c5f4ec --- /dev/null +++ b/l1/recipes/woods-recipe.sexp @@ -0,0 +1,134 @@ +;; Woods Recipe - Using friendly names for all assets +;; +;; Requires uploaded: +;; - Media: woods-1 through woods-8 (videos), woods-audio (audio) +;; - Effects: fx-rotate, fx-zoom, fx-blend, fx-ripple, fx-invert, fx-hue-shift +;; - Templates: tpl-standard-primitives, tpl-standard-effects, tpl-process-pair, +;; tpl-crossfade-zoom, tpl-scan-spin, tpl-scan-ripple + +(stream "woods-recipe" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives and effects via friendly names + (include :name "tpl-standard-primitives") + (include :name "tpl-standard-effects") + + ;; Load reusable templates + (include :name "tpl-process-pair") + (include :name "tpl-crossfade-zoom") + + ;; === SOURCES AS ARRAY (using friendly names) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config: [rot-dir, rot-a-max, rot-b-max, zoom-a-max, zoom-b-max] + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio analyzer (using friendly name) + (def music (streaming:make-audio-analyzer "woods-audio")) + + ;; Audio playback (friendly name resolved by streaming primitives) + (audio-playback "woods-audio") + + ;; === GLOBAL SCANS === + + ;; Cycle state: which source is active + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Reusable scans from templates + (include :name "tpl-scan-spin") + (include :name "tpl-scan-ripple") + + ;; === PER-PAIR STATE === + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Get pair states array + pair-states (bind pairs :states) + + ;; Process active pair using macro from template + active-frame (process-pair active) + + ;; Crossfade with zoom during transition + result (if fading + (crossfade-zoom active-frame (process-pair next-idx) fade-amt) + active-frame) + + ;; Final: global spin + ripple + spun (rotate result :angle (bind spin :angle)) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50))] + + (ripple spun + :amplitude rip-amp + :center_x (bind ripple-state :cx) + :center_y (bind ripple-state :cy) + :frequency 8 + :decay 2 + :speed 5)))) diff --git a/l1/requirements-dev.txt b/l1/requirements-dev.txt new file mode 100644 index 0000000..b7e7438 --- /dev/null +++ b/l1/requirements-dev.txt @@ -0,0 +1,16 @@ +# Development dependencies +-r requirements.txt + +# Type checking +mypy>=1.8.0 +types-requests>=2.31.0 +types-PyYAML>=6.0.0 +typing_extensions>=4.9.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 + +# Linting +ruff>=0.2.0 diff --git a/l1/requirements.txt b/l1/requirements.txt new file mode 100644 index 0000000..deab545 --- /dev/null +++ b/l1/requirements.txt @@ -0,0 +1,21 @@ +celery[redis]>=5.3.0 +redis>=5.0.0 +requests>=2.31.0 +httpx>=0.27.0 +itsdangerous>=2.0 +cryptography>=41.0 +fastapi>=0.109.0 +uvicorn>=0.27.0 +python-multipart>=0.0.6 +PyYAML>=6.0 +asyncpg>=0.29.0 +markdown>=3.5.0 +# Common effect dependencies (used by uploaded effects) +numpy>=1.24.0 +opencv-python-headless>=4.8.0 +# Core artdag from GitHub (tracks main branch) +git+https://github.com/gilesbradshaw/art-dag.git@main +# Shared components (tracks master branch) +git+https://git.rose-ash.com/art-dag/common.git@master +psycopg2-binary +nest_asyncio diff --git a/l1/scripts/cloud-init-gpu.sh b/l1/scripts/cloud-init-gpu.sh new file mode 100644 index 0000000..fe8cc27 --- /dev/null +++ b/l1/scripts/cloud-init-gpu.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Cloud-init startup script for GPU droplet (RTX 6000 Ada, etc.) +# Paste this into DigitalOcean "User data" field when creating droplet + +set -e +export DEBIAN_FRONTEND=noninteractive +exec > /var/log/artdag-setup.log 2>&1 + +echo "=== ArtDAG GPU Setup Started $(date) ===" + +# Update system (non-interactive, keep existing configs) +apt-get update +apt-get -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" upgrade + +# Install essentials +apt-get install -y \ + python3 python3-venv python3-pip \ + git curl wget \ + ffmpeg \ + vulkan-tools \ + build-essential + +# Create venv +VENV_DIR="/opt/artdag-gpu" +python3 -m venv "$VENV_DIR" +source "$VENV_DIR/bin/activate" + +# Install Python packages +pip install --upgrade pip +pip install \ + numpy \ + opencv-python-headless \ + wgpu \ + httpx \ + pyyaml \ + celery[redis] \ + fastapi \ + uvicorn \ + asyncpg + +# Create code directory +mkdir -p "$VENV_DIR/celery/sexp_effects/effects" +mkdir -p "$VENV_DIR/celery/sexp_effects/primitive_libs" +mkdir -p "$VENV_DIR/celery/streaming" + +# Add SSH key for easier access (optional - add your key here) +# echo "ssh-ed25519 AAAA... your-key" >> /root/.ssh/authorized_keys + +# Test GPU +echo "=== GPU Info ===" +nvidia-smi || echo "nvidia-smi not available yet" + +echo "=== NVENC Check ===" +ffmpeg -encoders 2>/dev/null | grep -E "nvenc|cuda" || echo "NVENC not detected" + +echo "=== wgpu Check ===" +"$VENV_DIR/bin/python3" -c " +import wgpu +try: + adapter = wgpu.gpu.request_adapter_sync(power_preference='high-performance') + print(f'GPU: {adapter.info}') +except Exception as e: + print(f'wgpu error: {e}') +" || echo "wgpu test failed" + +# Add environment setup +cat >> /etc/profile.d/artdag-gpu.sh << 'ENVEOF' +export WGPU_BACKEND_TYPE=Vulkan +export PATH="/opt/artdag-gpu/bin:$PATH" +ENVEOF + +# Mark setup complete +touch /opt/artdag-gpu/.setup-complete +echo "=== Setup Complete $(date) ===" +echo "Venv: /opt/artdag-gpu" +echo "Activate: source /opt/artdag-gpu/bin/activate" +echo "Vulkan: export WGPU_BACKEND_TYPE=Vulkan" diff --git a/l1/scripts/deploy-to-gpu.sh b/l1/scripts/deploy-to-gpu.sh new file mode 100755 index 0000000..e41c802 --- /dev/null +++ b/l1/scripts/deploy-to-gpu.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Deploy art-dag GPU code to a remote droplet +# Usage: ./deploy-to-gpu.sh + +set -e + +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "Example: $0 159.223.7.100" + exit 1 +fi + +DROPLET_IP="$1" +REMOTE_DIR="/opt/artdag-gpu/celery" +LOCAL_DIR="$(dirname "$0")/.." + +echo "=== Deploying to $DROPLET_IP ===" + +# Create remote directory +echo "[1/4] Creating remote directory..." +ssh "root@$DROPLET_IP" "mkdir -p $REMOTE_DIR/sexp_effects $REMOTE_DIR/streaming $REMOTE_DIR/scripts" + +# Copy core files +echo "[2/4] Copying core files..." +scp "$LOCAL_DIR/sexp_effects/wgsl_compiler.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/parser.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/interpreter.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/__init__.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/streaming/backends.py" "root@$DROPLET_IP:$REMOTE_DIR/streaming/" + +# Copy effects +echo "[3/4] Copying effects..." +ssh "root@$DROPLET_IP" "mkdir -p $REMOTE_DIR/sexp_effects/effects $REMOTE_DIR/sexp_effects/primitive_libs" +scp -r "$LOCAL_DIR/sexp_effects/effects/"*.sexp "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/effects/" 2>/dev/null || true +scp -r "$LOCAL_DIR/sexp_effects/primitive_libs/"*.py "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/primitive_libs/" 2>/dev/null || true + +# Test +echo "[4/4] Testing deployment..." +ssh "root@$DROPLET_IP" "cd $REMOTE_DIR && /opt/artdag-gpu/bin/python3 -c ' +import sys +sys.path.insert(0, \".\") +from sexp_effects.wgsl_compiler import compile_effect_file +result = compile_effect_file(\"sexp_effects/effects/invert.sexp\") +print(f\"Compiled effect: {result.name}\") +print(\"Deployment OK\") +'" || echo "Test failed - may need to run setup script first" + +echo "" +echo "=== Deployment complete ===" +echo "SSH: ssh root@$DROPLET_IP" +echo "Test: ssh root@$DROPLET_IP 'cd $REMOTE_DIR && /opt/artdag-gpu/bin/python3 -c \"from streaming.backends import get_backend; b=get_backend(\\\"wgpu\\\"); print(b)\"'" diff --git a/l1/scripts/gpu-dev-deploy.sh b/l1/scripts/gpu-dev-deploy.sh new file mode 100755 index 0000000..f1be595 --- /dev/null +++ b/l1/scripts/gpu-dev-deploy.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Quick deploy to GPU node with hot reload +# Usage: ./scripts/gpu-dev-deploy.sh + +set -e + +GPU_HOST="${GPU_HOST:-root@138.197.163.123}" +REMOTE_DIR="/root/art-dag/celery" + +echo "=== GPU Dev Deploy ===" +echo "Syncing code to $GPU_HOST..." + +# Sync code (excluding cache, git, __pycache__) +rsync -avz --delete \ + --exclude '.git' \ + --exclude '__pycache__' \ + --exclude '*.pyc' \ + --exclude '.pytest_cache' \ + --exclude 'node_modules' \ + --exclude '.env' \ + ./ "$GPU_HOST:$REMOTE_DIR/" + +echo "Restarting GPU worker..." +ssh "$GPU_HOST" "docker kill \$(docker ps -q -f name=l1-gpu-worker) 2>/dev/null || true" + +echo "Waiting for new container..." +sleep 10 + +# Show new container logs +ssh "$GPU_HOST" "docker logs --tail 30 \$(docker ps -q -f name=l1-gpu-worker)" + +echo "" +echo "=== Deploy Complete ===" +echo "Use 'ssh $GPU_HOST docker logs -f \$(docker ps -q -f name=l1-gpu-worker)' to follow logs" diff --git a/l1/scripts/setup-gpu-droplet.sh b/l1/scripts/setup-gpu-droplet.sh new file mode 100755 index 0000000..e731ef8 --- /dev/null +++ b/l1/scripts/setup-gpu-droplet.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Setup script for GPU droplet with NVENC support +# Run as root on a fresh Ubuntu droplet with NVIDIA GPU + +set -e + +echo "=== ArtDAG GPU Droplet Setup ===" + +# 1. System updates +echo "[1/7] Updating system..." +apt-get update +apt-get upgrade -y + +# 2. Install NVIDIA drivers (if not already installed) +echo "[2/7] Checking NVIDIA drivers..." +if ! command -v nvidia-smi &> /dev/null; then + echo "Installing NVIDIA drivers..." + apt-get install -y nvidia-driver-535 nvidia-utils-535 + echo "NVIDIA drivers installed. Reboot required." + echo "After reboot, run this script again." + exit 0 +fi + +nvidia-smi +echo "NVIDIA drivers OK" + +# 3. Install FFmpeg with NVENC support +echo "[3/7] Installing FFmpeg with NVENC..." +apt-get install -y ffmpeg + +# Verify NVENC +if ffmpeg -encoders 2>/dev/null | grep -q nvenc; then + echo "NVENC available:" + ffmpeg -encoders 2>/dev/null | grep nvenc +else + echo "WARNING: NVENC not available. GPU may not support hardware encoding." +fi + +# 4. Install Python and create venv +echo "[4/7] Setting up Python environment..." +apt-get install -y python3 python3-venv python3-pip git + +VENV_DIR="/opt/artdag-gpu" +python3 -m venv "$VENV_DIR" +source "$VENV_DIR/bin/activate" + +# 5. Install Python dependencies +echo "[5/7] Installing Python packages..." +pip install --upgrade pip +pip install \ + numpy \ + opencv-python-headless \ + wgpu \ + httpx \ + pyyaml \ + celery[redis] \ + fastapi \ + uvicorn + +# 6. Clone/update art-dag code +echo "[6/7] Setting up art-dag code..." +ARTDAG_DIR="$VENV_DIR/celery" +if [ -d "$ARTDAG_DIR" ]; then + echo "Updating existing code..." + cd "$ARTDAG_DIR" + git pull || true +else + echo "Cloning art-dag..." + git clone https://git.rose-ash.com/art-dag/celery.git "$ARTDAG_DIR" || { + echo "Git clone failed. You may need to copy code manually." + } +fi + +# 7. Test GPU compute +echo "[7/7] Testing GPU compute..." +"$VENV_DIR/bin/python3" << 'PYTEST' +import sys +try: + import wgpu + adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + print(f"GPU Adapter: {adapter.info.get('device', 'unknown')}") + device = adapter.request_device_sync() + print("wgpu device created successfully") + + # Check for NVENC via FFmpeg + import subprocess + result = subprocess.run(['ffmpeg', '-encoders'], capture_output=True, text=True) + if 'h264_nvenc' in result.stdout: + print("NVENC H.264 encoder: AVAILABLE") + else: + print("NVENC H.264 encoder: NOT AVAILABLE") + if 'hevc_nvenc' in result.stdout: + print("NVENC HEVC encoder: AVAILABLE") + else: + print("NVENC HEVC encoder: NOT AVAILABLE") + +except Exception as e: + print(f"Error: {e}") + sys.exit(1) +PYTEST + +echo "" +echo "=== Setup Complete ===" +echo "Venv: $VENV_DIR" +echo "Code: $ARTDAG_DIR" +echo "" +echo "To activate: source $VENV_DIR/bin/activate" +echo "To test: cd $ARTDAG_DIR && python -c 'from streaming.backends import get_backend; print(get_backend(\"wgpu\"))'" diff --git a/l1/server.py b/l1/server.py new file mode 100644 index 0000000..f7c1e1e --- /dev/null +++ b/l1/server.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +Art DAG L1 Server + +Minimal entry point that uses the modular app factory. +All routes are defined in app/routers/. +All templates are in app/templates/. +""" + +import logging +import os + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(levelname)s %(name)s: %(message)s' +) + +# Import the app from the factory +from app import app + +if __name__ == "__main__": + import uvicorn + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", "8100")) + uvicorn.run("server:app", host=host, port=port, workers=4) diff --git a/l1/sexp_effects/__init__.py b/l1/sexp_effects/__init__.py new file mode 100644 index 0000000..b001c71 --- /dev/null +++ b/l1/sexp_effects/__init__.py @@ -0,0 +1,32 @@ +""" +S-Expression Effects System + +Safe, shareable effects defined in S-expressions. +""" + +from .parser import parse, parse_file, Symbol, Keyword +from .interpreter import ( + Interpreter, + get_interpreter, + load_effect, + load_effects_dir, + run_effect, + list_effects, + make_process_frame, +) +from .primitives import PRIMITIVES + +__all__ = [ + 'parse', + 'parse_file', + 'Symbol', + 'Keyword', + 'Interpreter', + 'get_interpreter', + 'load_effect', + 'load_effects_dir', + 'run_effect', + 'list_effects', + 'make_process_frame', + 'PRIMITIVES', +] diff --git a/l1/sexp_effects/derived.sexp b/l1/sexp_effects/derived.sexp new file mode 100644 index 0000000..7e1aae3 --- /dev/null +++ b/l1/sexp_effects/derived.sexp @@ -0,0 +1,206 @@ +;; Derived Operations +;; +;; These are built from true primitives using S-expressions. +;; Load with: (require "derived") + +;; ============================================================================= +;; Math Helpers (derivable from where + basic ops) +;; ============================================================================= + +;; Absolute value +(define (abs x) (where (< x 0) (- x) x)) + +;; Minimum of two values +(define (min2 a b) (where (< a b) a b)) + +;; Maximum of two values +(define (max2 a b) (where (> a b) a b)) + +;; Clamp x to range [lo, hi] +(define (clamp x lo hi) (max2 lo (min2 hi x))) + +;; Square of x +(define (sq x) (* x x)) + +;; Linear interpolation: a*(1-t) + b*t +(define (lerp a b t) (+ (* a (- 1 t)) (* b t))) + +;; Smooth interpolation between edges +(define (smoothstep edge0 edge1 x) + (let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1))) + (* t (* t (- 3 (* 2 t)))))) + +;; ============================================================================= +;; Channel Shortcuts (derivable from channel primitive) +;; ============================================================================= + +;; Extract red channel as xector +(define (red frame) (channel frame 0)) + +;; Extract green channel as xector +(define (green frame) (channel frame 1)) + +;; Extract blue channel as xector +(define (blue frame) (channel frame 2)) + +;; Convert to grayscale xector (ITU-R BT.601) +(define (gray frame) + (+ (* (red frame) 0.299) + (* (green frame) 0.587) + (* (blue frame) 0.114))) + +;; Alias for gray +(define (luminance frame) (gray frame)) + +;; ============================================================================= +;; Coordinate Generators (derivable from iota + repeat/tile) +;; ============================================================================= + +;; X coordinate for each pixel [0, width) +(define (x-coords frame) (tile (iota (width frame)) (height frame))) + +;; Y coordinate for each pixel [0, height) +(define (y-coords frame) (repeat (iota (height frame)) (width frame))) + +;; Normalized X coordinate [0, 1] +(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1)))) + +;; Normalized Y coordinate [0, 1] +(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1)))) + +;; Distance from frame center for each pixel +(define (dist-from-center frame) + (let* ((cx (/ (width frame) 2)) + (cy (/ (height frame) 2)) + (dx (- (x-coords frame) cx)) + (dy (- (y-coords frame) cy))) + (sqrt (+ (sq dx) (sq dy))))) + +;; Normalized distance from center [0, ~1] +(define (dist-norm frame) + (let ((d (dist-from-center frame))) + (/ d (max2 1 (βmax d))))) + +;; ============================================================================= +;; Cell/Grid Operations (derivable from floor + basic math) +;; ============================================================================= + +;; Cell row index for each pixel +(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size))) + +;; Cell column index for each pixel +(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size))) + +;; Number of cell rows +(define (num-rows frame cell-size) (floor (/ (height frame) cell-size))) + +;; Number of cell columns +(define (num-cols frame cell-size) (floor (/ (width frame) cell-size))) + +;; Flat cell index for each pixel +(define (cell-indices frame cell-size) + (+ (* (cell-row frame cell-size) (num-cols frame cell-size)) + (cell-col frame cell-size))) + +;; Total number of cells +(define (num-cells frame cell-size) + (* (num-rows frame cell-size) (num-cols frame cell-size))) + +;; X position within cell [0, cell-size) +(define (local-x frame cell-size) (mod (x-coords frame) cell-size)) + +;; Y position within cell [0, cell-size) +(define (local-y frame cell-size) (mod (y-coords frame) cell-size)) + +;; Normalized X within cell [0, 1] +(define (local-x-norm frame cell-size) + (/ (local-x frame cell-size) (max2 1 (- cell-size 1)))) + +;; Normalized Y within cell [0, 1] +(define (local-y-norm frame cell-size) + (/ (local-y frame cell-size) (max2 1 (- cell-size 1)))) + +;; ============================================================================= +;; Fill Operations (derivable from iota) +;; ============================================================================= + +;; Xector of n zeros +(define (zeros n) (* (iota n) 0)) + +;; Xector of n ones +(define (ones n) (+ (zeros n) 1)) + +;; Xector of n copies of val +(define (fill val n) (+ (zeros n) val)) + +;; Xector of zeros matching x's length +(define (zeros-like x) (* x 0)) + +;; Xector of ones matching x's length +(define (ones-like x) (+ (zeros-like x) 1)) + +;; ============================================================================= +;; Pooling (derivable from group-reduce) +;; ============================================================================= + +;; Pool a channel by cell index +(define (pool-channel chan cell-idx num-cells) + (group-reduce chan cell-idx num-cells "mean")) + +;; Pool red channel to cells +(define (pool-red frame cell-size) + (pool-channel (red frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool green channel to cells +(define (pool-green frame cell-size) + (pool-channel (green frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool blue channel to cells +(define (pool-blue frame cell-size) + (pool-channel (blue frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool grayscale to cells +(define (pool-gray frame cell-size) + (pool-channel (gray frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; ============================================================================= +;; Blending (derivable from math) +;; ============================================================================= + +;; Additive blend +(define (blend-add a b) (clamp (+ a b) 0 255)) + +;; Multiply blend (normalized) +(define (blend-multiply a b) (* (/ a 255) b)) + +;; Screen blend +(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 255 b)))) + +;; Overlay blend +(define (blend-overlay a b) + (where (< a 128) + (* 2 (/ (* a b) 255)) + (- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255))))) + +;; ============================================================================= +;; Simple Effects (derivable from primitives) +;; ============================================================================= + +;; Invert a channel (255 - c) +(define (invert-channel c) (- 255 c)) + +;; Binary threshold +(define (threshold-channel c thresh) (where (> c thresh) 255 0)) + +;; Reduce to n levels +(define (posterize-channel c levels) + (let ((step (/ 255 (- levels 1)))) + (* (round (/ c step)) step))) diff --git a/l1/sexp_effects/effects/ascii_art.sexp b/l1/sexp_effects/effects/ascii_art.sexp new file mode 100644 index 0000000..0504768 --- /dev/null +++ b/l1/sexp_effects/effects/ascii_art.sexp @@ -0,0 +1,17 @@ +;; ASCII Art effect - converts image to ASCII characters +(require-primitives "ascii") + +(define-effect ascii_art + :params ( + (char_size :type int :default 8 :range [4 32]) + (alphabet :type string :default "standard") + (color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex") + (background_color :type string :default "black" :desc "background color name/hex") + (invert_colors :type int :default 0 :desc "swap foreground and background colors") + (contrast :type float :default 1.5 :range [1 3]) + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid frame chars colors char_size color_mode background_color invert_colors))) diff --git a/l1/sexp_effects/effects/ascii_art_fx.sexp b/l1/sexp_effects/effects/ascii_art_fx.sexp new file mode 100644 index 0000000..2bb14be --- /dev/null +++ b/l1/sexp_effects/effects/ascii_art_fx.sexp @@ -0,0 +1,52 @@ +;; ASCII Art FX - converts image to ASCII characters with per-character effects +(require-primitives "ascii") + +(define-effect ascii_art_fx + :params ( + ;; Basic parameters + (char_size :type int :default 8 :range [4 32] + :desc "Size of each character cell in pixels") + (alphabet :type string :default "standard" + :desc "Character set to use") + (color_mode :type string :default "color" + :choices [color mono invert] + :desc "Color mode: color, mono, invert, or any color name/hex") + (background_color :type string :default "black" + :desc "Background color name or hex value") + (invert_colors :type int :default 0 :range [0 1] + :desc "Swap foreground and background colors (0/1)") + (contrast :type float :default 1.5 :range [1 3] + :desc "Character selection contrast") + + ;; Per-character effects + (char_jitter :type float :default 0 :range [0 20] + :desc "Position jitter amount in pixels") + (char_scale :type float :default 1.0 :range [0.5 2.0] + :desc "Character scale factor") + (char_rotation :type float :default 0 :range [0 180] + :desc "Rotation amount in degrees") + (char_hue_shift :type float :default 0 :range [0 360] + :desc "Hue shift in degrees") + + ;; Modulation sources + (jitter_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives jitter modulation") + (scale_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives scale modulation") + (rotation_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives rotation modulation") + (hue_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives hue shift modulation") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid-fx frame chars colors luminances char_size + color_mode background_color invert_colors + char_jitter char_scale char_rotation char_hue_shift + jitter_source scale_source rotation_source hue_source))) diff --git a/l1/sexp_effects/effects/ascii_fx_zone.sexp b/l1/sexp_effects/effects/ascii_fx_zone.sexp new file mode 100644 index 0000000..69e5340 --- /dev/null +++ b/l1/sexp_effects/effects/ascii_fx_zone.sexp @@ -0,0 +1,102 @@ +;; Composable ASCII Art with Per-Zone Expression-Driven Effects +;; Requires ascii primitive library for the ascii-fx-zone primitive + +(require-primitives "ascii") + +;; Two modes of operation: +;; +;; 1. EXPRESSION MODE: Use zone-* variables in expression parameters +;; Zone variables available: +;; zone-row, zone-col: Grid position (integers) +;; zone-row-norm, zone-col-norm: Normalized position (0-1) +;; zone-lum: Cell luminance (0-1) +;; zone-sat: Cell saturation (0-1) +;; zone-hue: Cell hue (0-360) +;; zone-r, zone-g, zone-b: RGB components (0-1) +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 80 +;; :char_hue (* zone-lum 180) +;; :char_rotation (* zone-col-norm 30)) +;; +;; 2. CELL EFFECT MODE: Pass a lambda to apply arbitrary effects per-cell +;; The lambda receives (cell-image zone-dict) and returns modified cell. +;; Zone dict contains: row, col, row-norm, col-norm, lum, sat, hue, r, g, b, +;; char, color, cell_size, plus any bound analysis values. +;; +;; Any loaded sexp effect can be called on cells - each cell is just a small frame: +;; (blur cell radius) - Gaussian blur +;; (rotate cell angle) - Rotate by angle degrees +;; (brightness cell factor) - Adjust brightness +;; (contrast cell factor) - Adjust contrast +;; (saturation cell factor) - Adjust saturation +;; (hue_shift cell degrees) - Shift hue +;; (rgb_split cell offset_x offset_y) - RGB channel split +;; (invert cell) - Invert colors +;; (pixelate cell block_size) - Pixelate +;; (wave cell amplitude freq) - Wave distortion +;; ... and any other loaded effect +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 60 +;; :cell_effect (lambda [cell zone] +;; (blur (rotate cell (* (get zone "energy") 45)) +;; (if (> (get zone "lum") 0.5) 3 0)))) + +(define-effect ascii_fx_zone + :params ( + (cols :type int :default 80 :range [20 200] + :desc "Number of character columns") + (char_size :type int :default nil :range [4 32] + :desc "Character cell size in pixels (overrides cols if set)") + (alphabet :type string :default "standard" + :desc "Character set: standard, blocks, simple, digits, or custom string") + (color_mode :type string :default "color" + :desc "Color mode: color, mono, invert, or any color name/hex") + (background :type string :default "black" + :desc "Background color name or hex value") + (contrast :type float :default 1.5 :range [0.5 3.0] + :desc "Contrast for character selection") + (char_hue :type any :default nil + :desc "Hue shift expression (evaluated per-zone with zone-* vars)") + (char_saturation :type any :default nil + :desc "Saturation multiplier expression (1.0 = unchanged)") + (char_brightness :type any :default nil + :desc "Brightness multiplier expression (1.0 = unchanged)") + (char_scale :type any :default nil + :desc "Character scale expression (1.0 = normal size)") + (char_rotation :type any :default nil + :desc "Character rotation expression (degrees)") + (char_jitter :type any :default nil + :desc "Position jitter expression (pixels)") + (cell_effect :type any :default nil + :desc "Lambda (cell zone) -> cell for arbitrary per-cell effects") + ;; Convenience params for staged recipes (avoids compile-time expression issues) + (energy :type float :default nil + :desc "Energy multiplier (0-1) from audio analysis bind") + (rotation_scale :type float :default 0 + :desc "Max rotation at top-right when energy=1 (degrees)") + ) + ;; The ascii-fx-zone special form handles expression params + ;; If energy + rotation_scale provided, it builds: energy * scale * position_factor + ;; where position_factor = 0 at bottom-left, 3 at top-right + ;; If cell_effect provided, each character is rendered to a cell image, + ;; passed to the lambda, and the result composited back + (ascii-fx-zone frame + :cols cols + :char_size char_size + :alphabet alphabet + :color_mode color_mode + :background background + :contrast contrast + :char_hue char_hue + :char_saturation char_saturation + :char_brightness char_brightness + :char_scale char_scale + :char_rotation char_rotation + :char_jitter char_jitter + :cell_effect cell_effect + :energy energy + :rotation_scale rotation_scale)) diff --git a/l1/sexp_effects/effects/ascii_zones.sexp b/l1/sexp_effects/effects/ascii_zones.sexp new file mode 100644 index 0000000..6bc441c --- /dev/null +++ b/l1/sexp_effects/effects/ascii_zones.sexp @@ -0,0 +1,30 @@ +;; ASCII Zones effect - different character sets for different brightness zones +;; Dark areas use simple chars, mid uses standard, bright uses blocks +(require-primitives "ascii") + +(define-effect ascii_zones + :params ( + (char_size :type int :default 8 :range [4 32]) + (dark_threshold :type int :default 80 :range [0 128]) + (bright_threshold :type int :default 180 :range [128 255]) + (color_mode :type string :default "color") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + ;; Start with simple chars as base + (base-chars (luminance-to-chars luminances "simple" 1.2)) + ;; Map each cell to appropriate alphabet based on brightness zone + (zoned-chars (map-char-grid base-chars luminances + (lambda (r c ch lum) + (cond + ;; Bright zones: use block characters + ((> lum bright_threshold) + (alphabet-char "blocks" (floor (/ (- lum bright_threshold) 15)))) + ;; Dark zones: use simple sparse chars + ((< lum dark_threshold) + (alphabet-char " .-" (floor (/ lum 30)))) + ;; Mid zones: use standard ASCII + (else + (alphabet-char "standard" (floor (/ lum 4))))))))) + (render-char-grid frame zoned-chars colors char_size color_mode (list 0 0 0)))) diff --git a/l1/sexp_effects/effects/blend.sexp b/l1/sexp_effects/effects/blend.sexp new file mode 100644 index 0000000..bf7fefd --- /dev/null +++ b/l1/sexp_effects/effects/blend.sexp @@ -0,0 +1,31 @@ +;; Blend effect - combines two video frames +;; Streaming-compatible: frame is background, overlay is second frame +;; Usage: (blend background overlay :opacity 0.5 :mode "alpha") +;; +;; Params: +;; mode - blend mode (add, multiply, screen, overlay, difference, lighten, darken, alpha) +;; opacity - blend amount (0-1) + +(require-primitives "image" "blending" "core") + +(define-effect blend + :params ( + (overlay :type frame :default nil) + (mode :type string :default "alpha") + (opacity :type float :default 0.5) + ) + (if (core:is-nil overlay) + frame + (let [a frame + b overlay + a-h (image:height a) + a-w (image:width a) + b-h (image:height b) + b-w (image:width b) + ;; Resize b to match a if needed + b-sized (if (and (= a-w b-w) (= a-h b-h)) + b + (image:resize b a-w a-h "linear"))] + (if (= mode "alpha") + (blending:blend-images a b-sized opacity) + (blending:blend-images a (blending:blend-mode a b-sized mode) opacity))))) diff --git a/l1/sexp_effects/effects/blend_multi.sexp b/l1/sexp_effects/effects/blend_multi.sexp new file mode 100644 index 0000000..1ee160f --- /dev/null +++ b/l1/sexp_effects/effects/blend_multi.sexp @@ -0,0 +1,58 @@ +;; N-way weighted blend effect +;; Streaming-compatible: pass inputs as a list of frames +;; Usage: (blend_multi :inputs [(read a) (read b) (read c)] :weights [0.3 0.4 0.3]) +;; +;; Parameters: +;; inputs - list of N frames to blend +;; weights - list of N floats, one per input (resolved per-frame) +;; mode - blend mode applied when folding each frame in: +;; "alpha" — pure weighted average (default) +;; "multiply" — darken by multiplication +;; "screen" — lighten (inverse multiply) +;; "overlay" — contrast-boosting midtone blend +;; "soft-light" — gentle dodge/burn +;; "hard-light" — strong dodge/burn +;; "color-dodge" — brightens towards white +;; "color-burn" — darkens towards black +;; "difference" — absolute pixel difference +;; "exclusion" — softer difference +;; "add" — additive (clamped) +;; "subtract" — subtractive (clamped) +;; "darken" — per-pixel minimum +;; "lighten" — per-pixel maximum +;; resize_mode - how to match frame dimensions (fit, crop, stretch) +;; +;; Uses a left-fold over inputs[1..N-1]. At each step the running +;; opacity is: w[i] / (w[0] + w[1] + ... + w[i]) +;; which produces the correct normalised weighted result. + +(require-primitives "image" "blending") + +(define-effect blend_multi + :params ( + (inputs :type list :default []) + (weights :type list :default []) + (mode :type string :default "alpha") + (resize_mode :type string :default "fit") + ) + (let [n (len inputs) + ;; Target dimensions from first frame + target-w (image:width (nth inputs 0)) + target-h (image:height (nth inputs 0)) + ;; Fold over indices 1..n-1 + ;; Accumulator is (list blended-frame running-weight-sum) + seed (list (nth inputs 0) (nth weights 0)) + result (reduce (range 1 n) seed + (lambda (pair i) + (let [acc (nth pair 0) + running (nth pair 1) + w (nth weights i) + new-running (+ running w) + opacity (/ w (max new-running 0.001)) + f (image:resize (nth inputs i) target-w target-h "linear") + ;; Apply blend mode then mix with opacity + blended (if (= mode "alpha") + (blending:blend-images acc f opacity) + (blending:blend-images acc (blending:blend-mode acc f mode) opacity))] + (list blended new-running))))] + (nth result 0))) diff --git a/l1/sexp_effects/effects/bloom.sexp b/l1/sexp_effects/effects/bloom.sexp new file mode 100644 index 0000000..3524d01 --- /dev/null +++ b/l1/sexp_effects/effects/bloom.sexp @@ -0,0 +1,16 @@ +;; Bloom effect - glow on bright areas +(require-primitives "image" "blending") + +(define-effect bloom + :params ( + (intensity :type float :default 0.5 :range [0 2]) + (threshold :type int :default 200 :range [0 255]) + (radius :type int :default 15 :range [1 50]) + ) + (let* ((bright (map-pixels frame + (lambda (x y c) + (if (> (luminance c) threshold) + c + (rgb 0 0 0))))) + (blurred (image:blur bright radius))) + (blending:blend-mode frame blurred "add"))) diff --git a/l1/sexp_effects/effects/blur.sexp b/l1/sexp_effects/effects/blur.sexp new file mode 100644 index 0000000..b71a55a --- /dev/null +++ b/l1/sexp_effects/effects/blur.sexp @@ -0,0 +1,8 @@ +;; Blur effect - gaussian blur +(require-primitives "image") + +(define-effect blur + :params ( + (radius :type int :default 5 :range [1 50]) + ) + (image:blur frame (max 1 radius))) diff --git a/l1/sexp_effects/effects/brightness.sexp b/l1/sexp_effects/effects/brightness.sexp new file mode 100644 index 0000000..4af53a7 --- /dev/null +++ b/l1/sexp_effects/effects/brightness.sexp @@ -0,0 +1,9 @@ +;; Brightness effect - adjusts overall brightness +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect brightness + :params ( + (amount :type int :default 0 :range [-255 255]) + ) + (color_ops:adjust-brightness frame amount)) diff --git a/l1/sexp_effects/effects/cell_pattern.sexp b/l1/sexp_effects/effects/cell_pattern.sexp new file mode 100644 index 0000000..bc503bb --- /dev/null +++ b/l1/sexp_effects/effects/cell_pattern.sexp @@ -0,0 +1,65 @@ +;; Cell Pattern effect - custom patterns within cells +;; +;; Demonstrates building arbitrary per-cell visuals from primitives. +;; Uses local coordinates within cells to draw patterns scaled by luminance. + +(require-primitives "xector") + +(define-effect cell_pattern + :params ( + (cell-size :type int :default 16 :range [8 48] :desc "Cell size") + (pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring") + ) + (let* ( + ;; Pool to get cell colors + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + (cell-lum (α/ (nth pooled 3) 255)) + + ;; Cell indices for each pixel + (cell-idx (cell-indices frame cell-size)) + + ;; Look up cell values for each pixel + (pix-r (gather cell-r cell-idx)) + (pix-g (gather cell-g cell-idx)) + (pix-b (gather cell-b cell-idx)) + (pix-lum (gather cell-lum cell-idx)) + + ;; Local position within cell [0, 1] + (lx (local-x-norm frame cell-size)) + (ly (local-y-norm frame cell-size)) + + ;; Pattern mask based on pattern type + (mask + (cond + ;; Diagonal lines - thickness based on luminance + ((= pattern "diagonal") + (let* ((diag (αmod (α+ lx ly) 0.25)) + (thickness (α* pix-lum 0.125))) + (α< diag thickness))) + + ;; Cross pattern + ((= pattern "cross") + (let* ((cx (αabs (α- lx 0.5))) + (cy (αabs (α- ly 0.5))) + (thickness (α* pix-lum 0.25))) + (αor (α< cx thickness) (α< cy thickness)))) + + ;; Ring pattern + ((= pattern "ring") + (let* ((dx (α- lx 0.5)) + (dy (α- ly 0.5)) + (dist (αsqrt (α+ (α² dx) (α² dy)))) + (target (α* pix-lum 0.4)) + (thickness 0.05)) + (α< (αabs (α- dist target)) thickness))) + + ;; Default: solid + (else (α> pix-lum 0))))) + + ;; Apply mask: show cell color where mask is true, black elsewhere + (rgb (where mask pix-r 0) + (where mask pix-g 0) + (where mask pix-b 0)))) diff --git a/l1/sexp_effects/effects/color-adjust.sexp b/l1/sexp_effects/effects/color-adjust.sexp new file mode 100644 index 0000000..5318bdd --- /dev/null +++ b/l1/sexp_effects/effects/color-adjust.sexp @@ -0,0 +1,13 @@ +;; Color adjustment effect - replaces TRANSFORM node +(require-primitives "color_ops") + +(define-effect color-adjust + :params ( + (brightness :type int :default 0 :range [-255 255] :desc "Brightness adjustment") + (contrast :type float :default 1 :range [0 3] :desc "Contrast multiplier") + (saturation :type float :default 1 :range [0 2] :desc "Saturation multiplier") + ) + (-> frame + (color_ops:adjust-brightness brightness) + (color_ops:adjust-contrast contrast) + (color_ops:adjust-saturation saturation))) diff --git a/l1/sexp_effects/effects/color_cycle.sexp b/l1/sexp_effects/effects/color_cycle.sexp new file mode 100644 index 0000000..e08dbb6 --- /dev/null +++ b/l1/sexp_effects/effects/color_cycle.sexp @@ -0,0 +1,13 @@ +;; Color Cycle effect - animated hue rotation +(require-primitives "color_ops") + +(define-effect color_cycle + :params ( + (speed :type int :default 1 :range [0 10]) + ) + (let ((shift (* t speed 360))) + (map-pixels frame + (lambda (x y c) + (let* ((hsv (rgb->hsv c)) + (new-h (mod (+ (first hsv) shift) 360))) + (hsv->rgb (list new-h (nth hsv 1) (nth hsv 2)))))))) diff --git a/l1/sexp_effects/effects/contrast.sexp b/l1/sexp_effects/effects/contrast.sexp new file mode 100644 index 0000000..660661d --- /dev/null +++ b/l1/sexp_effects/effects/contrast.sexp @@ -0,0 +1,9 @@ +;; Contrast effect - adjusts image contrast +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect contrast + :params ( + (amount :type int :default 1 :range [0.5 3]) + ) + (color_ops:adjust-contrast frame amount)) diff --git a/l1/sexp_effects/effects/crt.sexp b/l1/sexp_effects/effects/crt.sexp new file mode 100644 index 0000000..097eaf9 --- /dev/null +++ b/l1/sexp_effects/effects/crt.sexp @@ -0,0 +1,30 @@ +;; CRT effect - old monitor simulation +(require-primitives "image") + +(define-effect crt + :params ( + (line_spacing :type int :default 2 :range [1 10]) + (line_opacity :type float :default 0.3 :range [0 1]) + (vignette_amount :type float :default 0.2) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (sqrt (+ (* cx cx) (* cy cy))))) + (map-pixels frame + (lambda (x y c) + (let* (;; Scanline darkening + (scanline-factor (if (= 0 (mod y line_spacing)) + (- 1 line_opacity) + 1)) + ;; Vignette + (dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (vignette-factor (- 1 (* (/ dist max-dist) vignette_amount))) + ;; Combined + (factor (* scanline-factor vignette-factor))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/l1/sexp_effects/effects/datamosh.sexp b/l1/sexp_effects/effects/datamosh.sexp new file mode 100644 index 0000000..60cec66 --- /dev/null +++ b/l1/sexp_effects/effects/datamosh.sexp @@ -0,0 +1,14 @@ +;; Datamosh effect - glitch block corruption + +(define-effect datamosh + :params ( + (block_size :type int :default 32 :range [8 128]) + (corruption :type float :default 0.3 :range [0 1]) + (max_offset :type int :default 50 :range [0 200]) + (color_corrupt :type bool :default true) + ) + ;; Get previous frame from state, or use current frame if none + (let ((prev (state-get "prev_frame" frame))) + (begin + (state-set "prev_frame" (copy frame)) + (datamosh frame prev block_size corruption max_offset color_corrupt)))) diff --git a/l1/sexp_effects/effects/echo.sexp b/l1/sexp_effects/effects/echo.sexp new file mode 100644 index 0000000..599a1d6 --- /dev/null +++ b/l1/sexp_effects/effects/echo.sexp @@ -0,0 +1,19 @@ +;; Echo effect - motion trails using frame buffer +(require-primitives "blending") + +(define-effect echo + :params ( + (num_echoes :type int :default 4 :range [1 20]) + (decay :type float :default 0.5 :range [0 1]) + ) + (let* ((buffer (state-get "buffer" (list))) + (new-buffer (take (cons frame buffer) (+ num_echoes 1)))) + (begin + (state-set "buffer" new-buffer) + ;; Blend frames with decay + (if (< (length new-buffer) 2) + frame + (let ((result (copy frame))) + ;; Simple blend of first two frames for now + ;; Full version would fold over all frames + (blending:blend-images frame (nth new-buffer 1) (* decay 0.5))))))) diff --git a/l1/sexp_effects/effects/edge_detect.sexp b/l1/sexp_effects/effects/edge_detect.sexp new file mode 100644 index 0000000..170befb --- /dev/null +++ b/l1/sexp_effects/effects/edge_detect.sexp @@ -0,0 +1,9 @@ +;; Edge detection effect - highlights edges +(require-primitives "image") + +(define-effect edge_detect + :params ( + (low :type int :default 50 :range [10 100]) + (high :type int :default 150 :range [50 300]) + ) + (image:edge-detect frame low high)) diff --git a/l1/sexp_effects/effects/emboss.sexp b/l1/sexp_effects/effects/emboss.sexp new file mode 100644 index 0000000..1eac3ce --- /dev/null +++ b/l1/sexp_effects/effects/emboss.sexp @@ -0,0 +1,13 @@ +;; Emboss effect - creates raised/3D appearance +(require-primitives "blending") + +(define-effect emboss + :params ( + (strength :type int :default 1 :range [0.5 3]) + (blend :type float :default 0.3 :range [0 1]) + ) + (let* ((kernel (list (list (- strength) (- strength) 0) + (list (- strength) 1 strength) + (list 0 strength strength))) + (embossed (convolve frame kernel))) + (blending:blend-images embossed frame blend))) diff --git a/l1/sexp_effects/effects/film_grain.sexp b/l1/sexp_effects/effects/film_grain.sexp new file mode 100644 index 0000000..29bdd75 --- /dev/null +++ b/l1/sexp_effects/effects/film_grain.sexp @@ -0,0 +1,19 @@ +;; Film Grain effect - adds film grain texture +(require-primitives "core") + +(define-effect film_grain + :params ( + (intensity :type float :default 0.2 :range [0 1]) + (colored :type bool :default false) + ) + (let ((grain-amount (* intensity 50))) + (map-pixels frame + (lambda (x y c) + (if colored + (rgb (clamp (+ (red c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (green c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (blue c) (gaussian 0 grain-amount)) 0 255)) + (let ((n (gaussian 0 grain-amount))) + (rgb (clamp (+ (red c) n) 0 255) + (clamp (+ (green c) n) 0 255) + (clamp (+ (blue c) n) 0 255)))))))) diff --git a/l1/sexp_effects/effects/fisheye.sexp b/l1/sexp_effects/effects/fisheye.sexp new file mode 100644 index 0000000..37750a7 --- /dev/null +++ b/l1/sexp_effects/effects/fisheye.sexp @@ -0,0 +1,16 @@ +;; Fisheye effect - barrel/pincushion lens distortion +(require-primitives "geometry" "image") + +(define-effect fisheye + :params ( + (strength :type float :default 0.3 :range [-1 1]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom_correct :type bool :default true) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:fisheye-coords w h strength cx cy zoom_correct))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/l1/sexp_effects/effects/flip.sexp b/l1/sexp_effects/effects/flip.sexp new file mode 100644 index 0000000..977e1e1 --- /dev/null +++ b/l1/sexp_effects/effects/flip.sexp @@ -0,0 +1,16 @@ +;; Flip effect - flips image horizontally or vertically +(require-primitives "geometry") + +(define-effect flip + :params ( + (horizontal :type bool :default true) + (vertical :type bool :default false) + ) + (let ((result frame)) + (if horizontal + (set! result (geometry:flip-img result "horizontal")) + nil) + (if vertical + (set! result (geometry:flip-img result "vertical")) + nil) + result)) diff --git a/l1/sexp_effects/effects/grayscale.sexp b/l1/sexp_effects/effects/grayscale.sexp new file mode 100644 index 0000000..848f8a7 --- /dev/null +++ b/l1/sexp_effects/effects/grayscale.sexp @@ -0,0 +1,7 @@ +;; Grayscale effect - converts to grayscale +;; Uses vectorized mix-gray primitive for fast processing +(require-primitives "image") + +(define-effect grayscale + :params () + (image:grayscale frame)) diff --git a/l1/sexp_effects/effects/halftone.sexp b/l1/sexp_effects/effects/halftone.sexp new file mode 100644 index 0000000..2190a4a --- /dev/null +++ b/l1/sexp_effects/effects/halftone.sexp @@ -0,0 +1,49 @@ +;; Halftone/dot effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell luminances +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel +;; local-x/y-norm - position within cell [0,1] +;; where - conditional per-pixel + +(require-primitives "xector") + +(define-effect halftone + :params ( + (cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells") + (dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius") + (invert :type bool :default false :desc "Invert (white dots on black)") + ) + (let* ( + ;; Pool frame to get luminance per cell + (pooled (pool-frame frame cell-size)) + (cell-lum (nth pooled 3)) ; luminance is 4th element + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Get cell luminance for each pixel + (pixel-lum (α/ (gather cell-lum cell-idx) 255)) + + ;; Position within cell, normalized to [-0.5, 0.5] + (lx (α- (local-x-norm frame cell-size) 0.5)) + (ly (α- (local-y-norm frame cell-size) 0.5)) + + ;; Distance from cell center (0 at center, ~0.7 at corners) + (dist (αsqrt (α+ (α² lx) (α² ly)))) + + ;; Radius based on luminance (brighter = bigger dot) + (radius (α* (if invert (α- 1 pixel-lum) pixel-lum) + (α* dot-scale 0.5))) + + ;; Is this pixel inside the dot? + (inside (α< dist radius)) + + ;; Output color + (fg (if invert 255 0)) + (bg (if invert 0 255)) + (out (where inside fg bg))) + + ;; Grayscale output + (rgb out out out))) diff --git a/l1/sexp_effects/effects/hue_shift.sexp b/l1/sexp_effects/effects/hue_shift.sexp new file mode 100644 index 0000000..ab61bd6 --- /dev/null +++ b/l1/sexp_effects/effects/hue_shift.sexp @@ -0,0 +1,12 @@ +;; Hue shift effect - rotates hue values +;; Uses vectorized shift-hsv primitive for fast processing + +(require-primitives "color_ops") + +(define-effect hue_shift + :params ( + (degrees :type int :default 0 :range [0 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((shift (+ degrees (* speed t)))) + (color_ops:shift-hsv frame shift 1 1))) diff --git a/l1/sexp_effects/effects/invert.sexp b/l1/sexp_effects/effects/invert.sexp new file mode 100644 index 0000000..34936da --- /dev/null +++ b/l1/sexp_effects/effects/invert.sexp @@ -0,0 +1,9 @@ +;; Invert effect - inverts all colors +;; Uses vectorized invert-img primitive for fast processing +;; amount param: 0 = no invert, 1 = full invert (threshold at 0.5) + +(require-primitives "color_ops") + +(define-effect invert + :params ((amount :type float :default 1 :range [0 1])) + (if (> amount 0.5) (color_ops:invert-img frame) frame)) diff --git a/l1/sexp_effects/effects/kaleidoscope.sexp b/l1/sexp_effects/effects/kaleidoscope.sexp new file mode 100644 index 0000000..9487ae2 --- /dev/null +++ b/l1/sexp_effects/effects/kaleidoscope.sexp @@ -0,0 +1,20 @@ +;; Kaleidoscope effect - mandala-like symmetry patterns +(require-primitives "geometry" "image") + +(define-effect kaleidoscope + :params ( + (segments :type int :default 6 :range [3 16]) + (rotation :type int :default 0 :range [0 360]) + (rotation_speed :type int :default 0 :range [-180 180]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom :type int :default 1 :range [0.5 3]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + ;; Total rotation including time-based animation + (total_rot (+ rotation (* rotation_speed (or _time 0)))) + (coords (geometry:kaleidoscope-coords w h segments total_rot cx cy zoom))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/l1/sexp_effects/effects/layer.sexp b/l1/sexp_effects/effects/layer.sexp new file mode 100644 index 0000000..e57d627 --- /dev/null +++ b/l1/sexp_effects/effects/layer.sexp @@ -0,0 +1,36 @@ +;; Layer effect - composite overlay over background at position +;; Streaming-compatible: frame is background, overlay is foreground +;; Usage: (layer background overlay :x 10 :y 20 :opacity 0.8) +;; +;; Params: +;; overlay - frame to composite on top +;; x, y - position to place overlay +;; opacity - blend amount (0-1) +;; mode - blend mode (alpha, multiply, screen, etc.) + +(require-primitives "image" "blending" "core") + +(define-effect layer + :params ( + (overlay :type frame :default nil) + (x :type int :default 0) + (y :type int :default 0) + (opacity :type float :default 1.0) + (mode :type string :default "alpha") + ) + (if (core:is-nil overlay) + frame + (let [bg (copy frame) + fg overlay + fg-w (image:width fg) + fg-h (image:height fg)] + (if (= opacity 1.0) + ;; Simple paste + (paste bg fg x y) + ;; Blend with opacity + (let [blended (if (= mode "alpha") + (blending:blend-images (image:crop bg x y fg-w fg-h) fg opacity) + (blending:blend-images (image:crop bg x y fg-w fg-h) + (blending:blend-mode (image:crop bg x y fg-w fg-h) fg mode) + opacity))] + (paste bg blended x y)))))) diff --git a/l1/sexp_effects/effects/mirror.sexp b/l1/sexp_effects/effects/mirror.sexp new file mode 100644 index 0000000..a450cb6 --- /dev/null +++ b/l1/sexp_effects/effects/mirror.sexp @@ -0,0 +1,33 @@ +;; Mirror effect - mirrors half of image +(require-primitives "geometry" "image") + +(define-effect mirror + :params ( + (mode :type string :default "left_right") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (hw (floor (/ w 2))) + (hh (floor (/ h 2)))) + (cond + ((= mode "left_right") + (let ((left (image:crop frame 0 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img left "horizontal") hw 0))) + + ((= mode "right_left") + (let ((right (image:crop frame hw 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img right "horizontal") 0 0))) + + ((= mode "top_bottom") + (let ((top (image:crop frame 0 0 w hh)) + (result (copy frame))) + (paste result (geometry:flip-img top "vertical") 0 hh))) + + ((= mode "bottom_top") + (let ((bottom (image:crop frame 0 hh w hh)) + (result (copy frame))) + (paste result (geometry:flip-img bottom "vertical") 0 0))) + + (else frame)))) diff --git a/l1/sexp_effects/effects/mosaic.sexp b/l1/sexp_effects/effects/mosaic.sexp new file mode 100644 index 0000000..5de07de --- /dev/null +++ b/l1/sexp_effects/effects/mosaic.sexp @@ -0,0 +1,30 @@ +;; Mosaic effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell averages +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel + +(require-primitives "xector") + +(define-effect mosaic + :params ( + (cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells") + ) + (let* ( + ;; Pool frame to get average color per cell (returns r,g,b,lum xectors) + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Gather: look up cell color for each pixel + (out-r (gather cell-r cell-idx)) + (out-g (gather cell-g cell-idx)) + (out-b (gather cell-b cell-idx))) + + ;; Reconstruct frame + (rgb out-r out-g out-b))) diff --git a/l1/sexp_effects/effects/neon_glow.sexp b/l1/sexp_effects/effects/neon_glow.sexp new file mode 100644 index 0000000..39245ab --- /dev/null +++ b/l1/sexp_effects/effects/neon_glow.sexp @@ -0,0 +1,23 @@ +;; Neon Glow effect - glowing edge effect +(require-primitives "image" "blending") + +(define-effect neon_glow + :params ( + (edge_low :type int :default 50 :range [10 200]) + (edge_high :type int :default 150 :range [50 300]) + (glow_radius :type int :default 15 :range [1 50]) + (glow_intensity :type int :default 2 :range [0.5 5]) + (background :type float :default 0.3 :range [0 1]) + ) + (let* ((edge-img (image:edge-detect frame edge_low edge_high)) + (glow (image:blur edge-img glow_radius)) + ;; Intensify the glow + (bright-glow (map-pixels glow + (lambda (x y c) + (rgb (clamp (* (red c) glow_intensity) 0 255) + (clamp (* (green c) glow_intensity) 0 255) + (clamp (* (blue c) glow_intensity) 0 255)))))) + (blending:blend-mode (blending:blend-images frame (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 background)) + bright-glow + "screen"))) diff --git a/l1/sexp_effects/effects/noise.sexp b/l1/sexp_effects/effects/noise.sexp new file mode 100644 index 0000000..4da8298 --- /dev/null +++ b/l1/sexp_effects/effects/noise.sexp @@ -0,0 +1,8 @@ +;; Noise effect - adds random noise +;; Uses vectorized add-noise primitive for fast processing + +(define-effect noise + :params ( + (amount :type int :default 20 :range [0 100]) + ) + (add-noise frame amount)) diff --git a/l1/sexp_effects/effects/outline.sexp b/l1/sexp_effects/effects/outline.sexp new file mode 100644 index 0000000..921a0b8 --- /dev/null +++ b/l1/sexp_effects/effects/outline.sexp @@ -0,0 +1,24 @@ +;; Outline effect - shows only edges +(require-primitives "image") + +(define-effect outline + :params ( + (thickness :type int :default 2 :range [1 10]) + (threshold :type int :default 100 :range [20 300]) + (color :type list :default (list 0 0 0)) + (fill_mode :type string :default "original") + ) + (let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold)) + (dilated (if (> thickness 1) + (dilate edge-img thickness) + edge-img)) + (base (cond + ((= fill_mode "original") (copy frame)) + ((= fill_mode "white") (make-image (image:width frame) (image:height frame) (list 255 255 255))) + (else (make-image (image:width frame) (image:height frame) (list 0 0 0)))))) + (map-pixels base + (lambda (x y c) + (let ((edge-val (luminance (pixel dilated x y)))) + (if (> edge-val 128) + color + c)))))) diff --git a/l1/sexp_effects/effects/pixelate.sexp b/l1/sexp_effects/effects/pixelate.sexp new file mode 100644 index 0000000..3d28ce1 --- /dev/null +++ b/l1/sexp_effects/effects/pixelate.sexp @@ -0,0 +1,13 @@ +;; Pixelate effect - creates blocky pixels +(require-primitives "image") + +(define-effect pixelate + :params ( + (block_size :type int :default 8 :range [2 64]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (small-w (max 1 (floor (/ w block_size)))) + (small-h (max 1 (floor (/ h block_size)))) + (small (image:resize frame small-w small-h "area"))) + (image:resize small w h "nearest"))) diff --git a/l1/sexp_effects/effects/pixelsort.sexp b/l1/sexp_effects/effects/pixelsort.sexp new file mode 100644 index 0000000..155ac13 --- /dev/null +++ b/l1/sexp_effects/effects/pixelsort.sexp @@ -0,0 +1,11 @@ +;; Pixelsort effect - glitch art pixel sorting + +(define-effect pixelsort + :params ( + (sort_by :type string :default "lightness") + (threshold_low :type int :default 50 :range [0 255]) + (threshold_high :type int :default 200 :range [0 255]) + (angle :type int :default 0 :range [0 180]) + (reverse :type bool :default false) + ) + (pixelsort frame sort_by threshold_low threshold_high angle reverse)) diff --git a/l1/sexp_effects/effects/posterize.sexp b/l1/sexp_effects/effects/posterize.sexp new file mode 100644 index 0000000..7052ed3 --- /dev/null +++ b/l1/sexp_effects/effects/posterize.sexp @@ -0,0 +1,8 @@ +;; Posterize effect - reduces color levels +(require-primitives "color_ops") + +(define-effect posterize + :params ( + (levels :type int :default 8 :range [2 32]) + ) + (color_ops:posterize frame levels)) diff --git a/l1/sexp_effects/effects/resize-frame.sexp b/l1/sexp_effects/effects/resize-frame.sexp new file mode 100644 index 0000000..a1cce27 --- /dev/null +++ b/l1/sexp_effects/effects/resize-frame.sexp @@ -0,0 +1,11 @@ +;; Resize effect - replaces RESIZE node +;; Note: uses target-w/target-h to avoid conflict with width/height primitives +(require-primitives "image") + +(define-effect resize-frame + :params ( + (target-w :type int :default 640 :desc "Target width in pixels") + (target-h :type int :default 480 :desc "Target height in pixels") + (mode :type string :default "linear" :choices [linear nearest area] :desc "Interpolation mode") + ) + (image:resize frame target-w target-h mode)) diff --git a/l1/sexp_effects/effects/rgb_split.sexp b/l1/sexp_effects/effects/rgb_split.sexp new file mode 100644 index 0000000..4582701 --- /dev/null +++ b/l1/sexp_effects/effects/rgb_split.sexp @@ -0,0 +1,13 @@ +;; RGB Split effect - chromatic aberration + +(define-effect rgb_split + :params ( + (offset_x :type int :default 10 :range [-50 50]) + (offset_y :type int :default 0 :range [-50 50]) + ) + (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (r-shifted (translate (merge-channels r r r) offset_x offset_y)) + (b-shifted (translate (merge-channels b b b) (- offset_x) (- offset_y)))) + (merge-channels (channel r-shifted 0) g (channel b-shifted 0)))) diff --git a/l1/sexp_effects/effects/ripple.sexp b/l1/sexp_effects/effects/ripple.sexp new file mode 100644 index 0000000..0bb7a8d --- /dev/null +++ b/l1/sexp_effects/effects/ripple.sexp @@ -0,0 +1,19 @@ +;; Ripple effect - radial wave distortion from center +(require-primitives "geometry" "image" "math") + +(define-effect ripple + :params ( + (frequency :type int :default 5 :range [1 20]) + (amplitude :type int :default 10 :range [0 50]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (decay :type int :default 1 :range [0 5]) + (speed :type int :default 1 :range [0 10]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (phase (* (or t 0) speed 2 pi)) + (coords (geometry:ripple-displace w h frequency amplitude cx cy decay phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/l1/sexp_effects/effects/rotate.sexp b/l1/sexp_effects/effects/rotate.sexp new file mode 100644 index 0000000..d06c2f7 --- /dev/null +++ b/l1/sexp_effects/effects/rotate.sexp @@ -0,0 +1,11 @@ +;; Rotate effect - rotates image + +(require-primitives "geometry") + +(define-effect rotate + :params ( + (angle :type int :default 0 :range [-360 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((total-angle (+ angle (* speed t)))) + (geometry:rotate-img frame total-angle))) diff --git a/l1/sexp_effects/effects/saturation.sexp b/l1/sexp_effects/effects/saturation.sexp new file mode 100644 index 0000000..9852dc7 --- /dev/null +++ b/l1/sexp_effects/effects/saturation.sexp @@ -0,0 +1,9 @@ +;; Saturation effect - adjusts color saturation +;; Uses vectorized shift-hsv primitive for fast processing +(require-primitives "color_ops") + +(define-effect saturation + :params ( + (amount :type int :default 1 :range [0 3]) + ) + (color_ops:adjust-saturation frame amount)) diff --git a/l1/sexp_effects/effects/scanlines.sexp b/l1/sexp_effects/effects/scanlines.sexp new file mode 100644 index 0000000..ddfcf44 --- /dev/null +++ b/l1/sexp_effects/effects/scanlines.sexp @@ -0,0 +1,15 @@ +;; Scanlines effect - VHS-style horizontal line shifting +(require-primitives "core") + +(define-effect scanlines + :params ( + (amplitude :type int :default 10 :range [0 100]) + (frequency :type int :default 10 :range [1 100]) + (randomness :type float :default 0.5 :range [0 1]) + ) + (map-rows frame + (lambda (y row) + (let* ((sine-shift (* amplitude (sin (/ (* y 6.28) (max 1 frequency))))) + (rand-shift (core:rand-range (- amplitude) amplitude)) + (shift (floor (lerp sine-shift rand-shift randomness)))) + (roll row shift 0))))) diff --git a/l1/sexp_effects/effects/sepia.sexp b/l1/sexp_effects/effects/sepia.sexp new file mode 100644 index 0000000..e3a5875 --- /dev/null +++ b/l1/sexp_effects/effects/sepia.sexp @@ -0,0 +1,7 @@ +;; Sepia effect - applies sepia tone +;; Classic warm vintage look +(require-primitives "color_ops") + +(define-effect sepia + :params () + (color_ops:sepia frame)) diff --git a/l1/sexp_effects/effects/sharpen.sexp b/l1/sexp_effects/effects/sharpen.sexp new file mode 100644 index 0000000..538bd7f --- /dev/null +++ b/l1/sexp_effects/effects/sharpen.sexp @@ -0,0 +1,8 @@ +;; Sharpen effect - sharpens edges +(require-primitives "image") + +(define-effect sharpen + :params ( + (amount :type int :default 1 :range [0 5]) + ) + (image:sharpen frame amount)) diff --git a/l1/sexp_effects/effects/strobe.sexp b/l1/sexp_effects/effects/strobe.sexp new file mode 100644 index 0000000..2bf80b4 --- /dev/null +++ b/l1/sexp_effects/effects/strobe.sexp @@ -0,0 +1,16 @@ +;; Strobe effect - holds frames for choppy look +(require-primitives "core") + +(define-effect strobe + :params ( + (frame_rate :type int :default 12 :range [1 60]) + ) + (let* ((held (state-get "held" nil)) + (held-until (state-get "held-until" 0)) + (frame-duration (/ 1 frame_rate))) + (if (or (core:is-nil held) (>= t held-until)) + (begin + (state-set "held" (copy frame)) + (state-set "held-until" (+ t frame-duration)) + frame) + held))) diff --git a/l1/sexp_effects/effects/swirl.sexp b/l1/sexp_effects/effects/swirl.sexp new file mode 100644 index 0000000..ba9cf57 --- /dev/null +++ b/l1/sexp_effects/effects/swirl.sexp @@ -0,0 +1,17 @@ +;; Swirl effect - spiral vortex distortion +(require-primitives "geometry" "image") + +(define-effect swirl + :params ( + (strength :type int :default 1 :range [-10 10]) + (radius :type float :default 0.5 :range [0.1 2]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (falloff :type string :default "quadratic") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:swirl-coords w h strength radius cx cy falloff))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/l1/sexp_effects/effects/threshold.sexp b/l1/sexp_effects/effects/threshold.sexp new file mode 100644 index 0000000..50d3bc5 --- /dev/null +++ b/l1/sexp_effects/effects/threshold.sexp @@ -0,0 +1,9 @@ +;; Threshold effect - converts to black and white +(require-primitives "color_ops") + +(define-effect threshold + :params ( + (level :type int :default 128 :range [0 255]) + (invert :type bool :default false) + ) + (color_ops:threshold frame level invert)) diff --git a/l1/sexp_effects/effects/tile_grid.sexp b/l1/sexp_effects/effects/tile_grid.sexp new file mode 100644 index 0000000..44487a9 --- /dev/null +++ b/l1/sexp_effects/effects/tile_grid.sexp @@ -0,0 +1,29 @@ +;; Tile Grid effect - tiles image in grid +(require-primitives "geometry" "image") + +(define-effect tile_grid + :params ( + (rows :type int :default 2 :range [1 10]) + (cols :type int :default 2 :range [1 10]) + (gap :type int :default 0 :range [0 50]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (tile-w (floor (/ (- w (* gap (- cols 1))) cols))) + (tile-h (floor (/ (- h (* gap (- rows 1))) rows))) + (tile (image:resize frame tile-w tile-h "area")) + (result (make-image w h (list 0 0 0)))) + (begin + ;; Manually place tiles using nested iteration + ;; This is a simplified version - full version would loop + (paste result tile 0 0) + (if (> cols 1) + (paste result tile (+ tile-w gap) 0) + nil) + (if (> rows 1) + (paste result tile 0 (+ tile-h gap)) + nil) + (if (and (> cols 1) (> rows 1)) + (paste result tile (+ tile-w gap) (+ tile-h gap)) + nil) + result))) diff --git a/l1/sexp_effects/effects/trails.sexp b/l1/sexp_effects/effects/trails.sexp new file mode 100644 index 0000000..5c0fc7c --- /dev/null +++ b/l1/sexp_effects/effects/trails.sexp @@ -0,0 +1,20 @@ +;; Trails effect - persistent motion trails +(require-primitives "image" "blending") + +(define-effect trails + :params ( + (persistence :type float :default 0.8 :range [0 0.99]) + ) + (let* ((buffer (state-get "buffer" nil)) + (current frame)) + (if (= buffer nil) + (begin + (state-set "buffer" (copy frame)) + frame) + (let* ((faded (blending:blend-images buffer + (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 persistence))) + (result (blending:blend-mode faded current "lighten"))) + (begin + (state-set "buffer" result) + result))))) diff --git a/l1/sexp_effects/effects/vignette.sexp b/l1/sexp_effects/effects/vignette.sexp new file mode 100644 index 0000000..46e63ee --- /dev/null +++ b/l1/sexp_effects/effects/vignette.sexp @@ -0,0 +1,23 @@ +;; Vignette effect - darkens corners +(require-primitives "image") + +(define-effect vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type int :default 1 :range [0.5 2]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (* (sqrt (+ (* cx cx) (* cy cy))) radius))) + (map-pixels frame + (lambda (x y c) + (let* ((dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (factor (- 1 (* (/ dist max-dist) strength))) + (factor (clamp factor 0 1))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/l1/sexp_effects/effects/wave.sexp b/l1/sexp_effects/effects/wave.sexp new file mode 100644 index 0000000..98b03c2 --- /dev/null +++ b/l1/sexp_effects/effects/wave.sexp @@ -0,0 +1,22 @@ +;; Wave effect - sine wave displacement distortion +(require-primitives "geometry" "image") + +(define-effect wave + :params ( + (amplitude :type int :default 10 :range [0 100]) + (wavelength :type int :default 50 :range [10 500]) + (speed :type int :default 1 :range [0 10]) + (direction :type string :default "horizontal") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + ;; Use _time for animation phase + (phase (* (or _time 0) speed 2 pi)) + ;; Calculate frequency: waves per dimension + (freq (/ (if (= direction "vertical") w h) wavelength)) + (axis (cond + ((= direction "horizontal") "x") + ((= direction "vertical") "y") + (else "both"))) + (coords (geometry:wave-coords w h axis freq amplitude phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/l1/sexp_effects/effects/xector_feathered_blend.sexp b/l1/sexp_effects/effects/xector_feathered_blend.sexp new file mode 100644 index 0000000..96224fb --- /dev/null +++ b/l1/sexp_effects/effects/xector_feathered_blend.sexp @@ -0,0 +1,44 @@ +;; Feathered blend - blend two same-size frames with distance-based falloff +;; Center shows overlay, edges show background, with smooth transition + +(require-primitives "xector") + +(define-effect xector_feathered_blend + :params ( + (inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)") + (fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)") + (overlay :type frame :default nil :desc "Frame to blend in center") + ) + (let* ( + ;; Get normalized distance from center (0 at center, ~1 at corners) + (dist (dist-from-center frame)) + (max-dist (βmax dist)) + (dist-norm (α/ dist max-dist)) + + ;; Calculate blend factor: + ;; - 1.0 when dist-norm < inner-radius (fully overlay) + ;; - 0.0 when dist-norm > inner-radius + fade-width (fully background) + ;; - linear ramp between + (t (α/ (α- dist-norm inner-radius) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + (inv-blend (α- 1 blend)) + + ;; Background channels + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + (if (nil? overlay) + ;; No overlay - visualize the blend mask + (let ((vis (α* blend 255))) + (rgb vis vis vis)) + + ;; Blend overlay with background using the mask + (let* ((ov-r (red overlay)) + (ov-g (green overlay)) + (ov-b (blue overlay)) + ;; lerp: bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r inv-blend) (α* ov-r blend))) + (g-out (α+ (α* bg-g inv-blend) (α* ov-g blend))) + (b-out (α+ (α* bg-b inv-blend) (α* ov-b blend)))) + (rgb r-out g-out b-out))))) diff --git a/l1/sexp_effects/effects/xector_grain.sexp b/l1/sexp_effects/effects/xector_grain.sexp new file mode 100644 index 0000000..64ebfa6 --- /dev/null +++ b/l1/sexp_effects/effects/xector_grain.sexp @@ -0,0 +1,34 @@ +;; Film grain effect using xector operations +;; Demonstrates random xectors and mixing scalar/xector math + +(require-primitives "xector") + +(define-effect xector_grain + :params ( + (intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity") + (colored :type bool :default false :desc "Use colored grain") + ) + (let* ( + ;; Extract channels + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Generate noise xector(s) + ;; randn-x generates normal distribution noise + (grain-amount (* intensity 50))) + + (if colored + ;; Colored grain: different noise per channel + (let* ((nr (randn-x frame 0 grain-amount)) + (ng (randn-x frame 0 grain-amount)) + (nb (randn-x frame 0 grain-amount))) + (rgb (αclamp (α+ r nr) 0 255) + (αclamp (α+ g ng) 0 255) + (αclamp (α+ b nb) 0 255))) + + ;; Monochrome grain: same noise for all channels + (let ((n (randn-x frame 0 grain-amount))) + (rgb (αclamp (α+ r n) 0 255) + (αclamp (α+ g n) 0 255) + (αclamp (α+ b n) 0 255)))))) diff --git a/l1/sexp_effects/effects/xector_inset_blend.sexp b/l1/sexp_effects/effects/xector_inset_blend.sexp new file mode 100644 index 0000000..597e23a --- /dev/null +++ b/l1/sexp_effects/effects/xector_inset_blend.sexp @@ -0,0 +1,57 @@ +;; Inset blend - fade a smaller frame into a larger background +;; Uses distance-based alpha for smooth transition (no hard edges) + +(require-primitives "xector") + +(define-effect xector_inset_blend + :params ( + (x :type int :default 0 :desc "X position of inset") + (y :type int :default 0 :desc "Y position of inset") + (fade-width :type int :default 50 :desc "Width of fade region in pixels") + (overlay :type frame :default nil :desc "The smaller frame to inset") + ) + (let* ( + ;; Get dimensions + (bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height + (bg-w bg-h) ;; placeholder + + ;; For now, create a simple centered circular blend + ;; Distance from center of overlay position + (cx (+ x (/ (- bg-w (* 2 x)) 2))) + (cy (+ y (/ (- bg-h (* 2 y)) 2))) + + ;; Get coordinates as xectors + (px (x-coords frame)) + (py (y-coords frame)) + + ;; Distance from center + (dx (α- px cx)) + (dy (α- py cy)) + (dist (αsqrt (α+ (α* dx dx) (α* dy dy)))) + + ;; Inner radius (fully overlay) and outer radius (fully background) + (inner-r (- (/ bg-w 2) x fade-width)) + (outer-r (- (/ bg-w 2) x)) + + ;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between + (t (α/ (α- dist inner-r) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + + ;; Extract channels from both frames + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + ;; If overlay provided, blend it + (if overlay + (let* ((ov-r (red overlay)) + (ov-g (green overlay)) + (ov-b (blue overlay)) + ;; Linear blend: result = bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend))) + (g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend))) + (b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend)))) + (rgb r-out g-out b-out)) + ;; No overlay - just show the blend mask for debugging + (let ((mask-vis (α* blend 255))) + (rgb mask-vis mask-vis mask-vis))))) diff --git a/l1/sexp_effects/effects/xector_threshold.sexp b/l1/sexp_effects/effects/xector_threshold.sexp new file mode 100644 index 0000000..c571468 --- /dev/null +++ b/l1/sexp_effects/effects/xector_threshold.sexp @@ -0,0 +1,27 @@ +;; Threshold effect using xector operations +;; Demonstrates where (conditional select) and β (reduction) for normalization + +(require-primitives "xector") + +(define-effect xector_threshold + :params ( + (threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)") + (invert :type bool :default false :desc "Invert the threshold") + ) + (let* ( + ;; Get grayscale luminance as xector + (luma (gray frame)) + + ;; Normalize to 0-1 range + (luma-norm (α/ luma 255)) + + ;; Create boolean mask: pixels above threshold + (mask (if invert + (α< luma-norm threshold) + (α>= luma-norm threshold))) + + ;; Use where to select: white (255) if above threshold, black (0) if below + (out (where mask 255 0))) + + ;; Output as grayscale (same value for R, G, B) + (rgb out out out))) diff --git a/l1/sexp_effects/effects/xector_vignette.sexp b/l1/sexp_effects/effects/xector_vignette.sexp new file mode 100644 index 0000000..d654ca7 --- /dev/null +++ b/l1/sexp_effects/effects/xector_vignette.sexp @@ -0,0 +1,36 @@ +;; Vignette effect using xector operations +;; Demonstrates α (element-wise) and β (reduction) patterns + +(require-primitives "xector") + +(define-effect xector_vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type float :default 1.0 :range [0.5 2]) + ) + (let* ( + ;; Get normalized distance from center for each pixel + (dist (dist-from-center frame)) + + ;; Calculate max distance (corner distance) + (max-dist (* (βmax dist) radius)) + + ;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength) + ;; Using explicit α operators + (factor (α- 1 (α* (α/ dist max-dist) strength))) + + ;; Clamp factor to [0, 1] + (factor (αclamp factor 0 1)) + + ;; Extract channels as xectors + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Apply factor to each channel (implicit element-wise via Xector operators) + (r-out (* r factor)) + (g-out (* g factor)) + (b-out (* b factor))) + + ;; Combine back to frame + (rgb r-out g-out b-out))) diff --git a/l1/sexp_effects/effects/zoom.sexp b/l1/sexp_effects/effects/zoom.sexp new file mode 100644 index 0000000..6e4b9ff --- /dev/null +++ b/l1/sexp_effects/effects/zoom.sexp @@ -0,0 +1,8 @@ +;; Zoom effect - zooms in/out from center +(require-primitives "geometry") + +(define-effect zoom + :params ( + (amount :type int :default 1 :range [0.1 5]) + ) + (geometry:scale-img frame amount amount)) diff --git a/l1/sexp_effects/interpreter.py b/l1/sexp_effects/interpreter.py new file mode 100644 index 0000000..406f6da --- /dev/null +++ b/l1/sexp_effects/interpreter.py @@ -0,0 +1,1085 @@ +""" +S-Expression Effect Interpreter + +Interprets effect definitions written in S-expressions. +Only allows safe primitives - no arbitrary code execution. +""" + +import numpy as np +from typing import Any, Dict, List, Optional, Callable +from pathlib import Path + +from .parser import Symbol, Keyword, parse, parse_file +from .primitives import PRIMITIVES, reset_rng + + +def _is_symbol(x) -> bool: + """Check if x is a Symbol (duck typing to support multiple Symbol classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Symbol' + + +def _is_keyword(x) -> bool: + """Check if x is a Keyword (duck typing to support multiple Keyword classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Keyword' + + +def _symbol_name(x) -> str: + """Get the name from a Symbol.""" + return x.name if hasattr(x, 'name') else str(x) + + +class Environment: + """Lexical environment for variable bindings.""" + + def __init__(self, parent: 'Environment' = None): + self.bindings: Dict[str, Any] = {} + self.parent = parent + + def get(self, name: str) -> Any: + if name in self.bindings: + return self.bindings[name] + if self.parent: + return self.parent.get(name) + raise NameError(f"Undefined variable: {name}") + + def set(self, name: str, value: Any): + self.bindings[name] = value + + def has(self, name: str) -> bool: + if name in self.bindings: + return True + if self.parent: + return self.parent.has(name) + return False + + +class Lambda: + """A user-defined function (lambda).""" + + def __init__(self, params: List[str], body: Any, env: Environment): + self.params = params + self.body = body + self.env = env # Closure environment + + def __repr__(self): + return f"" + + +class EffectDefinition: + """A parsed effect definition.""" + + def __init__(self, name: str, params: Dict[str, Any], body: Any): + self.name = name + self.params = params # {name: (type, default)} + self.body = body + + def __repr__(self): + return f"" + + +class Interpreter: + """ + S-Expression interpreter for effects. + + Provides a safe execution environment where only + whitelisted primitives can be called. + + Args: + minimal_primitives: If True, only load core primitives (arithmetic, comparison, + basic data access). Additional primitives must be loaded with + (require-primitives) or (with-primitives). + If False (default), load all legacy primitives for backward compatibility. + """ + + def __init__(self, minimal_primitives: bool = False): + # Base environment with primitives + self.global_env = Environment() + self.minimal_primitives = minimal_primitives + + if minimal_primitives: + # Load only core primitives + from .primitive_libs.core import PRIMITIVES as CORE_PRIMITIVES + for name, fn in CORE_PRIMITIVES.items(): + self.global_env.set(name, fn) + else: + # Load all legacy primitives for backward compatibility + for name, fn in PRIMITIVES.items(): + self.global_env.set(name, fn) + + # Special values + self.global_env.set('true', True) + self.global_env.set('false', False) + self.global_env.set('nil', None) + + # Loaded effect definitions + self.effects: Dict[str, EffectDefinition] = {} + + def eval(self, expr: Any, env: Environment = None) -> Any: + """Evaluate an S-expression.""" + if env is None: + env = self.global_env + + # Atoms + if isinstance(expr, (int, float, str, bool)): + return expr + + if expr is None: + return None + + # Handle Symbol (duck typing to support both sexp_effects.parser.Symbol and artdag.sexp.parser.Symbol) + if _is_symbol(expr): + return env.get(expr.name) + + # Handle Keyword (duck typing) + if _is_keyword(expr): + return expr # Keywords evaluate to themselves + + if isinstance(expr, np.ndarray): + return expr # Images pass through + + # Lists (function calls / special forms) + if isinstance(expr, list): + if not expr: + return [] + + head = expr[0] + + # Special forms + if _is_symbol(head): + form = head.name + + # Quote + if form == 'quote': + return expr[1] + + # Define + if form == 'define': + name = expr[1] + if _is_symbol(name): + # Simple define: (define name value) + value = self.eval(expr[2], env) + self.global_env.set(name.name, value) + return value + elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]): + # Function define: (define (fn-name args...) body) + # Desugars to: (define fn-name (lambda (args...) body)) + fn_name = name[0].name + params = [p.name if _is_symbol(p) else p for p in name[1:]] + body = expr[2] + fn = Lambda(params, body, env) + self.global_env.set(fn_name, fn) + return fn + else: + raise SyntaxError(f"define requires symbol or (name args...), got {name}") + + # Define-effect + if form == 'define-effect': + return self._define_effect(expr, env) + + # Lambda + if form == 'lambda' or form == 'λ': + params = [p.name if _is_symbol(p) else p for p in expr[1]] + body = expr[2] + return Lambda(params, body, env) + + # Let + if form == 'let': + return self._eval_let(expr, env) + + # Let* + if form == 'let*': + return self._eval_let_star(expr, env) + + # If + if form == 'if': + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + # Cond + if form == 'cond': + return self._eval_cond(expr, env) + + # And + if form == 'and': + result = True + for e in expr[1:]: + result = self.eval(e, env) + if not result: + return False + return result + + # Or + if form == 'or': + for e in expr[1:]: + result = self.eval(e, env) + if result: + return result + return False + + # Not + if form == 'not': + return not self.eval(expr[1], env) + + # Begin (sequence) + if form == 'begin': + result = None + for e in expr[1:]: + result = self.eval(e, env) + return result + + # Thread-first macro: (-> x (f a) (g b)) => (g (f x a) b) + if form == '->': + result = self.eval(expr[1], env) + for form_expr in expr[2:]: + if isinstance(form_expr, list): + # Insert result as first arg: (f a b) => (f result a b) + result = self.eval([form_expr[0], result] + form_expr[1:], env) + else: + # Just a symbol: f => (f result) + result = self.eval([form_expr, result], env) + return result + + # Set! (mutation) + if form == 'set!': + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + value = self.eval(expr[2], env) + # Find and update in appropriate scope + scope = env + while scope: + if name in scope.bindings: + scope.bindings[name] = value + return value + scope = scope.parent + raise NameError(f"Cannot set undefined variable: {name}") + + # State-get / state-set (for effect state) + if form == 'state-get': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + default = self.eval(expr[2], env) if len(expr) > 2 else None + return state.get(key, default) + + if form == 'state-set': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + value = self.eval(expr[2], env) + state[key] = value + return value + + # ascii-fx-zone special form - delays evaluation of expression parameters + if form == 'ascii-fx-zone': + return self._eval_ascii_fx_zone(expr, env) + + # with-primitives - load primitive library and scope to body + if form == 'with-primitives': + return self._eval_with_primitives(expr, env) + + # require-primitives - load primitive library into current scope + if form == 'require-primitives': + return self._eval_require_primitives(expr, env) + + # require - load .sexp file into current scope + if form == 'require': + return self._eval_require(expr, env) + + # Function call + fn = self.eval(head, env) + args = [self.eval(arg, env) for arg in expr[1:]] + + # Handle keyword arguments + pos_args = [] + kw_args = {} + i = 0 + while i < len(args): + if _is_keyword(args[i]): + kw_args[args[i].name] = args[i + 1] if i + 1 < len(args) else None + i += 2 + else: + pos_args.append(args[i]) + i += 1 + + return self._apply(fn, pos_args, kw_args, env) + + raise TypeError(f"Cannot evaluate: {expr}") + + def _wrap_lambda(self, lam: 'Lambda') -> Callable: + """Wrap a Lambda in a Python callable for use by primitives.""" + def wrapper(*args): + new_env = Environment(lam.env) + for i, param in enumerate(lam.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(lam.body, new_env) + return wrapper + + def _apply(self, fn: Any, args: List[Any], kwargs: Dict[str, Any], env: Environment) -> Any: + """Apply a function to arguments.""" + if isinstance(fn, Lambda): + # User-defined function + new_env = Environment(fn.env) + for i, param in enumerate(fn.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(fn.body, new_env) + + elif callable(fn): + # Wrap any Lambda arguments so primitives can call them + wrapped_args = [] + for arg in args: + if isinstance(arg, Lambda): + wrapped_args.append(self._wrap_lambda(arg)) + else: + wrapped_args.append(arg) + + # Inject _interp and _env for primitives that need them + import inspect + try: + sig = inspect.signature(fn) + params = sig.parameters + if '_interp' in params and '_interp' not in kwargs: + kwargs['_interp'] = self + if '_env' in params and '_env' not in kwargs: + kwargs['_env'] = env + except (ValueError, TypeError): + # Some built-in functions don't have inspectable signatures + pass + + # Primitive function + if kwargs: + return fn(*wrapped_args, **kwargs) + return fn(*wrapped_args) + + else: + raise TypeError(f"Cannot call: {fn}") + + def _parse_bindings(self, bindings: list) -> list: + """Parse bindings in either Scheme or Clojure style. + + Scheme: ((x 1) (y 2)) -> [(x, 1), (y, 2)] + Clojure: [x 1 y 2] -> [(x, 1), (y, 2)] + """ + if not bindings: + return [] + + # Check if Clojure style (flat list with symbols and values alternating) + if _is_symbol(bindings[0]): + # Clojure style: [x 1 y 2] + pairs = [] + i = 0 + while i < len(bindings) - 1: + name = bindings[i].name if _is_symbol(bindings[i]) else bindings[i] + value = bindings[i + 1] + pairs.append((name, value)) + i += 2 + return pairs + else: + # Scheme style: ((x 1) (y 2)) + pairs = [] + for binding in bindings: + name = binding[0].name if _is_symbol(binding[0]) else binding[0] + value = binding[1] + pairs.append((name, value)) + return pairs + + def _eval_let(self, expr: Any, env: Environment) -> Any: + """Evaluate let expression: (let ((x 1) (y 2)) body) or (let [x 1 y 2] body) + + Note: Uses sequential binding (like Clojure let / Scheme let*) so each + binding can reference previous bindings. + """ + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Sequential: can see previous bindings + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_let_star(self, expr: Any, env: Environment) -> Any: + """Evaluate let* expression: sequential bindings.""" + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Evaluate in current env + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_cond(self, expr: Any, env: Environment) -> Any: + """Evaluate cond expression.""" + for clause in expr[1:]: + test = clause[0] + if _is_symbol(test) and test.name == 'else': + return self.eval(clause[1], env) + if self.eval(test, env): + return self.eval(clause[1], env) + return None + + def _eval_with_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate with-primitives: scoped primitive library loading. + + Syntax: + (with-primitives "math" + (sin (* x pi))) + + (with-primitives "math" :path "custom/math.py" + body) + + The primitives from the library are only available within the body. + """ + # Parse library name and optional path + lib_name = expr[1] + if _is_symbol(lib_name): + lib_name = lib_name.name + + path = None + body_start = 2 + + # Check for :path keyword + if len(expr) > 2 and _is_keyword(expr[2]) and expr[2].name == 'path': + path = expr[3] + body_start = 4 + + # Load the primitive library + primitives = self.load_primitive_library(lib_name, path) + + # Create new environment with primitives + new_env = Environment(env) + for name, fn in primitives.items(): + new_env.set(name, fn) + + # Evaluate body in new environment + result = None + for e in expr[body_start:]: + result = self.eval(e, new_env) + return result + + def _eval_require_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require-primitives: load primitives into current scope. + + Syntax: + (require-primitives "math" "color" "filters") + + Unlike with-primitives, this loads into the current environment + (typically used at top-level to set up an effect's dependencies). + """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + primitives = self.load_primitive_library(lib_name) + for name, fn in primitives.items(): + env.set(name, fn) + + return None + + def load_primitive_library(self, name: str, path: str = None) -> dict: + """ + Load a primitive library by name or path. + + Returns dict of {name: function}. + """ + from .primitive_libs import load_primitive_library + return load_primitive_library(name, path) + + def _eval_require(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require: load a .sexp file and evaluate its definitions. + + Syntax: + (require "derived") ; loads derived.sexp from sexp_effects/ + (require "path/to/file.sexp") ; loads from explicit path + + Definitions from the file are added to the current environment. + """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + # Find the .sexp file + sexp_path = self._find_sexp_file(lib_name) + if sexp_path is None: + raise ValueError(f"Cannot find sexp file: {lib_name}") + + # Parse and evaluate the file + content = parse_file(sexp_path) + + # Evaluate all top-level expressions + if isinstance(content, list) and content and isinstance(content[0], list): + for e in content: + self.eval(e, env) + else: + self.eval(content, env) + + return None + + def _find_sexp_file(self, name: str) -> Optional[str]: + """Find a .sexp file by name.""" + # Try various locations + candidates = [ + # Explicit path + name, + name + '.sexp', + # In sexp_effects directory + Path(__file__).parent / f'{name}.sexp', + Path(__file__).parent / name, + # In effects directory + Path(__file__).parent / 'effects' / f'{name}.sexp', + Path(__file__).parent / 'effects' / name, + ] + + for path in candidates: + p = Path(path) if not isinstance(path, Path) else path + if p.exists() and p.is_file(): + return str(p) + + return None + + def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any: + """ + Evaluate ascii-fx-zone special form. + + Syntax: + (ascii-fx-zone frame + :cols 80 + :alphabet "standard" + :color_mode "color" + :background "black" + :contrast 1.5 + :char_hue ;; NOT evaluated - passed to primitive + :char_saturation + :char_brightness + :char_scale + :char_rotation + :char_jitter ) + + The expression parameters (:char_hue, etc.) are NOT pre-evaluated. + They are passed as raw S-expressions to the primitive which + evaluates them per-zone with zone context variables injected. + + Requires: (require-primitives "ascii") + """ + # Look up ascii-fx-zone primitive from environment + # It must be loaded via (require-primitives "ascii") + try: + prim_ascii_fx_zone = env.get('ascii-fx-zone') + except NameError: + raise NameError( + "ascii-fx-zone primitive not found. " + "Add (require-primitives \"ascii\") to your effect file." + ) + + # Expression parameter names that should NOT be evaluated + expr_params = {'char_hue', 'char_saturation', 'char_brightness', + 'char_scale', 'char_rotation', 'char_jitter', 'cell_effect'} + + # Parse arguments + frame = self.eval(expr[1], env) # First arg is always the frame + + # Defaults + cols = 80 + char_size = None # If set, overrides cols + alphabet = "standard" + color_mode = "color" + background = "black" + contrast = 1.5 + char_hue = None + char_saturation = None + char_brightness = None + char_scale = None + char_rotation = None + char_jitter = None + cell_effect = None # Lambda for arbitrary per-cell effects + # Convenience params for staged recipes + energy = None + rotation_scale = 0 + # Extra params to pass to zone dict for lambdas + extra_params = {} + + # Parse keyword arguments + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item): + if i + 1 >= len(expr): + break + value_expr = expr[i + 1] + kw_name = item.name + + if kw_name in expr_params: + # Resolve symbol references but don't evaluate expressions + # This handles the case where effect definition passes a param like :char_hue char_hue + resolved = value_expr + if _is_symbol(value_expr): + try: + resolved = env.get(value_expr.name) + except NameError: + resolved = value_expr # Keep as symbol if not found + + if kw_name == 'char_hue': + char_hue = resolved + elif kw_name == 'char_saturation': + char_saturation = resolved + elif kw_name == 'char_brightness': + char_brightness = resolved + elif kw_name == 'char_scale': + char_scale = resolved + elif kw_name == 'char_rotation': + char_rotation = resolved + elif kw_name == 'char_jitter': + char_jitter = resolved + elif kw_name == 'cell_effect': + cell_effect = resolved + else: + # Evaluate normally + value = self.eval(value_expr, env) + if kw_name == 'cols': + cols = int(value) + elif kw_name == 'char_size': + # Handle nil/None values + if value is None or (_is_symbol(value) and value.name == 'nil'): + char_size = None + else: + char_size = int(value) + elif kw_name == 'alphabet': + alphabet = str(value) + elif kw_name == 'color_mode': + color_mode = str(value) + elif kw_name == 'background': + background = str(value) + elif kw_name == 'contrast': + contrast = float(value) + elif kw_name == 'energy': + if value is None or (_is_symbol(value) and value.name == 'nil'): + energy = None + else: + energy = float(value) + elif kw_name == 'rotation_scale': + rotation_scale = float(value) + else: + # Store any other params for lambdas to access + extra_params[kw_name] = value + i += 2 + else: + i += 1 + + # If energy and rotation_scale provided, build rotation expression + # rotation = energy * rotation_scale * position_factor + # position_factor: bottom-left=0, top-right=3 + # Formula: 1.5 * (zone-col-norm + (1 - zone-row-norm)) + if energy is not None and rotation_scale > 0: + # Build expression as S-expression list that will be evaluated per-zone + # (* (* energy rotation_scale) (* 1.5 (+ zone-col-norm (- 1 zone-row-norm)))) + energy_times_scale = energy * rotation_scale + # The position part uses zone variables, so we build it as an expression + char_rotation = [ + Symbol('*'), + energy_times_scale, + [Symbol('*'), 1.5, + [Symbol('+'), Symbol('zone-col-norm'), + [Symbol('-'), 1, Symbol('zone-row-norm')]]] + ] + + # Pull any extra params from environment that aren't standard params + # These are typically passed from recipes for use in cell_effect lambdas + standard_params = { + 'cols', 'char_size', 'alphabet', 'color_mode', 'background', 'contrast', + 'char_hue', 'char_saturation', 'char_brightness', 'char_scale', + 'char_rotation', 'char_jitter', 'cell_effect', 'energy', 'rotation_scale', + 'frame', 't', '_time', '__state__', '__interp__', 'true', 'false', 'nil' + } + # Check environment for extra bindings + current_env = env + while current_env is not None: + for k, v in current_env.bindings.items(): + if k not in standard_params and k not in extra_params and not callable(v): + # Add non-standard, non-callable bindings to extra_params + if isinstance(v, (int, float, str, bool)) or v is None: + extra_params[k] = v + current_env = current_env.parent + + # Call the primitive with interpreter and env for expression evaluation + return prim_ascii_fx_zone( + frame, + cols=cols, + char_size=char_size, + alphabet=alphabet, + color_mode=color_mode, + background=background, + contrast=contrast, + char_hue=char_hue, + char_saturation=char_saturation, + char_brightness=char_brightness, + char_scale=char_scale, + char_rotation=char_rotation, + char_jitter=char_jitter, + cell_effect=cell_effect, + energy=energy, + rotation_scale=rotation_scale, + _interp=self, + _env=env, + **extra_params + ) + + def _define_effect(self, expr: Any, env: Environment) -> EffectDefinition: + """ + Parse effect definition. + + Required syntax: + (define-effect name + :params ( + (param1 :type int :default 8 :desc "description") + ) + body) + + Effects MUST use :params syntax. Legacy ((param default) ...) is not supported. + """ + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + + params = {} + body = None + found_params = False + + # Parse :params and body + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item) and item.name == "params": + # :params syntax + if i + 1 >= len(expr): + raise SyntaxError(f"Effect '{name}': Missing params list after :params keyword") + params_list = expr[i + 1] + params = self._parse_params_block(params_list) + found_params = True + i += 2 + elif _is_keyword(item): + # Skip other keywords (like :desc) + i += 2 + elif body is None: + # First non-keyword item is the body + if isinstance(item, list) and item: + first_elem = item[0] + # Check for legacy syntax and reject it + if isinstance(first_elem, list) and len(first_elem) >= 2: + raise SyntaxError( + f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. " + f"Use :params block instead." + ) + body = item + i += 1 + else: + i += 1 + + if body is None: + raise SyntaxError(f"Effect '{name}': No body found") + + if not found_params: + raise SyntaxError( + f"Effect '{name}': Missing :params block. " + f"For effects with no parameters, use empty :params ()" + ) + + effect = EffectDefinition(name, params, body) + self.effects[name] = effect + return effect + + def _parse_params_block(self, params_list: list) -> Dict[str, Any]: + """ + Parse :params block syntax: + ( + (param_name :type int :default 8 :range [4 32] :desc "description") + ) + """ + params = {} + for param_def in params_list: + if not isinstance(param_def, list) or len(param_def) < 1: + continue + + # First element is the parameter name + first = param_def[0] + if _is_symbol(first): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + continue + + # Parse keyword arguments + default = None + i = 1 + while i < len(param_def): + item = param_def[i] + if _is_keyword(item): + if i + 1 >= len(param_def): + break + kw_value = param_def[i + 1] + + if item.name == "default": + default = kw_value + i += 2 + else: + i += 1 + + params[param_name] = default + + return params + + def load_effect(self, path: str) -> EffectDefinition: + """Load an effect definition from a .sexp file.""" + expr = parse_file(path) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the last defined effect + if self.effects: + return list(self.effects.values())[-1] + return None + + def load_effect_from_string(self, sexp_content: str, effect_name: str = None) -> EffectDefinition: + """Load an effect definition from an S-expression string. + + Args: + sexp_content: The S-expression content as a string + effect_name: Optional name hint (used if effect doesn't define its own name) + + Returns: + The loaded EffectDefinition + """ + expr = parse(sexp_content) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the effect if we can find it by name + if effect_name and effect_name in self.effects: + return self.effects[effect_name] + + # Return the most recently loaded effect + if self.effects: + return list(self.effects.values())[-1] + + return None + + def run_effect(self, name: str, frame, params: Dict[str, Any], + state: Dict[str, Any]) -> tuple: + """ + Run an effect on frame(s). + + Args: + name: Effect name + frame: Input frame (H, W, 3) RGB uint8, or list of frames for multi-input + params: Effect parameters (overrides defaults) + state: Persistent state dict + + Returns: + (output_frame, new_state) + """ + if name not in self.effects: + raise ValueError(f"Unknown effect: {name}") + + effect = self.effects[name] + + # Create environment for this run + env = Environment(self.global_env) + + # Bind frame(s) - support both single frame and list of frames + if isinstance(frame, list): + # Multi-input effect + frames = frame + env.set('frame', frames[0] if frames else None) # Backwards compat + env.set('inputs', frames) + # Named frame bindings + for i, f in enumerate(frames): + env.set(f'frame-{chr(ord("a") + i)}', f) # frame-a, frame-b, etc. + else: + # Single-input effect + env.set('frame', frame) + + # Bind state + if state is None: + state = {} + env.set('__state__', state) + + # Validate that all provided params are known (except internal params) + # Extra params are allowed and will be passed through to cell_effect lambdas + known_params = set(effect.params.keys()) + internal_params = {'_time', 'seed', '_binding', 'effect', 'cid', 'hash', 'effect_path'} + extra_effect_params = {} # Unknown params passed through for cell_effect lambdas + for k in params.keys(): + if k not in known_params and k not in internal_params: + # Allow unknown params - they'll be passed to cell_effect lambdas via zone dict + extra_effect_params[k] = params[k] + + # Bind parameters (defaults + overrides) + for pname, pdefault in effect.params.items(): + value = params.get(pname) + if value is None: + # Evaluate default if it's an expression (list) or a symbol (like 'nil') + if isinstance(pdefault, list) or _is_symbol(pdefault): + value = self.eval(pdefault, env) + else: + value = pdefault + env.set(pname, value) + + # Bind extra params (unknown params passed through for cell_effect lambdas) + for k, v in extra_effect_params.items(): + env.set(k, v) + + # Reset RNG with seed if provided + seed = params.get('seed', 42) + reset_rng(int(seed)) + + # Bind time if provided + time_val = params.get('_time', 0) + env.set('t', time_val) + env.set('_time', time_val) + + # Evaluate body + result = self.eval(effect.body, env) + + # Ensure result is an image + if not isinstance(result, np.ndarray): + result = frame + + return result, state + + def eval_with_zone(self, expr, env: Environment, zone) -> Any: + """ + Evaluate expression with zone-* variables injected. + + Args: + expr: Expression to evaluate (S-expression) + env: Parent environment with bound values + zone: ZoneContext object with cell data + + Zone variables injected: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: + Evaluated result (typically a number) + """ + # Create child environment with zone variables + zone_env = Environment(env) + zone_env.set('zone-row', zone.row) + zone_env.set('zone-col', zone.col) + zone_env.set('zone-row-norm', zone.row_norm) + zone_env.set('zone-col-norm', zone.col_norm) + zone_env.set('zone-lum', zone.luminance) + zone_env.set('zone-sat', zone.saturation) + zone_env.set('zone-hue', zone.hue) + zone_env.set('zone-r', zone.r) + zone_env.set('zone-g', zone.g) + zone_env.set('zone-b', zone.b) + + return self.eval(expr, zone_env) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +_interpreter = None +_interpreter_minimal = None + + +def get_interpreter(minimal_primitives: bool = False) -> Interpreter: + """Get or create the global interpreter. + + Args: + minimal_primitives: If True, return interpreter with only core primitives. + Additional primitives must be loaded with require-primitives or with-primitives. + """ + global _interpreter, _interpreter_minimal + + if minimal_primitives: + if _interpreter_minimal is None: + _interpreter_minimal = Interpreter(minimal_primitives=True) + return _interpreter_minimal + else: + if _interpreter is None: + _interpreter = Interpreter(minimal_primitives=False) + return _interpreter + + +def load_effect(path: str) -> EffectDefinition: + """Load an effect from a .sexp file.""" + return get_interpreter().load_effect(path) + + +def load_effects_dir(directory: str): + """Load all .sexp effects from a directory.""" + interp = get_interpreter() + dir_path = Path(directory) + for path in dir_path.glob('*.sexp'): + try: + interp.load_effect(str(path)) + except Exception as e: + print(f"Warning: Failed to load {path}: {e}") + + +def run_effect(name: str, frame: np.ndarray, params: Dict[str, Any], + state: Dict[str, Any] = None) -> tuple: + """Run an effect.""" + return get_interpreter().run_effect(name, frame, params, state or {}) + + +def list_effects() -> List[str]: + """List loaded effect names.""" + return list(get_interpreter().effects.keys()) + + +# ============================================================================= +# Adapter for existing effect system +# ============================================================================= + +def make_process_frame(effect_path: str) -> Callable: + """ + Create a process_frame function from a .sexp effect. + + This allows S-expression effects to be used with the existing + effect system. + """ + interp = get_interpreter() + interp.load_effect(effect_path) + effect_name = Path(effect_path).stem + + def process_frame(frame: np.ndarray, params: dict, state: dict) -> tuple: + return interp.run_effect(effect_name, frame, params, state) + + return process_frame diff --git a/l1/sexp_effects/parser.py b/l1/sexp_effects/parser.py new file mode 100644 index 0000000..5e17565 --- /dev/null +++ b/l1/sexp_effects/parser.py @@ -0,0 +1,396 @@ +""" +S-expression parser for ArtDAG recipes and plans. + +Supports: +- Lists: (a b c) +- Symbols: foo, bar-baz, -> +- Keywords: :key +- Strings: "hello world" +- Numbers: 42, 3.14, -1.5 +- Comments: ; to end of line +- Vectors: [a b c] (syntactic sugar for lists) +- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts) +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Union +import re + + +@dataclass +class Symbol: + """An unquoted symbol/identifier.""" + name: str + + def __repr__(self): + return f"Symbol({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Symbol): + return self.name == other.name + if isinstance(other, str): + return self.name == other + return False + + def __hash__(self): + return hash(self.name) + + +@dataclass +class Keyword: + """A keyword starting with colon.""" + name: str + + def __repr__(self): + return f"Keyword({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Keyword): + return self.name == other.name + return False + + def __hash__(self): + return hash((':' , self.name)) + + +class ParseError(Exception): + """Error during S-expression parsing.""" + def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1): + self.position = position + self.line = line + self.col = col + super().__init__(f"{message} at line {line}, column {col}") + + +class Tokenizer: + """Tokenize S-expression text into tokens.""" + + # Token patterns + WHITESPACE = re.compile(r'\s+') + COMMENT = re.compile(r';[^\n]*') + STRING = re.compile(r'"(?:[^"\\]|\\.)*"') + NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?') + KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*') + # Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations + SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*') + + def __init__(self, text: str): + self.text = text + self.pos = 0 + self.line = 1 + self.col = 1 + + def _advance(self, count: int = 1): + """Advance position, tracking line/column.""" + for _ in range(count): + if self.pos < len(self.text): + if self.text[self.pos] == '\n': + self.line += 1 + self.col = 1 + else: + self.col += 1 + self.pos += 1 + + def _skip_whitespace_and_comments(self): + """Skip whitespace and comments.""" + while self.pos < len(self.text): + # Whitespace + match = self.WHITESPACE.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + # Comments + match = self.COMMENT.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + break + + def peek(self) -> str | None: + """Peek at current character.""" + self._skip_whitespace_and_comments() + if self.pos >= len(self.text): + return None + return self.text[self.pos] + + def next_token(self) -> Any: + """Get the next token.""" + self._skip_whitespace_and_comments() + + if self.pos >= len(self.text): + return None + + char = self.text[self.pos] + start_line, start_col = self.line, self.col + + # Single-character tokens (parens, brackets, braces) + if char in '()[]{}': + self._advance() + return char + + # String + if char == '"': + match = self.STRING.match(self.text, self.pos) + if not match: + raise ParseError("Unterminated string", self.pos, self.line, self.col) + self._advance(match.end() - self.pos) + # Parse escape sequences + content = match.group()[1:-1] + content = content.replace('\\n', '\n') + content = content.replace('\\t', '\t') + content = content.replace('\\"', '"') + content = content.replace('\\\\', '\\') + return content + + # Keyword + if char == ':': + match = self.KEYWORD.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Keyword(match.group()[1:]) # Strip leading colon + raise ParseError(f"Invalid keyword", self.pos, self.line, self.col) + + # Number (must check before symbol due to - prefix) + if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and + (self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')): + match = self.NUMBER.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + num_str = match.group() + if '.' in num_str or 'e' in num_str or 'E' in num_str: + return float(num_str) + return int(num_str) + + # Symbol + match = self.SYMBOL.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Symbol(match.group()) + + raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col) + + +def parse(text: str) -> Any: + """ + Parse an S-expression string into Python data structures. + + Returns: + Parsed S-expression as nested Python structures: + - Lists become Python lists + - Symbols become Symbol objects + - Keywords become Keyword objects + - Strings become Python strings + - Numbers become int/float + + Example: + >>> parse('(recipe "test" :version "1.0")') + [Symbol('recipe'), 'test', Keyword('version'), '1.0'] + """ + tokenizer = Tokenizer(text) + result = _parse_expr(tokenizer) + + # Check for trailing content + if tokenizer.peek() is not None: + raise ParseError("Unexpected content after expression", + tokenizer.pos, tokenizer.line, tokenizer.col) + + return result + + +def parse_all(text: str) -> List[Any]: + """ + Parse multiple S-expressions from a string. + + Returns list of parsed expressions. + """ + tokenizer = Tokenizer(text) + results = [] + + while tokenizer.peek() is not None: + results.append(_parse_expr(tokenizer)) + + return results + + +def _parse_expr(tokenizer: Tokenizer) -> Any: + """Parse a single expression.""" + token = tokenizer.next_token() + + if token is None: + raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col) + + # List + if token == '(': + return _parse_list(tokenizer, ')') + + # Vector (sugar for list) + if token == '[': + return _parse_list(tokenizer, ']') + + # Map/dict: {:key1 val1 :key2 val2} + if token == '{': + return _parse_map(tokenizer) + + # Unexpected closers + if isinstance(token, str) and token in ')]}': + raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col) + + # Atom + return token + + +def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]: + """Parse a list until the closing delimiter.""" + items = [] + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError(f"Unterminated list, expected {closer!r}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == closer: + tokenizer.next_token() # Consume closer + return items + + items.append(_parse_expr(tokenizer)) + + +def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]: + """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}.""" + result = {} + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError("Unterminated map, expected '}'", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == '}': + tokenizer.next_token() # Consume closer + return result + + # Parse key (should be a keyword like :key) + key_token = _parse_expr(tokenizer) + if isinstance(key_token, Keyword): + key = key_token.name + elif isinstance(key_token, str): + key = key_token + else: + raise ParseError(f"Map key must be keyword or string, got {type(key_token).__name__}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + # Parse value + value = _parse_expr(tokenizer) + result[key] = value + + +def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str: + """ + Serialize a Python data structure back to S-expression format. + + Args: + expr: The expression to serialize + indent: Current indentation level (for pretty printing) + pretty: Whether to use pretty printing with newlines + + Returns: + S-expression string + """ + if isinstance(expr, list): + if not expr: + return "()" + + if pretty: + return _serialize_pretty(expr, indent) + else: + items = [serialize(item, indent, False) for item in expr] + return "(" + " ".join(items) + ")" + + if isinstance(expr, Symbol): + return expr.name + + if isinstance(expr, Keyword): + return f":{expr.name}" + + if isinstance(expr, str): + # Escape special characters + escaped = expr.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n').replace('\t', '\\t') + return f'"{escaped}"' + + if isinstance(expr, bool): + return "true" if expr else "false" + + if isinstance(expr, (int, float)): + return str(expr) + + if expr is None: + return "nil" + + if isinstance(expr, dict): + # Serialize dict as property list: {:key1 val1 :key2 val2} + items = [] + for k, v in expr.items(): + items.append(f":{k}") + items.append(serialize(v, indent, pretty)) + return "{" + " ".join(items) + "}" + + raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}") + + +def _serialize_pretty(expr: List, indent: int) -> str: + """Pretty-print a list expression with smart formatting.""" + if not expr: + return "()" + + prefix = " " * indent + inner_prefix = " " * (indent + 1) + + # Check if this is a simple list that fits on one line + simple = serialize(expr, indent, False) + if len(simple) < 60 and '\n' not in simple: + return simple + + # Start building multiline output + head = serialize(expr[0], indent + 1, False) + parts = [f"({head}"] + + i = 1 + while i < len(expr): + item = expr[i] + + # Group keyword-value pairs on same line + if isinstance(item, Keyword) and i + 1 < len(expr): + key = serialize(item, 0, False) + val = serialize(expr[i + 1], indent + 1, False) + + # If value is short, put on same line + if len(val) < 50 and '\n' not in val: + parts.append(f"{inner_prefix}{key} {val}") + else: + # Value is complex, serialize it pretty + val_pretty = serialize(expr[i + 1], indent + 1, True) + parts.append(f"{inner_prefix}{key} {val_pretty}") + i += 2 + else: + # Regular item + item_str = serialize(item, indent + 1, True) + parts.append(f"{inner_prefix}{item_str}") + i += 1 + + return "\n".join(parts) + ")" + + +def parse_file(path: str) -> Any: + """Parse an S-expression file (supports multiple top-level expressions).""" + with open(path, 'r') as f: + return parse_all(f.read()) + + +def to_sexp(obj: Any) -> str: + """Convert Python object back to S-expression string (alias for serialize).""" + return serialize(obj) diff --git a/l1/sexp_effects/primitive_libs/__init__.py b/l1/sexp_effects/primitive_libs/__init__.py new file mode 100644 index 0000000..47ee174 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/__init__.py @@ -0,0 +1,102 @@ +""" +Primitive Libraries System + +Provides modular loading of primitives. Core primitives are always available, +additional primitive libraries can be loaded on-demand with scoped availability. + +Usage in sexp: + ;; Load at recipe level - available throughout + (primitives math :path "primitive_libs/math.py") + + ;; Or use with-primitives for scoped access + (with-primitives "image" + (blur frame 3)) ;; blur only available inside + + ;; Nested scopes work + (with-primitives "math" + (with-primitives "color" + (hue-shift frame (* (sin t) 30)))) + +Library file format (primitive_libs/math.py): + import math + + def prim_sin(x): return math.sin(x) + def prim_cos(x): return math.cos(x) + + PRIMITIVES = { + 'sin': prim_sin, + 'cos': prim_cos, + } +""" + +import importlib.util +from pathlib import Path +from typing import Dict, Callable, Any, Optional + +# Cache of loaded primitive libraries +_library_cache: Dict[str, Dict[str, Any]] = {} + +# Core primitives - always available, cannot be overridden +CORE_PRIMITIVES: Dict[str, Any] = {} + + +def register_core_primitive(name: str, fn: Callable): + """Register a core primitive that's always available.""" + CORE_PRIMITIVES[name] = fn + + +def load_primitive_library(name: str, path: Optional[str] = None) -> Dict[str, Any]: + """ + Load a primitive library by name or path. + + Args: + name: Library name (e.g., "math", "image", "color") + path: Optional explicit path to library file + + Returns: + Dict of primitive name -> function + """ + # Check cache first + cache_key = path or name + if cache_key in _library_cache: + return _library_cache[cache_key] + + # Find library file + if path: + lib_path = Path(path) + else: + # Look in standard locations + lib_dir = Path(__file__).parent + lib_path = lib_dir / f"{name}.py" + + if not lib_path.exists(): + raise ValueError(f"Primitive library '{name}' not found at {lib_path}") + + if not lib_path.exists(): + raise ValueError(f"Primitive library file not found: {lib_path}") + + # Load the module + spec = importlib.util.spec_from_file_location(f"prim_lib_{name}", lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get PRIMITIVES dict from module + if not hasattr(module, 'PRIMITIVES'): + raise ValueError(f"Primitive library '{name}' missing PRIMITIVES dict") + + primitives = module.PRIMITIVES + + # Cache and return + _library_cache[cache_key] = primitives + return primitives + + +def get_library_names() -> list: + """Get names of available primitive libraries.""" + lib_dir = Path(__file__).parent + return [p.stem for p in lib_dir.glob("*.py") if p.stem != "__init__"] + + +def clear_cache(): + """Clear the library cache (useful for testing).""" + _library_cache.clear() diff --git a/l1/sexp_effects/primitive_libs/arrays.py b/l1/sexp_effects/primitive_libs/arrays.py new file mode 100644 index 0000000..61da196 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/arrays.py @@ -0,0 +1,196 @@ +""" +Array Primitives Library + +Vectorized operations on numpy arrays for coordinate transformations. +""" +import numpy as np + + +# Arithmetic +def prim_arr_add(a, b): + return np.add(a, b) + + +def prim_arr_sub(a, b): + return np.subtract(a, b) + + +def prim_arr_mul(a, b): + return np.multiply(a, b) + + +def prim_arr_div(a, b): + return np.divide(a, b) + + +def prim_arr_mod(a, b): + return np.mod(a, b) + + +def prim_arr_neg(a): + return np.negative(a) + + +# Math functions +def prim_arr_sin(a): + return np.sin(a) + + +def prim_arr_cos(a): + return np.cos(a) + + +def prim_arr_tan(a): + return np.tan(a) + + +def prim_arr_sqrt(a): + return np.sqrt(np.maximum(a, 0)) + + +def prim_arr_pow(a, b): + return np.power(a, b) + + +def prim_arr_abs(a): + return np.abs(a) + + +def prim_arr_exp(a): + return np.exp(a) + + +def prim_arr_log(a): + return np.log(np.maximum(a, 1e-10)) + + +def prim_arr_atan2(y, x): + return np.arctan2(y, x) + + +# Comparison / selection +def prim_arr_min(a, b): + return np.minimum(a, b) + + +def prim_arr_max(a, b): + return np.maximum(a, b) + + +def prim_arr_clip(a, lo, hi): + return np.clip(a, lo, hi) + + +def prim_arr_where(cond, a, b): + return np.where(cond, a, b) + + +def prim_arr_floor(a): + return np.floor(a) + + +def prim_arr_ceil(a): + return np.ceil(a) + + +def prim_arr_round(a): + return np.round(a) + + +# Interpolation +def prim_arr_lerp(a, b, t): + return a + (b - a) * t + + +def prim_arr_smoothstep(edge0, edge1, x): + t = prim_arr_clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +# Creation +def prim_arr_zeros(shape): + return np.zeros(shape, dtype=np.float32) + + +def prim_arr_ones(shape): + return np.ones(shape, dtype=np.float32) + + +def prim_arr_full(shape, value): + return np.full(shape, value, dtype=np.float32) + + +def prim_arr_arange(start, stop, step=1): + return np.arange(start, stop, step, dtype=np.float32) + + +def prim_arr_linspace(start, stop, num): + return np.linspace(start, stop, num, dtype=np.float32) + + +def prim_arr_meshgrid(x, y): + return np.meshgrid(x, y) + + +# Coordinate transforms +def prim_polar_from_center(map_x, map_y, cx, cy): + """Convert Cartesian to polar coordinates centered at (cx, cy).""" + dx = map_x - cx + dy = map_y - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + return (r, theta) + + +def prim_cart_from_polar(r, theta, cx, cy): + """Convert polar to Cartesian, adding center offset.""" + x = r * np.cos(theta) + cx + y = r * np.sin(theta) + cy + return (x, y) + + +PRIMITIVES = { + # Arithmetic + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-neg': prim_arr_neg, + + # Math + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-exp': prim_arr_exp, + 'arr-log': prim_arr_log, + 'arr-atan2': prim_arr_atan2, + + # Selection + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-ceil': prim_arr_ceil, + 'arr-round': prim_arr_round, + + # Interpolation + 'arr-lerp': prim_arr_lerp, + 'arr-smoothstep': prim_arr_smoothstep, + + # Creation + 'arr-zeros': prim_arr_zeros, + 'arr-ones': prim_arr_ones, + 'arr-full': prim_arr_full, + 'arr-arange': prim_arr_arange, + 'arr-linspace': prim_arr_linspace, + 'arr-meshgrid': prim_arr_meshgrid, + + # Coordinates + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, +} diff --git a/l1/sexp_effects/primitive_libs/ascii.py b/l1/sexp_effects/primitive_libs/ascii.py new file mode 100644 index 0000000..858f010 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/ascii.py @@ -0,0 +1,388 @@ +""" +ASCII Art Primitives Library + +ASCII art rendering with per-zone expression evaluation and cell effects. +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont +from typing import Any, Dict, List, Optional, Callable +import colorsys + + +# Character sets +CHAR_SETS = { + "standard": " .:-=+*#%@", + "blocks": " ░▒▓█", + "simple": " .:oO@", + "digits": "0123456789", + "binary": "01", + "ascii": " `.-':_,^=;><+!rc*/z?sLTv)J7(|Fi{C}fI31tlu[neoZ5Yxjya]2ESwqkP6h9d4VpOGbUAKXHm8RD#$Bg0MNWQ%&@", +} + +# Default font +_default_font = None + + +def _get_font(size: int): + """Get monospace font at given size.""" + global _default_font + try: + return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) + except: + return ImageFont.load_default() + + +def _parse_color(color_str: str) -> tuple: + """Parse color string to RGB tuple.""" + if color_str.startswith('#'): + hex_color = color_str[1:] + if len(hex_color) == 3: + hex_color = ''.join(c*2 for c in hex_color) + return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + + colors = { + 'black': (0, 0, 0), 'white': (255, 255, 255), + 'red': (255, 0, 0), 'green': (0, 255, 0), 'blue': (0, 0, 255), + 'yellow': (255, 255, 0), 'cyan': (0, 255, 255), 'magenta': (255, 0, 255), + 'gray': (128, 128, 128), 'grey': (128, 128, 128), + } + return colors.get(color_str.lower(), (0, 0, 0)) + + +def _cell_sample(frame: np.ndarray, cell_size: int): + """Sample frame into cells, returning colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + """ + h, w = frame.shape[:2] + rows = h // cell_size + cols = w // cell_size + + # Crop to exact grid then block-average via cv2 area interpolation. + cropped = frame[:rows * cell_size, :cols * cell_size] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + luminances = ((0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]) / 255.0).astype(np.float32) + + return colors, luminances + + +def _luminance_to_char(lum: float, alphabet: str, contrast: float) -> str: + """Map luminance to character.""" + chars = CHAR_SETS.get(alphabet, alphabet) + lum = ((lum - 0.5) * contrast + 0.5) + lum = max(0, min(1, lum)) + idx = int(lum * (len(chars) - 1)) + return chars[idx] + + +def _render_char_cell(char: str, cell_size: int, color: tuple, bg_color: tuple) -> np.ndarray: + """Render a single character to a cell image.""" + img = Image.new('RGB', (cell_size, cell_size), bg_color) + draw = ImageDraw.Draw(img) + font = _get_font(cell_size) + + # Center the character + bbox = draw.textbbox((0, 0), char, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + x = (cell_size - text_w) // 2 + y = (cell_size - text_h) // 2 - bbox[1] + + draw.text((x, y), char, fill=color, font=font) + return np.array(img) + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int = 80, + char_size: int = None, + alphabet: str = "standard", + color_mode: str = "color", + background: str = "black", + contrast: float = 1.5, + char_hue = None, + char_saturation = None, + char_brightness = None, + char_scale = None, + char_rotation = None, + char_jitter = None, + cell_effect = None, + energy: float = None, + rotation_scale: float = 0, + _interp = None, + _env = None, + **extra_params +) -> np.ndarray: + """ + Render frame as ASCII art with per-zone effects. + + Args: + frame: Input image + cols: Number of character columns + char_size: Cell size in pixels (overrides cols if set) + alphabet: Character set name or custom string + color_mode: "color", "mono", "invert", or color name + background: Background color name or hex + contrast: Contrast for character selection + char_hue/saturation/brightness/scale/rotation/jitter: Per-zone expressions + cell_effect: Lambda (cell, zone) -> cell for per-cell effects + energy: Energy value from audio analysis + rotation_scale: Max rotation degrees + _interp: Interpreter (auto-injected) + _env: Environment (auto-injected) + **extra_params: Additional params passed to zone dict + """ + h, w = frame.shape[:2] + + # Calculate cell size + if char_size is None or char_size == 0: + cell_size = max(4, w // cols) + else: + cell_size = max(4, int(char_size)) + + # Sample cells + colors, luminances = _cell_sample(frame, cell_size) + rows, cols_actual = luminances.shape + + # Parse background color + bg_color = _parse_color(background) + + # Create output image + out_h = rows * cell_size + out_w = cols_actual * cell_size + output = np.full((out_h, out_w, 3), bg_color, dtype=np.uint8) + + # Check if we have cell_effect + has_cell_effect = cell_effect is not None + + # Process each cell + for r in range(rows): + for c in range(cols_actual): + lum = luminances[r, c] + cell_color = tuple(colors[r, c]) + + # Build zone context + zone = { + 'row': r, + 'col': c, + 'row-norm': r / max(1, rows - 1), + 'col-norm': c / max(1, cols_actual - 1), + 'lum': float(lum), + 'r': cell_color[0] / 255, + 'g': cell_color[1] / 255, + 'b': cell_color[2] / 255, + 'cell_size': cell_size, + } + + # Add HSV + r_f, g_f, b_f = cell_color[0]/255, cell_color[1]/255, cell_color[2]/255 + hsv = colorsys.rgb_to_hsv(r_f, g_f, b_f) + zone['hue'] = hsv[0] * 360 + zone['sat'] = hsv[1] + + # Add energy and rotation_scale + if energy is not None: + zone['energy'] = energy + zone['rotation_scale'] = rotation_scale + + # Add extra params + for k, v in extra_params.items(): + if isinstance(v, (int, float, str, bool)) or v is None: + zone[k] = v + + # Get character + char = _luminance_to_char(lum, alphabet, contrast) + zone['char'] = char + + # Determine cell color based on mode + if color_mode == "mono": + render_color = (255, 255, 255) + elif color_mode == "invert": + render_color = tuple(255 - c for c in cell_color) + elif color_mode == "color": + render_color = cell_color + else: + render_color = _parse_color(color_mode) + + zone['color'] = render_color + + # Render character to cell + cell_img = _render_char_cell(char, cell_size, render_color, bg_color) + + # Apply cell_effect if provided + if has_cell_effect and _interp is not None: + cell_img = _apply_cell_effect(cell_img, zone, cell_effect, _interp, _env, extra_params) + + # Paste cell to output + y1, y2 = r * cell_size, (r + 1) * cell_size + x1, x2 = c * cell_size, (c + 1) * cell_size + output[y1:y2, x1:x2] = cell_img + + # Resize to match input dimensions + if output.shape[:2] != frame.shape[:2]: + output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LINEAR) + + return output + + +def _apply_cell_effect(cell_img, zone, cell_effect, interp, env, extra_params): + """Apply cell_effect lambda to a cell image. + + cell_effect is a Lambda object with params and body. + We create a child environment with zone variables and cell, + then evaluate the lambda body. + """ + # Get Environment class from the interpreter's module + Environment = type(env) + + # Create child environment with zone variables + cell_env = Environment(env) + + # Bind zone variables + for k, v in zone.items(): + cell_env.set(k, v) + + # Also bind with zone- prefix for consistency + cell_env.set('zone-row', zone.get('row', 0)) + cell_env.set('zone-col', zone.get('col', 0)) + cell_env.set('zone-row-norm', zone.get('row-norm', 0)) + cell_env.set('zone-col-norm', zone.get('col-norm', 0)) + cell_env.set('zone-lum', zone.get('lum', 0)) + cell_env.set('zone-sat', zone.get('sat', 0)) + cell_env.set('zone-hue', zone.get('hue', 0)) + cell_env.set('zone-r', zone.get('r', 0)) + cell_env.set('zone-g', zone.get('g', 0)) + cell_env.set('zone-b', zone.get('b', 0)) + + # Inject loaded effects as callable functions + if hasattr(interp, 'effects'): + for effect_name in interp.effects: + def make_effect_fn(name): + def effect_fn(frame, *args): + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 2: + params['offset_x'] = args[0] + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['size'] = args[0] + elif name == 'invert': + pass + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + cell_env.set(effect_name, make_effect_fn(effect_name)) + + # Bind cell image and zone dict + cell_env.set('cell', cell_img) + cell_env.set('zone', zone) + + # Evaluate the cell_effect lambda + # Lambda has params and body - we need to bind the params then evaluate + if hasattr(cell_effect, 'params') and hasattr(cell_effect, 'body'): + # Bind lambda parameters: (lambda [cell zone] body) + if len(cell_effect.params) >= 1: + cell_env.set(cell_effect.params[0], cell_img) + if len(cell_effect.params) >= 2: + cell_env.set(cell_effect.params[1], zone) + + result = interp.eval(cell_effect.body, cell_env) + elif isinstance(cell_effect, list): + # Raw S-expression lambda like (lambda [cell zone] body) or (fn [cell zone] body) + # Check if it's a lambda expression + head = cell_effect[0] if cell_effect else None + head_name = head.name if head and hasattr(head, 'name') else str(head) if head else None + is_lambda = head_name in ('lambda', 'fn') + + if is_lambda: + # (lambda [params...] body) + params = cell_effect[1] if len(cell_effect) > 1 else [] + body = cell_effect[2] if len(cell_effect) > 2 else None + + # Bind lambda parameters + if isinstance(params, list) and len(params) >= 1: + param_name = params[0].name if hasattr(params[0], 'name') else str(params[0]) + cell_env.set(param_name, cell_img) + if isinstance(params, list) and len(params) >= 2: + param_name = params[1].name if hasattr(params[1], 'name') else str(params[1]) + cell_env.set(param_name, zone) + + result = interp.eval(body, cell_env) if body else cell_img + else: + # Some other expression - just evaluate it + result = interp.eval(cell_effect, cell_env) + elif callable(cell_effect): + # It's a callable + result = cell_effect(cell_img, zone) + else: + raise ValueError(f"cell_effect must be a Lambda, list, or callable, got {type(cell_effect)}") + + if isinstance(result, np.ndarray) and result.shape == cell_img.shape: + return result + elif isinstance(result, np.ndarray): + # Shape mismatch - resize to fit + result = cv2.resize(result, (cell_img.shape[1], cell_img.shape[0])) + return result + + raise ValueError(f"cell_effect must return an image array, got {type(result)}") + + +def _get_legacy_ascii_primitives(): + """Import ASCII primitives from legacy primitives module. + + These are loaded lazily to avoid import issues during module loading. + By the time a primitive library is loaded, sexp_effects.primitives + is already in sys.modules (imported by sexp_effects.__init__). + """ + from sexp_effects.primitives import ( + prim_cell_sample, + prim_luminance_to_chars, + prim_render_char_grid, + prim_render_char_grid_fx, + prim_alphabet_char, + prim_alphabet_length, + prim_map_char_grid, + prim_map_colors, + prim_make_char_grid, + prim_set_char, + prim_get_char, + prim_char_grid_dimensions, + cell_sample_extended, + ) + return { + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + } + + +PRIMITIVES = { + 'ascii-fx-zone': prim_ascii_fx_zone, + **_get_legacy_ascii_primitives(), +} diff --git a/l1/sexp_effects/primitive_libs/blending.py b/l1/sexp_effects/primitive_libs/blending.py new file mode 100644 index 0000000..0bf345d --- /dev/null +++ b/l1/sexp_effects/primitive_libs/blending.py @@ -0,0 +1,116 @@ +""" +Blending Primitives Library + +Image blending and compositing operations. +""" +import numpy as np + + +def prim_blend_images(a, b, alpha): + """Blend two images: a * (1-alpha) + b * alpha.""" + alpha = max(0.0, min(1.0, alpha)) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a, b, mode): + """Blend using Photoshop-style blend modes.""" + a = a.astype(float) / 255 + b = b.astype(float) / 255 + + if mode == "multiply": + result = a * b + elif mode == "screen": + result = 1 - (1 - a) * (1 - b) + elif mode == "overlay": + mask = a < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "soft-light": + mask = b < 0.5 + result = np.where(mask, + a - (1 - 2 * b) * a * (1 - a), + a + (2 * b - 1) * (np.sqrt(a) - a)) + elif mode == "hard-light": + mask = b < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "color-dodge": + result = np.clip(a / (1 - b + 0.001), 0, 1) + elif mode == "color-burn": + result = 1 - np.clip((1 - a) / (b + 0.001), 0, 1) + elif mode == "difference": + result = np.abs(a - b) + elif mode == "exclusion": + result = a + b - 2 * a * b + elif mode == "add": + result = np.clip(a + b, 0, 1) + elif mode == "subtract": + result = np.clip(a - b, 0, 1) + elif mode == "darken": + result = np.minimum(a, b) + elif mode == "lighten": + result = np.maximum(a, b) + else: + # Default to normal (just return b) + result = b + + return (result * 255).astype(np.uint8) + + +def prim_mask(img, mask_img): + """Apply grayscale mask to image (white=opaque, black=transparent).""" + if len(mask_img.shape) == 3: + mask = mask_img[:, :, 0].astype(float) / 255 + else: + mask = mask_img.astype(float) / 255 + + mask = mask[:, :, np.newaxis] + return (img.astype(float) * mask).astype(np.uint8) + + +def prim_alpha_composite(base, overlay, alpha_channel): + """Composite overlay onto base using alpha channel.""" + if len(alpha_channel.shape) == 3: + alpha = alpha_channel[:, :, 0].astype(float) / 255 + else: + alpha = alpha_channel.astype(float) / 255 + + alpha = alpha[:, :, np.newaxis] + result = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha + return result.astype(np.uint8) + + +def prim_overlay(base, overlay, x, y, alpha=1.0): + """Overlay image at position (x, y) with optional alpha.""" + result = base.copy() + x, y = int(x), int(y) + oh, ow = overlay.shape[:2] + bh, bw = base.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(float) * (1 - alpha) + src.astype(float) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(np.uint8) + + return result + + +PRIMITIVES = { + # Basic blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + + # Masking + 'mask': prim_mask, + 'alpha-composite': prim_alpha_composite, + + # Overlay + 'overlay': prim_overlay, +} diff --git a/l1/sexp_effects/primitive_libs/blending_gpu.py b/l1/sexp_effects/primitive_libs/blending_gpu.py new file mode 100644 index 0000000..c768be3 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/blending_gpu.py @@ -0,0 +1,220 @@ +""" +GPU-Accelerated Blending Primitives Library + +Uses CuPy for CUDA-accelerated image blending and compositing. +Keeps frames on GPU when STREAMING_GPU_PERSIST=1 for maximum performance. +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + GPU_AVAILABLE = True + print("[blending_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[blending_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[blending_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def _get_xp(img): + """Get the array module (numpy or cupy) for the given image.""" + if GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp + return np + + +def prim_blend_images(a, b, alpha): + """Blend two images: a * (1-alpha) + b * alpha.""" + alpha = max(0.0, min(1.0, float(alpha))) + + if GPU_AVAILABLE: + a_gpu = _to_gpu(a) + b_gpu = _to_gpu(b) + result = (a_gpu.astype(cp.float32) * (1 - alpha) + b_gpu.astype(cp.float32) * alpha).astype(cp.uint8) + return _to_cpu(result) + + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a, b, mode): + """Blend using Photoshop-style blend modes.""" + if GPU_AVAILABLE: + a_gpu = _to_gpu(a).astype(cp.float32) / 255 + b_gpu = _to_gpu(b).astype(cp.float32) / 255 + xp = cp + else: + a_gpu = a.astype(float) / 255 + b_gpu = b.astype(float) / 255 + xp = np + + if mode == "multiply": + result = a_gpu * b_gpu + elif mode == "screen": + result = 1 - (1 - a_gpu) * (1 - b_gpu) + elif mode == "overlay": + mask = a_gpu < 0.5 + result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu)) + elif mode == "soft-light": + mask = b_gpu < 0.5 + result = xp.where(mask, + a_gpu - (1 - 2 * b_gpu) * a_gpu * (1 - a_gpu), + a_gpu + (2 * b_gpu - 1) * (xp.sqrt(a_gpu) - a_gpu)) + elif mode == "hard-light": + mask = b_gpu < 0.5 + result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu)) + elif mode == "color-dodge": + result = xp.clip(a_gpu / (1 - b_gpu + 0.001), 0, 1) + elif mode == "color-burn": + result = 1 - xp.clip((1 - a_gpu) / (b_gpu + 0.001), 0, 1) + elif mode == "difference": + result = xp.abs(a_gpu - b_gpu) + elif mode == "exclusion": + result = a_gpu + b_gpu - 2 * a_gpu * b_gpu + elif mode == "add": + result = xp.clip(a_gpu + b_gpu, 0, 1) + elif mode == "subtract": + result = xp.clip(a_gpu - b_gpu, 0, 1) + elif mode == "darken": + result = xp.minimum(a_gpu, b_gpu) + elif mode == "lighten": + result = xp.maximum(a_gpu, b_gpu) + else: + # Default to normal (just return b) + result = b_gpu + + result = (result * 255).astype(xp.uint8) + return _to_cpu(result) + + +def prim_mask(img, mask_img): + """Apply grayscale mask to image (white=opaque, black=transparent).""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + mask_gpu = _to_gpu(mask_img) + + if len(mask_gpu.shape) == 3: + mask = mask_gpu[:, :, 0].astype(cp.float32) / 255 + else: + mask = mask_gpu.astype(cp.float32) / 255 + + mask = mask[:, :, cp.newaxis] + result = (img_gpu.astype(cp.float32) * mask).astype(cp.uint8) + return _to_cpu(result) + + if len(mask_img.shape) == 3: + mask = mask_img[:, :, 0].astype(float) / 255 + else: + mask = mask_img.astype(float) / 255 + + mask = mask[:, :, np.newaxis] + return (img.astype(float) * mask).astype(np.uint8) + + +def prim_alpha_composite(base, overlay, alpha_channel): + """Composite overlay onto base using alpha channel.""" + if GPU_AVAILABLE: + base_gpu = _to_gpu(base) + overlay_gpu = _to_gpu(overlay) + alpha_gpu = _to_gpu(alpha_channel) + + if len(alpha_gpu.shape) == 3: + alpha = alpha_gpu[:, :, 0].astype(cp.float32) / 255 + else: + alpha = alpha_gpu.astype(cp.float32) / 255 + + alpha = alpha[:, :, cp.newaxis] + result = base_gpu.astype(cp.float32) * (1 - alpha) + overlay_gpu.astype(cp.float32) * alpha + return _to_cpu(result.astype(cp.uint8)) + + if len(alpha_channel.shape) == 3: + alpha = alpha_channel[:, :, 0].astype(float) / 255 + else: + alpha = alpha_channel.astype(float) / 255 + + alpha = alpha[:, :, np.newaxis] + result = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha + return result.astype(np.uint8) + + +def prim_overlay(base, overlay, x, y, alpha=1.0): + """Overlay image at position (x, y) with optional alpha.""" + if GPU_AVAILABLE: + base_gpu = _to_gpu(base) + overlay_gpu = _to_gpu(overlay) + result = base_gpu.copy() + + x, y = int(x), int(y) + oh, ow = overlay_gpu.shape[:2] + bh, bw = base_gpu.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay_gpu[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(cp.float32) * (1 - alpha) + src.astype(cp.float32) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(cp.uint8) + + return _to_cpu(result) + + result = base.copy() + x, y = int(x), int(y) + oh, ow = overlay.shape[:2] + bh, bw = base.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(float) * (1 - alpha) + src.astype(float) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(np.uint8) + + return result + + +PRIMITIVES = { + # Basic blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + + # Masking + 'mask': prim_mask, + 'alpha-composite': prim_alpha_composite, + + # Overlay + 'overlay': prim_overlay, +} diff --git a/l1/sexp_effects/primitive_libs/color.py b/l1/sexp_effects/primitive_libs/color.py new file mode 100644 index 0000000..0b6854b --- /dev/null +++ b/l1/sexp_effects/primitive_libs/color.py @@ -0,0 +1,137 @@ +""" +Color Primitives Library + +Color manipulation: RGB, HSV, blending, luminance. +""" +import numpy as np +import colorsys + + +def prim_rgb(r, g, b): + """Create RGB color as [r, g, b] (0-255).""" + return [int(max(0, min(255, r))), + int(max(0, min(255, g))), + int(max(0, min(255, b)))] + + +def prim_red(c): + return c[0] + + +def prim_green(c): + return c[1] + + +def prim_blue(c): + return c[2] + + +def prim_luminance(c): + """Perceived luminance (0-1) using standard weights.""" + return (0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) / 255 + + +def prim_rgb_to_hsv(c): + """Convert RGB [0-255] to HSV [h:0-360, s:0-1, v:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, s, v = colorsys.rgb_to_hsv(r, g, b) + return [h * 360, s, v] + + +def prim_hsv_to_rgb(hsv): + """Convert HSV [h:0-360, s:0-1, v:0-1] to RGB [0-255].""" + h, s, v = hsv[0] / 360, hsv[1], hsv[2] + r, g, b = colorsys.hsv_to_rgb(h, s, v) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_rgb_to_hsl(c): + """Convert RGB [0-255] to HSL [h:0-360, s:0-1, l:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, l, s = colorsys.rgb_to_hls(r, g, b) + return [h * 360, s, l] + + +def prim_hsl_to_rgb(hsl): + """Convert HSL [h:0-360, s:0-1, l:0-1] to RGB [0-255].""" + h, s, l = hsl[0] / 360, hsl[1], hsl[2] + r, g, b = colorsys.hls_to_rgb(h, l, s) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_blend_color(c1, c2, alpha): + """Blend two colors: c1 * (1-alpha) + c2 * alpha.""" + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img): + """Get average color of an image.""" + mean = np.mean(img, axis=(0, 1)) + return [int(mean[0]), int(mean[1]), int(mean[2])] + + +def prim_dominant_color(img, k=1): + """Get dominant color using k-means (simplified: just average for now).""" + return prim_average_color(img) + + +def prim_invert_color(c): + """Invert a color.""" + return [255 - c[0], 255 - c[1], 255 - c[2]] + + +def prim_grayscale_color(c): + """Convert color to grayscale.""" + gray = int(0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) + return [gray, gray, gray] + + +def prim_saturate(c, amount): + """Adjust saturation of color. amount=0 is grayscale, 1 is unchanged, >1 is more saturated.""" + hsv = prim_rgb_to_hsv(c) + hsv[1] = max(0, min(1, hsv[1] * amount)) + return prim_hsv_to_rgb(hsv) + + +def prim_brighten(c, amount): + """Adjust brightness. amount=0 is black, 1 is unchanged, >1 is brighter.""" + return [int(max(0, min(255, c[i] * amount))) for i in range(3)] + + +def prim_shift_hue(c, degrees): + """Shift hue by degrees.""" + hsv = prim_rgb_to_hsv(c) + hsv[0] = (hsv[0] + degrees) % 360 + return prim_hsv_to_rgb(hsv) + + +PRIMITIVES = { + # Construction + 'rgb': prim_rgb, + + # Component access + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + + # Color space conversion + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'rgb->hsl': prim_rgb_to_hsl, + 'hsl->rgb': prim_hsl_to_rgb, + + # Blending + 'blend-color': prim_blend_color, + + # Analysis + 'average-color': prim_average_color, + 'dominant-color': prim_dominant_color, + + # Manipulation + 'invert-color': prim_invert_color, + 'grayscale-color': prim_grayscale_color, + 'saturate': prim_saturate, + 'brighten': prim_brighten, + 'shift-hue': prim_shift_hue, +} diff --git a/l1/sexp_effects/primitive_libs/color_ops.py b/l1/sexp_effects/primitive_libs/color_ops.py new file mode 100644 index 0000000..a0da497 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/color_ops.py @@ -0,0 +1,109 @@ +""" +Color Operations Primitives Library + +Vectorized color adjustments: brightness, contrast, saturation, invert, HSV. +These operate on entire images for fast processing. +""" +import numpy as np +import cv2 + + +def _to_numpy(img): + """Convert GPU frames or CuPy arrays to numpy for CPU processing.""" + # Handle GPUFrame objects + if hasattr(img, 'cpu'): + return img.cpu + # Handle CuPy arrays + if hasattr(img, 'get'): + return img.get() + return img + + +def prim_adjust(img, brightness=0, contrast=1): + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + img = _to_numpy(img) + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img_raw, amount): + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + img = _to_numpy(img_raw) + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img): + """Invert all pixel values.""" + img = _to_numpy(img) + return (255 - img).astype(np.uint8) + + +def prim_shift_hsv(img, h=0, s=1, v=1): + """Shift HSV: h=degrees offset, s/v=multipliers.""" + img = _to_numpy(img) + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +def prim_add_noise(img, amount): + """Add gaussian noise to image.""" + img = _to_numpy(img) + noise = np.random.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img, levels): + """Reduce to N color levels per channel.""" + img = _to_numpy(img) + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_sepia(img, intensity=1.0): + """Apply sepia tone effect.""" + img = _to_numpy(img) + sepia_matrix = np.array([ + [0.393, 0.769, 0.189], + [0.349, 0.686, 0.168], + [0.272, 0.534, 0.131] + ]) + sepia = np.dot(img, sepia_matrix.T) + result = img.astype(np.float32) * (1 - intensity) + sepia * intensity + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_grayscale(img): + """Convert to grayscale (still RGB output).""" + img = _to_numpy(img) + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + return np.stack([gray, gray, gray], axis=-1).astype(np.uint8) + + +PRIMITIVES = { + # Brightness/Contrast + 'adjust': prim_adjust, + + # Saturation + 'mix-gray': prim_mix_gray, + 'grayscale': prim_grayscale, + + # HSV manipulation + 'shift-hsv': prim_shift_hsv, + + # Inversion + 'invert-img': prim_invert_img, + + # Effects + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'sepia': prim_sepia, +} diff --git a/l1/sexp_effects/primitive_libs/color_ops_gpu.py b/l1/sexp_effects/primitive_libs/color_ops_gpu.py new file mode 100644 index 0000000..a4f5272 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/color_ops_gpu.py @@ -0,0 +1,280 @@ +""" +GPU-Accelerated Color Operations Library + +Uses CuPy for CUDA-accelerated color transforms. + +Performance Mode: +- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations +- This dramatically improves performance by avoiding CPU<->GPU transfers +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + GPU_AVAILABLE = True + print("[color_ops_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[color_ops_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[color_ops_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def prim_invert(img): + """Invert image colors.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(255 - img_gpu) + return 255 - img + + +def prim_grayscale(img): + """Convert to grayscale.""" + if img.ndim != 3: + return img + + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + # Standard luminance weights + gray = 0.299 * img_gpu[:, :, 0] + 0.587 * img_gpu[:, :, 1] + 0.114 * img_gpu[:, :, 2] + gray = cp.clip(gray, 0, 255).astype(cp.uint8) + # Stack to 3 channels + result = cp.stack([gray, gray, gray], axis=2) + return _to_cpu(result) + + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray = np.clip(gray, 0, 255).astype(np.uint8) + return np.stack([gray, gray, gray], axis=2) + + +def prim_brightness(img, factor=1.0): + """Adjust brightness by factor.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + result = xp.clip(img_gpu * factor, 0, 255).astype(xp.uint8) + return _to_cpu(result) + return np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8) + + +def prim_contrast(img, factor=1.0): + """Adjust contrast around midpoint.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + result = xp.clip((img_gpu - 128) * factor + 128, 0, 255).astype(xp.uint8) + return _to_cpu(result) + return np.clip((img.astype(np.float32) - 128) * factor + 128, 0, 255).astype(np.uint8) + + +# CUDA kernel for HSV hue shift +if GPU_AVAILABLE: + _hue_shift_kernel = cp.RawKernel(r''' + extern "C" __global__ + void hue_shift(unsigned char* img, int width, int height, float shift) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int idx = (y * width + x) * 3; + + // Get RGB + float r = img[idx] / 255.0f; + float g = img[idx + 1] / 255.0f; + float b = img[idx + 2] / 255.0f; + + // RGB to HSV + float max_c = fmaxf(r, fmaxf(g, b)); + float min_c = fminf(r, fminf(g, b)); + float delta = max_c - min_c; + + float h = 0.0f, s = 0.0f, v = max_c; + + if (delta > 0.00001f) { + s = delta / max_c; + + if (max_c == r) { + h = 60.0f * fmodf((g - b) / delta, 6.0f); + } else if (max_c == g) { + h = 60.0f * ((b - r) / delta + 2.0f); + } else { + h = 60.0f * ((r - g) / delta + 4.0f); + } + + if (h < 0) h += 360.0f; + } + + // Shift hue + h = fmodf(h + shift, 360.0f); + if (h < 0) h += 360.0f; + + // HSV to RGB + float c = v * s; + float x_val = c * (1.0f - fabsf(fmodf(h / 60.0f, 2.0f) - 1.0f)); + float m = v - c; + + float r_out, g_out, b_out; + if (h < 60) { + r_out = c; g_out = x_val; b_out = 0; + } else if (h < 120) { + r_out = x_val; g_out = c; b_out = 0; + } else if (h < 180) { + r_out = 0; g_out = c; b_out = x_val; + } else if (h < 240) { + r_out = 0; g_out = x_val; b_out = c; + } else if (h < 300) { + r_out = x_val; g_out = 0; b_out = c; + } else { + r_out = c; g_out = 0; b_out = x_val; + } + + img[idx] = (unsigned char)fminf(255.0f, (r_out + m) * 255.0f); + img[idx + 1] = (unsigned char)fminf(255.0f, (g_out + m) * 255.0f); + img[idx + 2] = (unsigned char)fminf(255.0f, (b_out + m) * 255.0f); + } + ''', 'hue_shift') + + +def prim_hue_shift(img, shift=0.0): + """Shift hue by degrees.""" + if img.ndim != 3 or img.shape[2] != 3: + return img + + if not GPU_AVAILABLE: + import cv2 + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.float32) + shift / 2) % 180 + return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + h, w = img.shape[:2] + img_gpu = _to_gpu(img.astype(np.uint8)).copy() + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _hue_shift_kernel(grid, block, (img_gpu, np.int32(w), np.int32(h), np.float32(shift))) + + return _to_cpu(img_gpu) + + +def prim_saturate(img, factor=1.0): + """Adjust saturation by factor.""" + if img.ndim != 3: + return img + + if not GPU_AVAILABLE: + import cv2 + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + # GPU version - simple desaturation blend + img_gpu = _to_gpu(img.astype(np.float32)) + gray = 0.299 * img_gpu[:, :, 0] + 0.587 * img_gpu[:, :, 1] + 0.114 * img_gpu[:, :, 2] + gray = gray[:, :, cp.newaxis] + + if factor < 1.0: + # Desaturate: blend toward gray + result = img_gpu * factor + gray * (1 - factor) + else: + # Oversaturate: extrapolate away from gray + result = gray + (img_gpu - gray) * factor + + result = cp.clip(result, 0, 255).astype(cp.uint8) + return _to_cpu(result) + + +def prim_blend(img1, img2, alpha=0.5): + """Blend two images with alpha.""" + xp = cp if GPU_AVAILABLE else np + + if GPU_AVAILABLE: + img1_gpu = _to_gpu(img1.astype(np.float32)) + img2_gpu = _to_gpu(img2.astype(np.float32)) + result = img1_gpu * (1 - alpha) + img2_gpu * alpha + result = xp.clip(result, 0, 255).astype(xp.uint8) + return _to_cpu(result) + + result = img1.astype(np.float32) * (1 - alpha) + img2.astype(np.float32) * alpha + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_add(img1, img2): + """Add two images (clamped).""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + result = xp.clip(_to_gpu(img1).astype(np.int16) + _to_gpu(img2).astype(np.int16), 0, 255) + return _to_cpu(result.astype(xp.uint8)) + return np.clip(img1.astype(np.int16) + img2.astype(np.int16), 0, 255).astype(np.uint8) + + +def prim_multiply(img1, img2): + """Multiply two images (normalized).""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + result = (_to_gpu(img1).astype(np.float32) * _to_gpu(img2).astype(np.float32)) / 255.0 + result = xp.clip(result, 0, 255).astype(xp.uint8) + return _to_cpu(result) + result = (img1.astype(np.float32) * img2.astype(np.float32)) / 255.0 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_screen(img1, img2): + """Screen blend mode.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + i1 = _to_gpu(img1).astype(np.float32) / 255.0 + i2 = _to_gpu(img2).astype(np.float32) / 255.0 + result = 1.0 - (1.0 - i1) * (1.0 - i2) + result = xp.clip(result * 255, 0, 255).astype(xp.uint8) + return _to_cpu(result) + i1 = img1.astype(np.float32) / 255.0 + i2 = img2.astype(np.float32) / 255.0 + result = 1.0 - (1.0 - i1) * (1.0 - i2) + return np.clip(result * 255, 0, 255).astype(np.uint8) + + +# Import CPU primitives as fallbacks +def _get_cpu_primitives(): + """Get all primitives from CPU color_ops module as fallbacks.""" + from sexp_effects.primitive_libs import color_ops + return color_ops.PRIMITIVES + + +# Export functions - start with CPU primitives, then override with GPU versions +PRIMITIVES = _get_cpu_primitives().copy() + +# Override specific primitives with GPU-accelerated versions +PRIMITIVES.update({ + 'invert': prim_invert, + 'grayscale': prim_grayscale, + 'brightness': prim_brightness, + 'contrast': prim_contrast, + 'hue-shift': prim_hue_shift, + 'saturate': prim_saturate, + 'blend': prim_blend, + 'add': prim_add, + 'multiply': prim_multiply, + 'screen': prim_screen, +}) diff --git a/l1/sexp_effects/primitive_libs/core.py b/l1/sexp_effects/primitive_libs/core.py new file mode 100644 index 0000000..34b580a --- /dev/null +++ b/l1/sexp_effects/primitive_libs/core.py @@ -0,0 +1,294 @@ +""" +Core Primitives - Always available, minimal essential set. + +These are the primitives that form the foundation of the language. +They cannot be overridden by libraries. +""" + + +# Arithmetic +def prim_add(*args): + if len(args) == 0: + return 0 + result = args[0] + for arg in args[1:]: + result = result + arg + return result + + +def prim_sub(a, b=None): + if b is None: + return -a + return a - b + + +def prim_mul(*args): + if len(args) == 0: + return 1 + result = args[0] + for arg in args[1:]: + result = result * arg + return result + + +def prim_div(a, b): + return a / b + + +def prim_mod(a, b): + return a % b + + +def prim_abs(x): + return abs(x) + + +def prim_min(*args): + return min(args) + + +def prim_max(*args): + return max(args) + + +def prim_round(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.round(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.round(x) + return round(x) + + +def prim_floor(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.floor(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.floor(x) + import math + return math.floor(x) + + +def prim_ceil(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.ceil(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.ceil(x) + import math + return math.ceil(x) + + +# Comparison +def prim_lt(a, b): + return a < b + + +def prim_gt(a, b): + return a > b + + +def prim_le(a, b): + return a <= b + + +def prim_ge(a, b): + return a >= b + + +def prim_eq(a, b): + if isinstance(a, float) or isinstance(b, float): + return abs(a - b) < 1e-9 + return a == b + + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# Logic +def prim_not(x): + return not x + + +def prim_and(*args): + for a in args: + if not a: + return False + return True + + +def prim_or(*args): + for a in args: + if a: + return True + return False + + +# Basic data access +def prim_get(obj, key, default=None): + """Get value from dict or list.""" + if isinstance(obj, dict): + return obj.get(key, default) + elif isinstance(obj, (list, tuple)): + try: + return obj[int(key)] + except (IndexError, ValueError): + return default + return default + + +def prim_nth(seq, i): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_length(seq): + return len(seq) + + +def prim_list(*args): + return list(args) + + +# Type checking +def prim_is_number(x): + return isinstance(x, (int, float)) + + +def prim_is_string(x): + return isinstance(x, str) + + +def prim_is_list(x): + return isinstance(x, (list, tuple)) + + +def prim_is_dict(x): + return isinstance(x, dict) + + +def prim_is_nil(x): + return x is None + + +# Higher-order / iteration +def prim_reduce(seq, init, fn): + """(reduce seq init fn) — fold left: fn(fn(fn(init, s0), s1), s2) ...""" + acc = init + for item in seq: + acc = fn(acc, item) + return acc + + +def prim_map(seq, fn): + """(map seq fn) — apply fn to each element, return new list.""" + return [fn(item) for item in seq] + + +def prim_range(*args): + """(range end), (range start end), or (range start end step) — integer range.""" + if len(args) == 1: + return list(range(int(args[0]))) + elif len(args) == 2: + return list(range(int(args[0]), int(args[1]))) + elif len(args) >= 3: + return list(range(int(args[0]), int(args[1]), int(args[2]))) + return [] + + +# Random +import random +_rng = random.Random() + +def set_random_seed(seed): + """Set the random seed for deterministic output.""" + global _rng + _rng = random.Random(seed) + +def prim_rand(): + """Return random float in [0, 1).""" + return _rng.random() + +def prim_rand_int(lo, hi): + """Return random integer in [lo, hi].""" + return _rng.randint(int(lo), int(hi)) + +def prim_rand_range(lo, hi): + """Return random float in [lo, hi).""" + return lo + _rng.random() * (hi - lo) + +def prim_map_range(val, from_lo, from_hi, to_lo, to_hi): + """Map value from one range to another.""" + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + +# Core primitives dict +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + 'mod': prim_mod, + 'abs': prim_abs, + 'min': prim_min, + 'max': prim_max, + 'round': prim_round, + 'floor': prim_floor, + 'ceil': prim_ceil, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Logic + 'not': prim_not, + 'and': prim_and, + 'or': prim_or, + + # Data access + 'get': prim_get, + 'nth': prim_nth, + 'first': prim_first, + 'length': prim_length, + 'len': prim_length, + 'list': prim_list, + + # Type predicates + 'number?': prim_is_number, + 'string?': prim_is_string, + 'list?': prim_is_list, + 'dict?': prim_is_dict, + 'nil?': prim_is_nil, + 'is-nil': prim_is_nil, + + # Higher-order / iteration + 'reduce': prim_reduce, + 'fold': prim_reduce, + 'map': prim_map, + 'range': prim_range, + + # Random + 'rand': prim_rand, + 'rand-int': prim_rand_int, + 'rand-range': prim_rand_range, + 'map-range': prim_map_range, +} diff --git a/l1/sexp_effects/primitive_libs/drawing.py b/l1/sexp_effects/primitive_libs/drawing.py new file mode 100644 index 0000000..50e0c45 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/drawing.py @@ -0,0 +1,690 @@ +""" +Drawing Primitives Library + +Draw shapes, text, and characters on images with sophisticated text handling. + +Text Features: +- Font loading from files or system fonts +- Text measurement and fitting +- Alignment (left/center/right, top/middle/bottom) +- Opacity for fade effects +- Multi-line text support +- Shadow and outline effects +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont +import os +import glob as glob_module +from typing import Optional, Tuple, List, Union + + +# ============================================================================= +# Font Management +# ============================================================================= + +# Font cache: (path, size) -> font object +_font_cache = {} + +# Common system font directories +FONT_DIRS = [ + "/usr/share/fonts", + "/usr/local/share/fonts", + "~/.fonts", + "~/.local/share/fonts", + "/System/Library/Fonts", # macOS + "/Library/Fonts", # macOS + "C:/Windows/Fonts", # Windows +] + +# Default fonts to try (in order of preference) +DEFAULT_FONTS = [ + "DejaVuSans.ttf", + "DejaVuSansMono.ttf", + "Arial.ttf", + "Helvetica.ttf", + "FreeSans.ttf", + "LiberationSans-Regular.ttf", +] + + +def _find_font_file(name: str) -> Optional[str]: + """Find a font file by name in system directories.""" + # If it's already a full path + if os.path.isfile(name): + return name + + # Expand user paths + expanded = os.path.expanduser(name) + if os.path.isfile(expanded): + return expanded + + # Search in font directories + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + # Direct match + direct = os.path.join(font_dir, name) + if os.path.isfile(direct): + return direct + + # Recursive search + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower() == name.lower(): + return os.path.join(root, f) + # Also match without extension + base = os.path.splitext(f)[0] + if base.lower() == name.lower(): + return os.path.join(root, f) + + return None + + +def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont: + """Get a default font at the given size.""" + for font_name in DEFAULT_FONTS: + path = _find_font_file(font_name) + if path: + try: + return ImageFont.truetype(path, size) + except: + continue + + # Last resort: PIL default + return ImageFont.load_default() + + +def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont: + """ + Load a font by name or path. + + (make-font "Arial" 32) ; system font by name + (make-font "/path/to/font.ttf" 24) ; font file path + (make-font "DejaVuSans" 48) ; searches common locations + + Returns a font object for use with text primitives. + """ + size = int(size) + + # Check cache + cache_key = (name_or_path, size) + if cache_key in _font_cache: + return _font_cache[cache_key] + + # Find the font file + path = _find_font_file(name_or_path) + if not path: + raise FileNotFoundError(f"Font not found: {name_or_path}") + + # Load and cache + font = ImageFont.truetype(path, size) + _font_cache[cache_key] = font + return font + + +def prim_list_fonts() -> List[str]: + """ + List available system fonts. + + (list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...) + + Returns list of font filenames found in system directories. + """ + fonts = set() + + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower().endswith(('.ttf', '.otf', '.ttc')): + fonts.add(f) + + return sorted(fonts) + + +def prim_font_size(font: ImageFont.FreeTypeFont) -> int: + """ + Get the size of a font. + + (font-size my-font) ; -> 24 + """ + return font.size + + +# ============================================================================= +# Text Measurement +# ============================================================================= + +def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]: + """ + Measure text dimensions. + + (text-size "Hello" my-font) ; -> (width height) + (text-size "Hello" :font-size 32) ; -> (width height) with default font + + For multi-line text, returns total bounding box. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + # Create temporary image for measurement + img = Image.new('RGB', (1, 1)) + draw = ImageDraw.Draw(img) + + bbox = draw.textbbox((0, 0), str(text), font=font) + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + + return (width, height) + + +def prim_text_metrics(font=None, font_size: int = 24) -> dict: + """ + Get font metrics. + + (text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25} + + Useful for precise text layout. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + ascent, descent = font.getmetrics() + return { + 'ascent': ascent, + 'descent': descent, + 'height': ascent + descent, + 'size': font.size, + } + + +def prim_fit_text_size(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> int: + """ + Calculate font size to fit text within bounds. + + (fit-text-size "Hello World" 400 100) ; -> 48 + (fit-text-size "Title" 800 200 :font-name "Arial") + + Returns the largest font size that fits within max_width x max_height. + """ + max_width = int(max_width) + max_height = int(max_height) + min_size = int(min_size) + max_size = int(max_size) + text = str(text) + + # Binary search for optimal size + best_size = min_size + low, high = min_size, max_size + + while low <= high: + mid = (low + high) // 2 + + if font_name: + try: + font = prim_make_font(font_name, mid) + except: + font = _get_default_font(mid) + else: + font = _get_default_font(mid) + + w, h = prim_text_size(text, font) + + if w <= max_width and h <= max_height: + best_size = mid + low = mid + 1 + else: + high = mid - 1 + + return best_size + + +def prim_fit_font(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> ImageFont.FreeTypeFont: + """ + Create a font sized to fit text within bounds. + + (fit-font "Hello World" 400 100) ; -> font object + (fit-font "Title" 800 200 :font-name "Arial") + + Returns a font object at the optimal size. + """ + size = prim_fit_text_size(text, max_width, max_height, + font_name, min_size, max_size) + + if font_name: + try: + return prim_make_font(font_name, size) + except: + pass + + return _get_default_font(size) + + +# ============================================================================= +# Text Drawing +# ============================================================================= + +def prim_text(img: np.ndarray, text: str, + x: int = None, y: int = None, + width: int = None, height: int = None, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "left", valign: str = "top", + fit: bool = False, + shadow: bool = False, shadow_color=None, shadow_offset: int = 2, + outline: bool = False, outline_color=None, outline_width: int = 1, + line_spacing: float = 1.2) -> np.ndarray: + """ + Draw text with alignment, opacity, and effects. + + Basic usage: + (text frame "Hello" :x 100 :y 50) + + Centered in frame: + (text frame "Title" :align "center" :valign "middle") + + Fit to box: + (text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true) + + With fade (for animations): + (text frame "Fading" :x 100 :y 100 :opacity 0.5) + + With effects: + (text frame "Shadow" :x 100 :y 100 :shadow true) + (text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0)) + + Args: + img: Input frame + text: Text to draw + x, y: Position (if not specified, uses alignment in full frame) + width, height: Bounding box (for fit and alignment within box) + font: Font object from make-font + font_size: Size if no font specified + font_name: Font name to load + color: RGB tuple (default white) + opacity: 0.0 (invisible) to 1.0 (opaque) for fading + align: "left", "center", "right" + valign: "top", "middle", "bottom" + fit: If true, auto-size font to fit in box + shadow: Draw drop shadow + shadow_color: Shadow color (default black) + shadow_offset: Shadow offset in pixels + outline: Draw text outline + outline_color: Outline color (default black) + outline_width: Outline thickness + line_spacing: Multiplier for line height (for multi-line) + + Returns: + Frame with text drawn + """ + h, w = img.shape[:2] + text = str(text) + + # Default colors + if color is None: + color = (255, 255, 255) + else: + color = tuple(int(c) for c in color) + + if shadow_color is None: + shadow_color = (0, 0, 0) + else: + shadow_color = tuple(int(c) for c in shadow_color) + + if outline_color is None: + outline_color = (0, 0, 0) + else: + outline_color = tuple(int(c) for c in outline_color) + + # Determine bounding box + if x is None: + x = 0 + if width is None: + width = w + if y is None: + y = 0 + if height is None: + height = h + + x, y = int(x), int(y) + box_width = int(width) if width else w - x + box_height = int(height) if height else h - y + + # Get or create font + if font is None: + if fit: + font = prim_fit_font(text, box_width, box_height, font_name) + elif font_name: + try: + font = prim_make_font(font_name, int(font_size)) + except: + font = _get_default_font(int(font_size)) + else: + font = _get_default_font(int(font_size)) + + # Measure text + text_w, text_h = prim_text_size(text, font) + + # Calculate position based on alignment + if align == "center": + draw_x = x + (box_width - text_w) // 2 + elif align == "right": + draw_x = x + box_width - text_w + else: # left + draw_x = x + + if valign == "middle": + draw_y = y + (box_height - text_h) // 2 + elif valign == "bottom": + draw_y = y + box_height - text_h + else: # top + draw_y = y + + # Create RGBA image for compositing with opacity + pil_img = Image.fromarray(img).convert('RGBA') + + # Create text layer with transparency + text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(text_layer) + + # Draw shadow first (if enabled) + if shadow: + shadow_x = draw_x + shadow_offset + shadow_y = draw_y + shadow_offset + shadow_rgba = shadow_color + (int(255 * opacity * 0.5),) + draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font) + + # Draw outline (if enabled) + if outline: + outline_rgba = outline_color + (int(255 * opacity),) + ow = int(outline_width) + for dx in range(-ow, ow + 1): + for dy in range(-ow, ow + 1): + if dx != 0 or dy != 0: + draw.text((draw_x + dx, draw_y + dy), text, + fill=outline_rgba, font=font) + + # Draw main text + text_rgba = color + (int(255 * opacity),) + draw.text((draw_x, draw_y), text, fill=text_rgba, font=font) + + # Composite + result = Image.alpha_composite(pil_img, text_layer) + return np.array(result.convert('RGB')) + + +def prim_text_box(img: np.ndarray, text: str, + x: int, y: int, width: int, height: int, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "center", valign: str = "middle", + fit: bool = True, + padding: int = 0, + background=None, background_opacity: float = 0.5, + **kwargs) -> np.ndarray: + """ + Draw text fitted within a box, optionally with background. + + (text-box frame "Title" 50 50 400 100) + (text-box frame "Subtitle" 50 160 400 50 + :background (0 0 0) :background-opacity 0.7) + + Convenience wrapper around text() for common box-with-text pattern. + """ + x, y = int(x), int(y) + width, height = int(width), int(height) + padding = int(padding) + + result = img.copy() + + # Draw background if specified + if background is not None: + bg_color = tuple(int(c) for c in background) + + # Create background with opacity + pil_img = Image.fromarray(result).convert('RGBA') + bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + bg_draw = ImageDraw.Draw(bg_layer) + bg_rgba = bg_color + (int(255 * background_opacity),) + bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba) + result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB')) + + # Draw text within padded box + return prim_text(result, text, + x=x + padding, y=y + padding, + width=width - 2 * padding, height=height - 2 * padding, + font=font, font_size=font_size, font_name=font_name, + color=color, opacity=opacity, + align=align, valign=valign, fit=fit, + **kwargs) + + +# ============================================================================= +# Legacy text functions (keep for compatibility) +# ============================================================================= + +def prim_draw_char(img, char, x, y, font_size=16, color=None): + """Draw a single character at (x, y). Legacy function.""" + return prim_text(img, str(char), x=int(x), y=int(y), + font_size=int(font_size), color=color) + + +def prim_draw_text(img, text, x, y, font_size=16, color=None): + """Draw text string at (x, y). Legacy function.""" + return prim_text(img, str(text), x=int(x), y=int(y), + font_size=int(font_size), color=color) + + +# ============================================================================= +# Shape Drawing +# ============================================================================= + +def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0): + """ + Fill a rectangle with color. + + (fill-rect frame 10 10 100 50 (255 0 0)) + (fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5) + """ + if color is None: + color = [255, 255, 255] + + x, y, w, h = int(x), int(y), int(w), int(h) + + if opacity >= 1.0: + result = img.copy() + result[y:y+h, x:x+w] = color + return result + + # With opacity, use alpha compositing + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([x, y, x + w, y + h], fill=fill_rgba) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0): + """Draw rectangle outline.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([int(x), int(y), int(x+w), int(y+h)], + outline=outline_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0): + """Draw a line from (x1, y1) to (x2, y2).""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.line([(int(x1), int(y1)), (int(x2), int(y2))], + fill=line_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, + fill=False, opacity: float = 1.0): + """Draw a circle.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.circle(result, (int(cx), int(cy)), int(radius), + tuple(int(c) for c in color), t) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + cx, cy, r = int(cx), int(cy), int(radius) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + + if fill: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba) + else: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], + outline=rgba, width=int(thickness)) + + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, + thickness=1, fill=False, opacity: float = 1.0): + """Draw an ellipse.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, tuple(int(c) for c in color), t) + return result + + # With opacity (note: PIL doesn't support rotated ellipses easily) + # Fall back to cv2 on a separate layer + layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) + t = -1 if fill else int(thickness) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, rgba, t) + + pil_img = Image.fromarray(img).convert('RGBA') + pil_layer = Image.fromarray(layer) + result = Image.alpha_composite(pil_img, pil_layer) + return np.array(result.convert('RGB')) + + +def prim_draw_polygon(img, points, color=None, thickness=1, + fill=False, opacity: float = 1.0): + """Draw a polygon from list of [x, y] points.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + if fill: + cv2.fillPoly(result, [pts], tuple(int(c) for c in color)) + else: + cv2.polylines(result, [pts], True, + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + + pts_flat = [(int(p[0]), int(p[1])) for p in points] + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + + if fill: + draw.polygon(pts_flat, fill=rgba) + else: + draw.polygon(pts_flat, outline=rgba, width=int(thickness)) + + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +# ============================================================================= +# PRIMITIVES Export +# ============================================================================= + +PRIMITIVES = { + # Font management + 'make-font': prim_make_font, + 'list-fonts': prim_list_fonts, + 'font-size': prim_font_size, + + # Text measurement + 'text-size': prim_text_size, + 'text-metrics': prim_text_metrics, + 'fit-text-size': prim_fit_text_size, + 'fit-font': prim_fit_font, + + # Text drawing + 'text': prim_text, + 'text-box': prim_text_box, + + # Legacy text (compatibility) + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + + # Rectangles + 'fill-rect': prim_fill_rect, + 'draw-rect': prim_draw_rect, + + # Lines and shapes + 'draw-line': prim_draw_line, + 'draw-circle': prim_draw_circle, + 'draw-ellipse': prim_draw_ellipse, + 'draw-polygon': prim_draw_polygon, +} diff --git a/l1/sexp_effects/primitive_libs/filters.py b/l1/sexp_effects/primitive_libs/filters.py new file mode 100644 index 0000000..a66f107 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/filters.py @@ -0,0 +1,119 @@ +""" +Filters Primitives Library + +Image filters: blur, sharpen, edges, convolution. +""" +import numpy as np +import cv2 + + +def prim_blur(img, radius): + """Gaussian blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img, radius): + """Box blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_median_blur(img, radius): + """Median blur (good for noise removal).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.medianBlur(img, ksize) + + +def prim_bilateral(img, d=9, sigma_color=75, sigma_space=75): + """Bilateral filter (edge-preserving blur).""" + return cv2.bilateralFilter(img, d, sigma_color, sigma_space) + + +def prim_sharpen(img, amount=1.0): + """Sharpen image using unsharp mask.""" + blurred = cv2.GaussianBlur(img, (0, 0), 3) + return cv2.addWeighted(img, 1.0 + amount, blurred, -amount, 0) + + +def prim_edges(img, low=50, high=150): + """Canny edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, low, high) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img, ksize=3): + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=ksize) + sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=ksize) + mag = np.sqrt(sobelx**2 + sobely**2) + mag = np.clip(mag, 0, 255).astype(np.uint8) + return cv2.cvtColor(mag, cv2.COLOR_GRAY2RGB) + + +def prim_laplacian(img, ksize=3): + """Laplacian edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + lap = cv2.Laplacian(gray, cv2.CV_64F, ksize=ksize) + lap = np.abs(lap) + lap = np.clip(lap, 0, 255).astype(np.uint8) + return cv2.cvtColor(lap, cv2.COLOR_GRAY2RGB) + + +def prim_emboss(img): + """Emboss effect.""" + kernel = np.array([[-2, -1, 0], + [-1, 1, 1], + [ 0, 1, 2]]) + result = cv2.filter2D(img, -1, kernel) + return np.clip(result + 128, 0, 255).astype(np.uint8) + + +def prim_dilate(img, size=1): + """Morphological dilation.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.dilate(img, kernel) + + +def prim_erode(img, size=1): + """Morphological erosion.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.erode(img, kernel) + + +def prim_convolve(img, kernel): + """Apply custom convolution kernel.""" + kernel = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, kernel) + + +PRIMITIVES = { + # Blur + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'median-blur': prim_median_blur, + 'bilateral': prim_bilateral, + + # Sharpen + 'sharpen': prim_sharpen, + + # Edges + 'edges': prim_edges, + 'sobel': prim_sobel, + 'laplacian': prim_laplacian, + + # Effects + 'emboss': prim_emboss, + + # Morphology + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Custom + 'convolve': prim_convolve, +} diff --git a/l1/sexp_effects/primitive_libs/geometry.py b/l1/sexp_effects/primitive_libs/geometry.py new file mode 100644 index 0000000..5b385a4 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/geometry.py @@ -0,0 +1,143 @@ +""" +Geometry Primitives Library + +Geometric transforms: rotate, scale, flip, translate, remap. +""" +import numpy as np +import cv2 + + +def prim_translate(img, dx, dy): + """Translate image by (dx, dy) pixels.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_rotate(img, angle, cx=None, cy=None): + """Rotate image by angle degrees around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_scale(img, sx, sy, cx=None, cy=None): + """Scale image by (sx, sy) around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Build transform matrix + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_flip_h(img): + """Flip image horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img): + """Flip image vertically.""" + return cv2.flip(img, 0) + + +def prim_flip(img, direction="horizontal"): + """Flip image in given direction.""" + if direction in ("horizontal", "h"): + return prim_flip_h(img) + elif direction in ("vertical", "v"): + return prim_flip_v(img) + elif direction in ("both", "hv", "vh"): + return cv2.flip(img, -1) + return img + + +def prim_transpose(img): + """Transpose image (swap x and y).""" + return np.transpose(img, (1, 0, 2)) + + +def prim_remap(img, map_x, map_y): + """Remap image using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), + map_y.astype(np.float32), + cv2.INTER_LINEAR) + + +def prim_make_coords(w, h): + """Create coordinate grids for remapping.""" + x = np.arange(w, dtype=np.float32) + y = np.arange(h, dtype=np.float32) + map_x, map_y = np.meshgrid(x, y) + return (map_x, map_y) + + +def prim_perspective(img, src_pts, dst_pts): + """Apply perspective transform.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getPerspectiveTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpPerspective(img, M, (w, h)) + + +def prim_affine(img, src_pts, dst_pts): + """Apply affine transform using 3 point pairs.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getAffineTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpAffine(img, M, (w, h)) + + +def _get_legacy_geometry_primitives(): + """Import geometry primitives from legacy primitives module.""" + from sexp_effects.primitives import ( + prim_coords_x, + prim_coords_y, + prim_ripple_displace, + prim_fisheye_displace, + prim_kaleidoscope_displace, + ) + return { + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'ripple-displace': prim_ripple_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + } + + +PRIMITIVES = { + # Basic transforms + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + + # Flips + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'flip': prim_flip, + 'transpose': prim_transpose, + + # Remapping + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Advanced transforms + 'perspective': prim_perspective, + 'affine': prim_affine, + + # Displace / coordinate ops (from legacy primitives) + **_get_legacy_geometry_primitives(), +} diff --git a/l1/sexp_effects/primitive_libs/geometry_gpu.py b/l1/sexp_effects/primitive_libs/geometry_gpu.py new file mode 100644 index 0000000..d4e3193 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/geometry_gpu.py @@ -0,0 +1,403 @@ +""" +GPU-Accelerated Geometry Primitives Library + +Uses CuPy for CUDA-accelerated image transforms. +Falls back to CPU if GPU unavailable. + +Performance Mode: +- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations +- This dramatically improves performance by avoiding CPU<->GPU transfers +- Frames only transfer to CPU at final output +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + from cupyx.scipy import ndimage as cpndimage + GPU_AVAILABLE = True + print("[geometry_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[geometry_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +# Set STREAMING_GPU_PERSIST=1 for maximum performance +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[geometry_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def _ensure_output_format(img): + """Ensure output is in correct format based on GPU_PERSIST setting.""" + return _to_cpu(img) + + +def prim_rotate(img, angle, cx=None, cy=None): + """Rotate image by angle degrees around center (cx, cy). + + Uses fast CUDA kernel when available (< 1ms vs 20ms for scipy). + """ + if not GPU_AVAILABLE: + # Fallback to OpenCV + import cv2 + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + # Use fast CUDA kernel (prim_rotate_gpu defined below) + return prim_rotate_gpu(img, angle, cx, cy) + + +def prim_scale(img, sx, sy, cx=None, cy=None): + """Scale image by (sx, sy) around center (cx, cy).""" + if not GPU_AVAILABLE: + import cv2 + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h)) + + img_gpu = _to_gpu(img) + h, w = img_gpu.shape[:2] + + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Use cupyx.scipy.ndimage.zoom + if img_gpu.ndim == 3: + zoom_factors = (sy, sx, 1) # Don't zoom color channels + else: + zoom_factors = (sy, sx) + + zoomed = cpndimage.zoom(img_gpu, zoom_factors, order=1) + + # Crop/pad to original size + zh, zw = zoomed.shape[:2] + result = cp.zeros_like(img_gpu) + + # Calculate offsets + src_y = max(0, (zh - h) // 2) + src_x = max(0, (zw - w) // 2) + dst_y = max(0, (h - zh) // 2) + dst_x = max(0, (w - zw) // 2) + + copy_h = min(h - dst_y, zh - src_y) + copy_w = min(w - dst_x, zw - src_x) + + result[dst_y:dst_y+copy_h, dst_x:dst_x+copy_w] = zoomed[src_y:src_y+copy_h, src_x:src_x+copy_w] + + return _to_cpu(result) + + +def prim_translate(img, dx, dy): + """Translate image by (dx, dy) pixels.""" + if not GPU_AVAILABLE: + import cv2 + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h)) + + img_gpu = _to_gpu(img) + # Use cupyx.scipy.ndimage.shift + if img_gpu.ndim == 3: + shift = (dy, dx, 0) # Don't shift color channels + else: + shift = (dy, dx) + + shifted = cpndimage.shift(img_gpu, shift, order=1) + return _to_cpu(shifted) + + +def prim_flip_h(img): + """Flip image horizontally.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(img_gpu, axis=1)) + return np.flip(img, axis=1) + + +def prim_flip_v(img): + """Flip image vertically.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(img_gpu, axis=0)) + return np.flip(img, axis=0) + + +def prim_flip(img, direction="horizontal"): + """Flip image in given direction.""" + if direction in ("horizontal", "h"): + return prim_flip_h(img) + elif direction in ("vertical", "v"): + return prim_flip_v(img) + elif direction in ("both", "hv", "vh"): + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(cp.flip(img_gpu, axis=0), axis=1)) + return np.flip(np.flip(img, axis=0), axis=1) + return img + + +# CUDA kernel for ripple effect +if GPU_AVAILABLE: + _ripple_kernel = cp.RawKernel(r''' + extern "C" __global__ + void ripple(const unsigned char* src, unsigned char* dst, + int width, int height, int channels, + float amplitude, float frequency, float decay, + float speed, float time, float cx, float cy) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Distance from center + float dx = x - cx; + float dy = y - cy; + float dist = sqrtf(dx * dx + dy * dy); + + // Ripple displacement + float wave = sinf(dist * frequency * 0.1f - time * speed) * amplitude; + float falloff = expf(-dist * decay * 0.01f); + float displacement = wave * falloff; + + // Direction from center + float len = dist + 0.0001f; // Avoid division by zero + float dir_x = dx / len; + float dir_y = dy / len; + + // Source coordinates + float src_x = x - dir_x * displacement; + float src_y = y - dir_y * displacement; + + // Clamp to bounds + src_x = fmaxf(0.0f, fminf(width - 1.0f, src_x)); + src_y = fmaxf(0.0f, fminf(height - 1.0f, src_y)); + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = min(x0 + 1, width - 1); + int y1 = min(y0 + 1, height - 1); + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float val = v0 * (1 - fy) + v1 * fy; + + dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val)); + } + } + ''', 'ripple') + + +def prim_ripple(img, amplitude=10.0, frequency=8.0, decay=2.0, speed=5.0, + time=0.0, center_x=None, center_y=None): + """Apply ripple distortion effect.""" + h, w = img.shape[:2] + channels = img.shape[2] if img.ndim == 3 else 1 + + if center_x is None: + center_x = w / 2 + if center_y is None: + center_y = h / 2 + + if not GPU_AVAILABLE: + # CPU fallback using coordinate mapping + import cv2 + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx**2 + dy**2) + + wave = np.sin(dist * frequency * 0.1 - time * speed) * amplitude + falloff = np.exp(-dist * decay * 0.01) + displacement = wave * falloff + + length = dist + 0.0001 + dir_x = dx / length + dir_y = dy / length + + map_x = (x_coords - dir_x * displacement).astype(np.float32) + map_y = (y_coords - dir_y * displacement).astype(np.float32) + + return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR) + + # GPU implementation + img_gpu = _to_gpu(img.astype(np.uint8)) + if img_gpu.ndim == 2: + img_gpu = img_gpu[:, :, cp.newaxis] + channels = 1 + + dst = cp.zeros_like(img_gpu) + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _ripple_kernel(grid, block, ( + img_gpu, dst, + np.int32(w), np.int32(h), np.int32(channels), + np.float32(amplitude), np.float32(frequency), np.float32(decay), + np.float32(speed), np.float32(time), + np.float32(center_x), np.float32(center_y) + )) + + result = _to_cpu(dst) + if channels == 1: + result = result[:, :, 0] + return result + + +# CUDA kernel for fast rotation with bilinear interpolation +if GPU_AVAILABLE: + _rotate_kernel = cp.RawKernel(r''' + extern "C" __global__ + void rotate_img(const unsigned char* src, unsigned char* dst, + int width, int height, int channels, + float cos_a, float sin_a, float cx, float cy) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Translate to center, rotate, translate back + float dx = x - cx; + float dy = y - cy; + + float src_x = cos_a * dx + sin_a * dy + cx; + float src_y = -sin_a * dx + cos_a * dy + cy; + + // Check bounds + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[(y * width + x) * channels + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = x0 + 1; + int y1 = y0 + 1; + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float val = v0 * (1 - fy) + v1 * fy; + + dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val)); + } + } + ''', 'rotate_img') + + +def prim_rotate_gpu(img, angle, cx=None, cy=None): + """Fast GPU rotation using custom CUDA kernel.""" + if not GPU_AVAILABLE: + return prim_rotate(img, angle, cx, cy) + + h, w = img.shape[:2] + channels = img.shape[2] if img.ndim == 3 else 1 + + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + img_gpu = _to_gpu(img.astype(np.uint8)) + if img_gpu.ndim == 2: + img_gpu = img_gpu[:, :, cp.newaxis] + channels = 1 + + dst = cp.zeros_like(img_gpu) + + # Convert angle to radians + rad = np.radians(angle) + cos_a = np.cos(rad) + sin_a = np.sin(rad) + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _rotate_kernel(grid, block, ( + img_gpu, dst, + np.int32(w), np.int32(h), np.int32(channels), + np.float32(cos_a), np.float32(sin_a), + np.float32(cx), np.float32(cy) + )) + + result = _to_cpu(dst) + if channels == 1: + result = result[:, :, 0] + return result + + +# Import CPU primitives as fallbacks for functions we don't GPU-accelerate +def _get_cpu_primitives(): + """Get all primitives from CPU geometry module as fallbacks.""" + from sexp_effects.primitive_libs import geometry + return geometry.PRIMITIVES + + +# Export functions - start with CPU primitives, then override with GPU versions +PRIMITIVES = _get_cpu_primitives().copy() + +# Override specific primitives with GPU-accelerated versions +PRIMITIVES.update({ + 'translate': prim_translate, + 'rotate': prim_rotate_gpu if GPU_AVAILABLE else prim_rotate, # Fast CUDA kernel + 'rotate-img': prim_rotate_gpu if GPU_AVAILABLE else prim_rotate, # Alias + 'scale-img': prim_scale, + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'flip': prim_flip, + 'ripple': prim_ripple, # Fast CUDA kernel + # Note: ripple-displace uses CPU version (different API - returns coords, not image) +}) diff --git a/l1/sexp_effects/primitive_libs/image.py b/l1/sexp_effects/primitive_libs/image.py new file mode 100644 index 0000000..2ab922c --- /dev/null +++ b/l1/sexp_effects/primitive_libs/image.py @@ -0,0 +1,150 @@ +""" +Image Primitives Library + +Basic image operations: dimensions, pixels, resize, crop, paste. +""" +import numpy as np +import cv2 + + +def prim_width(img): + if isinstance(img, (list, tuple)): + raise TypeError(f"image:width expects an image array, got {type(img).__name__} with {len(img)} elements") + return img.shape[1] + + +def prim_height(img): + if isinstance(img, (list, tuple)): + import sys + print(f"DEBUG image:height got list: {img[:3]}... (types: {[type(x).__name__ for x in img[:3]]})", file=sys.stderr) + raise TypeError(f"image:height expects an image array, got {type(img).__name__} with {len(img)} elements: {img}") + return img.shape[0] + + +def prim_make_image(w, h, color=None): + """Create a new image filled with color (default black).""" + if color is None: + color = [0, 0, 0] + img = np.zeros((h, w, 3), dtype=np.uint8) + img[:] = color + return img + + +def prim_copy(img): + return img.copy() + + +def prim_pixel(img, x, y): + """Get pixel color at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + if 0 <= x < w and 0 <= y < h: + return list(img[int(y), int(x)]) + return [0, 0, 0] + + +def prim_set_pixel(img, x, y, color): + """Set pixel at (x, y) to color, returns modified image.""" + result = img.copy() + h, w = result.shape[:2] + if 0 <= x < w and 0 <= y < h: + result[int(y), int(x)] = color + return result + + +def prim_sample(img, x, y): + """Bilinear sample at float coordinates, returns [r, g, b] as floats.""" + h, w = img.shape[:2] + x = max(0, min(w - 1.001, x)) + y = max(0, min(h - 1.001, y)) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + top = c00 * (1 - fx) + c10 * fx + bottom = c01 * (1 - fx) + c11 * fx + return list(top * (1 - fy) + bottom * fy) + + +def prim_channel(img, c): + """Extract single channel (0=R, 1=G, 2=B).""" + return img[:, :, c] + + +def prim_merge_channels(r, g, b): + """Merge three single-channel arrays into RGB image.""" + return np.stack([r, g, b], axis=2).astype(np.uint8) + + +def prim_resize(img, w, h, mode="linear"): + """Resize image to w x h.""" + interp = cv2.INTER_LINEAR + if mode == "nearest": + interp = cv2.INTER_NEAREST + elif mode == "cubic": + interp = cv2.INTER_CUBIC + elif mode == "area": + interp = cv2.INTER_AREA + return cv2.resize(img, (int(w), int(h)), interpolation=interp) + + +def prim_crop(img, x, y, w, h): + """Crop rectangle from image.""" + x, y, w, h = int(x), int(y), int(w), int(h) + ih, iw = img.shape[:2] + x = max(0, min(x, iw - 1)) + y = max(0, min(y, ih - 1)) + w = min(w, iw - x) + h = min(h, ih - y) + return img[y:y+h, x:x+w].copy() + + +def prim_paste(dst, src, x, y): + """Paste src onto dst at position (x, y).""" + result = dst.copy() + x, y = int(x), int(y) + sh, sw = src.shape[:2] + dh, dw = dst.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + if sx2 > sx1 and sy2 > sy1: + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = src[sy1:sy2, sx1:sx2] + + return result + + +PRIMITIVES = { + # Dimensions + 'width': prim_width, + 'height': prim_height, + + # Creation + 'make-image': prim_make_image, + 'copy': prim_copy, + + # Pixel access + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + + # Channels + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + + # Geometry + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, +} diff --git a/l1/sexp_effects/primitive_libs/math.py b/l1/sexp_effects/primitive_libs/math.py new file mode 100644 index 0000000..140ad3e --- /dev/null +++ b/l1/sexp_effects/primitive_libs/math.py @@ -0,0 +1,164 @@ +""" +Math Primitives Library + +Trigonometry, rounding, clamping, random numbers, etc. +""" +import math +import random as rand_module + + +def prim_sin(x): + return math.sin(x) + + +def prim_cos(x): + return math.cos(x) + + +def prim_tan(x): + return math.tan(x) + + +def prim_asin(x): + return math.asin(x) + + +def prim_acos(x): + return math.acos(x) + + +def prim_atan(x): + return math.atan(x) + + +def prim_atan2(y, x): + return math.atan2(y, x) + + +def prim_sqrt(x): + return math.sqrt(x) + + +def prim_pow(x, y): + return math.pow(x, y) + + +def prim_exp(x): + return math.exp(x) + + +def prim_log(x, base=None): + if base is None: + return math.log(x) + return math.log(x, base) + + +def prim_abs(x): + return abs(x) + + +def prim_floor(x): + return math.floor(x) + + +def prim_ceil(x): + return math.ceil(x) + + +def prim_round(x): + return round(x) + + +def prim_min(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return min(args[0]) + return min(args) + + +def prim_max(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return max(args[0]) + return max(args) + + +def prim_clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +def prim_lerp(a, b, t): + """Linear interpolation: a + (b - a) * t""" + return a + (b - a) * t + + +def prim_smoothstep(edge0, edge1, x): + """Smooth interpolation between 0 and 1.""" + t = prim_clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +def prim_random(lo=0.0, hi=1.0): + return rand_module.uniform(lo, hi) + + +def prim_randint(lo, hi): + return rand_module.randint(lo, hi) + + +def prim_gaussian(mean=0.0, std=1.0): + return rand_module.gauss(mean, std) + + +def prim_sign(x): + if x > 0: + return 1 + elif x < 0: + return -1 + return 0 + + +def prim_fract(x): + """Fractional part of x.""" + return x - math.floor(x) + + +PRIMITIVES = { + # Trigonometry + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'asin': prim_asin, + 'acos': prim_acos, + 'atan': prim_atan, + 'atan2': prim_atan2, + + # Powers and roots + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'exp': prim_exp, + 'log': prim_log, + + # Rounding + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'sign': prim_sign, + 'fract': prim_fract, + + # Min/max/clamp + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'smoothstep': prim_smoothstep, + + # Random + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + + # Constants + 'pi': math.pi, + 'tau': math.tau, + 'e': math.e, +} diff --git a/l1/sexp_effects/primitive_libs/streaming.py b/l1/sexp_effects/primitive_libs/streaming.py new file mode 100644 index 0000000..ccb6056 --- /dev/null +++ b/l1/sexp_effects/primitive_libs/streaming.py @@ -0,0 +1,593 @@ +""" +Streaming primitives for video/audio processing. + +These primitives handle video source reading and audio analysis, +keeping the interpreter completely generic. + +GPU Acceleration: +- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU) +- Hardware video decoding (NVDEC) is used when available +- Dramatically improves performance on GPU nodes + +Async Prefetching: +- Set STREAMING_PREFETCH=1 to enable background frame prefetching +- Decodes upcoming frames while current frame is being processed +""" + +import os +import numpy as np +import subprocess +import json +import threading +from collections import deque +from pathlib import Path + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + +# GPU persistence mode - output CuPy arrays instead of numpy +# Disabled by default until all primitives support GPU frames +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE + +# Async prefetch mode - decode frames in background thread +PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1" +PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10")) + +# Check for hardware decode support (cached) +_HWDEC_AVAILABLE = None + + +def _check_hwdec(): + """Check if NVIDIA hardware decode is available.""" + global _HWDEC_AVAILABLE + if _HWDEC_AVAILABLE is not None: + return _HWDEC_AVAILABLE + + try: + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2) + if result.returncode != 0: + _HWDEC_AVAILABLE = False + return False + result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5) + _HWDEC_AVAILABLE = "cuda" in result.stdout + except Exception: + _HWDEC_AVAILABLE = False + + return _HWDEC_AVAILABLE + + +class VideoSource: + """Video source with persistent streaming pipe for fast sequential reads.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + self.fps = fps # Output fps for the stream + self._frame_size = None + self._duration = None + self._proc = None # Persistent ffmpeg process + self._stream_time = 0.0 # Current position in stream + self._frame_time = 1.0 / fps # Time per frame at output fps + self._last_read_time = -1 + self._cached_frame = None + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + # Get video info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}") + try: + info = json.loads(result.stdout) + except json.JSONDecodeError: + raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}") + + for stream in info.get("streams", []): + if stream.get("codec_type") == "video": + self._frame_size = (stream.get("width", 720), stream.get("height", 720)) + # Try direct duration field first + if "duration" in stream: + self._duration = float(stream["duration"]) + # Fall back to tags.DURATION (webm format: "00:01:00.124000000") + elif "tags" in stream and "DURATION" in stream["tags"]: + dur_str = stream["tags"]["DURATION"] + parts = dur_str.split(":") + if len(parts) == 3: + h, m, s = parts + self._duration = int(h) * 3600 + int(m) * 60 + float(s) + break + + # Fallback: check format duration if stream duration not found + if self._duration is None and "format" in info and "duration" in info["format"]: + self._duration = float(info["format"]["duration"]) + + if not self._frame_size: + self._frame_size = (720, 720) + + import sys + print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr) + + def _start_stream(self, seek_time: float = 0): + """Start or restart the ffmpeg streaming process. + + Uses NVIDIA hardware decoding (NVDEC) when available for better performance. + """ + if self._proc: + self._proc.kill() + self._proc = None + + # Check file exists before trying to open + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + w, h = self._frame_size + + # Build ffmpeg command with optional hardware decode + cmd = ["ffmpeg", "-v", "error"] + + # Use hardware decode if available (significantly faster) + if _check_hwdec(): + cmd.extend(["-hwaccel", "cuda"]) + + cmd.extend([ + "-ss", f"{seek_time:.3f}", + "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(self.fps), # Output at specified fps + "-" + ]) + + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._stream_time = seek_time + + # Check if process started successfully by reading first bit of stderr + import select + import sys + readable, _, _ = select.select([self._proc.stderr], [], [], 0.5) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr) + + def _read_frame_from_stream(self): + """Read one frame from the stream. + + Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise. + """ + w, h = self._frame_size + frame_size = w * h * 3 + + if not self._proc or self._proc.poll() is not None: + return None + + data = self._proc.stdout.read(frame_size) + if len(data) < frame_size: + return None + + frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() + + # Transfer to GPU if persistence mode enabled + if GPU_PERSIST: + return cp.asarray(frame) + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + if self._cached_frame is not None: + return self._cached_frame + return self.read_at(0) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time using streaming with smart seeking.""" + # Cache check - return same frame for same time + if t == self._last_read_time and self._cached_frame is not None: + return self._cached_frame + + w, h = self._frame_size + + # Loop time if video is shorter + seek_time = t + if self._duration and self._duration > 0: + seek_time = t % self._duration + # If we're within 0.1s of the end, wrap to beginning to avoid EOF issues + if seek_time > self._duration - 0.1: + seek_time = 0.0 + + # Decide whether to seek or continue streaming + # Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead + # Allow small backward tolerance to handle floating point and timing jitter + need_seek = ( + self._proc is None or + self._proc.poll() is not None or + seek_time < self._stream_time - self._frame_time or # More than 1 frame backward + seek_time > self._stream_time + 2.0 + ) + + if need_seek: + import sys + reason = "no proc" if self._proc is None else "proc dead" if self._proc.poll() is not None else "backward" if seek_time < self._stream_time else "jump" + print(f"SEEK {self.path.name}: t={t:.4f} seek={seek_time:.4f} stream={self._stream_time:.4f} ({reason})", file=sys.stderr) + self._start_stream(seek_time) + + # Skip frames to reach target time + skip_retries = 0 + while self._stream_time + self._frame_time <= seek_time: + frame = self._read_frame_from_stream() + if frame is None: + # Stream ended or failed - restart from seek point + import time + skip_retries += 1 + if skip_retries > 3: + # Give up skipping, just start fresh at seek_time + self._start_stream(seek_time) + time.sleep(0.1) + break + self._start_stream(seek_time) + time.sleep(0.05) + continue + self._stream_time += self._frame_time + skip_retries = 0 # Reset on successful read + + # Read the target frame with retry logic + frame = None + max_retries = 3 + for attempt in range(max_retries): + frame = self._read_frame_from_stream() + if frame is not None: + break + + # Stream failed - try restarting + import sys + import time + print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr) + + # Check for ffmpeg errors + if self._proc and self._proc.stderr: + try: + import select + readable, _, _ = select.select([self._proc.stderr], [], [], 0.1) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error: {err}", file=sys.stderr) + except: + pass + + # Wait a bit and restart + time.sleep(0.1) + self._start_stream(seek_time) + + # Give ffmpeg time to start + time.sleep(0.1) + + if frame is None: + import sys + raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries") + else: + self._stream_time += self._frame_time + + self._last_read_time = t + self._cached_frame = frame + return frame + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._frame_size + + def close(self): + if self._proc: + self._proc.kill() + self._proc = None + + +class PrefetchingVideoSource: + """ + Video source with background prefetching for improved performance. + + Wraps VideoSource and adds a background thread that pre-decodes + upcoming frames while the main thread processes the current frame. + """ + + def __init__(self, path: str, fps: float = 30, buffer_size: int = None): + self._source = VideoSource(path, fps) + self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE + self._buffer = {} # time -> frame + self._buffer_lock = threading.Lock() + self._prefetch_time = 0.0 + self._frame_time = 1.0 / fps + self._stop_event = threading.Event() + self._request_event = threading.Event() + self._target_time = 0.0 + + # Start prefetch thread + self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) + self._thread.start() + + import sys + print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr) + + def _prefetch_loop(self): + """Background thread that pre-reads frames.""" + while not self._stop_event.is_set(): + # Wait for work or timeout + self._request_event.wait(timeout=0.01) + self._request_event.clear() + + if self._stop_event.is_set(): + break + + # Prefetch frames ahead of target time + target = self._target_time + with self._buffer_lock: + # Clean old frames (more than 1 second behind) + old_times = [t for t in self._buffer.keys() if t < target - 1.0] + for t in old_times: + del self._buffer[t] + + # Count how many frames we have buffered ahead + buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target) + + # Prefetch if buffer not full + if buffered_ahead < self._buffer_size: + # Find next time to prefetch + prefetch_t = target + with self._buffer_lock: + existing_times = set(self._buffer.keys()) + for _ in range(self._buffer_size): + if prefetch_t not in existing_times: + break + prefetch_t += self._frame_time + + # Read the frame (this is the slow part) + try: + frame = self._source.read_at(prefetch_t) + with self._buffer_lock: + self._buffer[prefetch_t] = frame + except Exception as e: + import sys + print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time, using prefetch buffer if available.""" + self._target_time = t + self._request_event.set() # Wake up prefetch thread + + # Round to frame time for buffer lookup + t_key = round(t / self._frame_time) * self._frame_time + + # Check buffer first + with self._buffer_lock: + if t_key in self._buffer: + return self._buffer[t_key] + # Also check for close matches (within half frame time) + for buf_t, frame in self._buffer.items(): + if abs(buf_t - t) < self._frame_time * 0.5: + return frame + + # Not in buffer - read directly (blocking) + frame = self._source.read_at(t) + + # Store in buffer + with self._buffer_lock: + self._buffer[t_key] = frame + + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + return self.read_at(0) + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._source.size + + @property + def path(self): + return self._source.path + + def close(self): + self._stop_event.set() + self._request_event.set() # Wake up thread to exit + self._thread.join(timeout=1.0) + self._source.close() + + +class AudioAnalyzer: + """Audio analyzer for energy and beat detection.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + self.sample_rate = sample_rate + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Audio file not found: {self.path}") + + # Load audio via ffmpeg + cmd = ["ffmpeg", "-v", "error", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}") + self._audio = np.frombuffer(result.stdout, dtype=np.float32) + if len(self._audio) == 0: + raise RuntimeError(f"Audio file is empty or invalid: {self.path}") + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}") + info = json.loads(result.stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + # Beat detection state + self._flux_history = [] + self._last_beat_time = -1 + self._beat_count = 0 + self._last_beat_check_time = -1 + # Cache beat result for current time (so multiple scans see same result) + self._beat_cache_time = -1 + self._beat_cache_result = False + + def get_energy(self, t: float) -> float: + """Get energy level at time t (0-1).""" + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + """Check if there's a beat at time t.""" + # Return cached result if same time (multiple scans query same frame) + if t == self._beat_cache_time: + return self._beat_cache_result + + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size/2: + self._beat_cache_time = t + self._beat_cache_result = False + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + self._beat_cache_time = t + self._beat_cache_result = False + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + if len(self._flux_history) > 50: + self._flux_history = self._flux_history[-50:] + + if len(self._flux_history) < 5: + self._beat_cache_time = t + self._beat_cache_result = False + return False + + recent = [f for _, f in self._flux_history[-20:]] + threshold = np.mean(recent) + 1.5 * np.std(recent) + + is_beat = flux > threshold and (t - self._last_beat_time) > 0.1 + if is_beat: + self._last_beat_time = t + if t > self._last_beat_check_time: + self._beat_count += 1 + self._last_beat_check_time = t + + # Cache result for this time + self._beat_cache_time = t + self._beat_cache_result = is_beat + return is_beat + + def get_beat_count(self, t: float) -> int: + """Get cumulative beat count up to time t.""" + # Ensure beat detection has run up to this time + self.get_beat(t) + return self._beat_count + + +# === Primitives === + +def prim_make_video_source(path: str, fps: float = 30): + """Create a video source from a file path. + + Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default). + """ + if PREFETCH_ENABLED: + return PrefetchingVideoSource(path, fps) + return VideoSource(path, fps) + + +def prim_source_read(source: VideoSource, t: float = None): + """Read a frame from a video source.""" + import sys + if t is not None: + frame = source.read_at(t) + # Debug: show source and time + if int(t * 10) % 10 == 0: # Every second + print(f"READ {source.path.name}: t={t:.2f} stream={source._stream_time:.2f}", file=sys.stderr) + return frame + return source.read() + + +def prim_source_skip(source: VideoSource): + """Skip a frame (keep pipe in sync).""" + source.skip() + + +def prim_source_size(source: VideoSource): + """Get (width, height) of source.""" + return source.size + + +def prim_make_audio_analyzer(path: str): + """Create an audio analyzer from a file path.""" + return AudioAnalyzer(path) + + +def prim_audio_energy(analyzer: AudioAnalyzer, t: float) -> float: + """Get energy level (0-1) at time t.""" + return analyzer.get_energy(t) + + +def prim_audio_beat(analyzer: AudioAnalyzer, t: float) -> bool: + """Check if there's a beat at time t.""" + return analyzer.get_beat(t) + + +def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int: + """Get cumulative beat count up to time t.""" + return analyzer.get_beat_count(t) + + +def prim_audio_duration(analyzer: AudioAnalyzer) -> float: + """Get audio duration in seconds.""" + return analyzer.duration + + +# Export primitives +PRIMITIVES = { + # Video source + 'make-video-source': prim_make_video_source, + 'source-read': prim_source_read, + 'source-skip': prim_source_skip, + 'source-size': prim_source_size, + + # Audio analyzer + 'make-audio-analyzer': prim_make_audio_analyzer, + 'audio-energy': prim_audio_energy, + 'audio-beat': prim_audio_beat, + 'audio-beat-count': prim_audio_beat_count, + 'audio-duration': prim_audio_duration, +} diff --git a/l1/sexp_effects/primitive_libs/streaming_gpu.py b/l1/sexp_effects/primitive_libs/streaming_gpu.py new file mode 100644 index 0000000..f2aa7ea --- /dev/null +++ b/l1/sexp_effects/primitive_libs/streaming_gpu.py @@ -0,0 +1,1165 @@ +""" +GPU-Accelerated Streaming Primitives + +Provides GPU-native video source and frame processing. +Frames stay on GPU memory throughout the pipeline for maximum performance. + +Architecture: +- GPUFrame: Wrapper that tracks whether data is on CPU or GPU +- GPUVideoSource: Hardware-accelerated decode to GPU memory +- GPU primitives operate directly on GPU frames using fast CUDA kernels +- Transfer to CPU only at final output + +Requirements: +- CuPy for CUDA support +- FFmpeg with NVDEC support (for hardware decode) +- NVIDIA GPU with CUDA capability +""" + +import os +import sys +import json +import subprocess +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, Union + +# Try to import CuPy +try: + import cupy as cp + GPU_AVAILABLE = True +except ImportError: + cp = None + GPU_AVAILABLE = False + +# Try to import fast CUDA kernels from JIT compiler +_FAST_KERNELS_AVAILABLE = False +try: + if GPU_AVAILABLE: + from streaming.jit_compiler import ( + fast_rotate, fast_zoom, fast_blend, fast_hue_shift, + fast_invert, fast_ripple, get_fast_ops + ) + _FAST_KERNELS_AVAILABLE = True + print("[streaming_gpu] Fast CUDA kernels loaded", file=sys.stderr) +except ImportError as e: + print(f"[streaming_gpu] Fast kernels not available: {e}", file=sys.stderr) + +# Check for hardware decode support +_HWDEC_AVAILABLE: Optional[bool] = None +_DECORD_GPU_AVAILABLE: Optional[bool] = None + + +def check_hwdec_available() -> bool: + """Check if NVIDIA hardware decode is available.""" + global _HWDEC_AVAILABLE + if _HWDEC_AVAILABLE is not None: + return _HWDEC_AVAILABLE + + try: + # Check for nvidia-smi (GPU present) + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2) + if result.returncode != 0: + _HWDEC_AVAILABLE = False + return False + + # Check for nvdec in ffmpeg + result = subprocess.run( + ["ffmpeg", "-hwaccels"], + capture_output=True, + text=True, + timeout=5 + ) + _HWDEC_AVAILABLE = "cuda" in result.stdout + except Exception: + _HWDEC_AVAILABLE = False + + return _HWDEC_AVAILABLE + + +def check_decord_gpu_available() -> bool: + """Check if decord with CUDA GPU decode is available.""" + global _DECORD_GPU_AVAILABLE + if _DECORD_GPU_AVAILABLE is not None: + return _DECORD_GPU_AVAILABLE + + try: + import decord + from decord import gpu + # Try to create a GPU context to verify CUDA support + ctx = gpu(0) + _DECORD_GPU_AVAILABLE = True + print("[streaming_gpu] decord GPU (CUDA) decode available", file=sys.stderr) + except Exception as e: + _DECORD_GPU_AVAILABLE = False + print(f"[streaming_gpu] decord GPU not available: {e}", file=sys.stderr) + + return _DECORD_GPU_AVAILABLE + + +class GPUFrame: + """ + Frame container that tracks data location (CPU/GPU). + + Enables zero-copy operations when data is already on the right device. + Lazy transfer - only moves data when actually needed. + """ + + def __init__(self, data: Union[np.ndarray, 'cp.ndarray'], on_gpu: bool = None): + self._cpu_data: Optional[np.ndarray] = None + self._gpu_data = None # Optional[cp.ndarray] + + if on_gpu is None: + # Auto-detect based on type + if GPU_AVAILABLE and isinstance(data, cp.ndarray): + self._gpu_data = data + else: + self._cpu_data = np.asarray(data) + elif on_gpu and GPU_AVAILABLE: + self._gpu_data = cp.asarray(data) if not isinstance(data, cp.ndarray) else data + else: + self._cpu_data = np.asarray(data) if isinstance(data, np.ndarray) else cp.asnumpy(data) + + @property + def cpu(self) -> np.ndarray: + """Get frame as numpy array (transfers from GPU if needed).""" + if self._cpu_data is None: + if self._gpu_data is not None and GPU_AVAILABLE: + self._cpu_data = cp.asnumpy(self._gpu_data) + else: + raise ValueError("No frame data available") + return self._cpu_data + + @property + def gpu(self): + """Get frame as CuPy array (transfers to GPU if needed).""" + if not GPU_AVAILABLE: + raise RuntimeError("GPU not available") + if self._gpu_data is None: + if self._cpu_data is not None: + self._gpu_data = cp.asarray(self._cpu_data) + else: + raise ValueError("No frame data available") + return self._gpu_data + + @property + def is_on_gpu(self) -> bool: + """Check if data is currently on GPU.""" + return self._gpu_data is not None + + @property + def shape(self) -> Tuple[int, ...]: + """Get frame shape.""" + if self._gpu_data is not None: + return self._gpu_data.shape + return self._cpu_data.shape + + @property + def dtype(self): + """Get frame dtype.""" + if self._gpu_data is not None: + return self._gpu_data.dtype + return self._cpu_data.dtype + + def numpy(self) -> np.ndarray: + """Alias for cpu property.""" + return self.cpu + + def cupy(self): + """Alias for gpu property.""" + return self.gpu + + def free_cpu(self): + """Free CPU memory (keep GPU only).""" + if self._gpu_data is not None: + self._cpu_data = None + + def free_gpu(self): + """Free GPU memory (keep CPU only).""" + if self._cpu_data is not None: + self._gpu_data = None + + +class GPUVideoSource: + """ + GPU-accelerated video source using hardware decode. + + Uses decord with CUDA GPU context for true NVDEC decode - frames + decode directly to GPU memory via CUDA. + + Falls back to FFmpeg pipe if decord GPU unavailable (slower due to CPU copy). + """ + + def __init__(self, path: str, fps: float = 30, prefer_gpu: bool = True): + self.path = Path(path) + self.fps = fps + self.prefer_gpu = prefer_gpu and GPU_AVAILABLE + self._use_decord_gpu = self.prefer_gpu and check_decord_gpu_available() + + self._frame_size: Optional[Tuple[int, int]] = None + self._duration: Optional[float] = None + self._video_fps: float = 30.0 + self._total_frames: int = 0 + self._frame_time = 1.0 / fps + self._last_read_time = -1 + self._cached_frame: Optional[GPUFrame] = None + + # Decord VideoReader with GPU context + self._vr = None + self._decord_ctx = None + + # FFmpeg fallback state + self._proc = None + self._stream_time = 0.0 + + # Initialize video source + self._init_video() + + mode = "decord-GPU" if self._use_decord_gpu else ("ffmpeg-hwaccel" if check_hwdec_available() else "ffmpeg-CPU") + print(f"[GPUVideoSource] {self.path.name}: {self._frame_size}, " + f"duration={self._duration:.1f}s, mode={mode}", file=sys.stderr) + + def _init_video(self): + """Initialize video reader (decord GPU or probe for ffmpeg).""" + if self._use_decord_gpu: + try: + from decord import VideoReader, gpu + + # Use GPU context for NVDEC hardware decode + self._decord_ctx = gpu(0) + self._vr = VideoReader(str(self.path), ctx=self._decord_ctx, num_threads=1) + + self._total_frames = len(self._vr) + self._video_fps = self._vr.get_avg_fps() + self._duration = self._total_frames / self._video_fps + + # Get frame size from first frame + first_frame = self._vr[0] + self._frame_size = (first_frame.shape[1], first_frame.shape[0]) + + print(f"[GPUVideoSource] decord GPU initialized: {self._frame_size}, " + f"{self._total_frames} frames @ {self._video_fps:.1f}fps", file=sys.stderr) + return + except Exception as e: + print(f"[GPUVideoSource] decord GPU init failed, falling back to ffmpeg: {e}", file=sys.stderr) + self._use_decord_gpu = False + self._vr = None + self._decord_ctx = None + + # FFmpeg fallback - probe video for metadata + self._probe_video() + + def _probe_video(self): + """Probe video file for metadata (FFmpeg fallback).""" + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + info = json.loads(result.stdout) + + for stream in info.get("streams", []): + if stream.get("codec_type") == "video": + self._frame_size = (stream.get("width", 720), stream.get("height", 720)) + if "duration" in stream: + self._duration = float(stream["duration"]) + elif "tags" in stream and "DURATION" in stream["tags"]: + dur_str = stream["tags"]["DURATION"] + parts = dur_str.split(":") + if len(parts) == 3: + h, m, s = parts + self._duration = int(h) * 3600 + int(m) * 60 + float(s) + # Get fps + if "r_frame_rate" in stream: + fps_str = stream["r_frame_rate"] + if "/" in fps_str: + num, den = fps_str.split("/") + self._video_fps = float(num) / float(den) + break + + if self._duration is None and "format" in info: + if "duration" in info["format"]: + self._duration = float(info["format"]["duration"]) + + if not self._frame_size: + self._frame_size = (720, 720) + if not self._duration: + self._duration = 60.0 + + self._total_frames = int(self._duration * self._video_fps) + + def _start_stream(self, seek_time: float = 0): + """Start ffmpeg decode process (fallback mode).""" + if self._proc: + self._proc.kill() + self._proc = None + + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + w, h = self._frame_size + + # Build ffmpeg command + cmd = ["ffmpeg", "-v", "error"] + + # Hardware decode if available + if check_hwdec_available(): + cmd.extend(["-hwaccel", "cuda"]) + + cmd.extend([ + "-ss", f"{seek_time:.3f}", + "-i", str(self.path), + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(self.fps), + "-" + ]) + + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._stream_time = seek_time + + def _read_frame_raw(self) -> Optional[np.ndarray]: + """Read one frame from ffmpeg pipe (fallback mode).""" + w, h = self._frame_size + frame_size = w * h * 3 + + if not self._proc or self._proc.poll() is not None: + return None + + data = self._proc.stdout.read(frame_size) + if len(data) < frame_size: + return None + + return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() + + def _read_frame_decord_gpu(self, frame_idx: int) -> Optional[GPUFrame]: + """Read frame using decord with GPU context (NVDEC, zero-copy to CuPy).""" + if self._vr is None: + return None + + try: + # Handle looping + frame_idx = frame_idx % max(1, self._total_frames) + + # Decode frame - with GPU context, this uses NVDEC + frame_tensor = self._vr[frame_idx] + + # Convert to CuPy via DLPack (zero-copy GPU transfer) + if GPU_AVAILABLE: + # decord tensors have .to_dlpack() which returns a PyCapsule + # that CuPy can consume for zero-copy GPU transfer + try: + dlpack_capsule = frame_tensor.to_dlpack() + gpu_frame = cp.from_dlpack(dlpack_capsule) + # Log success once per source + if not getattr(self, '_dlpack_logged', False): + print(f"[GPUVideoSource] DLPack zero-copy SUCCESS - frames stay on GPU", file=sys.stderr) + self._dlpack_logged = True + return GPUFrame(gpu_frame, on_gpu=True) + except Exception as dlpack_err: + # Fallback: convert via numpy (involves CPU copy) + if not getattr(self, '_dlpack_fail_logged', False): + print(f"[GPUVideoSource] DLPack FAILED ({dlpack_err}), using CPU copy fallback", file=sys.stderr) + self._dlpack_fail_logged = True + frame_np = frame_tensor.asnumpy() + return GPUFrame(frame_np, on_gpu=True) + else: + return GPUFrame(frame_tensor.asnumpy(), on_gpu=False) + + except Exception as e: + print(f"[GPUVideoSource] decord GPU read error at frame {frame_idx}: {e}", file=sys.stderr) + return None + + def read_at(self, t: float) -> Optional[GPUFrame]: + """ + Read frame at specific time. + + Returns GPUFrame with data on GPU if GPU mode enabled. + """ + # Cache check + if t == self._last_read_time and self._cached_frame is not None: + return self._cached_frame + + # Loop time for shorter videos + seek_time = t + if self._duration and self._duration > 0: + seek_time = t % self._duration + if seek_time > self._duration - 0.1: + seek_time = 0.0 + + self._last_read_time = t + + # Use decord GPU if available (NVDEC decode, zero-copy via DLPack) + if self._use_decord_gpu: + frame_idx = int(seek_time * self._video_fps) + self._cached_frame = self._read_frame_decord_gpu(frame_idx) + if self._cached_frame is not None: + # Free CPU copy if on GPU (saves memory) + if self.prefer_gpu and self._cached_frame.is_on_gpu: + self._cached_frame.free_cpu() + return self._cached_frame + + # FFmpeg fallback + need_seek = ( + self._proc is None or + self._proc.poll() is not None or + seek_time < self._stream_time - self._frame_time or + seek_time > self._stream_time + 2.0 + ) + + if need_seek: + self._start_stream(seek_time) + + # Skip frames to reach target + while self._stream_time + self._frame_time <= seek_time: + frame = self._read_frame_raw() + if frame is None: + self._start_stream(seek_time) + break + self._stream_time += self._frame_time + + # Read target frame + frame_np = self._read_frame_raw() + if frame_np is None: + return self._cached_frame + + self._stream_time += self._frame_time + + # Create GPUFrame - transfer to GPU if in GPU mode + self._cached_frame = GPUFrame(frame_np, on_gpu=self.prefer_gpu) + + # Free CPU copy if on GPU (saves memory) + if self.prefer_gpu and self._cached_frame.is_on_gpu: + self._cached_frame.free_cpu() + + return self._cached_frame + + def read(self) -> Optional[GPUFrame]: + """Read current frame.""" + if self._cached_frame is not None: + return self._cached_frame + return self.read_at(0) + + @property + def size(self) -> Tuple[int, int]: + return self._frame_size + + @property + def duration(self) -> float: + return self._duration + + def close(self): + """Close the video source.""" + if self._proc: + self._proc.kill() + self._proc = None + # Release decord resources + self._vr = None + self._decord_ctx = None + + +# GPU-aware primitive functions + +def gpu_blend(frame_a: GPUFrame, frame_b: GPUFrame, alpha: float = 0.5) -> GPUFrame: + """ + Blend two frames on GPU using fast CUDA kernel. + + Both frames stay on GPU throughout - no CPU transfer. + """ + if not GPU_AVAILABLE: + a = frame_a.cpu.astype(np.float32) + b = frame_b.cpu.astype(np.float32) + result = (a * alpha + b * (1 - alpha)).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + # Use fast CUDA kernel + if _FAST_KERNELS_AVAILABLE: + a_gpu = frame_a.gpu + b_gpu = frame_b.gpu + if a_gpu.dtype != cp.uint8: + a_gpu = cp.clip(a_gpu, 0, 255).astype(cp.uint8) + if b_gpu.dtype != cp.uint8: + b_gpu = cp.clip(b_gpu, 0, 255).astype(cp.uint8) + result = fast_blend(a_gpu, b_gpu, alpha) + return GPUFrame(result, on_gpu=True) + + # Fallback + a = frame_a.gpu.astype(cp.float32) + b = frame_b.gpu.astype(cp.float32) + result = (a * alpha + b * (1 - alpha)).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_resize(frame: GPUFrame, size: Tuple[int, int]) -> GPUFrame: + """Resize frame on GPU using fast CUDA zoom kernel.""" + import cv2 + + if not GPU_AVAILABLE or not frame.is_on_gpu: + resized = cv2.resize(frame.cpu, size) + return GPUFrame(resized, on_gpu=False) + + gpu_data = frame.gpu + h, w = gpu_data.shape[:2] + target_w, target_h = size + + # Use fast zoom kernel if same aspect ratio (pure zoom) + if _FAST_KERNELS_AVAILABLE and target_w == target_h == w == h: + # For uniform zoom we can use the zoom kernel + pass # Fall through to scipy for now - full resize needs different approach + + # CuPy doesn't have built-in resize, use scipy zoom + from cupyx.scipy import ndimage as cpndimage + + zoom_y = target_h / h + zoom_x = target_w / w + + if gpu_data.ndim == 3: + resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x, 1), order=1) + else: + resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x), order=1) + + return GPUFrame(resized, on_gpu=True) + + +def gpu_zoom(frame: GPUFrame, factor: float, cx: float = None, cy: float = None) -> GPUFrame: + """Zoom frame on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + h, w = frame.cpu.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), 0, factor) + zoomed = cv2.warpAffine(frame.cpu, M, (w, h)) + return GPUFrame(zoomed, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + zoomed = fast_zoom(frame.gpu, factor, cx=cx, cy=cy) + return GPUFrame(zoomed, on_gpu=True) + + # Fallback - basic zoom via slice and resize + return frame + + +def gpu_hue_shift(frame: GPUFrame, degrees: float) -> GPUFrame: + """Shift hue on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + hsv = cv2.cvtColor(frame.cpu, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.float32) + degrees / 2) % 180 + result = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + shifted = fast_hue_shift(gpu_data, degrees) + return GPUFrame(shifted, on_gpu=True) + + # Fallback - no GPU hue shift without fast kernels + return frame + + +def gpu_invert(frame: GPUFrame) -> GPUFrame: + """Invert colors on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = 255 - frame.cpu + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + inverted = fast_invert(gpu_data) + return GPUFrame(inverted, on_gpu=True) + + # Fallback - basic CuPy invert + result = 255 - frame.gpu + return GPUFrame(result, on_gpu=True) + + +def gpu_ripple(frame: GPUFrame, amplitude: float, frequency: float = 8, + decay: float = 2, phase: float = 0, + cx: float = None, cy: float = None) -> GPUFrame: + """Apply ripple effect on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + return frame # No CPU fallback for ripple + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + rippled = fast_ripple( + gpu_data, amplitude, + center_x=cx if cx else w/2, + center_y=cy if cy else h/2, + frequency=frequency, + decay=decay, + speed=1.0, + t=phase + ) + return GPUFrame(rippled, on_gpu=True) + + return frame + + +def gpu_contrast(frame: GPUFrame, factor: float) -> GPUFrame: + """Adjust contrast on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = np.clip((frame.cpu.astype(np.float32) - 128) * factor + 128, 0, 255).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(gpu_data) + ops.contrast(factor) + return GPUFrame(ops.get_output().copy(), on_gpu=True) + + # Fallback + result = cp.clip((frame.gpu.astype(cp.float32) - 128) * factor + 128, 0, 255).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_rotate(frame: GPUFrame, angle: float) -> GPUFrame: + """Rotate frame on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + h, w = frame.cpu.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(frame.cpu, M, (w, h)) + return GPUFrame(rotated, on_gpu=False) + + # Use fast CUDA kernel (< 1ms vs 20ms for scipy) + if _FAST_KERNELS_AVAILABLE: + rotated = fast_rotate(frame.gpu, angle) + return GPUFrame(rotated, on_gpu=True) + + # Fallback to scipy (slow) + from cupyx.scipy import ndimage as cpndimage + rotated = cpndimage.rotate(frame.gpu, angle, reshape=False, order=1) + return GPUFrame(rotated, on_gpu=True) + + +def gpu_brightness(frame: GPUFrame, factor: float) -> GPUFrame: + """Adjust brightness on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = np.clip(frame.cpu.astype(np.float32) * factor, 0, 255).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + # Use fast CUDA kernel + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(gpu_data) + ops.brightness(factor) + return GPUFrame(ops.get_output().copy(), on_gpu=True) + + # Fallback + result = cp.clip(frame.gpu.astype(cp.float32) * factor, 0, 255).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_composite(frames: list, weights: list = None) -> GPUFrame: + """ + Composite multiple frames with weights. + + All frames processed on GPU for efficiency. + """ + if not frames: + raise ValueError("No frames to composite") + + if len(frames) == 1: + return frames[0] + + if weights is None: + weights = [1.0 / len(frames)] * len(frames) + + # Normalize weights + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + + use_gpu = GPU_AVAILABLE and any(f.is_on_gpu for f in frames) + + if use_gpu: + # All on GPU + target_shape = frames[0].gpu.shape + result = cp.zeros(target_shape, dtype=cp.float32) + + for frame, weight in zip(frames, weights): + gpu_data = frame.gpu.astype(cp.float32) + if gpu_data.shape != target_shape: + # Resize to match + from cupyx.scipy import ndimage as cpndimage + h, w = target_shape[:2] + fh, fw = gpu_data.shape[:2] + zoom_factors = (h/fh, w/fw, 1) if gpu_data.ndim == 3 else (h/fh, w/fw) + gpu_data = cpndimage.zoom(gpu_data, zoom_factors, order=1) + result += gpu_data * weight + + return GPUFrame(cp.clip(result, 0, 255).astype(cp.uint8), on_gpu=True) + else: + # All on CPU + import cv2 + target_shape = frames[0].cpu.shape + result = np.zeros(target_shape, dtype=np.float32) + + for frame, weight in zip(frames, weights): + cpu_data = frame.cpu.astype(np.float32) + if cpu_data.shape != target_shape: + cpu_data = cv2.resize(cpu_data, (target_shape[1], target_shape[0])) + result += cpu_data * weight + + return GPUFrame(np.clip(result, 0, 255).astype(np.uint8), on_gpu=False) + + +# Primitive registration for streaming interpreter + +def _to_gpu_frame(img): + """Convert any image type to GPUFrame, keeping data on GPU if possible.""" + if isinstance(img, GPUFrame): + return img + # Check for CuPy array (stays on GPU) + if GPU_AVAILABLE and hasattr(img, '__cuda_array_interface__'): + # Already a CuPy array - wrap directly + return GPUFrame(img, on_gpu=True) + # Numpy or other - will be uploaded to GPU + return GPUFrame(img, on_gpu=True) + + +def get_primitives(): + """ + Get GPU-aware primitives for registration with interpreter. + + These wrap the GPU functions to work with the sexp interpreter. + All use fast CUDA kernels when available for maximum performance. + + Primitives detect CuPy arrays and keep them on GPU (no CPU round-trips). + """ + def prim_make_video_source_gpu(path: str, fps: float = 30): + """Create GPU-accelerated video source.""" + return GPUVideoSource(path, fps, prefer_gpu=True) + + def prim_gpu_blend(a, b, alpha=0.5): + """Blend two frames using fast CUDA kernel.""" + fa = _to_gpu_frame(a) + fb = _to_gpu_frame(b) + result = gpu_blend(fa, fb, alpha) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_rotate(img, angle): + """Rotate image using fast CUDA kernel (< 1ms).""" + f = _to_gpu_frame(img) + result = gpu_rotate(f, angle) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_brightness(img, factor): + """Adjust brightness using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_brightness(f, factor) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_contrast(img, factor): + """Adjust contrast using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_contrast(f, factor) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_zoom(img, factor, cx=None, cy=None): + """Zoom image using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_zoom(f, factor, cx, cy) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_hue_shift(img, degrees): + """Shift hue using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_hue_shift(f, degrees) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_invert(img): + """Invert colors using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_invert(f) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_ripple(img, amplitude, frequency=8, decay=2, phase=0, cx=None, cy=None): + """Apply ripple effect using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_ripple(f, amplitude, frequency, decay, phase, cx, cy) + return result.gpu if result.is_on_gpu else result.cpu + + return { + 'streaming-gpu:make-video-source': prim_make_video_source_gpu, + 'gpu:blend': prim_gpu_blend, + 'gpu:rotate': prim_gpu_rotate, + 'gpu:brightness': prim_gpu_brightness, + 'gpu:contrast': prim_gpu_contrast, + 'gpu:zoom': prim_gpu_zoom, + 'gpu:hue-shift': prim_gpu_hue_shift, + 'gpu:invert': prim_gpu_invert, + 'gpu:ripple': prim_gpu_ripple, + } + + +# Export +__all__ = [ + 'GPU_AVAILABLE', + 'GPUFrame', + 'GPUVideoSource', + 'gpu_blend', + 'gpu_resize', + 'gpu_rotate', + 'gpu_brightness', + 'gpu_contrast', + 'gpu_zoom', + 'gpu_hue_shift', + 'gpu_invert', + 'gpu_ripple', + 'gpu_composite', + 'get_primitives', + 'check_hwdec_available', + 'PRIMITIVES', +] + + +# Import CPU primitives from streaming.py and include them in PRIMITIVES +# This ensures audio analysis primitives are available when streaming_gpu is loaded +def _get_cpu_primitives(): + from sexp_effects.primitive_libs import streaming + return streaming.PRIMITIVES + + +PRIMITIVES = _get_cpu_primitives().copy() + +# Try to import fused kernel compiler +_FUSED_KERNELS_AVAILABLE = False +_compile_frame_pipeline = None +_compile_autonomous_pipeline = None +try: + if GPU_AVAILABLE: + from streaming.sexp_to_cuda import compile_frame_pipeline as _compile_frame_pipeline + from streaming.sexp_to_cuda import compile_autonomous_pipeline as _compile_autonomous_pipeline + _FUSED_KERNELS_AVAILABLE = True + print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr) +except ImportError as e: + print(f"[streaming_gpu] Fused kernels not available: {e}", file=sys.stderr) + + +# Fused pipeline cache +_FUSED_PIPELINE_CACHE = {} + + +def _normalize_effect_dict(effect): + """Convert effect dict with Keyword keys to string keys.""" + result = {} + for k, v in effect.items(): + # Handle Keyword objects from sexp parser + if hasattr(k, 'name'): # Keyword object + key = k.name + else: + key = str(k) + result[key] = v + return result + + +_FUSED_CALL_COUNT = 0 + +def prim_fused_pipeline(img, effects_list, **dynamic_params): + """ + Apply a fused CUDA kernel pipeline to an image. + + This compiles multiple effects into a single CUDA kernel that processes + the entire pipeline in one GPU pass, eliminating Python interpreter overhead. + + Args: + img: Input image (GPU array or numpy array) + effects_list: List of effect dicts like: + [{'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 90.0}, + {'op': 'ripple', 'amplitude': 10, ...}] + **dynamic_params: Parameters that change per-frame like: + rotate_angle=45, ripple_phase=0.5 + + Returns: + Processed image as GPU array + + Supported ops: rotate, zoom, ripple, invert, hue_shift, brightness, resize + """ + global _FUSED_CALL_COUNT + _FUSED_CALL_COUNT += 1 + if _FUSED_CALL_COUNT <= 5 or _FUSED_CALL_COUNT % 100 == 0: + print(f"[FUSED] call #{_FUSED_CALL_COUNT}, effects={len(effects_list)}, params={list(dynamic_params.keys())}", file=sys.stderr) + + # Normalize effects list - convert Keyword keys to strings + effects_list = [_normalize_effect_dict(e) for e in effects_list] + + # Handle resize separately - it changes dimensions so must happen before fused kernel + resize_ops = [e for e in effects_list if e.get('op') == 'resize'] + other_effects = [e for e in effects_list if e.get('op') != 'resize'] + + # Apply resize first if needed + if resize_ops: + for resize_op in resize_ops: + target_w = int(resize_op.get('width', 640)) + target_h = int(resize_op.get('height', 360)) + # Wrap in GPUFrame if needed + if isinstance(img, GPUFrame): + img = gpu_resize(img, (target_w, target_h)) + img = img.gpu if img.is_on_gpu else img.cpu + else: + frame = GPUFrame(img, on_gpu=hasattr(img, '__cuda_array_interface__')) + img = gpu_resize(frame, (target_w, target_h)) + img = img.gpu if img.is_on_gpu else img.cpu + + # If no other effects, just return the resized image + if not other_effects: + return img + + # Update effects list to exclude resize ops + effects_list = other_effects + + if not _FUSED_KERNELS_AVAILABLE: + # Fallback: apply effects one by one + print(f"[FUSED FALLBACK] Using fallback path for {len(effects_list)} effects", file=sys.stderr) + # Wrap in GPUFrame if needed (GPU functions expect GPUFrame objects) + if isinstance(img, GPUFrame): + result = img + else: + on_gpu = hasattr(img, '__cuda_array_interface__') + result = GPUFrame(img, on_gpu=on_gpu) + for effect in effects_list: + op = effect['op'] + if op == 'rotate': + angle = dynamic_params.get('rotate_angle', effect.get('angle', 0)) + result = gpu_rotate(result, angle) + elif op == 'zoom': + amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0)) + result = gpu_zoom(result, amount) + elif op == 'hue_shift': + degrees = effect.get('degrees', 0) + if abs(degrees) > 0.1: # Only apply if significant shift + result = gpu_hue_shift(result, degrees) + elif op == 'ripple': + amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) + if amplitude > 0.1: # Only apply if amplitude is significant + result = gpu_ripple(result, + amplitude=amplitude, + frequency=effect.get('frequency', 8), + decay=effect.get('decay', 2), + phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)), + cx=effect.get('center_x'), + cy=effect.get('center_y')) + elif op == 'brightness': + factor = effect.get('factor', 1.0) + result = gpu_contrast(result, factor, 0) + elif op == 'invert': + amount = effect.get('amount', 0) + if amount > 0.5: # Only invert if amount > 0.5 + result = gpu_invert(result) + else: + raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") + # Return raw array, not GPUFrame (downstream expects arrays with .flags attribute) + if isinstance(result, GPUFrame): + return result.gpu if result.is_on_gpu else result.cpu + return result + + # Get image dimensions + if hasattr(img, 'shape'): + h, w = img.shape[:2] + else: + raise ValueError("Image must have shape attribute") + + # Create cache key from effects + import hashlib + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects_list]) + cache_key = f"{w}x{h}_{hashlib.md5(ops_key.encode()).hexdigest()}" + + # Compile or get cached pipeline + if cache_key not in _FUSED_PIPELINE_CACHE: + _FUSED_PIPELINE_CACHE[cache_key] = _compile_frame_pipeline(effects_list, w, h) + + pipeline = _FUSED_PIPELINE_CACHE[cache_key] + + # Ensure image is on GPU and uint8 + if hasattr(img, '__cuda_array_interface__'): + gpu_img = img + elif GPU_AVAILABLE: + gpu_img = cp.asarray(img) + else: + gpu_img = img + + # Run the fused pipeline + # Debug: log dynamic params occasionally + import random + if random.random() < 0.01: # 1% of frames + print(f"[fused] dynamic_params: {dynamic_params}", file=sys.stderr) + print(f"[fused] effects: {[(e['op'], e.get('amount'), e.get('amplitude')) for e in effects_list]}", file=sys.stderr) + return pipeline(gpu_img, **dynamic_params) + + +# Autonomous pipeline cache (separate from fused) +_AUTONOMOUS_PIPELINE_CACHE = {} + + +def prim_autonomous_pipeline(img, effects_list, dynamic_expressions, frame_num, fps=30.0): + """ + Apply a fully autonomous CUDA kernel pipeline. + + This computes ALL parameters on GPU - including time-based expressions + like sin(t), t*30, etc. Zero Python in the hot path! + + Args: + img: Input image (GPU array or numpy array) + effects_list: List of effect dicts + dynamic_expressions: Dict mapping param names to CUDA expressions: + {'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)'} + frame_num: Current frame number + fps: Frames per second (default 30) + + Returns: + Processed image as GPU array + + Note: Expressions use CUDA syntax - use sinf() not sin(), etc. + """ + # Normalize effects and expressions + effects_list = [_normalize_effect_dict(e) for e in effects_list] + dynamic_expressions = { + (k.name if hasattr(k, 'name') else str(k)): v + for k, v in dynamic_expressions.items() + } + + if not _FUSED_KERNELS_AVAILABLE or _compile_autonomous_pipeline is None: + # Fallback to regular fused pipeline with Python-computed params + import math + t = float(frame_num) / float(fps) + # Evaluate expressions in Python as fallback + dynamic_params = {} + for key, expr in dynamic_expressions.items(): + try: + # Simple eval with t and math functions + result = eval(expr.replace('f', '').replace('sin', 'math.sin').replace('cos', 'math.cos'), + {'t': t, 'math': math, 'frame_num': frame_num}) + dynamic_params[key] = result + except: + dynamic_params[key] = 0 + return prim_fused_pipeline(img, effects_list, **dynamic_params) + + # Get image dimensions + if hasattr(img, 'shape'): + h, w = img.shape[:2] + else: + raise ValueError("Image must have shape attribute") + + # Create cache key + import hashlib + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects_list]) + expr_key = str(sorted(dynamic_expressions.items())) + cache_key = f"auto_{w}x{h}_{hashlib.md5((ops_key + expr_key).encode()).hexdigest()}" + + # Compile or get cached pipeline + if cache_key not in _AUTONOMOUS_PIPELINE_CACHE: + _AUTONOMOUS_PIPELINE_CACHE[cache_key] = _compile_autonomous_pipeline( + effects_list, w, h, dynamic_expressions) + + pipeline = _AUTONOMOUS_PIPELINE_CACHE[cache_key] + + # Ensure image is on GPU + if hasattr(img, '__cuda_array_interface__'): + gpu_img = img + elif GPU_AVAILABLE: + gpu_img = cp.asarray(img) + else: + gpu_img = img + + # Run - just pass frame_num and fps, kernel does the rest! + return pipeline(gpu_img, int(frame_num), float(fps)) + + +# ============================================================ +# GPU Image Primitives (keep images on GPU) +# ============================================================ + +def gpu_make_image(w, h, color=None): + """Create a new image on GPU filled with color (default black). + + Unlike image:make-image, this keeps the image on GPU memory, + avoiding CPU<->GPU transfers in the pipeline. + """ + if not GPU_AVAILABLE: + # Fallback to CPU + import numpy as np + if color is None: + color = [0, 0, 0] + img = np.zeros((int(h), int(w), 3), dtype=np.uint8) + img[:] = color + return img + + w, h = int(w), int(h) + if color is None: + color = [0, 0, 0] + + # Create on GPU directly + img = cp.zeros((h, w, 3), dtype=cp.uint8) + img[:, :, 0] = int(color[0]) if len(color) > 0 else 0 + img[:, :, 1] = int(color[1]) if len(color) > 1 else 0 + img[:, :, 2] = int(color[2]) if len(color) > 2 else 0 + + return img + + +def gpu_gradient_image(w, h, color1=None, color2=None, direction='horizontal'): + """Create a gradient image on GPU. + + Args: + w, h: Dimensions + color1, color2: Start and end colors [r, g, b] + direction: 'horizontal', 'vertical', 'diagonal' + """ + if not GPU_AVAILABLE: + return gpu_make_image(w, h, color1) + + w, h = int(w), int(h) + if color1 is None: + color1 = [0, 0, 0] + if color2 is None: + color2 = [255, 255, 255] + + img = cp.zeros((h, w, 3), dtype=cp.uint8) + + if direction == 'horizontal': + for c in range(3): + grad = cp.linspace(color1[c], color2[c], w, dtype=cp.float32) + img[:, :, c] = grad[cp.newaxis, :].astype(cp.uint8) + elif direction == 'vertical': + for c in range(3): + grad = cp.linspace(color1[c], color2[c], h, dtype=cp.float32) + img[:, :, c] = grad[:, cp.newaxis].astype(cp.uint8) + elif direction == 'diagonal': + for c in range(3): + x_grad = cp.linspace(0, 1, w, dtype=cp.float32)[cp.newaxis, :] + y_grad = cp.linspace(0, 1, h, dtype=cp.float32)[:, cp.newaxis] + combined = (x_grad + y_grad) / 2 + img[:, :, c] = (color1[c] + (color2[c] - color1[c]) * combined).astype(cp.uint8) + + return img + + +# Add GPU-specific primitives +PRIMITIVES['fused-pipeline'] = prim_fused_pipeline +PRIMITIVES['autonomous-pipeline'] = prim_autonomous_pipeline +PRIMITIVES['gpu-make-image'] = gpu_make_image +PRIMITIVES['gpu-gradient'] = gpu_gradient_image +# (The GPU video source will be added by create_cid_primitives in the task) diff --git a/l1/sexp_effects/primitive_libs/xector.py b/l1/sexp_effects/primitive_libs/xector.py new file mode 100644 index 0000000..fb95dfd --- /dev/null +++ b/l1/sexp_effects/primitive_libs/xector.py @@ -0,0 +1,1382 @@ +""" +Xector Primitives - Parallel array operations for GPU-style data parallelism. + +Inspired by Connection Machine Lisp and hillisp. Xectors are parallel arrays +where operations automatically apply element-wise. + +Usage in sexp: + (require-primitives "xector") + + ;; Extract channels as xectors + (let* ((r (red frame)) + (g (green frame)) + (b (blue frame)) + ;; Operations are element-wise on xectors + (brightness (α+ (α* r 0.299) (α* g 0.587) (α* b 0.114)))) + ;; Reduce to scalar + (βmax brightness)) + + ;; Explicit α for element-wise, implicit also works + (α+ r 10) ;; explicit: add 10 to every element + (+ r 10) ;; implicit: same thing when r is a xector + + ;; β for reductions + (β+ r) ;; sum all elements + (βmax r) ;; maximum element + (βmean r) ;; average + +Operators: + α (alpha) - element-wise: (α+ x y) adds corresponding elements + β (beta) - reduce: (β+ x) sums all elements +""" + +import numpy as np +from typing import Union, Callable, Any + +# Try to use CuPy for GPU acceleration if available +try: + import cupy as cp + HAS_CUPY = True +except ImportError: + cp = None + HAS_CUPY = False + + +class Xector: + """ + Parallel array type for element-wise operations. + + Wraps a numpy/cupy array and provides automatic broadcasting + and element-wise operation semantics. + """ + + def __init__(self, data, shape=None): + """ + Create a Xector from data. + + Args: + data: numpy array, cupy array, scalar, or list + shape: optional shape tuple (for coordinate xectors) + """ + if isinstance(data, Xector): + self._data = data._data + self._shape = data._shape + elif isinstance(data, np.ndarray): + self._data = data.astype(np.float32) + self._shape = shape or data.shape + elif HAS_CUPY and isinstance(data, cp.ndarray): + self._data = data.astype(cp.float32) + self._shape = shape or data.shape + elif isinstance(data, (list, tuple)): + self._data = np.array(data, dtype=np.float32) + self._shape = shape or self._data.shape + else: + # Scalar - will broadcast + self._data = np.float32(data) + self._shape = shape or () + + @property + def data(self): + return self._data + + @property + def shape(self): + return self._shape + + def __len__(self): + return self._data.size + + def __repr__(self): + if self._data.size <= 10: + return f"Xector({self._data})" + return f"Xector(shape={self._shape}, size={self._data.size})" + + def to_numpy(self): + """Convert to numpy array.""" + if HAS_CUPY and isinstance(self._data, cp.ndarray): + return cp.asnumpy(self._data) + return self._data + + def to_gpu(self): + """Move to GPU if available.""" + if HAS_CUPY and not isinstance(self._data, cp.ndarray): + self._data = cp.asarray(self._data) + return self + + # Arithmetic operators - enable implicit element-wise ops + def __add__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data + other_data, self._shape) + + def __radd__(self, other): + return Xector(other + self._data, self._shape) + + def __sub__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data - other_data, self._shape) + + def __rsub__(self, other): + return Xector(other - self._data, self._shape) + + def __mul__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data * other_data, self._shape) + + def __rmul__(self, other): + return Xector(other * self._data, self._shape) + + def __truediv__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data / other_data, self._shape) + + def __rtruediv__(self, other): + return Xector(other / self._data, self._shape) + + def __pow__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data ** other_data, self._shape) + + def __neg__(self): + return Xector(-self._data, self._shape) + + def __abs__(self): + return Xector(np.abs(self._data), self._shape) + + # Comparison operators - return boolean xectors + def __lt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data < other_data, self._shape) + + def __le__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data <= other_data, self._shape) + + def __gt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data > other_data, self._shape) + + def __ge__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data >= other_data, self._shape) + + def __eq__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data == other_data, self._shape) + + def __ne__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data != other_data, self._shape) + + +def _unwrap(x): + """Unwrap Xector to underlying data, or return as-is.""" + if isinstance(x, Xector): + return x._data + return x + + +def _wrap(data, shape=None): + """Wrap result in Xector if it's an array.""" + if isinstance(data, (np.ndarray,)) or (HAS_CUPY and isinstance(data, cp.ndarray)): + return Xector(data, shape) + return data + + +# ============================================================================= +# Frame/Xector Conversion +# ============================================================================= +# NOTE: red, green, blue, gray are derived in derived.sexp using (channel frame n) + +def xector_from_frame(frame): + """Convert entire frame to xector (flattened RGB). (xector frame) -> Xector""" + if isinstance(frame, np.ndarray): + return Xector(frame.flatten().astype(np.float32), frame.shape) + raise TypeError(f"Expected frame array, got {type(frame)}") + + +def xector_to_frame(x, shape=None): + """Convert xector back to frame. (to-frame x) or (to-frame x shape) -> frame""" + data = _unwrap(x) + if shape is None and isinstance(x, Xector): + shape = x._shape + if shape is None: + raise ValueError("Shape required to convert xector to frame") + return np.clip(data, 0, 255).reshape(shape).astype(np.uint8) + + +# ============================================================================= +# Coordinate Generators +# ============================================================================= +# NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived +# in derived.sexp using iota, tile, repeat primitives + + +# ============================================================================= +# Alpha (α) - Element-wise Operations +# ============================================================================= + +def alpha_lift(fn): + """Lift a scalar function to work element-wise on xectors.""" + def lifted(*args): + # Check if any arg is a Xector + has_xector = any(isinstance(a, Xector) for a in args) + if not has_xector: + return fn(*args) + + # Get shape from first xector + shape = None + for a in args: + if isinstance(a, Xector): + shape = a._shape + break + + # Unwrap all args + unwrapped = [_unwrap(a) for a in args] + + # Apply function + result = fn(*unwrapped) + + return _wrap(result, shape) + + return lifted + + +# Element-wise math operations +def alpha_add(*args): + """Element-wise addition. (α+ a b ...) -> Xector""" + if len(args) == 0: + return 0 + result = _unwrap(args[0]) + for a in args[1:]: + result = result + _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_sub(a, b=None): + """Element-wise subtraction. (α- a b) -> Xector""" + if b is None: + return Xector(-_unwrap(a)) if isinstance(a, Xector) else -a + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) - _unwrap(b), shape) + + +def alpha_mul(*args): + """Element-wise multiplication. (α* a b ...) -> Xector""" + if len(args) == 0: + return 1 + result = _unwrap(args[0]) + for a in args[1:]: + result = result * _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_div(a, b): + """Element-wise division. (α/ a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) / _unwrap(b), shape) + + +def alpha_pow(a, b): + """Element-wise power. (α** a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) ** _unwrap(b), shape) + + +def alpha_sqrt(x): + """Element-wise square root. (αsqrt x) -> Xector""" + return _wrap(np.sqrt(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_abs(x): + """Element-wise absolute value. (αabs x) -> Xector""" + return _wrap(np.abs(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_sin(x): + """Element-wise sine. (αsin x) -> Xector""" + return _wrap(np.sin(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_cos(x): + """Element-wise cosine. (αcos x) -> Xector""" + return _wrap(np.cos(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_exp(x): + """Element-wise exponential. (αexp x) -> Xector""" + return _wrap(np.exp(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_log(x): + """Element-wise natural log. (αlog x) -> Xector""" + return _wrap(np.log(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_clamp is derived in derived.sexp as (max2 lo (min2 hi x)) + +def alpha_min(a, b): + """Element-wise minimum. (αmin a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.minimum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_max(a, b): + """Element-wise maximum. (αmax a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.maximum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_mod(a, b): + """Element-wise modulo. (αmod a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) % _unwrap(b), shape) + + +def alpha_floor(x): + """Element-wise floor. (αfloor x) -> Xector""" + return _wrap(np.floor(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_ceil(x): + """Element-wise ceiling. (αceil x) -> Xector""" + return _wrap(np.ceil(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_round(x): + """Element-wise round. (αround x) -> Xector""" + return _wrap(np.round(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_sq is derived in derived.sexp as (* x x) + +# Comparison operators (return boolean xectors) +def alpha_lt(a, b): + """Element-wise less than. (α< a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) < _unwrap(b), shape) + + +def alpha_le(a, b): + """Element-wise less-or-equal. (α<= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) <= _unwrap(b), shape) + + +def alpha_gt(a, b): + """Element-wise greater than. (α> a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) > _unwrap(b), shape) + + +def alpha_ge(a, b): + """Element-wise greater-or-equal. (α>= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) >= _unwrap(b), shape) + + +def alpha_eq(a, b): + """Element-wise equality. (α= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) == _unwrap(b), shape) + + +# Logical operators +def alpha_and(a, b): + """Element-wise logical and. (αand a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_and(_unwrap(a), _unwrap(b)), shape) + + +def alpha_or(a, b): + """Element-wise logical or. (αor a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_or(_unwrap(a), _unwrap(b)), shape) + + +def alpha_not(x): + """Element-wise logical not. (αnot x) -> Xector[bool]""" + return _wrap(np.logical_not(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# ============================================================================= +# Beta (β) - Reduction Operations +# ============================================================================= + +def beta_add(x): + """Sum all elements. (β+ x) -> scalar""" + return float(np.sum(_unwrap(x))) + + +def beta_mul(x): + """Product of all elements. (β* x) -> scalar""" + return float(np.prod(_unwrap(x))) + + +def beta_min(x): + """Minimum element. (βmin x) -> scalar""" + return float(np.min(_unwrap(x))) + + +def beta_max(x): + """Maximum element. (βmax x) -> scalar""" + return float(np.max(_unwrap(x))) + + +def beta_mean(x): + """Mean of all elements. (βmean x) -> scalar""" + return float(np.mean(_unwrap(x))) + + +def beta_std(x): + """Standard deviation. (βstd x) -> scalar""" + return float(np.std(_unwrap(x))) + + +def beta_count(x): + """Count of elements. (βcount x) -> scalar""" + return int(np.size(_unwrap(x))) + + +def beta_any(x): + """True if any element is truthy. (βany x) -> bool""" + return bool(np.any(_unwrap(x))) + + +def beta_all(x): + """True if all elements are truthy. (βall x) -> bool""" + return bool(np.all(_unwrap(x))) + + +# ============================================================================= +# Conditional / Selection +# ============================================================================= + +def xector_where(cond, true_val, false_val): + """ + Conditional select. (where cond true-val false-val) -> Xector + + Like numpy.where - selects elements based on condition. + """ + cond_data = _unwrap(cond) + true_data = _unwrap(true_val) + false_data = _unwrap(false_val) + + # Get shape from condition or values + shape = None + for x in [cond, true_val, false_val]: + if isinstance(x, Xector): + shape = x._shape + break + + result = np.where(cond_data, true_data, false_data) + return _wrap(result, shape) + + +# NOTE: fill, zeros, ones are derived in derived.sexp using iota + +def xector_rand(size_or_frame): + """Create xector of random values [0,1). (rand-x frame) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector(np.random.random(size).astype(np.float32), shape) + + +def xector_randn(size_or_frame, mean=0, std=1): + """Create xector of normal random values. (randn-x frame) or (randn-x frame mean std) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector((np.random.randn(size) * std + mean).astype(np.float32), shape) + + +# ============================================================================= +# Type checking +# ============================================================================= + +def is_xector(x): + """Check if x is a Xector. (xector? x) -> bool""" + return isinstance(x, Xector) + + +# ============================================================================= +# CORE PRIMITIVES: gather, scatter, group-reduce, reshape +# These are the fundamental operations everything else builds on. +# ============================================================================= + +def xector_gather(data, indices): + """ + Parallel index lookup. (gather data indices) -> Xector + + For each index in indices, look up the corresponding value in data. + This is the fundamental operation for remapping/resampling. + + Example: + (gather [10 20 30 40] [2 0 1 2]) ; -> [30 10 20 30] + """ + data_arr = _unwrap(data) + idx_arr = _unwrap(indices).astype(np.int32) + + # Flatten data for 1D indexing + flat_data = data_arr.flatten() + + # Clip indices to valid range + idx_clipped = np.clip(idx_arr, 0, len(flat_data) - 1) + + result = flat_data[idx_clipped] + shape = indices._shape if isinstance(indices, Xector) else None + return Xector(result, shape) + + +def xector_gather_2d(data, row_indices, col_indices): + """ + 2D parallel index lookup. (gather-2d data rows cols) -> Xector + + For each (row, col) pair, look up the value in 2D data. + Essential for grid/cell operations. + + Example: + (gather-2d image-lum cell-rows cell-cols) + """ + data_arr = _unwrap(data) + row_arr = _unwrap(row_indices).astype(np.int32) + col_arr = _unwrap(col_indices).astype(np.int32) + + # Get data shape + if isinstance(data, Xector) and data._shape and len(data._shape) >= 2: + h, w = data._shape[:2] + data_2d = data_arr.reshape(h, w) + elif len(data_arr.shape) >= 2: + h, w = data_arr.shape[:2] + data_2d = data_arr.reshape(h, w) if data_arr.ndim == 1 else data_arr + else: + # Assume square + size = int(np.sqrt(len(data_arr))) + h, w = size, size + data_2d = data_arr.reshape(h, w) + + # Clip indices + row_clipped = np.clip(row_arr, 0, h - 1) + col_clipped = np.clip(col_arr, 0, w - 1) + + result = data_2d[row_clipped.flatten(), col_clipped.flatten()] + shape = row_indices._shape if isinstance(row_indices, Xector) else None + return Xector(result, shape) + + +def xector_scatter(indices, values, size): + """ + Parallel index write. (scatter indices values size) -> Xector + + Create a new xector of given size, writing values at indices. + Later writes overwrite earlier ones at same index. + + Example: + (scatter [0 2 4] [10 20 30] 5) ; -> [10 0 20 0 30] + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + idx_clipped = np.clip(idx_arr, 0, int(size) - 1) + result[idx_clipped] = val_arr + + return Xector(result, (int(size),)) + + +def xector_scatter_add(indices, values, size): + """ + Parallel index accumulate. (scatter-add indices values size) -> Xector + + Like scatter, but adds to existing values instead of overwriting. + Useful for histograms, pooling reductions. + + Example: + (scatter-add [0 0 1] [1 2 3] 3) ; -> [3 3 0] (1+2 at index 0) + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + np.add.at(result, np.clip(idx_arr, 0, int(size) - 1), val_arr) + + return Xector(result, (int(size),)) + + +def xector_group_reduce(values, group_indices, num_groups, op='mean'): + """ + Reduce values by group. (group-reduce values groups num-groups op) -> Xector + + Groups values by group_indices and reduces each group. + This is the primitive for pooling operations. + + Args: + values: Xector of values to reduce + group_indices: Xector of group assignments (integers) + num_groups: Number of groups (output size) + op: 'mean', 'sum', 'max', 'min' + + Example: + ; Pool 4 values into 2 groups + (group-reduce [1 2 3 4] [0 0 1 1] 2 "mean") ; -> [1.5 3.5] + """ + val_arr = _unwrap(values).flatten() + grp_arr = _unwrap(group_indices).astype(np.int32).flatten() + n = int(num_groups) + + if op == 'sum': + result = np.zeros(n, dtype=np.float32) + np.add.at(result, grp_arr, val_arr) + elif op == 'mean': + sums = np.zeros(n, dtype=np.float32) + counts = np.zeros(n, dtype=np.float32) + np.add.at(sums, grp_arr, val_arr) + np.add.at(counts, grp_arr, 1) + result = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + elif op == 'max': + result = np.full(n, -np.inf, dtype=np.float32) + np.maximum.at(result, grp_arr, val_arr) + result[result == -np.inf] = 0 + elif op == 'min': + result = np.full(n, np.inf, dtype=np.float32) + np.minimum.at(result, grp_arr, val_arr) + result[result == np.inf] = 0 + else: + raise ValueError(f"Unknown reduce op: {op}") + + return Xector(result, (n,)) + + +def xector_reshape(x, *dims): + """ + Reshape xector. (reshape x h w) or (reshape x n) -> Xector + + Changes the logical shape of the xector without changing data. + """ + data = _unwrap(x) + if len(dims) == 1: + new_shape = (int(dims[0]),) + else: + new_shape = tuple(int(d) for d in dims) + + return Xector(data.reshape(-1), new_shape) + + +def xector_shape(x): + """Get shape of xector. (shape x) -> list""" + if isinstance(x, Xector): + return list(x._shape) if x._shape else [len(x)] + if isinstance(x, np.ndarray): + return list(x.shape) + return [] + + +def xector_len(x): + """Get length of xector. (xlen x) -> int""" + return len(_unwrap(x).flatten()) + + +def xector_iota(n): + """ + Generate indices 0 to n-1. (iota n) -> Xector + + Fundamental for generating coordinate xectors. + + Example: + (iota 5) ; -> [0 1 2 3 4] + """ + return Xector(np.arange(int(n), dtype=np.float32), (int(n),)) + + +def xector_repeat(x, n): + """ + Repeat each element n times. (repeat x n) -> Xector + + Example: + (repeat [1 2 3] 2) ; -> [1 1 2 2 3 3] + """ + data = _unwrap(x) + result = np.repeat(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +def xector_tile(x, n): + """ + Tile entire xector n times. (tile x n) -> Xector + + Example: + (tile [1 2 3] 2) ; -> [1 2 3 1 2 3] + """ + data = _unwrap(x) + result = np.tile(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +# ============================================================================= +# 2D Grid Helpers (built on primitives above) +# ============================================================================= + +def xector_cell_indices(frame, cell_size): + """ + Compute cell index for each pixel. (cell-indices frame cell-size) -> Xector + + Returns flat index of which cell each pixel belongs to. + This is the bridge between pixel-space and cell-space. + """ + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y = np.repeat(np.arange(h), w) # [0,0,0..., 1,1,1..., ...] + x = np.tile(np.arange(w), h) # [0,1,2..., 0,1,2..., ...] + + cell_row = y // cell_size + cell_col = x // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + cell_idx = np.clip(cell_idx, 0, rows * cols - 1) + + return Xector(cell_idx.astype(np.float32), (h, w)) + + +def xector_local_x(frame, cell_size): + """ + X position within each cell [0, cell_size). (local-x frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + x = np.tile(np.arange(w), h) + local = (x % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y(frame, cell_size): + """ + Y position within each cell [0, cell_size). (local-y frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + y = np.repeat(np.arange(h), w) + local = (y % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_x_norm(frame, cell_size): + """ + Normalized X within cell [0, 1]. (local-x-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + x = np.tile(np.arange(w), h) + local = ((x % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y_norm(frame, cell_size): + """ + Normalized Y within cell [0, 1]. (local-y-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + y = np.repeat(np.arange(h), w) + local = ((y % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_pool_frame(frame, cell_size, op='mean'): + """ + Pool frame to cell values. (pool-frame frame cell-size) -> (r, g, b, lum) Xectors + + Returns tuple of xectors: (red, green, blue, luminance) for cells. + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y = np.repeat(np.arange(h), w) + x = np.tile(np.arange(w), h) + cell_row = np.clip(y // cs, 0, rows - 1) + cell_col = np.clip(x // cs, 0, cols - 1) + cell_idx = cell_row * cols + cell_col + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(np.float32) + g_flat = frame[:, :, 1].flatten().astype(np.float32) + b_flat = frame[:, :, 2].flatten().astype(np.float32) + + # Pool each channel + def pool_channel(data): + sums = np.zeros(num_cells, dtype=np.float32) + counts = np.zeros(num_cells, dtype=np.float32) + np.add.at(sums, cell_idx, data) + np.add.at(counts, cell_idx, 1) + return np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + shape = (rows, cols) + return (Xector(r_pooled, shape), + Xector(g_pooled, shape), + Xector(b_pooled, shape), + Xector(lum, shape)) + + +def xector_cell_row(frame, cell_size): + """ + Cell row index for each pixel. (cell-row frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + + y = np.repeat(np.arange(h), w) + cell_row = np.clip(y // cs, 0, rows - 1).astype(np.float32) + return Xector(cell_row, (h, w)) + + +def xector_cell_col(frame, cell_size): + """ + Cell column index for each pixel. (cell-col frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + cols = w // cs + + x = np.tile(np.arange(w), h) + cell_col = np.clip(x // cs, 0, cols - 1).astype(np.float32) + return Xector(cell_col, (h, w)) + + +def xector_num_cells(frame, cell_size): + """Number of cells. (num-cells frame cell-size) -> (rows, cols, total)""" + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + return (rows, cols, rows * cols) + + +# ============================================================================= +# Scan (Prefix Operations) - cumulative reductions +# ============================================================================= + +def xector_scan_add(x, axis=None): + """ + Cumulative sum (prefix sum). (scan+ x) or (scan+ x :axis 0) + + Returns array where each element is sum of all previous elements. + Useful for integral images, cumulative effects. + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None: + # Reshape to 2D for axis operation + if shape and len(shape) == 2: + result = np.cumsum(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.cumsum(data, axis=int(axis)) + else: + result = np.cumsum(data) + + return _wrap(result, shape) + + +def xector_scan_mul(x, axis=None): + """Cumulative product. (scan* x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.cumprod(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.cumprod(data) + + return _wrap(result, shape) + + +def xector_scan_max(x, axis=None): + """Cumulative maximum. (scan-max x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.maximum.accumulate(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.maximum.accumulate(data) + + return _wrap(result, shape) + + +def xector_scan_min(x, axis=None): + """Cumulative minimum. (scan-min x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.minimum.accumulate(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.minimum.accumulate(data) + + return _wrap(result, shape) + + +# ============================================================================= +# Outer Product - Cartesian operations +# ============================================================================= + +def xector_outer(x, y, op='*'): + """ + Outer product. (outer x y) or (outer x y :op '+') + + Creates 2D result where result[i,j] = op(x[i], y[j]). + Default is multiplication (*). + + Useful for generating 2D patterns from 1D vectors. + """ + x_data = _unwrap(x) + y_data = _unwrap(y) + + ops = { + '*': np.multiply, + '+': np.add, + '-': np.subtract, + '/': np.divide, + 'max': np.maximum, + 'min': np.minimum, + 'and': np.logical_and, + 'or': np.logical_or, + 'xor': np.logical_xor, + } + + op_fn = ops.get(op, np.multiply) + result = op_fn.outer(x_data.flatten(), y_data.flatten()) + + # Return as xector with 2D shape + h, w = len(x_data.flatten()), len(y_data.flatten()) + return _wrap(result.flatten(), (h, w)) + + +def xector_outer_add(x, y): + """Outer sum. (outer+ x y) -> result[i,j] = x[i] + y[j]""" + return xector_outer(x, y, '+') + + +def xector_outer_mul(x, y): + """Outer product. (outer* x y) -> result[i,j] = x[i] * y[j]""" + return xector_outer(x, y, '*') + + +def xector_outer_max(x, y): + """Outer max. (outer-max x y) -> result[i,j] = max(x[i], y[j])""" + return xector_outer(x, y, 'max') + + +def xector_outer_min(x, y): + """Outer min. (outer-min x y) -> result[i,j] = min(x[i], y[j])""" + return xector_outer(x, y, 'min') + + +# ============================================================================= +# Reduce with Axis - dimensional reductions +# ============================================================================= + +def xector_reduce_axis(x, op='sum', axis=0): + """ + Reduce along an axis. (reduce-axis x :op 'sum' :axis 0) + + ops: 'sum', 'mean', 'max', 'min', 'prod', 'std' + axis: 0 (rows), 1 (columns) + + For a frame-sized xector (H*W): + axis=0: reduce across rows -> W values (one per column) + axis=1: reduce across columns -> H values (one per row) + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if shape is None or len(shape) != 2: + # Can't do axis reduction without 2D shape + raise ValueError("reduce-axis requires 2D xector (with shape)") + + h, w = shape + data_2d = data.reshape(h, w) + axis = int(axis) + + ops = { + 'sum': lambda d, a: np.sum(d, axis=a), + '+': lambda d, a: np.sum(d, axis=a), + 'mean': lambda d, a: np.mean(d, axis=a), + 'max': lambda d, a: np.max(d, axis=a), + 'min': lambda d, a: np.min(d, axis=a), + 'prod': lambda d, a: np.prod(d, axis=a), + '*': lambda d, a: np.prod(d, axis=a), + 'std': lambda d, a: np.std(d, axis=a), + } + + op_fn = ops.get(op, ops['sum']) + result = op_fn(data_2d, axis) + + # Result shape: if axis=0, shape is (w,); if axis=1, shape is (h,) + new_shape = (w,) if axis == 0 else (h,) + return _wrap(result.flatten(), new_shape) + + +def xector_sum_axis(x, axis=0): + """Sum along axis. (sum-axis x :axis 0)""" + return xector_reduce_axis(x, 'sum', axis) + + +def xector_mean_axis(x, axis=0): + """Mean along axis. (mean-axis x :axis 0)""" + return xector_reduce_axis(x, 'mean', axis) + + +def xector_max_axis(x, axis=0): + """Max along axis. (max-axis x :axis 0)""" + return xector_reduce_axis(x, 'max', axis) + + +def xector_min_axis(x, axis=0): + """Min along axis. (min-axis x :axis 0)""" + return xector_reduce_axis(x, 'min', axis) + + +# ============================================================================= +# Windowed Operations - sliding window computations +# ============================================================================= + +def xector_window(x, size, op='mean', stride=1): + """ + Sliding window operation. (window x size :op 'mean' :stride 1) + + Applies reduction over sliding windows of given size. + ops: 'sum', 'mean', 'max', 'min' + + For 1D: windows slide along the array + For 2D (with shape): windows are size x size squares + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + size = int(size) + stride = int(stride) + + ops = { + 'sum': np.sum, + 'mean': np.mean, + 'max': np.max, + 'min': np.min, + 'std': np.std, + } + op_fn = ops.get(op, np.mean) + + if shape and len(shape) == 2: + # 2D sliding window + h, w = shape + data_2d = data.reshape(h, w) + + # Use stride tricks for efficient windowing + out_h = (h - size) // stride + 1 + out_w = (w - size) // stride + 1 + + result = np.zeros((out_h, out_w)) + for i in range(out_h): + for j in range(out_w): + window = data_2d[i*stride:i*stride+size, j*stride:j*stride+size] + result[i, j] = op_fn(window) + + return _wrap(result.flatten(), (out_h, out_w)) + else: + # 1D sliding window + n = len(data) + out_n = (n - size) // stride + 1 + result = np.array([op_fn(data[i*stride:i*stride+size]) for i in range(out_n)]) + return _wrap(result, (out_n,)) + + +def xector_window_sum(x, size, stride=1): + """Sliding window sum. (window-sum x size)""" + return xector_window(x, size, 'sum', stride) + + +def xector_window_mean(x, size, stride=1): + """Sliding window mean. (window-mean x size)""" + return xector_window(x, size, 'mean', stride) + + +def xector_window_max(x, size, stride=1): + """Sliding window max. (window-max x size)""" + return xector_window(x, size, 'max', stride) + + +def xector_window_min(x, size, stride=1): + """Sliding window min. (window-min x size)""" + return xector_window(x, size, 'min', stride) + + +def xector_integral_image(frame): + """ + Compute integral image (summed area table). (integral-image frame) + + Each pixel contains sum of all pixels above and to the left. + Enables O(1) box blur at any radius. + + Returns xector with same shape as frame's luminance. + """ + if hasattr(frame, 'shape') and len(frame.shape) == 3: + # Convert frame to grayscale + gray = np.mean(frame, axis=2) + else: + data = _unwrap(frame) + shape = frame._shape if isinstance(frame, Xector) else None + if shape and len(shape) == 2: + gray = data.reshape(shape) + else: + gray = data + + integral = np.cumsum(np.cumsum(gray, axis=0), axis=1) + h, w = integral.shape + return _wrap(integral.flatten(), (h, w)) + + +def xector_box_blur_fast(integral, x, y, radius, width, height): + """ + Fast box blur using integral image. (box-blur-fast integral x y radius w h) + + Given pre-computed integral image, compute average in box centered at (x,y). + O(1) regardless of radius. + """ + integral_data = _unwrap(integral) + shape = integral._shape if isinstance(integral, Xector) else None + + if shape is None or len(shape) != 2: + raise ValueError("box-blur-fast requires 2D integral image") + + h, w = shape + integral_2d = integral_data.reshape(h, w) + + radius = int(radius) + x, y = int(x), int(y) + + # Clamp coordinates + x1 = max(0, x - radius) + y1 = max(0, y - radius) + x2 = min(w - 1, x + radius) + y2 = min(h - 1, y + radius) + + # Sum in rectangle using integral image + total = integral_2d[y2, x2] + if x1 > 0: + total -= integral_2d[y2, x1 - 1] + if y1 > 0: + total -= integral_2d[y1 - 1, x2] + if x1 > 0 and y1 > 0: + total += integral_2d[y1 - 1, x1 - 1] + + count = (x2 - x1 + 1) * (y2 - y1 + 1) + return total / max(count, 1) + + +# ============================================================================= +# PRIMITIVES Export +# ============================================================================= + +PRIMITIVES = { + # Frame/Xector conversion + # NOTE: red, green, blue, gray, rgb are derived in derived.sexp using (channel frame n) + 'xector': xector_from_frame, + 'to-frame': xector_to_frame, + + # Coordinate generators + # NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived + # in derived.sexp using iota, tile, repeat primitives + + # Alpha (α) - element-wise operations + 'α+': alpha_add, + 'α-': alpha_sub, + 'α*': alpha_mul, + 'α/': alpha_div, + 'α**': alpha_pow, + 'αsqrt': alpha_sqrt, + 'αabs': alpha_abs, + 'αsin': alpha_sin, + 'αcos': alpha_cos, + 'αexp': alpha_exp, + 'αlog': alpha_log, + # NOTE: αclamp is derived in derived.sexp as (max2 lo (min2 hi x)) + 'αmin': alpha_min, + 'αmax': alpha_max, + 'αmod': alpha_mod, + 'αfloor': alpha_floor, + 'αceil': alpha_ceil, + 'αround': alpha_round, + # NOTE: α² / αsq is derived in derived.sexp as (* x x) + + # Alpha comparison + 'α<': alpha_lt, + 'α<=': alpha_le, + 'α>': alpha_gt, + 'α>=': alpha_ge, + 'α=': alpha_eq, + + # Alpha logical + 'αand': alpha_and, + 'αor': alpha_or, + 'αnot': alpha_not, + + # ASCII fallbacks for α + 'alpha+': alpha_add, + 'alpha-': alpha_sub, + 'alpha*': alpha_mul, + 'alpha/': alpha_div, + 'alpha**': alpha_pow, + 'alpha-sqrt': alpha_sqrt, + 'alpha-abs': alpha_abs, + 'alpha-sin': alpha_sin, + 'alpha-cos': alpha_cos, + 'alpha-exp': alpha_exp, + 'alpha-log': alpha_log, + 'alpha-min': alpha_min, + 'alpha-max': alpha_max, + 'alpha-mod': alpha_mod, + 'alpha-floor': alpha_floor, + 'alpha-ceil': alpha_ceil, + 'alpha-round': alpha_round, + 'alpha<': alpha_lt, + 'alpha<=': alpha_le, + 'alpha>': alpha_gt, + 'alpha>=': alpha_ge, + 'alpha=': alpha_eq, + 'alpha-and': alpha_and, + 'alpha-or': alpha_or, + 'alpha-not': alpha_not, + + # Beta (β) - reduction operations + 'β+': beta_add, + 'β*': beta_mul, + 'βmin': beta_min, + 'βmax': beta_max, + 'βmean': beta_mean, + 'βstd': beta_std, + 'βcount': beta_count, + 'βany': beta_any, + 'βall': beta_all, + + # ASCII fallbacks for β + 'beta+': beta_add, + 'beta*': beta_mul, + 'beta-min': beta_min, + 'beta-max': beta_max, + 'beta-mean': beta_mean, + 'beta-std': beta_std, + 'beta-count': beta_count, + 'beta-any': beta_any, + 'beta-all': beta_all, + + # Convenience aliases + 'sum': beta_add, + 'product': beta_mul, + 'mean': beta_mean, + + # Conditional / Selection + 'where': xector_where, + # NOTE: fill, zeros, ones are derived in derived.sexp using iota + 'rand-x': xector_rand, + 'randn-x': xector_randn, + + # Type checking + 'xector?': is_xector, + + # =========================================== + # CORE PRIMITIVES - fundamental operations + # =========================================== + + # Gather/Scatter - parallel indexing + 'gather': xector_gather, + 'gather-2d': xector_gather_2d, + 'scatter': xector_scatter, + 'scatter-add': xector_scatter_add, + + # Group reduce - pooling primitive + 'group-reduce': xector_group_reduce, + + # Shape operations + 'reshape': xector_reshape, + 'shape': xector_shape, + 'xlen': xector_len, + + # Index generation + 'iota': xector_iota, + 'repeat': xector_repeat, + 'tile': xector_tile, + + # Cell/Grid helpers (built on primitives) + 'cell-indices': xector_cell_indices, + 'cell-row': xector_cell_row, + 'cell-col': xector_cell_col, + 'local-x': xector_local_x, + 'local-y': xector_local_y, + 'local-x-norm': xector_local_x_norm, + 'local-y-norm': xector_local_y_norm, + 'pool-frame': xector_pool_frame, + 'num-cells': xector_num_cells, + + # Scan (prefix) operations - cumulative reductions + 'scan+': xector_scan_add, + 'scan*': xector_scan_mul, + 'scan-max': xector_scan_max, + 'scan-min': xector_scan_min, + 'scan-add': xector_scan_add, + 'scan-mul': xector_scan_mul, + + # Outer product - Cartesian operations + 'outer': xector_outer, + 'outer+': xector_outer_add, + 'outer*': xector_outer_mul, + 'outer-add': xector_outer_add, + 'outer-mul': xector_outer_mul, + 'outer-max': xector_outer_max, + 'outer-min': xector_outer_min, + + # Reduce with axis - dimensional reductions + 'reduce-axis': xector_reduce_axis, + 'sum-axis': xector_sum_axis, + 'mean-axis': xector_mean_axis, + 'max-axis': xector_max_axis, + 'min-axis': xector_min_axis, + + # Windowed operations - sliding window computations + 'window': xector_window, + 'window-sum': xector_window_sum, + 'window-mean': xector_window_mean, + 'window-max': xector_window_max, + 'window-min': xector_window_min, + + # Integral image - for fast box blur + 'integral-image': xector_integral_image, + 'box-blur-fast': xector_box_blur_fast, +} diff --git a/l1/sexp_effects/primitives.py b/l1/sexp_effects/primitives.py new file mode 100644 index 0000000..9a50356 --- /dev/null +++ b/l1/sexp_effects/primitives.py @@ -0,0 +1,3075 @@ +""" +Safe Primitives for S-Expression Effects + +These are the building blocks that user-defined effects can use. +All primitives operate only on image data - no filesystem, network, etc. +""" + +import numpy as np +import cv2 +from typing import Any, Callable, Dict, List, Tuple, Optional +from dataclasses import dataclass +import math + + +@dataclass +class ZoneContext: + """Context for a single cell/zone in ASCII art grid.""" + row: int + col: int + row_norm: float # Normalized row position 0-1 + col_norm: float # Normalized col position 0-1 + luminance: float # Cell luminance 0-1 + saturation: float # Cell saturation 0-1 + hue: float # Cell hue 0-360 + r: float # Red component 0-1 + g: float # Green component 0-1 + b: float # Blue component 0-1 + + +class DeterministicRNG: + """Seeded RNG for reproducible effects.""" + + def __init__(self, seed: int = 42): + self._rng = np.random.RandomState(seed) + + def random(self, low: float = 0, high: float = 1) -> float: + return self._rng.uniform(low, high) + + def randint(self, low: int, high: int) -> int: + return self._rng.randint(low, high + 1) + + def gaussian(self, mean: float = 0, std: float = 1) -> float: + return self._rng.normal(mean, std) + + +# Global RNG instance (reset per frame with seed param) +_rng = DeterministicRNG() + + +def reset_rng(seed: int): + """Reset the global RNG with a new seed.""" + global _rng + _rng = DeterministicRNG(seed) + + +# ============================================================================= +# Color Names (FFmpeg/X11 compatible) +# ============================================================================= + +NAMED_COLORS = { + # Basic colors + "black": (0, 0, 0), + "white": (255, 255, 255), + "red": (255, 0, 0), + "green": (0, 128, 0), + "blue": (0, 0, 255), + "yellow": (255, 255, 0), + "cyan": (0, 255, 255), + "magenta": (255, 0, 255), + + # Grays + "gray": (128, 128, 128), + "grey": (128, 128, 128), + "darkgray": (169, 169, 169), + "darkgrey": (169, 169, 169), + "lightgray": (211, 211, 211), + "lightgrey": (211, 211, 211), + "dimgray": (105, 105, 105), + "dimgrey": (105, 105, 105), + "silver": (192, 192, 192), + + # Reds + "darkred": (139, 0, 0), + "firebrick": (178, 34, 34), + "crimson": (220, 20, 60), + "indianred": (205, 92, 92), + "lightcoral": (240, 128, 128), + "salmon": (250, 128, 114), + "darksalmon": (233, 150, 122), + "lightsalmon": (255, 160, 122), + "tomato": (255, 99, 71), + "orangered": (255, 69, 0), + "coral": (255, 127, 80), + + # Oranges + "orange": (255, 165, 0), + "darkorange": (255, 140, 0), + + # Yellows + "gold": (255, 215, 0), + "lightyellow": (255, 255, 224), + "lemonchiffon": (255, 250, 205), + "papayawhip": (255, 239, 213), + "moccasin": (255, 228, 181), + "peachpuff": (255, 218, 185), + "palegoldenrod": (238, 232, 170), + "khaki": (240, 230, 140), + "darkkhaki": (189, 183, 107), + + # Greens + "lime": (0, 255, 0), + "limegreen": (50, 205, 50), + "forestgreen": (34, 139, 34), + "darkgreen": (0, 100, 0), + "seagreen": (46, 139, 87), + "mediumseagreen": (60, 179, 113), + "springgreen": (0, 255, 127), + "mediumspringgreen": (0, 250, 154), + "lightgreen": (144, 238, 144), + "palegreen": (152, 251, 152), + "darkseagreen": (143, 188, 143), + "greenyellow": (173, 255, 47), + "chartreuse": (127, 255, 0), + "lawngreen": (124, 252, 0), + "olivedrab": (107, 142, 35), + "olive": (128, 128, 0), + "darkolivegreen": (85, 107, 47), + "yellowgreen": (154, 205, 50), + + # Cyans/Teals + "aqua": (0, 255, 255), + "teal": (0, 128, 128), + "darkcyan": (0, 139, 139), + "lightcyan": (224, 255, 255), + "aquamarine": (127, 255, 212), + "mediumaquamarine": (102, 205, 170), + "paleturquoise": (175, 238, 238), + "turquoise": (64, 224, 208), + "mediumturquoise": (72, 209, 204), + "darkturquoise": (0, 206, 209), + "cadetblue": (95, 158, 160), + + # Blues + "navy": (0, 0, 128), + "darkblue": (0, 0, 139), + "mediumblue": (0, 0, 205), + "royalblue": (65, 105, 225), + "cornflowerblue": (100, 149, 237), + "steelblue": (70, 130, 180), + "dodgerblue": (30, 144, 255), + "deepskyblue": (0, 191, 255), + "lightskyblue": (135, 206, 250), + "skyblue": (135, 206, 235), + "lightsteelblue": (176, 196, 222), + "lightblue": (173, 216, 230), + "powderblue": (176, 224, 230), + "slateblue": (106, 90, 205), + "mediumslateblue": (123, 104, 238), + "darkslateblue": (72, 61, 139), + "midnightblue": (25, 25, 112), + + # Purples/Violets + "purple": (128, 0, 128), + "darkmagenta": (139, 0, 139), + "darkviolet": (148, 0, 211), + "blueviolet": (138, 43, 226), + "darkorchid": (153, 50, 204), + "mediumorchid": (186, 85, 211), + "orchid": (218, 112, 214), + "violet": (238, 130, 238), + "plum": (221, 160, 221), + "thistle": (216, 191, 216), + "lavender": (230, 230, 250), + "indigo": (75, 0, 130), + "mediumpurple": (147, 112, 219), + "fuchsia": (255, 0, 255), + "hotpink": (255, 105, 180), + "deeppink": (255, 20, 147), + "mediumvioletred": (199, 21, 133), + "palevioletred": (219, 112, 147), + + # Pinks + "pink": (255, 192, 203), + "lightpink": (255, 182, 193), + "mistyrose": (255, 228, 225), + + # Browns + "brown": (165, 42, 42), + "maroon": (128, 0, 0), + "saddlebrown": (139, 69, 19), + "sienna": (160, 82, 45), + "chocolate": (210, 105, 30), + "peru": (205, 133, 63), + "sandybrown": (244, 164, 96), + "burlywood": (222, 184, 135), + "tan": (210, 180, 140), + "rosybrown": (188, 143, 143), + "goldenrod": (218, 165, 32), + "darkgoldenrod": (184, 134, 11), + + # Whites + "snow": (255, 250, 250), + "honeydew": (240, 255, 240), + "mintcream": (245, 255, 250), + "azure": (240, 255, 255), + "aliceblue": (240, 248, 255), + "ghostwhite": (248, 248, 255), + "whitesmoke": (245, 245, 245), + "seashell": (255, 245, 238), + "beige": (245, 245, 220), + "oldlace": (253, 245, 230), + "floralwhite": (255, 250, 240), + "ivory": (255, 255, 240), + "antiquewhite": (250, 235, 215), + "linen": (250, 240, 230), + "lavenderblush": (255, 240, 245), + "wheat": (245, 222, 179), + "cornsilk": (255, 248, 220), + "blanchedalmond": (255, 235, 205), + "bisque": (255, 228, 196), + "navajowhite": (255, 222, 173), + + # Special + "transparent": (0, 0, 0), # Note: no alpha support, just black +} + + +def parse_color(color_spec: str) -> Optional[Tuple[int, int, int]]: + """ + Parse a color specification into RGB tuple. + + Supports: + - Named colors: "red", "green", "lime", "navy", etc. + - Hex colors: "#FF0000", "#f00", "0xFF0000" + - Special modes: "color", "mono", "invert" return None (handled separately) + + Returns: + RGB tuple (r, g, b) or None for special modes + """ + if color_spec is None: + return None + + color_spec = str(color_spec).strip().lower() + + # Special modes handled elsewhere + if color_spec in ("color", "mono", "invert"): + return None + + # Check named colors + if color_spec in NAMED_COLORS: + return NAMED_COLORS[color_spec] + + # Handle hex colors + hex_str = None + if color_spec.startswith("#"): + hex_str = color_spec[1:] + elif color_spec.startswith("0x"): + hex_str = color_spec[2:] + elif all(c in "0123456789abcdef" for c in color_spec) and len(color_spec) in (3, 6): + hex_str = color_spec + + if hex_str: + try: + if len(hex_str) == 3: + # Short form: #RGB -> #RRGGBB + r = int(hex_str[0] * 2, 16) + g = int(hex_str[1] * 2, 16) + b = int(hex_str[2] * 2, 16) + return (r, g, b) + elif len(hex_str) == 6: + r = int(hex_str[0:2], 16) + g = int(hex_str[2:4], 16) + b = int(hex_str[4:6], 16) + return (r, g, b) + except ValueError: + pass + + # Unknown color - default to None (will use original colors) + return None + + +# ============================================================================= +# Image Primitives +# ============================================================================= + +def prim_width(img: np.ndarray) -> int: + """Get image width.""" + return img.shape[1] + + +def prim_height(img: np.ndarray) -> int: + """Get image height.""" + return img.shape[0] + + +def prim_make_image(w: int, h: int, color: List[int]) -> np.ndarray: + """Create a new image filled with color.""" + img = np.zeros((int(h), int(w), 3), dtype=np.uint8) + if color: + img[:, :] = color[:3] + return img + + +def prim_copy(img: np.ndarray) -> np.ndarray: + """Copy an image.""" + return img.copy() + + +def prim_pixel(img: np.ndarray, x: int, y: int) -> List[int]: + """Get pixel at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + return list(img[y, x]) + return [0, 0, 0] + + +def prim_set_pixel(img: np.ndarray, x: int, y: int, color: List[int]) -> np.ndarray: + """Set pixel at (x, y). Returns modified image.""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + img[y, x] = color[:3] + return img + + +def prim_sample(img: np.ndarray, x: float, y: float) -> List[float]: + """Bilinear sample at float coordinates.""" + h, w = img.shape[:2] + x = np.clip(x, 0, w - 1) + y = np.clip(y, 0, h - 1) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + c = (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + return list(c) + + +def prim_channel(img: np.ndarray, c: int) -> np.ndarray: + """Extract a single channel as 2D array.""" + return img[:, :, int(c)].copy() + + +def prim_merge_channels(r: np.ndarray, g: np.ndarray, b: np.ndarray) -> np.ndarray: + """Merge three channels into RGB image.""" + return np.stack([r, g, b], axis=-1).astype(np.uint8) + + +def prim_resize(img: np.ndarray, w: int, h: int, mode: str = "linear") -> np.ndarray: + """Resize image. Mode: linear, nearest, area.""" + w, h = int(w), int(h) + if w < 1 or h < 1: + return img + interp = { + "linear": cv2.INTER_LINEAR, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + }.get(mode, cv2.INTER_LINEAR) + return cv2.resize(img, (w, h), interpolation=interp) + + +def prim_crop(img: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray: + """Crop a region from image.""" + ih, iw = img.shape[:2] + x, y, w, h = int(x), int(y), int(w), int(h) + x = max(0, min(x, iw)) + y = max(0, min(y, ih)) + w = max(0, min(w, iw - x)) + h = max(0, min(h, ih - y)) + return img[y:y + h, x:x + w].copy() + + +def prim_paste(dst: np.ndarray, src: np.ndarray, x: int, y: int) -> np.ndarray: + """Paste src onto dst at position (x, y).""" + dh, dw = dst.shape[:2] + sh, sw = src.shape[:2] + x, y = int(x), int(y) + + # Calculate valid regions + sx1 = max(0, -x) + sy1 = max(0, -y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + dx1 = max(0, x) + dy1 = max(0, y) + dx2 = dx1 + (sx2 - sx1) + dy2 = dy1 + (sy2 - sy1) + + if dx2 > dx1 and dy2 > dy1: + dst[dy1:dy2, dx1:dx2] = src[sy1:sy2, sx1:sx2] + + return dst + + +# ============================================================================= +# Color Primitives +# ============================================================================= + +def prim_rgb(r: float, g: float, b: float) -> List[int]: + """Create RGB color.""" + return [int(np.clip(r, 0, 255)), + int(np.clip(g, 0, 255)), + int(np.clip(b, 0, 255))] + + +def prim_red(c: List[int]) -> int: + return c[0] if c else 0 + + +def prim_green(c: List[int]) -> int: + return c[1] if len(c) > 1 else 0 + + +def prim_blue(c: List[int]) -> int: + return c[2] if len(c) > 2 else 0 + + +def prim_luminance(c: List[int]) -> float: + """Calculate luminance (grayscale value).""" + if not c: + return 0 + return 0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2] + + +def prim_rgb_to_hsv(c: List[int]) -> List[float]: + """Convert RGB to HSV.""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + mx, mn = max(r, g, b), min(r, g, b) + diff = mx - mn + + if diff == 0: + h = 0 + elif mx == r: + h = (60 * ((g - b) / diff) + 360) % 360 + elif mx == g: + h = (60 * ((b - r) / diff) + 120) % 360 + else: + h = (60 * ((r - g) / diff) + 240) % 360 + + s = 0 if mx == 0 else diff / mx + v = mx + + return [h, s * 100, v * 100] + + +def prim_hsv_to_rgb(hsv: List[float]) -> List[int]: + """Convert HSV to RGB.""" + h, s, v = hsv[0], hsv[1] / 100, hsv[2] / 100 + c = v * s + x = c * (1 - abs((h / 60) % 2 - 1)) + m = v - c + + if h < 60: + r, g, b = c, x, 0 + elif h < 120: + r, g, b = x, c, 0 + elif h < 180: + r, g, b = 0, c, x + elif h < 240: + r, g, b = 0, x, c + elif h < 300: + r, g, b = x, 0, c + else: + r, g, b = c, 0, x + + return [int((r + m) * 255), int((g + m) * 255), int((b + m) * 255)] + + +def prim_blend_color(c1: List[int], c2: List[int], alpha: float) -> List[int]: + """Blend two colors.""" + alpha = np.clip(alpha, 0, 1) + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img: np.ndarray) -> List[int]: + """Get average color of image/region.""" + return [int(x) for x in img.mean(axis=(0, 1))] + + +# ============================================================================= +# Image Operations (Bulk) +# ============================================================================= + +def prim_map_pixels(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each pixel: fn(x, y, [r,g,b]) -> [r,g,b].""" + result = img.copy() + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + new_color = fn(x, y, color) + if new_color is not None: + result[y, x] = new_color[:3] + return result + + +def prim_map_rows(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each row: fn(y, row) -> row.""" + result = img.copy() + h = img.shape[0] + for y in range(h): + row = img[y].copy() + new_row = fn(y, row) + if new_row is not None: + result[y] = new_row + return result + + +def prim_for_grid(img: np.ndarray, cell_size: int, fn: Callable) -> np.ndarray: + """Iterate over grid cells: fn(gx, gy, cell_img) for side effects.""" + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + for gy in range(rows): + for gx in range(cols): + y, x = gy * cell_size, gx * cell_size + cell = img[y:y + cell_size, x:x + cell_size] + fn(gx, gy, cell) + + return img + + +def prim_fold_pixels(img: np.ndarray, init: Any, fn: Callable) -> Any: + """Fold over pixels: fn(acc, x, y, color) -> acc.""" + acc = init + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + acc = fn(acc, x, y, color) + return acc + + +# ============================================================================= +# Convolution / Filters +# ============================================================================= + +def prim_convolve(img: np.ndarray, kernel: List[List[float]]) -> np.ndarray: + """Apply convolution kernel.""" + k = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, k) + + +def prim_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Gaussian blur.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Box blur (faster than Gaussian).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_edges(img: np.ndarray, low: int = 50, high: int = 150) -> np.ndarray: + """Canny edge detection, returns grayscale edges.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, int(low), int(high)) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img: np.ndarray) -> np.ndarray: + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.float32) + sx = cv2.Sobel(gray, cv2.CV_32F, 1, 0) + sy = cv2.Sobel(gray, cv2.CV_32F, 0, 1) + magnitude = np.sqrt(sx ** 2 + sy ** 2) + magnitude = np.clip(magnitude, 0, 255).astype(np.uint8) + return cv2.cvtColor(magnitude, cv2.COLOR_GRAY2RGB) + + +def prim_dilate(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological dilation.""" + kernel = np.ones((size, size), np.uint8) + return cv2.dilate(img, kernel, iterations=1) + + +def prim_erode(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological erosion.""" + kernel = np.ones((size, size), np.uint8) + return cv2.erode(img, kernel, iterations=1) + + +# ============================================================================= +# Geometric Transforms +# ============================================================================= + +def prim_translate(img: np.ndarray, dx: float, dy: float) -> np.ndarray: + """Translate image.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_rotate(img: np.ndarray, angle: float, cx: float = None, cy: float = None) -> np.ndarray: + """Rotate image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_scale(img: np.ndarray, sx: float, sy: float, cx: float = None, cy: float = None) -> np.ndarray: + """Scale image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_flip_h(img: np.ndarray) -> np.ndarray: + """Flip horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img: np.ndarray) -> np.ndarray: + """Flip vertically.""" + return cv2.flip(img, 0) + + +def prim_remap(img: np.ndarray, map_x: np.ndarray, map_y: np.ndarray) -> np.ndarray: + """Remap using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), map_y.astype(np.float32), + cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + +def prim_make_coords(w: int, h: int) -> Tuple[np.ndarray, np.ndarray]: + """Create coordinate grid (map_x, map_y).""" + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + return map_x, map_y + + +# ============================================================================= +# Blending +# ============================================================================= + +def prim_blend_images(a: np.ndarray, b: np.ndarray, alpha: float) -> np.ndarray: + """Blend two images. Auto-resizes b to match a if sizes differ.""" + alpha = np.clip(alpha, 0, 1) + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a: np.ndarray, b: np.ndarray, mode: str) -> np.ndarray: + """Blend with various modes: add, multiply, screen, overlay, difference. + Auto-resizes b to match a if sizes differ.""" + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + af = a.astype(float) / 255 + bf = b.astype(float) / 255 + + if mode == "add": + result = af + bf + elif mode == "multiply": + result = af * bf + elif mode == "screen": + result = 1 - (1 - af) * (1 - bf) + elif mode == "overlay": + mask = af < 0.5 + result = np.where(mask, 2 * af * bf, 1 - 2 * (1 - af) * (1 - bf)) + elif mode == "difference": + result = np.abs(af - bf) + elif mode == "lighten": + result = np.maximum(af, bf) + elif mode == "darken": + result = np.minimum(af, bf) + else: + result = af + + return (np.clip(result, 0, 1) * 255).astype(np.uint8) + + +def prim_mask(img: np.ndarray, mask_img: np.ndarray) -> np.ndarray: + """Apply grayscale mask to image.""" + if len(mask_img.shape) == 3: + mask = cv2.cvtColor(mask_img, cv2.COLOR_RGB2GRAY) + else: + mask = mask_img + mask_f = mask.astype(float) / 255 + result = img.astype(float) * mask_f[:, :, np.newaxis] + return result.astype(np.uint8) + + +# ============================================================================= +# Drawing +# ============================================================================= + +# Simple font (5x7 bitmap characters) +FONT_5X7 = { + ' ': [0, 0, 0, 0, 0, 0, 0], + '.': [0, 0, 0, 0, 0, 0, 4], + ':': [0, 0, 4, 0, 4, 0, 0], + '-': [0, 0, 0, 14, 0, 0, 0], + '=': [0, 0, 14, 0, 14, 0, 0], + '+': [0, 4, 4, 31, 4, 4, 0], + '*': [0, 4, 21, 14, 21, 4, 0], + '#': [10, 31, 10, 10, 31, 10, 0], + '%': [19, 19, 4, 8, 25, 25, 0], + '@': [14, 17, 23, 21, 23, 16, 14], + '0': [14, 17, 19, 21, 25, 17, 14], + '1': [4, 12, 4, 4, 4, 4, 14], + '2': [14, 17, 1, 2, 4, 8, 31], + '3': [31, 2, 4, 2, 1, 17, 14], + '4': [2, 6, 10, 18, 31, 2, 2], + '5': [31, 16, 30, 1, 1, 17, 14], + '6': [6, 8, 16, 30, 17, 17, 14], + '7': [31, 1, 2, 4, 8, 8, 8], + '8': [14, 17, 17, 14, 17, 17, 14], + '9': [14, 17, 17, 15, 1, 2, 12], +} + +# Add uppercase letters +for i, c in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ'): + FONT_5X7[c] = [0] * 7 # Placeholder + + +def prim_draw_char(img: np.ndarray, char: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw a character at position.""" + # Use OpenCV's built-in font for simplicity + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, char, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_draw_text(img: np.ndarray, text: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw text at position.""" + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, text, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_fill_rect(img: np.ndarray, x: int, y: int, w: int, h: int, + color: List[int]) -> np.ndarray: + """Fill rectangle.""" + x, y, w, h = int(x), int(y), int(w), int(h) + img[y:y + h, x:x + w] = color[:3] + return img + + +def prim_draw_line(img: np.ndarray, x1: int, y1: int, x2: int, y2: int, + color: List[int], thickness: int = 1) -> np.ndarray: + """Draw line.""" + cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), tuple(color[:3]), int(thickness)) + return img + + +# ============================================================================= +# Math Primitives +# ============================================================================= + +def prim_sin(x: float) -> float: + return math.sin(x) + + +def prim_cos(x: float) -> float: + return math.cos(x) + + +def prim_tan(x: float) -> float: + return math.tan(x) + + +def prim_atan2(y, x): + if hasattr(y, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape) + return math.atan2(y, x) + + +def prim_sqrt(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.sqrt(np.maximum(0, x._data)), x._shape) + if isinstance(x, np.ndarray): + return np.sqrt(np.maximum(0, x)) + return math.sqrt(max(0, x)) + + +def prim_pow(x, y): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + y_data = y._data if hasattr(y, '_data') else y + return Xector(np.power(x._data, y_data), x._shape) + return math.pow(x, y) + + +def prim_abs(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.abs(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.abs(x) + return abs(x) + + +def prim_floor(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.floor(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.floor(x) + return int(math.floor(x)) + + +def prim_ceil(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.ceil(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.ceil(x) + return int(math.ceil(x)) + + +def prim_round(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.round(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.round(x) + return int(round(x)) + + +def prim_min(*args) -> float: + return min(args) + + +def prim_max(*args) -> float: + return max(args) + + +def prim_clamp(x: float, lo: float, hi: float) -> float: + return max(lo, min(hi, x)) + + +def prim_lerp(a: float, b: float, t: float) -> float: + """Linear interpolation.""" + return a + (b - a) * t + + +def prim_mod(a: float, b: float) -> float: + return a % b + + +def prim_random(lo: float = 0, hi: float = 1) -> float: + """Random number from global RNG.""" + return _rng.random(lo, hi) + + +def prim_randint(lo: int, hi: int) -> int: + """Random integer from global RNG.""" + return _rng.randint(lo, hi) + + +def prim_gaussian(mean: float = 0, std: float = 1) -> float: + """Gaussian random from global RNG.""" + return _rng.gaussian(mean, std) + + +def prim_assert(condition, message: str = "Assertion failed"): + """Assert that condition is true, raise error with message if false.""" + if not condition: + raise RuntimeError(f"Assertion error: {message}") + return True + + +# ============================================================================= +# Array/List Primitives +# ============================================================================= + +def prim_length(seq) -> int: + return len(seq) + + +def prim_nth(seq, i: int): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_rest(seq): + return seq[1:] if seq else [] + + +def prim_take(seq, n: int): + return seq[:int(n)] + + +def prim_drop(seq, n: int): + return seq[int(n):] + + +def prim_cons(x, seq): + return [x] + list(seq) + + +def prim_append(*seqs): + result = [] + for s in seqs: + result.extend(s) + return result + + +def prim_reverse(seq): + return list(reversed(seq)) + + +def prim_range(start: int, end: int, step: int = 1) -> List[int]: + return list(range(int(start), int(end), int(step))) + + +def prim_roll(arr: np.ndarray, shift: int, axis: int = 0) -> np.ndarray: + """Circular roll of array.""" + return np.roll(arr, int(shift), axis=int(axis)) + + +def prim_list(*args) -> list: + """Create a list.""" + return list(args) + + +# ============================================================================= +# Primitive Registry +# ============================================================================= + +def prim_add(*args): + return sum(args) + +def prim_sub(a, b=None): + if b is None: + return -a # Unary negation + return a - b + +def prim_mul(*args): + result = 1 + for x in args: + result *= x + return result + +def prim_div(a, b): + return a / b if b != 0 else 0 + +def prim_lt(a, b): + return a < b + +def prim_gt(a, b): + return a > b + +def prim_le(a, b): + return a <= b + +def prim_ge(a, b): + return a >= b + +def prim_eq(a, b): + # Handle None/nil comparisons with numpy arrays + if a is None: + return b is None + if b is None: + return a is None + if isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.array_equal(a, b) + return False # array vs non-array + return a == b + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# ============================================================================= +# Vectorized Bulk Operations (true primitives for composing effects) +# ============================================================================= + +def prim_color_matrix(img: np.ndarray, matrix: List[List[float]]) -> np.ndarray: + """Apply a 3x3 color transformation matrix to all pixels.""" + m = np.array(matrix, dtype=np.float32) + result = img.astype(np.float32) @ m.T + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_adjust(img: np.ndarray, brightness: float = 0, contrast: float = 1) -> np.ndarray: + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img: np.ndarray, amount: float) -> np.ndarray: + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img: np.ndarray) -> np.ndarray: + """Invert all pixel values.""" + return (255 - img).astype(np.uint8) + + +def prim_add_noise(img: np.ndarray, amount: float) -> np.ndarray: + """Add gaussian noise to image.""" + noise = _rng._rng.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img: np.ndarray, levels: int) -> np.ndarray: + """Reduce to N color levels per channel.""" + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_shift_hsv(img: np.ndarray, h: float = 0, s: float = 1, v: float = 1) -> np.ndarray: + """Shift HSV: h=degrees offset, s/v=multipliers.""" + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +# ============================================================================= +# Array Math Primitives (vectorized operations on coordinate arrays) +# ============================================================================= + +def prim_arr_add(a: np.ndarray, b) -> np.ndarray: + """Element-wise addition. b can be array or scalar.""" + return (np.asarray(a) + np.asarray(b)).astype(np.float32) + + +def prim_arr_sub(a: np.ndarray, b) -> np.ndarray: + """Element-wise subtraction. b can be array or scalar.""" + return (np.asarray(a) - np.asarray(b)).astype(np.float32) + + +def prim_arr_mul(a: np.ndarray, b) -> np.ndarray: + """Element-wise multiplication. b can be array or scalar.""" + return (np.asarray(a) * np.asarray(b)).astype(np.float32) + + +def prim_arr_div(a: np.ndarray, b) -> np.ndarray: + """Element-wise division. b can be array or scalar.""" + b = np.asarray(b) + # Avoid division by zero + with np.errstate(divide='ignore', invalid='ignore'): + result = np.asarray(a) / np.where(b == 0, 1e-10, b) + return result.astype(np.float32) + + +def prim_arr_mod(a: np.ndarray, b) -> np.ndarray: + """Element-wise modulo.""" + return (np.asarray(a) % np.asarray(b)).astype(np.float32) + + +def prim_arr_sin(a: np.ndarray) -> np.ndarray: + """Element-wise sine.""" + return np.sin(np.asarray(a)).astype(np.float32) + + +def prim_arr_cos(a: np.ndarray) -> np.ndarray: + """Element-wise cosine.""" + return np.cos(np.asarray(a)).astype(np.float32) + + +def prim_arr_tan(a: np.ndarray) -> np.ndarray: + """Element-wise tangent.""" + return np.tan(np.asarray(a)).astype(np.float32) + + +def prim_arr_sqrt(a: np.ndarray) -> np.ndarray: + """Element-wise square root.""" + return np.sqrt(np.maximum(0, np.asarray(a))).astype(np.float32) + + +def prim_arr_pow(a: np.ndarray, b) -> np.ndarray: + """Element-wise power.""" + return np.power(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_abs(a: np.ndarray) -> np.ndarray: + """Element-wise absolute value.""" + return np.abs(np.asarray(a)).astype(np.float32) + + +def prim_arr_neg(a: np.ndarray) -> np.ndarray: + """Element-wise negation.""" + return (-np.asarray(a)).astype(np.float32) + + +def prim_arr_exp(a: np.ndarray) -> np.ndarray: + """Element-wise exponential.""" + return np.exp(np.asarray(a)).astype(np.float32) + + +def prim_arr_atan2(y: np.ndarray, x: np.ndarray) -> np.ndarray: + """Element-wise atan2(y, x).""" + return np.arctan2(np.asarray(y), np.asarray(x)).astype(np.float32) + + +def prim_arr_min(a: np.ndarray, b) -> np.ndarray: + """Element-wise minimum.""" + return np.minimum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_max(a: np.ndarray, b) -> np.ndarray: + """Element-wise maximum.""" + return np.maximum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_clip(a: np.ndarray, lo, hi) -> np.ndarray: + """Element-wise clip to range.""" + return np.clip(np.asarray(a), lo, hi).astype(np.float32) + + +def prim_arr_where(cond: np.ndarray, a, b) -> np.ndarray: + """Element-wise conditional: where cond is true, use a, else b.""" + return np.where(np.asarray(cond), np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_floor(a: np.ndarray) -> np.ndarray: + """Element-wise floor.""" + return np.floor(np.asarray(a)).astype(np.float32) + + +def prim_arr_lerp(a: np.ndarray, b: np.ndarray, t) -> np.ndarray: + """Element-wise linear interpolation.""" + a, b = np.asarray(a), np.asarray(b) + return (a + (b - a) * t).astype(np.float32) + + +# ============================================================================= +# Coordinate Transformation Primitives +# ============================================================================= + +def prim_polar_from_center(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create polar coordinates (r, theta) from image center. + + Usage: + (polar-from-center img) ; center of image + (polar-from-center img cx cy) ; custom center + (polar-from-center w h cx cy) ; explicit dimensions + + Returns: (r, theta) tuple of arrays + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + + return (r, theta) + + +def prim_cart_from_polar(r: np.ndarray, theta: np.ndarray, cx: float, cy: float) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert polar coordinates back to Cartesian. + + Args: + r: radius array + theta: angle array + cx, cy: center point + + Returns: (x, y) tuple of coordinate arrays + """ + x = (cx + r * np.cos(theta)).astype(np.float32) + y = (cy + r * np.sin(theta)).astype(np.float32) + return (x, y) + + +def prim_normalize_coords(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create normalized coordinates (-1 to 1) from center. + + Returns: (x_norm, y_norm) tuple of arrays where center is (0,0) + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + + return (x_norm, y_norm) + + +def prim_coords_x(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get x/first component from coordinate tuple.""" + return coords[0] + + +def prim_coords_y(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get y/second component from coordinate tuple.""" + return coords[1] + + +def prim_make_coords_centered(w: int, h: int, cx: float = None, cy: float = None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create coordinate grids centered at (cx, cy). + Like make-coords but returns coordinates relative to center. + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + return (x_coords - cx, y_coords - cy) + + +# ============================================================================= +# Specialized Distortion Primitives +# ============================================================================= + +def prim_wave_displace(w: int, h: int, axis: str, freq: float, amp: float, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create wave displacement maps. + + Args: + w, h: dimensions + axis: "x" (horizontal waves) or "y" (vertical waves) + freq: wave frequency (waves per image width/height) + amp: wave amplitude in pixels + phase: phase offset in radians + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + + if axis == "x" or axis == "horizontal": + # Horizontal waves: displace x based on y + wave = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + map_x = map_x + wave + elif axis == "y" or axis == "vertical": + # Vertical waves: displace y based on x + wave = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_y = map_y + wave + elif axis == "both": + wave_x = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + wave_y = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_x = map_x + wave_x + map_y = map_y + wave_y + + return (map_x, map_y) + + +def prim_swirl_displace(w: int, h: int, strength: float, radius: float = 0.5, + cx: float = None, cy: float = None, falloff: str = "quadratic") -> Tuple[np.ndarray, np.ndarray]: + """ + Create swirl displacement maps. + + Args: + w, h: dimensions + strength: swirl strength in radians + radius: effect radius as fraction of max dimension + cx, cy: center (defaults to image center) + falloff: "linear", "quadratic", or "gaussian" + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + radius_px = max(w, h) * radius + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + angle = np.arctan2(dy, dx) + + # Normalized distance for falloff + norm_dist = dist / radius_px + + # Calculate falloff factor + if falloff == "linear": + factor = np.maximum(0, 1 - norm_dist) + elif falloff == "gaussian": + factor = np.exp(-norm_dist**2 * 2) + else: # quadratic + factor = np.maximum(0, 1 - norm_dist**2) + + # Apply swirl rotation + new_angle = angle + strength * factor + + # Calculate new coordinates + map_x = (cx + dist * np.cos(new_angle)).astype(np.float32) + map_y = (cy + dist * np.sin(new_angle)).astype(np.float32) + + return (map_x, map_y) + + +def prim_fisheye_displace(w: int, h: int, strength: float, cx: float = None, cy: float = None, + zoom_correct: bool = True) -> Tuple[np.ndarray, np.ndarray]: + """ + Create fisheye/barrel distortion displacement maps. + + Args: + w, h: dimensions + strength: distortion strength (-1 to 1, positive=bulge, negative=pinch) + cx, cy: center (defaults to image center) + zoom_correct: auto-zoom to hide black edges + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize coordinates + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + r = np.sqrt(x_norm**2 + y_norm**2) + + # Apply barrel/pincushion distortion + if strength > 0: + r_distorted = r * (1 + strength * r**2) + else: + r_distorted = r / (1 - strength * r**2 + 0.001) + + # Calculate scale factor + with np.errstate(divide='ignore', invalid='ignore'): + scale = np.where(r > 0, r_distorted / r, 1) + + # Apply zoom correction + if zoom_correct and strength > 0: + zoom = 1 + strength * 0.5 + scale = scale / zoom + + # Calculate new coordinates + map_x = (x_norm * scale * (w / 2) + cx).astype(np.float32) + map_y = (y_norm * scale * (h / 2) + cy).astype(np.float32) + + return (map_x, map_y) + + +def prim_kaleidoscope_displace(w: int, h: int, segments: int, rotation: float = 0, + cx: float = None, cy: float = None, zoom: float = 1.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create kaleidoscope displacement maps. + + Args: + w, h: dimensions + segments: number of symmetry segments (3-16) + rotation: rotation angle in degrees + cx, cy: center (defaults to image center) + zoom: zoom factor + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + segments = max(3, min(int(segments), 16)) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + segment_angle = 2 * np.pi / segments + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Translate to center + x_centered = x_coords - cx + y_centered = y_coords - cy + + # Convert to polar + r = np.sqrt(x_centered**2 + y_centered**2) + theta = np.arctan2(y_centered, x_centered) + + # Apply rotation + theta = theta - np.deg2rad(rotation) + + # Fold angle into first segment and mirror + theta_normalized = theta % (2 * np.pi) + segment_idx = (theta_normalized / segment_angle).astype(int) + theta_in_segment = theta_normalized - segment_idx * segment_angle + + # Mirror alternating segments + mirror_mask = (segment_idx % 2) == 1 + theta_in_segment = np.where(mirror_mask, segment_angle - theta_in_segment, theta_in_segment) + + # Apply zoom + r = r / zoom + + # Convert back to Cartesian + map_x = (r * np.cos(theta_in_segment) + cx).astype(np.float32) + map_y = (r * np.sin(theta_in_segment) + cy).astype(np.float32) + + return (map_x, map_y) + + +# ============================================================================= +# Character/ASCII Art Primitives +# ============================================================================= + +# Character sets ordered by visual density (light to dark) +CHAR_ALPHABETS = { + "standard": " .`'^\",:;Il!i><~+_-?][}{1)(|/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$", + "blocks": " ░▒▓█", + "simple": " .-:=+*#%@", + "digits": " 0123456789", +} + +# Global atlas cache: keyed on (frozenset(chars), cell_size) -> +# (atlas_array, char_to_idx) where atlas_array is (N, cell_size, cell_size) uint8. +_char_atlas_cache = {} +_CHAR_ATLAS_CACHE_MAX = 32 + + +def _get_char_atlas(alphabet: str, cell_size: int) -> dict: + """Get or create character atlas for alphabet (legacy dict version).""" + atlas_arr, char_to_idx = _get_render_atlas(alphabet, cell_size) + # Build legacy dict from array + idx_to_char = {v: k for k, v in char_to_idx.items()} + return {idx_to_char[i]: atlas_arr[i] for i in range(len(atlas_arr))} + + +def _get_render_atlas(unique_chars_or_alphabet, cell_size: int): + """Get or build a stacked numpy atlas for vectorised rendering. + + Args: + unique_chars_or_alphabet: Either an alphabet name (str looked up in + CHAR_ALPHABETS), a literal character string, or a set/frozenset + of characters. + cell_size: Pixel size of each cell. + + Returns: + (atlas_array, char_to_idx) where + atlas_array: (num_chars, cell_size, cell_size) uint8 masks + char_to_idx: dict mapping character -> index in atlas_array + """ + if isinstance(unique_chars_or_alphabet, (set, frozenset)): + chars_tuple = tuple(sorted(unique_chars_or_alphabet)) + else: + resolved = CHAR_ALPHABETS.get(unique_chars_or_alphabet, unique_chars_or_alphabet) + chars_tuple = tuple(resolved) + + cache_key = (chars_tuple, cell_size) + cached = _char_atlas_cache.get(cache_key) + if cached is not None: + return cached + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + n = len(chars_tuple) + atlas = np.zeros((n, cell_size, cell_size), dtype=np.uint8) + char_to_idx = {} + + for i, char in enumerate(chars_tuple): + char_to_idx[char] = i + if char and char != ' ': + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + cv2.putText(atlas[i], char, (text_x, text_y), + font, font_scale, 255, thickness, cv2.LINE_AA) + except Exception: + pass + + # Evict oldest entry if cache is full + if len(_char_atlas_cache) >= _CHAR_ATLAS_CACHE_MAX: + _char_atlas_cache.pop(next(iter(_char_atlas_cache))) + + _char_atlas_cache[cache_key] = (atlas, char_to_idx) + return atlas, char_to_idx + + +def prim_cell_sample(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Sample image into cell grid, returning average colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + + Args: + img: source image + cell_size: size of each cell in pixels + + Returns: (colors, luminances) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32)) + + # Crop to exact grid then block-average via cv2 area interpolation. + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + return (colors, luminances) + + +def cell_sample_extended(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray, List[List[ZoneContext]]]: + """ + Sample image into cell grid, returning colors, luminances, and full zone contexts. + + Args: + img: source image (RGB) + cell_size: size of each cell in pixels + + Returns: (colors, luminances, zone_contexts) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + - zone_contexts: 2D list of ZoneContext objects with full cell data + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32), + [[ZoneContext(0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]]) + + # Crop to grid + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + + # Reshape and average + reshaped = cropped.reshape(rows, cell_size, cols, cell_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(np.uint8) + + # Compute luminance (0-255) + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + # Normalize colors to 0-1 for HSV/saturation calculations + colors_float = colors.astype(np.float32) / 255.0 + + # Compute HSV values for each cell + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + diff = max_c - min_c + + # Saturation + saturation = np.where(max_c > 0, diff / max_c, 0) + + # Hue (0-360) + hue = np.zeros((rows, cols), dtype=np.float32) + # Avoid division by zero + mask = diff > 0 + r, g, b = colors_float[:, :, 0], colors_float[:, :, 1], colors_float[:, :, 2] + + # Red is max + red_max = mask & (max_c == r) + hue[red_max] = 60 * (((g[red_max] - b[red_max]) / diff[red_max]) % 6) + + # Green is max + green_max = mask & (max_c == g) + hue[green_max] = 60 * ((b[green_max] - r[green_max]) / diff[green_max] + 2) + + # Blue is max + blue_max = mask & (max_c == b) + hue[blue_max] = 60 * ((r[blue_max] - g[blue_max]) / diff[blue_max] + 4) + + # Ensure hue is in 0-360 range + hue = hue % 360 + + # Build zone contexts + zone_contexts = [] + for row in range(rows): + row_contexts = [] + for col in range(cols): + ctx = ZoneContext( + row=row, + col=col, + row_norm=row / max(1, rows - 1) if rows > 1 else 0.5, + col_norm=col / max(1, cols - 1) if cols > 1 else 0.5, + luminance=luminances[row, col] / 255.0, # Normalize to 0-1 + saturation=float(saturation[row, col]), + hue=float(hue[row, col]), + r=float(colors_float[row, col, 0]), + g=float(colors_float[row, col, 1]), + b=float(colors_float[row, col, 2]), + ) + row_contexts.append(ctx) + zone_contexts.append(row_contexts) + + return (colors, luminances, zone_contexts) + + +def prim_luminance_to_chars(luminances: np.ndarray, alphabet: str, contrast: float = 1.0) -> List[List[str]]: + """ + Map luminance values to characters from alphabet. + + Args: + luminances: (rows, cols) array of brightness values 0-255 + alphabet: character set name or literal string (light to dark) + contrast: contrast boost factor + + Returns: 2D list of single-character strings + """ + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + num_chars = len(chars) + + # Apply contrast + lum = luminances.astype(np.float32) + if contrast != 1.0: + lum = (lum - 128) * contrast + 128 + lum = np.clip(lum, 0, 255) + + # Map to indices + indices = ((lum / 255) * (num_chars - 1)).astype(np.int32) + indices = np.clip(indices, 0, num_chars - 1) + + # Vectorised conversion via numpy char array lookup + chars_arr = np.array(list(chars)) + char_grid = chars_arr[indices.ravel()].reshape(indices.shape) + + return char_grid.tolist() + + +def prim_render_char_grid(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + cell_size: int, color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False) -> np.ndarray: + """ + Render a grid of characters onto an image. + + Uses vectorised numpy operations instead of per-cell Python loops: + the character atlas is looked up via fancy indexing and the full + mask + colour image are assembled in bulk. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + cell_size: size of each cell + color_mode: "color" (original colors), "mono" (white), "invert", + or any color name/hex value ("green", "lime", "#00ff00") + background_color: background color name/hex ("black", "navy", "#001100") + invert_colors: if True, swap foreground and background colors + + Returns: rendered image + """ + # Parse color_mode - may be a named color or hex value + fg_color = parse_color(color_mode) + + # Parse background_color + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + # Handle invert_colors - swap fg and bg + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + + # --- Build atlas & index grid --- + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + atlas, char_to_idx = _get_render_atlas(unique_chars, cell_size) + + # Convert 2D char list to index array using ordinal lookup table + # (avoids per-cell Python dict lookup). + space_idx = char_to_idx.get(' ', 0) + max_ord = max(ord(ch) for ch in char_to_idx) + 1 + ord_lookup = np.full(max_ord, space_idx, dtype=np.int32) + for ch, idx in char_to_idx.items(): + if ch: + ord_lookup[ord(ch)] = idx + + flat = [ch for row in chars for ch in row] + ords = np.frombuffer(np.array(flat, dtype='U1'), dtype=np.uint32) + char_indices = ord_lookup[ords].reshape(rows, cols) + + # --- Vectorised mask assembly --- + # atlas[char_indices] -> (rows, cols, cell_size, cell_size) + # Transpose to (rows, cell_size, cols, cell_size) then reshape to full image. + all_masks = atlas[char_indices] + full_mask = all_masks.transpose(0, 2, 1, 3).reshape(h, w) + + # Expand per-cell colours to per-pixel (only when needed). + need_color_full = (color_mode in ("color", "invert") + or (fg_color is None and color_mode != "mono")) + + if need_color_full: + color_full = np.repeat( + np.repeat(colors[:rows, :cols], cell_size, axis=0), + cell_size, axis=1) + + # --- Vectorised colour composite --- + # Use element-wise multiply/np.where instead of boolean-indexed scatter + # for much better memory access patterns. + mask_u8 = (full_mask > 0).astype(np.uint8)[:, :, np.newaxis] + + if color_mode == "invert": + # Background is source colour; characters are black. + # result = color_full * (1 - mask) + result = color_full * (1 - mask_u8) + elif fg_color is not None: + # Fixed foreground colour on background. + fg = np.array(fg_color, dtype=np.uint8) + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, fg, bg_arr).astype(np.uint8) + elif color_mode == "mono": + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, np.uint8(255), bg_arr).astype(np.uint8) + else: + # "color" mode – each cell uses its source colour on bg. + if bg == [0, 0, 0]: + result = color_full * mask_u8 + else: + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, color_full, bg_arr).astype(np.uint8) + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_render_char_grid_fx(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + luminances: np.ndarray, cell_size: int, + color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False, + char_jitter: float = 0.0, + char_scale: float = 1.0, + char_rotation: float = 0.0, + char_hue_shift: float = 0.0, + jitter_source: str = "none", + scale_source: str = "none", + rotation_source: str = "none", + hue_source: str = "none") -> np.ndarray: + """ + Render a grid of characters with per-character effects. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + luminances: (rows, cols) array of luminance values (0-255) + cell_size: size of each cell + color_mode: "color", "mono", "invert", or any color name/hex + background_color: background color name/hex + invert_colors: if True, swap foreground and background colors + char_jitter: base jitter amount in pixels + char_scale: base scale factor (1.0 = normal) + char_rotation: base rotation in degrees + char_hue_shift: base hue shift in degrees (0-360) + jitter_source: source for jitter modulation ("none", "luminance", "position", "random") + scale_source: source for scale modulation + rotation_source: source for rotation modulation + hue_source: source for hue modulation + + Per-character effect sources: + "none" - use base value only + "luminance" - modulate by cell luminance (0-1) + "inv_luminance" - modulate by inverse luminance (dark = high) + "saturation" - modulate by cell color saturation + "position_x" - modulate by horizontal position (0-1) + "position_y" - modulate by vertical position (0-1) + "position_diag" - modulate by diagonal position + "random" - random per-cell value (deterministic from position) + "center_dist" - distance from center (0=center, 1=corner) + + Returns: rendered image + """ + # Parse colors + fg_color = parse_color(color_mode) + + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + result = np.full((h, w, 3), bg, dtype=np.uint8) + + # Normalize luminances to 0-1 + lum_normalized = luminances.astype(np.float32) / 255.0 + + # Compute saturation from colors + colors_float = colors.astype(np.float32) / 255.0 + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + saturation = np.where(max_c > 0, (max_c - min_c) / max_c, 0) + + # Helper to get modulation value for a cell + def get_mod_value(source: str, r: int, c: int) -> float: + if source == "none": + return 1.0 + elif source == "luminance": + return lum_normalized[r, c] + elif source == "inv_luminance": + return 1.0 - lum_normalized[r, c] + elif source == "saturation": + return saturation[r, c] + elif source == "position_x": + return c / max(1, cols - 1) if cols > 1 else 0.5 + elif source == "position_y": + return r / max(1, rows - 1) if rows > 1 else 0.5 + elif source == "position_diag": + px = c / max(1, cols - 1) if cols > 1 else 0.5 + py = r / max(1, rows - 1) if rows > 1 else 0.5 + return (px + py) / 2.0 + elif source == "random": + # Deterministic random based on position + seed = (r * 1000 + c) % 10000 + return ((seed * 9301 + 49297) % 233280) / 233280.0 + elif source == "center_dist": + cx, cy = (cols - 1) / 2.0, (rows - 1) / 2.0 + dx = (c - cx) / max(1, cx) if cx > 0 else 0 + dy = (r - cy) / max(1, cy) if cy > 0 else 0 + return min(1.0, math.sqrt(dx*dx + dy*dy)) + else: + return 1.0 + + # Build character atlas at base size + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, we need to render characters larger then transform + max_scale = max(1.0, char_scale * 1.5) # Allow headroom for scaling + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with effects + for r in range(rows): + for c in range(cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell modulation values + jitter_mod = get_mod_value(jitter_source, r, c) + scale_mod = get_mod_value(scale_source, r, c) + rot_mod = get_mod_value(rotation_source, r, c) + hue_mod = get_mod_value(hue_source, r, c) + + # Compute effective values + eff_jitter = char_jitter * jitter_mod + eff_scale = char_scale * (0.5 + 0.5 * scale_mod) if scale_source != "none" else char_scale + eff_rotation = char_rotation * (rot_mod * 2 - 1) # -1 to 1 range + eff_hue_shift = char_hue_shift * hue_mod + + # Apply transformations + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + # Fill cell with source color first + cy1 = max(0, r * cell_size) + cy2 = min(h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + # Convert to HSV, shift hue, convert back + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + # Cast to int to avoid uint8 overflow, then back to uint8 + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < h and 0 <= px < w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def _render_with_cell_effect( + frame: np.ndarray, + chars: List[List[str]], + colors: np.ndarray, + luminances: np.ndarray, + zone_contexts: List[List['ZoneContext']], + cell_size: int, + bg_color: tuple, + fg_color: tuple, + color_mode: str, + cell_effect, # Lambda or callable: (cell_image, zone_dict) -> cell_image + extra_params: dict, + interp, + env, + result: np.ndarray, +) -> np.ndarray: + """ + Render ASCII art using a cell_effect lambda for arbitrary per-cell transforms. + + Each character is rendered to a cell image, the cell_effect is called with + (cell_image, zone_dict), and the returned cell is composited into result. + + This allows arbitrary effects (rotate, blur, etc.) to be applied per-character. + """ + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + out_h, out_w = result.shape[:2] + + # Build character atlas (cell-sized colored characters on transparent bg) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + # Helper to render a single character cell + def render_char_cell(char: str, color: np.ndarray) -> np.ndarray: + """Render a character onto a cell-sized RGB image.""" + cell = np.full((cell_size, cell_size, 3), bg_color, dtype=np.uint8) + if not char or char == ' ': + return cell + + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + + # Render character in white on mask, then apply color + mask = np.zeros((cell_size, cell_size), dtype=np.uint8) + cv2.putText(mask, char, (text_x, text_y), font, font_scale, 255, thickness, cv2.LINE_AA) + + # Apply color where mask is set + for ch in range(3): + cell[:, :, ch] = np.where(mask > 0, color[ch], bg_color[ch]) + except: + pass + + return cell + + # Helper to evaluate cell_effect (handles artdag Lambda objects) + def eval_cell_effect(cell_img: np.ndarray, zone_dict: dict) -> np.ndarray: + """Call cell_effect with (cell_image, zone_dict), handle Lambda objects.""" + if callable(cell_effect): + return cell_effect(cell_img, zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(cell_effect, ArtdagLambda): + # Build env with closure values + eval_env = dict(cell_effect.closure) if cell_effect.closure else {} + # Bind lambda parameters + if len(cell_effect.params) >= 2: + eval_env[cell_effect.params[0]] = cell_img + eval_env[cell_effect.params[1]] = zone_dict + elif len(cell_effect.params) == 1: + # Single param gets zone_dict with cell as 'cell' key + zone_dict['cell'] = cell_img + eval_env[cell_effect.params[0]] = zone_dict + + # Add primitives to eval env + eval_env.update(PRIMITIVES) + + # Add effect runner - allows calling any loaded sexp effect on a cell + # Usage: (apply-effect "effect_name" cell {"param" value ...}) + # Or: (apply-effect "effect_name" cell) for defaults + def apply_effect_fn(effect_name, frame, params=None): + """Run a loaded sexp effect on a frame (cell).""" + if interp and hasattr(interp, 'run_effect'): + if params is None: + params = {} + result, _ = interp.run_effect(effect_name, frame, params, {}) + return result + return frame + eval_env['apply-effect'] = apply_effect_fn + + # Also inject loaded effects directly as callable functions + # These wrappers take positional args in common order for each effect + # Usage: (blur cell 5) or (rotate cell 45) etc. + if interp and hasattr(interp, 'effects'): + for effect_name in interp.effects: + # Create a wrapper that calls run_effect with positional-to-named mapping + def make_effect_fn(name): + def effect_fn(frame, *args): + # Map common positional args to named params + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 1: + params['offset_x'] = args[0] + if len(args) >= 2: + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['block_size'] = args[0] + elif name == 'wave' and len(args) >= 1: + params['amplitude'] = args[0] + if len(args) >= 2: + params['frequency'] = args[1] + elif name == 'noise' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'posterize' and len(args) >= 1: + params['levels'] = args[0] + elif name == 'threshold' and len(args) >= 1: + params['level'] = args[0] + elif name == 'sharpen' and len(args) >= 1: + params['amount'] = args[0] + elif len(args) == 1 and isinstance(args[0], dict): + # Accept dict as single arg + params = args[0] + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + eval_env[effect_name] = make_effect_fn(effect_name) + + result = artdag_evaluate(cell_effect.body, eval_env) + if isinstance(result, np.ndarray): + return result + return cell_img + except ImportError: + pass + + # Fallback: return cell unchanged + return cell_img + + # Render each cell + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + zone = zone_contexts[r][c] + + # Determine character color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + color = np.array([0, 0, 0], dtype=np.uint8) + else: + color = colors[r, c].copy() + + # Render character to cell image + cell_img = render_char_cell(char, color) + + # Build zone dict + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + 'color': color.tolist(), + 'cell_size': cell_size, + } + # Add extra params (energy, rotation_scale, etc.) + if extra_params: + zone_dict.update(extra_params) + + # Call cell_effect + modified_cell = eval_cell_effect(cell_img, zone_dict) + + # Ensure result is valid + if modified_cell is None or not isinstance(modified_cell, np.ndarray): + modified_cell = cell_img + if modified_cell.shape[:2] != (cell_size, cell_size): + # Resize if cell size changed + modified_cell = cv2.resize(modified_cell, (cell_size, cell_size)) + if len(modified_cell.shape) == 2: + # Convert grayscale to RGB + modified_cell = cv2.cvtColor(modified_cell, cv2.COLOR_GRAY2RGB) + + # Composite into result + y1 = r * cell_size + x1 = c * cell_size + y2 = min(y1 + cell_size, out_h) + x2 = min(x1 + cell_size, out_w) + ch = y2 - y1 + cw = x2 - x1 + result[y1:y2, x1:x2] = modified_cell[:ch, :cw] + + # Resize to match original frame if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + bg = list(bg_color) + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int, + char_size_override: int, # If set, overrides cols-based calculation + alphabet: str, + color_mode: str, + background: str, + contrast: float, + char_hue_expr, # Expression, literal, or None + char_sat_expr, # Expression, literal, or None + char_bright_expr, # Expression, literal, or None + char_scale_expr, # Expression, literal, or None + char_rotation_expr, # Expression, literal, or None + char_jitter_expr, # Expression, literal, or None + interp, # Interpreter for expression evaluation + env, # Environment with bound values + extra_params=None, # Extra params to include in zone dict for lambdas + cell_effect=None, # Lambda (cell_image, zone_dict) -> cell_image for arbitrary cell effects +) -> np.ndarray: + """ + Render ASCII art with per-zone expression-driven transforms. + + Args: + frame: Source image (H, W, 3) RGB uint8 + cols: Number of character columns + char_size_override: If set, use this cell size instead of cols-based + alphabet: Character set name or literal string + color_mode: "color", "mono", "invert", or color name/hex + background: Background color name or hex + contrast: Contrast boost for character selection + char_hue_expr: Expression for hue shift (evaluated per zone) + char_sat_expr: Expression for saturation adjustment (evaluated per zone) + char_bright_expr: Expression for brightness adjustment (evaluated per zone) + char_scale_expr: Expression for scale factor (evaluated per zone) + char_rotation_expr: Expression for rotation degrees (evaluated per zone) + char_jitter_expr: Expression for position jitter (evaluated per zone) + interp: Interpreter instance for expression evaluation + env: Environment with bound variables + cell_effect: Optional lambda that receives (cell_image, zone_dict) and returns + a modified cell_image. When provided, each character is rendered + to a cell image, passed to this lambda, and the result composited. + This allows arbitrary effects to be applied per-character. + + Zone variables available in expressions: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: Rendered image + """ + h, w = frame.shape[:2] + # Use char_size if provided, otherwise calculate from cols + if char_size_override is not None: + cell_size = max(4, int(char_size_override)) + else: + cell_size = max(4, w // cols) + + # Get zone data using extended sampling + colors, luminances, zone_contexts = cell_sample_extended(frame, cell_size) + + # Convert luminances to characters + chars = prim_luminance_to_chars(luminances, alphabet, contrast) + + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + + # Parse colors + fg_color = parse_color(color_mode) + if isinstance(background, (list, tuple)): + bg_color = tuple(int(c) for c in background[:3]) + else: + bg_color = parse_color(background) + if bg_color is None: + bg_color = (0, 0, 0) + + # Arrays for per-zone transform values + hue_shifts = np.zeros((grid_rows, grid_cols), dtype=np.float32) + saturations = np.ones((grid_rows, grid_cols), dtype=np.float32) + brightness = np.ones((grid_rows, grid_cols), dtype=np.float32) + scales = np.ones((grid_rows, grid_cols), dtype=np.float32) + rotations = np.zeros((grid_rows, grid_cols), dtype=np.float32) + jitters = np.zeros((grid_rows, grid_cols), dtype=np.float32) + + # Helper to evaluate expression or return literal value + def eval_expr(expr, zone, char): + if expr is None: + return None + if isinstance(expr, (int, float)): + return expr + + # Build zone dict for lambda calls + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + } + # Add extra params (energy, rotation_scale, etc.) for lambdas to access + if extra_params: + zone_dict.update(extra_params) + + # Check if it's a Python callable + if callable(expr): + return expr(zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(expr, ArtdagLambda): + # Build env with zone dict and any closure values + eval_env = dict(expr.closure) if expr.closure else {} + # Bind the lambda parameter to zone_dict + if expr.params: + eval_env[expr.params[0]] = zone_dict + return artdag_evaluate(expr.body, eval_env) + except ImportError: + pass + + # It's an expression - evaluate with zone context (sexp_effects style) + return interp.eval_with_zone(expr, env, zone) + + # Evaluate expressions for each zone + for r in range(grid_rows): + for c in range(grid_cols): + zone = zone_contexts[r][c] + char = chars[r][c] + + val = eval_expr(char_hue_expr, zone, char) + if val is not None: + hue_shifts[r, c] = float(val) + + val = eval_expr(char_sat_expr, zone, char) + if val is not None: + saturations[r, c] = float(val) + + val = eval_expr(char_bright_expr, zone, char) + if val is not None: + brightness[r, c] = float(val) + + val = eval_expr(char_scale_expr, zone, char) + if val is not None: + scales[r, c] = float(val) + + val = eval_expr(char_rotation_expr, zone, char) + if val is not None: + rotations[r, c] = float(val) + + val = eval_expr(char_jitter_expr, zone, char) + if val is not None: + jitters[r, c] = float(val) + + # Now render with computed transform arrays + out_h, out_w = grid_rows * cell_size, grid_cols * cell_size + bg = list(bg_color) + result = np.full((out_h, out_w, 3), bg, dtype=np.uint8) + + # If cell_effect is provided, use the cell-mapper rendering path + if cell_effect is not None: + return _render_with_cell_effect( + frame, chars, colors, luminances, zone_contexts, + cell_size, bg_color, fg_color, color_mode, + cell_effect, extra_params, interp, env, result + ) + + # Build character atlas + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, render characters larger then transform + max_scale = max(1.0, np.max(scales) * 1.5) + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with per-zone effects + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell values + eff_scale = scales[r, c] + eff_rotation = rotations[r, c] + eff_jitter = jitters[r, c] + eff_hue_shift = hue_shifts[r, c] + eff_brightness = brightness[r, c] + eff_saturation = saturations[r, c] + + # Apply transformations to character + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + cy1 = max(0, r * cell_size) + cy2 = min(out_h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(out_w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode - use source colors + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply saturation adjustment + if abs(eff_saturation - 1.0) > 0.01 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_sat = np.clip(int(color_hsv[0, 0, 1] * eff_saturation), 0, 255) + color_hsv[0, 0, 1] = np.uint8(new_sat) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply brightness adjustment + if abs(eff_brightness - 1.0) > 0.01: + color = np.clip(color.astype(np.float32) * eff_brightness, 0, 255).astype(np.uint8) + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < out_h and 0 <= px < out_w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_make_char_grid(rows: int, cols: int, fill_char: str = " ") -> List[List[str]]: + """Create a character grid filled with a character.""" + return [[fill_char for _ in range(cols)] for _ in range(rows)] + + +def prim_set_char(chars: List[List[str]], row: int, col: int, char: str) -> List[List[str]]: + """Set a character at position (returns modified copy).""" + result = [r[:] for r in chars] # shallow copy rows + if 0 <= row < len(result) and 0 <= col < len(result[0]): + result[row][col] = char + return result + + +def prim_get_char(chars: List[List[str]], row: int, col: int) -> str: + """Get character at position.""" + if 0 <= row < len(chars) and 0 <= col < len(chars[0]): + return chars[row][col] + return " " + + +def prim_char_grid_dimensions(chars: List[List[str]]) -> Tuple[int, int]: + """Get (rows, cols) of character grid.""" + if not chars: + return (0, 0) + return (len(chars), len(chars[0]) if chars[0] else 0) + + +def prim_alphabet_char(alphabet: str, index: int) -> str: + """Get character at index from alphabet (wraps around).""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + if not chars: + return " " + return chars[int(index) % len(chars)] + + +def prim_alphabet_length(alphabet: str) -> int: + """Get length of alphabet.""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + return len(chars) + + +def prim_map_char_grid(chars: List[List[str]], luminances: np.ndarray, fn: Callable) -> List[List[str]]: + """ + Map a function over character grid. + + fn receives (row, col, char, luminance) and returns new character. + This allows per-cell character selection based on position, brightness, etc. + + Example: + (map-char-grid chars luminances + (lambda (r c ch lum) + (if (> lum 128) + (alphabet-char "blocks" (floor (/ lum 50))) + ch))) + """ + if not chars or not chars[0]: + return chars + + rows = len(chars) + cols = len(chars[0]) + result = [] + + for r in range(rows): + row = [] + for c in range(cols): + ch = chars[r][c] + lum = float(luminances[r, c]) if r < luminances.shape[0] and c < luminances.shape[1] else 0 + new_ch = fn(r, c, ch, lum) + row.append(str(new_ch) if new_ch else " ") + result.append(row) + + return result + + +def prim_map_colors(colors: np.ndarray, fn: Callable) -> np.ndarray: + """ + Map a function over color grid. + + fn receives (row, col, color) and returns new [r, g, b]. + Color is a list [r, g, b]. + """ + if colors.size == 0: + return colors + + rows, cols = colors.shape[:2] + result = colors.copy() + + for r in range(rows): + for c in range(cols): + color = list(colors[r, c]) + new_color = fn(r, c, color) + if new_color is not None: + result[r, c] = new_color[:3] + + return result + + +# ============================================================================= +# Glitch Art Primitives +# ============================================================================= + +def prim_pixelsort(img: np.ndarray, sort_by: str = "lightness", + threshold_low: float = 50, threshold_high: float = 200, + angle: float = 0, reverse: bool = False) -> np.ndarray: + """ + Pixel sorting glitch effect. + + Args: + img: source image + sort_by: "lightness", "hue", "saturation", "red", "green", "blue" + threshold_low: pixels below this aren't sorted + threshold_high: pixels above this aren't sorted + angle: 0 = horizontal, 90 = vertical + reverse: reverse sort order + """ + h, w = img.shape[:2] + + # Rotate for vertical sorting + if 45 <= (angle % 180) <= 135: + frame = np.transpose(img, (1, 0, 2)) + h, w = frame.shape[:2] + rotated = True + else: + frame = img + rotated = False + + result = frame.copy() + + # Get sort values + if sort_by == "lightness": + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + elif sort_by == "hue": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 0].astype(np.float32) + elif sort_by == "saturation": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 1].astype(np.float32) + elif sort_by == "red": + sort_values = frame[:, :, 0].astype(np.float32) + elif sort_by == "green": + sort_values = frame[:, :, 1].astype(np.float32) + elif sort_by == "blue": + sort_values = frame[:, :, 2].astype(np.float32) + else: + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + + # Create mask + mask = (sort_values >= threshold_low) & (sort_values <= threshold_high) + + # Sort each row + for y in range(h): + row = result[y].copy() + row_mask = mask[y] + row_values = sort_values[y] + + # Find contiguous segments + segments = [] + start = None + for i, val in enumerate(row_mask): + if val and start is None: + start = i + elif not val and start is not None: + segments.append((start, i)) + start = None + if start is not None: + segments.append((start, len(row_mask))) + + # Sort each segment + for seg_start, seg_end in segments: + if seg_end - seg_start > 1: + segment_values = row_values[seg_start:seg_end] + sort_indices = np.argsort(segment_values) + if reverse: + sort_indices = sort_indices[::-1] + row[seg_start:seg_end] = row[seg_start:seg_end][sort_indices] + + result[y] = row + + # Rotate back + if rotated: + result = np.transpose(result, (1, 0, 2)) + + return np.ascontiguousarray(result) + + +def prim_datamosh(img: np.ndarray, prev_frame: np.ndarray, + block_size: int = 32, corruption: float = 0.3, + max_offset: int = 50, color_corrupt: bool = True) -> np.ndarray: + """ + Datamosh/glitch block corruption effect. + + Args: + img: current frame + prev_frame: previous frame (or None) + block_size: size of corruption blocks + corruption: probability 0-1 of corrupting each block + max_offset: maximum pixel shift + color_corrupt: also apply color channel shifts + """ + if corruption <= 0: + return img.copy() + + block_size = max(8, min(int(block_size), 128)) + h, w = img.shape[:2] + result = img.copy() + + for by in range(0, h, block_size): + for bx in range(0, w, block_size): + bh = min(block_size, h - by) + bw = min(block_size, w - bx) + + if _rng.random() < corruption: + corruption_type = _rng.randint(0, 3) + + if corruption_type == 0 and max_offset > 0: + # Shift + ox = _rng.randint(-max_offset, max_offset) + oy = _rng.randint(-max_offset, max_offset) + src_x = max(0, min(bx + ox, w - bw)) + src_y = max(0, min(by + oy, h - bh)) + result[by:by+bh, bx:bx+bw] = img[src_y:src_y+bh, src_x:src_x+bw] + + elif corruption_type == 1 and prev_frame is not None: + # Duplicate from previous frame + if prev_frame.shape == img.shape: + result[by:by+bh, bx:bx+bw] = prev_frame[by:by+bh, bx:bx+bw] + + elif corruption_type == 2 and color_corrupt: + # Color channel shift + block = result[by:by+bh, bx:bx+bw].copy() + shift = _rng.randint(1, 3) + channel = _rng.randint(0, 2) + block[:, :, channel] = np.roll(block[:, :, channel], shift, axis=0) + result[by:by+bh, bx:bx+bw] = block + + else: + # Swap with another block + other_bx = _rng.randint(0, max(0, w - bw)) + other_by = _rng.randint(0, max(0, h - bh)) + temp = result[by:by+bh, bx:bx+bw].copy() + result[by:by+bh, bx:bx+bw] = img[other_by:other_by+bh, other_bx:other_bx+bw] + result[other_by:other_by+bh, other_bx:other_bx+bw] = temp + + return result + + +def prim_ripple_displace(w: int, h: int, freq: float, amp: float, cx: float = None, cy: float = None, + decay: float = 0, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create radial ripple displacement maps. + + Args: + w, h: dimensions + freq: ripple frequency + amp: ripple amplitude in pixels + cx, cy: center + decay: how fast ripples decay with distance (0 = no decay) + phase: phase offset + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + + # Calculate ripple displacement (radial) + ripple = np.sin(2 * np.pi * freq * dist / max(w, h) + phase) * amp + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max(w, h)) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return (map_x, map_y) + + +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Image + 'width': prim_width, + 'height': prim_height, + 'make-image': prim_make_image, + 'copy': prim_copy, + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, + + # Color + 'rgb': prim_rgb, + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'blend-color': prim_blend_color, + 'average-color': prim_average_color, + + # Vectorized bulk operations + 'color-matrix': prim_color_matrix, + 'adjust': prim_adjust, + 'mix-gray': prim_mix_gray, + 'invert-img': prim_invert_img, + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'shift-hsv': prim_shift_hsv, + + # Bulk operations + 'map-pixels': prim_map_pixels, + 'map-rows': prim_map_rows, + 'for-grid': prim_for_grid, + 'fold-pixels': prim_fold_pixels, + + # Filters + 'convolve': prim_convolve, + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'edges': prim_edges, + 'sobel': prim_sobel, + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Geometry + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + 'mask': prim_mask, + + # Drawing + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + 'fill-rect': prim_fill_rect, + 'draw-line': prim_draw_line, + + # Math + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'atan2': prim_atan2, + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'mod': prim_mod, + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + 'assert': prim_assert, + 'pi': math.pi, + 'tau': math.tau, + + # Array + 'length': prim_length, + 'len': prim_length, # alias + 'nth': prim_nth, + 'first': prim_first, + 'rest': prim_rest, + 'take': prim_take, + 'drop': prim_drop, + 'cons': prim_cons, + 'append': prim_append, + 'reverse': prim_reverse, + 'range': prim_range, + 'roll': prim_roll, + 'list': prim_list, + + # Array math (vectorized operations on coordinate arrays) + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-neg': prim_arr_neg, + 'arr-exp': prim_arr_exp, + 'arr-atan2': prim_arr_atan2, + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-lerp': prim_arr_lerp, + + # Coordinate transformations + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, + 'normalize-coords': prim_normalize_coords, + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'make-coords-centered': prim_make_coords_centered, + + # Specialized distortion maps + 'wave-displace': prim_wave_displace, + 'swirl-displace': prim_swirl_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + 'ripple-displace': prim_ripple_displace, + + # Character/ASCII art + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'ascii-fx-zone': prim_ascii_fx_zone, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + + # Glitch art + 'pixelsort': prim_pixelsort, + 'datamosh': prim_datamosh, + +} diff --git a/l1/sexp_effects/test_interpreter.py b/l1/sexp_effects/test_interpreter.py new file mode 100644 index 0000000..550b21a --- /dev/null +++ b/l1/sexp_effects/test_interpreter.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +Test the S-expression effect interpreter. +""" + +import numpy as np +import sys +from pathlib import Path + +# Add parent to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sexp_effects import ( + get_interpreter, + load_effects_dir, + run_effect, + list_effects, + parse, +) + + +def test_parser(): + """Test S-expression parser.""" + print("Testing parser...") + + # Simple expressions + assert parse("42") == 42 + assert parse("3.14") == 3.14 + assert parse('"hello"') == "hello" + assert parse("true") == True + + # Lists + assert parse("(+ 1 2)")[0].name == "+" + assert parse("(+ 1 2)")[1] == 1 + + # Nested + expr = parse("(define x (+ 1 2))") + assert expr[0].name == "define" + + print(" Parser OK") + + +def test_interpreter_basics(): + """Test basic interpreter operations.""" + print("Testing interpreter basics...") + + interp = get_interpreter() + + # Math + assert interp.eval(parse("(+ 1 2)")) == 3 + assert interp.eval(parse("(* 3 4)")) == 12 + assert interp.eval(parse("(- 10 3)")) == 7 + + # Comparison + assert interp.eval(parse("(< 1 2)")) == True + assert interp.eval(parse("(> 1 2)")) == False + + # Let binding + assert interp.eval(parse("(let ((x 5)) x)")) == 5 + assert interp.eval(parse("(let ((x 5) (y 3)) (+ x y))")) == 8 + + # Lambda + result = interp.eval(parse("((lambda (x) (* x 2)) 5)")) + assert result == 10 + + # If + assert interp.eval(parse("(if true 1 2)")) == 1 + assert interp.eval(parse("(if false 1 2)")) == 2 + + print(" Interpreter basics OK") + + +def test_primitives(): + """Test image primitives.""" + print("Testing primitives...") + + interp = get_interpreter() + + # Create test image + img = np.zeros((100, 100, 3), dtype=np.uint8) + img[50, 50] = [255, 128, 64] + + interp.global_env.set('test_img', img) + + # Width/height + assert interp.eval(parse("(width test_img)")) == 100 + assert interp.eval(parse("(height test_img)")) == 100 + + # Pixel + pixel = interp.eval(parse("(pixel test_img 50 50)")) + assert pixel == [255, 128, 64] + + # RGB + color = interp.eval(parse("(rgb 100 150 200)")) + assert color == [100, 150, 200] + + # Luminance + lum = interp.eval(parse("(luminance (rgb 100 100 100))")) + assert abs(lum - 100) < 1 + + print(" Primitives OK") + + +def test_effect_loading(): + """Test loading effects from .sexp files.""" + print("Testing effect loading...") + + # Load all effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + effects = list_effects() + print(f" Loaded {len(effects)} effects: {', '.join(sorted(effects))}") + + assert len(effects) > 0 + print(" Effect loading OK") + + +def test_effect_execution(): + """Test running effects on images.""" + print("Testing effect execution...") + + # Create test image + img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + + # Load effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Test each effect + effects = list_effects() + passed = 0 + failed = [] + + for name in sorted(effects): + try: + result, state = run_effect(name, img.copy(), {'_time': 0.5}, {}) + assert isinstance(result, np.ndarray) + assert result.shape == img.shape + passed += 1 + print(f" {name}: OK") + except Exception as e: + failed.append((name, str(e))) + print(f" {name}: FAILED - {e}") + + print(f" Passed: {passed}/{len(effects)}") + if failed: + print(f" Failed: {[f[0] for f in failed]}") + + return passed, failed + + +def test_ascii_fx_zone(): + """Test ascii_fx_zone effect with zone expressions.""" + print("Testing ascii_fx_zone...") + + interp = get_interpreter() + + # Load the effect + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Create gradient test frame + frame = np.zeros((120, 160, 3), dtype=np.uint8) + for x in range(160): + frame[:, x] = int(x / 160 * 255) + frame = np.stack([frame[:,:,0]]*3, axis=2) + + # Test 1: Basic without expressions + result, _ = run_effect('ascii_fx_zone', frame, {'cols': 20}, {}) + assert result.shape == frame.shape + print(" Basic run: OK") + + # Test 2: With zone-lum expression + expr = parse('(* zone-lum 180)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': expr + }, {}) + assert result.shape == frame.shape + print(" Zone-lum expression: OK") + + # Test 3: With multiple expressions + scale_expr = parse('(+ 0.5 (* zone-lum 0.5))') + rot_expr = parse('(* zone-row-norm 30)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_scale': scale_expr, + 'char_rotation': rot_expr + }, {}) + assert result.shape == frame.shape + print(" Multiple expressions: OK") + + # Test 4: With numeric literals + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': 90, + 'char_scale': 1.2 + }, {}) + assert result.shape == frame.shape + print(" Numeric literals: OK") + + # Test 5: Zone position expressions + col_expr = parse('(* zone-col-norm 360)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': col_expr + }, {}) + assert result.shape == frame.shape + print(" Zone position expression: OK") + + print(" ascii_fx_zone OK") + + +def main(): + print("=" * 60) + print("S-Expression Effect Interpreter Tests") + print("=" * 60) + + test_parser() + test_interpreter_basics() + test_primitives() + test_effect_loading() + test_ascii_fx_zone() + passed, failed = test_effect_execution() + + print("=" * 60) + if not failed: + print("All tests passed!") + else: + print(f"Tests completed with {len(failed)} failures") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/l1/sexp_effects/wgsl_compiler.py b/l1/sexp_effects/wgsl_compiler.py new file mode 100644 index 0000000..0c8b024 --- /dev/null +++ b/l1/sexp_effects/wgsl_compiler.py @@ -0,0 +1,715 @@ +""" +S-Expression to WGSL Compiler + +Compiles sexp effect definitions to WGSL compute shaders for GPU execution. +The compilation happens at effect upload time (AOT), not at runtime. + +Architecture: +- Parse sexp AST +- Analyze primitives used +- Generate WGSL compute shader + +Shader Categories: +1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel) +2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample) +3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors) +""" + +from typing import Dict, List, Any, Optional, Tuple, Set +from dataclasses import dataclass, field +from pathlib import Path +import math + +from .parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class WGSLParam: + """A shader parameter (uniform).""" + name: str + wgsl_type: str # f32, i32, u32, vec2f, etc. + default: Any + + +@dataclass +class CompiledEffect: + """Result of compiling an sexp effect to WGSL.""" + name: str + wgsl_code: str + params: List[WGSLParam] + workgroup_size: Tuple[int, int, int] = (16, 16, 1) + # Metadata for runtime + uses_time: bool = False + uses_sampling: bool = False # Needs texture sampler + category: str = "per_pixel" # per_pixel, geometric, neighborhood + + +@dataclass +class CompilerContext: + """Context during compilation.""" + effect_name: str = "" + params: Dict[str, WGSLParam] = field(default_factory=dict) + locals: Dict[str, str] = field(default_factory=dict) # local var -> wgsl expr + required_libs: Set[str] = field(default_factory=set) + uses_time: bool = False + uses_sampling: bool = False + temp_counter: int = 0 + + def fresh_temp(self) -> str: + """Generate a fresh temporary variable name.""" + self.temp_counter += 1 + return f"_t{self.temp_counter}" + + +class SexpToWGSLCompiler: + """ + Compiles S-expression effect definitions to WGSL compute shaders. + """ + + # Map sexp types to WGSL types + TYPE_MAP = { + 'int': 'i32', + 'float': 'f32', + 'bool': 'u32', # WGSL doesn't have bool in storage + 'string': None, # Strings handled specially + } + + # Per-pixel primitives that can be compiled directly + PER_PIXEL_PRIMITIVES = { + 'color_ops:invert-img', + 'color_ops:grayscale', + 'color_ops:sepia', + 'color_ops:adjust', + 'color_ops:adjust-brightness', + 'color_ops:shift-hsv', + 'color_ops:quantize', + } + + # Geometric primitives (coordinate remapping) + GEOMETRIC_PRIMITIVES = { + 'geometry:scale-img', + 'geometry:rotate-img', + 'geometry:translate', + 'geometry:flip-h', + 'geometry:flip-v', + 'geometry:remap', + } + + def __init__(self): + self.ctx: Optional[CompilerContext] = None + + def compile_file(self, path: str) -> CompiledEffect: + """Compile an effect from a .sexp file.""" + with open(path, 'r') as f: + content = f.read() + exprs = parse_all(content) + return self.compile(exprs) + + def compile_string(self, sexp_code: str) -> CompiledEffect: + """Compile an effect from an sexp string.""" + exprs = parse_all(sexp_code) + return self.compile(exprs) + + def compile(self, expr: Any) -> CompiledEffect: + """Compile a parsed sexp expression.""" + self.ctx = CompilerContext() + + # Handle multiple top-level expressions (require-primitives, define-effect) + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self._process_toplevel(e) + else: + self._process_toplevel(expr) + + # Generate the WGSL shader + wgsl = self._generate_wgsl() + + # Determine category based on primitives used + category = self._determine_category() + + return CompiledEffect( + name=self.ctx.effect_name, + wgsl_code=wgsl, + params=list(self.ctx.params.values()), + uses_time=self.ctx.uses_time, + uses_sampling=self.ctx.uses_sampling, + category=category, + ) + + def _process_toplevel(self, expr: Any): + """Process a top-level expression.""" + if not isinstance(expr, list) or not expr: + return + + head = expr[0] + if isinstance(head, Symbol): + if head.name == 'require-primitives': + # Track required primitive libraries + for lib in expr[1:]: + lib_name = lib.name if isinstance(lib, Symbol) else str(lib) + self.ctx.required_libs.add(lib_name) + + elif head.name == 'define-effect': + self._compile_effect_def(expr) + + def _compile_effect_def(self, expr: list): + """Compile a define-effect form.""" + # (define-effect name :params (...) body) + self.ctx.effect_name = expr[1].name if isinstance(expr[1], Symbol) else str(expr[1]) + + # Parse :params and body + i = 2 + body = None + while i < len(expr): + item = expr[i] + if isinstance(item, Keyword) and item.name == 'params': + self._parse_params(expr[i + 1]) + i += 2 + elif isinstance(item, Keyword): + i += 2 # Skip other keywords + else: + body = item + i += 1 + + if body: + self.ctx.body_expr = body + + def _parse_params(self, params_list: list): + """Parse the :params block.""" + for param_def in params_list: + if not isinstance(param_def, list): + continue + + name = param_def[0].name if isinstance(param_def[0], Symbol) else str(param_def[0]) + + # Parse keyword args + param_type = 'float' + default = 0 + + i = 1 + while i < len(param_def): + item = param_def[i] + if isinstance(item, Keyword): + if i + 1 < len(param_def): + val = param_def[i + 1] + if item.name == 'type': + param_type = val.name if isinstance(val, Symbol) else str(val) + elif item.name == 'default': + default = val + i += 2 + else: + i += 1 + + wgsl_type = self.TYPE_MAP.get(param_type, 'f32') + if wgsl_type: + self.ctx.params[name] = WGSLParam(name, wgsl_type, default) + + def _determine_category(self) -> str: + """Determine shader category based on primitives used.""" + for lib in self.ctx.required_libs: + if lib == 'geometry': + return 'geometric' + if lib == 'filters': + return 'neighborhood' + return 'per_pixel' + + def _generate_wgsl(self) -> str: + """Generate the complete WGSL shader code.""" + lines = [] + + # Header comment + lines.append(f"// WGSL Shader: {self.ctx.effect_name}") + lines.append(f"// Auto-generated from sexp effect definition") + lines.append("") + + # Bindings + lines.append("@group(0) @binding(0) var input: array;") + lines.append("@group(0) @binding(1) var output: array;") + lines.append("") + + # Params struct + if self.ctx.params: + lines.append("struct Params {") + lines.append(" width: u32,") + lines.append(" height: u32,") + lines.append(" time: f32,") + for param in self.ctx.params.values(): + lines.append(f" {param.name}: {param.wgsl_type},") + lines.append("}") + lines.append("@group(0) @binding(2) var params: Params;") + else: + lines.append("struct Params {") + lines.append(" width: u32,") + lines.append(" height: u32,") + lines.append(" time: f32,") + lines.append("}") + lines.append("@group(0) @binding(2) var params: Params;") + lines.append("") + + # Helper functions + lines.extend(self._generate_helpers()) + lines.append("") + + # Main compute shader + lines.append("@compute @workgroup_size(16, 16, 1)") + lines.append("fn main(@builtin(global_invocation_id) gid: vec3) {") + lines.append(" let x = gid.x;") + lines.append(" let y = gid.y;") + lines.append(" if (x >= params.width || y >= params.height) { return; }") + lines.append(" let idx = y * params.width + x;") + lines.append("") + + # Compile the effect body + body_code = self._compile_expr(self.ctx.body_expr) + lines.append(f" // Effect: {self.ctx.effect_name}") + lines.append(body_code) + lines.append("}") + + return "\n".join(lines) + + def _generate_helpers(self) -> List[str]: + """Generate WGSL helper functions.""" + helpers = [] + + # Pack/unpack RGB from u32 + helpers.append("fn unpack_rgb(packed: u32) -> vec3 {") + helpers.append(" let r = f32((packed >> 16u) & 0xFFu) / 255.0;") + helpers.append(" let g = f32((packed >> 8u) & 0xFFu) / 255.0;") + helpers.append(" let b = f32(packed & 0xFFu) / 255.0;") + helpers.append(" return vec3(r, g, b);") + helpers.append("}") + helpers.append("") + + helpers.append("fn pack_rgb(rgb: vec3) -> u32 {") + helpers.append(" let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0);") + helpers.append(" let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0);") + helpers.append(" let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0);") + helpers.append(" return (r << 16u) | (g << 8u) | b;") + helpers.append("}") + helpers.append("") + + # Bilinear sampling for geometric transforms + if self.ctx.uses_sampling or 'geometry' in self.ctx.required_libs: + helpers.append("fn sample_bilinear(sx: f32, sy: f32) -> vec3 {") + helpers.append(" let w = f32(params.width);") + helpers.append(" let h = f32(params.height);") + helpers.append(" let cx = clamp(sx, 0.0, w - 1.001);") + helpers.append(" let cy = clamp(sy, 0.0, h - 1.001);") + helpers.append(" let x0 = u32(cx);") + helpers.append(" let y0 = u32(cy);") + helpers.append(" let x1 = min(x0 + 1u, params.width - 1u);") + helpers.append(" let y1 = min(y0 + 1u, params.height - 1u);") + helpers.append(" let fx = cx - f32(x0);") + helpers.append(" let fy = cy - f32(y0);") + helpers.append(" let c00 = unpack_rgb(input[y0 * params.width + x0]);") + helpers.append(" let c10 = unpack_rgb(input[y0 * params.width + x1]);") + helpers.append(" let c01 = unpack_rgb(input[y1 * params.width + x0]);") + helpers.append(" let c11 = unpack_rgb(input[y1 * params.width + x1]);") + helpers.append(" let top = mix(c00, c10, fx);") + helpers.append(" let bot = mix(c01, c11, fx);") + helpers.append(" return mix(top, bot, fy);") + helpers.append("}") + helpers.append("") + + # HSV conversion for color effects + if 'color_ops' in self.ctx.required_libs or 'color' in self.ctx.required_libs: + helpers.append("fn rgb_to_hsv(rgb: vec3) -> vec3 {") + helpers.append(" let mx = max(max(rgb.r, rgb.g), rgb.b);") + helpers.append(" let mn = min(min(rgb.r, rgb.g), rgb.b);") + helpers.append(" let d = mx - mn;") + helpers.append(" var h = 0.0;") + helpers.append(" if (d > 0.0) {") + helpers.append(" if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }") + helpers.append(" else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }") + helpers.append(" else { h = 4.0 + (rgb.r - rgb.g) / d; }") + helpers.append(" h = h / 6.0;") + helpers.append(" if (h < 0.0) { h = h + 1.0; }") + helpers.append(" }") + helpers.append(" let s = select(0.0, d / mx, mx > 0.0);") + helpers.append(" return vec3(h, s, mx);") + helpers.append("}") + helpers.append("") + + helpers.append("fn hsv_to_rgb(hsv: vec3) -> vec3 {") + helpers.append(" let h = hsv.x * 6.0;") + helpers.append(" let s = hsv.y;") + helpers.append(" let v = hsv.z;") + helpers.append(" let c = v * s;") + helpers.append(" let x = c * (1.0 - abs(h % 2.0 - 1.0));") + helpers.append(" let m = v - c;") + helpers.append(" var rgb: vec3;") + helpers.append(" if (h < 1.0) { rgb = vec3(c, x, 0.0); }") + helpers.append(" else if (h < 2.0) { rgb = vec3(x, c, 0.0); }") + helpers.append(" else if (h < 3.0) { rgb = vec3(0.0, c, x); }") + helpers.append(" else if (h < 4.0) { rgb = vec3(0.0, x, c); }") + helpers.append(" else if (h < 5.0) { rgb = vec3(x, 0.0, c); }") + helpers.append(" else { rgb = vec3(c, 0.0, x); }") + helpers.append(" return rgb + vec3(m, m, m);") + helpers.append("}") + helpers.append("") + + return helpers + + def _compile_expr(self, expr: Any, indent: int = 4) -> str: + """Compile an sexp expression to WGSL code.""" + ind = " " * indent + + # Literals + if isinstance(expr, (int, float)): + return f"{ind}// literal: {expr}" + + if isinstance(expr, str): + return f'{ind}// string: "{expr}"' + + # Symbol reference + if isinstance(expr, Symbol): + name = expr.name + if name == 'frame': + return f"{ind}let rgb = unpack_rgb(input[idx]);" + if name == 't' or name == '_time': + self.ctx.uses_time = True + return f"{ind}let t = params.time;" + if name in self.ctx.params: + return f"{ind}let {name} = params.{name};" + if name in self.ctx.locals: + return f"{ind}// local: {name}" + return f"{ind}// unknown symbol: {name}" + + # List (function call or special form) + if isinstance(expr, list) and expr: + head = expr[0] + + if isinstance(head, Symbol): + form = head.name + + # Special forms + if form == 'let' or form == 'let*': + return self._compile_let(expr, indent) + + if form == 'if': + return self._compile_if(expr, indent) + + if form == 'or': + # (or a b) - return a if truthy, else b + return self._compile_or(expr, indent) + + # Primitive calls + if ':' in form: + return self._compile_primitive_call(expr, indent) + + # Arithmetic + if form in ('+', '-', '*', '/'): + return self._compile_arithmetic(expr, indent) + + if form in ('>', '<', '>=', '<=', '='): + return self._compile_comparison(expr, indent) + + if form == 'max': + return self._compile_builtin('max', expr[1:], indent) + + if form == 'min': + return self._compile_builtin('min', expr[1:], indent) + + return f"{ind}// unhandled: {expr}" + + def _compile_let(self, expr: list, indent: int) -> str: + """Compile let/let* binding form.""" + ind = " " * indent + lines = [] + + bindings = expr[1] + body = expr[2] + + # Parse bindings (Clojure style: [x 1 y 2] or Scheme style: ((x 1) (y 2))) + pairs = [] + if bindings and isinstance(bindings[0], Symbol): + # Clojure style + i = 0 + while i < len(bindings) - 1: + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + value = bindings[i + 1] + pairs.append((name, value)) + i += 2 + else: + # Scheme style + for binding in bindings: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + value = binding[1] + pairs.append((name, value)) + + # Compile bindings + for name, value in pairs: + val_code = self._expr_to_wgsl(value) + lines.append(f"{ind}let {name} = {val_code};") + self.ctx.locals[name] = val_code + + # Compile body + body_lines = self._compile_body(body, indent) + lines.append(body_lines) + + return "\n".join(lines) + + def _compile_body(self, body: Any, indent: int) -> str: + """Compile the body of an effect (the final image expression).""" + ind = " " * indent + + # Most effects end with a primitive call that produces the output + if isinstance(body, list) and body: + head = body[0] + if isinstance(head, Symbol) and ':' in head.name: + return self._compile_primitive_call(body, indent) + + # If body is just 'frame', pass through + if isinstance(body, Symbol) and body.name == 'frame': + return f"{ind}output[idx] = input[idx];" + + return f"{ind}// body: {body}" + + def _compile_primitive_call(self, expr: list, indent: int) -> str: + """Compile a primitive function call.""" + ind = " " * indent + head = expr[0] + prim_name = head.name if isinstance(head, Symbol) else str(head) + args = expr[1:] + + # Per-pixel color operations + if prim_name == 'color_ops:invert-img': + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let result = vec3(1.0, 1.0, 1.0) - rgb; +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:grayscale': + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b; +{ind}let result = vec3(gray, gray, gray); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:adjust-brightness': + amount = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let adj = f32({amount}) / 255.0; +{ind}let result = clamp(rgb + vec3(adj, adj, adj), vec3(0.0, 0.0, 0.0), vec3(1.0, 1.0, 1.0)); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:adjust': + # (adjust img brightness contrast) + brightness = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + contrast = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let centered = rgb - vec3(0.5, 0.5, 0.5); +{ind}let contrasted = centered * {contrast}; +{ind}let brightened = contrasted + vec3(0.5, 0.5, 0.5) + vec3({brightness}/255.0); +{ind}let result = clamp(brightened, vec3(0.0), vec3(1.0)); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:sepia': + intensity = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b; +{ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b; +{ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b; +{ind}let sepia = vec3(sepia_r, sepia_g, sepia_b); +{ind}let result = mix(rgb, sepia, {intensity}); +{ind}output[idx] = pack_rgb(clamp(result, vec3(0.0), vec3(1.0)));""" + + if prim_name == 'color_ops:shift-hsv': + h_shift = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + s_mult = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" + v_mult = self._expr_to_wgsl(args[3]) if len(args) > 3 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}var hsv = rgb_to_hsv(rgb); +{ind}hsv.x = fract(hsv.x + {h_shift} / 360.0); +{ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0); +{ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0); +{ind}let result = hsv_to_rgb(hsv); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:quantize': + levels = self._expr_to_wgsl(args[1]) if len(args) > 1 else "8.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let lvl = max(2.0, {levels}); +{ind}let result = floor(rgb * lvl) / lvl; +{ind}output[idx] = pack_rgb(result);""" + + # Geometric transforms + if prim_name == 'geometry:scale-img': + sx = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" + sy = self._expr_to_wgsl(args[2]) if len(args) > 2 else sx + self.ctx.uses_sampling = True + return f"""{ind}let w = f32(params.width); +{ind}let h = f32(params.height); +{ind}let cx = w / 2.0; +{ind}let cy = h / 2.0; +{ind}let sx = f32(x) - cx; +{ind}let sy = f32(y) - cy; +{ind}let src_x = sx / {sx} + cx; +{ind}let src_y = sy / {sy} + cy; +{ind}let result = sample_bilinear(src_x, src_y); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'geometry:rotate-img': + angle = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + self.ctx.uses_sampling = True + return f"""{ind}let w = f32(params.width); +{ind}let h = f32(params.height); +{ind}let cx = w / 2.0; +{ind}let cy = h / 2.0; +{ind}let angle_rad = {angle} * 3.14159265 / 180.0; +{ind}let cos_a = cos(-angle_rad); +{ind}let sin_a = sin(-angle_rad); +{ind}let dx = f32(x) - cx; +{ind}let dy = f32(y) - cy; +{ind}let src_x = dx * cos_a - dy * sin_a + cx; +{ind}let src_y = dx * sin_a + dy * cos_a + cy; +{ind}let result = sample_bilinear(src_x, src_y); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'geometry:flip-h': + return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x); +{ind}output[idx] = input[src_idx];""" + + if prim_name == 'geometry:flip-v': + return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x; +{ind}output[idx] = input[src_idx];""" + + # Image library + if prim_name == 'image:blur': + radius = self._expr_to_wgsl(args[1]) if len(args) > 1 else "5" + # Box blur approximation (separable would be better) + return f"""{ind}let radius = i32({radius}); +{ind}var sum = vec3(0.0, 0.0, 0.0); +{ind}var count = 0.0; +{ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{ +{ind} for (var dx = -radius; dx <= radius; dx = dx + 1) {{ +{ind} let sx = i32(x) + dx; +{ind} let sy = i32(y) + dy; +{ind} if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{ +{ind} let sidx = u32(sy) * params.width + u32(sx); +{ind} sum = sum + unpack_rgb(input[sidx]); +{ind} count = count + 1.0; +{ind} }} +{ind} }} +{ind}}} +{ind}let result = sum / count; +{ind}output[idx] = pack_rgb(result);""" + + # Fallback - passthrough + return f"""{ind}// Unimplemented primitive: {prim_name} +{ind}output[idx] = input[idx];""" + + def _compile_if(self, expr: list, indent: int) -> str: + """Compile if expression.""" + ind = " " * indent + cond = self._expr_to_wgsl(expr[1]) + then_expr = expr[2] + else_expr = expr[3] if len(expr) > 3 else None + + lines = [] + lines.append(f"{ind}if ({cond}) {{") + lines.append(self._compile_body(then_expr, indent + 4)) + if else_expr: + lines.append(f"{ind}}} else {{") + lines.append(self._compile_body(else_expr, indent + 4)) + lines.append(f"{ind}}}") + + return "\n".join(lines) + + def _compile_or(self, expr: list, indent: int) -> str: + """Compile or expression - returns first truthy value.""" + # For numeric context, (or a b) means "a if a != 0 else b" + a = self._expr_to_wgsl(expr[1]) + b = self._expr_to_wgsl(expr[2]) if len(expr) > 2 else "0.0" + return f"select({b}, {a}, {a} != 0.0)" + + def _compile_arithmetic(self, expr: list, indent: int) -> str: + """Compile arithmetic expression to inline WGSL.""" + op = expr[0].name + operands = [self._expr_to_wgsl(arg) for arg in expr[1:]] + + if len(operands) == 1: + if op == '-': + return f"(-{operands[0]})" + return operands[0] + + return f"({f' {op} '.join(operands)})" + + def _compile_comparison(self, expr: list, indent: int) -> str: + """Compile comparison expression.""" + op = expr[0].name + if op == '=': + op = '==' + a = self._expr_to_wgsl(expr[1]) + b = self._expr_to_wgsl(expr[2]) + return f"({a} {op} {b})" + + def _compile_builtin(self, fn: str, args: list, indent: int) -> str: + """Compile builtin function call.""" + compiled_args = [self._expr_to_wgsl(arg) for arg in args] + return f"{fn}({', '.join(compiled_args)})" + + def _expr_to_wgsl(self, expr: Any) -> str: + """Convert an expression to inline WGSL code.""" + if isinstance(expr, (int, float)): + # Ensure floats have decimal point + if isinstance(expr, float) or '.' not in str(expr): + return f"{float(expr)}" + return str(expr) + + if isinstance(expr, str): + return f'"{expr}"' + + if isinstance(expr, Symbol): + name = expr.name + if name == 'frame': + return "rgb" # Assume rgb is already loaded + if name == 't' or name == '_time': + self.ctx.uses_time = True + return "params.time" + if name == 'pi': + return "3.14159265" + if name in self.ctx.params: + return f"params.{name}" + if name in self.ctx.locals: + return name + return name + + if isinstance(expr, list) and expr: + head = expr[0] + if isinstance(head, Symbol): + form = head.name + + # Arithmetic + if form in ('+', '-', '*', '/'): + return self._compile_arithmetic(expr, 0) + + # Comparison + if form in ('>', '<', '>=', '<=', '='): + return self._compile_comparison(expr, 0) + + # Builtins + if form in ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt'): + args = [self._expr_to_wgsl(a) for a in expr[1:]] + return f"{form}({', '.join(args)})" + + if form == 'or': + return self._compile_or(expr, 0) + + # Image dimension queries + if form == 'image:width': + return "f32(params.width)" + if form == 'image:height': + return "f32(params.height)" + + return f"/* unknown: {expr} */" + + +def compile_effect(sexp_code: str) -> CompiledEffect: + """Convenience function to compile an sexp effect string.""" + compiler = SexpToWGSLCompiler() + return compiler.compile_string(sexp_code) + + +def compile_effect_file(path: str) -> CompiledEffect: + """Convenience function to compile an sexp effect file.""" + compiler = SexpToWGSLCompiler() + return compiler.compile_file(path) diff --git a/l1/storage_providers.py b/l1/storage_providers.py new file mode 100644 index 0000000..1cee65d --- /dev/null +++ b/l1/storage_providers.py @@ -0,0 +1,1009 @@ +""" +Storage provider abstraction for user-attachable storage. + +Supports: +- Pinata (IPFS pinning service) +- web3.storage (IPFS pinning service) +- Local filesystem storage +""" + +import hashlib +import json +import logging +import os +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +import requests + +logger = logging.getLogger(__name__) + + +class StorageProvider(ABC): + """Abstract base class for storage backends.""" + + provider_type: str = "unknown" + + @abstractmethod + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """ + Pin content to storage. + + Args: + cid: SHA3-256 hash of the content + data: Raw bytes to store + filename: Optional filename hint + + Returns: + IPFS CID or provider-specific ID, or None on failure + """ + pass + + @abstractmethod + async def unpin(self, cid: str) -> bool: + """ + Unpin content from storage. + + Args: + cid: SHA3-256 hash of the content + + Returns: + True if unpinned successfully + """ + pass + + @abstractmethod + async def get(self, cid: str) -> Optional[bytes]: + """ + Retrieve content from storage. + + Args: + cid: SHA3-256 hash of the content + + Returns: + Raw bytes or None if not found + """ + pass + + @abstractmethod + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned in this storage.""" + pass + + @abstractmethod + async def test_connection(self) -> tuple[bool, str]: + """ + Test connectivity to the storage provider. + + Returns: + (success, message) tuple + """ + pass + + @abstractmethod + def get_usage(self) -> dict: + """ + Get storage usage statistics. + + Returns: + {used_bytes, capacity_bytes, pin_count} + """ + pass + + +class PinataProvider(StorageProvider): + """Pinata IPFS pinning service provider.""" + + provider_type = "pinata" + + def __init__(self, api_key: str, secret_key: str, capacity_gb: int = 1): + self.api_key = api_key + self.secret_key = secret_key + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.pinata.cloud" + self._usage_cache = None + + def _headers(self) -> dict: + return { + "pinata_api_key": self.api_key, + "pinata_secret_api_key": self.secret_key, + } + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Pinata.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{cid[:16]}.bin", data)} + metadata = { + "name": filename or cid[:16], + "keyvalues": {"cid": cid} + } + response = requests.post( + f"{self.base_url}/pinning/pinFileToIPFS", + files=files, + data={"pinataMetadata": json.dumps(metadata)}, + headers=self._headers(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("IpfsHash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Pinata: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Pinata pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Unpin content from Pinata by finding its CID first.""" + try: + import asyncio + + def do_unpin(): + # First find the pin by cid metadata + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return False + + # Unpin each matching CID + for pin in pins: + cid = pin.get("ipfs_pin_hash") + if cid: + resp = requests.delete( + f"{self.base_url}/pinning/unpin/{cid}", + headers=self._headers(), + timeout=30 + ) + resp.raise_for_status() + return True + + result = await asyncio.to_thread(do_unpin) + logger.info(f"Pinata: Unpinned {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Pinata unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Pinata via IPFS gateway.""" + try: + import asyncio + + def do_get(): + # First find the CID + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return None + + cid = pins[0].get("ipfs_pin_hash") + if not cid: + return None + + # Fetch from gateway + gateway_response = requests.get( + f"https://gateway.pinata.cloud/ipfs/{cid}", + timeout=120 + ) + gateway_response.raise_for_status() + return gateway_response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Pinata get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned on Pinata.""" + try: + import asyncio + + def do_check(): + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + return len(response.json().get("rows", [])) > 0 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Pinata API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/data/testAuthentication", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Pinata successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Pinata usage stats.""" + try: + response = requests.get( + f"{self.base_url}/data/userPinnedDataTotal", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + data = response.json() + return { + "used_bytes": data.get("pin_size_total", 0), + "capacity_bytes": self.capacity_bytes, + "pin_count": data.get("pin_count", 0) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class Web3StorageProvider(StorageProvider): + """web3.storage pinning service provider.""" + + provider_type = "web3storage" + + def __init__(self, api_token: str, capacity_gb: int = 1): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.web3.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to web3.storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={ + **self._headers(), + "X-Name": filename or cid[:16] + }, + timeout=120 + ) + response.raise_for_status() + return response.json().get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"web3.storage: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"web3.storage pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """web3.storage doesn't support unpinning - data is stored permanently.""" + logger.warning("web3.storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from web3.storage - would need CID mapping.""" + # web3.storage requires knowing the CID to fetch + # For now, return None - we'd need to maintain a mapping + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test web3.storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1}, + timeout=10 + ) + response.raise_for_status() + return True, "Connected to web3.storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get web3.storage usage stats.""" + try: + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1000}, + timeout=30 + ) + response.raise_for_status() + uploads = response.json() + total_size = sum(u.get("dagSize", 0) for u in uploads) + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": len(uploads) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class NFTStorageProvider(StorageProvider): + """NFT.Storage pinning service provider (free for NFT data).""" + + provider_type = "nftstorage" + + def __init__(self, api_token: str, capacity_gb: int = 5): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.nft.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to NFT.Storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={**self._headers(), "Content-Type": "application/octet-stream"}, + timeout=120 + ) + response.raise_for_status() + return response.json().get("value", {}).get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"NFT.Storage: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"NFT.Storage pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """NFT.Storage doesn't support unpinning - data is stored permanently.""" + logger.warning("NFT.Storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from NFT.Storage - would need CID mapping.""" + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test NFT.Storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to NFT.Storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get NFT.Storage usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class InfuraIPFSProvider(StorageProvider): + """Infura IPFS pinning service provider.""" + + provider_type = "infura" + + def __init__(self, project_id: str, project_secret: str, capacity_gb: int = 5): + self.project_id = project_id + self.project_secret = project_secret + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://ipfs.infura.io:5001/api/v0" + + def _auth(self) -> tuple: + return (self.project_id, self.project_secret) + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Infura IPFS.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{cid[:16]}.bin", data)} + response = requests.post( + f"{self.base_url}/add", + files=files, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("Hash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Infura IPFS: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Infura IPFS pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Unpin content from Infura IPFS.""" + try: + import asyncio + + def do_unpin(): + response = requests.post( + f"{self.base_url}/pin/rm", + params={"arg": cid}, + auth=self._auth(), + timeout=30 + ) + response.raise_for_status() + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Infura IPFS unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Infura IPFS gateway.""" + try: + import asyncio + + def do_get(): + response = requests.post( + f"{self.base_url}/cat", + params={"arg": cid}, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Infura IPFS get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned on Infura IPFS.""" + try: + import asyncio + + def do_check(): + response = requests.post( + f"{self.base_url}/pin/ls", + params={"arg": cid}, + auth=self._auth(), + timeout=30 + ) + return response.status_code == 200 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Infura IPFS API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.post( + f"{self.base_url}/id", + auth=self._auth(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Infura IPFS successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid project credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Infura usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class FilebaseProvider(StorageProvider): + """Filebase S3-compatible IPFS pinning service.""" + + provider_type = "filebase" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 5): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://s3.filebase.com" + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or f"{cid[:16]}.bin" + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + # Get CID from response headers + head = s3.head_object(Bucket=self.bucket, Key=key) + return head.get('Metadata', {}).get('cid', cid) + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Filebase: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Filebase pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Filebase unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=cid) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Filebase get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists in Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Filebase connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Filebase bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Filebase usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class StorjProvider(StorageProvider): + """Storj decentralized cloud storage (S3-compatible).""" + + provider_type = "storj" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 25): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://gateway.storjshare.io" + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or cid + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + return cid + + result = await asyncio.to_thread(do_pin) + logger.info(f"Storj: Stored {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Storj pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Storj unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=cid) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Storj get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Storj connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Storj bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Storj usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class LocalStorageProvider(StorageProvider): + """Local filesystem storage provider.""" + + provider_type = "local" + + def __init__(self, base_path: str, capacity_gb: int = 10): + self.base_path = Path(base_path) + self.capacity_bytes = capacity_gb * 1024**3 + # Create directory if it doesn't exist + self.base_path.mkdir(parents=True, exist_ok=True) + + def _get_file_path(self, cid: str) -> Path: + """Get file path for a content hash (using subdirectories).""" + # Use first 2 chars as subdirectory for better filesystem performance + subdir = cid[:2] + return self.base_path / subdir / cid + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content locally.""" + try: + import asyncio + + def do_store(): + file_path = self._get_file_path(cid) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(data) + return cid # Use cid as ID for local storage + + result = await asyncio.to_thread(do_store) + logger.info(f"Local: Stored {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Local storage failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from local storage.""" + try: + import asyncio + + def do_remove(): + file_path = self._get_file_path(cid) + if file_path.exists(): + file_path.unlink() + return True + return False + + return await asyncio.to_thread(do_remove) + except Exception as e: + logger.error(f"Local unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from local storage.""" + try: + import asyncio + + def do_get(): + file_path = self._get_file_path(cid) + if file_path.exists(): + return file_path.read_bytes() + return None + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Local get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists in local storage.""" + return self._get_file_path(cid).exists() + + async def test_connection(self) -> tuple[bool, str]: + """Test local storage is writable.""" + try: + test_file = self.base_path / ".write_test" + test_file.write_text("test") + test_file.unlink() + return True, f"Local storage ready at {self.base_path}" + except Exception as e: + return False, f"Cannot write to {self.base_path}: {e}" + + def get_usage(self) -> dict: + """Get local storage usage stats.""" + try: + total_size = 0 + file_count = 0 + for subdir in self.base_path.iterdir(): + if subdir.is_dir() and len(subdir.name) == 2: + for f in subdir.iterdir(): + if f.is_file(): + total_size += f.stat().st_size + file_count += 1 + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": file_count + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +def create_provider(provider_type: str, config: dict) -> Optional[StorageProvider]: + """ + Factory function to create a storage provider from config. + + Args: + provider_type: One of 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + config: Provider-specific configuration dict + + Returns: + StorageProvider instance or None if invalid + """ + try: + if provider_type == "pinata": + return PinataProvider( + api_key=config["api_key"], + secret_key=config["secret_key"], + capacity_gb=config.get("capacity_gb", 1) + ) + elif provider_type == "web3storage": + return Web3StorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "nftstorage": + return NFTStorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "infura": + return InfuraIPFSProvider( + project_id=config["project_id"], + project_secret=config["project_secret"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "filebase": + return FilebaseProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "storj": + return StorjProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 25) + ) + elif provider_type == "local": + return LocalStorageProvider( + base_path=config["path"], + capacity_gb=config.get("capacity_gb", 10) + ) + else: + logger.error(f"Unknown provider type: {provider_type}") + return None + except KeyError as e: + logger.error(f"Missing config key for {provider_type}: {e}") + return None + except Exception as e: + logger.error(f"Failed to create provider {provider_type}: {e}") + return None diff --git a/l1/streaming/__init__.py b/l1/streaming/__init__.py new file mode 100644 index 0000000..2c007cc --- /dev/null +++ b/l1/streaming/__init__.py @@ -0,0 +1,44 @@ +""" +Streaming video compositor for real-time effect processing. + +This module provides a frame-by-frame streaming architecture that: +- Reads from multiple video sources with automatic looping +- Applies effects inline (no intermediate files) +- Composites layers with time-varying weights +- Outputs to display, file, or stream + +Usage: + from streaming import StreamingCompositor, VideoSource, AudioAnalyzer + + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[...], + compositor_config={...}, + ) + + # With live audio + audio = AudioAnalyzer(device=0) + compositor.run(output="output.mp4", duration=60, audio=audio) + + # With preview window + compositor.run(output="preview", duration=60) + +Backends: + - numpy: Works everywhere, ~3-5 fps (default) + - glsl: Requires GPU, 30+ fps real-time (future) +""" + +from .sources import VideoSource, ImageSource +from .compositor import StreamingCompositor +from .backends import NumpyBackend, get_backend +from .output import DisplayOutput, FileOutput + +__all__ = [ + "StreamingCompositor", + "VideoSource", + "ImageSource", + "NumpyBackend", + "get_backend", + "DisplayOutput", + "FileOutput", +] diff --git a/l1/streaming/audio.py b/l1/streaming/audio.py new file mode 100644 index 0000000..9d20937 --- /dev/null +++ b/l1/streaming/audio.py @@ -0,0 +1,486 @@ +""" +Live audio analysis for reactive effects. + +Provides real-time audio features: +- Energy (RMS amplitude) +- Beat detection +- Frequency bands (bass, mid, high) +""" + +import numpy as np +from typing import Optional +import threading +import time + + +class AudioAnalyzer: + """ + Real-time audio analyzer using sounddevice. + + Captures audio from microphone/line-in and computes + features in real-time for effect parameter bindings. + + Example: + analyzer = AudioAnalyzer(device=0) + analyzer.start() + + # In compositor loop: + energy = analyzer.get_energy() + beat = analyzer.get_beat() + + analyzer.stop() + """ + + def __init__( + self, + device: int = None, + sample_rate: int = 44100, + block_size: int = 1024, + buffer_seconds: float = 0.5, + ): + """ + Initialize audio analyzer. + + Args: + device: Audio input device index (None = default) + sample_rate: Audio sample rate + block_size: Samples per block + buffer_seconds: Ring buffer duration + """ + self.sample_rate = sample_rate + self.block_size = block_size + self.device = device + + # Ring buffer for recent audio + buffer_size = int(sample_rate * buffer_seconds) + self._buffer = np.zeros(buffer_size, dtype=np.float32) + self._buffer_pos = 0 + self._lock = threading.Lock() + + # Beat detection state + self._last_energy = 0 + self._energy_history = [] + self._last_beat_time = 0 + self._beat_threshold = 1.5 # Energy ratio for beat detection + self._min_beat_interval = 0.1 # Min seconds between beats + + # Stream state + self._stream = None + self._running = False + + def _audio_callback(self, indata, frames, time_info, status): + """Called by sounddevice for each audio block.""" + with self._lock: + # Add to ring buffer + data = indata[:, 0] if len(indata.shape) > 1 else indata + n = len(data) + if self._buffer_pos + n <= len(self._buffer): + self._buffer[self._buffer_pos:self._buffer_pos + n] = data + else: + # Wrap around + first = len(self._buffer) - self._buffer_pos + self._buffer[self._buffer_pos:] = data[:first] + self._buffer[:n - first] = data[first:] + self._buffer_pos = (self._buffer_pos + n) % len(self._buffer) + + def start(self): + """Start audio capture.""" + try: + import sounddevice as sd + except ImportError: + print("Warning: sounddevice not installed. Audio analysis disabled.") + print("Install with: pip install sounddevice") + return + + self._stream = sd.InputStream( + device=self.device, + channels=1, + samplerate=self.sample_rate, + blocksize=self.block_size, + callback=self._audio_callback, + ) + self._stream.start() + self._running = True + + def stop(self): + """Stop audio capture.""" + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + self._running = False + + def get_energy(self) -> float: + """ + Get current audio energy (RMS amplitude). + + Returns: + Energy value normalized to 0-1 range (approximately) + """ + with self._lock: + # Use recent samples + recent = 2048 + if self._buffer_pos >= recent: + data = self._buffer[self._buffer_pos - recent:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(recent - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # RMS energy + rms = np.sqrt(np.mean(data ** 2)) + + # Normalize (typical mic input is quite low) + normalized = min(1.0, rms * 10) + + return normalized + + def get_beat(self) -> bool: + """ + Detect if current moment is a beat. + + Simple onset detection based on energy spikes. + + Returns: + True if beat detected, False otherwise + """ + current_energy = self.get_energy() + now = time.time() + + # Update energy history + self._energy_history.append(current_energy) + if len(self._energy_history) > 20: + self._energy_history.pop(0) + + # Need enough history + if len(self._energy_history) < 5: + self._last_energy = current_energy + return False + + # Average recent energy + avg_energy = np.mean(self._energy_history[:-1]) + + # Beat if current energy is significantly above average + is_beat = ( + current_energy > avg_energy * self._beat_threshold and + now - self._last_beat_time > self._min_beat_interval and + current_energy > self._last_energy # Rising edge + ) + + if is_beat: + self._last_beat_time = now + + self._last_energy = current_energy + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """ + Get frequency spectrum divided into bands. + + Args: + bands: Number of frequency bands (default 3: bass, mid, high) + + Returns: + Array of band energies, normalized to 0-1 + """ + with self._lock: + # Use recent samples for FFT + n = 2048 + if self._buffer_pos >= n: + data = self._buffer[self._buffer_pos - n:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(n - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # FFT + fft = np.abs(np.fft.rfft(data * np.hanning(len(data)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + start = i * band_size + end = start + band_size + result[i] = np.mean(fft[start:end]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return self._running + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() + + +class FileAudioAnalyzer: + """ + Audio analyzer that reads from a file (for testing/development). + + Pre-computes analysis and plays back in sync with video. + """ + + def __init__(self, path: str, analysis_data: dict = None): + """ + Initialize from audio file. + + Args: + path: Path to audio file + analysis_data: Pre-computed analysis (times, values, etc.) + """ + self.path = path + self.analysis_data = analysis_data or {} + self._current_time = 0 + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Get energy at current time from pre-computed data.""" + track = self.analysis_data.get("energy", {}) + return self._interpolate(track, self._current_time) + + def get_beat(self) -> bool: + """Check if current time is near a beat.""" + track = self.analysis_data.get("beats", {}) + times = track.get("times", []) + + # Check if we're within 50ms of a beat + for beat_time in times: + if abs(beat_time - self._current_time) < 0.05: + return True + return False + + def _interpolate(self, track: dict, t: float) -> float: + """Interpolate value at time t.""" + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + if t <= times[0]: + return values[0] + if t >= times[-1]: + return values[-1] + + # Find bracket and interpolate + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + return values[i] * (1 - alpha) + values[i + 1] * alpha + + return values[-1] + + @property + def is_running(self) -> bool: + return True + + +class StreamingAudioAnalyzer: + """ + Real-time audio analyzer that streams from a file. + + Reads audio in sync with video time and computes features on-the-fly. + No pre-computation needed - analysis happens as frames are processed. + """ + + def __init__(self, path: str, sample_rate: int = 22050, hop_length: int = 512): + """ + Initialize streaming audio analyzer. + + Args: + path: Path to audio file + sample_rate: Sample rate for analysis + hop_length: Hop length for feature extraction + """ + import subprocess + import json + + self.path = path + self.sample_rate = sample_rate + self.hop_length = hop_length + self._current_time = 0.0 + + # Get audio duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(path)] + result = subprocess.run(cmd, capture_output=True, text=True) + info = json.loads(result.stdout) + self.duration = float(info["format"]["duration"]) + + # Audio buffer and state + self._audio_data = None + self._energy_history = [] + self._last_energy = 0 + self._last_beat_time = -1 + self._beat_threshold = 1.5 + self._min_beat_interval = 0.15 + + # Load audio lazily + self._loaded = False + + def _load_audio(self): + """Load audio data on first use.""" + if self._loaded: + return + + import subprocess + + # Use ffmpeg to decode audio to raw PCM + cmd = [ + "ffmpeg", "-v", "quiet", + "-i", str(self.path), + "-f", "f32le", # 32-bit float, little-endian + "-ac", "1", # mono + "-ar", str(self.sample_rate), + "-" + ] + result = subprocess.run(cmd, capture_output=True) + self._audio_data = np.frombuffer(result.stdout, dtype=np.float32) + self._loaded = True + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Compute energy at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return 0.0 + + # Get sample index for current time + sample_idx = int(self._current_time * self.sample_rate) + window_size = self.hop_length * 2 + + start = max(0, sample_idx - window_size // 2) + end = min(len(self._audio_data), sample_idx + window_size // 2) + + if start >= end: + return 0.0 + + # RMS energy + chunk = self._audio_data[start:end] + rms = np.sqrt(np.mean(chunk ** 2)) + + # Normalize to 0-1 range (approximate) + energy = min(1.0, rms * 3.0) + + self._last_energy = energy + return energy + + def get_beat(self) -> bool: + """Detect beat using spectral flux (change in frequency content).""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return False + + # Get audio chunks for current and previous frame + sample_idx = int(self._current_time * self.sample_rate) + chunk_size = self.hop_length * 2 + + # Current chunk + start = max(0, sample_idx - chunk_size // 2) + end = min(len(self._audio_data), sample_idx + chunk_size // 2) + if end - start < chunk_size // 2: + return False + current_chunk = self._audio_data[start:end] + + # Previous chunk (one hop back) + prev_start = max(0, start - self.hop_length) + prev_end = max(0, end - self.hop_length) + if prev_end <= prev_start: + return False + prev_chunk = self._audio_data[prev_start:prev_end] + + # Compute spectra + current_spec = np.abs(np.fft.rfft(current_chunk * np.hanning(len(current_chunk)))) + prev_spec = np.abs(np.fft.rfft(prev_chunk * np.hanning(len(prev_chunk)))) + + # Spectral flux: sum of positive differences (onset = new frequencies appearing) + min_len = min(len(current_spec), len(prev_spec)) + diff = current_spec[:min_len] - prev_spec[:min_len] + flux = np.sum(np.maximum(0, diff)) # Only count increases + + # Normalize by spectrum size + flux = flux / (min_len + 1) + + # Update flux history + self._energy_history.append((self._current_time, flux)) + while self._energy_history and self._energy_history[0][0] < self._current_time - 1.5: + self._energy_history.pop(0) + + if len(self._energy_history) < 3: + return False + + # Adaptive threshold based on recent flux values + flux_values = [f for t, f in self._energy_history] + mean_flux = np.mean(flux_values) + std_flux = np.std(flux_values) + 0.001 # Avoid division by zero + + # Beat if flux is above mean (more sensitive threshold) + threshold = mean_flux + std_flux * 0.3 # Lower = more sensitive + min_interval = 0.1 # Allow up to 600 BPM + time_ok = self._current_time - self._last_beat_time > min_interval + + is_beat = flux > threshold and time_ok + + if is_beat: + self._last_beat_time = self._current_time + + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """Get frequency spectrum at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return np.zeros(bands) + + sample_idx = int(self._current_time * self.sample_rate) + n = 2048 + + start = max(0, sample_idx - n // 2) + end = min(len(self._audio_data), sample_idx + n // 2) + + if end - start < n // 2: + return np.zeros(bands) + + chunk = self._audio_data[start:end] + + # FFT + fft = np.abs(np.fft.rfft(chunk * np.hanning(len(chunk)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + s, e = i * band_size, (i + 1) * band_size + result[i] = np.mean(fft[s:e]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return True diff --git a/l1/streaming/backends.py b/l1/streaming/backends.py new file mode 100644 index 0000000..80c558a --- /dev/null +++ b/l1/streaming/backends.py @@ -0,0 +1,572 @@ +""" +Effect processing backends. + +Provides abstraction over different rendering backends: +- numpy: CPU-based, works everywhere, ~3-5 fps +- glsl: GPU-based, requires OpenGL, 30+ fps (future) +""" + +import numpy as np +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from pathlib import Path + + +class Backend(ABC): + """Abstract base class for effect processing backends.""" + + @abstractmethod + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process multiple input frames through effects and composite. + + Args: + frames: List of input frames (one per source) + effects_per_frame: List of effect chains (one per source) + compositor_config: How to blend the layers + t: Current time in seconds + analysis_data: Analysis data for binding resolution + + Returns: + Composited output frame + """ + pass + + @abstractmethod + def load_effect(self, effect_path: Path) -> Any: + """Load an effect definition.""" + pass + + +class NumpyBackend(Backend): + """ + CPU-based effect processing using NumPy. + + Uses existing sexp_effects interpreter for effect execution. + Works on any system, but limited to ~3-5 fps for complex effects. + """ + + def __init__(self, recipe_dir: Path = None, minimal_primitives: bool = True): + self.recipe_dir = recipe_dir or Path(".") + self.minimal_primitives = minimal_primitives + self._interpreter = None + self._loaded_effects = {} + + def _get_interpreter(self): + """Lazy-load the sexp interpreter.""" + if self._interpreter is None: + from sexp_effects import get_interpreter + self._interpreter = get_interpreter(minimal_primitives=self.minimal_primitives) + return self._interpreter + + def load_effect(self, effect_path: Path) -> Any: + """Load an effect from sexp file.""" + if isinstance(effect_path, str): + effect_path = Path(effect_path) + effect_key = str(effect_path) + if effect_key not in self._loaded_effects: + interp = self._get_interpreter() + interp.load_effect(str(effect_path)) + self._loaded_effects[effect_key] = effect_path.stem + return self._loaded_effects[effect_key] + + def _resolve_binding(self, value: Any, t: float, analysis_data: Dict) -> Any: + """Resolve a parameter binding to its value at time t.""" + if not isinstance(value, dict): + return value + + if "_binding" in value or "_bind" in value: + source = value.get("source") or value.get("_bind") + feature = value.get("feature", "values") + range_map = value.get("range") + + track = analysis_data.get(source, {}) + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + # Find value at time t (linear interpolation) + if t <= times[0]: + val = values[0] + elif t >= times[-1]: + val = values[-1] + else: + # Binary search for bracket + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + val = values[i] * (1 - alpha) + values[i + 1] * alpha + break + else: + val = values[-1] + + # Apply range mapping + if range_map and len(range_map) == 2: + val = range_map[0] + val * (range_map[1] - range_map[0]) + + return val + + return value + + def _apply_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Apply a single effect to a frame.""" + # Resolve bindings in params + resolved_params = {"_time": t} + for key, value in params.items(): + if key in ("effect", "effect_path", "cid", "analysis_refs"): + continue + resolved_params[key] = self._resolve_binding(value, t, analysis_data) + + # Try fast native effects first + result = self._apply_native_effect(frame, effect_name, resolved_params) + if result is not None: + return result + + # Fall back to sexp interpreter for complex effects + interp = self._get_interpreter() + if effect_name in interp.effects: + result, _ = interp.run_effect(effect_name, frame, resolved_params, {}) + return result + + # Unknown effect - pass through + return frame + + def _apply_native_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + ) -> Optional[np.ndarray]: + """Fast native numpy effects for real-time streaming.""" + import cv2 + + if effect_name == "zoom": + amount = float(params.get("amount", 1.0)) + if abs(amount - 1.0) < 0.01: + return frame + h, w = frame.shape[:2] + # Crop center and resize + new_w, new_h = int(w / amount), int(h / amount) + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + + elif effect_name == "rotate": + angle = float(params.get("angle", 0)) + if abs(angle) < 0.5: + return frame + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + return cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == "brightness": + amount = float(params.get("amount", 1.0)) + return np.clip(frame * amount, 0, 255).astype(np.uint8) + + elif effect_name == "invert": + amount = float(params.get("amount", 1.0)) + if amount < 0.5: + return frame + return 255 - frame + + # Not a native effect + return None + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process frames through effects and composite. + """ + if not frames: + return np.zeros((720, 1280, 3), dtype=np.uint8) + + processed = [] + + # Apply effects to each input frame + for i, (frame, effects) in enumerate(zip(frames, effects_per_frame)): + result = frame.copy() + for effect_config in effects: + effect_name = effect_config.get("effect", "") + if effect_name: + result = self._apply_effect( + result, effect_name, effect_config, t, analysis_data + ) + processed.append(result) + + # Composite layers + if len(processed) == 1: + return processed[0] + + return self._composite(processed, compositor_config, t, analysis_data) + + def _composite( + self, + frames: List[np.ndarray], + config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Composite multiple frames into one.""" + mode = config.get("mode", "alpha") + weights = config.get("weights", [1.0 / len(frames)] * len(frames)) + + # Resolve weight bindings + resolved_weights = [] + for w in weights: + resolved_weights.append(self._resolve_binding(w, t, analysis_data)) + + # Normalize weights + total = sum(resolved_weights) + if total > 0: + resolved_weights = [w / total for w in resolved_weights] + else: + resolved_weights = [1.0 / len(frames)] * len(frames) + + # Resize frames to match first frame + target_h, target_w = frames[0].shape[:2] + resized = [] + for frame in frames: + if frame.shape[:2] != (target_h, target_w): + import cv2 + frame = cv2.resize(frame, (target_w, target_h)) + resized.append(frame.astype(np.float32)) + + # Weighted blend + result = np.zeros_like(resized[0]) + for frame, weight in zip(resized, resolved_weights): + result += frame * weight + + return np.clip(result, 0, 255).astype(np.uint8) + + +class WGPUBackend(Backend): + """ + GPU-based effect processing using wgpu/WebGPU compute shaders. + + Compiles sexp effects to WGSL at load time, executes on GPU. + Achieves 30+ fps real-time processing on supported hardware. + + Requirements: + - wgpu-py library + - Vulkan-capable GPU (or software renderer) + """ + + def __init__(self, recipe_dir: Path = None): + self.recipe_dir = recipe_dir or Path(".") + self._device = None + self._loaded_effects: Dict[str, Any] = {} # name -> compiled shader info + self._numpy_fallback = NumpyBackend(recipe_dir) + # Buffer pool for reuse - keyed by (width, height) + self._buffer_pool: Dict[tuple, Dict] = {} + + def _ensure_device(self): + """Lazy-initialize wgpu device.""" + if self._device is not None: + return + + try: + import wgpu + adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + self._device = adapter.request_device_sync() + print(f"[WGPUBackend] Using GPU: {adapter.info.get('device', 'unknown')}") + except Exception as e: + print(f"[WGPUBackend] GPU init failed: {e}, falling back to CPU") + self._device = None + + def load_effect(self, effect_path: Path) -> Any: + """Load and compile an effect from sexp file to WGSL.""" + effect_key = str(effect_path) + if effect_key in self._loaded_effects: + return self._loaded_effects[effect_key] + + try: + from sexp_effects.wgsl_compiler import compile_effect_file + compiled = compile_effect_file(str(effect_path)) + + self._ensure_device() + if self._device is None: + # Fall back to numpy + return self._numpy_fallback.load_effect(effect_path) + + # Create shader module + import wgpu + shader_module = self._device.create_shader_module(code=compiled.wgsl_code) + + # Create compute pipeline + pipeline = self._device.create_compute_pipeline( + layout="auto", + compute={"module": shader_module, "entry_point": "main"} + ) + + self._loaded_effects[effect_key] = { + 'compiled': compiled, + 'pipeline': pipeline, + 'name': compiled.name, + } + return compiled.name + + except Exception as e: + print(f"[WGPUBackend] Failed to compile {effect_path}: {e}") + # Fall back to numpy for this effect + return self._numpy_fallback.load_effect(effect_path) + + def _resolve_binding(self, value: Any, t: float, analysis_data: Dict) -> Any: + """Resolve a parameter binding to its value at time t.""" + # Delegate to numpy backend's implementation + return self._numpy_fallback._resolve_binding(value, t, analysis_data) + + def _get_or_create_buffers(self, w: int, h: int): + """Get or create reusable buffers for given dimensions.""" + import wgpu + + key = (w, h) + if key in self._buffer_pool: + return self._buffer_pool[key] + + size = w * h * 4 # u32 per pixel + + # Create staging buffer for uploads (MAP_WRITE) + staging_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.MAP_WRITE | wgpu.BufferUsage.COPY_SRC, + mapped_at_creation=False, + ) + + # Create input buffer (STORAGE, receives data from staging) + input_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST, + ) + + # Create output buffer (STORAGE + COPY_SRC for readback) + output_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC, + ) + + # Params buffer (uniform, 256 bytes should be enough) + params_buffer = self._device.create_buffer( + size=256, + usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST, + ) + + self._buffer_pool[key] = { + 'staging': staging_buffer, + 'input': input_buffer, + 'output': output_buffer, + 'params': params_buffer, + 'size': size, + } + return self._buffer_pool[key] + + def _apply_effect_gpu( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + ) -> Optional[np.ndarray]: + """Apply effect using GPU. Returns None if GPU not available.""" + import wgpu + + # Find the loaded effect + effect_info = None + for key, info in self._loaded_effects.items(): + if info.get('name') == effect_name: + effect_info = info + break + + if effect_info is None or self._device is None: + return None + + compiled = effect_info['compiled'] + pipeline = effect_info['pipeline'] + + h, w = frame.shape[:2] + + # Get reusable buffers + buffers = self._get_or_create_buffers(w, h) + + # Pack frame as u32 array (RGB -> packed u32) + r = frame[:, :, 0].astype(np.uint32) + g = frame[:, :, 1].astype(np.uint32) + b = frame[:, :, 2].astype(np.uint32) + packed = (r << 16) | (g << 8) | b + input_data = packed.flatten().astype(np.uint32) + + # Upload input data via queue.write_buffer (more efficient than recreation) + self._device.queue.write_buffer(buffers['input'], 0, input_data.tobytes()) + + # Build params struct + import struct + param_values = [w, h] # width, height as u32 + param_format = "II" # two u32 + + # Add time as f32 + param_values.append(t) + param_format += "f" + + # Add effect-specific params + for param in compiled.params: + val = params.get(param.name, param.default) + if val is None: + val = 0 + if param.wgsl_type == 'f32': + param_values.append(float(val)) + param_format += "f" + elif param.wgsl_type == 'i32': + param_values.append(int(val)) + param_format += "i" + elif param.wgsl_type == 'u32': + param_values.append(int(val)) + param_format += "I" + + # Pad to 16-byte alignment + param_bytes = struct.pack(param_format, *param_values) + while len(param_bytes) % 16 != 0: + param_bytes += b'\x00' + + self._device.queue.write_buffer(buffers['params'], 0, param_bytes) + + # Create bind group (unfortunately this can't be easily reused with different effects) + bind_group = self._device.create_bind_group( + layout=pipeline.get_bind_group_layout(0), + entries=[ + {"binding": 0, "resource": {"buffer": buffers['input']}}, + {"binding": 1, "resource": {"buffer": buffers['output']}}, + {"binding": 2, "resource": {"buffer": buffers['params']}}, + ] + ) + + # Dispatch compute + encoder = self._device.create_command_encoder() + compute_pass = encoder.begin_compute_pass() + compute_pass.set_pipeline(pipeline) + compute_pass.set_bind_group(0, bind_group) + + # Workgroups: ceil(w/16) x ceil(h/16) + wg_x = (w + 15) // 16 + wg_y = (h + 15) // 16 + compute_pass.dispatch_workgroups(wg_x, wg_y, 1) + compute_pass.end() + + self._device.queue.submit([encoder.finish()]) + + # Read back result + result_data = self._device.queue.read_buffer(buffers['output']) + result_packed = np.frombuffer(result_data, dtype=np.uint32).reshape(h, w) + + # Unpack u32 -> RGB + result = np.zeros((h, w, 3), dtype=np.uint8) + result[:, :, 0] = ((result_packed >> 16) & 0xFF).astype(np.uint8) + result[:, :, 1] = ((result_packed >> 8) & 0xFF).astype(np.uint8) + result[:, :, 2] = (result_packed & 0xFF).astype(np.uint8) + + return result + + def _apply_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Apply a single effect to a frame.""" + # Resolve bindings in params + resolved_params = {"_time": t} + for key, value in params.items(): + if key in ("effect", "effect_path", "cid", "analysis_refs"): + continue + resolved_params[key] = self._resolve_binding(value, t, analysis_data) + + # Try GPU first + self._ensure_device() + if self._device is not None: + result = self._apply_effect_gpu(frame, effect_name, resolved_params, t) + if result is not None: + return result + + # Fall back to numpy + return self._numpy_fallback._apply_effect( + frame, effect_name, params, t, analysis_data + ) + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Process frames through effects and composite.""" + if not frames: + return np.zeros((720, 1280, 3), dtype=np.uint8) + + processed = [] + + # Apply effects to each input frame + for i, (frame, effects) in enumerate(zip(frames, effects_per_frame)): + result = frame.copy() + for effect_config in effects: + effect_name = effect_config.get("effect", "") + if effect_name: + result = self._apply_effect( + result, effect_name, effect_config, t, analysis_data + ) + processed.append(result) + + # Composite layers (use numpy backend for now) + if len(processed) == 1: + return processed[0] + + return self._numpy_fallback._composite( + processed, compositor_config, t, analysis_data + ) + + +# Keep GLSLBackend as alias for backwards compatibility +GLSLBackend = WGPUBackend + + +def get_backend(name: str = "numpy", **kwargs) -> Backend: + """ + Get a backend by name. + + Args: + name: "numpy", "wgpu", or "glsl" (alias for wgpu) + **kwargs: Backend-specific options + + Returns: + Backend instance + """ + if name == "numpy": + return NumpyBackend(**kwargs) + elif name in ("wgpu", "glsl", "gpu"): + return WGPUBackend(**kwargs) + else: + raise ValueError(f"Unknown backend: {name}") diff --git a/l1/streaming/compositor.py b/l1/streaming/compositor.py new file mode 100644 index 0000000..477128f --- /dev/null +++ b/l1/streaming/compositor.py @@ -0,0 +1,595 @@ +""" +Streaming video compositor. + +Main entry point for the streaming pipeline. Combines: +- Multiple video sources (with looping) +- Per-source effect chains +- Layer compositing +- Optional live audio analysis +- Output to display/file/stream +""" + +import time +import sys +import numpy as np +from typing import List, Dict, Any, Optional, Union +from pathlib import Path + +from .sources import Source, VideoSource +from .backends import Backend, NumpyBackend, get_backend +from .output import Output, DisplayOutput, FileOutput, MultiOutput + + +class StreamingCompositor: + """ + Real-time streaming video compositor. + + Reads frames from multiple sources, applies effects, composites layers, + and outputs the result - all frame-by-frame without intermediate files. + + Example: + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[ + [{"effect": "rotate", "angle": 45}], + [{"effect": "zoom", "amount": 1.5}], + ], + compositor_config={"mode": "alpha", "weights": [0.5, 0.5]}, + ) + compositor.run(output="preview", duration=60) + """ + + def __init__( + self, + sources: List[Union[str, Source]], + effects_per_source: List[List[Dict]] = None, + compositor_config: Dict = None, + analysis_data: Dict = None, + backend: str = "numpy", + recipe_dir: Path = None, + fps: float = 30, + audio_source: str = None, + ): + """ + Initialize the streaming compositor. + + Args: + sources: List of video paths or Source objects + effects_per_source: List of effect chains, one per source + compositor_config: How to blend layers (mode, weights) + analysis_data: Pre-computed analysis data for bindings + backend: "numpy" or "glsl" + recipe_dir: Directory for resolving relative effect paths + fps: Output frame rate + audio_source: Path to audio file for streaming analysis + """ + self.fps = fps + self.recipe_dir = recipe_dir or Path(".") + self.analysis_data = analysis_data or {} + + # Initialize streaming audio analyzer if audio source provided + self._audio_analyzer = None + self._audio_source = audio_source + if audio_source: + from .audio import StreamingAudioAnalyzer + self._audio_analyzer = StreamingAudioAnalyzer(audio_source) + print(f"Streaming audio: {audio_source}", file=sys.stderr) + + # Initialize sources + self.sources: List[Source] = [] + for src in sources: + if isinstance(src, Source): + self.sources.append(src) + elif isinstance(src, (str, Path)): + self.sources.append(VideoSource(str(src), target_fps=fps)) + else: + raise ValueError(f"Unknown source type: {type(src)}") + + # Effect chains (default: no effects) + self.effects_per_source = effects_per_source or [[] for _ in self.sources] + if len(self.effects_per_source) != len(self.sources): + raise ValueError( + f"effects_per_source length ({len(self.effects_per_source)}) " + f"must match sources length ({len(self.sources)})" + ) + + # Compositor config (default: equal blend) + self.compositor_config = compositor_config or { + "mode": "alpha", + "weights": [1.0 / len(self.sources)] * len(self.sources), + } + + # Initialize backend + self.backend: Backend = get_backend( + backend, + recipe_dir=self.recipe_dir, + ) + + # Load effects + self._load_effects() + + def _load_effects(self): + """Pre-load all effect definitions.""" + for effects in self.effects_per_source: + for effect_config in effects: + effect_path = effect_config.get("effect_path") + if effect_path: + full_path = self.recipe_dir / effect_path + if full_path.exists(): + self.backend.load_effect(full_path) + + def _create_output( + self, + output: Union[str, Output], + size: tuple, + ) -> Output: + """Create output target from string or Output object.""" + if isinstance(output, Output): + return output + + if output == "preview": + return DisplayOutput("Streaming Preview", size, + audio_source=self._audio_source, fps=self.fps) + elif output == "null": + from .output import NullOutput + return NullOutput() + elif isinstance(output, str): + return FileOutput(output, size, fps=self.fps, audio_source=self._audio_source) + else: + raise ValueError(f"Unknown output type: {output}") + + def run( + self, + output: Union[str, Output] = "preview", + duration: float = None, + audio_analyzer=None, + show_fps: bool = True, + recipe_executor=None, + ): + """ + Run the streaming compositor. + + Args: + output: Output target - "preview", filename, or Output object + duration: Duration in seconds (None = run until quit) + audio_analyzer: Optional AudioAnalyzer for live audio reactivity + show_fps: Show FPS counter in console + recipe_executor: Optional StreamingRecipeExecutor for full recipe logic + """ + # Determine output size from first source + output_size = self.sources[0].size + + # Create output + out = self._create_output(output, output_size) + + # Determine duration + if duration is None: + # Run until stopped (or min source duration if not looping) + duration = min(s.duration for s in self.sources) + if duration == float('inf'): + duration = 3600 # 1 hour max for live sources + + total_frames = int(duration * self.fps) + frame_time = 1.0 / self.fps + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {total_frames} frames @ {self.fps}fps", file=sys.stderr) + print(f"Output size: {output_size[0]}x{output_size[1]}", file=sys.stderr) + print(f"Press 'q' to quit (if preview)", file=sys.stderr) + + # Frame loop + start_time = time.time() + frame_count = 0 + fps_update_interval = 30 # Update FPS display every N frames + last_fps_time = start_time + last_fps_count = 0 + + try: + for frame_num in range(total_frames): + if not out.is_open: + print(f"\nOutput closed at frame {frame_num}", file=sys.stderr) + break + + t = frame_num * frame_time + + try: + # Update analysis data from streaming audio (file-based) + energy = 0.0 + is_beat = False + if self._audio_analyzer: + self._update_from_audio(self._audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + elif audio_analyzer: + self._update_from_audio(audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + + # Read frames from all sources + frames = [src.read_frame(t) for src in self.sources] + + # Process through recipe executor if provided + if recipe_executor: + result = self._process_with_executor( + frames, recipe_executor, energy, is_beat, t + ) + else: + # Simple backend processing + result = self.backend.process_frame( + frames, + self.effects_per_source, + self.compositor_config, + t, + self.analysis_data, + ) + + # Output + out.write(result, t) + frame_count += 1 + + # FPS display + if show_fps and frame_count % fps_update_interval == 0: + now = time.time() + elapsed = now - last_fps_time + if elapsed > 0: + current_fps = (frame_count - last_fps_count) / elapsed + progress = frame_num / total_frames * 100 + print( + f"\r {progress:5.1f}% | {current_fps:5.1f} fps | " + f"frame {frame_num}/{total_frames}", + end="", file=sys.stderr + ) + last_fps_time = now + last_fps_count = frame_count + + except Exception as e: + print(f"\nError at frame {frame_num}, t={t:.1f}s: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + break + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources: + if hasattr(src, 'close'): + src.close() + + # Final stats + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", file=sys.stderr) + + def _process_with_executor( + self, + frames: List[np.ndarray], + executor, + energy: float, + is_beat: bool, + t: float, + ) -> np.ndarray: + """ + Process frames using the recipe executor for full pipeline. + + Implements: + 1. process-pair: two clips per source with effects, blended + 2. cycle-crossfade: dynamic composition with zoom and weights + 3. Final effects: whole-spin, ripple + """ + import cv2 + + # Target size from first source + target_h, target_w = frames[0].shape[:2] + + # Resize all frames to target size (letterbox to preserve aspect ratio) + resized_frames = [] + for frame in frames: + fh, fw = frame.shape[:2] + if (fh, fw) != (target_h, target_w): + # Calculate scale to fit while preserving aspect ratio + scale = min(target_w / fw, target_h / fh) + new_w, new_h = int(fw * scale), int(fh * scale) + resized = cv2.resize(frame, (new_w, new_h)) + # Center on black canvas + canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8) + x_off = (target_w - new_w) // 2 + y_off = (target_h - new_h) // 2 + canvas[y_off:y_off+new_h, x_off:x_off+new_w] = resized + resized_frames.append(canvas) + else: + resized_frames.append(frame) + frames = resized_frames + + # Update executor state + executor.on_frame(energy, is_beat, t) + + # Get weights to know which sources are active + weights = executor.get_cycle_weights() + + # Process each source as a "pair" (clip A and B with different effects) + processed_pairs = [] + + for i, frame in enumerate(frames): + # Skip sources with zero weight (but still need placeholder) + if i < len(weights) and weights[i] < 0.001: + processed_pairs.append(None) + continue + # Get effect params for clip A and B + params_a = executor.get_effect_params(i, "a", energy) + params_b = executor.get_effect_params(i, "b", energy) + pair_params = executor.get_pair_params(i) + + # Process clip A + clip_a = self._apply_clip_effects(frame.copy(), params_a, t) + + # Process clip B + clip_b = self._apply_clip_effects(frame.copy(), params_b, t) + + # Blend A and B using pair_mix opacity + opacity = pair_params["blend_opacity"] + blended = cv2.addWeighted( + clip_a, 1 - opacity, + clip_b, opacity, + 0 + ) + + # Apply pair rotation + h, w = blended.shape[:2] + center = (w // 2, h // 2) + angle = pair_params["pair_rotation"] + if abs(angle) > 0.5: + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + blended = cv2.warpAffine(blended, matrix, (w, h)) + + processed_pairs.append(blended) + + # Cycle-crossfade composition + weights = executor.get_cycle_weights() + zooms = executor.get_cycle_zooms() + + # Apply zoom per pair and composite + h, w = target_h, target_w + result = np.zeros((h, w, 3), dtype=np.float32) + + for idx, (pair, weight, zoom) in enumerate(zip(processed_pairs, weights, zooms)): + # Skip zero-weight sources + if pair is None or weight < 0.001: + continue + + orig_shape = pair.shape + + # Apply zoom + if zoom > 1.01: + # Zoom in: crop center and resize up + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = pair[y1:y1+new_h, x1:x1+new_w] + pair = cv2.resize(cropped, (w, h)) + elif zoom < 0.99: + # Zoom out: shrink video and center on black + scaled_w, scaled_h = int(w * zoom), int(h * zoom) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(pair, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + pair = canvas.copy() + + # Draw colored border - size indicates zoom level + border_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)] + color = border_colors[idx % 4] + thickness = max(3, int(10 * weight)) # Thicker border = higher weight + pair = np.ascontiguousarray(pair) + pair[:thickness, :] = color + pair[-thickness:, :] = color + pair[:, :thickness] = color + pair[:, -thickness:] = color + + result += pair.astype(np.float32) * weight + + result = np.clip(result, 0, 255).astype(np.uint8) + + # Apply final effects (whole-spin, ripple) + final_params = executor.get_final_effects(energy) + + # Whole spin + spin_angle = final_params["whole_spin_angle"] + if abs(spin_angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, spin_angle, 1.0) + result = cv2.warpAffine(result, matrix, (w, h)) + + # Ripple effect + amp = final_params["ripple_amplitude"] + if amp > 1: + result = self._apply_ripple(result, amp, + final_params["ripple_cx"], + final_params["ripple_cy"], + t) + + return result + + def _apply_clip_effects(self, frame: np.ndarray, params: dict, t: float) -> np.ndarray: + """Apply per-clip effects: rotate, zoom, invert, hue_shift, ascii.""" + import cv2 + + h, w = frame.shape[:2] + + # Rotate + angle = params["rotate_angle"] + if abs(angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + # Zoom + zoom = params["zoom_amount"] + if abs(zoom - 1.0) > 0.01: + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x1 + new_w), min(h, y1 + new_h) + if x2 > x1 and y2 > y1: + cropped = frame[y1:y2, x1:x2] + frame = cv2.resize(cropped, (w, h)) + + # Invert + if params["invert_amount"] > 0.5: + frame = 255 - frame + + # Hue shift + hue_deg = params["hue_degrees"] + if abs(hue_deg) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.int32) + int(hue_deg / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + # ASCII art + if params["ascii_mix"] > 0.5: + char_size = max(4, int(params["ascii_char_size"])) + frame = self._apply_ascii(frame, char_size) + + return frame + + def _apply_ascii(self, frame: np.ndarray, char_size: int) -> np.ndarray: + """Apply ASCII art effect.""" + import cv2 + from PIL import Image, ImageDraw, ImageFont + + h, w = frame.shape[:2] + chars = " .:-=+*#%@" + + # Get font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size) + except: + font = ImageFont.load_default() + + # Sample cells using area interpolation (fast block average) + rows = h // char_size + cols = w // char_size + if rows < 1 or cols < 1: + return frame + + # Crop to exact grid and downsample + cropped = frame[:rows * char_size, :cols * char_size] + cell_colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * cell_colors[:, :, 0] + + 0.587 * cell_colors[:, :, 1] + + 0.114 * cell_colors[:, :, 2]) / 255.0 + + # Create output image + out_h = rows * char_size + out_w = cols * char_size + output = Image.new('RGB', (out_w, out_h), (0, 0, 0)) + draw = ImageDraw.Draw(output) + + # Draw characters + for r in range(rows): + for c in range(cols): + lum = luminances[r, c] + color = tuple(cell_colors[r, c]) + + # Map luminance to character + idx = int(lum * (len(chars) - 1)) + char = chars[idx] + + # Draw character + x = c * char_size + y = r * char_size + draw.text((x, y), char, fill=color, font=font) + + # Convert back to numpy and resize to original + result = np.array(output) + if result.shape[:2] != (h, w): + result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LINEAR) + + return result + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, t: float = 0) -> np.ndarray: + """Apply ripple distortion effect.""" + import cv2 + + h, w = frame.shape[:2] + center_x, center_y = cx * w, cy * h + max_dim = max(w, h) + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Distance from center + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx*dx + dy*dy) + + # Ripple parameters (matching recipe: frequency=8, decay=2, speed=5) + freq = 8 + decay = 2 + speed = 5 + phase = t * speed * 2 * np.pi + + # Ripple displacement (matching original formula) + ripple = np.sin(2 * np.pi * freq * dist / max_dim + phase) * amplitude + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max_dim) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT) + + def _update_from_audio(self, analyzer, t: float): + """Update analysis data from audio analyzer (streaming or live).""" + # Set time for file-based streaming analyzers + if hasattr(analyzer, 'set_time'): + analyzer.set_time(t) + + # Get current audio features + energy = analyzer.get_energy() if hasattr(analyzer, 'get_energy') else 0 + beat = analyzer.get_beat() if hasattr(analyzer, 'get_beat') else False + + # Update analysis tracks - these can be referenced by effect bindings + self.analysis_data["live_energy"] = { + "times": [t], + "values": [energy], + "duration": float('inf'), + } + self.analysis_data["live_beat"] = { + "times": [t], + "values": [1.0 if beat else 0.0], + "duration": float('inf'), + } + + +def quick_preview( + sources: List[str], + effects: List[List[Dict]] = None, + duration: float = 10, + fps: float = 30, +): + """ + Quick preview helper - show sources with optional effects. + + Example: + quick_preview(["video1.mp4", "video2.mp4"], duration=30) + """ + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects, + fps=fps, + ) + compositor.run(output="preview", duration=duration) diff --git a/l1/streaming/demo.py b/l1/streaming/demo.py new file mode 100644 index 0000000..0b1899f --- /dev/null +++ b/l1/streaming/demo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Demo script for streaming compositor. + +Usage: + # Preview two videos blended + python -m streaming.demo preview video1.mp4 video2.mp4 + + # Record output to file + python -m streaming.demo record video1.mp4 video2.mp4 -o output.mp4 + + # Benchmark (no output) + python -m streaming.demo benchmark video1.mp4 --duration 10 +""" + +import argparse +import sys +from pathlib import Path + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from streaming import StreamingCompositor, VideoSource +from streaming.output import NullOutput + + +def demo_preview(sources: list, duration: float, effects: bool = False): + """Preview sources with optional simple effects.""" + effects_config = None + if effects: + effects_config = [ + [{"effect": "rotate", "angle": 15}], + [{"effect": "zoom", "amount": 1.2}], + ][:len(sources)] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="preview", duration=duration) + + +def demo_record(sources: list, output_path: str, duration: float): + """Record blended output to file.""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output=output_path, duration=duration) + + +def demo_benchmark(sources: list, duration: float): + """Benchmark processing speed (no output).""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="null", duration=duration) + + +def demo_audio_reactive(sources: list, duration: float): + """Preview with live audio reactivity.""" + from streaming.audio import AudioAnalyzer + + # Create compositor with energy-reactive effects + effects_config = [ + [{ + "effect": "zoom", + "amount": {"_binding": True, "source": "live_energy", "feature": "values", "range": [1.0, 1.5]}, + }] + for _ in sources + ] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + + # Start audio analyzer + try: + with AudioAnalyzer() as audio: + print("Audio analyzer started. Make some noise!", file=sys.stderr) + compositor.run(output="preview", duration=duration, audio_analyzer=audio) + except Exception as e: + print(f"Audio not available: {e}", file=sys.stderr) + print("Running without audio...", file=sys.stderr) + compositor.run(output="preview", duration=duration) + + +def main(): + parser = argparse.ArgumentParser(description="Streaming compositor demo") + parser.add_argument("mode", choices=["preview", "record", "benchmark", "audio"], + help="Demo mode") + parser.add_argument("sources", nargs="+", help="Video source files") + parser.add_argument("-o", "--output", help="Output file (for record mode)") + parser.add_argument("-d", "--duration", type=float, default=30, + help="Duration in seconds") + parser.add_argument("--effects", action="store_true", + help="Apply simple effects (for preview)") + + args = parser.parse_args() + + # Verify sources exist + for src in args.sources: + if not Path(src).exists(): + print(f"Error: Source not found: {src}", file=sys.stderr) + sys.exit(1) + + if args.mode == "preview": + demo_preview(args.sources, args.duration, args.effects) + elif args.mode == "record": + if not args.output: + print("Error: --output required for record mode", file=sys.stderr) + sys.exit(1) + demo_record(args.sources, args.output, args.duration) + elif args.mode == "benchmark": + demo_benchmark(args.sources, args.duration) + elif args.mode == "audio": + demo_audio_reactive(args.sources, args.duration) + + +if __name__ == "__main__": + main() diff --git a/l1/streaming/gpu_output.py b/l1/streaming/gpu_output.py new file mode 100644 index 0000000..3034310 --- /dev/null +++ b/l1/streaming/gpu_output.py @@ -0,0 +1,538 @@ +""" +Zero-copy GPU video encoding output. + +Uses PyNvVideoCodec for direct GPU-to-GPU encoding without CPU transfers. +Frames stay on GPU throughout: CuPy → NV12 conversion → NVENC encoding. +""" + +import numpy as np +import subprocess +import sys +import threading +import queue +from pathlib import Path +from typing import Tuple, Optional, Union +import time + +# Try to import GPU libraries +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + +try: + import PyNvVideoCodec as nvc + PYNVCODEC_AVAILABLE = True +except ImportError: + nvc = None + PYNVCODEC_AVAILABLE = False + + +def check_gpu_encode_available() -> bool: + """Check if zero-copy GPU encoding is available.""" + return CUPY_AVAILABLE and PYNVCODEC_AVAILABLE + + +# RGB to NV12 CUDA kernel +_RGB_TO_NV12_KERNEL = None + +def _get_rgb_to_nv12_kernel(): + """Get or create the RGB to NV12 conversion kernel.""" + global _RGB_TO_NV12_KERNEL + if _RGB_TO_NV12_KERNEL is None and CUPY_AVAILABLE: + _RGB_TO_NV12_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void rgb_to_nv12( + const unsigned char* rgb, + unsigned char* y_plane, + unsigned char* uv_plane, + int width, int height +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int rgb_idx = (y * width + x) * 3; + unsigned char r = rgb[rgb_idx]; + unsigned char g = rgb[rgb_idx + 1]; + unsigned char b = rgb[rgb_idx + 2]; + + // RGB to Y (BT.601) + int y_val = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16; + y_plane[y * width + x] = (unsigned char)(y_val > 255 ? 255 : (y_val < 0 ? 0 : y_val)); + + // UV (subsample 2x2) - only process even pixels + if ((x & 1) == 0 && (y & 1) == 0) { + int u_val = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + int v_val = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + + int uv_idx = (y / 2) * width + x; + uv_plane[uv_idx] = (unsigned char)(u_val > 255 ? 255 : (u_val < 0 ? 0 : u_val)); + uv_plane[uv_idx + 1] = (unsigned char)(v_val > 255 ? 255 : (v_val < 0 ? 0 : v_val)); + } +} +''', 'rgb_to_nv12') + return _RGB_TO_NV12_KERNEL + + +class GPUEncoder: + """ + Zero-copy GPU video encoder using PyNvVideoCodec. + + Frames are converted from RGB to NV12 on GPU and encoded directly + without any CPU memory transfers. + """ + + def __init__(self, width: int, height: int, fps: float = 30, crf: int = 23): + if not check_gpu_encode_available(): + raise RuntimeError("GPU encoding not available (need CuPy and PyNvVideoCodec)") + + self.width = width + self.height = height + self.fps = fps + self.crf = crf + + # Create dummy video to get frame buffer template + self._init_frame_buffer() + + # Create encoder with low-latency settings (no B-frames for immediate output) + # Use H264 codec explicitly, with SPS/PPS headers for browser compatibility + self.encoder = nvc.CreateEncoder( + width, height, "NV12", usecpuinputbuffer=False, + codec="h264", # Explicit H.264 (not HEVC) + bf=0, # No B-frames - immediate output + repeatSPSPPS=1, # Include SPS/PPS with each IDR frame + idrPeriod=30, # IDR frame every 30 frames (1 sec at 30fps) + ) + + # CUDA kernel grid/block config + self._block = (16, 16) + self._grid = ((width + 15) // 16, (height + 15) // 16) + + self._frame_count = 0 + self._encoded_data = [] + + print(f"[GPUEncoder] Initialized {width}x{height} @ {fps}fps, zero-copy GPU encoding", file=sys.stderr) + + def _init_frame_buffer(self): + """Initialize frame buffer from dummy decode.""" + # Create minimal dummy video + dummy_path = Path("/tmp/gpu_encoder_dummy.mp4") + subprocess.run([ + "ffmpeg", "-y", "-f", "lavfi", + "-i", f"color=black:size={self.width}x{self.height}:duration=0.1:rate=30", + "-c:v", "h264", "-pix_fmt", "yuv420p", + str(dummy_path) + ], capture_output=True) + + # Decode to get frame buffer + demuxer = nvc.CreateDemuxer(str(dummy_path)) + decoder = nvc.CreateDecoder(gpuid=0, usedevicememory=True) + + self._template_frame = None + for _ in range(30): + packet = demuxer.Demux() + if not packet: + break + frames = decoder.Decode(packet) + if frames: + self._template_frame = frames[0] + break + + if not self._template_frame: + raise RuntimeError("Failed to initialize GPU frame buffer") + + # Wrap frame planes with CuPy for zero-copy access + y_ptr = self._template_frame.GetPtrToPlane(0) + uv_ptr = self._template_frame.GetPtrToPlane(1) + + y_mem = cp.cuda.UnownedMemory(y_ptr, self.height * self.width, None) + self._y_plane = cp.ndarray( + (self.height, self.width), dtype=cp.uint8, + memptr=cp.cuda.MemoryPointer(y_mem, 0) + ) + + uv_mem = cp.cuda.UnownedMemory(uv_ptr, (self.height // 2) * self.width, None) + self._uv_plane = cp.ndarray( + (self.height // 2, self.width), dtype=cp.uint8, + memptr=cp.cuda.MemoryPointer(uv_mem, 0) + ) + + # Keep references to prevent GC + self._decoder = decoder + self._demuxer = demuxer + + # Cleanup dummy file + dummy_path.unlink(missing_ok=True) + + def encode_frame(self, frame: Union[np.ndarray, 'cp.ndarray']) -> bytes: + """ + Encode a frame (RGB format) to H.264. + + Args: + frame: RGB frame as numpy or CuPy array, shape (H, W, 3) + + Returns: + Encoded bytes (may be empty if frame is buffered) + """ + # Ensure frame is on GPU + if isinstance(frame, np.ndarray): + frame_gpu = cp.asarray(frame) + else: + frame_gpu = frame + + # Ensure uint8 + if frame_gpu.dtype != cp.uint8: + frame_gpu = cp.clip(frame_gpu, 0, 255).astype(cp.uint8) + + # Ensure contiguous + if not frame_gpu.flags['C_CONTIGUOUS']: + frame_gpu = cp.ascontiguousarray(frame_gpu) + + # Debug: check input frame has actual data (first few frames only) + if self._frame_count < 3: + frame_sum = float(cp.sum(frame_gpu)) + print(f"[GPUEncoder] Frame {self._frame_count}: shape={frame_gpu.shape}, dtype={frame_gpu.dtype}, sum={frame_sum:.0f}", file=sys.stderr) + if frame_sum < 1000: + print(f"[GPUEncoder] WARNING: Frame appears to be mostly black!", file=sys.stderr) + + # Convert RGB to NV12 on GPU + kernel = _get_rgb_to_nv12_kernel() + kernel(self._grid, self._block, (frame_gpu, self._y_plane, self._uv_plane, self.width, self.height)) + + # CRITICAL: Synchronize CUDA to ensure kernel completes before encoding + cp.cuda.Stream.null.synchronize() + + # Debug: check Y plane has data after conversion (first few frames only) + if self._frame_count < 3: + y_sum = float(cp.sum(self._y_plane)) + print(f"[GPUEncoder] Frame {self._frame_count}: Y plane sum={y_sum:.0f}", file=sys.stderr) + + # Encode (GPU to GPU) + result = self.encoder.Encode(self._template_frame) + self._frame_count += 1 + + return result if result else b'' + + def flush(self) -> bytes: + """Flush encoder and return remaining data.""" + return self.encoder.EndEncode() + + def close(self): + """Close encoder and cleanup.""" + pass + + +class GPUHLSOutput: + """ + GPU-accelerated HLS output with IPFS upload. + + Uses zero-copy GPU encoding and writes HLS segments. + Uploads happen asynchronously in a background thread to avoid stuttering. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + crf: int = 23, + audio_source: str = None, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._on_playlist_update = on_playlist_update + self._is_open = True + self.audio_source = audio_source + + # GPU encoder + self._gpu_encoder = GPUEncoder(size[0], size[1], fps, crf) + + # Segment management + self._current_segment = 0 + self._frames_in_segment = 0 + self._frames_per_segment = int(fps * segment_duration) + self._segment_data = [] + + # Track segment CIDs for IPFS + self.segment_cids = {} + self._playlist_cid = None + self._upload_lock = threading.Lock() + + # Import IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Background upload thread + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Setup ffmpeg for muxing (takes raw H.264, outputs .ts segments) + self._setup_muxer() + + print(f"[GPUHLSOutput] Initialized {size[0]}x{size[1]} @ {fps}fps, GPU encoding", file=sys.stderr) + + def _setup_muxer(self): + """Setup ffmpeg for muxing H.264 to MPEG-TS segments with optional audio.""" + self.local_playlist_path = self.output_dir / "stream.m3u8" + + cmd = [ + "ffmpeg", "-y", + "-f", "h264", # Input is raw H.264 + "-i", "-", + ] + + # Add audio input if provided + if self.audio_source: + cmd.extend(["-i", str(self.audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + cmd.extend([ + "-c:v", "copy", # Just copy video, no re-encoding + ]) + + # Add audio codec if we have audio + if self.audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + cmd.extend([ + "-f", "hls", + "-hls_time", str(self.segment_duration), + "-hls_list_size", "0", + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.local_playlist_path), + ]) + + print(f"[GPUHLSOutput] FFmpeg cmd: {' '.join(cmd)}", file=sys.stderr) + + self._muxer = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, # Capture stderr for debugging + ) + + # Start thread to drain stderr (prevents pipe buffer from filling and blocking FFmpeg) + self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True) + self._stderr_thread.start() + + def _drain_stderr(self): + """Drain FFmpeg stderr to prevent blocking.""" + try: + for line in self._muxer.stderr: + line_str = line.decode('utf-8', errors='replace').strip() + if line_str: + print(f"[FFmpeg] {line_str}", file=sys.stderr) + except Exception as e: + print(f"[FFmpeg stderr] Error reading: {e}", file=sys.stderr) + + def write(self, frame: Union[np.ndarray, 'cp.ndarray'], t: float = 0): + """Write a frame using GPU encoding.""" + if not self._is_open: + return + + # Handle GPUFrame objects (from streaming_gpu primitives) + if hasattr(frame, 'gpu') and hasattr(frame, 'is_on_gpu'): + # It's a GPUFrame - extract the underlying array + frame = frame.gpu if frame.is_on_gpu else frame.cpu + + # GPU encode + encoded = self._gpu_encoder.encode_frame(frame) + + # Send to muxer + if encoded: + try: + self._muxer.stdin.write(encoded) + except BrokenPipeError as e: + print(f"[GPUHLSOutput] FFmpeg pipe broken after {self._frames_in_segment} frames in segment, total segments: {self._current_segment}", file=sys.stderr) + # Check if muxer is still running + if self._muxer.poll() is not None: + print(f"[GPUHLSOutput] FFmpeg exited with code {self._muxer.returncode}", file=sys.stderr) + self._is_open = False + return + except Exception as e: + print(f"[GPUHLSOutput] Error writing to FFmpeg: {e}", file=sys.stderr) + self._is_open = False + return + + self._frames_in_segment += 1 + + # Check for segment completion + if self._frames_in_segment >= self._frames_per_segment: + self._frames_in_segment = 0 + self._check_upload_segments() + + def _upload_worker(self): + """Background worker thread for async IPFS uploads.""" + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + seg_path, seg_num = item + self._do_upload(seg_path, seg_num) + except queue.Empty: + continue + except Exception as e: + print(f"Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, seg_path: Path, seg_num: int): + """Actually perform the upload (runs in background thread).""" + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + with self._upload_lock: + self.segment_cids[seg_num] = cid + print(f"Added to IPFS: {seg_path.name} -> {cid}", file=sys.stderr) + self._update_playlist() + except Exception as e: + print(f"Failed to add to IPFS: {e}", file=sys.stderr) + + def _check_upload_segments(self): + """Check for and queue new segments for async IPFS upload.""" + segments = sorted(self.output_dir.glob("segment_*.ts")) + + for seg_path in segments: + seg_num = int(seg_path.stem.split("_")[1]) + + with self._upload_lock: + if seg_num in self.segment_cids: + continue + + # Check if segment is complete (quick check, no blocking) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue + # Quick non-blocking check + time.sleep(0.01) + size2 = seg_path.stat().st_size + if size1 != size2: + continue + except FileNotFoundError: + continue + + # Queue for async upload (non-blocking!) + self._upload_queue.put((seg_path, seg_num)) + + def _update_playlist(self): + """Generate and upload IPFS-aware playlist.""" + with self._upload_lock: + if not self.segment_cids: + return + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + # Use /ipfs-ts/ path for segments to get correct MIME type (video/mp2t) + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + lines.append(f"{segment_gateway}/{cid}") + + playlist_content = "\n".join(lines) + "\n" + + # Upload playlist + self._playlist_cid = self._ipfs_add_bytes(playlist_content.encode(), pin=True) + if self._playlist_cid and self._on_playlist_update: + self._on_playlist_update(self._playlist_cid) + + def close(self): + """Close output and flush remaining data.""" + if not self._is_open: + return + + self._is_open = False + + # Flush GPU encoder + final_data = self._gpu_encoder.flush() + if final_data: + try: + self._muxer.stdin.write(final_data) + except: + pass + + # Close muxer + try: + self._muxer.stdin.close() + self._muxer.wait(timeout=10) + except: + self._muxer.kill() + + # Final segment upload + self._check_upload_segments() + + # Wait for pending uploads to complete + self._upload_queue.put(None) # Signal shutdown + self._upload_thread.join(timeout=30) + + # Generate final playlist with #EXT-X-ENDLIST for VOD playback + self._generate_final_playlist() + + self._gpu_encoder.close() + + def _generate_final_playlist(self): + """Generate final IPFS playlist with #EXT-X-ENDLIST for completed streams.""" + with self._upload_lock: + if not self.segment_cids: + return + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + "#EXT-X-PLAYLIST-TYPE:VOD", # Mark as VOD for completed streams + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + # Use /ipfs-ts/ path for segments to get correct MIME type (video/mp2t) + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + lines.append(f"{segment_gateway}/{cid}") + + # Mark stream as complete - critical for VOD playback + lines.append("#EXT-X-ENDLIST") + + playlist_content = "\n".join(lines) + "\n" + + # Upload final playlist + self._playlist_cid = self._ipfs_add_bytes(playlist_content.encode(), pin=True) + if self._playlist_cid: + print(f"[GPUHLSOutput] Final VOD playlist: {self._playlist_cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + if self._on_playlist_update: + self._on_playlist_update(self._playlist_cid) + + @property + def is_open(self) -> bool: + return self._is_open + + @property + def playlist_cid(self) -> Optional[str]: + return self._playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + """Get the full IPFS URL for the playlist.""" + if self._playlist_cid: + return f"{self.ipfs_gateway}/{self._playlist_cid}" + return None diff --git a/l1/streaming/jax_typography.py b/l1/streaming/jax_typography.py new file mode 100644 index 0000000..74c0b31 --- /dev/null +++ b/l1/streaming/jax_typography.py @@ -0,0 +1,1642 @@ +""" +JAX Typography Primitives + +Two approaches for text rendering, both compile to JAX/GPU: + +## 1. TextStrip - Pixel-perfect static text + Pre-render entire strings at compile time using PIL. + Perfect sub-pixel anti-aliasing, exact match with PIL. + Use for: static titles, labels, any text without per-character effects. + + S-expression: + (let ((strip (render-text-strip "Hello World" 48))) + (place-text-strip frame strip x y :color white)) + +## 2. Glyph-by-glyph - Dynamic text effects + Individual glyph placement for wave, arc, audio-reactive effects. + Each character can have independent position, color, opacity. + Note: slight anti-aliasing differences vs PIL due to integer positioning. + + S-expression: + ; Wave text - y oscillates with character index + (let ((glyphs (text-glyphs "Wavy" 48))) + (first + (fold glyphs (list frame 0) + (lambda (acc g) + (let ((frm (first acc)) + (cursor (second acc)) + (i (length acc))) ; approximate index + (list + (place-glyph frm (glyph-image g) + (+ x cursor) + (+ y (* amplitude (sin (* i frequency)))) + (glyph-bearing-x g) (glyph-bearing-y g) + white 1.0) + (+ cursor (glyph-advance g)))))))) + + ; Audio-reactive spacing + (let ((glyphs (text-glyphs "Bass" 48)) + (bass (audio-band 0 200))) + (first + (fold glyphs (list frame 0) + (lambda (acc g) + (let ((frm (first acc)) + (cursor (second acc))) + (list + (place-glyph frm (glyph-image g) + (+ x cursor) y + (glyph-bearing-x g) (glyph-bearing-y g) + white 1.0) + (+ cursor (glyph-advance g) (* bass 20)))))))) + +Kerning support: + ; With kerning adjustment + (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size)) +""" + +import math +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from typing import Tuple, Dict, Any, List, Optional +from dataclasses import dataclass + + +# ============================================================================= +# Glyph Data (computed at compile time) +# ============================================================================= + +@dataclass +class GlyphData: + """Glyph data computed at compile time. + + Attributes: + char: The character + image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime + advance: Horizontal advance (distance to next glyph origin) + bearing_x: Left side bearing (x offset from origin to first pixel) + bearing_y: Top bearing (y offset from baseline to top of glyph) + width: Image width + height: Image height + """ + char: str + image: np.ndarray # (H, W, 4) RGBA uint8 + advance: float + bearing_x: float + bearing_y: float + width: int + height: int + + +# Font cache: (font_name, font_size) -> {char: GlyphData} +_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {} + +# Font metrics cache: (font_name, font_size) -> (ascent, descent) +_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {} + +# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment} +# Kerning adjustment is added to advance: new_advance = advance + kerning +# Typically negative (characters move closer together) +_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {} + + +def _load_font(font_name: str = None, font_size: int = 32): + """Load a font. Called at compile time.""" + from PIL import ImageFont + + candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', + '/usr/share/fonts/truetype/freefont/FreeSans.ttf', + ] + + for path in candidates: + if path is None: + continue + try: + return ImageFont.truetype(path, font_size) + except (IOError, OSError): + continue + + return ImageFont.load_default() + + +def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]: + """Get or create glyph cache for a font. Called at compile time.""" + cache_key = (font_name, font_size) + + if cache_key in _GLYPH_CACHE: + return _GLYPH_CACHE[cache_key] + + from PIL import Image, ImageDraw + + font = _load_font(font_name, font_size) + ascent, descent = font.getmetrics() + _METRICS_CACHE[cache_key] = (ascent, descent) + + glyphs = {} + charset = ''.join(chr(i) for i in range(32, 127)) + + temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0)) + temp_draw = ImageDraw.Draw(temp_img) + + for char in charset: + # Get metrics + bbox = temp_draw.textbbox((0, 0), char, font=font) + advance = font.getlength(char) + + x_min, y_min, x_max, y_max = bbox + + # Create glyph image with padding + padding = 2 + img_w = max(int(x_max - x_min) + padding * 2, 1) + img_h = max(int(y_max - y_min) + padding * 2, 1) + + glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0)) + glyph_draw = ImageDraw.Draw(glyph_img) + + # Draw at position accounting for bbox offset + draw_x = padding - x_min + draw_y = padding - y_min + glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font) + + glyphs[char] = GlyphData( + char=char, + image=np.array(glyph_img, dtype=np.uint8), + advance=float(advance), + bearing_x=float(x_min), + bearing_y=float(-y_min), # Distance from baseline to top + width=img_w, + height=img_h, + ) + + _GLYPH_CACHE[cache_key] = glyphs + return glyphs + + +def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]: + """Get or create kerning cache for a font. Called at compile time. + + Kerning is computed as: + kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b) + + This gives the adjustment needed when placing 'b' after 'a'. + Typically negative (characters move closer together). + """ + cache_key = (font_name, font_size) + + if cache_key in _KERNING_CACHE: + return _KERNING_CACHE[cache_key] + + font = _load_font(font_name, font_size) + kerning = {} + + # Compute kerning for all printable ASCII pairs + charset = ''.join(chr(i) for i in range(32, 127)) + + # Pre-compute individual character lengths + char_lengths = {c: font.getlength(c) for c in charset} + + # Compute kerning for each pair + for c1 in charset: + for c2 in charset: + pair_length = font.getlength(c1 + c2) + individual_sum = char_lengths[c1] + char_lengths[c2] + kern = pair_length - individual_sum + + # Only store non-zero kerning to save memory + if abs(kern) > 0.01: + kerning[(c1, c2)] = kern + + _KERNING_CACHE[cache_key] = kerning + return kerning + + +def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float: + """Get kerning adjustment between two characters. Compile-time. + + Returns the adjustment to add to char1's advance when char2 follows. + Typically negative (characters move closer). + + Usage in S-expression: + (+ (glyph-advance g1) (kerning g1 g2)) + """ + kerning_cache = _get_kerning_cache(font_name, font_size) + return kerning_cache.get((char1, char2), 0.0) + + +@dataclass +class TextStrip: + """Pre-rendered text strip with proper sub-pixel anti-aliasing. + + Rendered at compile time using PIL for exact matching. + At runtime, just composite onto frame at integer positions. + + Attributes: + text: The original text + image: RGBA image as numpy array (H, W, 4) + width: Strip width + height: Strip height + baseline_y: Y position of baseline within the strip + bearing_x: Left side bearing of first character + anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right) + anchor_y: Y offset for anchor point (depends on anchor type) + stroke_width: Stroke width used when rendering + """ + text: str + image: np.ndarray + width: int + height: int + baseline_y: int + bearing_x: float + anchor_x: float = 0.0 + anchor_y: float = 0.0 + stroke_width: int = 0 + + +# Text strip cache: cache_key -> TextStrip +_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {} + + +def render_text_strip( + text: str, + font_name: str = None, + font_size: int = 32, + stroke_width: int = 0, + stroke_fill: tuple = None, + anchor: str = "la", # left-ascender (PIL default is "la") + multiline: bool = False, + line_spacing: int = 4, + align: str = "left", +) -> TextStrip: + """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing. + + Args: + text: Text to render + font_name: Path to font file (None for default) + font_size: Font size in pixels + stroke_width: Outline width in pixels (0 for no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black + anchor: PIL anchor code - first char: h=left, m=middle, r=right + second char: a=ascender, t=top, m=middle, s=baseline, d=descender + multiline: If True, handle newlines in text + line_spacing: Extra pixels between lines (for multiline) + align: 'left', 'center', 'right' (for multiline) + + Returns: + TextStrip with pre-rendered text + """ + # Build cache key from all parameters + cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align) + if cache_key in _TEXT_STRIP_CACHE: + return _TEXT_STRIP_CACHE[cache_key] + + from PIL import Image, ImageDraw + + font = _load_font(font_name, font_size) + ascent, descent = font.getmetrics() + + # Default stroke fill to black + if stroke_fill is None: + stroke_fill = (0, 0, 0, 255) + elif len(stroke_fill) == 3: + stroke_fill = (*stroke_fill, 255) + + # Get text bbox (accounting for stroke) + temp = Image.new('RGBA', (1, 1)) + temp_draw = ImageDraw.Draw(temp) + + if multiline: + bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing, + stroke_width=stroke_width) + else: + bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width) + + # bbox is (left, top, right, bottom) relative to origin + x_min, y_min, x_max, y_max = bbox + + # Create image with padding (extra for stroke) + padding = 2 + stroke_width + img_width = max(int(x_max - x_min) + padding * 2, 1) + img_height = max(int(y_max - y_min) + padding * 2, 1) + + # Create RGBA image + img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Draw text at position that puts it in the image + draw_x = padding - x_min + draw_y = padding - y_min + + if multiline: + draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font, + spacing=line_spacing, align=align, + stroke_width=stroke_width, stroke_fill=stroke_fill) + else: + draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill) + + # Baseline is at y=0 in text coordinates, which is at draw_y in image + baseline_y = draw_y + + # Convert to numpy for pixel analysis + img_array = np.array(img, dtype=np.uint8) + + # Calculate anchor offsets + # For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching + h_anchor = anchor[0] if len(anchor) > 0 else 'l' + v_anchor = anchor[1] if len(anchor) > 1 else 'a' + + # Find actual pixel bounds (for middle anchor calculations) + alpha = img_array[:, :, 3] + nonzero_cols = np.where(alpha.max(axis=0) > 0)[0] + nonzero_rows = np.where(alpha.max(axis=1) > 0)[0] + + if len(nonzero_cols) > 0: + pixel_x_min = nonzero_cols.min() + pixel_x_max = nonzero_cols.max() + pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0 + else: + pixel_x_center = img_width / 2.0 + + if len(nonzero_rows) > 0: + pixel_y_min = nonzero_rows.min() + pixel_y_max = nonzero_rows.max() + pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0 + else: + pixel_y_center = img_height / 2.0 + + # Horizontal offset + text_width = x_max - x_min + if h_anchor == 'l': # left edge of text + anchor_x = float(draw_x) + elif h_anchor == 'm': # middle - use actual pixel center for perfect matching + anchor_x = pixel_x_center + elif h_anchor == 'r': # right edge of text + anchor_x = float(draw_x + text_width) + else: + anchor_x = float(draw_x) + + # Vertical offset + # PIL anchor positions are based on font metrics (ascent/descent): + # - 'a' (ascender): at the ascender line → draw_y in strip + # - 't' (top): at top of text bounding box → padding in strip + # - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender + # - 's' (baseline): at baseline = ascent below ascender + # - 'd' (descender): at descender line = ascent + descent below ascender + + if v_anchor == 'a': # ascender + anchor_y = float(draw_y) + elif v_anchor == 't': # top of bbox + anchor_y = float(padding) + elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation) + anchor_y = float(draw_y + (ascent + descent) / 2.0) + elif v_anchor == 's': # baseline + anchor_y = float(draw_y + ascent) + elif v_anchor == 'd': # descender + anchor_y = float(draw_y + ascent + descent) + else: + anchor_y = float(draw_y) # default to ascender + + strip = TextStrip( + text=text, + image=img_array, + width=img_width, + height=img_height, + baseline_y=baseline_y, + bearing_x=float(x_min), + anchor_x=anchor_x, + anchor_y=anchor_y, + stroke_width=stroke_width, + ) + + _TEXT_STRIP_CACHE[cache_key] = strip + return strip + + +# ============================================================================= +# Compile-time functions (called during S-expression compilation) +# ============================================================================= + +def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData: + """Get glyph data for a single character. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + return cache.get(char, cache.get(' ')) + + +def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list: + """Get glyph data for a string. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + space = cache.get(' ') + return [cache.get(c, space) for c in text] + + +def get_font_ascent(font_name: str = None, font_size: int = 32) -> float: + """Get font ascent. Compile-time.""" + _get_glyph_cache(font_name, font_size) # Ensure cache exists + return _METRICS_CACHE[(font_name, font_size)][0] + + +def get_font_descent(font_name: str = None, font_size: int = 32) -> float: + """Get font descent. Compile-time.""" + _get_glyph_cache(font_name, font_size) + return _METRICS_CACHE[(font_name, font_size)][1] + + +# ============================================================================= +# JAX Runtime Primitives +# ============================================================================= + +def place_glyph_jax( + frame: jnp.ndarray, + glyph_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + bearing_x: float, + bearing_y: float, + color: jnp.ndarray, # (3,) RGB 0-255 + opacity: float = 1.0, +) -> jnp.ndarray: + """ + Place a glyph onto a frame. This is the core JAX primitive. + + All positioning math can use traced values (x, y from audio, time, etc.) + The glyph_image is static (determined at compile time). + + Args: + frame: (H, W, 3) RGB frame + glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array) + x: X position of glyph origin (baseline point) + y: Y position of baseline + bearing_x: Left side bearing + bearing_y: Top bearing (from baseline to top) + color: RGB color array + opacity: Opacity 0-1 + + Returns: + Frame with glyph composited + """ + h, w = frame.shape[:2] + gh, gw = glyph_image.shape[:2] + + # Calculate destination position + # bearing_x: how far right of origin the glyph starts (can be negative) + # bearing_y: how far up from baseline the glyph extends + padding = 2 # Must match padding used in glyph creation + dst_x = x + bearing_x - padding + dst_y = y - bearing_y - padding + + # Extract glyph RGB and alpha + glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0 + # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255 + opacity_int = jnp.round(opacity * 255) + glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32) + glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply color tint (glyph is white, multiply by color) + color_normalized = color.astype(jnp.float32) / 255.0 + tinted = glyph_rgb * color_normalized + + from jax.lax import dynamic_update_slice + + # Use padded buffer to avoid XLA's dynamic_update_slice clamping + buf_h = h + 2 * gh + buf_w = w + 2 * gw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + dst_x_int = dst_x.astype(jnp.int32) + dst_y_int = dst_y.astype(jnp.int32) + place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32) + + rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0)) + alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0)) + + rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :] + alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :] + + # Alpha composite using PIL-compatible integer arithmetic + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_text_strip_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, +) -> jnp.ndarray: + """ + Place a pre-rendered text strip onto a frame. + + The strip was rendered at compile time with proper sub-pixel anti-aliasing. + This just composites it at the specified position. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + x: X position for anchor point + y: Y position for anchor point + baseline_y: Y position of baseline within the strip + bearing_x: Left side bearing + color: RGB color + opacity: Opacity 0-1 + anchor_x: X offset of anchor point within strip + anchor_y: Y offset of anchor point within strip + stroke_width: Stroke width used when rendering (affects padding) + + Returns: + Frame with text composited + """ + h, w = frame.shape[:2] + sh, sw = strip_image.shape[:2] + + # Calculate destination position + # Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame + # anchor_x/anchor_y already account for the anchor position within the strip + # Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding) + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + # Extract strip RGB and alpha + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255 + # Use jnp.round (banker's rounding) to match Python's round() used by PIL + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply color tint + color_normalized = color.astype(jnp.float32) / 255.0 + tinted = strip_rgb * color_normalized + + from jax.lax import dynamic_update_slice + + # Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior. + # XLA clamps indices so the update fits, which silently shifts the strip. + # By placing into a buffer padded by strip dimensions, then extracting the + # frame-sized region, we get correct clipping for both overflow and underflow. + buf_h = h + 2 * sh + buf_w = w + 2 * sw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + # Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer + place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32) + + rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0)) + alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0)) + + # Extract frame-sized region (sh, sw are compile-time constants from strip shape) + rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :] + alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :] + + # Alpha composite using PIL-compatible integer arithmetic: + # result = (src * alpha + dst * (255 - alpha) + 127) // 255 + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_glyph_simple( + frame: jnp.ndarray, + glyph: GlyphData, + x: float, + y: float, + color: tuple = (255, 255, 255), + opacity: float = 1.0, +) -> jnp.ndarray: + """ + Convenience wrapper that takes GlyphData directly. + Converts glyph image to JAX array. + + For S-expression use, prefer place_glyph_jax with pre-converted arrays. + """ + glyph_jax = jnp.asarray(glyph.image) + color_jax = jnp.array(color, dtype=jnp.float32) + + return place_glyph_jax( + frame, glyph_jax, x, y, + glyph.bearing_x, glyph.bearing_y, + color_jax, opacity + ) + + +# ============================================================================= +# Gradient Functions (compile-time: generate color maps from strip dimensions) +# ============================================================================= + +def make_linear_gradient( + width: int, + height: int, + color1: tuple, + color2: tuple, + angle: float = 0.0, +) -> np.ndarray: + """Create a linear gradient color map. + + Args: + width, height: Dimensions of the gradient (match strip dimensions) + color1: Start color (R, G, B) 0-255 + color2: End color (R, G, B) 0-255 + angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom) + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + c1 = np.array(color1[:3], dtype=np.float32) / 255.0 + c2 = np.array(color2[:3], dtype=np.float32) / 255.0 + + # Create coordinate grid + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + # Normalize to [0, 1] + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + # Project onto gradient axis + theta = angle * np.pi / 180.0 + cos_t = np.cos(theta) + sin_t = np.sin(theta) + + # Project (nx - 0.5, ny - 0.5) onto direction vector, then remap to [0, 1] + proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t + # Normalize: max projection is 0.5*|cos|+0.5*|sin| = 0.5*(|cos|+|sin|) + max_proj = 0.5 * (abs(cos_t) + abs(sin_t)) + if max_proj > 0: + t = (proj / max_proj + 1.0) / 2.0 + else: + t = np.full_like(proj, 0.5) + t = np.clip(t, 0.0, 1.0) + + # Interpolate + gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None] + return gradient + + +def make_radial_gradient( + width: int, + height: int, + color1: tuple, + color2: tuple, + center_x: float = 0.5, + center_y: float = 0.5, +) -> np.ndarray: + """Create a radial gradient color map. + + Args: + width, height: Dimensions + color1: Inner color (R, G, B) + color2: Outer color (R, G, B) + center_x, center_y: Center position in [0, 1] (0.5 = center) + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + c1 = np.array(color1[:3], dtype=np.float32) / 255.0 + c2 = np.array(color2[:3], dtype=np.float32) / 255.0 + + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + # Normalize to [0, 1] + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + # Distance from center, normalized so corners are ~1.0 + dx = nx - center_x + dy = ny - center_y + # Max possible distance from center to a corner + max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2) + if max_dist > 0: + t = np.sqrt(dx**2 + dy**2) / max_dist + else: + t = np.zeros_like(dx) + t = np.clip(t, 0.0, 1.0) + + gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None] + return gradient + + +def make_multi_stop_gradient( + width: int, + height: int, + stops: list, + angle: float = 0.0, + radial: bool = False, + center_x: float = 0.5, + center_y: float = 0.5, +) -> np.ndarray: + """Create a multi-stop gradient color map. + + Args: + width, height: Dimensions + stops: List of (position, (R, G, B)) tuples, position in [0, 1] + angle: Gradient angle in degrees (for linear mode) + radial: If True, use radial gradient + center_x, center_y: Center for radial gradient + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + if len(stops) < 2: + if len(stops) == 1: + c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0 + return np.broadcast_to(c, (height, width, 3)).copy() + return np.zeros((height, width, 3), dtype=np.float32) + + # Sort stops by position + stops = sorted(stops, key=lambda s: s[0]) + + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + if radial: + dx = nx - center_x + dy = ny - center_y + max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2) + t = np.sqrt(dx**2 + dy**2) / max(max_dist, 1e-6) + else: + theta = angle * np.pi / 180.0 + cos_t = np.cos(theta) + sin_t = np.sin(theta) + proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t + max_proj = 0.5 * (abs(cos_t) + abs(sin_t)) + if max_proj > 0: + t = (proj / max_proj + 1.0) / 2.0 + else: + t = np.full_like(proj, 0.5) + + t = np.clip(t, 0.0, 1.0) + + # Build gradient from stops using piecewise linear interpolation + colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops]) + positions = np.array([s[0] for s in stops], dtype=np.float32) + + # Start with first color + gradient = np.broadcast_to(colors[0], (height, width, 3)).copy() + + for i in range(len(stops) - 1): + p0, p1 = positions[i], positions[i + 1] + c0, c1 = colors[i], colors[i + 1] + + if p1 <= p0: + continue + + # Segment interpolation factor + seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0) + # Only apply where t >= p0 + mask = (t >= p0)[:, :, None] + seg_color = c0[None, None, :] * (1 - seg_t[:, :, None]) + c1[None, None, :] * seg_t[:, :, None] + gradient = np.where(mask, seg_color, gradient) + + return gradient + + +def _composite_strip_onto_frame( + frame: jnp.ndarray, + strip_rgb: jnp.ndarray, + strip_alpha: jnp.ndarray, + dst_x: jnp.ndarray, + dst_y: jnp.ndarray, + sh: int, + sw: int, +) -> jnp.ndarray: + """Core compositing: place tinted+alpha strip onto frame using padded buffer. + + Args: + frame: (H, W, 3) RGB uint8 + strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB + strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha + dst_x, dst_y: int32 destination position + sh, sw: strip dimensions (compile-time constants) + + Returns: + Composited frame (H, W, 3) uint8 + """ + h, w = frame.shape[:2] + + buf_h = h + 2 * sh + buf_w = w + 2 * sw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32) + + rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0)) + alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0)) + + rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :] + alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :] + + # PIL-compatible integer alpha blending + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_text_strip_gradient_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int, + bearing_x: float, + gradient_map: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, +) -> jnp.ndarray: + """Place text strip with gradient coloring instead of solid color. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + gradient_map: (sh, sw, 3) float32 color map in [0, 1] + Other args same as place_text_strip_jax + + Returns: + Composited frame + """ + sh, sw = strip_image.shape[:2] + + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + # Extract alpha with opacity + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply gradient instead of solid color + tinted = strip_rgb * gradient_map + + return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw) + + +# ============================================================================= +# Strip Rotation (RGBA bilinear interpolation) +# ============================================================================= + +def _sample_rgba(strip, x, y): + """Bilinear sample all 4 RGBA channels from a strip. + + Args: + strip: (H, W, 4) RGBA float32 + x, y: coordinate arrays (flattened) + + Returns: + (r, g, b, a) each same shape as x + """ + h, w = strip.shape[:2] + + x0 = jnp.floor(x).astype(jnp.int32) + y0 = jnp.floor(y).astype(jnp.int32) + x1 = x0 + 1 + y1 = y0 + 1 + + fx = x - x0.astype(jnp.float32) + fy = y - y0.astype(jnp.float32) + + valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) + valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) + valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) + valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) + + x0_safe = jnp.clip(x0, 0, w - 1) + x1_safe = jnp.clip(x1, 0, w - 1) + y0_safe = jnp.clip(y0, 0, h - 1) + y1_safe = jnp.clip(y1, 0, h - 1) + + channels = [] + for c in range(4): + c00 = jnp.where(valid00, strip[y0_safe, x0_safe, c], 0.0) + c10 = jnp.where(valid10, strip[y0_safe, x1_safe, c], 0.0) + c01 = jnp.where(valid01, strip[y1_safe, x0_safe, c], 0.0) + c11 = jnp.where(valid11, strip[y1_safe, x1_safe, c], 0.0) + + val = (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + channels.append(val) + + return channels[0], channels[1], channels[2], channels[3] + + +def rotate_strip_jax( + strip_image: jnp.ndarray, + angle: float, +) -> jnp.ndarray: + """Rotate an RGBA strip by angle (degrees), counter-clockwise. + + Output buffer is sized to contain the full rotated strip. + The output size is ceil(sqrt(w^2 + h^2)), computed at trace time + from the strip's static shape. + + Args: + strip_image: (H, W, 4) RGBA uint8 + angle: Rotation angle in degrees + + Returns: + (out_h, out_w, 4) RGBA uint8 - rotated strip + """ + sh, sw = strip_image.shape[:2] + + # Output size: diagonal of original strip (compile-time constant). + # Ensure output dimensions have same parity as source so that the + # center offset (out - src) / 2 is always an integer. Otherwise + # identity rotations would place content at half-pixel offsets. + diag = int(math.ceil(math.sqrt(sw * sw + sh * sh))) + out_w = diag + ((diag % 2) != (sw % 2)) + out_h = diag + ((diag % 2) != (sh % 2)) + + # Center of input strip and output buffer (pixel-center convention). + # Using (dim-1)/2 ensures integer coords map to integer coords for + # identity rotation regardless of even/odd dimension parity. + src_cx = (sw - 1) / 2.0 + src_cy = (sh - 1) / 2.0 + dst_cx = (out_w - 1) / 2.0 + dst_cy = (out_h - 1) / 2.0 + + # Convert to radians and snap trig values near 0/±1 to exact values. + # Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0, causing + # bilinear blending at pixel edges and 1-value differences. + theta = angle * jnp.pi / 180.0 + cos_t = jnp.cos(theta) + sin_t = jnp.sin(theta) + cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t) + cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t) + cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t) + sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t) + + # Create output coordinate grid + y_coords = jnp.repeat(jnp.arange(out_h), out_w).reshape(out_h, out_w) + x_coords = jnp.tile(jnp.arange(out_w), out_h).reshape(out_h, out_w) + + # Inverse rotation: map output coords to source coords + x_centered = x_coords.astype(jnp.float32) - dst_cx + y_centered = y_coords.astype(jnp.float32) - dst_cy + + src_x = cos_t * x_centered - sin_t * y_centered + src_cx + src_y = sin_t * x_centered + cos_t * y_centered + src_cy + + # Sample all 4 channels + strip_f = strip_image.astype(jnp.float32) + r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(a, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + ], axis=2) + + +# ============================================================================= +# Shadow Compositing +# ============================================================================= + +def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray: + """Blur a single-channel alpha array using Gaussian convolution. + + Args: + alpha: (H, W) float32 alpha channel + radius: Blur radius (compile-time constant) + + Returns: + (H, W) float32 blurred alpha + """ + size = radius * 2 + 1 + x = jnp.arange(size, dtype=jnp.float32) - radius + sigma = max(radius / 2.0, 0.5) + gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + # Use JAX conv with SAME padding + h, w = alpha.shape + data_4d = alpha.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(size, size, 1, 1) + + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + ) + return result.reshape(h, w) + + +def place_text_strip_shadow_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + shadow_offset_x: float = 3.0, + shadow_offset_y: float = 3.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.5, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Place text strip with a drop shadow. + + Composites the strip twice: first as shadow (offset, colored, optionally blurred), + then the text itself on top. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + shadow_offset_x/y: Shadow offset in pixels + shadow_color: (3,) RGB color for shadow (default black) + shadow_opacity: Shadow opacity 0-1 + shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time) + Other args same as place_text_strip_jax + + Returns: + Composited frame + """ + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Shadow pass --- + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + # Shadow alpha from strip alpha + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + # Blur the alpha channel for soft shadow + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1) + + # Shadow RGB: solid shadow color + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Text pass --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0 + + color_norm = color.astype(jnp.float32) / 255.0 + tinted = strip_rgb * color_norm + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw) + + return frame + + +# ============================================================================= +# Combined FX Pipeline +# ============================================================================= + +def place_text_strip_fx_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int = 0, + bearing_x: float = 0.0, + color: jnp.ndarray = None, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + gradient_map: jnp.ndarray = None, + angle: float = 0.0, + shadow_offset_x: float = 0.0, + shadow_offset_y: float = 0.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.0, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Combined text placement with gradient, rotation, and shadow. + + Pipeline order: + 1. Build color layer (solid color or gradient map) + 2. Rotate strip + color layer if angle != 0 + 3. Composite shadow if shadow_opacity > 0 + 4. Composite text + + Note: angle and shadow_blur_radius should be compile-time constants + for optimal JIT performance (they affect buffer shapes/kernel sizes). + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + x, y: Anchor point position + color: (3,) RGB color (ignored if gradient_map provided) + opacity: Text opacity + gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color + angle: Rotation angle in degrees (0 = no rotation) + shadow_offset_x/y: Shadow offset + shadow_color: (3,) RGB shadow color + shadow_opacity: Shadow opacity (0 = no shadow) + shadow_blur_radius: Shadow blur radius + + Returns: + Composited frame + """ + if color is None: + color = jnp.array([255, 255, 255], dtype=jnp.float32) + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Step 1: Build color layer --- + if gradient_map is not None: + color_layer = gradient_map # (sh, sw, 3) float32 [0, 1] + else: + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3)) + + # --- Step 2: Rotate if needed --- + # angle is expected to be a compile-time constant or static value + # We check at Python level to avoid tracing issues with dynamic shapes + use_rotation = not isinstance(angle, (int, float)) or angle != 0.0 + + if use_rotation: + # Rotate the strip + rotated_strip = rotate_strip_jax(strip_image, angle) + rh, rw = rotated_strip.shape[:2] + + # Rotate the color layer by building a 4-channel color+dummy image + # Actually, just re-create color layer at rotated size + if gradient_map is not None: + # Rotate gradient map: pack into 3-channel "image", rotate via sampling + grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8) + # Create RGBA from gradient (alpha=255 everywhere) + grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2) + rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle) + color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0 + else: + # Solid color: just broadcast to rotated size + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3)) + + # Update anchor point for rotation (pixel-center convention) + # Rotate the anchor offset around the strip center + theta = angle * jnp.pi / 180.0 + cos_t = jnp.cos(theta) + sin_t = jnp.sin(theta) + cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t) + cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t) + cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t) + sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t) + + # Original anchor relative to strip pixel center + src_cx = (sw - 1) / 2.0 + src_cy = (sh - 1) / 2.0 + dst_cx = (rw - 1) / 2.0 + dst_cy = (rh - 1) / 2.0 + + ax_rel = anchor_x - src_cx + ay_rel = anchor_y - src_cy + + # Rotate anchor point (forward rotation, not inverse) + new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx + new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy + + strip_image = rotated_strip + anchor_x = new_ax + anchor_y = new_ay + sh, sw = rh, rw + + # --- Step 3: Shadow --- + has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0 + if has_shadow: + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] + + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Step 4: Composite text --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + tinted = strip_rgb * color_layer + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw) + + return frame + + +# ============================================================================= +# S-Expression Primitive Bindings +# ============================================================================= + +def bind_typography_primitives(env: dict) -> dict: + """ + Add typography primitives to an S-expression environment. + + Primitives added: + (text-glyphs text font-size) -> list of glyph data + (glyph-image g) -> JAX array (H, W, 4) + (glyph-advance g) -> float + (glyph-bearing-x g) -> float + (glyph-bearing-y g) -> float + (glyph-width g) -> int + (glyph-height g) -> int + (font-ascent font-size) -> float + (font-descent font-size) -> float + (place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame + """ + + def prim_text_glyphs(text, font_size=32, font_name=None): + """Get list of glyph data for text. Compile-time.""" + return get_glyphs(str(text), font_name, int(font_size)) + + def prim_glyph_image(glyph): + """Get glyph image as JAX array.""" + return jnp.asarray(glyph.image) + + def prim_glyph_advance(glyph): + """Get glyph advance width.""" + return glyph.advance + + def prim_glyph_bearing_x(glyph): + """Get glyph left side bearing.""" + return glyph.bearing_x + + def prim_glyph_bearing_y(glyph): + """Get glyph top bearing.""" + return glyph.bearing_y + + def prim_glyph_width(glyph): + """Get glyph image width.""" + return glyph.width + + def prim_glyph_height(glyph): + """Get glyph image height.""" + return glyph.height + + def prim_font_ascent(font_size=32, font_name=None): + """Get font ascent.""" + return get_font_ascent(font_name, int(font_size)) + + def prim_font_descent(font_size=32, font_name=None): + """Get font descent.""" + return get_font_descent(font_name, int(font_size)) + + def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y, + color=(255, 255, 255), opacity=1.0): + """Place glyph on frame. Runtime JAX operation.""" + color_arr = jnp.array(color, dtype=jnp.float32) + return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y, + color_arr, opacity) + + def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None): + """Get kerning adjustment between two glyphs. Compile-time. + + Returns adjustment to add to glyph1's advance when glyph2 follows. + Typically negative (characters move closer). + + Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size)) + """ + return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size)) + + def prim_char_kerning(char1, char2, font_size=32, font_name=None): + """Get kerning adjustment between two characters. Compile-time.""" + return get_kerning(str(char1), str(char2), font_name, int(font_size)) + + # TextStrip primitives for pre-rendered text with proper anti-aliasing + def prim_render_text_strip(text, font_size=32, font_name=None): + """Render text to a strip at compile time. Perfect anti-aliasing.""" + return render_text_strip(str(text), font_name, int(font_size)) + + def prim_render_text_strip_styled( + text, font_size=32, font_name=None, + stroke_width=0, stroke_fill=None, + anchor="la", multiline=False, line_spacing=4, align="left" + ): + """Render styled text to a strip. Supports stroke, anchors, multiline. + + Args: + text: Text to render + font_size: Size in pixels + font_name: Path to font file + stroke_width: Outline width (0 = no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A) + anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender) + multiline: If True, handle newlines + line_spacing: Extra pixels between lines + align: "left", "center", "right" for multiline + """ + return render_text_strip( + str(text), font_name, int(font_size), + stroke_width=int(stroke_width), + stroke_fill=stroke_fill, + anchor=str(anchor), + multiline=bool(multiline), + line_spacing=int(line_spacing), + align=str(align), + ) + + def prim_text_strip_image(strip): + """Get text strip image as JAX array.""" + return jnp.asarray(strip.image) + + def prim_text_strip_width(strip): + """Get text strip width.""" + return strip.width + + def prim_text_strip_height(strip): + """Get text strip height.""" + return strip.height + + def prim_text_strip_baseline_y(strip): + """Get text strip baseline Y position.""" + return strip.baseline_y + + def prim_text_strip_bearing_x(strip): + """Get text strip left bearing.""" + return strip.bearing_x + + def prim_text_strip_anchor_x(strip): + """Get text strip anchor X offset.""" + return strip.anchor_x + + def prim_text_strip_anchor_y(strip): + """Get text strip anchor Y offset.""" + return strip.anchor_y + + def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0): + """Place pre-rendered text strip on frame. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # --- Gradient primitives --- + + def prim_linear_gradient(strip, color1, color2, angle=0.0): + """Create linear gradient color map for a text strip. Compile-time.""" + grad = make_linear_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(angle)) + return jnp.asarray(grad) + + def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5): + """Create radial gradient color map for a text strip. Compile-time.""" + grad = make_radial_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False, + center_x=0.5, center_y=0.5): + """Create multi-stop gradient for a text strip. Compile-time. + + stops: list of (position, (R, G, B)) tuples + """ + parsed_stops = [] + for s in stops: + pos = float(s[0]) + color_tuple = tuple(int(c) for c in s[1]) + parsed_stops.append((pos, color_tuple)) + grad = make_multi_stop_gradient(strip.width, strip.height, + parsed_stops, float(angle), + bool(radial), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0): + """Place text strip with gradient coloring. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + return place_text_strip_gradient_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + gradient_map, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # --- Rotation primitive --- + + def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255), + opacity=1.0, angle=0.0): + """Place text strip with rotation. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + angle=float(angle), + ) + + # --- Shadow primitive --- + + def prim_place_text_strip_shadow(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + shadow_offset_x=3.0, shadow_offset_y=3.0, + shadow_color=(0, 0, 0), shadow_opacity=0.5, + shadow_blur_radius=0): + """Place text strip with shadow. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_shadow_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur_radius), + ) + + # --- Combined FX primitive --- + + def prim_place_text_strip_fx(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + gradient=None, angle=0.0, + shadow_offset_x=0.0, shadow_offset_y=0.0, + shadow_color=(0, 0, 0), shadow_opacity=0.0, + shadow_blur=0): + """Place text strip with all effects. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + gradient_map=gradient, + angle=float(angle), + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur), + ) + + # Add to environment + env.update({ + # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects) + 'text-glyphs': prim_text_glyphs, + 'glyph-image': prim_glyph_image, + 'glyph-advance': prim_glyph_advance, + 'glyph-bearing-x': prim_glyph_bearing_x, + 'glyph-bearing-y': prim_glyph_bearing_y, + 'glyph-width': prim_glyph_width, + 'glyph-height': prim_glyph_height, + 'glyph-kerning': prim_glyph_kerning, + 'char-kerning': prim_char_kerning, + 'font-ascent': prim_font_ascent, + 'font-descent': prim_font_descent, + 'place-glyph': prim_place_glyph, + # TextStrip primitives (for pixel-perfect static text) + 'render-text-strip': prim_render_text_strip, + 'render-text-strip-styled': prim_render_text_strip_styled, + 'text-strip-image': prim_text_strip_image, + 'text-strip-width': prim_text_strip_width, + 'text-strip-height': prim_text_strip_height, + 'text-strip-baseline-y': prim_text_strip_baseline_y, + 'text-strip-bearing-x': prim_text_strip_bearing_x, + 'text-strip-anchor-x': prim_text_strip_anchor_x, + 'text-strip-anchor-y': prim_text_strip_anchor_y, + 'place-text-strip': prim_place_text_strip, + # Gradient primitives + 'linear-gradient': prim_linear_gradient, + 'radial-gradient': prim_radial_gradient, + 'multi-stop-gradient': prim_multi_stop_gradient, + 'place-text-strip-gradient': prim_place_text_strip_gradient, + # Rotation + 'place-text-strip-rotated': prim_place_text_strip_rotated, + # Shadow + 'place-text-strip-shadow': prim_place_text_strip_shadow, + # Combined FX + 'place-text-strip-fx': prim_place_text_strip_fx, + }) + + return env + + +# ============================================================================= +# Example: Render text using primitives (for testing) +# ============================================================================= + +def render_text_primitives( + frame: jnp.ndarray, + text: str, + x: float, + y: float, + font_size: int = 32, + color: tuple = (255, 255, 255), + opacity: float = 1.0, + use_kerning: bool = True, +) -> jnp.ndarray: + """ + Render text using the primitives. + This is what an S-expression would compile to. + + Args: + use_kerning: If True, apply kerning adjustments between characters + """ + glyphs = get_glyphs(text, None, font_size) + color_arr = jnp.array(color, dtype=jnp.float32) + + cursor = x + for i, g in enumerate(glyphs): + glyph_jax = jnp.asarray(g.image) + frame = place_glyph_jax( + frame, glyph_jax, cursor, y, + g.bearing_x, g.bearing_y, + color_arr, opacity + ) + # Advance cursor with optional kerning + advance = g.advance + if use_kerning and i + 1 < len(glyphs): + advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size) + cursor = cursor + advance + + return frame diff --git a/l1/streaming/jit_compiler.py b/l1/streaming/jit_compiler.py new file mode 100644 index 0000000..bb8c97c --- /dev/null +++ b/l1/streaming/jit_compiler.py @@ -0,0 +1,531 @@ +""" +JIT Compiler for sexp frame pipelines. + +Compiles sexp expressions to fused CUDA kernels for maximum performance. +""" + +import cupy as cp +import numpy as np +from typing import Dict, List, Any, Optional, Tuple, Callable +import hashlib +import sys + +# Cache for compiled kernels +_KERNEL_CACHE: Dict[str, Callable] = {} + + +def _generate_kernel_key(ops: List[Tuple]) -> str: + """Generate cache key for operation sequence.""" + return hashlib.md5(str(ops).encode()).hexdigest() + + +# ============================================================================= +# CUDA Kernel Templates +# ============================================================================= + +AFFINE_WARP_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void affine_warp( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float m00, float m01, float m02, + float m10, float m11, float m12 +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Apply inverse affine transform + float src_x = m00 * x + m01 * y + m02; + float src_y = m10 * x + m11 * y + m12; + + int dst_idx = (y * width + x) * channels; + + // Bounds check + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = x0 + 1; + int y1 = y0 + 1; + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float v = v0 * (1 - fy) + v1 * fy; + + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'affine_warp') + + +BLEND_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void blend( + const unsigned char* src1, + const unsigned char* src2, + unsigned char* dst, + int size, + float alpha +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + float v = src1[idx] * (1.0f - alpha) + src2[idx] * alpha; + dst[idx] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); +} +''', 'blend') + + +BRIGHTNESS_CONTRAST_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void brightness_contrast( + const unsigned char* src, + unsigned char* dst, + int size, + float brightness, + float contrast +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + float v = src[idx]; + v = (v - 128.0f) * contrast + 128.0f + brightness; + dst[idx] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); +} +''', 'brightness_contrast') + + +HUE_SHIFT_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void hue_shift( + const unsigned char* src, + unsigned char* dst, + int width, int height, + float hue_shift +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int idx = (y * width + x) * 3; + + float r = src[idx] / 255.0f; + float g = src[idx + 1] / 255.0f; + float b = src[idx + 2] / 255.0f; + + // RGB to HSV + float max_c = fmaxf(r, fmaxf(g, b)); + float min_c = fminf(r, fminf(g, b)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) { + s = delta / max_c; + if (r >= max_c) h = (g - b) / delta; + else if (g >= max_c) h = 2.0f + (b - r) / delta; + else h = 4.0f + (r - g) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + } + + // Apply hue shift + h = fmodf(h + hue_shift + 360.0f, 360.0f); + + // HSV to RGB + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) { r2 = c; g2 = x_val; b2 = 0; } + else if (h < 120) { r2 = x_val; g2 = c; b2 = 0; } + else if (h < 180) { r2 = 0; g2 = c; b2 = x_val; } + else if (h < 240) { r2 = 0; g2 = x_val; b2 = c; } + else if (h < 300) { r2 = x_val; g2 = 0; b2 = c; } + else { r2 = c; g2 = 0; b2 = x_val; } + + dst[idx] = (unsigned char)((r2 + m) * 255); + dst[idx + 1] = (unsigned char)((g2 + m) * 255); + dst[idx + 2] = (unsigned char)((b2 + m) * 255); +} +''', 'hue_shift') + + +INVERT_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void invert( + const unsigned char* src, + unsigned char* dst, + int size +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + dst[idx] = 255 - src[idx]; +} +''', 'invert') + + +ZOOM_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void zoom( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float zoom_factor, + float cx, float cy +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Map to source coordinates (zoom from center) + float src_x = (x - cx) / zoom_factor + cx; + float src_y = (y - cy) / zoom_factor + cy; + + int dst_idx = (y * width + x) * channels; + + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + (x0+1)) * channels + c]; + float v01 = src[((y0+1) * width + x0) * channels + c]; + float v11 = src[((y0+1) * width + (x0+1)) * channels + c]; + + float v = v00*(1-fx)*(1-fy) + v10*fx*(1-fy) + v01*(1-fx)*fy + v11*fx*fy; + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'zoom') + + +RIPPLE_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void ripple( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float cx, float cy, + float amplitude, float frequency, float decay, float phase +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + float dx = x - cx; + float dy = y - cy; + float dist = sqrtf(dx * dx + dy * dy); + + // Ripple displacement + float wave = sinf(dist * frequency * 0.1f + phase); + float amp = amplitude * expf(-dist * decay * 0.01f); + + float src_x = x + dx / (dist + 0.001f) * wave * amp; + float src_y = y + dy / (dist + 0.001f) * wave * amp; + + int dst_idx = (y * width + x) * channels; + + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = src[dst_idx + c]; // Keep original on boundary + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + (x0+1)) * channels + c]; + float v01 = src[((y0+1) * width + x0) * channels + c]; + float v11 = src[((y0+1) * width + (x0+1)) * channels + c]; + + float v = v00*(1-fx)*(1-fy) + v10*fx*(1-fy) + v01*(1-fx)*fy + v11*fx*fy; + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'ripple') + + +# ============================================================================= +# Fast GPU Operations +# ============================================================================= + +class FastGPUOps: + """Optimized GPU operations using CUDA kernels.""" + + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.channels = 3 + + # Pre-allocate work buffers + self._buf1 = cp.zeros((height, width, 3), dtype=cp.uint8) + self._buf2 = cp.zeros((height, width, 3), dtype=cp.uint8) + self._current_buf = 0 + + # Grid/block config + self._block_2d = (16, 16) + self._grid_2d = ((width + 15) // 16, (height + 15) // 16) + self._block_1d = 256 + self._grid_1d = (width * height * 3 + 255) // 256 + + def _get_buffers(self): + """Get source and destination buffers (ping-pong).""" + if self._current_buf == 0: + return self._buf1, self._buf2 + return self._buf2, self._buf1 + + def _swap_buffers(self): + """Swap ping-pong buffers.""" + self._current_buf = 1 - self._current_buf + + def set_input(self, frame: cp.ndarray): + """Set input frame.""" + if self._current_buf == 0: + cp.copyto(self._buf1, frame) + else: + cp.copyto(self._buf2, frame) + + def get_output(self) -> cp.ndarray: + """Get current output buffer.""" + if self._current_buf == 0: + return self._buf1 + return self._buf2 + + def rotate(self, angle: float, cx: float = None, cy: float = None): + """Fast GPU rotation.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + # Compute inverse rotation matrix + import math + rad = math.radians(-angle) # Negative for inverse + cos_a = math.cos(rad) + sin_a = math.sin(rad) + + # Inverse affine matrix (rotate around center) + m00 = cos_a + m01 = -sin_a + m02 = cx - cos_a * cx + sin_a * cy + m10 = sin_a + m11 = cos_a + m12 = cy - sin_a * cx - cos_a * cy + + AFFINE_WARP_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(m00), np.float32(m01), np.float32(m02), + np.float32(m10), np.float32(m11), np.float32(m12)) + ) + self._swap_buffers() + + def zoom(self, factor: float, cx: float = None, cy: float = None): + """Fast GPU zoom.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + ZOOM_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(factor), np.float32(cx), np.float32(cy)) + ) + self._swap_buffers() + + def blend(self, other: cp.ndarray, alpha: float): + """Fast GPU blend.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BLEND_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), other.ravel(), dst.ravel(), size, np.float32(alpha)) + ) + self._swap_buffers() + + def brightness(self, factor: float): + """Fast GPU brightness adjustment.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BRIGHTNESS_CONTRAST_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size, np.float32((factor - 1) * 128), np.float32(1.0)) + ) + self._swap_buffers() + + def contrast(self, factor: float): + """Fast GPU contrast adjustment.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BRIGHTNESS_CONTRAST_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size, np.float32(0), np.float32(factor)) + ) + self._swap_buffers() + + def hue_shift(self, degrees: float): + """Fast GPU hue shift.""" + src, dst = self._get_buffers() + + HUE_SHIFT_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, np.float32(degrees)) + ) + self._swap_buffers() + + def invert(self): + """Fast GPU invert.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + INVERT_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size) + ) + self._swap_buffers() + + def ripple(self, amplitude: float, cx: float = None, cy: float = None, + frequency: float = 8, decay: float = 2, phase: float = 0): + """Fast GPU ripple effect.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + RIPPLE_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(cx), np.float32(cy), + np.float32(amplitude), np.float32(frequency), + np.float32(decay), np.float32(phase)) + ) + self._swap_buffers() + + +# Global fast ops instance (created per resolution) +_FAST_OPS: Dict[Tuple[int, int], FastGPUOps] = {} + + +def get_fast_ops(width: int, height: int) -> FastGPUOps: + """Get or create FastGPUOps for given resolution.""" + key = (width, height) + if key not in _FAST_OPS: + _FAST_OPS[key] = FastGPUOps(width, height) + return _FAST_OPS[key] + + +# ============================================================================= +# Fast effect functions (drop-in replacements) +# ============================================================================= + +def fast_rotate(frame: cp.ndarray, angle: float, **kwargs) -> cp.ndarray: + """Fast GPU rotation.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.rotate(angle, kwargs.get('cx'), kwargs.get('cy')) + return ops.get_output().copy() + + +def fast_zoom(frame: cp.ndarray, factor: float, **kwargs) -> cp.ndarray: + """Fast GPU zoom.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.zoom(factor, kwargs.get('cx'), kwargs.get('cy')) + return ops.get_output().copy() + + +def fast_blend(frame1: cp.ndarray, frame2: cp.ndarray, alpha: float) -> cp.ndarray: + """Fast GPU blend.""" + h, w = frame1.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame1) + ops.blend(frame2, alpha) + return ops.get_output().copy() + + +def fast_hue_shift(frame: cp.ndarray, degrees: float) -> cp.ndarray: + """Fast GPU hue shift.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.hue_shift(degrees) + return ops.get_output().copy() + + +def fast_invert(frame: cp.ndarray) -> cp.ndarray: + """Fast GPU invert.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.invert() + return ops.get_output().copy() + + +def fast_ripple(frame: cp.ndarray, amplitude: float, **kwargs) -> cp.ndarray: + """Fast GPU ripple.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.ripple( + amplitude, + kwargs.get('center_x', w/2), + kwargs.get('center_y', h/2), + kwargs.get('frequency', 8), + kwargs.get('decay', 2), + kwargs.get('speed', 0) * kwargs.get('t', 0) # phase from speed*time + ) + return ops.get_output().copy() + + +print("[jit_compiler] CUDA kernels loaded", file=sys.stderr) diff --git a/l1/streaming/multi_res_output.py b/l1/streaming/multi_res_output.py new file mode 100644 index 0000000..40c661a --- /dev/null +++ b/l1/streaming/multi_res_output.py @@ -0,0 +1,509 @@ +""" +Multi-Resolution HLS Output with IPFS Storage. + +Renders video at multiple quality levels simultaneously: +- Original resolution (from recipe) +- 720p (streaming quality) +- 360p (mobile/low bandwidth) + +All segments stored on IPFS. Master playlist enables adaptive bitrate streaming. +""" + +import os +import sys +import subprocess +import threading +import queue +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union +from dataclasses import dataclass, field + +import numpy as np + +# Try GPU imports +try: + import cupy as cp + GPU_AVAILABLE = True +except ImportError: + cp = None + GPU_AVAILABLE = False + + +@dataclass +class QualityLevel: + """Configuration for a quality level.""" + name: str + width: int + height: int + bitrate: int # kbps + segment_cids: Dict[int, str] = field(default_factory=dict) + playlist_cid: Optional[str] = None + + +class MultiResolutionHLSOutput: + """ + GPU-accelerated multi-resolution HLS output with IPFS storage. + + Encodes video at multiple quality levels simultaneously using NVENC. + Segments are uploaded to IPFS as they're created. + Generates adaptive bitrate master playlist. + """ + + def __init__( + self, + output_dir: str, + source_size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + audio_source: str = None, + resume_from: Optional[Dict] = None, + ): + """Initialize multi-resolution HLS output. + + Args: + output_dir: Directory for HLS output files + source_size: (width, height) of source frames + fps: Frames per second + segment_duration: Duration of each HLS segment in seconds + ipfs_gateway: IPFS gateway URL for playlist URLs + on_playlist_update: Callback when playlists are updated + audio_source: Optional audio file to mux with video + resume_from: Optional dict to resume from checkpoint with keys: + - segment_cids: Dict of quality -> {seg_num: cid} + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.source_width, self.source_height = source_size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._on_playlist_update = on_playlist_update + self.audio_source = audio_source + self._is_open = True + self._frame_count = 0 + + # Define quality levels + self.qualities: Dict[str, QualityLevel] = {} + self._setup_quality_levels() + + # Restore segment CIDs if resuming (don't re-upload existing segments) + if resume_from and resume_from.get('segment_cids'): + for name, cids in resume_from['segment_cids'].items(): + if name in self.qualities: + self.qualities[name].segment_cids = dict(cids) + print(f"[MultiResHLS] Restored {len(cids)} segment CIDs for {name}", file=sys.stderr) + + # IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Upload queue and thread + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Track master playlist + self._master_playlist_cid = None + + # Setup encoders + self._setup_encoders() + + print(f"[MultiResHLS] Initialized {self.source_width}x{self.source_height} @ {fps}fps", file=sys.stderr) + print(f"[MultiResHLS] Quality levels: {list(self.qualities.keys())}", file=sys.stderr) + + def _setup_quality_levels(self): + """Configure quality levels based on source resolution.""" + # Always include original resolution + self.qualities['original'] = QualityLevel( + name='original', + width=self.source_width, + height=self.source_height, + bitrate=self._estimate_bitrate(self.source_width, self.source_height), + ) + + # Add 720p if source is larger + if self.source_height > 720: + aspect = self.source_width / self.source_height + w720 = int(720 * aspect) + w720 = w720 - (w720 % 2) # Ensure even width + self.qualities['720p'] = QualityLevel( + name='720p', + width=w720, + height=720, + bitrate=2500, + ) + + # Add 360p if source is larger + if self.source_height > 360: + aspect = self.source_width / self.source_height + w360 = int(360 * aspect) + w360 = w360 - (w360 % 2) # Ensure even width + self.qualities['360p'] = QualityLevel( + name='360p', + width=w360, + height=360, + bitrate=800, + ) + + def _estimate_bitrate(self, width: int, height: int) -> int: + """Estimate appropriate bitrate for resolution (in kbps).""" + pixels = width * height + if pixels >= 3840 * 2160: # 4K + return 15000 + elif pixels >= 1920 * 1080: # 1080p + return 5000 + elif pixels >= 1280 * 720: # 720p + return 2500 + elif pixels >= 854 * 480: # 480p + return 1500 + else: + return 800 + + def _setup_encoders(self): + """Setup FFmpeg encoder processes for each quality level.""" + self._encoders: Dict[str, subprocess.Popen] = {} + self._encoder_threads: Dict[str, threading.Thread] = {} + + for name, quality in self.qualities.items(): + # Create output directory for this quality + quality_dir = self.output_dir / name + quality_dir.mkdir(parents=True, exist_ok=True) + + # Build FFmpeg command + cmd = self._build_encoder_cmd(quality, quality_dir) + + print(f"[MultiResHLS] Starting encoder for {name}: {quality.width}x{quality.height}", file=sys.stderr) + + # Start encoder process + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=10**7, # Large buffer to prevent blocking + ) + self._encoders[name] = proc + + # Start stderr drain thread + stderr_thread = threading.Thread( + target=self._drain_stderr, + args=(name, proc), + daemon=True + ) + stderr_thread.start() + self._encoder_threads[name] = stderr_thread + + def _build_encoder_cmd(self, quality: QualityLevel, output_dir: Path) -> List[str]: + """Build FFmpeg command for a quality level.""" + playlist_path = output_dir / "playlist.m3u8" + segment_pattern = output_dir / "segment_%05d.ts" + + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-pixel_format", "rgb24", + "-video_size", f"{self.source_width}x{self.source_height}", + "-framerate", str(self.fps), + "-i", "-", + ] + + # Add audio input if provided + if self.audio_source: + cmd.extend(["-i", str(self.audio_source)]) + # Map video from input 0, audio from input 1 + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Scale if not original resolution + if quality.width != self.source_width or quality.height != self.source_height: + cmd.extend([ + "-vf", f"scale={quality.width}:{quality.height}:flags=lanczos", + ]) + + # NVENC encoding with quality settings + cmd.extend([ + "-c:v", "h264_nvenc", + "-preset", "p4", # Balanced speed/quality + "-tune", "hq", + "-b:v", f"{quality.bitrate}k", + "-maxrate", f"{int(quality.bitrate * 1.5)}k", + "-bufsize", f"{quality.bitrate * 2}k", + "-g", str(int(self.fps * self.segment_duration)), # Keyframe interval = segment duration + "-keyint_min", str(int(self.fps * self.segment_duration)), + "-sc_threshold", "0", # Disable scene change detection for consistent segments + ]) + + # Add audio encoding if audio source provided + if self.audio_source: + cmd.extend([ + "-c:a", "aac", + "-b:a", "128k", + "-shortest", # Stop when shortest stream ends + ]) + + # HLS output + cmd.extend([ + "-f", "hls", + "-hls_time", str(self.segment_duration), + "-hls_list_size", "0", # Keep all segments in playlist + "-hls_flags", "independent_segments+append_list", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(segment_pattern), + str(playlist_path), + ]) + + return cmd + + def _drain_stderr(self, name: str, proc: subprocess.Popen): + """Drain FFmpeg stderr to prevent blocking.""" + try: + for line in proc.stderr: + line_str = line.decode('utf-8', errors='replace').strip() + if line_str and ('error' in line_str.lower() or 'warning' in line_str.lower()): + print(f"[FFmpeg/{name}] {line_str}", file=sys.stderr) + except Exception as e: + print(f"[FFmpeg/{name}] stderr drain error: {e}", file=sys.stderr) + + def write(self, frame: Union[np.ndarray, 'cp.ndarray'], t: float = 0): + """Write a frame to all quality encoders.""" + if not self._is_open: + return + + # Convert GPU frame to CPU if needed + if GPU_AVAILABLE and hasattr(frame, 'get'): + frame = frame.get() # CuPy to NumPy + elif hasattr(frame, 'cpu'): + frame = frame.cpu # GPUFrame to NumPy + elif hasattr(frame, 'gpu') and hasattr(frame, 'is_on_gpu'): + frame = frame.gpu.get() if frame.is_on_gpu else frame.cpu + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + + # Ensure contiguous + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + frame_bytes = frame.tobytes() + + # Write to all encoders + for name, proc in self._encoders.items(): + if proc.poll() is not None: + print(f"[MultiResHLS] Encoder {name} died with code {proc.returncode}", file=sys.stderr) + self._is_open = False + return + + try: + proc.stdin.write(frame_bytes) + except BrokenPipeError: + print(f"[MultiResHLS] Encoder {name} pipe broken", file=sys.stderr) + self._is_open = False + return + + self._frame_count += 1 + + # Check for new segments periodically + if self._frame_count % int(self.fps * self.segment_duration) == 0: + self._check_and_upload_segments() + + def _check_and_upload_segments(self): + """Check for new segments and queue them for upload.""" + for name, quality in self.qualities.items(): + quality_dir = self.output_dir / name + segments = sorted(quality_dir.glob("segment_*.ts")) + + for seg_path in segments: + seg_num = int(seg_path.stem.split("_")[1]) + + if seg_num in quality.segment_cids: + continue # Already uploaded + + # Check if segment is complete (not still being written) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue + time.sleep(0.05) + size2 = seg_path.stat().st_size + if size1 != size2: + continue # Still being written + except FileNotFoundError: + continue + + # Queue for upload + self._upload_queue.put((name, seg_path, seg_num)) + + def _upload_worker(self): + """Background worker for IPFS uploads.""" + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + + quality_name, seg_path, seg_num = item + self._do_upload(quality_name, seg_path, seg_num) + + except queue.Empty: + continue + except Exception as e: + print(f"[MultiResHLS] Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, quality_name: str, seg_path: Path, seg_num: int): + """Upload a segment to IPFS.""" + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + self.qualities[quality_name].segment_cids[seg_num] = cid + print(f"[MultiResHLS] Uploaded {quality_name}/segment_{seg_num:05d}.ts -> {cid[:16]}...", file=sys.stderr) + + # Update playlists after each upload + self._update_playlists() + except Exception as e: + print(f"[MultiResHLS] Failed to upload {seg_path}: {e}", file=sys.stderr) + + def _update_playlists(self): + """Generate and upload IPFS playlists.""" + # Generate quality-specific playlists + for name, quality in self.qualities.items(): + if not quality.segment_cids: + continue + + playlist = self._generate_quality_playlist(quality) + cid = self._ipfs_add_bytes(playlist.encode(), pin=True) + if cid: + quality.playlist_cid = cid + + # Generate master playlist + self._generate_master_playlist() + + def _generate_quality_playlist(self, quality: QualityLevel, finalize: bool = False) -> str: + """Generate HLS playlist for a quality level.""" + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + if finalize: + lines.append("#EXT-X-PLAYLIST-TYPE:VOD") + + # Use /ipfs-ts/ for correct MIME type + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + + for seg_num in sorted(quality.segment_cids.keys()): + cid = quality.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{segment_gateway}/{cid}") + + if finalize: + lines.append("#EXT-X-ENDLIST") + + return "\n".join(lines) + "\n" + + def _generate_master_playlist(self, finalize: bool = False): + """Generate and upload master playlist.""" + lines = ["#EXTM3U", "#EXT-X-VERSION:3"] + + for name, quality in self.qualities.items(): + if not quality.playlist_cid: + continue + + lines.append( + f"#EXT-X-STREAM-INF:BANDWIDTH={quality.bitrate * 1000}," + f"RESOLUTION={quality.width}x{quality.height}," + f"NAME=\"{name}\"" + ) + lines.append(f"{self.ipfs_gateway}/{quality.playlist_cid}") + + if len(lines) <= 2: + return # No quality playlists yet + + master_content = "\n".join(lines) + "\n" + cid = self._ipfs_add_bytes(master_content.encode(), pin=True) + + if cid: + self._master_playlist_cid = cid + print(f"[MultiResHLS] Master playlist: {cid}", file=sys.stderr) + + if self._on_playlist_update: + # Pass both master CID and quality info for dynamic playlist generation + quality_info = { + name: { + "cid": q.playlist_cid, + "width": q.width, + "height": q.height, + "bitrate": q.bitrate, + } + for name, q in self.qualities.items() + if q.playlist_cid + } + self._on_playlist_update(cid, quality_info) + + def close(self): + """Close all encoders and finalize output.""" + if not self._is_open: + return + + self._is_open = False + print(f"[MultiResHLS] Closing after {self._frame_count} frames", file=sys.stderr) + + # Close encoder stdin pipes + for name, proc in self._encoders.items(): + try: + proc.stdin.close() + except: + pass + + # Wait for encoders to finish + for name, proc in self._encoders.items(): + try: + proc.wait(timeout=30) + print(f"[MultiResHLS] Encoder {name} finished with code {proc.returncode}", file=sys.stderr) + except subprocess.TimeoutExpired: + proc.kill() + print(f"[MultiResHLS] Encoder {name} killed (timeout)", file=sys.stderr) + + # Final segment check and upload + self._check_and_upload_segments() + + # Wait for uploads to complete + self._upload_queue.put(None) # Shutdown signal + self._upload_thread.join(timeout=60) + + # Generate final playlists with EXT-X-ENDLIST + for name, quality in self.qualities.items(): + if quality.segment_cids: + playlist = self._generate_quality_playlist(quality, finalize=True) + cid = self._ipfs_add_bytes(playlist.encode(), pin=True) + if cid: + quality.playlist_cid = cid + print(f"[MultiResHLS] Final {name} playlist: {cid} ({len(quality.segment_cids)} segments)", file=sys.stderr) + + # Final master playlist + self._generate_master_playlist(finalize=True) + + print(f"[MultiResHLS] Complete. Master playlist: {self._master_playlist_cid}", file=sys.stderr) + + @property + def is_open(self) -> bool: + return self._is_open + + @property + def playlist_cid(self) -> Optional[str]: + return self._master_playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + if self._master_playlist_cid: + return f"{self.ipfs_gateway}/{self._master_playlist_cid}" + return None + + @property + def segment_cids(self) -> Dict[str, Dict[int, str]]: + """Get all segment CIDs organized by quality.""" + return {name: dict(q.segment_cids) for name, q in self.qualities.items()} diff --git a/l1/streaming/output.py b/l1/streaming/output.py new file mode 100644 index 0000000..b2a4e85 --- /dev/null +++ b/l1/streaming/output.py @@ -0,0 +1,963 @@ +""" +Output targets for streaming compositor. + +Supports: +- Display window (preview) +- File output (recording) +- Stream output (RTMP, etc.) - future +- NVENC hardware encoding (auto-detected) +- CuPy GPU arrays (auto-converted to numpy for output) +""" + +import numpy as np +import subprocess +import threading +import queue +from abc import ABC, abstractmethod +from typing import Tuple, Optional, List, Union +from pathlib import Path + +# Try to import CuPy for GPU array support +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + + +def ensure_numpy(frame: Union[np.ndarray, 'cp.ndarray']) -> np.ndarray: + """Convert frame to numpy array if it's a CuPy array.""" + if CUPY_AVAILABLE and isinstance(frame, cp.ndarray): + return cp.asnumpy(frame) + return frame + +# Cache NVENC availability check +_nvenc_available: Optional[bool] = None + + +def check_nvenc_available() -> bool: + """Check if NVENC hardware encoding is available and working. + + Does a real encode test to catch cases where nvenc is listed + but CUDA libraries aren't loaded. + """ + global _nvenc_available + if _nvenc_available is not None: + return _nvenc_available + + try: + # First check if encoder is listed + result = subprocess.run( + ["ffmpeg", "-encoders"], + capture_output=True, + text=True, + timeout=5 + ) + if "h264_nvenc" not in result.stdout: + _nvenc_available = False + return _nvenc_available + + # Actually try to encode a small test frame + result = subprocess.run( + ["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1", + "-c:v", "h264_nvenc", "-f", "null", "-"], + capture_output=True, + text=True, + timeout=10 + ) + _nvenc_available = result.returncode == 0 + if not _nvenc_available: + import sys + print("NVENC listed but not working, falling back to libx264", file=sys.stderr) + except Exception: + _nvenc_available = False + + return _nvenc_available + + +def get_encoder_params(codec: str, preset: str, crf: int) -> List[str]: + """ + Get encoder-specific FFmpeg parameters. + + For NVENC (h264_nvenc, hevc_nvenc): + - Uses -cq for constant quality (similar to CRF) + - Presets: p1 (fastest) to p7 (slowest/best quality) + - Mapping: fast->p4, medium->p5, slow->p6 + + For libx264: + - Uses -crf for constant rate factor + - Presets: ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow + """ + if codec in ("h264_nvenc", "hevc_nvenc"): + # Map libx264 presets to NVENC presets + nvenc_preset_map = { + "ultrafast": "p1", + "superfast": "p2", + "veryfast": "p3", + "faster": "p3", + "fast": "p4", + "medium": "p5", + "slow": "p6", + "slower": "p6", + "veryslow": "p7", + } + nvenc_preset = nvenc_preset_map.get(preset, "p4") + + # NVENC quality: 0 (best) to 51 (worst), similar to CRF + # CRF 18 = high quality, CRF 23 = good quality + return [ + "-c:v", codec, + "-preset", nvenc_preset, + "-cq", str(crf), # Constant quality mode + "-rc", "vbr", # Variable bitrate with quality target + ] + else: + # Standard libx264 params + return [ + "-c:v", codec, + "-preset", preset, + "-crf", str(crf), + ] + + +class Output(ABC): + """Abstract base class for output targets.""" + + @abstractmethod + def write(self, frame: np.ndarray, t: float): + """Write a frame to the output.""" + pass + + @abstractmethod + def close(self): + """Close the output and clean up resources.""" + pass + + @property + @abstractmethod + def is_open(self) -> bool: + """Check if output is still open/valid.""" + pass + + +class DisplayOutput(Output): + """ + Display frames using mpv (handles Wayland properly). + + Useful for live preview. Press 'q' to quit. + """ + + def __init__(self, title: str = "Streaming Preview", size: Tuple[int, int] = None, + audio_source: str = None, fps: float = 30): + self.title = title + self.size = size + self.audio_source = audio_source + self.fps = fps + self._is_open = True + self._process = None + self._audio_process = None + + def _start_mpv(self, frame_size: Tuple[int, int]): + """Start mpv process for display.""" + import sys + w, h = frame_size + cmd = [ + "mpv", + "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + f"--title={self.title}", + "-", + ] + print(f"Starting mpv: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Start audio playback if we have an audio source + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Display frame.""" + if not self._is_open: + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Ensure frame is correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + # Start mpv on first frame + if self._process is None: + self._start_mpv((frame.shape[1], frame.shape[0])) + + # Check if mpv is still running + if self._process.poll() is not None: + self._is_open = False + return + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() # Prevent buffering + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the display and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open + + +class FileOutput(Output): + """ + Write frames to a video file using ffmpeg. + + Automatically uses NVENC hardware encoding when available, + falling back to libx264 CPU encoding otherwise. + """ + + def __init__( + self, + path: str, + size: Tuple[int, int], + fps: float = 30, + codec: str = "auto", # "auto", "h264_nvenc", "libx264" + crf: int = 18, + preset: str = "fast", + audio_source: str = None, + ): + self.path = Path(path) + self.size = size + self.fps = fps + self._is_open = True + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # Build ffmpeg command + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + # Explicitly map: video from input 0 (rawvideo), audio from input 1 + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend(["-pix_fmt", "yuv420p"]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "192k", "-shortest"]) + + # Use fragmented mp4 for streamable output while writing + if str(self.path).endswith('.mp4'): + cmd.extend(["-movflags", "frag_keyframe+empty_moov+default_base_moof"]) + + cmd.append(str(self.path)) + + import sys + print(f"FileOutput cmd: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, # Show errors for debugging + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to video file.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the video file.""" + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None + + +class MultiOutput(Output): + """ + Write to multiple outputs simultaneously. + + Useful for recording while showing preview. + """ + + def __init__(self, outputs: list): + self.outputs = outputs + + def write(self, frame: np.ndarray, t: float): + for output in self.outputs: + if output.is_open: + output.write(frame, t) + + def close(self): + for output in self.outputs: + output.close() + + @property + def is_open(self) -> bool: + return any(o.is_open for o in self.outputs) + + +class NullOutput(Output): + """ + Discard frames (for benchmarking). + """ + + def __init__(self): + self._is_open = True + self.frame_count = 0 + + def write(self, frame: np.ndarray, t: float): + self.frame_count += 1 + + def close(self): + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open + + +class PipeOutput(Output): + """ + Pipe frames directly to mpv. + + Launches mpv with rawvideo demuxer and writes frames to stdin. + """ + + def __init__(self, size: Tuple[int, int], fps: float = 30, audio_source: str = None): + self.size = size + self.fps = fps + self.audio_source = audio_source + self._is_open = True + self._process = None + self._audio_process = None + self._started = False + + def _start(self): + """Start mpv and audio on first frame.""" + if self._started: + return + self._started = True + + import sys + w, h = self.size + + # Start mpv + cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + "--title=Streaming", + "-" + ] + print(f"Starting mpv: {w}x{h} @ {self.fps}fps", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + + # Start audio + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to mpv.""" + if not self._is_open: + return + + self._start() + + # Check mpv still running + if self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close mpv and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open + + +class HLSOutput(Output): + """ + Write frames as HLS stream (m3u8 playlist + .ts segments). + + This enables true live streaming where the browser can poll + for new segments as they become available. + + Automatically uses NVENC hardware encoding when available. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, # 4s segments for stability + codec: str = "auto", # "auto", "h264_nvenc", "libx264" + crf: int = 23, + preset: str = "fast", # Better quality than ultrafast + audio_source: str = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self._is_open = True + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # HLS playlist path + self.playlist_path = self.output_dir / "stream.m3u8" + + # Build ffmpeg command for HLS output + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Keyframe interval - must be exactly segment_duration for clean cuts + gop_size = int(fps * segment_duration) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend([ + "-pix_fmt", "yuv420p", + # Force keyframes at exact intervals for clean segment boundaries + "-g", str(gop_size), + "-keyint_min", str(gop_size), + "-sc_threshold", "0", # Disable scene change detection + "-force_key_frames", f"expr:gte(t,n_forced*{segment_duration})", + # Reduce buffering for faster segment availability + "-flush_packets", "1", + ]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + # HLS specific options for smooth live streaming + cmd.extend([ + "-f", "hls", + "-hls_time", str(segment_duration), + "-hls_list_size", "0", # Keep all segments in playlist + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.playlist_path), + ]) + + import sys + print(f"HLSOutput cmd: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, # Show errors for debugging + ) + + # Track segments for status reporting + self.segments_written = 0 + self._last_segment_check = 0 + + def write(self, frame: np.ndarray, t: float): + """Write frame to HLS stream.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + + # Periodically count segments + if t - self._last_segment_check > 1.0: + self._last_segment_check = t + self.segments_written = len(list(self.output_dir.glob("segment_*.ts"))) + + def close(self): + """Close the HLS stream.""" + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + # Final segment count + self.segments_written = len(list(self.output_dir.glob("segment_*.ts"))) + + # Mark playlist as ended (VOD mode) + if self.playlist_path.exists(): + with open(self.playlist_path, "a") as f: + f.write("#EXT-X-ENDLIST\n") + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None + + +class IPFSHLSOutput(Output): + """ + Write frames as HLS stream with segments uploaded to IPFS. + + Each segment is uploaded to IPFS as it's created, enabling distributed + streaming where clients can fetch segments from any IPFS gateway. + + The m3u8 playlist is continuously updated with IPFS URLs and can be + fetched via get_playlist() or the playlist_cid property. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + codec: str = "auto", + crf: int = 23, + preset: str = "fast", + audio_source: str = None, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._is_open = True + self._on_playlist_update = on_playlist_update # Callback when playlist CID changes + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # Track segment CIDs + self.segment_cids: dict = {} # segment_number -> cid + self._last_segment_checked = -1 + self._playlist_cid: Optional[str] = None + self._upload_lock = threading.Lock() + + # Import IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Background upload thread for async IPFS uploads + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Local HLS paths + self.local_playlist_path = self.output_dir / "stream.m3u8" + + # Build ffmpeg command for HLS output + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Keyframe interval + gop_size = int(fps * segment_duration) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend([ + "-pix_fmt", "yuv420p", + "-g", str(gop_size), + "-keyint_min", str(gop_size), + "-sc_threshold", "0", + "-force_key_frames", f"expr:gte(t,n_forced*{segment_duration})", + "-flush_packets", "1", + ]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + # HLS options + cmd.extend([ + "-f", "hls", + "-hls_time", str(segment_duration), + "-hls_list_size", "0", + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.local_playlist_path), + ]) + + import sys + print(f"IPFSHLSOutput: starting ffmpeg", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, + ) + + def _upload_worker(self): + """Background worker thread for async IPFS uploads.""" + import sys + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + seg_path, seg_num = item + self._do_upload(seg_path, seg_num) + except queue.Empty: + continue + except Exception as e: + print(f"Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, seg_path: Path, seg_num: int): + """Actually perform the upload (runs in background thread).""" + import sys + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + with self._upload_lock: + self.segment_cids[seg_num] = cid + print(f"IPFS: segment_{seg_num:05d}.ts -> {cid}", file=sys.stderr) + self._update_ipfs_playlist() + except Exception as e: + print(f"Failed to upload segment {seg_num}: {e}", file=sys.stderr) + + def _upload_new_segments(self): + """Check for new segments and queue them for async IPFS upload.""" + import sys + import time + + # Find all segments + segments = sorted(self.output_dir.glob("segment_*.ts")) + + for seg_path in segments: + # Extract segment number from filename + seg_name = seg_path.stem # segment_00000 + seg_num = int(seg_name.split("_")[1]) + + # Skip if already uploaded or queued + with self._upload_lock: + if seg_num in self.segment_cids: + continue + + # Skip if segment is still being written (quick non-blocking check) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue # Empty file, still being created + + time.sleep(0.01) # Very short check + size2 = seg_path.stat().st_size + if size1 != size2: + continue # File still being written + except FileNotFoundError: + continue + + # Queue for async upload (non-blocking!) + self._upload_queue.put((seg_path, seg_num)) + + def _update_ipfs_playlist(self): + """Generate and upload IPFS-aware m3u8 playlist.""" + import sys + + with self._upload_lock: + if not self.segment_cids: + return + + # Build m3u8 content with IPFS URLs + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + # Add segments in order + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + playlist_content = "\n".join(lines) + "\n" + + # Upload playlist to IPFS + cid = self._ipfs_add_bytes(playlist_content.encode("utf-8"), pin=True) + if cid: + self._playlist_cid = cid + print(f"IPFS: playlist updated -> {cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + # Notify callback (e.g., to update database for live HLS redirect) + if self._on_playlist_update: + try: + self._on_playlist_update(cid) + except Exception as e: + print(f"IPFS: playlist callback error: {e}", file=sys.stderr) + + def write(self, frame: np.ndarray, t: float): + """Write frame to HLS stream and upload segments to IPFS.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + return + + # Check for new segments periodically (every second) + current_segment = int(t / self.segment_duration) + if current_segment > self._last_segment_checked: + self._last_segment_checked = current_segment + self._upload_new_segments() + + def close(self): + """Close the HLS stream and finalize IPFS uploads.""" + import sys + + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + # Queue any remaining segments + self._upload_new_segments() + + # Wait for pending uploads to complete + self._upload_queue.put(None) # Signal shutdown + self._upload_thread.join(timeout=30) + + # Generate final playlist with #EXT-X-ENDLIST + if self.segment_cids: + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + "#EXT-X-PLAYLIST-TYPE:VOD", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + lines.append("#EXT-X-ENDLIST") + playlist_content = "\n".join(lines) + "\n" + + cid = self._ipfs_add_bytes(playlist_content.encode("utf-8"), pin=True) + if cid: + self._playlist_cid = cid + print(f"IPFS: final playlist -> {cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + + @property + def playlist_cid(self) -> Optional[str]: + """Get the current playlist CID.""" + return self._playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + """Get the full IPFS URL for the playlist.""" + if self._playlist_cid: + return f"{self.ipfs_gateway}/{self._playlist_cid}" + return None + + def get_playlist(self) -> str: + """Get the current m3u8 playlist content with IPFS URLs.""" + if not self.segment_cids: + return "#EXTM3U\n" + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + if not self._is_open: + lines.append("#EXT-X-ENDLIST") + + return "\n".join(lines) + "\n" + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None \ No newline at end of file diff --git a/l1/streaming/pipeline.py b/l1/streaming/pipeline.py new file mode 100644 index 0000000..29dd7e1 --- /dev/null +++ b/l1/streaming/pipeline.py @@ -0,0 +1,846 @@ +""" +Streaming pipeline executor. + +Directly executes compiled sexp recipes frame-by-frame. +No adapter layer - frames and analysis flow through the DAG. +""" + +import sys +import time +import numpy as np +from pathlib import Path +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sources import VideoSource +from .audio import StreamingAudioAnalyzer +from .output import DisplayOutput, FileOutput +from .sexp_interp import SexpInterpreter + + +@dataclass +class FrameContext: + """Context passed through the pipeline for each frame.""" + t: float # Current time + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + analysis: Dict[str, Any] = field(default_factory=dict) + + +class StreamingPipeline: + """ + Executes a compiled sexp recipe as a streaming pipeline. + + Frames flow through the DAG directly - no adapter needed. + Each node is evaluated lazily when its output is requested. + """ + + def __init__(self, compiled_recipe, recipe_dir: Path = None, fps: float = 30, seed: int = 42, + output_size: tuple = None): + self.recipe = compiled_recipe + self.recipe_dir = recipe_dir or Path(".") + self.fps = fps + self.seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # Runtime state + self.sources: Dict[str, VideoSource] = {} + self.audio_analyzer: Optional[StreamingAudioAnalyzer] = None + self.audio_source_path: Optional[str] = None + + # Sexp interpreter for expressions + self.interp = SexpInterpreter() + + # Scan state (node_id -> current value) + self.scan_state: Dict[str, Any] = {} + self.scan_emit: Dict[str, Any] = {} + + # SLICE_ON state + self.slice_on_acc: Dict[str, Any] = {} + self.slice_on_result: Dict[str, Any] = {} + + # Frame cache for current timestep (cleared each frame) + self._frame_cache: Dict[str, np.ndarray] = {} + + # Context for current frame + self.ctx = FrameContext(t=0.0) + + # Output size (w, h) - set after sources are initialized + self._output_size = output_size + + # Initialize + self._init_sources() + self._init_scans() + self._init_slice_on() + + # Set output size from first source if not specified + if self._output_size is None and self.sources: + first_source = next(iter(self.sources.values())) + self._output_size = first_source._size + + def _init_sources(self): + """Initialize video and audio sources.""" + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + config = node.get('config', {}) + path = config.get('path') + if path: + full_path = (self.recipe_dir / path).resolve() + suffix = full_path.suffix.lower() + + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + if not full_path.exists(): + print(f"Warning: video not found: {full_path}", file=sys.stderr) + continue + self.sources[node['id']] = VideoSource( + str(full_path), + target_fps=self.fps + ) + elif suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + if not full_path.exists(): + print(f"Warning: audio not found: {full_path}", file=sys.stderr) + continue + self.audio_source_path = str(full_path) + self.audio_analyzer = StreamingAudioAnalyzer(str(full_path)) + + def _init_scans(self): + """Initialize scan nodes with their initial state.""" + import random + seed_offset = 0 + + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + config = node.get('config', {}) + + # Create RNG for this scan + scan_seed = config.get('seed', self.seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + init_value = self.interp.eval(init_expr, {}) + + self.scan_state[node['id']] = { + 'value': init_value, + 'rng': rng, + 'config': config, + } + + # Compute initial emit + self._update_scan_emit(node['id']) + + def _update_scan_emit(self, node_id: str): + """Update the emit value for a scan.""" + state = self.scan_state[node_id] + config = state['config'] + emit_expr = config.get('emit_expr', config.get('emit', None)) + + if emit_expr is None: + # No emit expression - emit the value directly + self.scan_emit[node_id] = state['value'] + return + + # Build environment from state + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG for interpreter + self.interp.rng = state['rng'] + + self.scan_emit[node_id] = self.interp.eval(emit_expr, env) + + def _step_scan(self, node_id: str): + """Step a scan forward on beat.""" + state = self.scan_state[node_id] + config = state['config'] + step_expr = config.get('step_expr', config.get('step', None)) + + if step_expr is None: + return + + # Build environment + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG + self.interp.rng = state['rng'] + + # Evaluate step + new_value = self.interp.eval(step_expr, env) + state['value'] = new_value + + # Update emit + self._update_scan_emit(node_id) + + def _init_slice_on(self): + """Initialize SLICE_ON nodes.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + init = config.get('init', {}) + self.slice_on_acc[node['id']] = dict(init) + + # Evaluate initial state + self._eval_slice_on(node['id']) + + def _eval_slice_on(self, node_id: str): + """Evaluate a SLICE_ON node's Lambda.""" + node = self.nodes[node_id] + config = node.get('config', {}) + fn = config.get('fn') + videos = config.get('videos', []) + + if not fn: + return + + acc = self.slice_on_acc[node_id] + n_videos = len(videos) + + # Set up environment + self.interp.globals['videos'] = list(range(n_videos)) + + try: + from .sexp_interp import eval_slice_on_lambda + result = eval_slice_on_lambda( + fn, acc, self.ctx.beat_count, 0, 1, + list(range(n_videos)), self.interp + ) + self.slice_on_result[node_id] = result + + # Update accumulator + if 'acc' in result: + self.slice_on_acc[node_id] = result['acc'] + except Exception as e: + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _on_beat(self): + """Called when a beat is detected.""" + self.ctx.beat_count += 1 + + # Step all scans + for node_id in self.scan_state: + self._step_scan(node_id) + + # Step all SLICE_ON nodes + for node_id in self.slice_on_acc: + self._eval_slice_on(node_id) + + def _get_frame(self, node_id: str) -> Optional[np.ndarray]: + """ + Get the output frame for a node at current time. + + Recursively evaluates inputs as needed. + Results are cached for the current timestep. + """ + if node_id in self._frame_cache: + return self._frame_cache[node_id] + + node = self.nodes.get(node_id) + if not node: + return None + + node_type = node.get('type') + + if node_type == 'SOURCE': + frame = self._eval_source(node) + elif node_type == 'SEGMENT': + frame = self._eval_segment(node) + elif node_type == 'EFFECT': + frame = self._eval_effect(node) + elif node_type == 'SLICE_ON': + frame = self._eval_slice_on_frame(node) + else: + # Unknown node type - try to pass through input + inputs = node.get('inputs', []) + frame = self._get_frame(inputs[0]) if inputs else None + + self._frame_cache[node_id] = frame + return frame + + def _eval_source(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SOURCE node.""" + source = self.sources.get(node['id']) + if source: + return source.read_frame(self.ctx.t) + return None + + def _eval_segment(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SEGMENT node (time segment of source).""" + inputs = node.get('inputs', []) + if not inputs: + return None + + config = node.get('config', {}) + start = config.get('start', 0) + duration = config.get('duration') + + # Resolve any bindings + if isinstance(start, dict): + start = self._resolve_binding(start) if start.get('_binding') else 0 + if isinstance(duration, dict): + duration = self._resolve_binding(duration) if duration.get('_binding') else None + + # Adjust time for segment + t_local = self.ctx.t + (start if isinstance(start, (int, float)) else 0) + if duration and isinstance(duration, (int, float)): + t_local = t_local % duration # Loop within segment + + # Get source frame at adjusted time + source_id = inputs[0] + source = self.sources.get(source_id) + if source: + return source.read_frame(t_local) + + return self._get_frame(source_id) + + def _eval_effect(self, node: dict) -> Optional[np.ndarray]: + """Evaluate an EFFECT node.""" + import cv2 + + inputs = node.get('inputs', []) + config = node.get('config', {}) + effect_name = config.get('effect') + + # Get input frame(s) + input_frames = [self._get_frame(inp) for inp in inputs] + input_frames = [f for f in input_frames if f is not None] + + if not input_frames: + return None + + frame = input_frames[0] + + # Resolve bindings in config + params = self._resolve_config(config) + + # Apply effect based on name + if effect_name == 'rotate': + angle = params.get('angle', 0) + if abs(angle) > 0.5: + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == 'zoom': + amount = params.get('amount', 1.0) + if abs(amount - 1.0) > 0.01: + frame = self._apply_zoom(frame, amount) + + elif effect_name == 'invert': + amount = params.get('amount', 0) + if amount > 0.01: + inverted = 255 - frame + frame = cv2.addWeighted(frame, 1 - amount, inverted, amount, 0) + + elif effect_name == 'hue_shift': + degrees = params.get('degrees', 0) + if abs(degrees) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + int(degrees / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + elif effect_name == 'blend': + if len(input_frames) >= 2: + opacity = params.get('opacity', 0.5) + frame = cv2.addWeighted(input_frames[0], 1 - opacity, + input_frames[1], opacity, 0) + + elif effect_name == 'blend_multi': + weights = params.get('weights', []) + if len(input_frames) > 1 and weights: + h, w = input_frames[0].shape[:2] + result = np.zeros((h, w, 3), dtype=np.float32) + for f, wt in zip(input_frames, weights): + if f is not None and wt > 0.001: + if f.shape[:2] != (h, w): + f = cv2.resize(f, (w, h)) + result += f.astype(np.float32) * wt + frame = np.clip(result, 0, 255).astype(np.uint8) + + elif effect_name == 'ripple': + amp = params.get('amplitude', 0) + if amp > 1: + frame = self._apply_ripple(frame, amp, + params.get('center_x', 0.5), + params.get('center_y', 0.5), + params.get('frequency', 8), + params.get('decay', 2), + params.get('speed', 5)) + + return frame + + def _eval_slice_on_frame(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SLICE_ON node - returns composited frame.""" + import cv2 + + config = node.get('config', {}) + video_ids = config.get('videos', []) + result = self.slice_on_result.get(node['id'], {}) + + if not result: + # No result yet - return first video + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get layers and compose info + layers = result.get('layers', []) + compose = result.get('compose', {}) + weights = compose.get('weights', []) + + if not layers or not weights: + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get frames for each layer + frames = [] + for i, layer in enumerate(layers): + video_idx = layer.get('video', i) + if video_idx < len(video_ids): + frame = self._get_frame(video_ids[video_idx]) + + # Apply layer effects (zoom) + effects = layer.get('effects', []) + for eff in effects: + eff_name = eff.get('effect') + if hasattr(eff_name, 'name'): + eff_name = eff_name.name + if eff_name == 'zoom': + zoom_amt = eff.get('amount', 1.0) + if frame is not None: + frame = self._apply_zoom(frame, zoom_amt) + + frames.append(frame) + else: + frames.append(None) + + # Composite with weights - use consistent output size + if self._output_size: + w, h = self._output_size + else: + # Fallback to first non-None frame size + for f in frames: + if f is not None: + h, w = f.shape[:2] + break + else: + return None + + output = np.zeros((h, w, 3), dtype=np.float32) + + for frame, weight in zip(frames, weights): + if frame is None or weight < 0.001: + continue + + # Resize to output size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + + output += frame.astype(np.float32) * weight + + # Normalize weights + total_weight = sum(wt for wt in weights if wt > 0.001) + if total_weight > 0 and abs(total_weight - 1.0) > 0.01: + output /= total_weight + + return np.clip(output, 0, 255).astype(np.uint8) + + def _resolve_config(self, config: dict) -> dict: + """Resolve bindings in effect config to actual values.""" + resolved = {} + + for key, value in config.items(): + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', + 'analysis_refs', 'inputs', 'cid'): + continue + + if isinstance(value, dict) and value.get('_binding'): + resolved[key] = self._resolve_binding(value) + elif isinstance(value, dict) and value.get('_expr'): + resolved[key] = self._resolve_expr(value) + else: + resolved[key] = value + + return resolved + + def _resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to its current value.""" + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get raw value from scan or analysis + if source_id in self.scan_emit: + value = self.scan_emit[source_id] + elif source_id in self.ctx.analysis: + data = self.ctx.analysis[source_id] + value = data.get(feature, data.get('values', [0]))[0] if isinstance(data, dict) else data + else: + # Fallback to energy + value = self.ctx.energy + + # Extract feature from dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def _resolve_expr(self, expr: dict) -> Any: + """Resolve a compiled expression.""" + env = { + 'energy': self.ctx.energy, + 'beat_count': self.ctx.beat_count, + 't': self.ctx.t, + } + + # Add scan values + for scan_id, value in self.scan_emit.items(): + # Use short form if available + env[scan_id] = value + + # Extract the actual expression from _expr wrapper + actual_expr = expr.get('_expr', expr) + return self.interp.eval(actual_expr, env) + + def _apply_zoom(self, frame: np.ndarray, amount: float) -> np.ndarray: + """Apply zoom to frame.""" + import cv2 + h, w = frame.shape[:2] + + if amount > 1.01: + # Zoom in: crop center + new_w, new_h = int(w / amount), int(h / amount) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + elif amount < 0.99: + # Zoom out: shrink and center + scaled_w, scaled_h = int(w * amount), int(h * amount) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(frame, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + return canvas + + return frame + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, frequency: float, + decay: float, speed: float) -> np.ndarray: + """Apply ripple effect.""" + import cv2 + h, w = frame.shape[:2] + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize to center + center_x, center_y = w * cx, h * cy + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx**2 + dy**2) + + # Ripple displacement + phase = self.ctx.t * speed + ripple = amplitude * np.sin(dist / frequency - phase) * np.exp(-dist * decay / max(w, h)) + + # Displace coordinates + angle = np.arctan2(dy, dx) + map_x = (x_coords + ripple * np.cos(angle)).astype(np.float32) + map_y = (y_coords + ripple * np.sin(angle)).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + def _find_output_node(self) -> Optional[str]: + """Find the final output node (MUX or last EFFECT).""" + # Look for MUX node + for node in self.recipe.nodes: + if node.get('type') == 'MUX': + return node['id'] + + # Otherwise find last EFFECT after SLICE_ON + last_effect = None + found_slice_on = False + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + found_slice_on = True + elif node.get('type') == 'EFFECT' and found_slice_on: + last_effect = node['id'] + + return last_effect + + def render_frame(self, t: float) -> Optional[np.ndarray]: + """Render a single frame at time t.""" + # Clear frame cache + self._frame_cache.clear() + + # Update context + self.ctx.t = t + + # Update audio analysis + if self.audio_analyzer: + self.audio_analyzer.set_time(t) + energy = self.audio_analyzer.get_energy() + is_beat = self.audio_analyzer.get_beat() + + # Beat edge detection + was_beat = self.ctx.is_beat + self.ctx.energy = energy + self.ctx.is_beat = is_beat + + if is_beat and not was_beat: + self._on_beat() + + # Store in analysis dict + self.ctx.analysis['live_energy'] = {'values': [energy]} + self.ctx.analysis['live_beat'] = {'values': [1.0 if is_beat else 0.0]} + + # Find output node and render + output_node = self._find_output_node() + if output_node: + frame = self._get_frame(output_node) + # Normalize to output size + if frame is not None and self._output_size: + w, h = self._output_size + if frame.shape[1] != w or frame.shape[0] != h: + import cv2 + frame = cv2.resize(frame, (w, h)) + return frame + + return None + + def run(self, output: str = "preview", duration: float = None): + """ + Run the pipeline. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration or 60s) + """ + # Determine duration + if duration is None: + if self.audio_analyzer: + duration = self.audio_analyzer.duration + else: + duration = 60.0 + + # Create output + if output == "preview": + # Get frame size from first source + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = DisplayOutput(size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + elif isinstance(output, str): + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = FileOutput(output, size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + else: + out = output + + frame_time = 1.0 / self.fps + n_frames = int(duration * self.fps) + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {n_frames} frames @ {self.fps}fps", file=sys.stderr) + + start_time = time.time() + frame_count = 0 + + try: + for frame_num in range(n_frames): + t = frame_num * frame_time + + frame = self.render_frame(t) + + if frame is not None: + out.write(frame, t) + frame_count += 1 + + # Progress + if frame_num % 50 == 0: + elapsed = time.time() - start_time + fps = frame_count / elapsed if elapsed > 0 else 0 + pct = 100 * frame_num / n_frames + print(f"\r{pct:5.1f}% | {fps:5.1f} fps | frame {frame_num}/{n_frames}", + end="", file=sys.stderr) + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources.values(): + src.close() + + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", + file=sys.stderr) + + +def run_pipeline(recipe_path: str, output: str = "preview", + duration: float = None, fps: float = None): + """ + Run a recipe through the streaming pipeline. + + No adapter layer - directly executes the compiled recipe. + """ + from pathlib import Path + + # Add artdag to path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + pipeline.run(output=output, duration=duration) + + +def run_pipeline_piped(recipe_path: str, duration: float = None, fps: float = None): + """ + Run pipeline and pipe directly to mpv with audio. + """ + import subprocess + from pathlib import Path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + # Get frame info + first_source = next(iter(pipeline.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + + # Determine duration + if duration is None: + if pipeline.audio_analyzer: + duration = pipeline.audio_analyzer.duration + else: + duration = 60.0 + + actual_fps = fps or compiled.encoding.get('fps', 30) + n_frames = int(duration * actual_fps) + frame_time = 1.0 / actual_fps + + print(f"Streaming {n_frames} frames @ {actual_fps}fps to mpv", file=sys.stderr) + + # Start mpv + mpv_cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={actual_fps}", + "--title=Streaming Pipeline", + "-" + ] + mpv = subprocess.Popen(mpv_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Start audio if available + audio_proc = None + if pipeline.audio_source_path: + audio_cmd = ["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + pipeline.audio_source_path] + audio_proc = subprocess.Popen(audio_cmd, stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + + try: + import cv2 + for frame_num in range(n_frames): + if mpv.poll() is not None: + break # mpv closed + + t = frame_num * frame_time + frame = pipeline.render_frame(t) + if frame is not None: + # Ensure consistent frame size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + try: + mpv.stdin.write(frame.tobytes()) + mpv.stdin.flush() + except BrokenPipeError: + break + except KeyboardInterrupt: + pass + finally: + if mpv.stdin: + mpv.stdin.close() + mpv.terminate() + if audio_proc: + audio_proc.terminate() + for src in pipeline.sources.values(): + src.close() + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe through streaming pipeline") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_pipeline_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_pipeline(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/l1/streaming/recipe_adapter.py b/l1/streaming/recipe_adapter.py new file mode 100644 index 0000000..2133919 --- /dev/null +++ b/l1/streaming/recipe_adapter.py @@ -0,0 +1,470 @@ +""" +Adapter to run sexp recipes through the streaming compositor. + +Bridges the gap between: +- Existing recipe format (sexp files with stages, effects, analysis) +- Streaming compositor (sources, effect chains, compositor config) +""" + +import sys +from pathlib import Path +from typing import Dict, List, Any, Optional + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + +from .compositor import StreamingCompositor +from .sources import VideoSource +from .audio import FileAudioAnalyzer + + +class RecipeAdapter: + """ + Adapts a compiled sexp recipe to run through the streaming compositor. + + Example: + adapter = RecipeAdapter("effects/quick_test.sexp") + adapter.run(output="preview", duration=60) + """ + + def __init__( + self, + recipe_path: str, + params: Dict[str, Any] = None, + backend: str = "numpy", + ): + """ + Load and prepare a recipe for streaming. + + Args: + recipe_path: Path to .sexp recipe file + params: Parameter overrides + backend: "numpy" or "glsl" + """ + self.recipe_path = Path(recipe_path) + self.recipe_dir = self.recipe_path.parent + self.params = params or {} + self.backend = backend + + # Compile recipe + self._compile() + + def _compile(self): + """Compile the recipe and extract structure.""" + from artdag.sexp.compiler import compile_string + + recipe_text = self.recipe_path.read_text() + self.compiled = compile_string(recipe_text, self.params, recipe_dir=self.recipe_dir) + + # Extract key info + self.sources = {} # name -> path + self.effects_registry = {} # effect_name -> path + self.analyzers = {} # name -> analyzer info + + # Walk nodes to find sources and structure + # nodes is a list in CompiledRecipe + for node in self.compiled.nodes: + node_type = node.get("type", "") + + if node_type == "SOURCE": + config = node.get("config", {}) + path = config.get("path") + if path: + self.sources[node["id"]] = self.recipe_dir / path + + elif node_type == "ANALYZE": + config = node.get("config", {}) + self.analyzers[node["id"]] = { + "analyzer": config.get("analyzer"), + "path": config.get("analyzer_path"), + } + + # Get effects registry from compiled recipe + # registry has 'effects' sub-dict + effects_dict = self.compiled.registry.get("effects", {}) + for name, info in effects_dict.items(): + if info.get("path"): + self.effects_registry[name] = Path(info["path"]) + + def run_analysis(self) -> Dict[str, Any]: + """ + Run analysis phase (energy, beats, etc.). + + Returns: + Dict of analysis track name -> {times, values, duration} + """ + print(f"Running analysis...", file=sys.stderr) + + # Use existing planner's analysis execution + from artdag.sexp.planner import create_plan + + analysis_data = {} + + def on_analysis(node_id: str, results: dict): + analysis_data[node_id] = results + print(f" {node_id[:16]}...: {len(results.get('times', []))} samples", file=sys.stderr) + + # Create plan (runs analysis as side effect) + plan = create_plan( + self.compiled, + inputs={}, + recipe_dir=self.recipe_dir, + on_analysis=on_analysis, + ) + + # Also store named analysis tracks + for name, data in plan.analysis.items(): + analysis_data[name] = data + + return analysis_data + + def build_compositor( + self, + analysis_data: Dict[str, Any] = None, + fps: float = None, + ) -> StreamingCompositor: + """ + Build a streaming compositor from the recipe. + + This is a simplified version that handles common patterns. + Complex recipes may need manual configuration. + + Args: + analysis_data: Pre-computed analysis data + + Returns: + Configured StreamingCompositor + """ + # Extract video and audio sources in SLICE_ON input order + video_sources = [] + audio_source = None + + # Find audio source first + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + audio_source = str(path) + break + + # Find SLICE_ON node to get correct video order + slice_on_inputs = None + for node in self.compiled.nodes: + if node.get('type') == 'SLICE_ON': + # Use 'videos' config key which has the correct order + config = node.get('config', {}) + slice_on_inputs = config.get('videos', []) + break + + if slice_on_inputs: + # Trace each SLICE_ON input back to its SOURCE + node_lookup = {n['id']: n for n in self.compiled.nodes} + + def trace_to_source(node_id, visited=None): + """Trace a node back to its SOURCE, return source path.""" + if visited is None: + visited = set() + if node_id in visited: + return None + visited.add(node_id) + + node = node_lookup.get(node_id) + if not node: + return None + if node.get('type') == 'SOURCE': + return self.sources.get(node_id) + # Recurse through inputs + for inp in node.get('inputs', []): + result = trace_to_source(inp, visited) + if result: + return result + return None + + # Build video_sources in SLICE_ON input order + for inp_id in slice_on_inputs: + source_path = trace_to_source(inp_id) + if source_path: + suffix = source_path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(source_path)) + + # Fallback to definition order if no SLICE_ON + if not video_sources: + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(path)) + + if not video_sources: + raise ValueError("No video sources found in recipe") + + # Build effect chains - use live audio bindings (matching video_sources count) + effects_per_source = self._build_streaming_effects(n_sources=len(video_sources)) + + # Build compositor config from recipe + compositor_config = self._extract_compositor_config(analysis_data) + + return StreamingCompositor( + sources=video_sources, + effects_per_source=effects_per_source, + compositor_config=compositor_config, + analysis_data=analysis_data or {}, + backend=self.backend, + recipe_dir=self.recipe_dir, + fps=fps or self.compiled.encoding.get("fps", 30), + audio_source=audio_source, + ) + + def _build_streaming_effects(self, n_sources: int = None) -> List[List[Dict]]: + """ + Build effect chains for streaming with live audio bindings. + + Replicates the recipe's effect pipeline: + - Per source: rotate, zoom, invert, hue_shift, ascii_art + - All driven by live_energy and live_beat + """ + if n_sources is None: + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + + effects_per_source = [] + + for i in range(n_sources): + # Alternate rotation direction per source + rot_dir = 1 if i % 2 == 0 else -1 + + effects = [ + # Rotate - energy drives angle + { + "effect": "rotate", + "effect_path": str(self.effects_registry.get("rotate", "")), + "angle": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 45 * rot_dir], + }, + }, + # Zoom - energy drives amount + { + "effect": "zoom", + "effect_path": str(self.effects_registry.get("zoom", "")), + "amount": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [1.0, 1.5] if i % 2 == 0 else [1.0, 0.7], + }, + }, + # Invert - beat triggers + { + "effect": "invert", + "effect_path": str(self.effects_registry.get("invert", "")), + "amount": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + # Hue shift - energy drives hue + { + "effect": "hue_shift", + "effect_path": str(self.effects_registry.get("hue_shift", "")), + "degrees": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 180], + }, + }, + # ASCII art - energy drives char size, beat triggers mix + { + "effect": "ascii_art", + "effect_path": str(self.effects_registry.get("ascii_art", "")), + "char_size": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [4, 32], + }, + "mix": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + ] + effects_per_source.append(effects) + + return effects_per_source + + def _extract_effects(self) -> List[List[Dict]]: + """Extract effect chains for each source (legacy, pre-computed analysis).""" + # Simplified: find EFFECT nodes and their configs + effects_per_source = [] + + for node_id, path in self.sources.items(): + if path.suffix.lower() not in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + continue + + # Find effects that depend on this source + # This is simplified - real implementation would trace the DAG + effects = [] + + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + effect_name = config.get("effect") + if effect_name and effect_name in self.effects_registry: + effect_config = { + "effect": effect_name, + "effect_path": str(self.effects_registry[effect_name]), + } + # Copy only effect params (filter out internal fields) + internal_fields = ( + "effect", "effect_path", "cid", "effect_cid", + "effects_registry", "analysis_refs", "inputs", + ) + for k, v in config.items(): + if k not in internal_fields: + effect_config[k] = v + effects.append(effect_config) + break # One effect per source for now + + effects_per_source.append(effects) + + return effects_per_source + + def _extract_compositor_config(self, analysis_data: Dict) -> Dict: + """Extract compositor configuration.""" + # Look for blend_multi or similar composition nodes + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + if config.get("effect") == "blend_multi": + return { + "mode": config.get("mode", "alpha"), + "weights": config.get("weights", []), + } + + # Default: equal blend + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + return { + "mode": "alpha", + "weights": [1.0 / n_sources] * n_sources if n_sources > 0 else [1.0], + } + + def run( + self, + output: str = "preview", + duration: float = None, + fps: float = None, + ): + """ + Run the recipe through streaming compositor. + + Everything streams: video frames read on-demand, audio analyzed in real-time. + No pre-computation. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration) + fps: Frame rate (default from recipe, or 30) + """ + # Build compositor with recipe executor for full pipeline + from .recipe_executor import StreamingRecipeExecutor + + compositor = self.build_compositor(analysis_data={}, fps=fps) + + # Use audio duration if not specified + if duration is None: + if compositor._audio_analyzer: + duration = compositor._audio_analyzer.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + else: + # Live mode - run until quit + print("Live mode - press 'q' to quit", file=sys.stderr) + + # Create sexp executor that interprets the recipe + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(self.compiled, seed=42) + + compositor.run(output=output, duration=duration, recipe_executor=executor) + + +def run_recipe( + recipe_path: str, + output: str = "preview", + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run a recipe through streaming compositor. + + Everything streams in real-time: video frames, audio analysis. + No pre-computation - starts immediately. + + Example: + run_recipe("effects/quick_test.sexp", output="preview", duration=30) + run_recipe("effects/quick_test.sexp", output="preview", fps=5) # Lower fps for slow systems + """ + adapter = RecipeAdapter(recipe_path, params=params) + adapter.run(output=output, duration=duration, fps=fps) + + +def run_recipe_piped( + recipe_path: str, + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run recipe and pipe directly to mpv. + """ + from .output import PipeOutput + + adapter = RecipeAdapter(recipe_path, params=params) + compositor = adapter.build_compositor(analysis_data={}, fps=fps) + + # Get frame size + if compositor.sources: + first_source = compositor.sources[0] + w, h = first_source._size + else: + w, h = 720, 720 + + actual_fps = fps or adapter.compiled.encoding.get('fps', 30) + + # Create pipe output + pipe_out = PipeOutput( + size=(w, h), + fps=actual_fps, + audio_source=compositor._audio_source + ) + + # Create executor + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(adapter.compiled, seed=42) + + # Run with pipe output + compositor.run(output=pipe_out, duration=duration, recipe_executor=executor) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe with streaming compositor") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_recipe_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_recipe(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/l1/streaming/recipe_executor.py b/l1/streaming/recipe_executor.py new file mode 100644 index 0000000..678d9f6 --- /dev/null +++ b/l1/streaming/recipe_executor.py @@ -0,0 +1,415 @@ +""" +Streaming recipe executor. + +Implements the full recipe logic for real-time streaming: +- Scans (state machines that evolve on beats) +- Process-pair template (two clips with sporadic effects, blended) +- Cycle-crossfade (dynamic composition cycling through video pairs) +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + + +@dataclass +class ScanState: + """State for a scan (beat-driven state machine).""" + value: Any = 0 + rng: random.Random = field(default_factory=random.Random) + + +class StreamingScans: + """ + Real-time scan executor. + + Scans are state machines that evolve on each beat. + They drive effect parameters like invert triggers, hue shifts, etc. + """ + + def __init__(self, seed: int = 42, n_sources: int = 4): + self.master_seed = seed + self.n_sources = n_sources + self.scans: Dict[str, ScanState] = {} + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self._init_scans() + + def _init_scans(self): + """Initialize all scans with their own RNG seeds.""" + scan_names = [] + + # Per-pair scans (dynamic based on n_sources) + for i in range(self.n_sources): + scan_names.extend([ + f"inv_a_{i}", f"inv_b_{i}", f"hue_a_{i}", f"hue_b_{i}", + f"ascii_a_{i}", f"ascii_b_{i}", f"pair_mix_{i}", f"pair_rot_{i}", + ]) + + # Global scans + scan_names.extend(["whole_spin", "ripple_gate", "cycle"]) + + for i, name in enumerate(scan_names): + rng = random.Random(self.master_seed + i) + self.scans[name] = ScanState(value=self._init_value(name), rng=rng) + + def _init_value(self, name: str) -> Any: + """Get initial value for a scan.""" + if name.startswith("inv_") or name.startswith("ascii_"): + return 0 # Counter for remaining beats + elif name.startswith("hue_"): + return {"rem": 0, "hue": 0} + elif name.startswith("pair_mix"): + return {"rem": 0, "opacity": 0.5} + elif name.startswith("pair_rot"): + pair_idx = int(name.split("_")[-1]) + rot_dir = 1 if pair_idx % 2 == 0 else -1 + return {"beat": 0, "clen": 25, "dir": rot_dir, "angle": 0} + elif name == "whole_spin": + return { + "phase": 0, # 0 = waiting, 1 = spinning + "beat": 0, # beats into current phase + "plen": 20, # beats in this phase + "dir": 1, # spin direction + "total_angle": 0.0, # cumulative angle after all spins + "spin_start_angle": 0.0, # angle when current spin started + "spin_start_time": 0.0, # time when current spin started + "spin_end_time": 0.0, # estimated time when spin ends + } + elif name == "ripple_gate": + return {"rem": 0, "cx": 0.5, "cy": 0.5} + elif name == "cycle": + return {"cycle": 0, "beat": 0, "clen": 60} + return 0 + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + # Estimate beat interval from last two beats + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + for name, state in self.scans.items(): + state.value = self._step_scan(name, state.value, state.rng, beat_interval) + + def _step_scan(self, name: str, value: Any, rng: random.Random, beat_interval: float = 0.5) -> Any: + """Step a scan forward by one beat.""" + + # Invert scan: 10% chance, lasts 1-5 beats + if name.startswith("inv_"): + if value > 0: + return value - 1 + elif rng.random() < 0.1: + return rng.randint(1, 5) + return 0 + + # Hue scan: 10% chance, random hue 30-330, lasts 1-5 beats + elif name.startswith("hue_"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "hue": value["hue"]} + elif rng.random() < 0.1: + return {"rem": rng.randint(1, 5), "hue": rng.uniform(30, 330)} + return {"rem": 0, "hue": 0} + + # ASCII scan: 5% chance, lasts 1-3 beats + elif name.startswith("ascii_"): + if value > 0: + return value - 1 + elif rng.random() < 0.05: + return rng.randint(1, 3) + return 0 + + # Pair mix: changes every 1-11 beats + elif name.startswith("pair_mix"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "opacity": value["opacity"]} + return {"rem": rng.randint(1, 11), "opacity": rng.choice([0, 0.5, 1.0])} + + # Pair rotation: full rotation every 20-30 beats + elif name.startswith("pair_rot"): + beat = value["beat"] + clen = value["clen"] + dir_ = value["dir"] + angle = value["angle"] + + if beat + 1 < clen: + new_angle = angle + dir_ * (360 / clen) + return {"beat": beat + 1, "clen": clen, "dir": dir_, "angle": new_angle} + else: + return {"beat": 0, "clen": rng.randint(20, 30), "dir": -dir_, "angle": angle} + + # Whole spin: sporadic 720 degree spins (cumulative - stays rotated) + elif name == "whole_spin": + phase = value["phase"] + beat = value["beat"] + plen = value["plen"] + dir_ = value["dir"] + total_angle = value.get("total_angle", 0.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", 0.0) + + if phase == 1: + # Currently spinning + if beat + 1 < plen: + return { + "phase": 1, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Spin complete - update total_angle with final spin + new_total = spin_start_angle + dir_ * 720.0 + return { + "phase": 0, "beat": 0, "plen": rng.randint(20, 40), "dir": dir_, + "total_angle": new_total, + "spin_start_angle": new_total, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time, + } + else: + # Waiting phase + if beat + 1 < plen: + return { + "phase": 0, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Start new spin + new_dir = 1 if rng.random() < 0.5 else -1 + new_plen = rng.randint(10, 25) + spin_duration = new_plen * beat_interval + return { + "phase": 1, "beat": 0, "plen": new_plen, "dir": new_dir, + "total_angle": total_angle, + "spin_start_angle": total_angle, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time + spin_duration, + } + + # Ripple gate: 5% chance, lasts 1-20 beats + elif name == "ripple_gate": + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "cx": value["cx"], "cy": value["cy"]} + elif rng.random() < 0.05: + return {"rem": rng.randint(1, 20), + "cx": rng.uniform(0.1, 0.9), + "cy": rng.uniform(0.1, 0.9)} + return {"rem": 0, "cx": 0.5, "cy": 0.5} + + # Cycle: track which video pair is active + elif name == "cycle": + beat = value["beat"] + clen = value["clen"] + cycle = value["cycle"] + + if beat + 1 < clen: + return {"cycle": cycle, "beat": beat + 1, "clen": clen} + else: + # Move to next pair, vary cycle length + return {"cycle": (cycle + 1) % 4, "beat": 0, + "clen": 40 + (self.beat_count * 7) % 41} + + return value + + def get_emit(self, name: str) -> float: + """Get emitted value for a scan.""" + value = self.scans[name].value + + if name.startswith("inv_") or name.startswith("ascii_"): + return 1.0 if value > 0 else 0.0 + + elif name.startswith("hue_"): + return value["hue"] if value["rem"] > 0 else 0.0 + + elif name.startswith("pair_mix"): + return value["opacity"] + + elif name.startswith("pair_rot"): + return value["angle"] + + elif name == "whole_spin": + # Smooth time-based interpolation during spin + phase = value.get("phase", 0) + if phase == 1: + # Currently spinning - interpolate based on time + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", spin_start_time + 1.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + dir_ = value.get("dir", 1) + + duration = spin_end_time - spin_start_time + if duration > 0: + progress = (self.current_time - spin_start_time) / duration + progress = max(0.0, min(1.0, progress)) # clamp to 0-1 + else: + progress = 1.0 + + return spin_start_angle + progress * 720.0 * dir_ + else: + # Not spinning - return cumulative angle + return value.get("total_angle", 0.0) + + elif name == "ripple_gate": + return 1.0 if value["rem"] > 0 else 0.0 + + elif name == "cycle": + return value + + return 0.0 + + +class StreamingRecipeExecutor: + """ + Executes a recipe in streaming mode. + + Implements: + - process-pair: two video clips with opposite effects, blended + - cycle-crossfade: dynamic cycling through video pairs + - Final effects: whole-spin rotation, ripple + """ + + def __init__(self, n_sources: int = 4, seed: int = 42): + self.n_sources = n_sources + self.scans = StreamingScans(seed, n_sources=n_sources) + self.last_beat_detected = False + self.current_time = 0.0 + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with current audio analysis.""" + self.current_time = t + self.scans.current_time = t + # Update scans on beat + if is_beat and not self.last_beat_detected: + self.scans.on_beat() + self.last_beat_detected = is_beat + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """ + Get effect parameters for a source clip. + + Args: + source_idx: Which video source (0-3) + clip: "a" or "b" (each source has two clips) + energy: Current audio energy (0-1) + """ + suffix = f"_{source_idx}" + + # Rotation ranges alternate + if source_idx % 2 == 0: + rot_range = [0, 45] if clip == "a" else [0, -45] + zoom_range = [1, 1.5] if clip == "a" else [1, 0.5] + else: + rot_range = [0, -45] if clip == "a" else [0, 45] + zoom_range = [1, 0.5] if clip == "a" else [1, 1.5] + + return { + "rotate_angle": rot_range[0] + energy * (rot_range[1] - rot_range[0]), + "zoom_amount": zoom_range[0] + energy * (zoom_range[1] - zoom_range[0]), + "invert_amount": self.scans.get_emit(f"inv_{clip}{suffix}"), + "hue_degrees": self.scans.get_emit(f"hue_{clip}{suffix}"), + "ascii_mix": 0, # Disabled - too slow without GPU + "ascii_char_size": 4 + energy * 28, # 4-32 + } + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair.""" + suffix = f"_{source_idx}" + return { + "blend_opacity": self.scans.get_emit(f"pair_mix{suffix}"), + "pair_rotation": self.scans.get_emit(f"pair_rot{suffix}"), + } + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade composition.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + weights = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + w = 0.9 + elif phase3 < clen * 2: + w = 0.9 - ((phase3 - clen) / clen) * 0.85 + else: + w = 0.05 + elif active == prev: + if phase3 < clen: + w = 0.05 + elif phase3 < clen * 2: + w = 0.05 + ((phase3 - clen) / clen) * 0.85 + else: + w = 0.9 + else: + w = 0.05 + + weights.append(w) + + # Normalize + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + + return weights + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + zooms = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + z = 1.0 + elif phase3 < clen * 2: + z = 1.0 + ((phase3 - clen) / clen) * 1.0 + else: + z = 0.1 + elif active == prev: + if phase3 < clen: + z = 3.0 # Start big + elif phase3 < clen * 2: + z = 3.0 - ((phase3 - clen) / clen) * 2.0 # Shrink to 1.0 + else: + z = 1.0 + else: + z = 0.1 + + zooms.append(z) + + return zooms + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (whole-spin, ripple).""" + ripple_gate = self.scans.get_emit("ripple_gate") + ripple_state = self.scans.scans["ripple_gate"].value + + return { + "whole_spin_angle": self.scans.get_emit("whole_spin"), + "ripple_amplitude": ripple_gate * (5 + energy * 45), # 5-50 + "ripple_cx": ripple_state["cx"], + "ripple_cy": ripple_state["cy"], + } diff --git a/l1/streaming/sexp_executor.py b/l1/streaming/sexp_executor.py new file mode 100644 index 0000000..0151853 --- /dev/null +++ b/l1/streaming/sexp_executor.py @@ -0,0 +1,678 @@ +""" +Streaming S-expression executor. + +Executes compiled sexp recipes in real-time by: +- Evaluating scan expressions on each beat +- Resolving bindings to get effect parameter values +- Applying effects frame-by-frame +- Evaluating SLICE_ON Lambda for cycle crossfade +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sexp_interp import SexpInterpreter, eval_slice_on_lambda + + +@dataclass +class ScanState: + """Runtime state for a scan.""" + node_id: str + name: Optional[str] + value: Any + rng: random.Random + init_expr: dict + step_expr: dict + emit_expr: dict + + +class ExprEvaluator: + """ + Evaluates compiled expression ASTs. + + Expressions are dicts with: + - _expr: True (marks as expression) + - op: operation name + - args: list of arguments + - name: for 'var' ops + - keys: for 'dict' ops + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + + def eval(self, expr: Any, env: Dict[str, Any]) -> Any: + """Evaluate an expression in the given environment.""" + # Literal values + if not isinstance(expr, dict): + return expr + + # Check if it's an expression + if not expr.get('_expr'): + # It's a plain dict - return as-is + return expr + + op = expr.get('op') + args = expr.get('args', []) + + # Evaluate based on operation + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + raise KeyError(f"Unknown variable: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison ops + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic ops + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random ops + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.randint(lo, hi) + elif op == 'rand-range': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.uniform(lo, hi) + + # Logic ops + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown operation: {op}") + + +class SexpStreamingExecutor: + """ + Executes a compiled sexp recipe in streaming mode. + + Reads scan definitions, effect chains, and bindings from the + compiled recipe and executes them frame-by-frame. + """ + + def __init__(self, compiled_recipe, seed: int = 42): + self.recipe = compiled_recipe + self.master_seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # State (must be initialized before _init_scans) + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self.last_beat_detected = False + self.energy = 0.0 + + # Initialize scans + self.scans: Dict[str, ScanState] = {} + self.scan_outputs: Dict[str, Any] = {} # Current emit values by node_id + self._init_scans() + + # Initialize SLICE_ON interpreter + self.sexp_interp = SexpInterpreter(random.Random(seed)) + self._slice_on_lambda = None + self._slice_on_acc = None + self._slice_on_result = None # Last evaluation result {layers, compose, acc} + self._init_slice_on() + + def _init_slice_on(self): + """Initialize SLICE_ON Lambda for cycle crossfade.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + self._slice_on_lambda = config.get('fn') + init = config.get('init', {}) + self._slice_on_acc = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + # Evaluate initial state + self._eval_slice_on() + break + + def _eval_slice_on(self): + """Evaluate the SLICE_ON Lambda with current state.""" + if not self._slice_on_lambda: + return + + n = len(self._get_video_sources()) + videos = list(range(n)) # Placeholder video indices + + try: + result = eval_slice_on_lambda( + self._slice_on_lambda, + self._slice_on_acc, + self.beat_count, + 0.0, # start time (not used for weights) + 1.0, # end time (not used for weights) + videos, + self.sexp_interp, + ) + self._slice_on_result = result + # Update accumulator for next beat + if 'acc' in result: + self._slice_on_acc = result['acc'] + except Exception as e: + import sys + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _init_scans(self): + """Initialize all scan nodes from the recipe.""" + seed_offset = 0 + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + node_id = node['id'] + config = node.get('config', {}) + + # Create RNG with unique seed + scan_seed = config.get('seed', self.master_seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + evaluator = ExprEvaluator(rng) + init_value = evaluator.eval(init_expr, {}) + + self.scans[node_id] = ScanState( + node_id=node_id, + name=node.get('name'), + value=init_value, + rng=rng, + init_expr=init_expr, + step_expr=config.get('step_expr', {}), + emit_expr=config.get('emit_expr', {}), + ) + + # Compute initial emit + self._update_emit(node_id) + + def _update_emit(self, node_id: str): + """Update the emit value for a scan.""" + scan = self.scans[node_id] + evaluator = ExprEvaluator(scan.rng) + + # Build environment from current state + env = self._build_scan_env(scan) + + # Evaluate emit expression + emit_value = evaluator.eval(scan.emit_expr, env) + self.scan_outputs[node_id] = emit_value + + def _build_scan_env(self, scan: ScanState) -> Dict[str, Any]: + """Build environment for scan expression evaluation.""" + env = {} + + # Add state variables + if isinstance(scan.value, dict): + env.update(scan.value) + else: + env['acc'] = scan.value + + # Add beat count + env['beat_count'] = self.beat_count + env['time'] = self.current_time + + return env + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + + # Estimate beat interval + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + # Step each scan + for node_id, scan in self.scans.items(): + evaluator = ExprEvaluator(scan.rng) + env = self._build_scan_env(scan) + + # Evaluate step expression + new_value = evaluator.eval(scan.step_expr, env) + scan.value = new_value + + # Update emit + self._update_emit(node_id) + + # Step the cycle state + self._step_cycle() + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with audio analysis.""" + self.current_time = t + self.energy = energy + + # Update scans on beat (edge detection) + if is_beat and not self.last_beat_detected: + self.on_beat() + self.last_beat_detected = is_beat + + def resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to get the current value.""" + if not isinstance(binding, dict) or not binding.get('_binding'): + return binding + + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get the raw value + if source_id in self.scan_outputs: + value = self.scan_outputs[source_id] + else: + # Might be an analyzer reference - use energy as fallback + value = self.energy + + # Extract feature if value is a dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def get_effect_params(self, effect_node: dict) -> Dict[str, Any]: + """Get resolved parameters for an effect node.""" + config = effect_node.get('config', {}) + params = {} + + for key, value in config.items(): + # Skip internal fields + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', 'analysis_refs'): + continue + + # Resolve bindings + params[key] = self.resolve_binding(value) + + return params + + def get_scan_value(self, name: str) -> Any: + """Get scan output by name.""" + for node_id, scan in self.scans.items(): + if scan.name == name: + return self.scan_outputs.get(node_id) + return None + + def get_all_scan_values(self) -> Dict[str, Any]: + """Get all named scan outputs.""" + result = {} + for node_id, scan in self.scans.items(): + if scan.name: + result[scan.name] = self.scan_outputs.get(node_id) + return result + + # === Compositor interface methods === + + def _get_video_sources(self) -> List[str]: + """Get list of video source node IDs.""" + sources = [] + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + sources.append(node['id']) + # Filter to video only (exclude audio - last one is usually audio) + # Look at file extensions in the paths + return sources[:-1] if len(sources) > 1 else sources + + def _trace_effect_chain(self, start_id: str, stop_at_blend: bool = True) -> List[dict]: + """Trace effect chain from a node, returning effects in order.""" + chain = [] + current_id = start_id + + for _ in range(20): # Max depth + # Find node that uses current as input + next_node = None + for node in self.recipe.nodes: + if current_id in node.get('inputs', []): + if node.get('type') == 'EFFECT': + effect_type = node.get('config', {}).get('effect') + chain.append(node) + if stop_at_blend and effect_type == 'blend': + return chain + next_node = node + break + elif node.get('type') == 'SEGMENT': + next_node = node + break + + if next_node is None: + break + current_id = next_node['id'] + + return chain + + def _find_clip_chains(self, source_idx: int) -> tuple: + """Find effect chains for clip A and B from a source.""" + sources = self._get_video_sources() + if source_idx >= len(sources): + return [], [] + + source_id = sources[source_idx] + + # Find SEGMENT node + segment_id = None + for node in self.recipe.nodes: + if node.get('type') == 'SEGMENT' and source_id in node.get('inputs', []): + segment_id = node['id'] + break + + if not segment_id: + return [], [] + + # Find the two effect chains from segment (clip A and clip B) + chains = [] + for node in self.recipe.nodes: + if segment_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + chain = self._trace_effect_chain(segment_id) + # Get chain starting from this specific branch + branch_chain = [node] + current = node['id'] + for _ in range(10): + found = False + for n in self.recipe.nodes: + if current in n.get('inputs', []) and n.get('type') == 'EFFECT': + branch_chain.append(n) + if n.get('config', {}).get('effect') == 'blend': + break + current = n['id'] + found = True + break + if not found: + break + chains.append(branch_chain) + + # Return first two chains as A and B + chain_a = chains[0] if len(chains) > 0 else [] + chain_b = chains[1] if len(chains) > 1 else [] + return chain_a, chain_b + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """Get effect parameters for a source clip (compositor interface).""" + # Get the correct chain for this clip + chain_a, chain_b = self._find_clip_chains(source_idx) + chain = chain_a if clip == 'a' else chain_b + + # Default params + params = { + "rotate_angle": 0, + "zoom_amount": 1.0, + "invert_amount": 0, + "hue_degrees": 0, + "ascii_mix": 0, + "ascii_char_size": 8, + } + + # Resolve from effects in chain + for eff in chain: + config = eff.get('config', {}) + effect_type = config.get('effect') + + if effect_type == 'rotate': + angle_binding = config.get('angle') + if angle_binding: + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + # Bound to analyzer - use energy with range + range_map = angle_binding.get('range') + if range_map: + lo, hi = range_map + params["rotate_angle"] = lo + energy * (hi - lo) + else: + params["rotate_angle"] = self.resolve_binding(angle_binding) + else: + params["rotate_angle"] = angle_binding if isinstance(angle_binding, (int, float)) else 0 + + elif effect_type == 'zoom': + amount_binding = config.get('amount') + if amount_binding: + if isinstance(amount_binding, dict) and amount_binding.get('_binding'): + range_map = amount_binding.get('range') + if range_map: + lo, hi = range_map + params["zoom_amount"] = lo + energy * (hi - lo) + else: + params["zoom_amount"] = self.resolve_binding(amount_binding) + else: + params["zoom_amount"] = amount_binding if isinstance(amount_binding, (int, float)) else 1.0 + + elif effect_type == 'invert': + amount_binding = config.get('amount') + if amount_binding: + val = self.resolve_binding(amount_binding) + params["invert_amount"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'hue_shift': + deg_binding = config.get('degrees') + if deg_binding: + val = self.resolve_binding(deg_binding) + params["hue_degrees"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'ascii_art': + mix_binding = config.get('mix') + if mix_binding: + val = self.resolve_binding(mix_binding) + params["ascii_mix"] = val if isinstance(val, (int, float)) else 0 + size_binding = config.get('char_size') + if size_binding: + if isinstance(size_binding, dict) and size_binding.get('_binding'): + range_map = size_binding.get('range') + if range_map: + lo, hi = range_map + params["ascii_char_size"] = lo + energy * (hi - lo) + + return params + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair (compositor interface).""" + params = { + "blend_opacity": 0.5, + "pair_rotation": 0, + } + + # Find the blend node for this source + chain_a, _ = self._find_clip_chains(source_idx) + + # The last effect in chain_a should be the blend + blend_node = None + for eff in reversed(chain_a): + if eff.get('config', {}).get('effect') == 'blend': + blend_node = eff + break + + if blend_node: + config = blend_node.get('config', {}) + opacity_binding = config.get('opacity') + if opacity_binding: + val = self.resolve_binding(opacity_binding) + if isinstance(val, (int, float)): + params["blend_opacity"] = val + + # Find rotate after blend (pair rotation) + blend_id = blend_node['id'] + for node in self.recipe.nodes: + if blend_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + if node.get('config', {}).get('effect') == 'rotate': + angle_binding = node.get('config', {}).get('angle') + if angle_binding: + val = self.resolve_binding(angle_binding) + if isinstance(val, (int, float)): + params["pair_rotation"] = val + break + + return params + + def _get_cycle_state(self) -> dict: + """Get current cycle state from SLICE_ON or internal tracking.""" + if not hasattr(self, '_cycle_state'): + # Initialize from SLICE_ON node + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + init = node.get('config', {}).get('init', {}) + self._cycle_state = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + break + else: + self._cycle_state = {'cycle': 0, 'beat': 0, 'clen': 60} + + return self._cycle_state + + def _step_cycle(self): + """Step the cycle state forward on beat by evaluating SLICE_ON Lambda.""" + # Use interpreter to evaluate the Lambda + self._eval_slice_on() + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get weights from interpreted result + if self._slice_on_result: + compose = self._slice_on_result.get('compose', {}) + weights = compose.get('weights', []) + if weights and len(weights) == n: + # Normalize + total = sum(weights) + if total > 0: + return [w / total for w in weights] + + # Fallback: equal weights + return [1.0 / n] * n + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get zooms from interpreted result (layers -> effects -> zoom amount) + if self._slice_on_result: + layers = self._slice_on_result.get('layers', []) + if layers and len(layers) == n: + zooms = [] + for layer in layers: + effects = layer.get('effects', []) + zoom_amt = 1.0 + for eff in effects: + if eff.get('effect') == 'zoom' or (hasattr(eff.get('effect'), 'name') and eff.get('effect').name == 'zoom'): + zoom_amt = eff.get('amount', 1.0) + break + zooms.append(zoom_amt) + return zooms + + # Fallback + return [1.0] * n + + def _get_final_rotate_scan_id(self) -> str: + """Find the scan ID that drives the final rotation (after SLICE_ON).""" + if hasattr(self, '_final_rotate_scan_id'): + return self._final_rotate_scan_id + + # Find SLICE_ON node index + slice_on_idx = None + for i, node in enumerate(self.recipe.nodes): + if node.get('type') == 'SLICE_ON': + slice_on_idx = i + break + + # Find rotate effect after SLICE_ON + if slice_on_idx is not None: + for node in self.recipe.nodes[slice_on_idx + 1:]: + if node.get('type') == 'EFFECT': + config = node.get('config', {}) + if config.get('effect') == 'rotate': + angle_binding = config.get('angle', {}) + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + self._final_rotate_scan_id = angle_binding.get('source') + return self._final_rotate_scan_id + + self._final_rotate_scan_id = None + return None + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (compositor interface).""" + # Get named scans + scan_values = self.get_all_scan_values() + + # Whole spin - get from the specific scan bound to final rotate effect + whole_spin = 0 + final_rotate_scan_id = self._get_final_rotate_scan_id() + if final_rotate_scan_id and final_rotate_scan_id in self.scan_outputs: + val = self.scan_outputs[final_rotate_scan_id] + if isinstance(val, dict) and 'angle' in val: + whole_spin = val['angle'] + elif isinstance(val, (int, float)): + whole_spin = val + + # Ripple + ripple_gate = scan_values.get('ripple-gate', 0) + ripple_cx = scan_values.get('ripple-cx', 0.5) + ripple_cy = scan_values.get('ripple-cy', 0.5) + + if isinstance(ripple_gate, dict): + ripple_gate = ripple_gate.get('gate', 0) if 'gate' in ripple_gate else 1 + + return { + "whole_spin_angle": whole_spin, + "ripple_amplitude": ripple_gate * (5 + energy * 45), + "ripple_cx": ripple_cx if isinstance(ripple_cx, (int, float)) else 0.5, + "ripple_cy": ripple_cy if isinstance(ripple_cy, (int, float)) else 0.5, + } diff --git a/l1/streaming/sexp_interp.py b/l1/streaming/sexp_interp.py new file mode 100644 index 0000000..e3433b2 --- /dev/null +++ b/l1/streaming/sexp_interp.py @@ -0,0 +1,376 @@ +""" +S-expression interpreter for streaming execution. + +Evaluates sexp expressions including: +- let bindings +- lambda definitions and calls +- Arithmetic, comparison, logic operators +- dict/list operations +- Random number generation +""" + +import random +from typing import Any, Dict, List, Callable +from dataclasses import dataclass + + +@dataclass +class Lambda: + """Runtime lambda value.""" + params: List[str] + body: Any + closure: Dict[str, Any] + + +class Symbol: + """Symbol reference.""" + def __init__(self, name: str): + self.name = name + + def __repr__(self): + return f"Symbol({self.name})" + + +class SexpInterpreter: + """ + Interprets S-expressions in real-time. + + Handles the full sexp language used in recipes. + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + self.globals: Dict[str, Any] = {} + + def eval(self, expr: Any, env: Dict[str, Any] = None) -> Any: + """Evaluate an expression in the given environment.""" + if env is None: + env = {} + + # Literals + if isinstance(expr, (int, float, str, bool)) or expr is None: + return expr + + # Symbol lookup + if isinstance(expr, Symbol) or (hasattr(expr, 'name') and hasattr(expr, '__class__') and expr.__class__.__name__ == 'Symbol'): + name = expr.name if hasattr(expr, 'name') else str(expr) + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined symbol: {name}") + + # Compiled expression dict (from compiler) + if isinstance(expr, dict): + if expr.get('_expr'): + return self._eval_compiled_expr(expr, env) + # Plain dict - evaluate values that might be expressions + result = {} + for k, v in expr.items(): + # Some keys should keep Symbol values as strings (effect names, modes) + if k in ('effect', 'mode') and hasattr(v, 'name'): + result[k] = v.name + else: + result[k] = self.eval(v, env) + return result + + # List expression (sexp) + if isinstance(expr, (list, tuple)) and len(expr) > 0: + return self._eval_list(expr, env) + + # Empty list + if isinstance(expr, (list, tuple)): + return [] + + return expr + + def _eval_compiled_expr(self, expr: dict, env: Dict[str, Any]) -> Any: + """Evaluate a compiled expression dict.""" + op = expr.get('op') + args = expr.get('args', []) + + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + return self.rng.randint(self.eval(args[0], env), self.eval(args[1], env)) + elif op == 'rand-range': + return self.rng.uniform(self.eval(args[0], env), self.eval(args[1], env)) + + # Logic + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown op: {op}") + + def _eval_list(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate a list expression (sexp form).""" + if len(expr) == 0: + return [] + + head = expr[0] + + # Get head name + if isinstance(head, Symbol) or (hasattr(head, 'name') and hasattr(head, '__class__')): + head_name = head.name if hasattr(head, 'name') else str(head) + elif isinstance(head, str): + head_name = head + else: + # Not a symbol - check if it's a data list or function call + if isinstance(head, dict): + # List of dicts - evaluate each element as data + return [self.eval(item, env) for item in expr] + # Otherwise evaluate as function call + fn = self.eval(head, env) + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + # Special forms + if head_name == 'let': + return self._eval_let(expr, env) + elif head_name in ('lambda', 'fn'): + return self._eval_lambda(expr, env) + elif head_name == 'if': + return self._eval_if(expr, env) + elif head_name == 'dict': + return self._eval_dict(expr, env) + elif head_name == 'get': + obj = self.eval(expr[1], env) + key = self.eval(expr[2], env) if len(expr) > 2 else expr[2] + if isinstance(key, str): + return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None) + return obj[key] + elif head_name == 'len': + return len(self.eval(expr[1], env)) + elif head_name == 'range': + start = self.eval(expr[1], env) + end = self.eval(expr[2], env) if len(expr) > 2 else start + if len(expr) == 2: + return list(range(end)) + return list(range(start, end)) + elif head_name == 'map': + fn = self.eval(expr[1], env) + lst = self.eval(expr[2], env) + return [self._call(fn, [x], env) for x in lst] + elif head_name == 'mod': + return self.eval(expr[1], env) % self.eval(expr[2], env) + + # Arithmetic + elif head_name == '+': + return self.eval(expr[1], env) + self.eval(expr[2], env) + elif head_name == '-': + if len(expr) == 2: + return -self.eval(expr[1], env) + return self.eval(expr[1], env) - self.eval(expr[2], env) + elif head_name == '*': + return self.eval(expr[1], env) * self.eval(expr[2], env) + elif head_name == '/': + return self.eval(expr[1], env) / self.eval(expr[2], env) + + # Comparison + elif head_name == '<': + return self.eval(expr[1], env) < self.eval(expr[2], env) + elif head_name == '>': + return self.eval(expr[1], env) > self.eval(expr[2], env) + elif head_name == '<=': + return self.eval(expr[1], env) <= self.eval(expr[2], env) + elif head_name == '>=': + return self.eval(expr[1], env) >= self.eval(expr[2], env) + elif head_name == '=': + return self.eval(expr[1], env) == self.eval(expr[2], env) + + # Logic + elif head_name == 'and': + return all(self.eval(a, env) for a in expr[1:]) + elif head_name == 'or': + return any(self.eval(a, env) for a in expr[1:]) + elif head_name == 'not': + return not self.eval(expr[1], env) + + # Function call + else: + fn = env.get(head_name) or self.globals.get(head_name) + if fn is None: + raise NameError(f"Undefined function: {head_name}") + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + def _eval_let(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (let [bindings...] body).""" + bindings = expr[1] + body = expr[2] + + # Create new environment with bindings + new_env = dict(env) + + # Process bindings in pairs + i = 0 + while i < len(bindings): + name = bindings[i] + if isinstance(name, Symbol) or hasattr(name, 'name'): + name = name.name if hasattr(name, 'name') else str(name) + value = self.eval(bindings[i + 1], new_env) + new_env[name] = value + i += 2 + + return self.eval(body, new_env) + + def _eval_lambda(self, expr: list, env: Dict[str, Any]) -> Lambda: + """Evaluate (lambda [params] body).""" + params_expr = expr[1] + body = expr[2] + + # Extract parameter names + params = [] + for p in params_expr: + if isinstance(p, Symbol) or hasattr(p, 'name'): + params.append(p.name if hasattr(p, 'name') else str(p)) + else: + params.append(str(p)) + + return Lambda(params=params, body=body, closure=dict(env)) + + def _eval_if(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + def _eval_dict(self, expr: list, env: Dict[str, Any]) -> dict: + """Evaluate (dict :key val ...).""" + result = {} + i = 1 + while i < len(expr): + key = expr[i] + # Handle keyword syntax (:key) and Keyword objects + if hasattr(key, 'name'): + key = key.name + elif hasattr(key, '__class__') and key.__class__.__name__ == 'Keyword': + key = str(key).lstrip(':') + elif isinstance(key, str) and key.startswith(':'): + key = key[1:] + value = self.eval(expr[i + 1], env) + result[key] = value + i += 2 + return result + + def _call(self, fn: Any, args: List[Any], env: Dict[str, Any]) -> Any: + """Call a function with arguments.""" + if isinstance(fn, Lambda): + # Our own Lambda type + call_env = dict(fn.closure) + for param, arg in zip(fn.params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif hasattr(fn, 'params') and hasattr(fn, 'body'): + # Lambda from parser (artdag.sexp.parser.Lambda) + call_env = dict(env) + if hasattr(fn, 'closure') and fn.closure: + call_env.update(fn.closure) + # Get param names + params = [] + for p in fn.params: + if hasattr(p, 'name'): + params.append(p.name) + else: + params.append(str(p)) + for param, arg in zip(params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif callable(fn): + return fn(*args) + else: + raise TypeError(f"Not callable: {type(fn).__name__}") + + +def eval_slice_on_lambda(lambda_obj, acc: dict, i: int, start: float, end: float, + videos: list, interp: SexpInterpreter = None) -> dict: + """ + Evaluate a SLICE_ON lambda function. + + Args: + lambda_obj: The Lambda object from the compiled recipe + acc: Current accumulator state + i: Beat index + start: Slice start time + end: Slice end time + videos: List of video inputs + interp: Interpreter to use + + Returns: + Dict with 'layers', 'compose', 'acc' keys + """ + if interp is None: + interp = SexpInterpreter() + + # Set up global 'videos' for (len videos) to work + interp.globals['videos'] = videos + + # Build initial environment with lambda parameters + env = dict(lambda_obj.closure) if hasattr(lambda_obj, 'closure') and lambda_obj.closure else {} + env['videos'] = videos + + # Call the lambda + result = interp._call(lambda_obj, [acc, i, start, end], env) + + return result diff --git a/l1/streaming/sexp_to_cuda.py b/l1/streaming/sexp_to_cuda.py new file mode 100644 index 0000000..e4051bd --- /dev/null +++ b/l1/streaming/sexp_to_cuda.py @@ -0,0 +1,706 @@ +""" +Sexp to CUDA Kernel Compiler. + +Compiles sexp frame pipelines to fused CUDA kernels for maximum performance. +Instead of interpreting sexp and launching 10+ kernels per frame, +generates a single kernel that does everything in one pass. +""" + +import cupy as cp +import numpy as np +from typing import Dict, List, Any, Optional, Tuple +import hashlib +import sys +import logging + +logger = logging.getLogger(__name__) + +# Kernel cache +_COMPILED_KERNELS: Dict[str, Any] = {} + + +def compile_frame_pipeline(effects: List[dict], width: int, height: int) -> callable: + """ + Compile a list of effects to a fused CUDA kernel. + + Args: + effects: List of effect dicts like: + [{'op': 'rotate', 'angle': 45.0}, + {'op': 'blend', 'alpha': 0.5, 'src2': }, + {'op': 'hue_shift', 'degrees': 90.0}, + {'op': 'ripple', 'amplitude': 10.0, 'frequency': 8.0, ...}] + width, height: Frame dimensions + + Returns: + Callable that takes input frame and returns output frame + """ + + # Generate cache key + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects]) + cache_key = f"{width}x{height}_{hashlib.md5(ops_key.encode()).hexdigest()}" + + if cache_key in _COMPILED_KERNELS: + return _COMPILED_KERNELS[cache_key] + + # Generate fused kernel code + kernel_code = _generate_fused_kernel(effects, width, height) + + # Compile kernel + kernel = cp.RawKernel(kernel_code, 'fused_pipeline') + + # Create wrapper function + def run_pipeline(frame: cp.ndarray, **dynamic_params) -> cp.ndarray: + """Run the compiled pipeline on a frame.""" + if frame.dtype != cp.uint8: + frame = cp.clip(frame, 0, 255).astype(cp.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = cp.ascontiguousarray(frame) + + output = cp.zeros_like(frame) + + block = (16, 16) + grid = ((width + 15) // 16, (height + 15) // 16) + + # Build parameter array + params = _build_params(effects, dynamic_params) + + kernel(grid, block, (frame, output, width, height, params)) + + return output + + _COMPILED_KERNELS[cache_key] = run_pipeline + return run_pipeline + + +def _generate_fused_kernel(effects: List[dict], width: int, height: int) -> str: + """Generate CUDA kernel code for fused effects pipeline.""" + + # Validate all ops are supported + SUPPORTED_OPS = {'rotate', 'zoom', 'ripple', 'invert', 'hue_shift', 'brightness'} + for effect in effects: + op = effect.get('op') + if op not in SUPPORTED_OPS: + raise ValueError(f"Unsupported CUDA kernel operation: '{op}'. Supported ops: {', '.join(sorted(SUPPORTED_OPS))}. Note: 'resize' must be handled separately before the fused kernel.") + + # Build the kernel + code = r''' +extern "C" __global__ +void fused_pipeline( + const unsigned char* src, + unsigned char* dst, + int width, int height, + const float* params +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Start with source coordinates + float src_x = (float)x; + float src_y = (float)y; + float cx = width / 2.0f; + float cy = height / 2.0f; + + // Track accumulated transforms + float total_cos = 1.0f, total_sin = 0.0f; // rotation + float total_zoom = 1.0f; // zoom + float ripple_dx = 0.0f, ripple_dy = 0.0f; // ripple displacement + + int param_idx = 0; + +''' + + # Add effect-specific code + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'rotate': + code += f''' + // Rotate {i} + {{ + float angle = params[param_idx++] * 3.14159265f / 180.0f; + float c = cosf(angle); + float s = sinf(angle); + // Compose with existing rotation + float nc = total_cos * c - total_sin * s; + float ns = total_cos * s + total_sin * c; + total_cos = nc; + total_sin = ns; + }} +''' + elif op == 'zoom': + code += f''' + // Zoom {i} + {{ + float zoom = params[param_idx++]; + total_zoom *= zoom; + }} +''' + elif op == 'ripple': + code += f''' + // Ripple {i} - matching original formula: sin(dist/freq - phase) * exp(-dist*decay/maxdim) + {{ + float amplitude = params[param_idx++]; + float frequency = params[param_idx++]; + float decay = params[param_idx++]; + float phase = params[param_idx++]; + float rcx = params[param_idx++]; + float rcy = params[param_idx++]; + + float rdx = src_x - rcx; + float rdy = src_y - rcy; + float dist = sqrtf(rdx * rdx + rdy * rdy); + float max_dim = (float)(width > height ? width : height); + + // Original formula: sin(dist / frequency - phase) * exp(-dist * decay / max_dim) + float wave = sinf(dist / frequency - phase); + float amp = amplitude * expf(-dist * decay / max_dim); + + if (dist > 0.001f) {{ + ripple_dx += rdx / dist * wave * amp; + ripple_dy += rdy / dist * wave * amp; + }} + }} +''' + + # Apply all geometric transforms at once + code += ''' + // Apply accumulated geometric transforms + { + // Translate to center + float dx = src_x - cx; + float dy = src_y - cy; + + // Apply rotation + float rx = total_cos * dx + total_sin * dy; + float ry = -total_sin * dx + total_cos * dy; + + // Apply zoom (inverse for sampling) + rx /= total_zoom; + ry /= total_zoom; + + // Translate back and apply ripple + src_x = rx + cx - ripple_dx; + src_y = ry + cy - ripple_dy; + } + + // Sample source with bilinear interpolation + float r, g, b; + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + r = g = b = 0; + } else { + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + int idx00 = (y0 * width + x0) * 3; + int idx10 = (y0 * width + x0 + 1) * 3; + int idx01 = ((y0 + 1) * width + x0) * 3; + int idx11 = ((y0 + 1) * width + x0 + 1) * 3; + + #define BILERP(c) \\ + (src[idx00 + c] * (1-fx) * (1-fy) + \\ + src[idx10 + c] * fx * (1-fy) + \\ + src[idx01 + c] * (1-fx) * fy + \\ + src[idx11 + c] * fx * fy) + + r = BILERP(0); + g = BILERP(1); + b = BILERP(2); + } + +''' + + # Add color transforms + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'invert': + code += f''' + // Invert {i} + {{ + float amount = params[param_idx++]; + if (amount > 0.5f) {{ + r = 255.0f - r; + g = 255.0f - g; + b = 255.0f - b; + }} + }} +''' + elif op == 'hue_shift': + code += f''' + // Hue shift {i} + {{ + float shift = params[param_idx++]; + if (fabsf(shift) > 0.01f) {{ + // RGB to HSV + float rf = r / 255.0f; + float gf = g / 255.0f; + float bf = b / 255.0f; + + float max_c = fmaxf(rf, fmaxf(gf, bf)); + float min_c = fminf(rf, fminf(gf, bf)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) {{ + s = delta / max_c; + if (rf >= max_c) h = (gf - bf) / delta; + else if (gf >= max_c) h = 2.0f + (bf - rf) / delta; + else h = 4.0f + (rf - gf) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + }} + + h = fmodf(h + shift + 360.0f, 360.0f); + + // HSV to RGB + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) {{ r2 = c; g2 = x_val; b2 = 0; }} + else if (h < 120) {{ r2 = x_val; g2 = c; b2 = 0; }} + else if (h < 180) {{ r2 = 0; g2 = c; b2 = x_val; }} + else if (h < 240) {{ r2 = 0; g2 = x_val; b2 = c; }} + else if (h < 300) {{ r2 = x_val; g2 = 0; b2 = c; }} + else {{ r2 = c; g2 = 0; b2 = x_val; }} + + r = (r2 + m) * 255.0f; + g = (g2 + m) * 255.0f; + b = (b2 + m) * 255.0f; + }} + }} +''' + elif op == 'brightness': + code += f''' + // Brightness {i} + {{ + float factor = params[param_idx++]; + r *= factor; + g *= factor; + b *= factor; + }} +''' + + # Write output + code += ''' + // Write output + int dst_idx = (y * width + x) * 3; + dst[dst_idx] = (unsigned char)fminf(255.0f, fmaxf(0.0f, r)); + dst[dst_idx + 1] = (unsigned char)fminf(255.0f, fmaxf(0.0f, g)); + dst[dst_idx + 2] = (unsigned char)fminf(255.0f, fmaxf(0.0f, b)); +} +''' + + return code + + +_BUILD_PARAMS_COUNT = 0 + +def _build_params(effects: List[dict], dynamic_params: dict) -> cp.ndarray: + """Build parameter array for kernel. + + IMPORTANT: Parameters must be built in the same order the kernel consumes them: + 1. First all geometric transforms (rotate, zoom, ripple) in list order + 2. Then all color transforms (invert, hue_shift, brightness) in list order + """ + global _BUILD_PARAMS_COUNT + _BUILD_PARAMS_COUNT += 1 + + # ALWAYS log first few calls - use WARNING to ensure visibility in Celery logs + if _BUILD_PARAMS_COUNT <= 3: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] effects={[e['op'] for e in effects]}") + + params = [] + + # First pass: geometric transforms (matches kernel's first loop) + for effect in effects: + op = effect['op'] + + if op == 'rotate': + params.append(float(dynamic_params.get('rotate_angle', effect.get('angle', 0)))) + elif op == 'zoom': + params.append(float(dynamic_params.get('zoom_amount', effect.get('amount', 1.0)))) + elif op == 'ripple': + amp = float(dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))) + freq = float(effect.get('frequency', 8)) + decay = float(effect.get('decay', 2)) + phase = float(dynamic_params.get('ripple_phase', effect.get('phase', 0))) + cx = float(effect.get('center_x', 960)) + cy = float(effect.get('center_y', 540)) + params.extend([amp, freq, decay, phase, cx, cy]) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] ripple amp={amp} freq={freq} decay={decay} phase={phase:.2f} cx={cx} cy={cy}") + + # Second pass: color transforms (matches kernel's second loop) + for effect in effects: + op = effect['op'] + + if op == 'invert': + amt = float(effect.get('amount', 0)) + params.append(amt) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] invert amount={amt}") + elif op == 'hue_shift': + deg = float(effect.get('degrees', 0)) + params.append(deg) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] hue_shift degrees={deg}") + elif op == 'brightness': + params.append(float(effect.get('factor', 1.0))) + + return cp.array(params, dtype=cp.float32) + + +def compile_autonomous_pipeline(effects: List[dict], width: int, height: int, + dynamic_expressions: dict = None) -> callable: + """ + Compile a fully autonomous pipeline that computes ALL parameters on GPU. + + This eliminates Python from the hot path - the kernel computes time-based + parameters (sin, cos, etc.) directly on GPU. + + Args: + effects: List of effect dicts + width, height: Frame dimensions + dynamic_expressions: Dict mapping param names to expressions, e.g.: + {'rotate_angle': 't * 30', + 'ripple_phase': 't * 2', + 'brightness_factor': '0.8 + 0.4 * sin(t * 2)'} + + Returns: + Callable that takes (frame, frame_num, fps) and returns output frame + """ + if dynamic_expressions is None: + dynamic_expressions = {} + + # Generate cache key + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects]) + expr_key = str(sorted(dynamic_expressions.items())) + cache_key = f"auto_{width}x{height}_{hashlib.md5((ops_key + expr_key).encode()).hexdigest()}" + + if cache_key in _COMPILED_KERNELS: + return _COMPILED_KERNELS[cache_key] + + # Generate autonomous kernel code + kernel_code = _generate_autonomous_kernel(effects, width, height, dynamic_expressions) + + # Compile kernel + kernel = cp.RawKernel(kernel_code, 'autonomous_pipeline') + + # Create wrapper function + def run_autonomous(frame: cp.ndarray, frame_num: int, fps: float = 30.0) -> cp.ndarray: + """Run the autonomous pipeline - no Python in the hot path!""" + if frame.dtype != cp.uint8: + frame = cp.clip(frame, 0, 255).astype(cp.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = cp.ascontiguousarray(frame) + + output = cp.zeros_like(frame) + + block = (16, 16) + grid = ((width + 15) // 16, (height + 15) // 16) + + # Only pass frame_num and fps - kernel computes everything else! + t = float(frame_num) / float(fps) + kernel(grid, block, (frame, output, np.int32(width), np.int32(height), + np.float32(t), np.int32(frame_num))) + + return output + + _COMPILED_KERNELS[cache_key] = run_autonomous + return run_autonomous + + +def _generate_autonomous_kernel(effects: List[dict], width: int, height: int, + dynamic_expressions: dict) -> str: + """Generate CUDA kernel that computes everything autonomously.""" + + # Map simple expressions to CUDA code + def expr_to_cuda(expr: str) -> str: + """Convert simple expression to CUDA.""" + expr = expr.replace('sin(', 'sinf(') + expr = expr.replace('cos(', 'cosf(') + expr = expr.replace('abs(', 'fabsf(') + return expr + + code = r''' +extern "C" __global__ +void autonomous_pipeline( + const unsigned char* src, + unsigned char* dst, + int width, int height, + float t, int frame_num +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Compute dynamic parameters from time (ALL ON GPU!) +''' + + # Add dynamic parameter calculations + rotate_expr = dynamic_expressions.get('rotate_angle', '0.0f') + ripple_phase_expr = dynamic_expressions.get('ripple_phase', '0.0f') + brightness_expr = dynamic_expressions.get('brightness_factor', '1.0f') + zoom_expr = dynamic_expressions.get('zoom_amount', '1.0f') + + code += f''' + float rotate_angle = {expr_to_cuda(rotate_expr)}; + float ripple_phase = {expr_to_cuda(ripple_phase_expr)}; + float brightness_factor = {expr_to_cuda(brightness_expr)}; + float zoom_amount = {expr_to_cuda(zoom_expr)}; + + // Start with source coordinates + float src_x = (float)x; + float src_y = (float)y; + float cx = width / 2.0f; + float cy = height / 2.0f; + + // Accumulated transforms + float total_cos = 1.0f, total_sin = 0.0f; + float total_zoom = 1.0f; + float ripple_dx = 0.0f, ripple_dy = 0.0f; + +''' + + # Add effect-specific code + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'rotate': + code += f''' + // Rotate {i} + {{ + float angle = rotate_angle * 3.14159265f / 180.0f; + float c = cosf(angle); + float s = sinf(angle); + float nc = total_cos * c - total_sin * s; + float ns = total_cos * s + total_sin * c; + total_cos = nc; + total_sin = ns; + }} +''' + elif op == 'zoom': + code += f''' + // Zoom {i} + {{ + total_zoom *= zoom_amount; + }} +''' + elif op == 'ripple': + amp = float(effect.get('amplitude', 10)) + freq = float(effect.get('frequency', 8)) + decay = float(effect.get('decay', 2)) + rcx = float(effect.get('center_x', width/2)) + rcy = float(effect.get('center_y', height/2)) + code += f''' + // Ripple {i} + {{ + float amplitude = {amp:.1f}f; + float frequency = {freq:.1f}f; + float decay_val = {decay:.1f}f; + float rcx = {rcx:.1f}f; + float rcy = {rcy:.1f}f; + + float rdx = src_x - rcx; + float rdy = src_y - rcy; + float dist = sqrtf(rdx * rdx + rdy * rdy); + + float wave = sinf(dist * frequency * 0.1f + ripple_phase); + float amp = amplitude * expf(-dist * decay_val * 0.01f); + + if (dist > 0.001f) {{ + ripple_dx += rdx / dist * wave * amp; + ripple_dy += rdy / dist * wave * amp; + }} + }} +''' + + # Apply geometric transforms + code += ''' + // Apply accumulated transforms + { + float dx = src_x - cx; + float dy = src_y - cy; + float rx = total_cos * dx + total_sin * dy; + float ry = -total_sin * dx + total_cos * dy; + rx /= total_zoom; + ry /= total_zoom; + src_x = rx + cx - ripple_dx; + src_y = ry + cy - ripple_dy; + } + + // Bilinear sample + float r, g, b; + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + r = g = b = 0; + } else { + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + int idx00 = (y0 * width + x0) * 3; + int idx10 = (y0 * width + x0 + 1) * 3; + int idx01 = ((y0 + 1) * width + x0) * 3; + int idx11 = ((y0 + 1) * width + x0 + 1) * 3; + + #define BILERP(c) \\ + (src[idx00 + c] * (1-fx) * (1-fy) + \\ + src[idx10 + c] * fx * (1-fy) + \\ + src[idx01 + c] * (1-fx) * fy + \\ + src[idx11 + c] * fx * fy) + + r = BILERP(0); + g = BILERP(1); + b = BILERP(2); + } + +''' + + # Add color transforms + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'hue_shift': + degrees = float(effect.get('degrees', 0)) + code += f''' + // Hue shift {i} + {{ + float shift = {degrees:.1f}f; + float rf = r / 255.0f; + float gf = g / 255.0f; + float bf = b / 255.0f; + + float max_c = fmaxf(rf, fmaxf(gf, bf)); + float min_c = fminf(rf, fminf(gf, bf)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) {{ + s = delta / max_c; + if (rf >= max_c) h = (gf - bf) / delta; + else if (gf >= max_c) h = 2.0f + (bf - rf) / delta; + else h = 4.0f + (rf - gf) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + }} + + h = fmodf(h + shift + 360.0f, 360.0f); + + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) {{ r2 = c; g2 = x_val; b2 = 0; }} + else if (h < 120) {{ r2 = x_val; g2 = c; b2 = 0; }} + else if (h < 180) {{ r2 = 0; g2 = c; b2 = x_val; }} + else if (h < 240) {{ r2 = 0; g2 = x_val; b2 = c; }} + else if (h < 300) {{ r2 = x_val; g2 = 0; b2 = c; }} + else {{ r2 = c; g2 = 0; b2 = x_val; }} + + r = (r2 + m) * 255.0f; + g = (g2 + m) * 255.0f; + b = (b2 + m) * 255.0f; + }} +''' + elif op == 'brightness': + code += ''' + // Brightness + { + r *= brightness_factor; + g *= brightness_factor; + b *= brightness_factor; + } +''' + + # Write output + code += ''' + // Write output + int dst_idx = (y * width + x) * 3; + dst[dst_idx] = (unsigned char)fminf(255.0f, fmaxf(0.0f, r)); + dst[dst_idx + 1] = (unsigned char)fminf(255.0f, fmaxf(0.0f, g)); + dst[dst_idx + 2] = (unsigned char)fminf(255.0f, fmaxf(0.0f, b)); +} +''' + + return code + + +# Test the compiler +if __name__ == '__main__': + import time + + print("[sexp_to_cuda] Testing fused kernel compiler...") + print("=" * 60) + + # Define a test pipeline + effects = [ + {'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 30.0}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + frame = cp.random.randint(0, 255, (1080, 1920, 3), dtype=cp.uint8) + + # ===== Test 1: Standard fused kernel (params passed from Python) ===== + print("\n[Test 1] Standard fused kernel (Python computes params)") + pipeline = compile_frame_pipeline(effects, 1920, 1080) + + # Warmup + output = pipeline(frame) + cp.cuda.Stream.null.synchronize() + + # Benchmark with Python param computation + start = time.time() + for i in range(100): + # Simulate Python computing params (like sexp interpreter does) + import math + t = i / 30.0 + angle = t * 30 + phase = t * 2 + brightness = 0.8 + 0.4 * math.sin(t * 2) + output = pipeline(frame, rotate_angle=angle, ripple_phase=phase) + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + print(f" Time: {elapsed/100*1000:.2f}ms per frame") + print(f" FPS: {100/elapsed:.0f}") + + # ===== Test 2: Autonomous kernel (GPU computes everything) ===== + print("\n[Test 2] Autonomous kernel (GPU computes ALL params)") + + dynamic_expressions = { + 'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)', + } + + auto_pipeline = compile_autonomous_pipeline(effects, 1920, 1080, dynamic_expressions) + + # Warmup + output = auto_pipeline(frame, 0, 30.0) + cp.cuda.Stream.null.synchronize() + + # Benchmark - NO Python computation in loop! + start = time.time() + for i in range(100): + output = auto_pipeline(frame, i, 30.0) # Just pass frame_num! + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + print(f" Time: {elapsed/100*1000:.2f}ms per frame") + print(f" FPS: {100/elapsed:.0f}") + + print("\n" + "=" * 60) + print("Autonomous kernel eliminates Python from hot path!") diff --git a/l1/streaming/sexp_to_jax.py b/l1/streaming/sexp_to_jax.py new file mode 100644 index 0000000..db781f2 --- /dev/null +++ b/l1/streaming/sexp_to_jax.py @@ -0,0 +1,4628 @@ +""" +Sexp to JAX Compiler. + +Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU. +Uses XLA compilation via @jax.jit for automatic kernel fusion. + +Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles +S-expressions directly to JAX operations which XLA then optimizes. + +Usage: + from streaming.sexp_to_jax import compile_effect + + effect_code = ''' + (effect "threshold" + :params ((threshold :default 128)) + :body (let ((g (gray frame))) + (rgb (where (> g threshold) 255 0) + (where (> g threshold) 255 0) + (where (> g threshold) 255 0)))) + ''' + + run_effect = compile_effect(effect_code) + output = run_effect(frame, threshold=128) +""" + +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial +from typing import Any, Dict, List, Callable, Optional, Tuple +import hashlib +import numpy as np + +# Import parser +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) +from sexp_effects.parser import parse, parse_all, Symbol, Keyword + +# Import typography primitives +from streaming.jax_typography import bind_typography_primitives + + +# ============================================================================= +# Compilation Cache +# ============================================================================= + +_COMPILED_EFFECTS: Dict[str, Callable] = {} + + +# ============================================================================= +# Font Atlas for ASCII Effects +# ============================================================================= + +# Character sets for ASCII rendering +ASCII_ALPHABETS = { + 'standard': ' .:-=+*#%@', + 'blocks': ' ░▒▓█', + 'simple': ' .:oO@', + 'digits': ' 0123456789', + 'binary': ' 01', + 'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$', +} + +# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array +_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {} + + +def _create_font_atlas(alphabet: str, char_size: int, font_name: str = None) -> np.ndarray: + """ + Create a font atlas with all characters pre-rendered. + + Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time. + + Args: + alphabet: String of characters to render (ordered by brightness, dark to light) + char_size: Size of each character cell in pixels + font_name: Optional font name/path (uses default monospace if None) + + Returns: + NumPy array of shape (num_chars, char_size, char_size, 3) with rendered characters + Each character is white on black background. + """ + cache_key = (alphabet, char_size, font_name) + if cache_key in _FONT_ATLAS_CACHE: + return _FONT_ATLAS_CACHE[cache_key] + + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError: + # Fallback: create simple block-based atlas without PIL + return _create_block_atlas(alphabet, char_size) + + num_chars = len(alphabet) + atlas = [] + + # Try to load a monospace font + font = None + font_size = int(char_size * 0.9) # Slightly smaller than cell + + # Try various monospace fonts + font_candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf', + '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf', + '/System/Library/Fonts/Menlo.ttc', # macOS + '/System/Library/Fonts/Monaco.dfont', # macOS + 'C:\\Windows\\Fonts\\consola.ttf', # Windows + ] + + for font_path in font_candidates: + if font_path is None: + continue + try: + font = ImageFont.truetype(font_path, font_size) + break + except (IOError, OSError): + continue + + if font is None: + # Use default font + try: + font = ImageFont.load_default() + except: + # Ultimate fallback to blocks + return _create_block_atlas(alphabet, char_size) + + for char in alphabet: + # Create image for this character + img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Get text bounding box for centering + try: + bbox = draw.textbbox((0, 0), char, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + except AttributeError: + # Older PIL versions + text_width, text_height = draw.textsize(char, font=font) + + # Center the character + x = (char_size - text_width) // 2 + y = (char_size - text_height) // 2 + + # Draw white character on black background + draw.text((x, y), char, fill=(255, 255, 255), font=font) + + # Convert to numpy array (NOT jax array - avoids tracer issues) + char_array = np.array(img, dtype=np.uint8) + atlas.append(char_array) + + atlas = np.stack(atlas, axis=0) + _FONT_ATLAS_CACHE[cache_key] = atlas + return atlas + + +def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray: + """ + Create a simple block-based atlas without fonts. + Uses numpy to avoid tracer issues. + """ + num_chars = len(alphabet) + atlas = [] + + for i, char in enumerate(alphabet): + # Brightness proportional to position in alphabet + brightness = int(255 * i / max(num_chars - 1, 1)) + + # Create a simple pattern based on character + img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8) + + # Add some texture/pattern for visual interest + # Checkerboard pattern for mid-range characters + if 0.2 < i / num_chars < 0.8: + y_coords, x_coords = np.mgrid[:char_size, :char_size] + checker = ((x_coords + y_coords) % 2 == 0) + variation = int(brightness * 0.2) + img = np.where(checker[:, :, None], + np.clip(img.astype(np.int16) + variation, 0, 255).astype(np.uint8), + np.clip(img.astype(np.int16) - variation, 0, 255).astype(np.uint8)) + + atlas.append(img) + + return np.stack(atlas, axis=0) + + +def _get_alphabet_string(alphabet_name: str) -> str: + """Get the character string for a named alphabet or return as-is if custom.""" + if alphabet_name in ASCII_ALPHABETS: + return ASCII_ALPHABETS[alphabet_name] + return alphabet_name # Assume it's a custom character string + + +# ============================================================================= +# Text Rendering with Font Atlas (JAX-compatible) +# ============================================================================= + +# Default character set for text rendering (printable ASCII) +TEXT_CHARSET = ''.join(chr(i) for i in range(32, 127)) # space to ~ + +# Cache for text font atlases: (font_name, font_size) -> (atlas, char_to_idx, char_width, char_height) +_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {} + + +def _create_text_atlas(font_name: str = None, font_size: int = 32) -> tuple: + """ + Create a font atlas for general text rendering with proper baseline alignment. + + Font Metrics (from typography): + - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l) + - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y) + - Baseline: the line text "sits" on - all characters align to this + - Em-square: the design space, typically = ascender + descender + + Returns: + (atlas, char_to_idx, char_widths, char_height, baseline_offset) + - atlas: numpy array (num_chars, char_height, max_char_width, 4) RGBA + - char_to_idx: dict mapping character to atlas index + - char_widths: numpy array of actual width for each character + - char_height: height of character cells (ascent + descent) + - baseline_offset: pixels from top of cell to baseline (= ascent) + """ + cache_key = (font_name, font_size) + if cache_key in _TEXT_ATLAS_CACHE: + return _TEXT_ATLAS_CACHE[cache_key] + + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError: + raise ImportError("PIL/Pillow required for text rendering") + + # Load font - match drawing.py's font order for consistency + font = None + font_candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', # Same order as drawing.py + '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', + '/usr/share/fonts/truetype/freefont/FreeSans.ttf', + '/System/Library/Fonts/Helvetica.ttc', + '/System/Library/Fonts/Arial.ttf', + 'C:\\Windows\\Fonts\\arial.ttf', + ] + + for font_path in font_candidates: + if font_path is None: + continue + try: + font = ImageFont.truetype(font_path, font_size) + break + except (IOError, OSError): + continue + + if font is None: + font = ImageFont.load_default() + + # Get font metrics - this is the key to proper text layout + # getmetrics() returns (ascent, descent) where: + # ascent = pixels from baseline to top of tallest character + # descent = pixels from baseline to bottom of lowest character + ascent, descent = font.getmetrics() + + # Cell dimensions based on font metrics (not per-character bounding boxes) + cell_height = ascent + descent + 2 # +2 for padding + baseline_y = ascent + 1 # Baseline position within cell (1px padding from top) + + # Find max character width + temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0)) + temp_draw = ImageDraw.Draw(temp_img) + + max_width = 0 + char_widths_dict = {} + + for char in TEXT_CHARSET: + try: + # Use getlength for horizontal advance (proper character spacing) + advance = font.getlength(char) + char_widths_dict[char] = int(advance) + max_width = max(max_width, int(advance)) + except: + char_widths_dict[char] = font_size // 2 + max_width = max(max_width, font_size // 2) + + cell_width = max_width + 2 # +2 for padding + + # Create atlas with all characters - draw same way as prim_text for pixel-perfect match + char_to_idx = {} + char_widths = [] # Advance widths + char_left_bearings = [] # Left bearing (x offset from origin to first pixel) + atlas = [] + + # Position to draw at within each tile (with margin for negative bearings) + draw_x = 5 # Margin for chars with negative left bearing + draw_y = 0 # Top of cell (PIL default without anchor) + + for i, char in enumerate(TEXT_CHARSET): + char_to_idx[char] = i + char_widths.append(char_widths_dict.get(char, cell_width // 2)) + + # Create RGBA image for this character + img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Draw same way as prim_text - at (draw_x, draw_y), no anchor + # This positions the text origin, and glyphs may extend left/right from there + draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font) + + # Get bbox to find left bearing + bbox = draw.textbbox((draw_x, draw_y), char, font=font) + left_bearing = bbox[0] - draw_x # How far left of origin the glyph extends + char_left_bearings.append(left_bearing) + + # Convert to numpy + char_array = np.array(img, dtype=np.uint8) + atlas.append(char_array) + + atlas = np.stack(atlas, axis=0) # (num_chars, char_height, cell_width, 4) + char_widths = np.array(char_widths, dtype=np.int32) + char_left_bearings = np.array(char_left_bearings, dtype=np.int32) + + # Return draw_x (origin offset within tile) so rendering knows where origin is + result = (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings) + _TEXT_ATLAS_CACHE[cache_key] = result + return result + + +def jax_text_render(frame, text: str, x: int, y: int, + font_name: str = None, font_size: int = 32, + color=(255, 255, 255), opacity: float = 1.0, + align: str = "left", valign: str = "baseline", + shadow: bool = False, shadow_color=(0, 0, 0), + shadow_offset: int = 2): + """ + Render text onto frame using font atlas (JAX-compatible). + + This is designed to be called from within a JIT-compiled function. + The font atlas is created at compile time (using numpy/PIL), + then converted to JAX array for the actual rendering. + + Typography notes: + - Baseline: The line text "sits" on. Most characters rest on this line. + - Ascender: Top of tall letters (b, d, h, k, l) - above baseline + - Descender: Bottom of letters like g, j, p, q, y - below baseline + - For normal text, use valign="baseline" and y = the baseline position + + Args: + frame: Input frame (H, W, 3) + text: Text string to render + x, y: Position reference point (affected by align/valign) + font_name: Font to use (None = default) + font_size: Font size in pixels + color: RGB tuple (0-255) + opacity: 0.0 to 1.0 + align: Horizontal alignment relative to x: + "left" - text starts at x + "center" - text centered on x + "right" - text ends at x + valign: Vertical alignment relative to y: + "baseline" - text baseline at y (default, like normal text) + "top" - top of ascenders at y + "middle" - text vertically centered on y + "bottom" - bottom of descenders at y + shadow: Whether to draw drop shadow + shadow_color: Shadow RGB color + shadow_offset: Shadow offset in pixels + + Returns: + Frame with text rendered + """ + if not text: + return frame + + h, w = frame.shape[:2] + + # Get or create font atlas (this happens at trace time, uses numpy) + atlas_np, char_to_idx, char_widths_np, char_height, baseline_offset, origin_x, left_bearings_np = _create_text_atlas(font_name, font_size) + + # Convert atlas to JAX array + atlas = jnp.asarray(atlas_np) + + # Atlas dimensions + cell_width = atlas.shape[2] + + # Convert text to character indices and compute character widths + # (at trace time, text is static so we can pre-compute) + indices_list = [] + char_x_offsets = [0] # Starting x position for each character + total_width = 0 + + for char in text: + if char in char_to_idx: + idx = char_to_idx[char] + indices_list.append(idx) + char_w = int(char_widths_np[idx]) + else: + indices_list.append(char_to_idx.get(' ', 0)) + char_w = int(char_widths_np[char_to_idx.get(' ', 0)]) + total_width += char_w + char_x_offsets.append(total_width) + + indices = jnp.array(indices_list, dtype=jnp.int32) + num_chars = len(indices_list) + + # Actual text dimensions using proportional widths + text_width = total_width + text_height = char_height + + # Adjust position for horizontal alignment + if align == "center": + x = x - text_width // 2 + elif align == "right": + x = x - text_width + + # Adjust position for vertical alignment + # baseline_offset = pixels from top of cell to baseline + if valign == "baseline": + # y specifies baseline position, so top of text cell is above it + y = y - baseline_offset + elif valign == "middle": + y = y - text_height // 2 + elif valign == "bottom": + y = y - text_height + # valign == "top" needs no adjustment (default) + + # Ensure position is integer + x, y = int(x), int(y) + + # Create text strip with proper character spacing at trace time (using numpy) + # This ensures proportional fonts render correctly + # + # The atlas stores each character drawn at (origin_x, 0) in its tile. + # To place a character at cursor position 'cx': + # - The tile's origin_x should align with cx in the strip + # - So we blit tile to strip starting at (cx - origin_x) + # + # Add padding for characters with negative left bearings + strip_padding = origin_x # Extra space at start for negative bearings + text_strip_np = np.zeros((char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8) + + for i, char in enumerate(text): + if char in char_to_idx: + idx = char_to_idx[char] + char_tile = atlas_np[idx] # (char_height, cell_width, 4) + cx = char_x_offsets[i] + # Position tile so its origin aligns with cursor position + strip_x = strip_padding + cx - origin_x + if strip_x >= 0: + end_x = min(strip_x + cell_width, text_strip_np.shape[1]) + tile_end = end_x - strip_x + text_strip_np[:, strip_x:end_x] = np.maximum( + text_strip_np[:, strip_x:end_x], char_tile[:, :tile_end]) + + # Trim the strip: + # - Left side: trim to first visible pixel (handles negative left bearing) + # - Right side: use computed text_width (preserve advance width spacing) + alpha = text_strip_np[:, :, 3] + cols_with_content = np.any(alpha > 0, axis=0) + if cols_with_content.any(): + first_col = np.argmax(cols_with_content) + # Right edge: use the computed text width from the strip's logical end + right_col = strip_padding + text_width + # Adjust x to account for the left trim offset + x = x + first_col - strip_padding + text_strip_np = text_strip_np[:, first_col:right_col] + else: + # No visible content, return original frame + return frame + + # Convert to JAX + text_strip = jnp.asarray(text_strip_np) + + # Convert color to array + color = jnp.array(color, dtype=jnp.float32) + shadow_color = jnp.array(shadow_color, dtype=jnp.float32) + + # Apply color tint to text strip (white text * color) + text_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * color + text_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity + + # Start with frame as float + result = frame.astype(jnp.float32) + + # Draw shadow first if enabled + if shadow: + sx, sy = x + shadow_offset, y + shadow_offset + shadow_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * shadow_color + shadow_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity * 0.5 + result = _composite_text_strip(result, shadow_rgb, shadow_alpha, sx, sy) + + # Draw main text + result = _composite_text_strip(result, text_rgb, text_alpha, x, y) + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def _composite_text_strip(frame, text_rgb, text_alpha, x, y): + """ + Composite text strip onto frame at position (x, y). + + Uses alpha blending: result = text * alpha + frame * (1 - alpha) + + This is designed to work within JAX tracing. + """ + h, w = frame.shape[:2] + th, tw = text_rgb.shape[:2] + + # Clamp to frame bounds + # Source region (in text strip) + src_x1 = jnp.maximum(0, -x) + src_y1 = jnp.maximum(0, -y) + src_x2 = jnp.minimum(tw, w - x) + src_y2 = jnp.minimum(th, h - y) + + # Destination region (in frame) + dst_x1 = jnp.maximum(0, x) + dst_y1 = jnp.maximum(0, y) + dst_x2 = jnp.minimum(w, x + tw) + dst_y2 = jnp.minimum(h, y + th) + + # Check if there's anything to draw + # (We need to handle this carefully for JAX - can't use Python if with traced values) + # Instead, we'll do the full operation but the slicing will handle bounds + + # Create coordinate grids for the destination region + # We'll use dynamic_slice for JAX-compatible slicing + + # For simplicity and JAX compatibility, we'll create a full-frame text layer + # and composite it - this is less efficient but works with JIT + + # Create full-frame RGBA layer + text_layer_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32) + text_layer_alpha = jnp.zeros((h, w), dtype=jnp.float32) + + # Place text strip in the layer using dynamic_update_slice + # First pad the text strip to handle out-of-bounds + padded_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32) + padded_alpha = jnp.zeros((h, w), dtype=jnp.float32) + + # Calculate valid region + y_start = int(max(0, y)) + y_end = int(min(h, y + th)) + x_start = int(max(0, x)) + x_end = int(min(w, x + tw)) + + src_y_start = int(max(0, -y)) + src_y_end = src_y_start + (y_end - y_start) + src_x_start = int(max(0, -x)) + src_x_end = src_x_start + (x_end - x_start) + + # Only proceed if there's a valid region + if y_end > y_start and x_end > x_start and src_y_end > src_y_start and src_x_end > src_x_start: + # Extract the valid portion of text + valid_rgb = text_rgb[src_y_start:src_y_end, src_x_start:src_x_end] + valid_alpha = text_alpha[src_y_start:src_y_end, src_x_start:src_x_end] + + # Use lax.dynamic_update_slice for JAX compatibility + padded_rgb = lax.dynamic_update_slice(padded_rgb, valid_rgb, (y_start, x_start, 0)) + padded_alpha = lax.dynamic_update_slice(padded_alpha, valid_alpha, (y_start, x_start)) + + # Alpha composite: result = text * alpha + frame * (1 - alpha) + alpha_3d = padded_alpha[:, :, jnp.newaxis] + result = padded_rgb * alpha_3d + frame * (1.0 - alpha_3d) + + return result + + +def jax_text_size(text: str, font_name: str = None, font_size: int = 32) -> tuple: + """ + Measure text dimensions (width, height). + + This can be called at compile time to get text dimensions for layout. + + Returns: + (width, height) tuple in pixels + """ + _, char_to_idx, char_widths, char_height, _, _, _ = _create_text_atlas(font_name, font_size) + + # Sum actual character widths + total_width = 0 + for c in text: + if c in char_to_idx: + total_width += int(char_widths[char_to_idx[c]]) + else: + total_width += int(char_widths[char_to_idx.get(' ', 0)]) + + return (total_width, char_height) + + +def jax_font_metrics(font_name: str = None, font_size: int = 32) -> dict: + """ + Get font metrics for layout calculations. + + Typography terms: + - ascent: pixels from baseline to top of tallest glyph (b, d, h, etc.) + - descent: pixels from baseline to bottom of lowest glyph (g, j, p, etc.) + - height: total height = ascent + descent (plus padding) + - baseline: position of baseline from top of text cell + + Returns: + dict with keys: ascent, descent, height, baseline + """ + _, _, _, char_height, baseline_offset, _, _ = _create_text_atlas(font_name, font_size) + + # baseline_offset is pixels from top to baseline (= ascent + padding) + # descent = height - baseline (approximately) + ascent = baseline_offset - 1 # remove padding + descent = char_height - baseline_offset - 1 # remove padding + + return { + 'ascent': ascent, + 'descent': descent, + 'height': char_height, + 'baseline': baseline_offset, + } + + +def jax_fit_text_size(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, max_size: int = 200) -> int: + """ + Calculate font size to fit text within bounds. + + Binary search for largest size that fits. + """ + best_size = min_size + low, high = min_size, max_size + + while low <= high: + mid = (low + high) // 2 + w, h = jax_text_size(text, font_name, mid) + + if w <= max_width and h <= max_height: + best_size = mid + low = mid + 1 + else: + high = mid - 1 + + return best_size + + +# ============================================================================= +# JAX Primitives - True primitives that can't be derived +# ============================================================================= + +def jax_width(frame): + """Frame width.""" + return frame.shape[1] + + +def jax_height(frame): + """Frame height.""" + return frame.shape[0] + + +def jax_channel(frame, idx): + """Extract channel by index as flat array.""" + # idx must be a static int for indexing + return frame[:, :, int(idx)].flatten().astype(jnp.float32) + + +def jax_merge_channels(r, g, b, shape): + """Merge RGB channels back to frame.""" + h, w = shape + r_img = jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8) + g_img = jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8) + b_img = jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([r_img, g_img, b_img], axis=2) + + +def jax_iota(n): + """Generate [0, 1, 2, ..., n-1].""" + return jnp.arange(n, dtype=jnp.float32) + + +def jax_repeat(x, n): + """Repeat each element n times: [a,b] -> [a,a,b,b].""" + return jnp.repeat(x, n) + + +def jax_tile(x, n): + """Tile array n times: [a,b] -> [a,b,a,b].""" + return jnp.tile(x, n) + + +def jax_gather(data, indices): + """Parallel index lookup.""" + flat_data = data.flatten() + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, len(flat_data) - 1) + return flat_data[idx_clipped] + + +def jax_scatter(indices, values, size): + """Parallel index write (last write wins).""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].set(values) + + +def jax_scatter_add(indices, values, size): + """Parallel index accumulate.""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].add(values) + + +def jax_group_reduce(values, group_indices, num_groups, op='mean'): + """Reduce values by group.""" + grp = group_indices.astype(jnp.int32) + + if op == 'sum': + result = jnp.zeros(num_groups, dtype=jnp.float32) + return result.at[grp].add(values) + elif op == 'mean': + sums = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values) + counts = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + elif op == 'max': + result = jnp.full(num_groups, -jnp.inf, dtype=jnp.float32) + result = result.at[grp].max(values) + return jnp.where(result == -jnp.inf, 0.0, result) + elif op == 'min': + result = jnp.full(num_groups, jnp.inf, dtype=jnp.float32) + result = result.at[grp].min(values) + return jnp.where(result == jnp.inf, 0.0, result) + else: + raise ValueError(f"Unknown reduce op: {op}") + + +def jax_where(cond, true_val, false_val): + """Conditional select.""" + return jnp.where(cond, true_val, false_val) + + +def jax_cell_indices(frame, cell_size): + """Compute cell index for each pixel.""" + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + + cell_row = y_coords // cell_size + cell_col = x_coords // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + return jnp.clip(cell_idx, 0, rows * cols - 1).astype(jnp.float32) + + +def jax_pool_frame(frame, cell_size): + """ + Pool frame to cell values. + Returns tuple: (cell_r, cell_g, cell_b, cell_lum) + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + cell_row = jnp.clip(y_coords // cs, 0, rows - 1) + cell_col = jnp.clip(x_coords // cs, 0, cols - 1) + cell_idx = (cell_row * cols + cell_col).astype(jnp.int32) + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(jnp.float32) + g_flat = frame[:, :, 1].flatten().astype(jnp.float32) + b_flat = frame[:, :, 2].flatten().astype(jnp.float32) + + # Pool each channel (mean) + def pool_channel(data): + sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data) + counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + return (r_pooled, g_pooled, b_pooled, lum) + + +# ============================================================================= +# Scan (Prefix Operations) - JAX implementations +# ============================================================================= + +def jax_scan_add(x, axis=None): + """Cumulative sum (prefix sum).""" + if axis is not None: + return jnp.cumsum(x, axis=int(axis)) + return jnp.cumsum(x.flatten()) + + +def jax_scan_mul(x, axis=None): + """Cumulative product.""" + if axis is not None: + return jnp.cumprod(x, axis=int(axis)) + return jnp.cumprod(x.flatten()) + + +def jax_scan_max(x, axis=None): + """Cumulative maximum.""" + if axis is not None: + return lax.cummax(x, axis=int(axis)) + return lax.cummax(x.flatten(), axis=0) + + +def jax_scan_min(x, axis=None): + """Cumulative minimum.""" + if axis is not None: + return lax.cummin(x, axis=int(axis)) + return lax.cummin(x.flatten(), axis=0) + + +# ============================================================================= +# Outer Product - JAX implementations +# ============================================================================= + +def jax_outer(x, y, op='*'): + """Outer product with configurable operation.""" + x_flat = x.flatten() + y_flat = y.flatten() + + ops = { + '*': lambda a, b: jnp.outer(a, b), + '+': lambda a, b: a[:, None] + b[None, :], + '-': lambda a, b: a[:, None] - b[None, :], + '/': lambda a, b: a[:, None] / b[None, :], + 'max': lambda a, b: jnp.maximum(a[:, None], b[None, :]), + 'min': lambda a, b: jnp.minimum(a[:, None], b[None, :]), + } + + op_fn = ops.get(op, ops['*']) + return op_fn(x_flat, y_flat) + + +def jax_outer_add(x, y): + """Outer sum.""" + return jax_outer(x, y, '+') + + +def jax_outer_mul(x, y): + """Outer product.""" + return jax_outer(x, y, '*') + + +def jax_outer_max(x, y): + """Outer max.""" + return jax_outer(x, y, 'max') + + +def jax_outer_min(x, y): + """Outer min.""" + return jax_outer(x, y, 'min') + + +# ============================================================================= +# Reduce with Axis - JAX implementations +# ============================================================================= + +def jax_reduce_axis(x, op='sum', axis=0): + """Reduce along an axis.""" + axis = int(axis) + ops = { + 'sum': lambda d: jnp.sum(d, axis=axis), + '+': lambda d: jnp.sum(d, axis=axis), + 'mean': lambda d: jnp.mean(d, axis=axis), + 'max': lambda d: jnp.max(d, axis=axis), + 'min': lambda d: jnp.min(d, axis=axis), + 'prod': lambda d: jnp.prod(d, axis=axis), + '*': lambda d: jnp.prod(d, axis=axis), + 'std': lambda d: jnp.std(d, axis=axis), + } + op_fn = ops.get(op, ops['sum']) + return op_fn(x) + + +def jax_sum_axis(x, axis=0): + """Sum along axis.""" + return jnp.sum(x, axis=int(axis)) + + +def jax_mean_axis(x, axis=0): + """Mean along axis.""" + return jnp.mean(x, axis=int(axis)) + + +def jax_max_axis(x, axis=0): + """Max along axis.""" + return jnp.max(x, axis=int(axis)) + + +def jax_min_axis(x, axis=0): + """Min along axis.""" + return jnp.min(x, axis=int(axis)) + + +# ============================================================================= +# Windowed Operations - JAX implementations +# ============================================================================= + +def jax_window(x, size, op='mean', stride=1): + """ + Sliding window operation. + + For 1D arrays: standard sliding window + For 2D arrays: 2D sliding window (size x size) + """ + size = int(size) + stride = int(stride) + + if x.ndim == 1: + # 1D sliding window using convolution trick + n = len(x) + if op == 'sum': + kernel = jnp.ones(size) + return jnp.convolve(x, kernel, mode='valid')[::stride] + elif op == 'mean': + kernel = jnp.ones(size) / size + return jnp.convolve(x, kernel, mode='valid')[::stride] + else: + # For max/min, use manual approach + out_n = (n - size) // stride + 1 + indices = jnp.arange(out_n) * stride + windows = jax.vmap(lambda i: lax.dynamic_slice(x, (i,), (size,)))(indices) + if op == 'max': + return jnp.max(windows, axis=1) + elif op == 'min': + return jnp.min(windows, axis=1) + else: + return jnp.mean(windows, axis=1) + else: + # 2D sliding window + h, w = x.shape[:2] + out_h = (h - size) // stride + 1 + out_w = (w - size) // stride + 1 + + # Extract all windows using vmap + def extract_window(ij): + i, j = ij // out_w, ij % out_w + return lax.dynamic_slice(x, (i * stride, j * stride), (size, size)) + + indices = jnp.arange(out_h * out_w) + windows = jax.vmap(extract_window)(indices) + + if op == 'sum': + result = jnp.sum(windows, axis=(1, 2)) + elif op == 'mean': + result = jnp.mean(windows, axis=(1, 2)) + elif op == 'max': + result = jnp.max(windows, axis=(1, 2)) + elif op == 'min': + result = jnp.min(windows, axis=(1, 2)) + else: + result = jnp.mean(windows, axis=(1, 2)) + + return result.reshape(out_h, out_w) + + +def jax_window_sum(x, size, stride=1): + """Sliding window sum.""" + return jax_window(x, size, 'sum', stride) + + +def jax_window_mean(x, size, stride=1): + """Sliding window mean.""" + return jax_window(x, size, 'mean', stride) + + +def jax_window_max(x, size, stride=1): + """Sliding window max.""" + return jax_window(x, size, 'max', stride) + + +def jax_window_min(x, size, stride=1): + """Sliding window min.""" + return jax_window(x, size, 'min', stride) + + +def jax_integral_image(frame): + """ + Compute integral image (summed area table). + Enables O(1) box blur at any radius. + """ + if frame.ndim == 3: + # Convert to grayscale + gray = jnp.mean(frame.astype(jnp.float32), axis=2) + else: + gray = frame.astype(jnp.float32) + + # Cumsum along both axes + return jnp.cumsum(jnp.cumsum(gray, axis=0), axis=1) + + +def jax_sample(frame, x, y): + """Bilinear sample at (x, y) coordinates. + + Matches OpenCV cv2.remap with INTER_LINEAR and BORDER_CONSTANT (default): + out-of-bounds samples return 0, then bilinear blend includes those zeros. + """ + h, w = frame.shape[:2] + + # Get integer coords for the 4 sample points + x0 = jnp.floor(x).astype(jnp.int32) + y0 = jnp.floor(y).astype(jnp.int32) + x1 = x0 + 1 + y1 = y0 + 1 + + fx = x - x0.astype(jnp.float32) + fy = y - y0.astype(jnp.float32) + + # Check which sample points are in bounds + valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) + valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) + valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) + valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) + + # Clamp indices for safe array access (values will be masked anyway) + x0_safe = jnp.clip(x0, 0, w - 1) + x1_safe = jnp.clip(x1, 0, w - 1) + y0_safe = jnp.clip(y0, 0, h - 1) + y1_safe = jnp.clip(y1, 0, h - 1) + + # Bilinear interpolation for each channel + def interp_channel(c): + # Sample with 0 for out-of-bounds (BORDER_CONSTANT) + c00 = jnp.where(valid00, frame[y0_safe, x0_safe, c].astype(jnp.float32), 0.0) + c10 = jnp.where(valid10, frame[y0_safe, x1_safe, c].astype(jnp.float32), 0.0) + c01 = jnp.where(valid01, frame[y1_safe, x0_safe, c].astype(jnp.float32), 0.0) + c11 = jnp.where(valid11, frame[y1_safe, x1_safe, c].astype(jnp.float32), 0.0) + + return (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + r = interp_channel(0) + g = interp_channel(1) + b = interp_channel(2) + + return r, g, b + + +# ============================================================================= +# Convolution Operations +# ============================================================================= + +def jax_convolve2d(data, kernel): + """2D convolution on a single channel.""" + # data shape: (H, W), kernel shape: (kH, kW) + # Use JAX's conv with appropriate padding + h, w = data.shape + kh, kw = kernel.shape + + # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c) + data_4d = data.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(kh, kw, 1, 1) + + # Convolve with 'SAME' padding + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + + return result.reshape(h, w) + + +def jax_blur(frame, radius=1): + """Gaussian blur.""" + # Create gaussian kernel + size = int(radius) * 2 + 1 + x = jnp.arange(size) - radius + gaussian_1d = jnp.exp(-x**2 / (2 * (radius/2)**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_sharpen(frame, amount=1.0): + """Sharpen using unsharp mask.""" + kernel = jnp.array([ + [0, -1, 0], + [-1, 5, -1], + [0, -1, 0] + ], dtype=jnp.float32) + + # Adjust kernel based on amount + center = 4 * amount + 1 + kernel = kernel.at[1, 1].set(center) + kernel = kernel * amount + jnp.array([[0,0,0],[0,1,0],[0,0,0]]) * (1 - amount) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_edge_detect(frame): + """Sobel edge detection.""" + # Sobel kernels + sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) + sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) + + # Convert to grayscale first + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + gx = jax_convolve2d(gray, sobel_x) + gy = jax_convolve2d(gray, sobel_y) + + edges = jnp.sqrt(gx**2 + gy**2) + edges = jnp.clip(edges, 0, 255).astype(jnp.uint8) + + return jnp.stack([edges, edges, edges], axis=2) + + +def jax_emboss(frame): + """Emboss effect.""" + kernel = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32) + + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + embossed = jax_convolve2d(gray, kernel) + 128 + embossed = jnp.clip(embossed, 0, 255).astype(jnp.uint8) + + return jnp.stack([embossed, embossed, embossed], axis=2) + + +# ============================================================================= +# Color Space Conversion +# ============================================================================= + +def jax_rgb_to_hsv(r, g, b): + """Convert RGB to HSV. All inputs/outputs are 0-255 range.""" + r, g, b = r / 255.0, g / 255.0, b / 255.0 + + max_c = jnp.maximum(jnp.maximum(r, g), b) + min_c = jnp.minimum(jnp.minimum(r, g), b) + diff = max_c - min_c + + # Value + v = max_c + + # Saturation + s = jnp.where(max_c > 0, diff / max_c, 0.0) + + # Hue + h = jnp.where(diff == 0, 0.0, + jnp.where(max_c == r, (60 * ((g - b) / diff) + 360) % 360, + jnp.where(max_c == g, 60 * ((b - r) / diff) + 120, + 60 * ((r - g) / diff) + 240))) + + return h, s * 255, v * 255 + + +def jax_hsv_to_rgb(h, s, v): + """Convert HSV to RGB. H is 0-360, S and V are 0-255.""" + h = h % 360 + s, v = s / 255.0, v / 255.0 + + c = v * s + x = c * (1 - jnp.abs((h / 60) % 2 - 1)) + m = v - c + + h_sector = (h / 60).astype(jnp.int32) % 6 + + r = jnp.where(h_sector == 0, c, + jnp.where(h_sector == 1, x, + jnp.where(h_sector == 2, 0, + jnp.where(h_sector == 3, 0, + jnp.where(h_sector == 4, x, c))))) + + g = jnp.where(h_sector == 0, x, + jnp.where(h_sector == 1, c, + jnp.where(h_sector == 2, c, + jnp.where(h_sector == 3, x, + jnp.where(h_sector == 4, 0, 0))))) + + b = jnp.where(h_sector == 0, 0, + jnp.where(h_sector == 1, 0, + jnp.where(h_sector == 2, x, + jnp.where(h_sector == 3, c, + jnp.where(h_sector == 4, c, x))))) + + return (r + m) * 255, (g + m) * 255, (b + m) * 255 + + +def jax_adjust_saturation(frame, factor): + """Adjust saturation by factor (1.0 = unchanged).""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + s = jnp.clip(s * factor, 0, 255) + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +def jax_shift_hue(frame, degrees): + """Shift hue by degrees.""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + h = (h + degrees) % 360 + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Color Adjustment Operations +# ============================================================================= + +def jax_adjust_brightness(frame, amount): + """Adjust brightness by amount (-255 to 255).""" + result = frame.astype(jnp.float32) + amount + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_adjust_contrast(frame, factor): + """Adjust contrast by factor (1.0 = unchanged).""" + result = (frame.astype(jnp.float32) - 128) * factor + 128 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_invert(frame): + """Invert colors.""" + return 255 - frame + + +def jax_posterize(frame, levels): + """Reduce to N color levels per channel.""" + levels = int(levels) + if levels < 2: + levels = 2 + step = 255.0 / (levels - 1) + result = jnp.round(frame.astype(jnp.float32) / step) * step + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_threshold(frame, level, invert=False): + """Binary threshold.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + if invert: + binary = jnp.where(gray < level, 255, 0).astype(jnp.uint8) + else: + binary = jnp.where(gray >= level, 255, 0).astype(jnp.uint8) + + return jnp.stack([binary, binary, binary], axis=2) + + +def jax_sepia(frame): + """Apply sepia tone.""" + r = frame[:, :, 0].astype(jnp.float32) + g = frame[:, :, 1].astype(jnp.float32) + b = frame[:, :, 2].astype(jnp.float32) + + new_r = r * 0.393 + g * 0.769 + b * 0.189 + new_g = r * 0.349 + g * 0.686 + b * 0.168 + new_b = r * 0.272 + g * 0.534 + b * 0.131 + + return jnp.stack([ + jnp.clip(new_r, 0, 255).astype(jnp.uint8), + jnp.clip(new_g, 0, 255).astype(jnp.uint8), + jnp.clip(new_b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_grayscale(frame): + """Convert to grayscale.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + gray = gray.astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + +# ============================================================================= +# Geometry Operations +# ============================================================================= + +def jax_flip_horizontal(frame): + """Flip horizontally.""" + return frame[:, ::-1, :] + + +def jax_flip_vertical(frame): + """Flip vertically.""" + return frame[::-1, :, :] + + +def jax_rotate(frame, angle, center_x=None, center_y=None): + """Rotate frame by angle (degrees), matching OpenCV convention. + + Positive angle = counter-clockwise rotation. + """ + h, w = frame.shape[:2] + if center_x is None: + center_x = w / 2 + if center_y is None: + center_y = h / 2 + + # Convert to radians + theta = angle * jnp.pi / 180 + cos_t, sin_t = jnp.cos(theta), jnp.sin(theta) + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # OpenCV getRotationMatrix2D gives FORWARD transform M = [[cos,sin],[-sin,cos]] + # For sampling we need INVERSE: M^-1 = [[cos,-sin],[sin,cos]] + # So: src_x = cos(θ)*(x-cx) - sin(θ)*(y-cy) + cx + # src_y = sin(θ)*(x-cx) + cos(θ)*(y-cy) + cy + x_centered = x_coords - center_x + y_centered = y_coords - center_y + + src_x = cos_t * x_centered - sin_t * y_centered + center_x + src_y = sin_t * x_centered + cos_t * y_centered + center_y + + # Sample using bilinear interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_scale(frame, scale_x, scale_y=None): + """Scale frame (zoom). Matches OpenCV behavior with black out-of-bounds.""" + if scale_y is None: + scale_y = scale_x + + h, w = frame.shape[:2] + center_x, center_y = w / 2, h / 2 + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # Scale from center (inverse mapping: dst -> src) + src_x = (x_coords - center_x) / scale_x + center_x + src_y = (y_coords - center_y) / scale_y + center_y + + # Sample using bilinear interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_resize(frame, new_width, new_height): + """Resize frame to new dimensions.""" + h, w = frame.shape[:2] + new_h, new_w = int(new_height), int(new_width) + + # Create coordinate grids for new size + y_coords = jnp.repeat(jnp.arange(new_h), new_w) + x_coords = jnp.tile(jnp.arange(new_w), new_h) + + # Map to source coordinates + src_x = x_coords * (w - 1) / (new_w - 1) + src_y = y_coords * (h - 1) / (new_h - 1) + + r, g, b = jax_sample(frame, src_x, src_y) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Blending Operations +# ============================================================================= + +def _resize_to_match(frame1, frame2): + """Resize frame2 to match frame1's dimensions if they differ. + + Uses jax.image.resize for bilinear interpolation. + Returns frame2 resized to frame1's shape. + """ + h1, w1 = frame1.shape[:2] + h2, w2 = frame2.shape[:2] + + # If same size, return as-is + if h1 == h2 and w1 == w2: + return frame2 + + # Resize frame2 to match frame1 + # jax.image.resize expects (height, width, channels) and target shape + return jax.image.resize( + frame2.astype(jnp.float32), + (h1, w1, frame2.shape[2]), + method='bilinear' + ).astype(jnp.uint8) + + +def jax_blend(frame1, frame2, alpha): + """Blend two frames. alpha=0 -> frame1, alpha=1 -> frame2. + + Auto-resizes frame2 to match frame1 if dimensions differ. + """ + frame2 = _resize_to_match(frame1, frame2) + return (frame1.astype(jnp.float32) * (1 - alpha) + + frame2.astype(jnp.float32) * alpha).astype(jnp.uint8) + + +def jax_blend_add(frame1, frame2): + """Additive blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) + frame2.astype(jnp.float32) + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_multiply(frame1, frame2): + """Multiply blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) * frame2.astype(jnp.float32) / 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_screen(frame1, frame2): + """Screen blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = 1 - (1 - f1) * (1 - f2) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +def jax_blend_overlay(frame1, frame2): + """Overlay blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = jnp.where(f1 < 0.5, + 2 * f1 * f2, + 1 - 2 * (1 - f1) * (1 - f2)) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +# ============================================================================= +# Utility +# ============================================================================= + +def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0): + """Create a JAX random key that varies with frame and operation. + + Uses jax.random.fold_in to mix frame_num (which may be traced) into the key. + This allows JIT compilation without recompiling for each frame. + + Args: + seed: Base seed for determinism (must be concrete) + frame_num: Frame number for variation (can be traced) + op_id: Operation ID for variation (must be concrete) + + Returns: + JAX PRNGKey + """ + # Create base key from seed and op_id (both concrete) + base_key = jax.random.PRNGKey(seed + op_id * 1000003) + # Fold in frame_num (can be traced value) + return jax.random.fold_in(base_key, frame_num) + + +def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42): + """Random float in [lo, hi), varies with frame.""" + key = make_jax_key(seed, frame_num, op_id) + return lo + jax.random.uniform(key) * (hi - lo) + + +def jax_is_nil(x): + """Check if value is None/nil.""" + return x is None + + +# ============================================================================= +# S-expression to JAX Compiler +# ============================================================================= + +class JaxCompiler: + """Compiles S-expressions to JAX functions.""" + + def __init__(self): + self.env = {} # Variable bindings during compilation + self.params = {} # Effect parameters + self.primitives = {} # Loaded primitive libraries + self.derived = {} # Loaded derived functions + + def load_derived(self, path: str): + """Load derived operations from a .sexp file.""" + with open(path, 'r') as f: + code = f.read() + exprs = parse_all(code) + + # Evaluate all define expressions to populate derived functions + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + self._eval_define(expr[1:], self.derived) + + def compile_effect(self, sexp) -> Callable: + """ + Compile an effect S-expression to a JAX function. + + Supports both formats: + (effect "name" :params (...) :body ...) + (define-effect name :params (...) body) + + Args: + sexp: Parsed S-expression + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + if not isinstance(sexp, list) or len(sexp) < 2: + raise ValueError("Effect must be a list") + + head = sexp[0] + if not isinstance(head, Symbol): + raise ValueError("Effect must start with a symbol") + + form = head.name + + # Handle both 'effect' and 'define-effect' formats + if form == 'effect': + # (effect "name" :params (...) :body ...) + name = sexp[1] if len(sexp) > 1 else "unnamed" + if isinstance(name, Symbol): + name = name.name + start_idx = 2 + elif form == 'define-effect': + # (define-effect name :params (...) body) + name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1]) + start_idx = 2 + else: + raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'") + + params_spec = [] + body = None + + i = start_idx + while i < len(sexp): + item = sexp[i] + if isinstance(item, Keyword): + if item.name == 'params' and i + 1 < len(sexp): + params_spec = sexp[i + 1] + i += 2 + elif item.name == 'body' and i + 1 < len(sexp): + body = sexp[i + 1] + i += 2 + elif item.name in ('desc', 'type', 'range'): + # Skip metadata keywords + i += 2 + else: + i += 2 # Skip unknown keywords with their values + else: + # Assume it's the body if we haven't seen one + if body is None: + body = item + i += 1 + + if body is None: + raise ValueError(f"Effect '{name}' must have a body") + + # Extract parameter names, defaults, and static params (strings, bools) + param_info, static_params = self._parse_params(params_spec) + + # Capture derived functions for the closure + derived_fns = self.derived.copy() + + # Create the JAX function + def effect_fn(frame, **kwargs): + # Set up environment + h, w = frame.shape[:2] + # Get frame_num for deterministic random variation + frame_num = kwargs.get('frame_num', 0) + # Get seed from recipe config (passed via kwargs) + seed = kwargs.get('seed', 42) + env = { + 'frame': frame, + 'width': w, + 'height': h, + '_shape': (h, w), + # Time variables (default to 0, can be overridden via kwargs) + 't': kwargs.get('t', kwargs.get('_time', 0.0)), + '_time': kwargs.get('_time', kwargs.get('t', 0.0)), + 'time': kwargs.get('time', kwargs.get('t', 0.0)), + # Frame number for random key generation + 'frame_num': frame_num, + 'frame-num': frame_num, + '_frame_num': frame_num, + # Seed from recipe for deterministic random + '_seed': seed, + # Counter for unique random keys within same frame + '_rand_op_counter': 0, + # Common constants + 'pi': jnp.pi, + 'PI': jnp.pi, + } + + # Add derived functions + env.update(derived_fns) + + # Add typography primitives + bind_typography_primitives(env) + + # Add parameters with defaults + for pname, pdefault in param_info.items(): + if pname in kwargs: + env[pname] = kwargs[pname] + elif isinstance(pdefault, list): + # Unevaluated S-expression default - evaluate it + env[pname] = self._eval(pdefault, env) + else: + env[pname] = pdefault + + # Evaluate body + result = self._eval(body, env) + + # Ensure result is a frame + if isinstance(result, tuple) and len(result) == 3: + # RGB tuple - merge to frame + r, g, b = result + return jax_merge_channels(r, g, b, (h, w)) + elif result.ndim == 3: + return result + else: + # Single channel - replicate to RGB + h, w = env['_shape'] + gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # JIT compile with static args for string/bool parameters and seed + # seed must be static for PRNGKey, but frame_num can be traced via fold_in + all_static = set(static_params) | {'seed'} + return jax.jit(effect_fn, static_argnames=list(all_static)) + + def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]: + """Parse parameter specifications. + + Returns: + Tuple of (param_defaults, static_params) + - param_defaults: Dict mapping param names to default values + - static_params: Set of param names that should be static (strings, bools) + """ + result = {} + static_params = set() + if not isinstance(params_spec, list): + return result, static_params + + for param in params_spec: + if isinstance(param, Symbol): + result[param.name] = 0.0 + elif isinstance(param, list) and len(param) >= 1: + pname = param[0].name if isinstance(param[0], Symbol) else str(param[0]) + pdefault = 0.0 + ptype = None + + # Look for :default and :type keywords + i = 1 + while i < len(param): + if isinstance(param[i], Keyword): + kw = param[i].name + if kw == 'default' and i + 1 < len(param): + pdefault = param[i + 1] + if isinstance(pdefault, Symbol): + if pdefault.name == 'nil': + pdefault = None + elif pdefault.name == 'true': + pdefault = True + elif pdefault.name == 'false': + pdefault = False + i += 2 + elif kw == 'type' and i + 1 < len(param): + ptype = param[i + 1] + if isinstance(ptype, Symbol): + ptype = ptype.name + i += 2 + else: + i += 1 + else: + i += 1 + + result[pname] = pdefault + + # Mark string and bool parameters as static (can't be traced by JAX) + if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)): + static_params.add(pname) + + return result, static_params + + def _eval(self, expr, env: Dict[str, Any]) -> Any: + """Evaluate an S-expression in the given environment.""" + + # Already-evaluated values (e.g., from threading macros) + # JAX arrays, NumPy arrays, tuples, etc. + if hasattr(expr, 'shape'): # JAX/NumPy array + return expr + if isinstance(expr, tuple): # e.g., (r, g, b) from rgb + return expr + + # Literals - keep as Python numbers for static operations + if isinstance(expr, (int, float)): + return expr + + if isinstance(expr, str): + return expr + + # Symbols - variable lookup + if isinstance(expr, Symbol): + name = expr.name + if name in env: + return env[name] + if name == 'nil': + return None + if name == 'true': + return True + if name == 'false': + return False + raise NameError(f"Unknown symbol: {name}") + + # Lists - function calls + if isinstance(expr, list) and len(expr) > 0: + head = expr[0] + + if isinstance(head, Symbol): + op = head.name + args = expr[1:] + + # Special forms + if op == 'let' or op == 'let*': + return self._eval_let(args, env) + if op == 'if': + return self._eval_if(args, env) + if op == 'lambda' or op == 'λ': + return self._eval_lambda(args, env) + if op == 'define': + return self._eval_define(args, env) + + # Built-in operations + return self._eval_call(op, args, env) + + # Empty list + if isinstance(expr, list) and len(expr) == 0: + return [] + + raise ValueError(f"Cannot evaluate: {expr}") + + def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]): + """Extract a keyword argument from args list. + + Looks for :key value pattern in args and evaluates the value. + Returns default if not found. + """ + i = 0 + while i < len(args): + if isinstance(args[i], Keyword) and args[i].name == key: + if i + 1 < len(args): + val = self._eval(args[i + 1], env) + # Handle Symbol values (e.g., :op 'sum -> 'sum') + if isinstance(val, Symbol): + return val.name + return val + return default + i += 1 + return default + + def _eval_let(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body).""" + if len(args) < 2: + raise ValueError("let requires bindings and body") + + bindings = args[0] + body = args[1] + + new_env = env.copy() + + # Handle both ((var val) ...) and [var val var2 val2 ...] syntax + if isinstance(bindings, list): + # Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...) + if bindings and isinstance(bindings[0], Symbol): + # Flat list: [var val var2 val2 ...] + i = 0 + while i < len(bindings) - 1: + var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[var] = val + i += 2 + else: + # Nested list: ((var val) (var2 val2) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[var] = val + + return self._eval(body, new_env) + + def _eval_if(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + if len(args) < 2: + raise ValueError("if requires condition and then-branch") + + cond = self._eval(args[0], env) + + # Handle None as falsy (important for optional params like overlay) + if cond is None: + return self._eval(args[2], env) if len(args) > 2 else None + + # For Python scalar bools, use normal Python if + # This allows side effects and None values + if isinstance(cond, bool): + if cond: + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + + # For NumPy/JAX scalar bools with concrete values + if hasattr(cond, 'item') and cond.shape == (): + try: + if bool(cond.item()): + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + except: + pass # Fall through to jnp.where for traced values + + # For traced values, evaluate both branches and use jnp.where + then_val = self._eval(args[1], env) + else_val = self._eval(args[2], env) if len(args) > 2 else 0.0 + + # Handle None by converting to zeros + if then_val is None: + then_val = 0.0 + if else_val is None: + else_val = 0.0 + + # Convert lists to tuples + if isinstance(then_val, list): + then_val = tuple(then_val) + if isinstance(else_val, list): + else_val = tuple(else_val) + + # Handle tuple results (e.g., from rgb in map-pixels) + if isinstance(then_val, tuple) and isinstance(else_val, tuple): + return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val)) + + return jnp.where(cond, then_val, else_val) + + def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable: + """Evaluate (lambda (params) body).""" + if len(args) < 2: + raise ValueError("lambda requires parameters and body") + + params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + return fn + + def _eval_define(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (define name value) or (define (name params) body).""" + if len(args) < 2: + raise ValueError("define requires name and value") + + name_part = args[0] + + if isinstance(name_part, list): + # Function definition: (define (name params) body) + fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + env[fn_name] = fn + return fn + else: + # Variable definition + var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part) + val = self._eval(args[1], env) + env[var_name] = val + return val + + def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any: + """Evaluate a function call.""" + + # Check if it's a user-defined function + if op in env and callable(env[op]): + fn = env[op] + eval_args = [self._eval(a, env) for a in args] + return fn(*eval_args) + + # Arithmetic + if op == '+': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 0.0 + for v in vals[1:]: + result = result + v + return result + + if op == '-': + if len(args) == 1: + return -self._eval(args[0], env) + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result - v + return result + + if op == '*': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 1.0 + for v in vals[1:]: + result = result * v + return result + + if op == '/': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result / v + return result + + if op == 'mod': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a % b + + if op == 'pow' or op == '**': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return jnp.power(a, b) + + # Comparison + if op == '<': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a < b + if op == '>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a > b + if op == '<=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a <= b + if op == '>=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a >= b + if op == '=' or op == '==': + a, b = self._eval(args[0], env), self._eval(args[1], env) + # For scalar Python types, return Python bool to enable trace-time if + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a == b) + return a == b + if op == '!=' or op == '<>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a != b) + return a != b + + # Logic + if op == 'and': + vals = [self._eval(a, env) for a in args] + # Use Python and for concrete Python bools (e.g., shape comparisons) + if all(isinstance(v, (bool, np.bool_)) for v in vals): + result = True + for v in vals: + result = result and bool(v) + return result + # Otherwise use JAX logical_and + result = vals[0] + for v in vals[1:]: + result = jnp.logical_and(result, v) + return result + + if op == 'or': + # Lisp-style or: returns first truthy value, not boolean + # (or a b c) returns a if a is truthy, else b if b is truthy, else c + for arg in args: + val = self._eval(arg, env) + # Check if value is truthy + if val is None: + continue + if isinstance(val, (bool, np.bool_)): + if val: + return val + continue + if isinstance(val, (int, float)): + if val: + return val + continue + if hasattr(val, 'shape'): + # JAX/numpy array - return it (considered truthy) + return val + # For other types, check truthiness + if val: + return val + # All values were falsy, return the last one + return self._eval(args[-1], env) if args else None + + if op == 'not': + val = self._eval(args[0], env) + if isinstance(val, (bool, np.bool_)): + return not bool(val) + return jnp.logical_not(val) + + # Math functions + if op == 'sqrt': + return jnp.sqrt(self._eval(args[0], env)) + if op == 'sin': + return jnp.sin(self._eval(args[0], env)) + if op == 'cos': + return jnp.cos(self._eval(args[0], env)) + if op == 'tan': + return jnp.tan(self._eval(args[0], env)) + if op == 'exp': + return jnp.exp(self._eval(args[0], env)) + if op == 'log': + return jnp.log(self._eval(args[0], env)) + if op == 'abs': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return abs(x) + return jnp.abs(x) + if op == 'floor': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.floor(x) + return jnp.floor(x) + if op == 'ceil': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.ceil(x) + return jnp.ceil(x) + if op == 'round': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return round(x) + return jnp.round(x) + + # Frame primitives + if op == 'width': + return env['width'] + if op == 'height': + return env['height'] + + if op == 'channel': + frame = self._eval(args[0], env) + idx = self._eval(args[1], env) + # idx should be a Python int (literal from S-expression) + return jax_channel(frame, idx) + + if op == 'merge-channels' or op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + if op == 'sample': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + return jax_sample(frame, x, y) + + if op == 'cell-indices': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_cell_indices(frame, cell_size) + + if op == 'pool-frame': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_pool_frame(frame, cell_size) + + # Xector primitives + if op == 'iota': + n = self._eval(args[0], env) + return jax_iota(int(n)) + + if op == 'repeat': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_repeat(x, int(n)) + + if op == 'tile': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_tile(x, int(n)) + + if op == 'gather': + data = self._eval(args[0], env) + indices = self._eval(args[1], env) + return jax_gather(data, indices) + + if op == 'scatter': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter(indices, values, size) + + if op == 'scatter-add': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter_add(indices, values, size) + + if op == 'group-reduce': + values = self._eval(args[0], env) + groups = self._eval(args[1], env) + num_groups = int(self._eval(args[2], env)) + reduce_op = args[3] if len(args) > 3 else 'mean' + if isinstance(reduce_op, Symbol): + reduce_op = reduce_op.name + return jax_group_reduce(values, groups, num_groups, reduce_op) + + if op == 'where': + cond = self._eval(args[0], env) + true_val = self._eval(args[1], env) + false_val = self._eval(args[2], env) + # Handle None values + if true_val is None: + true_val = 0.0 + if false_val is None: + false_val = 0.0 + return jax_where(cond, true_val, false_val) + + if op == 'len' or op == 'length': + x = self._eval(args[0], env) + if isinstance(x, (list, tuple)): + return len(x) + return x.size + + # Beta reductions + if op in ('β+', 'beta+', 'sum'): + return jnp.sum(self._eval(args[0], env)) + if op in ('β*', 'beta*', 'product'): + return jnp.prod(self._eval(args[0], env)) + if op in ('βmin', 'beta-min'): + return jnp.min(self._eval(args[0], env)) + if op in ('βmax', 'beta-max'): + return jnp.max(self._eval(args[0], env)) + if op in ('βmean', 'beta-mean', 'mean'): + return jnp.mean(self._eval(args[0], env)) + if op in ('βstd', 'beta-std'): + return jnp.std(self._eval(args[0], env)) + if op in ('βany', 'beta-any'): + return jnp.any(self._eval(args[0], env)) + if op in ('βall', 'beta-all'): + return jnp.all(self._eval(args[0], env)) + + # Scan (prefix) operations + if op in ('scan+', 'scan-add'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_add(x, axis) + if op in ('scan*', 'scan-mul'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_mul(x, axis) + if op == 'scan-max': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_max(x, axis) + if op == 'scan-min': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_min(x, axis) + + # Outer product operations + if op == 'outer': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + op_type = self._eval_kwarg(args, 'op', '*', env) + return jax_outer(x, y, op_type) + if op in ('outer+', 'outer-add'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_add(x, y) + if op in ('outer*', 'outer-mul'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_mul(x, y) + if op == 'outer-max': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_max(x, y) + if op == 'outer-min': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_min(x, y) + + # Reduce with axis operations + if op == 'reduce-axis': + x = self._eval(args[0], env) + reduce_op = self._eval_kwarg(args, 'op', 'sum', env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_reduce_axis(x, reduce_op, axis) + if op == 'sum-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_sum_axis(x, axis) + if op == 'mean-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_mean_axis(x, axis) + if op == 'max-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_max_axis(x, axis) + if op == 'min-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_min_axis(x, axis) + + # Windowed operations + if op == 'window': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + win_op = self._eval_kwarg(args, 'op', 'mean', env) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window(x, size, win_op, stride) + if op == 'window-sum': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_sum(x, size, stride) + if op == 'window-mean': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_mean(x, size, stride) + if op == 'window-max': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_max(x, size, stride) + if op == 'window-min': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_min(x, size, stride) + + # Integral image + if op == 'integral-image': + frame = self._eval(args[0], env) + return jax_integral_image(frame) + + # Convenience - min/max of two values (handle both scalars and arrays) + if op == 'min' or op == 'min2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return min(a, b) + return jnp.minimum(jnp.asarray(a), jnp.asarray(b)) + if op == 'max' or op == 'max2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return max(a, b) + return jnp.maximum(jnp.asarray(a), jnp.asarray(b)) + if op == 'clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + # List operations + if op == 'list': + return tuple(self._eval(a, env) for a in args) + + if op == 'nth': + seq = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[idx] if 0 <= idx < len(seq) else None + return seq[idx] # For arrays + + if op == 'first': + seq = self._eval(args[0], env) + return seq[0] if len(seq) > 0 else None + + if op == 'second': + seq = self._eval(args[0], env) + return seq[1] if len(seq) > 1 else None + + # Random (JAX-compatible) + # Get frame_num for deterministic variation - can be traced, fold_in handles it + frame_num = env.get('_frame_num', env.get('frame_num', 0)) + # Convert to int32 for fold_in if needed (but keep as JAX array if traced) + if frame_num is None: + frame_num = 0 + elif isinstance(frame_num, (int, float)): + frame_num = int(frame_num) + # If it's a JAX array, leave it as-is for tracing + + # Increment operation counter for unique keys within same frame + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + + if op == 'rand' or op == 'rand-x': + # For size-based random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + shape = (n,) + else: + n = size.size + shape = size.shape + elif hasattr(size, 'size'): + n = size.size + shape = (n,) + else: + n = int(size) + shape = (n,) + # Use deterministic key that varies with frame + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, shape).flatten() + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, ()) + + if op == 'randn' or op == 'randn-x': + # Normal random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + else: + n = size.size + elif hasattr(size, 'size'): + n = size.size + else: + n = int(size) + mean = self._eval(args[1], env) if len(args) > 1 else 0.0 + std = self._eval(args[2], env) if len(args) > 2 else 1.0 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, (n,)) * std + mean + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) + + if op == 'rand-range' or op == 'core:rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + seed = env.get('_seed', 42) + return jax_rand_range(lo, hi, frame_num, op_counter, seed) + + # ===================================================================== + # Convolution operations + # ===================================================================== + if op == 'blur' or op == 'image:blur': + frame = self._eval(args[0], env) + radius = self._eval(args[1], env) if len(args) > 1 else 1 + # Convert traced value to concrete for kernel size + if hasattr(radius, 'item'): + radius = int(radius.item()) + elif hasattr(radius, '__float__'): + radius = int(float(radius)) + else: + radius = int(radius) + return jax_blur(frame, max(1, radius)) + + if op == 'gaussian': + first_arg = self._eval(args[0], env) + # Check if first arg is a frame (blur) or scalar (random) + if hasattr(first_arg, 'shape') and first_arg.ndim == 3: + # Frame - apply gaussian blur + sigma = self._eval(args[1], env) if len(args) > 1 else 1.0 + radius = max(1, int(sigma * 3)) + return jax_blur(first_arg, radius) + else: + # Scalar args - generate gaussian random value + mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg + std = self._eval(args[1], env) if len(args) > 1 else 1.0 + # Return a single random value + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) * std + mean + + if op == 'sharpen' or op == 'image:sharpen': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + return jax_sharpen(frame, amount) + + if op == 'edge-detect' or op == 'image:edge-detect': + frame = self._eval(args[0], env) + return jax_edge_detect(frame) + + if op == 'emboss': + frame = self._eval(args[0], env) + return jax_emboss(frame) + + if op == 'convolve': + frame = self._eval(args[0], env) + kernel = self._eval(args[1], env) + # Convert kernel to array if it's a list + if isinstance(kernel, (list, tuple)): + kernel = jnp.array(kernel, dtype=jnp.float32) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'add-noise': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 0.1 + h, w = frame.shape[:2] + # Use frame-varying key for noise + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1] + result = frame.astype(jnp.float32) + noise * amount * 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + if op == 'translate': + frame = self._eval(args[0], env) + dx = self._eval(args[1], env) + dy = self._eval(args[2], env) if len(args) > 2 else 0 + h, w = frame.shape[:2] + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + src_x = (x_coords - dx).flatten() + src_y = (y_coords - dy).flatten() + r, g, b = jax_sample(frame, src_x, src_y) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'image:crop' or op == 'crop': + frame = self._eval(args[0], env) + x = int(self._eval(args[1], env)) + y = int(self._eval(args[2], env)) + w = int(self._eval(args[3], env)) + h = int(self._eval(args[4], env)) + return frame[y:y+h, x:x+w, :] + + if op == 'dilate': + frame = self._eval(args[0], env) + size = int(self._eval(args[1], env)) if len(args) > 1 else 3 + # Simple dilation using max pooling approximation + kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'map-rows': + frame = self._eval(args[0], env) + fn = args[1] # S-expression function + h, w = frame.shape[:2] + # For each row, apply the function + results = [] + for row_idx in range(h): + row_env = env.copy() + row_env['row'] = frame[row_idx, :, :] + row_env['row-idx'] = row_idx + + # Check if fn is a lambda + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + # Bind lambda params to y and row + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + row_env[param_name] = row_idx # y + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + row_env[param_name] = frame[row_idx, :, :] # row + result_row = self._eval(body, row_env) + results.append(result_row) + continue + + result_row = self._eval(fn, row_env) + # If result is a function, call it + if callable(result_row): + result_row = result_row(row_idx, frame[row_idx, :, :]) + results.append(result_row) + return jnp.stack(results, axis=0) + + # ===================================================================== + # Text rendering operations + # ===================================================================== + if op == 'text': + frame = self._eval(args[0], env) + text_str = self._eval(args[1], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + + # Extract keyword arguments + x = self._eval_kwarg(args, 'x', None, env) + y = self._eval_kwarg(args, 'y', None, env) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + color = self._eval_kwarg(args, 'color', (255, 255, 255), env) + opacity = self._eval_kwarg(args, 'opacity', 1.0, env) + align = self._eval_kwarg(args, 'align', 'left', env) + valign = self._eval_kwarg(args, 'valign', 'top', env) + shadow = self._eval_kwarg(args, 'shadow', False, env) + shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env) + shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env) + fit = self._eval_kwarg(args, 'fit', False, env) + width = self._eval_kwarg(args, 'width', None, env) + height = self._eval_kwarg(args, 'height', None, env) + + # Handle color as list/tuple + if isinstance(color, (list, tuple)): + color = tuple(int(c) for c in color[:3]) + if isinstance(shadow_color, (list, tuple)): + shadow_color = tuple(int(c) for c in shadow_color[:3]) + + h, w_frame = frame.shape[:2] + + # Default position to 0,0 or center based on alignment + if x is None: + if align == 'center': + x = w_frame // 2 + elif align == 'right': + x = w_frame + else: + x = 0 + if y is None: + if valign == 'middle': + y = h // 2 + elif valign == 'bottom': + y = h + else: + y = 0 + + # Auto-fit text to bounds + if fit and width is not None and height is not None: + font_size = jax_fit_text_size(text_str, int(width), int(height), + font_name, min_size=8, max_size=200) + + return jax_text_render(frame, text_str, int(x), int(y), + font_name=font_name, font_size=int(font_size), + color=color, opacity=float(opacity), + align=str(align), valign=str(valign), + shadow=bool(shadow), shadow_color=shadow_color, + shadow_offset=int(shadow_offset)) + + if op == 'text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return jax_text_size(text_str, font_name, int(font_size)) + + if op == 'fit-text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + max_width = int(self._eval(args[1], env)) + max_height = int(self._eval(args[2], env)) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return jax_fit_text_size(text_str, max_width, max_height, font_name) + + # ===================================================================== + # Color operations + # ===================================================================== + if op == 'rgb->hsv' or op == 'rgb-to-hsv': + # Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple + if len(args) == 1: + c = self._eval(args[0], env) + if isinstance(c, tuple) and len(c) == 3: + r, g, b = c + else: + # Assume it's a list-like + r, g, b = c[0], c[1], c[2] + else: + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + return jax_rgb_to_hsv(r, g, b) + + if op == 'hsv->rgb' or op == 'hsv-to-rgb': + # Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list) + if len(args) == 1: + hsv = self._eval(args[0], env) + if isinstance(hsv, (tuple, list)) and len(hsv) >= 3: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h = self._eval(args[0], env) + s = self._eval(args[1], env) + v = self._eval(args[2], env) + return jax_hsv_to_rgb(h, s, v) + + if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) + return jax_adjust_brightness(frame, amount) + + if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_contrast(frame, factor) + + if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_saturation(frame, factor) + + if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift': + frame = self._eval(args[0], env) + degrees = self._eval(args[1], env) + return jax_shift_hue(frame, degrees) + + if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img': + frame = self._eval(args[0], env) + return jax_invert(frame) + + if op == 'posterize' or op == 'color_ops:posterize': + frame = self._eval(args[0], env) + levels = self._eval(args[1], env) + return jax_posterize(frame, levels) + + if op == 'threshold' or op == 'color_ops:threshold': + frame = self._eval(args[0], env) + level = self._eval(args[1], env) + invert = self._eval(args[2], env) if len(args) > 2 else False + return jax_threshold(frame, level, invert) + + if op == 'sepia' or op == 'color_ops:sepia': + frame = self._eval(args[0], env) + return jax_sepia(frame) + + if op == 'grayscale' or op == 'image:grayscale': + frame = self._eval(args[0], env) + return jax_grayscale(frame) + + # ===================================================================== + # Geometry operations + # ===================================================================== + if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img': + frame = self._eval(args[0], env) + direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal' + if direction == 'vertical' or direction == 'v': + return jax_flip_vertical(frame) + return jax_flip_horizontal(frame) + + if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v': + frame = self._eval(args[0], env) + return jax_flip_vertical(frame) + + if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img': + frame = self._eval(args[0], env) + angle = self._eval(args[1], env) + return jax_rotate(frame, angle) + + if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img': + frame = self._eval(args[0], env) + scale_x = self._eval(args[1], env) + scale_y = self._eval(args[2], env) if len(args) > 2 else None + return jax_scale(frame, scale_x, scale_y) + + if op == 'resize' or op == 'image:resize': + frame = self._eval(args[0], env) + new_w = self._eval(args[1], env) + new_h = self._eval(args[2], env) + return jax_resize(frame, new_w, new_h) + + # ===================================================================== + # Geometry distortion effects + # ===================================================================== + if op == 'geometry:fisheye-coords' or op == 'fisheye': + # Signature: (w h strength cx cy zoom_correct) or (frame strength) + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + # (w h strength cx cy zoom_correct) signature + w = int(first_arg) + h = int(self._eval(args[1], env)) + strength = self._eval(args[2], env) if len(args) > 2 else 0.5 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = first_arg + strength = self._eval(args[1], env) if len(args) > 1 else 0.5 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Fisheye distortion + r_new = r + strength * r * (1 - r / max_r) + + src_x = r_new * jnp.cos(theta) + cx + src_y = r_new * jnp.sin(theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:swirl-coords' or op == 'swirl': + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + w = int(first_arg) + h = int(self._eval(args[1], env)) + amount = self._eval(args[2], env) if len(args) > 2 else 1.0 + frame = None + else: + frame = first_arg + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + h, w = frame.shape[:2] + + cx, cy = w / 2, h / 2 + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + swirl_angle = amount * (1 - r / max_r) + new_theta = theta + swirl_angle + + src_x = r * jnp.cos(new_theta) + cx + src_y = r * jnp.sin(new_theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Wave effect (frame-first signature for simple usage) + if op == 'wave-distort': + first_arg = self._eval(args[0], env) + frame = first_arg + amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0 + amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0 + freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1 + freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1 + h, w = frame.shape[:2] + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y) + src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x) + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:ripple-displace' or op == 'ripple': + # Match Python prim_ripple_displace signature: + # (w h freq amp cx cy decay phase) or (frame ...) + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + # Coordinate-only mode: (w h freq amp cx cy decay phase) + w = int(first_arg) + h = int(self._eval(args[1], env)) + freq = self._eval(args[2], env) if len(args) > 2 else 5.0 + amp = self._eval(args[3], env) if len(args) > 3 else 10.0 + cx = self._eval(args[4], env) if len(args) > 4 else w / 2 + cy = self._eval(args[5], env) if len(args) > 5 else h / 2 + decay = self._eval(args[6], env) if len(args) > 6 else 0.0 + phase = self._eval(args[7], env) if len(args) > 7 else 0.0 + frame = None + else: + # Frame mode: (frame :amplitude A :frequency F :center_x CX ...) + frame = first_arg + h, w = frame.shape[:2] + # Parse keyword args + amp = 10.0 + freq = 5.0 + cx = w / 2 + cy = h / 2 + decay = 0.0 + phase = 0.0 + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + kw = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + if kw == 'amplitude': + amp = val + elif kw == 'frequency': + freq = val + elif kw == 'center_x': + cx = val * w if val <= 1 else val # normalized or absolute + elif kw == 'center_y': + cy = val * h if val <= 1 else val + elif kw == 'decay': + decay = val + elif kw == 'speed': + # speed affects phase via time + t = env.get('t', 0) + phase = t * val * 2 * jnp.pi + elif kw == 'phase': + phase = val + i += 2 + else: + i += 1 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + dist = jnp.sqrt(dx*dx + dy*dy) + + # Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp + max_dim = jnp.maximum(w, h) + ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp + + # Apply decay (when decay=0, exp(0)=1 so no effect) + decay_factor = jnp.exp(-decay * dist / max_dim) + ripple = ripple * decay_factor + + # Radial displacement - use ADDITION to match Python prim_ripple_displace + # Python (primitives.py line 2890-2891): + # map_x = x_coords + ripple * norm_dx + # map_y = y_coords + ripple * norm_dy + # where norm_dx = dx/dist = cos(angle), norm_dy = dy/dist = sin(angle) + angle = jnp.arctan2(dy, dx) + src_x = x_coords + ripple * jnp.cos(angle) + src_y = y_coords + ripple * jnp.sin(angle) + + if frame is None: + return {'x': src_x, 'y': src_y} + + # Sample using bilinear interpolation (jax_sample clamps coords, + # matching OpenCV's default remap behavior) + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:coords-x' or op == 'coords-x': + # Extract x coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('x', coords.get('map_x')) + return coords[0] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:coords-y' or op == 'coords-y': + # Extract y coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('y', coords.get('map_y')) + return coords[1] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:remap' or op == 'remap': + # Remap image using coordinate maps: (frame map_x map_y) + # OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords + frame = self._eval(args[0], env) + map_x = self._eval(args[1], env) + map_y = self._eval(args[2], env) + + h, w = frame.shape[:2] + + # Flatten coordinate maps + src_x = map_x.flatten() + src_y = map_y.flatten() + + # Sample using bilinear interpolation (jax_sample clamps coords internally, + # matching OpenCV's default behavior) + r_out, g_out, b_out = jax_sample(frame, src_x, src_y) + + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope': + # Two signatures: (frame segments) or (w h segments cx cy) + if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'): + # (w h segments cx cy) signature + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + segments = int(self._eval(args[2], env)) if len(args) > 2 else 6 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = self._eval(args[0], env) + segments = int(self._eval(args[1], env)) if len(args) > 1 else 6 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Mirror into segments + segment_angle = 2 * jnp.pi / segments + theta_mod = theta % segment_angle + theta_mirror = jnp.where( + (jnp.floor(theta / segment_angle) % 2) == 0, + theta_mod, + segment_angle - theta_mod + ) + + src_x = r * jnp.cos(theta_mirror) + cx + src_y = r * jnp.sin(theta_mirror) + cy + + if frame is None: + # Return coordinate arrays + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Geometry coordinate extraction + if op == 'geometry:coords-x': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['x'] + return coords[0] if isinstance(coords, tuple) else coords + + if op == 'geometry:coords-y': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['y'] + return coords[1] if isinstance(coords, tuple) else coords + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x_coords = self._eval(args[1], env) + y_coords = self._eval(args[2], env) + h, w = frame.shape[:2] + r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(frame1, frame2, alpha) + + if op == 'blend-add': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_add(frame1, frame2) + + if op == 'blend-multiply': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_multiply(frame1, frame2) + + if op == 'blend-screen': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_screen(frame1, frame2) + + if op == 'blend-overlay': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_overlay(frame1, frame2) + + # ===================================================================== + # Image dimension queries (namespaced aliases) + # ===================================================================== + if op == 'image:width': + if args: + frame = self._eval(args[0], env) + return frame.shape[1] # width is second dimension (h, w, c) + return env['width'] + + if op == 'image:height': + if args: + frame = self._eval(args[0], env) + return frame.shape[0] # height is first dimension (h, w, c) + return env['height'] + + # ===================================================================== + # Utility + # ===================================================================== + if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?': + x = self._eval(args[0], env) + return jax_is_nil(x) + + # ===================================================================== + # Xector channel operations (shortcuts) + # ===================================================================== + if op == 'red': + val = self._eval(args[0], env) + # Works on frames or pixel tuples + if isinstance(val, tuple): + return val[0] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 0) + else: + return val # Assume it's already a channel + + if op == 'green': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[1] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 1) + else: + return val + + if op == 'blue': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[2] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 2) + else: + return val + + if op == 'gray' or op == 'luminance': + val = self._eval(args[0], env) + # Handle tuple (r, g, b) from map-pixels + if isinstance(val, tuple) and len(val) == 3: + r, g, b = val + return r * 0.299 + g * 0.587 + b * 0.114 + # Handle frame + frame = val + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + return r * 0.299 + g * 0.587 + b * 0.114 + + if op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + # ===================================================================== + # Coordinate operations + # ===================================================================== + if op == 'x-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + + if op == 'y-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + if op == 'dist-from-center': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx + y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy + return jnp.sqrt(x*x + y*y) + + # ===================================================================== + # Alpha operations (element-wise on xectors) + # ===================================================================== + if op == 'α/' or op == 'alpha/': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a / b + + if op == 'α+' or op == 'alpha+': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result + v + return result + + if op == 'α*' or op == 'alpha*': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result * v + return result + + if op == 'α-' or op == 'alpha-': + if len(args) == 1: + return -self._eval(args[0], env) + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a - b + + if op == 'αclamp' or op == 'alpha-clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + if op == 'αmin' or op == 'alpha-min': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.minimum(a, b) + + if op == 'αmax' or op == 'alpha-max': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.maximum(a, b) + + if op == 'αsqrt' or op == 'alpha-sqrt': + return jnp.sqrt(self._eval(args[0], env)) + + if op == 'αsin' or op == 'alpha-sin': + return jnp.sin(self._eval(args[0], env)) + + if op == 'αcos' or op == 'alpha-cos': + return jnp.cos(self._eval(args[0], env)) + + if op == 'αmod' or op == 'alpha-mod': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a % b + + if op == 'α²' or op == 'αsq' or op == 'alpha-sq': + x = self._eval(args[0], env) + return x * x + + if op == 'α<' or op == 'alpha<': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a < b + + if op == 'α>' or op == 'alpha>': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a > b + + if op == 'α<=' or op == 'alpha<=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a <= b + + if op == 'α>=' or op == 'alpha>=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a >= b + + if op == 'α=' or op == 'alpha=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a == b + + if op == 'αfloor' or op == 'alpha-floor': + return jnp.floor(self._eval(args[0], env)) + + if op == 'αceil' or op == 'alpha-ceil': + return jnp.ceil(self._eval(args[0], env)) + + if op == 'αround' or op == 'alpha-round': + return jnp.round(self._eval(args[0], env)) + + if op == 'αabs' or op == 'alpha-abs': + return jnp.abs(self._eval(args[0], env)) + + if op == 'αexp' or op == 'alpha-exp': + return jnp.exp(self._eval(args[0], env)) + + if op == 'αlog' or op == 'alpha-log': + return jnp.log(self._eval(args[0], env)) + + if op == 'αor' or op == 'alpha-or': + # Element-wise logical OR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_or(a, b) + + if op == 'αand' or op == 'alpha-and': + # Element-wise logical AND + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_and(a, b) + + if op == 'αnot' or op == 'alpha-not': + # Element-wise logical NOT + return jnp.logical_not(self._eval(args[0], env)) + + if op == 'αxor' or op == 'alpha-xor': + # Element-wise logical XOR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_xor(a, b) + + # ===================================================================== + # Threading/arrow operations + # ===================================================================== + if op == '->': + # Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b) + val = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list): + # Insert val as first argument + fn_name = form[0].name if isinstance(form[0], Symbol) else form[0] + new_args = [val] + [self._eval(a, env) for a in form[1:]] + val = self._eval_call(fn_name, [val] + form[1:], env) + else: + # Simple function call + fn_name = form.name if isinstance(form, Symbol) else form + val = self._eval_call(fn_name, [args[0]], env) + return val + + # ===================================================================== + # Range and iteration + # ===================================================================== + if op == 'range': + if len(args) == 1: + end = int(self._eval(args[0], env)) + return list(range(end)) + elif len(args) == 2: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + return list(range(start, end)) + else: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + step = int(self._eval(args[2], env)) + return list(range(start, end, step)) + + if op == 'reduce' or op == 'fold': + # (reduce seq init fn) - left fold + seq = self._eval(args[0], env) + acc = self._eval(args[1], env) + fn = args[2] # Lambda S-expression + + # Handle lambda + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + for item in seq: + fn_env = env.copy() + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + fn_env[param_name] = acc + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + fn_env[param_name] = item + acc = self._eval(body, fn_env) + return acc + + # Fallback - try evaluating fn and calling it + fn_eval = self._eval(fn, env) + if callable(fn_eval): + for item in seq: + acc = fn_eval(acc, item) + return acc + + if op == 'fold-indexed': + # (fold-indexed seq init fn) - fold with index + # fn takes (acc item index) or (acc item index cursor) for typography + seq = self._eval(args[0], env) + acc = self._eval(args[1], env) + fn = args[2] # Lambda S-expression + + # Handle lambda + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + for idx, item in enumerate(seq): + fn_env = env.copy() + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + fn_env[param_name] = acc + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + fn_env[param_name] = item + if len(params) >= 3: + param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) + fn_env[param_name] = idx + acc = self._eval(body, fn_env) + return acc + + # Fallback + fn_eval = self._eval(fn, env) + if callable(fn_eval): + for idx, item in enumerate(seq): + acc = fn_eval(acc, item, idx) + return acc + + # ===================================================================== + # Map-pixels (apply function to each pixel) + # ===================================================================== + if op == 'map-pixels': + frame = self._eval(args[0], env) + fn = args[1] # Lambda or S-expression + h, w = frame.shape[:2] + + # Extract channels + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + # Set up pixel environment + pixel_env = env.copy() + pixel_env['r'] = r + pixel_env['g'] = g + pixel_env['b'] = b + pixel_env['x'] = x_coords + pixel_env['y'] = y_coords + # Also provide c (color) as a tuple for lambda (x y c) style + pixel_env['c'] = (r, g, b) + + # If fn is a lambda, we need to handle it specially + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + # Lambda: (lambda (x y c) body) + params = fn[1] + body = fn[2] + # Bind parameters + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + pixel_env[param_name] = x_coords + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + pixel_env[param_name] = y_coords + if len(params) >= 3: + param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) + pixel_env[param_name] = (r, g, b) + result = self._eval(body, pixel_env) + else: + result = self._eval(fn, pixel_env) + else: + result = self._eval(fn, pixel_env) + + if isinstance(result, tuple) and len(result) == 3: + nr, ng, nb = result + return jax_merge_channels(nr, ng, nb, (h, w)) + elif hasattr(result, 'shape') and result.ndim == 3: + return result + else: + # Single channel result + if hasattr(result, 'flatten'): + result = result.flatten() + gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # ===================================================================== + # State operations (return unchanged for stateless JIT) + # ===================================================================== + if op == 'state-get': + key = self._eval(args[0], env) + default = self._eval(args[1], env) if len(args) > 1 else None + return default # State not supported in JIT, return default + + if op == 'state-set': + return None # No-op in JIT + + # ===================================================================== + # Cell/grid operations + # ===================================================================== + if op == 'local-x-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size) / max(1, cell_size - 1) + + if op == 'local-y-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size) / max(1, cell_size - 1) + + if op == 'local-x': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size).astype(jnp.float32) + + if op == 'local-y': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size).astype(jnp.float32) + + if op == 'cell-row': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return jnp.floor(y / cell_size) + + if op == 'cell-col': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return jnp.floor(x / cell_size) + + if op == 'num-rows': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[0] // cell_size + + if op == 'num-cols': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[1] // cell_size + + # ===================================================================== + # Control flow + # ===================================================================== + if op == 'cond': + # (cond (test1 expr1) (test2 expr2) ... (else exprN)) + # For JAX compatibility, build a nested jnp.where structure + # Start from the else clause and work backwards + + # Collect clauses + clauses = [] + else_expr = None + for clause in args: + if isinstance(clause, list) and len(clause) >= 2: + test = clause[0] + if isinstance(test, Symbol) and test.name == 'else': + else_expr = clause[1] + else: + clauses.append((test, clause[1])) + + # If no else, default to None/0 + if else_expr is not None: + result = self._eval(else_expr, env) + else: + result = 0 + + # Build nested where from last to first + for test_expr, val_expr in reversed(clauses): + cond_val = self._eval(test_expr, env) + then_val = self._eval(val_expr, env) + + # Check if condition is array or scalar + if hasattr(cond_val, 'shape') and cond_val.shape != (): + # Array condition - use jnp.where + result = jnp.where(cond_val, then_val, result) + else: + # Scalar - can use Python if + if cond_val: + result = then_val + + return result + + if op == 'set!' or op == 'set': + # Mutation - not really supported in JAX, but we can update env + var = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + val = self._eval(args[1], env) + env[var] = val + return val + + if op == 'begin' or op == 'do': + # Evaluate all expressions, return last + result = None + for expr in args: + result = self._eval(expr, env) + return result + + # ===================================================================== + # Additional math + # ===================================================================== + if op == 'sq' or op == 'square': + x = self._eval(args[0], env) + return x * x + + if op == 'lerp': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + t = self._eval(args[2], env) + return a * (1 - t) + b * t + + if op == 'smoothstep': + edge0 = self._eval(args[0], env) + edge1 = self._eval(args[1], env) + x = self._eval(args[2], env) + t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1) + return t * t * (3 - 2 * t) + + if op == 'atan2': + y = self._eval(args[0], env) + x = self._eval(args[1], env) + return jnp.arctan2(y, x) + + if op == 'fract' or op == 'frac': + x = self._eval(args[0], env) + return x - jnp.floor(x) + + # ===================================================================== + # Frame copy and construction operations + # ===================================================================== + if op == 'pixel': + # Get pixel at (x, y) from frame + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + # Convert to int and clip to bounds + if isinstance(x, (int, float)): + x = max(0, min(int(x), w - 1)) + else: + x = jnp.clip(x, 0, w - 1).astype(jnp.int32) + if isinstance(y, (int, float)): + y = max(0, min(int(y), h - 1)) + else: + y = jnp.clip(y, 0, h - 1).astype(jnp.int32) + r = frame[y, x, 0] + g = frame[y, x, 1] + b = frame[y, x, 2] + return (r, g, b) + + if op == 'copy': + frame = self._eval(args[0], env) + return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame) + + if op == 'make-image': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + if len(args) > 2: + color = self._eval(args[2], env) + if isinstance(color, (list, tuple)): + r, g, b = int(color[0]), int(color[1]), int(color[2]) + else: + r = g = b = int(color) + else: + r = g = b = 0 + img = jnp.zeros((h, w, 3), dtype=jnp.uint8) + img = img.at[:, :, 0].set(r) + img = img.at[:, :, 1].set(g) + img = img.at[:, :, 2].set(b) + return img + + if op == 'paste': + dest = self._eval(args[0], env) + src = self._eval(args[1], env) + x = int(self._eval(args[2], env)) + y = int(self._eval(args[3], env)) + sh, sw = src.shape[:2] + dh, dw = dest.shape[:2] + # Clip to dest bounds + x1, y1 = max(0, x), max(0, y) + x2, y2 = min(dw, x + sw), min(dh, y + sh) + sx1, sy1 = x1 - x, y1 - y + sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1) + result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest) + result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :]) + return result + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blending:blend-images' or op == 'blend-images': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(a, b, alpha) + + if op == 'blending:blend-mode' or op == 'blend-mode': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + mode = self._eval(args[2], env) if len(args) > 2 else 'add' + if mode == 'add': + return jax_blend_add(a, b) + elif mode == 'multiply': + return jax_blend_multiply(a, b) + elif mode == 'screen': + return jax_blend_screen(a, b) + elif mode == 'overlay': + return jax_blend_overlay(a, b) + elif mode == 'lighten': + return jnp.maximum(a, b) + elif mode == 'darken': + return jnp.minimum(a, b) + elif mode == 'difference': + return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8) + else: + return jax_blend(a, b, 0.5) + + # ===================================================================== + # Geometry coordinate operations + # ===================================================================== + if op == 'geometry:wave-coords' or op == 'wave-coords': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + axis = self._eval(args[2], env) if len(args) > 2 else 'x' + freq = self._eval(args[3], env) if len(args) > 3 else 1.0 + amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0 + phase = self._eval(args[5], env) if len(args) > 5 else 0.0 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + if axis == 'x' or axis == 'horizontal': + # Wave displaces X based on Y + offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + src_x = x_coords + offset + src_y = y_coords + elif axis == 'y' or axis == 'vertical': + # Wave displaces Y based on X + offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + src_y = y_coords + offset + else: # both + offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + offset_x + src_y = y_coords + offset_y + + return {'x': src_x, 'y': src_y} + + if op == 'geometry:coords-x' or op == 'coords-x': + coords = self._eval(args[0], env) + return coords['x'] + + if op == 'geometry:coords-y' or op == 'coords-y': + coords = self._eval(args[0], env) + return coords['y'] + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + r, g, b = jax_sample(frame, x.flatten(), y.flatten()) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Glitch effects + # ===================================================================== + if op == 'pixelsort': + frame = self._eval(args[0], env) + sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness' + thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50 + thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200 + angle = int(self._eval(args[4], env)) if len(args) > 4 else 0 + reverse = self._eval(args[5], env) if len(args) > 5 else False + + h, w = frame.shape[:2] + result = frame.copy() + + # Get luminance for thresholding + lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + # Sort each row + for y in range(h): + row_lum = lum[y, :] + row = frame[y, :, :] + + # Find mask of pixels to sort + mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi) + + # Get indices where we should sort + sort_indices = jnp.where(mask, jnp.arange(w), -1) + + # Simple sort by luminance for the row + if sort_by == 'lightness': + sort_key = row_lum + elif sort_by == 'hue': + # Approximate hue from RGB + sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32), + row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32))) + else: + sort_key = row_lum + + # Sort pixels in masked region + sorted_indices = jnp.argsort(sort_key) + if reverse: + sorted_indices = sorted_indices[::-1] + + # Apply partial sort (only where mask is true) + # This is a simplified version - full pixelsort is more complex + result = result.at[y, :, :].set(row[sorted_indices]) + + return result + + if op == 'datamosh': + frame = self._eval(args[0], env) + prev = self._eval(args[1], env) + block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32 + corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3 + max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50 + color_corrupt = self._eval(args[5], env) if len(args) > 5 else True + + h, w = frame.shape[:2] + + # Use deterministic random for JIT with frame variation + seed = env.get('_seed', 42) + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + key = make_jax_key(seed, frame_num, op_counter) + + num_blocks_y = h // block_size + num_blocks_x = w // block_size + total_blocks = num_blocks_y * num_blocks_x + + # Pre-generate all random values at once (vectorized) + key, k1, k2, k3, k4, k5 = jax.random.split(key, 6) + corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption + offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1) + offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1) + channels = jax.random.randint(k4, (total_blocks,), 0, 3) + color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51) + + # Create coordinate grids for blocks + by_grid = jnp.arange(num_blocks_y) + bx_grid = jnp.arange(num_blocks_x) + + # Create block coordinate arrays + by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...] + bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...] + + # Create pixel coordinate grids + y_coords, x_coords = jnp.mgrid[:h, :w] + + # Determine which block each pixel belongs to + pixel_block_y = y_coords // block_size + pixel_block_x = x_coords // block_size + pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x + + # Clamp to valid block indices (for pixels outside the block grid) + pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1) + + # Get the corrupt mask, offsets for each pixel's block + pixel_corrupt = corrupt_mask[pixel_block_idx] + pixel_offset_y = offsets_y[pixel_block_idx] + pixel_offset_x = offsets_x[pixel_block_idx] + + # Calculate source coordinates with offset (clamped) + src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1) + src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1) + + # Sample from previous frame at offset positions + prev_sampled = prev[src_y, src_x, :] + + # Where corrupt mask is true, use prev_sampled; else use frame + result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame) + + # Apply color corruption to corrupted blocks + if color_corrupt: + pixel_channel = channels[pixel_block_idx] + pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16) + + # Create per-channel shift arrays (only shift the selected channel) + shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0) + shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0) + shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0) + + result_int = result.astype(jnp.int16) + result_int = result_int.at[:, :, 0].add(shift_r) + result_int = result_int.at[:, :, 1].add(shift_g) + result_int = result_int.at[:, :, 2].add(shift_b) + result = jnp.clip(result_int, 0, 255).astype(jnp.uint8) + + return result + + # ===================================================================== + # ASCII Art Operations (using pre-rendered font atlas) + # ===================================================================== + + if op == 'cell-sample': + # (cell-sample frame char_size) -> (colors, luminances) + # Downsample frame into cells, return average colors and luminances + frame = self._eval(args[0], env) + char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8 + + h, w = frame.shape[:2] + num_rows = h // char_size + num_cols = w // char_size + + # Crop to exact multiple of char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + + # Reshape to (num_rows, char_size, num_cols, char_size, 3) + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + + # Average over char_size dimensions -> (num_rows, num_cols, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminance per cell + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + return (colors, luminances) + + if op == 'luminance-to-chars': + # (luminance-to-chars luminances alphabet contrast) -> char_indices + # Map luminance values to character indices + luminances = self._eval(args[0], env) + alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard' + contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5 + + # Get alphabet string + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Apply contrast + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + + # Map to character indices (0 = darkest, num_chars-1 = brightest) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + return char_indices + + if op == 'render-char-grid': + # (render-char-grid frame chars colors char_size color_mode background_color invert_colors) + # Render character grid using font atlas + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8 + color_mode = self._eval(args[4], env) if len(args) > 4 else 'color' + background_color = self._eval(args[5], env) if len(args) > 5 else 'black' + invert_colors = self._eval(args[6], env) if len(args) > 6 else 0 + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get the alphabet used (stored in env or default) + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + + # Get or create font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + # Try to parse hex color + try: + if background_color.startswith('#'): + bg_hex = background_color[1:] + bg = jnp.array([int(bg_hex[0:2], 16), + int(bg_hex[2:4], 16), + int(bg_hex[4:6], 16)], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + except: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create output image starting with background + output_h = num_rows * char_size + output_w = num_cols * char_size + result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy() + + # Gather characters from atlas based on indices + # char_indices shape: (num_rows, num_cols) + # font_atlas shape: (num_chars, char_size, char_size, 3) + # Convert numpy atlas to JAX for indexing with traced indices + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3) + + # Reshape to grid + char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create coordinate grids for output pixels + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = y_out // char_size + cell_col = x_out // char_size + local_y = y_out % char_size + local_x = x_out % char_size + + # Clamp to valid ranges + cell_row = jnp.clip(cell_row, 0, num_rows - 1) + cell_col = jnp.clip(cell_col, 0, num_cols - 1) + + # Get character pixel values + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + + # Get character brightness (for masking) + char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Handle color modes + if color_mode == 'mono': + # White characters on background + fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8) + fg = jnp.broadcast_to(fg_color, char_pixels.shape) + elif color_mode == 'invert': + # Inverted cell colors + cell_colors = colors[cell_row, cell_col] + fg = 255 - cell_colors + else: + # 'color' mode - use cell colors + fg = colors[cell_row, cell_col] + + # Blend foreground onto background based on character brightness + if invert_colors: + # Swap fg and bg + fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg + result = (fg.astype(jnp.float32) * (1 - char_brightness) + + bg_broadcast.astype(jnp.float32) * char_brightness) + else: + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) + + fg.astype(jnp.float32) * char_brightness) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize back to original frame size if needed + if result.shape[0] != h or result.shape[1] != w: + # Simple nearest-neighbor resize + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = (jnp.arange(h) * y_scale).astype(jnp.int32) + x_src = (jnp.arange(w) * x_scale).astype(jnp.int32) + y_src = jnp.clip(y_src, 0, result.shape[0] - 1) + x_src = jnp.clip(x_src, 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'ascii-fx-zone': + # Complex ASCII effect with per-zone expressions + # (ascii-fx-zone frame :cols cols :alphabet alphabet ...) + frame = self._eval(args[0], env) + + # Parse keyword arguments + kwargs = {} + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + if i + 1 < len(args): + kwargs[key] = args[i + 1] + i += 2 + else: + i += 1 + + # Get parameters + cols = int(self._eval(kwargs.get('cols', 80), env)) + char_size_param = kwargs.get('char_size') + alphabet = self._eval(kwargs.get('alphabet', 'standard'), env) + color_mode = self._eval(kwargs.get('color_mode', 'color'), env) + background = self._eval(kwargs.get('background', 'black'), env) + contrast = float(self._eval(kwargs.get('contrast', 1.5), env)) + + h, w = frame.shape[:2] + + # Calculate char_size from cols if not specified + if char_size_param is not None: + char_size_val = self._eval(char_size_param, env) + if char_size_val is not None: + char_size = int(char_size_val) + else: + char_size = w // cols + else: + char_size = w // cols + char_size = max(4, min(char_size, 64)) + + # Store alphabet for render-char-grid to use + env['_ascii_alphabet'] = alphabet + + # Cell sampling + num_rows = h // char_size + num_cols = w // char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminances + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + # Get alphabet and map luminances to chars + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + # Handle optional per-zone effects (char_hue, char_saturation, etc.) + # These would modify colors based on zone position + char_hue = kwargs.get('char_hue') + char_saturation = kwargs.get('char_saturation') + char_brightness = kwargs.get('char_brightness') + + if char_hue is not None or char_saturation is not None or char_brightness is not None: + # Create zone coordinate arrays for expression evaluation + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + # Bind zone variables + zone_env = env.copy() + zone_env['zone-row'] = row_coords + zone_env['zone-col'] = col_coords + zone_env['zone-row-norm'] = row_norm + zone_env['zone-col-norm'] = col_norm + zone_env['zone-lum'] = luminances + + # Apply color modifications (simplified - full version would use HSV) + if char_brightness is not None: + brightness_mult = self._eval(char_brightness, zone_env) + if brightness_mult is not None: + colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None], + 0, 255).astype(jnp.uint8) + + # Render using font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original dimensions + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'render-char-grid-fx': + # Enhanced render with per-character effects + # (render-char-grid-fx frame chars colors luminances char_size + # color_mode bg_color invert_colors + # char_jitter char_scale char_rotation char_hue_shift + # jitter_source scale_source rotation_source hue_source) + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + luminances = self._eval(args[3], env) + char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8 + color_mode = self._eval(args[5], env) if len(args) > 5 else 'color' + background_color = self._eval(args[6], env) if len(args) > 6 else 'black' + invert_colors = self._eval(args[7], env) if len(args) > 7 else 0 + + # Per-char effect amounts + char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0 + char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0 + char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0 + char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0 + + # Modulation sources + jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none' + scale_source = self._eval(args[13], env) if len(args) > 13 else 'none' + rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none' + hue_source = self._eval(args[15], env) if len(args) > 15 else 'none' + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get alphabet + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create modulation values based on source + def get_modulation(source, lums, num_rows, num_cols): + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + if source == 'luminance': + return lums + elif source == 'inv_luminance': + return 1.0 - lums + elif source == 'position_x': + return col_norm + elif source == 'position_y': + return row_norm + elif source == 'position_diag': + return (row_norm + col_norm) / 2 + elif source == 'center_dist': + cy, cx = 0.5, 0.5 + dist = jnp.sqrt((row_norm - cy)**2 + (col_norm - cx)**2) + return jnp.clip(dist / 0.707, 0, 1) # Normalize by max diagonal + elif source == 'random': + # Use frame-varying key for random source + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key = make_jax_key(seed, frame_num, op_ctr) + return jax.random.uniform(key, (num_rows, num_cols)) + else: + return jnp.zeros((num_rows, num_cols)) + + # Get modulation values + jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols) + scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols) + rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols) + hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + # Apply jitter if enabled + if char_jitter > 0: + jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter + # Use frame-varying key for jitter + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2) + # Generate deterministic jitter per cell + jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1) + jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1) + offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + local_y = jnp.clip(local_y + offset_y, 0, char_size - 1) + local_x = jnp.clip(local_x + offset_x, 0, char_size - 1) + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Get foreground colors + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + # Apply hue shift if enabled + if char_hue_shift > 0 and color_mode == 'color': + hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift + # Simple hue rotation via channel cycling + fg_float = fg.astype(jnp.float32) + shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB + # Simplified: blend channels based on shift + r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2] + shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac + # Just do a simple tint for now + fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8) + + # Blend + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + if invert_colors: + result = (fg.astype(jnp.float32) * (1 - char_bright) + + bg_broadcast.astype(jnp.float32) * char_bright) + else: + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'alphabet-char': + # (alphabet-char alphabet-name index) -> char_index in that alphabet + alphabet = self._eval(args[0], env) + index = self._eval(args[1], env) + + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Handle both scalar and array indices + if hasattr(index, 'shape'): + index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1) + else: + index = max(0, min(int(index), num_chars - 1)) + + return index + + if op == 'map-char-grid': + # (map-char-grid base-chars luminances (lambda (r c ch lum) ...)) + # Map over character grid, allowing per-cell character selection + base_chars = self._eval(args[0], env) + luminances = self._eval(args[1], env) + fn = args[2] # Lambda expression + + num_rows, num_cols = base_chars.shape + + # For JAX compatibility, we can't use Python loops with traced values + # Instead, we'll evaluate the lambda for the whole grid at once + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + + # Create grid coordinates + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + + # Bind parameters for whole-grid evaluation + fn_env = env.copy() + + # Params: (r c ch lum) + if len(params) >= 1: + fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords + if len(params) >= 2: + fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords + if len(params) >= 3: + fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars + if len(params) >= 4: + # Luminances scaled to 0-255 range + fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32) + + # Evaluate body - should return new character indices + result = self._eval(body, fn_env) + if hasattr(result, 'shape'): + return result.astype(jnp.int32) + return base_chars + + return base_chars + + # ===================================================================== + # List operations + # ===================================================================== + if op == 'take': + seq = self._eval(args[0], env) + n = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[:n] + return seq[:n] # Works for arrays too + + if op == 'cons': + item = self._eval(args[0], env) + seq = self._eval(args[1], env) + if isinstance(seq, list): + return [item] + seq + elif isinstance(seq, tuple): + return (item,) + seq + return jnp.concatenate([jnp.array([item]), seq]) + + if op == 'roll': + arr = self._eval(args[0], env) + shift = self._eval(args[1], env) + axis = self._eval(args[2], env) if len(args) > 2 else 0 + # Convert to int for concrete values, keep as-is for JAX traced values + if isinstance(shift, (int, float)): + shift = int(shift) + elif hasattr(shift, 'astype'): + shift = shift.astype(jnp.int32) + if isinstance(axis, (int, float)): + axis = int(axis) + return jnp.roll(arr, shift, axis=axis) + + # ===================================================================== + # Pi constant + # ===================================================================== + if op == 'pi': + return jnp.pi + + raise ValueError(f"Unknown operation: {op}") + + +# ============================================================================= +# Public API +# ============================================================================= + +def compile_effect(code: str) -> Callable: + """ + Compile an S-expression effect to a JAX function. + + Args: + code: S-expression effect code + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + # Check cache + cache_key = hashlib.md5(code.encode()).hexdigest() + if cache_key in _COMPILED_EFFECTS: + return _COMPILED_EFFECTS[cache_key] + + # Parse and compile + sexp = parse(code) + compiler = JaxCompiler() + fn = compiler.compile_effect(sexp) + + _COMPILED_EFFECTS[cache_key] = fn + return fn + + +def compile_effect_file(path: str, derived_paths: List[str] = None) -> Callable: + """ + Compile an effect from a .sexp file. + + Args: + path: Path to the .sexp effect file + derived_paths: Optional list of paths to derived.sexp files to load + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + with open(path, 'r') as f: + code = f.read() + + # Parse all expressions in file + exprs = parse_all(code) + + # Create compiler + compiler = JaxCompiler() + + # Load derived files if specified + if derived_paths: + for dp in derived_paths: + compiler.load_derived(dp) + + # Process expressions - find require statements and the effect + effect_sexp = None + effect_dir = Path(path).parent + + for expr in exprs: + if not isinstance(expr, list) or len(expr) < 2: + continue + + head = expr[0] + if not isinstance(head, Symbol): + continue + + if head.name == 'require': + # (require "derived") or (require "path/to/file") + req_path = expr[1] + if isinstance(req_path, str): + # Resolve relative to effect file + if not req_path.endswith('.sexp'): + req_path = req_path + '.sexp' + full_path = effect_dir / req_path + if not full_path.exists(): + # Try sexp_effects directory + full_path = Path(__file__).parent.parent / 'sexp_effects' / req_path + if full_path.exists(): + compiler.load_derived(str(full_path)) + + elif head.name == 'require-primitives': + # (require-primitives "lib") - currently ignored for JAX + # JAX has all primitives built-in + pass + + elif head.name in ('effect', 'define-effect'): + effect_sexp = expr + + if effect_sexp is None: + raise ValueError(f"No effect definition found in {path}") + + return compiler.compile_effect(effect_sexp) + + +def load_derived(derived_path: str = None) -> Dict[str, Callable]: + """ + Load derived operations from derived.sexp. + + Returns dict of compiled functions that can be used in effects. + """ + if derived_path is None: + derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp' + + with open(derived_path, 'r') as f: + code = f.read() + + exprs = parse_all(code) + compiler = JaxCompiler() + env = {} + + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + compiler._eval_define(expr[1:], env) + + return env + + +# ============================================================================= +# Test / Demo +# ============================================================================= + +if __name__ == '__main__': + import numpy as np + + # Test effect + test_effect = ''' + (effect "threshold-test" + :params ((threshold :default 128)) + :body (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (gray (+ (* r 0.299) (* g 0.587) (* b 0.114))) + (mask (where (> gray threshold) 255 0))) + (merge-channels mask mask mask))) + ''' + + print("Compiling effect...") + run_effect = compile_effect(test_effect) + + # Create test frame + print("Creating test frame...") + frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) + + # Run effect + print("Running effect (first run includes JIT compilation)...") + import time + + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms") + + # Second run should be faster + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"Second run (cached): {(t1-t0)*1000:.2f}ms") + + # Multiple runs + t0 = time.time() + for _ in range(100): + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"100 runs: {(t1-t0)*1000:.2f}ms total, {(t1-t0)*10:.2f}ms avg") + + print(f"\nResult shape: {result.shape}, dtype: {result.dtype}") + print("Done!") diff --git a/l1/streaming/sources.py b/l1/streaming/sources.py new file mode 100644 index 0000000..71e7e53 --- /dev/null +++ b/l1/streaming/sources.py @@ -0,0 +1,281 @@ +""" +Video and image sources with looping support. +""" + +import numpy as np +import subprocess +import json +from pathlib import Path +from typing import Optional, Tuple +from abc import ABC, abstractmethod + + +class Source(ABC): + """Abstract base class for frame sources.""" + + @abstractmethod + def read_frame(self, t: float) -> np.ndarray: + """Read frame at time t (with looping if needed).""" + pass + + @property + @abstractmethod + def duration(self) -> float: + """Source duration in seconds.""" + pass + + @property + @abstractmethod + def size(self) -> Tuple[int, int]: + """Frame size as (width, height).""" + pass + + @property + @abstractmethod + def fps(self) -> float: + """Frames per second.""" + pass + + +class VideoSource(Source): + """ + Video file source with automatic looping. + + Reads frames on-demand, seeking as needed. When time exceeds + duration, wraps around (loops). + """ + + def __init__(self, path: str, target_fps: float = 30): + self.path = Path(path) + self.target_fps = target_fps + + # Initialize decode state first (before _probe which could fail) + self._process: Optional[subprocess.Popen] = None + self._current_start: Optional[float] = None + self._frame_buffer: Optional[np.ndarray] = None + self._buffer_time: Optional[float] = None + + self._duration = None + self._size = None + self._fps = None + + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + self._probe() + + def _probe(self): + """Get video metadata.""" + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_format", "-show_streams", + str(self.path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + data = json.loads(result.stdout) + + # Get duration + self._duration = float(data["format"]["duration"]) + + # Get video stream info + for stream in data["streams"]: + if stream["codec_type"] == "video": + self._size = (int(stream["width"]), int(stream["height"])) + # Parse fps from r_frame_rate (e.g., "30/1" or "30000/1001") + fps_parts = stream.get("r_frame_rate", "30/1").split("/") + self._fps = float(fps_parts[0]) / float(fps_parts[1]) + break + + @property + def duration(self) -> float: + return self._duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def _start_decode(self, start_time: float): + """Start ffmpeg decode process from given time.""" + if self._process: + try: + self._process.stdout.close() + except: + pass + self._process.terminate() + try: + self._process.wait(timeout=1) + except: + self._process.kill() + self._process.wait() + + w, h = self._size + cmd = [ + "ffmpeg", "-v", "quiet", + "-ss", str(start_time), + "-i", str(self.path), + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-r", str(self.target_fps), + "-" + ] + self._process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + bufsize=w * h * 3 * 4, # Buffer a few frames + ) + self._current_start = start_time + self._buffer_time = start_time + + def read_frame(self, t: float) -> np.ndarray: + """ + Read frame at time t. + + If t exceeds duration, wraps around (loops). + Seeks if needed, otherwise reads sequentially. + """ + # Wrap time for looping + t_wrapped = t % self._duration + + # Check if we need to seek (loop point or large time jump) + need_seek = ( + self._process is None or + self._buffer_time is None or + abs(t_wrapped - self._buffer_time) > 1.0 / self.target_fps * 2 + ) + + if need_seek: + self._start_decode(t_wrapped) + + # Read frame + w, h = self._size + frame_size = w * h * 3 + + # Try to read with retries for seek settling + for attempt in range(3): + raw = self._process.stdout.read(frame_size) + if len(raw) == frame_size: + break + # End of stream or seek not ready - restart from beginning + self._start_decode(0) + + if len(raw) < frame_size: + # Still no data - return last frame or black + if self._frame_buffer is not None: + return self._frame_buffer.copy() + return np.zeros((h, w, 3), dtype=np.uint8) + + frame = np.frombuffer(raw, dtype=np.uint8).reshape((h, w, 3)) + self._frame_buffer = frame # Cache for fallback + self._buffer_time = t_wrapped + 1.0 / self.target_fps + + return frame + + def close(self): + """Clean up resources.""" + if self._process: + self._process.terminate() + self._process.wait() + self._process = None + + def __del__(self): + self.close() + + def __repr__(self): + return f"VideoSource({self.path.name}, {self._size[0]}x{self._size[1]}, {self._duration:.1f}s)" + + +class ImageSource(Source): + """ + Static image source (returns same frame for any time). + + Useful for backgrounds, overlays, etc. + """ + + def __init__(self, path: str): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Image not found: {path}") + + # Load image + import cv2 + self._frame = cv2.imread(str(self.path)) + self._frame = cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB) + self._size = (self._frame.shape[1], self._frame.shape[0]) + + @property + def duration(self) -> float: + return float('inf') # Images last forever + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return 30.0 # Arbitrary + + def read_frame(self, t: float) -> np.ndarray: + return self._frame.copy() + + def __repr__(self): + return f"ImageSource({self.path.name}, {self._size[0]}x{self._size[1]})" + + +class LiveSource(Source): + """ + Live video capture source (webcam, capture card, etc.). + + Time parameter is ignored - always returns latest frame. + """ + + def __init__(self, device: int = 0, size: Tuple[int, int] = (1280, 720), fps: float = 30): + import cv2 + self._cap = cv2.VideoCapture(device) + self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, size[0]) + self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, size[1]) + self._cap.set(cv2.CAP_PROP_FPS, fps) + + # Get actual settings + self._size = ( + int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + ) + self._fps = self._cap.get(cv2.CAP_PROP_FPS) + + if not self._cap.isOpened(): + raise RuntimeError(f"Could not open video device {device}") + + @property + def duration(self) -> float: + return float('inf') # Live - no duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def read_frame(self, t: float) -> np.ndarray: + """Read latest frame (t is ignored for live sources).""" + import cv2 + ret, frame = self._cap.read() + if not ret: + return np.zeros((self._size[1], self._size[0], 3), dtype=np.uint8) + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + def close(self): + self._cap.release() + + def __del__(self): + self.close() + + def __repr__(self): + return f"LiveSource({self._size[0]}x{self._size[1]}, {self._fps}fps)" diff --git a/l1/streaming/stream_sexp.py b/l1/streaming/stream_sexp.py new file mode 100644 index 0000000..07acb2a --- /dev/null +++ b/l1/streaming/stream_sexp.py @@ -0,0 +1,1098 @@ +""" +Generic Streaming S-expression Interpreter. + +Executes streaming sexp recipes frame-by-frame. +The sexp defines the pipeline logic - interpreter just provides primitives. + +Primitives: + (read source-name) - read frame from source + (rotate frame :angle N) - rotate frame + (zoom frame :amount N) - zoom frame + (invert frame :amount N) - invert colors + (hue-shift frame :degrees N) - shift hue + (blend frame1 frame2 :opacity N) - blend two frames + (blend-weighted [frames...] [weights...]) - weighted blend + (ripple frame :amplitude N :cx N :cy N ...) - ripple effect + + (bind scan-name :field) - get scan state field + (map value [lo hi]) - map 0-1 value to range + energy - current energy (0-1) + beat - 1 if beat, 0 otherwise + t - current time + beat-count - total beats so far + +Example sexp: + (stream "test" + :fps 30 + (source vid "video.mp4") + (audio aud "music.mp3") + + (scan spin beat + :init {:angle 0 :dir 1} + :step (dict :angle (+ angle (* dir 10)) :dir dir)) + + (frame + (-> (read vid) + (rotate :angle (bind spin :angle)) + (zoom :amount (map energy [1 1.5]))))) +""" + +import sys +import time +import json +import hashlib +import numpy as np +import subprocess +from pathlib import Path +from dataclasses import dataclass, field +from typing import Dict, List, Any, Optional, Tuple, Union + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) +from artdag.sexp.parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class StreamContext: + """Runtime context for streaming.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + output_size: Tuple[int, int] = (720, 720) + + +class StreamCache: + """Cache for streaming data.""" + + def __init__(self, cache_dir: Path, recipe_hash: str): + self.cache_dir = cache_dir / recipe_hash + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.analysis_buffer: Dict[str, List] = {} + self.scan_states: Dict[str, List] = {} + self.keyframe_interval = 5.0 + + def record_analysis(self, name: str, t: float, value: float): + if name not in self.analysis_buffer: + self.analysis_buffer[name] = [] + t = float(t) if hasattr(t, 'item') else t + value = float(value) if hasattr(value, 'item') else value + self.analysis_buffer[name].append((t, value)) + + def record_scan_state(self, name: str, t: float, state: dict): + if name not in self.scan_states: + self.scan_states[name] = [] + states = self.scan_states[name] + if not states or t - states[-1][0] >= self.keyframe_interval: + t = float(t) if hasattr(t, 'item') else t + clean = {k: (float(v) if hasattr(v, 'item') else v) for k, v in state.items()} + self.scan_states[name].append((t, clean)) + + def flush(self): + for name, data in self.analysis_buffer.items(): + path = self.cache_dir / f"analysis_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(data) + path.write_text(json.dumps(existing)) + self.analysis_buffer.clear() + + for name, states in self.scan_states.items(): + path = self.cache_dir / f"scan_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(states) + path.write_text(json.dumps(existing)) + self.scan_states.clear() + + +class VideoSource: + """Video source - reads frames sequentially.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + # Get info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + + for s in info.get("streams", []): + if s.get("codec_type") == "video": + self.width = s.get("width", 720) + self.height = s.get("height", 720) + break + else: + self.width, self.height = 720, 720 + + self.duration = float(info.get("format", {}).get("duration", 60)) + self.size = (self.width, self.height) + + # Start decoder + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", "-r", str(fps), "-"] + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) + self._frame_size = self.width * self.height * 3 + self._current_frame = None + + def read(self) -> Optional[np.ndarray]: + """Read next frame.""" + data = self._proc.stdout.read(self._frame_size) + if len(data) < self._frame_size: + return self._current_frame # Return last frame if stream ends + self._current_frame = np.frombuffer(data, dtype=np.uint8).reshape( + self.height, self.width, 3).copy() + return self._current_frame + + def skip(self): + """Read and discard frame (keep pipe in sync).""" + self._proc.stdout.read(self._frame_size) + + def close(self): + if self._proc: + self._proc.terminate() + self._proc.wait() + + +class AudioAnalyzer: + """Real-time audio analysis.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + + # Load audio + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + self._audio = np.frombuffer( + subprocess.run(cmd, capture_output=True).stdout, dtype=np.float32) + self.sample_rate = sample_rate + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + self._flux_history = [] + self._last_beat_time = -1 + + def get_energy(self, t: float) -> float: + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size//2: + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + while self._flux_history and self._flux_history[0][0] < t - 1.5: + self._flux_history.pop(0) + + if len(self._flux_history) < 3: + return False + + vals = [f for _, f in self._flux_history] + threshold = np.mean(vals) + np.std(vals) * 0.3 + 0.001 + + is_beat = flux > threshold and t - self._last_beat_time > 0.1 + if is_beat: + self._last_beat_time = t + return is_beat + + +class StreamInterpreter: + """ + Generic streaming sexp interpreter. + + Evaluates the frame pipeline expression each frame. + """ + + def __init__(self, sexp_path: str, cache_dir: str = None): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + recipe_hash = hashlib.sha256(text.encode()).hexdigest()[:16] + cache_path = Path(cache_dir) if cache_dir else self.sexp_dir / ".stream_cache" + self.cache = StreamCache(cache_path, recipe_hash) + + self.ctx = StreamContext(fps=self.config.get('fps', 30)) + self.sources: Dict[str, VideoSource] = {} + self.frames: Dict[str, np.ndarray] = {} # Current frame per source + self._sources_read: set = set() # Track which sources read this frame + self.audios: Dict[str, AudioAnalyzer] = {} # Multiple named audio sources + self.audio_paths: Dict[str, str] = {} + self.audio_state: Dict[str, dict] = {} # Per-audio: {energy, is_beat, beat_count, last_beat} + self.scans: Dict[str, dict] = {} + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} # name -> Python function + self.effects: Dict[str, dict] = {} # name -> {params, body} + self.macros: Dict[str, dict] = {} # name -> {params, body} + self.primitive_lib_dir = self.sexp_dir.parent / "sexp_effects" / "primitive_libs" + + self.frame_pipeline = None # The (frame ...) expression + + import random + self.rng = random.Random(self.config.get('seed', 42)) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file.""" + import importlib.util + + # Try multiple paths + lib_paths = [ + self.primitive_lib_dir / f"{lib_name}.py", + self.sexp_dir / "primitive_libs" / f"{lib_name}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{lib_name}.py", + ] + + lib_path = None + for p in lib_paths: + if p.exists(): + lib_path = p + break + + if not lib_path: + print(f"Warning: primitive library '{lib_name}' not found", file=sys.stderr) + return + + spec = importlib.util.spec_from_file_location(lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Extract all prim_* functions + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] # Remove 'prim_' prefix + self.primitives[prim_name] = func + # Also register with dashes instead of underscores + dash_name = prim_name.replace('_', '-') + self.primitives[dash_name] = func + # Also register with -img suffix (sexp convention) + self.primitives[dash_name + '-img'] = func + count += 1 + + # Also check for PRIMITIVES dict (some modules use this for additional exports) + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + self.primitives[name] = func + # Also register underscore version + underscore_name = name.replace('-', '_') + self.primitives[underscore_name] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions)", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + print(f"Warning: effect file not found: {effect_path}", file=sys.stderr) + return + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + # Parse params list + params_list = form[i + 1] + for p in params_list: + if isinstance(p, list) and p: + pname = p[0].name if isinstance(p[0], Symbol) else str(p[0]) + pdef = {'default': 0} + j = 1 + while j < len(p): + if isinstance(p[j], Keyword): + pdef[p[j].name] = p[j + 1] if j + 1 < len(p) else None + j += 2 + else: + j += 1 + params[pname] = pdef + i += 2 + else: + i += 2 + else: + # Body expression + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body, 'path': str(effect_path)} + print(f"Effect: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [] + body = None + + if len(form) > 2 and isinstance(form[2], list): + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + if len(form) > 3: + body = form[3] + + self.macros[name] = {'params': params, 'body': body} + print(f"Macro: {name}", file=sys.stderr) + + def _init(self): + """Initialize sources, scans, and pipeline from sexp.""" + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + # === External loading === + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + # (effect name :path "...") + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'include': + # (include :path "...") + i = 1 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) # Reuse effect loader for includes + i += 2 + else: + i += 1 + + # === Sources === + + elif cmd == 'source': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.sources[name] = VideoSource(str(full), self.ctx.fps) + print(f"Source: {name} -> {full}", file=sys.stderr) + else: + print(f"Warning: {full} not found", file=sys.stderr) + + elif cmd == 'audio': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.audios[name] = AudioAnalyzer(str(full)) + self.audio_paths[name] = str(full) + self.audio_state[name] = {'energy': 0.0, 'is_beat': False, 'beat_count': 0, 'last_beat': False} + print(f"Audio: {name} -> {full}", file=sys.stderr) + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + # Trigger can be: + # (beat audio-name) - trigger on beat from specific audio + # beat - legacy: trigger on beat from first audio + trigger_expr = form[2] + if isinstance(trigger_expr, list) and len(trigger_expr) >= 2: + # (beat audio-name) + trigger_type = trigger_expr[0].name if isinstance(trigger_expr[0], Symbol) else str(trigger_expr[0]) + trigger_audio = trigger_expr[1].name if isinstance(trigger_expr[1], Symbol) else str(trigger_expr[1]) + trigger = (trigger_type, trigger_audio) + else: + # Legacy bare symbol + trigger = trigger_expr.name if isinstance(trigger_expr, Symbol) else str(trigger_expr) + + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], {}) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger, + } + trigger_str = f"{trigger[0]} {trigger[1]}" if isinstance(trigger, tuple) else trigger + print(f"Scan: {name} (on {trigger_str})", file=sys.stderr) + + elif cmd == 'frame': + # (frame expr) - the pipeline to evaluate each frame + self.frame_pipeline = form[1] if len(form) > 1 else None + + # Set output size from first source + if self.sources: + first = next(iter(self.sources.values())) + self.ctx.output_size = first.size + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + import cv2 + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, Symbol): + name = expr.name + # Built-in values + if name == 't' or name == '_time': + return self.ctx.t + if name == 'pi': + import math + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + return 0 + + if isinstance(expr, Keyword): + return expr.name + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check if op is a closure in environment + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + # Invoke closure + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + # Insert result as first arg + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Audio analysis (explicit) === + + if op == 'energy': + # (energy audio-name) - get current energy from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['energy'] + return 0.0 + + if op == 'beat': + # (beat audio-name) - 1 if beat this frame, 0 otherwise + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return 1.0 if self.audio_state[audio_name]['is_beat'] else 0.0 + return 0.0 + + if op == 'beat-count': + # (beat-count audio-name) - total beats from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['beat_count'] + return 0 + + # === Frame operations === + + if op == 'read': + # (read source-name) - get current frame from source (lazy read) + name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if name not in self.frames: + if name in self.sources: + self.frames[name] = self.sources[name].read() + self._sources_read.add(name) + return self.frames.get(name) + + # === Binding and mapping === + + if op == 'bind': + # (bind scan-name :field) or (bind scan-name) + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + field = None + if len(args) > 1 and isinstance(args[1], Keyword): + field = args[1].name + + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if field: + return state.get(field, 0) + return state + return 0 + + if op == 'map': + # (map value [lo hi]) + val = self._eval(args[0], env) + range_list = self._eval(args[1], env) if len(args) > 1 else [0, 1] + if isinstance(range_list, list) and len(range_list) >= 2: + lo, hi = range_list[0], range_list[1] + return lo + val * (hi - lo) + return val + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + if op == 'map-range': + # (map-range val from-lo from-hi to-lo to-hi) + val = self._eval(args[0], env) + from_lo = self._eval(args[1], env) + from_hi = self._eval(args[2], env) + to_lo = self._eval(args[3], env) + to_hi = self._eval(args[4], env) + # Normalize val to 0-1 in source range, then scale to target range + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + # Lisp-style or: returns first truthy value, or last value if none truthy + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + # (cond pred1 expr1 pred2 expr2 ... true else-expr) + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + # (lambda (params...) body) - create a closure + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + # Return a closure dict that captures the current env + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + # Support both formats: + # (let [name val name val ...] body) - flat vector + # (let ((name val) (name val) ...) body) - nested list + # Note: our let already evaluates sequentially like let* + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + # Nested format: ((name val) (name val) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + # Flat format: [name val name val ...] + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Random === + + if op == 'rand': + return self.rng.random() + if op == 'rand-int': + lo = int(self._eval(args[0], env)) + hi = int(self._eval(args[1], env)) + return self.rng.randint(lo, hi) + if op == 'rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + return lo + self.rng.random() * (hi - lo) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + result[args[i].name] = self._eval(args[i + 1], env) if i + 1 < len(args) else None + i += 2 + else: + i += 1 + return result + + if op == 'get': + d = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(d, dict): + return d.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, list) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + lst = self._eval(args[0], env) + return len(lst) if isinstance(lst, (list, dict, str)) else 0 + + # === External effects === + if op in self.effects: + effect = self.effects[op] + effect_env = dict(env) + effect_env['t'] = self.ctx.t + + # Set defaults for all params + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + # Parse args: first is frame, then positional params, then kwargs + positional_idx = 0 + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + # Keyword arg + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + # Positional arg + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + return self._eval(effect['body'], effect_env) + + # === External primitives === + if op in self.primitives: + prim_func = self.primitives[op] + # Evaluate all args + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + # Call primitive + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # === Macros === + if op in self.macros: + macro = self.macros[op] + # Bind macro params to args (unevaluated) + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + macro_env[pname] = args[i] if i < len(args) else None + # Expand and evaluate + return self._eval(macro['body'], macro_env) + + # === Primitive-style call (name-with-dashes -> prim_name_with_underscores) === + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # Unknown - return as-is + return expr + + def _step_scans(self): + """Step scans on beat from specific audio.""" + for name, scan in self.scans.items(): + trigger = scan['trigger'] + + # Check if this scan should step + should_step = False + audio_name = None + + if isinstance(trigger, tuple) and trigger[0] == 'beat': + # Explicit: (beat audio-name) + audio_name = trigger[1] + if audio_name in self.audio_state: + should_step = self.audio_state[audio_name]['is_beat'] + elif trigger == 'beat': + # Legacy: use first audio + if self.audio_state: + audio_name = next(iter(self.audio_state)) + should_step = self.audio_state[audio_name]['is_beat'] + + if should_step and audio_name: + state = self.audio_state[audio_name] + env = dict(scan['state']) + env['beat_count'] = state['beat_count'] + env['t'] = self.ctx.t + env['energy'] = state['energy'] + + if scan['step']: + new_state = self._eval(scan['step'], env) + if isinstance(new_state, dict): + scan['state'] = new_state + elif new_state is not None: + scan['state'] = {'acc': new_state} + + self.cache.record_scan_state(name, self.ctx.t, scan['state']) + + def run(self, duration: float = None, output: str = "pipe"): + """Run the streaming pipeline.""" + from .output import PipeOutput, DisplayOutput, FileOutput + + self._init() + + if not self.sources: + print("Error: no sources", file=sys.stderr) + return + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w, h = self.ctx.output_size + + # Duration from first audio or default + if duration is None: + if self.audios: + first_audio = next(iter(self.audios.values())) + duration = first_audio.duration + else: + duration = 60.0 + + n_frames = int(duration * self.ctx.fps) + frame_time = 1.0 / self.ctx.fps + + print(f"Streaming {n_frames} frames @ {self.ctx.fps}fps", file=sys.stderr) + + # Use first audio for playback sync + first_audio_path = next(iter(self.audio_paths.values())) if self.audio_paths else None + + # Output + if output == "pipe": + out = PipeOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + else: + out = FileOutput(output, size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + + try: + for frame_num in range(n_frames): + if not out.is_open: + print(f"\nOutput closed at {frame_num}", file=sys.stderr) + break + + self.ctx.t = frame_num * frame_time + self.ctx.frame_num = frame_num + + # Update all audio states + for audio_name, analyzer in self.audios.items(): + state = self.audio_state[audio_name] + energy = analyzer.get_energy(self.ctx.t) + is_beat_raw = analyzer.get_beat(self.ctx.t) + is_beat = is_beat_raw and not state['last_beat'] + state['last_beat'] = is_beat_raw + + state['energy'] = energy + state['is_beat'] = is_beat + if is_beat: + state['beat_count'] += 1 + + self.cache.record_analysis(f'{audio_name}_energy', self.ctx.t, energy) + self.cache.record_analysis(f'{audio_name}_beat', self.ctx.t, 1.0 if is_beat else 0.0) + + # Step scans + self._step_scans() + + # Clear frames - will be read lazily + self.frames.clear() + self._sources_read = set() + + # Evaluate pipeline (reads happen on-demand) + result = self._eval(self.frame_pipeline, {}) + + # Skip unread sources to keep pipes in sync + for name, src in self.sources.items(): + if name not in self._sources_read: + src.skip() + + # Ensure output size + if result is not None: + import cv2 + if result.shape[:2] != (h, w): + # Handle CuPy arrays - cv2 can't resize them directly + if hasattr(result, '__cuda_array_interface__'): + # Use GPU resize via cupyx.scipy + try: + import cupy as cp + from cupyx.scipy import ndimage as cpndimage + curr_h, curr_w = result.shape[:2] + zoom_y = h / curr_h + zoom_x = w / curr_w + if result.ndim == 3: + result = cpndimage.zoom(result, (zoom_y, zoom_x, 1), order=1) + else: + result = cpndimage.zoom(result, (zoom_y, zoom_x), order=1) + except ImportError: + # Fallback to CPU resize + result = cv2.resize(cp.asnumpy(result), (w, h)) + else: + result = cv2.resize(result, (w, h)) + out.write(result, self.ctx.t) + + # Progress + if frame_num % 30 == 0: + pct = 100 * frame_num / n_frames + # Show beats from first audio + total_beats = 0 + if self.audio_state: + first_state = next(iter(self.audio_state.values())) + total_beats = first_state['beat_count'] + print(f"\r{pct:5.1f}% | beats:{total_beats}", + end="", file=sys.stderr) + sys.stderr.flush() + + if frame_num % 300 == 0: + self.cache.flush() + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + finally: + out.close() + for src in self.sources.values(): + src.close() + self.cache.flush() + + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path) + if fps: + interp.ctx.fps = fps + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None, help="Override fps (default: from sexp)") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps) diff --git a/l1/streaming/stream_sexp_generic.py b/l1/streaming/stream_sexp_generic.py new file mode 100644 index 0000000..0619589 --- /dev/null +++ b/l1/streaming/stream_sexp_generic.py @@ -0,0 +1,1739 @@ +""" +Fully Generic Streaming S-expression Interpreter. + +The interpreter knows NOTHING about video, audio, or any domain. +All domain logic comes from primitives loaded via (require-primitives ...). + +Built-in forms: + - Control: if, cond, let, let*, lambda, -> + - Arithmetic: +, -, *, /, mod, map-range + - Comparison: <, >, =, <=, >=, and, or, not + - Data: dict, get, list, nth, len, quote + - Random: rand, rand-int, rand-range + - Scan: bind (access scan state) + +Everything else comes from primitives or effects. + +Context (ctx) is passed explicitly to frame evaluation: + - ctx.t: current time + - ctx.frame-num: current frame number + - ctx.fps: frames per second +""" + +import sys +import os +import time +import json +import hashlib +import math +import numpy as np +from pathlib import Path +from dataclasses import dataclass +from typing import Dict, List, Any, Optional, Tuple, Callable + +# Use local sexp_effects parser (supports namespaced symbols like math:sin) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from sexp_effects.parser import parse, parse_all, Symbol, Keyword + +# JAX backend (optional - loaded on demand) +_JAX_AVAILABLE = False +_jax_compiler = None + +def _init_jax(): + """Lazily initialize JAX compiler.""" + global _JAX_AVAILABLE, _jax_compiler + if _jax_compiler is not None: + return _JAX_AVAILABLE + try: + from streaming.sexp_to_jax import JaxCompiler, compile_effect_file + _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file} + _JAX_AVAILABLE = True + print("JAX backend initialized", file=sys.stderr) + except ImportError as e: + print(f"JAX backend not available: {e}", file=sys.stderr) + _JAX_AVAILABLE = False + return _JAX_AVAILABLE + + +@dataclass +class Context: + """Runtime context passed to frame evaluation.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + + +class DeferredEffectChain: + """ + Represents a chain of JAX effects that haven't been executed yet. + + Allows effects to be accumulated through let bindings and fused + into a single JIT-compiled function when the result is needed. + """ + __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num') + + def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int): + self.effects = effects # List of effect names, innermost first + self.params_list = params_list # List of param dicts, matching effects + self.base_frame = base_frame # The actual frame array at the start + self.t = t + self.frame_num = frame_num + + def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain': + """Add another effect to the chain (outermost).""" + return DeferredEffectChain( + self.effects + [effect_name], + self.params_list + [params], + self.base_frame, + self.t, + self.frame_num + ) + + @property + def shape(self): + """Allow shape check without forcing execution.""" + return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None + + +class StreamInterpreter: + """ + Fully generic streaming sexp interpreter. + + No domain-specific knowledge - just evaluates expressions + and calls primitives. + """ + + def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + self.actor_id = actor_id # For friendly name resolution + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + # Global environment for def bindings + self.globals: Dict[str, Any] = {} + + # Scans + self.scans: Dict[str, dict] = {} + + # Audio playback path (for syncing output) + self.audio_playback: Optional[str] = None + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} + self.effects: Dict[str, dict] = {} + self.macros: Dict[str, dict] = {} + + # JAX backend for accelerated effect evaluation + self.use_jax = use_jax + self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects + self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects + self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains + self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains + self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env + if use_jax: + if _init_jax(): + print("JAX acceleration enabled", file=sys.stderr) + else: + print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr) + self.use_jax = False + # Try multiple locations for primitive_libs + possible_paths = [ + self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir + self.sexp_dir / "sexp_effects" / "primitive_libs", # app root + Path(__file__).parent.parent / "sexp_effects" / "primitive_libs", # relative to interpreter + ] + self.primitive_lib_dir = next((p for p in possible_paths if p.exists()), possible_paths[0]) + + self.frame_pipeline = None + + # External config files (set before run()) + self.sources_config: Optional[Path] = None + self.audio_config: Optional[Path] = None + + # Error tracking + self.errors: List[str] = [] + + # Callback for live streaming (called when IPFS playlist is updated) + self.on_playlist_update: callable = None + + # Callback for progress updates (called periodically during rendering) + # Signature: on_progress(percent: float, frame_num: int, total_frames: int) + self.on_progress: callable = None + + # Callback for checkpoint saves (called at segment boundaries for resumability) + # Signature: on_checkpoint(checkpoint: dict) + # checkpoint contains: frame_num, t, scans + self.on_checkpoint: callable = None + + # Frames per segment for checkpoint timing (default 4 seconds at 30fps = 120 frames) + self._frames_per_segment: int = 120 + + def _resolve_name(self, name: str) -> Optional[Path]: + """Resolve a friendly name to a file path using the naming service.""" + try: + # Import here to avoid circular imports + from tasks.streaming import resolve_asset + path = resolve_asset(name, self.actor_id) + if path: + return path + except Exception as e: + print(f"Warning: failed to resolve name '{name}': {e}", file=sys.stderr) + return None + + def _record_error(self, msg: str): + """Record an error that occurred during evaluation.""" + self.errors.append(msg) + print(f"ERROR: {msg}", file=sys.stderr) + + def _maybe_to_numpy(self, val, for_gpu_primitive: bool = False): + """Convert GPU frames/CuPy arrays to numpy for CPU primitives. + + If for_gpu_primitive=True, preserve GPU data (CuPy arrays stay on GPU). + """ + if val is None: + return val + + # For GPU primitives, keep data on GPU + if for_gpu_primitive: + # Handle GPUFrame - return the GPU array + if hasattr(val, 'gpu') and hasattr(val, 'is_on_gpu'): + if val.is_on_gpu: + return val.gpu + return val.cpu + # CuPy arrays pass through unchanged + if hasattr(val, '__cuda_array_interface__'): + return val + return val + + # For CPU primitives, convert to numpy + # Handle GPUFrame objects (have .cpu property) + if hasattr(val, 'cpu'): + return val.cpu + # Handle CuPy arrays (have .get() method) + if hasattr(val, 'get') and callable(val.get): + return val.get() + return val + + def _load_config_file(self, config_path): + """Load a config file and process its definitions.""" + config_path = Path(config_path) # Accept str or Path + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + text = config_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'def': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Config: {name}", file=sys.stderr) + + elif cmd == 'audio-playback': + # Path relative to working directory (consistent with other paths) + path = str(form[1]).strip('"') + self.audio_playback = str(Path(path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file. + + Prefers GPU-accelerated versions (*_gpu.py) when available. + Uses cached modules from sys.modules to ensure consistent state + (e.g., same RNG instance for all interpreters). + """ + import importlib.util + + # Try GPU version first, then fall back to CPU version + lib_names_to_try = [f"{lib_name}_gpu", lib_name] + + lib_path = None + actual_lib_name = lib_name + + for try_lib in lib_names_to_try: + lib_paths = [ + self.primitive_lib_dir / f"{try_lib}.py", + self.sexp_dir / "primitive_libs" / f"{try_lib}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{try_lib}.py", + ] + for p in lib_paths: + if p.exists(): + lib_path = p + actual_lib_name = try_lib + break + if lib_path: + break + + if not lib_path: + raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}") + + # Use cached module if already imported to preserve state (e.g., RNG) + # This is critical for deterministic random number sequences + # Check multiple possible module keys (standard import paths and our cache) + possible_keys = [ + f"sexp_effects.primitive_libs.{actual_lib_name}", + f"sexp_primitives.{actual_lib_name}", + ] + + module = None + for key in possible_keys: + if key in sys.modules: + module = sys.modules[key] + break + + if module is None: + spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Cache for future use under our key + sys.modules[f"sexp_primitives.{actual_lib_name}"] = module + + # Check if this is a GPU-accelerated module + is_gpu = actual_lib_name.endswith('_gpu') + gpu_tag = " [GPU]" if is_gpu else "" + + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] + dash_name = prim_name.replace('_', '-') + # Register with original lib_name namespace (geometry:rotate, not geometry_gpu:rotate) + # Don't overwrite if already registered (allows pre-registration of overrides) + key = f"{lib_name}:{dash_name}" + if key not in self.primitives: + self.primitives[key] = func + count += 1 + + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + # Register with original lib_name namespace + # Don't overwrite if already registered + dash_name = name.replace('_', '-') + key = f"{lib_name}:{dash_name}" + if key not in self.primitives: + self.primitives[key] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions){gpu_tag}", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + raise FileNotFoundError(f"Effect/include file not found: {effect_path}") + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + for pdef in form[i + 1]: + if isinstance(pdef, list) and pdef: + pname = pdef[0].name if isinstance(pdef[0], Symbol) else str(pdef[0]) + pinfo = {'default': 0} + j = 1 + while j < len(pdef): + if isinstance(pdef[j], Keyword) and j + 1 < len(pdef): + pinfo[pdef[j].name] = pdef[j + 1] + j += 2 + else: + j += 1 + params[pname] = pinfo + i += 2 + else: + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body} + self.jax_effect_paths[name] = effect_path # Track source for JAX compilation + print(f"Effect: {name}", file=sys.stderr) + + # Try to compile with JAX if enabled + if self.use_jax and _JAX_AVAILABLE: + self._compile_jax_effect(name, effect_path) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'effect': + # Handle (effect name :path "...") or (effect name :name "...") in included files + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'include': + # Handle (include :path "...") or (include :name "...") in included files + i = 1 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'scan': + # Handle scans from included files + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + def _compile_jax_effect(self, name: str, effect_path: Path): + """Compile an effect with JAX for accelerated execution.""" + if not _JAX_AVAILABLE or name in self.jax_effects: + return + + try: + compile_effect_file = _jax_compiler['compile_effect_file'] + jax_fn = compile_effect_file(str(effect_path)) + self.jax_effects[name] = jax_fn + print(f" [JAX compiled: {name}]", file=sys.stderr) + except Exception as e: + # Silently fall back to interpreter for unsupported effects + if 'Unknown operation' not in str(e): + print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr) + + def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]: + """Apply a JAX-compiled effect to a frame.""" + if name not in self.jax_effects: + return None + + try: + jax_fn = self.jax_effects[name] + # Handle GPU frames (CuPy) - need to move to CPU for CPU JAX + # JAX handles numpy and JAX arrays natively, no conversion needed + if hasattr(frame, 'cpu'): + frame = frame.cpu + elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'): + frame = frame.get() # CuPy array -> numpy + + # Get seed from config for deterministic random + seed = self.config.get('seed', 42) + + # Call JAX function with parameters + # Return JAX array directly - don't block or convert per-effect + # Conversion to numpy happens once at frame write time + return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params) + except Exception as e: + # Fall back to interpreter on error + print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr) + return None + + def _is_jax_effect_expr(self, expr) -> bool: + """Check if an expression is a JAX-compiled effect call.""" + if not isinstance(expr, list) or not expr: + return False + head = expr[0] + if not isinstance(head, Symbol): + return False + return head.name in self.jax_effects + + def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]: + """ + Extract a chain of JAX effects from an expression. + + Returns: (effect_names, params_list, base_frame_expr) or None if not a chain. + effect_names and params_list are in execution order (innermost first). + + For (effect1 (effect2 frame :p1 v1) :p2 v2): + Returns: (['effect2', 'effect1'], [params2, params1], frame_expr) + """ + if not self._is_jax_effect_expr(expr): + return None + + chain = [] + params_list = [] + current = expr + + while self._is_jax_effect_expr(current): + head = current[0] + effect_name = head.name + args = current[1:] + + # Extract params for this effect + effect = self.effects[effect_name] + effect_params = {} + for pname, pdef in effect['params'].items(): + effect_params[pname] = pdef.get('default', 0) + + # Find the frame argument (first positional) and other params + frame_arg = None + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_params[pname] = self._eval(args[i + 1], env) + i += 2 + else: + if frame_arg is None: + frame_arg = args[i] # First positional is frame + i += 1 + + chain.append(effect_name) + params_list.append(effect_params) + + if frame_arg is None: + return None # No frame argument found + + # Check if frame_arg is another effect call + if self._is_jax_effect_expr(frame_arg): + current = frame_arg + else: + # End of chain - frame_arg is the base frame + # Reverse to get innermost-first execution order + chain.reverse() + params_list.reverse() + return (chain, params_list, frame_arg) + + return None + + def _get_chain_key(self, effect_names: list, params_list: list) -> str: + """Generate a cache key for an effect chain. + + Includes static param values in the key since they affect compilation. + """ + parts = [] + for name, params in zip(effect_names, params_list): + param_parts = [] + for pname in sorted(params.keys()): + pval = params[pname] + # Include static values in key (strings, bools) + if isinstance(pval, (str, bool)): + param_parts.append(f"{pname}={pval}") + else: + param_parts.append(pname) + parts.append(f"{name}:{','.join(param_parts)}") + return '|'.join(parts) + + def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]: + """ + Compile a chain of effects into a single fused JAX function. + + Args: + effect_names: List of effect names in order [innermost, ..., outermost] + params_list: List of param dicts for each effect (used to detect static types) + + Returns: + JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame + """ + if not _JAX_AVAILABLE: + return None + + try: + import jax + + # Get the individual JAX functions + jax_fns = [self.jax_effects[name] for name in effect_names] + + # Pre-extract param names and identify static params from actual values + effect_param_names = [] + static_params = ['seed'] # seed is always static + for i, (name, params) in enumerate(zip(effect_names, params_list)): + param_names = list(params.keys()) + effect_param_names.append(param_names) + # Check actual values to identify static types + for pname, pval in params.items(): + if isinstance(pval, (str, bool)): + static_params.append(f"_p{i}_{pname}") + + def fused_fn(frame, t, frame_num, seed, **kwargs): + result = frame + for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)): + # Extract params for this effect from kwargs + effect_kwargs = {} + for pname in param_names: + key = f"_p{i}_{pname}" + if key in kwargs: + effect_kwargs[pname] = kwargs[key] + result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs) + return result + + # JIT with static params (seed + any string/bool params) + return jax.jit(fused_fn, static_argnames=static_params) + except Exception as e: + print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr) + return None + + def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int): + """Apply a chain of effects, using fused compilation if available.""" + chain_key = self._get_chain_key(effect_names, params_list) + + # Try to get or compile fused chain + if chain_key not in self.jax_fused_chains: + fused_fn = self._compile_effect_chain(effect_names, params_list) + self.jax_fused_chains[chain_key] = fused_fn + if fused_fn: + print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr) + + fused_fn = self.jax_fused_chains.get(chain_key) + + if fused_fn is not None: + # Build kwargs with prefixed param names + kwargs = {} + for i, params in enumerate(params_list): + for pname, pval in params.items(): + kwargs[f"_p{i}_{pname}"] = pval + + seed = self.config.get('seed', 42) + + # Handle GPU frames + if hasattr(frame, 'cpu'): + frame = frame.cpu + elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'): + frame = frame.get() + + try: + return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs) + except Exception as e: + print(f"Fused chain error: {e}", file=sys.stderr) + + # Fall back to sequential application + result = frame + for name, params in zip(effect_names, params_list): + result = self._apply_jax_effect(name, result, params, t, frame_num) + if result is None: + return None + return result + + def _force_deferred(self, deferred: DeferredEffectChain): + """Execute a deferred effect chain and return the actual array.""" + if len(deferred.effects) == 0: + return deferred.base_frame + + return self._apply_effect_chain( + deferred.effects, + deferred.params_list, + deferred.base_frame, + deferred.t, + deferred.frame_num + ) + + def _maybe_force(self, value): + """Force a deferred chain if needed, otherwise return as-is.""" + if isinstance(value, DeferredEffectChain): + return self._force_deferred(value) + return value + + def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]: + """ + Compile a vmapped version of an effect chain for batch processing. + + Args: + effect_names: List of effect names in order [innermost, ..., outermost] + params_list: List of param dicts (used to detect static types) + + Returns: + Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames + Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar + """ + if not _JAX_AVAILABLE: + return None + + try: + import jax + import jax.numpy as jnp + + # Get the individual JAX functions + jax_fns = [self.jax_effects[name] for name in effect_names] + + # Pre-extract param info + effect_param_names = [] + static_params = set() + for i, (name, params) in enumerate(zip(effect_names, params_list)): + param_names = list(params.keys()) + effect_param_names.append(param_names) + for pname, pval in params.items(): + if isinstance(pval, (str, bool)): + static_params.add(f"_p{i}_{pname}") + + # Single-frame function (will be vmapped) + def single_frame_fn(frame, t, frame_num, seed, **kwargs): + result = frame + for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)): + effect_kwargs = {} + for pname in param_names: + key = f"_p{i}_{pname}" + if key in kwargs: + effect_kwargs[pname] = kwargs[key] + result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs) + return result + + # Return unbatched function - we'll vmap at call time with proper in_axes + return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params)) + except Exception as e: + print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr) + return None + + def _apply_batched_chain(self, effect_names: list, params_list_batch: list, + frames: list, ts: list, frame_nums: list) -> Optional[list]: + """ + Apply an effect chain to a batch of frames using vmap. + + Args: + effect_names: List of effect names + params_list_batch: List of params_list for each frame in batch + frames: List of input frames + ts: List of time values + frame_nums: List of frame numbers + + Returns: + List of output frames, or None on failure + """ + if not frames: + return [] + + # Use first frame's params for chain key (assume same structure) + chain_key = self._get_chain_key(effect_names, params_list_batch[0]) + batch_key = f"batch:{chain_key}" + + # Compile batched version if needed + if batch_key not in self.jax_batched_chains: + batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0]) + self.jax_batched_chains[batch_key] = batched_fn + if batched_fn: + print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr) + + batched_fn = self.jax_batched_chains.get(batch_key) + + if batched_fn is not None: + try: + import jax + import jax.numpy as jnp + + # Stack frames into batch array + frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames]) + ts_array = jnp.array(ts) + frame_nums_array = jnp.array(frame_nums) + + # Build kwargs - all numeric params as arrays for vmap + kwargs = {} + static_kwargs = {} # Non-vmapped (strings, bools) + + for i, plist in enumerate(zip(*[p for p in params_list_batch])): + for j, pname in enumerate(params_list_batch[0][i].keys()): + key = f"_p{i}_{pname}" + values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]] + + first = values[0] + if isinstance(first, (str, bool)): + # Static params - not vmapped + static_kwargs[key] = first + elif isinstance(first, (int, float)): + # Always batch numeric params for simplicity + kwargs[key] = jnp.array(values) + elif hasattr(first, 'shape'): + kwargs[key] = jnp.stack(values) + else: + kwargs[key] = jnp.array(values) + + seed = self.config.get('seed', 42) + + # Create wrapper that unpacks the params dict + def single_call(frame, t, frame_num, params_dict): + return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs) + + # vmap over frame, t, frame_num, and the params dict (as pytree) + vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0)) + + # Stack kwargs into a dict of arrays (pytree with matching structure) + results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs) + + # Unstack results + return [results[i] for i in range(len(frames))] + except Exception as e: + print(f"Batched chain error: {e}", file=sys.stderr) + + # Fall back to sequential + return None + + def _init(self): + """Initialize from sexp - load primitives, effects, defs, scans.""" + # Set random seed for deterministic output + seed = self.config.get('seed', 42) + try: + from sexp_effects.primitive_libs.core import set_random_seed + set_random_seed(seed) + except ImportError: + pass + + # Load external config files first (they can override recipe definitions) + if self.sources_config: + self._load_config_file(self.sources_config) + if self.audio_config: + self._load_config_file(self.audio_config) + + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + # Resolve friendly name to path + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'include': + i = 1 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + # Resolve friendly name to path + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'audio-playback': + # (audio-playback "path") - set audio file for playback sync + # Skip if already set by config file + if self.audio_playback is None: + path = str(form[1]).strip('"') + # Try to resolve as friendly name first + resolved = self._resolve_name(path) + if resolved: + self.audio_playback = str(resolved) + else: + # Fall back to relative path + self.audio_playback = str((self.sexp_dir / path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + elif cmd == 'def': + # (def name expr) - evaluate and store in globals + # Skip if already defined by config file + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + if name in self.globals: + print(f"Def: {name} (from config, skipped)", file=sys.stderr) + continue + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Def: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + elif cmd == 'frame': + self.frame_pipeline = form[1] if len(form) > 1 else None + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, bool): + return expr + + if isinstance(expr, Symbol): + name = expr.name + # Built-in constants + if name == 'pi': + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Global lookup + if name in self.globals: + return self.globals[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + raise NameError(f"Undefined variable: {name}") + + if isinstance(expr, Keyword): + return expr.name + + # Handle dicts from new parser - evaluate values + if isinstance(expr, dict): + return {k: self._eval(v, env) for k, v in expr.items()} + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check for closure call + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + if op in self.globals: + val = self.globals[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Binding === + + if op == 'bind': + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if len(args) > 1: + key = args[1].name if isinstance(args[1], Keyword) else str(args[1]) + return state.get(key, 0) + return state + return 0 + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + result[key] = val + i += 2 + else: + i += 1 + return result + + if op == 'get': + obj = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(obj, dict): + return obj.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'quote': + return args[0] if args else None + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + val = self._eval(args[0], env) + return len(val) if hasattr(val, '__len__') else 0 + + if op == 'map': + seq = self._eval(args[0], env) + fn = self._eval(args[1], env) + if not isinstance(seq, (list, tuple)): + return [] + # Handle closure (lambda from sexp) + if isinstance(fn, dict) and fn.get('_type') == 'closure': + results = [] + for item in seq: + closure_env = dict(fn['env']) + if fn['params']: + closure_env[fn['params'][0]] = item + results.append(self._eval(fn['body'], closure_env)) + return results + # Handle Python callable + if callable(fn): + return [fn(item) for item in seq] + return [] + + # === Effects === + + if op in self.effects: + # Try to detect and fuse effect chains for JAX acceleration + if self.use_jax and op in self.jax_effects: + chain_info = self._extract_effect_chain(expr, env) + if chain_info is not None: + effect_names, params_list, base_frame_expr = chain_info + # Only use chain if we have 2+ effects (worth fusing) + if len(effect_names) >= 2: + base_frame = self._eval(base_frame_expr, env) + if base_frame is not None and hasattr(base_frame, 'shape'): + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num) + if result is not None: + return result + # Fall through if chain application fails + + effect = self.effects[op] + effect_env = dict(env) + + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + positional_idx = 0 + frame_val = None + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + frame_val = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + # Try JAX-accelerated execution with deferred chaining + if self.use_jax and op in self.jax_effects and frame_val is not None: + # Build params dict for JAX (exclude 'frame') + jax_params = {k: self._maybe_force(v) for k, v in effect_env.items() + if k != 'frame' and k in effect['params']} + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + + # Check if input is a deferred chain - if so, extend it + if isinstance(frame_val, DeferredEffectChain): + return frame_val.extend(op, jax_params) + + # Check if input is a valid frame - create new deferred chain + if hasattr(frame_val, 'shape'): + return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num) + + # Fall through to interpreter if not a valid frame + + # Force any deferred frame before interpreter evaluation + if isinstance(frame_val, DeferredEffectChain): + frame_val = self._force_deferred(frame_val) + effect_env['frame'] = frame_val + + return self._eval(effect['body'], effect_env) + + # === Primitives === + + if op in self.primitives: + prim_func = self.primitives[op] + # Check if this is a GPU primitive (preserves GPU arrays) + is_gpu_prim = op.startswith('gpu:') or 'gpu' in op.lower() + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + # Force deferred chains before passing to primitives + v = self._maybe_force(v) + kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) + i += 2 + else: + val = self._eval(args[i], env) + # Force deferred chains before passing to primitives + val = self._maybe_force(val) + evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + self._record_error(f"Primitive {op} error: {e}") + raise RuntimeError(f"Primitive {op} failed: {e}") + + # === Macros (function-like: args evaluated before binding) === + + if op in self.macros: + macro = self.macros[op] + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + # Evaluate args in calling environment before binding + macro_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(macro['body'], macro_env) + + # Underscore variant lookup + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + # Check if this is a GPU primitive (preserves GPU arrays) + is_gpu_prim = 'gpu' in prim_name.lower() + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) + i += 2 + else: + evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim)) + i += 1 + + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + self._record_error(f"Primitive {op} error: {e}") + raise RuntimeError(f"Primitive {op} failed: {e}") + + # Unknown function call - raise meaningful error + raise RuntimeError(f"Unknown function or primitive: '{op}'. " + f"Available primitives: {sorted(list(self.primitives.keys())[:10])}... " + f"Available effects: {sorted(list(self.effects.keys())[:10])}... " + f"Available macros: {sorted(list(self.macros.keys())[:10])}...") + + def _step_scans(self, ctx: Context, env: dict): + """Step scans based on trigger evaluation.""" + for name, scan in self.scans.items(): + trigger_expr = scan['trigger'] + + # Evaluate trigger in context + should_step = self._eval(trigger_expr, env) + + if should_step: + state = scan['state'] + step_env = dict(state) + step_env.update(env) + + new_state = self._eval(scan['step'], step_env) + if isinstance(new_state, dict): + scan['state'] = new_state + else: + scan['state'] = {'acc': new_state} + + def _restore_checkpoint(self, checkpoint: dict): + """Restore scan states from a checkpoint. + + Called when resuming a render from a previous checkpoint. + + Args: + checkpoint: Dict with 'scans' key containing {scan_name: state_dict} + """ + scans_state = checkpoint.get('scans', {}) + for name, state in scans_state.items(): + if name in self.scans: + self.scans[name]['state'] = dict(state) + print(f"Restored scan '{name}' state from checkpoint", file=sys.stderr) + + def _get_checkpoint_state(self) -> dict: + """Get current scan states for checkpointing. + + Returns: + Dict mapping scan names to their current state dicts + """ + return {name: dict(scan['state']) for name, scan in self.scans.items()} + + def run(self, duration: float = None, output: str = "pipe", resume_from: dict = None): + """Run the streaming pipeline. + + Args: + duration: Duration in seconds (auto-detected from audio if None) + output: Output mode ("pipe", "preview", path/hls, path/ipfs-hls, or file path) + resume_from: Checkpoint dict to resume from, with keys: + - frame_num: Last completed frame + - t: Time value for checkpoint frame + - scans: Dict of scan states to restore + - segment_cids: Dict of quality -> {seg_num: cid} for output resume + """ + # Import output classes - handle both package and direct execution + try: + from .output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput + from .gpu_output import GPUHLSOutput, check_gpu_encode_available + from .multi_res_output import MultiResolutionHLSOutput + except ImportError: + from output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput + try: + from gpu_output import GPUHLSOutput, check_gpu_encode_available + except ImportError: + GPUHLSOutput = None + check_gpu_encode_available = lambda: False + try: + from multi_res_output import MultiResolutionHLSOutput + except ImportError: + MultiResolutionHLSOutput = None + + self._init() + + # Restore checkpoint state if resuming + if resume_from: + self._restore_checkpoint(resume_from) + print(f"Resuming from frame {resume_from.get('frame_num', 0)}", file=sys.stderr) + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w = self.config.get('width', 720) + h = self.config.get('height', 720) + fps = self.config.get('fps', 30) + + if duration is None: + # Try to get duration from audio if available + for name, val in self.globals.items(): + if hasattr(val, 'duration'): + duration = val.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + break + else: + duration = 60.0 + + n_frames = int(duration * fps) + frame_time = 1.0 / fps + + print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr) + + # Create context + ctx = Context(fps=fps) + + # Output (with optional audio sync) + # Resolve audio path lazily here if it wasn't resolved during parsing + audio = self.audio_playback + if audio and not Path(audio).exists(): + # Try to resolve as friendly name (may have failed during parsing) + audio_name = Path(audio).name # Get just the name part + resolved = self._resolve_name(audio_name) + if resolved and resolved.exists(): + audio = str(resolved) + print(f"Lazy resolved audio: {audio}", file=sys.stderr) + else: + raise FileNotFoundError(f"Audio file not found: {audio}") + if output == "pipe": + out = PipeOutput(size=(w, h), fps=fps, audio_source=audio) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=fps, audio_source=audio) + elif output.endswith("/hls"): + # HLS output - output is a directory path ending in /hls + hls_dir = output[:-4] # Remove /hls suffix + out = HLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio) + elif output.endswith("/ipfs-hls"): + # IPFS HLS output - multi-resolution adaptive streaming + hls_dir = output[:-9] # Remove /ipfs-hls suffix + import os + ipfs_gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + + # Build resume state for output if resuming + output_resume = None + if resume_from and resume_from.get('segment_cids'): + output_resume = {'segment_cids': resume_from['segment_cids']} + + # Use multi-resolution output (renders original + 720p + 360p) + if MultiResolutionHLSOutput is not None: + print(f"[StreamInterpreter] Using multi-resolution HLS output ({w}x{h} + 720p + 360p)", file=sys.stderr) + out = MultiResolutionHLSOutput( + hls_dir, + source_size=(w, h), + fps=fps, + ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update, + audio_source=audio, + resume_from=output_resume, + ) + # Fallback to GPU single-resolution if multi-res not available + elif GPUHLSOutput is not None and check_gpu_encode_available(): + print(f"[StreamInterpreter] Using GPU zero-copy encoding (single resolution)", file=sys.stderr) + out = GPUHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update) + else: + out = IPFSHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update) + else: + out = FileOutput(output, size=(w, h), fps=fps, audio_source=audio) + + # Calculate frames per segment based on fps and segment duration (4 seconds default) + segment_duration = 4.0 + self._frames_per_segment = int(fps * segment_duration) + + # Determine start frame (resume from checkpoint + 1, or 0) + start_frame = 0 + if resume_from and resume_from.get('frame_num') is not None: + start_frame = resume_from['frame_num'] + 1 + print(f"Starting from frame {start_frame} (checkpoint was at {resume_from['frame_num']})", file=sys.stderr) + + try: + frame_times = [] + profile_interval = 30 # Profile every N frames + scan_times = [] + eval_times = [] + write_times = [] + + # Batch accumulation for JAX + batch_deferred = [] # Accumulated DeferredEffectChains + batch_times = [] # Corresponding time values + batch_start_frame = 0 + + def flush_batch(): + """Execute accumulated batch and write results.""" + nonlocal batch_deferred, batch_times + if not batch_deferred: + return + + t_flush = time.time() + + # Check if all chains have same structure (can batch) + first = batch_deferred[0] + can_batch = ( + self.use_jax and + len(batch_deferred) >= 2 and + all(d.effects == first.effects for d in batch_deferred) + ) + + if can_batch: + # Try batched execution + frames = [d.base_frame for d in batch_deferred] + ts = [d.t for d in batch_deferred] + frame_nums = [d.frame_num for d in batch_deferred] + params_batch = [d.params_list for d in batch_deferred] + + results = self._apply_batched_chain( + first.effects, params_batch, frames, ts, frame_nums + ) + + if results is not None: + # Write batched results + for result, t in zip(results, batch_times): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + batch_deferred = [] + batch_times = [] + return + + # Fall back to sequential execution + for deferred, t in zip(batch_deferred, batch_times): + result = self._force_deferred(deferred) + if result is not None and hasattr(result, 'shape'): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + + batch_deferred = [] + batch_times = [] + + for frame_num in range(start_frame, n_frames): + if not out.is_open: + break + + frame_start = time.time() + ctx.t = frame_num * frame_time + ctx.frame_num = frame_num + + # Build frame environment with context + frame_env = { + 'ctx': { + 't': ctx.t, + 'frame-num': ctx.frame_num, + 'fps': ctx.fps, + }, + 't': ctx.t, # Also expose t directly for convenience + 'frame-num': ctx.frame_num, + } + + # Step scans + t0 = time.time() + self._step_scans(ctx, frame_env) + scan_times.append(time.time() - t0) + + # Evaluate pipeline + t1 = time.time() + result = self._eval(self.frame_pipeline, frame_env) + eval_times.append(time.time() - t1) + + t2 = time.time() + if result is not None: + if isinstance(result, DeferredEffectChain): + # Accumulate for batching + batch_deferred.append(result) + batch_times.append(ctx.t) + + # Flush when batch is full + if len(batch_deferred) >= self.jax_batch_size: + flush_batch() + else: + # Not deferred - flush any pending batch first, then write + flush_batch() + if hasattr(result, 'shape'): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, ctx.t) + write_times.append(time.time() - t2) + + frame_elapsed = time.time() - frame_start + frame_times.append(frame_elapsed) + + # Checkpoint at segment boundaries (every _frames_per_segment frames) + if frame_num > 0 and frame_num % self._frames_per_segment == 0: + if self.on_checkpoint: + try: + checkpoint = { + 'frame_num': frame_num, + 't': ctx.t, + 'scans': self._get_checkpoint_state(), + } + self.on_checkpoint(checkpoint) + except Exception as e: + print(f"Warning: checkpoint callback failed: {e}", file=sys.stderr) + + # Progress with timing and profile breakdown + if frame_num % profile_interval == 0 and frame_num > 0: + pct = 100 * frame_num / n_frames + avg_ms = 1000 * sum(frame_times[-profile_interval:]) / max(1, len(frame_times[-profile_interval:])) + avg_scan = 1000 * sum(scan_times[-profile_interval:]) / max(1, len(scan_times[-profile_interval:])) + avg_eval = 1000 * sum(eval_times[-profile_interval:]) / max(1, len(eval_times[-profile_interval:])) + avg_write = 1000 * sum(write_times[-profile_interval:]) / max(1, len(write_times[-profile_interval:])) + target_ms = 1000 * frame_time + print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms] scan={avg_scan:.0f}ms eval={avg_eval:.0f}ms write={avg_write:.0f}ms", end="", file=sys.stderr, flush=True) + + # Call progress callback if set (for Celery task state updates) + if self.on_progress: + try: + self.on_progress(pct, frame_num, n_frames) + except Exception as e: + print(f"Warning: progress callback failed: {e}", file=sys.stderr) + + # Flush any remaining batch + flush_batch() + + finally: + out.close() + # Store output for access to properties like playlist_cid + self.output = out + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None, + sources_config: str = None, audio_config: str = None, use_jax: bool = False): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path, use_jax=use_jax) + if fps: + interp.config['fps'] = fps + if sources_config: + interp.sources_config = Path(sources_config) + if audio_config: + interp.audio_config = Path(audio_config) + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None) + parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file") + parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file") + parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps, + sources_config=args.sources_config, audio_config=args.audio_config, + use_jax=args.jax) diff --git a/l1/tasks/__init__.py b/l1/tasks/__init__.py new file mode 100644 index 0000000..6a07c25 --- /dev/null +++ b/l1/tasks/__init__.py @@ -0,0 +1,13 @@ +# art-celery/tasks - Celery tasks for streaming video rendering +# +# Tasks: +# 1. run_stream - Execute a streaming S-expression recipe +# 2. upload_to_ipfs - Background IPFS upload for media files + +from .streaming import run_stream +from .ipfs_upload import upload_to_ipfs + +__all__ = [ + "run_stream", + "upload_to_ipfs", +] diff --git a/l1/tasks/ipfs_upload.py b/l1/tasks/ipfs_upload.py new file mode 100644 index 0000000..541f850 --- /dev/null +++ b/l1/tasks/ipfs_upload.py @@ -0,0 +1,93 @@ +""" +Background IPFS upload task. + +Uploads files to IPFS in the background after initial local storage. +This allows fast uploads while still getting IPFS CIDs eventually. +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Optional + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from celery_app import app +import ipfs_client + +logger = logging.getLogger(__name__) + + +@app.task(bind=True, max_retries=3, default_retry_delay=60) +def upload_to_ipfs(self, local_cid: str, actor_id: str) -> Optional[str]: + """ + Upload a locally cached file to IPFS in the background. + + Args: + local_cid: The local content hash of the file + actor_id: The user who uploaded the file + + Returns: + IPFS CID if successful, None if failed + """ + from cache_manager import get_cache_manager + import asyncio + import database + + logger.info(f"Background IPFS upload starting for {local_cid[:16]}...") + + try: + cache_mgr = get_cache_manager() + + # Get the file path from local cache + file_path = cache_mgr.get_by_cid(local_cid) + if not file_path or not file_path.exists(): + logger.error(f"File not found for local CID {local_cid[:16]}...") + return None + + # Upload to IPFS + logger.info(f"Uploading {file_path} to IPFS...") + ipfs_cid = ipfs_client.add_file(file_path) + + if not ipfs_cid: + logger.error(f"IPFS upload failed for {local_cid[:16]}...") + raise self.retry(exc=Exception("IPFS upload failed")) + + logger.info(f"IPFS upload successful: {local_cid[:16]}... -> {ipfs_cid[:16]}...") + + # Update database with IPFS CID + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Initialize database pool if needed + loop.run_until_complete(database.init_pool()) + + # Update cache_items table + loop.run_until_complete( + database.update_cache_item_ipfs_cid(local_cid, ipfs_cid) + ) + + # Update friendly_names table to use IPFS CID instead of local hash + # This ensures assets can be fetched by remote workers via IPFS + try: + loop.run_until_complete( + database.update_friendly_name_cid(actor_id, local_cid, ipfs_cid) + ) + logger.info(f"Friendly name updated: {local_cid[:16]}... -> {ipfs_cid[:16]}...") + except Exception as e: + logger.warning(f"Failed to update friendly name CID: {e}") + + # Create index from IPFS CID to local cache + cache_mgr._set_content_index(ipfs_cid, local_cid) + + logger.info(f"Database updated with IPFS CID for {local_cid[:16]}...") + finally: + loop.close() + + return ipfs_cid + + except Exception as e: + logger.error(f"Background IPFS upload error: {e}") + raise self.retry(exc=e) diff --git a/l1/tasks/streaming.py b/l1/tasks/streaming.py new file mode 100644 index 0000000..7ac6057 --- /dev/null +++ b/l1/tasks/streaming.py @@ -0,0 +1,724 @@ +""" +Streaming video rendering task. + +Executes S-expression recipes for frame-by-frame video processing. +Supports CID and friendly name references for assets. +Supports pause/resume/restart for long renders. +""" + +import hashlib +import logging +import os +import signal +import sys +import tempfile +from pathlib import Path +from typing import Dict, Optional + +from celery import current_task + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from celery_app import app +from cache_manager import get_cache_manager + +logger = logging.getLogger(__name__) + + +class PauseRequested(Exception): + """Raised when user requests pause via SIGTERM.""" + pass + +# Debug: verify module is being loaded +print(f"DEBUG MODULE LOAD: tasks/streaming.py loaded at {__file__}", file=sys.stderr) + +# Module-level event loop for database operations +_resolve_loop = None +_db_initialized = False + + +def resolve_asset(ref: str, actor_id: Optional[str] = None) -> Optional[Path]: + """ + Resolve an asset reference (CID or friendly name) to a file path. + + Args: + ref: CID or friendly name (e.g., "my-video" or "QmXyz...") + actor_id: User ID for friendly name resolution + + Returns: + Path to the asset file, or None if not found + """ + global _resolve_loop, _db_initialized + import sys + print(f"RESOLVE_ASSET: ref={ref}, actor_id={actor_id}", file=sys.stderr) + cache_mgr = get_cache_manager() + + # Try as direct CID first + path = cache_mgr.get_by_cid(ref) + print(f"RESOLVE_ASSET: get_by_cid({ref}) = {path}", file=sys.stderr) + if path and path.exists(): + logger.info(f"Resolved {ref[:16]}... as CID to {path}") + return path + + # Try as friendly name if actor_id provided + print(f"RESOLVE_ASSET: trying friendly name lookup, actor_id={actor_id}", file=sys.stderr) + if actor_id: + from database import resolve_friendly_name_sync, get_ipfs_cid_sync + + try: + # Use synchronous database functions to avoid event loop issues + cid = resolve_friendly_name_sync(actor_id, ref) + print(f"RESOLVE_ASSET: resolve_friendly_name_sync({actor_id}, {ref}) = {cid}", file=sys.stderr) + + if cid: + path = cache_mgr.get_by_cid(cid) + print(f"RESOLVE_ASSET: get_by_cid({cid}) = {path}", file=sys.stderr) + if path and path.exists(): + print(f"RESOLVE_ASSET: SUCCESS - resolved to {path}", file=sys.stderr) + logger.info(f"Resolved '{ref}' via friendly name to {path}") + return path + + # File not in local cache - look up IPFS CID and fetch + # The cid from friendly_names is internal, need to get ipfs_cid from cache_items + ipfs_cid = get_ipfs_cid_sync(cid) + if not ipfs_cid or ipfs_cid == cid: + # No separate IPFS CID, try using the cid directly (might be IPFS CID) + ipfs_cid = cid + print(f"RESOLVE_ASSET: file not local, trying IPFS fetch for {ipfs_cid}", file=sys.stderr) + import ipfs_client + content = ipfs_client.get_bytes(ipfs_cid, use_gateway_fallback=True) + if content: + # Save to local cache + import tempfile + from pathlib import Path + with tempfile.NamedTemporaryFile(delete=False, suffix='.sexp') as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + # Store in cache + cached_file, _ = cache_mgr.put(tmp_path, node_type="effect", skip_ipfs=True) + # Index by IPFS CID for future lookups + cache_mgr._set_content_index(cid, cached_file.cid) + print(f"RESOLVE_ASSET: fetched from IPFS and cached at {cached_file.path}", file=sys.stderr) + logger.info(f"Fetched '{ref}' from IPFS and cached at {cached_file.path}") + return cached_file.path + else: + print(f"RESOLVE_ASSET: IPFS fetch failed for {cid}", file=sys.stderr) + except Exception as e: + print(f"RESOLVE_ASSET: ERROR - {e}", file=sys.stderr) + logger.warning(f"Failed to resolve friendly name '{ref}': {e}") + + logger.warning(f"Could not resolve asset reference: {ref}") + return None + + +class CIDVideoSource: + """ + Video source that resolves CIDs to file paths. + + Wraps the streaming VideoSource to work with cached assets. + """ + + def __init__(self, cid: str, fps: float = 30, actor_id: Optional[str] = None): + self.cid = cid + self.fps = fps + self.actor_id = actor_id + self._source = None + + def _ensure_source(self): + if self._source is None: + logger.info(f"CIDVideoSource._ensure_source: resolving cid={self.cid} with actor_id={self.actor_id}") + path = resolve_asset(self.cid, self.actor_id) + if not path: + raise ValueError(f"Could not resolve video source '{self.cid}' for actor_id={self.actor_id}") + + logger.info(f"CIDVideoSource._ensure_source: resolved to path={path}") + # Use GPU-accelerated video source if available + try: + from sexp_effects.primitive_libs.streaming_gpu import GPUVideoSource, GPU_AVAILABLE + if GPU_AVAILABLE: + logger.info(f"CIDVideoSource: using GPUVideoSource for {path}") + self._source = GPUVideoSource(str(path), self.fps, prefer_gpu=True) + else: + raise ImportError("GPU not available") + except (ImportError, Exception) as e: + logger.info(f"CIDVideoSource: falling back to CPU VideoSource ({e})") + from sexp_effects.primitive_libs.streaming import VideoSource + self._source = VideoSource(str(path), self.fps) + + def read_at(self, t: float): + self._ensure_source() + return self._source.read_at(t) + + def read(self): + self._ensure_source() + return self._source.read() + + @property + def size(self): + self._ensure_source() + return self._source.size + + @property + def duration(self): + self._ensure_source() + return self._source._duration + + @property + def path(self): + self._ensure_source() + return self._source.path + + @property + def _stream_time(self): + self._ensure_source() + return self._source._stream_time + + def skip(self): + self._ensure_source() + return self._source.skip() + + def close(self): + if self._source: + self._source.close() + + +class CIDAudioAnalyzer: + """ + Audio analyzer that resolves CIDs to file paths. + """ + + def __init__(self, cid: str, actor_id: Optional[str] = None): + self.cid = cid + self.actor_id = actor_id + self._analyzer = None + + def _ensure_analyzer(self): + if self._analyzer is None: + path = resolve_asset(self.cid, self.actor_id) + if not path: + raise ValueError(f"Could not resolve audio source: {self.cid}") + + from sexp_effects.primitive_libs.streaming import AudioAnalyzer + self._analyzer = AudioAnalyzer(str(path)) + + def get_energy(self, t: float) -> float: + self._ensure_analyzer() + return self._analyzer.get_energy(t) + + def get_beat(self, t: float) -> bool: + self._ensure_analyzer() + return self._analyzer.get_beat(t) + + def get_beat_count(self, t: float) -> int: + self._ensure_analyzer() + return self._analyzer.get_beat_count(t) + + @property + def duration(self): + self._ensure_analyzer() + return self._analyzer.duration + + +def create_cid_primitives(actor_id: Optional[str] = None): + """ + Create CID-aware primitive functions. + + Returns dict of primitives that resolve CIDs before creating sources. + """ + from celery.utils.log import get_task_logger + cid_logger = get_task_logger(__name__) + def prim_make_video_source_cid(cid: str, fps: float = 30): + cid_logger.warning(f"DEBUG: CID-aware make-video-source: cid={cid}, actor_id={actor_id}") + return CIDVideoSource(cid, fps, actor_id) + + def prim_make_audio_analyzer_cid(cid: str): + cid_logger.warning(f"DEBUG: CID-aware make-audio-analyzer: cid={cid}, actor_id={actor_id}") + return CIDAudioAnalyzer(cid, actor_id) + + return { + 'streaming:make-video-source': prim_make_video_source_cid, + 'streaming:make-audio-analyzer': prim_make_audio_analyzer_cid, + } + + +@app.task(bind=True, name='tasks.run_stream') +def run_stream( + self, + run_id: str, + recipe_sexp: str, + output_name: str = "output.mp4", + duration: Optional[float] = None, + fps: Optional[float] = None, + actor_id: Optional[str] = None, + sources_sexp: Optional[str] = None, + audio_sexp: Optional[str] = None, + resume: bool = False, +) -> dict: + """ + Execute a streaming S-expression recipe. + + Args: + run_id: The run ID for database tracking + recipe_sexp: The recipe S-expression content + output_name: Name for the output file + duration: Optional duration override (seconds) + fps: Optional FPS override + actor_id: User ID for friendly name resolution + sources_sexp: Optional sources config S-expression + audio_sexp: Optional audio config S-expression + resume: If True, load checkpoint and resume from where we left off + + Returns: + Dict with output_cid, output_path, and status + """ + global _resolve_loop, _db_initialized + task_id = self.request.id + logger.info(f"Starting stream task {task_id} for run {run_id} (resume={resume})") + + # Handle graceful pause (SIGTERM from Celery revoke) + pause_requested = False + original_sigterm = signal.getsignal(signal.SIGTERM) + + def handle_sigterm(signum, frame): + nonlocal pause_requested + pause_requested = True + logger.info(f"Pause requested for run {run_id} (SIGTERM received)") + + signal.signal(signal.SIGTERM, handle_sigterm) + + self.update_state(state='INITIALIZING', meta={'progress': 0}) + + # Get the app directory for primitive/effect paths + app_dir = Path(__file__).parent.parent # celery/ + sexp_effects_dir = app_dir / "sexp_effects" + effects_dir = app_dir / "effects" + templates_dir = app_dir / "templates" + + # Create temp directory for work + work_dir = Path(tempfile.mkdtemp(prefix="stream_")) + recipe_path = work_dir / "recipe.sexp" + + # Write output to shared cache for live streaming access + cache_dir = Path(os.environ.get("CACHE_DIR", "/data/cache")) + stream_dir = cache_dir / "streaming" / run_id + stream_dir.mkdir(parents=True, exist_ok=True) + # Use IPFS HLS output for distributed streaming - segments uploaded to IPFS + output_path = str(stream_dir) + "/ipfs-hls" # /ipfs-hls suffix triggers IPFS HLS mode + + # Create symlinks to effect directories so relative paths work + (work_dir / "sexp_effects").symlink_to(sexp_effects_dir) + (work_dir / "effects").symlink_to(effects_dir) + (work_dir / "templates").symlink_to(templates_dir) + + try: + # Write recipe to temp file + recipe_path.write_text(recipe_sexp) + + # Write optional config files + sources_path = None + if sources_sexp: + sources_path = work_dir / "sources.sexp" + sources_path.write_text(sources_sexp) + + audio_path = None + if audio_sexp: + audio_path = work_dir / "audio.sexp" + audio_path.write_text(audio_sexp) + + self.update_state(state='RENDERING', meta={'progress': 5}) + + # Import the streaming interpreter + from streaming.stream_sexp_generic import StreamInterpreter + + # Load checkpoint if resuming + checkpoint = None + if resume: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + checkpoint = _resolve_loop.run_until_complete(database.get_run_checkpoint(run_id)) + if checkpoint: + logger.info(f"Loaded checkpoint for run {run_id}: frame {checkpoint.get('frame_num')}") + else: + logger.warning(f"No checkpoint found for run {run_id}, starting from beginning") + except Exception as e: + logger.error(f"Failed to load checkpoint: {e}") + checkpoint = None + + # Create interpreter (pass actor_id for friendly name resolution) + interp = StreamInterpreter(str(recipe_path), actor_id=actor_id) + + # Set primitive library directory explicitly + interp.primitive_lib_dir = sexp_effects_dir / "primitive_libs" + + if fps: + interp.config['fps'] = fps + if sources_path: + interp.sources_config = sources_path + if audio_path: + interp.audio_config = audio_path + + # Override primitives with CID-aware versions + cid_prims = create_cid_primitives(actor_id) + from celery.utils.log import get_task_logger + task_logger = get_task_logger(__name__) + task_logger.warning(f"DEBUG: Overriding primitives: {list(cid_prims.keys())}") + task_logger.warning(f"DEBUG: Primitives before: {list(interp.primitives.keys())[:10]}...") + interp.primitives.update(cid_prims) + task_logger.warning(f"DEBUG: streaming:make-video-source is now: {type(interp.primitives.get('streaming:make-video-source'))}") + + # Set up callback to update database when IPFS playlist is created (for live HLS redirect) + def on_playlist_update(playlist_cid, quality_playlists=None): + """Update database with playlist CID and quality info. + + Args: + playlist_cid: Master playlist CID + quality_playlists: Dict of quality name -> {cid, width, height, bitrate} + """ + global _resolve_loop, _db_initialized + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_playlist(run_id, playlist_cid, quality_playlists)) + logger.info(f"Updated pending run {run_id} with IPFS playlist: {playlist_cid}, qualities: {list(quality_playlists.keys()) if quality_playlists else []}") + except Exception as e: + logger.error(f"Failed to update playlist CID in database: {e}") + + interp.on_playlist_update = on_playlist_update + + # Set up progress callback to update Celery task state + def on_progress(pct, frame_num, total_frames): + nonlocal pause_requested + # Scale progress: 5% (start) to 85% (before caching) + scaled_progress = 5 + (pct * 0.8) # 5% to 85% + self.update_state(state='RENDERING', meta={ + 'progress': scaled_progress, + 'frame': frame_num, + 'total_frames': total_frames, + 'percent': pct, + }) + + interp.on_progress = on_progress + + # Set up checkpoint callback to save state at segment boundaries + def on_checkpoint(ckpt): + """Save checkpoint state to database.""" + nonlocal pause_requested + global _resolve_loop, _db_initialized + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + + # Get total frames from interpreter config + total_frames = None + if hasattr(interp, 'output') and hasattr(interp.output, '_frame_count'): + # Estimate total frames based on duration + fps_val = interp.config.get('fps', 30) + for name, val in interp.globals.items(): + if hasattr(val, 'duration'): + total_frames = int(val.duration * fps_val) + break + + _resolve_loop.run_until_complete(database.update_pending_run_checkpoint( + run_id=run_id, + checkpoint_frame=ckpt['frame_num'], + checkpoint_t=ckpt['t'], + checkpoint_scans=ckpt.get('scans'), + total_frames=total_frames, + )) + logger.info(f"Saved checkpoint for run {run_id}: frame {ckpt['frame_num']}") + + # Check if pause was requested after checkpoint + if pause_requested: + logger.info(f"Pause requested after checkpoint, raising PauseRequested") + raise PauseRequested("Render paused by user") + + except PauseRequested: + raise # Re-raise to stop the render + except Exception as e: + logger.error(f"Failed to save checkpoint: {e}") + + interp.on_checkpoint = on_checkpoint + + # Build resume state for the interpreter (includes segment CIDs for output) + resume_from = None + if checkpoint: + resume_from = { + 'frame_num': checkpoint.get('frame_num'), + 't': checkpoint.get('t'), + 'scans': checkpoint.get('scans', {}), + } + # Add segment CIDs if available (from quality_playlists in checkpoint) + # Note: We need to extract segment_cids from the output's state, which isn't + # directly stored. For now, the output will re-check existing segments on disk. + + # Run rendering to file + logger.info(f"Rendering to {output_path}" + (f" (resuming from frame {resume_from['frame_num']})" if resume_from else "")) + render_paused = False + try: + interp.run(duration=duration, output=str(output_path), resume_from=resume_from) + except PauseRequested: + # Graceful pause - checkpoint already saved + render_paused = True + logger.info(f"Render paused for run {run_id}") + + # Restore original signal handler + signal.signal(signal.SIGTERM, original_sigterm) + + if render_paused: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status(run_id, 'paused')) + except Exception as e: + logger.error(f"Failed to update status to paused: {e}") + return {"status": "paused", "run_id": run_id, "task_id": task_id} + + # Check for interpreter errors + if interp.errors: + error_msg = f"Rendering failed with {len(interp.errors)} errors: {interp.errors[0]}" + raise RuntimeError(error_msg) + + self.update_state(state='CACHING', meta={'progress': 90}) + + # Get IPFS playlist CID if available (from IPFSHLSOutput) + ipfs_playlist_cid = None + ipfs_playlist_url = None + segment_cids = {} + if hasattr(interp, 'output') and hasattr(interp.output, 'playlist_cid'): + ipfs_playlist_cid = interp.output.playlist_cid + ipfs_playlist_url = interp.output.playlist_url + segment_cids = getattr(interp.output, 'segment_cids', {}) + logger.info(f"IPFS HLS: playlist={ipfs_playlist_cid}, segments={len(segment_cids)}") + + # Update pending run with playlist CID for live HLS redirect + if ipfs_playlist_cid: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_playlist(run_id, ipfs_playlist_cid)) + logger.info(f"Updated pending run {run_id} with IPFS playlist CID: {ipfs_playlist_cid}") + except Exception as e: + logger.error(f"Failed to update pending run with playlist CID: {e}") + raise # Fail fast - database errors should not be silently ignored + + # HLS output creates playlist and segments + # - Single-res: stream_dir/stream.m3u8 and stream_dir/segment_*.ts + # - Multi-res: stream_dir/original/playlist.m3u8 and stream_dir/original/segment_*.ts + hls_playlist = stream_dir / "stream.m3u8" + if not hls_playlist.exists(): + # Try multi-res output path + hls_playlist = stream_dir / "original" / "playlist.m3u8" + + # Validate HLS output (must have playlist and at least one segment) + if not hls_playlist.exists(): + raise RuntimeError("HLS playlist not created - rendering likely failed") + + segments = list(stream_dir.glob("segment_*.ts")) + if not segments: + # Try multi-res output path + segments = list(stream_dir.glob("original/segment_*.ts")) + if not segments: + raise RuntimeError("No HLS segments created - rendering likely failed") + + logger.info(f"HLS rendering complete: {len(segments)} segments created, IPFS playlist: {ipfs_playlist_cid}") + + # Mux HLS segments into a single MP4 for persistent cache storage + final_mp4 = stream_dir / "output.mp4" + import subprocess + mux_cmd = [ + "ffmpeg", "-y", + "-i", str(hls_playlist), + "-c", "copy", # Just copy streams, no re-encoding + "-movflags", "+faststart", # Move moov atom to start for web playback + "-fflags", "+genpts", # Generate proper timestamps + str(final_mp4) + ] + logger.info(f"Muxing HLS to MP4: {' '.join(mux_cmd)}") + result = subprocess.run(mux_cmd, capture_output=True, text=True) + if result.returncode != 0: + logger.warning(f"HLS mux failed: {result.stderr}") + # Fall back to using the first segment for caching + final_mp4 = segments[0] + + # Store output in cache + if final_mp4.exists(): + cache_mgr = get_cache_manager() + cached_file, ipfs_cid = cache_mgr.put( + source_path=final_mp4, + node_type="STREAM_OUTPUT", + node_id=f"stream_{task_id}", + ) + + logger.info(f"Stream output cached: CID={cached_file.cid}, IPFS={ipfs_cid}") + + # Save to database - reuse the module-level loop to avoid pool conflicts + import asyncio + import database + + try: + # Reuse or create event loop + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + + # Initialize database pool if needed + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + + # Get recipe CID from pending_run + pending = _resolve_loop.run_until_complete(database.get_pending_run(run_id)) + recipe_cid = pending.get("recipe", "streaming") if pending else "streaming" + + # Save to run_cache for completed runs + logger.info(f"Saving run {run_id} to run_cache with actor_id={actor_id}") + _resolve_loop.run_until_complete(database.save_run_cache( + run_id=run_id, + output_cid=cached_file.cid, + recipe=recipe_cid, + inputs=[], + ipfs_cid=ipfs_cid, + actor_id=actor_id, + )) + # Register output as video type so frontend displays it correctly + _resolve_loop.run_until_complete(database.add_item_type( + cid=cached_file.cid, + actor_id=actor_id, + item_type="video", + path=str(cached_file.path), + description=f"Stream output from run {run_id}", + )) + logger.info(f"Registered output {cached_file.cid} as video type") + # Update pending run status + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="completed", + )) + logger.info(f"Saved run {run_id} to database with actor_id={actor_id}") + except Exception as db_err: + logger.error(f"Failed to save run to database: {db_err}") + raise RuntimeError(f"Database error saving run {run_id}: {db_err}") from db_err + + return { + "status": "completed", + "run_id": run_id, + "task_id": task_id, + "output_cid": cached_file.cid, + "ipfs_cid": ipfs_cid, + "output_path": str(cached_file.path), + # IPFS HLS streaming info + "ipfs_playlist_cid": ipfs_playlist_cid, + "ipfs_playlist_url": ipfs_playlist_url, + "ipfs_segment_count": len(segment_cids), + } + else: + # Update pending run status to failed - reuse module loop + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="failed", + error="Output file not created", + )) + except Exception as db_err: + logger.warning(f"Failed to update run status: {db_err}") + + return { + "status": "failed", + "run_id": run_id, + "task_id": task_id, + "error": "Output file not created", + } + + except Exception as e: + logger.error(f"Stream task {task_id} failed: {e}") + import traceback + traceback.print_exc() + + # Update pending run status to failed - reuse module loop + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="failed", + error=str(e), + )) + except Exception as db_err: + logger.warning(f"Failed to update run status: {db_err}") + + return { + "status": "failed", + "run_id": run_id, + "task_id": task_id, + "error": str(e), + } + + finally: + # Cleanup temp directory only - NOT the streaming directory! + # The streaming directory contains HLS segments that may still be uploading + # to IPFS. Deleting it prematurely causes upload failures and missing segments. + # + # stream_dir cleanup should happen via: + # 1. A separate cleanup task that runs after confirming IPFS uploads succeeded + # 2. Or a periodic cleanup job that removes old streaming dirs + import shutil + if work_dir.exists(): + shutil.rmtree(work_dir, ignore_errors=True) + # NOTE: stream_dir is intentionally NOT deleted here to allow IPFS uploads to complete + # TODO: Implement a deferred cleanup mechanism for stream_dir after IPFS confirmation diff --git a/l1/templates/crossfade-zoom.sexp b/l1/templates/crossfade-zoom.sexp new file mode 100644 index 0000000..fc6d9ad --- /dev/null +++ b/l1/templates/crossfade-zoom.sexp @@ -0,0 +1,25 @@ +;; Crossfade with Zoom Transition +;; +;; Macro for transitioning between two frames with a zoom effect. +;; Active frame zooms out while next frame zooms in. +;; +;; Required context: +;; - zoom effect must be loaded +;; - blend effect must be loaded +;; +;; Parameters: +;; active-frame: current frame +;; next-frame: frame to transition to +;; fade-amt: transition progress (0 = all active, 1 = all next) +;; +;; Usage: +;; (include :path "../templates/crossfade-zoom.sexp") +;; ... +;; (crossfade-zoom active-frame next-frame 0.5) + +(defmacro crossfade-zoom (active-frame next-frame fade-amt) + (let [active-zoom (+ 1.0 fade-amt) + active-zoomed (zoom active-frame :amount active-zoom) + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (zoom next-frame :amount next-zoom)] + (blend active-zoomed next-zoomed :opacity fade-amt))) diff --git a/l1/templates/cycle-crossfade.sexp b/l1/templates/cycle-crossfade.sexp new file mode 100644 index 0000000..40a87ca --- /dev/null +++ b/l1/templates/cycle-crossfade.sexp @@ -0,0 +1,65 @@ +;; cycle-crossfade template +;; +;; Generalized cycling zoom-crossfade for any number of video layers. +;; Cycles through videos with smooth zoom-based crossfade transitions. +;; +;; Parameters: +;; beat-data - beat analysis node (drives timing) +;; input-videos - list of video nodes to cycle through +;; init-clen - initial cycle length in beats +;; +;; NOTE: The parameter is named "input-videos" (not "videos") because +;; template substitution replaces param names everywhere in the AST. +;; The planner's _expand_slice_on injects env["videos"] at plan time, +;; so (len videos) inside the lambda references that injected value. + +(deftemplate cycle-crossfade + (beat-data input-videos init-clen) + + (slice-on beat-data + :videos input-videos + :init {:cycle 0 :beat 0 :clen init-clen} + :fn (lambda [acc i start end] + (let [beat (get acc "beat") + clen (get acc "clen") + active (get acc "cycle") + n (len videos) + phase3 (* beat 3) + wt (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (- 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.0)) + (if (= active prev) + (if (< phase3 clen) 0.0 + (if (< phase3 (* clen 2)) + (* (/ (- phase3 clen) clen) 1.0) + 1.0)) + 0.0)))) + zm (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + ;; Active video: normal -> zoom out during transition -> tiny + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (+ 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.1)) + (if (= active prev) + ;; Incoming video: tiny -> zoom in during transition -> normal + (if (< phase3 clen) 0.1 + (if (< phase3 (* clen 2)) + (+ 0.1 (* (/ (- phase3 clen) clen) 0.9)) + 1.0)) + 0.1)))) + new-acc (if (< (+ beat 1) clen) + (dict :cycle active :beat (+ beat 1) :clen clen) + (dict :cycle (mod (+ active 1) n) :beat 0 + :clen (+ 40 (mod (* i 7) 41))))] + {:layers (map (lambda [p] + {:video p :effects [{:effect zoom :amount (zm p)}]}) + (range 0 n)) + :compose {:effect blend_multi :mode "alpha" + :weights (map (lambda [p] (wt p)) (range 0 n))} + :acc new-acc})))) diff --git a/l1/templates/process-pair.sexp b/l1/templates/process-pair.sexp new file mode 100644 index 0000000..6720cd2 --- /dev/null +++ b/l1/templates/process-pair.sexp @@ -0,0 +1,112 @@ +;; process-pair template +;; +;; Reusable video-pair processor: takes a single video source, creates two +;; clips (A and B) with opposite rotations and sporadic effects, blends them, +;; and applies a per-pair slow rotation driven by a beat scan. +;; +;; All sporadic triggers (invert, hue-shift, ascii) and pair-level controls +;; (blend opacity, rotation) are defined internally using seed offsets. +;; +;; Parameters: +;; video - source video node +;; energy - energy analysis node (drives rotation/zoom amounts) +;; beat-data - beat analysis node (drives sporadic triggers) +;; rng - RNG object from (make-rng seed) for auto-derived seeds +;; rot-dir - initial rotation direction: 1 (clockwise) or -1 (anti-clockwise) +;; rot-a/b - rotation ranges for clip A/B (e.g. [0 45]) +;; zoom-a/b - zoom ranges for clip A/B (e.g. [1 1.5]) + +(deftemplate process-pair + (video energy beat-data rng rot-dir rot-a rot-b zoom-a zoom-b) + + ;; --- Sporadic triggers for clip A --- + + ;; Invert: 10% chance per beat, lasts 1-5 beats + (def inv-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + ;; Hue shift: 10% chance, random hue 30-330 deg, lasts 1-5 beats + (def hue-a (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + ;; ASCII art: 5% chance, lasts 1-3 beats + (def ascii-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Sporadic triggers for clip B (offset seeds) --- + + (def inv-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + (def hue-b (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + (def ascii-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Pair-level controls --- + + ;; Internal A/B blend: randomly show A (0), both (0.5), or B (1), every 1-11 beats + (def pair-mix (scan beat-data :rng rng + :init (dict :rem 0 :opacity 0.5) + :step (if (> rem 0) + (dict :rem (- rem 1) :opacity opacity) + (dict :rem (rand-int 1 11) :opacity (* (rand-int 0 2) 0.5))) + :emit opacity)) + + ;; Per-pair rotation: one full rotation every 20-30 beats, alternating direction + (def pair-rot (scan beat-data :rng rng + :init (dict :beat 0 :clen 25 :dir rot-dir :angle 0) + :step (if (< (+ beat 1) clen) + (dict :beat (+ beat 1) :clen clen :dir dir + :angle (+ angle (* dir (/ 360 clen)))) + (dict :beat 0 :clen (rand-int 20 30) :dir (* dir -1) + :angle angle)) + :emit angle)) + + ;; --- Clip A processing --- + (def clip-a (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-a (-> clip-a + (effect rotate :angle (bind energy values :range rot-a)) + (effect zoom :amount (bind energy values :range zoom-a)) + (effect invert :amount (bind inv-a values)) + (effect hue_shift :degrees (bind hue-a values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-a values)) + )) + + ;; --- Clip B processing --- + (def clip-b (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-b (-> clip-b + (effect rotate :angle (bind energy values :range rot-b)) + (effect zoom :amount (bind energy values :range zoom-b)) + (effect invert :amount (bind inv-b values)) + (effect hue_shift :degrees (bind hue-b values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-b values)) + )) + + ;; --- Blend A+B and apply pair rotation --- + (-> rotated-a + (effect blend rotated-b + :mode "alpha" :opacity (bind pair-mix values) :resize_mode "fit") + (effect rotate + :angle (bind pair-rot values)))) diff --git a/l1/templates/scan-oscillating-spin.sexp b/l1/templates/scan-oscillating-spin.sexp new file mode 100644 index 0000000..051f079 --- /dev/null +++ b/l1/templates/scan-oscillating-spin.sexp @@ -0,0 +1,28 @@ +;; Oscillating Spin Scan +;; +;; Accumulates rotation angle on each beat, reversing direction +;; periodically for an oscillating effect. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; +;; Provides scan: spin +;; Bind with: (bind spin :angle) ;; cumulative rotation angle +;; +;; Behavior: +;; - Rotates 14.4 degrees per beat (completes 360 in 25 beats) +;; - After 20-30 beats, reverses direction +;; - Creates a swinging/oscillating rotation effect +;; +;; Usage: +;; (include :path "../templates/scan-oscillating-spin.sexp") +;; +;; In frame: +;; (rotate frame :angle (bind spin :angle)) + +(scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :left 25} + :step (if (> left 0) + (dict :angle (+ angle (* dir 14.4)) :dir dir :left (- left 1)) + (dict :angle angle :dir (* dir -1) + :left (+ 20 (mod (streaming:audio-beat-count music t) 11))))) diff --git a/l1/templates/scan-ripple-drops.sexp b/l1/templates/scan-ripple-drops.sexp new file mode 100644 index 0000000..7caf720 --- /dev/null +++ b/l1/templates/scan-ripple-drops.sexp @@ -0,0 +1,41 @@ +;; Beat-Triggered Ripple Drops Scan +;; +;; Creates random ripple drops triggered by audio beats. +;; Each drop has a random center position and duration. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; - core primitives loaded +;; +;; Provides scan: ripple-state +;; Bind with: (bind ripple-state :gate) ;; 0 or 1 +;; (bind ripple-state :cx) ;; center x (0-1) +;; (bind ripple-state :cy) ;; center y (0-1) +;; +;; Parameters: +;; trigger-chance: probability per beat (default 0.15) +;; min-duration: minimum beats (default 1) +;; max-duration: maximum beats (default 15) +;; +;; Usage: +;; (include :path "../templates/scan-ripple-drops.sexp") +;; ;; Uses default: 15% chance, 1-15 beat duration +;; +;; In frame: +;; (let [rip-gate (bind ripple-state :gate) +;; rip-amp (* rip-gate (core:map-range e 0 1 5 50))] +;; (ripple frame +;; :amplitude rip-amp +;; :center_x (bind ripple-state :cx) +;; :center_y (bind ripple-state :cy))) + +(scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 0.5 :cy 0.5 :left 0} + :step (if (> left 0) + (dict :gate 1 :cx cx :cy cy :left (- left 1)) + (if (< (core:rand) 0.15) + (dict :gate 1 + :cx (+ 0.2 (* (core:rand) 0.6)) + :cy (+ 0.2 (* (core:rand) 0.6)) + :left (+ 1 (mod (streaming:audio-beat-count music t) 15))) + (dict :gate 0 :cx 0.5 :cy 0.5 :left 0)))) diff --git a/l1/templates/standard-effects.sexp b/l1/templates/standard-effects.sexp new file mode 100644 index 0000000..ce4a92f --- /dev/null +++ b/l1/templates/standard-effects.sexp @@ -0,0 +1,22 @@ +;; Standard Effects Bundle +;; +;; Loads commonly-used video effects. +;; Include after primitives are loaded. +;; +;; Effects provided: +;; - rotate: rotation by angle +;; - zoom: scale in/out +;; - blend: alpha blend two frames +;; - ripple: water ripple distortion +;; - invert: color inversion +;; - hue_shift: hue rotation +;; +;; Usage: +;; (include :path "../templates/standard-effects.sexp") + +(effect rotate :name "fx-rotate") +(effect zoom :name "fx-zoom") +(effect blend :name "fx-blend") +(effect ripple :name "fx-ripple") +(effect invert :name "fx-invert") +(effect hue_shift :name "fx-hue-shift") diff --git a/l1/templates/standard-primitives.sexp b/l1/templates/standard-primitives.sexp new file mode 100644 index 0000000..6e2c62d --- /dev/null +++ b/l1/templates/standard-primitives.sexp @@ -0,0 +1,14 @@ +;; Standard Primitives Bundle +;; +;; Loads all commonly-used primitive libraries. +;; Include this at the top of streaming recipes. +;; +;; Usage: +;; (include :path "../templates/standard-primitives.sexp") + +(require-primitives "geometry") +(require-primitives "core") +(require-primitives "image") +(require-primitives "blending") +(require-primitives "color_ops") +(require-primitives "streaming") diff --git a/l1/templates/stream-process-pair.sexp b/l1/templates/stream-process-pair.sexp new file mode 100644 index 0000000..55f408e --- /dev/null +++ b/l1/templates/stream-process-pair.sexp @@ -0,0 +1,72 @@ +;; stream-process-pair template (streaming-compatible) +;; +;; Macro for processing a video source pair with full effects. +;; Reads source, applies A/B effects (rotate, zoom, invert, hue), blends, +;; and applies pair-level rotation. +;; +;; Required context (must be defined in calling scope): +;; - sources: array of video sources +;; - pair-configs: array of {:dir :rot-a :rot-b :zoom-a :zoom-b} configs +;; - pair-states: array from (bind pairs :states) +;; - now: current time (t) +;; - e: audio energy (0-1) +;; +;; Required effects (must be loaded): +;; - rotate, zoom, invert, hue_shift, blend +;; +;; Usage: +;; (include :path "../templates/stream-process-pair.sexp") +;; ...in frame pipeline... +;; (let [pair-states (bind pairs :states) +;; now t +;; e (streaming:audio-energy music now)] +;; (process-pair 0)) ;; process source at index 0 + +(require-primitives "core") + +(defmacro process-pair (src-idx) + (let [src (nth sources src-idx) + frame (streaming:source-read src now) + cfg (nth pair-configs src-idx) + state (nth pair-states src-idx) + + ;; Get state values (invert uses countdown > 0) + inv-a-active (if (> (get state :inv-a) 0) 1 0) + inv-b-active (if (> (get state :inv-b) 0) 1 0) + ;; Hue is active only when countdown > 0 + hue-a-val (if (> (get state :hue-a) 0) (get state :hue-a-val) 0) + hue-b-val (if (> (get state :hue-b) 0) (get state :hue-b-val) 0) + mix-opacity (get state :mix) + pair-rot-angle (* (get state :angle) (get cfg :dir)) + + ;; Get config values for energy-mapped ranges + rot-a-max (get cfg :rot-a) + rot-b-max (get cfg :rot-b) + zoom-a-max (get cfg :zoom-a) + zoom-b-max (get cfg :zoom-b) + + ;; Energy-driven rotation and zoom + rot-a (core:map-range e 0 1 0 rot-a-max) + rot-b (core:map-range e 0 1 0 rot-b-max) + zoom-a (core:map-range e 0 1 1 zoom-a-max) + zoom-b (core:map-range e 0 1 1 zoom-b-max) + + ;; Apply effects to clip A + clip-a (-> frame + (rotate :angle rot-a) + (zoom :amount zoom-a) + (invert :amount inv-a-active) + (hue_shift :degrees hue-a-val)) + + ;; Apply effects to clip B + clip-b (-> frame + (rotate :angle rot-b) + (zoom :amount zoom-b) + (invert :amount inv-b-active) + (hue_shift :degrees hue-b-val)) + + ;; Blend A+B + blended (blend clip-a clip-b :opacity mix-opacity)] + + ;; Apply pair-level rotation + (rotate blended :angle pair-rot-angle))) diff --git a/l1/test_autonomous.sexp b/l1/test_autonomous.sexp new file mode 100644 index 0000000..9e190a2 --- /dev/null +++ b/l1/test_autonomous.sexp @@ -0,0 +1,36 @@ +;; Autonomous Pipeline Test +;; +;; Uses the autonomous-pipeline primitive which computes ALL parameters +;; on GPU - including sin/cos expressions. Zero Python in the hot path! + +(stream "autonomous_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + + ;; Effects pipeline (what effects to apply) + (def effects + [{:op "rotate" :angle 0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Dynamic expressions (computed on GPU!) + ;; These use CUDA syntax: sinf(), cosf(), t (time), frame_num + (def expressions + {:rotate_angle "t * 30.0f" + :ripple_phase "t * 2.0f" + :brightness_factor "0.8f + 0.4f * sinf(t * 2.0f)"}) + + ;; Frame pipeline - creates image ON GPU and applies autonomous pipeline + (frame + (let [;; Create base image ON GPU (no CPU involvement!) + base (streaming_gpu:gpu-make-image 1920 1080 [128 100 200])] + + ;; Apply autonomous pipeline - ALL EFFECTS + ALL MATH ON GPU! + (streaming_gpu:autonomous-pipeline base effects expressions frame-num 30.0)))) diff --git a/l1/test_autonomous_prealloc.py b/l1/test_autonomous_prealloc.py new file mode 100644 index 0000000..5fde7f7 --- /dev/null +++ b/l1/test_autonomous_prealloc.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +""" +Test autonomous pipeline with pre-allocated buffer. +This eliminates ALL Python from the hot path. +""" + +import time +import sys +sys.path.insert(0, '/app') + +import cupy as cp +from streaming.sexp_to_cuda import compile_autonomous_pipeline + +def test_autonomous_prealloc(): + width, height = 1920, 1080 + n_frames = 300 + fps = 30.0 + + print(f"Testing {n_frames} frames at {width}x{height}") + print("=" * 60) + + # Pre-allocate frame buffer (stays on GPU) + frame = cp.zeros((height, width, 3), dtype=cp.uint8) + frame[:, :, 0] = 128 # R + frame[:, :, 1] = 100 # G + frame[:, :, 2] = 200 # B + + # Define effects + effects = [ + {'op': 'rotate', 'angle': 0}, + {'op': 'hue_shift', 'degrees': 30}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + # Dynamic expressions (computed on GPU) + dynamic_expressions = { + 'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)', + } + + # Compile autonomous pipeline + print("Compiling autonomous pipeline...") + pipeline = compile_autonomous_pipeline(effects, width, height, dynamic_expressions) + + # Warmup + output = pipeline(frame, 0, fps) + cp.cuda.Stream.null.synchronize() + + # Benchmark - ZERO Python in the hot path! + print(f"Running {n_frames} frames...") + start = time.time() + for i in range(n_frames): + output = pipeline(frame, i, fps) + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + ms_per_frame = elapsed / n_frames * 1000 + actual_fps = n_frames / elapsed + + print("=" * 60) + print(f"Time: {ms_per_frame:.2f}ms per frame") + print(f"FPS: {actual_fps:.0f}") + print(f"Real-time: {actual_fps / 30:.1f}x (at 30fps target)") + print("=" * 60) + + # Compare with original baseline + print(f"\nOriginal Python sexp: ~150ms = 6 fps") + print(f"Autonomous GPU: {ms_per_frame:.2f}ms = {actual_fps:.0f} fps") + print(f"Speedup: {150 / ms_per_frame:.0f}x faster!") + + +if __name__ == '__main__': + test_autonomous_prealloc() diff --git a/l1/test_full_optimized.py b/l1/test_full_optimized.py new file mode 100644 index 0000000..6d7ae48 --- /dev/null +++ b/l1/test_full_optimized.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Full Optimized GPU Pipeline Test + +This demonstrates the maximum performance achievable: +1. Pre-allocated GPU frame buffer +2. Autonomous CUDA kernel (all params computed on GPU) +3. GPU HLS encoder (zero-copy NVENC) + +The entire pipeline runs on GPU with zero CPU involvement per frame! +""" + +import time +import sys +import os + +sys.path.insert(0, '/app') + +import cupy as cp +import numpy as np +from streaming.sexp_to_cuda import compile_autonomous_pipeline + +# Try to import GPU encoder +try: + from streaming.gpu_output import GPUHLSOutput, check_gpu_encode_available + GPU_ENCODE = check_gpu_encode_available() +except: + GPU_ENCODE = False + +def run_optimized_stream(duration: float = 10.0, fps: float = 30.0, output_dir: str = '/tmp/optimized'): + width, height = 1920, 1080 + n_frames = int(duration * fps) + + print("=" * 60) + print("FULL OPTIMIZED GPU PIPELINE") + print("=" * 60) + print(f"Resolution: {width}x{height}") + print(f"Duration: {duration}s ({n_frames} frames @ {fps}fps)") + print(f"GPU encode: {GPU_ENCODE}") + print("=" * 60) + + # Pre-allocate frame buffer on GPU + print("\n[1/4] Pre-allocating GPU frame buffer...") + frame = cp.zeros((height, width, 3), dtype=cp.uint8) + # Create a gradient pattern + y_grad = cp.linspace(0, 255, height, dtype=cp.float32)[:, cp.newaxis] + x_grad = cp.linspace(0, 255, width, dtype=cp.float32)[cp.newaxis, :] + frame[:, :, 0] = (y_grad * 0.5).astype(cp.uint8) # R + frame[:, :, 1] = (x_grad * 0.5).astype(cp.uint8) # G + frame[:, :, 2] = 128 # B + + # Define effects + effects = [ + {'op': 'rotate', 'angle': 0}, + {'op': 'hue_shift', 'degrees': 30}, + {'op': 'ripple', 'amplitude': 20, 'frequency': 12, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + # Dynamic expressions (computed on GPU) + dynamic_expressions = { + 'rotate_angle': 't * 45.0f', # 45 degrees per second + 'ripple_phase': 't * 3.0f', # Ripple animation + 'brightness_factor': '0.7f + 0.3f * sinf(t * 2.0f)', # Pulsing brightness + } + + # Compile autonomous pipeline + print("[2/4] Compiling autonomous CUDA kernel...") + pipeline = compile_autonomous_pipeline(effects, width, height, dynamic_expressions) + + # Setup output + print("[3/4] Setting up output...") + os.makedirs(output_dir, exist_ok=True) + + if GPU_ENCODE: + print(" Using GPU HLS encoder (zero-copy)") + out = GPUHLSOutput(output_dir, size=(width, height), fps=fps) + else: + print(" Using ffmpeg encoder") + import subprocess + cmd = [ + 'ffmpeg', '-y', + '-f', 'rawvideo', '-vcodec', 'rawvideo', + '-pix_fmt', 'rgb24', '-s', f'{width}x{height}', '-r', str(fps), + '-i', '-', + '-c:v', 'h264_nvenc', '-preset', 'p4', '-cq', '18', + '-pix_fmt', 'yuv420p', + f'{output_dir}/output.mp4' + ] + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Warmup + output = pipeline(frame, 0, fps) + cp.cuda.Stream.null.synchronize() + + # Run the pipeline! + print(f"[4/4] Running {n_frames} frames...") + print("-" * 60) + + frame_times = [] + start_total = time.time() + + for i in range(n_frames): + frame_start = time.time() + + # Apply effects (autonomous kernel - all on GPU!) + output = pipeline(frame, i, fps) + + # Write output + if GPU_ENCODE: + out.write(output, i / fps) + else: + # Transfer to CPU for ffmpeg (slower path) + cpu_frame = cp.asnumpy(output) + proc.stdin.write(cpu_frame.tobytes()) + + cp.cuda.Stream.null.synchronize() + frame_times.append(time.time() - frame_start) + + # Progress + if (i + 1) % 30 == 0: + avg_ms = sum(frame_times[-30:]) / 30 * 1000 + print(f" Frame {i+1}/{n_frames}: {avg_ms:.1f}ms/frame") + + total_time = time.time() - start_total + + # Cleanup + if GPU_ENCODE: + out.close() + else: + proc.stdin.close() + proc.wait() + + # Results + print("-" * 60) + avg_ms = sum(frame_times) / len(frame_times) * 1000 + actual_fps = n_frames / total_time + + print("\nRESULTS:") + print("=" * 60) + print(f"Total time: {total_time:.2f}s") + print(f"Avg per frame: {avg_ms:.2f}ms") + print(f"Actual FPS: {actual_fps:.0f}") + print(f"Real-time: {actual_fps / fps:.1f}x") + print("=" * 60) + + if GPU_ENCODE: + print(f"\nOutput: {output_dir}/*.ts (HLS segments)") + else: + print(f"\nOutput: {output_dir}/output.mp4") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--duration', type=float, default=10.0) + parser.add_argument('-o', '--output', default='/tmp/optimized') + parser.add_argument('--fps', type=float, default=30.0) + args = parser.parse_args() + + run_optimized_stream(args.duration, args.fps, args.output) diff --git a/l1/test_funky_text.py b/l1/test_funky_text.py new file mode 100644 index 0000000..342ef1c --- /dev/null +++ b/l1/test_funky_text.py @@ -0,0 +1,542 @@ +#!/usr/bin/env python3 +""" +Funky comparison tests: PIL vs TextStrip system. +Tests colors, opacity, fonts, sizes, edge positions, clipping, overlaps, etc. +""" + +import numpy as np +import jax.numpy as jnp +from PIL import Image, ImageDraw, ImageFont +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font +) + +FONTS = { + 'sans': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + 'bold': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', + 'serif': '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf', + 'serif_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf', + 'mono': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + 'mono_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf', + 'narrow': '/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf', + 'italic': '/usr/share/fonts/truetype/freefont/FreeSansOblique.ttf', +} + + +def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100), + fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0, + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with PIL directly, including color/opacity.""" + frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8) + # For opacity, render to RGBA then composite + txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(txt_layer) + font = _load_font(font_path, font_size) + + if stroke_fill is None: + stroke_fill = (0, 0, 0) + + # PIL fill with alpha for opacity + alpha_int = int(round(opacity * 255)) + fill_rgba = (*fill, alpha_int) + stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None + + if multiline: + draw.multiline_text((x, y), text, fill=fill_rgba, font=font, + stroke_width=stroke_width, stroke_fill=stroke_rgba, + spacing=line_spacing, align=align, anchor=anchor) + else: + draw.text((x, y), text, fill=fill_rgba, font=font, + stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor) + + # Composite onto background + bg_img = Image.fromarray(frame).convert('RGBA') + result = Image.alpha_composite(bg_img, txt_layer) + return np.array(result.convert('RGB')) + + +def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100), + fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0, + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with TextStrip system.""" + frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8) + + strip = render_text_strip( + text, font_path, font_size, + stroke_width=stroke_width, stroke_fill=stroke_fill, + anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align + ) + strip_img = jnp.asarray(strip.image) + color = jnp.array(fill, dtype=jnp.float32) + + result = place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + return np.array(result) + + +def compare(name, tolerance=0, **kwargs): + """Compare PIL and TextStrip rendering.""" + pil = render_pil(**kwargs) + strip = render_strip(**kwargs) + + diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff == 0: + print(f" PASS: {name}") + return True + + if tolerance > 0: + best_diff = diff.copy() + for dy in range(-tolerance, tolerance + 1): + for dx in range(-tolerance, tolerance + 1): + if dy == 0 and dx == 0: + continue + shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1) + sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16)) + best_diff = np.minimum(best_diff, sdiff) + if best_diff.max() == 0: + print(f" PASS: {name} (within {tolerance}px tolerance)") + return True + + print(f" FAIL: {name}") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + Image.fromarray(pil).save(f"/tmp/pil_{name}.png") + Image.fromarray(strip).save(f"/tmp/strip_{name}.png") + diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png") + print(f" Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png") + return False + + +def test_colors(): + """Test various text colors on various backgrounds.""" + print("\n--- Colors ---") + results = [] + + # White on black (baseline) + results.append(compare("color_white_on_black", + text="Hello", x=20, y=30, fill=(255, 255, 255), bg=(0, 0, 0))) + + # Red on black + results.append(compare("color_red", + text="Red Text", x=20, y=30, fill=(255, 0, 0), bg=(0, 0, 0))) + + # Green on black + results.append(compare("color_green", + text="Green!", x=20, y=30, fill=(0, 255, 0), bg=(0, 0, 0))) + + # Blue on black + results.append(compare("color_blue", + text="Blue sky", x=20, y=30, fill=(0, 100, 255), bg=(0, 0, 0))) + + # Yellow on dark gray + results.append(compare("color_yellow_on_gray", + text="Yellow", x=20, y=30, fill=(255, 255, 0), bg=(40, 40, 40))) + + # Magenta on white + results.append(compare("color_magenta_on_white", + text="Magenta", x=20, y=30, fill=(255, 0, 255), bg=(255, 255, 255))) + + # Subtle: gray text on slightly lighter gray + results.append(compare("color_subtle_gray", + text="Subtle", x=20, y=30, fill=(128, 128, 128), bg=(64, 64, 64))) + + # Orange on deep blue + results.append(compare("color_orange_on_blue", + text="Warm", x=20, y=30, fill=(255, 165, 0), bg=(0, 0, 80))) + + return results + + +def test_opacity(): + """Test different opacity levels.""" + print("\n--- Opacity ---") + results = [] + + results.append(compare("opacity_100", + text="Full", x=20, y=30, opacity=1.0)) + + results.append(compare("opacity_75", + text="75%", x=20, y=30, opacity=0.75)) + + results.append(compare("opacity_50", + text="Half", x=20, y=30, opacity=0.5)) + + results.append(compare("opacity_25", + text="Quarter", x=20, y=30, opacity=0.25)) + + results.append(compare("opacity_10", + text="Ghost", x=20, y=30, opacity=0.1)) + + # Opacity on colored background + results.append(compare("opacity_on_colored_bg", + text="Overlay", x=20, y=30, fill=(255, 255, 255), bg=(100, 0, 0), + opacity=0.5)) + + # Color + opacity combo + results.append(compare("opacity_red_on_green", + text="Blend", x=20, y=30, fill=(255, 0, 0), bg=(0, 100, 0), + opacity=0.6)) + + return results + + +def test_fonts(): + """Test different fonts and sizes.""" + print("\n--- Fonts & Sizes ---") + results = [] + + for label, path in FONTS.items(): + results.append(compare(f"font_{label}", + text="Quick Fox", x=20, y=30, font_path=path, font_size=28, + frame_size=(300, 80))) + + # Tiny text + results.append(compare("size_tiny", + text="Tiny text at 12px", x=10, y=15, font_size=12, + frame_size=(200, 40))) + + # Big text + results.append(compare("size_big", + text="BIG", x=20, y=10, font_size=72, + frame_size=(300, 100))) + + # Huge text + results.append(compare("size_huge", + text="XL", x=10, y=10, font_size=120, + frame_size=(300, 160))) + + return results + + +def test_anchors(): + """Test all anchor combinations.""" + print("\n--- Anchors ---") + results = [] + + # All horizontal x vertical combos + for h in ['l', 'm', 'r']: + for v in ['a', 'm', 's', 'd']: + anchor = h + v + results.append(compare(f"anchor_{anchor}", + text="Anchor", x=200, y=50, anchor=anchor, + frame_size=(400, 100))) + + return results + + +def test_strokes(): + """Test various stroke widths and colors.""" + print("\n--- Strokes ---") + results = [] + + for sw in [1, 2, 3, 4, 6, 8]: + results.append(compare(f"stroke_w{sw}", + text="Stroke", x=30, y=20, font_size=40, + stroke_width=sw, stroke_fill=(0, 0, 0), + frame_size=(300, 80))) + + # Colored strokes + results.append(compare("stroke_red", + text="Red outline", x=20, y=20, font_size=36, + stroke_width=3, stroke_fill=(255, 0, 0), + frame_size=(350, 80))) + + results.append(compare("stroke_white_on_black", + text="Glow", x=20, y=20, font_size=40, + fill=(255, 255, 255), stroke_width=4, stroke_fill=(0, 0, 255), + frame_size=(250, 80))) + + # Stroke with bold font + results.append(compare("stroke_bold", + text="Bold+Stroke", x=20, y=20, + font_path=FONTS['bold'], font_size=36, + stroke_width=3, stroke_fill=(0, 0, 0), + frame_size=(400, 80))) + + # Stroke + colored text on colored bg + results.append(compare("stroke_colored_on_bg", + text="Party", x=20, y=20, font_size=48, + fill=(255, 255, 0), bg=(50, 0, 80), + stroke_width=3, stroke_fill=(255, 0, 0), + frame_size=(300, 80))) + + return results + + +def test_edge_clipping(): + """Test text at frame edges - clipping behavior.""" + print("\n--- Edge Clipping ---") + results = [] + + # Text at very left edge + results.append(compare("clip_left_edge", + text="LEFT", x=0, y=30, frame_size=(200, 80))) + + # Text partially off right edge + results.append(compare("clip_right_edge", + text="RIGHT SIDE", x=150, y=30, frame_size=(200, 80))) + + # Text at top edge + results.append(compare("clip_top", + text="TOP", x=20, y=0, frame_size=(200, 80))) + + # Text at bottom edge - partially clipped + results.append(compare("clip_bottom", + text="BOTTOM", x=20, y=55, font_size=40, + frame_size=(200, 80))) + + # Large text overflowing all sides from center + results.append(compare("clip_overflow_center", + text="HUGE", x=75, y=40, font_size=100, anchor="mm", + frame_size=(150, 80))) + + # Corner placement + results.append(compare("clip_corner_br", + text="Corner", x=350, y=70, font_size=36, + frame_size=(400, 100))) + + return results + + +def test_multiline_fancy(): + """Test multiline with various styles.""" + print("\n--- Multiline Fancy ---") + results = [] + + # Right-aligned (1px tolerance: same sub-pixel issue as center alignment) + results.append(compare("multi_right", + text="Right\nAligned\nText", x=380, y=20, + frame_size=(400, 150), + multiline=True, anchor="ra", align="right", + tolerance=1)) + + # Center + stroke + results.append(compare("multi_center_stroke", + text="Title\nSubtitle", x=200, y=20, + font_size=32, frame_size=(400, 120), + multiline=True, anchor="ma", align="center", + stroke_width=2, stroke_fill=(0, 0, 0), + tolerance=1)) + + # Wide line spacing + results.append(compare("multi_wide_spacing", + text="Line A\nLine B\nLine C", x=20, y=10, + frame_size=(300, 200), + multiline=True, line_spacing=20)) + + # Zero extra spacing + results.append(compare("multi_tight_spacing", + text="Tight\nPacked\nLines", x=20, y=10, + frame_size=(300, 150), + multiline=True, line_spacing=0)) + + # Many lines + results.append(compare("multi_many_lines", + text="One\nTwo\nThree\nFour\nFive\nSix", x=20, y=5, + font_size=20, frame_size=(200, 200), + multiline=True, line_spacing=4)) + + # Bold multiline with stroke + results.append(compare("multi_bold_stroke", + text="BOLD\nSTROKE", x=20, y=10, + font_path=FONTS['bold'], font_size=48, + stroke_width=3, stroke_fill=(200, 0, 0), + frame_size=(350, 150), multiline=True)) + + # Multiline on colored bg with opacity + results.append(compare("multi_opacity_on_bg", + text="Semi\nTransparent", x=20, y=10, + fill=(255, 255, 0), bg=(0, 50, 100), opacity=0.7, + frame_size=(300, 120), multiline=True)) + + return results + + +def test_special_chars(): + """Test special characters and edge cases.""" + print("\n--- Special Characters ---") + results = [] + + # Numbers and symbols + results.append(compare("chars_numbers", + text="0123456789", x=20, y=30, frame_size=(300, 80))) + + results.append(compare("chars_punctuation", + text="Hello, World! (v2.0)", x=10, y=30, frame_size=(350, 80))) + + results.append(compare("chars_symbols", + text="@#$%^&*+=", x=20, y=30, frame_size=(300, 80))) + + # Single character + results.append(compare("chars_single", + text="X", x=50, y=30, font_size=48, frame_size=(100, 80))) + + # Very long text (clipped) + results.append(compare("chars_long", + text="The quick brown fox jumps over the lazy dog", x=10, y=30, + font_size=24, frame_size=(400, 80))) + + # Mixed case + results.append(compare("chars_mixed_case", + text="AaBbCcDdEeFf", x=10, y=30, frame_size=(350, 80))) + + return results + + +def test_combos(): + """Complex combinations of features.""" + print("\n--- Combos ---") + results = [] + + # Big bold stroke + color + opacity + results.append(compare("combo_all_features", + text="EPIC", x=20, y=10, + font_path=FONTS['bold'], font_size=64, + fill=(255, 200, 0), bg=(20, 0, 40), opacity=0.85, + stroke_width=4, stroke_fill=(180, 0, 0), + frame_size=(350, 100))) + + # Small mono on busy background + results.append(compare("combo_mono_code", + text="fn main() {}", x=10, y=15, + font_path=FONTS['mono'], font_size=16, + fill=(0, 255, 100), bg=(30, 30, 30), + frame_size=(250, 50))) + + # Serif italic multiline with stroke + results.append(compare("combo_serif_italic_multi", + text="Once upon\na time...", x=20, y=10, + font_path=FONTS['italic'], font_size=28, + stroke_width=1, stroke_fill=(80, 80, 80), + frame_size=(300, 120), multiline=True)) + + # Narrow font, big stroke, center anchored + results.append(compare("combo_narrow_stroke_center", + text="NARROW", x=150, y=40, + font_path=FONTS['narrow'], font_size=44, + stroke_width=5, stroke_fill=(0, 0, 0), + anchor="mm", frame_size=(300, 80))) + + # Multiple strips on same frame (simulated by sequential placement) + results.append(compare("combo_opacity_blend", + text="Layered", x=20, y=30, + fill=(255, 0, 0), bg=(0, 0, 255), opacity=0.5, + font_size=48, frame_size=(300, 80))) + + return results + + +def test_multi_strip_overlay(): + """Test placing multiple strips on the same frame.""" + print("\n--- Multi-Strip Overlay ---") + results = [] + + frame_size = (400, 150) + bg = (20, 20, 40) + + # PIL version - multiple draw calls + font1 = _load_font(FONTS['bold'], 48) + font2 = _load_font(None, 24) + font3 = _load_font(FONTS['mono'], 18) + + pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8) + txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(txt_layer) + draw.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1, + stroke_width=2, stroke_fill=(0, 0, 0, 255)) + draw.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2) + draw.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3) + bg_img = Image.fromarray(pil_frame).convert('RGBA') + pil_result = np.array(Image.alpha_composite(bg_img, txt_layer).convert('RGB')) + + # Strip version - multiple placements + frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8) + + s1 = render_text_strip("TITLE", FONTS['bold'], 48, stroke_width=2, stroke_fill=(0, 0, 0)) + s2 = render_text_strip("Subtitle here", None, 24) + s3 = render_text_strip("code_snippet()", FONTS['mono'], 18) + + frame = place_text_strip_jax( + frame, jnp.asarray(s1.image), 20, 10, + s1.baseline_y, s1.bearing_x, + jnp.array([255, 255, 0], dtype=jnp.float32), 1.0, + anchor_x=s1.anchor_x, anchor_y=s1.anchor_y, + stroke_width=s1.stroke_width) + + frame = place_text_strip_jax( + frame, jnp.asarray(s2.image), 20, 70, + s2.baseline_y, s2.bearing_x, + jnp.array([200, 200, 200], dtype=jnp.float32), 1.0, + anchor_x=s2.anchor_x, anchor_y=s2.anchor_y, + stroke_width=s2.stroke_width) + + frame = place_text_strip_jax( + frame, jnp.asarray(s3.image), 20, 110, + s3.baseline_y, s3.bearing_x, + jnp.array([0, 255, 128], dtype=jnp.float32), 200/255, + anchor_x=s3.anchor_x, anchor_y=s3.anchor_y, + stroke_width=s3.stroke_width) + + strip_result = np.array(frame) + + diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff <= 1: + print(f" PASS: multi_overlay (max_diff={max_diff})") + results.append(True) + else: + print(f" FAIL: multi_overlay") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png") + Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png") + diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png") + print(f" Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png") + results.append(False) + + return results + + +def main(): + print("=" * 60) + print("Funky TextStrip vs PIL Comparison") + print("=" * 60) + + all_results = [] + all_results.extend(test_colors()) + all_results.extend(test_opacity()) + all_results.extend(test_fonts()) + all_results.extend(test_anchors()) + all_results.extend(test_strokes()) + all_results.extend(test_edge_clipping()) + all_results.extend(test_multiline_fancy()) + all_results.extend(test_special_chars()) + all_results.extend(test_combos()) + all_results.extend(test_multi_strip_overlay()) + + print("\n" + "=" * 60) + passed = sum(all_results) + total = len(all_results) + print(f"Results: {passed}/{total} passed") + if passed == total: + print("ALL TESTS PASSED!") + else: + failed = [i for i, r in enumerate(all_results) if not r] + print(f"FAILED: {total - passed} tests (indices: {failed})") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/l1/test_fused_direct.py b/l1/test_fused_direct.py new file mode 100644 index 0000000..5b87462 --- /dev/null +++ b/l1/test_fused_direct.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Direct test of fused pipeline primitive. + +Compares performance of: +1. Fused kernel (single CUDA kernel for all effects) +2. Separate kernels (one CUDA kernel per effect) +""" + +import time +import sys + +# Check for CuPy +try: + import cupy as cp + print("[test] CuPy available") +except ImportError: + print("[test] CuPy not available - can't run test") + sys.exit(1) + +# Add path for imports +sys.path.insert(0, '/app') + +from streaming.sexp_to_cuda import compile_frame_pipeline +from streaming.jit_compiler import fast_rotate, fast_hue_shift, fast_ripple + +def test_fused_vs_separate(): + """Compare fused vs separate kernel performance.""" + + width, height = 1920, 1080 + n_frames = 100 + + # Create test frame + frame = cp.random.randint(0, 255, (height, width, 3), dtype=cp.uint8) + + # Define effects pipeline + effects = [ + {'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 30.0}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + ] + + print(f"\n[test] Testing {n_frames} frames at {width}x{height}") + print(f"[test] Effects: rotate, hue_shift, ripple\n") + + # ========== Test fused kernel ========== + print("[test] Compiling fused kernel...") + pipeline = compile_frame_pipeline(effects, width, height) + + # Warmup + output = pipeline(frame, rotate_angle=45, ripple_phase=0) + cp.cuda.Stream.null.synchronize() + + print("[test] Running fused kernel benchmark...") + start = time.time() + for i in range(n_frames): + output = pipeline(frame, rotate_angle=i * 3.6, ripple_phase=i * 0.1) + cp.cuda.Stream.null.synchronize() + fused_time = time.time() - start + + fused_ms = fused_time / n_frames * 1000 + fused_fps = n_frames / fused_time + print(f"[test] Fused kernel: {fused_ms:.2f}ms/frame ({fused_fps:.0f} fps)") + + # ========== Test separate kernels ========== + print("\n[test] Running separate kernels benchmark...") + + # Warmup + temp = fast_rotate(frame, 45.0) + temp = fast_hue_shift(temp, 30.0) + temp = fast_ripple(temp, 15, frequency=10, decay=2, phase=0, center_x=960, center_y=540) + cp.cuda.Stream.null.synchronize() + + start = time.time() + for i in range(n_frames): + temp = fast_rotate(frame, i * 3.6) + temp = fast_hue_shift(temp, 30.0) + temp = fast_ripple(temp, 15, frequency=10, decay=2, phase=i * 0.1, center_x=960, center_y=540) + cp.cuda.Stream.null.synchronize() + separate_time = time.time() - start + + separate_ms = separate_time / n_frames * 1000 + separate_fps = n_frames / separate_time + print(f"[test] Separate kernels: {separate_ms:.2f}ms/frame ({separate_fps:.0f} fps)") + + # ========== Summary ========== + speedup = separate_time / fused_time + print(f"\n{'='*50}") + print(f"SPEEDUP: {speedup:.1f}x faster with fused kernel") + print(f"") + print(f"Fused: {fused_ms:.2f}ms ({fused_fps:.0f} fps)") + print(f"Separate: {separate_ms:.2f}ms ({separate_fps:.0f} fps)") + print(f"{'='*50}") + + # Compare with original Python sexp interpreter baseline (126-205ms) + python_baseline_ms = 150 # Approximate from profiling + vs_python = python_baseline_ms / fused_ms + print(f"\nVs Python sexp interpreter (~{python_baseline_ms}ms): {vs_python:.0f}x faster!") + + +if __name__ == '__main__': + test_fused_vs_separate() diff --git a/l1/test_fused_pipeline.sexp b/l1/test_fused_pipeline.sexp new file mode 100644 index 0000000..72ee033 --- /dev/null +++ b/l1/test_fused_pipeline.sexp @@ -0,0 +1,44 @@ +;; Test Fused Pipeline - Should be much faster than interpreted +;; +;; This uses the fused-pipeline primitive which compiles all effects +;; into a single CUDA kernel instead of interpreting them one by one. + +(stream "fused_pipeline_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + (require-primitives "math") + + ;; Define the effects pipeline (compiled to single CUDA kernel) + ;; Each effect is a map with :op and effect-specific parameters + (def effects-pipeline + [{:op "rotate" :angle 0} + {:op "zoom" :amount 1.0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Frame pipeline + (frame + (let [;; Create a gradient image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters (change per frame) + angle (* t 30) + zoom (+ 1.0 (* 0.2 (math:sin (* t 0.5)))) + phase (* t 2)] + + ;; Apply fused pipeline - all effects in ONE CUDA kernel! + (streaming_gpu:fused-pipeline base effects-pipeline + :rotate_angle angle + :zoom_amount zoom + :ripple_phase phase)))) diff --git a/l1/test_heavy_fused.sexp b/l1/test_heavy_fused.sexp new file mode 100644 index 0000000..d421cfb --- /dev/null +++ b/l1/test_heavy_fused.sexp @@ -0,0 +1,39 @@ +;; Heavy Fused Pipeline Test +;; +;; Tests with many effects to show the full benefit of kernel fusion + +(stream "heavy_fused_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + (require-primitives "math") + + ;; Define heavy effects pipeline (4 effects fused into one kernel) + (def heavy-pipeline + [{:op "rotate" :angle 0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Frame pipeline + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters + angle (* t 30) + phase (* t 2)] + + ;; 4 effects in ONE kernel call! + (streaming_gpu:fused-pipeline base heavy-pipeline + :rotate_angle angle + :ripple_phase phase)))) diff --git a/l1/test_heavy_interpreted.sexp b/l1/test_heavy_interpreted.sexp new file mode 100644 index 0000000..72c7965 --- /dev/null +++ b/l1/test_heavy_interpreted.sexp @@ -0,0 +1,38 @@ +;; Heavy Interpreted Pipeline Test +;; +;; Same effects as test_heavy_fused.sexp but using separate primitive calls + +(stream "heavy_interpreted_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "geometry_gpu") + (require-primitives "color_ops") + (require-primitives "image") + (require-primitives "math") + + ;; Frame pipeline - INTERPRETED (separate calls) + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters + angle (* t 30) + zoom (+ 1.0 (* 0.1 (math:sin (* t 0.5)))) + phase (* t 2) + + ;; Apply 4 effects one by one (INTERPRETED - many kernel launches) + step1 (geometry_gpu:rotate base angle) + step2 (color_ops:hue-shift step1 30) + step3 (geometry_gpu:ripple step2 15 :frequency 10 :decay 2 :time phase) + step4 (color_ops:brightness step3 1.0)] + + step4))) diff --git a/l1/test_interpreted_vs_fused.sexp b/l1/test_interpreted_vs_fused.sexp new file mode 100644 index 0000000..0a365bd --- /dev/null +++ b/l1/test_interpreted_vs_fused.sexp @@ -0,0 +1,37 @@ +;; Test: Interpreted vs Fused Pipeline Comparison +;; +;; This simulates a typical effects pipeline using the INTERPRETED approach +;; (calling primitives one by one through Python). + +(stream "interpreted_pipeline_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "geometry_gpu") + (require-primitives "color_ops") + (require-primitives "image") + (require-primitives "math") + + ;; Frame pipeline - INTERPRETED approach (one primitive call at a time) + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Apply effects one by one (INTERPRETED - slow) + angle (* t 30) + rotated (geometry_gpu:rotate base angle) + + hued (color_ops:hue-shift rotated 30) + + brightness (+ 0.8 (* 0.4 (math:sin (* t 2)))) + bright (color_ops:brightness hued brightness)] + + bright))) diff --git a/l1/test_pil_options.py b/l1/test_pil_options.py new file mode 100644 index 0000000..fb5ffb9 --- /dev/null +++ b/l1/test_pil_options.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Explore PIL text options and test if we can match them. +""" + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +def load_font(font_name=None, font_size=32): + """Load a font.""" + candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', + '/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', + ] + for path in candidates: + if path is None: + continue + try: + return ImageFont.truetype(path, font_size) + except (IOError, OSError): + continue + return ImageFont.load_default() + + +def test_pil_options(): + """Test various PIL text options.""" + + # Create a test frame + frame_size = (600, 400) + + font = load_font(None, 36) + font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36) + font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36) + + tests = [] + + # Test 1: Basic text + def basic_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font) + return img, "basic" + tests.append(basic_text) + + # Test 2: Stroke/outline + def stroke_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font, + stroke_width=2, stroke_fill=(255, 0, 0, 255)) + return img, "stroke" + tests.append(stroke_text) + + # Test 3: Bold font + def bold_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold) + return img, "bold" + tests.append(bold_text) + + # Test 4: Italic font + def italic_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic) + return img, "italic" + tests.append(italic_text) + + # Test 5: Different anchors + def anchor_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + # Draw crosshairs at anchor points + for x in [100, 300, 500]: + draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255)) + draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255)) + + draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm") + draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm") + draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm") + return img, "anchor" + tests.append(anchor_text) + + # Test 6: Multiline text + def multiline_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three", + fill=(255, 255, 255, 255), font=font, spacing=10) + return img, "multiline" + tests.append(multiline_text) + + # Test 7: Semi-transparent text + def alpha_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font) + draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font) + draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font) + return img, "alpha" + tests.append(alpha_text) + + # Test 8: Colored text + def colored_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font) + draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font) + draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font) + draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font) + return img, "colored" + tests.append(colored_text) + + # Test 9: Large stroke + def large_stroke(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font, + stroke_width=5, stroke_fill=(0, 0, 0, 255)) + return img, "large_stroke" + tests.append(large_stroke) + + # Test 10: Emoji (if supported) + def emoji_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + try: + # Try to find an emoji font + emoji_font = None + emoji_paths = [ + '/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf', + '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf', + ] + for p in emoji_paths: + try: + emoji_font = ImageFont.truetype(p, 36) + break + except: + pass + if emoji_font: + draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font) + else: + draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font) + except Exception as e: + draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font) + return img, "emoji" + tests.append(emoji_text) + + # Run all tests + print("PIL Text Options Test") + print("=" * 60) + + for test_fn in tests: + img, name = test_fn() + fname = f"/tmp/pil_test_{name}.png" + img.save(fname) + print(f"Saved: {fname}") + + print("\nCheck /tmp/pil_test_*.png for results") + + # Print available parameters + print("\n" + "=" * 60) + print("PIL draw.text() parameters:") + print(" - xy: position tuple") + print(" - text: string to draw") + print(" - fill: color (R,G,B) or (R,G,B,A)") + print(" - font: ImageFont object") + print(" - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)") + print(" - spacing: line spacing for multiline") + print(" - align: 'left', 'center', 'right' for multiline") + print(" - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)") + print(" - features: OpenType features list") + print(" - language: language code for shaping") + print(" - stroke_width: outline width in pixels") + print(" - stroke_fill: outline color") + print(" - embedded_color: use embedded color glyphs (emoji)") + + +if __name__ == "__main__": + test_pil_options() diff --git a/l1/test_styled_text.py b/l1/test_styled_text.py new file mode 100644 index 0000000..925a7fb --- /dev/null +++ b/l1/test_styled_text.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Test styled TextStrip rendering against PIL. +""" + +import numpy as np +import jax.numpy as jnp +from PIL import Image, ImageDraw, ImageFont + +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font +) + + +def render_pil(text, x, y, font_size=36, frame_size=(400, 100), + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with PIL directly.""" + frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8) + img = Image.fromarray(frame) + draw = ImageDraw.Draw(img) + + font = _load_font(None, font_size) + + # Default stroke fill + if stroke_fill is None: + stroke_fill = (0, 0, 0) + + if multiline: + draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill, + spacing=line_spacing, align=align, anchor=anchor) + else: + draw.text((x, y), text, fill=(255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill, anchor=anchor) + + return np.array(img) + + +def render_strip(text, x, y, font_size=36, frame_size=(400, 100), + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with TextStrip.""" + frame = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8) + + strip = render_text_strip( + text, None, font_size, + stroke_width=stroke_width, stroke_fill=stroke_fill, + anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align + ) + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + return np.array(result) + + +def compare(name, text, x, y, font_size=36, frame_size=(400, 100), + tolerance=0, **kwargs): + """Compare PIL and TextStrip rendering. + + tolerance=0: exact pixel match required + tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences + in center-aligned multiline text where the strip is pre-rendered + at a different base position than the final placement) + """ + pil = render_pil(text, x, y, font_size, frame_size, **kwargs) + strip = render_strip(text, x, y, font_size, frame_size, **kwargs) + + diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff == 0: + print(f"PASS: {name}") + print(f" Max diff: 0, Pixels different: 0") + return True + + if tolerance > 0: + # Check if the difference is just a sub-pixel position shift: + # for each shifted version, compute the minimum diff + best_diff = diff.copy() + for dy in range(-tolerance, tolerance + 1): + for dx in range(-tolerance, tolerance + 1): + if dy == 0 and dx == 0: + continue + shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1) + sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16)) + best_diff = np.minimum(best_diff, sdiff) + max_shift_diff = best_diff.max() + pixels_shift_diff = (best_diff > 0).any(axis=2).sum() + if max_shift_diff == 0: + print(f"PASS: {name} (within {tolerance}px position tolerance)") + print(f" Raw diff: {max_diff}, After shift tolerance: 0") + return True + + status = "FAIL" + print(f"{status}: {name}") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + + # Save debug images + Image.fromarray(pil).save(f"/tmp/pil_{name}.png") + Image.fromarray(strip).save(f"/tmp/strip_{name}.png") + diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png") + print(f" Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png") + + return False + + +def main(): + print("=" * 60) + print("Styled TextStrip vs PIL Comparison") + print("=" * 60) + + results = [] + + # Basic text + results.append(compare("basic", "Hello World", 20, 50)) + + # Stroke/outline + results.append(compare("stroke_2", "Outlined", 20, 50, + stroke_width=2, stroke_fill=(255, 0, 0))) + + results.append(compare("stroke_5", "Big Outline", 30, 60, font_size=48, + frame_size=(500, 120), + stroke_width=5, stroke_fill=(0, 0, 0))) + + # Anchors - center + results.append(compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100), + anchor="mm")) + + # Anchors - right + results.append(compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100), + anchor="rm")) + + # Multiline + results.append(compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20, + frame_size=(400, 150), + multiline=True, line_spacing=8)) + + # Multiline centered (1px tolerance: sub-pixel rendering differs because + # the strip is pre-rendered at an integer position while PIL's center + # alignment uses fractional getlength values for the 'm' anchor shift) + results.append(compare("multiline_center", "Short\nMedium Length\nX", 200, 20, + frame_size=(400, 150), + multiline=True, anchor="ma", align="center", + tolerance=1)) + + # Stroke + multiline + results.append(compare("stroke_multiline", "Line A\nLine B", 20, 20, + frame_size=(400, 120), + stroke_width=2, stroke_fill=(0, 0, 255), + multiline=True)) + + print("=" * 60) + passed = sum(results) + total = len(results) + print(f"Results: {passed}/{total} passed") + + if passed == total: + print("ALL TESTS PASSED!") + else: + print(f"FAILED: {total - passed} tests") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/l1/test_typography_fx.py b/l1/test_typography_fx.py new file mode 100644 index 0000000..e57186e --- /dev/null +++ b/l1/test_typography_fx.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 +""" +Tests for typography FX: gradients, rotation, shadow, and combined effects. +""" + +import numpy as np +import jax +import jax.numpy as jnp +from PIL import Image + +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font, + make_linear_gradient, make_radial_gradient, make_multi_stop_gradient, + place_text_strip_gradient_jax, rotate_strip_jax, + place_text_strip_shadow_jax, place_text_strip_fx_jax, + bind_typography_primitives, +) + + +def make_frame(w=400, h=200): + """Create a dark gray test frame.""" + return jnp.full((h, w, 3), 40, dtype=jnp.uint8) + + +def get_strip(text="Hello", font_size=48): + """Get a pre-rendered text strip.""" + return render_text_strip(text, None, font_size) + + +def has_visible_pixels(frame, threshold=50): + """Check if frame has pixels above threshold.""" + return int(frame.max()) > threshold + + +def save_debug(name, frame): + """Save frame for visual inspection.""" + arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame + Image.fromarray(arr).save(f"/tmp/fx_{name}.png") + + +# ============================================================================= +# Gradient Tests +# ============================================================================= + +def test_linear_gradient_shape(): + grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255)) + assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}" + assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}" + # Left edge should be red-ish, right edge blue-ish + assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}" + assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}" + print("PASS: test_linear_gradient_shape") + return True + + +def test_linear_gradient_angle(): + # 90 degrees: top-to-bottom + grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0) + # Top row should be red, bottom row should be blue + assert grad[0, 50, 0] > 0.8, "Top should be red" + assert grad[-1, 50, 2] > 0.8, "Bottom should be blue" + print("PASS: test_linear_gradient_angle") + return True + + +def test_radial_gradient_shape(): + grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128)) + assert grad.shape == (100, 100, 3) + # Center should be yellow (color1) + assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)" + assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)" + # Corner should be closer to dark blue (color2) + assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue" + print("PASS: test_radial_gradient_shape") + return True + + +def test_multi_stop_gradient(): + stops = [ + (0.0, (255, 0, 0)), + (0.5, (0, 255, 0)), + (1.0, (0, 0, 255)), + ] + grad = make_multi_stop_gradient(100, 10, stops) + assert grad.shape == (10, 100, 3) + # Left: red, Middle: green, Right: blue + assert grad[5, 0, 0] > 0.8, "Left should be red" + assert grad[5, 50, 1] > 0.8, "Middle should be green" + assert grad[5, -1, 2] > 0.8, "Right should be blue" + print("PASS: test_multi_stop_gradient") + return True + + +def test_place_gradient(): + """Test gradient text rendering produces visible output.""" + frame = make_frame() + strip = get_strip() + grad = make_linear_gradient(strip.width, strip.height, + (255, 0, 0), (0, 0, 255)) + grad_jax = jnp.asarray(grad) + strip_img = jnp.asarray(strip.image) + + result = place_text_strip_gradient_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + grad_jax, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + assert result.shape == frame.shape + # Should have visible colored pixels + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Gradient text should be visible" + save_debug("gradient", result) + print("PASS: test_place_gradient") + return True + + +# ============================================================================= +# Rotation Tests +# ============================================================================= + +def test_rotate_strip_identity(): + """Rotation by 0 degrees should preserve content.""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + + rotated = rotate_strip_jax(strip_img, 0.0) + # Output is larger (diagonal size) + assert rotated.shape[2] == 4, "Should be RGBA" + assert rotated.shape[0] >= strip.height + assert rotated.shape[1] >= strip.width + + # Alpha should have non-zero pixels (text was preserved) + assert rotated[:, :, 3].max() > 200, "Should have visible alpha" + print("PASS: test_rotate_strip_identity") + return True + + +def test_rotate_strip_90(): + """Rotation by 90 degrees.""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + + rotated = rotate_strip_jax(strip_img, 90.0) + assert rotated.shape[2] == 4 + # Should still have visible content + assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha" + save_debug("rotated_90", np.array(rotated)) + print("PASS: test_rotate_strip_90") + return True + + +def test_rotate_360_exact(): + """360-degree rotation must be pixel-exact (regression test for trig snapping).""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + sh, sw = strip.height, strip.width + + rotated = rotate_strip_jax(strip_img, 360.0) + rh, rw = rotated.shape[:2] + off_y = (rh - sh) // 2 + off_x = (rw - sw) // 2 + + crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw]) + orig = np.array(strip_img) + d = np.abs(crop.astype(np.int16) - orig.astype(np.int16)) + max_diff = int(d.max()) + assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}" + print("PASS: test_rotate_360_exact") + return True + + +def test_place_rotated(): + """Test rotated text placement produces visible output.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 0], dtype=jnp.float32) + + result = place_text_strip_fx_jax( + frame, strip_img, 200.0, 100.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + angle=30.0, + ) + + assert result.shape == frame.shape + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Rotated text should be visible" + save_debug("rotated_30", result) + print("PASS: test_place_rotated") + return True + + +# ============================================================================= +# Shadow Tests +# ============================================================================= + +def test_shadow_basic(): + """Test shadow produces visible offset copy.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_shadow_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=5.0, shadow_offset_y=5.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.8, + ) + + assert result.shape == frame.shape + # Should have both bright (text) and dark (shadow) pixels + assert result.max() > 200, "Should have bright text" + save_debug("shadow_basic", result) + print("PASS: test_shadow_basic") + return True + + +def test_shadow_blur(): + """Test blurred shadow.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_shadow_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=4.0, shadow_offset_y=4.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.7, + shadow_blur_radius=3, + ) + + assert result.shape == frame.shape + save_debug("shadow_blur", result) + print("PASS: test_shadow_blur") + return True + + +# ============================================================================= +# Combined FX Tests +# ============================================================================= + +def test_fx_combined(): + """Test combined gradient + shadow + rotation.""" + frame = make_frame(500, 300) + strip = get_strip("FX Test", 64) + strip_img = jnp.asarray(strip.image) + + grad = make_linear_gradient(strip.width, strip.height, + (255, 100, 0), (0, 100, 255)) + grad_jax = jnp.asarray(grad) + + result = place_text_strip_fx_jax( + frame, strip_img, 250.0, 150.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + gradient_map=grad_jax, + angle=15.0, + shadow_offset_x=4.0, shadow_offset_y=4.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.6, + shadow_blur_radius=2, + ) + + assert result.shape == frame.shape + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Combined FX should produce visible output" + save_debug("fx_combined", result) + print("PASS: test_fx_combined") + return True + + +def test_fx_no_effects(): + """FX function with no effects should match basic place_text_strip_jax.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + # Using FX function with defaults + result_fx = place_text_strip_fx_jax( + frame, strip_img, 50.0, 100.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + # Using original function + result_orig = place_text_strip_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16)) + max_diff = int(diff.max()) + assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}" + print("PASS: test_fx_no_effects") + return True + + +# ============================================================================= +# S-Expression Binding Tests +# ============================================================================= + +def test_sexp_bindings(): + """Test that all new primitives are registered.""" + env = {} + bind_typography_primitives(env) + + expected = [ + 'linear-gradient', 'radial-gradient', 'multi-stop-gradient', + 'place-text-strip-gradient', 'place-text-strip-rotated', + 'place-text-strip-shadow', 'place-text-strip-fx', + ] + for name in expected: + assert name in env, f"Missing binding: {name}" + + print("PASS: test_sexp_bindings") + return True + + +def test_sexp_gradient_primitive(): + """Test gradient primitive via binding.""" + env = {} + bind_typography_primitives(env) + + strip = env['render-text-strip']("Test", 36) + grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255)) + + assert grad.shape == (strip.height, strip.width, 3) + print("PASS: test_sexp_gradient_primitive") + return True + + +def test_sexp_fx_primitive(): + """Test combined FX primitive via binding.""" + env = {} + bind_typography_primitives(env) + + strip = env['render-text-strip']("FX", 36) + frame = make_frame() + + result = env['place-text-strip-fx']( + frame, strip, 100.0, 80.0, + color=(255, 200, 0), opacity=0.9, + shadow_offset_x=3, shadow_offset_y=3, + shadow_opacity=0.5, + ) + assert result.shape == frame.shape + print("PASS: test_sexp_fx_primitive") + return True + + +# ============================================================================= +# JIT Compilation Test +# ============================================================================= + +def test_jit_fx(): + """Test that place_text_strip_fx_jax can be JIT compiled.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + # JIT compile with static args for angle and blur radius + @jax.jit + def render(frame, x, y, opacity): + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=3.0, shadow_offset_y=3.0, + shadow_color=shadow_color, + shadow_opacity=0.5, + shadow_blur_radius=2, + ) + + # First call traces, second uses cache + result1 = render(frame, 50.0, 100.0, 1.0) + result2 = render(frame, 60.0, 90.0, 0.8) + + assert result1.shape == frame.shape + assert result2.shape == frame.shape + print("PASS: test_jit_fx") + return True + + +def test_jit_gradient(): + """Test that gradient placement can be JIT compiled.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + grad = jnp.asarray(make_linear_gradient(strip.width, strip.height, + (255, 0, 0), (0, 0, 255))) + + @jax.jit + def render(frame, x, y): + return place_text_strip_gradient_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + grad, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + result = render(frame, 50.0, 100.0) + assert result.shape == frame.shape + print("PASS: test_jit_gradient") + return True + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 60) + print("Typography FX Tests") + print("=" * 60) + + tests = [ + # Gradients + test_linear_gradient_shape, + test_linear_gradient_angle, + test_radial_gradient_shape, + test_multi_stop_gradient, + test_place_gradient, + # Rotation + test_rotate_strip_identity, + test_rotate_strip_90, + test_rotate_360_exact, + test_place_rotated, + # Shadow + test_shadow_basic, + test_shadow_blur, + # Combined FX + test_fx_combined, + test_fx_no_effects, + # S-expression bindings + test_sexp_bindings, + test_sexp_gradient_primitive, + test_sexp_fx_primitive, + # JIT compilation + test_jit_fx, + test_jit_gradient, + ] + + results = [] + for test in tests: + try: + results.append(test()) + except Exception as e: + print(f"FAIL: {test.__name__}: {e}") + import traceback + traceback.print_exc() + results.append(False) + + print("=" * 60) + passed = sum(r for r in results if r) + total = len(results) + print(f"Results: {passed}/{total} passed") + if passed == total: + print("ALL TESTS PASSED!") + else: + print(f"FAILED: {total - passed} tests") + print("=" * 60) + return passed == total + + +if __name__ == "__main__": + import sys + sys.exit(0 if main() else 1) diff --git a/l1/tests/__init__.py b/l1/tests/__init__.py new file mode 100644 index 0000000..9849184 --- /dev/null +++ b/l1/tests/__init__.py @@ -0,0 +1 @@ +# Tests for art-celery diff --git a/l1/tests/conftest.py b/l1/tests/conftest.py new file mode 100644 index 0000000..a99ae02 --- /dev/null +++ b/l1/tests/conftest.py @@ -0,0 +1,93 @@ +""" +Pytest fixtures for art-celery tests. +""" + +import pytest +from typing import Any, Dict, List + + +@pytest.fixture +def sample_compiled_nodes() -> List[Dict[str, Any]]: + """Sample nodes as produced by the S-expression compiler.""" + return [ + { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "cat"}, + "inputs": [], + "name": None, + }, + { + "id": "source_2", + "type": "SOURCE", + "config": { + "input": True, + "name": "Second Video", + "description": "A user-provided video", + }, + "inputs": [], + "name": "second-video", + }, + { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "identity"}, + "inputs": ["source_1"], + "name": None, + }, + { + "id": "effect_2", + "type": "EFFECT", + "config": {"effect": "invert", "intensity": 1.0}, + "inputs": ["source_2"], + "name": None, + }, + { + "id": "sequence_1", + "type": "SEQUENCE", + "config": {}, + "inputs": ["effect_1", "effect_2"], + "name": None, + }, + ] + + +@pytest.fixture +def sample_registry() -> Dict[str, Dict[str, Any]]: + """Sample registry with assets and effects.""" + return { + "assets": { + "cat": { + "cid": "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ", + "url": "https://example.com/cat.jpg", + }, + }, + "effects": { + "identity": { + "cid": "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko", + }, + "invert": { + "cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + }, + }, + } + + +@pytest.fixture +def sample_recipe( + sample_compiled_nodes: List[Dict[str, Any]], + sample_registry: Dict[str, Dict[str, Any]], +) -> Dict[str, Any]: + """Sample compiled recipe.""" + return { + "name": "test-recipe", + "version": "1.0", + "description": "A test recipe", + "owner": "@test@example.com", + "registry": sample_registry, + "dag": { + "nodes": sample_compiled_nodes, + "output": "sequence_1", + }, + "recipe_id": "Qmtest123", + } diff --git a/l1/tests/test_auth.py b/l1/tests/test_auth.py new file mode 100644 index 0000000..a4eb163 --- /dev/null +++ b/l1/tests/test_auth.py @@ -0,0 +1,42 @@ +""" +Tests for authentication service. +""" + +import pytest + + +class TestUserContext: + """Tests for UserContext dataclass.""" + + def test_user_context_accepts_l2_server(self) -> None: + """ + Regression test: UserContext must accept l2_server field. + + Bug found 2026-01-12: auth_service.py passes l2_server to UserContext + but the art-common library was pinned to old version without this field. + """ + from artdag_common.middleware.auth import UserContext + + # This should not raise TypeError + ctx = UserContext( + username="testuser", + actor_id="@testuser@example.com", + token="test-token", + l2_server="https://l2.example.com", + ) + + assert ctx.username == "testuser" + assert ctx.actor_id == "@testuser@example.com" + assert ctx.token == "test-token" + assert ctx.l2_server == "https://l2.example.com" + + def test_user_context_l2_server_optional(self) -> None: + """l2_server should be optional (default None).""" + from artdag_common.middleware.auth import UserContext + + ctx = UserContext( + username="testuser", + actor_id="@testuser@example.com", + ) + + assert ctx.l2_server is None diff --git a/l1/tests/test_cache_manager.py b/l1/tests/test_cache_manager.py new file mode 100644 index 0000000..da2b5ab --- /dev/null +++ b/l1/tests/test_cache_manager.py @@ -0,0 +1,397 @@ +# tests/test_cache_manager.py +"""Tests for the L1 cache manager.""" + +import tempfile +import time +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from cache_manager import ( + L1CacheManager, + L2SharedChecker, + CachedFile, + file_hash, +) + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def mock_l2(): + """Mock L2 server responses.""" + with patch("cache_manager.requests") as mock_requests: + mock_requests.get.return_value = Mock(status_code=404) + yield mock_requests + + +@pytest.fixture +def manager(temp_dir, mock_l2): + """Create a cache manager instance.""" + return L1CacheManager( + cache_dir=temp_dir / "cache", + l2_server="http://mock-l2:8200", + ) + + +def create_test_file(path: Path, content: str = "test content") -> Path: + """Create a test file with content.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + +class TestFileHash: + """Tests for file_hash function.""" + + def test_consistent_hash(self, temp_dir): + """Same content produces same hash.""" + file1 = create_test_file(temp_dir / "f1.txt", "hello") + file2 = create_test_file(temp_dir / "f2.txt", "hello") + + assert file_hash(file1) == file_hash(file2) + + def test_different_content_different_hash(self, temp_dir): + """Different content produces different hash.""" + file1 = create_test_file(temp_dir / "f1.txt", "hello") + file2 = create_test_file(temp_dir / "f2.txt", "world") + + assert file_hash(file1) != file_hash(file2) + + def test_sha3_256_length(self, temp_dir): + """Hash is SHA3-256 (64 hex chars).""" + f = create_test_file(temp_dir / "f.txt", "test") + assert len(file_hash(f)) == 64 + + +class TestL2SharedChecker: + """Tests for L2 shared status checking.""" + + def test_not_shared_returns_false(self, mock_l2): + """Non-existent content returns False.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=404) + + assert checker.is_shared("abc123") is False + + def test_shared_returns_true(self, mock_l2): + """Published content returns True.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=200) + + assert checker.is_shared("abc123") is True + + def test_caches_result(self, mock_l2): + """Results are cached to avoid repeated API calls.""" + checker = L2SharedChecker("http://mock:8200", cache_ttl=60) + mock_l2.get.return_value = Mock(status_code=200) + + checker.is_shared("abc123") + checker.is_shared("abc123") + + # Should only call API once + assert mock_l2.get.call_count == 1 + + def test_mark_shared(self, mock_l2): + """mark_shared updates cache without API call.""" + checker = L2SharedChecker("http://mock:8200") + + checker.mark_shared("abc123") + + assert checker.is_shared("abc123") is True + assert mock_l2.get.call_count == 0 + + def test_invalidate(self, mock_l2): + """invalidate clears cache for a hash.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=200) + + checker.is_shared("abc123") + checker.invalidate("abc123") + + mock_l2.get.return_value = Mock(status_code=404) + assert checker.is_shared("abc123") is False + + def test_error_returns_true(self, mock_l2): + """API errors return True (safe - prevents accidental deletion).""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.side_effect = Exception("Network error") + + # On error, assume IS shared to prevent accidental deletion + assert checker.is_shared("abc123") is True + + +class TestL1CacheManagerStorage: + """Tests for cache storage operations.""" + + def test_put_and_get_by_cid(self, manager, temp_dir): + """Can store and retrieve by content hash.""" + test_file = create_test_file(temp_dir / "input.txt", "hello world") + + cached, cid = manager.put(test_file, node_type="test") + + retrieved_path = manager.get_by_cid(cached.cid) + assert retrieved_path is not None + assert retrieved_path.read_text() == "hello world" + + def test_put_with_custom_node_id(self, manager, temp_dir): + """Can store with custom node_id.""" + test_file = create_test_file(temp_dir / "input.txt", "content") + + cached, cid = manager.put(test_file, node_id="custom-node-123", node_type="test") + + assert cached.node_id == "custom-node-123" + assert manager.get_by_node_id("custom-node-123") is not None + + def test_has_content(self, manager, temp_dir): + """has_content checks existence.""" + test_file = create_test_file(temp_dir / "input.txt", "data") + + cached, cid = manager.put(test_file, node_type="test") + + assert manager.has_content(cached.cid) is True + assert manager.has_content("nonexistent") is False + + def test_list_all(self, manager, temp_dir): + """list_all returns all cached files.""" + f1 = create_test_file(temp_dir / "f1.txt", "one") + f2 = create_test_file(temp_dir / "f2.txt", "two") + + manager.put(f1, node_type="test") + manager.put(f2, node_type="test") + + all_files = manager.list_all() + assert len(all_files) == 2 + + def test_deduplication(self, manager, temp_dir): + """Same content is not stored twice.""" + f1 = create_test_file(temp_dir / "f1.txt", "identical") + f2 = create_test_file(temp_dir / "f2.txt", "identical") + + cached1, cid1 = manager.put(f1, node_type="test") + cached2, cid2 = manager.put(f2, node_type="test") + + assert cached1.cid == cached2.cid + assert len(manager.list_all()) == 1 + + +class TestL1CacheManagerActivities: + """Tests for activity tracking.""" + + def test_record_simple_activity(self, manager, temp_dir): + """Can record a simple activity.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + activity = manager.record_simple_activity( + input_hashes=[input_cached.cid], + output_cid=output_cached.cid, + run_id="run-001", + ) + + assert activity.activity_id == "run-001" + assert input_cached.cid in activity.input_ids + assert activity.output_id == output_cached.cid + + def test_list_activities(self, manager, temp_dir): + """Can list all activities.""" + for i in range(3): + inp = create_test_file(temp_dir / f"in{i}.txt", f"input{i}") + out = create_test_file(temp_dir / f"out{i}.txt", f"output{i}") + inp_c, _ = manager.put(inp, node_type="source") + out_c, _ = manager.put(out, node_type="effect") + manager.record_simple_activity([inp_c.cid], out_c.cid) + + activities = manager.list_activities() + assert len(activities) == 3 + + def test_find_activities_by_inputs(self, manager, temp_dir): + """Can find activities with same inputs.""" + input_file = create_test_file(temp_dir / "shared_input.txt", "shared") + input_cached, _ = manager.put(input_file, node_type="source") + + # Two activities with same input + out1 = create_test_file(temp_dir / "out1.txt", "output1") + out2 = create_test_file(temp_dir / "out2.txt", "output2") + out1_c, _ = manager.put(out1, node_type="effect") + out2_c, _ = manager.put(out2, node_type="effect") + + manager.record_simple_activity([input_cached.cid], out1_c.cid, "run1") + manager.record_simple_activity([input_cached.cid], out2_c.cid, "run2") + + found = manager.find_activities_by_inputs([input_cached.cid]) + assert len(found) == 2 + + +class TestL1CacheManagerDeletionRules: + """Tests for deletion rules enforcement.""" + + def test_can_delete_orphaned_item(self, manager, temp_dir): + """Orphaned items can be deleted.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cached, _ = manager.put(test_file, node_type="test") + + can_delete, reason = manager.can_delete(cached.cid) + assert can_delete is True + + def test_cannot_delete_activity_input(self, manager, temp_dir): + """Activity inputs cannot be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + can_delete, reason = manager.can_delete(input_cached.cid) + assert can_delete is False + assert "input" in reason.lower() + + def test_cannot_delete_activity_output(self, manager, temp_dir): + """Activity outputs cannot be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + can_delete, reason = manager.can_delete(output_cached.cid) + assert can_delete is False + assert "output" in reason.lower() + + def test_cannot_delete_pinned_item(self, manager, temp_dir): + """Pinned items cannot be deleted.""" + test_file = create_test_file(temp_dir / "shared.txt", "shared") + cached, _ = manager.put(test_file, node_type="test") + + # Mark as pinned (published) + manager.pin(cached.cid, reason="published") + + can_delete, reason = manager.can_delete(cached.cid) + assert can_delete is False + assert "pinned" in reason + + def test_delete_orphaned_item(self, manager, temp_dir): + """Can delete orphaned items.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cached, _ = manager.put(test_file, node_type="test") + + success, msg = manager.delete_by_cid(cached.cid) + + assert success is True + assert manager.has_content(cached.cid) is False + + def test_delete_protected_item_fails(self, manager, temp_dir): + """Cannot delete protected items.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + success, msg = manager.delete_by_cid(input_cached.cid) + + assert success is False + assert manager.has_content(input_cached.cid) is True + + +class TestL1CacheManagerActivityDiscard: + """Tests for activity discard functionality.""" + + def test_can_discard_unshared_activity(self, manager, temp_dir): + """Activities with no shared items can be discarded.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + activity = manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + can_discard, reason = manager.can_discard_activity("run-001") + assert can_discard is True + + def test_cannot_discard_activity_with_pinned_output(self, manager, temp_dir): + """Activities with pinned outputs cannot be discarded.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + # Mark output as pinned (published) + manager.pin(output_cached.cid, reason="published") + + can_discard, reason = manager.can_discard_activity("run-001") + assert can_discard is False + assert "pinned" in reason + + def test_discard_activity_cleans_up(self, manager, temp_dir): + """Discarding activity cleans up orphaned items.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + success, msg = manager.discard_activity("run-001") + + assert success is True + assert manager.get_activity("run-001") is None + + +class TestL1CacheManagerStats: + """Tests for cache statistics.""" + + def test_get_stats(self, manager, temp_dir): + """get_stats returns cache statistics.""" + f1 = create_test_file(temp_dir / "f1.txt", "content1") + f2 = create_test_file(temp_dir / "f2.txt", "content2") + + manager.put(f1, node_type="test") + manager.put(f2, node_type="test") + + stats = manager.get_stats() + + assert stats["total_entries"] == 2 + assert stats["total_size_bytes"] > 0 + assert "activities" in stats diff --git a/l1/tests/test_dag_transform.py b/l1/tests/test_dag_transform.py new file mode 100644 index 0000000..32d45f8 --- /dev/null +++ b/l1/tests/test_dag_transform.py @@ -0,0 +1,492 @@ +""" +Tests for DAG transformation and input binding. + +These tests verify the critical path that was causing bugs: +- Node transformation from compiled format to artdag format +- Asset/effect CID resolution from registry +- Variable input name mapping +- Input binding +""" + +import json +import logging +import pytest +from typing import Any, Dict, List, Optional, Tuple + +# Standalone implementations of the functions for testing +# This avoids importing the full app which requires external dependencies + +logger = logging.getLogger(__name__) + + +def is_variable_input(config: Dict[str, Any]) -> bool: + """Check if a SOURCE node config represents a variable input.""" + return bool(config.get("input")) + + +def transform_node( + node: Dict[str, Any], + assets: Dict[str, Dict[str, Any]], + effects: Dict[str, Dict[str, Any]], +) -> Dict[str, Any]: + """Transform a compiled node to artdag execution format.""" + node_id = node.get("id", "") + config = dict(node.get("config", {})) + + if node.get("type") == "SOURCE" and "asset" in config: + asset_name = config["asset"] + if asset_name in assets: + config["cid"] = assets[asset_name].get("cid") + + if node.get("type") == "EFFECT" and "effect" in config: + effect_name = config["effect"] + if effect_name in effects: + config["cid"] = effects[effect_name].get("cid") + + return { + "node_id": node_id, + "node_type": node.get("type", "EFFECT"), + "config": config, + "inputs": node.get("inputs", []), + "name": node.get("name"), + } + + +def build_input_name_mapping(nodes: Dict[str, Dict[str, Any]]) -> Dict[str, str]: + """Build a mapping from input names to node IDs for variable inputs.""" + input_name_to_node: Dict[str, str] = {} + + for node_id, node in nodes.items(): + if node.get("node_type") != "SOURCE": + continue + + config = node.get("config", {}) + if not is_variable_input(config): + continue + + input_name_to_node[node_id] = node_id + + name = config.get("name") + if name: + input_name_to_node[name] = node_id + input_name_to_node[name.lower().replace(" ", "_")] = node_id + input_name_to_node[name.lower().replace(" ", "-")] = node_id + + node_name = node.get("name") + if node_name: + input_name_to_node[node_name] = node_id + input_name_to_node[node_name.replace("-", "_")] = node_id + + return input_name_to_node + + +def bind_inputs( + nodes: Dict[str, Dict[str, Any]], + input_name_to_node: Dict[str, str], + user_inputs: Dict[str, str], +) -> List[str]: + """Bind user-provided input CIDs to source nodes.""" + warnings: List[str] = [] + + for input_name, cid in user_inputs.items(): + if input_name in nodes: + node = nodes[input_name] + if node.get("node_type") == "SOURCE": + node["config"]["cid"] = cid + continue + + if input_name in input_name_to_node: + node_id = input_name_to_node[input_name] + node = nodes[node_id] + node["config"]["cid"] = cid + continue + + warnings.append(f"Input '{input_name}' not found in recipe") + + return warnings + + +def prepare_dag_for_execution( + recipe: Dict[str, Any], + user_inputs: Dict[str, str], +) -> Tuple[str, List[str]]: + """Prepare a recipe DAG for execution.""" + recipe_dag = recipe.get("dag") + if not recipe_dag or not isinstance(recipe_dag, dict): + raise ValueError("Recipe has no DAG definition") + + dag_copy = json.loads(json.dumps(recipe_dag)) + nodes = dag_copy.get("nodes", {}) + + registry = recipe.get("registry", {}) + assets = registry.get("assets", {}) if registry else {} + effects = registry.get("effects", {}) if registry else {} + + if isinstance(nodes, list): + nodes_dict: Dict[str, Dict[str, Any]] = {} + for node in nodes: + node_id = node.get("id") + if node_id: + nodes_dict[node_id] = transform_node(node, assets, effects) + nodes = nodes_dict + dag_copy["nodes"] = nodes + + input_name_to_node = build_input_name_mapping(nodes) + warnings = bind_inputs(nodes, input_name_to_node, user_inputs) + + if "output" in dag_copy: + dag_copy["output_id"] = dag_copy.pop("output") + + if "metadata" not in dag_copy: + dag_copy["metadata"] = {} + + return json.dumps(dag_copy), warnings + + +class TestTransformNode: + """Tests for transform_node function.""" + + def test_source_node_with_asset_resolves_cid( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """SOURCE nodes with asset reference should get CID from registry.""" + node = { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "cat"}, + "inputs": [], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["node_id"] == "source_1" + assert result["node_type"] == "SOURCE" + assert result["config"]["cid"] == "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + def test_effect_node_resolves_cid( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """EFFECT nodes should get CID from registry.""" + node = { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "invert", "intensity": 1.0}, + "inputs": ["source_1"], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["node_id"] == "effect_1" + assert result["node_type"] == "EFFECT" + assert result["config"]["effect"] == "invert" + assert result["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + assert result["config"]["intensity"] == 1.0 + + def test_variable_input_node_preserves_config( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """Variable input SOURCE nodes should preserve their config.""" + node = { + "id": "source_2", + "type": "SOURCE", + "config": { + "input": True, + "name": "Second Video", + "description": "User-provided video", + }, + "inputs": [], + "name": "second-video", + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["config"]["input"] is True + assert result["config"]["name"] == "Second Video" + assert result["name"] == "second-video" + # No CID yet - will be bound at runtime + assert "cid" not in result["config"] + + def test_unknown_asset_not_resolved( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """Unknown assets should not crash, just not get CID.""" + node = { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "unknown_asset"}, + "inputs": [], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert "cid" not in result["config"] + + +class TestBuildInputNameMapping: + """Tests for build_input_name_mapping function.""" + + def test_maps_by_node_id(self) -> None: + """Should map by node_id.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["source_2"] == "source_2" + + def test_maps_by_config_name(self) -> None: + """Should map by config.name with various formats.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["Second Video"] == "source_2" + assert mapping["second_video"] == "source_2" + assert mapping["second-video"] == "source_2" + + def test_maps_by_def_name(self) -> None: + """Should map by node.name (def binding name).""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + "name": "inverted-video", + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["inverted-video"] == "source_2" + assert mapping["inverted_video"] == "source_2" + + def test_ignores_non_source_nodes(self) -> None: + """Should not include non-SOURCE nodes.""" + nodes = { + "effect_1": { + "node_id": "effect_1", + "node_type": "EFFECT", + "config": {"effect": "invert"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert "effect_1" not in mapping + + def test_ignores_fixed_sources(self) -> None: + """Should not include SOURCE nodes without 'input' flag.""" + nodes = { + "source_1": { + "node_id": "source_1", + "node_type": "SOURCE", + "config": {"asset": "cat", "cid": "Qm123"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert "source_1" not in mapping + + +class TestBindInputs: + """Tests for bind_inputs function.""" + + def test_binds_by_direct_node_id(self) -> None: + """Should bind when using node_id directly.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + mapping = {"source_2": "source_2"} + user_inputs = {"source_2": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert nodes["source_2"]["config"]["cid"] == "QmUserInput123" + assert len(warnings) == 0 + + def test_binds_by_name_lookup(self) -> None: + """Should bind when using input name.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + } + } + mapping = { + "source_2": "source_2", + "Second Video": "source_2", + "second-video": "source_2", + } + user_inputs = {"second-video": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert nodes["source_2"]["config"]["cid"] == "QmUserInput123" + assert len(warnings) == 0 + + def test_warns_on_unknown_input(self) -> None: + """Should warn when input name not found.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + mapping = {"source_2": "source_2", "Test": "source_2"} + user_inputs = {"unknown-input": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert len(warnings) == 1 + assert "unknown-input" in warnings[0] + + +class TestRegressions: + """Tests for specific bugs that were found in production.""" + + def test_effect_cid_key_not_effect_hash( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """ + Regression test: Effect CID must use 'cid' key, not 'effect_hash'. + + Bug found 2026-01-12: The transform_node function was setting + config["effect_hash"] but the executor looks for config["cid"]. + This caused "Unknown effect: invert" errors. + """ + node = { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "invert"}, + "inputs": ["source_1"], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + # MUST use 'cid' key - executor checks config.get("cid") + assert "cid" in result["config"], "Effect CID must be stored as 'cid'" + assert "effect_hash" not in result["config"], "Must not use 'effect_hash' key" + assert result["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + def test_source_cid_binding_persists( + self, + sample_recipe: Dict[str, Any], + ) -> None: + """ + Regression test: Bound CIDs must appear in final DAG JSON. + + Bug found 2026-01-12: Variable input CIDs were being bound but + not appearing in the serialized DAG sent to the executor. + """ + user_inputs = {"second-video": "QmTestUserInput123"} + + dag_json, _ = prepare_dag_for_execution(sample_recipe, user_inputs) + dag = json.loads(dag_json) + + # The bound CID must be in the final JSON + source_2 = dag["nodes"]["source_2"] + assert source_2["config"]["cid"] == "QmTestUserInput123" + + +class TestPrepareDagForExecution: + """Integration tests for prepare_dag_for_execution.""" + + def test_full_pipeline(self, sample_recipe: Dict[str, Any]) -> None: + """Test the full DAG preparation pipeline.""" + user_inputs = { + "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + } + + dag_json, warnings = prepare_dag_for_execution(sample_recipe, user_inputs) + + # Parse result + dag = json.loads(dag_json) + + # Check structure + assert "nodes" in dag + assert "output_id" in dag + assert dag["output_id"] == "sequence_1" + + # Check fixed source has CID + source_1 = dag["nodes"]["source_1"] + assert source_1["config"]["cid"] == "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + # Check variable input was bound + source_2 = dag["nodes"]["source_2"] + assert source_2["config"]["cid"] == "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + + # Check effects have CIDs + effect_1 = dag["nodes"]["effect_1"] + assert effect_1["config"]["cid"] == "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko" + + effect_2 = dag["nodes"]["effect_2"] + assert effect_2["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + assert effect_2["config"]["intensity"] == 1.0 + + # No warnings + assert len(warnings) == 0 + + def test_missing_input_produces_warning( + self, + sample_recipe: Dict[str, Any], + ) -> None: + """Missing inputs should produce warnings but not fail.""" + user_inputs = {} # No inputs provided + + dag_json, warnings = prepare_dag_for_execution(sample_recipe, user_inputs) + + # Should still produce valid JSON + dag = json.loads(dag_json) + assert "nodes" in dag + + # Variable input should not have CID + source_2 = dag["nodes"]["source_2"] + assert "cid" not in source_2["config"] + + def test_raises_on_missing_dag(self) -> None: + """Should raise ValueError if recipe has no DAG.""" + recipe = {"name": "broken", "registry": {}} + + with pytest.raises(ValueError, match="no DAG"): + prepare_dag_for_execution(recipe, {}) diff --git a/l1/tests/test_effect_loading.py b/l1/tests/test_effect_loading.py new file mode 100644 index 0000000..3e41467 --- /dev/null +++ b/l1/tests/test_effect_loading.py @@ -0,0 +1,327 @@ +""" +Tests for effect loading from cache and IPFS. + +These tests verify that: +- Effects can be loaded from the local cache directory +- IPFS gateway configuration is correct for Docker environments +- The effect executor correctly resolves CIDs from config +""" + +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch, MagicMock + +import pytest + + +# Minimal effect loading implementation for testing +# This mirrors the logic in artdag/nodes/effect.py + + +def get_effects_cache_dir_impl(env_vars: Dict[str, str]) -> Optional[Path]: + """Get the effects cache directory from environment or default.""" + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = env_vars.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + if effects_dir.exists(): + return effects_dir + + # Try default locations + for base in [Path.home() / ".artdag" / "cache", Path("/var/cache/artdag")]: + effects_dir = base / "_effects" + if effects_dir.exists(): + return effects_dir + + return None + + +def effect_path_for_cid(effects_dir: Path, effect_cid: str) -> Path: + """Get the expected path for an effect given its CID.""" + return effects_dir / effect_cid / "effect.py" + + +class TestEffectCacheDirectory: + """Tests for effect cache directory resolution.""" + + def test_cache_dir_from_env(self, tmp_path: Path) -> None: + """CACHE_DIR env var should determine effects directory.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + + env = {"CACHE_DIR": str(tmp_path)} + result = get_effects_cache_dir_impl(env) + + assert result == effects_dir + + def test_artdag_cache_dir_fallback(self, tmp_path: Path) -> None: + """ARTDAG_CACHE_DIR should work as fallback.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + + env = {"ARTDAG_CACHE_DIR": str(tmp_path)} + result = get_effects_cache_dir_impl(env) + + assert result == effects_dir + + def test_no_env_returns_none_if_no_default_exists(self) -> None: + """Should return None if no cache directory exists.""" + env = {} + result = get_effects_cache_dir_impl(env) + + # Will return None unless default dirs exist + # This is expected behavior + if result is not None: + assert result.exists() + + +class TestEffectPathResolution: + """Tests for effect path resolution.""" + + def test_effect_path_structure(self, tmp_path: Path) -> None: + """Effect should be at _effects/{cid}/effect.py.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffect123" + + path = effect_path_for_cid(effects_dir, effect_cid) + + assert path == effects_dir / effect_cid / "effect.py" + + def test_effect_file_exists_after_upload(self, tmp_path: Path) -> None: + """After upload, effect.py should exist in the right location.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Simulate effect upload (as done by app/routers/effects.py) + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + effect_source = '''""" +@effect invert +@version 1.0.0 +""" + +def process_frame(frame, params, state): + return 255 - frame, state +''' + (effect_dir / "effect.py").write_text(effect_source) + + # Verify the path structure + expected_path = effect_path_for_cid(effects_dir, effect_cid) + assert expected_path.exists() + assert "process_frame" in expected_path.read_text() + + +class TestIPFSAPIConfiguration: + """Tests for IPFS API configuration (consistent across codebase).""" + + def test_ipfs_api_multiaddr_conversion(self) -> None: + """ + IPFS_API multiaddr should convert to correct URL. + + Both ipfs_client.py and artdag/nodes/effect.py now use IPFS_API + with multiaddr format for consistency. + """ + import re + + def multiaddr_to_url(multiaddr: str) -> str: + """Convert multiaddr to URL (same logic as ipfs_client.py).""" + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + return "http://127.0.0.1:5001" + + # Docker config + docker_api = "/dns/ipfs/tcp/5001" + url = multiaddr_to_url(docker_api) + assert url == "http://ipfs:5001" + + # Local dev config + local_api = "/ip4/127.0.0.1/tcp/5001" + url = multiaddr_to_url(local_api) + assert url == "http://127.0.0.1:5001" + + def test_all_ipfs_access_uses_api_not_gateway(self) -> None: + """ + All IPFS access should use IPFS_API (port 5001), not IPFS_GATEWAY (port 8080). + + Fixed 2026-01-12: artdag/nodes/effect.py was using a separate IPFS_GATEWAY + variable. Now it uses IPFS_API like ipfs_client.py for consistency. + """ + # The API endpoint that both modules use + api_endpoint = "/api/v0/cat" + + # This is correct - using the API + assert "api/v0" in api_endpoint + + # Gateway endpoint would be /ipfs/{cid} - we don't use this anymore + gateway_pattern = "/ipfs/" + assert gateway_pattern not in api_endpoint + + +class TestEffectExecutorConfigResolution: + """Tests for how the effect executor resolves CID from config.""" + + def test_executor_should_use_cid_key(self) -> None: + """ + Effect executor must look for 'cid' key in config. + + The transform_node function sets config["cid"] for effects. + The executor must read from the same key. + """ + config = { + "effect": "invert", + "cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + "intensity": 1.0, + } + + # Simulate executor CID extraction (from artdag/nodes/effect.py:258) + effect_cid = config.get("cid") or config.get("hash") + + assert effect_cid == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + def test_executor_should_not_use_effect_hash(self) -> None: + """ + Regression test: 'effect_hash' is not a valid config key. + + Bug found 2026-01-12: transform_node was using config["effect_hash"] + but executor only checks config["cid"] or config["hash"]. + """ + config = { + "effect": "invert", + "effect_hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + } + + # This simulates the buggy behavior where effect_hash was set + # but executor doesn't look for it + effect_cid = config.get("cid") or config.get("hash") + + # The bug: effect_hash is ignored, effect_cid is None + assert effect_cid is None, "effect_hash should NOT be recognized" + + def test_hash_key_is_legacy_fallback(self) -> None: + """'hash' key should work as legacy fallback for 'cid'.""" + config = { + "effect": "invert", + "hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + } + + effect_cid = config.get("cid") or config.get("hash") + + assert effect_cid == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + +class TestEffectLoadingIntegration: + """Integration tests for complete effect loading path.""" + + def test_effect_loads_from_cache_when_present(self, tmp_path: Path) -> None: + """Effect should load from cache without hitting IPFS.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Create effect file in cache + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +def process_frame(frame, params, state): + """Invert colors.""" + return 255 - frame, state +''') + + # Verify the effect can be found + effect_path = effect_path_for_cid(effects_dir, effect_cid) + assert effect_path.exists() + + # Load and verify it has the expected function + import importlib.util + spec = importlib.util.spec_from_file_location("test_effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + assert hasattr(module, "process_frame") + + def test_effect_fetch_uses_ipfs_api(self, tmp_path: Path) -> None: + """Effect fetch should use IPFS API endpoint, not gateway.""" + import re + + def multiaddr_to_url(multiaddr: str) -> str: + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + return "http://127.0.0.1:5001" + + # In Docker, IPFS_API=/dns/ipfs/tcp/5001 + docker_multiaddr = "/dns/ipfs/tcp/5001" + base_url = multiaddr_to_url(docker_multiaddr) + effect_cid = "QmTestCid123" + + # Should use API endpoint + api_url = f"{base_url}/api/v0/cat?arg={effect_cid}" + + assert "ipfs:5001" in api_url + assert "/api/v0/cat" in api_url + assert "127.0.0.1" not in api_url + + +class TestSharedVolumeScenario: + """ + Tests simulating the Docker shared volume scenario. + + In Docker: + - l1-server uploads effect to /data/cache/_effects/{cid}/effect.py + - l1-worker should find it at the same path via shared volume + """ + + def test_effect_visible_on_shared_volume(self, tmp_path: Path) -> None: + """Effect uploaded on server should be visible to worker.""" + # Simulate shared volume mounted at /data/cache on both containers + shared_volume = tmp_path / "data" / "cache" + effects_dir = shared_volume / "_effects" + + # Server uploads effect + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + effect_upload_dir = effects_dir / effect_cid + effect_upload_dir.mkdir(parents=True) + (effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s') + (effect_upload_dir / "metadata.json").write_text('{"cid": "' + effect_cid + '"}') + + # Worker should find the effect + env_vars = {"CACHE_DIR": str(shared_volume)} + worker_effects_dir = get_effects_cache_dir_impl(env_vars) + + assert worker_effects_dir is not None + assert worker_effects_dir == effects_dir + + worker_effect_path = effect_path_for_cid(worker_effects_dir, effect_cid) + assert worker_effect_path.exists() + + def test_effect_cid_matches_registry(self, tmp_path: Path) -> None: + """CID in recipe registry must match the uploaded effect directory name.""" + shared_volume = tmp_path + effects_dir = shared_volume / "_effects" + + # The CID used in the recipe registry + registry_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Upload creates directory with CID as name + effect_upload_dir = effects_dir / registry_cid + effect_upload_dir.mkdir(parents=True) + (effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s') + + # Executor receives the same CID from DAG config + dag_config_cid = registry_cid # This comes from transform_node + + # These must match for the lookup to work + assert dag_config_cid == registry_cid + + # And the path must exist + lookup_path = effects_dir / dag_config_cid / "effect.py" + assert lookup_path.exists() diff --git a/l1/tests/test_effects_web.py b/l1/tests/test_effects_web.py new file mode 100644 index 0000000..5ad461f --- /dev/null +++ b/l1/tests/test_effects_web.py @@ -0,0 +1,367 @@ +""" +Tests for Effects web UI. + +Tests effect metadata parsing, listing, and templates. +""" + +import pytest +import re +from pathlib import Path +from unittest.mock import MagicMock + + +def parse_effect_metadata_standalone(source: str) -> dict: + """ + Standalone copy of parse_effect_metadata for testing. + + This avoids import issues with the router module. + """ + metadata = { + "name": "", + "version": "1.0.0", + "author": "", + "temporal": False, + "description": "", + "params": [], + "dependencies": [], + "requires_python": ">=3.10", + } + + # Parse PEP 723 dependencies + pep723_match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL) + if pep723_match: + block = pep723_match.group(1) + deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL) + if deps_match: + metadata["dependencies"] = re.findall(r'"([^"]+)"', deps_match.group(1)) + python_match = re.search(r'# requires-python = "([^"]+)"', block) + if python_match: + metadata["requires_python"] = python_match.group(1) + + # Parse docstring @-tags + docstring_match = re.search(r'"""(.*?)"""', source, re.DOTALL) + if not docstring_match: + docstring_match = re.search(r"'''(.*?)'''", source, re.DOTALL) + + if docstring_match: + docstring = docstring_match.group(1) + lines = docstring.split("\n") + + current_param = None + desc_lines = [] + in_description = False + + for line in lines: + stripped = line.strip() + + if stripped.startswith("@effect "): + metadata["name"] = stripped[8:].strip() + in_description = False + + elif stripped.startswith("@version "): + metadata["version"] = stripped[9:].strip() + + elif stripped.startswith("@author "): + metadata["author"] = stripped[8:].strip() + + elif stripped.startswith("@temporal "): + val = stripped[10:].strip().lower() + metadata["temporal"] = val in ("true", "yes", "1") + + elif stripped.startswith("@description"): + in_description = True + desc_lines = [] + + elif stripped.startswith("@param "): + if in_description: + metadata["description"] = " ".join(desc_lines) + in_description = False + if current_param: + metadata["params"].append(current_param) + parts = stripped[7:].split() + if len(parts) >= 2: + current_param = { + "name": parts[0], + "type": parts[1], + "description": "", + } + else: + current_param = None + + elif stripped.startswith("@range ") and current_param: + range_parts = stripped[7:].split() + if len(range_parts) >= 2: + try: + current_param["range"] = [float(range_parts[0]), float(range_parts[1])] + except ValueError: + pass + + elif stripped.startswith("@default ") and current_param: + current_param["default"] = stripped[9:].strip() + + elif stripped.startswith("@example"): + if in_description: + metadata["description"] = " ".join(desc_lines) + in_description = False + if current_param: + metadata["params"].append(current_param) + current_param = None + + elif in_description and stripped: + desc_lines.append(stripped) + + elif current_param and stripped and not stripped.startswith("@"): + current_param["description"] = stripped + + if in_description: + metadata["description"] = " ".join(desc_lines) + + if current_param: + metadata["params"].append(current_param) + + return metadata + + +class TestEffectMetadataParsing: + """Tests for parse_effect_metadata function.""" + + def test_parses_pep723_dependencies(self) -> None: + """Should extract dependencies from PEP 723 script block.""" + source = ''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["numpy", "opencv-python"] +# /// +""" +@effect test_effect +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["dependencies"] == ["numpy", "opencv-python"] + assert meta["requires_python"] == ">=3.10" + + def test_parses_effect_name(self) -> None: + """Should extract effect name from @effect tag.""" + source = ''' +""" +@effect brightness +@version 2.0.0 +@author @artist@example.com +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["name"] == "brightness" + assert meta["version"] == "2.0.0" + assert meta["author"] == "@artist@example.com" + + def test_parses_parameters(self) -> None: + """Should extract parameter definitions.""" + source = ''' +""" +@effect brightness +@param level float +@range -1.0 1.0 +@default 0.0 +Brightness adjustment level +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert len(meta["params"]) == 1 + param = meta["params"][0] + assert param["name"] == "level" + assert param["type"] == "float" + assert param["range"] == [-1.0, 1.0] + assert param["default"] == "0.0" + + def test_parses_temporal_flag(self) -> None: + """Should parse temporal flag correctly.""" + source_temporal = ''' +""" +@effect motion_blur +@temporal true +""" +def process_frame(frame, params, state): + return frame, state +''' + source_not_temporal = ''' +""" +@effect brightness +@temporal false +""" +def process_frame(frame, params, state): + return frame, state +''' + + assert parse_effect_metadata_standalone(source_temporal)["temporal"] is True + assert parse_effect_metadata_standalone(source_not_temporal)["temporal"] is False + + def test_handles_missing_metadata(self) -> None: + """Should return sensible defaults for minimal source.""" + source = ''' +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["name"] == "" + assert meta["version"] == "1.0.0" + assert meta["dependencies"] == [] + assert meta["params"] == [] + + def test_parses_description(self) -> None: + """Should extract description text.""" + source = ''' +""" +@effect test +@description +This is a multi-line +description of the effect. +@param x float +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert "multi-line" in meta["description"] + + +class TestHomePageEffectsCount: + """Test that effects count is shown on home page.""" + + def test_home_template_has_effects_card(self) -> None: + """Home page should display effects count.""" + path = Path('/home/giles/art/art-celery/app/templates/home.html') + content = path.read_text() + + assert 'stats.effects' in content, \ + "Home page should display stats.effects count" + assert 'href="/effects"' in content, \ + "Home page should link to /effects" + + def test_home_route_provides_effects_count(self) -> None: + """Home route should provide effects count in stats.""" + path = Path('/home/giles/art/art-celery/app/routers/home.py') + content = path.read_text() + + assert 'stats["effects"]' in content, \ + "Home route should populate stats['effects']" + + +class TestNavigationIncludesEffects: + """Test that Effects link is in navigation.""" + + def test_base_template_has_effects_link(self) -> None: + """Base template should have Effects navigation link.""" + base_path = Path('/home/giles/art/art-celery/app/templates/base.html') + content = base_path.read_text() + + assert 'href="/effects"' in content + assert "Effects" in content + assert "active_tab == 'effects'" in content + + def test_effects_link_between_recipes_and_media(self) -> None: + """Effects link should be positioned between Recipes and Media.""" + base_path = Path('/home/giles/art/art-celery/app/templates/base.html') + content = base_path.read_text() + + recipes_pos = content.find('href="/recipes"') + effects_pos = content.find('href="/effects"') + media_pos = content.find('href="/media"') + + assert recipes_pos < effects_pos < media_pos, \ + "Effects link should be between Recipes and Media" + + +class TestEffectsTemplatesExist: + """Tests for effects template files.""" + + def test_list_template_exists(self) -> None: + """List template should exist.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + assert path.exists(), "effects/list.html template should exist" + + def test_detail_template_exists(self) -> None: + """Detail template should exist.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + assert path.exists(), "effects/detail.html template should exist" + + def test_list_template_extends_base(self) -> None: + """List template should extend base.html.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + content = path.read_text() + assert '{% extends "base.html" %}' in content + + def test_detail_template_extends_base(self) -> None: + """Detail template should extend base.html.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert '{% extends "base.html" %}' in content + + def test_list_template_has_upload_button(self) -> None: + """List template should have upload functionality.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + content = path.read_text() + assert 'Upload Effect' in content or 'upload' in content.lower() + + def test_detail_template_shows_parameters(self) -> None: + """Detail template should display parameters.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'params' in content.lower() or 'parameter' in content.lower() + + def test_detail_template_shows_source_code(self) -> None: + """Detail template should show source code.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'source' in content.lower() + assert 'language-python' in content + + def test_detail_template_shows_dependencies(self) -> None: + """Detail template should display dependencies.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'dependencies' in content.lower() + + +class TestEffectsRouterExists: + """Tests for effects router configuration.""" + + def test_effects_router_file_exists(self) -> None: + """Effects router should exist.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + assert path.exists(), "effects.py router should exist" + + def test_effects_router_has_list_endpoint(self) -> None: + """Effects router should have list endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.get("")' in content or "@router.get('')" in content + + def test_effects_router_has_detail_endpoint(self) -> None: + """Effects router should have detail endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.get("/{cid}")' in content + + def test_effects_router_has_upload_endpoint(self) -> None: + """Effects router should have upload endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.post("/upload")' in content + + def test_effects_router_renders_templates(self) -> None: + """Effects router should render HTML templates.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert 'effects/list.html' in content + assert 'effects/detail.html' in content diff --git a/l1/tests/test_execute_recipe.py b/l1/tests/test_execute_recipe.py new file mode 100644 index 0000000..0a32326 --- /dev/null +++ b/l1/tests/test_execute_recipe.py @@ -0,0 +1,529 @@ +""" +Tests for execute_recipe SOURCE node resolution. + +These tests verify that SOURCE nodes with :input true are correctly +resolved from input_hashes at execution time. +""" + +import pytest +from unittest.mock import MagicMock, patch +from typing import Dict, Any + + +class MockStep: + """Mock ExecutionStep for testing.""" + + def __init__( + self, + step_id: str, + node_type: str, + config: Dict[str, Any], + cache_id: str, + input_steps: list = None, + level: int = 0, + ): + self.step_id = step_id + self.node_type = node_type + self.config = config + self.cache_id = cache_id + self.input_steps = input_steps or [] + self.level = level + self.name = config.get("name") + self.outputs = [] + + +class MockPlan: + """Mock ExecutionPlanSexp for testing.""" + + def __init__(self, steps: list, output_step_id: str, plan_id: str = "test-plan"): + self.steps = steps + self.output_step_id = output_step_id + self.plan_id = plan_id + + def to_string(self, pretty: bool = False) -> str: + return "(plan test)" + + +def resolve_source_cid(step_config: Dict[str, Any], input_hashes: Dict[str, str]) -> str: + """ + Resolve CID for a SOURCE node. + + This is the logic that should be in execute_recipe - extracted here for unit testing. + """ + source_cid = step_config.get("cid") + + # If source has :input true, resolve CID from input_hashes + if not source_cid and step_config.get("input"): + source_name = step_config.get("name", "") + # Try various key formats for lookup + name_variants = [ + source_name, + source_name.lower().replace(" ", "-"), + source_name.lower().replace(" ", "_"), + source_name.lower(), + ] + for variant in name_variants: + if variant in input_hashes: + source_cid = input_hashes[variant] + break + + if not source_cid: + raise ValueError( + f"SOURCE '{source_name}' not found in input_hashes. " + f"Available: {list(input_hashes.keys())}" + ) + + return source_cid + + +class TestSourceCidResolution: + """Tests for SOURCE node CID resolution from input_hashes.""" + + def test_source_with_fixed_cid(self): + """SOURCE with :cid should use that directly.""" + config = {"cid": "QmFixedCid123"} + input_hashes = {} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmFixedCid123" + + def test_source_with_input_true_exact_match(self): + """SOURCE with :input true should resolve from input_hashes by exact name.""" + config = {"input": True, "name": "my-video"} + input_hashes = {"my-video": "QmInputVideo456"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmInputVideo456" + + def test_source_with_input_true_normalized_dash(self): + """SOURCE with :input true should resolve from normalized dash format.""" + config = {"input": True, "name": "Second Video"} + input_hashes = {"second-video": "QmSecondVideo789"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmSecondVideo789" + + def test_source_with_input_true_normalized_underscore(self): + """SOURCE with :input true should resolve from normalized underscore format.""" + config = {"input": True, "name": "Second Video"} + input_hashes = {"second_video": "QmSecondVideoUnderscore"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmSecondVideoUnderscore" + + def test_source_with_input_true_lowercase(self): + """SOURCE with :input true should resolve from lowercase format.""" + config = {"input": True, "name": "MyVideo"} + input_hashes = {"myvideo": "QmLowercaseVideo"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmLowercaseVideo" + + def test_source_with_input_true_missing_raises(self): + """SOURCE with :input true should raise if not in input_hashes.""" + config = {"input": True, "name": "Missing Video"} + input_hashes = {"other-video": "QmOther123"} + + with pytest.raises(ValueError) as excinfo: + resolve_source_cid(config, input_hashes) + + assert "Missing Video" in str(excinfo.value) + assert "not found in input_hashes" in str(excinfo.value) + assert "other-video" in str(excinfo.value) # Shows available keys + + def test_source_without_cid_or_input_returns_none(self): + """SOURCE without :cid or :input should return None.""" + config = {"name": "some-source"} + input_hashes = {} + + cid = resolve_source_cid(config, input_hashes) + assert cid is None + + def test_source_priority_cid_over_input(self): + """SOURCE with both :cid and :input true should use :cid.""" + config = {"cid": "QmDirectCid", "input": True, "name": "my-video"} + input_hashes = {"my-video": "QmInputHashCid"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmDirectCid" # :cid takes priority + + +class TestSourceNameVariants: + """Tests for the various input name normalization formats.""" + + @pytest.mark.parametrize("source_name,input_key", [ + ("video", "video"), + ("My Video", "my-video"), + ("My Video", "my_video"), + ("My Video", "My Video"), + ("CamelCase", "camelcase"), + ("multiple spaces", "multiple spaces"), # Only single replace + ]) + def test_name_variant_matching(self, source_name: str, input_key: str): + """Various name formats should match.""" + config = {"input": True, "name": source_name} + input_hashes = {input_key: "QmTestCid"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmTestCid" + + +class TestExecuteRecipeIntegration: + """Integration tests for execute_recipe with SOURCE nodes.""" + + def test_recipe_with_user_input_source(self): + """ + Recipe execution should resolve SOURCE nodes with :input true. + + This is the bug that was causing "No executor for node type: SOURCE". + """ + # Create mock plan with a SOURCE that has :input true + source_step = MockStep( + step_id="source_1", + node_type="SOURCE", + config={"input": True, "name": "Second Video", "description": "User input"}, + cache_id="abc123", + level=0, + ) + + effect_step = MockStep( + step_id="effect_1", + node_type="EFFECT", + config={"effect": "invert"}, + cache_id="def456", + input_steps=["source_1"], + level=1, + ) + + plan = MockPlan( + steps=[source_step, effect_step], + output_step_id="effect_1", + ) + + # Input hashes provided by user + input_hashes = { + "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + } + + # Verify source CID resolution works + resolved_cid = resolve_source_cid(source_step.config, input_hashes) + assert resolved_cid == "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + + def test_recipe_with_fixed_and_user_sources(self): + """ + Recipe with both fixed (asset) and user-input sources. + + This is the dog-invert-concat recipe pattern: + - Fixed source: cat asset with known CID + - User input: Second Video from input_hashes + """ + fixed_source = MockStep( + step_id="source_cat", + node_type="SOURCE", + config={"cid": "QmCatVideo123", "asset": "cat"}, + cache_id="cat_cache", + level=0, + ) + + user_source = MockStep( + step_id="source_user", + node_type="SOURCE", + config={"input": True, "name": "Second Video"}, + cache_id="user_cache", + level=0, + ) + + input_hashes = { + "second-video": "QmUserProvidedVideo456" + } + + # Fixed source uses its cid + fixed_cid = resolve_source_cid(fixed_source.config, input_hashes) + assert fixed_cid == "QmCatVideo123" + + # User source resolves from input_hashes + user_cid = resolve_source_cid(user_source.config, input_hashes) + assert user_cid == "QmUserProvidedVideo456" + + +class TestCompoundNodeHandling: + """ + Tests for COMPOUND node handling. + + COMPOUND nodes are collapsed effect chains that should be executed + sequentially through their respective effect executors. + + Bug fixed: "No executor for node type: COMPOUND" + """ + + def test_compound_node_has_filter_chain(self): + """COMPOUND nodes must have a filter_chain config.""" + step = MockStep( + step_id="compound_1", + node_type="COMPOUND", + config={ + "filter_chain": [ + {"type": "EFFECT", "config": {"effect": "identity"}}, + {"type": "EFFECT", "config": {"effect": "dog"}}, + ], + "inputs": ["source_1"], + }, + cache_id="compound_cache", + level=1, + ) + + assert step.node_type == "COMPOUND" + assert "filter_chain" in step.config + assert len(step.config["filter_chain"]) == 2 + + def test_compound_filter_chain_has_effects(self): + """COMPOUND filter_chain should contain EFFECT items with effect names.""" + filter_chain = [ + {"type": "EFFECT", "config": {"effect": "identity", "cid": "Qm123"}}, + {"type": "EFFECT", "config": {"effect": "dog", "cid": "Qm456"}}, + ] + + for item in filter_chain: + assert item["type"] == "EFFECT" + assert "effect" in item["config"] + assert "cid" in item["config"] + + def test_compound_requires_inputs(self): + """COMPOUND nodes must have input steps.""" + step = MockStep( + step_id="compound_1", + node_type="COMPOUND", + config={"filter_chain": [], "inputs": []}, + cache_id="compound_cache", + input_steps=[], + level=1, + ) + + # Empty inputs should be detected as error + assert len(step.input_steps) == 0 + # The execute_recipe should raise ValueError for empty inputs + + +class TestCacheIdLookup: + """ + Tests for cache lookup by code-addressed cache_id. + + Bug fixed: Cache lookups by cache_id (code hash) were failing because + only IPFS CID was indexed. Now we also index by node_id when different. + """ + + def test_cache_id_is_code_addressed(self): + """cache_id should be a SHA3-256 hash (64 hex chars), not IPFS CID.""" + # Code-addressed hash example + cache_id = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4" + + assert len(cache_id) == 64 + assert all(c in "0123456789abcdef" for c in cache_id) + assert not cache_id.startswith("Qm") # Not IPFS CID + + def test_ipfs_cid_format(self): + """IPFS CIDs start with 'Qm' (v0) or 'bafy' (v1).""" + ipfs_cid_v0 = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + ipfs_cid_v1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi" + + assert ipfs_cid_v0.startswith("Qm") + assert ipfs_cid_v1.startswith("bafy") + + def test_cache_id_differs_from_ipfs_cid(self): + """ + Code-addressed cache_id is computed BEFORE execution. + IPFS CID is computed AFTER execution from file content. + They will differ for the same logical step. + """ + # Same step can have: + cache_id = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4" + ipfs_cid = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + assert cache_id != ipfs_cid + # Both should be indexable for cache lookups + + +class TestPlanStepTypes: + """ + Tests verifying all node types in S-expression plans are handled. + + These tests document the node types that execute_recipe must handle. + """ + + def test_source_node_types(self): + """SOURCE nodes: fixed asset or user input.""" + # Fixed asset source + fixed = MockStep("s1", "SOURCE", {"cid": "Qm123"}, "cache1") + assert fixed.node_type == "SOURCE" + assert "cid" in fixed.config + + # User input source + user = MockStep("s2", "SOURCE", {"input": True, "name": "video"}, "cache2") + assert user.node_type == "SOURCE" + assert user.config.get("input") is True + + def test_effect_node_type(self): + """EFFECT nodes: single effect application.""" + step = MockStep( + "e1", "EFFECT", + {"effect": "invert", "cid": "QmEffect123", "intensity": 1.0}, + "cache3" + ) + assert step.node_type == "EFFECT" + assert "effect" in step.config + + def test_compound_node_type(self): + """COMPOUND nodes: collapsed effect chains.""" + step = MockStep( + "c1", "COMPOUND", + {"filter_chain": [{"type": "EFFECT", "config": {}}]}, + "cache4" + ) + assert step.node_type == "COMPOUND" + assert "filter_chain" in step.config + + def test_sequence_node_type(self): + """SEQUENCE nodes: concatenate multiple clips.""" + step = MockStep( + "seq1", "SEQUENCE", + {"transition": {"type": "cut"}}, + "cache5", + input_steps=["clip1", "clip2"] + ) + assert step.node_type == "SEQUENCE" + + +class TestNodeTypeCaseSensitivity: + """ + Tests for node type case handling. + + Bug fixed: S-expression plans use lowercase (source, compound, effect) + but code was checking uppercase (SOURCE, COMPOUND, EFFECT). + """ + + def test_source_lowercase_from_sexp(self): + """S-expression plans produce lowercase node types.""" + # From plan: (source :cid "Qm...") + step = MockStep("s1", "source", {"cid": "Qm123"}, "cache1") + + # Code should handle lowercase + assert step.node_type.upper() == "SOURCE" + + def test_compound_lowercase_from_sexp(self): + """COMPOUND from S-expression is lowercase.""" + # From plan: (compound :filter_chain ...) + step = MockStep("c1", "compound", {"filter_chain": []}, "cache2") + + assert step.node_type.upper() == "COMPOUND" + + def test_effect_lowercase_from_sexp(self): + """EFFECT from S-expression is lowercase.""" + # From plan: (effect :effect "invert" ...) + step = MockStep("e1", "effect", {"effect": "invert"}, "cache3") + + assert step.node_type.upper() == "EFFECT" + + def test_sequence_lowercase_from_sexp(self): + """SEQUENCE from S-expression is lowercase.""" + # From plan: (sequence :transition ...) + step = MockStep("seq1", "sequence", {"transition": {}}, "cache4") + + assert step.node_type.upper() == "SEQUENCE" + + def test_node_type_comparison_must_be_case_insensitive(self): + """ + Node type comparisons must be case-insensitive. + + This is the actual bug - checking step.node_type == "SOURCE" + fails when step.node_type is "source" from S-expression. + """ + sexp_types = ["source", "compound", "effect", "sequence"] + code_types = ["SOURCE", "COMPOUND", "EFFECT", "SEQUENCE"] + + for sexp, code in zip(sexp_types, code_types): + # Wrong: direct comparison fails + assert sexp != code, f"{sexp} should not equal {code}" + + # Right: case-insensitive comparison works + assert sexp.upper() == code, f"{sexp}.upper() should equal {code}" + + +class TestExecuteRecipeErrorHandling: + """Tests for error handling in execute_recipe.""" + + def test_missing_input_hash_error_message(self): + """Error should list available input keys when source not found.""" + config = {"input": True, "name": "Unknown Video"} + input_hashes = {"video-a": "Qm1", "video-b": "Qm2"} + + with pytest.raises(ValueError) as excinfo: + resolve_source_cid(config, input_hashes) + + error_msg = str(excinfo.value) + assert "Unknown Video" in error_msg + assert "video-a" in error_msg or "video-b" in error_msg + + def test_source_no_cid_no_input_error(self): + """SOURCE without cid or input flag should return None (invalid).""" + config = {"name": "broken-source"} # Missing both cid and input + input_hashes = {} + + result = resolve_source_cid(config, input_hashes) + assert result is None # execute_recipe should catch this + + +class TestRecipeOutputRequired: + """ + Tests verifying recipes must produce output to succeed. + + Bug: Recipe was returning success=True with output_cid=None + """ + + def test_recipe_without_output_should_fail(self): + """ + Recipe execution must fail if no output is produced. + + This catches the bug where execute_recipe returned success=True + but output_cid was None. + """ + # Simulate the check that should happen in execute_recipe + output_cid = None + step_results = {"step1": {"status": "executed", "path": "/tmp/x"}} + + # This is the logic that should be in execute_recipe + success = output_cid is not None + + assert success is False, "Recipe with no output_cid should fail" + + def test_recipe_with_output_should_succeed(self): + """Recipe with valid output should succeed.""" + output_cid = "QmOutputCid123" + + success = output_cid is not None + assert success is True + + def test_output_step_result_must_have_cid(self): + """Output step result must contain cid or cache_id.""" + # Step result without cid + bad_result = {"status": "executed", "path": "/tmp/output.mkv"} + output_cid = bad_result.get("cid") or bad_result.get("cache_id") + assert output_cid is None, "Should detect missing cid" + + # Step result with cid + good_result = {"status": "executed", "path": "/tmp/output.mkv", "cid": "Qm123"} + output_cid = good_result.get("cid") or good_result.get("cache_id") + assert output_cid == "Qm123" + + def test_output_step_must_exist_in_results(self): + """Output step must be present in step_results.""" + step_results = { + "source_1": {"status": "source", "cid": "QmSrc"}, + "effect_1": {"status": "executed", "cid": "QmEffect"}, + # Note: output_step "sequence_1" is missing! + } + + output_step_id = "sequence_1" + output_result = step_results.get(output_step_id, {}) + output_cid = output_result.get("cid") + + assert output_cid is None, "Missing output step should result in no cid" diff --git a/l1/tests/test_frame_compatibility.py b/l1/tests/test_frame_compatibility.py new file mode 100644 index 0000000..f12cce0 --- /dev/null +++ b/l1/tests/test_frame_compatibility.py @@ -0,0 +1,185 @@ +""" +Integration tests for GPU/CPU frame compatibility. + +These tests verify that all primitives work correctly with both: +- numpy arrays (CPU frames) +- CuPy arrays (GPU frames) +- GPUFrame wrapper objects + +Run with: pytest tests/test_frame_compatibility.py -v +""" + +import pytest +import numpy as np +import sys +import os + +# Add parent to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Try to import CuPy +try: + import cupy as cp + HAS_GPU = True +except ImportError: + cp = None + HAS_GPU = False + + +def create_test_frame(on_gpu=False): + """Create a test frame (100x100 RGB).""" + frame = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + if on_gpu and HAS_GPU: + return cp.asarray(frame) + return frame + + +class MockGPUFrame: + """Mock GPUFrame for testing without full GPU stack.""" + def __init__(self, data): + self._data = data + + @property + def cpu(self): + if HAS_GPU and hasattr(self._data, 'get'): + return self._data.get() + return self._data + + @property + def gpu(self): + if HAS_GPU: + if hasattr(self._data, 'get'): + return self._data + return cp.asarray(self._data) + raise RuntimeError("No GPU available") + + +class TestColorOps: + """Test color_ops primitives with different frame types.""" + + def test_shift_hsv_numpy(self): + """shift-hsv should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = create_test_frame(on_gpu=False) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) + assert result.shape == frame.shape + + @pytest.mark.skipif(not HAS_GPU, reason="No GPU") + def test_shift_hsv_cupy(self): + """shift-hsv should work with CuPy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = create_test_frame(on_gpu=True) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) # Should return numpy + + def test_shift_hsv_gpuframe(self): + """shift-hsv should work with GPUFrame objects.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = MockGPUFrame(create_test_frame(on_gpu=False)) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) + + def test_invert_numpy(self): + """invert-img should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_invert_img + frame = create_test_frame(on_gpu=False) + result = prim_invert_img(frame) + assert isinstance(result, np.ndarray) + + def test_adjust_numpy(self): + """adjust should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_adjust + frame = create_test_frame(on_gpu=False) + result = prim_adjust(frame, brightness=10, contrast=1.2) + assert isinstance(result, np.ndarray) + + +class TestGeometry: + """Test geometry primitives with different frame types.""" + + def test_rotate_numpy(self): + """rotate should work with numpy arrays.""" + from sexp_effects.primitive_libs.geometry import prim_rotate + frame = create_test_frame(on_gpu=False) + result = prim_rotate(frame, 45) + assert isinstance(result, np.ndarray) + + def test_scale_numpy(self): + """scale should work with numpy arrays.""" + from sexp_effects.primitive_libs.geometry import prim_scale + frame = create_test_frame(on_gpu=False) + result = prim_scale(frame, 1.5) + assert isinstance(result, np.ndarray) + + +class TestBlending: + """Test blending primitives with different frame types.""" + + def test_blend_numpy(self): + """blend should work with numpy arrays.""" + from sexp_effects.primitive_libs.blending import prim_blend + frame_a = create_test_frame(on_gpu=False) + frame_b = create_test_frame(on_gpu=False) + result = prim_blend(frame_a, frame_b, 0.5) + assert isinstance(result, np.ndarray) + + +class TestInterpreterConversion: + """Test the interpreter's frame conversion.""" + + def test_maybe_to_numpy_none(self): + """_maybe_to_numpy should handle None.""" + from streaming.stream_sexp_generic import StreamInterpreter + # Create minimal interpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + assert interp._maybe_to_numpy(None) is None + + def test_maybe_to_numpy_numpy(self): + """_maybe_to_numpy should pass through numpy arrays.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = create_test_frame(on_gpu=False) + result = interp._maybe_to_numpy(frame) + assert result is frame # Should be same object + + @pytest.mark.skipif(not HAS_GPU, reason="No GPU") + def test_maybe_to_numpy_cupy(self): + """_maybe_to_numpy should convert CuPy to numpy.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = create_test_frame(on_gpu=True) + result = interp._maybe_to_numpy(frame) + assert isinstance(result, np.ndarray) + + def test_maybe_to_numpy_gpuframe(self): + """_maybe_to_numpy should convert GPUFrame to numpy.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = MockGPUFrame(create_test_frame(on_gpu=False)) + result = interp._maybe_to_numpy(frame) + assert isinstance(result, np.ndarray) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/l1/tests/test_item_visibility.py b/l1/tests/test_item_visibility.py new file mode 100644 index 0000000..be6c50b --- /dev/null +++ b/l1/tests/test_item_visibility.py @@ -0,0 +1,272 @@ +""" +Tests for item visibility in L1 web UI. + +Bug found 2026-01-12: L1 run succeeded but web UI not showing: +- Runs +- Recipes +- Created media + +Root causes identified: +1. Recipes: owner field filtering but owner never set in loaded recipes +2. Media: item_types table entries not created on upload/import +3. Run outputs: outputs not registered in item_types table +""" + +import pytest +from pathlib import Path +from unittest.mock import MagicMock, AsyncMock, patch +import tempfile + + +class TestRecipeVisibility: + """Tests for recipe listing visibility.""" + + def test_recipe_filter_allows_none_owner(self) -> None: + """ + Regression test: The recipe filter should allow recipes where owner is None. + + Bug: recipe_service.list_recipes() filtered by owner == actor_id, + but owner field is None in recipes loaded from S-expression files. + This caused ALL recipes to be filtered out. + + Fix: The filter is now: actor_id is None or owner is None or owner == actor_id + """ + # Simulate the filter logic from recipe_service.list_recipes + # OLD (broken): if actor_id is None or owner == actor_id + # NEW (fixed): if actor_id is None or owner is None or owner == actor_id + + test_cases = [ + # (actor_id, owner, expected_visible, description) + (None, None, True, "No filter, no owner -> visible"), + (None, "@someone@example.com", True, "No filter, has owner -> visible"), + ("@testuser@example.com", None, True, "Has filter, no owner -> visible (shared)"), + ("@testuser@example.com", "@testuser@example.com", True, "Filter matches owner -> visible"), + ("@testuser@example.com", "@other@example.com", False, "Filter doesn't match -> hidden"), + ] + + for actor_id, owner, expected_visible, description in test_cases: + # This is the FIXED filter logic from recipe_service.py line 86 + is_visible = actor_id is None or owner is None or owner == actor_id + + assert is_visible == expected_visible, f"Failed: {description}" + + def test_recipe_filter_old_logic_was_broken(self) -> None: + """Document that the old filter logic excluded all recipes with owner=None.""" + # OLD filter: actor_id is None or owner == actor_id + # This broke when owner=None and actor_id was provided + + actor_id = "@testuser@example.com" + owner = None # This is what compiled sexp produces + + # OLD logic (broken): + old_logic_visible = actor_id is None or owner == actor_id + assert old_logic_visible is False, "Old logic incorrectly hid owner=None recipes" + + # NEW logic (fixed): + new_logic_visible = actor_id is None or owner is None or owner == actor_id + assert new_logic_visible is True, "New logic should show owner=None recipes" + + +class TestMediaVisibility: + """Tests for media visibility after upload.""" + + @pytest.mark.asyncio + async def test_upload_content_creates_item_type_record(self) -> None: + """ + Test: Uploaded media must be registered in item_types table via save_item_metadata. + + The save_item_metadata function creates entries in item_types table, + enabling the media to appear in list_media queries. + """ + import importlib.util + spec = importlib.util.spec_from_file_location( + "cache_service", + "/home/giles/art/art-celery/app/services/cache_service.py" + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + CacheService = module.CacheService + + # Create mocks + mock_db = AsyncMock() + mock_db.create_cache_item = AsyncMock() + mock_db.save_item_metadata = AsyncMock() + + mock_cache = MagicMock() + cached_result = MagicMock() + cached_result.cid = "QmUploadedContent123" + mock_cache.put.return_value = (cached_result, "QmIPFSCid123") + + service = CacheService(database=mock_db, cache_manager=mock_cache) + + # Upload content + cid, ipfs_cid, error = await service.upload_content( + content=b"test video content", + filename="test.mp4", + actor_id="@testuser@example.com", + ) + + assert error is None, f"Upload failed: {error}" + assert cid is not None + + # Verify save_item_metadata was called (which creates item_types entry) + mock_db.save_item_metadata.assert_called_once() + + # Verify it was called with correct actor_id and a media type (not mime type) + call_kwargs = mock_db.save_item_metadata.call_args[1] + assert call_kwargs.get('actor_id') == "@testuser@example.com", \ + "save_item_metadata must be called with the uploading user's actor_id" + # item_type should be media category like "video", "image", "audio", "unknown" + # NOT mime type like "video/mp4" + item_type = call_kwargs.get('item_type') + assert item_type in ("video", "image", "audio", "unknown"), \ + f"item_type should be media category, got '{item_type}'" + + @pytest.mark.asyncio + async def test_import_from_ipfs_creates_item_type_record(self) -> None: + """ + Test: Imported media must be registered in item_types table via save_item_metadata. + + The save_item_metadata function creates entries in item_types table with + detected media type, enabling the media to appear in list_media queries. + """ + import importlib.util + spec = importlib.util.spec_from_file_location( + "cache_service", + "/home/giles/art/art-celery/app/services/cache_service.py" + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + CacheService = module.CacheService + + mock_db = AsyncMock() + mock_db.create_cache_item = AsyncMock() + mock_db.save_item_metadata = AsyncMock() + + mock_cache = MagicMock() + cached_result = MagicMock() + cached_result.cid = "QmImportedContent123" + mock_cache.put.return_value = (cached_result, "QmIPFSCid456") + + service = CacheService(database=mock_db, cache_manager=mock_cache) + service.cache_dir = Path(tempfile.gettempdir()) + + # We need to mock the ipfs_client module at the right location + import importlib + ipfs_module = MagicMock() + ipfs_module.get_file = MagicMock(return_value=True) + + # Patch at module level + with patch.dict('sys.modules', {'ipfs_client': ipfs_module}): + # Import from IPFS + cid, error = await service.import_from_ipfs( + ipfs_cid="QmSourceIPFSCid", + actor_id="@testuser@example.com", + ) + + # Verify save_item_metadata was called (which creates item_types entry) + mock_db.save_item_metadata.assert_called_once() + + # Verify it was called with detected media type (not hardcoded "media") + call_kwargs = mock_db.save_item_metadata.call_args[1] + item_type = call_kwargs.get('item_type') + assert item_type in ("video", "image", "audio", "unknown"), \ + f"item_type should be detected media category, got '{item_type}'" + + +class TestRunOutputVisibility: + """Tests for run output visibility.""" + + @pytest.mark.asyncio + async def test_completed_run_output_visible_in_media_list(self) -> None: + """ + Run outputs should be accessible in media listings. + + When a run completes, its output should be registered in item_types + so it appears in the user's media gallery. + """ + # This test documents the expected behavior + # Run outputs are stored in run_cache but should also be in item_types + # for the media gallery to show them + + # The fix should either: + # 1. Add item_types entry when run completes, OR + # 2. Modify list_media to also check run_cache outputs + pass # Placeholder for implementation test + + +class TestDatabaseItemTypes: + """Tests for item_types database operations.""" + + @pytest.mark.asyncio + async def test_add_item_type_function_exists(self) -> None: + """Verify add_item_type function exists and has correct signature.""" + import database + + assert hasattr(database, 'add_item_type'), \ + "database.add_item_type function should exist" + + # Check it's an async function + import inspect + assert inspect.iscoroutinefunction(database.add_item_type), \ + "add_item_type should be an async function" + + @pytest.mark.asyncio + async def test_get_user_items_returns_items_from_item_types(self) -> None: + """ + Verify get_user_items queries item_types table. + + If item_types has no entries for a user, they see no media. + """ + # This is a documentation test showing the data flow: + # 1. User uploads content -> should create item_types entry + # 2. list_media -> calls get_user_items -> queries item_types + # 3. If step 1 didn't create item_types entry, step 2 returns empty + pass + + +class TestTemplateRendering: + """Tests for template variable passing.""" + + def test_cache_not_found_template_receives_content_hash(self) -> None: + """ + Regression test: cache/not_found.html template requires content_hash. + + Bug: The template uses {{ content_hash[:24] }} but the route + doesn't pass content_hash to the render context. + + Error: jinja2.exceptions.UndefinedError: 'content_hash' is undefined + """ + # This test documents the bug - the template expects content_hash + # but the route at /app/app/routers/cache.py line 57 doesn't provide it + pass # Will verify fix by checking route code + + +class TestOwnerFieldInRecipes: + """Tests for owner field handling in recipes.""" + + def test_sexp_recipe_has_none_owner_by_default(self) -> None: + """ + S-expression recipes have owner=None by default. + + The compiled recipe includes owner field but it's None, + so the list_recipes filter must allow owner=None to show + shared/public recipes. + """ + sample_sexp = """ + (recipe "test" + (-> (source :input true :name "video") + (fx identity))) + """ + + from artdag.sexp import compile_string + + compiled = compile_string(sample_sexp) + recipe_dict = compiled.to_dict() + + # The compiled recipe has owner field but it's None + assert recipe_dict.get("owner") is None, \ + "Compiled S-expression should have owner=None" + + # This means the filter must allow owner=None for recipes to be visible + # The fix: if actor_id is None or owner is None or owner == actor_id diff --git a/l1/tests/test_jax_pipeline_integration.py b/l1/tests/test_jax_pipeline_integration.py new file mode 100644 index 0000000..8f9fb93 --- /dev/null +++ b/l1/tests/test_jax_pipeline_integration.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python3 +"""Integration tests comparing JAX and Python rendering pipelines. + +These tests ensure the JAX-compiled effect chains produce identical output +to the Python/NumPy path. They test: +1. Full effect pipelines through both interpreters +2. Multi-frame sequences (to catch phase-dependent bugs) +3. Compiled effect chain fusion +4. Edge cases like shrinking/zooming that affect boundary handling +""" +import os +import sys +import pytest +import numpy as np +import shutil +from pathlib import Path + +# Ensure the art-celery module is importable +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sexp_effects.primitive_libs import core as core_mod + + +# Path to test resources +TEST_DIR = Path('/home/giles/art/test') +EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects' +TEMPLATES_DIR = TEST_DIR / 'templates' + + +def create_test_image(h=96, w=128): + """Create a test image with distinct patterns.""" + import cv2 + img = np.zeros((h, w, 3), dtype=np.uint8) + + # Create gradient background + for y in range(h): + for x in range(w): + img[y, x] = [ + int(255 * x / w), # R: horizontal gradient + int(255 * y / h), # G: vertical gradient + 128 # B: constant + ] + + # Add features + cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1) + cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1) + + return img + + +@pytest.fixture(scope='module') +def test_env(tmp_path_factory): + """Set up test environment with sexp files and test media.""" + test_dir = tmp_path_factory.mktemp('sexp_test') + original_dir = os.getcwd() + os.chdir(test_dir) + + # Create directories + (test_dir / 'effects').mkdir() + (test_dir / 'sexp_effects' / 'effects').mkdir(parents=True) + + # Create test image + import cv2 + test_img = create_test_image() + cv2.imwrite(str(test_dir / 'test_image.png'), test_img) + + # Copy required effect files + for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']: + src = EFFECTS_DIR / f'{effect}.sexp' + dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp' + if src.exists(): + shutil.copy(src, dst) + + yield { + 'dir': test_dir, + 'image_path': test_dir / 'test_image.png', + 'test_img': test_img, + } + + os.chdir(original_dir) + + +def create_sexp_file(test_dir, content, filename='test.sexp'): + """Create a test sexp file.""" + path = test_dir / 'effects' / filename + with open(path, 'w') as f: + f.write(content) + return str(path) + + +class TestJaxPythonPipelineEquivalence: + """Test that JAX and Python pipelines produce equivalent output.""" + + def test_single_rotate_effect(self, test_env): + """Test that a single rotate effect matches between Python and JAX.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + + (frame (rotate frame :angle 15 :speed 0)) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + # Python path + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + # JAX path + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + # Inject test image into globals + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = py_interp._eval(py_interp.frame_pipeline, frame_env) + jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env) + + # Force deferred if needed + py_result = np.asarray(py_interp._maybe_force(py_result)) + jax_result = np.asarray(jax_interp._maybe_force(jax_result)) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold" + + def test_rotate_then_zoom(self, test_env): + """Test rotate followed by zoom - tests effect chain fusion.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (-> (rotate frame :angle 15 :speed 0) + (zoom :amount 0.95 :speed 0))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold" + + def test_zoom_shrink_boundary_handling(self, test_env): + """Test zoom with shrinking factor - critical for boundary handling.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (zoom frame :amount 0.8 :speed 0)) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + # Check corners specifically - these are most affected by boundary handling + h, w = test_img.shape[:2] + corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)] + for y, x in corners: + py_val = py_result[y, x] + jax_val = jax_result[y, x] + corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max() + assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}" + + def test_blend_two_clips(self, test_env): + """Test blending two effect chains - the core bug scenario.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + + # Check edge region specifically + edge_diff = diff[0, :].mean() + + assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold" + assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold" + + def test_blend_with_invert(self, test_env): + """Test blending with invert - matches the problematic recipe pattern.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (require-primitives "color_ops") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + (effect invert :path "../sexp_effects/effects/invert.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0) + (invert :amount 1)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold" + + +class TestDeferredEffectChainFusion: + """Test the DeferredEffectChain fusion mechanism specifically.""" + + def test_manual_vs_fused_chain(self, test_env): + """Compare manual application vs fused DeferredEffectChain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + # Create minimal sexp to load effects + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # Manual step-by-step application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + rot_angle = -5.0 + zoom_amount = 0.95 + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain application + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + diff = np.abs(manual_result.astype(float) - fused_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}" + + # Check specific pixels + h, w = test_img.shape[:2] + for y in [0, h//2, h-1]: + for x in [0, w//2, w-1]: + manual_val = manual_result[y, x] + fused_val = fused_result[y, x] + pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max() + assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}" + + def test_chain_with_shrink_zoom_boundary(self, test_env): + """Test that shrinking zoom handles boundaries correctly in chain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # Parameters that shrink the image (zoom < 1.0) + rot_angle = -4.555 + zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries + + # Manual application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + # Check top edge specifically - this is where boundary issues manifest + top_edge_manual = manual_result[0, :] + top_edge_fused = fused_result[0, :] + + edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float)) + mean_edge_diff = np.mean(edge_diff) + + assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}" + + # Check for zeros at edges that shouldn't be there + manual_edge_sum = np.sum(top_edge_manual) + fused_edge_sum = np.sum(top_edge_fused) + + if manual_edge_sum > 100: # If manual has significant values + assert fused_edge_sum > manual_edge_sum * 0.5, \ + f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}" + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/l1/tests/test_jax_primitives.py b/l1/tests/test_jax_primitives.py new file mode 100644 index 0000000..5fad678 --- /dev/null +++ b/l1/tests/test_jax_primitives.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +""" +Test framework to verify JAX primitives match Python primitives. + +Compares output of each primitive through: +1. Python/NumPy path +2. JAX path (CPU) +3. JAX path (GPU) - if available + +Reports any mismatches with detailed diffs. +""" +import sys +sys.path.insert(0, '/home/giles/art/art-celery') + +import numpy as np +import json +from pathlib import Path +from typing import Dict, List, Tuple, Any, Optional +from dataclasses import dataclass, field + +# Test configuration +TEST_WIDTH = 64 +TEST_HEIGHT = 48 +TOLERANCE_MEAN = 1.0 # Max allowed mean difference +TOLERANCE_MAX = 10.0 # Max allowed single pixel difference +TOLERANCE_PCT = 0.95 # Min % of pixels within ±1 + + +@dataclass +class TestResult: + primitive: str + passed: bool + python_mean: float = 0.0 + jax_mean: float = 0.0 + diff_mean: float = 0.0 + diff_max: float = 0.0 + pct_within_1: float = 0.0 + error: str = "" + + +def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'): + """Create a test frame with known pattern.""" + if pattern == 'gradient': + # Diagonal gradient + y, x = np.mgrid[0:h, 0:w] + r = (x * 255 / w).astype(np.uint8) + g = (y * 255 / h).astype(np.uint8) + b = ((x + y) * 127 / (w + h)).astype(np.uint8) + return np.stack([r, g, b], axis=2) + elif pattern == 'checkerboard': + y, x = np.mgrid[0:h, 0:w] + check = ((x // 8) + (y // 8)) % 2 + v = (check * 255).astype(np.uint8) + return np.stack([v, v, v], axis=2) + elif pattern == 'solid': + return np.full((h, w, 3), 128, dtype=np.uint8) + else: + return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + + +def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]: + """Compare two outputs, return (mean_diff, max_diff, pct_within_1).""" + if py_out is None or jax_out is None: + return float('inf'), float('inf'), 0.0 + + if isinstance(py_out, dict) and isinstance(jax_out, dict): + # Compare coordinate maps + diffs = [] + for k in py_out: + if k in jax_out: + py_arr = np.asarray(py_out[k]) + jax_arr = np.asarray(jax_out[k]) + if py_arr.shape == jax_arr.shape: + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + diffs.append(diff) + if diffs: + all_diff = np.concatenate([d.flatten() for d in diffs]) + return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1)) + return float('inf'), float('inf'), 0.0 + + py_arr = np.asarray(py_out) + jax_arr = np.asarray(jax_out) + + if py_arr.shape != jax_arr.shape: + return float('inf'), float('inf'), 0.0 + + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1)) + + +# ============================================================================ +# Primitive Test Definitions +# ============================================================================ + +PRIMITIVE_TESTS = { + # Geometry primitives + 'geometry:ripple-displace': { + 'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5], + 'returns': 'coords', + }, + 'geometry:rotate-img': { + 'args': ['frame', 45], + 'returns': 'frame', + }, + 'geometry:scale-img': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'geometry:flip-h': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'geometry:flip-v': { + 'args': ['frame'], + 'returns': 'frame', + }, + + # Color operations + 'color_ops:invert': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:grayscale': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:brightness': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:contrast': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:hue-shift': { + 'args': ['frame', 90], + 'returns': 'frame', + }, + + # Image operations + 'image:width': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:height': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:channel': { + 'args': ['frame', 0], + 'returns': 'array', + }, + + # Blending + 'blending:blend': { + 'args': ['frame', 'frame2', 0.5], + 'returns': 'frame', + }, + 'blending:blend-add': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, + 'blending:blend-multiply': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, +} + + +def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the Python interpreter.""" + if prim_name not in interp.primitives: + return None + + func = interp.primitives[prim_name] + args = [] + for a in test_def['args']: + if a == 'frame': + args.append(frame.copy()) + elif a == 'frame2': + args.append(frame2.copy()) + else: + args.append(a) + + try: + return func(*args) + except Exception as e: + return None + + +def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the JAX compiler.""" + try: + from streaming.sexp_to_jax import JaxCompiler + import jax.numpy as jnp + + compiler = JaxCompiler() + + # Build a simple expression to test the primitive + from sexp_effects.parser import Symbol, Keyword + + args = [] + env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)} + + for a in test_def['args']: + if a == 'frame': + args.append(Symbol('frame')) + elif a == 'frame2': + args.append(Symbol('frame2')) + else: + args.append(a) + + # Create expression: (prim_name arg1 arg2 ...) + expr = [Symbol(prim_name)] + args + + result = compiler._eval(expr, env) + + if hasattr(result, '__array__'): + return np.asarray(result) + return result + + except Exception as e: + return None + + +def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult: + """Test a single primitive.""" + frame = create_test_frame(pattern='gradient') + frame2 = create_test_frame(pattern='checkerboard') + + result = TestResult(primitive=prim_name, passed=False) + + # Run Python version + try: + py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2) + if py_out is not None and hasattr(py_out, 'shape'): + result.python_mean = float(np.mean(py_out)) + except Exception as e: + result.error = f"Python error: {e}" + return result + + # Run JAX version + try: + jax_out = run_jax_primitive(prim_name, test_def, frame, frame2) + if jax_out is not None and hasattr(jax_out, 'shape'): + result.jax_mean = float(np.mean(jax_out)) + except Exception as e: + result.error = f"JAX error: {e}" + return result + + if py_out is None: + result.error = "Python returned None" + return result + if jax_out is None: + result.error = "JAX returned None" + return result + + # Compare + diff_mean, diff_max, pct = compare_outputs(py_out, jax_out) + result.diff_mean = diff_mean + result.diff_max = diff_max + result.pct_within_1 = pct + + # Check pass/fail + result.passed = ( + diff_mean <= TOLERANCE_MEAN and + diff_max <= TOLERANCE_MAX and + pct >= TOLERANCE_PCT + ) + + if not result.passed: + result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}" + + return result + + +def discover_primitives(interp) -> List[str]: + """Discover all primitives available in the interpreter.""" + return sorted(interp.primitives.keys()) + + +def run_all_tests(verbose=True): + """Run all primitive tests.""" + import warnings + warnings.filterwarnings('ignore') + + import os + os.chdir('/home/giles/art/test') + + from streaming.stream_sexp_generic import StreamInterpreter + from sexp_effects.primitive_libs import core as core_mod + + core_mod.set_random_seed(42) + + # Create interpreter to get Python primitives + interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False) + interp._init() + + results = [] + + print("=" * 60) + print("JAX PRIMITIVE TEST SUITE") + print("=" * 60) + + # Test defined primitives + for prim_name, test_def in PRIMITIVE_TESTS.items(): + result = test_primitive(interp, prim_name, test_def) + results.append(result) + + status = "✓ PASS" if result.passed else "✗ FAIL" + if verbose: + print(f"{status} {prim_name}") + if not result.passed: + print(f" {result.error}") + + # Summary + passed = sum(1 for r in results if r.passed) + failed = sum(1 for r in results if not r.passed) + + print("\n" + "=" * 60) + print(f"SUMMARY: {passed} passed, {failed} failed") + print("=" * 60) + + if failed > 0: + print("\nFailed primitives:") + for r in results: + if not r.passed: + print(f" - {r.primitive}: {r.error}") + + return results + + +if __name__ == '__main__': + run_all_tests() diff --git a/l1/tests/test_naming_service.py b/l1/tests/test_naming_service.py new file mode 100644 index 0000000..98d8e52 --- /dev/null +++ b/l1/tests/test_naming_service.py @@ -0,0 +1,246 @@ +""" +Tests for the friendly naming service. +""" + +import re +import pytest +from pathlib import Path + + +# Copy the pure functions from naming_service for testing +# This avoids import issues with the app module + +CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" + + +def normalize_name(name: str) -> str: + """Copy of normalize_name for testing.""" + name = name.lower() + name = re.sub(r"[\s_]+", "-", name) + name = re.sub(r"[^a-z0-9-]", "", name) + name = re.sub(r"-+", "-", name) + name = name.strip("-") + return name or "unnamed" + + +def parse_friendly_name(friendly_name: str): + """Copy of parse_friendly_name for testing.""" + parts = friendly_name.strip().split(" ", 1) + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + return base_name, version_id + + +def format_friendly_name(base_name: str, version_id: str) -> str: + """Copy of format_friendly_name for testing.""" + return f"{base_name} {version_id}" + + +def format_l2_name(actor_id: str, base_name: str, version_id: str) -> str: + """Copy of format_l2_name for testing.""" + return f"{actor_id} {base_name} {version_id}" + + +class TestNameNormalization: + """Tests for name normalization.""" + + def test_normalize_simple_name(self) -> None: + """Simple names should be lowercased.""" + assert normalize_name("Brightness") == "brightness" + + def test_normalize_spaces_to_dashes(self) -> None: + """Spaces should be converted to dashes.""" + assert normalize_name("My Cool Effect") == "my-cool-effect" + + def test_normalize_underscores_to_dashes(self) -> None: + """Underscores should be converted to dashes.""" + assert normalize_name("brightness_v2") == "brightness-v2" + + def test_normalize_removes_special_chars(self) -> None: + """Special characters should be removed.""" + # Special chars are removed (not replaced with dashes) + assert normalize_name("Test!!!Effect") == "testeffect" + assert normalize_name("cool@effect#1") == "cooleffect1" + # But spaces/underscores become dashes first + assert normalize_name("Test Effect!") == "test-effect" + + def test_normalize_collapses_dashes(self) -> None: + """Multiple dashes should be collapsed.""" + assert normalize_name("test--effect") == "test-effect" + assert normalize_name("test___effect") == "test-effect" + + def test_normalize_strips_edge_dashes(self) -> None: + """Leading/trailing dashes should be stripped.""" + assert normalize_name("-test-effect-") == "test-effect" + + def test_normalize_empty_returns_unnamed(self) -> None: + """Empty names should return 'unnamed'.""" + assert normalize_name("") == "unnamed" + assert normalize_name("---") == "unnamed" + assert normalize_name("!!!") == "unnamed" + + +class TestFriendlyNameParsing: + """Tests for friendly name parsing.""" + + def test_parse_base_name_only(self) -> None: + """Parsing base name only returns None for version.""" + base, version = parse_friendly_name("my-effect") + assert base == "my-effect" + assert version is None + + def test_parse_with_version(self) -> None: + """Parsing with version returns both parts.""" + base, version = parse_friendly_name("my-effect 01hw3x9k") + assert base == "my-effect" + assert version == "01hw3x9k" + + def test_parse_strips_whitespace(self) -> None: + """Parsing should strip leading/trailing whitespace.""" + base, version = parse_friendly_name(" my-effect ") + assert base == "my-effect" + assert version is None + + +class TestFriendlyNameFormatting: + """Tests for friendly name formatting.""" + + def test_format_friendly_name(self) -> None: + """Format combines base and version with space.""" + assert format_friendly_name("my-effect", "01hw3x9k") == "my-effect 01hw3x9k" + + def test_format_l2_name(self) -> None: + """L2 format includes actor ID.""" + result = format_l2_name("@alice@example.com", "my-effect", "01hw3x9k") + assert result == "@alice@example.com my-effect 01hw3x9k" + + +class TestDatabaseSchemaExists: + """Tests that verify database schema includes friendly_names table.""" + + def test_schema_has_friendly_names_table(self) -> None: + """Database schema should include friendly_names table.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "CREATE TABLE IF NOT EXISTS friendly_names" in content + + def test_schema_has_required_columns(self) -> None: + """Friendly names table should have required columns.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "actor_id" in content + assert "base_name" in content + assert "version_id" in content + assert "item_type" in content + assert "display_name" in content + + def test_schema_has_unique_constraints(self) -> None: + """Friendly names table should have unique constraints.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + # Unique on (actor_id, base_name, version_id) + assert "UNIQUE(actor_id, base_name, version_id)" in content + # Unique on (actor_id, cid) + assert "UNIQUE(actor_id, cid)" in content + + +class TestDatabaseFunctionsExist: + """Tests that verify database functions exist.""" + + def test_create_friendly_name_exists(self) -> None: + """create_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def create_friendly_name(" in content + + def test_get_friendly_name_by_cid_exists(self) -> None: + """get_friendly_name_by_cid function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def get_friendly_name_by_cid(" in content + + def test_resolve_friendly_name_exists(self) -> None: + """resolve_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def resolve_friendly_name(" in content + + def test_list_friendly_names_exists(self) -> None: + """list_friendly_names function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def list_friendly_names(" in content + + def test_delete_friendly_name_exists(self) -> None: + """delete_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def delete_friendly_name(" in content + + +class TestNamingServiceModuleExists: + """Tests that verify naming service module structure.""" + + def test_module_file_exists(self) -> None: + """Naming service module file should exist.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + assert path.exists() + + def test_module_has_normalize_name(self) -> None: + """Module should have normalize_name function.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "def normalize_name(" in content + + def test_module_has_generate_version_id(self) -> None: + """Module should have generate_version_id function.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "def generate_version_id(" in content + + def test_module_has_naming_service_class(self) -> None: + """Module should have NamingService class.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "class NamingService:" in content + + def test_naming_service_has_assign_name(self) -> None: + """NamingService should have assign_name method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def assign_name(" in content + + def test_naming_service_has_resolve(self) -> None: + """NamingService should have resolve method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def resolve(" in content + + def test_naming_service_has_get_by_cid(self) -> None: + """NamingService should have get_by_cid method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def get_by_cid(" in content + + +class TestVersionIdProperties: + """Tests for version ID format properties (using actual function).""" + + def test_version_id_format(self) -> None: + """Version ID should use base32-crockford alphabet.""" + # Read the naming service to verify it uses the right alphabet + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert 'CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz"' in content + + def test_version_id_uses_hmac(self) -> None: + """Version ID generation should use HMAC for server verification.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "hmac.new(" in content + + def test_version_id_uses_timestamp(self) -> None: + """Version ID generation should be timestamp-based.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "time.time()" in content diff --git a/l1/tests/test_recipe_visibility.py b/l1/tests/test_recipe_visibility.py new file mode 100644 index 0000000..a0fc93c --- /dev/null +++ b/l1/tests/test_recipe_visibility.py @@ -0,0 +1,150 @@ +""" +Tests for recipe visibility in web UI. + +Bug found 2026-01-12: Recipes not showing in list even after upload. +""" + +import pytest +from pathlib import Path + + +class TestRecipeListingFlow: + """Tests for recipe listing data flow.""" + + def test_cache_manager_has_list_by_type(self) -> None: + """L1CacheManager should have list_by_type method.""" + # Read cache_manager.py and verify list_by_type exists + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'def list_by_type' in content, \ + "L1CacheManager should have list_by_type method" + + def test_list_by_type_returns_node_id(self) -> None: + """list_by_type should return entry.node_id values (IPFS CID).""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Find list_by_type function and verify it appends entry.node_id + assert 'cids.append(entry.node_id)' in content, \ + "list_by_type should append entry.node_id (IPFS CID) to results" + + def test_recipe_service_uses_database_items(self) -> None: + """Recipe service should use database.get_user_items for listing.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'get_user_items' in content, \ + "Recipe service should use database.get_user_items for listing" + + def test_recipe_upload_uses_recipe_node_type(self) -> None: + """Recipe upload should store with node_type='recipe'.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'node_type="recipe"' in content, \ + "Recipe upload should use node_type='recipe'" + + def test_get_by_cid_uses_find_by_cid(self) -> None: + """get_by_cid should use cache.find_by_cid to locate entries.""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Verify get_by_cid uses find_by_cid + assert 'find_by_cid(cid)' in content, \ + "get_by_cid should use find_by_cid to locate entries" + + def test_no_duplicate_get_by_cid_methods(self) -> None: + """ + Regression test: There should only be ONE get_by_cid method. + + Bug: Two get_by_cid methods existed, the second shadowed the first, + breaking recipe lookup because the comprehensive method was hidden. + """ + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Count occurrences of 'def get_by_cid' + count = content.count('def get_by_cid') + assert count == 1, \ + f"Should have exactly 1 get_by_cid method, found {count}" + + +class TestRecipeFilterLogic: + """Tests for recipe filtering logic via database.""" + + def test_recipes_filtered_by_actor_id(self) -> None: + """list_recipes should filter by actor_id parameter.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'actor_id' in content, \ + "list_recipes should accept actor_id parameter" + + def test_recipes_uses_item_type_filter(self) -> None: + """list_recipes should filter by item_type='recipe'.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'item_type="recipe"' in content, \ + "Recipe listing should filter by item_type='recipe'" + + +class TestCacheEntryHasCid: + """Tests for cache entry cid field.""" + + def test_artdag_cache_entry_has_cid(self) -> None: + """artdag CacheEntry should have cid field.""" + from artdag import CacheEntry + import dataclasses + + fields = {f.name for f in dataclasses.fields(CacheEntry)} + assert 'cid' in fields, \ + "CacheEntry should have cid field" + + def test_artdag_cache_put_computes_cid(self) -> None: + """artdag Cache.put should compute and store cid in metadata.""" + from artdag import Cache + import inspect + + source = inspect.getsource(Cache.put) + assert '"cid":' in source or "'cid':" in source, \ + "Cache.put should store cid in metadata" + + +class TestListByTypeReturnsEntries: + """Tests for list_by_type returning cached entries.""" + + def test_list_by_type_iterates_cache_entries(self) -> None: + """list_by_type should iterate self.cache.list_entries().""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'self.cache.list_entries()' in content, \ + "list_by_type should iterate cache entries" + + def test_list_by_type_filters_by_node_type(self) -> None: + """list_by_type should filter entries by node_type.""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'entry.node_type == node_type' in content, \ + "list_by_type should filter by node_type" + + def test_list_by_type_returns_node_id(self) -> None: + """list_by_type should return entry.node_id (IPFS CID).""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'cids.append(entry.node_id)' in content, \ + "list_by_type should append entry.node_id (IPFS CID)" + + def test_artdag_cache_list_entries_returns_all(self) -> None: + """artdag Cache.list_entries should return all entries.""" + from artdag import Cache + import inspect + + source = inspect.getsource(Cache.list_entries) + # Should return self._entries.values() + assert '_entries' in source, \ + "list_entries should access _entries dict" diff --git a/l1/tests/test_run_artifacts.py b/l1/tests/test_run_artifacts.py new file mode 100644 index 0000000..e6042c4 --- /dev/null +++ b/l1/tests/test_run_artifacts.py @@ -0,0 +1,111 @@ +""" +Tests for run artifacts data structure and template variables. + +Bug found 2026-01-12: runs/detail.html template expects artifact.cid +but the route provides artifact.hash, causing UndefinedError. +""" + +import pytest +from pathlib import Path + + +class TestCacheNotFoundTemplate: + """Tests for cache/not_found.html template.""" + + def test_template_uses_cid_not_content_hash(self) -> None: + """ + Regression test: not_found.html must use 'cid' variable. + + Bug: Template used 'content_hash' but route passes 'cid'. + Fix: Changed template to use 'cid'. + """ + template_path = Path('/home/giles/art/art-celery/app/templates/cache/not_found.html') + content = template_path.read_text() + + assert 'cid' in content, "Template should use 'cid' variable" + assert 'content_hash' not in content, \ + "Template should not use 'content_hash' (route passes 'cid')" + + def test_route_passes_cid_to_template(self) -> None: + """Verify route passes 'cid' variable to not_found template.""" + router_path = Path('/home/giles/art/art-celery/app/routers/cache.py') + content = router_path.read_text() + + # Find the render call for not_found.html + assert 'cid=cid' in content, \ + "Route should pass cid=cid to not_found.html template" + + +class TestInputPreviewsDataStructure: + """Tests for input_previews dict keys matching template expectations.""" + + def test_run_card_template_expects_inp_cid(self) -> None: + """Run card template uses inp.cid for input previews.""" + path = Path('/home/giles/art/art-celery/app/templates/runs/_run_card.html') + content = path.read_text() + + assert 'inp.cid' in content, \ + "Run card template should use inp.cid for input previews" + + def test_input_previews_use_cid_key(self) -> None: + """ + Regression test: input_previews must use 'cid' key not 'hash'. + + Bug: Router created input_previews with 'hash' key but template expected 'cid'. + """ + path = Path('/home/giles/art/art-celery/app/routers/runs.py') + content = path.read_text() + + # Should have: "cid": input_hash, not "hash": input_hash + assert '"cid": input_hash' in content, \ + "input_previews should use 'cid' key" + assert '"hash": input_hash' not in content, \ + "input_previews should not use 'hash' key (template expects 'cid')" + + +class TestArtifactDataStructure: + """Tests for artifact dict keys matching template expectations.""" + + def test_template_expects_cid_key(self) -> None: + """Template uses artifact.cid - verify this expectation.""" + template_path = Path('/home/giles/art/art-celery/app/templates/runs/detail.html') + content = template_path.read_text() + + # Template uses artifact.cid in multiple places + assert 'artifact.cid' in content, "Template should reference artifact.cid" + + def test_run_service_artifacts_have_cid_key(self) -> None: + """ + Regression test: get_run_artifacts must return dicts with 'cid' key. + + Bug: Service returned artifacts with 'hash' key but template expected 'cid'. + Fix: Changed service to use 'cid' key for consistency. + """ + # Read the run_service.py file and check it uses 'cid' not 'hash' + service_path = Path('/home/giles/art/art-celery/app/services/run_service.py') + content = service_path.read_text() + + # Find the get_artifact_info function and check it returns 'cid' + # The function should have: "cid": cid, not "hash": cid + assert '"cid": cid' in content or "'cid': cid" in content, \ + "get_run_artifacts should return artifacts with 'cid' key, not 'hash'" + + def test_router_artifacts_have_cid_key(self) -> None: + """ + Regression test: inline artifact creation in router must use 'cid' key. + + Bug: Router created artifacts with 'hash' key but template expected 'cid'. + """ + router_path = Path('/home/giles/art/art-celery/app/routers/runs.py') + content = router_path.read_text() + + # Check that artifacts.append uses 'cid' key + # Should have: "cid": output_cid, not "hash": output_cid + # Count occurrences of the patterns + hash_pattern_count = content.count('"hash": output_cid') + cid_pattern_count = content.count('"cid": output_cid') + + assert cid_pattern_count > 0, \ + "Router should create artifacts with 'cid' key" + assert hash_pattern_count == 0, \ + "Router should not use 'hash' key for artifacts (template expects 'cid')" diff --git a/l1/tests/test_xector.py b/l1/tests/test_xector.py new file mode 100644 index 0000000..0d006e5 --- /dev/null +++ b/l1/tests/test_xector.py @@ -0,0 +1,305 @@ +""" +Tests for xector primitives - parallel array operations. +""" + +import pytest +import numpy as np +from sexp_effects.primitive_libs.xector import ( + Xector, + xector_red, xector_green, xector_blue, xector_rgb, + xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm, + xector_dist_from_center, + alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp, + alpha_sin, alpha_cos, alpha_sq, + alpha_lt, alpha_gt, alpha_eq, + beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count, + xector_where, xector_fill, xector_zeros, xector_ones, + is_xector, +) + + +class TestXectorBasics: + """Test Xector class basic operations.""" + + def test_create_from_list(self): + x = Xector([1, 2, 3]) + assert len(x) == 3 + assert is_xector(x) + + def test_create_from_numpy(self): + arr = np.array([1.0, 2.0, 3.0]) + x = Xector(arr) + assert len(x) == 3 + np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32)) + + def test_implicit_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = a + b + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_implicit_mul(self): + a = Xector([1, 2, 3]) + b = Xector([2, 2, 2]) + c = a * b + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_scalar_broadcast(self): + a = Xector([1, 2, 3]) + c = a + 10 + np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13]) + + def test_scalar_broadcast_rmul(self): + a = Xector([1, 2, 3]) + c = 2 * a + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + +class TestAlphaOperations: + """Test α (element-wise) operations.""" + + def test_alpha_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = alpha_add(a, b) + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_alpha_add_multi(self): + a = Xector([1, 2, 3]) + b = Xector([1, 1, 1]) + c = Xector([10, 10, 10]) + d = alpha_add(a, b, c) + np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14]) + + def test_alpha_mul_scalar(self): + a = Xector([1, 2, 3]) + c = alpha_mul(a, 2) + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_alpha_sqrt(self): + a = Xector([1, 4, 9, 16]) + c = alpha_sqrt(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4]) + + def test_alpha_clamp(self): + a = Xector([-5, 0, 5, 10, 15]) + c = alpha_clamp(a, 0, 10) + np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10]) + + def test_alpha_sin_cos(self): + a = Xector([0, np.pi/2, np.pi]) + s = alpha_sin(a) + c = alpha_cos(a) + np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5) + np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5) + + def test_alpha_sq(self): + a = Xector([1, 2, 3, 4]) + c = alpha_sq(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16]) + + def test_alpha_comparison(self): + a = Xector([1, 2, 3, 4]) + b = Xector([2, 2, 2, 2]) + lt = alpha_lt(a, b) + gt = alpha_gt(a, b) + eq = alpha_eq(a, b) + np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False]) + np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True]) + np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False]) + + +class TestBetaOperations: + """Test β (reduction) operations.""" + + def test_beta_add(self): + a = Xector([1, 2, 3, 4]) + assert beta_add(a) == 10 + + def test_beta_mul(self): + a = Xector([1, 2, 3, 4]) + assert beta_mul(a) == 24 + + def test_beta_min_max(self): + a = Xector([3, 1, 4, 1, 5, 9, 2, 6]) + assert beta_min(a) == 1 + assert beta_max(a) == 9 + + def test_beta_mean(self): + a = Xector([1, 2, 3, 4]) + assert beta_mean(a) == 2.5 + + def test_beta_count(self): + a = Xector([1, 2, 3, 4, 5]) + assert beta_count(a) == 5 + + +class TestFrameConversion: + """Test frame/xector conversion.""" + + def test_extract_channels(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[255, 0, 0], [0, 255, 0]], + [[0, 0, 255], [128, 128, 128]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + assert len(r) == 4 + np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128]) + np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128]) + np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128]) + + def test_rgb_roundtrip(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[100, 150, 200], [50, 75, 100]], + [[200, 100, 50], [25, 50, 75]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + reconstructed = xector_rgb(r, g, b) + np.testing.assert_array_equal(reconstructed, frame) + + def test_modify_and_reconstruct(self): + frame = np.array([ + [[100, 100, 100], [100, 100, 100]], + [[100, 100, 100], [100, 100, 100]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + # Double red channel + r_doubled = r * 2 + + result = xector_rgb(r_doubled, g, b) + + # Red should be 200, others unchanged + assert result[0, 0, 0] == 200 + assert result[0, 0, 1] == 100 + assert result[0, 0, 2] == 100 + + +class TestCoordinates: + """Test coordinate generation.""" + + def test_x_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + x = xector_x_coords(frame) + # Should be [0,1,2, 0,1,2] (x coords repeated for each row) + np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2]) + + def test_y_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + y = xector_y_coords(frame) + # Should be [0,0,0, 1,1,1] (y coords for each pixel) + np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1]) + + def test_normalized_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_x_norm(frame) + y = xector_y_norm(frame) + + # x should go 0 to 1 across width + assert x.to_numpy()[0] == 0 + assert x.to_numpy()[2] == 1 + + # y should go 0 to 1 down height + assert y.to_numpy()[0] == 0 + assert y.to_numpy()[3] == 1 + + +class TestConditional: + """Test conditional operations.""" + + def test_where(self): + cond = Xector([True, False, True, False]) + true_val = Xector([1, 1, 1, 1]) + false_val = Xector([0, 0, 0, 0]) + + result = xector_where(cond, true_val, false_val) + np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0]) + + def test_where_with_comparison(self): + a = Xector([1, 5, 3, 7]) + threshold = 4 + + # Elements > 4 become 255, others become 0 + result = xector_where(alpha_gt(a, threshold), 255, 0) + np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255]) + + def test_fill(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_fill(42, frame) + assert len(x) == 6 + assert all(v == 42 for v in x.to_numpy()) + + def test_zeros_ones(self): + frame = np.zeros((2, 2, 3), dtype=np.uint8) + z = xector_zeros(frame) + o = xector_ones(frame) + + assert all(v == 0 for v in z.to_numpy()) + assert all(v == 1 for v in o.to_numpy()) + + +class TestInterpreterIntegration: + """Test xector operations through the interpreter.""" + + def test_xector_vignette_effect(self): + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + + # Load the xector vignette effect + interp.load_effect('sexp_effects/effects/xector_vignette.sexp') + + # Create a test frame (white) + frame = np.full((100, 100, 3), 255, dtype=np.uint8) + + # Run effect + result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {}) + + # Center should be brighter than corners + center = result[50, 50] + corner = result[0, 0] + + assert center.mean() > corner.mean(), "Center should be brighter than corners" + # Corners should be darkened + assert corner.mean() < 255, "Corners should be darkened" + + def test_implicit_elementwise(self): + """Test that regular + works element-wise on xectors.""" + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + # Load xector primitives + from sexp_effects.primitive_libs.xector import PRIMITIVES + for name, fn in PRIMITIVES.items(): + interp.global_env.set(name, fn) + + # Parse and eval a simple xector expression + from sexp_effects.parser import parse + expr = parse('(+ (red frame) 10)') + + # Create test frame + frame = np.full((2, 2, 3), 100, dtype=np.uint8) + interp.global_env.set('frame', frame) + + result = interp.eval(expr) + + # Should be a xector with values 110 + assert is_xector(result) + np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110]) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])