From 3ca1c14432b9fb9d968f62f9f37330efcc0030f2 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:04:48 +0000 Subject: [PATCH 01/24] Initial monorepo commit From 80c94ebea7c4c96a1068b0998bb0bedd73b1625b Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:07:19 +0000 Subject: [PATCH 02/24] Squashed 'l1/' content from commit 670aa58 git-subtree-dir: l1 git-subtree-split: 670aa582df99e87fca7c247b949baf452e8c234f --- .dockerignore | 22 + .env.example | 20 + .env.gpu | 11 + .gitea/workflows/ci.yml | 63 + .gitignore | 8 + Dockerfile | 31 + Dockerfile.gpu | 98 + README.md | 329 ++ app/__init__.py | 237 + app/config.py | 116 + app/dependencies.py | 186 + app/repositories/__init__.py | 10 + app/routers/__init__.py | 23 + app/routers/api.py | 257 + app/routers/auth.py | 165 + app/routers/cache.py | 515 ++ app/routers/effects.py | 415 ++ app/routers/fragments.py | 143 + app/routers/home.py | 253 + app/routers/inbox.py | 125 + app/routers/oembed.py | 74 + app/routers/recipes.py | 686 +++ app/routers/runs.py | 1704 ++++++ app/routers/storage.py | 264 + app/services/__init__.py | 15 + app/services/auth_service.py | 138 + app/services/cache_service.py | 618 +++ app/services/naming_service.py | 234 + app/services/recipe_service.py | 337 ++ app/services/run_service.py | 1001 ++++ app/services/storage_service.py | 232 + app/templates/404.html | 14 + app/templates/base.html | 46 + app/templates/cache/detail.html | 182 + app/templates/cache/media_list.html | 325 ++ app/templates/cache/not_found.html | 21 + app/templates/effects/detail.html | 203 + app/templates/effects/list.html | 200 + app/templates/fragments/link_card.html | 22 + app/templates/fragments/nav_item.html | 7 + app/templates/home.html | 51 + app/templates/recipes/detail.html | 265 + app/templates/recipes/list.html | 136 + app/templates/runs/_run_card.html | 89 + app/templates/runs/artifacts.html | 62 + app/templates/runs/detail.html | 1073 ++++ app/templates/runs/list.html | 45 + app/templates/runs/plan.html | 99 + app/templates/runs/plan_node.html | 99 + app/templates/storage/list.html | 90 + app/templates/storage/type.html | 152 + app/types.py | 197 + app/utils/__init__.py | 0 app/utils/http_signatures.py | 84 + artdag-client.tar.gz | Bin 0 -> 7982 bytes build-client.sh | 37 + cache_manager.py | 872 ++++ celery_app.py | 51 + check_redis.py | 12 + claiming.py | 421 ++ configs/audio-dizzy.sexp | 17 + configs/audio-halleluwah.sexp | 17 + configs/sources-default.sexp | 38 + configs/sources-woods-half.sexp | 19 + configs/sources-woods.sexp | 39 + database.py | 2144 ++++++++ deploy.sh | 19 + diagnose_gpu.py | 249 + docker-compose.gpu-dev.yml | 36 + docker-compose.yml | 191 + effects/quick_test_explicit.sexp | 150 + ipfs_client.py | 345 ++ path_registry.py | 477 ++ pyproject.toml | 51 + recipes/woods-lowres.sexp | 223 + recipes/woods-recipe-optimized.sexp | 211 + recipes/woods-recipe.sexp | 134 + requirements-dev.txt | 16 + requirements.txt | 21 + scripts/cloud-init-gpu.sh | 77 + scripts/deploy-to-gpu.sh | 51 + scripts/gpu-dev-deploy.sh | 34 + scripts/setup-gpu-droplet.sh | 108 + server.py | 26 + sexp_effects/__init__.py | 32 + sexp_effects/derived.sexp | 206 + sexp_effects/effects/ascii_art.sexp | 17 + sexp_effects/effects/ascii_art_fx.sexp | 52 + sexp_effects/effects/ascii_fx_zone.sexp | 102 + sexp_effects/effects/ascii_zones.sexp | 30 + sexp_effects/effects/blend.sexp | 31 + sexp_effects/effects/blend_multi.sexp | 58 + sexp_effects/effects/bloom.sexp | 16 + sexp_effects/effects/blur.sexp | 8 + sexp_effects/effects/brightness.sexp | 9 + sexp_effects/effects/cell_pattern.sexp | 65 + sexp_effects/effects/color-adjust.sexp | 13 + sexp_effects/effects/color_cycle.sexp | 13 + sexp_effects/effects/contrast.sexp | 9 + sexp_effects/effects/crt.sexp | 30 + sexp_effects/effects/datamosh.sexp | 14 + sexp_effects/effects/echo.sexp | 19 + sexp_effects/effects/edge_detect.sexp | 9 + sexp_effects/effects/emboss.sexp | 13 + sexp_effects/effects/film_grain.sexp | 19 + sexp_effects/effects/fisheye.sexp | 16 + sexp_effects/effects/flip.sexp | 16 + sexp_effects/effects/grayscale.sexp | 7 + sexp_effects/effects/halftone.sexp | 49 + sexp_effects/effects/hue_shift.sexp | 12 + sexp_effects/effects/invert.sexp | 9 + sexp_effects/effects/kaleidoscope.sexp | 20 + sexp_effects/effects/layer.sexp | 36 + sexp_effects/effects/mirror.sexp | 33 + sexp_effects/effects/mosaic.sexp | 30 + sexp_effects/effects/neon_glow.sexp | 23 + sexp_effects/effects/noise.sexp | 8 + sexp_effects/effects/outline.sexp | 24 + sexp_effects/effects/pixelate.sexp | 13 + sexp_effects/effects/pixelsort.sexp | 11 + sexp_effects/effects/posterize.sexp | 8 + sexp_effects/effects/resize-frame.sexp | 11 + sexp_effects/effects/rgb_split.sexp | 13 + sexp_effects/effects/ripple.sexp | 19 + sexp_effects/effects/rotate.sexp | 11 + sexp_effects/effects/saturation.sexp | 9 + sexp_effects/effects/scanlines.sexp | 15 + sexp_effects/effects/sepia.sexp | 7 + sexp_effects/effects/sharpen.sexp | 8 + sexp_effects/effects/strobe.sexp | 16 + sexp_effects/effects/swirl.sexp | 17 + sexp_effects/effects/threshold.sexp | 9 + sexp_effects/effects/tile_grid.sexp | 29 + sexp_effects/effects/trails.sexp | 20 + sexp_effects/effects/vignette.sexp | 23 + sexp_effects/effects/wave.sexp | 22 + .../effects/xector_feathered_blend.sexp | 44 + sexp_effects/effects/xector_grain.sexp | 34 + sexp_effects/effects/xector_inset_blend.sexp | 57 + sexp_effects/effects/xector_threshold.sexp | 27 + sexp_effects/effects/xector_vignette.sexp | 36 + sexp_effects/effects/zoom.sexp | 8 + sexp_effects/interpreter.py | 1085 ++++ sexp_effects/parser.py | 396 ++ sexp_effects/primitive_libs/__init__.py | 102 + sexp_effects/primitive_libs/arrays.py | 196 + sexp_effects/primitive_libs/ascii.py | 388 ++ sexp_effects/primitive_libs/blending.py | 116 + sexp_effects/primitive_libs/blending_gpu.py | 220 + sexp_effects/primitive_libs/color.py | 137 + sexp_effects/primitive_libs/color_ops.py | 109 + sexp_effects/primitive_libs/color_ops_gpu.py | 280 + sexp_effects/primitive_libs/core.py | 294 ++ sexp_effects/primitive_libs/drawing.py | 690 +++ sexp_effects/primitive_libs/filters.py | 119 + sexp_effects/primitive_libs/geometry.py | 143 + sexp_effects/primitive_libs/geometry_gpu.py | 403 ++ sexp_effects/primitive_libs/image.py | 150 + sexp_effects/primitive_libs/math.py | 164 + sexp_effects/primitive_libs/streaming.py | 593 +++ sexp_effects/primitive_libs/streaming_gpu.py | 1165 +++++ sexp_effects/primitive_libs/xector.py | 1382 +++++ sexp_effects/primitives.py | 3075 +++++++++++ sexp_effects/test_interpreter.py | 236 + sexp_effects/wgsl_compiler.py | 715 +++ storage_providers.py | 1009 ++++ streaming/__init__.py | 44 + streaming/audio.py | 486 ++ streaming/backends.py | 572 ++ streaming/compositor.py | 595 +++ streaming/demo.py | 125 + streaming/gpu_output.py | 538 ++ streaming/jax_typography.py | 1642 ++++++ streaming/jit_compiler.py | 531 ++ streaming/multi_res_output.py | 509 ++ streaming/output.py | 963 ++++ streaming/pipeline.py | 846 +++ streaming/recipe_adapter.py | 470 ++ streaming/recipe_executor.py | 415 ++ streaming/sexp_executor.py | 678 +++ streaming/sexp_interp.py | 376 ++ streaming/sexp_to_cuda.py | 706 +++ streaming/sexp_to_jax.py | 4628 +++++++++++++++++ streaming/sources.py | 281 + streaming/stream_sexp.py | 1098 ++++ streaming/stream_sexp_generic.py | 1739 +++++++ tasks/__init__.py | 13 + tasks/ipfs_upload.py | 93 + tasks/streaming.py | 724 +++ templates/crossfade-zoom.sexp | 25 + templates/cycle-crossfade.sexp | 65 + templates/process-pair.sexp | 112 + templates/scan-oscillating-spin.sexp | 28 + templates/scan-ripple-drops.sexp | 41 + templates/standard-effects.sexp | 22 + templates/standard-primitives.sexp | 14 + templates/stream-process-pair.sexp | 72 + test_autonomous.sexp | 36 + test_autonomous_prealloc.py | 75 + test_full_optimized.py | 161 + test_funky_text.py | 542 ++ test_fused_direct.py | 102 + test_fused_pipeline.sexp | 44 + test_heavy_fused.sexp | 39 + test_heavy_interpreted.sexp | 38 + test_interpreted_vs_fused.sexp | 37 + test_pil_options.py | 183 + test_styled_text.py | 176 + test_typography_fx.py | 486 ++ tests/__init__.py | 1 + tests/conftest.py | 93 + tests/test_auth.py | 42 + tests/test_cache_manager.py | 397 ++ tests/test_dag_transform.py | 492 ++ tests/test_effect_loading.py | 327 ++ tests/test_effects_web.py | 367 ++ tests/test_execute_recipe.py | 529 ++ tests/test_frame_compatibility.py | 185 + tests/test_item_visibility.py | 272 + tests/test_jax_pipeline_integration.py | 517 ++ tests/test_jax_primitives.py | 334 ++ tests/test_naming_service.py | 246 + tests/test_recipe_visibility.py | 150 + tests/test_run_artifacts.py | 111 + tests/test_xector.py | 305 ++ 225 files changed, 57298 insertions(+) create mode 100644 .dockerignore create mode 100644 .env.example create mode 100644 .env.gpu create mode 100644 .gitea/workflows/ci.yml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 Dockerfile.gpu create mode 100644 README.md create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/dependencies.py create mode 100644 app/repositories/__init__.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/api.py create mode 100644 app/routers/auth.py create mode 100644 app/routers/cache.py create mode 100644 app/routers/effects.py create mode 100644 app/routers/fragments.py create mode 100644 app/routers/home.py create mode 100644 app/routers/inbox.py create mode 100644 app/routers/oembed.py create mode 100644 app/routers/recipes.py create mode 100644 app/routers/runs.py create mode 100644 app/routers/storage.py create mode 100644 app/services/__init__.py create mode 100644 app/services/auth_service.py create mode 100644 app/services/cache_service.py create mode 100644 app/services/naming_service.py create mode 100644 app/services/recipe_service.py create mode 100644 app/services/run_service.py create mode 100644 app/services/storage_service.py create mode 100644 app/templates/404.html create mode 100644 app/templates/base.html create mode 100644 app/templates/cache/detail.html create mode 100644 app/templates/cache/media_list.html create mode 100644 app/templates/cache/not_found.html create mode 100644 app/templates/effects/detail.html create mode 100644 app/templates/effects/list.html create mode 100644 app/templates/fragments/link_card.html create mode 100644 app/templates/fragments/nav_item.html create mode 100644 app/templates/home.html create mode 100644 app/templates/recipes/detail.html create mode 100644 app/templates/recipes/list.html create mode 100644 app/templates/runs/_run_card.html create mode 100644 app/templates/runs/artifacts.html create mode 100644 app/templates/runs/detail.html create mode 100644 app/templates/runs/list.html create mode 100644 app/templates/runs/plan.html create mode 100644 app/templates/runs/plan_node.html create mode 100644 app/templates/storage/list.html create mode 100644 app/templates/storage/type.html create mode 100644 app/types.py create mode 100644 app/utils/__init__.py create mode 100644 app/utils/http_signatures.py create mode 100644 artdag-client.tar.gz create mode 100755 build-client.sh create mode 100644 cache_manager.py create mode 100644 celery_app.py create mode 100644 check_redis.py create mode 100644 claiming.py create mode 100644 configs/audio-dizzy.sexp create mode 100644 configs/audio-halleluwah.sexp create mode 100644 configs/sources-default.sexp create mode 100644 configs/sources-woods-half.sexp create mode 100644 configs/sources-woods.sexp create mode 100644 database.py create mode 100755 deploy.sh create mode 100755 diagnose_gpu.py create mode 100644 docker-compose.gpu-dev.yml create mode 100644 docker-compose.yml create mode 100644 effects/quick_test_explicit.sexp create mode 100644 ipfs_client.py create mode 100644 path_registry.py create mode 100644 pyproject.toml create mode 100644 recipes/woods-lowres.sexp create mode 100644 recipes/woods-recipe-optimized.sexp create mode 100644 recipes/woods-recipe.sexp create mode 100644 requirements-dev.txt create mode 100644 requirements.txt create mode 100644 scripts/cloud-init-gpu.sh create mode 100755 scripts/deploy-to-gpu.sh create mode 100755 scripts/gpu-dev-deploy.sh create mode 100755 scripts/setup-gpu-droplet.sh create mode 100644 server.py create mode 100644 sexp_effects/__init__.py create mode 100644 sexp_effects/derived.sexp create mode 100644 sexp_effects/effects/ascii_art.sexp create mode 100644 sexp_effects/effects/ascii_art_fx.sexp create mode 100644 sexp_effects/effects/ascii_fx_zone.sexp create mode 100644 sexp_effects/effects/ascii_zones.sexp create mode 100644 sexp_effects/effects/blend.sexp create mode 100644 sexp_effects/effects/blend_multi.sexp create mode 100644 sexp_effects/effects/bloom.sexp create mode 100644 sexp_effects/effects/blur.sexp create mode 100644 sexp_effects/effects/brightness.sexp create mode 100644 sexp_effects/effects/cell_pattern.sexp create mode 100644 sexp_effects/effects/color-adjust.sexp create mode 100644 sexp_effects/effects/color_cycle.sexp create mode 100644 sexp_effects/effects/contrast.sexp create mode 100644 sexp_effects/effects/crt.sexp create mode 100644 sexp_effects/effects/datamosh.sexp create mode 100644 sexp_effects/effects/echo.sexp create mode 100644 sexp_effects/effects/edge_detect.sexp create mode 100644 sexp_effects/effects/emboss.sexp create mode 100644 sexp_effects/effects/film_grain.sexp create mode 100644 sexp_effects/effects/fisheye.sexp create mode 100644 sexp_effects/effects/flip.sexp create mode 100644 sexp_effects/effects/grayscale.sexp create mode 100644 sexp_effects/effects/halftone.sexp create mode 100644 sexp_effects/effects/hue_shift.sexp create mode 100644 sexp_effects/effects/invert.sexp create mode 100644 sexp_effects/effects/kaleidoscope.sexp create mode 100644 sexp_effects/effects/layer.sexp create mode 100644 sexp_effects/effects/mirror.sexp create mode 100644 sexp_effects/effects/mosaic.sexp create mode 100644 sexp_effects/effects/neon_glow.sexp create mode 100644 sexp_effects/effects/noise.sexp create mode 100644 sexp_effects/effects/outline.sexp create mode 100644 sexp_effects/effects/pixelate.sexp create mode 100644 sexp_effects/effects/pixelsort.sexp create mode 100644 sexp_effects/effects/posterize.sexp create mode 100644 sexp_effects/effects/resize-frame.sexp create mode 100644 sexp_effects/effects/rgb_split.sexp create mode 100644 sexp_effects/effects/ripple.sexp create mode 100644 sexp_effects/effects/rotate.sexp create mode 100644 sexp_effects/effects/saturation.sexp create mode 100644 sexp_effects/effects/scanlines.sexp create mode 100644 sexp_effects/effects/sepia.sexp create mode 100644 sexp_effects/effects/sharpen.sexp create mode 100644 sexp_effects/effects/strobe.sexp create mode 100644 sexp_effects/effects/swirl.sexp create mode 100644 sexp_effects/effects/threshold.sexp create mode 100644 sexp_effects/effects/tile_grid.sexp create mode 100644 sexp_effects/effects/trails.sexp create mode 100644 sexp_effects/effects/vignette.sexp create mode 100644 sexp_effects/effects/wave.sexp create mode 100644 sexp_effects/effects/xector_feathered_blend.sexp create mode 100644 sexp_effects/effects/xector_grain.sexp create mode 100644 sexp_effects/effects/xector_inset_blend.sexp create mode 100644 sexp_effects/effects/xector_threshold.sexp create mode 100644 sexp_effects/effects/xector_vignette.sexp create mode 100644 sexp_effects/effects/zoom.sexp create mode 100644 sexp_effects/interpreter.py create mode 100644 sexp_effects/parser.py create mode 100644 sexp_effects/primitive_libs/__init__.py create mode 100644 sexp_effects/primitive_libs/arrays.py create mode 100644 sexp_effects/primitive_libs/ascii.py create mode 100644 sexp_effects/primitive_libs/blending.py create mode 100644 sexp_effects/primitive_libs/blending_gpu.py create mode 100644 sexp_effects/primitive_libs/color.py create mode 100644 sexp_effects/primitive_libs/color_ops.py create mode 100644 sexp_effects/primitive_libs/color_ops_gpu.py create mode 100644 sexp_effects/primitive_libs/core.py create mode 100644 sexp_effects/primitive_libs/drawing.py create mode 100644 sexp_effects/primitive_libs/filters.py create mode 100644 sexp_effects/primitive_libs/geometry.py create mode 100644 sexp_effects/primitive_libs/geometry_gpu.py create mode 100644 sexp_effects/primitive_libs/image.py create mode 100644 sexp_effects/primitive_libs/math.py create mode 100644 sexp_effects/primitive_libs/streaming.py create mode 100644 sexp_effects/primitive_libs/streaming_gpu.py create mode 100644 sexp_effects/primitive_libs/xector.py create mode 100644 sexp_effects/primitives.py create mode 100644 sexp_effects/test_interpreter.py create mode 100644 sexp_effects/wgsl_compiler.py create mode 100644 storage_providers.py create mode 100644 streaming/__init__.py create mode 100644 streaming/audio.py create mode 100644 streaming/backends.py create mode 100644 streaming/compositor.py create mode 100644 streaming/demo.py create mode 100644 streaming/gpu_output.py create mode 100644 streaming/jax_typography.py create mode 100644 streaming/jit_compiler.py create mode 100644 streaming/multi_res_output.py create mode 100644 streaming/output.py create mode 100644 streaming/pipeline.py create mode 100644 streaming/recipe_adapter.py create mode 100644 streaming/recipe_executor.py create mode 100644 streaming/sexp_executor.py create mode 100644 streaming/sexp_interp.py create mode 100644 streaming/sexp_to_cuda.py create mode 100644 streaming/sexp_to_jax.py create mode 100644 streaming/sources.py create mode 100644 streaming/stream_sexp.py create mode 100644 streaming/stream_sexp_generic.py create mode 100644 tasks/__init__.py create mode 100644 tasks/ipfs_upload.py create mode 100644 tasks/streaming.py create mode 100644 templates/crossfade-zoom.sexp create mode 100644 templates/cycle-crossfade.sexp create mode 100644 templates/process-pair.sexp create mode 100644 templates/scan-oscillating-spin.sexp create mode 100644 templates/scan-ripple-drops.sexp create mode 100644 templates/standard-effects.sexp create mode 100644 templates/standard-primitives.sexp create mode 100644 templates/stream-process-pair.sexp create mode 100644 test_autonomous.sexp create mode 100644 test_autonomous_prealloc.py create mode 100644 test_full_optimized.py create mode 100644 test_funky_text.py create mode 100644 test_fused_direct.py create mode 100644 test_fused_pipeline.sexp create mode 100644 test_heavy_fused.sexp create mode 100644 test_heavy_interpreted.sexp create mode 100644 test_interpreted_vs_fused.sexp create mode 100644 test_pil_options.py create mode 100644 test_styled_text.py create mode 100644 test_typography_fx.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_cache_manager.py create mode 100644 tests/test_dag_transform.py create mode 100644 tests/test_effect_loading.py create mode 100644 tests/test_effects_web.py create mode 100644 tests/test_execute_recipe.py create mode 100644 tests/test_frame_compatibility.py create mode 100644 tests/test_item_visibility.py create mode 100644 tests/test_jax_pipeline_integration.py create mode 100644 tests/test_jax_primitives.py create mode 100644 tests/test_naming_service.py create mode 100644 tests/test_recipe_visibility.py create mode 100644 tests/test_run_artifacts.py create mode 100644 tests/test_xector.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f48a442 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,22 @@ +# Don't copy local clones - Dockerfile will clone fresh +artdag-effects/ + +# Python cache +__pycache__/ +*.py[cod] +*.egg-info/ +.pytest_cache/ + +# Virtual environments +.venv/ +venv/ + +# Local env +.env + +# Git +.git/ + +# IDE +.vscode/ +.idea/ diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..0b0e063 --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# L1 Server Configuration + +# PostgreSQL password (REQUIRED - no default) +POSTGRES_PASSWORD=changeme-generate-with-openssl-rand-hex-16 + +# Admin token for purge operations (REQUIRED - no default) +# Generate with: openssl rand -hex 32 +ADMIN_TOKEN=changeme-generate-with-openssl-rand-hex-32 + +# L1 host IP/hostname for GPU worker cross-VPC access +L1_HOST=your-l1-server-ip + +# This L1 server's public URL (sent to L2 when publishing) +L1_PUBLIC_URL=https://l1.artdag.rose-ash.com + +# L2 server URL (for authentication and publishing) +L2_SERVER=https://artdag.rose-ash.com + +# L2 domain for ActivityPub actor IDs (e.g., @user@domain) +L2_DOMAIN=artdag.rose-ash.com diff --git a/.env.gpu b/.env.gpu new file mode 100644 index 0000000..9253dcd --- /dev/null +++ b/.env.gpu @@ -0,0 +1,11 @@ +# GPU worker env - connects to L1 host via public IP (cross-VPC) +REDIS_URL=redis://138.68.142.139:16379/5 +DATABASE_URL=postgresql://artdag:f960bcc61d8b2155a1d57f7dd72c1c58@138.68.142.139:15432/artdag +IPFS_API=/ip4/138.68.142.139/tcp/15001 +IPFS_GATEWAYS=https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link +IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs +CACHE_DIR=/data/cache +C_FORCE_ROOT=true +ARTDAG_CLUSTER_KEY= +NVIDIA_VISIBLE_DEVICES=all +STREAMING_GPU_PERSIST=0 diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..a79f66e --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,63 @@ +name: Build and Deploy + +on: + push: + +env: + REGISTRY: registry.rose-ash.com:5000 + IMAGE_CPU: celery-l1-server + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install tools + run: | + apt-get update && apt-get install -y --no-install-recommends openssh-client + + - name: Set up SSH + env: + SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + mkdir -p ~/.ssh + echo "$SSH_KEY" > ~/.ssh/id_rsa + chmod 600 ~/.ssh/id_rsa + ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true + + - name: Pull latest code on server + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + BRANCH: ${{ github.ref_name }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + git fetch origin $BRANCH + git checkout $BRANCH + git reset --hard origin/$BRANCH + " + + - name: Build and push image + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + docker build --build-arg CACHEBUST=\$(date +%s) -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} . + docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest + docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} + " + + - name: Deploy stack + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/celery + docker stack deploy -c docker-compose.yml celery + echo 'Waiting for services to update...' + sleep 10 + docker stack services celery + " diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ca2eb4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.py[cod] +.pytest_cache/ +*.egg-info/ +.venv/ +venv/ +.env +artdag-effects/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..90a770d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,31 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install git and ffmpeg (for video transcoding) +RUN apt-get update && apt-get install -y --no-install-recommends git ffmpeg && rm -rf /var/lib/apt/lists/* + +# Install dependencies +COPY requirements.txt . +ARG CACHEBUST=1 +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application +COPY . . + +# Clone effects repo +RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects + +# Build client tarball for download +RUN ./build-client.sh + +# Create cache directory +RUN mkdir -p /data/cache + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV EFFECTS_PATH=/app/artdag-effects +ENV PYTHONPATH=/app + +# Default command runs the server +CMD ["python", "server.py"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 0000000..967f788 --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,98 @@ +# GPU-enabled worker image +# Multi-stage build: use devel image for compiling, runtime for final image + +# Stage 1: Build decord with CUDA +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 AS builder + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 \ + python3.11-venv \ + python3.11-dev \ + python3-pip \ + git \ + cmake \ + build-essential \ + pkg-config \ + libavcodec-dev \ + libavformat-dev \ + libavutil-dev \ + libavdevice-dev \ + libavfilter-dev \ + libswresample-dev \ + libswscale-dev \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3.11 /usr/bin/python3 \ + && ln -sf /usr/bin/python3 /usr/bin/python + +# Download Video Codec SDK headers for NVDEC/NVCUVID +RUN git clone https://github.com/FFmpeg/nv-codec-headers.git /tmp/nv-codec-headers && \ + cd /tmp/nv-codec-headers && make install && rm -rf /tmp/nv-codec-headers + +# Create stub for libnvcuvid (real library comes from driver at runtime) +RUN echo 'void* __nvcuvid_stub__;' | gcc -shared -x c - -o /usr/local/cuda/lib64/libnvcuvid.so + +# Build decord with CUDA support +RUN git clone --recursive https://github.com/dmlc/decord /tmp/decord && \ + cd /tmp/decord && \ + mkdir build && cd build && \ + cmake .. -DUSE_CUDA=ON -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES="70;75;80;86;89;90" && \ + make -j$(nproc) && \ + cd ../python && pip install --target=/decord-install . + +# Stage 2: Runtime image +FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 + +WORKDIR /app + +# Install Python 3.11 and system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 \ + python3.11-venv \ + python3-pip \ + git \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3.11 /usr/bin/python3 \ + && ln -sf /usr/bin/python3 /usr/bin/python + +# Upgrade pip +RUN python3 -m pip install --upgrade pip + +# Install CPU dependencies first +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Install GPU-specific dependencies (CuPy for CUDA 12.x) +RUN pip install --no-cache-dir cupy-cuda12x + +# Install PyNvVideoCodec for zero-copy GPU encoding +RUN pip install --no-cache-dir PyNvVideoCodec + +# Copy decord from builder stage +COPY --from=builder /decord-install /usr/local/lib/python3.11/dist-packages/ +COPY --from=builder /tmp/decord/build/libdecord.so /usr/local/lib/ +RUN ldconfig + +# Clone effects repo (before COPY so it gets cached) +RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects + +# Copy application (this invalidates cache for any code change) +COPY . . + +# Create cache directory +RUN mkdir -p /data/cache + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV EFFECTS_PATH=/app/artdag-effects +ENV PYTHONPATH=/app +# GPU persistence enabled - frames stay on GPU throughout pipeline +ENV STREAMING_GPU_PERSIST=1 +# Preload libnvcuvid for decord NVDEC GPU decode +ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnvcuvid.so +# Use cluster's public IPFS gateway for HLS segment URLs +ENV IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs + +# Default command runs celery worker +CMD ["celery", "-A", "celery_app", "worker", "--loglevel=info", "-E", "-Q", "gpu,celery"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..d387437 --- /dev/null +++ b/README.md @@ -0,0 +1,329 @@ +# Art DAG L1 Server + +L1 rendering server for the Art DAG system. Manages distributed rendering jobs via Celery workers with content-addressable caching and optional IPFS integration. + +## Features + +- **3-Phase Execution**: Analyze → Plan → Execute pipeline for recipe-based rendering +- **Content-Addressable Caching**: IPFS CIDs with deduplication +- **IPFS Integration**: Optional IPFS-primary mode for distributed storage +- **Storage Providers**: S3, IPFS, and local storage backends +- **DAG Visualization**: Interactive graph visualization of execution plans +- **SPA-Style Navigation**: Smooth URL-based navigation without full page reloads +- **L2 Federation**: Publish outputs to ActivityPub registry + +## Dependencies + +- **artdag** (GitHub): Core DAG execution engine +- **artdag-effects** (rose-ash): Effect implementations +- **artdag-common**: Shared templates and middleware +- **Redis**: Message broker, result backend, and run persistence +- **PostgreSQL**: Metadata storage +- **IPFS** (optional): Distributed content storage + +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Start Redis +redis-server + +# Start a worker +celery -A celery_app worker --loglevel=info -E + +# Start the L1 server +python server.py +``` + +## Docker Swarm Deployment + +```bash +docker stack deploy -c docker-compose.yml artdag +``` + +The stack includes: +- **redis**: Message broker (Redis 7) +- **postgres**: Metadata database (PostgreSQL 16) +- **ipfs**: IPFS node (Kubo) +- **l1-server**: FastAPI web server +- **l1-worker**: Celery workers (2 replicas) +- **flower**: Celery task monitoring + +## Configuration + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `HOST` | `0.0.0.0` | Server bind address | +| `PORT` | `8000` | Server port | +| `REDIS_URL` | `redis://localhost:6379/5` | Redis connection | +| `DATABASE_URL` | **(required)** | PostgreSQL connection | +| `CACHE_DIR` | `~/.artdag/cache` | Local cache directory | +| `IPFS_API` | `/dns/localhost/tcp/5001` | IPFS API multiaddr | +| `IPFS_GATEWAY_URL` | `https://ipfs.io/ipfs` | Public IPFS gateway | +| `IPFS_PRIMARY` | `false` | Enable IPFS-primary mode | +| `L1_PUBLIC_URL` | `http://localhost:8100` | Public URL for redirects | +| `L2_SERVER` | - | L2 ActivityPub server URL | +| `L2_DOMAIN` | - | L2 domain for federation | +| `ARTDAG_CLUSTER_KEY` | - | Cluster key for trust domains | + +### IPFS-Primary Mode + +When `IPFS_PRIMARY=true`, all content is stored on IPFS: +- Input files are added to IPFS on upload +- Analysis results stored as JSON on IPFS +- Execution plans stored on IPFS +- Step outputs pinned to IPFS +- Local cache becomes a read-through cache + +This enables distributed execution across multiple L1 nodes sharing the same IPFS network. + +## Web UI + +| Path | Description | +|------|-------------| +| `/` | Home page with server info | +| `/runs` | View and manage rendering runs | +| `/run/{id}` | Run detail with tabs: Plan, Analysis, Artifacts | +| `/run/{id}/plan` | Interactive DAG visualization | +| `/run/{id}/analysis` | Audio/video analysis data | +| `/run/{id}/artifacts` | Cached step outputs | +| `/recipes` | Browse and run available recipes | +| `/recipe/{id}` | Recipe detail page | +| `/recipe/{id}/dag` | Recipe DAG visualization | +| `/media` | Browse cached media files | +| `/storage` | Manage storage providers | +| `/auth` | Receive auth token from L2 | +| `/logout` | Log out | +| `/download/client` | Download CLI client | + +## API Reference + +Interactive docs: http://localhost:8100/docs + +### Runs + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/runs` | Start a rendering run | +| GET | `/runs` | List all runs (paginated) | +| GET | `/runs/{run_id}` | Get run status | +| DELETE | `/runs/{run_id}` | Delete a run | +| GET | `/api/run/{run_id}` | Get run as JSON | +| GET | `/api/run/{run_id}/plan` | Get execution plan JSON | +| GET | `/api/run/{run_id}/analysis` | Get analysis data JSON | + +### Recipes + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/recipes/upload` | Upload recipe YAML | +| GET | `/recipes` | List recipes (paginated) | +| GET | `/recipes/{recipe_id}` | Get recipe details | +| DELETE | `/recipes/{recipe_id}` | Delete recipe | +| POST | `/recipes/{recipe_id}/run` | Execute recipe | + +### Cache + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/cache/{cid}` | Get cached content (with preview) | +| GET | `/cache/{cid}/raw` | Download raw content | +| GET | `/cache/{cid}/mp4` | Get MP4 video | +| GET | `/cache/{cid}/meta` | Get content metadata | +| PATCH | `/cache/{cid}/meta` | Update metadata | +| POST | `/cache/{cid}/publish` | Publish to L2 | +| DELETE | `/cache/{cid}` | Delete from cache | +| POST | `/cache/import?path=` | Import local file | +| POST | `/cache/upload` | Upload file | +| GET | `/media` | Browse media gallery | + +### IPFS + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/ipfs/{cid}` | Redirect to IPFS gateway | +| GET | `/ipfs/{cid}/raw` | Fetch raw content from IPFS | + +### Storage Providers + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/storage` | List storage providers | +| POST | `/storage` | Add provider (form) | +| POST | `/storage/add` | Add provider (JSON) | +| GET | `/storage/{id}` | Get provider details | +| PATCH | `/storage/{id}` | Update provider | +| DELETE | `/storage/{id}` | Delete provider | +| POST | `/storage/{id}/test` | Test connection | +| GET | `/storage/type/{type}` | Get form for provider type | + +### 3-Phase API + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/api/plan` | Generate execution plan | +| POST | `/api/execute` | Execute a plan | +| POST | `/api/run-recipe` | Full pipeline (analyze+plan+execute) | + +### Authentication + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/auth` | Receive auth token from L2 | +| GET | `/logout` | Log out | +| POST | `/auth/revoke` | Revoke a specific token | +| POST | `/auth/revoke-user` | Revoke all user tokens | + +## 3-Phase Execution + +Recipes are executed in three phases: + +### Phase 1: Analyze +Extract features from input files: +- **Audio/Video**: Tempo, beat times, energy levels +- Results cached by CID + +### Phase 2: Plan +Generate an execution plan: +- Parse recipe YAML +- Resolve dependencies between steps +- Compute cache IDs for each step +- Skip already-cached steps + +### Phase 3: Execute +Run the plan level by level: +- Steps at each level run in parallel +- Results cached with content-addressable hashes +- Progress tracked in Redis + +## Recipe Format + +Recipes define reusable DAG pipelines: + +```yaml +name: beat-sync +version: "1.0" +description: "Synchronize video to audio beats" + +inputs: + video: + type: video + description: "Source video" + audio: + type: audio + description: "Audio track" + +steps: + - id: analyze_audio + type: ANALYZE + inputs: [audio] + config: + features: [beats, energy] + + - id: sync_video + type: BEAT_SYNC + inputs: [video, analyze_audio] + config: + mode: stretch + +output: sync_video +``` + +## Storage + +### Local Cache +- Location: `~/.artdag/cache/` (or `CACHE_DIR`) +- Content-addressed by IPFS CID +- Subdirectories: `plans/`, `analysis/` + +### Redis +- Database 5 (configurable via `REDIS_URL`) +- Keys: + - `artdag:run:*` - Run state + - `artdag:recipe:*` - Recipe definitions + - `artdag:revoked:*` - Token revocation + - `artdag:user_tokens:*` - User token tracking + +### PostgreSQL +- Content metadata +- Storage provider configurations +- Provenance records + +## Authentication + +L1 servers authenticate via L2 (ActivityPub registry). No shared secrets required. + +### Flow +1. User clicks "Attach" on L2's Renderers page +2. L2 creates a scoped token bound to this L1 +3. User redirected to L1's `/auth?auth_token=...` +4. L1 calls L2's `/auth/verify` to validate +5. L1 sets local cookie and records token + +### Token Revocation +- Tokens tracked per-user in Redis +- L2 calls `/auth/revoke-user` on logout +- Revoked hashes stored with 30-day expiry +- Every request checks revocation list + +## CLI Usage + +```bash +# Quick render (effect mode) +python render.py dog cat --sync + +# Submit async +python render.py dog cat + +# Run a recipe +curl -X POST http://localhost:8100/recipes/beat-sync/run \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{"inputs": {"video": "abc123...", "audio": "def456..."}}' +``` + +## Architecture + +``` +L1 Server (FastAPI) + │ + ├── Web UI (Jinja2 + HTMX + Tailwind) + │ + ├── POST /runs → Celery tasks + │ │ + │ └── celery_app.py + │ ├── tasks/analyze.py (Phase 1) + │ ├── tasks/execute.py (Phase 3 steps) + │ └── tasks/orchestrate.py (Full pipeline) + │ + ├── cache_manager.py + │ │ + │ ├── Local filesystem (CACHE_DIR) + │ ├── IPFS (ipfs_client.py) + │ └── S3/Storage providers + │ + └── database.py (PostgreSQL metadata) +``` + +## Provenance + +Every render produces a provenance record: + +```json +{ + "task_id": "celery-task-uuid", + "rendered_at": "2026-01-07T...", + "rendered_by": "@giles@artdag.rose-ash.com", + "output": {"name": "...", "cid": "Qm..."}, + "inputs": [...], + "effects": [...], + "infrastructure": { + "software": {"name": "infra:artdag", "cid": "Qm..."}, + "hardware": {"name": "infra:giles-hp", "cid": "Qm..."} + } +} +``` diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..408983b --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,237 @@ +""" +Art-DAG L1 Server Application Factory. + +Creates and configures the FastAPI application with all routers and middleware. +""" + +import secrets +import time +from pathlib import Path +from urllib.parse import quote + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles + +from artdag_common import create_jinja_env +from artdag_common.middleware.auth import get_user_from_cookie + +from .config import settings + +# Paths that should never trigger a silent auth check +_SKIP_PREFIXES = ("/auth/", "/static/", "/api/", "/ipfs/", "/download/", "/inbox", "/health", "/internal/", "/oembed") +_SILENT_CHECK_COOLDOWN = 300 # 5 minutes +_DEVICE_COOKIE = "artdag_did" +_DEVICE_COOKIE_MAX_AGE = 30 * 24 * 3600 # 30 days + +# Derive external base URL from oauth_redirect_uri (e.g. https://celery-artdag.rose-ash.com) +_EXTERNAL_BASE = settings.oauth_redirect_uri.rsplit("/auth/callback", 1)[0] + + +def _external_url(request: Request) -> str: + """Build external URL from request path + query, using configured base domain.""" + url = f"{_EXTERNAL_BASE}{request.url.path}" + if request.url.query: + url += f"?{request.url.query}" + return url + + +def create_app() -> FastAPI: + """ + Create and configure the L1 FastAPI application. + + Returns: + Configured FastAPI instance + """ + app = FastAPI( + title="Art-DAG L1 Server", + description="Content-addressed media processing with distributed execution", + version="1.0.0", + ) + + # Database lifecycle events + from database import init_db, close_db + + @app.on_event("startup") + async def startup(): + await init_db() + + @app.on_event("shutdown") + async def shutdown(): + await close_db() + + # Silent auth check — auto-login via prompt=none OAuth + # NOTE: registered BEFORE device_id so device_id is outermost (runs first) + @app.middleware("http") + async def silent_auth_check(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _SKIP_PREFIXES) + or request.headers.get("hx-request") # skip HTMX + ): + return await call_next(request) + + # Already logged in — but verify account hasn't logged out + if get_user_from_cookie(request): + device_id = getattr(request.state, "device_id", None) + if device_id: + try: + from .dependencies import get_redis_client + r = get_redis_client() + if not r.get(f"did_auth:{device_id}"): + # Account logged out — clear our cookie + response = await call_next(request) + response.delete_cookie("artdag_session") + response.delete_cookie("pnone_at") + return response + except Exception: + pass + return await call_next(request) + + # Check cooldown — don't re-check within 5 minutes + pnone_at = request.cookies.get("pnone_at") + if pnone_at: + try: + pnone_ts = float(pnone_at) + if (time.time() - pnone_ts) < _SILENT_CHECK_COOLDOWN: + # But first check if account signalled a login via inbox delivery + device_id = getattr(request.state, "device_id", None) + if device_id: + try: + from .dependencies import get_redis_client + r = get_redis_client() + auth_ts = r.get(f"did_auth:{device_id}") + if auth_ts and float(auth_ts) > pnone_ts: + # Login happened since our last check — retry + current_url = _external_url(request) + return RedirectResponse( + url=f"/auth/login?prompt=none&next={quote(current_url, safe='')}", + status_code=302, + ) + except Exception: + pass + return await call_next(request) + except (ValueError, TypeError): + pass + + # Redirect to silent OAuth check + current_url = _external_url(request) + return RedirectResponse( + url=f"/auth/login?prompt=none&next={quote(current_url, safe='')}", + status_code=302, + ) + + # Device ID middleware — track browser identity across domains + # Registered AFTER silent_auth_check so it's outermost (always runs) + @app.middleware("http") + async def device_id_middleware(request: Request, call_next): + did = request.cookies.get(_DEVICE_COOKIE) + if did: + request.state.device_id = did + request.state._new_device_id = False + else: + request.state.device_id = secrets.token_urlsafe(32) + request.state._new_device_id = True + + response = await call_next(request) + + if getattr(request.state, "_new_device_id", False): + response.set_cookie( + key=_DEVICE_COOKIE, + value=request.state.device_id, + max_age=_DEVICE_COOKIE_MAX_AGE, + httponly=True, + samesite="lax", + secure=True, + ) + return response + + # Coop fragment pre-fetch — inject nav-tree, auth-menu, cart-mini into + # request.state for full-page HTML renders. Skips HTMX, API, and + # internal paths. Failures are silent (fragments default to ""). + _FRAG_SKIP = ("/auth/", "/api/", "/internal/", "/health", "/oembed", + "/ipfs/", "/download/", "/inbox", "/static/") + + @app.middleware("http") + async def coop_fragments_middleware(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _FRAG_SKIP) + or request.headers.get("hx-request") + or request.headers.get(fragments.FRAGMENT_HEADER) + ): + request.state.nav_tree_html = "" + request.state.auth_menu_html = "" + request.state.cart_mini_html = "" + return await call_next(request) + + from artdag_common.fragments import fetch_fragments as _fetch_frags + + user = get_user_from_cookie(request) + auth_params = {"email": user.email} if user and user.email else {} + nav_params = {"app_name": "artdag", "path": path} + + try: + nav_tree_html, auth_menu_html, cart_mini_html = await _fetch_frags([ + ("blog", "nav-tree", nav_params), + ("account", "auth-menu", auth_params or None), + ("cart", "cart-mini", None), + ]) + except Exception: + nav_tree_html = auth_menu_html = cart_mini_html = "" + + request.state.nav_tree_html = nav_tree_html + request.state.auth_menu_html = auth_menu_html + request.state.cart_mini_html = cart_mini_html + + return await call_next(request) + + # Initialize Jinja2 templates + template_dir = Path(__file__).parent / "templates" + app.state.templates = create_jinja_env(template_dir) + + # Custom 404 handler + @app.exception_handler(404) + async def not_found_handler(request: Request, exc): + from artdag_common.middleware import wants_html + if wants_html(request): + from artdag_common import render + return render(app.state.templates, "404.html", request, + user=None, + status_code=404, + ) + return JSONResponse({"detail": "Not found"}, status_code=404) + + # Include routers + from .routers import auth, storage, api, recipes, cache, runs, home, effects, inbox, fragments, oembed + + # Home and auth routers (root level) + app.include_router(home.router, tags=["home"]) + app.include_router(auth.router, prefix="/auth", tags=["auth"]) + app.include_router(inbox.router, tags=["inbox"]) + app.include_router(fragments.router, tags=["fragments"]) + app.include_router(oembed.router, tags=["oembed"]) + + # Feature routers + app.include_router(storage.router, prefix="/storage", tags=["storage"]) + app.include_router(api.router, prefix="/api", tags=["api"]) + + # Runs and recipes routers + app.include_router(runs.router, prefix="/runs", tags=["runs"]) + app.include_router(recipes.router, prefix="/recipes", tags=["recipes"]) + + # Cache router - handles /cache and /media + app.include_router(cache.router, prefix="/cache", tags=["cache"]) + # Also mount cache router at /media for convenience + app.include_router(cache.router, prefix="/media", tags=["media"]) + + # Effects router + app.include_router(effects.router, prefix="/effects", tags=["effects"]) + + return app + + +# Create the default app instance +app = create_app() diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..8aa94d7 --- /dev/null +++ b/app/config.py @@ -0,0 +1,116 @@ +""" +L1 Server Configuration. + +Environment-based configuration with sensible defaults. +All config should go through this module - no direct os.environ calls elsewhere. +""" + +import os +import sys +from pathlib import Path +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Settings: + """Application settings loaded from environment.""" + + # Server + host: str = field(default_factory=lambda: os.environ.get("HOST", "0.0.0.0")) + port: int = field(default_factory=lambda: int(os.environ.get("PORT", "8000"))) + debug: bool = field(default_factory=lambda: os.environ.get("DEBUG", "").lower() == "true") + + # Cache (use /data/cache in Docker via env var, ~/.artdag/cache locally) + cache_dir: Path = field( + default_factory=lambda: Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache"))) + ) + + # Redis + redis_url: str = field( + default_factory=lambda: os.environ.get("REDIS_URL", "redis://localhost:6379/5") + ) + + # Database + database_url: str = field( + default_factory=lambda: os.environ.get("DATABASE_URL", "") + ) + + # IPFS + ipfs_api: str = field( + default_factory=lambda: os.environ.get("IPFS_API", "/dns/localhost/tcp/5001") + ) + ipfs_gateway_url: str = field( + default_factory=lambda: os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ) + + # OAuth SSO (replaces L2 auth) + oauth_authorize_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_AUTHORIZE_URL", "https://account.rose-ash.com/auth/oauth/authorize") + ) + oauth_token_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_TOKEN_URL", "https://account.rose-ash.com/auth/oauth/token") + ) + oauth_client_id: str = field( + default_factory=lambda: os.environ.get("OAUTH_CLIENT_ID", "artdag") + ) + oauth_redirect_uri: str = field( + default_factory=lambda: os.environ.get("OAUTH_REDIRECT_URI", "https://celery-artdag.rose-ash.com/auth/callback") + ) + oauth_logout_url: str = field( + default_factory=lambda: os.environ.get("OAUTH_LOGOUT_URL", "https://account.rose-ash.com/auth/sso-logout/") + ) + secret_key: str = field( + default_factory=lambda: os.environ.get("SECRET_KEY", "change-me-in-production") + ) + + # GPU/Streaming settings + streaming_gpu_persist: bool = field( + default_factory=lambda: os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" + ) + ipfs_gateways: str = field( + default_factory=lambda: os.environ.get( + "IPFS_GATEWAYS", "https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link" + ) + ) + + # Derived paths + @property + def plan_cache_dir(self) -> Path: + return self.cache_dir / "plans" + + @property + def analysis_cache_dir(self) -> Path: + return self.cache_dir / "analysis" + + def ensure_dirs(self) -> None: + """Create required directories.""" + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.plan_cache_dir.mkdir(parents=True, exist_ok=True) + self.analysis_cache_dir.mkdir(parents=True, exist_ok=True) + + + def log_config(self, logger=None) -> None: + """Log all configuration values for debugging.""" + output = logger.info if logger else lambda x: print(x, file=sys.stderr) + output("=" * 60) + output("CONFIGURATION") + output("=" * 60) + output(f" cache_dir: {self.cache_dir}") + output(f" redis_url: {self.redis_url}") + output(f" database_url: {self.database_url[:50]}...") + output(f" ipfs_api: {self.ipfs_api}") + output(f" ipfs_gateway_url: {self.ipfs_gateway_url}") + output(f" ipfs_gateways: {self.ipfs_gateways[:50]}...") + output(f" streaming_gpu_persist: {self.streaming_gpu_persist}") + output(f" oauth_client_id: {self.oauth_client_id}") + output(f" oauth_authorize_url: {self.oauth_authorize_url}") + output("=" * 60) + + +# Singleton settings instance +settings = Settings() + +# Log config on import if DEBUG or SHOW_CONFIG is set +if os.environ.get("DEBUG") or os.environ.get("SHOW_CONFIG"): + settings.log_config() diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000..fc59947 --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,186 @@ +""" +FastAPI dependency injection container. + +Provides shared resources and services to route handlers. +""" + +from functools import lru_cache +from typing import Optional +import asyncio + +from fastapi import Request, Depends, HTTPException +from jinja2 import Environment + +from artdag_common.middleware.auth import UserContext, get_user_from_cookie, get_user_from_header + +from .config import settings + + +# Lazy imports to avoid circular dependencies +_redis_client = None +_cache_manager = None +_database = None + + +def get_redis_client(): + """Get the Redis client singleton.""" + global _redis_client + if _redis_client is None: + import redis + _redis_client = redis.from_url(settings.redis_url, decode_responses=True) + return _redis_client + + +def get_cache_manager(): + """Get the cache manager singleton.""" + global _cache_manager + if _cache_manager is None: + from cache_manager import get_cache_manager as _get_cache_manager + _cache_manager = _get_cache_manager() + return _cache_manager + + +def get_database(): + """Get the database singleton.""" + global _database + if _database is None: + import database + _database = database + return _database + + +def get_templates(request: Request) -> Environment: + """Get the Jinja2 environment from app state.""" + return request.app.state.templates + + +async def get_current_user(request: Request) -> Optional[UserContext]: + """ + Get the current user from request (cookie or header). + + This is a permissive dependency - returns None if not authenticated. + Use require_auth for routes that require authentication. + """ + # Try header first (API clients) + ctx = get_user_from_header(request) + if ctx: + return ctx + + # Fall back to cookie (browser) + return get_user_from_cookie(request) + + +async def require_auth(request: Request) -> UserContext: + """ + Require authentication for a route. + + Raises: + HTTPException 401 if not authenticated + HTTPException 302 redirect to login for HTML requests + """ + ctx = await get_current_user(request) + if ctx is None: + # Check if HTML request for redirect + accept = request.headers.get("accept", "") + if "text/html" in accept: + raise HTTPException( + status_code=302, + headers={"Location": "/auth/login"} + ) + raise HTTPException(status_code=401, detail="Authentication required") + return ctx + + +async def get_user_context_from_cookie(request: Request) -> Optional[UserContext]: + """ + Legacy compatibility: get user from cookie. + + Validates token with L2 server if configured. + """ + ctx = get_user_from_cookie(request) + if ctx is None: + return None + + # If L2 server configured, could validate token here + # For now, trust the cookie + return ctx + + +# Service dependencies (lazy loading) + +def get_run_service(): + """Get the run service.""" + from .services.run_service import RunService + return RunService( + database=get_database(), + redis=get_redis_client(), + cache=get_cache_manager(), + ) + + +def get_recipe_service(): + """Get the recipe service.""" + from .services.recipe_service import RecipeService + return RecipeService( + redis=get_redis_client(), # Kept for API compatibility, not used + cache=get_cache_manager(), + ) + + +def get_cache_service(): + """Get the cache service.""" + from .services.cache_service import CacheService + return CacheService( + cache_manager=get_cache_manager(), + database=get_database(), + ) + + +async def get_nav_counts(actor_id: Optional[str] = None) -> dict: + """ + Get counts for navigation bar display. + + Returns dict with: runs, recipes, effects, media, storage + """ + counts = {} + + try: + import database + counts["media"] = await database.count_user_items(actor_id) if actor_id else 0 + except Exception: + pass + + try: + recipe_service = get_recipe_service() + recipes = await recipe_service.list_recipes(actor_id) + counts["recipes"] = len(recipes) + except Exception: + pass + + try: + run_service = get_run_service() + runs = await run_service.list_runs(actor_id) + counts["runs"] = len(runs) + except Exception: + pass + + try: + # Effects are stored in _effects/ directory, not in cache + from pathlib import Path + cache_mgr = get_cache_manager() + effects_dir = Path(cache_mgr.cache_dir) / "_effects" + if effects_dir.exists(): + counts["effects"] = len([d for d in effects_dir.iterdir() if d.is_dir()]) + else: + counts["effects"] = 0 + except Exception: + pass + + try: + import database + storage_providers = await database.get_user_storage_providers(actor_id) if actor_id else [] + counts["storage"] = len(storage_providers) if storage_providers else 0 + except Exception: + pass + + return counts diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py new file mode 100644 index 0000000..7985294 --- /dev/null +++ b/app/repositories/__init__.py @@ -0,0 +1,10 @@ +""" +L1 Server Repositories. + +Data access layer for persistence operations. +""" + +# TODO: Implement repositories +# - RunRepository - Redis-backed run storage +# - RecipeRepository - Redis-backed recipe storage +# - CacheRepository - Filesystem + PostgreSQL cache metadata diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..f0a9d54 --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1,23 @@ +""" +L1 Server Routers. + +Each router handles a specific domain of functionality. +""" + +from . import auth +from . import storage +from . import api +from . import recipes +from . import cache +from . import runs +from . import home + +__all__ = [ + "auth", + "storage", + "api", + "recipes", + "cache", + "runs", + "home", +] diff --git a/app/routers/api.py b/app/routers/api.py new file mode 100644 index 0000000..5288342 --- /dev/null +++ b/app/routers/api.py @@ -0,0 +1,257 @@ +""" +3-phase API routes for L1 server. + +Provides the plan/execute/run-recipe endpoints for programmatic access. +""" + +import hashlib +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from artdag_common.middleware.auth import UserContext +from ..dependencies import require_auth, get_redis_client, get_cache_manager + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Redis key prefix +RUNS_KEY_PREFIX = "artdag:run:" + + +class PlanRequest(BaseModel): + recipe_sexp: str + input_hashes: Dict[str, str] + + +class ExecutePlanRequest(BaseModel): + plan_json: str + run_id: Optional[str] = None + + +class RecipeRunRequest(BaseModel): + recipe_sexp: str + input_hashes: Dict[str, str] + + +def compute_run_id(input_hashes: List[str], recipe: str, recipe_hash: str = None) -> str: + """Compute deterministic run_id from inputs and recipe.""" + data = { + "inputs": sorted(input_hashes), + "recipe": recipe_hash or f"effect:{recipe}", + "version": "1", + } + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha3_256(json_str.encode()).hexdigest() + + +@router.post("/plan") +async def generate_plan_endpoint( + request: PlanRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Generate an execution plan without executing it. + + Phase 1 (Analyze) + Phase 2 (Plan) of the 3-phase model. + Returns the plan with cache status for each step. + """ + from tasks.orchestrate import generate_plan + + try: + task = generate_plan.delay( + recipe_sexp=request.recipe_sexp, + input_hashes=request.input_hashes, + ) + + # Wait for result (plan generation is usually fast) + result = task.get(timeout=60) + + return { + "status": result.get("status"), + "recipe": result.get("recipe"), + "plan_id": result.get("plan_id"), + "total_steps": result.get("total_steps"), + "cached_steps": result.get("cached_steps"), + "pending_steps": result.get("pending_steps"), + "steps": result.get("steps"), + } + except Exception as e: + logger.error(f"Plan generation failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/execute") +async def execute_plan_endpoint( + request: ExecutePlanRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Execute a pre-generated execution plan. + + Phase 3 (Execute) of the 3-phase model. + Submits the plan to Celery for parallel execution. + """ + from tasks.orchestrate import run_plan + + run_id = request.run_id or str(uuid.uuid4()) + + try: + task = run_plan.delay( + plan_json=request.plan_json, + run_id=run_id, + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + } + except Exception as e: + logger.error(f"Plan execution failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/run-recipe") +async def run_recipe_endpoint( + request: RecipeRunRequest, + ctx: UserContext = Depends(require_auth), +): + """ + Run a complete recipe through all 3 phases. + + 1. Analyze: Extract features from inputs + 2. Plan: Generate execution plan with cache IDs + 3. Execute: Run steps with parallel execution + + Returns immediately with run_id. Poll /api/run/{run_id} for status. + """ + from tasks.orchestrate import run_recipe + from artdag.sexp import compile_string + import database + + redis = get_redis_client() + cache = get_cache_manager() + + # Parse recipe name from S-expression + try: + compiled = compile_string(request.recipe_sexp) + recipe_name = compiled.name or "unknown" + except Exception: + recipe_name = "unknown" + + # Compute deterministic run_id + run_id = compute_run_id( + list(request.input_hashes.values()), + recipe_name, + hashlib.sha3_256(request.recipe_sexp.encode()).hexdigest() + ) + + # Check if already completed + cached = await database.get_run_cache(run_id) + if cached: + output_cid = cached.get("output_cid") + if cache.has_content(output_cid): + return { + "status": "completed", + "run_id": run_id, + "output_cid": output_cid, + "output_ipfs_cid": cache.get_ipfs_cid(output_cid), + "cached": True, + } + + # Submit to Celery + try: + task = run_recipe.delay( + recipe_sexp=request.recipe_sexp, + input_hashes=request.input_hashes, + run_id=run_id, + ) + + # Store run status in Redis + run_data = { + "run_id": run_id, + "status": "pending", + "recipe": recipe_name, + "inputs": list(request.input_hashes.values()), + "celery_task_id": task.id, + "created_at": datetime.now(timezone.utc).isoformat(), + "username": ctx.actor_id, + } + redis.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, + json.dumps(run_data) + ) + + return { + "status": "submitted", + "run_id": run_id, + "celery_task_id": task.id, + "recipe": recipe_name, + } + except Exception as e: + logger.error(f"Recipe run failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/run/{run_id}") +async def get_run_status( + run_id: str, + ctx: UserContext = Depends(require_auth), +): + """Get status of a recipe execution run.""" + import database + from celery.result import AsyncResult + + redis = get_redis_client() + + # Check Redis for run status + run_data = redis.get(f"{RUNS_KEY_PREFIX}{run_id}") + if run_data: + data = json.loads(run_data) + + # If pending, check Celery task status + if data.get("status") == "pending" and data.get("celery_task_id"): + result = AsyncResult(data["celery_task_id"]) + + if result.ready(): + if result.successful(): + task_result = result.get() + data["status"] = task_result.get("status", "completed") + data["output_cid"] = task_result.get("output_cache_id") + data["output_ipfs_cid"] = task_result.get("output_ipfs_cid") + data["total_steps"] = task_result.get("total_steps") + data["cached"] = task_result.get("cached") + data["executed"] = task_result.get("executed") + + # Update Redis + redis.setex( + f"{RUNS_KEY_PREFIX}{run_id}", + 86400, + json.dumps(data) + ) + else: + data["status"] = "failed" + data["error"] = str(result.result) + else: + data["celery_status"] = result.status + + return data + + # Check database cache + cached = await database.get_run_cache(run_id) + if cached: + return { + "run_id": run_id, + "status": "completed", + "output_cid": cached.get("output_cid"), + "cached": True, + } + + raise HTTPException(status_code=404, detail="Run not found") diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..c447f3d --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,165 @@ +""" +Authentication routes — OAuth2 authorization code flow via account.rose-ash.com. + +GET /auth/login — redirect to account OAuth authorize +GET /auth/callback — exchange code for user info, set session cookie +GET /auth/logout — clear cookie, redirect through account SSO logout +""" + +import secrets +import time + +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import RedirectResponse +from itsdangerous import URLSafeSerializer + +from artdag_common.middleware.auth import UserContext, set_auth_cookie, clear_auth_cookie + +from ..config import settings + +router = APIRouter() + +_signer = None + + +def _get_signer() -> URLSafeSerializer: + global _signer + if _signer is None: + _signer = URLSafeSerializer(settings.secret_key, salt="oauth-state") + return _signer + + +@router.get("/login") +async def login(request: Request): + """Store state + next in signed cookie, redirect to account OAuth authorize.""" + next_url = request.query_params.get("next", "/") + prompt = request.query_params.get("prompt", "") + state = secrets.token_urlsafe(32) + + signer = _get_signer() + state_payload = signer.dumps({"state": state, "next": next_url, "prompt": prompt}) + + device_id = getattr(request.state, "device_id", "") + authorize_url = ( + f"{settings.oauth_authorize_url}" + f"?client_id={settings.oauth_client_id}" + f"&redirect_uri={settings.oauth_redirect_uri}" + f"&device_id={device_id}" + f"&state={state}" + ) + if prompt: + authorize_url += f"&prompt={prompt}" + + response = RedirectResponse(url=authorize_url, status_code=302) + response.set_cookie( + key="oauth_state", + value=state_payload, + max_age=600, # 10 minutes + httponly=True, + samesite="lax", + secure=True, + ) + return response + + +@router.get("/callback") +async def callback(request: Request): + """Validate state, exchange code via token endpoint, set session cookie.""" + code = request.query_params.get("code", "") + state = request.query_params.get("state", "") + error = request.query_params.get("error", "") + account_did = request.query_params.get("account_did", "") + + # Adopt account's device ID as our own (one identity across all apps) + if account_did: + request.state.device_id = account_did + request.state._new_device_id = True # device_id middleware will set cookie + + # Recover state from signed cookie + state_cookie = request.cookies.get("oauth_state", "") + signer = _get_signer() + try: + payload = signer.loads(state_cookie) if state_cookie else {} + except Exception: + payload = {} + + next_url = payload.get("next", "/") + + # Handle prompt=none rejection (user not logged in on account) + if error == "login_required": + response = RedirectResponse(url=next_url, status_code=302) + response.delete_cookie("oauth_state") + # Set cooldown cookie — don't re-check for 5 minutes + response.set_cookie( + key="pnone_at", + value=str(time.time()), + max_age=300, + httponly=True, + samesite="lax", + secure=True, + ) + # Set device cookie if adopted + if account_did: + response.set_cookie( + key="artdag_did", + value=account_did, + max_age=30 * 24 * 3600, + httponly=True, + samesite="lax", + secure=True, + ) + return response + + # Normal callback — validate state + code + if not state_cookie or not code or not state: + return RedirectResponse(url="/", status_code=302) + + if payload.get("state") != state: + return RedirectResponse(url="/", status_code=302) + + # Exchange code for user info via account's token endpoint + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + settings.oauth_token_url, + json={ + "code": code, + "client_id": settings.oauth_client_id, + "redirect_uri": settings.oauth_redirect_uri, + }, + ) + except httpx.HTTPError: + return RedirectResponse(url="/", status_code=302) + + if resp.status_code != 200: + return RedirectResponse(url="/", status_code=302) + + data = resp.json() + if "error" in data: + return RedirectResponse(url="/", status_code=302) + + # Map OAuth response to artdag UserContext + # Note: account token endpoint returns user.email as "username" + display_name = data.get("display_name", "") + username = data.get("username", "") + email = username # OAuth response "username" is the user's email + actor_id = f"@{username}" + + user = UserContext(username=username, actor_id=actor_id, email=email) + + response = RedirectResponse(url=next_url, status_code=302) + set_auth_cookie(response, user) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") + return response + + +@router.get("/logout") +async def logout(): + """Clear session cookie, redirect through account SSO logout.""" + response = RedirectResponse(url=settings.oauth_logout_url, status_code=302) + clear_auth_cookie(response) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") + return response diff --git a/app/routers/cache.py b/app/routers/cache.py new file mode 100644 index 0000000..dc03d44 --- /dev/null +++ b/app/routers/cache.py @@ -0,0 +1,515 @@ +""" +Cache and media routes for L1 server. + +Handles content retrieval, metadata, media preview, and publishing. +""" + +import logging +from pathlib import Path +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File, Form +from fastapi.responses import HTMLResponse, FileResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_redis_client, + get_cache_manager, get_current_user +) +from ..services.auth_service import AuthService +from ..services.cache_service import CacheService + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class UpdateMetadataRequest(BaseModel): + title: Optional[str] = None + description: Optional[str] = None + tags: Optional[list] = None + custom: Optional[Dict[str, Any]] = None + + +def get_cache_service(): + """Get cache service instance.""" + import database + return CacheService(database, get_cache_manager()) + + +@router.get("/{cid}") +async def get_cached( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get cached content by hash. Content negotiation: HTML for browsers, JSON for APIs.""" + ctx = await get_current_user(request) + + # Pass actor_id to get friendly name and user-specific metadata + actor_id = ctx.actor_id if ctx else None + cache_item = await cache_service.get_cache_item(cid, actor_id=actor_id) + if not cache_item: + if wants_html(request): + templates = get_templates(request) + return render(templates, "cache/not_found.html", request, + cid=cid, + user=ctx, + active_tab="media", + ) + raise HTTPException(404, f"Content {cid} not in cache") + + # JSON response + if wants_json(request): + return cache_item + + # HTML response + if not ctx: + from fastapi.responses import RedirectResponse + return RedirectResponse(url="/auth", status_code=302) + + # Check access + has_access = await cache_service.check_access(cid, ctx.actor_id, ctx.username) + if not has_access: + raise HTTPException(403, "Access denied") + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "cache/detail.html", request, + cache=cache_item, + user=ctx, + nav_counts=nav_counts, + active_tab="media", + ) + + +@router.get("/{cid}/raw") +async def get_cached_raw( + cid: str, + cache_service: CacheService = Depends(get_cache_service), +): + """Get raw cached content (file download).""" + file_path, media_type, filename = await cache_service.get_raw_file(cid) + + if not file_path: + raise HTTPException(404, f"Content {cid} not in cache") + + return FileResponse(file_path, media_type=media_type, filename=filename) + + +@router.get("/{cid}/mp4") +async def get_cached_mp4( + cid: str, + cache_service: CacheService = Depends(get_cache_service), +): + """Get cached content as MP4 (transcodes MKV on first request).""" + mp4_path, error = await cache_service.get_as_mp4(cid) + + if error: + raise HTTPException(400 if "not a video" in error else 404, error) + + return FileResponse(mp4_path, media_type="video/mp4") + + +@router.get("/{cid}/meta") +async def get_metadata( + cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Get content metadata.""" + meta = await cache_service.get_metadata(cid, ctx.actor_id) + if meta is None: + raise HTTPException(404, "Content not found") + return meta + + +@router.patch("/{cid}/meta") +async def update_metadata( + cid: str, + req: UpdateMetadataRequest, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Update content metadata.""" + success, error = await cache_service.update_metadata( + cid=cid, + actor_id=ctx.actor_id, + title=req.title, + description=req.description, + tags=req.tags, + custom=req.custom, + ) + + if error: + raise HTTPException(400, error) + + return {"updated": True} + + +@router.post("/{cid}/publish") +async def publish_content( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Publish content to L2 and IPFS.""" + ipfs_cid, error = await cache_service.publish_to_l2( + cid=cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Published: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "published": True} + + +@router.delete("/{cid}") +async def delete_content( + cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Delete content from cache.""" + success, error = await cache_service.delete_content(cid, ctx.actor_id) + + if error: + raise HTTPException(400 if "Cannot" in error or "pinned" in error else 404, error) + + return {"deleted": True} + + +@router.post("/import") +async def import_from_ipfs( + ipfs_cid: str, + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Import content from IPFS.""" + cid, error = await cache_service.import_from_ipfs(ipfs_cid, ctx.actor_id) + + if error: + raise HTTPException(400, error) + + return {"cid": cid, "imported": True} + + +@router.post("/upload/chunk") +async def upload_chunk( + request: Request, + chunk: UploadFile = File(...), + upload_id: str = Form(...), + chunk_index: int = Form(...), + total_chunks: int = Form(...), + filename: str = Form(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Upload a file chunk. Assembles file when all chunks received.""" + import tempfile + import os + + # Create temp dir for this upload + chunk_dir = Path(tempfile.gettempdir()) / "uploads" / upload_id + chunk_dir.mkdir(parents=True, exist_ok=True) + + # Save this chunk + chunk_path = chunk_dir / f"chunk_{chunk_index:05d}" + chunk_data = await chunk.read() + chunk_path.write_bytes(chunk_data) + + # Check if all chunks received + received = len(list(chunk_dir.glob("chunk_*"))) + + if received < total_chunks: + return {"status": "partial", "received": received, "total": total_chunks} + + # All chunks received - assemble file + final_path = chunk_dir / filename + with open(final_path, 'wb') as f: + for i in range(total_chunks): + cp = chunk_dir / f"chunk_{i:05d}" + f.write(cp.read_bytes()) + cp.unlink() # Clean up chunk + + # Read assembled file + content = final_path.read_bytes() + final_path.unlink() + chunk_dir.rmdir() + + # Now do the normal upload flow + cid, ipfs_cid, error = await cache_service.upload_content( + content=content, + filename=filename, + actor_id=ctx.actor_id, + ) + + if error: + raise HTTPException(400, error) + + # Assign friendly name + final_cid = ipfs_cid or cid + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=final_cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, + filename=filename, + ) + + return { + "status": "complete", + "cid": final_cid, + "friendly_name": friendly_entry["friendly_name"], + "filename": filename, + "size": len(content), + "uploaded": True, + } + + +@router.post("/upload") +async def upload_content( + file: UploadFile = File(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), + cache_service: CacheService = Depends(get_cache_service), +): + """Upload content to cache and IPFS. + + Args: + file: The file to upload + display_name: Optional custom name for the media (used as friendly name) + """ + content = await file.read() + cid, ipfs_cid, error = await cache_service.upload_content( + content=content, + filename=file.filename, + actor_id=ctx.actor_id, + ) + + if error: + raise HTTPException(400, error) + + # Assign friendly name (use IPFS CID if available, otherwise local hash) + final_cid = ipfs_cid or cid + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=final_cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, # Use custom name if provided + filename=file.filename, + ) + + return { + "cid": final_cid, + "content_hash": cid, # Legacy, for backwards compatibility + "friendly_name": friendly_entry["friendly_name"], + "filename": file.filename, + "size": len(content), + "uploaded": True, + } + + +# Media listing endpoint +@router.get("") +async def list_media( + request: Request, + offset: int = 0, + limit: int = 24, + media_type: Optional[str] = None, + cache_service: CacheService = Depends(get_cache_service), + ctx: UserContext = Depends(require_auth), +): + """List all media in cache.""" + items = await cache_service.list_media( + actor_id=ctx.actor_id, + username=ctx.username, + offset=offset, + limit=limit, + media_type=media_type, + ) + has_more = len(items) >= limit + + if wants_json(request): + return {"items": items, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "cache/media_list.html", request, + items=items, + user=ctx, + nav_counts=nav_counts, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="media", + ) + + +# HTMX metadata form +@router.get("/{cid}/meta-form", response_class=HTMLResponse) +async def get_metadata_form( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get metadata editing form (HTMX).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + meta = await cache_service.get_metadata(cid, ctx.actor_id) + + return HTMLResponse(f''' +

Metadata

+
+
+ + +
+
+ + +
+ +
+ ''') + + +@router.patch("/{cid}/meta", response_class=HTMLResponse) +async def update_metadata_htmx( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Update metadata (HTMX form handler).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + form_data = await request.form() + + success, error = await cache_service.update_metadata( + cid=cid, + actor_id=ctx.actor_id, + title=form_data.get("title"), + description=form_data.get("description"), + ) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse(''' +
Metadata saved!
+ + ''') + + +# Friendly name editing +@router.get("/{cid}/name-form", response_class=HTMLResponse) +async def get_name_form( + cid: str, + request: Request, + cache_service: CacheService = Depends(get_cache_service), +): + """Get friendly name editing form (HTMX).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + # Get current friendly name + from ..services.naming_service import get_naming_service + naming = get_naming_service() + entry = await naming.get_by_cid(ctx.actor_id, cid) + current_name = entry.get("base_name", "") if entry else "" + + return HTMLResponse(f''' +
+
+ + +

A name to reference this media in recipes

+
+
+ + +
+
+ ''') + + +@router.post("/{cid}/name", response_class=HTMLResponse) +async def update_friendly_name( + cid: str, + request: Request, +): + """Update friendly name (HTMX form handler).""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
') + + form_data = await request.form() + display_name = form_data.get("display_name", "").strip() + + if not display_name: + return HTMLResponse('
Name cannot be empty
') + + from ..services.naming_service import get_naming_service + naming = get_naming_service() + + try: + entry = await naming.assign_name( + cid=cid, + actor_id=ctx.actor_id, + item_type="media", + display_name=display_name, + ) + + return HTMLResponse(f''' +
Name updated!
+ + ''') + except Exception as e: + return HTMLResponse(f'
Error: {e}
') diff --git a/app/routers/effects.py b/app/routers/effects.py new file mode 100644 index 0000000..994a925 --- /dev/null +++ b/app/routers/effects.py @@ -0,0 +1,415 @@ +""" +Effects routes for L1 server. + +Handles effect upload, listing, and metadata. +Effects are S-expression files stored in IPFS like all other content-addressed data. +""" + +import json +import logging +import re +import time +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File, Form +from fastapi.responses import HTMLResponse, PlainTextResponse + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_redis_client, + get_cache_manager, +) +from ..services.auth_service import AuthService +import ipfs_client + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def get_effects_dir() -> Path: + """Get effects storage directory.""" + cache_mgr = get_cache_manager() + effects_dir = Path(cache_mgr.cache_dir) / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + return effects_dir + + +def parse_effect_metadata(source: str) -> dict: + """ + Parse effect metadata from S-expression source code. + + Extracts metadata from comment headers (;; @key value format) + or from (defeffect name ...) form. + """ + metadata = { + "name": "", + "version": "1.0.0", + "author": "", + "temporal": False, + "description": "", + "params": [], + } + + # Parse comment-based metadata (;; @key value) + for line in source.split("\n"): + stripped = line.strip() + if not stripped.startswith(";"): + # Stop parsing metadata at first non-comment line + if stripped and not stripped.startswith("("): + continue + if stripped.startswith("("): + break + + # Remove comment prefix + comment = stripped.lstrip(";").strip() + + if comment.startswith("@effect "): + metadata["name"] = comment[8:].strip() + elif comment.startswith("@name "): + metadata["name"] = comment[6:].strip() + elif comment.startswith("@version "): + metadata["version"] = comment[9:].strip() + elif comment.startswith("@author "): + metadata["author"] = comment[8:].strip() + elif comment.startswith("@temporal"): + val = comment[9:].strip().lower() if len(comment) > 9 else "true" + metadata["temporal"] = val in ("true", "yes", "1", "") + elif comment.startswith("@description "): + metadata["description"] = comment[13:].strip() + elif comment.startswith("@param "): + # Format: @param name type [description] + parts = comment[7:].split(None, 2) + if len(parts) >= 2: + param = {"name": parts[0], "type": parts[1]} + if len(parts) > 2: + param["description"] = parts[2] + metadata["params"].append(param) + + # Also try to extract name from (defeffect "name" ...) or (effect "name" ...) + if not metadata["name"]: + name_match = re.search(r'\((defeffect|effect)\s+"([^"]+)"', source) + if name_match: + metadata["name"] = name_match.group(2) + + # Try to extract name from first (define ...) form + if not metadata["name"]: + define_match = re.search(r'\(define\s+(\w+)', source) + if define_match: + metadata["name"] = define_match.group(1) + + return metadata + + +@router.post("/upload") +async def upload_effect( + file: UploadFile = File(...), + display_name: Optional[str] = Form(None), + ctx: UserContext = Depends(require_auth), +): + """ + Upload an S-expression effect to IPFS. + + Parses metadata from comment headers. + Returns IPFS CID for use in recipes. + + Args: + file: The .sexp effect file + display_name: Optional custom friendly name for the effect + """ + content = await file.read() + + try: + source = content.decode("utf-8") + except UnicodeDecodeError: + raise HTTPException(400, "Effect must be valid UTF-8 text") + + # Parse metadata from sexp source + try: + meta = parse_effect_metadata(source) + except Exception as e: + logger.warning(f"Failed to parse effect metadata: {e}") + meta = {"name": file.filename or "unknown"} + + if not meta.get("name"): + meta["name"] = Path(file.filename).stem if file.filename else "unknown" + + # Store effect source in IPFS + cid = ipfs_client.add_bytes(content) + if not cid: + raise HTTPException(500, "Failed to store effect in IPFS") + + # Also keep local cache for fast worker access + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + effect_dir.mkdir(parents=True, exist_ok=True) + (effect_dir / "effect.sexp").write_text(source, encoding="utf-8") + + # Store metadata (locally and in IPFS) + full_meta = { + "cid": cid, + "meta": meta, + "uploader": ctx.actor_id, + "uploaded_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "filename": file.filename, + } + (effect_dir / "metadata.json").write_text(json.dumps(full_meta, indent=2)) + + # Also store metadata in IPFS for discoverability + meta_cid = ipfs_client.add_json(full_meta) + + # Track ownership in item_types + import database + await database.save_item_metadata( + cid=cid, + actor_id=ctx.actor_id, + item_type="effect", + filename=file.filename, + ) + + # Assign friendly name (use custom display_name if provided, else from metadata) + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly_entry = await naming.assign_name( + cid=cid, + actor_id=ctx.actor_id, + item_type="effect", + display_name=display_name or meta.get("name"), + filename=file.filename, + ) + + logger.info(f"Uploaded effect '{meta.get('name')}' cid={cid} friendly_name='{friendly_entry['friendly_name']}' by {ctx.actor_id}") + + return { + "cid": cid, + "metadata_cid": meta_cid, + "name": meta.get("name"), + "friendly_name": friendly_entry["friendly_name"], + "version": meta.get("version"), + "temporal": meta.get("temporal", False), + "params": meta.get("params", []), + "uploaded": True, + } + + +@router.get("/{cid}") +async def get_effect( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Get effect metadata by CID.""" + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + metadata_path = effect_dir / "metadata.json" + + # Try local cache first + if metadata_path.exists(): + meta = json.loads(metadata_path.read_text()) + else: + # Fetch from IPFS + source_bytes = ipfs_client.get_bytes(cid) + if not source_bytes: + raise HTTPException(404, f"Effect {cid[:16]}... not found") + + # Cache locally + effect_dir.mkdir(parents=True, exist_ok=True) + source = source_bytes.decode("utf-8") + (effect_dir / "effect.sexp").write_text(source) + + # Parse metadata from source + parsed_meta = parse_effect_metadata(source) + meta = {"cid": cid, "meta": parsed_meta} + (effect_dir / "metadata.json").write_text(json.dumps(meta, indent=2)) + + # Add friendly name if available + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(ctx.actor_id, cid) + if friendly: + meta["friendly_name"] = friendly["friendly_name"] + meta["base_name"] = friendly["base_name"] + meta["version_id"] = friendly["version_id"] + + if wants_json(request): + return meta + + # HTML response + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "effects/detail.html", request, + effect=meta, + user=ctx, + nav_counts=nav_counts, + active_tab="effects", + ) + + +@router.get("/{cid}/source") +async def get_effect_source( + cid: str, + ctx: UserContext = Depends(require_auth), +): + """Get effect source code.""" + effects_dir = get_effects_dir() + source_path = effects_dir / cid / "effect.sexp" + + # Try local cache first (check both .sexp and legacy .py) + if source_path.exists(): + return PlainTextResponse(source_path.read_text()) + + legacy_path = effects_dir / cid / "effect.py" + if legacy_path.exists(): + return PlainTextResponse(legacy_path.read_text()) + + # Fetch from IPFS + source_bytes = ipfs_client.get_bytes(cid) + if not source_bytes: + raise HTTPException(404, f"Effect {cid[:16]}... not found") + + # Cache locally + source_path.parent.mkdir(parents=True, exist_ok=True) + source = source_bytes.decode("utf-8") + source_path.write_text(source) + + return PlainTextResponse(source) + + +@router.get("") +async def list_effects( + request: Request, + offset: int = 0, + limit: int = 20, + ctx: UserContext = Depends(require_auth), +): + """List user's effects with pagination.""" + import database + effects_dir = get_effects_dir() + effects = [] + + # Get user's effect CIDs from item_types + user_items = await database.get_user_items(ctx.actor_id, item_type="effect", limit=1000) + effect_cids = [item["cid"] for item in user_items] + + # Get naming service for friendly name lookup + from ..services.naming_service import get_naming_service + naming = get_naming_service() + + for cid in effect_cids: + effect_dir = effects_dir / cid + metadata_path = effect_dir / "metadata.json" + if metadata_path.exists(): + try: + meta = json.loads(metadata_path.read_text()) + # Add friendly name if available + friendly = await naming.get_by_cid(ctx.actor_id, cid) + if friendly: + meta["friendly_name"] = friendly["friendly_name"] + meta["base_name"] = friendly["base_name"] + effects.append(meta) + except json.JSONDecodeError: + pass + + # Sort by upload time (newest first) + effects.sort(key=lambda e: e.get("uploaded_at", ""), reverse=True) + + # Apply pagination + total = len(effects) + paginated_effects = effects[offset:offset + limit] + has_more = offset + limit < total + + if wants_json(request): + return {"effects": paginated_effects, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "effects/list.html", request, + effects=paginated_effects, + user=ctx, + nav_counts=nav_counts, + active_tab="effects", + offset=offset, + limit=limit, + has_more=has_more, + ) + + +@router.post("/{cid}/publish") +async def publish_effect( + cid: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Publish effect to L2 ActivityPub server.""" + from ..services.cache_service import CacheService + import database + + # Verify effect exists + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + if not effect_dir.exists(): + error = "Effect not found" + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(404, error) + + # Use cache service to publish + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + logger.info(f"Published effect {cid[:16]}... to L2 by {ctx.actor_id}") + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "cid": cid, "published": True} + + +@router.delete("/{cid}") +async def delete_effect( + cid: str, + ctx: UserContext = Depends(require_auth), +): + """Remove user's ownership link to an effect.""" + import database + + # Remove user's ownership link from item_types + await database.delete_item_type(cid, ctx.actor_id, "effect") + + # Remove friendly name + await database.delete_friendly_name(ctx.actor_id, cid) + + # Check if anyone still owns this effect + remaining_owners = await database.get_item_types(cid) + + # Only delete local files if no one owns it anymore + if not remaining_owners: + effects_dir = get_effects_dir() + effect_dir = effects_dir / cid + if effect_dir.exists(): + import shutil + shutil.rmtree(effect_dir) + + # Unpin from IPFS + ipfs_client.unpin(cid) + logger.info(f"Garbage collected effect {cid[:16]}... (no remaining owners)") + + logger.info(f"Removed effect {cid[:16]}... ownership for {ctx.actor_id}") + return {"deleted": True} diff --git a/app/routers/fragments.py b/app/routers/fragments.py new file mode 100644 index 0000000..5d6d821 --- /dev/null +++ b/app/routers/fragments.py @@ -0,0 +1,143 @@ +""" +Art-DAG fragment endpoints. + +Exposes HTML fragments at ``/internal/fragments/{type}`` for consumption +by coop apps via the fragment client. +""" + +import os + +from fastapi import APIRouter, Request, Response + +router = APIRouter() + +# Registry of fragment handlers: type -> async callable(request) returning HTML str +_handlers: dict[str, object] = {} + +FRAGMENT_HEADER = "X-Fragment-Request" + + +@router.get("/internal/fragments/{fragment_type}") +async def get_fragment(fragment_type: str, request: Request): + if not request.headers.get(FRAGMENT_HEADER): + return Response(content="", status_code=403) + + handler = _handlers.get(fragment_type) + if handler is None: + return Response(content="", media_type="text/html", status_code=200) + html = await handler(request) + return Response(content=html, media_type="text/html", status_code=200) + + +# --- nav-item fragment --- + +async def _nav_item_handler(request: Request) -> str: + from artdag_common import render_fragment + + templates = request.app.state.templates + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + return render_fragment(templates, "fragments/nav_item.html", artdag_url=artdag_url) + + +_handlers["nav-item"] = _nav_item_handler + + +# --- link-card fragment --- + +async def _link_card_handler(request: Request) -> str: + from artdag_common import render_fragment + import database + + templates = request.app.state.templates + cid = request.query_params.get("cid", "") + content_type = request.query_params.get("type", "media") + slug = request.query_params.get("slug", "") + keys_raw = request.query_params.get("keys", "") + + # Batch mode: return multiple cards separated by markers + if keys_raw: + keys = [k.strip() for k in keys_raw.split(",") if k.strip()] + parts = [] + for key in keys: + parts.append(f"") + card_html = await _render_single_link_card( + templates, key, content_type, + ) + parts.append(card_html) + return "\n".join(parts) + + # Single mode: use cid or slug + lookup_cid = cid or slug + if not lookup_cid: + return "" + return await _render_single_link_card(templates, lookup_cid, content_type) + + +async def _render_single_link_card(templates, cid: str, content_type: str) -> str: + import database + from artdag_common import render_fragment + + if not cid: + return "" + + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + + # Try item_types first (has metadata) + item = await database.get_item_types(cid) + # get_item_types returns a list; pick best match for content_type + meta = None + if item: + for it in item: + if it.get("type") == content_type: + meta = it + break + if not meta: + meta = item[0] + + # Try friendly name for display + friendly = None + if meta and meta.get("actor_id"): + friendly = await database.get_friendly_name_by_cid(meta["actor_id"], cid) + + # Try run cache if type is "run" + run = None + if content_type == "run": + run = await database.get_run_cache(cid) + + title = "" + description = "" + link = "" + + if friendly: + title = friendly.get("display_name") or friendly.get("base_name", cid[:12]) + elif meta: + title = meta.get("filename") or meta.get("description", cid[:12]) + elif run: + title = f"Run {cid[:12]}" + else: + title = cid[:16] + + if meta: + description = meta.get("description", "") + + if content_type == "run": + link = f"{artdag_url}/runs/{cid}" + elif content_type == "recipe": + link = f"{artdag_url}/recipes/{cid}" + elif content_type == "effect": + link = f"{artdag_url}/effects/{cid}" + else: + link = f"{artdag_url}/cache/{cid}" + + return render_fragment( + templates, "fragments/link_card.html", + title=title, + description=description, + link=link, + cid=cid, + content_type=content_type, + artdag_url=artdag_url, + ) + + +_handlers["link-card"] = _link_card_handler diff --git a/app/routers/home.py b/app/routers/home.py new file mode 100644 index 0000000..4b89b94 --- /dev/null +++ b/app/routers/home.py @@ -0,0 +1,253 @@ +""" +Home and root routes for L1 server. +""" + +from pathlib import Path + +import markdown +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse, RedirectResponse, FileResponse + +from artdag_common import render +from artdag_common.middleware import wants_html + +from ..dependencies import get_templates, get_current_user + +router = APIRouter() + + +@router.get("/health") +async def health(): + """Health check endpoint — always returns 200.""" + return {"status": "ok"} + + +async def get_user_stats(actor_id: str) -> dict: + """Get stats for a user.""" + import database + from ..services.run_service import RunService + from ..dependencies import get_redis_client, get_cache_manager + + stats = {} + + try: + # Count only actual media types (video, image, audio), not effects/recipes + media_count = 0 + for media_type in ["video", "image", "audio", "unknown"]: + media_count += await database.count_user_items(actor_id, item_type=media_type) + stats["media"] = media_count + except Exception: + stats["media"] = 0 + + try: + # Count user's recipes from database (ownership-based) + stats["recipes"] = await database.count_user_items(actor_id, item_type="recipe") + except Exception: + stats["recipes"] = 0 + + try: + run_service = RunService(database, get_redis_client(), get_cache_manager()) + runs = await run_service.list_runs(actor_id) + stats["runs"] = len(runs) + except Exception: + stats["runs"] = 0 + + try: + storage_providers = await database.get_user_storage_providers(actor_id) + stats["storage"] = len(storage_providers) if storage_providers else 0 + except Exception: + stats["storage"] = 0 + + try: + # Count user's effects from database (ownership-based) + stats["effects"] = await database.count_user_items(actor_id, item_type="effect") + except Exception: + stats["effects"] = 0 + + return stats + + +@router.get("/api/stats") +async def api_stats(request: Request): + """Get user stats as JSON for CLI and API clients.""" + user = await get_current_user(request) + if not user: + raise HTTPException(401, "Authentication required") + + stats = await get_user_stats(user.actor_id) + return stats + + +@router.delete("/api/clear-data") +async def clear_user_data(request: Request): + """ + Clear all user L1 data except storage configuration. + + Deletes: runs, recipes, effects, media/cache items. + Preserves: storage provider configurations. + """ + import logging + logger = logging.getLogger(__name__) + + user = await get_current_user(request) + if not user: + raise HTTPException(401, "Authentication required") + + import database + from ..services.recipe_service import RecipeService + from ..services.run_service import RunService + from ..dependencies import get_redis_client, get_cache_manager + + actor_id = user.actor_id + username = user.username + deleted = { + "runs": 0, + "recipes": 0, + "effects": 0, + "media": 0, + } + errors = [] + + # Delete all runs + try: + run_service = RunService(database, get_redis_client(), get_cache_manager()) + runs = await run_service.list_runs(actor_id, offset=0, limit=10000) + for run in runs: + try: + await run_service.discard_run(run["run_id"], actor_id, username) + deleted["runs"] += 1 + except Exception as e: + errors.append(f"Run {run['run_id']}: {e}") + except Exception as e: + errors.append(f"Failed to list runs: {e}") + + # Delete all recipes + try: + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipes = await recipe_service.list_recipes(actor_id, offset=0, limit=10000) + for recipe in recipes: + try: + success, error = await recipe_service.delete_recipe(recipe["recipe_id"], actor_id) + if success: + deleted["recipes"] += 1 + else: + errors.append(f"Recipe {recipe['recipe_id']}: {error}") + except Exception as e: + errors.append(f"Recipe {recipe['recipe_id']}: {e}") + except Exception as e: + errors.append(f"Failed to list recipes: {e}") + + # Delete all effects (uses ownership model) + cache_manager = get_cache_manager() + try: + # Get user's effects from item_types + effect_items = await database.get_user_items(actor_id, item_type="effect", limit=10000) + for item in effect_items: + cid = item.get("cid") + if cid: + try: + # Remove ownership link + await database.delete_item_type(cid, actor_id, "effect") + await database.delete_friendly_name(actor_id, cid) + + # Check if orphaned + remaining = await database.get_item_types(cid) + if not remaining: + # Garbage collect + effects_dir = Path(cache_manager.cache_dir) / "_effects" / cid + if effects_dir.exists(): + import shutil + shutil.rmtree(effects_dir) + import ipfs_client + ipfs_client.unpin(cid) + deleted["effects"] += 1 + except Exception as e: + errors.append(f"Effect {cid[:16]}...: {e}") + except Exception as e: + errors.append(f"Failed to delete effects: {e}") + + # Delete all media/cache items for user (uses ownership model) + try: + from ..services.cache_service import CacheService + cache_service = CacheService(database, cache_manager) + + # Get user's media items (video, image, audio) + for media_type in ["video", "image", "audio", "unknown"]: + items = await database.get_user_items(actor_id, item_type=media_type, limit=10000) + for item in items: + cid = item.get("cid") + if cid: + try: + success, error = await cache_service.delete_content(cid, actor_id) + if success: + deleted["media"] += 1 + elif error: + errors.append(f"Media {cid[:16]}...: {error}") + except Exception as e: + errors.append(f"Media {cid[:16]}...: {e}") + except Exception as e: + errors.append(f"Failed to delete media: {e}") + + logger.info(f"Cleared data for {actor_id}: {deleted}") + if errors: + logger.warning(f"Errors during clear: {errors[:10]}") # Log first 10 errors + + return { + "message": "User data cleared", + "deleted": deleted, + "errors": errors[:10] if errors else [], # Return first 10 errors + "storage_preserved": True, + } + + +@router.get("/") +async def home(request: Request): + """ + Home page - show README and stats. + """ + user = await get_current_user(request) + + # Load README + readme_html = "" + try: + readme_path = Path(__file__).parent.parent.parent / "README.md" + if readme_path.exists(): + readme_html = markdown.markdown(readme_path.read_text(), extensions=['tables', 'fenced_code']) + except Exception: + pass + + # Get stats for current user + stats = {} + if user: + stats = await get_user_stats(user.actor_id) + + templates = get_templates(request) + return render(templates, "home.html", request, + user=user, + readme_html=readme_html, + stats=stats, + nav_counts=stats, # Reuse stats for nav counts + active_tab="home", + ) + + +@router.get("/login") +async def login_redirect(request: Request): + """Redirect to OAuth login flow.""" + return RedirectResponse(url="/auth/login", status_code=302) + + +# Client tarball path +CLIENT_TARBALL = Path(__file__).parent.parent.parent / "artdag-client.tar.gz" + + +@router.get("/download/client") +async def download_client(): + """Download the Art DAG CLI client.""" + if not CLIENT_TARBALL.exists(): + raise HTTPException(404, "Client package not found. Run build-client.sh to create it.") + return FileResponse( + CLIENT_TARBALL, + media_type="application/gzip", + filename="artdag-client.tar.gz" + ) diff --git a/app/routers/inbox.py b/app/routers/inbox.py new file mode 100644 index 0000000..d6fa37c --- /dev/null +++ b/app/routers/inbox.py @@ -0,0 +1,125 @@ +"""AP-style inbox endpoint for receiving signed activities from the coop. + +POST /inbox — verify HTTP Signature, dispatch by activity type. +""" +from __future__ import annotations + +import logging +import time + +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +from ..dependencies import get_redis_client +from ..utils.http_signatures import verify_request_signature, parse_key_id + +log = logging.getLogger(__name__) +router = APIRouter() + +# Cache fetched public keys in Redis for 24 hours +_KEY_CACHE_TTL = 86400 + + +async def _fetch_actor_public_key(actor_url: str) -> str | None: + """Fetch an actor's public key, with Redis caching.""" + redis = get_redis_client() + cache_key = f"actor_pubkey:{actor_url}" + + # Check cache + cached = redis.get(cache_key) + if cached: + return cached + + # Fetch actor JSON + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + actor_url, + headers={"Accept": "application/activity+json, application/ld+json"}, + ) + if resp.status_code != 200: + log.warning("Failed to fetch actor %s: %d", actor_url, resp.status_code) + return None + data = resp.json() + except Exception: + log.warning("Error fetching actor %s", actor_url, exc_info=True) + return None + + pub_key_pem = (data.get("publicKey") or {}).get("publicKeyPem") + if not pub_key_pem: + log.warning("No publicKey in actor %s", actor_url) + return None + + # Cache it + redis.set(cache_key, pub_key_pem, ex=_KEY_CACHE_TTL) + return pub_key_pem + + +@router.post("/inbox") +async def inbox(request: Request): + """Receive signed AP activities from the coop platform.""" + sig_header = request.headers.get("signature", "") + if not sig_header: + return JSONResponse({"error": "missing signature"}, status_code=401) + + # Read body + body = await request.body() + + # Verify HTTP Signature + actor_url = parse_key_id(sig_header) + if not actor_url: + return JSONResponse({"error": "invalid keyId"}, status_code=401) + + pub_key = await _fetch_actor_public_key(actor_url) + if not pub_key: + return JSONResponse({"error": "could not fetch public key"}, status_code=401) + + req_headers = dict(request.headers) + path = request.url.path + valid = verify_request_signature( + public_key_pem=pub_key, + signature_header=sig_header, + method="POST", + path=path, + headers=req_headers, + ) + if not valid: + log.warning("Invalid signature from %s", actor_url) + return JSONResponse({"error": "invalid signature"}, status_code=401) + + # Parse and dispatch + try: + activity = await request.json() + except Exception: + return JSONResponse({"error": "invalid json"}, status_code=400) + + activity_type = activity.get("type", "") + log.info("Inbox received: %s from %s", activity_type, actor_url) + + if activity_type == "rose:DeviceAuth": + _handle_device_auth(activity) + + # Always 202 — AP convention + return JSONResponse({"status": "accepted"}, status_code=202) + + +def _handle_device_auth(activity: dict) -> None: + """Set or delete did_auth:{device_id} in local Redis.""" + obj = activity.get("object", {}) + device_id = obj.get("device_id", "") + action = obj.get("action", "") + + if not device_id: + log.warning("rose:DeviceAuth missing device_id") + return + + redis = get_redis_client() + if action == "login": + redis.set(f"did_auth:{device_id}", str(time.time()), ex=30 * 24 * 3600) + log.info("did_auth set for device %s...", device_id[:16]) + elif action == "logout": + redis.delete(f"did_auth:{device_id}") + log.info("did_auth cleared for device %s...", device_id[:16]) + else: + log.warning("rose:DeviceAuth unknown action: %s", action) diff --git a/app/routers/oembed.py b/app/routers/oembed.py new file mode 100644 index 0000000..615dfda --- /dev/null +++ b/app/routers/oembed.py @@ -0,0 +1,74 @@ +"""Art-DAG oEmbed endpoint. + +Returns oEmbed JSON responses for Art-DAG content (media, recipes, effects, runs). +""" + +import os + +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse + +router = APIRouter() + + +@router.get("/oembed") +async def oembed(request: Request): + url = request.query_params.get("url", "") + if not url: + return JSONResponse({"error": "url parameter required"}, status_code=400) + + # Parse URL to extract content type and CID + # URL patterns: /cache/{cid}, /recipes/{cid}, /effects/{cid}, /runs/{cid} + from urllib.parse import urlparse + + parsed = urlparse(url) + parts = [p for p in parsed.path.strip("/").split("/") if p] + + if len(parts) < 2: + return JSONResponse({"error": "could not parse content URL"}, status_code=404) + + content_type = parts[0].rstrip("s") # recipes -> recipe, runs -> run + cid = parts[1] + + import database + + title = cid[:16] + thumbnail_url = None + + # Look up metadata + items = await database.get_item_types(cid) + if items: + meta = items[0] + title = meta.get("filename") or meta.get("description") or title + + # Try friendly name + actor_id = meta.get("actor_id") + if actor_id: + friendly = await database.get_friendly_name_by_cid(actor_id, cid) + if friendly: + title = friendly.get("display_name") or friendly.get("base_name", title) + + # Media items get a thumbnail + if meta.get("type") == "media": + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + thumbnail_url = f"{artdag_url}/cache/{cid}/raw" + + elif content_type == "run": + run = await database.get_run_cache(cid) + if run: + title = f"Run {cid[:12]}" + + artdag_url = os.getenv("APP_URL_ARTDAG", "https://celery-artdag.rose-ash.com") + + resp = { + "version": "1.0", + "type": "link", + "title": title, + "provider_name": "art-dag", + "provider_url": artdag_url, + "url": url, + } + if thumbnail_url: + resp["thumbnail_url"] = thumbnail_url + + return JSONResponse(resp) diff --git a/app/routers/recipes.py b/app/routers/recipes.py new file mode 100644 index 0000000..1a55397 --- /dev/null +++ b/app/routers/recipes.py @@ -0,0 +1,686 @@ +""" +Recipe management routes for L1 server. + +Handles recipe upload, listing, viewing, and execution. +""" + +import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import require_auth, get_current_user, get_templates, get_redis_client, get_cache_manager +from ..services.auth_service import AuthService +from ..services.recipe_service import RecipeService +from ..types import ( + CompiledNode, TransformedNode, Registry, Recipe, + is_variable_input, get_effect_cid, +) + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class RecipeUploadRequest(BaseModel): + content: str # S-expression or YAML + name: Optional[str] = None + description: Optional[str] = None + + +class RecipeRunRequest(BaseModel): + """Request to run a recipe with variable inputs.""" + inputs: Dict[str, str] = {} # Map input names to CIDs + + +def get_recipe_service() -> RecipeService: + """Get recipe service instance.""" + return RecipeService(get_redis_client(), get_cache_manager()) + + +def transform_node( + node: CompiledNode, + assets: Dict[str, Dict[str, Any]], + effects: Dict[str, Dict[str, Any]], +) -> TransformedNode: + """ + Transform a compiled node to artdag execution format. + + - Resolves asset references to CIDs for SOURCE nodes + - Resolves effect references to CIDs for EFFECT nodes + - Renames 'type' to 'node_type', 'id' to 'node_id' + """ + node_id = node.get("id", "") + config = dict(node.get("config", {})) # Copy to avoid mutation + + # Resolve asset references for SOURCE nodes + if node.get("type") == "SOURCE" and "asset" in config: + asset_name = config["asset"] + if asset_name in assets: + config["cid"] = assets[asset_name].get("cid") + + # Resolve effect references for EFFECT nodes + if node.get("type") == "EFFECT" and "effect" in config: + effect_name = config["effect"] + if effect_name in effects: + config["cid"] = effects[effect_name].get("cid") + + return { + "node_id": node_id, + "node_type": node.get("type", "EFFECT"), + "config": config, + "inputs": node.get("inputs", []), + "name": node.get("name"), + } + + +def build_input_name_mapping( + nodes: Dict[str, TransformedNode], +) -> Dict[str, str]: + """ + Build a mapping from input names to node IDs for variable inputs. + + Variable inputs can be referenced by: + - node_id directly + - config.name (e.g., "Second Video") + - snake_case version (e.g., "second_video") + - kebab-case version (e.g., "second-video") + - node.name (def binding name) + """ + input_name_to_node: Dict[str, str] = {} + + for node_id, node in nodes.items(): + if node.get("node_type") != "SOURCE": + continue + + config = node.get("config", {}) + if not is_variable_input(config): + continue + + # Map by node_id + input_name_to_node[node_id] = node_id + + # Map by config.name + name = config.get("name") + if name: + input_name_to_node[name] = node_id + input_name_to_node[name.lower().replace(" ", "_")] = node_id + input_name_to_node[name.lower().replace(" ", "-")] = node_id + + # Map by node.name (def binding) + node_name = node.get("name") + if node_name: + input_name_to_node[node_name] = node_id + input_name_to_node[node_name.replace("-", "_")] = node_id + + return input_name_to_node + + +def bind_inputs( + nodes: Dict[str, TransformedNode], + input_name_to_node: Dict[str, str], + user_inputs: Dict[str, str], +) -> List[str]: + """ + Bind user-provided input CIDs to source nodes. + + Returns list of warnings for inputs that couldn't be bound. + """ + warnings: List[str] = [] + + for input_name, cid in user_inputs.items(): + # Try direct node ID match first + if input_name in nodes: + node = nodes[input_name] + if node.get("node_type") == "SOURCE": + node["config"]["cid"] = cid + logger.info(f"Bound input {input_name} directly to node, cid={cid[:16]}...") + continue + + # Try input name lookup + if input_name in input_name_to_node: + node_id = input_name_to_node[input_name] + node = nodes[node_id] + node["config"]["cid"] = cid + logger.info(f"Bound input {input_name} via lookup to node {node_id}, cid={cid[:16]}...") + continue + + # Input not found + warnings.append(f"Input '{input_name}' not found in recipe") + logger.warning(f"Input {input_name} not found in nodes or input_name_to_node") + + return warnings + + +async def resolve_friendly_names_in_registry( + registry: dict, + actor_id: str, +) -> dict: + """ + Resolve friendly names to CIDs in the registry. + + Friendly names are identified by containing a space (e.g., "brightness 01hw3x9k") + or by not being a valid CID format. + """ + from ..services.naming_service import get_naming_service + import re + + naming = get_naming_service() + resolved = {"assets": {}, "effects": {}} + + # CID patterns: IPFS CID (Qm..., bafy...) or SHA256 hash (64 hex chars) + cid_pattern = re.compile(r'^(Qm[a-zA-Z0-9]{44}|bafy[a-zA-Z0-9]+|[a-f0-9]{64})$') + + for asset_name, asset_info in registry.get("assets", {}).items(): + cid = asset_info.get("cid", "") + if cid and not cid_pattern.match(cid): + # Looks like a friendly name, resolve it + resolved_cid = await naming.resolve(actor_id, cid, item_type="media") + if resolved_cid: + asset_info = dict(asset_info) + asset_info["cid"] = resolved_cid + asset_info["_resolved_from"] = cid + resolved["assets"][asset_name] = asset_info + + for effect_name, effect_info in registry.get("effects", {}).items(): + cid = effect_info.get("cid", "") + if cid and not cid_pattern.match(cid): + # Looks like a friendly name, resolve it + resolved_cid = await naming.resolve(actor_id, cid, item_type="effect") + if resolved_cid: + effect_info = dict(effect_info) + effect_info["cid"] = resolved_cid + effect_info["_resolved_from"] = cid + resolved["effects"][effect_name] = effect_info + + return resolved + + +async def prepare_dag_for_execution( + recipe: Recipe, + user_inputs: Dict[str, str], + actor_id: str = None, +) -> Tuple[str, List[str]]: + """ + Prepare a recipe DAG for execution by transforming nodes and binding inputs. + + Resolves friendly names to CIDs if actor_id is provided. + Returns (dag_json, warnings). + """ + recipe_dag = recipe.get("dag") + if not recipe_dag or not isinstance(recipe_dag, dict): + raise ValueError("Recipe has no DAG definition") + + # Deep copy to avoid mutating original + dag_copy = json.loads(json.dumps(recipe_dag)) + nodes = dag_copy.get("nodes", {}) + + # Get registry for resolving references + registry = recipe.get("registry", {}) + + # Resolve friendly names to CIDs + if actor_id and registry: + registry = await resolve_friendly_names_in_registry(registry, actor_id) + + assets = registry.get("assets", {}) if registry else {} + effects = registry.get("effects", {}) if registry else {} + + # Transform nodes from list to dict if needed + if isinstance(nodes, list): + nodes_dict: Dict[str, TransformedNode] = {} + for node in nodes: + node_id = node.get("id") + if node_id: + nodes_dict[node_id] = transform_node(node, assets, effects) + nodes = nodes_dict + dag_copy["nodes"] = nodes + + # Build input name mapping and bind user inputs + input_name_to_node = build_input_name_mapping(nodes) + logger.info(f"Input name to node mapping: {input_name_to_node}") + logger.info(f"User-provided inputs: {user_inputs}") + + warnings = bind_inputs(nodes, input_name_to_node, user_inputs) + + # Log final SOURCE node configs for debugging + for nid, n in nodes.items(): + if n.get("node_type") == "SOURCE": + logger.info(f"Final SOURCE node {nid}: config={n.get('config')}") + + # Transform output to output_id + if "output" in dag_copy: + dag_copy["output_id"] = dag_copy.pop("output") + + # Add metadata if not present + if "metadata" not in dag_copy: + dag_copy["metadata"] = {} + + return json.dumps(dag_copy), warnings + + +@router.post("/upload") +async def upload_recipe( + file: UploadFile = File(...), + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Upload a new recipe from S-expression or YAML file.""" + import yaml + + # Read content from the uploaded file + content = (await file.read()).decode("utf-8") + + # Detect format (skip comments starting with ;) + def is_sexp_format(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(') + return False + + is_sexp = is_sexp_format(content) + + try: + from artdag.sexp import compile_string, ParseError, CompileError + SEXP_AVAILABLE = True + except ImportError: + SEXP_AVAILABLE = False + + recipe_name = None + recipe_version = "1.0" + recipe_description = None + variable_inputs = [] + fixed_inputs = [] + + if is_sexp: + if not SEXP_AVAILABLE: + raise HTTPException(500, "S-expression recipes require artdag.sexp module (not installed on server)") + # Parse S-expression + try: + compiled = compile_string(content) + recipe_name = compiled.name + recipe_version = compiled.version + recipe_description = compiled.description + + for node in compiled.nodes: + if node.get("type") == "SOURCE": + config = node.get("config", {}) + if config.get("input"): + variable_inputs.append(config.get("name", node.get("id"))) + elif config.get("asset"): + fixed_inputs.append(config.get("asset")) + except Exception as e: + raise HTTPException(400, f"Parse error: {e}") + else: + # Parse YAML + try: + recipe_data = yaml.safe_load(content) + recipe_name = recipe_data.get("name") + recipe_version = recipe_data.get("version", "1.0") + recipe_description = recipe_data.get("description") + + inputs = recipe_data.get("inputs", {}) + for input_name, input_def in inputs.items(): + if isinstance(input_def, dict) and input_def.get("fixed"): + fixed_inputs.append(input_name) + else: + variable_inputs.append(input_name) + except yaml.YAMLError as e: + raise HTTPException(400, f"Invalid YAML: {e}") + + # Use filename as recipe name if not specified + if not recipe_name and file.filename: + recipe_name = file.filename.rsplit(".", 1)[0] + + recipe_id, error = await recipe_service.upload_recipe( + content=content, + uploader=ctx.actor_id, + name=recipe_name, + description=recipe_description, + ) + + if error: + raise HTTPException(400, error) + + return { + "recipe_id": recipe_id, + "name": recipe_name or "unnamed", + "version": recipe_version, + "variable_inputs": variable_inputs, + "fixed_inputs": fixed_inputs, + "message": "Recipe uploaded successfully", + } + + +@router.get("") +async def list_recipes( + request: Request, + offset: int = 0, + limit: int = 20, + recipe_service: RecipeService = Depends(get_recipe_service), + ctx: UserContext = Depends(require_auth), +): + """List available recipes.""" + recipes = await recipe_service.list_recipes(ctx.actor_id, offset=offset, limit=limit) + has_more = len(recipes) >= limit + + if wants_json(request): + return {"recipes": recipes, "offset": offset, "limit": limit, "has_more": has_more} + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "recipes/list.html", request, + recipes=recipes, + user=ctx, + nav_counts=nav_counts, + active_tab="recipes", + offset=offset, + limit=limit, + has_more=has_more, + ) + + +@router.get("/{recipe_id}") +async def get_recipe( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), + ctx: UserContext = Depends(require_auth), +): + """Get recipe details.""" + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + # Add friendly name if available + from ..services.naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(ctx.actor_id, recipe_id) + if friendly: + recipe["friendly_name"] = friendly["friendly_name"] + recipe["base_name"] = friendly["base_name"] + recipe["version_id"] = friendly["version_id"] + + if wants_json(request): + return recipe + + # Build DAG elements for visualization and convert nodes to steps format + dag_elements = [] + steps = [] + node_colors = { + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + } + + # Debug: log recipe structure + logger.info(f"Recipe keys: {list(recipe.keys())}") + + # Get nodes from dag - can be list or dict, can be under "dag" or directly on recipe + dag = recipe.get("dag", {}) + logger.info(f"DAG type: {type(dag)}, keys: {list(dag.keys()) if isinstance(dag, dict) else 'not dict'}") + nodes = dag.get("nodes", []) if isinstance(dag, dict) else [] + logger.info(f"Nodes from dag.nodes: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + # Also check for nodes directly on recipe (alternative formats) + if not nodes: + nodes = recipe.get("nodes", []) + logger.info(f"Nodes from recipe.nodes: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + if not nodes: + nodes = recipe.get("pipeline", []) + logger.info(f"Nodes from recipe.pipeline: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + if not nodes: + nodes = recipe.get("steps", []) + logger.info(f"Nodes from recipe.steps: {type(nodes)}, len: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + logger.info(f"Final nodes count: {len(nodes) if hasattr(nodes, '__len__') else 'N/A'}") + + # Convert list of nodes to steps format + if isinstance(nodes, list): + for node in nodes: + node_id = node.get("id", "") + node_type = node.get("type", "EFFECT") + inputs = node.get("inputs", []) + config = node.get("config", {}) + + steps.append({ + "id": node_id, + "name": node_id, + "type": node_type, + "inputs": inputs, + "params": config, + }) + + dag_elements.append({ + "data": { + "id": node_id, + "label": node_id, + "color": node_colors.get(node_type, "#6b7280"), + } + }) + for inp in inputs: + if isinstance(inp, str): + dag_elements.append({ + "data": {"source": inp, "target": node_id} + }) + elif isinstance(nodes, dict): + for node_id, node in nodes.items(): + node_type = node.get("type", "EFFECT") + inputs = node.get("inputs", []) + config = node.get("config", {}) + + steps.append({ + "id": node_id, + "name": node_id, + "type": node_type, + "inputs": inputs, + "params": config, + }) + + dag_elements.append({ + "data": { + "id": node_id, + "label": node_id, + "color": node_colors.get(node_type, "#6b7280"), + } + }) + for inp in inputs: + if isinstance(inp, str): + dag_elements.append({ + "data": {"source": inp, "target": node_id} + }) + + # Add steps to recipe for template + recipe["steps"] = steps + + # Use S-expression source if available + if "sexp" not in recipe: + recipe["sexp"] = "; No S-expression source available" + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "recipes/detail.html", request, + recipe=recipe, + dag_elements=dag_elements, + user=ctx, + nav_counts=nav_counts, + active_tab="recipes", + ) + + +@router.delete("/{recipe_id}") +async def delete_recipe( + recipe_id: str, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Delete a recipe.""" + success, error = await recipe_service.delete_recipe(recipe_id, ctx.actor_id) + if error: + raise HTTPException(400 if "Cannot" in error else 404, error) + return {"deleted": True, "recipe_id": recipe_id} + + +@router.post("/{recipe_id}/run") +async def run_recipe( + recipe_id: str, + req: RecipeRunRequest, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Run a recipe with given inputs.""" + from ..services.run_service import RunService + from ..dependencies import get_cache_manager + import database + + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + try: + # Create run using run service + run_service = RunService(database, get_redis_client(), get_cache_manager()) + + # Prepare DAG for execution (transform nodes, bind inputs, resolve friendly names) + dag_json = None + if recipe.get("dag"): + dag_json, warnings = await prepare_dag_for_execution(recipe, req.inputs, actor_id=ctx.actor_id) + for warning in warnings: + logger.warning(warning) + + run, error = await run_service.create_run( + recipe=recipe_id, # Use recipe hash as primary identifier + inputs=req.inputs, + use_dag=True, + dag_json=dag_json, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + recipe_name=recipe.get("name"), # Store name for display + recipe_sexp=recipe.get("sexp"), # S-expression for code-addressed execution + ) + + if error: + raise HTTPException(400, error) + + if not run: + raise HTTPException(500, "Run creation returned no result") + + return { + "run_id": run["run_id"] if isinstance(run, dict) else run.run_id, + "status": run.get("status", "pending") if isinstance(run, dict) else run.status, + "message": "Recipe execution started", + } + except HTTPException: + raise + except Exception as e: + logger.exception(f"Error running recipe {recipe_id}") + raise HTTPException(500, f"Run failed: {e}") + + +@router.get("/{recipe_id}/dag") +async def recipe_dag( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Get recipe DAG visualization data.""" + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + dag_elements = [] + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + } + + for i, step in enumerate(recipe.get("steps", [])): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + dag_elements.append({ + "data": {"source": inp, "target": step_id} + }) + + return {"elements": dag_elements} + + +@router.delete("/{recipe_id}/ui", response_class=HTMLResponse) +async def ui_discard_recipe( + recipe_id: str, + request: Request, + recipe_service: RecipeService = Depends(get_recipe_service), +): + """HTMX handler: discard a recipe.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Login required
', status_code=401) + + success, error = await recipe_service.delete_recipe(recipe_id, ctx.actor_id) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse( + '
Recipe deleted
' + '' + ) + + +@router.post("/{recipe_id}/publish") +async def publish_recipe( + recipe_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), + recipe_service: RecipeService = Depends(get_recipe_service), +): + """Publish recipe to L2 and IPFS.""" + from ..services.cache_service import CacheService + from ..dependencies import get_cache_manager + import database + + # Verify recipe exists + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + raise HTTPException(404, "Recipe not found") + + # Use cache service to publish (recipes are stored in cache) + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=recipe_id, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "published": True} diff --git a/app/routers/runs.py b/app/routers/runs.py new file mode 100644 index 0000000..29c7d25 --- /dev/null +++ b/app/routers/runs.py @@ -0,0 +1,1704 @@ +""" +Run management routes for L1 server. + +Handles run creation, status, listing, and detail views. +""" + +import asyncio +import json +import logging +from datetime import datetime, timezone +from typing import List, Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import ( + require_auth, get_templates, get_current_user, + get_redis_client, get_cache_manager +) +from ..services.run_service import RunService + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def plan_to_sexp(plan: dict, recipe_name: str = None) -> str: + """Convert a plan to S-expression format for display.""" + if not plan or not plan.get("steps"): + return ";; No plan available" + + lines = [] + lines.append(f'(plan "{recipe_name or "unknown"}"') + + # Group nodes by type for cleaner output + steps = plan.get("steps", []) + + for step in steps: + step_id = step.get("id", "?") + step_type = step.get("type", "EFFECT") + inputs = step.get("inputs", []) + config = step.get("config", {}) + + # Build the step S-expression + if step_type == "SOURCE": + if config.get("input"): + # Variable input + input_name = config.get("name", config.get("input", "input")) + lines.append(f' (source :input "{input_name}")') + elif config.get("asset"): + # Fixed asset + lines.append(f' (source {config.get("asset", step_id)})') + else: + lines.append(f' (source {step_id})') + elif step_type == "EFFECT": + effect_name = config.get("effect", step_id) + if inputs: + inp_str = " ".join(inputs) + lines.append(f' (-> {inp_str} (effect {effect_name}))') + else: + lines.append(f' (effect {effect_name})') + elif step_type == "SEQUENCE": + if inputs: + inp_str = " ".join(inputs) + lines.append(f' (sequence {inp_str})') + else: + lines.append(f' (sequence)') + else: + # Generic node + if inputs: + inp_str = " ".join(inputs) + lines.append(f' ({step_type.lower()} {inp_str})') + else: + lines.append(f' ({step_type.lower()} {step_id})') + + lines.append(')') + return "\n".join(lines) + +RUNS_KEY_PREFIX = "artdag:run:" + + +class RunRequest(BaseModel): + recipe: str + inputs: List[str] + output_name: Optional[str] = None + use_dag: bool = True + dag_json: Optional[str] = None + + +class RunStatus(BaseModel): + run_id: str + status: str + recipe: Optional[str] = None + inputs: Optional[List[str]] = None + output_name: Optional[str] = None + created_at: Optional[str] = None + completed_at: Optional[str] = None + output_cid: Optional[str] = None + username: Optional[str] = None + provenance_cid: Optional[str] = None + celery_task_id: Optional[str] = None + error: Optional[str] = None + plan_id: Optional[str] = None + plan_name: Optional[str] = None + step_results: Optional[Dict[str, Any]] = None + all_outputs: Optional[List[str]] = None + effects_commit: Optional[str] = None + effect_url: Optional[str] = None + infrastructure: Optional[Dict[str, Any]] = None + + +class StreamRequest(BaseModel): + """Request to run a streaming recipe.""" + recipe_sexp: str # The recipe S-expression content + output_name: str = "output.mp4" + duration: Optional[float] = None # Duration in seconds + fps: Optional[float] = None # FPS override + sources_sexp: Optional[str] = None # Sources config S-expression + audio_sexp: Optional[str] = None # Audio config S-expression + + +def get_run_service(): + """Get run service instance.""" + import database + return RunService(database, get_redis_client(), get_cache_manager()) + + +@router.post("", response_model=RunStatus) +async def create_run( + request: RunRequest, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Start a new rendering run. Checks cache before executing.""" + run, error = await run_service.create_run( + recipe=request.recipe, + inputs=request.inputs, + output_name=request.output_name, + use_dag=request.use_dag, + dag_json=request.dag_json, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + ) + + if error: + raise HTTPException(400, error) + + return run + + +@router.post("/stream", response_model=RunStatus) +async def create_stream_run( + request: StreamRequest, + req: Request, + ctx: UserContext = Depends(get_current_user), +): + """Start a streaming video render. + + The recipe_sexp should be a complete streaming recipe with + (stream ...) form defining the pipeline. + + Assets can be referenced by CID or friendly name in the recipe. + Requires authentication OR admin token in X-Admin-Token header. + """ + import uuid + import tempfile + import os + from pathlib import Path + import database + from tasks.streaming import run_stream + + # Check for admin token if no user auth + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = req.headers.get("X-Admin-Token") + admin_actor_id = req.headers.get("X-Actor-Id", "admin@local") + + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Use context actor_id or admin actor_id + actor_id = ctx.actor_id if ctx else admin_actor_id + + # Generate run ID + run_id = str(uuid.uuid4()) + + # Store recipe in cache so it appears on /recipes page + recipe_id = None + try: + cache_manager = get_cache_manager() + with tempfile.NamedTemporaryFile(delete=False, suffix=".sexp", mode="w") as tmp: + tmp.write(request.recipe_sexp) + tmp_path = Path(tmp.name) + + cached, ipfs_cid = cache_manager.put(tmp_path, node_type="recipe", move=True) + recipe_id = cached.cid + + # Extract recipe name from S-expression (look for (stream "name" ...) pattern) + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', request.recipe_sexp) + recipe_name = name_match.group(1) if name_match else f"stream-{run_id[:8]}" + + # Track ownership in item_types + await database.save_item_metadata( + cid=recipe_id, + actor_id=actor_id, + item_type="recipe", + description=f"Streaming recipe: {recipe_name}", + filename=f"{recipe_name}.sexp", + ) + + # Assign friendly name + from ..services.naming_service import get_naming_service + naming = get_naming_service() + await naming.assign_name( + cid=recipe_id, + actor_id=actor_id, + item_type="recipe", + display_name=recipe_name, + ) + + logger.info(f"Stored streaming recipe {recipe_id[:16]}... as '{recipe_name}'") + except Exception as e: + logger.warning(f"Failed to store recipe in cache: {e}") + # Continue anyway - run will still work, just won't appear in /recipes + + # Submit Celery task to GPU queue for hardware-accelerated rendering + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=request.recipe_sexp, + output_name=request.output_name, + duration=request.duration, + fps=request.fps, + actor_id=actor_id, + sources_sexp=request.sources_sexp, + audio_sexp=request.audio_sexp, + ), + queue='gpu', + ) + + # Store in database for durability + pending = await database.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe_id or "streaming", # Use recipe CID if available + inputs=[], # Streaming recipes don't have traditional inputs + actor_id=actor_id, + dag_json=request.recipe_sexp, # Store recipe content for viewing + output_name=request.output_name, + ) + + logger.info(f"Started stream run {run_id} with task {task.id}") + + return RunStatus( + run_id=run_id, + status="pending", + recipe=recipe_id or "streaming", + created_at=pending.get("created_at"), + celery_task_id=task.id, + ) + + +@router.get("/{run_id}") +async def get_run( + request: Request, + run_id: str, + run_service: RunService = Depends(get_run_service), +): + """Get status of a run.""" + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, f"Run {run_id} not found") + + # Only render HTML if browser explicitly requests it + if wants_html(request): + # Extract username from actor_id (format: @user@server) + actor_id = run.get("actor_id", "") + if actor_id and actor_id.startswith("@"): + parts = actor_id[1:].split("@") + run["username"] = parts[0] if parts else "Unknown" + else: + run["username"] = actor_id or "Unknown" + + # Helper to normalize input refs to just node IDs + def normalize_inputs(inputs): + """Convert input refs (may be dicts or strings) to list of node IDs.""" + result = [] + for inp in inputs: + if isinstance(inp, dict): + node_id = inp.get("node") or inp.get("input") or inp.get("id") + else: + node_id = inp + if node_id: + result.append(node_id) + return result + + # Try to load the recipe to show the plan + plan = None + plan_sexp = None # Native S-expression if available + recipe_ipfs_cid = None + recipe_id = run.get("recipe") + # Check for valid recipe ID (64-char hash, IPFS CIDv0 "Qm...", or CIDv1 "bafy...") + is_valid_recipe_id = recipe_id and ( + len(recipe_id) == 64 or + recipe_id.startswith("Qm") or + recipe_id.startswith("bafy") + ) + if is_valid_recipe_id: + try: + from ..services.recipe_service import RecipeService + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipe = await recipe_service.get_recipe(recipe_id) + if recipe: + # Use native S-expression if available (code is data!) + if recipe.get("sexp"): + plan_sexp = recipe["sexp"] + # Get IPFS CID for the recipe + recipe_ipfs_cid = recipe.get("ipfs_cid") + + # Build steps for DAG visualization + dag = recipe.get("dag", {}) + nodes = dag.get("nodes", []) + + steps = [] + if isinstance(nodes, list): + for node in nodes: + node_id = node.get("id", "") + steps.append({ + "id": node_id, + "name": node_id, + "type": node.get("type", "EFFECT"), + "status": "completed", # Run completed + "inputs": normalize_inputs(node.get("inputs", [])), + "config": node.get("config", {}), + }) + elif isinstance(nodes, dict): + for node_id, node in nodes.items(): + steps.append({ + "id": node_id, + "name": node_id, + "type": node.get("type", "EFFECT"), + "status": "completed", + "inputs": normalize_inputs(node.get("inputs", [])), + "config": node.get("config", {}), + }) + + if steps: + plan = {"steps": steps} + run["total_steps"] = len(steps) + run["executed"] = len(steps) if run.get("status") == "completed" else 0 + + # Use recipe name instead of hash for display (if not already set) + if recipe.get("name") and not run.get("recipe_name"): + run["recipe_name"] = recipe["name"] + except Exception as e: + logger.warning(f"Failed to load recipe for plan: {e}") + + # Handle streaming runs - detect by recipe_sexp content or legacy "streaming" marker + recipe_sexp_content = run.get("recipe_sexp") + is_streaming = run.get("recipe") == "streaming" # Legacy marker + if not is_streaming and recipe_sexp_content: + # Check if content starts with (stream after skipping comments + for line in recipe_sexp_content.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + is_streaming = stripped.startswith('(stream') + break + if is_streaming and recipe_sexp_content and not plan: + plan_sexp = recipe_sexp_content + plan = { + "steps": [{ + "id": "stream", + "type": "STREAM", + "name": "Streaming Recipe", + "inputs": [], + "config": {}, + "status": "completed" if run.get("status") == "completed" else "pending", + }] + } + run["total_steps"] = 1 + run["executed"] = 1 if run.get("status") == "completed" else 0 + + # Helper to convert simple type to MIME type prefix for template + def type_to_mime(simple_type: str) -> str: + if simple_type == "video": + return "video/mp4" + elif simple_type == "image": + return "image/jpeg" + elif simple_type == "audio": + return "audio/mpeg" + return None + + # Build artifacts list from output and inputs + artifacts = [] + output_media_type = None + if run.get("output_cid"): + # Detect media type using magic bytes, fall back to database item_type + output_cid = run["output_cid"] + media_type = None + + # Streaming runs (with ipfs_cid) are always video/mp4 + if run.get("ipfs_cid"): + media_type = "video/mp4" + output_media_type = media_type + else: + try: + from ..services.run_service import detect_media_type + cache_path = get_cache_manager().get_by_cid(output_cid) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + media_type = type_to_mime(simple_type) + output_media_type = media_type + except Exception: + pass + # Fall back to database item_type if local detection failed + if not media_type: + try: + import database + item_types = await database.get_item_types(output_cid, run.get("actor_id")) + if item_types: + media_type = type_to_mime(item_types[0].get("type")) + output_media_type = media_type + except Exception: + pass + artifacts.append({ + "cid": output_cid, + "step_name": "Output", + "media_type": media_type or "application/octet-stream", + }) + + # Build inputs list with media types + run_inputs = [] + if run.get("inputs"): + from ..services.run_service import detect_media_type + cache_manager = get_cache_manager() + for i, input_hash in enumerate(run["inputs"]): + media_type = None + try: + cache_path = cache_manager.get_by_cid(input_hash) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + media_type = type_to_mime(simple_type) + except Exception: + pass + run_inputs.append({ + "cid": input_hash, + "name": f"Input {i + 1}", + "media_type": media_type, + }) + + # Build DAG elements for visualization + dag_elements = [] + if plan and plan.get("steps"): + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + for i, step in enumerate(plan["steps"]): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + # Handle both string and dict inputs + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": { + "source": source, + "target": step_id, + } + }) + + # Use native S-expression if available, otherwise generate from plan + if not plan_sexp and plan: + plan_sexp = plan_to_sexp(plan, run.get("recipe_name")) + + from ..dependencies import get_nav_counts + user = await get_current_user(request) + nav_counts = await get_nav_counts(user.actor_id if user else None) + + templates = get_templates(request) + return render(templates, "runs/detail.html", request, + run=run, + plan=plan, + artifacts=artifacts, + run_inputs=run_inputs, + dag_elements=dag_elements, + output_media_type=output_media_type, + plan_sexp=plan_sexp, + recipe_ipfs_cid=recipe_ipfs_cid, + nav_counts=nav_counts, + active_tab="runs", + ) + + # Default to JSON for API clients + return run + + +@router.delete("/{run_id}") +async def discard_run( + run_id: str, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Discard (delete) a run and its outputs.""" + success, error = await run_service.discard_run(run_id, ctx.actor_id, ctx.username) + if error: + raise HTTPException(400 if "Cannot" in error else 404, error) + return {"discarded": True, "run_id": run_id} + + +@router.get("") +async def list_runs( + request: Request, + offset: int = 0, + limit: int = 20, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(get_current_user), +): + """List all runs for the current user.""" + import os + + # Check for admin token if no user auth + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = request.headers.get("X-Admin-Token") + admin_actor_id = request.headers.get("X-Actor-Id") + + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Use context actor_id or admin actor_id + actor_id = ctx.actor_id if ctx else admin_actor_id + if not actor_id: + raise HTTPException(400, "X-Actor-Id header required with admin token") + + runs = await run_service.list_runs(actor_id, offset=offset, limit=limit) + has_more = len(runs) >= limit + + if wants_json(request): + return {"runs": runs, "offset": offset, "limit": limit, "has_more": has_more} + + # Add media info for inline previews (only for HTML) + cache_manager = get_cache_manager() + from ..services.run_service import detect_media_type + + def type_to_mime(simple_type: str) -> str: + if simple_type == "video": + return "video/mp4" + elif simple_type == "image": + return "image/jpeg" + elif simple_type == "audio": + return "audio/mpeg" + return None + + for run in runs: + # Add output media info + if run.get("output_cid"): + # Streaming runs (with ipfs_cid) are always video/mp4 + if run.get("ipfs_cid"): + run["output_media_type"] = "video/mp4" + else: + try: + cache_path = cache_manager.get_by_cid(run["output_cid"]) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + run["output_media_type"] = type_to_mime(simple_type) + except Exception: + pass + + # Add input media info (first 3 inputs) + input_previews = [] + inputs = run.get("inputs", []) + if isinstance(inputs, list): + for input_hash in inputs[:3]: + preview = {"cid": input_hash, "media_type": None} + try: + cache_path = cache_manager.get_by_cid(input_hash) + if cache_path and cache_path.exists(): + simple_type = detect_media_type(cache_path) + preview["media_type"] = type_to_mime(simple_type) + except Exception: + pass + input_previews.append(preview) + run["input_previews"] = input_previews + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(actor_id) + + templates = get_templates(request) + return render(templates, "runs/list.html", request, + runs=runs, + user=ctx or {"actor_id": actor_id}, + nav_counts=nav_counts, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="runs", + ) + + +@router.get("/{run_id}/detail") +async def run_detail( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Run detail page with tabs for plan/analysis/artifacts.""" + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, f"Run {run_id} not found") + + # Get plan, artifacts, and analysis + plan = await run_service.get_run_plan(run_id) + artifacts = await run_service.get_run_artifacts(run_id) + analysis = await run_service.get_run_analysis(run_id) + + # Build DAG elements for visualization + dag_elements = [] + if plan and plan.get("steps"): + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + for i, step in enumerate(plan["steps"]): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + # Add edges from inputs (handle both string and dict formats) + for inp in step.get("inputs", []): + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": { + "source": source, + "target": step_id, + } + }) + + if wants_json(request): + return { + "run": run, + "plan": plan, + "artifacts": artifacts, + "analysis": analysis, + } + + # Extract plan_sexp for streaming runs + plan_sexp = plan.get("sexp") if plan else None + + templates = get_templates(request) + return render(templates, "runs/detail.html", request, + run=run, + plan=plan, + plan_sexp=plan_sexp, + artifacts=artifacts, + analysis=analysis, + dag_elements=dag_elements, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/plan") +async def run_plan( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Plan visualization as interactive DAG.""" + plan = await run_service.get_run_plan(run_id) + if not plan: + raise HTTPException(404, "Plan not found for this run") + + if wants_json(request): + return plan + + # Build DAG elements + dag_elements = [] + node_colors = { + "input": "#3b82f6", + "effect": "#8b5cf6", + "analyze": "#ec4899", + "transform": "#10b981", + "output": "#f59e0b", + "SOURCE": "#3b82f6", + "EFFECT": "#8b5cf6", + "SEQUENCE": "#ec4899", + } + + for i, step in enumerate(plan.get("steps", [])): + step_id = step.get("id", f"step-{i}") + dag_elements.append({ + "data": { + "id": step_id, + "label": step.get("name", f"Step {i+1}"), + "color": node_colors.get(step.get("type", "effect"), "#6b7280"), + } + }) + for inp in step.get("inputs", []): + # Handle both string and dict formats + if isinstance(inp, dict): + source = inp.get("node") or inp.get("input") or inp.get("id") + else: + source = inp + if source: + dag_elements.append({ + "data": {"source": source, "target": step_id} + }) + + templates = get_templates(request) + return render(templates, "runs/plan.html", request, + run_id=run_id, + plan=plan, + dag_elements=dag_elements, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/artifacts") +async def run_artifacts( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), + ctx: UserContext = Depends(require_auth), +): + """Get artifacts list for a run.""" + artifacts = await run_service.get_run_artifacts(run_id) + + if wants_json(request): + return {"artifacts": artifacts} + + templates = get_templates(request) + return render(templates, "runs/artifacts.html", request, + run_id=run_id, + artifacts=artifacts, + user=ctx, + active_tab="runs", + ) + + +@router.get("/{run_id}/plan/node/{cache_id}", response_class=HTMLResponse) +async def plan_node_detail( + run_id: str, + cache_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), +): + """HTMX partial: Get plan node detail by cache_id.""" + from artdag_common import render_fragment + + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('

Login required

', status_code=401) + + run = await run_service.get_run(run_id) + if not run: + return HTMLResponse('

Run not found

', status_code=404) + + plan = await run_service.get_run_plan(run_id) + if not plan: + return HTMLResponse('

Plan not found

') + + # Build lookups + steps_by_cache_id = {} + steps_by_step_id = {} + for s in plan.get("steps", []): + if s.get("cache_id"): + steps_by_cache_id[s["cache_id"]] = s + if s.get("step_id"): + steps_by_step_id[s["step_id"]] = s + + step = steps_by_cache_id.get(cache_id) + if not step: + return HTMLResponse(f'

Step not found

') + + cache_manager = get_cache_manager() + + # Node colors + node_colors = { + "SOURCE": "#3b82f6", "EFFECT": "#22c55e", "OUTPUT": "#a855f7", + "ANALYSIS": "#f59e0b", "_LIST": "#6366f1", "default": "#6b7280" + } + node_color = node_colors.get(step.get("node_type", "EFFECT"), node_colors["default"]) + + # Check cache status + has_cached = cache_manager.has_content(cache_id) if cache_id else False + + # Determine output media type + output_media_type = None + output_preview = False + if has_cached: + cache_path = cache_manager.get_content_path(cache_id) + if cache_path: + output_media_type = run_service.detect_media_type(cache_path) + output_preview = output_media_type in ('video', 'image', 'audio') + + # Check for IPFS CID + ipfs_cid = None + if run.step_results: + res = run.step_results.get(step.get("step_id")) + if isinstance(res, dict) and res.get("cid"): + ipfs_cid = res["cid"] + + # Build input previews + inputs = [] + for inp_step_id in step.get("input_steps", []): + inp_step = steps_by_step_id.get(inp_step_id) + if inp_step: + inp_cache_id = inp_step.get("cache_id", "") + inp_has_cached = cache_manager.has_content(inp_cache_id) if inp_cache_id else False + inp_media_type = None + if inp_has_cached: + inp_path = cache_manager.get_content_path(inp_cache_id) + if inp_path: + inp_media_type = run_service.detect_media_type(inp_path) + + inputs.append({ + "name": inp_step.get("name", inp_step_id[:12]), + "cache_id": inp_cache_id, + "media_type": inp_media_type, + "has_cached": inp_has_cached, + }) + + status = "cached" if (has_cached or ipfs_cid) else ("completed" if run.status == "completed" else "pending") + + templates = get_templates(request) + return HTMLResponse(render_fragment(templates, "runs/plan_node.html", + step=step, + cache_id=cache_id, + node_color=node_color, + status=status, + has_cached=has_cached, + output_preview=output_preview, + output_media_type=output_media_type, + ipfs_cid=ipfs_cid, + ipfs_gateway="https://ipfs.io/ipfs", + inputs=inputs, + config=step.get("config", {}), + )) + + +@router.delete("/{run_id}/ui", response_class=HTMLResponse) +async def ui_discard_run( + run_id: str, + request: Request, + run_service: RunService = Depends(get_run_service), +): + """HTMX handler: discard a run.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse( + '
Login required
', + status_code=401 + ) + + success, error = await run_service.discard_run(run_id, ctx.actor_id, ctx.username) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse( + '
Run discarded
' + '' + ) + + +@router.post("/{run_id}/publish") +async def publish_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), + run_service: RunService = Depends(get_run_service), +): + """Publish run output to L2 and IPFS.""" + from ..services.cache_service import CacheService + from ..dependencies import get_cache_manager + import database + + run = await run_service.get_run(run_id) + if not run: + raise HTTPException(404, "Run not found") + + # Check if run has output + output_cid = run.get("output_cid") + if not output_cid: + error = "Run has no output to publish" + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + # Use cache service to publish the output + cache_service = CacheService(database, get_cache_manager()) + ipfs_cid, error = await cache_service.publish_to_l2( + cid=output_cid, + actor_id=ctx.actor_id, + l2_server=ctx.l2_server, + auth_token=request.cookies.get("auth_token"), + ) + + if error: + if wants_html(request): + return HTMLResponse(f'{error}') + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse(f'Shared: {ipfs_cid[:16]}...') + + return {"ipfs_cid": ipfs_cid, "output_cid": output_cid, "published": True} + + +@router.post("/rerun/{recipe_id}", response_class=HTMLResponse) +async def rerun_recipe( + recipe_id: str, + request: Request, +): + """HTMX handler: run a recipe again. + + Fetches the recipe by CID and starts a new streaming run. + """ + import uuid + import database + from tasks.streaming import run_stream + + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse( + '
Login required
', + status_code=401 + ) + + # Fetch the recipe + try: + from ..services.recipe_service import RecipeService + recipe_service = RecipeService(get_redis_client(), get_cache_manager()) + recipe = await recipe_service.get_recipe(recipe_id) + if not recipe: + return HTMLResponse(f'
Recipe not found: {recipe_id[:16]}...
') + + # Get the S-expression content + recipe_sexp = recipe.get("sexp") + if not recipe_sexp: + return HTMLResponse('
Recipe has no S-expression content
') + + # Extract recipe name for output + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', recipe_sexp) + recipe_name = name_match.group(1) if name_match else f"stream" + + # Generate new run ID + run_id = str(uuid.uuid4()) + + # Extract duration from recipe if present (look for :duration pattern) + duration = None + duration_match = re.search(r':duration\s+(\d+(?:\.\d+)?)', recipe_sexp) + if duration_match: + duration = float(duration_match.group(1)) + + # Extract fps from recipe if present + fps = None + fps_match = re.search(r':fps\s+(\d+(?:\.\d+)?)', recipe_sexp) + if fps_match: + fps = float(fps_match.group(1)) + + # Submit Celery task to GPU queue + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=recipe_sexp, + output_name=f"{recipe_name}.mp4", + duration=duration, + fps=fps, + actor_id=ctx.actor_id, + sources_sexp=None, + audio_sexp=None, + ), + queue='gpu', + ) + + # Store in database + await database.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe_id, + inputs=[], + actor_id=ctx.actor_id, + dag_json=recipe_sexp, + output_name=f"{recipe_name}.mp4", + ) + + logger.info(f"Started rerun {run_id} for recipe {recipe_id[:16]}...") + + return HTMLResponse( + f'
Started new run
' + f'' + ) + + except Exception as e: + logger.error(f"Failed to rerun recipe {recipe_id}: {e}") + return HTMLResponse(f'
Error: {str(e)}
') + + +@router.delete("/admin/purge-failed") +async def purge_failed_runs( + request: Request, + ctx: UserContext = Depends(get_current_user), +): + """Delete all failed runs from pending_runs table. + + Requires authentication OR admin token in X-Admin-Token header. + """ + import database + import os + + # Check for admin token + admin_token = os.environ.get("ADMIN_TOKEN") + request_token = request.headers.get("X-Admin-Token") + + # Require either valid auth or admin token + if not ctx and (not admin_token or request_token != admin_token): + raise HTTPException(401, "Authentication required") + + # Get all failed runs + failed_runs = await database.list_pending_runs(status="failed") + + deleted = [] + for run in failed_runs: + run_id = run.get("run_id") + try: + await database.delete_pending_run(run_id) + deleted.append(run_id) + except Exception as e: + logger.warning(f"Failed to delete run {run_id}: {e}") + + logger.info(f"Purged {len(deleted)} failed runs") + return {"purged": len(deleted), "run_ids": deleted} + + +@router.post("/{run_id}/pause") +async def pause_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Pause a running render. Waits for current segment to complete. + + The render will checkpoint at the next segment boundary and stop. + """ + import database + from celery_app import app as celery_app + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] != 'running': + raise HTTPException(400, f"Can only pause running renders (current status: {pending['status']})") + + # Revoke the Celery task (soft termination via SIGTERM - allows cleanup) + celery_task_id = pending.get('celery_task_id') + if celery_task_id: + celery_app.control.revoke(celery_task_id, terminate=True, signal='SIGTERM') + logger.info(f"Sent SIGTERM to task {celery_task_id} for run {run_id}") + + # Update status to 'paused' + await database.update_pending_run_status(run_id, 'paused') + + return { + "run_id": run_id, + "status": "paused", + "checkpoint_frame": pending.get('checkpoint_frame'), + } + + +@router.post("/{run_id}/resume") +async def resume_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Resume a paused or failed run from its last checkpoint. + + The render will continue from the checkpoint frame. + """ + import database + from tasks.streaming import run_stream + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] not in ('failed', 'paused'): + raise HTTPException(400, f"Can only resume failed/paused runs (current status: {pending['status']})") + + if not pending.get('checkpoint_frame'): + raise HTTPException(400, "No checkpoint available - use restart instead") + + if not pending.get('resumable', True): + raise HTTPException(400, "Run checkpoint is corrupted - use restart instead") + + # Submit new Celery task with resume=True + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=pending.get('dag_json', ''), # Recipe is stored in dag_json + output_name=pending.get('output_name', 'output.mp4'), + actor_id=pending.get('actor_id'), + resume=True, + ), + queue='gpu', + ) + + # Update status and celery_task_id + await database.update_pending_run_status(run_id, 'running') + + # Update the celery_task_id manually since create_pending_run isn't called + async with database.pool.acquire() as conn: + await conn.execute( + "UPDATE pending_runs SET celery_task_id = $2, updated_at = NOW() WHERE run_id = $1", + run_id, task.id + ) + + logger.info(f"Resumed run {run_id} from frame {pending.get('checkpoint_frame')} with task {task.id}") + + return { + "run_id": run_id, + "status": "running", + "celery_task_id": task.id, + "resumed_from_frame": pending.get('checkpoint_frame'), + } + + +@router.post("/{run_id}/restart") +async def restart_run( + run_id: str, + request: Request, + ctx: UserContext = Depends(require_auth), +): + """Restart a failed/paused run from the beginning (discard checkpoint). + + All progress will be lost. Use resume instead to continue from checkpoint. + """ + import database + from tasks.streaming import run_stream + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + if pending['status'] not in ('failed', 'paused'): + raise HTTPException(400, f"Can only restart failed/paused runs (current status: {pending['status']})") + + # Clear checkpoint data + await database.clear_run_checkpoint(run_id) + + # Submit new Celery task (without resume) + task = run_stream.apply_async( + kwargs=dict( + run_id=run_id, + recipe_sexp=pending.get('dag_json', ''), # Recipe is stored in dag_json + output_name=pending.get('output_name', 'output.mp4'), + actor_id=pending.get('actor_id'), + resume=False, + ), + queue='gpu', + ) + + # Update status and celery_task_id + await database.update_pending_run_status(run_id, 'running') + + async with database.pool.acquire() as conn: + await conn.execute( + "UPDATE pending_runs SET celery_task_id = $2, updated_at = NOW() WHERE run_id = $1", + run_id, task.id + ) + + logger.info(f"Restarted run {run_id} from beginning with task {task.id}") + + return { + "run_id": run_id, + "status": "running", + "celery_task_id": task.id, + } + + +@router.get("/{run_id}/stream") +async def stream_run_output( + run_id: str, + request: Request, +): + """Stream the video output of a running render. + + For IPFS HLS streams, redirects to the IPFS gateway playlist. + For local HLS streams, redirects to the m3u8 playlist. + For legacy MP4 streams, returns the file directly. + """ + from fastapi.responses import StreamingResponse, FileResponse, RedirectResponse + from pathlib import Path + import os + import database + from celery_app import app as celery_app + + await database.init_db() + + # Check for IPFS HLS streaming first (distributed P2P streaming) + pending = await database.get_pending_run(run_id) + if pending and pending.get("celery_task_id"): + task_id = pending["celery_task_id"] + result = celery_app.AsyncResult(task_id) + if result.ready() and isinstance(result.result, dict): + ipfs_playlist_url = result.result.get("ipfs_playlist_url") + if ipfs_playlist_url: + logger.info(f"Redirecting to IPFS stream: {ipfs_playlist_url}") + return RedirectResponse(url=ipfs_playlist_url, status_code=302) + + cache_dir = os.environ.get("CACHE_DIR", "/data/cache") + stream_dir = Path(cache_dir) / "streaming" / run_id + + # Check for local HLS output + hls_playlist = stream_dir / "stream.m3u8" + if hls_playlist.exists(): + # Redirect to the HLS playlist endpoint + return RedirectResponse( + url=f"/runs/{run_id}/hls/stream.m3u8", + status_code=302 + ) + + # Fall back to legacy MP4 streaming + stream_path = stream_dir / "output.mp4" + if not stream_path.exists(): + raise HTTPException(404, "Stream not available yet") + + file_size = stream_path.stat().st_size + if file_size == 0: + raise HTTPException(404, "Stream not ready") + + return FileResponse( + path=str(stream_path), + media_type="video/mp4", + headers={ + "Accept-Ranges": "bytes", + "Cache-Control": "no-cache, no-store, must-revalidate", + "X-Content-Size": str(file_size), + } + ) + + +@router.get("/{run_id}/hls/{filename:path}") +async def serve_hls_content( + run_id: str, + filename: str, + request: Request, +): + """Serve HLS playlist and segments for live streaming. + + Serves stream.m3u8 (playlist) and segment_*.ts files. + The playlist updates as new segments are rendered. + + If files aren't found locally, proxies to the GPU worker (if configured). + """ + from fastapi.responses import FileResponse, StreamingResponse + from pathlib import Path + import os + import httpx + + cache_dir = os.environ.get("CACHE_DIR", "/data/cache") + stream_dir = Path(cache_dir) / "streaming" / run_id + file_path = stream_dir / filename + + # Security: ensure we're only serving files within stream_dir + try: + file_path_resolved = file_path.resolve() + stream_dir_resolved = stream_dir.resolve() + if stream_dir.exists() and not str(file_path_resolved).startswith(str(stream_dir_resolved)): + raise HTTPException(403, "Invalid path") + except Exception: + pass # Allow proxy fallback + + # Determine content type + if filename.endswith(".m3u8"): + media_type = "application/vnd.apple.mpegurl" + headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Access-Control-Allow-Origin": "*", + } + elif filename.endswith(".ts"): + media_type = "video/mp2t" + headers = { + "Cache-Control": "public, max-age=3600", + "Access-Control-Allow-Origin": "*", + } + else: + raise HTTPException(400, "Invalid file type") + + # For playlist requests, check IPFS first (redirect to IPFS gateway) + if filename == "stream.m3u8" and not file_path.exists(): + import database + from fastapi.responses import RedirectResponse + await database.init_db() + + # Check pending run for IPFS playlist + pending = await database.get_pending_run(run_id) + if pending and pending.get("ipfs_playlist_cid"): + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ipfs_url = f"{gateway}/{pending['ipfs_playlist_cid']}" + return RedirectResponse(url=ipfs_url, status_code=302) + + # Check completed run cache + run = await database.get_run_cache(run_id) + if run and run.get("ipfs_playlist_cid"): + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + ipfs_url = f"{gateway}/{run['ipfs_playlist_cid']}" + return RedirectResponse(url=ipfs_url, status_code=302) + + # Try local file first + if file_path.exists(): + return FileResponse( + path=str(file_path), + media_type=media_type, + headers=headers, + ) + + # Fallback: proxy to GPU worker if configured + gpu_worker_url = os.environ.get("GPU_WORKER_STREAM_URL") + if gpu_worker_url: + # Proxy request to GPU worker + proxy_url = f"{gpu_worker_url}/{run_id}/{filename}" + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.get(proxy_url) + if resp.status_code == 200: + return StreamingResponse( + content=iter([resp.content]), + media_type=media_type, + headers=headers, + ) + except Exception as e: + logger.warning(f"GPU worker proxy failed: {e}") + + raise HTTPException(404, f"File not found: {filename}") + + +@router.get("/{run_id}/playlist.m3u8") +async def get_playlist(run_id: str, request: Request): + """Get live HLS master playlist for a streaming run. + + For multi-resolution streams: generates a master playlist with DYNAMIC quality URLs. + For single-resolution streams: returns the playlist directly from IPFS. + """ + import database + import os + import httpx + from fastapi.responses import Response + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + quality_playlists = pending.get("quality_playlists") + + # Multi-resolution stream: generate master playlist with dynamic quality URLs + if quality_playlists: + lines = ["#EXTM3U", "#EXT-X-VERSION:3"] + + for name, info in quality_playlists.items(): + if not info.get("cid"): + continue + + lines.append( + f"#EXT-X-STREAM-INF:BANDWIDTH={info['bitrate'] * 1000}," + f"RESOLUTION={info['width']}x{info['height']}," + f"NAME=\"{name}\"" + ) + # Use dynamic URL that fetches latest CID from database + lines.append(f"/runs/{run_id}/quality/{name}/playlist.m3u8") + + if len(lines) <= 2: + raise HTTPException(404, "No quality playlists available") + + playlist_content = "\n".join(lines) + "\n" + + else: + # Single-resolution stream: fetch directly from IPFS + ipfs_playlist_cid = pending.get("ipfs_playlist_cid") + if not ipfs_playlist_cid: + raise HTTPException(404, "HLS playlist not created - rendering likely failed") + + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={ipfs_playlist_cid}") + if resp.status_code != 200: + raise HTTPException(502, "Failed to fetch playlist from IPFS") + playlist_content = resp.text + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Rewrite IPFS URLs to use our proxy endpoint + import re + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + + playlist_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + playlist_content = re.sub( + r'/ipfs(?:-ts)?/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + + return Response( + content=playlist_content, + media_type="application/vnd.apple.mpegurl", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/quality/{quality}/playlist.m3u8") +async def get_quality_playlist(run_id: str, quality: str, request: Request): + """Get quality-level HLS playlist for a streaming run. + + Fetches the LATEST CID for this quality from the database, + so HLS.js always gets updated content. + """ + import database + import os + import httpx + from fastapi.responses import Response + + await database.init_db() + + pending = await database.get_pending_run(run_id) + if not pending: + raise HTTPException(404, "Run not found") + + quality_playlists = pending.get("quality_playlists") + if not quality_playlists or quality not in quality_playlists: + raise HTTPException(404, f"Quality '{quality}' not found") + + quality_cid = quality_playlists[quality].get("cid") + if not quality_cid: + raise HTTPException(404, f"Quality '{quality}' playlist not ready") + + # Fetch playlist from local IPFS node + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={quality_cid}") + if resp.status_code != 200: + raise HTTPException(502, f"Failed to fetch quality playlist from IPFS: {quality_cid}") + playlist_content = resp.text + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Rewrite segment URLs to use our proxy (segments are still static IPFS content) + import re + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + + # Replace absolute gateway URLs with our proxy + playlist_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + # Also handle /ipfs/ paths and /ipfs-ts/ paths + playlist_content = re.sub( + r'/ipfs(?:-ts)?/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + playlist_content + ) + + return Response( + content=playlist_content, + media_type="application/vnd.apple.mpegurl", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/ipfs-proxy/{cid}") +async def proxy_ipfs_content(run_id: str, cid: str, request: Request): + """Proxy IPFS content with no-cache headers for live streaming. + + This allows HLS.js to poll for updated playlists through us rather than + hitting static IPFS URLs directly. + """ + import os + import httpx + from fastapi.responses import Response + + ipfs_api = os.environ.get("IPFS_API_URL", "http://celery_ipfs:5001") + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post(f"{ipfs_api}/api/v0/cat?arg={cid}") + if resp.status_code != 200: + raise HTTPException(502, f"Failed to fetch from IPFS: {cid}") + content = resp.content + except httpx.RequestError as e: + raise HTTPException(502, f"IPFS error: {e}") + + # Determine content type + if cid.endswith('.m3u8') or b'#EXTM3U' in content[:20]: + media_type = "application/vnd.apple.mpegurl" + # Rewrite any IPFS URLs in sub-playlists too + import re + text_content = content.decode('utf-8', errors='replace') + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://celery-artdag.rose-ash.com/ipfs") + text_content = re.sub( + rf'{re.escape(gateway)}/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + text_content + ) + text_content = re.sub( + r'/ipfs/([A-Za-z0-9]+)', + rf'/runs/{run_id}/ipfs-proxy/\1', + text_content + ) + content = text_content.encode('utf-8') + elif b'\x47' in content[:1]: # MPEG-TS sync byte + media_type = "video/mp2t" + else: + media_type = "application/octet-stream" + + return Response( + content=content, + media_type=media_type, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", + "Access-Control-Allow-Origin": "*", + } + ) + + +@router.get("/{run_id}/ipfs-stream") +async def get_ipfs_stream_info(run_id: str, request: Request): + """Get IPFS streaming info for a run. + + Returns the IPFS playlist URL and segment info if available. + This allows clients to stream directly from IPFS gateways. + """ + from celery_app import app as celery_app + import database + import os + + await database.init_db() + + # Try to get pending run to find the Celery task ID + pending = await database.get_pending_run(run_id) + from fastapi.responses import JSONResponse + no_cache_headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + + if not pending: + # Try completed runs + run = await database.get_run_cache(run_id) + if not run: + raise HTTPException(404, "Run not found") + # For completed runs, check if we have IPFS info stored + ipfs_cid = run.get("ipfs_cid") + if ipfs_cid: + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + return JSONResponse( + content={ + "run_id": run_id, + "status": "completed", + "ipfs_video_url": f"{gateway}/{ipfs_cid}", + }, + headers=no_cache_headers + ) + raise HTTPException(404, "No IPFS stream info available") + + task_id = pending.get("celery_task_id") + if not task_id: + raise HTTPException(404, "No task ID for this run") + + # Get the Celery task result + result = celery_app.AsyncResult(task_id) + + from fastapi.responses import JSONResponse + gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + no_cache_headers = { + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + + if result.ready(): + # Task is complete - check the result for IPFS playlist info + task_result = result.result + if isinstance(task_result, dict): + ipfs_playlist_cid = task_result.get("ipfs_playlist_cid") + ipfs_playlist_url = task_result.get("ipfs_playlist_url") + if ipfs_playlist_url: + return JSONResponse( + content={ + "run_id": run_id, + "status": "completed", + "ipfs_playlist_cid": ipfs_playlist_cid, + "ipfs_playlist_url": ipfs_playlist_url, + "segment_count": task_result.get("ipfs_segment_count", 0), + }, + headers=no_cache_headers + ) + + # Task is still running - check database for live playlist updates + ipfs_playlist_cid = pending.get("ipfs_playlist_cid") + + # Get task state and progress metadata + task_state = result.state + task_info = result.info if isinstance(result.info, dict) else {} + + response_data = { + "run_id": run_id, + "status": task_state.lower() if task_state else pending.get("status", "pending"), + } + + # Add progress metadata if available + if task_info: + if "progress" in task_info: + response_data["progress"] = task_info["progress"] + if "frame" in task_info: + response_data["frame"] = task_info["frame"] + if "total_frames" in task_info: + response_data["total_frames"] = task_info["total_frames"] + if "percent" in task_info: + response_data["percent"] = task_info["percent"] + + if ipfs_playlist_cid: + response_data["ipfs_playlist_cid"] = ipfs_playlist_cid + response_data["ipfs_playlist_url"] = f"{gateway}/{ipfs_playlist_cid}" + else: + response_data["message"] = "IPFS streaming info not yet available" + + # No caching for live streaming data + return JSONResponse( + content=response_data, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + ) diff --git a/app/routers/storage.py b/app/routers/storage.py new file mode 100644 index 0000000..b8f2fc8 --- /dev/null +++ b/app/routers/storage.py @@ -0,0 +1,264 @@ +""" +Storage provider routes for L1 server. + +Manages user storage backends (Pinata, web3.storage, local, etc.) +""" + +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException, Form +from fastapi.responses import HTMLResponse, RedirectResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json +from artdag_common.middleware.auth import UserContext + +from ..dependencies import get_database, get_current_user, require_auth, get_templates +from ..services.storage_service import StorageService, STORAGE_PROVIDERS_INFO, VALID_PROVIDER_TYPES + +router = APIRouter() + + +# Import storage_providers module +import storage_providers as sp_module + + +def get_storage_service(): + """Get storage service instance.""" + import database + return StorageService(database, sp_module) + + +class AddStorageRequest(BaseModel): + provider_type: str + config: Dict[str, Any] + capacity_gb: int = 5 + provider_name: Optional[str] = None + + +class UpdateStorageRequest(BaseModel): + config: Optional[Dict[str, Any]] = None + capacity_gb: Optional[int] = None + is_active: Optional[bool] = None + + +@router.get("") +async def list_storage( + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """List user's storage providers. HTML for browsers, JSON for API.""" + storages = await storage_service.list_storages(ctx.actor_id) + + if wants_json(request): + return {"storages": storages} + + # Render HTML template + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "storage/list.html", request, + storages=storages, + user=ctx, + nav_counts=nav_counts, + providers_info=STORAGE_PROVIDERS_INFO, + active_tab="storage", + ) + + +@router.post("") +async def add_storage( + req: AddStorageRequest, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Add a storage provider via API.""" + ctx = await require_auth(request) + + storage_id, error = await storage_service.add_storage( + actor_id=ctx.actor_id, + provider_type=req.provider_type, + config=req.config, + capacity_gb=req.capacity_gb, + provider_name=req.provider_name, + ) + + if error: + raise HTTPException(400, error) + + return {"id": storage_id, "message": "Storage provider added"} + + +@router.post("/add") +async def add_storage_form( + request: Request, + provider_type: str = Form(...), + provider_name: Optional[str] = Form(None), + description: Optional[str] = Form(None), + capacity_gb: int = Form(5), + api_key: Optional[str] = Form(None), + secret_key: Optional[str] = Form(None), + api_token: Optional[str] = Form(None), + project_id: Optional[str] = Form(None), + project_secret: Optional[str] = Form(None), + access_key: Optional[str] = Form(None), + bucket: Optional[str] = Form(None), + path: Optional[str] = Form(None), + storage_service: StorageService = Depends(get_storage_service), +): + """Add a storage provider via HTML form.""" + ctx = await get_current_user(request) + if not ctx: + return HTMLResponse('
Not authenticated
', status_code=401) + + # Build config from form + form_data = { + "api_key": api_key, + "secret_key": secret_key, + "api_token": api_token, + "project_id": project_id, + "project_secret": project_secret, + "access_key": access_key, + "bucket": bucket, + "path": path, + } + config, error = storage_service.build_config_from_form(provider_type, form_data) + + if error: + return HTMLResponse(f'
{error}
') + + storage_id, error = await storage_service.add_storage( + actor_id=ctx.actor_id, + provider_type=provider_type, + config=config, + capacity_gb=capacity_gb, + provider_name=provider_name, + description=description, + ) + + if error: + return HTMLResponse(f'
{error}
') + + return HTMLResponse(f''' +
Storage provider added successfully!
+ + ''') + + +@router.get("/{storage_id}") +async def get_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Get a specific storage provider.""" + ctx = await require_auth(request) + + storage = await storage_service.get_storage(storage_id, ctx.actor_id) + if not storage: + raise HTTPException(404, "Storage provider not found") + + return storage + + +@router.patch("/{storage_id}") +async def update_storage( + storage_id: int, + req: UpdateStorageRequest, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Update a storage provider.""" + ctx = await require_auth(request) + + success, error = await storage_service.update_storage( + storage_id=storage_id, + actor_id=ctx.actor_id, + config=req.config, + capacity_gb=req.capacity_gb, + is_active=req.is_active, + ) + + if error: + raise HTTPException(400, error) + + return {"message": "Storage provider updated"} + + +@router.delete("/{storage_id}") +async def delete_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """Remove a storage provider.""" + success, error = await storage_service.delete_storage(storage_id, ctx.actor_id) + + if error: + raise HTTPException(400, error) + + if wants_html(request): + return HTMLResponse("") + + return {"message": "Storage provider removed"} + + +@router.post("/{storage_id}/test") +async def test_storage( + storage_id: int, + request: Request, + storage_service: StorageService = Depends(get_storage_service), +): + """Test storage provider connectivity.""" + ctx = await get_current_user(request) + if not ctx: + if wants_html(request): + return HTMLResponse('Not authenticated', status_code=401) + raise HTTPException(401, "Not authenticated") + + success, message = await storage_service.test_storage(storage_id, ctx.actor_id) + + if wants_html(request): + color = "green" if success else "red" + return HTMLResponse(f'{message}') + + return {"success": success, "message": message} + + +@router.get("/type/{provider_type}") +async def storage_type_page( + provider_type: str, + request: Request, + storage_service: StorageService = Depends(get_storage_service), + ctx: UserContext = Depends(require_auth), +): + """Page for managing storage configs of a specific type.""" + if provider_type not in STORAGE_PROVIDERS_INFO: + raise HTTPException(404, "Invalid provider type") + + storages = await storage_service.list_by_type(ctx.actor_id, provider_type) + provider_info = STORAGE_PROVIDERS_INFO[provider_type] + + if wants_json(request): + return { + "provider_type": provider_type, + "provider_info": provider_info, + "storages": storages, + } + + from ..dependencies import get_nav_counts + nav_counts = await get_nav_counts(ctx.actor_id) + + templates = get_templates(request) + return render(templates, "storage/type.html", request, + provider_type=provider_type, + provider_info=provider_info, + storages=storages, + user=ctx, + nav_counts=nav_counts, + active_tab="storage", + ) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..76eba24 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,15 @@ +""" +L1 Server Services. + +Business logic layer between routers and repositories. +""" + +from .run_service import RunService +from .recipe_service import RecipeService +from .cache_service import CacheService + +__all__ = [ + "RunService", + "RecipeService", + "CacheService", +] diff --git a/app/services/auth_service.py b/app/services/auth_service.py new file mode 100644 index 0000000..3f3ce26 --- /dev/null +++ b/app/services/auth_service.py @@ -0,0 +1,138 @@ +""" +Auth Service - token management and user verification. +""" + +import hashlib +import base64 +import json +from typing import Optional, Dict, Any, TYPE_CHECKING + +import httpx + +from artdag_common.middleware.auth import UserContext +from ..config import settings + +if TYPE_CHECKING: + import redis + from starlette.requests import Request + + +# Token expiry (30 days to match token lifetime) +TOKEN_EXPIRY_SECONDS = 60 * 60 * 24 * 30 + +# Redis key prefixes +REVOKED_KEY_PREFIX = "artdag:revoked:" +USER_TOKENS_PREFIX = "artdag:user_tokens:" + + +class AuthService: + """Service for authentication and token management.""" + + def __init__(self, redis_client: "redis.Redis[bytes]") -> None: + self.redis = redis_client + + def register_user_token(self, username: str, token: str) -> None: + """Track a token for a user (for later revocation by username).""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{USER_TOKENS_PREFIX}{username}" + self.redis.sadd(key, token_hash) + self.redis.expire(key, TOKEN_EXPIRY_SECONDS) + + def revoke_token(self, token: str) -> bool: + """Add token to revocation set. Returns True if newly revoked.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) + return result is not None + + def revoke_token_hash(self, token_hash: str) -> bool: + """Add token hash to revocation set. Returns True if newly revoked.""" + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + result = self.redis.set(key, "1", ex=TOKEN_EXPIRY_SECONDS, nx=True) + return result is not None + + def revoke_all_user_tokens(self, username: str) -> int: + """Revoke all tokens for a user. Returns count revoked.""" + key = f"{USER_TOKENS_PREFIX}{username}" + token_hashes = self.redis.smembers(key) + count = 0 + for token_hash in token_hashes: + if self.revoke_token_hash( + token_hash.decode() if isinstance(token_hash, bytes) else token_hash + ): + count += 1 + self.redis.delete(key) + return count + + def is_token_revoked(self, token: str) -> bool: + """Check if token has been revoked.""" + token_hash = hashlib.sha256(token.encode()).hexdigest() + key = f"{REVOKED_KEY_PREFIX}{token_hash}" + return self.redis.exists(key) > 0 + + def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]: + """Decode JWT claims without verification.""" + try: + parts = token.split(".") + if len(parts) != 3: + return None + payload = parts[1] + # Add padding + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + return json.loads(base64.urlsafe_b64decode(payload)) + except (json.JSONDecodeError, ValueError): + return None + + def get_user_context_from_token(self, token: str) -> Optional[UserContext]: + """Extract user context from a token.""" + if self.is_token_revoked(token): + return None + + claims = self.decode_token_claims(token) + if not claims: + return None + + username = claims.get("username") or claims.get("sub") + actor_id = claims.get("actor_id") or claims.get("actor") + + if not username: + return None + + return UserContext( + username=username, + actor_id=actor_id or f"@{username}", + token=token, + l2_server=settings.l2_server, + ) + + async def verify_token_with_l2(self, token: str) -> Optional[UserContext]: + """Verify token with L2 server.""" + ctx = self.get_user_context_from_token(token) + if not ctx: + return None + + # If L2 server configured, verify token + if settings.l2_server: + try: + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{settings.l2_server}/auth/verify", + headers={"Authorization": f"Bearer {token}"}, + timeout=5.0, + ) + if resp.status_code != 200: + return None + except httpx.RequestError: + # L2 unavailable, trust the token + pass + + return ctx + + def get_user_from_cookie(self, request: "Request") -> Optional[UserContext]: + """Extract user context from auth cookie.""" + token = request.cookies.get("auth_token") + if not token: + return None + return self.get_user_context_from_token(token) diff --git a/app/services/cache_service.py b/app/services/cache_service.py new file mode 100644 index 0000000..9b7bcd8 --- /dev/null +++ b/app/services/cache_service.py @@ -0,0 +1,618 @@ +""" +Cache Service - business logic for cache and media management. +""" + +import asyncio +import json +import logging +import os +import subprocess +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +import httpx + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from database import Database + from cache_manager import L1CacheManager + + +def detect_media_type(cache_path: Path) -> str: + """Detect if file is image, video, or audio based on magic bytes.""" + try: + with open(cache_path, "rb") as f: + header = f.read(32) + except Exception: + return "unknown" + + # Video signatures + if header[:4] == b'\x1a\x45\xdf\xa3': # WebM/MKV + return "video" + if len(header) > 8 and header[4:8] == b'ftyp': # MP4/MOV + return "video" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'AVI ': # AVI + return "video" + + # Image signatures + if header[:8] == b'\x89PNG\r\n\x1a\n': # PNG + return "image" + if header[:2] == b'\xff\xd8': # JPEG + return "image" + if header[:6] in (b'GIF87a', b'GIF89a'): # GIF + return "image" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WEBP': # WebP + return "image" + + # Audio signatures + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WAVE': # WAV + return "audio" + if header[:3] == b'ID3' or header[:2] == b'\xff\xfb': # MP3 + return "audio" + if header[:4] == b'fLaC': # FLAC + return "audio" + + return "unknown" + + +def get_mime_type(path: Path) -> str: + """Get MIME type based on file magic bytes.""" + media_type = detect_media_type(path) + if media_type == "video": + try: + with open(path, "rb") as f: + header = f.read(12) + if header[:4] == b'\x1a\x45\xdf\xa3': + return "video/x-matroska" + return "video/mp4" + except Exception: + return "video/mp4" + elif media_type == "image": + try: + with open(path, "rb") as f: + header = f.read(8) + if header[:8] == b'\x89PNG\r\n\x1a\n': + return "image/png" + if header[:2] == b'\xff\xd8': + return "image/jpeg" + if header[:6] in (b'GIF87a', b'GIF89a'): + return "image/gif" + return "image/jpeg" + except Exception: + return "image/jpeg" + elif media_type == "audio": + return "audio/mpeg" + return "application/octet-stream" + + +class CacheService: + """ + Service for managing cached content. + + Handles content retrieval, metadata, and media type detection. + """ + + def __init__(self, database: "Database", cache_manager: "L1CacheManager") -> None: + self.db = database + self.cache = cache_manager + self.cache_dir = Path(os.environ.get("CACHE_DIR", "/tmp/artdag-cache")) + + async def get_cache_item(self, cid: str, actor_id: str = None) -> Optional[Dict[str, Any]]: + """Get cached item with full metadata for display.""" + # Get metadata from database first + meta = await self.db.load_item_metadata(cid, actor_id) + cache_item = await self.db.get_cache_item(cid) + + # Check if content exists locally + path = self.cache.get_by_cid(cid) if self.cache.has_content(cid) else None + + if path and path.exists(): + # Local file exists - detect type from file + media_type = detect_media_type(path) + mime_type = get_mime_type(path) + size = path.stat().st_size + else: + # File not local - check database for type info + # Try to get type from item_types table + media_type = "unknown" + mime_type = "application/octet-stream" + size = 0 + + if actor_id: + try: + item_types = await self.db.get_item_types(cid, actor_id) + if item_types: + media_type = item_types[0].get("type", "unknown") + if media_type == "video": + mime_type = "video/mp4" + elif media_type == "image": + mime_type = "image/png" + elif media_type == "audio": + mime_type = "audio/mpeg" + except Exception: + pass + + # If no local path but we have IPFS CID, content is available remotely + if not cache_item: + return None + + result = { + "cid": cid, + "path": str(path) if path else None, + "media_type": media_type, + "mime_type": mime_type, + "size": size, + "ipfs_cid": cache_item.get("ipfs_cid") if cache_item else None, + "meta": meta, + "remote_only": path is None or not path.exists(), + } + + # Unpack meta fields to top level for template convenience + if meta: + result["title"] = meta.get("title") + result["description"] = meta.get("description") + result["tags"] = meta.get("tags", []) + result["source_type"] = meta.get("source_type") + result["source_note"] = meta.get("source_note") + result["created_at"] = meta.get("created_at") + result["filename"] = meta.get("filename") + + # Get friendly name if actor_id provided + if actor_id: + from .naming_service import get_naming_service + naming = get_naming_service() + friendly = await naming.get_by_cid(actor_id, cid) + if friendly: + result["friendly_name"] = friendly["friendly_name"] + result["base_name"] = friendly["base_name"] + result["version_id"] = friendly["version_id"] + + return result + + async def check_access(self, cid: str, actor_id: str, username: str) -> bool: + """Check if user has access to content.""" + user_hashes = await self._get_user_cache_hashes(username, actor_id) + return cid in user_hashes + + async def _get_user_cache_hashes(self, username: str, actor_id: Optional[str] = None) -> set: + """Get all cache hashes owned by or associated with a user.""" + match_values = [username] + if actor_id: + match_values.append(actor_id) + + hashes = set() + + # Query database for items owned by user + if actor_id: + try: + db_items = await self.db.get_user_items(actor_id) + for item in db_items: + hashes.add(item["cid"]) + except Exception: + pass + + # Legacy: Files uploaded by user (JSON metadata) + if self.cache_dir.exists(): + for f in self.cache_dir.iterdir(): + if f.name.endswith('.meta.json'): + try: + with open(f, 'r') as mf: + meta = json.load(mf) + if meta.get("uploader") in match_values: + hashes.add(f.name.replace('.meta.json', '')) + except Exception: + pass + + # Files from user's runs (inputs and outputs) + runs = await self._list_user_runs(username, actor_id) + for run in runs: + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + hashes.update(inputs) + if run.get("output_cid"): + hashes.add(run["output_cid"]) + + return hashes + + async def _list_user_runs(self, username: str, actor_id: Optional[str]) -> List[Dict]: + """List runs for a user (helper for access check).""" + from ..dependencies import get_redis_client + import json + + redis = get_redis_client() + runs = [] + cursor = 0 + prefix = "artdag:run:" + + while True: + cursor, keys = redis.scan(cursor=cursor, match=f"{prefix}*", count=100) + for key in keys: + data = redis.get(key) + if data: + run = json.loads(data) + if run.get("actor_id") in (username, actor_id) or run.get("username") in (username, actor_id): + runs.append(run) + if cursor == 0: + break + + return runs + + async def get_raw_file(self, cid: str) -> Tuple[Optional[Path], Optional[str], Optional[str]]: + """Get raw file path, media type, and filename for download.""" + if not self.cache.has_content(cid): + return None, None, None + + path = self.cache.get_by_cid(cid) + if not path or not path.exists(): + return None, None, None + + media_type = detect_media_type(path) + mime = get_mime_type(path) + + # Determine extension + ext = "bin" + if media_type == "video": + try: + with open(path, "rb") as f: + header = f.read(12) + if header[:4] == b'\x1a\x45\xdf\xa3': + ext = "mkv" + else: + ext = "mp4" + except Exception: + ext = "mp4" + elif media_type == "image": + try: + with open(path, "rb") as f: + header = f.read(8) + if header[:8] == b'\x89PNG\r\n\x1a\n': + ext = "png" + else: + ext = "jpg" + except Exception: + ext = "jpg" + + filename = f"{cid}.{ext}" + return path, mime, filename + + async def get_as_mp4(self, cid: str) -> Tuple[Optional[Path], Optional[str]]: + """Get content as MP4, transcoding if necessary. Returns (path, error).""" + if not self.cache.has_content(cid): + return None, f"Content {cid} not in cache" + + path = self.cache.get_by_cid(cid) + if not path or not path.exists(): + return None, f"Content {cid} not in cache" + + # Check if video + media_type = detect_media_type(path) + if media_type != "video": + return None, "Content is not a video" + + # Check for cached MP4 + mp4_path = self.cache_dir / f"{cid}.mp4" + if mp4_path.exists(): + return mp4_path, None + + # Check if already MP4 format + try: + result = subprocess.run( + ["ffprobe", "-v", "error", "-select_streams", "v:0", + "-show_entries", "format=format_name", "-of", "csv=p=0", str(path)], + capture_output=True, text=True, timeout=10 + ) + if "mp4" in result.stdout.lower() or "mov" in result.stdout.lower(): + return path, None + except Exception: + pass + + # Transcode to MP4 + transcode_path = self.cache_dir / f"{cid}.transcoding.mp4" + try: + result = subprocess.run( + ["ffmpeg", "-y", "-i", str(path), + "-c:v", "libx264", "-preset", "fast", "-crf", "23", + "-c:a", "aac", "-b:a", "128k", + "-movflags", "+faststart", + str(transcode_path)], + capture_output=True, text=True, timeout=600 + ) + if result.returncode != 0: + return None, f"Transcoding failed: {result.stderr[:200]}" + + transcode_path.rename(mp4_path) + return mp4_path, None + + except subprocess.TimeoutExpired: + if transcode_path.exists(): + transcode_path.unlink() + return None, "Transcoding timed out" + except Exception as e: + if transcode_path.exists(): + transcode_path.unlink() + return None, f"Transcoding failed: {e}" + + async def get_metadata(self, cid: str, actor_id: str) -> Optional[Dict[str, Any]]: + """Get content metadata.""" + if not self.cache.has_content(cid): + return None + return await self.db.load_item_metadata(cid, actor_id) + + async def update_metadata( + self, + cid: str, + actor_id: str, + title: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + custom: Optional[Dict[str, Any]] = None, + ) -> Tuple[bool, Optional[str]]: + """Update content metadata. Returns (success, error).""" + if not self.cache.has_content(cid): + return False, "Content not found" + + # Build update dict + updates = {} + if title is not None: + updates["title"] = title + if description is not None: + updates["description"] = description + if tags is not None: + updates["tags"] = tags + if custom is not None: + updates["custom"] = custom + + try: + await self.db.update_item_metadata(cid, actor_id, **updates) + return True, None + except Exception as e: + return False, str(e) + + async def publish_to_l2( + self, + cid: str, + actor_id: str, + l2_server: str, + auth_token: str, + ) -> Tuple[Optional[str], Optional[str]]: + """Publish content to L2 and IPFS. Returns (ipfs_cid, error).""" + if not self.cache.has_content(cid): + return None, "Content not found" + + # Get IPFS CID + cache_item = await self.db.get_cache_item(cid) + ipfs_cid = cache_item.get("ipfs_cid") if cache_item else None + + # Get metadata for origin info + meta = await self.db.load_item_metadata(cid, actor_id) + origin = meta.get("origin") if meta else None + + if not origin or "type" not in origin: + return None, "Origin must be set before publishing" + + if not auth_token: + return None, "Authentication token required" + + # Call L2 publish-cache endpoint + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{l2_server}/assets/publish-cache", + headers={"Authorization": f"Bearer {auth_token}"}, + json={ + "cid": cid, + "ipfs_cid": ipfs_cid, + "asset_name": meta.get("title") or cid[:16], + "asset_type": detect_media_type(self.cache.get_by_cid(cid)), + "origin": origin, + "description": meta.get("description"), + "tags": meta.get("tags", []), + } + ) + resp.raise_for_status() + l2_result = resp.json() + except httpx.HTTPStatusError as e: + error_detail = str(e) + try: + error_detail = e.response.json().get("detail", str(e)) + except Exception: + pass + return None, f"L2 publish failed: {error_detail}" + except Exception as e: + return None, f"L2 publish failed: {e}" + + # Update local metadata with publish status + await self.db.save_l2_share( + cid=cid, + actor_id=actor_id, + l2_server=l2_server, + asset_name=meta.get("title") or cid[:16], + content_type=detect_media_type(self.cache.get_by_cid(cid)) + ) + await self.db.update_item_metadata( + cid=cid, + actor_id=actor_id, + pinned=True, + pin_reason="published" + ) + + return l2_result.get("ipfs_cid") or ipfs_cid, None + + async def delete_content(self, cid: str, actor_id: str) -> Tuple[bool, Optional[str]]: + """ + Remove user's ownership link to cached content. + + This removes the item_types entry linking the user to the content. + The cached file is only deleted if no other users own it. + Returns (success, error). + """ + import logging + logger = logging.getLogger(__name__) + + # Check if pinned for this user + meta = await self.db.load_item_metadata(cid, actor_id) + if meta and meta.get("pinned"): + pin_reason = meta.get("pin_reason", "unknown") + return False, f"Cannot discard pinned item (reason: {pin_reason})" + + # Get the item type to delete the right ownership entry + item_types = await self.db.get_item_types(cid, actor_id) + if not item_types: + return False, "You don't own this content" + + # Remove user's ownership links (all types for this user) + for item in item_types: + item_type = item.get("type", "media") + await self.db.delete_item_type(cid, actor_id, item_type) + + # Remove friendly name + await self.db.delete_friendly_name(actor_id, cid) + + # Check if anyone else still owns this content + remaining_owners = await self.db.get_item_types(cid) + + # Only delete the actual file if no one owns it anymore + if not remaining_owners: + # Check deletion rules via cache_manager + can_delete, reason = self.cache.can_delete(cid) + if can_delete: + # Delete via cache_manager + self.cache.delete_by_cid(cid) + + # Clean up legacy metadata files + meta_path = self.cache_dir / f"{cid}.meta.json" + if meta_path.exists(): + meta_path.unlink() + mp4_path = self.cache_dir / f"{cid}.mp4" + if mp4_path.exists(): + mp4_path.unlink() + + # Delete from database + await self.db.delete_cache_item(cid) + + logger.info(f"Garbage collected content {cid[:16]}... (no remaining owners)") + else: + logger.info(f"Content {cid[:16]}... orphaned but cannot delete: {reason}") + + logger.info(f"Removed content {cid[:16]}... ownership for {actor_id}") + return True, None + + async def import_from_ipfs(self, ipfs_cid: str, actor_id: str) -> Tuple[Optional[str], Optional[str]]: + """Import content from IPFS. Returns (cid, error).""" + try: + import ipfs_client + + # Download from IPFS + legacy_dir = self.cache_dir / "legacy" + legacy_dir.mkdir(parents=True, exist_ok=True) + tmp_path = legacy_dir / f"import-{ipfs_cid[:16]}" + + if not ipfs_client.get_file(ipfs_cid, str(tmp_path)): + return None, f"Could not fetch CID {ipfs_cid} from IPFS" + + # Detect media type before storing + media_type = detect_media_type(tmp_path) + + # Store in cache + cached, new_ipfs_cid = self.cache.put(tmp_path, node_type="import", move=True) + cid = new_ipfs_cid or cached.cid # Prefer IPFS CID + + # Save to database with detected media type + await self.db.create_cache_item(cid, new_ipfs_cid) + await self.db.save_item_metadata( + cid=cid, + actor_id=actor_id, + item_type=media_type, # Use detected type for filtering + filename=f"ipfs-{ipfs_cid[:16]}" + ) + + return cid, None + except Exception as e: + return None, f"Import failed: {e}" + + async def upload_content( + self, + content: bytes, + filename: str, + actor_id: str, + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Upload content to cache. Returns (cid, ipfs_cid, error). + + Files are stored locally first for fast response, then uploaded + to IPFS in the background. + """ + import tempfile + + try: + # Write to temp file + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + + # Detect media type (video/image/audio) before moving file + media_type = detect_media_type(tmp_path) + + # Store locally AND upload to IPFS synchronously + # This ensures the IPFS CID is available immediately for distributed access + cached, ipfs_cid = self.cache.put(tmp_path, node_type="upload", move=True, skip_ipfs=False) + cid = ipfs_cid or cached.cid # Prefer IPFS CID, fall back to local hash + + # Save to database with media category type + await self.db.create_cache_item(cached.cid, ipfs_cid) + await self.db.save_item_metadata( + cid=cid, + actor_id=actor_id, + item_type=media_type, + filename=filename + ) + + if ipfs_cid: + logger.info(f"Uploaded to IPFS: {ipfs_cid[:16]}...") + else: + logger.warning(f"IPFS upload failed, using local hash: {cid[:16]}...") + + return cid, ipfs_cid, None + except Exception as e: + return None, None, f"Upload failed: {e}" + + async def list_media( + self, + actor_id: Optional[str] = None, + username: Optional[str] = None, + offset: int = 0, + limit: int = 24, + media_type: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """List media items in cache.""" + # Get items from database (uses item_types table) + items = await self.db.get_user_items( + actor_id=actor_id or username, + item_type=media_type, # "video", "image", "audio", or None for all + limit=limit, + offset=offset, + ) + + # Add friendly names to items + if actor_id: + from .naming_service import get_naming_service + naming = get_naming_service() + for item in items: + cid = item.get("cid") + if cid: + friendly = await naming.get_by_cid(actor_id, cid) + if friendly: + item["friendly_name"] = friendly["friendly_name"] + item["base_name"] = friendly["base_name"] + + return items + + # Legacy compatibility methods + def has_content(self, cid: str) -> bool: + """Check if content exists in cache.""" + return self.cache.has_content(cid) + + def get_ipfs_cid(self, cid: str) -> Optional[str]: + """Get IPFS CID for cached content.""" + return self.cache.get_ipfs_cid(cid) diff --git a/app/services/naming_service.py b/app/services/naming_service.py new file mode 100644 index 0000000..5678ab2 --- /dev/null +++ b/app/services/naming_service.py @@ -0,0 +1,234 @@ +""" +Naming service for friendly names. + +Handles: +- Name normalization (My Cool Effect -> my-cool-effect) +- Version ID generation (server-signed timestamps) +- Friendly name assignment and resolution +""" + +import hmac +import os +import re +import time +from typing import Optional, Tuple + +import database + + +# Base32 Crockford alphabet (excludes I, L, O, U to avoid confusion) +CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" + + +def _get_server_secret() -> bytes: + """Get server secret for signing version IDs.""" + secret = os.environ.get("SERVER_SECRET", "") + if not secret: + # Fall back to a derived secret from other env vars + # In production, SERVER_SECRET should be set explicitly + secret = os.environ.get("SECRET_KEY", "default-dev-secret") + return secret.encode("utf-8") + + +def _base32_crockford_encode(data: bytes) -> str: + """Encode bytes as base32-crockford (lowercase).""" + # Convert bytes to integer + num = int.from_bytes(data, "big") + if num == 0: + return CROCKFORD_ALPHABET[0] + + result = [] + while num > 0: + result.append(CROCKFORD_ALPHABET[num % 32]) + num //= 32 + + return "".join(reversed(result)) + + +def generate_version_id() -> str: + """ + Generate a version ID that is: + - Always increasing (timestamp-based prefix) + - Verifiable as originating from this server (HMAC suffix) + - Short and URL-safe (13 chars) + + Format: 6 bytes timestamp (ms) + 2 bytes HMAC = 8 bytes = 13 base32 chars + """ + timestamp_ms = int(time.time() * 1000) + timestamp_bytes = timestamp_ms.to_bytes(6, "big") + + # HMAC the timestamp with server secret + secret = _get_server_secret() + sig = hmac.new(secret, timestamp_bytes, "sha256").digest() + + # Combine: 6 bytes timestamp + 2 bytes HMAC signature + combined = timestamp_bytes + sig[:2] + + # Encode as base32-crockford + return _base32_crockford_encode(combined) + + +def normalize_name(name: str) -> str: + """ + Normalize a display name to a base name. + + - Lowercase + - Replace spaces and underscores with dashes + - Remove special characters (keep alphanumeric and dashes) + - Collapse multiple dashes + - Strip leading/trailing dashes + + Examples: + "My Cool Effect" -> "my-cool-effect" + "Brightness_V2" -> "brightness-v2" + "Test!!!Effect" -> "test-effect" + """ + # Lowercase + name = name.lower() + + # Replace spaces and underscores with dashes + name = re.sub(r"[\s_]+", "-", name) + + # Remove anything that's not alphanumeric or dash + name = re.sub(r"[^a-z0-9-]", "", name) + + # Collapse multiple dashes + name = re.sub(r"-+", "-", name) + + # Strip leading/trailing dashes + name = name.strip("-") + + return name or "unnamed" + + +def parse_friendly_name(friendly_name: str) -> Tuple[str, Optional[str]]: + """ + Parse a friendly name into base name and optional version. + + Args: + friendly_name: Name like "my-effect" or "my-effect 01hw3x9k" + + Returns: + Tuple of (base_name, version_id or None) + """ + parts = friendly_name.strip().split(" ", 1) + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + return base_name, version_id + + +def format_friendly_name(base_name: str, version_id: str) -> str: + """Format a base name and version into a full friendly name.""" + return f"{base_name} {version_id}" + + +def format_l2_name(actor_id: str, base_name: str, version_id: str) -> str: + """ + Format a friendly name for L2 sharing. + + Format: @user@domain base-name version-id + """ + return f"{actor_id} {base_name} {version_id}" + + +class NamingService: + """Service for managing friendly names.""" + + async def assign_name( + self, + cid: str, + actor_id: str, + item_type: str, + display_name: Optional[str] = None, + filename: Optional[str] = None, + ) -> dict: + """ + Assign a friendly name to content. + + Args: + cid: Content ID + actor_id: User ID + item_type: Type (recipe, effect, media) + display_name: Human-readable name (optional) + filename: Original filename (used as fallback for media) + + Returns: + Friendly name entry dict + """ + # Determine display name + if not display_name: + if filename: + # Use filename without extension + display_name = os.path.splitext(filename)[0] + else: + display_name = f"unnamed-{item_type}" + + # Normalize to base name + base_name = normalize_name(display_name) + + # Generate version ID + version_id = generate_version_id() + + # Create database entry + entry = await database.create_friendly_name( + actor_id=actor_id, + base_name=base_name, + version_id=version_id, + cid=cid, + item_type=item_type, + display_name=display_name, + ) + + return entry + + async def get_by_cid(self, actor_id: str, cid: str) -> Optional[dict]: + """Get friendly name entry by CID.""" + return await database.get_friendly_name_by_cid(actor_id, cid) + + async def resolve( + self, + actor_id: str, + name: str, + item_type: Optional[str] = None, + ) -> Optional[str]: + """ + Resolve a friendly name to a CID. + + Args: + actor_id: User ID + name: Friendly name ("base-name" or "base-name version") + item_type: Optional type filter + + Returns: + CID or None if not found + """ + return await database.resolve_friendly_name(actor_id, name, item_type) + + async def list_names( + self, + actor_id: str, + item_type: Optional[str] = None, + latest_only: bool = False, + ) -> list: + """List friendly names for a user.""" + return await database.list_friendly_names( + actor_id=actor_id, + item_type=item_type, + latest_only=latest_only, + ) + + async def delete(self, actor_id: str, cid: str) -> bool: + """Delete a friendly name entry.""" + return await database.delete_friendly_name(actor_id, cid) + + +# Module-level instance +_naming_service: Optional[NamingService] = None + + +def get_naming_service() -> NamingService: + """Get the naming service singleton.""" + global _naming_service + if _naming_service is None: + _naming_service = NamingService() + return _naming_service diff --git a/app/services/recipe_service.py b/app/services/recipe_service.py new file mode 100644 index 0000000..6b0a70d --- /dev/null +++ b/app/services/recipe_service.py @@ -0,0 +1,337 @@ +""" +Recipe Service - business logic for recipe management. + +Recipes are S-expressions stored in the content-addressed cache (and IPFS). +The recipe ID is the content hash of the file. +""" + +import tempfile +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +from artdag.sexp import compile_string, parse, serialize, CompileError, ParseError + +if TYPE_CHECKING: + import redis + from cache_manager import L1CacheManager + +from ..types import Recipe, CompiledDAG, VisualizationDAG, VisNode, VisEdge + + +class RecipeService: + """ + Service for managing recipes. + + Recipes are S-expressions stored in the content-addressed cache. + """ + + def __init__(self, redis: "redis.Redis", cache: "L1CacheManager") -> None: + # Redis kept for compatibility but not used for recipe storage + self.redis = redis + self.cache = cache + + async def get_recipe(self, recipe_id: str) -> Optional[Recipe]: + """Get a recipe by ID (content hash).""" + import yaml + import logging + logger = logging.getLogger(__name__) + + # Get from cache (content-addressed storage) + logger.info(f"get_recipe: Looking up recipe_id={recipe_id[:16]}...") + path = self.cache.get_by_cid(recipe_id) + logger.info(f"get_recipe: cache.get_by_cid returned path={path}") + if not path or not path.exists(): + logger.warning(f"get_recipe: Recipe {recipe_id[:16]}... not found in cache") + return None + + with open(path) as f: + content = f.read() + + # Detect format - check if it starts with ( after skipping comments + def is_sexp_format(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(') + return False + + import logging + logger = logging.getLogger(__name__) + + if is_sexp_format(content): + # Detect if this is a streaming recipe (starts with (stream ...)) + def is_streaming_recipe(text): + for line in text.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(stream') + return False + + if is_streaming_recipe(content): + # Streaming recipes have different format - parse manually + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', content) + recipe_name = name_match.group(1) if name_match else "streaming" + + recipe_data = { + "name": recipe_name, + "sexp": content, + "format": "sexp", + "type": "streaming", + "dag": {"nodes": []}, # Streaming recipes don't have traditional DAG + } + logger.info(f"Parsed streaming recipe {recipe_id[:16]}..., name: {recipe_name}") + else: + # Parse traditional (recipe ...) S-expression + try: + compiled = compile_string(content) + recipe_data = compiled.to_dict() + recipe_data["sexp"] = content + recipe_data["format"] = "sexp" + logger.info(f"Parsed sexp recipe {recipe_id[:16]}..., keys: {list(recipe_data.keys())}") + except (ParseError, CompileError) as e: + logger.warning(f"Failed to parse sexp recipe {recipe_id[:16]}...: {e}") + return {"error": str(e), "recipe_id": recipe_id} + else: + # Parse YAML + try: + recipe_data = yaml.safe_load(content) + if not isinstance(recipe_data, dict): + return {"error": "Invalid YAML: expected dictionary", "recipe_id": recipe_id} + recipe_data["yaml"] = content + recipe_data["format"] = "yaml" + except yaml.YAMLError as e: + return {"error": f"YAML parse error: {e}", "recipe_id": recipe_id} + + # Add the recipe_id to the data for convenience + recipe_data["recipe_id"] = recipe_id + + # Get IPFS CID if available + ipfs_cid = self.cache.get_ipfs_cid(recipe_id) + if ipfs_cid: + recipe_data["ipfs_cid"] = ipfs_cid + + # Compute step_count from nodes (handle both formats) + if recipe_data.get("format") == "sexp": + nodes = recipe_data.get("dag", {}).get("nodes", []) + else: + # YAML format: nodes might be at top level or under dag + nodes = recipe_data.get("nodes", recipe_data.get("dag", {}).get("nodes", [])) + recipe_data["step_count"] = len(nodes) if isinstance(nodes, (list, dict)) else 0 + + return recipe_data + + async def list_recipes(self, actor_id: Optional[str] = None, offset: int = 0, limit: int = 20) -> List[Recipe]: + """ + List recipes owned by a user. + + Queries item_types table for user's recipe links. + """ + import logging + import database + logger = logging.getLogger(__name__) + + recipes = [] + + if not actor_id: + logger.warning("list_recipes called without actor_id") + return [] + + # Get user's recipe CIDs from item_types + user_items = await database.get_user_items(actor_id, item_type="recipe", limit=1000) + recipe_cids = [item["cid"] for item in user_items] + logger.info(f"Found {len(recipe_cids)} recipe CIDs for user {actor_id}") + + for cid in recipe_cids: + recipe = await self.get_recipe(cid) + if recipe and not recipe.get("error"): + recipes.append(recipe) + elif recipe and recipe.get("error"): + logger.warning(f"Recipe {cid[:16]}... has error: {recipe.get('error')}") + + # Add friendly names + from .naming_service import get_naming_service + naming = get_naming_service() + for recipe in recipes: + recipe_id = recipe.get("recipe_id") + if recipe_id: + friendly = await naming.get_by_cid(actor_id, recipe_id) + if friendly: + recipe["friendly_name"] = friendly["friendly_name"] + recipe["base_name"] = friendly["base_name"] + + # Sort by name + recipes.sort(key=lambda r: r.get("name", "")) + + return recipes[offset:offset + limit] + + async def upload_recipe( + self, + content: str, + uploader: str, + name: str = None, + description: str = None, + ) -> Tuple[Optional[str], Optional[str]]: + """ + Upload a recipe from S-expression content. + + The recipe is stored in the cache and pinned to IPFS. + Returns (recipe_id, error_message). + """ + # Validate S-expression + try: + compiled = compile_string(content) + except ParseError as e: + return None, f"Parse error: {e}" + except CompileError as e: + return None, f"Compile error: {e}" + + # Write to temp file for caching + import logging + logger = logging.getLogger(__name__) + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".sexp", mode="w") as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + + # Store in cache (content-addressed, auto-pins to IPFS) + logger.info(f"upload_recipe: Storing recipe in cache from {tmp_path}") + cached, ipfs_cid = self.cache.put(tmp_path, node_type="recipe", move=True) + recipe_id = ipfs_cid or cached.cid # Prefer IPFS CID + logger.info(f"upload_recipe: Stored recipe, cached.cid={cached.cid[:16]}..., ipfs_cid={ipfs_cid[:16] if ipfs_cid else None}, recipe_id={recipe_id[:16]}...") + + # Track ownership in item_types and assign friendly name + if uploader: + import database + display_name = name or compiled.name or "unnamed-recipe" + + # Create item_types entry (ownership link) + await database.save_item_metadata( + cid=recipe_id, + actor_id=uploader, + item_type="recipe", + description=description, + filename=f"{display_name}.sexp", + ) + + # Assign friendly name + from .naming_service import get_naming_service + naming = get_naming_service() + await naming.assign_name( + cid=recipe_id, + actor_id=uploader, + item_type="recipe", + display_name=display_name, + ) + + return recipe_id, None + + except Exception as e: + return None, f"Failed to cache recipe: {e}" + + async def delete_recipe(self, recipe_id: str, actor_id: str = None) -> Tuple[bool, Optional[str]]: + """ + Remove user's ownership link to a recipe. + + This removes the item_types entry linking the user to the recipe. + The cached file is only deleted if no other users own it. + Returns (success, error_message). + """ + import database + + if not actor_id: + return False, "actor_id required" + + # Remove user's ownership link + try: + await database.delete_item_type(recipe_id, actor_id, "recipe") + + # Also remove friendly name + await database.delete_friendly_name(actor_id, recipe_id) + + # Try to garbage collect if no one owns it anymore + # (delete_cache_item only deletes if no item_types remain) + await database.delete_cache_item(recipe_id) + + return True, None + except Exception as e: + return False, f"Failed to delete: {e}" + + def parse_recipe(self, content: str) -> CompiledDAG: + """Parse recipe S-expression content.""" + compiled = compile_string(content) + return compiled.to_dict() + + def build_dag(self, recipe: Recipe) -> VisualizationDAG: + """ + Build DAG visualization data from recipe. + + Returns nodes and edges for Cytoscape.js. + """ + vis_nodes: List[VisNode] = [] + edges: List[VisEdge] = [] + + dag = recipe.get("dag", {}) + dag_nodes = dag.get("nodes", []) + output_node = dag.get("output") + + # Handle list format (compiled S-expression) + if isinstance(dag_nodes, list): + for node_def in dag_nodes: + node_id = node_def.get("id") + node_type = node_def.get("type", "EFFECT") + + vis_nodes.append({ + "data": { + "id": node_id, + "label": node_id, + "nodeType": node_type, + "isOutput": node_id == output_node, + } + }) + + for input_ref in node_def.get("inputs", []): + if isinstance(input_ref, dict): + source = input_ref.get("node") or input_ref.get("input") + else: + source = input_ref + + if source: + edges.append({ + "data": { + "source": source, + "target": node_id, + } + }) + + # Handle dict format + elif isinstance(dag_nodes, dict): + for node_id, node_def in dag_nodes.items(): + node_type = node_def.get("type", "EFFECT") + + vis_nodes.append({ + "data": { + "id": node_id, + "label": node_id, + "nodeType": node_type, + "isOutput": node_id == output_node, + } + }) + + for input_ref in node_def.get("inputs", []): + if isinstance(input_ref, dict): + source = input_ref.get("node") or input_ref.get("input") + else: + source = input_ref + + if source: + edges.append({ + "data": { + "source": source, + "target": node_id, + } + }) + + return {"nodes": vis_nodes, "edges": edges} diff --git a/app/services/run_service.py b/app/services/run_service.py new file mode 100644 index 0000000..5bfe19d --- /dev/null +++ b/app/services/run_service.py @@ -0,0 +1,1001 @@ +""" +Run Service - business logic for run management. + +Runs are content-addressed (run_id computed from inputs + recipe). +Completed runs are stored in PostgreSQL, not Redis. +In-progress runs are tracked via Celery task state. +""" + +import hashlib +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional, List, Dict, Any, Tuple, Union, TYPE_CHECKING + +if TYPE_CHECKING: + import redis + from cache_manager import L1CacheManager + from database import Database + +from ..types import RunResult + + +def compute_run_id(input_hashes: Union[List[str], Dict[str, str]], recipe: str, recipe_hash: Optional[str] = None) -> str: + """ + Compute a deterministic run_id from inputs and recipe. + + The run_id is a SHA3-256 hash of: + - Sorted input content hashes + - Recipe identifier (recipe_hash if provided, else "effect:{recipe}") + + This makes runs content-addressable: same inputs + recipe = same run_id. + """ + # Handle both list and dict inputs + if isinstance(input_hashes, dict): + sorted_inputs = sorted(input_hashes.values()) + else: + sorted_inputs = sorted(input_hashes) + + data = { + "inputs": sorted_inputs, + "recipe": recipe_hash or f"effect:{recipe}", + "version": "1", + } + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha3_256(json_str.encode()).hexdigest() + + +def detect_media_type(cache_path: Path) -> str: + """Detect if file is image, video, or audio based on magic bytes.""" + try: + with open(cache_path, "rb") as f: + header = f.read(32) + except Exception: + return "unknown" + + # Video signatures + if header[:4] == b'\x1a\x45\xdf\xa3': # WebM/MKV + return "video" + if len(header) > 8 and header[4:8] == b'ftyp': # MP4/MOV/M4A + # Check for audio-only M4A + if len(header) > 11 and header[8:12] in (b'M4A ', b'm4a '): + return "audio" + return "video" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'AVI ': # AVI + return "video" + + # Image signatures + if header[:8] == b'\x89PNG\r\n\x1a\n': # PNG + return "image" + if header[:2] == b'\xff\xd8': # JPEG + return "image" + if header[:6] in (b'GIF87a', b'GIF89a'): # GIF + return "image" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WEBP': # WebP + return "image" + + # Audio signatures + if header[:3] == b'ID3' or header[:2] == b'\xff\xfb': # MP3 + return "audio" + if header[:4] == b'fLaC': # FLAC + return "audio" + if header[:4] == b'OggS': # Ogg (could be audio or video, assume audio) + return "audio" + if header[:4] == b'RIFF' and len(header) > 12 and header[8:12] == b'WAVE': # WAV + return "audio" + + return "unknown" + + +class RunService: + """ + Service for managing recipe runs. + + Uses PostgreSQL for completed runs, Celery for task state. + Redis is only used for task_id mapping (ephemeral). + """ + + def __init__(self, database: "Database", redis: "redis.Redis[bytes]", cache: "L1CacheManager") -> None: + self.db = database + self.redis = redis # Only for task_id mapping + self.cache = cache + self.task_key_prefix = "artdag:task:" # run_id -> task_id mapping only + self.cache_dir = Path(os.environ.get("CACHE_DIR", "/tmp/artdag-cache")) + + def _ensure_inputs_list(self, inputs: Any) -> List[str]: + """Ensure inputs is a list, parsing JSON string if needed.""" + if inputs is None: + return [] + if isinstance(inputs, list): + return inputs + if isinstance(inputs, str): + try: + parsed = json.loads(inputs) + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + pass + return [] + return [] + + async def get_run(self, run_id: str) -> Optional[RunResult]: + """Get a run by ID. Checks database first, then Celery task state.""" + # Check database for completed run + cached = await self.db.get_run_cache(run_id) + if cached: + output_cid = cached.get("output_cid") + # Only return as completed if we have an output + # (runs with no output should be re-executed) + if output_cid: + # Also fetch recipe content from pending_runs for streaming runs + recipe_sexp = None + recipe_name = None + pending = await self.db.get_pending_run(run_id) + if pending: + recipe_sexp = pending.get("dag_json") + + # Extract recipe name from streaming recipe content + if recipe_sexp: + import re + name_match = re.search(r'\(stream\s+"([^"]+)"', recipe_sexp) + if name_match: + recipe_name = name_match.group(1) + + return { + "run_id": run_id, + "status": "completed", + "recipe": cached.get("recipe"), + "recipe_name": recipe_name, + "inputs": self._ensure_inputs_list(cached.get("inputs")), + "output_cid": output_cid, + "ipfs_cid": cached.get("ipfs_cid"), + "ipfs_playlist_cid": cached.get("ipfs_playlist_cid") or (pending.get("ipfs_playlist_cid") if pending else None), + "provenance_cid": cached.get("provenance_cid"), + "plan_cid": cached.get("plan_cid"), + "actor_id": cached.get("actor_id"), + "created_at": cached.get("created_at"), + "completed_at": cached.get("created_at"), + "recipe_sexp": recipe_sexp, + } + + # Check database for pending run + pending = await self.db.get_pending_run(run_id) + if pending: + task_id = pending.get("celery_task_id") + if task_id: + # Check actual Celery task state + from celery.result import AsyncResult + from celery_app import app as celery_app + + result = AsyncResult(task_id, app=celery_app) + status = result.status.lower() + + # Normalize status + status_map = { + "pending": "pending", + "started": "running", + "rendering": "running", # Custom status from streaming task + "success": "completed", + "failure": "failed", + "retry": "running", + "revoked": "failed", + } + normalized_status = status_map.get(status, status) + + run_data = { + "run_id": run_id, + "status": normalized_status, + "celery_task_id": task_id, + "actor_id": pending.get("actor_id"), + "recipe": pending.get("recipe"), + "inputs": self._ensure_inputs_list(pending.get("inputs")), + "output_name": pending.get("output_name"), + "created_at": pending.get("created_at"), + "error": pending.get("error"), + "recipe_sexp": pending.get("dag_json"), # Recipe content for streaming runs + # Checkpoint fields for resumable renders + "checkpoint_frame": pending.get("checkpoint_frame"), + "checkpoint_t": pending.get("checkpoint_t"), + "total_frames": pending.get("total_frames"), + "resumable": pending.get("resumable", True), + # IPFS streaming info + "ipfs_playlist_cid": pending.get("ipfs_playlist_cid"), + "quality_playlists": pending.get("quality_playlists"), + } + + # If task completed, get result + if result.ready(): + if result.successful(): + task_result = result.result + if isinstance(task_result, dict): + # Check task's own success flag and output_cid + task_success = task_result.get("success", True) + output_cid = task_result.get("output_cid") + if task_success and output_cid: + run_data["status"] = "completed" + run_data["output_cid"] = output_cid + else: + run_data["status"] = "failed" + run_data["error"] = task_result.get("error", "No output produced") + else: + run_data["status"] = "completed" + else: + run_data["status"] = "failed" + run_data["error"] = str(result.result) + + return run_data + + # No task_id but have pending record - return from DB + return { + "run_id": run_id, + "status": pending.get("status", "pending"), + "recipe": pending.get("recipe"), + "inputs": self._ensure_inputs_list(pending.get("inputs")), + "output_name": pending.get("output_name"), + "actor_id": pending.get("actor_id"), + "created_at": pending.get("created_at"), + "error": pending.get("error"), + "recipe_sexp": pending.get("dag_json"), # Recipe content for streaming runs + # Checkpoint fields for resumable renders + "checkpoint_frame": pending.get("checkpoint_frame"), + "checkpoint_t": pending.get("checkpoint_t"), + "total_frames": pending.get("total_frames"), + "resumable": pending.get("resumable", True), + # IPFS streaming info + "ipfs_playlist_cid": pending.get("ipfs_playlist_cid"), + "quality_playlists": pending.get("quality_playlists"), + } + + # Fallback: Check Redis for backwards compatibility + task_data = self.redis.get(f"{self.task_key_prefix}{run_id}") + if task_data: + if isinstance(task_data, bytes): + task_data = task_data.decode() + + # Parse task data (supports both old format string and new JSON format) + try: + parsed = json.loads(task_data) + task_id = parsed.get("task_id") + task_actor_id = parsed.get("actor_id") + task_recipe = parsed.get("recipe") + task_recipe_name = parsed.get("recipe_name") + task_inputs = parsed.get("inputs") + # Ensure inputs is a list (might be JSON string) + if isinstance(task_inputs, str): + try: + task_inputs = json.loads(task_inputs) + except json.JSONDecodeError: + task_inputs = None + task_output_name = parsed.get("output_name") + task_created_at = parsed.get("created_at") + except json.JSONDecodeError: + # Old format: just the task_id string + task_id = task_data + task_actor_id = None + task_recipe = None + task_recipe_name = None + task_inputs = None + task_output_name = None + task_created_at = None + + # Get task state from Celery + from celery.result import AsyncResult + from celery_app import app as celery_app + + result = AsyncResult(task_id, app=celery_app) + status = result.status.lower() + + # Normalize Celery status names + status_map = { + "pending": "pending", + "started": "running", + "rendering": "running", # Custom status from streaming task + "success": "completed", + "failure": "failed", + "retry": "running", + "revoked": "failed", + } + normalized_status = status_map.get(status, status) + + run_data = { + "run_id": run_id, + "status": normalized_status, + "celery_task_id": task_id, + "actor_id": task_actor_id, + "recipe": task_recipe, + "recipe_name": task_recipe_name, + "inputs": self._ensure_inputs_list(task_inputs), + "output_name": task_output_name, + "created_at": task_created_at, + } + + # If task completed, get result + if result.ready(): + if result.successful(): + task_result = result.result + if isinstance(task_result, dict): + # Check task's own success flag and output_cid + task_success = task_result.get("success", True) + output_cid = task_result.get("output_cid") + if task_success and output_cid: + run_data["status"] = "completed" + run_data["output_cid"] = output_cid + else: + run_data["status"] = "failed" + run_data["error"] = task_result.get("error", "No output produced") + else: + run_data["status"] = "completed" + else: + run_data["status"] = "failed" + run_data["error"] = str(result.result) + + return run_data + + return None + + async def list_runs(self, actor_id: str, offset: int = 0, limit: int = 20) -> List[RunResult]: + """List runs for a user. Returns completed and pending runs from database.""" + # Get completed runs from database + completed_runs = await self.db.list_runs_by_actor(actor_id, offset=0, limit=limit + 50) + + # Get pending runs from database + pending_db = await self.db.list_pending_runs(actor_id=actor_id) + + # Convert pending runs to run format with live status check + pending = [] + for pr in pending_db: + run_id = pr.get("run_id") + # Skip if already in completed + if any(r.get("run_id") == run_id for r in completed_runs): + continue + + # Get live status - include pending, running, rendering, and failed runs + run = await self.get_run(run_id) + if run and run.get("status") in ("pending", "running", "rendering", "failed"): + pending.append(run) + + # Combine and sort + all_runs = pending + completed_runs + all_runs.sort(key=lambda r: r.get("created_at", ""), reverse=True) + + return all_runs[offset:offset + limit] + + async def create_run( + self, + recipe: str, + inputs: Union[List[str], Dict[str, str]], + output_name: Optional[str] = None, + use_dag: bool = True, + dag_json: Optional[str] = None, + actor_id: Optional[str] = None, + l2_server: Optional[str] = None, + recipe_name: Optional[str] = None, + recipe_sexp: Optional[str] = None, + ) -> Tuple[Optional[RunResult], Optional[str]]: + """ + Create a new rendering run. Checks cache before executing. + + If recipe_sexp is provided, uses the new S-expression execution path + which generates code-addressed cache IDs before execution. + + Returns (run_dict, error_message). + """ + import httpx + try: + from legacy_tasks import render_effect, execute_dag, build_effect_dag, execute_recipe + except ImportError as e: + return None, f"Celery tasks not available: {e}" + + # Handle both list and dict inputs + if isinstance(inputs, dict): + input_list = list(inputs.values()) + else: + input_list = inputs + + # Compute content-addressable run_id + run_id = compute_run_id(input_list, recipe) + + # Generate output name if not provided + if not output_name: + output_name = f"{recipe}-{run_id[:8]}" + + # Check database cache first (completed runs) + cached_run = await self.db.get_run_cache(run_id) + if cached_run: + output_cid = cached_run.get("output_cid") + if output_cid and self.cache.has_content(output_cid): + return { + "run_id": run_id, + "status": "completed", + "recipe": recipe, + "inputs": input_list, + "output_name": output_name, + "output_cid": output_cid, + "ipfs_cid": cached_run.get("ipfs_cid"), + "provenance_cid": cached_run.get("provenance_cid"), + "created_at": cached_run.get("created_at"), + "completed_at": cached_run.get("created_at"), + "actor_id": actor_id, + }, None + + # Check L2 if not in local cache + if l2_server: + try: + async with httpx.AsyncClient(timeout=10) as client: + l2_resp = await client.get(f"{l2_server}/assets/by-run-id/{run_id}") + if l2_resp.status_code == 200: + l2_data = l2_resp.json() + output_cid = l2_data.get("output_cid") + ipfs_cid = l2_data.get("ipfs_cid") + if output_cid and ipfs_cid: + # Pull from IPFS to local cache + try: + import ipfs_client + legacy_dir = self.cache_dir / "legacy" + legacy_dir.mkdir(parents=True, exist_ok=True) + recovery_path = legacy_dir / output_cid + if ipfs_client.get_file(ipfs_cid, str(recovery_path)): + # Save to database cache + await self.db.save_run_cache( + run_id=run_id, + output_cid=output_cid, + recipe=recipe, + inputs=input_list, + ipfs_cid=ipfs_cid, + provenance_cid=l2_data.get("provenance_cid"), + actor_id=actor_id, + ) + return { + "run_id": run_id, + "status": "completed", + "recipe": recipe, + "inputs": input_list, + "output_cid": output_cid, + "ipfs_cid": ipfs_cid, + "provenance_cid": l2_data.get("provenance_cid"), + "created_at": datetime.now(timezone.utc).isoformat(), + "actor_id": actor_id, + }, None + except Exception: + pass # IPFS recovery failed, continue to run + except Exception: + pass # L2 lookup failed, continue to run + + # Not cached - submit to Celery + try: + # Prefer S-expression execution path (code-addressed cache IDs) + if recipe_sexp: + # Convert inputs to dict if needed + if isinstance(inputs, dict): + input_hashes = inputs + else: + # Legacy list format - use positional names + input_hashes = {f"input_{i}": cid for i, cid in enumerate(input_list)} + + task = execute_recipe.delay(recipe_sexp, input_hashes, run_id) + elif use_dag or recipe == "dag": + if dag_json: + dag_data = dag_json + else: + dag = build_effect_dag(input_list, recipe) + dag_data = dag.to_json() + + task = execute_dag.delay(dag_data, run_id) + else: + if len(input_list) != 1: + return None, "Legacy mode only supports single-input recipes. Use use_dag=true for multi-input." + task = render_effect.delay(input_list[0], recipe, output_name) + + # Store pending run in database for durability + try: + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=task.id, + recipe=recipe, + inputs=input_list, + actor_id=actor_id, + dag_json=dag_json, + output_name=output_name, + ) + except Exception as e: + import logging + logging.getLogger(__name__).error(f"Failed to save pending run: {e}") + + # Also store in Redis for backwards compatibility (shorter TTL) + task_data = json.dumps({ + "task_id": task.id, + "actor_id": actor_id, + "recipe": recipe, + "recipe_name": recipe_name, + "inputs": input_list, + "output_name": output_name, + "created_at": datetime.now(timezone.utc).isoformat(), + }) + self.redis.setex( + f"{self.task_key_prefix}{run_id}", + 3600 * 4, # 4 hour TTL (database is primary now) + task_data + ) + + return { + "run_id": run_id, + "status": "running", + "recipe": recipe, + "recipe_name": recipe_name, + "inputs": input_list, + "output_name": output_name, + "celery_task_id": task.id, + "created_at": datetime.now(timezone.utc).isoformat(), + "actor_id": actor_id, + }, None + + except Exception as e: + return None, f"Failed to submit task: {e}" + + async def discard_run( + self, + run_id: str, + actor_id: str, + username: str, + ) -> Tuple[bool, Optional[str]]: + """ + Discard (delete) a run record and clean up outputs/intermediates. + + Outputs and intermediates are only deleted if not used by other runs. + """ + import logging + logger = logging.getLogger(__name__) + + run = await self.get_run(run_id) + if not run: + return False, f"Run {run_id} not found" + + # Check ownership + run_owner = run.get("actor_id") + if run_owner and run_owner not in (username, actor_id): + return False, "Access denied" + + # Clean up activity outputs/intermediates (only if orphaned) + # The activity_id is the same as run_id + try: + success, msg = self.cache.discard_activity_outputs_only(run_id) + if success: + logger.info(f"Cleaned up run {run_id}: {msg}") + else: + # Activity might not exist (old runs), that's OK + logger.debug(f"No activity cleanup for {run_id}: {msg}") + except Exception as e: + logger.warning(f"Failed to cleanup activity for {run_id}: {e}") + + # Remove task_id mapping from Redis + self.redis.delete(f"{self.task_key_prefix}{run_id}") + + # Remove from run_cache database table + try: + await self.db.delete_run_cache(run_id) + except Exception as e: + logger.warning(f"Failed to delete run_cache for {run_id}: {e}") + + # Remove pending run if exists + try: + await self.db.delete_pending_run(run_id) + except Exception: + pass + + return True, None + + def _dag_to_steps(self, dag: Dict[str, Any]) -> Dict[str, Any]: + """Convert DAG nodes dict format to steps list format. + + DAG format: {"nodes": {"id": {...}}, "output_id": "..."} + Steps format: {"steps": [{"id": "...", "type": "...", ...}], "output_id": "..."} + """ + if "steps" in dag: + # Already in steps format + return dag + + if "nodes" not in dag: + return dag + + nodes = dag.get("nodes", {}) + steps = [] + + # Sort by topological order (sources first, then by input dependencies) + def get_level(node_id: str, visited: set = None) -> int: + if visited is None: + visited = set() + if node_id in visited: + return 0 + visited.add(node_id) + node = nodes.get(node_id, {}) + inputs = node.get("inputs", []) + if not inputs: + return 0 + return 1 + max(get_level(inp, visited) for inp in inputs) + + sorted_ids = sorted(nodes.keys(), key=lambda nid: (get_level(nid), nid)) + + for node_id in sorted_ids: + node = nodes[node_id] + steps.append({ + "id": node_id, + "step_id": node_id, + "type": node.get("node_type", "EFFECT"), + "config": node.get("config", {}), + "inputs": node.get("inputs", []), + "name": node.get("name"), + "cache_id": node_id, # In code-addressed system, node_id IS the cache_id + }) + + return { + "steps": steps, + "output_id": dag.get("output_id"), + "metadata": dag.get("metadata", {}), + "format": "json", + } + + def _sexp_to_steps(self, sexp_content: str) -> Dict[str, Any]: + """Convert S-expression plan to steps list format for UI. + + Parses the S-expression plan format: + (plan :id :recipe :recipe-hash + (inputs (input_name hash) ...) + (step step_id :cache-id :level (node-type :key val ...)) + ... + :output ) + + Returns steps list compatible with UI visualization. + """ + try: + from artdag.sexp import parse, Symbol, Keyword + except ImportError: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + try: + parsed = parse(sexp_content) + except Exception: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + if not isinstance(parsed, list) or not parsed: + return {"sexp": sexp_content, "steps": [], "format": "sexp"} + + steps = [] + output_step_id = None + plan_id = None + recipe_name = None + + # Parse plan structure + i = 0 + while i < len(parsed): + item = parsed[i] + + if isinstance(item, Keyword): + key = item.name + if i + 1 < len(parsed): + value = parsed[i + 1] + if key == "id": + plan_id = value + elif key == "recipe": + recipe_name = value + elif key == "output": + output_step_id = value + i += 2 + continue + + if isinstance(item, list) and item: + first = item[0] + if isinstance(first, Symbol) and first.name == "step": + # Parse step: (step step_id :cache-id :level (node-expr)) + step_id = item[1] if len(item) > 1 else None + cache_id = None + level = 0 + node_type = "EFFECT" + config = {} + inputs = [] + + j = 2 + while j < len(item): + part = item[j] + if isinstance(part, Keyword): + key = part.name + if j + 1 < len(item): + val = item[j + 1] + if key == "cache-id": + cache_id = val + elif key == "level": + level = val + j += 2 + continue + elif isinstance(part, list) and part: + # Node expression: (node-type :key val ...) + if isinstance(part[0], Symbol): + node_type = part[0].name.upper() + k = 1 + while k < len(part): + if isinstance(part[k], Keyword): + kname = part[k].name + if k + 1 < len(part): + kval = part[k + 1] + if kname == "inputs": + inputs = kval if isinstance(kval, list) else [kval] + else: + config[kname] = kval + k += 2 + continue + k += 1 + j += 1 + + steps.append({ + "id": step_id, + "step_id": step_id, + "type": node_type, + "config": config, + "inputs": inputs, + "cache_id": cache_id or step_id, + "level": level, + }) + + i += 1 + + return { + "sexp": sexp_content, + "steps": steps, + "output_id": output_step_id, + "plan_id": plan_id, + "recipe": recipe_name, + "format": "sexp", + } + + async def get_run_plan(self, run_id: str) -> Optional[Dict[str, Any]]: + """Get execution plan for a run. + + Plans are just node outputs - cached by content hash like everything else. + For streaming runs, returns the recipe content as the plan. + """ + # Get run to find plan_cache_id + run = await self.get_run(run_id) + if not run: + return None + + # For streaming runs, return the recipe as the plan + if run.get("recipe") == "streaming" and run.get("recipe_sexp"): + return { + "steps": [{"id": "stream", "type": "STREAM", "name": "Streaming Recipe"}], + "sexp": run.get("recipe_sexp"), + "format": "sexp", + } + + # Check plan_cid (stored in database) or plan_cache_id (legacy) + plan_cid = run.get("plan_cid") or run.get("plan_cache_id") + if plan_cid: + # Get plan from cache by content hash + plan_path = self.cache.get_by_cid(plan_cid) + if plan_path and plan_path.exists(): + with open(plan_path) as f: + content = f.read() + # Detect format + if content.strip().startswith("("): + # S-expression format - parse for UI + return self._sexp_to_steps(content) + else: + plan = json.loads(content) + return self._dag_to_steps(plan) + + # Fall back to legacy plans directory + sexp_path = self.cache_dir / "plans" / f"{run_id}.sexp" + if sexp_path.exists(): + with open(sexp_path) as f: + return self._sexp_to_steps(f.read()) + + json_path = self.cache_dir / "plans" / f"{run_id}.json" + if json_path.exists(): + with open(json_path) as f: + plan = json.load(f) + return self._dag_to_steps(plan) + + return None + + async def get_run_plan_sexp(self, run_id: str) -> Optional[str]: + """Get execution plan as S-expression string.""" + plan = await self.get_run_plan(run_id) + if plan and plan.get("format") == "sexp": + return plan.get("sexp") + return None + + async def get_run_artifacts(self, run_id: str) -> List[Dict[str, Any]]: + """Get all artifacts (inputs + outputs) for a run.""" + run = await self.get_run(run_id) + if not run: + return [] + + artifacts = [] + + def get_artifact_info(cid: str, role: str, name: str) -> Optional[Dict]: + if self.cache.has_content(cid): + path = self.cache.get_by_cid(cid) + if path and path.exists(): + return { + "cid": cid, + "size_bytes": path.stat().st_size, + "media_type": detect_media_type(path), + "role": role, + "step_name": name, + } + return None + + # Add inputs + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + for i, h in enumerate(inputs): + info = get_artifact_info(h, "input", f"Input {i + 1}") + if info: + artifacts.append(info) + + # Add output + if run.get("output_cid"): + info = get_artifact_info(run["output_cid"], "output", "Output") + if info: + artifacts.append(info) + + return artifacts + + async def get_run_analysis(self, run_id: str) -> List[Dict[str, Any]]: + """Get analysis data for each input in a run.""" + run = await self.get_run(run_id) + if not run: + return [] + + analysis_dir = self.cache_dir / "analysis" + results = [] + + inputs = run.get("inputs", []) + if isinstance(inputs, dict): + inputs = list(inputs.values()) + + for i, input_hash in enumerate(inputs): + analysis_path = analysis_dir / f"{input_hash}.json" + analysis_data = None + + if analysis_path.exists(): + try: + with open(analysis_path) as f: + analysis_data = json.load(f) + except (json.JSONDecodeError, IOError): + pass + + results.append({ + "input_hash": input_hash, + "input_name": f"Input {i + 1}", + "has_analysis": analysis_data is not None, + "tempo": analysis_data.get("tempo") if analysis_data else None, + "beat_times": analysis_data.get("beat_times", []) if analysis_data else [], + "raw": analysis_data, + }) + + return results + + def detect_media_type(self, path: Path) -> str: + """Detect media type for a file path.""" + return detect_media_type(path) + + async def recover_pending_runs(self) -> Dict[str, Union[int, str]]: + """ + Recover pending runs after restart. + + Checks all pending runs in the database and: + - Updates status for completed tasks + - Re-queues orphaned tasks that can be retried + - Marks as failed if unrecoverable + + Returns counts of recovered, completed, failed runs. + """ + from celery.result import AsyncResult + from celery_app import app as celery_app + + try: + from legacy_tasks import execute_dag + except ImportError: + return {"error": "Celery tasks not available"} + + stats = {"recovered": 0, "completed": 0, "failed": 0, "still_running": 0} + + # Get all pending/running runs from database + pending_runs = await self.db.list_pending_runs() + + for run in pending_runs: + run_id = run.get("run_id") + task_id = run.get("celery_task_id") + status = run.get("status") + + if not task_id: + # No task ID - try to re-queue if we have dag_json + dag_json = run.get("dag_json") + if dag_json: + try: + new_task = execute_dag.delay(dag_json, run_id) + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=new_task.id, + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + actor_id=run.get("actor_id"), + dag_json=dag_json, + output_name=run.get("output_name"), + ) + stats["recovered"] += 1 + except Exception as e: + await self.db.update_pending_run_status( + run_id, "failed", f"Recovery failed: {e}" + ) + stats["failed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", "No DAG data for recovery" + ) + stats["failed"] += 1 + continue + + # Check Celery task state + result = AsyncResult(task_id, app=celery_app) + celery_status = result.status.lower() + + if result.ready(): + if result.successful(): + # Task completed - move to run_cache + task_result = result.result + if isinstance(task_result, dict) and task_result.get("output_cid"): + await self.db.save_run_cache( + run_id=run_id, + output_cid=task_result["output_cid"], + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + ipfs_cid=task_result.get("ipfs_cid"), + provenance_cid=task_result.get("provenance_cid"), + actor_id=run.get("actor_id"), + ) + await self.db.complete_pending_run(run_id) + stats["completed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", "Task completed but no output hash" + ) + stats["failed"] += 1 + else: + # Task failed + await self.db.update_pending_run_status( + run_id, "failed", str(result.result) + ) + stats["failed"] += 1 + elif celery_status in ("pending", "started", "retry"): + # Still running + stats["still_running"] += 1 + else: + # Unknown state - try to re-queue if we have dag_json + dag_json = run.get("dag_json") + if dag_json: + try: + new_task = execute_dag.delay(dag_json, run_id) + await self.db.create_pending_run( + run_id=run_id, + celery_task_id=new_task.id, + recipe=run.get("recipe", "unknown"), + inputs=run.get("inputs", []), + actor_id=run.get("actor_id"), + dag_json=dag_json, + output_name=run.get("output_name"), + ) + stats["recovered"] += 1 + except Exception as e: + await self.db.update_pending_run_status( + run_id, "failed", f"Recovery failed: {e}" + ) + stats["failed"] += 1 + else: + await self.db.update_pending_run_status( + run_id, "failed", f"Task in unknown state: {celery_status}" + ) + stats["failed"] += 1 + + return stats diff --git a/app/services/storage_service.py b/app/services/storage_service.py new file mode 100644 index 0000000..19d4e3c --- /dev/null +++ b/app/services/storage_service.py @@ -0,0 +1,232 @@ +""" +Storage Service - business logic for storage provider management. +""" + +import json +from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from database import Database + from storage_providers import StorageProvidersModule + + +STORAGE_PROVIDERS_INFO = { + "pinata": {"name": "Pinata", "desc": "1GB free, IPFS pinning", "color": "blue"}, + "web3storage": {"name": "web3.storage", "desc": "IPFS + Filecoin", "color": "green"}, + "nftstorage": {"name": "NFT.Storage", "desc": "Free for NFTs", "color": "pink"}, + "infura": {"name": "Infura IPFS", "desc": "5GB free", "color": "orange"}, + "filebase": {"name": "Filebase", "desc": "5GB free, S3+IPFS", "color": "cyan"}, + "storj": {"name": "Storj", "desc": "25GB free", "color": "indigo"}, + "local": {"name": "Local Storage", "desc": "Your own disk", "color": "purple"}, +} + +VALID_PROVIDER_TYPES = list(STORAGE_PROVIDERS_INFO.keys()) + + +class StorageService: + """Service for managing user storage providers.""" + + def __init__(self, database: "Database", storage_providers_module: "StorageProvidersModule") -> None: + self.db = database + self.providers = storage_providers_module + + async def list_storages(self, actor_id: str) -> List[Dict[str, Any]]: + """List all storage providers for a user with usage stats.""" + storages = await self.db.get_user_storage(actor_id) + + for storage in storages: + usage = await self.db.get_storage_usage(storage["id"]) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + + # Mask sensitive config keys for display + if storage.get("config"): + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + masked = {} + for k, v in config.items(): + if "key" in k.lower() or "token" in k.lower() or "secret" in k.lower(): + masked[k] = v[:4] + "..." + v[-4:] if len(str(v)) > 8 else "****" + else: + masked[k] = v + storage["config_display"] = masked + + return storages + + async def get_storage(self, storage_id: int, actor_id: str) -> Optional[Dict[str, Any]]: + """Get a specific storage provider.""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return None + if storage["actor_id"] != actor_id: + return None + + usage = await self.db.get_storage_usage(storage_id) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + + return storage + + async def add_storage( + self, + actor_id: str, + provider_type: str, + config: Dict[str, Any], + capacity_gb: int = 5, + provider_name: Optional[str] = None, + description: Optional[str] = None, + ) -> Tuple[Optional[int], Optional[str]]: + """Add a new storage provider. Returns (storage_id, error_message).""" + if provider_type not in VALID_PROVIDER_TYPES: + return None, f"Invalid provider type: {provider_type}" + + # Test connection before saving + provider = self.providers.create_provider(provider_type, { + **config, + "capacity_gb": capacity_gb + }) + if not provider: + return None, "Failed to create provider with given config" + + success, message = await provider.test_connection() + if not success: + return None, f"Provider connection failed: {message}" + + # Generate name if not provided + if not provider_name: + existing = await self.db.get_user_storage_by_type(actor_id, provider_type) + provider_name = f"{provider_type}-{len(existing) + 1}" + + storage_id = await self.db.add_user_storage( + actor_id=actor_id, + provider_type=provider_type, + provider_name=provider_name, + config=config, + capacity_gb=capacity_gb, + description=description + ) + + if not storage_id: + return None, "Failed to save storage provider" + + return storage_id, None + + async def update_storage( + self, + storage_id: int, + actor_id: str, + config: Optional[Dict[str, Any]] = None, + capacity_gb: Optional[int] = None, + is_active: Optional[bool] = None, + ) -> Tuple[bool, Optional[str]]: + """Update a storage provider. Returns (success, error_message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage provider not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + # Test new config if provided + if config: + existing_config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + new_config = {**existing_config, **config} + provider = self.providers.create_provider(storage["provider_type"], { + **new_config, + "capacity_gb": capacity_gb or storage["capacity_gb"] + }) + if provider: + success, message = await provider.test_connection() + if not success: + return False, f"Provider connection failed: {message}" + + success = await self.db.update_user_storage( + storage_id, + config=config, + capacity_gb=capacity_gb, + is_active=is_active + ) + + return success, None if success else "Failed to update storage provider" + + async def delete_storage(self, storage_id: int, actor_id: str) -> Tuple[bool, Optional[str]]: + """Delete a storage provider. Returns (success, error_message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage provider not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + success = await self.db.remove_user_storage(storage_id) + return success, None if success else "Failed to remove storage provider" + + async def test_storage(self, storage_id: int, actor_id: str) -> Tuple[bool, str]: + """Test storage provider connectivity. Returns (success, message).""" + storage = await self.db.get_storage_by_id(storage_id) + if not storage: + return False, "Storage not found" + if storage["actor_id"] != actor_id: + return False, "Not authorized" + + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + provider = self.providers.create_provider(storage["provider_type"], { + **config, + "capacity_gb": storage["capacity_gb"] + }) + + if not provider: + return False, "Failed to create provider" + + return await provider.test_connection() + + async def list_by_type(self, actor_id: str, provider_type: str) -> List[Dict[str, Any]]: + """List storage providers of a specific type.""" + return await self.db.get_user_storage_by_type(actor_id, provider_type) + + def build_config_from_form(self, provider_type: str, form_data: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """Build provider config from form data. Returns (config, error).""" + api_key = form_data.get("api_key") + secret_key = form_data.get("secret_key") + api_token = form_data.get("api_token") + project_id = form_data.get("project_id") + project_secret = form_data.get("project_secret") + access_key = form_data.get("access_key") + bucket = form_data.get("bucket") + path = form_data.get("path") + + if provider_type == "pinata": + if not api_key or not secret_key: + return None, "Pinata requires API Key and Secret Key" + return {"api_key": api_key, "secret_key": secret_key}, None + + elif provider_type == "web3storage": + if not api_token: + return None, "web3.storage requires API Token" + return {"api_token": api_token}, None + + elif provider_type == "nftstorage": + if not api_token: + return None, "NFT.Storage requires API Token" + return {"api_token": api_token}, None + + elif provider_type == "infura": + if not project_id or not project_secret: + return None, "Infura requires Project ID and Project Secret" + return {"project_id": project_id, "project_secret": project_secret}, None + + elif provider_type == "filebase": + if not access_key or not secret_key or not bucket: + return None, "Filebase requires Access Key, Secret Key, and Bucket" + return {"access_key": access_key, "secret_key": secret_key, "bucket": bucket}, None + + elif provider_type == "storj": + if not access_key or not secret_key or not bucket: + return None, "Storj requires Access Key, Secret Key, and Bucket" + return {"access_key": access_key, "secret_key": secret_key, "bucket": bucket}, None + + elif provider_type == "local": + if not path: + return None, "Local storage requires a path" + return {"path": path}, None + + return None, f"Unknown provider type: {provider_type}" diff --git a/app/templates/404.html b/app/templates/404.html new file mode 100644 index 0000000..0cd9c70 --- /dev/null +++ b/app/templates/404.html @@ -0,0 +1,14 @@ +{% extends "base.html" %} + +{% block title %}Not Found - Art-DAG L1{% endblock %} + +{% block content %} +
+

404

+

Page Not Found

+

The page you're looking for doesn't exist or has been moved.

+ + Go Home + +
+{% endblock %} diff --git a/app/templates/base.html b/app/templates/base.html new file mode 100644 index 0000000..9be32fb --- /dev/null +++ b/app/templates/base.html @@ -0,0 +1,46 @@ +{% extends "_base.html" %} + +{% block brand %} +Rose Ash +| +Art-DAG +{% endblock %} + +{% block cart_mini %} +{% if request and request.state.cart_mini_html %} + {{ request.state.cart_mini_html | safe }} +{% endif %} +{% endblock %} + +{% block nav_tree %} +{% if request and request.state.nav_tree_html %} + {{ request.state.nav_tree_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu_mobile %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block sub_nav %} + +{% endblock %} diff --git a/app/templates/cache/detail.html b/app/templates/cache/detail.html new file mode 100644 index 0000000..da30119 --- /dev/null +++ b/app/templates/cache/detail.html @@ -0,0 +1,182 @@ +{% extends "base.html" %} + +{% block title %}{{ cache.cid[:16] }} - Cache - Art-DAG L1{% endblock %} + +{% block content %} +
+ +
+ ← Media +

{{ cache.cid[:24] }}...

+
+ + +
+ {% if cache.mime_type and cache.mime_type.startswith('image/') %} + {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} + + {% elif cache.mime_type and cache.mime_type.startswith('video/') %} + {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} + + {% elif cache.mime_type and cache.mime_type.startswith('audio/') %} +
+ {% if cache.remote_only and cache.ipfs_cid %} + + {% else %} + + {% endif %} +
+ + {% elif cache.mime_type == 'application/json' %} +
+
{{ cache.content_preview }}
+
+ + {% else %} +
+
{{ cache.mime_type or 'Unknown type' }}
+
{{ cache.size | filesizeformat if cache.size else 'Unknown size' }}
+
+ {% endif %} +
+ + +
+
+ Friendly Name + +
+ {% if cache.friendly_name %} +

{{ cache.friendly_name }}

+

Use in recipes: {{ cache.base_name }}

+ {% else %} +

No friendly name assigned. Click Edit to add one.

+ {% endif %} +
+ + +
+
+

Details

+ +
+ {% if cache.title or cache.description or cache.filename %} +
+ {% if cache.title %} +

{{ cache.title }}

+ {% elif cache.filename %} +

{{ cache.filename }}

+ {% endif %} + {% if cache.description %} +

{{ cache.description }}

+ {% endif %} +
+ {% else %} +

No title or description set. Click Edit to add metadata.

+ {% endif %} + {% if cache.tags %} +
+ {% for tag in cache.tags %} + {{ tag }} + {% endfor %} +
+ {% endif %} + {% if cache.source_type or cache.source_note %} +
+ {% if cache.source_type %}Source: {{ cache.source_type }}{% endif %} + {% if cache.source_note %} - {{ cache.source_note }}{% endif %} +
+ {% endif %} +
+ + +
+
+
CID
+
{{ cache.cid }}
+
+
+
Content Type
+
{{ cache.mime_type or 'Unknown' }}
+
+
+
Size
+
{{ cache.size | filesizeformat if cache.size else 'Unknown' }}
+
+
+
Created
+
{{ cache.created_at or 'Unknown' }}
+
+
+ + + {% if cache.ipfs_cid %} +
+
IPFS CID
+
+ {{ cache.ipfs_cid }} + + View on IPFS Gateway → + +
+
+ {% endif %} + + + {% if cache.runs %} +

Related Runs

+
+ {% for run in cache.runs %} + +
+ {{ run.run_id[:16] }}... + {{ run.created_at }} +
+
+ {% endfor %} +
+ {% endif %} + + +
+ + Download + + + +
+
+{% endblock %} diff --git a/app/templates/cache/media_list.html b/app/templates/cache/media_list.html new file mode 100644 index 0000000..0a436aa --- /dev/null +++ b/app/templates/cache/media_list.html @@ -0,0 +1,325 @@ +{% extends "base.html" %} + +{% block title %}Media - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Media

+
+ + +
+
+ + + + + {% if items %} +
+ {% for item in items %} + {# Determine media category from type or filename #} + {% set is_image = item.type in ('image', 'image/jpeg', 'image/png', 'image/gif', 'image/webp') or (item.filename and item.filename.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))) %} + {% set is_video = item.type in ('video', 'video/mp4', 'video/webm', 'video/x-matroska') or (item.filename and item.filename.lower().endswith(('.mp4', '.mkv', '.webm', '.mov'))) %} + {% set is_audio = item.type in ('audio', 'audio/mpeg', 'audio/wav', 'audio/flac') or (item.filename and item.filename.lower().endswith(('.mp3', '.wav', '.flac', '.ogg'))) %} + + + + {% if is_image %} + + + {% elif is_video %} +
+ +
+
+ + + +
+
+
+ + {% elif is_audio %} +
+ + + + Audio +
+ + {% else %} +
+ {{ item.type or 'Media' }} +
+ {% endif %} + +
+ {% if item.friendly_name %} +
{{ item.friendly_name }}
+ {% else %} +
{{ item.cid[:16] }}...
+ {% endif %} + {% if item.filename %} +
{{ item.filename }}
+ {% endif %} +
+
+ {% endfor %} +
+ + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+ + + +

No media files yet

+

Run a recipe to generate media artifacts.

+
+ {% endif %} +
+ + +{% endblock %} diff --git a/app/templates/cache/not_found.html b/app/templates/cache/not_found.html new file mode 100644 index 0000000..600a77d --- /dev/null +++ b/app/templates/cache/not_found.html @@ -0,0 +1,21 @@ +{% extends "base.html" %} + +{% block title %}Content Not Found - Art-DAG L1{% endblock %} + +{% block content %} +
+

404

+

Content Not Found

+

+ The content with hash {{ cid[:24] if cid else 'unknown' }}... was not found in the cache. +

+ +
+{% endblock %} diff --git a/app/templates/effects/detail.html b/app/templates/effects/detail.html new file mode 100644 index 0000000..572f586 --- /dev/null +++ b/app/templates/effects/detail.html @@ -0,0 +1,203 @@ +{% extends "base.html" %} + +{% set meta = effect.meta or effect %} + +{% block title %}{{ meta.name or 'Effect' }} - Effects - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + + +{% endblock %} + +{% block content %} +
+ +
+ ← Effects +

{{ meta.name or 'Unnamed Effect' }}

+ v{{ meta.version or '1.0.0' }} + {% if meta.temporal %} + temporal + {% endif %} +
+ + {% if meta.author %} +

by {{ meta.author }}

+ {% endif %} + + {% if meta.description %} +

{{ meta.description }}

+ {% endif %} + + +
+ {% if effect.friendly_name %} +
+ Friendly Name +

{{ effect.friendly_name }}

+

Use in recipes: (effect {{ effect.base_name }})

+
+ {% endif %} +
+
+ Content ID (CID) +

{{ effect.cid }}

+
+ +
+ {% if effect.uploaded_at %} +
+ Uploaded: {{ effect.uploaded_at }} + {% if effect.uploader %} + by {{ effect.uploader }} + {% endif %} +
+ {% endif %} +
+ +
+ +
+ + {% if meta.params %} +
+
+ Parameters +
+
+ {% for param in meta.params %} +
+
+ {{ param.name }} + {{ param.type }} +
+ {% if param.description %} +

{{ param.description }}

+ {% endif %} +
+ {% if param.range %} + range: {{ param.range[0] }} - {{ param.range[1] }} + {% endif %} + {% if param.default is defined %} + default: {{ param.default }} + {% endif %} +
+
+ {% endfor %} +
+
+ {% endif %} + + +
+
+ Usage in Recipe +
+
+ {% if effect.base_name %} +
({{ effect.base_name }} ...)
+

+ Use the friendly name to reference this effect. +

+ {% else %} +
(effect :cid "{{ effect.cid }}")
+

+ Reference this effect by CID in your recipe. +

+ {% endif %} +
+
+
+ + +
+
+
+ Source Code (S-expression) + +
+
+
Loading...
+
+
+
+
+ + +
+ {% if effect.cid.startswith('Qm') or effect.cid.startswith('bafy') %} + + View on IPFS + + {% endif %} + + + +
+
+ + +{% endblock %} diff --git a/app/templates/effects/list.html b/app/templates/effects/list.html new file mode 100644 index 0000000..065d2bb --- /dev/null +++ b/app/templates/effects/list.html @@ -0,0 +1,200 @@ +{% extends "base.html" %} + +{% block title %}Effects - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Effects

+ +
+ + + + +

+ Effects are S-expression files that define video processing operations. + Each effect is stored in IPFS and can be referenced by name in recipes. +

+ + {% if effects %} + + + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+ + + +

No effects uploaded yet.

+

+ Effects are S-expression files with metadata in comment headers. +

+ +
+ {% endif %} +
+ + +{% endblock %} diff --git a/app/templates/fragments/link_card.html b/app/templates/fragments/link_card.html new file mode 100644 index 0000000..ecc4450 --- /dev/null +++ b/app/templates/fragments/link_card.html @@ -0,0 +1,22 @@ + +
+
+ {% if content_type == "recipe" %} + + {% elif content_type == "effect" %} + + {% elif content_type == "run" %} + + {% else %} + + {% endif %} +
+
+
{{ title }}
+ {% if description %} +
{{ description }}
+ {% endif %} +
{{ content_type }} · {{ cid[:12] }}…
+
+
+
diff --git a/app/templates/fragments/nav_item.html b/app/templates/fragments/nav_item.html new file mode 100644 index 0000000..e987cc5 --- /dev/null +++ b/app/templates/fragments/nav_item.html @@ -0,0 +1,7 @@ + diff --git a/app/templates/home.html b/app/templates/home.html new file mode 100644 index 0000000..c6a23aa --- /dev/null +++ b/app/templates/home.html @@ -0,0 +1,51 @@ +{% extends "base.html" %} + +{% block title %}Art-DAG L1{% endblock %} + +{% block content %} +
+

Art-DAG L1

+

Content-Addressable Media Processing

+ + + + {% if not user %} +
+

Sign in through your L2 server to access all features.

+ Sign In → +
+ {% endif %} + + {% if readme_html %} +
+ {{ readme_html | safe }} +
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/recipes/detail.html b/app/templates/recipes/detail.html new file mode 100644 index 0000000..daf134a --- /dev/null +++ b/app/templates/recipes/detail.html @@ -0,0 +1,265 @@ +{% extends "base.html" %} + +{% block title %}{{ recipe.name }} - Recipe - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + +{% endblock %} + +{% block content %} +
+ +
+ ← Recipes +

{{ recipe.name or 'Unnamed Recipe' }}

+ {% if recipe.version %} + v{{ recipe.version }} + {% endif %} +
+ + {% if recipe.description %} +

{{ recipe.description }}

+ {% endif %} + + +
+
+
+ Recipe ID +

{{ recipe.recipe_id[:16] }}...

+
+ {% if recipe.ipfs_cid %} +
+ IPFS CID +

{{ recipe.ipfs_cid[:16] }}...

+
+ {% endif %} +
+ Steps +

{{ recipe.step_count or recipe.steps|length }}

+
+ {% if recipe.author %} +
+ Author +

{{ recipe.author }}

+
+ {% endif %} +
+
+ + {% if recipe.type == 'streaming' %} + +
+
+ Streaming Recipe +
+

+ This recipe uses frame-by-frame streaming rendering. The pipeline is defined as an S-expression that generates frames dynamically. +

+
+ {% else %} + +
+
+ Pipeline DAG + {{ recipe.steps | length }} steps +
+
+
+ + +

Steps

+
+ {% for step in recipe.steps %} + {% set colors = { + 'effect': 'blue', + 'analyze': 'purple', + 'transform': 'green', + 'combine': 'orange', + 'output': 'cyan' + } %} + {% set color = colors.get(step.type, 'gray') %} + +
+
+
+ + {{ loop.index }} + + {{ step.name }} + + {{ step.type }} + +
+
+ + {% if step.inputs %} +
+ Inputs: {{ step.inputs | join(', ') }} +
+ {% endif %} + + {% if step.params %} +
+ {{ step.params | tojson }} +
+ {% endif %} +
+ {% endfor %} +
+ {% endif %} + + +

Recipe (S-expression)

+
+ {% if recipe.sexp %} +
{{ recipe.sexp }}
+ {% else %} +

No source available

+ {% endif %} +
+ + + + +
+ + {% if recipe.ipfs_cid %} + + View on IPFS + + {% elif recipe.recipe_id.startswith('Qm') or recipe.recipe_id.startswith('bafy') %} + + View on IPFS + + {% endif %} + + + +
+
+ + +{% endblock %} diff --git a/app/templates/recipes/list.html b/app/templates/recipes/list.html new file mode 100644 index 0000000..0cd484f --- /dev/null +++ b/app/templates/recipes/list.html @@ -0,0 +1,136 @@ +{% extends "base.html" %} + +{% block title %}Recipes - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Recipes

+ +
+ +

+ Recipes define processing pipelines for audio and media. Each recipe is a DAG of effects. +

+ + {% if recipes %} + + + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+

No recipes available.

+

+ Recipes are S-expression files (.sexp) that define processing pipelines. +

+ +
+ {% endif %} +
+ +
+ + +{% endblock %} diff --git a/app/templates/runs/_run_card.html b/app/templates/runs/_run_card.html new file mode 100644 index 0000000..88a42a2 --- /dev/null +++ b/app/templates/runs/_run_card.html @@ -0,0 +1,89 @@ +{# Run card partial - expects 'run' variable #} +{% set status_colors = { + 'completed': 'green', + 'running': 'blue', + 'pending': 'yellow', + 'failed': 'red', + 'cached': 'purple' +} %} +{% set color = status_colors.get(run.status, 'gray') %} + + +
+
+ {{ run.run_id[:12] }}... + + {{ run.status }} + + {% if run.cached %} + cached + {% endif %} +
+ {{ run.created_at }} +
+ +
+
+ + Recipe: {{ run.recipe_name or (run.recipe[:12] ~ '...' if run.recipe and run.recipe|length > 12 else run.recipe) or 'Unknown' }} + + {% if run.total_steps %} + + Steps: {{ run.executed or 0 }}/{{ run.total_steps }} + + {% endif %} +
+
+ + {# Media previews row #} +
+ {# Input previews #} + {% if run.input_previews %} +
+ In: + {% for inp in run.input_previews %} + {% if inp.media_type and inp.media_type.startswith('image/') %} + + {% elif inp.media_type and inp.media_type.startswith('video/') %} + + {% else %} +
?
+ {% endif %} + {% endfor %} + {% if run.inputs and run.inputs|length > 3 %} + +{{ run.inputs|length - 3 }} + {% endif %} +
+ {% elif run.inputs %} +
+ {{ run.inputs|length }} input(s) +
+ {% endif %} + + {# Arrow #} + -> + + {# Output preview - prefer IPFS URLs when available #} + {% if run.output_cid %} +
+ Out: + {% if run.output_media_type and run.output_media_type.startswith('image/') %} + + {% elif run.output_media_type and run.output_media_type.startswith('video/') %} + + {% else %} +
?
+ {% endif %} +
+ {% else %} + No output yet + {% endif %} + +
+ + {% if run.output_cid %} + {{ run.output_cid[:12] }}... + {% endif %} +
+
diff --git a/app/templates/runs/artifacts.html b/app/templates/runs/artifacts.html new file mode 100644 index 0000000..874188c --- /dev/null +++ b/app/templates/runs/artifacts.html @@ -0,0 +1,62 @@ +{% extends "base.html" %} + +{% block title %}Run Artifacts{% endblock %} + +{% block content %} + + +

Run Artifacts

+ +{% if artifacts %} +
+ {% for artifact in artifacts %} +
+
+ + {{ artifact.role }} + + {{ artifact.step_name }} +
+ +
+

Content Hash

+

{{ artifact.hash }}

+
+ +
+ + {% if artifact.media_type == 'video' %}Video + {% elif artifact.media_type == 'image' %}Image + {% elif artifact.media_type == 'audio' %}Audio + {% else %}File{% endif %} + + {{ (artifact.size_bytes / 1024)|round(1) }} KB +
+ + +
+ {% endfor %} +
+{% else %} +
+

No artifacts found for this run.

+
+{% endif %} +{% endblock %} diff --git a/app/templates/runs/detail.html b/app/templates/runs/detail.html new file mode 100644 index 0000000..ae87dd3 --- /dev/null +++ b/app/templates/runs/detail.html @@ -0,0 +1,1073 @@ +{% extends "base.html" %} + +{% block title %}Run {{ run.run_id[:12] }} - Art-DAG L1{% endblock %} + +{% block head %} +{{ super() }} + + + + +{% endblock %} + +{% block content %} +{% set status_colors = {'completed': 'green', 'running': 'blue', 'pending': 'yellow', 'failed': 'red', 'paused': 'yellow'} %} +{% set color = status_colors.get(run.status, 'gray') %} + +
+ +
+ ← Runs +

{{ run.run_id[:16] }}...

+ + {{ run.status }} + + {% if run.cached %} + Cached + {% endif %} + {% if run.error %} + {{ run.error }} + {% endif %} + {% if run.checkpoint_frame %} + + Checkpoint: {{ run.checkpoint_frame }}{% if run.total_frames %} / {{ run.total_frames }}{% endif %} frames + + {% endif %} +
+ + + {% if run.status == 'running' %} + + {% endif %} + + + {% if run.status in ['failed', 'paused'] %} + {% if run.checkpoint_frame %} + + {% endif %} + + {% endif %} + + {% if run.recipe %} + + {% endif %} + + + +
+ + +
+
+
Recipe
+
+ {% if run.recipe %} + + {{ run.recipe_name or (run.recipe[:16] ~ '...') }} + + {% else %} + Unknown + {% endif %} +
+
+
+
Steps
+
+ {% if run.recipe == 'streaming' %} + {% if run.status == 'completed' %}1 / 1{% else %}0 / 1{% endif %} + {% else %} + {{ run.executed or 0 }} / {{ run.total_steps or (plan.steps|length if plan and plan.steps else '?') }} + {% endif %} + {% if run.cached_steps %} + ({{ run.cached_steps }} cached) + {% endif %} +
+
+
+
Created
+
{{ run.created_at }}
+
+
+
User
+
{{ run.username or 'Unknown' }}
+
+
+ + + {% if run.status == 'rendering' or run.ipfs_playlist_cid or (run.status in ['paused', 'failed'] and run.checkpoint_frame) %} +
+
+

+ {% if run.status == 'rendering' %} + + Live Preview + {% elif run.status == 'paused' %} + + Partial Output (Paused) + {% elif run.status == 'failed' and run.checkpoint_frame %} + + Partial Output (Failed) + {% else %} + + Video + {% endif %} +

+
+ +
+ + +
+
Connecting...
+
+
+
+ +
+
+
+
Waiting for stream...
+
+
+
+
+ Stream: /runs/{{ run.run_id }}/playlist.m3u8 + +
+
+ + + {% endif %} + + +
+ +
+ + +
+ {% if plan %} +
+ +
+
+
+ +
+
+ Click a node to view details +
+ +
+
+ + +
+ {% for step in plan.steps %} + {% set step_color = 'green' if step.status == 'completed' or step.cache_id else ('purple' if step.cached else ('blue' if step.status == 'running' else 'gray')) %} +
+
+
+ + {{ loop.index }} + + {{ step.name }} + {{ step.type }} +
+
+ {% if step.cached %} + cached + {% elif step.status == 'completed' %} + completed + {% endif %} +
+
+ {% if step.cache_id %} + + {% endif %} +
+ {% endfor %} +
+ + + {% if plan_sexp %} +
+ + Recipe (S-expression) + {% if recipe_ipfs_cid %} + + ipfs://{{ recipe_ipfs_cid[:16] }}... + + {% endif %} + +
+
{{ plan_sexp }}
+
+
+ {% endif %} + + + + {% else %} +

No plan available for this run.

+ {% endif %} +
+ + + + + + + + + + + + {% if run.output_cid %} +
+

Output

+ + {# Inline media preview - prefer IPFS URLs when available #} +
+ {% if output_media_type and output_media_type.startswith('image/') %} + + Output + + {% elif output_media_type and output_media_type.startswith('video/') %} + {# HLS streams use the unified player above; show direct video for non-HLS #} + {% if run.ipfs_playlist_cid %} +
+ HLS stream available in player above. Use "From Start" to watch from beginning or "Live Edge" to follow rendering progress. +
+ {% else %} + {# Direct video file #} + + {% endif %} + {% elif output_media_type and output_media_type.startswith('audio/') %} + + {% else %} +
+
?
+
{{ output_media_type or 'Unknown media type' }}
+
+ {% endif %} +
+ +
+ + {% if run.ipfs_cid %}{{ run.ipfs_cid }}{% else %}{{ run.output_cid }}{% endif %} + +
+ {% if run.ipfs_playlist_cid %} + + HLS Playlist + + {% endif %} + {% if run.ipfs_cid %} + + View on IPFS Gateway + + {% endif %} +
+
+
+ {% endif %} +
+ + +{% endblock %} diff --git a/app/templates/runs/list.html b/app/templates/runs/list.html new file mode 100644 index 0000000..8d72415 --- /dev/null +++ b/app/templates/runs/list.html @@ -0,0 +1,45 @@ +{% extends "base.html" %} + +{% block title %}Runs - Art-DAG L1{% endblock %} + +{% block content %} +
+
+

Execution Runs

+ Browse Recipes → +
+ + {% if runs %} +
+ {% for run in runs %} + {% include "runs/_run_card.html" %} + {% endfor %} +
+ + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+
+ + + +

No runs yet

+
+

Execute a recipe to see your runs here.

+ + Browse Recipes + +
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/runs/plan.html b/app/templates/runs/plan.html new file mode 100644 index 0000000..f50090d --- /dev/null +++ b/app/templates/runs/plan.html @@ -0,0 +1,99 @@ +{% extends "base.html" %} + +{% block title %}Run Plan - {{ run_id[:16] }}{% endblock %} + +{% block head %} + +{% endblock %} + +{% block content %} + + +

Execution Plan

+ +{% if plan %} +
+ +
+

DAG Visualization

+
+
+ + +
+

Steps ({{ plan.steps|length if plan.steps else 0 }})

+
+ {% for step in plan.get('steps', []) %} +
+
+ {{ step.name or step.id or 'Step ' ~ loop.index }} + + {{ step.status or ('cached' if step.cached else 'pending') }} + +
+ {% if step.cache_id %} +
+ {{ step.cache_id[:24] }}... +
+ {% endif %} +
+ {% else %} +

No steps defined

+ {% endfor %} +
+
+
+ + +{% else %} +
+

No execution plan available for this run.

+
+{% endif %} +{% endblock %} diff --git a/app/templates/runs/plan_node.html b/app/templates/runs/plan_node.html new file mode 100644 index 0000000..99e1658 --- /dev/null +++ b/app/templates/runs/plan_node.html @@ -0,0 +1,99 @@ +{# Plan node detail panel - loaded via HTMX #} +{% set status_color = 'green' if status in ('cached', 'completed') else 'yellow' %} + +
+
+

{{ step.name or step.step_id[:20] }}

+
+ + {{ step.node_type or 'EFFECT' }} + + {{ status }} + Level {{ step.level or 0 }} +
+
+ +
+ +{# Output preview #} +{% if output_preview %} +
+
Output
+ {% if output_media_type == 'video' %} + + {% elif output_media_type == 'image' %} + + {% elif output_media_type == 'audio' %} + + {% endif %} +
+{% elif ipfs_cid %} +
+
Output (IPFS)
+ +
+{% endif %} + +{# Output link #} +{% if ipfs_cid %} + + {{ ipfs_cid[:24] }}... + View + +{% elif has_cached and cache_id %} + + {{ cache_id[:24] }}... + View + +{% endif %} + +{# Input media previews #} +{% if inputs %} + +{% endif %} + +{# Parameters/Config #} +{% if config %} +
+
Parameters
+
+ {% for key, value in config.items() %} +
+ {{ key }}: + {{ value if value is string else value|tojson }} +
+ {% endfor %} +
+
+{% endif %} + +{# Metadata #} +
+
Step ID: {{ step.step_id[:32] }}...
+
Cache ID: {{ cache_id[:32] }}...
+
diff --git a/app/templates/storage/list.html b/app/templates/storage/list.html new file mode 100644 index 0000000..a33f98a --- /dev/null +++ b/app/templates/storage/list.html @@ -0,0 +1,90 @@ +{% extends "base.html" %} + +{% block title %}Storage Providers - Art-DAG L1{% endblock %} + +{% block content %} +
+

Storage Providers

+ +

+ Configure your IPFS pinning services. Data is pinned to your accounts, giving you full control. +

+ + + + + + {% if storages %} +

Your Storage Providers

+
+ {% for storage in storages %} + {% set info = providers_info.get(storage.provider_type, {'name': storage.provider_type, 'color': 'gray'}) %} +
+
+
+ {{ storage.provider_name or info.name }} + {% if storage.is_active %} + Active + {% else %} + Inactive + {% endif %} +
+
+ + +
+
+ +
+
+ Capacity: + {{ storage.capacity_gb }} GB +
+
+ Used: + {{ (storage.used_bytes / 1024 / 1024 / 1024) | round(2) }} GB +
+
+ Pins: + {{ storage.pin_count }} +
+
+ +
+
+ {% endfor %} +
+ {% else %} +
+

No storage providers configured yet.

+

Click on a provider above to add your first one.

+
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/storage/type.html b/app/templates/storage/type.html new file mode 100644 index 0000000..851c633 --- /dev/null +++ b/app/templates/storage/type.html @@ -0,0 +1,152 @@ +{% extends "base.html" %} + +{% block title %}{{ provider_info.name }} - Storage - Art-DAG L1{% endblock %} + +{% block content %} +
+
+ ← All Providers +

{{ provider_info.name }}

+
+ +

{{ provider_info.desc }}

+ + +
+

Add {{ provider_info.name }} Account

+ +
+ + +
+ + +
+ + {% if provider_type == 'pinata' %} +
+
+ + +
+
+ + +
+
+ + {% elif provider_type in ['web3storage', 'nftstorage'] %} +
+ + +
+ + {% elif provider_type == 'infura' %} +
+
+ + +
+
+ + +
+
+ + {% elif provider_type in ['filebase', 'storj'] %} +
+
+ + +
+
+ + +
+
+
+ + +
+ + {% elif provider_type == 'local' %} +
+ + +
+ {% endif %} + +
+ + +
+ +
+ +
+ +
+
+
+ + + {% if storages %} +

Configured Accounts

+
+ {% for storage in storages %} +
+
+
+ {{ storage.provider_name }} + {% if storage.is_active %} + Active + {% endif %} +
+
+ + +
+
+ + {% if storage.config_display %} +
+ {% for key, value in storage.config_display.items() %} + {{ key }}: {{ value }} + {% endfor %} +
+ {% endif %} + +
+
+ {% endfor %} +
+ {% endif %} +
+{% endblock %} diff --git a/app/types.py b/app/types.py new file mode 100644 index 0000000..15d81c6 --- /dev/null +++ b/app/types.py @@ -0,0 +1,197 @@ +""" +Type definitions for Art DAG L1 server. + +Uses TypedDict for configuration structures to enable mypy checking. +""" + +from typing import Any, Dict, List, Optional, TypedDict, Union +from typing_extensions import NotRequired + + +# === Node Config Types === + +class SourceConfig(TypedDict, total=False): + """Config for SOURCE nodes.""" + cid: str # Content ID (IPFS CID or SHA3-256 hash) + asset: str # Asset name from registry + input: bool # True if this is a variable input + name: str # Human-readable name for variable inputs + description: str # Description for variable inputs + + +class EffectConfig(TypedDict, total=False): + """Config for EFFECT nodes.""" + effect: str # Effect name + cid: str # Effect CID (for cached/IPFS effects) + # Effect parameters are additional keys + intensity: float + level: float + + +class SequenceConfig(TypedDict, total=False): + """Config for SEQUENCE nodes.""" + transition: Dict[str, Any] # Transition config + + +class SegmentConfig(TypedDict, total=False): + """Config for SEGMENT nodes.""" + start: float + end: float + duration: float + + +# Union of all config types +NodeConfig = Union[SourceConfig, EffectConfig, SequenceConfig, SegmentConfig, Dict[str, Any]] + + +# === Node Types === + +class CompiledNode(TypedDict): + """Node as produced by the S-expression compiler.""" + id: str + type: str # "SOURCE", "EFFECT", "SEQUENCE", etc. + config: Dict[str, Any] + inputs: List[str] + name: NotRequired[str] + + +class TransformedNode(TypedDict): + """Node after transformation for artdag execution.""" + node_id: str + node_type: str + config: Dict[str, Any] + inputs: List[str] + name: NotRequired[str] + + +# === DAG Types === + +class CompiledDAG(TypedDict): + """DAG as produced by the S-expression compiler.""" + nodes: List[CompiledNode] + output: str + + +class TransformedDAG(TypedDict): + """DAG after transformation for artdag execution.""" + nodes: Dict[str, TransformedNode] + output_id: str + metadata: NotRequired[Dict[str, Any]] + + +# === Registry Types === + +class AssetEntry(TypedDict, total=False): + """Asset in the recipe registry.""" + cid: str + url: str + + +class EffectEntry(TypedDict, total=False): + """Effect in the recipe registry.""" + cid: str + url: str + temporal: bool + + +class Registry(TypedDict): + """Recipe registry containing assets and effects.""" + assets: Dict[str, AssetEntry] + effects: Dict[str, EffectEntry] + + +# === Visualization Types === + +class VisNodeData(TypedDict, total=False): + """Data for a visualization node (Cytoscape.js format).""" + id: str + label: str + nodeType: str + isOutput: bool + + +class VisNode(TypedDict): + """Visualization node wrapper.""" + data: VisNodeData + + +class VisEdgeData(TypedDict): + """Data for a visualization edge.""" + source: str + target: str + + +class VisEdge(TypedDict): + """Visualization edge wrapper.""" + data: VisEdgeData + + +class VisualizationDAG(TypedDict): + """DAG structure for Cytoscape.js visualization.""" + nodes: List[VisNode] + edges: List[VisEdge] + + +# === Recipe Types === + +class Recipe(TypedDict, total=False): + """Compiled recipe structure.""" + name: str + version: str + description: str + owner: str + registry: Registry + dag: CompiledDAG + recipe_id: str + ipfs_cid: str + sexp: str + step_count: int + error: str + + +# === API Request/Response Types === + +class RecipeRunInputs(TypedDict): + """Mapping of input names to CIDs for recipe execution.""" + # Keys are input names, values are CIDs + pass # Actually just Dict[str, str] + + +class RunResult(TypedDict, total=False): + """Result of a recipe run.""" + run_id: str + status: str # "pending", "running", "completed", "failed" + recipe: str + recipe_name: str + inputs: List[str] + output_cid: str + ipfs_cid: str + provenance_cid: str + error: str + created_at: str + completed_at: str + actor_id: str + celery_task_id: str + output_name: str + + +# === Helper functions for type narrowing === + +def is_source_node(node: TransformedNode) -> bool: + """Check if node is a SOURCE node.""" + return node.get("node_type") == "SOURCE" + + +def is_effect_node(node: TransformedNode) -> bool: + """Check if node is an EFFECT node.""" + return node.get("node_type") == "EFFECT" + + +def is_variable_input(config: Dict[str, Any]) -> bool: + """Check if a SOURCE node config represents a variable input.""" + return bool(config.get("input")) + + +def get_effect_cid(config: Dict[str, Any]) -> Optional[str]: + """Get effect CID from config, checking both 'cid' and 'hash' keys.""" + return config.get("cid") or config.get("hash") diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/http_signatures.py b/app/utils/http_signatures.py new file mode 100644 index 0000000..da1f105 --- /dev/null +++ b/app/utils/http_signatures.py @@ -0,0 +1,84 @@ +"""HTTP Signature verification for incoming AP-style inbox requests. + +Implements the same RSA-SHA256 / PKCS1v15 scheme used by the coop's +shared/utils/http_signatures.py, but only the verification side. +""" +from __future__ import annotations + +import base64 +import re + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding + + +def verify_request_signature( + public_key_pem: str, + signature_header: str, + method: str, + path: str, + headers: dict[str, str], +) -> bool: + """Verify an incoming HTTP Signature. + + Args: + public_key_pem: PEM-encoded public key of the sender. + signature_header: Value of the ``Signature`` header. + method: HTTP method (GET, POST, etc.). + path: Request path (e.g. ``/inbox``). + headers: All request headers (case-insensitive keys). + + Returns: + True if the signature is valid. + """ + parts = _parse_signature_header(signature_header) + signed_headers = parts.get("headers", "date").split() + signature_b64 = parts.get("signature", "") + + # Reconstruct the signed string + lc_headers = {k.lower(): v for k, v in headers.items()} + lines: list[str] = [] + for h in signed_headers: + if h == "(request-target)": + lines.append(f"(request-target): {method.lower()} {path}") + else: + lines.append(f"{h}: {lc_headers.get(h, '')}") + + signed_string = "\n".join(lines) + + public_key = serialization.load_pem_public_key(public_key_pem.encode()) + try: + public_key.verify( + base64.b64decode(signature_b64), + signed_string.encode(), + padding.PKCS1v15(), + hashes.SHA256(), + ) + return True + except Exception: + return False + + +def parse_key_id(signature_header: str) -> str: + """Extract the keyId from a Signature header. + + keyId is typically ``https://domain/users/username#main-key``. + Returns the actor URL (strips ``#main-key``). + """ + parts = _parse_signature_header(signature_header) + key_id = parts.get("keyId", "") + return re.sub(r"#.*$", "", key_id) + + +def _parse_signature_header(header: str) -> dict[str, str]: + """Parse a Signature header into its component parts.""" + parts: dict[str, str] = {} + for part in header.split(","): + part = part.strip() + eq = part.find("=") + if eq < 0: + continue + key = part[:eq] + val = part[eq + 1:].strip('"') + parts[key] = val + return parts diff --git a/artdag-client.tar.gz b/artdag-client.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..a4ec7f4490aabd26f65cc0b6c59bf4fed443139d GIT binary patch literal 7982 zcma*a({>z+0szq1w$a#Ztj20=Hdf;aCbrSoR%6?0Cbn(coU_*bg1g@~c{BpV|3F^~ z+qtguCB9Y}2U>{etTkH5DP6?aW((HtPqZy16_Z%CwD_>1vZ&(G;gr)b2ySJA-yfkV zAhC)Mw2hejWSOz7xrsi-iy+&a1zgzH8qewCp8-3or|!4ltEv0AFVT<1%L{;eE+Egt zTTpPiLvUSRd1q6p>|PQ;hFNxzL03+Ue0wl}-c?+>QFy(5djG(dbx*w|^gQ^JtqZ5| z9Do(7k2>N^9wO|*sZ_Q}wAb|McN-5eEoJW{$vV$&KL0eK5R@J@W@Nus+oZn`3I+I^rA6 zdW5Fy#+=~ItF&p%i)_s4c0cInrU+&3{Jt`vjZGR)t68PY|TbZb`ty@v3? zK3B8v^tx$Jx78v<3h@C?&SY?YFZHW)E2BAON&Q2vG9VQiJh6Xk!C0yn;Ps!M^f(4M z#1?NZ-h`aM-X9u(tFhjJ@GhL_3PBPtxdEx1)HB@AiT47&ehv(mH`9v`RBLEi-%=vp zsn`_PjSSJgC6&lWslW-dFblV# zHxOhz(;#Xt`l^R-Zr&F@ODV&XRznwXO(`%vOkhU~jJ5i6x-AIqPY6Xb0p-lFKPs39 z{l7!%r@2%-yvkeyAw#w#DQ|>i0|b8}v7JZUUtPYB6r&Xw$dO&_vPR=k^^AjsadyAL z6?Nag#T{{ZQu<4CnYWfU*7p+>Xs*ic7|-mdIdtiB8mZM^XJ-0r&UndI{AxKjFC|QA zHM8M5Nzwl{uJxN{{stHIu?lt407G;psr;ECMIPjcHm`i%@X&$a^?H)E6}|%6KwC9?#vgytMM%HgDEnHO^;~ z71DJGl#?ODFbsIN6u<~OoQ2h6Gk+wT_$I`><%AUo2X$u{&<~3hQ~=J_<~LL))fMb? zwuQ)+rQ*MbEa^cM?;&j~M}+xq){95iXLT`9Yz2eST$8A&OAKW2(bvbfBOi!20l-d< zIOEhdt9T6LWtt@=918pxp|p2@?S8g7nRH?6@tyD|nX!9-3LqkiaytP&X+7fg=C4#W zauFzxdJtGkvHa;~00i}UWfofOgVr7(6e4J~or6;cO$oj6cW?rKC(~eC2fDdH3a46x zGKpU!NpWX}p9U+GiJ#^|AvtC+`d*Fj($z&}Q#AD{{H1WD=|z@|FBrK%^l!wn9=T%; zqPJJ%R?Z5O3&bzvXP%PPYfo7&V~>Rjk}rqCC@Y+-lsi!Jt)6^NI*~^?=1^KV$am<= z-=kpVe}t$(DFXCQ49cq<|4crLDds>HjE^05iDF@)E2Ri34PcCS->s5qXx<_>zkEla zS8J8#M$LVv7)B&Otz@w$e>R)`87bPn_QyZ%ci~VLT()`4!nEG_P#a7x4gfCuoSE%r z&g*@*g!qLDWvOnIta(^)$0@z{&pg!z@-lpQg?!8SoD=z}h?u5Y>TzP$Tkht^A&DlA zfGCNKR6>#~>__rwMl$>N9Cx0a#E5uO2hthtQm2+g$8@`N-=A?DW1ncI5XGZWMt9dy zbMIG9iNU0Mwf&%YCVBGdL24$F!FWhr6V~n(;mGa+(fRD84`d)&aAcO^?;7L`g9kQo z86b}1F~b6e=G0Wmd6~vO($`~BKc5p@a+$;<7!sBq3*}@ADjpCDkUug>!W?!=7Pm93 ztC#6pTPb5zyfv^Vf^S}jTE!98vV7;U>)8+f9Q1(qb!I0U1UMvR^$~^(X2C8Oo;I_s z?FdWZ&oIOFe>`N$?A%OtqSFN%KW~B8;8{ZnssFl5%1N3@V-!=|s)b^If;BFbyLM%w zD13gDwR`s%KB^h zzt)xwUvIhnJ+uy4Fa)UN?zy3UfS7WgI@hguFWuJ@UKG^)rvezRu8~I#U-6-m`tV7S zzxfVm-~?t%s8+hFu_^zu?Ry?S=UN?+eQ<6 zI3g611?$sXu*+D621fsYRP-3xBgqc-zMahueNiXa@4L4Y7>@QGx(?m-H2YaD~ow;o5-WO18CzXV|eo?jPJ8e4Y5`XErP&{@K>#;BNPMR z%P>2eBHB}-c8l^0rNwWd#_0mp#n(vCr+Ensw0oD1n6_i)z3E&p*d8ec-)+yOf}?`x z!0pTR2ziaW5xr6Pq>-Fa82KU0vFSNakw{kjZ$eY_VAuAs@>rq7b(@5LcwicmWU;$P zdntVq|L~_vEy%-o14<0>9eT=_&^eTfc!Q|-obChnR6vf7|Lw{6j8HEVb8bT+#ByH= z+Uh>ly;D6Skb>!^id7N9lj0oHECt{gINp#~d~j*|aieLBBVBBiaBgPt>}fOc7K+(~ zCrr(RCZ$T#1djQrlS^6YsLfg3^cww`>nuac)UCZ+Yh0z!q*sDV(Xk@QDc7J^+XdM$ zeuu7Htbq6(ArEWtvkO^xV$lhay?jpokFA4DpU(MODa8~>;)6Fo-F_}%R6~v_*pKio z#>{}j-xEBWN64YZkh!rJK#K|EXNXZ}AaAUS$dHSx9*XV=2O6WpzAFx_R{i9%9-IWD zb$e;3RdtLcK%CK>K9}V=Tqm9F|M*bBO_W5^f7oGJSZ2{o4gWe#2i!5|rHi zo<&3CqEp5R;qK@y$-?_+S#gP#0^6PEHIMY(1b&(v?I*@f5@-6pu&*W>QLQI#c)rW` z3nHhdb-tEp(jP)tp?j72k6DQW`Oi&8o&`;ZgqQ97JsKX$+}^UU<_M$37%$a>5-Uj~ zj8@Z=P|b3di*KPig^9;GSQx@@g%<%emVc8C`!Bh20vhIZy}C6ZEVsp^xJBXfZOS8; zP66!?h*+#f`^iaiE(p^6FPyOroX2V(EgokuX29jHbH`S%dkHCbcZL_*4mL zQ3C1M4JChL?8ym0VKC+$%stm2Nyia`;Lkkh#a9pV(Pt2F;bU!HA&DdMQLSBQ(`?!p zc4nWX1RI%G3SS7byhd@qlohSgOb6?H?g_Xg(&iKKP|;8d2Xs3LA~YPr=+W z2?)?S@4Hg{lF8cZkj-ur?ywMemnvn53lg&zqe36D{d`pr8<^v_A$-~$exegW(xt{jJQ8vP&kv%Gnf}LlO)~&~5Fauqgj*BMzI91fr`EPsZ*Rv66NVzsyEQ6o*oT53c zg8NYFl+-mRVh1Q4dt6}cy-lftFO=L{N7*dGWN_e?fnRt*{Rb!94r+sh*2R=Hc)a?$ z4!$;ssgJ!Ezfudqs|8Arh)=OB6NlfSnAf?=7!Z~qt9GIurAVpRk(()dSjL2KuE#kWN%Vh!k zXNm=D{yGI6w!hX=^kE4FVz72FHI&B5WjMAXrg!7>`j0#t3B@j#h76jPnMK;2Z)x{C{bnC`y&uQ%zyS-QuX<7iY%(M z^j%J4dCGR)XVJy0K5BC5E(>Ua#zZ z0GI~&W4%h$MrM&g8^AbsDKnf!_3@7vNwWJNc4#ybJ-|(5kaVX)l~!n9u%uLp}pWM2`r1CgUns|^Np{qbL+wLk>+Q(e%X;=;45Uen*^_kxDHyG>emu$P_ z&|w~Q2HMo02UmW@TKpwR(lrn|RcXFIhj(dw{e^f?c!ldv@C{EYR=7Vdvb`Jc;1vr)Yja^at^) ztKsoqH`8P{{-tv$IJZAZzJ@nqgtY+E{+=matr5qo6eo8 zc~M~<^XdN7w?|c6GKp%~`-?jgx!a5iKI?QXBQq6e_-ru#@d$4uZJ*XA|EPal-4z~& zUK-=qO&&I_=KMIRh^O#aqAh(U-u{dBe%!H`@hyygBe@^9wC~?D&GBH)+G}0vr&jmg zD{uOlT*qRy2VVOxhKF1hvFoeK1hC|t`?sjg z6bP&zv0I>E(Ft!3Jy7MPBC1M0zN4wIXus-Q(t|NIhyKw?RgwQ9xdglia6Q44n{Nhjd3~rUz@EL)+vE>f%YH~UA}CruqZ#pLjBFxL)GFy zIfq0A%tI!iJ`l0e3Z`2(xZkg|q!MHF(NkCiu5~Ke+koWz%O4I~qrYL=Tr$Kn!ALPl zmc{zKg&F=Ju|k7p2R+rNS53FqP5as9>A73E7c8m&6OwIkZH7_Jaf7Ch^!%_DmW5eg}smE-mSLoV1Nr zE$A*qpr9@DP*IJqV1P3gc{b%q&JxXZX#kmHZ&Mbm*9fFV={%?f^CE&(ndefqFq0c9 zj%r;E4#Z>MCmTX7dI#8NDNPPr`(%g~&W_qF8)5P_O=)w;0G=gz5W{zAuLIXxqhGa0 zzl;1ZTS@fVP%8;=8O%Wh>)IyKCSi_@yKb*CM3hmda~pTl*~9w1B>kqYGvT=jOZ`*f)a4m|=?TXdT+-dfiz;Yk4AR#6VaWIE&CnvuR z!`4({91X|g#(r)3uLmmqZ%5jr$sU_8wCE6{1veSjS|)p&?&F9_8?!b@>MdM4Q*fE; zG0J(~UF6k|!APmp!}3jJ@=aFl#jdxR$J!D`WX9es3r~w;6)$>3FvZBSgVpNY58-A{O8lU-fX7 zdi>RP6oq0_QE z9KlJpAYHYU8)dGU*elGT04J0_MFZ>gs%~?0s~{t#YxFuQg!MWBa@teLH6LAB+<4!wg)dyo>9 z>nhCCGT`GaN13LPCYS9S54uX_;cL&IP!-B&q?c^yHO;U`cU!xAH74U~1POL&4jA@y zS1ZIPtrE5*3b1u1>tA3jW707}(z*ebB#(tB@eI@2yZc8Ild8iChH{5@68g!S8bX ziDs{2rih8_^fd0okA5Cz+rsEVJr-i3zrWP{2>;%PRB}arD4%6Qk@gW7r`h^%3LoC7 zg>;g~PAyiK72VUF%Ne@5t|T4SU^bTGACT&#N-;UydG*;e4V`Lv>t@7terrvruu4u< zSz7xZKCJMom7k5wqxydL+Vfji*C3ox*Uon0vqCrsS-`62x#|w|Lb?y~tdR(3$Ba5G zn&KoZ!5rvS&bi2e#s(YF337{6!W53;@8-krm7klXN5R&!&;rxV|FFm(7J&-3EO}pL z8_Xvy%O;p4t9dar1e}22gWEd_xiN|H9$Z$TJ(~&v>)dy5N6IyB?|iY1O9VYKyM3HB zk{5|dD3??qcjUi^(+Po18;AU@(v58 z^pgixLf?yKC~}Hbx09WDkXwt7owDF(mm{1)liqqwEV~XhA&AiD8KQTYmxaq{>a3*V zzl{6Tr;JH~c?oJ`7mCot7|5V=hHTp(7^cCjjAL+Q0^nNWo-UAxRzj{LeJxl0AE{@$ z3Y&o`kQxjJMD8NUeTxQo5qP0=`w*PDZ+wFjm&spUgQXCMXd)XI84tFM@&71YjU45KZ^M5eJ0jwx)nEZ`y2+GLrfPH7eljo zPE#A3mey{EMd9nVxySDtJ{?+8t1vP1s>cZ-;HHijPCWY{GTBA>FZB1uaz_ebGR9IU z8kASCt>h8rGh}|xPL|u7hSt9rrEV|x#jHz$+!3L6ugGDzXypc7V5zcuu>RI@)X52{ znndr1v*0+Z%}n50>G9y+$Q5M;hdpkQ1C^D#KHV1>_7171*juOm^U+8ff7%gfR*xIx zex8G*x?q@!m%oW~-qpNAw<%y;4O1{k*Q}|e6 zDcb+7re5+`Z1-( zFF=szH)T?%LH(_vApSp63NY|$*jEUFf$G`yTf=+2%XB`yE%zfF5t?=F^)E3bU-opU zbrd+Y5@^Mf*`!VpPTC+OKAW< zHGVz$%P$zB*v}#fLL-H4f_R!is(g-rk?_(L1D@F}_(54A#;f(VyhO@*QO`TalSHwi zUoZP4W*S`NW&p803Jl_UdOc!hnMD!mtI#xUD!#L>Q}hv;LWZ=uchkhu#;+qdW2}cY*4#=g69Z6+scm2e@nWIDb*Tt>)C>keXJG!(G>2A?pJ@kd1s%_h^s1w>t|D z&qBg7kJz4-?yioGK_lcKiAyZBTi)epLkb&sLK6>+f>meX#erzr>i_uHYzdmxA zt~y%um-G$%2OawROVJ`L7At&vgnKvVWt^8}D-4pn{E+INJp>?oY%|L1hw?%Iu$rj?A?^lr{~f1z|FhJ_4?J_v2$H;d+pBX}*&5UM;-$sNlq zq~<{!q)0V;j`NCWR!s7*DQLev{e@7w*~bvtT$R9g=ppm9`v7Ra;5~dYT_|sAW^mHk z^`e|hU}B%$0ytc}IXIfUKx&S93m~@i#V|gvd-)P{OTgO<4ku;U1hwR&ZRL!ZZRh%R zNg`YyRC{}7k!?TYXJ~)1Ky&uSv7J|Uot;Kp0lZk)Sk|<07?3&fw}%#{Cb`#&6%RzM z@Js5Ag09|Z`tUMYzQ44L5n|u0+JyU9(PSF&YDNFaQZyV}{!uzoL^(Vm1mQu-7d*g_ zzuMlesxJ?1>Q-B*|MOmtBUx8`P>`CVxhYY zhiXpS!AwP!ojw(FuusfK$djFL^uSjA3}XU}5oZOJh|rp%?bj2#Wu&`W%Mt?<8tC zTT@4&r?Bz}@rBWBeG2spCmKt{Shf)bR;m@v#&Y=^7uY@1=#%?!gFA~URz~dNndRJq fs@%m!FRke_UExZr$NxJW$6!bxm-!b22m<0iX<@vG literal 0 HcmV?d00001 diff --git a/build-client.sh b/build-client.sh new file mode 100755 index 0000000..c9443b6 --- /dev/null +++ b/build-client.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Build the artdag-client tarball +# This script is run during deployment to create the downloadable client package + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLIENT_REPO="https://git.rose-ash.com/art-dag/client.git" +TEMP_DIR=$(mktemp -d) +OUTPUT_FILE="$SCRIPT_DIR/artdag-client.tar.gz" + +echo "Building artdag-client.tar.gz..." + +# Clone the client repo +git clone --depth 1 "$CLIENT_REPO" "$TEMP_DIR/artdag-client" 2>/dev/null || { + echo "Failed to clone client repo, trying alternative..." + # Try GitHub if internal git fails + git clone --depth 1 "https://github.com/gilesbradshaw/art-client.git" "$TEMP_DIR/artdag-client" 2>/dev/null || { + echo "Error: Could not clone client repository" + rm -rf "$TEMP_DIR" + exit 1 + } +} + +# Remove .git directory +rm -rf "$TEMP_DIR/artdag-client/.git" +rm -rf "$TEMP_DIR/artdag-client/__pycache__" + +# Create tarball +cd "$TEMP_DIR" +tar -czf "$OUTPUT_FILE" artdag-client + +# Cleanup +rm -rf "$TEMP_DIR" + +echo "Created: $OUTPUT_FILE" +ls -lh "$OUTPUT_FILE" diff --git a/cache_manager.py b/cache_manager.py new file mode 100644 index 0000000..9474ca2 --- /dev/null +++ b/cache_manager.py @@ -0,0 +1,872 @@ +# art-celery/cache_manager.py +""" +Cache management for Art DAG L1 server. + +Integrates artdag's Cache, ActivityStore, and ActivityManager to provide: +- Content-addressed caching with both node_id and cid +- Activity tracking for runs (input/output/intermediate relationships) +- Deletion rules enforcement (shared items protected) +- L2 ActivityPub integration for "shared" status checks +- IPFS as durable backing store (local cache as hot storage) +- Redis-backed indexes for multi-worker consistency +""" + +import hashlib +import json +import logging +import os +import shutil +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING + +import requests + +if TYPE_CHECKING: + import redis + +from artdag import Cache, CacheEntry, DAG, Node, NodeType +from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + +import ipfs_client + +logger = logging.getLogger(__name__) + + +def file_hash(path: Path, algorithm: str = "sha3_256") -> str: + """Compute local content hash (fallback when IPFS unavailable).""" + hasher = hashlib.new(algorithm) + actual_path = path.resolve() if path.is_symlink() else path + with open(actual_path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class CachedFile: + """ + A cached file with both identifiers. + + Provides a unified view combining: + - node_id: computation identity (for DAG caching) + - cid: file content identity (for external references) + """ + node_id: str + cid: str + path: Path + size_bytes: int + node_type: str + created_at: float + + @classmethod + def from_cache_entry(cls, entry: CacheEntry) -> "CachedFile": + return cls( + node_id=entry.node_id, + cid=entry.cid, + path=entry.output_path, + size_bytes=entry.size_bytes, + node_type=entry.node_type, + created_at=entry.created_at, + ) + + +class L2SharedChecker: + """ + Checks if content is shared (published) via L2 ActivityPub server. + + Caches results to avoid repeated API calls. + """ + + def __init__(self, l2_server: str, cache_ttl: int = 300): + self.l2_server = l2_server + self.cache_ttl = cache_ttl + self._cache: Dict[str, tuple[bool, float]] = {} + + def is_shared(self, cid: str) -> bool: + """Check if cid has been published to L2.""" + import time + now = time.time() + + # Check cache + if cid in self._cache: + is_shared, cached_at = self._cache[cid] + if now - cached_at < self.cache_ttl: + logger.debug(f"L2 check (cached): {cid[:16]}... = {is_shared}") + return is_shared + + # Query L2 + try: + url = f"{self.l2_server}/assets/by-hash/{cid}" + logger.info(f"L2 check: GET {url}") + resp = requests.get(url, timeout=5) + logger.info(f"L2 check response: {resp.status_code}") + is_shared = resp.status_code == 200 + except Exception as e: + logger.warning(f"Failed to check L2 for {cid}: {e}") + # On error, assume IS shared (safer - prevents accidental deletion) + is_shared = True + + self._cache[cid] = (is_shared, now) + return is_shared + + def invalidate(self, cid: str): + """Invalidate cache for a cid (call after publishing).""" + self._cache.pop(cid, None) + + def mark_shared(self, cid: str): + """Mark as shared without querying (call after successful publish).""" + import time + self._cache[cid] = (True, time.time()) + + +class L1CacheManager: + """ + Unified cache manager for Art DAG L1 server. + + Combines: + - artdag Cache for file storage + - ActivityStore for run tracking + - ActivityManager for deletion rules + - L2 integration for shared status + + Provides both node_id and cid based access. + """ + + def __init__( + self, + cache_dir: Path | str, + l2_server: str = "http://localhost:8200", + redis_client: Optional["redis.Redis"] = None, + ): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Redis for shared state between workers + self._redis = redis_client + self._redis_content_key = "artdag:content_index" + self._redis_ipfs_key = "artdag:ipfs_index" + + # artdag components + self.cache = Cache(self.cache_dir / "nodes") + self.activity_store = ActivityStore(self.cache_dir / "activities") + + # L2 shared checker + self.l2_checker = L2SharedChecker(l2_server) + + # Activity manager with L2-based is_shared + self.activity_manager = ActivityManager( + cache=self.cache, + activity_store=self.activity_store, + is_shared_fn=self._is_shared_by_node_id, + ) + + # Legacy files directory (for files uploaded directly by cid) + self.legacy_dir = self.cache_dir / "legacy" + self.legacy_dir.mkdir(parents=True, exist_ok=True) + + # ============ Redis Index (no JSON files) ============ + # + # Content index maps: CID (content hash or IPFS CID) -> node_id (code hash) + # IPFS index maps: node_id -> IPFS CID + # + # Database is the ONLY source of truth for cache_id -> ipfs_cid mapping. + # No fallbacks - failures raise exceptions. + + def _run_async(self, coro): + """Run async coroutine from sync context. + + Always creates a fresh event loop to avoid issues with Celery's + prefork workers where loops may be closed by previous tasks. + """ + import asyncio + + # Check if we're already in an async context + try: + asyncio.get_running_loop() + # We're in an async context - use a thread with its own loop + import threading + result = [None] + error = [None] + + def run_in_thread(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + result[0] = new_loop.run_until_complete(coro) + finally: + new_loop.close() + except Exception as e: + error[0] = e + + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join(timeout=30) + if error[0]: + raise error[0] + return result[0] + except RuntimeError: + # No running loop - create a fresh one (don't reuse potentially closed loops) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + def _set_content_index(self, cache_id: str, ipfs_cid: str): + """Set content index entry in database (cache_id -> ipfs_cid).""" + import database + + async def save_to_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + await conn.execute( + """ + INSERT INTO cache_items (cid, ipfs_cid) + VALUES ($1, $2) + ON CONFLICT (cid) DO UPDATE SET ipfs_cid = $2 + """, + cache_id, ipfs_cid + ) + finally: + await conn.close() + + self._run_async(save_to_db()) + logger.info(f"Indexed in database: {cache_id[:16]}... -> {ipfs_cid}") + + def _get_content_index(self, cache_id: str) -> Optional[str]: + """Get content index entry (cache_id -> ipfs_cid) from database.""" + import database + + async def get_from_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + row = await conn.fetchrow( + "SELECT ipfs_cid FROM cache_items WHERE cid = $1", + cache_id + ) + return {"ipfs_cid": row["ipfs_cid"]} if row else None + finally: + await conn.close() + + result = self._run_async(get_from_db()) + if result and result.get("ipfs_cid"): + return result["ipfs_cid"] + return None + + def _del_content_index(self, cache_id: str): + """Delete content index entry from database.""" + import database + + async def delete_from_db(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + await conn.execute("DELETE FROM cache_items WHERE cid = $1", cache_id) + finally: + await conn.close() + + self._run_async(delete_from_db()) + + def _set_ipfs_index(self, cid: str, ipfs_cid: str): + """Set IPFS index entry in Redis.""" + if self._redis: + try: + self._redis.hset(self._redis_ipfs_key, cid, ipfs_cid) + except Exception as e: + logger.warning(f"Failed to set IPFS index in Redis: {e}") + + def _get_ipfs_cid_from_index(self, cid: str) -> Optional[str]: + """Get IPFS CID from Redis.""" + if self._redis: + try: + val = self._redis.hget(self._redis_ipfs_key, cid) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Failed to get IPFS CID from Redis: {e}") + return None + + def get_ipfs_cid(self, cid: str) -> Optional[str]: + """Get IPFS CID for a content hash.""" + return self._get_ipfs_cid_from_index(cid) + + def _is_shared_by_node_id(self, cid: str) -> bool: + """Check if a cid is shared via L2.""" + return self.l2_checker.is_shared(cid) + + def _load_meta(self, cid: str) -> dict: + """Load metadata for a cached file.""" + meta_path = self.cache_dir / f"{cid}.meta.json" + if meta_path.exists(): + with open(meta_path) as f: + return json.load(f) + return {} + + def is_pinned(self, cid: str) -> tuple[bool, str]: + """ + Check if a cid is pinned (non-deletable). + + Returns: + (is_pinned, reason) tuple + """ + meta = self._load_meta(cid) + if meta.get("pinned"): + return True, meta.get("pin_reason", "published") + return False, "" + + def _save_meta(self, cid: str, **updates) -> dict: + """Save/update metadata for a cached file.""" + meta = self._load_meta(cid) + meta.update(updates) + meta_path = self.cache_dir / f"{cid}.meta.json" + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + return meta + + def pin(self, cid: str, reason: str = "published") -> None: + """Mark an item as pinned (non-deletable).""" + self._save_meta(cid, pinned=True, pin_reason=reason) + + # ============ File Storage ============ + + def put( + self, + source_path: Path, + node_type: str = "upload", + node_id: str = None, + cache_id: str = None, + execution_time: float = 0.0, + move: bool = False, + skip_ipfs: bool = False, + ) -> tuple[CachedFile, Optional[str]]: + """ + Store a file in the cache and optionally upload to IPFS. + + Files are stored by IPFS CID when skip_ipfs=False (default), or by + local content hash when skip_ipfs=True. The cache_id parameter creates + an index from cache_id -> CID for code-addressed lookups. + + Args: + source_path: Path to file to cache + node_type: Type of node (e.g., "upload", "source", "effect") + node_id: DEPRECATED - ignored, always uses CID + cache_id: Optional code-addressed cache ID to index + execution_time: How long the operation took + move: If True, move instead of copy + skip_ipfs: If True, skip IPFS upload and use local hash (faster for large files) + + Returns: + Tuple of (CachedFile with both node_id and cid, CID or None if skip_ipfs) + """ + if skip_ipfs: + # Use local content hash instead of IPFS CID (much faster) + cid = file_hash(source_path) + ipfs_cid = None + logger.info(f"put: Using local hash (skip_ipfs=True): {cid[:16]}...") + else: + # Upload to IPFS first to get the CID (primary identifier) + cid = ipfs_client.add_file(source_path) + if not cid: + raise RuntimeError(f"IPFS upload failed for {source_path}. IPFS is required.") + ipfs_cid = cid + + # Always store by IPFS CID (node_id parameter is deprecated) + node_id = cid + + # Check if already cached (by node_id) + existing = self.cache.get_entry(node_id) + if existing and existing.output_path.exists(): + return CachedFile.from_cache_entry(existing), ipfs_cid + + # Compute local hash BEFORE moving the file (for dual-indexing) + # Only needed if we uploaded to IPFS (to map local hash -> IPFS CID) + local_hash = None + if not skip_ipfs and self._is_ipfs_cid(cid): + local_hash = file_hash(source_path) + + # Store in local cache + logger.info(f"put: Storing in cache with node_id={node_id[:16]}...") + self.cache.put( + node_id=node_id, + source_path=source_path, + node_type=node_type, + execution_time=execution_time, + move=move, + ) + + entry = self.cache.get_entry(node_id) + logger.info(f"put: After cache.put, get_entry(node_id={node_id[:16]}...) returned entry={entry is not None}, path={entry.output_path if entry else None}") + + # Verify we can retrieve it + verify_path = self.cache.get(node_id) + logger.info(f"put: Verify cache.get(node_id={node_id[:16]}...) = {verify_path}") + + # Index by cache_id if provided (code-addressed cache lookup) + # This allows get_by_cid(cache_id) to find files stored by IPFS CID + if cache_id and cache_id != cid: + self._set_content_index(cache_id, cid) + logger.info(f"put: Indexed cache_id {cache_id[:16]}... -> IPFS {cid}") + + # Also index by local hash for content-based lookup + if local_hash and local_hash != cid: + self._set_content_index(local_hash, cid) + logger.debug(f"Indexed local hash {local_hash[:16]}... -> IPFS {cid}") + + logger.info(f"Cached: {cid[:16]}..." + (" (local only)" if skip_ipfs else " (IPFS)")) + + return CachedFile.from_cache_entry(entry), ipfs_cid if not skip_ipfs else None + + def get_by_node_id(self, node_id: str) -> Optional[Path]: + """Get cached file path by node_id.""" + return self.cache.get(node_id) + + def _is_ipfs_cid(self, identifier: str) -> bool: + """Check if identifier looks like an IPFS CID.""" + # CIDv0 starts with "Qm", CIDv1 starts with "bafy" or other multibase prefixes + return identifier.startswith("Qm") or identifier.startswith("bafy") or identifier.startswith("baf") + + def get_by_cid(self, cid: str) -> Optional[Path]: + """Get cached file path by cid or IPFS CID. Falls back to IPFS if not in local cache.""" + logger.info(f"get_by_cid: Looking for cid={cid[:16]}...") + + # Check index first (Redis then local) + node_id = self._get_content_index(cid) + logger.info(f"get_by_cid: Index lookup returned node_id={node_id[:16] if node_id else None}...") + if node_id: + path = self.cache.get(node_id) + logger.info(f"get_by_cid: cache.get(node_id={node_id[:16]}...) returned path={path}") + if path and path.exists(): + logger.info(f"get_by_cid: Found via index: {path}") + return path + + # artdag Cache doesn't know about entry - check filesystem directly + # Files are stored at {cache_dir}/nodes/{node_id}/output.* + nodes_dir = self.cache_dir / "nodes" / node_id + if nodes_dir.exists(): + for f in nodes_dir.iterdir(): + if f.name.startswith("output."): + logger.info(f"get_by_cid: Found on filesystem: {f}") + return f + + # For uploads, node_id == cid, so try direct lookup + # This works even if cache index hasn't been reloaded + path = self.cache.get(cid) + logger.info(f"get_by_cid: Direct cache.get({cid[:16]}...) returned: {path}") + if path and path.exists(): + self._set_content_index(cid, cid) + return path + + # Check filesystem directly for cid as node_id + nodes_dir = self.cache_dir / "nodes" / cid + if nodes_dir.exists(): + for f in nodes_dir.iterdir(): + if f.name.startswith("output."): + logger.info(f"get_by_cid: Found on filesystem (direct): {f}") + self._set_content_index(cid, cid) + return f + + # Scan cache entries (fallback for new structure) + entry = self.cache.find_by_cid(cid) + logger.info(f"get_by_cid: find_by_cid({cid[:16]}...) returned entry={entry}") + if entry and entry.output_path.exists(): + logger.info(f"get_by_cid: Found via scan: {entry.output_path}") + self._set_content_index(cid, entry.node_id) + return entry.output_path + + # Check legacy location (files stored directly as CACHE_DIR/{cid}) + legacy_path = self.cache_dir / cid + logger.info(f"get_by_cid: Checking legacy path: {legacy_path} exists={legacy_path.exists()}") + if legacy_path.exists() and legacy_path.is_file(): + logger.info(f"get_by_cid: Found at legacy path: {legacy_path}") + return legacy_path + + # Fetch from IPFS - this is the source of truth for all content + if self._is_ipfs_cid(cid): + logger.info(f"get_by_cid: Fetching from IPFS: {cid[:16]}...") + recovery_path = self.legacy_dir / cid + recovery_path.parent.mkdir(parents=True, exist_ok=True) + if ipfs_client.get_file(cid, str(recovery_path)): + logger.info(f"get_by_cid: Fetched from IPFS: {recovery_path}") + self._set_content_index(cid, cid) + return recovery_path + else: + logger.warning(f"get_by_cid: IPFS fetch failed for {cid[:16]}...") + + # Also try with a mapped IPFS CID if different from cid + ipfs_cid = self._get_ipfs_cid_from_index(cid) + if ipfs_cid and ipfs_cid != cid: + logger.info(f"get_by_cid: Fetching from IPFS via mapping: {ipfs_cid[:16]}...") + recovery_path = self.legacy_dir / cid + recovery_path.parent.mkdir(parents=True, exist_ok=True) + if ipfs_client.get_file(ipfs_cid, str(recovery_path)): + logger.info(f"get_by_cid: Fetched from IPFS: {recovery_path}") + return recovery_path + + return None + + def has_content(self, cid: str) -> bool: + """Check if content exists in cache.""" + return self.get_by_cid(cid) is not None + + def get_entry_by_cid(self, cid: str) -> Optional[CacheEntry]: + """Get cache entry by cid.""" + node_id = self._get_content_index(cid) + if node_id: + return self.cache.get_entry(node_id) + return self.cache.find_by_cid(cid) + + def list_all(self) -> List[CachedFile]: + """List all cached files.""" + files = [] + seen_hashes = set() + + # New cache structure entries + for entry in self.cache.list_entries(): + files.append(CachedFile.from_cache_entry(entry)) + if entry.cid: + seen_hashes.add(entry.cid) + + # Legacy files stored directly in cache_dir (old structure) + # These are files named by cid directly in CACHE_DIR + for f in self.cache_dir.iterdir(): + # Skip directories and special files + if not f.is_file(): + continue + # Skip metadata/auxiliary files + if f.suffix in ('.json', '.mp4'): + continue + # Skip if name doesn't look like a hash (64 hex chars) + if len(f.name) != 64 or not all(c in '0123456789abcdef' for c in f.name): + continue + # Skip if already seen via new cache + if f.name in seen_hashes: + continue + + files.append(CachedFile( + node_id=f.name, + cid=f.name, + path=f, + size_bytes=f.stat().st_size, + node_type="legacy", + created_at=f.stat().st_mtime, + )) + seen_hashes.add(f.name) + + return files + + def list_by_type(self, node_type: str) -> List[str]: + """ + List CIDs of all cached files of a specific type. + + Args: + node_type: Type to filter by (e.g., "recipe", "upload", "effect") + + Returns: + List of CIDs (IPFS CID if available, otherwise node_id) + """ + cids = [] + for entry in self.cache.list_entries(): + if entry.node_type == node_type: + # Return node_id which is the IPFS CID for uploaded content + cids.append(entry.node_id) + return cids + + # ============ Activity Tracking ============ + + def record_activity(self, dag: DAG, run_id: str = None) -> Activity: + """ + Record a DAG execution as an activity. + + Args: + dag: The executed DAG + run_id: Optional run ID to use as activity_id + + Returns: + The created Activity + """ + activity = Activity.from_dag(dag, activity_id=run_id) + self.activity_store.add(activity) + return activity + + def record_simple_activity( + self, + input_hashes: List[str], + output_cid: str, + run_id: str = None, + ) -> Activity: + """ + Record a simple (non-DAG) execution as an activity. + + For legacy single-effect runs that don't use full DAG execution. + Uses cid as node_id. + """ + activity = Activity( + activity_id=run_id or str(hash((tuple(input_hashes), output_cid))), + input_ids=sorted(input_hashes), + output_id=output_cid, + intermediate_ids=[], + created_at=datetime.now(timezone.utc).timestamp(), + status="completed", + ) + self.activity_store.add(activity) + return activity + + def get_activity(self, activity_id: str) -> Optional[Activity]: + """Get activity by ID.""" + return self.activity_store.get(activity_id) + + def list_activities(self) -> List[Activity]: + """List all activities.""" + return self.activity_store.list() + + def find_activities_by_inputs(self, input_hashes: List[str]) -> List[Activity]: + """Find activities with matching inputs (for UI grouping).""" + return self.activity_store.find_by_input_ids(input_hashes) + + # ============ Deletion Rules ============ + + def can_delete(self, cid: str) -> tuple[bool, str]: + """ + Check if a cached item can be deleted. + + Returns: + (can_delete, reason) tuple + """ + # Check if pinned (published or input to published) + pinned, reason = self.is_pinned(cid) + if pinned: + return False, f"Item is pinned ({reason})" + + # Find node_id for this content + node_id = self._get_content_index(cid) or cid + + # Check if it's an input or output of any activity + for activity in self.activity_store.list(): + if node_id in activity.input_ids: + return False, f"Item is input to activity {activity.activity_id}" + if node_id == activity.output_id: + return False, f"Item is output of activity {activity.activity_id}" + + return True, "OK" + + def can_discard_activity(self, activity_id: str) -> tuple[bool, str]: + """ + Check if an activity can be discarded. + + Returns: + (can_discard, reason) tuple + """ + activity = self.activity_store.get(activity_id) + if not activity: + return False, "Activity not found" + + # Check if any item is pinned + for node_id in activity.all_node_ids: + entry = self.cache.get_entry(node_id) + if entry: + pinned, reason = self.is_pinned(entry.cid) + if pinned: + return False, f"Item {node_id} is pinned ({reason})" + + return True, "OK" + + def delete_by_cid(self, cid: str) -> tuple[bool, str]: + """ + Delete a cached item by cid. + + Enforces deletion rules. + + Returns: + (success, message) tuple + """ + can_delete, reason = self.can_delete(cid) + if not can_delete: + return False, reason + + # Find and delete + node_id = self._get_content_index(cid) + if node_id: + self.cache.remove(node_id) + self._del_content_index(cid) + return True, "Deleted" + + # Try legacy + legacy_path = self.legacy_dir / cid + if legacy_path.exists(): + legacy_path.unlink() + return True, "Deleted (legacy)" + + return False, "Not found" + + def discard_activity(self, activity_id: str) -> tuple[bool, str]: + """ + Discard an activity and clean up its cache entries. + + Enforces deletion rules. + + Returns: + (success, message) tuple + """ + can_discard, reason = self.can_discard_activity(activity_id) + if not can_discard: + return False, reason + + success = self.activity_manager.discard_activity(activity_id) + if success: + return True, "Activity discarded" + return False, "Failed to discard" + + def _is_used_by_other_activities(self, node_id: str, exclude_activity_id: str) -> bool: + """Check if a node is used by any activity other than the excluded one.""" + for other_activity in self.activity_store.list(): + if other_activity.activity_id == exclude_activity_id: + continue + # Check if used as input, output, or intermediate + if node_id in other_activity.input_ids: + return True + if node_id == other_activity.output_id: + return True + if node_id in other_activity.intermediate_ids: + return True + return False + + def discard_activity_outputs_only(self, activity_id: str) -> tuple[bool, str]: + """ + Discard an activity, deleting only outputs and intermediates. + + Inputs (cache items, configs) are preserved. + Outputs/intermediates used by other activities are preserved. + + Returns: + (success, message) tuple + """ + activity = self.activity_store.get(activity_id) + if not activity: + return False, "Activity not found" + + # Check if output is pinned + if activity.output_id: + entry = self.cache.get_entry(activity.output_id) + if entry: + pinned, reason = self.is_pinned(entry.cid) + if pinned: + return False, f"Output is pinned ({reason})" + + deleted_outputs = 0 + preserved_shared = 0 + + # Delete output (only if not used by other activities) + if activity.output_id: + if self._is_used_by_other_activities(activity.output_id, activity_id): + preserved_shared += 1 + else: + entry = self.cache.get_entry(activity.output_id) + if entry: + # Remove from cache + self.cache.remove(activity.output_id) + # Remove from content index (Redis + local) + self._del_content_index(entry.cid) + # Delete from legacy dir if exists + legacy_path = self.legacy_dir / entry.cid + if legacy_path.exists(): + legacy_path.unlink() + deleted_outputs += 1 + + # Delete intermediates (only if not used by other activities) + for node_id in activity.intermediate_ids: + if self._is_used_by_other_activities(node_id, activity_id): + preserved_shared += 1 + continue + entry = self.cache.get_entry(node_id) + if entry: + self.cache.remove(node_id) + self._del_content_index(entry.cid) + legacy_path = self.legacy_dir / entry.cid + if legacy_path.exists(): + legacy_path.unlink() + deleted_outputs += 1 + + # Remove activity record (inputs remain in cache) + self.activity_store.remove(activity_id) + + msg = f"Activity discarded (deleted {deleted_outputs} outputs" + if preserved_shared > 0: + msg += f", preserved {preserved_shared} shared items" + msg += ")" + return True, msg + + def cleanup_intermediates(self) -> int: + """Delete all intermediate cache entries (reconstructible).""" + return self.activity_manager.cleanup_intermediates() + + def get_deletable_items(self) -> List[CachedFile]: + """Get all items that can be deleted.""" + deletable = [] + for entry in self.activity_manager.get_deletable_entries(): + deletable.append(CachedFile.from_cache_entry(entry)) + return deletable + + # ============ L2 Integration ============ + + def mark_published(self, cid: str): + """Mark a cid as published to L2.""" + self.l2_checker.mark_shared(cid) + + def invalidate_shared_cache(self, cid: str): + """Invalidate shared status cache (call if item might be unpublished).""" + self.l2_checker.invalidate(cid) + + # ============ Stats ============ + + def get_stats(self) -> dict: + """Get cache statistics.""" + stats = self.cache.get_stats() + return { + "total_entries": stats.total_entries, + "total_size_bytes": stats.total_size_bytes, + "hits": stats.hits, + "misses": stats.misses, + "hit_rate": stats.hit_rate, + "activities": len(self.activity_store), + } + + +# Singleton instance (initialized on first import with env vars) +_manager: Optional[L1CacheManager] = None + + +def get_cache_manager() -> L1CacheManager: + """Get the singleton cache manager instance.""" + global _manager + if _manager is None: + import redis + from urllib.parse import urlparse + + cache_dir = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache"))) + l2_server = os.environ.get("L2_SERVER", "http://localhost:8200") + + # Initialize Redis client for shared cache index + redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + parsed = urlparse(redis_url) + redis_client = redis.Redis( + host=parsed.hostname or 'localhost', + port=parsed.port or 6379, + db=int(parsed.path.lstrip('/') or 0), + socket_timeout=5, + socket_connect_timeout=5 + ) + + _manager = L1CacheManager(cache_dir=cache_dir, l2_server=l2_server, redis_client=redis_client) + return _manager + + +def reset_cache_manager(): + """Reset the singleton (for testing).""" + global _manager + _manager = None diff --git a/celery_app.py b/celery_app.py new file mode 100644 index 0000000..4a843d2 --- /dev/null +++ b/celery_app.py @@ -0,0 +1,51 @@ +""" +Art DAG Celery Application + +Streaming video rendering for the Art DAG system. +Uses S-expression recipes with frame-by-frame processing. +""" + +import os +import sys +from celery import Celery +from celery.signals import worker_ready + +# Use central config +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from app.config import settings + +app = Celery( + 'art_celery', + broker=settings.redis_url, + backend=settings.redis_url, + include=['tasks', 'tasks.streaming', 'tasks.ipfs_upload'] +) + + +@worker_ready.connect +def log_config_on_startup(sender, **kwargs): + """Log configuration when worker starts.""" + print("=" * 60, file=sys.stderr) + print("WORKER STARTED - CONFIGURATION", file=sys.stderr) + print("=" * 60, file=sys.stderr) + settings.log_config() + print(f"Worker: {sender}", file=sys.stderr) + print("=" * 60, file=sys.stderr) + +app.conf.update( + result_expires=86400 * 7, # 7 days - allow time for recovery after restarts + task_serializer='json', + accept_content=['json', 'pickle'], # pickle needed for internal Celery messages + result_serializer='json', + event_serializer='json', + timezone='UTC', + enable_utc=True, + task_track_started=True, + task_acks_late=True, # Don't ack until task completes - survives worker restart + worker_prefetch_multiplier=1, + task_reject_on_worker_lost=True, # Re-queue if worker dies + task_acks_on_failure_or_timeout=True, # Ack failed tasks so they don't retry forever +) + +if __name__ == '__main__': + app.start() diff --git a/check_redis.py b/check_redis.py new file mode 100644 index 0000000..44f70aa --- /dev/null +++ b/check_redis.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +"""Check Redis connectivity.""" + +import redis + +try: + r = redis.Redis(host='localhost', port=6379, db=0) + r.ping() + print("Redis: OK") +except redis.ConnectionError: + print("Redis: Not running") + print("Start with: sudo systemctl start redis-server") diff --git a/claiming.py b/claiming.py new file mode 100644 index 0000000..77fa1a0 --- /dev/null +++ b/claiming.py @@ -0,0 +1,421 @@ +""" +Hash-based task claiming for distributed execution. + +Prevents duplicate work when multiple workers process the same plan. +Uses Redis Lua scripts for atomic claim operations. +""" + +import json +import logging +import os +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +import redis + +logger = logging.getLogger(__name__) + +REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + +# Key prefix for task claims +CLAIM_PREFIX = "artdag:claim:" + +# Default TTL for claims (5 minutes) +DEFAULT_CLAIM_TTL = 300 + +# TTL for completed results (1 hour) +COMPLETED_TTL = 3600 + + +class ClaimStatus(Enum): + """Status of a task claim.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + CACHED = "cached" + FAILED = "failed" + + +@dataclass +class ClaimInfo: + """Information about a task claim.""" + cache_id: str + status: ClaimStatus + worker_id: Optional[str] = None + task_id: Optional[str] = None + claimed_at: Optional[str] = None + completed_at: Optional[str] = None + output_path: Optional[str] = None + error: Optional[str] = None + + def to_dict(self) -> dict: + return { + "cache_id": self.cache_id, + "status": self.status.value, + "worker_id": self.worker_id, + "task_id": self.task_id, + "claimed_at": self.claimed_at, + "completed_at": self.completed_at, + "output_path": self.output_path, + "error": self.error, + } + + @classmethod + def from_dict(cls, data: dict) -> "ClaimInfo": + return cls( + cache_id=data["cache_id"], + status=ClaimStatus(data["status"]), + worker_id=data.get("worker_id"), + task_id=data.get("task_id"), + claimed_at=data.get("claimed_at"), + completed_at=data.get("completed_at"), + output_path=data.get("output_path"), + error=data.get("error"), + ) + + +# Lua script for atomic task claiming +# Returns 1 if claim successful, 0 if already claimed/completed +CLAIM_TASK_SCRIPT = """ +local key = KEYS[1] +local data = redis.call('GET', key) + +if data then + local status = cjson.decode(data) + local s = status['status'] + -- Already claimed, running, completed, or cached - don't claim + if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then + return 0 + end +end + +-- Claim the task +local claim_data = ARGV[1] +local ttl = tonumber(ARGV[2]) +redis.call('SETEX', key, ttl, claim_data) +return 1 +""" + +# Lua script for releasing a claim (e.g., on failure) +RELEASE_CLAIM_SCRIPT = """ +local key = KEYS[1] +local worker_id = ARGV[1] +local data = redis.call('GET', key) + +if data then + local status = cjson.decode(data) + -- Only release if we own the claim + if status['worker_id'] == worker_id then + redis.call('DEL', key) + return 1 + end +end +return 0 +""" + +# Lua script for updating claim status (claimed -> running -> completed) +UPDATE_STATUS_SCRIPT = """ +local key = KEYS[1] +local worker_id = ARGV[1] +local new_status = ARGV[2] +local new_data = ARGV[3] +local ttl = tonumber(ARGV[4]) + +local data = redis.call('GET', key) +if not data then + return 0 +end + +local status = cjson.decode(data) + +-- Only update if we own the claim +if status['worker_id'] ~= worker_id then + return 0 +end + +redis.call('SETEX', key, ttl, new_data) +return 1 +""" + + +class TaskClaimer: + """ + Manages hash-based task claiming for distributed execution. + + Uses Redis for coordination between workers. + Each task is identified by its cache_id (content-addressed). + """ + + def __init__(self, redis_url: str = None): + """ + Initialize the claimer. + + Args: + redis_url: Redis connection URL + """ + self.redis_url = redis_url or REDIS_URL + self._redis: Optional[redis.Redis] = None + self._claim_script = None + self._release_script = None + self._update_script = None + + @property + def redis(self) -> redis.Redis: + """Get Redis connection (lazy initialization).""" + if self._redis is None: + self._redis = redis.from_url(self.redis_url, decode_responses=True) + # Register Lua scripts + self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT) + self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT) + self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT) + return self._redis + + def _key(self, cache_id: str) -> str: + """Get Redis key for a cache_id.""" + return f"{CLAIM_PREFIX}{cache_id}" + + def claim( + self, + cache_id: str, + worker_id: str, + task_id: Optional[str] = None, + ttl: int = DEFAULT_CLAIM_TTL, + ) -> bool: + """ + Attempt to claim a task. + + Args: + cache_id: The cache ID of the task to claim + worker_id: Identifier for the claiming worker + task_id: Optional Celery task ID + ttl: Time-to-live for the claim in seconds + + Returns: + True if claim successful, False if already claimed + """ + claim_info = ClaimInfo( + cache_id=cache_id, + status=ClaimStatus.CLAIMED, + worker_id=worker_id, + task_id=task_id, + claimed_at=datetime.now(timezone.utc).isoformat(), + ) + + result = self._claim_script( + keys=[self._key(cache_id)], + args=[json.dumps(claim_info.to_dict()), ttl], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}") + return True + else: + logger.debug(f"Task {cache_id[:16]}... already claimed") + return False + + def update_status( + self, + cache_id: str, + worker_id: str, + status: ClaimStatus, + output_path: Optional[str] = None, + error: Optional[str] = None, + ttl: Optional[int] = None, + ) -> bool: + """ + Update the status of a claimed task. + + Args: + cache_id: The cache ID of the task + worker_id: Worker ID that owns the claim + status: New status + output_path: Path to output (for completed) + error: Error message (for failed) + ttl: New TTL (defaults based on status) + + Returns: + True if update successful + """ + if ttl is None: + if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED): + ttl = COMPLETED_TTL + else: + ttl = DEFAULT_CLAIM_TTL + + # Get existing claim info + existing = self.get_status(cache_id) + if not existing: + logger.warning(f"No claim found for {cache_id[:16]}...") + return False + + claim_info = ClaimInfo( + cache_id=cache_id, + status=status, + worker_id=worker_id, + task_id=existing.task_id, + claimed_at=existing.claimed_at, + completed_at=datetime.now(timezone.utc).isoformat() if status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED + ) else None, + output_path=output_path, + error=error, + ) + + result = self._update_script( + keys=[self._key(cache_id)], + args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Updated task {cache_id[:16]}... to {status.value}") + return True + else: + logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)") + return False + + def release(self, cache_id: str, worker_id: str) -> bool: + """ + Release a claim (e.g., on task failure before completion). + + Args: + cache_id: The cache ID of the task + worker_id: Worker ID that owns the claim + + Returns: + True if release successful + """ + result = self._release_script( + keys=[self._key(cache_id)], + args=[worker_id], + client=self.redis, + ) + + if result == 1: + logger.debug(f"Released claim on {cache_id[:16]}...") + return True + return False + + def get_status(self, cache_id: str) -> Optional[ClaimInfo]: + """ + Get the current status of a task. + + Args: + cache_id: The cache ID of the task + + Returns: + ClaimInfo if task has been claimed, None otherwise + """ + data = self.redis.get(self._key(cache_id)) + if data: + return ClaimInfo.from_dict(json.loads(data)) + return None + + def is_completed(self, cache_id: str) -> bool: + """Check if a task is completed or cached.""" + info = self.get_status(cache_id) + return info is not None and info.status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED + ) + + def wait_for_completion( + self, + cache_id: str, + timeout: float = 300, + poll_interval: float = 0.5, + ) -> Optional[ClaimInfo]: + """ + Wait for a task to complete. + + Args: + cache_id: The cache ID of the task + timeout: Maximum time to wait in seconds + poll_interval: How often to check status + + Returns: + ClaimInfo if completed, None if timeout + """ + start_time = time.time() + while time.time() - start_time < timeout: + info = self.get_status(cache_id) + if info and info.status in ( + ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED + ): + return info + time.sleep(poll_interval) + + logger.warning(f"Timeout waiting for {cache_id[:16]}...") + return None + + def mark_cached(self, cache_id: str, output_path: str) -> None: + """ + Mark a task as already cached (no processing needed). + + This is used when we discover the result already exists + before attempting to claim. + + Args: + cache_id: The cache ID of the task + output_path: Path to the cached output + """ + claim_info = ClaimInfo( + cache_id=cache_id, + status=ClaimStatus.CACHED, + output_path=output_path, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + self.redis.setex( + self._key(cache_id), + COMPLETED_TTL, + json.dumps(claim_info.to_dict()), + ) + + def clear_all(self) -> int: + """ + Clear all claims (for testing/reset). + + Returns: + Number of claims cleared + """ + pattern = f"{CLAIM_PREFIX}*" + keys = list(self.redis.scan_iter(match=pattern)) + if keys: + return self.redis.delete(*keys) + return 0 + + +# Global claimer instance +_claimer: Optional[TaskClaimer] = None + + +def get_claimer() -> TaskClaimer: + """Get the global TaskClaimer instance.""" + global _claimer + if _claimer is None: + _claimer = TaskClaimer() + return _claimer + + +def claim_task(cache_id: str, worker_id: str, task_id: str = None) -> bool: + """Convenience function to claim a task.""" + return get_claimer().claim(cache_id, worker_id, task_id) + + +def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool: + """Convenience function to mark a task as completed.""" + return get_claimer().update_status( + cache_id, worker_id, ClaimStatus.COMPLETED, output_path=output_path + ) + + +def fail_task(cache_id: str, worker_id: str, error: str) -> bool: + """Convenience function to mark a task as failed.""" + return get_claimer().update_status( + cache_id, worker_id, ClaimStatus.FAILED, error=error + ) diff --git a/configs/audio-dizzy.sexp b/configs/audio-dizzy.sexp new file mode 100644 index 0000000..dc16087 --- /dev/null +++ b/configs/audio-dizzy.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Paths relative to working directory (project root) +(def music (streaming:make-audio-analyzer "dizzy.mp3")) + +;; Audio playback path (for sync with video output) +(audio-playback "dizzy.mp3") diff --git a/configs/audio-halleluwah.sexp b/configs/audio-halleluwah.sexp new file mode 100644 index 0000000..7d7bfae --- /dev/null +++ b/configs/audio-halleluwah.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Using friendly name for asset resolution +(def music (streaming:make-audio-analyzer "woods-audio")) + +;; Audio playback path (for sync with video output) +(audio-playback "woods-audio") diff --git a/configs/sources-default.sexp b/configs/sources-default.sexp new file mode 100644 index 0000000..754bd92 --- /dev/null +++ b/configs/sources-default.sexp @@ -0,0 +1,38 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Paths relative to working directory (project root) +(def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 +]) diff --git a/configs/sources-woods-half.sexp b/configs/sources-woods-half.sexp new file mode 100644 index 0000000..d2feff8 --- /dev/null +++ b/configs/sources-woods-half.sexp @@ -0,0 +1,19 @@ +;; Half-resolution Woods Sources (960x540) +;; +;; Pass to recipe with: --sources configs/sources-woods-half.sexp + +(require-primitives "streaming") + +(def sources [ + (streaming:make-video-source "woods_half/1.webm" 30) + (streaming:make-video-source "woods_half/2.webm" 30) + (streaming:make-video-source "woods_half/3.webm" 30) + (streaming:make-video-source "woods_half/4.webm" 30) +]) + +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} +]) diff --git a/configs/sources-woods.sexp b/configs/sources-woods.sexp new file mode 100644 index 0000000..ab8dff4 --- /dev/null +++ b/configs/sources-woods.sexp @@ -0,0 +1,39 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Using friendly names for asset resolution +(def sources [ + (streaming:make-video-source "woods-1" 10) + (streaming:make-video-source "woods-2" 10) + (streaming:make-video-source "woods-3" 10) + (streaming:make-video-source "woods-4" 10) + (streaming:make-video-source "woods-5" 10) + (streaming:make-video-source "woods-6" 10) + (streaming:make-video-source "woods-7" 10) + (streaming:make-video-source "woods-8" 10) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + +]) diff --git a/database.py b/database.py new file mode 100644 index 0000000..70187db --- /dev/null +++ b/database.py @@ -0,0 +1,2144 @@ +# art-celery/database.py +""" +PostgreSQL database module for Art DAG L1 server. + +Provides connection pooling and CRUD operations for cache metadata. +""" + +import os +from datetime import datetime, timezone +from typing import List, Optional + +import asyncpg + +DATABASE_URL = os.getenv("DATABASE_URL") +if not DATABASE_URL: + raise RuntimeError("DATABASE_URL environment variable is required") + +pool: Optional[asyncpg.Pool] = None + +SCHEMA_SQL = """ +-- Core cache: just content hash and IPFS CID +-- Physical file storage - shared by all users +CREATE TABLE IF NOT EXISTS cache_items ( + cid VARCHAR(64) PRIMARY KEY, + ipfs_cid VARCHAR(128), + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Item types: per-user metadata (same item can be recipe AND media, per user) +-- actor_id format: @username@server (ActivityPub style) +CREATE TABLE IF NOT EXISTS item_types ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) REFERENCES cache_items(cid) ON DELETE CASCADE, + actor_id VARCHAR(255) NOT NULL, + type VARCHAR(50) NOT NULL, + path VARCHAR(255), + description TEXT, + source_type VARCHAR(20), + source_url TEXT, + source_note TEXT, + pinned BOOLEAN DEFAULT FALSE, + filename VARCHAR(255), + metadata JSONB DEFAULT '{}', + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(cid, actor_id, type, path) +); + + +-- Pin reasons: one-to-many from item_types +CREATE TABLE IF NOT EXISTS pin_reasons ( + id SERIAL PRIMARY KEY, + item_type_id INTEGER REFERENCES item_types(id) ON DELETE CASCADE, + reason VARCHAR(100) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- L2 shares: per-user shares (includes content_type for role when shared) +CREATE TABLE IF NOT EXISTS l2_shares ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) REFERENCES cache_items(cid) ON DELETE CASCADE, + actor_id VARCHAR(255) NOT NULL, + l2_server VARCHAR(255) NOT NULL, + asset_name VARCHAR(255) NOT NULL, + activity_id VARCHAR(128), + content_type VARCHAR(50) NOT NULL, + published_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + last_synced_at TIMESTAMP WITH TIME ZONE, + UNIQUE(cid, actor_id, l2_server, content_type) +); + + +-- Run cache: maps content-addressable run_id to output +-- run_id is a hash of (sorted inputs + recipe), making runs deterministic +CREATE TABLE IF NOT EXISTS run_cache ( + run_id VARCHAR(64) PRIMARY KEY, + output_cid VARCHAR(64) NOT NULL, + ipfs_cid VARCHAR(128), + provenance_cid VARCHAR(128), + plan_cid VARCHAR(128), + recipe VARCHAR(255) NOT NULL, + inputs JSONB NOT NULL, + actor_id VARCHAR(255), + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Pending/running runs: tracks in-progress work for durability +-- Allows runs to survive restarts and be recovered +CREATE TABLE IF NOT EXISTS pending_runs ( + run_id VARCHAR(64) PRIMARY KEY, + celery_task_id VARCHAR(128), + status VARCHAR(20) NOT NULL DEFAULT 'pending', -- pending, running, failed + recipe VARCHAR(255) NOT NULL, + inputs JSONB NOT NULL, + dag_json TEXT, + plan_cid VARCHAR(128), + output_name VARCHAR(255), + actor_id VARCHAR(255), + error TEXT, + ipfs_playlist_cid VARCHAR(128), -- For streaming: IPFS CID of HLS playlist + quality_playlists JSONB, -- For streaming: quality-level playlist CIDs {quality_name: {cid, width, height, bitrate}} + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Add ipfs_playlist_cid if table exists but column doesn't (migration) +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'ipfs_playlist_cid') THEN + ALTER TABLE pending_runs ADD COLUMN ipfs_playlist_cid VARCHAR(128); + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'quality_playlists') THEN + ALTER TABLE pending_runs ADD COLUMN quality_playlists JSONB; + END IF; + -- Checkpoint columns for resumable renders + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_frame') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_frame INTEGER; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_t') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_t FLOAT; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'checkpoint_scans') THEN + ALTER TABLE pending_runs ADD COLUMN checkpoint_scans JSONB; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'total_frames') THEN + ALTER TABLE pending_runs ADD COLUMN total_frames INTEGER; + END IF; + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name = 'pending_runs' AND column_name = 'resumable') THEN + ALTER TABLE pending_runs ADD COLUMN resumable BOOLEAN DEFAULT TRUE; + END IF; +END $$; + +CREATE INDEX IF NOT EXISTS idx_pending_runs_status ON pending_runs(status); +CREATE INDEX IF NOT EXISTS idx_pending_runs_actor ON pending_runs(actor_id); + +-- User storage backends (synced from L2 or configured locally) +CREATE TABLE IF NOT EXISTS storage_backends ( + id SERIAL PRIMARY KEY, + actor_id VARCHAR(255) NOT NULL, + provider_type VARCHAR(50) NOT NULL, -- 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + provider_name VARCHAR(255), + description TEXT, + config JSONB NOT NULL DEFAULT '{}', + capacity_gb INTEGER NOT NULL, + used_bytes BIGINT DEFAULT 0, + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + synced_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Storage pins tracking (what's pinned where) +CREATE TABLE IF NOT EXISTS storage_pins ( + id SERIAL PRIMARY KEY, + cid VARCHAR(64) NOT NULL, + storage_id INTEGER NOT NULL REFERENCES storage_backends(id) ON DELETE CASCADE, + ipfs_cid VARCHAR(128), + pin_type VARCHAR(20) NOT NULL, -- 'user_content', 'donated', 'system' + size_bytes BIGINT, + pinned_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + UNIQUE(cid, storage_id) +); + +-- Friendly names: human-readable versioned names for content +-- Version IDs are server-signed timestamps (always increasing, verifiable origin) +-- Names are per-user within L1; when shared to L2: @user@domain name version +CREATE TABLE IF NOT EXISTS friendly_names ( + id SERIAL PRIMARY KEY, + actor_id VARCHAR(255) NOT NULL, + base_name VARCHAR(255) NOT NULL, -- normalized: my-cool-effect + version_id VARCHAR(20) NOT NULL, -- server-signed timestamp: 01hw3x9kab2cd + cid VARCHAR(64) NOT NULL, -- content address + item_type VARCHAR(20) NOT NULL, -- recipe | effect | media + display_name VARCHAR(255), -- original "My Cool Effect" + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + + UNIQUE(actor_id, base_name, version_id), + UNIQUE(actor_id, cid) -- each CID has exactly one friendly name per user +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_item_types_cid ON item_types(cid); +CREATE INDEX IF NOT EXISTS idx_item_types_actor_id ON item_types(actor_id); +CREATE INDEX IF NOT EXISTS idx_item_types_type ON item_types(type); +CREATE INDEX IF NOT EXISTS idx_item_types_path ON item_types(path); +CREATE INDEX IF NOT EXISTS idx_pin_reasons_item_type ON pin_reasons(item_type_id); +CREATE INDEX IF NOT EXISTS idx_l2_shares_cid ON l2_shares(cid); +CREATE INDEX IF NOT EXISTS idx_l2_shares_actor_id ON l2_shares(actor_id); +CREATE INDEX IF NOT EXISTS idx_run_cache_output ON run_cache(output_cid); +CREATE INDEX IF NOT EXISTS idx_storage_backends_actor ON storage_backends(actor_id); +CREATE INDEX IF NOT EXISTS idx_storage_backends_type ON storage_backends(provider_type); +CREATE INDEX IF NOT EXISTS idx_storage_pins_hash ON storage_pins(cid); +CREATE INDEX IF NOT EXISTS idx_storage_pins_storage ON storage_pins(storage_id); +CREATE INDEX IF NOT EXISTS idx_friendly_names_actor ON friendly_names(actor_id); +CREATE INDEX IF NOT EXISTS idx_friendly_names_type ON friendly_names(item_type); +CREATE INDEX IF NOT EXISTS idx_friendly_names_base ON friendly_names(actor_id, base_name); +CREATE INDEX IF NOT EXISTS idx_friendly_names_latest ON friendly_names(actor_id, item_type, base_name, created_at DESC); +""" + + +async def init_db(): + """Initialize database connection pool and create schema. + + Raises: + asyncpg.PostgresError: If database connection fails + RuntimeError: If pool creation fails + """ + global pool + if pool is not None: + return # Already initialized + try: + pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + if pool is None: + raise RuntimeError(f"Failed to create database pool for {DATABASE_URL}") + async with pool.acquire() as conn: + await conn.execute(SCHEMA_SQL) + except asyncpg.PostgresError as e: + pool = None + raise RuntimeError(f"Database connection failed: {e}") from e + except Exception as e: + pool = None + raise RuntimeError(f"Database initialization failed: {e}") from e + + +async def close_db(): + """Close database connection pool.""" + global pool + if pool: + await pool.close() + pool = None + + +# ============ Cache Items ============ + +async def create_cache_item(cid: str, ipfs_cid: Optional[str] = None) -> dict: + """Create a cache item. Returns the created item.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO cache_items (cid, ipfs_cid) + VALUES ($1, $2) + ON CONFLICT (cid) DO UPDATE SET ipfs_cid = COALESCE($2, cache_items.ipfs_cid) + RETURNING cid, ipfs_cid, created_at + """, + cid, ipfs_cid + ) + return dict(row) + + +async def get_cache_item(cid: str) -> Optional[dict]: + """Get a cache item by content hash.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT cid, ipfs_cid, created_at FROM cache_items WHERE cid = $1", + cid + ) + return dict(row) if row else None + + +async def update_cache_item_ipfs_cid(cid: str, ipfs_cid: str) -> bool: + """Update the IPFS CID for a cache item.""" + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE cache_items SET ipfs_cid = $2 WHERE cid = $1", + cid, ipfs_cid + ) + return result == "UPDATE 1" + + +async def get_ipfs_cid(cid: str) -> Optional[str]: + """Get the IPFS CID for a cache item by its internal CID.""" + async with pool.acquire() as conn: + return await conn.fetchval( + "SELECT ipfs_cid FROM cache_items WHERE cid = $1", + cid + ) + + +async def delete_cache_item(cid: str) -> bool: + """Delete a cache item and all associated data (cascades).""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM cache_items WHERE cid = $1", + cid + ) + return result == "DELETE 1" + + +async def list_cache_items(limit: int = 100, offset: int = 0) -> List[dict]: + """List cache items with pagination.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT cid, ipfs_cid, created_at + FROM cache_items + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + """, + limit, offset + ) + return [dict(row) for row in rows] + + +# ============ Item Types ============ + +async def add_item_type( + cid: str, + actor_id: str, + item_type: str, + path: Optional[str] = None, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, +) -> dict: + """Add a type to a cache item for a user. Creates cache_item if needed.""" + async with pool.acquire() as conn: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + cid + ) + # Insert or update item_type + row = await conn.fetchrow( + """ + INSERT INTO item_types (cid, actor_id, type, path, description, source_type, source_url, source_note) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE($5, item_types.description), + source_type = COALESCE($6, item_types.source_type), + source_url = COALESCE($7, item_types.source_url), + source_note = COALESCE($8, item_types.source_note) + RETURNING id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + """, + cid, actor_id, item_type, path, description, source_type, source_url, source_note + ) + return dict(row) + + +async def get_item_types(cid: str, actor_id: Optional[str] = None) -> List[dict]: + """Get types for a cache item, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 + ORDER BY created_at + """, + cid, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 + ORDER BY created_at + """, + cid + ) + return [dict(row) for row in rows] + + +async def get_item_type(cid: str, actor_id: str, item_type: str, path: Optional[str] = None) -> Optional[dict]: + """Get a specific type for a cache item and user.""" + async with pool.acquire() as conn: + if path is None: + row = await conn.fetchrow( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """, + cid, actor_id, item_type + ) + else: + row = await conn.fetchrow( + """ + SELECT id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, created_at + FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path = $4 + """, + cid, actor_id, item_type, path + ) + return dict(row) if row else None + + +async def update_item_type( + item_type_id: int, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, +) -> bool: + """Update an item type's metadata.""" + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE item_types SET + description = COALESCE($2, description), + source_type = COALESCE($3, source_type), + source_url = COALESCE($4, source_url), + source_note = COALESCE($5, source_note) + WHERE id = $1 + """, + item_type_id, description, source_type, source_url, source_note + ) + return result == "UPDATE 1" + + +async def delete_item_type(cid: str, actor_id: str, item_type: str, path: Optional[str] = None) -> bool: + """Delete a specific type from a cache item for a user.""" + async with pool.acquire() as conn: + if path is None: + result = await conn.execute( + "DELETE FROM item_types WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL", + cid, actor_id, item_type + ) + else: + result = await conn.execute( + "DELETE FROM item_types WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path = $4", + cid, actor_id, item_type, path + ) + return result == "DELETE 1" + + +async def list_items_by_type(item_type: str, actor_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[dict]: + """List items of a specific type, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.actor_id = $2 + ORDER BY it.created_at DESC + LIMIT $3 OFFSET $4 + """, + item_type, actor_id, limit, offset + ) + else: + rows = await conn.fetch( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 + ORDER BY it.created_at DESC + LIMIT $2 OFFSET $3 + """, + item_type, limit, offset + ) + return [dict(row) for row in rows] + + +async def get_item_by_path(item_type: str, path: str, actor_id: Optional[str] = None) -> Optional[dict]: + """Get an item by its type and path (e.g., recipe:/effects/dog), optionally for a specific user.""" + async with pool.acquire() as conn: + if actor_id: + row = await conn.fetchrow( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.path = $2 AND it.actor_id = $3 + """, + item_type, path, actor_id + ) + else: + row = await conn.fetchrow( + """ + SELECT it.id, it.cid, it.actor_id, it.type, it.path, it.description, + it.source_type, it.source_url, it.source_note, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.type = $1 AND it.path = $2 + """, + item_type, path + ) + return dict(row) if row else None + + +# ============ Pinning ============ + +async def pin_item_type(item_type_id: int, reason: str) -> bool: + """Pin an item type with a reason.""" + async with pool.acquire() as conn: + async with conn.transaction(): + # Set pinned flag + await conn.execute( + "UPDATE item_types SET pinned = TRUE WHERE id = $1", + item_type_id + ) + # Add pin reason + await conn.execute( + "INSERT INTO pin_reasons (item_type_id, reason) VALUES ($1, $2)", + item_type_id, reason + ) + return True + + +async def unpin_item_type(item_type_id: int, reason: Optional[str] = None) -> bool: + """Remove a pin reason from an item type. If no reasons left, unpins the item.""" + async with pool.acquire() as conn: + async with conn.transaction(): + if reason: + # Remove specific reason + await conn.execute( + "DELETE FROM pin_reasons WHERE item_type_id = $1 AND reason = $2", + item_type_id, reason + ) + else: + # Remove all reasons + await conn.execute( + "DELETE FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + + # Check if any reasons remain + count = await conn.fetchval( + "SELECT COUNT(*) FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + + if count == 0: + await conn.execute( + "UPDATE item_types SET pinned = FALSE WHERE id = $1", + item_type_id + ) + return True + + +async def get_pin_reasons(item_type_id: int) -> List[dict]: + """Get all pin reasons for an item type.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT id, reason, created_at FROM pin_reasons WHERE item_type_id = $1 ORDER BY created_at", + item_type_id + ) + return [dict(row) for row in rows] + + +async def is_item_pinned(cid: str, item_type: Optional[str] = None) -> tuple[bool, List[str]]: + """Check if any type of a cache item is pinned. Returns (is_pinned, reasons).""" + async with pool.acquire() as conn: + if item_type: + rows = await conn.fetch( + """ + SELECT pr.reason + FROM pin_reasons pr + JOIN item_types it ON pr.item_type_id = it.id + WHERE it.cid = $1 AND it.type = $2 AND it.pinned = TRUE + """, + cid, item_type + ) + else: + rows = await conn.fetch( + """ + SELECT pr.reason + FROM pin_reasons pr + JOIN item_types it ON pr.item_type_id = it.id + WHERE it.cid = $1 AND it.pinned = TRUE + """, + cid + ) + reasons = [row["reason"] for row in rows] + return len(reasons) > 0, reasons + + +# ============ L2 Shares ============ + +async def add_l2_share( + cid: str, + actor_id: str, + l2_server: str, + asset_name: str, + content_type: str, +) -> dict: + """Add or update an L2 share for a user.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO l2_shares (cid, actor_id, l2_server, asset_name, content_type, last_synced_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + ON CONFLICT (cid, actor_id, l2_server, content_type) DO UPDATE SET + asset_name = $4, + last_synced_at = NOW() + RETURNING id, cid, actor_id, l2_server, asset_name, content_type, published_at, last_synced_at + """, + cid, actor_id, l2_server, asset_name, content_type + ) + return dict(row) + + +async def get_l2_shares(cid: str, actor_id: Optional[str] = None) -> List[dict]: + """Get L2 shares for a cache item, optionally filtered by user.""" + async with pool.acquire() as conn: + if actor_id: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares + WHERE cid = $1 AND actor_id = $2 + ORDER BY published_at + """, + cid, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT id, cid, actor_id, l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares + WHERE cid = $1 + ORDER BY published_at + """, + cid + ) + return [dict(row) for row in rows] + + +async def delete_l2_share(cid: str, actor_id: str, l2_server: str, content_type: str) -> bool: + """Delete an L2 share for a user.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM l2_shares WHERE cid = $1 AND actor_id = $2 AND l2_server = $3 AND content_type = $4", + cid, actor_id, l2_server, content_type + ) + return result == "DELETE 1" + + +# ============ Cache Item Cleanup ============ + +async def has_remaining_references(cid: str) -> bool: + """Check if a cache item has any remaining item_types or l2_shares.""" + async with pool.acquire() as conn: + item_types_count = await conn.fetchval( + "SELECT COUNT(*) FROM item_types WHERE cid = $1", + cid + ) + if item_types_count > 0: + return True + + l2_shares_count = await conn.fetchval( + "SELECT COUNT(*) FROM l2_shares WHERE cid = $1", + cid + ) + return l2_shares_count > 0 + + +async def cleanup_orphaned_cache_item(cid: str) -> bool: + """Delete a cache item if it has no remaining references. Returns True if deleted.""" + async with pool.acquire() as conn: + # Only delete if no item_types or l2_shares reference it + result = await conn.execute( + """ + DELETE FROM cache_items + WHERE cid = $1 + AND NOT EXISTS (SELECT 1 FROM item_types WHERE cid = $1) + AND NOT EXISTS (SELECT 1 FROM l2_shares WHERE cid = $1) + """, + cid + ) + return result == "DELETE 1" + + +# ============ High-Level Metadata Functions ============ +# These provide a compatible interface to the old JSON-based save_cache_meta/load_cache_meta + +import json as _json + + +async def save_item_metadata( + cid: str, + actor_id: str, + item_type: str = "media", + filename: Optional[str] = None, + description: Optional[str] = None, + source_type: Optional[str] = None, + source_url: Optional[str] = None, + source_note: Optional[str] = None, + pinned: bool = False, + pin_reason: Optional[str] = None, + tags: Optional[List[str]] = None, + folder: Optional[str] = None, + collections: Optional[List[str]] = None, + **extra_metadata +) -> dict: + """ + Save or update item metadata in the database. + + Returns a dict with the item metadata (compatible with old JSON format). + """ + import logging + logger = logging.getLogger(__name__) + logger.info(f"save_item_metadata: cid={cid[:16] if cid else None}..., actor_id={actor_id}, item_type={item_type}") + # Build metadata JSONB for extra fields + metadata = {} + if tags: + metadata["tags"] = tags + if folder: + metadata["folder"] = folder + if collections: + metadata["collections"] = collections + metadata.update(extra_metadata) + + async with pool.acquire() as conn: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + cid + ) + + # Upsert item_type + row = await conn.fetchrow( + """ + INSERT INTO item_types (cid, actor_id, type, description, source_type, source_url, source_note, pinned, filename, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE(EXCLUDED.description, item_types.description), + source_type = COALESCE(EXCLUDED.source_type, item_types.source_type), + source_url = COALESCE(EXCLUDED.source_url, item_types.source_url), + source_note = COALESCE(EXCLUDED.source_note, item_types.source_note), + pinned = EXCLUDED.pinned, + filename = COALESCE(EXCLUDED.filename, item_types.filename), + metadata = item_types.metadata || EXCLUDED.metadata + RETURNING id, cid, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + """, + cid, actor_id, item_type, description, source_type, source_url, source_note, pinned, filename, _json.dumps(metadata) + ) + + item_type_id = row["id"] + logger.info(f"save_item_metadata: Created/updated item_type id={item_type_id} for cid={cid[:16]}...") + + # Handle pinning + if pinned and pin_reason: + # Add pin reason if not exists + await conn.execute( + """ + INSERT INTO pin_reasons (item_type_id, reason) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, + item_type_id, pin_reason + ) + + # Build response dict (compatible with old format) + result = { + "uploader": actor_id, + "uploaded_at": row["created_at"].isoformat() if row["created_at"] else None, + "filename": row["filename"], + "type": row["type"], + "description": row["description"], + "pinned": row["pinned"], + } + + # Add origin if present + if row["source_type"] or row["source_url"] or row["source_note"]: + result["origin"] = { + "type": row["source_type"], + "url": row["source_url"], + "note": row["source_note"] + } + + # Add metadata fields + if row["metadata"]: + meta = row["metadata"] if isinstance(row["metadata"], dict) else _json.loads(row["metadata"]) + if meta.get("tags"): + result["tags"] = meta["tags"] + if meta.get("folder"): + result["folder"] = meta["folder"] + if meta.get("collections"): + result["collections"] = meta["collections"] + + # Get pin reasons + if row["pinned"]: + reasons = await conn.fetch( + "SELECT reason FROM pin_reasons WHERE item_type_id = $1", + item_type_id + ) + if reasons: + result["pin_reason"] = reasons[0]["reason"] + + return result + + +async def load_item_metadata(cid: str, actor_id: Optional[str] = None) -> dict: + """ + Load item metadata from the database. + + If actor_id is provided, returns metadata for that user's view of the item. + Otherwise, returns combined metadata from all users (for backwards compat). + + Returns a dict compatible with old JSON format. + """ + async with pool.acquire() as conn: + # Get cache item + cache_item = await conn.fetchrow( + "SELECT cid, ipfs_cid, created_at FROM cache_items WHERE cid = $1", + cid + ) + + if not cache_item: + return {} + + # Get item types + if actor_id: + item_types = await conn.fetch( + """ + SELECT id, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + FROM item_types WHERE cid = $1 AND actor_id = $2 + ORDER BY created_at + """, + cid, actor_id + ) + else: + item_types = await conn.fetch( + """ + SELECT id, actor_id, type, path, description, source_type, source_url, source_note, pinned, filename, metadata, created_at + FROM item_types WHERE cid = $1 + ORDER BY created_at + """, + cid + ) + + if not item_types: + return {"uploaded_at": cache_item["created_at"].isoformat() if cache_item["created_at"] else None} + + # Use first item type as primary (for backwards compat) + primary = item_types[0] + + result = { + "uploader": primary["actor_id"], + "uploaded_at": primary["created_at"].isoformat() if primary["created_at"] else None, + "filename": primary["filename"], + "type": primary["type"], + "description": primary["description"], + "pinned": any(it["pinned"] for it in item_types), + } + + # Add origin if present + if primary["source_type"] or primary["source_url"] or primary["source_note"]: + result["origin"] = { + "type": primary["source_type"], + "url": primary["source_url"], + "note": primary["source_note"] + } + + # Add metadata fields + if primary["metadata"]: + meta = primary["metadata"] if isinstance(primary["metadata"], dict) else _json.loads(primary["metadata"]) + if meta.get("tags"): + result["tags"] = meta["tags"] + if meta.get("folder"): + result["folder"] = meta["folder"] + if meta.get("collections"): + result["collections"] = meta["collections"] + + # Get pin reasons for pinned items + for it in item_types: + if it["pinned"]: + reasons = await conn.fetch( + "SELECT reason FROM pin_reasons WHERE item_type_id = $1", + it["id"] + ) + if reasons: + result["pin_reason"] = reasons[0]["reason"] + break + + # Get L2 shares + if actor_id: + shares = await conn.fetch( + """ + SELECT l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares WHERE cid = $1 AND actor_id = $2 + """, + cid, actor_id + ) + else: + shares = await conn.fetch( + """ + SELECT l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + FROM l2_shares WHERE cid = $1 + """, + cid + ) + + if shares: + result["l2_shares"] = [ + { + "l2_server": s["l2_server"], + "asset_name": s["asset_name"], + "activity_id": s["activity_id"], + "content_type": s["content_type"], + "published_at": s["published_at"].isoformat() if s["published_at"] else None, + "last_synced_at": s["last_synced_at"].isoformat() if s["last_synced_at"] else None, + } + for s in shares + ] + + # For backwards compat, also set "published" if shared + result["published"] = { + "to_l2": True, + "asset_name": shares[0]["asset_name"], + "activity_id": shares[0]["activity_id"], + "l2_server": shares[0]["l2_server"], + } + + return result + + +async def update_item_metadata( + cid: str, + actor_id: str, + item_type: str = "media", + **updates +) -> dict: + """ + Update specific fields of item metadata. + + Returns updated metadata dict. + """ + # Extract known fields from updates + description = updates.pop("description", None) + source_type = updates.pop("source_type", None) + source_url = updates.pop("source_url", None) + source_note = updates.pop("source_note", None) + + # Handle origin dict format + origin = updates.pop("origin", None) + if origin: + source_type = origin.get("type", source_type) + source_url = origin.get("url", source_url) + source_note = origin.get("note", source_note) + + pinned = updates.pop("pinned", None) + pin_reason = updates.pop("pin_reason", None) + filename = updates.pop("filename", None) + tags = updates.pop("tags", None) + folder = updates.pop("folder", None) + collections = updates.pop("collections", None) + + async with pool.acquire() as conn: + # Get existing item_type + existing = await conn.fetchrow( + """ + SELECT id, metadata FROM item_types + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """, + cid, actor_id, item_type + ) + + if not existing: + # Create new entry + return await save_item_metadata( + cid, actor_id, item_type, + filename=filename, description=description, + source_type=source_type, source_url=source_url, source_note=source_note, + pinned=pinned or False, pin_reason=pin_reason, + tags=tags, folder=folder, collections=collections, + **updates + ) + + # Build update query dynamically + set_parts = [] + params = [cid, actor_id, item_type] + param_idx = 4 + + if description is not None: + set_parts.append(f"description = ${param_idx}") + params.append(description) + param_idx += 1 + + if source_type is not None: + set_parts.append(f"source_type = ${param_idx}") + params.append(source_type) + param_idx += 1 + + if source_url is not None: + set_parts.append(f"source_url = ${param_idx}") + params.append(source_url) + param_idx += 1 + + if source_note is not None: + set_parts.append(f"source_note = ${param_idx}") + params.append(source_note) + param_idx += 1 + + if pinned is not None: + set_parts.append(f"pinned = ${param_idx}") + params.append(pinned) + param_idx += 1 + + if filename is not None: + set_parts.append(f"filename = ${param_idx}") + params.append(filename) + param_idx += 1 + + # Handle metadata updates + current_metadata = existing["metadata"] if isinstance(existing["metadata"], dict) else (_json.loads(existing["metadata"]) if existing["metadata"] else {}) + if tags is not None: + current_metadata["tags"] = tags + if folder is not None: + current_metadata["folder"] = folder + if collections is not None: + current_metadata["collections"] = collections + current_metadata.update(updates) + + if current_metadata: + set_parts.append(f"metadata = ${param_idx}") + params.append(_json.dumps(current_metadata)) + param_idx += 1 + + if set_parts: + query = f""" + UPDATE item_types SET {', '.join(set_parts)} + WHERE cid = $1 AND actor_id = $2 AND type = $3 AND path IS NULL + """ + await conn.execute(query, *params) + + # Handle pin reason + if pinned and pin_reason: + await conn.execute( + """ + INSERT INTO pin_reasons (item_type_id, reason) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, + existing["id"], pin_reason + ) + + return await load_item_metadata(cid, actor_id) + + +async def save_l2_share( + cid: str, + actor_id: str, + l2_server: str, + asset_name: str, + content_type: str = "media", + activity_id: Optional[str] = None +) -> dict: + """Save an L2 share and return share info.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO l2_shares (cid, actor_id, l2_server, asset_name, activity_id, content_type, last_synced_at) + VALUES ($1, $2, $3, $4, $5, $6, NOW()) + ON CONFLICT (cid, actor_id, l2_server, content_type) DO UPDATE SET + asset_name = EXCLUDED.asset_name, + activity_id = COALESCE(EXCLUDED.activity_id, l2_shares.activity_id), + last_synced_at = NOW() + RETURNING l2_server, asset_name, activity_id, content_type, published_at, last_synced_at + """, + cid, actor_id, l2_server, asset_name, activity_id, content_type + ) + return { + "l2_server": row["l2_server"], + "asset_name": row["asset_name"], + "activity_id": row["activity_id"], + "content_type": row["content_type"], + "published_at": row["published_at"].isoformat() if row["published_at"] else None, + "last_synced_at": row["last_synced_at"].isoformat() if row["last_synced_at"] else None, + } + + +async def get_user_items(actor_id: str, item_type: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[dict]: + """Get all items for a user, optionally filtered by type. Deduplicates by cid.""" + async with pool.acquire() as conn: + if item_type: + rows = await conn.fetch( + """ + SELECT * FROM ( + SELECT DISTINCT ON (it.cid) + it.cid, it.type, it.description, it.filename, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.actor_id = $1 AND it.type = $2 + ORDER BY it.cid, it.created_at DESC + ) deduped + ORDER BY created_at DESC + LIMIT $3 OFFSET $4 + """, + actor_id, item_type, limit, offset + ) + else: + rows = await conn.fetch( + """ + SELECT * FROM ( + SELECT DISTINCT ON (it.cid) + it.cid, it.type, it.description, it.filename, it.pinned, it.created_at, + ci.ipfs_cid + FROM item_types it + JOIN cache_items ci ON it.cid = ci.cid + WHERE it.actor_id = $1 + ORDER BY it.cid, it.created_at DESC + ) deduped + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + """, + actor_id, limit, offset + ) + + return [ + { + "cid": r["cid"], + "type": r["type"], + "description": r["description"], + "filename": r["filename"], + "pinned": r["pinned"], + "created_at": r["created_at"].isoformat() if r["created_at"] else None, + "ipfs_cid": r["ipfs_cid"], + } + for r in rows + ] + + +async def count_user_items(actor_id: str, item_type: Optional[str] = None) -> int: + """Count unique items (by cid) for a user.""" + async with pool.acquire() as conn: + if item_type: + return await conn.fetchval( + "SELECT COUNT(DISTINCT cid) FROM item_types WHERE actor_id = $1 AND type = $2", + actor_id, item_type + ) + else: + return await conn.fetchval( + "SELECT COUNT(DISTINCT cid) FROM item_types WHERE actor_id = $1", + actor_id + ) + + +# ============ Run Cache ============ + +async def get_run_cache(run_id: str) -> Optional[dict]: + """Get cached run result by content-addressable run_id.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + FROM run_cache WHERE run_id = $1 + """, + run_id + ) + if row: + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +async def save_run_cache( + run_id: str, + output_cid: str, + recipe: str, + inputs: List[str], + ipfs_cid: Optional[str] = None, + provenance_cid: Optional[str] = None, + plan_cid: Optional[str] = None, + actor_id: Optional[str] = None, +) -> dict: + """Save run result to cache. Updates if run_id already exists.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO run_cache (run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (run_id) DO UPDATE SET + output_cid = EXCLUDED.output_cid, + ipfs_cid = COALESCE(EXCLUDED.ipfs_cid, run_cache.ipfs_cid), + provenance_cid = COALESCE(EXCLUDED.provenance_cid, run_cache.provenance_cid), + plan_cid = COALESCE(EXCLUDED.plan_cid, run_cache.plan_cid), + actor_id = COALESCE(EXCLUDED.actor_id, run_cache.actor_id), + recipe = COALESCE(EXCLUDED.recipe, run_cache.recipe), + inputs = COALESCE(EXCLUDED.inputs, run_cache.inputs) + RETURNING run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + """, + run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, _json.dumps(inputs), actor_id + ) + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + + +async def get_run_by_output(output_cid: str) -> Optional[dict]: + """Get run cache entry by output hash.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, recipe, inputs, actor_id, created_at + FROM run_cache WHERE output_cid = $1 + """, + output_cid + ) + if row: + return { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +def _parse_inputs(inputs_value): + """Parse inputs from database - may be JSON string, list, or None.""" + if inputs_value is None: + return [] + if isinstance(inputs_value, list): + return inputs_value + if isinstance(inputs_value, str): + try: + parsed = _json.loads(inputs_value) + if isinstance(parsed, list): + return parsed + return [] + except (_json.JSONDecodeError, TypeError): + return [] + return [] + + +async def list_runs_by_actor(actor_id: str, offset: int = 0, limit: int = 20) -> List[dict]: + """List completed runs for a user, ordered by creation time (newest first).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT run_id, output_cid, ipfs_cid, provenance_cid, plan_cid, recipe, inputs, actor_id, created_at + FROM run_cache + WHERE actor_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + """, + actor_id, limit, offset + ) + return [ + { + "run_id": row["run_id"], + "output_cid": row["output_cid"], + "ipfs_cid": row["ipfs_cid"], + "provenance_cid": row["provenance_cid"], + "plan_cid": row["plan_cid"], + "recipe": row["recipe"], + "inputs": _parse_inputs(row["inputs"]), + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "status": "completed", + } + for row in rows + ] + + +async def delete_run_cache(run_id: str) -> bool: + """Delete a run from the cache.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM run_cache WHERE run_id = $1", + run_id + ) + return result == "DELETE 1" + + +async def delete_pending_run(run_id: str) -> bool: + """Delete a pending run.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM pending_runs WHERE run_id = $1", + run_id + ) + return result == "DELETE 1" + + +# ============ Storage Backends ============ + +async def get_user_storage(actor_id: str) -> List[dict]: + """Get all storage backends for a user.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE actor_id = $1 + ORDER BY provider_type, created_at""", + actor_id + ) + return [dict(row) for row in rows] + + +async def get_user_storage_by_type(actor_id: str, provider_type: str) -> List[dict]: + """Get storage backends of a specific type for a user.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE actor_id = $1 AND provider_type = $2 + ORDER BY created_at""", + actor_id, provider_type + ) + return [dict(row) for row in rows] + + +async def get_storage_by_id(storage_id: int) -> Optional[dict]: + """Get a storage backend by ID.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """SELECT id, actor_id, provider_type, provider_name, description, config, + capacity_gb, used_bytes, is_active, created_at, synced_at + FROM storage_backends WHERE id = $1""", + storage_id + ) + return dict(row) if row else None + + +async def add_user_storage( + actor_id: str, + provider_type: str, + provider_name: str, + config: dict, + capacity_gb: int, + description: Optional[str] = None +) -> Optional[int]: + """Add a storage backend for a user. Returns storage ID.""" + async with pool.acquire() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO storage_backends (actor_id, provider_type, provider_name, description, config, capacity_gb) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id""", + actor_id, provider_type, provider_name, description, _json.dumps(config), capacity_gb + ) + return row["id"] if row else None + except Exception: + return None + + +async def update_user_storage( + storage_id: int, + provider_name: Optional[str] = None, + description: Optional[str] = None, + config: Optional[dict] = None, + capacity_gb: Optional[int] = None, + is_active: Optional[bool] = None +) -> bool: + """Update a storage backend.""" + updates = [] + params = [] + param_num = 1 + + if provider_name is not None: + updates.append(f"provider_name = ${param_num}") + params.append(provider_name) + param_num += 1 + if description is not None: + updates.append(f"description = ${param_num}") + params.append(description) + param_num += 1 + if config is not None: + updates.append(f"config = ${param_num}") + params.append(_json.dumps(config)) + param_num += 1 + if capacity_gb is not None: + updates.append(f"capacity_gb = ${param_num}") + params.append(capacity_gb) + param_num += 1 + if is_active is not None: + updates.append(f"is_active = ${param_num}") + params.append(is_active) + param_num += 1 + + if not updates: + return False + + updates.append("synced_at = NOW()") + params.append(storage_id) + + async with pool.acquire() as conn: + result = await conn.execute( + f"UPDATE storage_backends SET {', '.join(updates)} WHERE id = ${param_num}", + *params + ) + return "UPDATE 1" in result + + +async def remove_user_storage(storage_id: int) -> bool: + """Remove a storage backend. Cascades to storage_pins.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM storage_backends WHERE id = $1", + storage_id + ) + return "DELETE 1" in result + + +async def get_storage_usage(storage_id: int) -> dict: + """Get storage usage stats.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """SELECT + COUNT(*) as pin_count, + COALESCE(SUM(size_bytes), 0) as used_bytes + FROM storage_pins WHERE storage_id = $1""", + storage_id + ) + return {"pin_count": row["pin_count"], "used_bytes": row["used_bytes"]} + + +async def get_all_active_storage() -> List[dict]: + """Get all active storage backends (for distributed pinning).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT sb.id, sb.actor_id, sb.provider_type, sb.provider_name, sb.description, + sb.config, sb.capacity_gb, sb.is_active, sb.created_at, sb.synced_at, + COALESCE(SUM(sp.size_bytes), 0) as used_bytes, + COUNT(sp.id) as pin_count + FROM storage_backends sb + LEFT JOIN storage_pins sp ON sb.id = sp.storage_id + WHERE sb.is_active = true + GROUP BY sb.id + ORDER BY sb.provider_type, sb.created_at""" + ) + return [dict(row) for row in rows] + + +async def add_storage_pin( + cid: str, + storage_id: int, + ipfs_cid: Optional[str], + pin_type: str, + size_bytes: int +) -> Optional[int]: + """Add a pin record. Returns pin ID.""" + async with pool.acquire() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO storage_pins (cid, storage_id, ipfs_cid, pin_type, size_bytes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (cid, storage_id) DO UPDATE SET + ipfs_cid = EXCLUDED.ipfs_cid, + pin_type = EXCLUDED.pin_type, + size_bytes = EXCLUDED.size_bytes, + pinned_at = NOW() + RETURNING id""", + cid, storage_id, ipfs_cid, pin_type, size_bytes + ) + return row["id"] if row else None + except Exception: + return None + + +async def remove_storage_pin(cid: str, storage_id: int) -> bool: + """Remove a pin record.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM storage_pins WHERE cid = $1 AND storage_id = $2", + cid, storage_id + ) + return "DELETE 1" in result + + +async def get_pins_for_content(cid: str) -> List[dict]: + """Get all storage locations where content is pinned.""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT sp.*, sb.provider_type, sb.provider_name, sb.actor_id + FROM storage_pins sp + JOIN storage_backends sb ON sp.storage_id = sb.id + WHERE sp.cid = $1""", + cid + ) + return [dict(row) for row in rows] + + +# ============ Pending Runs ============ + +async def create_pending_run( + run_id: str, + celery_task_id: str, + recipe: str, + inputs: List[str], + actor_id: str, + dag_json: Optional[str] = None, + output_name: Optional[str] = None, +) -> dict: + """Create a pending run record for durability.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO pending_runs (run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id) + VALUES ($1, $2, 'running', $3, $4, $5, $6, $7) + ON CONFLICT (run_id) DO UPDATE SET + celery_task_id = EXCLUDED.celery_task_id, + status = 'running', + updated_at = NOW() + RETURNING run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id, created_at, updated_at + """, + run_id, celery_task_id, recipe, _json.dumps(inputs), dag_json, output_name, actor_id + ) + return { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "dag_json": row["dag_json"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + + +async def get_pending_run(run_id: str) -> Optional[dict]: + """Get a pending run by ID.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT run_id, celery_task_id, status, recipe, inputs, dag_json, plan_cid, output_name, actor_id, error, + ipfs_playlist_cid, quality_playlists, checkpoint_frame, checkpoint_t, checkpoint_scans, + total_frames, resumable, created_at, updated_at + FROM pending_runs WHERE run_id = $1 + """, + run_id + ) + if row: + # Parse inputs if it's a string (JSONB should auto-parse but be safe) + inputs = row["inputs"] + if isinstance(inputs, str): + inputs = _json.loads(inputs) + # Parse quality_playlists if it's a string + quality_playlists = row.get("quality_playlists") + if isinstance(quality_playlists, str): + quality_playlists = _json.loads(quality_playlists) + # Parse checkpoint_scans if it's a string + checkpoint_scans = row.get("checkpoint_scans") + if isinstance(checkpoint_scans, str): + checkpoint_scans = _json.loads(checkpoint_scans) + return { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": inputs, + "dag_json": row["dag_json"], + "plan_cid": row["plan_cid"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "error": row["error"], + "ipfs_playlist_cid": row["ipfs_playlist_cid"], + "quality_playlists": quality_playlists, + "checkpoint_frame": row.get("checkpoint_frame"), + "checkpoint_t": row.get("checkpoint_t"), + "checkpoint_scans": checkpoint_scans, + "total_frames": row.get("total_frames"), + "resumable": row.get("resumable", True), + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + return None + + +async def list_pending_runs(actor_id: Optional[str] = None, status: Optional[str] = None) -> List[dict]: + """List pending runs, optionally filtered by actor and/or status.""" + async with pool.acquire() as conn: + conditions = [] + params = [] + param_idx = 1 + + if actor_id: + conditions.append(f"actor_id = ${param_idx}") + params.append(actor_id) + param_idx += 1 + + if status: + conditions.append(f"status = ${param_idx}") + params.append(status) + param_idx += 1 + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + rows = await conn.fetch( + f""" + SELECT run_id, celery_task_id, status, recipe, inputs, output_name, actor_id, error, created_at, updated_at + FROM pending_runs + WHERE {where_clause} + ORDER BY created_at DESC + """, + *params + ) + results = [] + for row in rows: + # Parse inputs if it's a string + inputs = row["inputs"] + if isinstance(inputs, str): + inputs = _json.loads(inputs) + results.append({ + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": inputs, + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "error": row["error"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + }) + return results + + +async def update_pending_run_status(run_id: str, status: str, error: Optional[str] = None) -> bool: + """Update the status of a pending run.""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + if error: + result = await conn.execute( + "UPDATE pending_runs SET status = $2, error = $3, updated_at = NOW() WHERE run_id = $1", + run_id, status, error + ) + else: + result = await conn.execute( + "UPDATE pending_runs SET status = $2, updated_at = NOW() WHERE run_id = $1", + run_id, status + ) + return "UPDATE 1" in result + + +async def update_pending_run_plan(run_id: str, plan_cid: str) -> bool: + """Update the plan_cid of a pending run (called when plan is generated).""" + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE pending_runs SET plan_cid = $2, updated_at = NOW() WHERE run_id = $1", + run_id, plan_cid + ) + return "UPDATE 1" in result + + +async def update_pending_run_playlist(run_id: str, ipfs_playlist_cid: str, quality_playlists: Optional[dict] = None) -> bool: + """Update the IPFS playlist CID of a streaming run. + + Args: + run_id: The run ID + ipfs_playlist_cid: Master playlist CID + quality_playlists: Dict of quality name -> {cid, width, height, bitrate} + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + if quality_playlists: + result = await conn.execute( + "UPDATE pending_runs SET ipfs_playlist_cid = $2, quality_playlists = $3, updated_at = NOW() WHERE run_id = $1", + run_id, ipfs_playlist_cid, _json.dumps(quality_playlists) + ) + else: + result = await conn.execute( + "UPDATE pending_runs SET ipfs_playlist_cid = $2, updated_at = NOW() WHERE run_id = $1", + run_id, ipfs_playlist_cid + ) + return "UPDATE 1" in result + + +async def update_pending_run_checkpoint( + run_id: str, + checkpoint_frame: int, + checkpoint_t: float, + checkpoint_scans: Optional[dict] = None, + total_frames: Optional[int] = None, +) -> bool: + """Update checkpoint state for a streaming run. + + Called at segment boundaries to enable resume after failures. + + Args: + run_id: The run ID + checkpoint_frame: Last completed frame at segment boundary + checkpoint_t: Time value for checkpoint frame + checkpoint_scans: Accumulated scan state {scan_name: state_dict} + total_frames: Total expected frames (for progress %) + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE pending_runs SET + checkpoint_frame = $2, + checkpoint_t = $3, + checkpoint_scans = $4, + total_frames = COALESCE($5, total_frames), + updated_at = NOW() + WHERE run_id = $1 + """, + run_id, + checkpoint_frame, + checkpoint_t, + _json.dumps(checkpoint_scans) if checkpoint_scans else None, + total_frames, + ) + return "UPDATE 1" in result + + +async def get_run_checkpoint(run_id: str) -> Optional[dict]: + """Get checkpoint data for resuming a run. + + Returns: + Dict with checkpoint_frame, checkpoint_t, checkpoint_scans, quality_playlists, etc. + or None if no checkpoint exists + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT checkpoint_frame, checkpoint_t, checkpoint_scans, total_frames, + quality_playlists, ipfs_playlist_cid, resumable + FROM pending_runs WHERE run_id = $1 + """, + run_id + ) + if row and row.get("checkpoint_frame") is not None: + # Parse JSONB fields + checkpoint_scans = row.get("checkpoint_scans") + if isinstance(checkpoint_scans, str): + checkpoint_scans = _json.loads(checkpoint_scans) + quality_playlists = row.get("quality_playlists") + if isinstance(quality_playlists, str): + quality_playlists = _json.loads(quality_playlists) + return { + "frame_num": row["checkpoint_frame"], + "t": row["checkpoint_t"], + "scans": checkpoint_scans or {}, + "total_frames": row.get("total_frames"), + "quality_playlists": quality_playlists, + "ipfs_playlist_cid": row.get("ipfs_playlist_cid"), + "resumable": row.get("resumable", True), + } + return None + + +async def clear_run_checkpoint(run_id: str) -> bool: + """Clear checkpoint data for a run (used on restart). + + Args: + run_id: The run ID + """ + if pool is None: + raise RuntimeError("Database pool not initialized - call init_db() first") + async with pool.acquire() as conn: + result = await conn.execute( + """ + UPDATE pending_runs SET + checkpoint_frame = NULL, + checkpoint_t = NULL, + checkpoint_scans = NULL, + quality_playlists = NULL, + ipfs_playlist_cid = NULL, + updated_at = NOW() + WHERE run_id = $1 + """, + run_id, + ) + return "UPDATE 1" in result + + +async def complete_pending_run(run_id: str) -> bool: + """Remove a pending run after it completes (moves to run_cache).""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM pending_runs WHERE run_id = $1", + run_id + ) + return "DELETE 1" in result + + +async def get_stale_pending_runs(older_than_hours: int = 24) -> List[dict]: + """Get pending runs that haven't been updated recently (for recovery).""" + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT run_id, celery_task_id, status, recipe, inputs, dag_json, output_name, actor_id, created_at, updated_at + FROM pending_runs + WHERE status IN ('pending', 'running') + AND updated_at < NOW() - INTERVAL '%s hours' + ORDER BY created_at + """, + older_than_hours + ) + return [ + { + "run_id": row["run_id"], + "celery_task_id": row["celery_task_id"], + "status": row["status"], + "recipe": row["recipe"], + "inputs": row["inputs"], + "dag_json": row["dag_json"], + "output_name": row["output_name"], + "actor_id": row["actor_id"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None, + } + for row in rows + ] + + +# ============ Friendly Names ============ + +async def create_friendly_name( + actor_id: str, + base_name: str, + version_id: str, + cid: str, + item_type: str, + display_name: Optional[str] = None, +) -> dict: + """ + Create a friendly name entry. + + Returns the created entry. + """ + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO friendly_names (actor_id, base_name, version_id, cid, item_type, display_name) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (actor_id, cid) DO UPDATE SET + base_name = EXCLUDED.base_name, + version_id = EXCLUDED.version_id, + display_name = EXCLUDED.display_name + RETURNING id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + """, + actor_id, base_name, version_id, cid, item_type, display_name + ) + return { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + + +async def get_friendly_name_by_cid(actor_id: str, cid: str) -> Optional[dict]: + """Get friendly name entry by CID.""" + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 AND cid = $2 + """, + actor_id, cid + ) + if row: + return { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + return None + + +async def resolve_friendly_name( + actor_id: str, + name: str, + item_type: Optional[str] = None, +) -> Optional[str]: + """ + Resolve a friendly name to a CID. + + Name can be: + - "base-name" -> resolves to latest version + - "base-name version-id" -> resolves to exact version + + Returns CID or None if not found. + """ + parts = name.strip().split(' ') + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + + async with pool.acquire() as conn: + if version_id: + # Exact version lookup + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = $1 AND base_name = $2 AND version_id = $3 + """ + params = [actor_id, base_name, version_id] + if item_type: + query += " AND item_type = $4" + params.append(item_type) + + return await conn.fetchval(query, *params) + else: + # Latest version lookup + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = $1 AND base_name = $2 + """ + params = [actor_id, base_name] + if item_type: + query += " AND item_type = $3" + params.append(item_type) + query += " ORDER BY created_at DESC LIMIT 1" + + return await conn.fetchval(query, *params) + + +async def list_friendly_names( + actor_id: str, + item_type: Optional[str] = None, + base_name: Optional[str] = None, + latest_only: bool = False, +) -> List[dict]: + """ + List friendly names for a user. + + Args: + actor_id: User ID + item_type: Filter by type (recipe, effect, media) + base_name: Filter by base name + latest_only: If True, only return latest version of each base name + """ + async with pool.acquire() as conn: + if latest_only: + # Use DISTINCT ON to get latest version of each base name + query = """ + SELECT DISTINCT ON (base_name) + id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 + """ + params = [actor_id] + if item_type: + query += " AND item_type = $2" + params.append(item_type) + if base_name: + query += f" AND base_name = ${len(params) + 1}" + params.append(base_name) + query += " ORDER BY base_name, created_at DESC" + else: + query = """ + SELECT id, actor_id, base_name, version_id, cid, item_type, display_name, created_at + FROM friendly_names + WHERE actor_id = $1 + """ + params = [actor_id] + if item_type: + query += " AND item_type = $2" + params.append(item_type) + if base_name: + query += f" AND base_name = ${len(params) + 1}" + params.append(base_name) + query += " ORDER BY base_name, created_at DESC" + + rows = await conn.fetch(query, *params) + return [ + { + "id": row["id"], + "actor_id": row["actor_id"], + "base_name": row["base_name"], + "version_id": row["version_id"], + "cid": row["cid"], + "item_type": row["item_type"], + "display_name": row["display_name"], + "friendly_name": f"{row['base_name']} {row['version_id']}", + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + for row in rows + ] + + +async def delete_friendly_name(actor_id: str, cid: str) -> bool: + """Delete a friendly name entry by CID.""" + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM friendly_names WHERE actor_id = $1 AND cid = $2", + actor_id, cid + ) + return "DELETE 1" in result + + +async def update_friendly_name_cid(actor_id: str, old_cid: str, new_cid: str) -> bool: + """ + Update a friendly name's CID (used when IPFS upload completes). + + This updates the CID from a local SHA256 hash to an IPFS CID, + ensuring assets can be fetched by remote workers via IPFS. + """ + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE friendly_names SET cid = $3 WHERE actor_id = $1 AND cid = $2", + actor_id, old_cid, new_cid + ) + return "UPDATE 1" in result + + +# ============================================================================= +# SYNCHRONOUS DATABASE FUNCTIONS (for use from non-async contexts like video streaming) +# ============================================================================= + +def resolve_friendly_name_sync( + actor_id: str, + name: str, + item_type: Optional[str] = None, +) -> Optional[str]: + """ + Synchronous version of resolve_friendly_name using psycopg2. + + Useful when calling from synchronous code (e.g., video streaming callbacks) + where async/await is not possible. + + Returns CID or None if not found. + """ + import psycopg2 + + parts = name.strip().split(' ') + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + + try: + conn = psycopg2.connect(DATABASE_URL) + cursor = conn.cursor() + + if version_id: + # Exact version lookup + if item_type: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND version_id = %s AND item_type = %s + """ + cursor.execute(query, (actor_id, base_name, version_id, item_type)) + else: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND version_id = %s + """ + cursor.execute(query, (actor_id, base_name, version_id)) + else: + # Latest version lookup + if item_type: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s AND item_type = %s + ORDER BY created_at DESC LIMIT 1 + """ + cursor.execute(query, (actor_id, base_name, item_type)) + else: + query = """ + SELECT cid FROM friendly_names + WHERE actor_id = %s AND base_name = %s + ORDER BY created_at DESC LIMIT 1 + """ + cursor.execute(query, (actor_id, base_name)) + + result = cursor.fetchone() + cursor.close() + conn.close() + + return result[0] if result else None + + except Exception as e: + import sys + print(f"resolve_friendly_name_sync ERROR: {e}", file=sys.stderr) + return None + + +def get_ipfs_cid_sync(cid: str) -> Optional[str]: + """ + Synchronous version of get_ipfs_cid using psycopg2. + + Returns the IPFS CID for a given internal CID, or None if not found. + """ + import psycopg2 + + try: + conn = psycopg2.connect(DATABASE_URL) + cursor = conn.cursor() + + cursor.execute( + "SELECT ipfs_cid FROM cache_items WHERE cid = %s", + (cid,) + ) + + result = cursor.fetchone() + cursor.close() + conn.close() + + return result[0] if result else None + + except Exception as e: + import sys + print(f"get_ipfs_cid_sync ERROR: {e}", file=sys.stderr) + return None diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000..a2d6e69 --- /dev/null +++ b/deploy.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +cd "$(dirname "$0")" + +echo "=== Pulling latest code ===" +git pull + +echo "=== Building Docker image ===" +docker build --build-arg CACHEBUST=$(date +%s) -t registry.rose-ash.com:5000/celery-l1-server:latest . + +echo "=== Pushing to registry ===" +docker push registry.rose-ash.com:5000/celery-l1-server:latest + +echo "=== Redeploying celery stack ===" +docker stack deploy -c docker-compose.yml celery + +echo "=== Done ===" +docker stack services celery diff --git a/diagnose_gpu.py b/diagnose_gpu.py new file mode 100755 index 0000000..5136139 --- /dev/null +++ b/diagnose_gpu.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +GPU Rendering Diagnostic Script + +Checks for common issues that cause GPU rendering slowdowns in art-dag. +Run this script to identify potential performance bottlenecks. +""" + +import sys +import subprocess +import os + +def print_section(title): + print(f"\n{'='*60}") + print(f" {title}") + print(f"{'='*60}") + +def check_pass(msg): + print(f" [PASS] {msg}") + +def check_fail(msg): + print(f" [FAIL] {msg}") + +def check_warn(msg): + print(f" [WARN] {msg}") + +def check_info(msg): + print(f" [INFO] {msg}") + +# ============================================================ +# 1. Check GPU Availability +# ============================================================ +print_section("1. GPU AVAILABILITY") + +# Check nvidia-smi +try: + result = subprocess.run(["nvidia-smi", "--query-gpu=name,memory.total,memory.free,utilization.gpu", + "--format=csv,noheader"], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + for line in result.stdout.strip().split('\n'): + check_pass(f"GPU found: {line}") + else: + check_fail("nvidia-smi failed - no GPU detected") +except FileNotFoundError: + check_fail("nvidia-smi not found - NVIDIA drivers not installed") +except Exception as e: + check_fail(f"nvidia-smi error: {e}") + +# ============================================================ +# 2. Check CuPy +# ============================================================ +print_section("2. CUPY (GPU ARRAY LIBRARY)") + +try: + import cupy as cp + check_pass(f"CuPy available, version {cp.__version__}") + + # Test basic GPU operation + try: + a = cp.zeros((100, 100), dtype=cp.uint8) + cp.cuda.Stream.null.synchronize() + check_pass("CuPy GPU operations working") + + # Check memory + mempool = cp.get_default_memory_pool() + check_info(f"GPU memory pool: {mempool.used_bytes() / 1024**2:.1f} MB used, " + f"{mempool.total_bytes() / 1024**2:.1f} MB total") + except Exception as e: + check_fail(f"CuPy GPU test failed: {e}") +except ImportError: + check_fail("CuPy not installed - GPU rendering disabled") + +# ============================================================ +# 3. Check PyNvVideoCodec (GPU Encoding) +# ============================================================ +print_section("3. PYNVVIDEOCODEC (GPU ENCODING)") + +try: + import PyNvVideoCodec as nvc + check_pass("PyNvVideoCodec available - zero-copy GPU encoding enabled") +except ImportError: + check_warn("PyNvVideoCodec not available - using FFmpeg NVENC (slower)") + +# ============================================================ +# 4. Check Decord GPU (Hardware Decode) +# ============================================================ +print_section("4. DECORD GPU (HARDWARE DECODE)") + +try: + import decord + from decord import gpu + ctx = gpu(0) + check_pass(f"Decord GPU (NVDEC) available - hardware video decode enabled") +except ImportError: + check_warn("Decord not installed - using FFmpeg decode") +except Exception as e: + check_warn(f"Decord GPU not available ({e}) - using FFmpeg decode") + +# ============================================================ +# 5. Check DLPack Support +# ============================================================ +print_section("5. DLPACK (ZERO-COPY TRANSFER)") + +try: + import decord + from decord import VideoReader, gpu + import cupy as cp + + # Need a test video file + test_video = None + for path in ["/data/cache", "/tmp"]: + if os.path.exists(path): + for f in os.listdir(path): + if f.endswith(('.mp4', '.webm', '.mkv')): + test_video = os.path.join(path, f) + break + if test_video: + break + + if test_video: + try: + vr = VideoReader(test_video, ctx=gpu(0)) + frame = vr[0] + dlpack = frame.to_dlpack() + gpu_frame = cp.from_dlpack(dlpack) + check_pass(f"DLPack zero-copy working (tested with {os.path.basename(test_video)})") + except Exception as e: + check_fail(f"DLPack FAILED: {e}") + check_info("This means every frame does GPU->CPU->GPU copy (SLOW)") + else: + check_warn("No test video found - cannot verify DLPack") +except ImportError: + check_warn("Cannot test DLPack - decord or cupy not available") + +# ============================================================ +# 6. Check Fast CUDA Kernels +# ============================================================ +print_section("6. FAST CUDA KERNELS (JIT COMPILED)") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from streaming.jit_compiler import ( + fast_rotate, fast_zoom, fast_blend, fast_hue_shift, + fast_invert, fast_ripple, get_fast_ops + ) + check_pass("Fast CUDA kernels loaded successfully") + + # Test one kernel + try: + import cupy as cp + test_img = cp.zeros((720, 1280, 3), dtype=cp.uint8) + result = fast_rotate(test_img, 45.0) + cp.cuda.Stream.null.synchronize() + check_pass("Fast rotate kernel working") + except Exception as e: + check_fail(f"Fast kernel execution failed: {e}") +except ImportError as e: + check_warn(f"Fast CUDA kernels not available: {e}") + check_info("Fallback to slower CuPy operations") + +# ============================================================ +# 7. Check Fused Pipeline Compiler +# ============================================================ +print_section("7. FUSED PIPELINE COMPILER") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from streaming.sexp_to_cuda import compile_frame_pipeline, compile_autonomous_pipeline + check_pass("Fused CUDA pipeline compiler available") +except ImportError as e: + check_warn(f"Fused pipeline compiler not available: {e}") + check_info("Using per-operation fallback (slower for multi-effect pipelines)") + +# ============================================================ +# 8. Check FFmpeg NVENC +# ============================================================ +print_section("8. FFMPEG NVENC (HARDWARE ENCODE)") + +try: + result = subprocess.run(["ffmpeg", "-encoders"], capture_output=True, text=True, timeout=5) + if "h264_nvenc" in result.stdout: + check_pass("FFmpeg h264_nvenc encoder available") + else: + check_warn("FFmpeg h264_nvenc not available - using libx264 (CPU)") + + if "hevc_nvenc" in result.stdout: + check_pass("FFmpeg hevc_nvenc encoder available") +except Exception as e: + check_fail(f"FFmpeg check failed: {e}") + +# ============================================================ +# 9. Check FFmpeg NVDEC +# ============================================================ +print_section("9. FFMPEG NVDEC (HARDWARE DECODE)") + +try: + result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5) + if "cuda" in result.stdout: + check_pass("FFmpeg CUDA hwaccel available") + else: + check_warn("FFmpeg CUDA hwaccel not available - using CPU decode") +except Exception as e: + check_fail(f"FFmpeg hwaccel check failed: {e}") + +# ============================================================ +# 10. Check Pipeline Cache Status +# ============================================================ +print_section("10. PIPELINE CACHE STATUS") + +try: + sys.path.insert(0, '/root/art-dag/celery') + from sexp_effects.primitive_libs.streaming_gpu import ( + _FUSED_PIPELINE_CACHE, _AUTONOMOUS_PIPELINE_CACHE + ) + fused_count = len(_FUSED_PIPELINE_CACHE) + auto_count = len(_AUTONOMOUS_PIPELINE_CACHE) + + if fused_count > 0 or auto_count > 0: + check_info(f"Fused pipeline cache: {fused_count} entries") + check_info(f"Autonomous pipeline cache: {auto_count} entries") + if fused_count > 100 or auto_count > 100: + check_warn("Large pipeline cache - may cause memory pressure") + else: + check_info("Pipeline caches empty (no rendering done yet)") +except Exception as e: + check_info(f"Could not check pipeline cache: {e}") + +# ============================================================ +# Summary +# ============================================================ +print_section("SUMMARY") +print(""" +Optimal GPU rendering requires: + 1. [CRITICAL] CuPy with working GPU operations + 2. [CRITICAL] DLPack zero-copy transfer (decord -> CuPy) + 3. [HIGH] Fast CUDA kernels from jit_compiler + 4. [MEDIUM] Fused pipeline compiler for multi-effect recipes + 5. [MEDIUM] PyNvVideoCodec for zero-copy encoding + 6. [LOW] FFmpeg NVENC/NVDEC as fallback + +If DLPack is failing, check: + - decord version (needs 0.6.0+ with DLPack support) + - CuPy version compatibility + - CUDA toolkit version match + +If fast kernels are not loading: + - Check if streaming/jit_compiler.py exists + - Verify CUDA compiler (nvcc) is available +""") diff --git a/docker-compose.gpu-dev.yml b/docker-compose.gpu-dev.yml new file mode 100644 index 0000000..1facb3b --- /dev/null +++ b/docker-compose.gpu-dev.yml @@ -0,0 +1,36 @@ +# GPU Worker Development Override +# +# Usage: docker stack deploy -c docker-compose.yml -c docker-compose.gpu-dev.yml celery +# Or for quick testing: docker-compose -f docker-compose.yml -f docker-compose.gpu-dev.yml up l1-gpu-worker +# +# Features: +# - Mounts source code for instant changes (no rebuild needed) +# - Uses watchmedo for auto-reload on file changes +# - Shows config on startup + +version: '3.8' + +services: + l1-gpu-worker: + # Override command to use watchmedo for auto-reload + command: > + sh -c " + pip install -q watchdog[watchmedo] 2>/dev/null || true; + echo '=== GPU WORKER DEV MODE ==='; + echo 'Source mounted - changes take effect on restart'; + echo 'Auto-reload enabled via watchmedo'; + env | grep -E 'STREAMING_GPU|IPFS_GATEWAY|REDIS|DATABASE' | sort; + echo '==========================='; + watchmedo auto-restart --directory=/app --pattern='*.py' --recursive -- \ + celery -A celery_app worker --loglevel=info -E -Q gpu,celery + " + environment: + # Development defaults (can override with .env) + - STREAMING_GPU_PERSIST=0 + - IPFS_GATEWAY_URL=https://celery-artdag.rose-ash.com/ipfs + - SHOW_CONFIG=1 + volumes: + # Mount source code for hot reload + - ./:/app:ro + # Keep cache local + - gpu_cache:/data/cache diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..301e439 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,191 @@ +version: "3.8" + +services: + redis: + image: redis:7-alpine + ports: + - target: 6379 + published: 16379 + mode: host # Bypass swarm routing mesh + volumes: + - redis_data:/data + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + postgres: + image: postgres:16-alpine + env_file: + - .env + environment: + - POSTGRES_USER=artdag + - POSTGRES_DB=artdag + ports: + - target: 5432 + published: 15432 + mode: host # Expose for GPU worker on different VPC + volumes: + - postgres_data:/var/lib/postgresql/data + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + ipfs: + image: ipfs/kubo:latest + ports: + - "4001:4001" # Swarm TCP + - "4001:4001/udp" # Swarm UDP + - target: 5001 + published: 15001 + mode: host # API port for GPU worker on different VPC + volumes: + - ipfs_data:/data/ipfs + - l1_cache:/data/cache:ro # Read-only access to cache for adding files + networks: + - celery + - externalnet # For gateway access + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + l1-server: + image: registry.rose-ash.com:5000/celery-l1-server:latest + env_file: + - .env + environment: + - REDIS_URL=redis://redis:6379/5 + # IPFS_API multiaddr - used for all IPFS operations (add, cat, pin) + - IPFS_API=/dns/ipfs/tcp/5001 + - CACHE_DIR=/data/cache + # Coop app internal URLs for fragment composition + - INTERNAL_URL_BLOG=http://blog:8000 + - INTERNAL_URL_CART=http://cart:8000 + - INTERNAL_URL_ACCOUNT=http://account:8000 + # DATABASE_URL, ADMIN_TOKEN, ARTDAG_CLUSTER_KEY, + # L2_SERVER, L2_DOMAIN, IPFS_GATEWAY_URL from .env file + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8100/health')"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 15s + volumes: + - l1_cache:/data/cache + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + - externalnet + deploy: + replicas: 1 + update_config: + order: start-first + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + l1-worker: + image: registry.rose-ash.com:5000/celery-l1-server:latest + command: sh -c "find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; celery -A celery_app worker --loglevel=info -E" + env_file: + - .env + environment: + - REDIS_URL=redis://redis:6379/5 + # IPFS_API multiaddr - used for all IPFS operations (add, cat, pin) + - IPFS_API=/dns/ipfs/tcp/5001 + - CACHE_DIR=/data/cache + - C_FORCE_ROOT=true + # DATABASE_URL, ARTDAG_CLUSTER_KEY from .env file + volumes: + - l1_cache:/data/cache + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + deploy: + replicas: 2 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + flower: + image: mher/flower:2.0 + command: celery --broker=redis://redis:6379/5 flower --port=5555 + environment: + - CELERY_BROKER_URL=redis://redis:6379/5 + - FLOWER_PORT=5555 + depends_on: + - redis + networks: + - celery + - externalnet + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + # GPU worker for streaming/rendering tasks + # Build: docker build -f Dockerfile.gpu -t registry.rose-ash.com:5000/celery-l1-gpu-server:latest . + # Requires: docker node update --label-add gpu=true + l1-gpu-worker: + image: registry.rose-ash.com:5000/celery-l1-gpu-server:latest + command: sh -c "cd /app && celery -A celery_app worker --loglevel=info -E -Q gpu,celery" + env_file: + - .env.gpu + volumes: + # Local cache - ephemeral, just for working files + - gpu_cache:/data/cache + # Note: No source mount - GPU worker uses code from image + depends_on: + - redis + - postgres + - ipfs + networks: + - celery + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu == true + +volumes: + redis_data: + postgres_data: + ipfs_data: + l1_cache: + gpu_cache: # Ephemeral cache for GPU workers + +networks: + celery: + driver: overlay + externalnet: + external: true diff --git a/effects/quick_test_explicit.sexp b/effects/quick_test_explicit.sexp new file mode 100644 index 0000000..0a3698b --- /dev/null +++ b/effects/quick_test_explicit.sexp @@ -0,0 +1,150 @@ +;; Quick Test - Fully Explicit Streaming Version +;; +;; The interpreter is completely generic - knows nothing about video/audio. +;; All domain logic is explicit via primitives. +;; +;; Run with built-in sources/audio: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp --fps 30 +;; +;; Run with external config files: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp \ +;; --sources configs/sources-default.sexp \ +;; --audio configs/audio-dizzy.sexp \ +;; --fps 30 + +(stream "quick_test_explicit" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives and effects + (include :path "../templates/standard-primitives.sexp") + (include :path "../templates/standard-effects.sexp") + + ;; Load reusable templates + (include :path "../templates/stream-process-pair.sexp") + (include :path "../templates/crossfade-zoom.sexp") + + ;; === SOURCES AS ARRAY === + (def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) + ]) + + ;; Per-pair config: [rot-dir, rot-a-max, rot-b-max, zoom-a-max, zoom-b-max] + ;; Pairs 3,6: reversed (negative rot-a, positive rot-b, shrink zoom-a, grow zoom-b) + ;; Pair 5: smaller ranges + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 + ]) + + ;; Audio analyzer + (def music (streaming:make-audio-analyzer "dizzy.mp3")) + + ;; Audio playback + (audio-playback "../dizzy.mp3") + + ;; === GLOBAL SCANS === + + ;; Cycle state: which source is active (recipe-specific) + ;; clen = beats per source (8-24 beats = ~4-12 seconds) + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Reusable scans from templates (require 'music' to be defined) + (include :path "../templates/scan-oscillating-spin.sexp") + (include :path "../templates/scan-ripple-drops.sexp") + + ;; === PER-PAIR STATE (dynamically sized based on sources) === + ;; Each pair has: inv-a, inv-b, hue-a, hue-b, mix, rot-angle + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [;; Invert toggles (10% chance, lasts 1-4 beats) + new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + ;; Hue shifts (10% chance, lasts 1-4 beats) - use countdown like invert + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + ;; Pick random hue value when triggering (stored separately) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + ;; Mix (holds for 1-10 beats, then picks 0, 0.5, or 1) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + ;; Rotation (accumulates, reverses direction when cycle completes) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + ;; Note: dir comes from pair-configs, but we store rotation state here + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic: last third of cycle crossfades to next + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Get pair states array (required by process-pair macro) + pair-states (bind pairs :states) + + ;; Process active pair using macro from template + active-frame (process-pair active) + + ;; Crossfade with zoom during transition (using macro) + result (if fading + (crossfade-zoom active-frame (process-pair next-idx) fade-amt) + active-frame) + + ;; Final: global spin + ripple + spun (rotate result :angle (bind spin :angle)) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50))] + + (ripple spun + :amplitude rip-amp + :center_x (bind ripple-state :cx) + :center_y (bind ripple-state :cy) + :frequency 8 + :decay 2 + :speed 5)))) diff --git a/ipfs_client.py b/ipfs_client.py new file mode 100644 index 0000000..3edf5b1 --- /dev/null +++ b/ipfs_client.py @@ -0,0 +1,345 @@ +# art-celery/ipfs_client.py +""" +IPFS client for Art DAG L1 server. + +Provides functions to add, retrieve, and pin files on IPFS. +Uses direct HTTP API calls for compatibility with all Kubo versions. +""" + +import logging +import os +import re +from pathlib import Path +from typing import Optional, Union + +import requests + +logger = logging.getLogger(__name__) + +# IPFS API multiaddr - default to local, docker uses /dns/ipfs/tcp/5001 +IPFS_API = os.getenv("IPFS_API", "/ip4/127.0.0.1/tcp/5001") + +# Connection timeout in seconds (increased for large files) +IPFS_TIMEOUT = int(os.getenv("IPFS_TIMEOUT", "120")) + +# IPFS gateway URLs for fallback when local node doesn't have content +# Comma-separated list of gateway URLs (without /ipfs/ suffix) +IPFS_GATEWAYS = [g.strip() for g in os.getenv( + "IPFS_GATEWAYS", + "https://ipfs.io,https://cloudflare-ipfs.com,https://dweb.link" +).split(",") if g.strip()] + +# Gateway timeout (shorter than API timeout for faster fallback) +GATEWAY_TIMEOUT = int(os.getenv("GATEWAY_TIMEOUT", "30")) + + +def _multiaddr_to_url(multiaddr: str) -> str: + """Convert IPFS multiaddr to HTTP URL.""" + # Handle /dns/hostname/tcp/port format + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +# Base URL for IPFS API +IPFS_BASE_URL = _multiaddr_to_url(IPFS_API) + + +def add_file(file_path: Union[Path, str], pin: bool = True) -> Optional[str]: + """ + Add a file to IPFS and optionally pin it. + + Args: + file_path: Path to the file to add (Path object or string) + pin: Whether to pin the file (default: True) + + Returns: + IPFS CID (content identifier) or None on failure + """ + try: + # Ensure file_path is a Path object + if isinstance(file_path, str): + file_path = Path(file_path) + + url = f"{IPFS_BASE_URL}/api/v0/add" + params = {"pin": str(pin).lower()} + + with open(file_path, "rb") as f: + files = {"file": (file_path.name, f)} + response = requests.post(url, params=params, files=files, timeout=IPFS_TIMEOUT) + + response.raise_for_status() + result = response.json() + cid = result["Hash"] + logger.info(f"Added to IPFS: {file_path.name} -> {cid}") + return cid + except Exception as e: + logger.error(f"Failed to add to IPFS: {e}") + return None + + +def add_bytes(data: bytes, pin: bool = True) -> Optional[str]: + """ + Add bytes data to IPFS and optionally pin it. + + Args: + data: Bytes to add + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/add" + params = {"pin": str(pin).lower()} + files = {"file": ("data", data)} + + response = requests.post(url, params=params, files=files, timeout=IPFS_TIMEOUT) + response.raise_for_status() + result = response.json() + cid = result["Hash"] + + logger.info(f"Added bytes to IPFS: {len(data)} bytes -> {cid}") + return cid + except Exception as e: + logger.error(f"Failed to add bytes to IPFS: {e}") + return None + + +def add_json(data: dict, pin: bool = True) -> Optional[str]: + """ + Serialize dict to JSON and add to IPFS. + + Args: + data: Dictionary to serialize and store + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + import json + json_bytes = json.dumps(data, indent=2, sort_keys=True).encode('utf-8') + return add_bytes(json_bytes, pin=pin) + + +def add_string(content: str, pin: bool = True) -> Optional[str]: + """ + Add a string to IPFS and optionally pin it. + + Args: + content: String content to add (e.g., S-expression) + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID or None on failure + """ + return add_bytes(content.encode('utf-8'), pin=pin) + + +def get_file(cid: str, dest_path: Union[Path, str]) -> bool: + """ + Retrieve a file from IPFS and save to destination. + + Args: + cid: IPFS CID to retrieve + dest_path: Path to save the file (Path object or string) + + Returns: + True on success, False on failure + """ + try: + data = get_bytes(cid) + if data is None: + return False + + # Ensure dest_path is a Path object + if isinstance(dest_path, str): + dest_path = Path(dest_path) + + dest_path.parent.mkdir(parents=True, exist_ok=True) + dest_path.write_bytes(data) + logger.info(f"Retrieved from IPFS: {cid} -> {dest_path}") + return True + except Exception as e: + logger.error(f"Failed to get from IPFS: {e}") + return False + + +def get_bytes_from_gateway(cid: str) -> Optional[bytes]: + """ + Retrieve bytes from IPFS via public gateways (fallback). + + Tries each configured gateway in order until one succeeds. + + Args: + cid: IPFS CID to retrieve + + Returns: + File content as bytes or None if all gateways fail + """ + for gateway in IPFS_GATEWAYS: + try: + url = f"{gateway}/ipfs/{cid}" + logger.info(f"Trying gateway: {url}") + response = requests.get(url, timeout=GATEWAY_TIMEOUT) + response.raise_for_status() + data = response.content + logger.info(f"Retrieved from gateway {gateway}: {cid} ({len(data)} bytes)") + return data + except Exception as e: + logger.warning(f"Gateway {gateway} failed for {cid}: {e}") + continue + + logger.error(f"All gateways failed for {cid}") + return None + + +def get_bytes(cid: str, use_gateway_fallback: bool = True) -> Optional[bytes]: + """ + Retrieve bytes data from IPFS. + + Tries local IPFS node first, then falls back to public gateways + if configured and use_gateway_fallback is True. + + Args: + cid: IPFS CID to retrieve + use_gateway_fallback: If True, try public gateways on local failure + + Returns: + File content as bytes or None on failure + """ + # Try local IPFS node first + try: + url = f"{IPFS_BASE_URL}/api/v0/cat" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + data = response.content + + logger.info(f"Retrieved from IPFS: {cid} ({len(data)} bytes)") + return data + except Exception as e: + logger.warning(f"Local IPFS failed for {cid}: {e}") + + # Try gateway fallback + if use_gateway_fallback and IPFS_GATEWAYS: + logger.info(f"Trying gateway fallback for {cid}") + return get_bytes_from_gateway(cid) + + logger.error(f"Failed to get bytes from IPFS: {e}") + return None + + +def pin(cid: str) -> bool: + """ + Pin a CID on IPFS. + + Args: + cid: IPFS CID to pin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/add" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Pinned on IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to pin on IPFS: {e}") + return False + + +def unpin(cid: str) -> bool: + """ + Unpin a CID on IPFS. + + Args: + cid: IPFS CID to unpin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/rm" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Unpinned on IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to unpin on IPFS: {e}") + return False + + +def is_pinned(cid: str) -> bool: + """ + Check if a CID is pinned on IPFS. + + Args: + cid: IPFS CID to check + + Returns: + True if pinned, False otherwise + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/ls" + params = {"arg": cid, "type": "recursive"} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + if response.status_code == 200: + result = response.json() + return cid in result.get("Keys", {}) + return False + except Exception as e: + logger.error(f"Failed to check pin status: {e}") + return False + + +def is_available() -> bool: + """ + Check if IPFS daemon is available. + + Returns: + True if IPFS is available, False otherwise + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def get_node_id() -> Optional[str]: + """ + Get this IPFS node's peer ID. + + Returns: + Peer ID string or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=IPFS_TIMEOUT) + response.raise_for_status() + return response.json().get("ID") + except Exception as e: + logger.error(f"Failed to get node ID: {e}") + return None diff --git a/path_registry.py b/path_registry.py new file mode 100644 index 0000000..985be18 --- /dev/null +++ b/path_registry.py @@ -0,0 +1,477 @@ +""" +Path Registry - Maps human-friendly paths to content-addressed IDs. + +This module provides a bidirectional mapping between: +- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp") +- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes) + +The registry is useful for: +- Looking up effects by their friendly path name +- Resolving cids back to the original path for debugging +- Maintaining a stable naming scheme across cache updates + +Storage: +- Uses the existing item_types table in the database (path column) +- Caches in Redis for fast lookups across distributed workers + +The registry uses a system actor (@system@local) for global path mappings, +allowing effects to be resolved by path without requiring user context. +""" + +import logging +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# System actor for global path mappings (effects, recipes, analyzers) +SYSTEM_ACTOR = "@system@local" + + +@dataclass +class PathEntry: + """A registered path with its content-addressed ID.""" + path: str # Human-friendly path (relative or normalized) + cid: str # Content-addressed ID (IPFS CID or hash) + content_type: str # Type: "effect", "recipe", "analyzer", etc. + actor_id: str = SYSTEM_ACTOR # Owner (system for global) + description: Optional[str] = None + created_at: float = 0.0 + + +class PathRegistry: + """ + Registry for mapping paths to content-addressed IDs. + + Uses the existing item_types table for persistence and Redis + for fast lookups in distributed Celery workers. + """ + + def __init__(self, redis_client=None): + self._redis = redis_client + self._redis_path_to_cid_key = "artdag:path_to_cid" + self._redis_cid_to_path_key = "artdag:cid_to_path" + + def _run_async(self, coro): + """Run async coroutine from sync context.""" + import asyncio + + try: + loop = asyncio.get_running_loop() + import threading + result = [None] + error = [None] + + def run_in_thread(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + result[0] = new_loop.run_until_complete(coro) + finally: + new_loop.close() + except Exception as e: + error[0] = e + + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join(timeout=30) + if error[0]: + raise error[0] + return result[0] + except RuntimeError: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + def _normalize_path(self, path: str) -> str: + """Normalize a path for consistent storage.""" + # Remove leading ./ or / + path = path.lstrip('./') + # Normalize separators + path = path.replace('\\', '/') + # Remove duplicate slashes + while '//' in path: + path = path.replace('//', '/') + return path + + def register( + self, + path: str, + cid: str, + content_type: str = "effect", + actor_id: str = SYSTEM_ACTOR, + description: Optional[str] = None, + ) -> PathEntry: + """ + Register a path -> cid mapping. + + Args: + path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp") + cid: Content-addressed ID (IPFS CID or hash) + content_type: Type of content ("effect", "recipe", "analyzer") + actor_id: Owner (default: system for global mappings) + description: Optional description + + Returns: + The created PathEntry + """ + norm_path = self._normalize_path(path) + now = datetime.now(timezone.utc).timestamp() + + entry = PathEntry( + path=norm_path, + cid=cid, + content_type=content_type, + actor_id=actor_id, + description=description, + created_at=now, + ) + + # Store in database (item_types table) + self._save_to_db(entry) + + # Update Redis cache + self._update_redis_cache(norm_path, cid) + + logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...") + return entry + + def _save_to_db(self, entry: PathEntry): + """Save entry to database using item_types table.""" + import database + + async def save(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + # Ensure cache_item exists + await conn.execute( + "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING", + entry.cid + ) + # Insert or update item_type with path + await conn.execute( + """ + INSERT INTO item_types (cid, actor_id, type, path, description) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET + description = COALESCE(EXCLUDED.description, item_types.description) + """, + entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description + ) + finally: + await conn.close() + + try: + self._run_async(save()) + except Exception as e: + logger.warning(f"Failed to save path registry to DB: {e}") + + def _update_redis_cache(self, path: str, cid: str): + """Update Redis cache with mapping.""" + if self._redis: + try: + self._redis.hset(self._redis_path_to_cid_key, path, cid) + self._redis.hset(self._redis_cid_to_path_key, cid, path) + except Exception as e: + logger.warning(f"Failed to update Redis cache: {e}") + + def get_cid(self, path: str, content_type: str = None) -> Optional[str]: + """ + Get the cid for a path. + + Args: + path: Human-friendly path + content_type: Optional type filter + + Returns: + The cid, or None if not found + """ + norm_path = self._normalize_path(path) + + # Try Redis first (fast path) + if self._redis: + try: + val = self._redis.hget(self._redis_path_to_cid_key, norm_path) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Redis lookup failed: {e}") + + # Fall back to database + return self._get_cid_from_db(norm_path, content_type) + + def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]: + """Get cid from database using item_types table.""" + import database + + async def get(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if content_type: + row = await conn.fetchrow( + "SELECT cid FROM item_types WHERE path = $1 AND type = $2", + path, content_type + ) + else: + row = await conn.fetchrow( + "SELECT cid FROM item_types WHERE path = $1", + path + ) + return row["cid"] if row else None + finally: + await conn.close() + + try: + result = self._run_async(get()) + # Update Redis cache if found + if result and self._redis: + self._update_redis_cache(path, result) + return result + except Exception as e: + logger.warning(f"Failed to get from DB: {e}") + return None + + def get_path(self, cid: str) -> Optional[str]: + """ + Get the path for a cid. + + Args: + cid: Content-addressed ID + + Returns: + The path, or None if not found + """ + # Try Redis first + if self._redis: + try: + val = self._redis.hget(self._redis_cid_to_path_key, cid) + if val: + return val.decode() if isinstance(val, bytes) else val + except Exception as e: + logger.warning(f"Redis lookup failed: {e}") + + # Fall back to database + return self._get_path_from_db(cid) + + def _get_path_from_db(self, cid: str) -> Optional[str]: + """Get path from database using item_types table.""" + import database + + async def get(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + row = await conn.fetchrow( + "SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1", + cid + ) + return row["path"] if row else None + finally: + await conn.close() + + try: + result = self._run_async(get()) + # Update Redis cache if found + if result and self._redis: + self._update_redis_cache(result, cid) + return result + except Exception as e: + logger.warning(f"Failed to get from DB: {e}") + return None + + def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]: + """ + List all entries of a given type. + + Args: + content_type: Type to filter by ("effect", "recipe", etc.) + actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global) + + Returns: + List of PathEntry objects + """ + import database + + async def list_entries(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if actor_id: + rows = await conn.fetch( + """ + SELECT cid, path, type, actor_id, description, + EXTRACT(EPOCH FROM created_at) as created_at + FROM item_types + WHERE type = $1 AND actor_id = $2 AND path IS NOT NULL + ORDER BY path + """, + content_type, actor_id + ) + else: + rows = await conn.fetch( + """ + SELECT cid, path, type, actor_id, description, + EXTRACT(EPOCH FROM created_at) as created_at + FROM item_types + WHERE type = $1 AND path IS NOT NULL + ORDER BY path + """, + content_type + ) + return [ + PathEntry( + path=row["path"], + cid=row["cid"], + content_type=row["type"], + actor_id=row["actor_id"], + description=row["description"], + created_at=row["created_at"] or 0, + ) + for row in rows + ] + finally: + await conn.close() + + try: + return self._run_async(list_entries()) + except Exception as e: + logger.warning(f"Failed to list from DB: {e}") + return [] + + def delete(self, path: str, content_type: str = None) -> bool: + """ + Delete a path registration. + + Args: + path: The path to delete + content_type: Optional type filter + + Returns: + True if deleted, False if not found + """ + norm_path = self._normalize_path(path) + + # Get cid for Redis cleanup + cid = self.get_cid(norm_path, content_type) + + # Delete from database + deleted = self._delete_from_db(norm_path, content_type) + + # Clean up Redis + if deleted and cid and self._redis: + try: + self._redis.hdel(self._redis_path_to_cid_key, norm_path) + self._redis.hdel(self._redis_cid_to_path_key, cid) + except Exception as e: + logger.warning(f"Failed to clean up Redis: {e}") + + return deleted + + def _delete_from_db(self, path: str, content_type: str = None) -> bool: + """Delete from database.""" + import database + + async def delete(): + import asyncpg + conn = await asyncpg.connect(database.DATABASE_URL) + try: + if content_type: + result = await conn.execute( + "DELETE FROM item_types WHERE path = $1 AND type = $2", + path, content_type + ) + else: + result = await conn.execute( + "DELETE FROM item_types WHERE path = $1", + path + ) + return "DELETE" in result + finally: + await conn.close() + + try: + return self._run_async(delete()) + except Exception as e: + logger.warning(f"Failed to delete from DB: {e}") + return False + + def register_effect( + self, + path: str, + cid: str, + description: Optional[str] = None, + ) -> PathEntry: + """ + Convenience method to register an effect. + + Args: + path: Effect path (e.g., "effects/ascii_fx_zone.sexp") + cid: IPFS CID of the effect file + description: Optional description + + Returns: + The created PathEntry + """ + return self.register( + path=path, + cid=cid, + content_type="effect", + actor_id=SYSTEM_ACTOR, + description=description, + ) + + def get_effect_cid(self, path: str) -> Optional[str]: + """ + Get CID for an effect by path. + + Args: + path: Effect path + + Returns: + IPFS CID or None + """ + return self.get_cid(path, content_type="effect") + + def list_effects(self) -> List[PathEntry]: + """List all registered effects.""" + return self.list_by_type("effect", actor_id=SYSTEM_ACTOR) + + +# Singleton instance +_registry: Optional[PathRegistry] = None + + +def get_path_registry() -> PathRegistry: + """Get the singleton path registry instance.""" + global _registry + if _registry is None: + import redis + from urllib.parse import urlparse + + redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5') + parsed = urlparse(redis_url) + redis_client = redis.Redis( + host=parsed.hostname or 'localhost', + port=parsed.port or 6379, + db=int(parsed.path.lstrip('/') or 0), + socket_timeout=5, + socket_connect_timeout=5 + ) + + _registry = PathRegistry(redis_client=redis_client) + return _registry + + +def reset_path_registry(): + """Reset the singleton (for testing).""" + global _registry + _registry = None diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b358312 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,51 @@ +[project] +name = "art-celery" +version = "0.1.0" +description = "Art DAG L1 Server and Celery Workers" +requires-python = ">=3.11" + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +strict_optional = true +no_implicit_optional = true + +# Start strict on new code, gradually enable for existing +files = [ + "app/types.py", + "app/routers/recipes.py", + "tests/", +] + +# Ignore missing imports for third-party packages without stubs +[[tool.mypy.overrides]] +module = [ + "celery.*", + "redis.*", + "artdag.*", + "artdag_common.*", + "ipfs_client.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] +asyncio_mode = "auto" +addopts = "-v --tb=short" +filterwarnings = [ + "ignore::DeprecationWarning", +] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "UP"] +ignore = ["E501"] # Line length handled separately diff --git a/recipes/woods-lowres.sexp b/recipes/woods-lowres.sexp new file mode 100644 index 0000000..55a1a6a --- /dev/null +++ b/recipes/woods-lowres.sexp @@ -0,0 +1,223 @@ +;; Woods Recipe - OPTIMIZED VERSION +;; +;; Uses fused-pipeline for GPU acceleration when available, +;; falls back to individual primitives on CPU. +;; +;; Key optimizations: +;; 1. Uses streaming_gpu primitives with fast CUDA kernels +;; 2. Uses fused-pipeline to batch effects into single kernel passes +;; 3. GPU persistence - frames stay on GPU throughout pipeline + +(stream "woods-lowres" + :fps 30 + :width 640 + :height 360 + :seed 42 + + ;; Load standard primitives (includes proper asset resolution) + ;; Auto-selects GPU versions when available, falls back to CPU + (include :name "tpl-standard-primitives") + + ;; === SOURCES (using streaming: which has proper asset resolution) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio + (def music (streaming:make-audio-analyzer "woods-audio")) + (audio-playback "woods-audio") + + ;; === SCANS === + + ;; Cycle state + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Spin scan + (scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :speed 2} + :step (let [new-dir (if (< (core:rand) 0.05) (* dir -1) dir) + new-speed (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) speed)] + (dict :angle (+ angle (* new-dir new-speed)) + :dir new-dir + :speed new-speed))) + + ;; Ripple scan - raindrop style, all params randomized + ;; Higher freq = bigger gaps between waves (formula is dist/freq) + (scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 320 :cy 180 :freq 20 :decay 6 :amp-mult 1.0} + :step (let [new-gate (if (< (core:rand) 0.2) (+ 2 (core:rand-int 0 4)) (core:max 0 (- gate 1))) + triggered (> new-gate gate) + new-cx (if triggered (core:rand-int 50 590) cx) + new-cy (if triggered (core:rand-int 50 310) cy) + new-freq (if triggered (+ 15 (core:rand-int 0 20)) freq) + new-decay (if triggered (+ 5 (core:rand-int 0 4)) decay) + new-amp-mult (if triggered (+ 0.8 (* (core:rand) 1.2)) amp-mult)] + (dict :gate new-gate :cx new-cx :cy new-cy :freq new-freq :decay new-decay :amp-mult new-amp-mult))) + + ;; Pair states + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === OPTIMIZED PROCESS-PAIR MACRO === + ;; Uses fused-pipeline to batch rotate+hue+invert into single kernel + (defmacro process-pair-fast (idx) + (let [;; Get sources for this pair (with safe modulo indexing) + num-sources (len sources) + src-a (nth sources (mod (* idx 2) num-sources)) + src-b (nth sources (mod (+ (* idx 2) 1) num-sources)) + cfg (nth pair-configs idx) + pstate (nth (bind pairs :states) idx) + + ;; Read frames (GPU decode, stays on GPU) + frame-a (streaming:source-read src-a t) + frame-b (streaming:source-read src-b t) + + ;; Get state values + dir (get cfg :dir) + rot-max-a (get cfg :rot-a) + rot-max-b (get cfg :rot-b) + zoom-max-a (get cfg :zoom-a) + zoom-max-b (get cfg :zoom-b) + pair-angle (get pstate :angle) + inv-a-on (> (get pstate :inv-a) 0) + inv-b-on (> (get pstate :inv-b) 0) + hue-a-on (> (get pstate :hue-a) 0) + hue-b-on (> (get pstate :hue-b) 0) + hue-a-val (get pstate :hue-a-val) + hue-b-val (get pstate :hue-b-val) + mix-ratio (get pstate :mix) + + ;; Calculate rotation angles + angle-a (* dir pair-angle rot-max-a 0.01) + angle-b (* dir pair-angle rot-max-b 0.01) + + ;; Energy-driven zoom (maps audio energy 0-1 to 1-max) + zoom-a (core:map-range e 0 1 1 zoom-max-a) + zoom-b (core:map-range e 0 1 1 zoom-max-b) + + ;; Define effect pipelines for each source + ;; These get compiled to single CUDA kernels! + ;; First resize to target resolution, then apply effects + effects-a [{:op "resize" :width 640 :height 360} + {:op "zoom" :amount zoom-a} + {:op "rotate" :angle angle-a} + {:op "hue_shift" :degrees (if hue-a-on hue-a-val 0)} + {:op "invert" :amount (if inv-a-on 1 0)}] + effects-b [{:op "resize" :width 640 :height 360} + {:op "zoom" :amount zoom-b} + {:op "rotate" :angle angle-b} + {:op "hue_shift" :degrees (if hue-b-on hue-b-val 0)} + {:op "invert" :amount (if inv-b-on 1 0)}] + + ;; Apply fused pipelines (single kernel per source!) + processed-a (streaming:fused-pipeline frame-a effects-a) + processed-b (streaming:fused-pipeline frame-b effects-b)] + + ;; Blend the two processed frames + (blending:blend-images processed-a processed-b mix-ratio))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Process active pair with fused pipeline + active-frame (process-pair-fast active) + + ;; Crossfade with zoom during transition + ;; Old pair: zooms out (1.0 -> 2.0) and fades out + ;; New pair: starts small (0.1), zooms in (-> 1.0) and fades in + result (if fading + (let [next-frame (process-pair-fast next-idx) + ;; Active zooms out as it fades + active-zoom (+ 1.0 fade-amt) + active-zoomed (streaming:fused-pipeline active-frame + [{:op "zoom" :amount active-zoom}]) + ;; Next starts small and zooms in + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (streaming:fused-pipeline next-frame + [{:op "zoom" :amount next-zoom}])] + (blending:blend-images active-zoomed next-zoomed fade-amt)) + active-frame) + + ;; Final effects pipeline (fused!) + spin-angle (bind spin :angle) + ;; Ripple params - all randomized per ripple trigger + rip-gate (bind ripple-state :gate) + rip-amp-mult (bind ripple-state :amp-mult) + rip-amp (* rip-gate rip-amp-mult (core:map-range e 0 1 50 200)) + rip-cx (bind ripple-state :cx) + rip-cy (bind ripple-state :cy) + rip-freq (bind ripple-state :freq) + rip-decay (bind ripple-state :decay) + + ;; Fused final effects + final-effects [{:op "rotate" :angle spin-angle} + {:op "ripple" :amplitude rip-amp :frequency rip-freq :decay rip-decay + :phase (* now 5) :center_x rip-cx :center_y rip-cy}]] + + ;; Apply final fused pipeline + (streaming:fused-pipeline result final-effects + :rotate_angle spin-angle + :ripple_phase (* now 5) + :ripple_amplitude rip-amp)))) diff --git a/recipes/woods-recipe-optimized.sexp b/recipes/woods-recipe-optimized.sexp new file mode 100644 index 0000000..bec96b8 --- /dev/null +++ b/recipes/woods-recipe-optimized.sexp @@ -0,0 +1,211 @@ +;; Woods Recipe - OPTIMIZED VERSION +;; +;; Uses fused-pipeline for GPU acceleration when available, +;; falls back to individual primitives on CPU. +;; +;; Key optimizations: +;; 1. Uses streaming_gpu primitives with fast CUDA kernels +;; 2. Uses fused-pipeline to batch effects into single kernel passes +;; 3. GPU persistence - frames stay on GPU throughout pipeline + +(stream "woods-recipe-optimized" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives (includes proper asset resolution) + ;; Auto-selects GPU versions when available, falls back to CPU + (include :name "tpl-standard-primitives") + + ;; === SOURCES (using streaming: which has proper asset resolution) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio + (def music (streaming:make-audio-analyzer "woods-audio")) + (audio-playback "woods-audio") + + ;; === SCANS === + + ;; Cycle state + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Spin scan + (scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :speed 2} + :step (let [new-dir (if (< (core:rand) 0.05) (* dir -1) dir) + new-speed (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) speed)] + (dict :angle (+ angle (* new-dir new-speed)) + :dir new-dir + :speed new-speed))) + + ;; Ripple scan + (scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 960 :cy 540} + :step (let [new-gate (if (< (core:rand) 0.15) (+ 3 (core:rand-int 0 5)) (core:max 0 (- gate 1))) + new-cx (if (> new-gate gate) (+ 200 (core:rand-int 0 1520)) cx) + new-cy (if (> new-gate gate) (+ 200 (core:rand-int 0 680)) cy)] + (dict :gate new-gate :cx new-cx :cy new-cy))) + + ;; Pair states + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === OPTIMIZED PROCESS-PAIR MACRO === + ;; Uses fused-pipeline to batch rotate+hue+invert into single kernel + (defmacro process-pair-fast (idx) + (let [;; Get sources for this pair (with safe modulo indexing) + num-sources (len sources) + src-a (nth sources (mod (* idx 2) num-sources)) + src-b (nth sources (mod (+ (* idx 2) 1) num-sources)) + cfg (nth pair-configs idx) + pstate (nth (bind pairs :states) idx) + + ;; Read frames (GPU decode, stays on GPU) + frame-a (streaming:source-read src-a t) + frame-b (streaming:source-read src-b t) + + ;; Get state values + dir (get cfg :dir) + rot-max-a (get cfg :rot-a) + rot-max-b (get cfg :rot-b) + zoom-max-a (get cfg :zoom-a) + zoom-max-b (get cfg :zoom-b) + pair-angle (get pstate :angle) + inv-a-on (> (get pstate :inv-a) 0) + inv-b-on (> (get pstate :inv-b) 0) + hue-a-on (> (get pstate :hue-a) 0) + hue-b-on (> (get pstate :hue-b) 0) + hue-a-val (get pstate :hue-a-val) + hue-b-val (get pstate :hue-b-val) + mix-ratio (get pstate :mix) + + ;; Calculate rotation angles + angle-a (* dir pair-angle rot-max-a 0.01) + angle-b (* dir pair-angle rot-max-b 0.01) + + ;; Energy-driven zoom (maps audio energy 0-1 to 1-max) + zoom-a (core:map-range e 0 1 1 zoom-max-a) + zoom-b (core:map-range e 0 1 1 zoom-max-b) + + ;; Define effect pipelines for each source + ;; These get compiled to single CUDA kernels! + effects-a [{:op "zoom" :amount zoom-a} + {:op "rotate" :angle angle-a} + {:op "hue_shift" :degrees (if hue-a-on hue-a-val 0)} + {:op "invert" :amount (if inv-a-on 1 0)}] + effects-b [{:op "zoom" :amount zoom-b} + {:op "rotate" :angle angle-b} + {:op "hue_shift" :degrees (if hue-b-on hue-b-val 0)} + {:op "invert" :amount (if inv-b-on 1 0)}] + + ;; Apply fused pipelines (single kernel per source!) + processed-a (streaming:fused-pipeline frame-a effects-a) + processed-b (streaming:fused-pipeline frame-b effects-b)] + + ;; Blend the two processed frames + (blending:blend-images processed-a processed-b mix-ratio))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Process active pair with fused pipeline + active-frame (process-pair-fast active) + + ;; Crossfade with zoom during transition + ;; Old pair: zooms out (1.0 -> 2.0) and fades out + ;; New pair: starts small (0.1), zooms in (-> 1.0) and fades in + result (if fading + (let [next-frame (process-pair-fast next-idx) + ;; Active zooms out as it fades + active-zoom (+ 1.0 fade-amt) + active-zoomed (streaming:fused-pipeline active-frame + [{:op "zoom" :amount active-zoom}]) + ;; Next starts small and zooms in + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (streaming:fused-pipeline next-frame + [{:op "zoom" :amount next-zoom}])] + (blending:blend-images active-zoomed next-zoomed fade-amt)) + active-frame) + + ;; Final effects pipeline (fused!) + spin-angle (bind spin :angle) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50)) + rip-cx (bind ripple-state :cx) + rip-cy (bind ripple-state :cy) + + ;; Fused final effects + final-effects [{:op "rotate" :angle spin-angle} + {:op "ripple" :amplitude rip-amp :frequency 8 :decay 2 + :phase (* now 5) :center_x rip-cx :center_y rip-cy}]] + + ;; Apply final fused pipeline + (streaming:fused-pipeline result final-effects + :rotate_angle spin-angle + :ripple_phase (* now 5) + :ripple_amplitude rip-amp)))) diff --git a/recipes/woods-recipe.sexp b/recipes/woods-recipe.sexp new file mode 100644 index 0000000..4c5f4ec --- /dev/null +++ b/recipes/woods-recipe.sexp @@ -0,0 +1,134 @@ +;; Woods Recipe - Using friendly names for all assets +;; +;; Requires uploaded: +;; - Media: woods-1 through woods-8 (videos), woods-audio (audio) +;; - Effects: fx-rotate, fx-zoom, fx-blend, fx-ripple, fx-invert, fx-hue-shift +;; - Templates: tpl-standard-primitives, tpl-standard-effects, tpl-process-pair, +;; tpl-crossfade-zoom, tpl-scan-spin, tpl-scan-ripple + +(stream "woods-recipe" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives and effects via friendly names + (include :name "tpl-standard-primitives") + (include :name "tpl-standard-effects") + + ;; Load reusable templates + (include :name "tpl-process-pair") + (include :name "tpl-crossfade-zoom") + + ;; === SOURCES AS ARRAY (using friendly names) === + (def sources [ + (streaming:make-video-source "woods-1" 30) + (streaming:make-video-source "woods-2" 30) + (streaming:make-video-source "woods-3" 30) + (streaming:make-video-source "woods-4" 30) + (streaming:make-video-source "woods-5" 30) + (streaming:make-video-source "woods-6" 30) + (streaming:make-video-source "woods-7" 30) + (streaming:make-video-source "woods-8" 30) + ]) + + ;; Per-pair config: [rot-dir, rot-a-max, rot-b-max, zoom-a-max, zoom-b-max] + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + ]) + + ;; Audio analyzer (using friendly name) + (def music (streaming:make-audio-analyzer "woods-audio")) + + ;; Audio playback (friendly name resolved by streaming primitives) + (audio-playback "woods-audio") + + ;; === GLOBAL SCANS === + + ;; Cycle state: which source is active + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Reusable scans from templates + (include :name "tpl-scan-spin") + (include :name "tpl-scan-ripple") + + ;; === PER-PAIR STATE === + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Get pair states array + pair-states (bind pairs :states) + + ;; Process active pair using macro from template + active-frame (process-pair active) + + ;; Crossfade with zoom during transition + result (if fading + (crossfade-zoom active-frame (process-pair next-idx) fade-amt) + active-frame) + + ;; Final: global spin + ripple + spun (rotate result :angle (bind spin :angle)) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50))] + + (ripple spun + :amplitude rip-amp + :center_x (bind ripple-state :cx) + :center_y (bind ripple-state :cy) + :frequency 8 + :decay 2 + :speed 5)))) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..b7e7438 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,16 @@ +# Development dependencies +-r requirements.txt + +# Type checking +mypy>=1.8.0 +types-requests>=2.31.0 +types-PyYAML>=6.0.0 +typing_extensions>=4.9.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 + +# Linting +ruff>=0.2.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..deab545 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +celery[redis]>=5.3.0 +redis>=5.0.0 +requests>=2.31.0 +httpx>=0.27.0 +itsdangerous>=2.0 +cryptography>=41.0 +fastapi>=0.109.0 +uvicorn>=0.27.0 +python-multipart>=0.0.6 +PyYAML>=6.0 +asyncpg>=0.29.0 +markdown>=3.5.0 +# Common effect dependencies (used by uploaded effects) +numpy>=1.24.0 +opencv-python-headless>=4.8.0 +# Core artdag from GitHub (tracks main branch) +git+https://github.com/gilesbradshaw/art-dag.git@main +# Shared components (tracks master branch) +git+https://git.rose-ash.com/art-dag/common.git@master +psycopg2-binary +nest_asyncio diff --git a/scripts/cloud-init-gpu.sh b/scripts/cloud-init-gpu.sh new file mode 100644 index 0000000..fe8cc27 --- /dev/null +++ b/scripts/cloud-init-gpu.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Cloud-init startup script for GPU droplet (RTX 6000 Ada, etc.) +# Paste this into DigitalOcean "User data" field when creating droplet + +set -e +export DEBIAN_FRONTEND=noninteractive +exec > /var/log/artdag-setup.log 2>&1 + +echo "=== ArtDAG GPU Setup Started $(date) ===" + +# Update system (non-interactive, keep existing configs) +apt-get update +apt-get -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" upgrade + +# Install essentials +apt-get install -y \ + python3 python3-venv python3-pip \ + git curl wget \ + ffmpeg \ + vulkan-tools \ + build-essential + +# Create venv +VENV_DIR="/opt/artdag-gpu" +python3 -m venv "$VENV_DIR" +source "$VENV_DIR/bin/activate" + +# Install Python packages +pip install --upgrade pip +pip install \ + numpy \ + opencv-python-headless \ + wgpu \ + httpx \ + pyyaml \ + celery[redis] \ + fastapi \ + uvicorn \ + asyncpg + +# Create code directory +mkdir -p "$VENV_DIR/celery/sexp_effects/effects" +mkdir -p "$VENV_DIR/celery/sexp_effects/primitive_libs" +mkdir -p "$VENV_DIR/celery/streaming" + +# Add SSH key for easier access (optional - add your key here) +# echo "ssh-ed25519 AAAA... your-key" >> /root/.ssh/authorized_keys + +# Test GPU +echo "=== GPU Info ===" +nvidia-smi || echo "nvidia-smi not available yet" + +echo "=== NVENC Check ===" +ffmpeg -encoders 2>/dev/null | grep -E "nvenc|cuda" || echo "NVENC not detected" + +echo "=== wgpu Check ===" +"$VENV_DIR/bin/python3" -c " +import wgpu +try: + adapter = wgpu.gpu.request_adapter_sync(power_preference='high-performance') + print(f'GPU: {adapter.info}') +except Exception as e: + print(f'wgpu error: {e}') +" || echo "wgpu test failed" + +# Add environment setup +cat >> /etc/profile.d/artdag-gpu.sh << 'ENVEOF' +export WGPU_BACKEND_TYPE=Vulkan +export PATH="/opt/artdag-gpu/bin:$PATH" +ENVEOF + +# Mark setup complete +touch /opt/artdag-gpu/.setup-complete +echo "=== Setup Complete $(date) ===" +echo "Venv: /opt/artdag-gpu" +echo "Activate: source /opt/artdag-gpu/bin/activate" +echo "Vulkan: export WGPU_BACKEND_TYPE=Vulkan" diff --git a/scripts/deploy-to-gpu.sh b/scripts/deploy-to-gpu.sh new file mode 100755 index 0000000..e41c802 --- /dev/null +++ b/scripts/deploy-to-gpu.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Deploy art-dag GPU code to a remote droplet +# Usage: ./deploy-to-gpu.sh + +set -e + +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "Example: $0 159.223.7.100" + exit 1 +fi + +DROPLET_IP="$1" +REMOTE_DIR="/opt/artdag-gpu/celery" +LOCAL_DIR="$(dirname "$0")/.." + +echo "=== Deploying to $DROPLET_IP ===" + +# Create remote directory +echo "[1/4] Creating remote directory..." +ssh "root@$DROPLET_IP" "mkdir -p $REMOTE_DIR/sexp_effects $REMOTE_DIR/streaming $REMOTE_DIR/scripts" + +# Copy core files +echo "[2/4] Copying core files..." +scp "$LOCAL_DIR/sexp_effects/wgsl_compiler.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/parser.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/interpreter.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/sexp_effects/__init__.py" "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/" +scp "$LOCAL_DIR/streaming/backends.py" "root@$DROPLET_IP:$REMOTE_DIR/streaming/" + +# Copy effects +echo "[3/4] Copying effects..." +ssh "root@$DROPLET_IP" "mkdir -p $REMOTE_DIR/sexp_effects/effects $REMOTE_DIR/sexp_effects/primitive_libs" +scp -r "$LOCAL_DIR/sexp_effects/effects/"*.sexp "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/effects/" 2>/dev/null || true +scp -r "$LOCAL_DIR/sexp_effects/primitive_libs/"*.py "root@$DROPLET_IP:$REMOTE_DIR/sexp_effects/primitive_libs/" 2>/dev/null || true + +# Test +echo "[4/4] Testing deployment..." +ssh "root@$DROPLET_IP" "cd $REMOTE_DIR && /opt/artdag-gpu/bin/python3 -c ' +import sys +sys.path.insert(0, \".\") +from sexp_effects.wgsl_compiler import compile_effect_file +result = compile_effect_file(\"sexp_effects/effects/invert.sexp\") +print(f\"Compiled effect: {result.name}\") +print(\"Deployment OK\") +'" || echo "Test failed - may need to run setup script first" + +echo "" +echo "=== Deployment complete ===" +echo "SSH: ssh root@$DROPLET_IP" +echo "Test: ssh root@$DROPLET_IP 'cd $REMOTE_DIR && /opt/artdag-gpu/bin/python3 -c \"from streaming.backends import get_backend; b=get_backend(\\\"wgpu\\\"); print(b)\"'" diff --git a/scripts/gpu-dev-deploy.sh b/scripts/gpu-dev-deploy.sh new file mode 100755 index 0000000..f1be595 --- /dev/null +++ b/scripts/gpu-dev-deploy.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Quick deploy to GPU node with hot reload +# Usage: ./scripts/gpu-dev-deploy.sh + +set -e + +GPU_HOST="${GPU_HOST:-root@138.197.163.123}" +REMOTE_DIR="/root/art-dag/celery" + +echo "=== GPU Dev Deploy ===" +echo "Syncing code to $GPU_HOST..." + +# Sync code (excluding cache, git, __pycache__) +rsync -avz --delete \ + --exclude '.git' \ + --exclude '__pycache__' \ + --exclude '*.pyc' \ + --exclude '.pytest_cache' \ + --exclude 'node_modules' \ + --exclude '.env' \ + ./ "$GPU_HOST:$REMOTE_DIR/" + +echo "Restarting GPU worker..." +ssh "$GPU_HOST" "docker kill \$(docker ps -q -f name=l1-gpu-worker) 2>/dev/null || true" + +echo "Waiting for new container..." +sleep 10 + +# Show new container logs +ssh "$GPU_HOST" "docker logs --tail 30 \$(docker ps -q -f name=l1-gpu-worker)" + +echo "" +echo "=== Deploy Complete ===" +echo "Use 'ssh $GPU_HOST docker logs -f \$(docker ps -q -f name=l1-gpu-worker)' to follow logs" diff --git a/scripts/setup-gpu-droplet.sh b/scripts/setup-gpu-droplet.sh new file mode 100755 index 0000000..e731ef8 --- /dev/null +++ b/scripts/setup-gpu-droplet.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Setup script for GPU droplet with NVENC support +# Run as root on a fresh Ubuntu droplet with NVIDIA GPU + +set -e + +echo "=== ArtDAG GPU Droplet Setup ===" + +# 1. System updates +echo "[1/7] Updating system..." +apt-get update +apt-get upgrade -y + +# 2. Install NVIDIA drivers (if not already installed) +echo "[2/7] Checking NVIDIA drivers..." +if ! command -v nvidia-smi &> /dev/null; then + echo "Installing NVIDIA drivers..." + apt-get install -y nvidia-driver-535 nvidia-utils-535 + echo "NVIDIA drivers installed. Reboot required." + echo "After reboot, run this script again." + exit 0 +fi + +nvidia-smi +echo "NVIDIA drivers OK" + +# 3. Install FFmpeg with NVENC support +echo "[3/7] Installing FFmpeg with NVENC..." +apt-get install -y ffmpeg + +# Verify NVENC +if ffmpeg -encoders 2>/dev/null | grep -q nvenc; then + echo "NVENC available:" + ffmpeg -encoders 2>/dev/null | grep nvenc +else + echo "WARNING: NVENC not available. GPU may not support hardware encoding." +fi + +# 4. Install Python and create venv +echo "[4/7] Setting up Python environment..." +apt-get install -y python3 python3-venv python3-pip git + +VENV_DIR="/opt/artdag-gpu" +python3 -m venv "$VENV_DIR" +source "$VENV_DIR/bin/activate" + +# 5. Install Python dependencies +echo "[5/7] Installing Python packages..." +pip install --upgrade pip +pip install \ + numpy \ + opencv-python-headless \ + wgpu \ + httpx \ + pyyaml \ + celery[redis] \ + fastapi \ + uvicorn + +# 6. Clone/update art-dag code +echo "[6/7] Setting up art-dag code..." +ARTDAG_DIR="$VENV_DIR/celery" +if [ -d "$ARTDAG_DIR" ]; then + echo "Updating existing code..." + cd "$ARTDAG_DIR" + git pull || true +else + echo "Cloning art-dag..." + git clone https://git.rose-ash.com/art-dag/celery.git "$ARTDAG_DIR" || { + echo "Git clone failed. You may need to copy code manually." + } +fi + +# 7. Test GPU compute +echo "[7/7] Testing GPU compute..." +"$VENV_DIR/bin/python3" << 'PYTEST' +import sys +try: + import wgpu + adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + print(f"GPU Adapter: {adapter.info.get('device', 'unknown')}") + device = adapter.request_device_sync() + print("wgpu device created successfully") + + # Check for NVENC via FFmpeg + import subprocess + result = subprocess.run(['ffmpeg', '-encoders'], capture_output=True, text=True) + if 'h264_nvenc' in result.stdout: + print("NVENC H.264 encoder: AVAILABLE") + else: + print("NVENC H.264 encoder: NOT AVAILABLE") + if 'hevc_nvenc' in result.stdout: + print("NVENC HEVC encoder: AVAILABLE") + else: + print("NVENC HEVC encoder: NOT AVAILABLE") + +except Exception as e: + print(f"Error: {e}") + sys.exit(1) +PYTEST + +echo "" +echo "=== Setup Complete ===" +echo "Venv: $VENV_DIR" +echo "Code: $ARTDAG_DIR" +echo "" +echo "To activate: source $VENV_DIR/bin/activate" +echo "To test: cd $ARTDAG_DIR && python -c 'from streaming.backends import get_backend; print(get_backend(\"wgpu\"))'" diff --git a/server.py b/server.py new file mode 100644 index 0000000..f7c1e1e --- /dev/null +++ b/server.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +Art DAG L1 Server + +Minimal entry point that uses the modular app factory. +All routes are defined in app/routers/. +All templates are in app/templates/. +""" + +import logging +import os + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(levelname)s %(name)s: %(message)s' +) + +# Import the app from the factory +from app import app + +if __name__ == "__main__": + import uvicorn + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", "8100")) + uvicorn.run("server:app", host=host, port=port, workers=4) diff --git a/sexp_effects/__init__.py b/sexp_effects/__init__.py new file mode 100644 index 0000000..b001c71 --- /dev/null +++ b/sexp_effects/__init__.py @@ -0,0 +1,32 @@ +""" +S-Expression Effects System + +Safe, shareable effects defined in S-expressions. +""" + +from .parser import parse, parse_file, Symbol, Keyword +from .interpreter import ( + Interpreter, + get_interpreter, + load_effect, + load_effects_dir, + run_effect, + list_effects, + make_process_frame, +) +from .primitives import PRIMITIVES + +__all__ = [ + 'parse', + 'parse_file', + 'Symbol', + 'Keyword', + 'Interpreter', + 'get_interpreter', + 'load_effect', + 'load_effects_dir', + 'run_effect', + 'list_effects', + 'make_process_frame', + 'PRIMITIVES', +] diff --git a/sexp_effects/derived.sexp b/sexp_effects/derived.sexp new file mode 100644 index 0000000..7e1aae3 --- /dev/null +++ b/sexp_effects/derived.sexp @@ -0,0 +1,206 @@ +;; Derived Operations +;; +;; These are built from true primitives using S-expressions. +;; Load with: (require "derived") + +;; ============================================================================= +;; Math Helpers (derivable from where + basic ops) +;; ============================================================================= + +;; Absolute value +(define (abs x) (where (< x 0) (- x) x)) + +;; Minimum of two values +(define (min2 a b) (where (< a b) a b)) + +;; Maximum of two values +(define (max2 a b) (where (> a b) a b)) + +;; Clamp x to range [lo, hi] +(define (clamp x lo hi) (max2 lo (min2 hi x))) + +;; Square of x +(define (sq x) (* x x)) + +;; Linear interpolation: a*(1-t) + b*t +(define (lerp a b t) (+ (* a (- 1 t)) (* b t))) + +;; Smooth interpolation between edges +(define (smoothstep edge0 edge1 x) + (let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1))) + (* t (* t (- 3 (* 2 t)))))) + +;; ============================================================================= +;; Channel Shortcuts (derivable from channel primitive) +;; ============================================================================= + +;; Extract red channel as xector +(define (red frame) (channel frame 0)) + +;; Extract green channel as xector +(define (green frame) (channel frame 1)) + +;; Extract blue channel as xector +(define (blue frame) (channel frame 2)) + +;; Convert to grayscale xector (ITU-R BT.601) +(define (gray frame) + (+ (* (red frame) 0.299) + (* (green frame) 0.587) + (* (blue frame) 0.114))) + +;; Alias for gray +(define (luminance frame) (gray frame)) + +;; ============================================================================= +;; Coordinate Generators (derivable from iota + repeat/tile) +;; ============================================================================= + +;; X coordinate for each pixel [0, width) +(define (x-coords frame) (tile (iota (width frame)) (height frame))) + +;; Y coordinate for each pixel [0, height) +(define (y-coords frame) (repeat (iota (height frame)) (width frame))) + +;; Normalized X coordinate [0, 1] +(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1)))) + +;; Normalized Y coordinate [0, 1] +(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1)))) + +;; Distance from frame center for each pixel +(define (dist-from-center frame) + (let* ((cx (/ (width frame) 2)) + (cy (/ (height frame) 2)) + (dx (- (x-coords frame) cx)) + (dy (- (y-coords frame) cy))) + (sqrt (+ (sq dx) (sq dy))))) + +;; Normalized distance from center [0, ~1] +(define (dist-norm frame) + (let ((d (dist-from-center frame))) + (/ d (max2 1 (βmax d))))) + +;; ============================================================================= +;; Cell/Grid Operations (derivable from floor + basic math) +;; ============================================================================= + +;; Cell row index for each pixel +(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size))) + +;; Cell column index for each pixel +(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size))) + +;; Number of cell rows +(define (num-rows frame cell-size) (floor (/ (height frame) cell-size))) + +;; Number of cell columns +(define (num-cols frame cell-size) (floor (/ (width frame) cell-size))) + +;; Flat cell index for each pixel +(define (cell-indices frame cell-size) + (+ (* (cell-row frame cell-size) (num-cols frame cell-size)) + (cell-col frame cell-size))) + +;; Total number of cells +(define (num-cells frame cell-size) + (* (num-rows frame cell-size) (num-cols frame cell-size))) + +;; X position within cell [0, cell-size) +(define (local-x frame cell-size) (mod (x-coords frame) cell-size)) + +;; Y position within cell [0, cell-size) +(define (local-y frame cell-size) (mod (y-coords frame) cell-size)) + +;; Normalized X within cell [0, 1] +(define (local-x-norm frame cell-size) + (/ (local-x frame cell-size) (max2 1 (- cell-size 1)))) + +;; Normalized Y within cell [0, 1] +(define (local-y-norm frame cell-size) + (/ (local-y frame cell-size) (max2 1 (- cell-size 1)))) + +;; ============================================================================= +;; Fill Operations (derivable from iota) +;; ============================================================================= + +;; Xector of n zeros +(define (zeros n) (* (iota n) 0)) + +;; Xector of n ones +(define (ones n) (+ (zeros n) 1)) + +;; Xector of n copies of val +(define (fill val n) (+ (zeros n) val)) + +;; Xector of zeros matching x's length +(define (zeros-like x) (* x 0)) + +;; Xector of ones matching x's length +(define (ones-like x) (+ (zeros-like x) 1)) + +;; ============================================================================= +;; Pooling (derivable from group-reduce) +;; ============================================================================= + +;; Pool a channel by cell index +(define (pool-channel chan cell-idx num-cells) + (group-reduce chan cell-idx num-cells "mean")) + +;; Pool red channel to cells +(define (pool-red frame cell-size) + (pool-channel (red frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool green channel to cells +(define (pool-green frame cell-size) + (pool-channel (green frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool blue channel to cells +(define (pool-blue frame cell-size) + (pool-channel (blue frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool grayscale to cells +(define (pool-gray frame cell-size) + (pool-channel (gray frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; ============================================================================= +;; Blending (derivable from math) +;; ============================================================================= + +;; Additive blend +(define (blend-add a b) (clamp (+ a b) 0 255)) + +;; Multiply blend (normalized) +(define (blend-multiply a b) (* (/ a 255) b)) + +;; Screen blend +(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 255 b)))) + +;; Overlay blend +(define (blend-overlay a b) + (where (< a 128) + (* 2 (/ (* a b) 255)) + (- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255))))) + +;; ============================================================================= +;; Simple Effects (derivable from primitives) +;; ============================================================================= + +;; Invert a channel (255 - c) +(define (invert-channel c) (- 255 c)) + +;; Binary threshold +(define (threshold-channel c thresh) (where (> c thresh) 255 0)) + +;; Reduce to n levels +(define (posterize-channel c levels) + (let ((step (/ 255 (- levels 1)))) + (* (round (/ c step)) step))) diff --git a/sexp_effects/effects/ascii_art.sexp b/sexp_effects/effects/ascii_art.sexp new file mode 100644 index 0000000..0504768 --- /dev/null +++ b/sexp_effects/effects/ascii_art.sexp @@ -0,0 +1,17 @@ +;; ASCII Art effect - converts image to ASCII characters +(require-primitives "ascii") + +(define-effect ascii_art + :params ( + (char_size :type int :default 8 :range [4 32]) + (alphabet :type string :default "standard") + (color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex") + (background_color :type string :default "black" :desc "background color name/hex") + (invert_colors :type int :default 0 :desc "swap foreground and background colors") + (contrast :type float :default 1.5 :range [1 3]) + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid frame chars colors char_size color_mode background_color invert_colors))) diff --git a/sexp_effects/effects/ascii_art_fx.sexp b/sexp_effects/effects/ascii_art_fx.sexp new file mode 100644 index 0000000..2bb14be --- /dev/null +++ b/sexp_effects/effects/ascii_art_fx.sexp @@ -0,0 +1,52 @@ +;; ASCII Art FX - converts image to ASCII characters with per-character effects +(require-primitives "ascii") + +(define-effect ascii_art_fx + :params ( + ;; Basic parameters + (char_size :type int :default 8 :range [4 32] + :desc "Size of each character cell in pixels") + (alphabet :type string :default "standard" + :desc "Character set to use") + (color_mode :type string :default "color" + :choices [color mono invert] + :desc "Color mode: color, mono, invert, or any color name/hex") + (background_color :type string :default "black" + :desc "Background color name or hex value") + (invert_colors :type int :default 0 :range [0 1] + :desc "Swap foreground and background colors (0/1)") + (contrast :type float :default 1.5 :range [1 3] + :desc "Character selection contrast") + + ;; Per-character effects + (char_jitter :type float :default 0 :range [0 20] + :desc "Position jitter amount in pixels") + (char_scale :type float :default 1.0 :range [0.5 2.0] + :desc "Character scale factor") + (char_rotation :type float :default 0 :range [0 180] + :desc "Rotation amount in degrees") + (char_hue_shift :type float :default 0 :range [0 360] + :desc "Hue shift in degrees") + + ;; Modulation sources + (jitter_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives jitter modulation") + (scale_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives scale modulation") + (rotation_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives rotation modulation") + (hue_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives hue shift modulation") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid-fx frame chars colors luminances char_size + color_mode background_color invert_colors + char_jitter char_scale char_rotation char_hue_shift + jitter_source scale_source rotation_source hue_source))) diff --git a/sexp_effects/effects/ascii_fx_zone.sexp b/sexp_effects/effects/ascii_fx_zone.sexp new file mode 100644 index 0000000..69e5340 --- /dev/null +++ b/sexp_effects/effects/ascii_fx_zone.sexp @@ -0,0 +1,102 @@ +;; Composable ASCII Art with Per-Zone Expression-Driven Effects +;; Requires ascii primitive library for the ascii-fx-zone primitive + +(require-primitives "ascii") + +;; Two modes of operation: +;; +;; 1. EXPRESSION MODE: Use zone-* variables in expression parameters +;; Zone variables available: +;; zone-row, zone-col: Grid position (integers) +;; zone-row-norm, zone-col-norm: Normalized position (0-1) +;; zone-lum: Cell luminance (0-1) +;; zone-sat: Cell saturation (0-1) +;; zone-hue: Cell hue (0-360) +;; zone-r, zone-g, zone-b: RGB components (0-1) +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 80 +;; :char_hue (* zone-lum 180) +;; :char_rotation (* zone-col-norm 30)) +;; +;; 2. CELL EFFECT MODE: Pass a lambda to apply arbitrary effects per-cell +;; The lambda receives (cell-image zone-dict) and returns modified cell. +;; Zone dict contains: row, col, row-norm, col-norm, lum, sat, hue, r, g, b, +;; char, color, cell_size, plus any bound analysis values. +;; +;; Any loaded sexp effect can be called on cells - each cell is just a small frame: +;; (blur cell radius) - Gaussian blur +;; (rotate cell angle) - Rotate by angle degrees +;; (brightness cell factor) - Adjust brightness +;; (contrast cell factor) - Adjust contrast +;; (saturation cell factor) - Adjust saturation +;; (hue_shift cell degrees) - Shift hue +;; (rgb_split cell offset_x offset_y) - RGB channel split +;; (invert cell) - Invert colors +;; (pixelate cell block_size) - Pixelate +;; (wave cell amplitude freq) - Wave distortion +;; ... and any other loaded effect +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 60 +;; :cell_effect (lambda [cell zone] +;; (blur (rotate cell (* (get zone "energy") 45)) +;; (if (> (get zone "lum") 0.5) 3 0)))) + +(define-effect ascii_fx_zone + :params ( + (cols :type int :default 80 :range [20 200] + :desc "Number of character columns") + (char_size :type int :default nil :range [4 32] + :desc "Character cell size in pixels (overrides cols if set)") + (alphabet :type string :default "standard" + :desc "Character set: standard, blocks, simple, digits, or custom string") + (color_mode :type string :default "color" + :desc "Color mode: color, mono, invert, or any color name/hex") + (background :type string :default "black" + :desc "Background color name or hex value") + (contrast :type float :default 1.5 :range [0.5 3.0] + :desc "Contrast for character selection") + (char_hue :type any :default nil + :desc "Hue shift expression (evaluated per-zone with zone-* vars)") + (char_saturation :type any :default nil + :desc "Saturation multiplier expression (1.0 = unchanged)") + (char_brightness :type any :default nil + :desc "Brightness multiplier expression (1.0 = unchanged)") + (char_scale :type any :default nil + :desc "Character scale expression (1.0 = normal size)") + (char_rotation :type any :default nil + :desc "Character rotation expression (degrees)") + (char_jitter :type any :default nil + :desc "Position jitter expression (pixels)") + (cell_effect :type any :default nil + :desc "Lambda (cell zone) -> cell for arbitrary per-cell effects") + ;; Convenience params for staged recipes (avoids compile-time expression issues) + (energy :type float :default nil + :desc "Energy multiplier (0-1) from audio analysis bind") + (rotation_scale :type float :default 0 + :desc "Max rotation at top-right when energy=1 (degrees)") + ) + ;; The ascii-fx-zone special form handles expression params + ;; If energy + rotation_scale provided, it builds: energy * scale * position_factor + ;; where position_factor = 0 at bottom-left, 3 at top-right + ;; If cell_effect provided, each character is rendered to a cell image, + ;; passed to the lambda, and the result composited back + (ascii-fx-zone frame + :cols cols + :char_size char_size + :alphabet alphabet + :color_mode color_mode + :background background + :contrast contrast + :char_hue char_hue + :char_saturation char_saturation + :char_brightness char_brightness + :char_scale char_scale + :char_rotation char_rotation + :char_jitter char_jitter + :cell_effect cell_effect + :energy energy + :rotation_scale rotation_scale)) diff --git a/sexp_effects/effects/ascii_zones.sexp b/sexp_effects/effects/ascii_zones.sexp new file mode 100644 index 0000000..6bc441c --- /dev/null +++ b/sexp_effects/effects/ascii_zones.sexp @@ -0,0 +1,30 @@ +;; ASCII Zones effect - different character sets for different brightness zones +;; Dark areas use simple chars, mid uses standard, bright uses blocks +(require-primitives "ascii") + +(define-effect ascii_zones + :params ( + (char_size :type int :default 8 :range [4 32]) + (dark_threshold :type int :default 80 :range [0 128]) + (bright_threshold :type int :default 180 :range [128 255]) + (color_mode :type string :default "color") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + ;; Start with simple chars as base + (base-chars (luminance-to-chars luminances "simple" 1.2)) + ;; Map each cell to appropriate alphabet based on brightness zone + (zoned-chars (map-char-grid base-chars luminances + (lambda (r c ch lum) + (cond + ;; Bright zones: use block characters + ((> lum bright_threshold) + (alphabet-char "blocks" (floor (/ (- lum bright_threshold) 15)))) + ;; Dark zones: use simple sparse chars + ((< lum dark_threshold) + (alphabet-char " .-" (floor (/ lum 30)))) + ;; Mid zones: use standard ASCII + (else + (alphabet-char "standard" (floor (/ lum 4))))))))) + (render-char-grid frame zoned-chars colors char_size color_mode (list 0 0 0)))) diff --git a/sexp_effects/effects/blend.sexp b/sexp_effects/effects/blend.sexp new file mode 100644 index 0000000..bf7fefd --- /dev/null +++ b/sexp_effects/effects/blend.sexp @@ -0,0 +1,31 @@ +;; Blend effect - combines two video frames +;; Streaming-compatible: frame is background, overlay is second frame +;; Usage: (blend background overlay :opacity 0.5 :mode "alpha") +;; +;; Params: +;; mode - blend mode (add, multiply, screen, overlay, difference, lighten, darken, alpha) +;; opacity - blend amount (0-1) + +(require-primitives "image" "blending" "core") + +(define-effect blend + :params ( + (overlay :type frame :default nil) + (mode :type string :default "alpha") + (opacity :type float :default 0.5) + ) + (if (core:is-nil overlay) + frame + (let [a frame + b overlay + a-h (image:height a) + a-w (image:width a) + b-h (image:height b) + b-w (image:width b) + ;; Resize b to match a if needed + b-sized (if (and (= a-w b-w) (= a-h b-h)) + b + (image:resize b a-w a-h "linear"))] + (if (= mode "alpha") + (blending:blend-images a b-sized opacity) + (blending:blend-images a (blending:blend-mode a b-sized mode) opacity))))) diff --git a/sexp_effects/effects/blend_multi.sexp b/sexp_effects/effects/blend_multi.sexp new file mode 100644 index 0000000..1ee160f --- /dev/null +++ b/sexp_effects/effects/blend_multi.sexp @@ -0,0 +1,58 @@ +;; N-way weighted blend effect +;; Streaming-compatible: pass inputs as a list of frames +;; Usage: (blend_multi :inputs [(read a) (read b) (read c)] :weights [0.3 0.4 0.3]) +;; +;; Parameters: +;; inputs - list of N frames to blend +;; weights - list of N floats, one per input (resolved per-frame) +;; mode - blend mode applied when folding each frame in: +;; "alpha" — pure weighted average (default) +;; "multiply" — darken by multiplication +;; "screen" — lighten (inverse multiply) +;; "overlay" — contrast-boosting midtone blend +;; "soft-light" — gentle dodge/burn +;; "hard-light" — strong dodge/burn +;; "color-dodge" — brightens towards white +;; "color-burn" — darkens towards black +;; "difference" — absolute pixel difference +;; "exclusion" — softer difference +;; "add" — additive (clamped) +;; "subtract" — subtractive (clamped) +;; "darken" — per-pixel minimum +;; "lighten" — per-pixel maximum +;; resize_mode - how to match frame dimensions (fit, crop, stretch) +;; +;; Uses a left-fold over inputs[1..N-1]. At each step the running +;; opacity is: w[i] / (w[0] + w[1] + ... + w[i]) +;; which produces the correct normalised weighted result. + +(require-primitives "image" "blending") + +(define-effect blend_multi + :params ( + (inputs :type list :default []) + (weights :type list :default []) + (mode :type string :default "alpha") + (resize_mode :type string :default "fit") + ) + (let [n (len inputs) + ;; Target dimensions from first frame + target-w (image:width (nth inputs 0)) + target-h (image:height (nth inputs 0)) + ;; Fold over indices 1..n-1 + ;; Accumulator is (list blended-frame running-weight-sum) + seed (list (nth inputs 0) (nth weights 0)) + result (reduce (range 1 n) seed + (lambda (pair i) + (let [acc (nth pair 0) + running (nth pair 1) + w (nth weights i) + new-running (+ running w) + opacity (/ w (max new-running 0.001)) + f (image:resize (nth inputs i) target-w target-h "linear") + ;; Apply blend mode then mix with opacity + blended (if (= mode "alpha") + (blending:blend-images acc f opacity) + (blending:blend-images acc (blending:blend-mode acc f mode) opacity))] + (list blended new-running))))] + (nth result 0))) diff --git a/sexp_effects/effects/bloom.sexp b/sexp_effects/effects/bloom.sexp new file mode 100644 index 0000000..3524d01 --- /dev/null +++ b/sexp_effects/effects/bloom.sexp @@ -0,0 +1,16 @@ +;; Bloom effect - glow on bright areas +(require-primitives "image" "blending") + +(define-effect bloom + :params ( + (intensity :type float :default 0.5 :range [0 2]) + (threshold :type int :default 200 :range [0 255]) + (radius :type int :default 15 :range [1 50]) + ) + (let* ((bright (map-pixels frame + (lambda (x y c) + (if (> (luminance c) threshold) + c + (rgb 0 0 0))))) + (blurred (image:blur bright radius))) + (blending:blend-mode frame blurred "add"))) diff --git a/sexp_effects/effects/blur.sexp b/sexp_effects/effects/blur.sexp new file mode 100644 index 0000000..b71a55a --- /dev/null +++ b/sexp_effects/effects/blur.sexp @@ -0,0 +1,8 @@ +;; Blur effect - gaussian blur +(require-primitives "image") + +(define-effect blur + :params ( + (radius :type int :default 5 :range [1 50]) + ) + (image:blur frame (max 1 radius))) diff --git a/sexp_effects/effects/brightness.sexp b/sexp_effects/effects/brightness.sexp new file mode 100644 index 0000000..4af53a7 --- /dev/null +++ b/sexp_effects/effects/brightness.sexp @@ -0,0 +1,9 @@ +;; Brightness effect - adjusts overall brightness +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect brightness + :params ( + (amount :type int :default 0 :range [-255 255]) + ) + (color_ops:adjust-brightness frame amount)) diff --git a/sexp_effects/effects/cell_pattern.sexp b/sexp_effects/effects/cell_pattern.sexp new file mode 100644 index 0000000..bc503bb --- /dev/null +++ b/sexp_effects/effects/cell_pattern.sexp @@ -0,0 +1,65 @@ +;; Cell Pattern effect - custom patterns within cells +;; +;; Demonstrates building arbitrary per-cell visuals from primitives. +;; Uses local coordinates within cells to draw patterns scaled by luminance. + +(require-primitives "xector") + +(define-effect cell_pattern + :params ( + (cell-size :type int :default 16 :range [8 48] :desc "Cell size") + (pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring") + ) + (let* ( + ;; Pool to get cell colors + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + (cell-lum (α/ (nth pooled 3) 255)) + + ;; Cell indices for each pixel + (cell-idx (cell-indices frame cell-size)) + + ;; Look up cell values for each pixel + (pix-r (gather cell-r cell-idx)) + (pix-g (gather cell-g cell-idx)) + (pix-b (gather cell-b cell-idx)) + (pix-lum (gather cell-lum cell-idx)) + + ;; Local position within cell [0, 1] + (lx (local-x-norm frame cell-size)) + (ly (local-y-norm frame cell-size)) + + ;; Pattern mask based on pattern type + (mask + (cond + ;; Diagonal lines - thickness based on luminance + ((= pattern "diagonal") + (let* ((diag (αmod (α+ lx ly) 0.25)) + (thickness (α* pix-lum 0.125))) + (α< diag thickness))) + + ;; Cross pattern + ((= pattern "cross") + (let* ((cx (αabs (α- lx 0.5))) + (cy (αabs (α- ly 0.5))) + (thickness (α* pix-lum 0.25))) + (αor (α< cx thickness) (α< cy thickness)))) + + ;; Ring pattern + ((= pattern "ring") + (let* ((dx (α- lx 0.5)) + (dy (α- ly 0.5)) + (dist (αsqrt (α+ (α² dx) (α² dy)))) + (target (α* pix-lum 0.4)) + (thickness 0.05)) + (α< (αabs (α- dist target)) thickness))) + + ;; Default: solid + (else (α> pix-lum 0))))) + + ;; Apply mask: show cell color where mask is true, black elsewhere + (rgb (where mask pix-r 0) + (where mask pix-g 0) + (where mask pix-b 0)))) diff --git a/sexp_effects/effects/color-adjust.sexp b/sexp_effects/effects/color-adjust.sexp new file mode 100644 index 0000000..5318bdd --- /dev/null +++ b/sexp_effects/effects/color-adjust.sexp @@ -0,0 +1,13 @@ +;; Color adjustment effect - replaces TRANSFORM node +(require-primitives "color_ops") + +(define-effect color-adjust + :params ( + (brightness :type int :default 0 :range [-255 255] :desc "Brightness adjustment") + (contrast :type float :default 1 :range [0 3] :desc "Contrast multiplier") + (saturation :type float :default 1 :range [0 2] :desc "Saturation multiplier") + ) + (-> frame + (color_ops:adjust-brightness brightness) + (color_ops:adjust-contrast contrast) + (color_ops:adjust-saturation saturation))) diff --git a/sexp_effects/effects/color_cycle.sexp b/sexp_effects/effects/color_cycle.sexp new file mode 100644 index 0000000..e08dbb6 --- /dev/null +++ b/sexp_effects/effects/color_cycle.sexp @@ -0,0 +1,13 @@ +;; Color Cycle effect - animated hue rotation +(require-primitives "color_ops") + +(define-effect color_cycle + :params ( + (speed :type int :default 1 :range [0 10]) + ) + (let ((shift (* t speed 360))) + (map-pixels frame + (lambda (x y c) + (let* ((hsv (rgb->hsv c)) + (new-h (mod (+ (first hsv) shift) 360))) + (hsv->rgb (list new-h (nth hsv 1) (nth hsv 2)))))))) diff --git a/sexp_effects/effects/contrast.sexp b/sexp_effects/effects/contrast.sexp new file mode 100644 index 0000000..660661d --- /dev/null +++ b/sexp_effects/effects/contrast.sexp @@ -0,0 +1,9 @@ +;; Contrast effect - adjusts image contrast +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect contrast + :params ( + (amount :type int :default 1 :range [0.5 3]) + ) + (color_ops:adjust-contrast frame amount)) diff --git a/sexp_effects/effects/crt.sexp b/sexp_effects/effects/crt.sexp new file mode 100644 index 0000000..097eaf9 --- /dev/null +++ b/sexp_effects/effects/crt.sexp @@ -0,0 +1,30 @@ +;; CRT effect - old monitor simulation +(require-primitives "image") + +(define-effect crt + :params ( + (line_spacing :type int :default 2 :range [1 10]) + (line_opacity :type float :default 0.3 :range [0 1]) + (vignette_amount :type float :default 0.2) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (sqrt (+ (* cx cx) (* cy cy))))) + (map-pixels frame + (lambda (x y c) + (let* (;; Scanline darkening + (scanline-factor (if (= 0 (mod y line_spacing)) + (- 1 line_opacity) + 1)) + ;; Vignette + (dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (vignette-factor (- 1 (* (/ dist max-dist) vignette_amount))) + ;; Combined + (factor (* scanline-factor vignette-factor))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/sexp_effects/effects/datamosh.sexp b/sexp_effects/effects/datamosh.sexp new file mode 100644 index 0000000..60cec66 --- /dev/null +++ b/sexp_effects/effects/datamosh.sexp @@ -0,0 +1,14 @@ +;; Datamosh effect - glitch block corruption + +(define-effect datamosh + :params ( + (block_size :type int :default 32 :range [8 128]) + (corruption :type float :default 0.3 :range [0 1]) + (max_offset :type int :default 50 :range [0 200]) + (color_corrupt :type bool :default true) + ) + ;; Get previous frame from state, or use current frame if none + (let ((prev (state-get "prev_frame" frame))) + (begin + (state-set "prev_frame" (copy frame)) + (datamosh frame prev block_size corruption max_offset color_corrupt)))) diff --git a/sexp_effects/effects/echo.sexp b/sexp_effects/effects/echo.sexp new file mode 100644 index 0000000..599a1d6 --- /dev/null +++ b/sexp_effects/effects/echo.sexp @@ -0,0 +1,19 @@ +;; Echo effect - motion trails using frame buffer +(require-primitives "blending") + +(define-effect echo + :params ( + (num_echoes :type int :default 4 :range [1 20]) + (decay :type float :default 0.5 :range [0 1]) + ) + (let* ((buffer (state-get "buffer" (list))) + (new-buffer (take (cons frame buffer) (+ num_echoes 1)))) + (begin + (state-set "buffer" new-buffer) + ;; Blend frames with decay + (if (< (length new-buffer) 2) + frame + (let ((result (copy frame))) + ;; Simple blend of first two frames for now + ;; Full version would fold over all frames + (blending:blend-images frame (nth new-buffer 1) (* decay 0.5))))))) diff --git a/sexp_effects/effects/edge_detect.sexp b/sexp_effects/effects/edge_detect.sexp new file mode 100644 index 0000000..170befb --- /dev/null +++ b/sexp_effects/effects/edge_detect.sexp @@ -0,0 +1,9 @@ +;; Edge detection effect - highlights edges +(require-primitives "image") + +(define-effect edge_detect + :params ( + (low :type int :default 50 :range [10 100]) + (high :type int :default 150 :range [50 300]) + ) + (image:edge-detect frame low high)) diff --git a/sexp_effects/effects/emboss.sexp b/sexp_effects/effects/emboss.sexp new file mode 100644 index 0000000..1eac3ce --- /dev/null +++ b/sexp_effects/effects/emboss.sexp @@ -0,0 +1,13 @@ +;; Emboss effect - creates raised/3D appearance +(require-primitives "blending") + +(define-effect emboss + :params ( + (strength :type int :default 1 :range [0.5 3]) + (blend :type float :default 0.3 :range [0 1]) + ) + (let* ((kernel (list (list (- strength) (- strength) 0) + (list (- strength) 1 strength) + (list 0 strength strength))) + (embossed (convolve frame kernel))) + (blending:blend-images embossed frame blend))) diff --git a/sexp_effects/effects/film_grain.sexp b/sexp_effects/effects/film_grain.sexp new file mode 100644 index 0000000..29bdd75 --- /dev/null +++ b/sexp_effects/effects/film_grain.sexp @@ -0,0 +1,19 @@ +;; Film Grain effect - adds film grain texture +(require-primitives "core") + +(define-effect film_grain + :params ( + (intensity :type float :default 0.2 :range [0 1]) + (colored :type bool :default false) + ) + (let ((grain-amount (* intensity 50))) + (map-pixels frame + (lambda (x y c) + (if colored + (rgb (clamp (+ (red c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (green c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (blue c) (gaussian 0 grain-amount)) 0 255)) + (let ((n (gaussian 0 grain-amount))) + (rgb (clamp (+ (red c) n) 0 255) + (clamp (+ (green c) n) 0 255) + (clamp (+ (blue c) n) 0 255)))))))) diff --git a/sexp_effects/effects/fisheye.sexp b/sexp_effects/effects/fisheye.sexp new file mode 100644 index 0000000..37750a7 --- /dev/null +++ b/sexp_effects/effects/fisheye.sexp @@ -0,0 +1,16 @@ +;; Fisheye effect - barrel/pincushion lens distortion +(require-primitives "geometry" "image") + +(define-effect fisheye + :params ( + (strength :type float :default 0.3 :range [-1 1]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom_correct :type bool :default true) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:fisheye-coords w h strength cx cy zoom_correct))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/flip.sexp b/sexp_effects/effects/flip.sexp new file mode 100644 index 0000000..977e1e1 --- /dev/null +++ b/sexp_effects/effects/flip.sexp @@ -0,0 +1,16 @@ +;; Flip effect - flips image horizontally or vertically +(require-primitives "geometry") + +(define-effect flip + :params ( + (horizontal :type bool :default true) + (vertical :type bool :default false) + ) + (let ((result frame)) + (if horizontal + (set! result (geometry:flip-img result "horizontal")) + nil) + (if vertical + (set! result (geometry:flip-img result "vertical")) + nil) + result)) diff --git a/sexp_effects/effects/grayscale.sexp b/sexp_effects/effects/grayscale.sexp new file mode 100644 index 0000000..848f8a7 --- /dev/null +++ b/sexp_effects/effects/grayscale.sexp @@ -0,0 +1,7 @@ +;; Grayscale effect - converts to grayscale +;; Uses vectorized mix-gray primitive for fast processing +(require-primitives "image") + +(define-effect grayscale + :params () + (image:grayscale frame)) diff --git a/sexp_effects/effects/halftone.sexp b/sexp_effects/effects/halftone.sexp new file mode 100644 index 0000000..2190a4a --- /dev/null +++ b/sexp_effects/effects/halftone.sexp @@ -0,0 +1,49 @@ +;; Halftone/dot effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell luminances +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel +;; local-x/y-norm - position within cell [0,1] +;; where - conditional per-pixel + +(require-primitives "xector") + +(define-effect halftone + :params ( + (cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells") + (dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius") + (invert :type bool :default false :desc "Invert (white dots on black)") + ) + (let* ( + ;; Pool frame to get luminance per cell + (pooled (pool-frame frame cell-size)) + (cell-lum (nth pooled 3)) ; luminance is 4th element + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Get cell luminance for each pixel + (pixel-lum (α/ (gather cell-lum cell-idx) 255)) + + ;; Position within cell, normalized to [-0.5, 0.5] + (lx (α- (local-x-norm frame cell-size) 0.5)) + (ly (α- (local-y-norm frame cell-size) 0.5)) + + ;; Distance from cell center (0 at center, ~0.7 at corners) + (dist (αsqrt (α+ (α² lx) (α² ly)))) + + ;; Radius based on luminance (brighter = bigger dot) + (radius (α* (if invert (α- 1 pixel-lum) pixel-lum) + (α* dot-scale 0.5))) + + ;; Is this pixel inside the dot? + (inside (α< dist radius)) + + ;; Output color + (fg (if invert 255 0)) + (bg (if invert 0 255)) + (out (where inside fg bg))) + + ;; Grayscale output + (rgb out out out))) diff --git a/sexp_effects/effects/hue_shift.sexp b/sexp_effects/effects/hue_shift.sexp new file mode 100644 index 0000000..ab61bd6 --- /dev/null +++ b/sexp_effects/effects/hue_shift.sexp @@ -0,0 +1,12 @@ +;; Hue shift effect - rotates hue values +;; Uses vectorized shift-hsv primitive for fast processing + +(require-primitives "color_ops") + +(define-effect hue_shift + :params ( + (degrees :type int :default 0 :range [0 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((shift (+ degrees (* speed t)))) + (color_ops:shift-hsv frame shift 1 1))) diff --git a/sexp_effects/effects/invert.sexp b/sexp_effects/effects/invert.sexp new file mode 100644 index 0000000..34936da --- /dev/null +++ b/sexp_effects/effects/invert.sexp @@ -0,0 +1,9 @@ +;; Invert effect - inverts all colors +;; Uses vectorized invert-img primitive for fast processing +;; amount param: 0 = no invert, 1 = full invert (threshold at 0.5) + +(require-primitives "color_ops") + +(define-effect invert + :params ((amount :type float :default 1 :range [0 1])) + (if (> amount 0.5) (color_ops:invert-img frame) frame)) diff --git a/sexp_effects/effects/kaleidoscope.sexp b/sexp_effects/effects/kaleidoscope.sexp new file mode 100644 index 0000000..9487ae2 --- /dev/null +++ b/sexp_effects/effects/kaleidoscope.sexp @@ -0,0 +1,20 @@ +;; Kaleidoscope effect - mandala-like symmetry patterns +(require-primitives "geometry" "image") + +(define-effect kaleidoscope + :params ( + (segments :type int :default 6 :range [3 16]) + (rotation :type int :default 0 :range [0 360]) + (rotation_speed :type int :default 0 :range [-180 180]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom :type int :default 1 :range [0.5 3]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + ;; Total rotation including time-based animation + (total_rot (+ rotation (* rotation_speed (or _time 0)))) + (coords (geometry:kaleidoscope-coords w h segments total_rot cx cy zoom))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/layer.sexp b/sexp_effects/effects/layer.sexp new file mode 100644 index 0000000..e57d627 --- /dev/null +++ b/sexp_effects/effects/layer.sexp @@ -0,0 +1,36 @@ +;; Layer effect - composite overlay over background at position +;; Streaming-compatible: frame is background, overlay is foreground +;; Usage: (layer background overlay :x 10 :y 20 :opacity 0.8) +;; +;; Params: +;; overlay - frame to composite on top +;; x, y - position to place overlay +;; opacity - blend amount (0-1) +;; mode - blend mode (alpha, multiply, screen, etc.) + +(require-primitives "image" "blending" "core") + +(define-effect layer + :params ( + (overlay :type frame :default nil) + (x :type int :default 0) + (y :type int :default 0) + (opacity :type float :default 1.0) + (mode :type string :default "alpha") + ) + (if (core:is-nil overlay) + frame + (let [bg (copy frame) + fg overlay + fg-w (image:width fg) + fg-h (image:height fg)] + (if (= opacity 1.0) + ;; Simple paste + (paste bg fg x y) + ;; Blend with opacity + (let [blended (if (= mode "alpha") + (blending:blend-images (image:crop bg x y fg-w fg-h) fg opacity) + (blending:blend-images (image:crop bg x y fg-w fg-h) + (blending:blend-mode (image:crop bg x y fg-w fg-h) fg mode) + opacity))] + (paste bg blended x y)))))) diff --git a/sexp_effects/effects/mirror.sexp b/sexp_effects/effects/mirror.sexp new file mode 100644 index 0000000..a450cb6 --- /dev/null +++ b/sexp_effects/effects/mirror.sexp @@ -0,0 +1,33 @@ +;; Mirror effect - mirrors half of image +(require-primitives "geometry" "image") + +(define-effect mirror + :params ( + (mode :type string :default "left_right") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (hw (floor (/ w 2))) + (hh (floor (/ h 2)))) + (cond + ((= mode "left_right") + (let ((left (image:crop frame 0 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img left "horizontal") hw 0))) + + ((= mode "right_left") + (let ((right (image:crop frame hw 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img right "horizontal") 0 0))) + + ((= mode "top_bottom") + (let ((top (image:crop frame 0 0 w hh)) + (result (copy frame))) + (paste result (geometry:flip-img top "vertical") 0 hh))) + + ((= mode "bottom_top") + (let ((bottom (image:crop frame 0 hh w hh)) + (result (copy frame))) + (paste result (geometry:flip-img bottom "vertical") 0 0))) + + (else frame)))) diff --git a/sexp_effects/effects/mosaic.sexp b/sexp_effects/effects/mosaic.sexp new file mode 100644 index 0000000..5de07de --- /dev/null +++ b/sexp_effects/effects/mosaic.sexp @@ -0,0 +1,30 @@ +;; Mosaic effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell averages +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel + +(require-primitives "xector") + +(define-effect mosaic + :params ( + (cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells") + ) + (let* ( + ;; Pool frame to get average color per cell (returns r,g,b,lum xectors) + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Gather: look up cell color for each pixel + (out-r (gather cell-r cell-idx)) + (out-g (gather cell-g cell-idx)) + (out-b (gather cell-b cell-idx))) + + ;; Reconstruct frame + (rgb out-r out-g out-b))) diff --git a/sexp_effects/effects/neon_glow.sexp b/sexp_effects/effects/neon_glow.sexp new file mode 100644 index 0000000..39245ab --- /dev/null +++ b/sexp_effects/effects/neon_glow.sexp @@ -0,0 +1,23 @@ +;; Neon Glow effect - glowing edge effect +(require-primitives "image" "blending") + +(define-effect neon_glow + :params ( + (edge_low :type int :default 50 :range [10 200]) + (edge_high :type int :default 150 :range [50 300]) + (glow_radius :type int :default 15 :range [1 50]) + (glow_intensity :type int :default 2 :range [0.5 5]) + (background :type float :default 0.3 :range [0 1]) + ) + (let* ((edge-img (image:edge-detect frame edge_low edge_high)) + (glow (image:blur edge-img glow_radius)) + ;; Intensify the glow + (bright-glow (map-pixels glow + (lambda (x y c) + (rgb (clamp (* (red c) glow_intensity) 0 255) + (clamp (* (green c) glow_intensity) 0 255) + (clamp (* (blue c) glow_intensity) 0 255)))))) + (blending:blend-mode (blending:blend-images frame (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 background)) + bright-glow + "screen"))) diff --git a/sexp_effects/effects/noise.sexp b/sexp_effects/effects/noise.sexp new file mode 100644 index 0000000..4da8298 --- /dev/null +++ b/sexp_effects/effects/noise.sexp @@ -0,0 +1,8 @@ +;; Noise effect - adds random noise +;; Uses vectorized add-noise primitive for fast processing + +(define-effect noise + :params ( + (amount :type int :default 20 :range [0 100]) + ) + (add-noise frame amount)) diff --git a/sexp_effects/effects/outline.sexp b/sexp_effects/effects/outline.sexp new file mode 100644 index 0000000..921a0b8 --- /dev/null +++ b/sexp_effects/effects/outline.sexp @@ -0,0 +1,24 @@ +;; Outline effect - shows only edges +(require-primitives "image") + +(define-effect outline + :params ( + (thickness :type int :default 2 :range [1 10]) + (threshold :type int :default 100 :range [20 300]) + (color :type list :default (list 0 0 0)) + (fill_mode :type string :default "original") + ) + (let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold)) + (dilated (if (> thickness 1) + (dilate edge-img thickness) + edge-img)) + (base (cond + ((= fill_mode "original") (copy frame)) + ((= fill_mode "white") (make-image (image:width frame) (image:height frame) (list 255 255 255))) + (else (make-image (image:width frame) (image:height frame) (list 0 0 0)))))) + (map-pixels base + (lambda (x y c) + (let ((edge-val (luminance (pixel dilated x y)))) + (if (> edge-val 128) + color + c)))))) diff --git a/sexp_effects/effects/pixelate.sexp b/sexp_effects/effects/pixelate.sexp new file mode 100644 index 0000000..3d28ce1 --- /dev/null +++ b/sexp_effects/effects/pixelate.sexp @@ -0,0 +1,13 @@ +;; Pixelate effect - creates blocky pixels +(require-primitives "image") + +(define-effect pixelate + :params ( + (block_size :type int :default 8 :range [2 64]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (small-w (max 1 (floor (/ w block_size)))) + (small-h (max 1 (floor (/ h block_size)))) + (small (image:resize frame small-w small-h "area"))) + (image:resize small w h "nearest"))) diff --git a/sexp_effects/effects/pixelsort.sexp b/sexp_effects/effects/pixelsort.sexp new file mode 100644 index 0000000..155ac13 --- /dev/null +++ b/sexp_effects/effects/pixelsort.sexp @@ -0,0 +1,11 @@ +;; Pixelsort effect - glitch art pixel sorting + +(define-effect pixelsort + :params ( + (sort_by :type string :default "lightness") + (threshold_low :type int :default 50 :range [0 255]) + (threshold_high :type int :default 200 :range [0 255]) + (angle :type int :default 0 :range [0 180]) + (reverse :type bool :default false) + ) + (pixelsort frame sort_by threshold_low threshold_high angle reverse)) diff --git a/sexp_effects/effects/posterize.sexp b/sexp_effects/effects/posterize.sexp new file mode 100644 index 0000000..7052ed3 --- /dev/null +++ b/sexp_effects/effects/posterize.sexp @@ -0,0 +1,8 @@ +;; Posterize effect - reduces color levels +(require-primitives "color_ops") + +(define-effect posterize + :params ( + (levels :type int :default 8 :range [2 32]) + ) + (color_ops:posterize frame levels)) diff --git a/sexp_effects/effects/resize-frame.sexp b/sexp_effects/effects/resize-frame.sexp new file mode 100644 index 0000000..a1cce27 --- /dev/null +++ b/sexp_effects/effects/resize-frame.sexp @@ -0,0 +1,11 @@ +;; Resize effect - replaces RESIZE node +;; Note: uses target-w/target-h to avoid conflict with width/height primitives +(require-primitives "image") + +(define-effect resize-frame + :params ( + (target-w :type int :default 640 :desc "Target width in pixels") + (target-h :type int :default 480 :desc "Target height in pixels") + (mode :type string :default "linear" :choices [linear nearest area] :desc "Interpolation mode") + ) + (image:resize frame target-w target-h mode)) diff --git a/sexp_effects/effects/rgb_split.sexp b/sexp_effects/effects/rgb_split.sexp new file mode 100644 index 0000000..4582701 --- /dev/null +++ b/sexp_effects/effects/rgb_split.sexp @@ -0,0 +1,13 @@ +;; RGB Split effect - chromatic aberration + +(define-effect rgb_split + :params ( + (offset_x :type int :default 10 :range [-50 50]) + (offset_y :type int :default 0 :range [-50 50]) + ) + (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (r-shifted (translate (merge-channels r r r) offset_x offset_y)) + (b-shifted (translate (merge-channels b b b) (- offset_x) (- offset_y)))) + (merge-channels (channel r-shifted 0) g (channel b-shifted 0)))) diff --git a/sexp_effects/effects/ripple.sexp b/sexp_effects/effects/ripple.sexp new file mode 100644 index 0000000..0bb7a8d --- /dev/null +++ b/sexp_effects/effects/ripple.sexp @@ -0,0 +1,19 @@ +;; Ripple effect - radial wave distortion from center +(require-primitives "geometry" "image" "math") + +(define-effect ripple + :params ( + (frequency :type int :default 5 :range [1 20]) + (amplitude :type int :default 10 :range [0 50]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (decay :type int :default 1 :range [0 5]) + (speed :type int :default 1 :range [0 10]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (phase (* (or t 0) speed 2 pi)) + (coords (geometry:ripple-displace w h frequency amplitude cx cy decay phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/rotate.sexp b/sexp_effects/effects/rotate.sexp new file mode 100644 index 0000000..d06c2f7 --- /dev/null +++ b/sexp_effects/effects/rotate.sexp @@ -0,0 +1,11 @@ +;; Rotate effect - rotates image + +(require-primitives "geometry") + +(define-effect rotate + :params ( + (angle :type int :default 0 :range [-360 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((total-angle (+ angle (* speed t)))) + (geometry:rotate-img frame total-angle))) diff --git a/sexp_effects/effects/saturation.sexp b/sexp_effects/effects/saturation.sexp new file mode 100644 index 0000000..9852dc7 --- /dev/null +++ b/sexp_effects/effects/saturation.sexp @@ -0,0 +1,9 @@ +;; Saturation effect - adjusts color saturation +;; Uses vectorized shift-hsv primitive for fast processing +(require-primitives "color_ops") + +(define-effect saturation + :params ( + (amount :type int :default 1 :range [0 3]) + ) + (color_ops:adjust-saturation frame amount)) diff --git a/sexp_effects/effects/scanlines.sexp b/sexp_effects/effects/scanlines.sexp new file mode 100644 index 0000000..ddfcf44 --- /dev/null +++ b/sexp_effects/effects/scanlines.sexp @@ -0,0 +1,15 @@ +;; Scanlines effect - VHS-style horizontal line shifting +(require-primitives "core") + +(define-effect scanlines + :params ( + (amplitude :type int :default 10 :range [0 100]) + (frequency :type int :default 10 :range [1 100]) + (randomness :type float :default 0.5 :range [0 1]) + ) + (map-rows frame + (lambda (y row) + (let* ((sine-shift (* amplitude (sin (/ (* y 6.28) (max 1 frequency))))) + (rand-shift (core:rand-range (- amplitude) amplitude)) + (shift (floor (lerp sine-shift rand-shift randomness)))) + (roll row shift 0))))) diff --git a/sexp_effects/effects/sepia.sexp b/sexp_effects/effects/sepia.sexp new file mode 100644 index 0000000..e3a5875 --- /dev/null +++ b/sexp_effects/effects/sepia.sexp @@ -0,0 +1,7 @@ +;; Sepia effect - applies sepia tone +;; Classic warm vintage look +(require-primitives "color_ops") + +(define-effect sepia + :params () + (color_ops:sepia frame)) diff --git a/sexp_effects/effects/sharpen.sexp b/sexp_effects/effects/sharpen.sexp new file mode 100644 index 0000000..538bd7f --- /dev/null +++ b/sexp_effects/effects/sharpen.sexp @@ -0,0 +1,8 @@ +;; Sharpen effect - sharpens edges +(require-primitives "image") + +(define-effect sharpen + :params ( + (amount :type int :default 1 :range [0 5]) + ) + (image:sharpen frame amount)) diff --git a/sexp_effects/effects/strobe.sexp b/sexp_effects/effects/strobe.sexp new file mode 100644 index 0000000..2bf80b4 --- /dev/null +++ b/sexp_effects/effects/strobe.sexp @@ -0,0 +1,16 @@ +;; Strobe effect - holds frames for choppy look +(require-primitives "core") + +(define-effect strobe + :params ( + (frame_rate :type int :default 12 :range [1 60]) + ) + (let* ((held (state-get "held" nil)) + (held-until (state-get "held-until" 0)) + (frame-duration (/ 1 frame_rate))) + (if (or (core:is-nil held) (>= t held-until)) + (begin + (state-set "held" (copy frame)) + (state-set "held-until" (+ t frame-duration)) + frame) + held))) diff --git a/sexp_effects/effects/swirl.sexp b/sexp_effects/effects/swirl.sexp new file mode 100644 index 0000000..ba9cf57 --- /dev/null +++ b/sexp_effects/effects/swirl.sexp @@ -0,0 +1,17 @@ +;; Swirl effect - spiral vortex distortion +(require-primitives "geometry" "image") + +(define-effect swirl + :params ( + (strength :type int :default 1 :range [-10 10]) + (radius :type float :default 0.5 :range [0.1 2]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (falloff :type string :default "quadratic") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:swirl-coords w h strength radius cx cy falloff))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/threshold.sexp b/sexp_effects/effects/threshold.sexp new file mode 100644 index 0000000..50d3bc5 --- /dev/null +++ b/sexp_effects/effects/threshold.sexp @@ -0,0 +1,9 @@ +;; Threshold effect - converts to black and white +(require-primitives "color_ops") + +(define-effect threshold + :params ( + (level :type int :default 128 :range [0 255]) + (invert :type bool :default false) + ) + (color_ops:threshold frame level invert)) diff --git a/sexp_effects/effects/tile_grid.sexp b/sexp_effects/effects/tile_grid.sexp new file mode 100644 index 0000000..44487a9 --- /dev/null +++ b/sexp_effects/effects/tile_grid.sexp @@ -0,0 +1,29 @@ +;; Tile Grid effect - tiles image in grid +(require-primitives "geometry" "image") + +(define-effect tile_grid + :params ( + (rows :type int :default 2 :range [1 10]) + (cols :type int :default 2 :range [1 10]) + (gap :type int :default 0 :range [0 50]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (tile-w (floor (/ (- w (* gap (- cols 1))) cols))) + (tile-h (floor (/ (- h (* gap (- rows 1))) rows))) + (tile (image:resize frame tile-w tile-h "area")) + (result (make-image w h (list 0 0 0)))) + (begin + ;; Manually place tiles using nested iteration + ;; This is a simplified version - full version would loop + (paste result tile 0 0) + (if (> cols 1) + (paste result tile (+ tile-w gap) 0) + nil) + (if (> rows 1) + (paste result tile 0 (+ tile-h gap)) + nil) + (if (and (> cols 1) (> rows 1)) + (paste result tile (+ tile-w gap) (+ tile-h gap)) + nil) + result))) diff --git a/sexp_effects/effects/trails.sexp b/sexp_effects/effects/trails.sexp new file mode 100644 index 0000000..5c0fc7c --- /dev/null +++ b/sexp_effects/effects/trails.sexp @@ -0,0 +1,20 @@ +;; Trails effect - persistent motion trails +(require-primitives "image" "blending") + +(define-effect trails + :params ( + (persistence :type float :default 0.8 :range [0 0.99]) + ) + (let* ((buffer (state-get "buffer" nil)) + (current frame)) + (if (= buffer nil) + (begin + (state-set "buffer" (copy frame)) + frame) + (let* ((faded (blending:blend-images buffer + (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 persistence))) + (result (blending:blend-mode faded current "lighten"))) + (begin + (state-set "buffer" result) + result))))) diff --git a/sexp_effects/effects/vignette.sexp b/sexp_effects/effects/vignette.sexp new file mode 100644 index 0000000..46e63ee --- /dev/null +++ b/sexp_effects/effects/vignette.sexp @@ -0,0 +1,23 @@ +;; Vignette effect - darkens corners +(require-primitives "image") + +(define-effect vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type int :default 1 :range [0.5 2]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (* (sqrt (+ (* cx cx) (* cy cy))) radius))) + (map-pixels frame + (lambda (x y c) + (let* ((dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (factor (- 1 (* (/ dist max-dist) strength))) + (factor (clamp factor 0 1))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/sexp_effects/effects/wave.sexp b/sexp_effects/effects/wave.sexp new file mode 100644 index 0000000..98b03c2 --- /dev/null +++ b/sexp_effects/effects/wave.sexp @@ -0,0 +1,22 @@ +;; Wave effect - sine wave displacement distortion +(require-primitives "geometry" "image") + +(define-effect wave + :params ( + (amplitude :type int :default 10 :range [0 100]) + (wavelength :type int :default 50 :range [10 500]) + (speed :type int :default 1 :range [0 10]) + (direction :type string :default "horizontal") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + ;; Use _time for animation phase + (phase (* (or _time 0) speed 2 pi)) + ;; Calculate frequency: waves per dimension + (freq (/ (if (= direction "vertical") w h) wavelength)) + (axis (cond + ((= direction "horizontal") "x") + ((= direction "vertical") "y") + (else "both"))) + (coords (geometry:wave-coords w h axis freq amplitude phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/xector_feathered_blend.sexp b/sexp_effects/effects/xector_feathered_blend.sexp new file mode 100644 index 0000000..96224fb --- /dev/null +++ b/sexp_effects/effects/xector_feathered_blend.sexp @@ -0,0 +1,44 @@ +;; Feathered blend - blend two same-size frames with distance-based falloff +;; Center shows overlay, edges show background, with smooth transition + +(require-primitives "xector") + +(define-effect xector_feathered_blend + :params ( + (inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)") + (fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)") + (overlay :type frame :default nil :desc "Frame to blend in center") + ) + (let* ( + ;; Get normalized distance from center (0 at center, ~1 at corners) + (dist (dist-from-center frame)) + (max-dist (βmax dist)) + (dist-norm (α/ dist max-dist)) + + ;; Calculate blend factor: + ;; - 1.0 when dist-norm < inner-radius (fully overlay) + ;; - 0.0 when dist-norm > inner-radius + fade-width (fully background) + ;; - linear ramp between + (t (α/ (α- dist-norm inner-radius) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + (inv-blend (α- 1 blend)) + + ;; Background channels + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + (if (nil? overlay) + ;; No overlay - visualize the blend mask + (let ((vis (α* blend 255))) + (rgb vis vis vis)) + + ;; Blend overlay with background using the mask + (let* ((ov-r (red overlay)) + (ov-g (green overlay)) + (ov-b (blue overlay)) + ;; lerp: bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r inv-blend) (α* ov-r blend))) + (g-out (α+ (α* bg-g inv-blend) (α* ov-g blend))) + (b-out (α+ (α* bg-b inv-blend) (α* ov-b blend)))) + (rgb r-out g-out b-out))))) diff --git a/sexp_effects/effects/xector_grain.sexp b/sexp_effects/effects/xector_grain.sexp new file mode 100644 index 0000000..64ebfa6 --- /dev/null +++ b/sexp_effects/effects/xector_grain.sexp @@ -0,0 +1,34 @@ +;; Film grain effect using xector operations +;; Demonstrates random xectors and mixing scalar/xector math + +(require-primitives "xector") + +(define-effect xector_grain + :params ( + (intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity") + (colored :type bool :default false :desc "Use colored grain") + ) + (let* ( + ;; Extract channels + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Generate noise xector(s) + ;; randn-x generates normal distribution noise + (grain-amount (* intensity 50))) + + (if colored + ;; Colored grain: different noise per channel + (let* ((nr (randn-x frame 0 grain-amount)) + (ng (randn-x frame 0 grain-amount)) + (nb (randn-x frame 0 grain-amount))) + (rgb (αclamp (α+ r nr) 0 255) + (αclamp (α+ g ng) 0 255) + (αclamp (α+ b nb) 0 255))) + + ;; Monochrome grain: same noise for all channels + (let ((n (randn-x frame 0 grain-amount))) + (rgb (αclamp (α+ r n) 0 255) + (αclamp (α+ g n) 0 255) + (αclamp (α+ b n) 0 255)))))) diff --git a/sexp_effects/effects/xector_inset_blend.sexp b/sexp_effects/effects/xector_inset_blend.sexp new file mode 100644 index 0000000..597e23a --- /dev/null +++ b/sexp_effects/effects/xector_inset_blend.sexp @@ -0,0 +1,57 @@ +;; Inset blend - fade a smaller frame into a larger background +;; Uses distance-based alpha for smooth transition (no hard edges) + +(require-primitives "xector") + +(define-effect xector_inset_blend + :params ( + (x :type int :default 0 :desc "X position of inset") + (y :type int :default 0 :desc "Y position of inset") + (fade-width :type int :default 50 :desc "Width of fade region in pixels") + (overlay :type frame :default nil :desc "The smaller frame to inset") + ) + (let* ( + ;; Get dimensions + (bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height + (bg-w bg-h) ;; placeholder + + ;; For now, create a simple centered circular blend + ;; Distance from center of overlay position + (cx (+ x (/ (- bg-w (* 2 x)) 2))) + (cy (+ y (/ (- bg-h (* 2 y)) 2))) + + ;; Get coordinates as xectors + (px (x-coords frame)) + (py (y-coords frame)) + + ;; Distance from center + (dx (α- px cx)) + (dy (α- py cy)) + (dist (αsqrt (α+ (α* dx dx) (α* dy dy)))) + + ;; Inner radius (fully overlay) and outer radius (fully background) + (inner-r (- (/ bg-w 2) x fade-width)) + (outer-r (- (/ bg-w 2) x)) + + ;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between + (t (α/ (α- dist inner-r) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + + ;; Extract channels from both frames + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + ;; If overlay provided, blend it + (if overlay + (let* ((ov-r (red overlay)) + (ov-g (green overlay)) + (ov-b (blue overlay)) + ;; Linear blend: result = bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend))) + (g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend))) + (b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend)))) + (rgb r-out g-out b-out)) + ;; No overlay - just show the blend mask for debugging + (let ((mask-vis (α* blend 255))) + (rgb mask-vis mask-vis mask-vis))))) diff --git a/sexp_effects/effects/xector_threshold.sexp b/sexp_effects/effects/xector_threshold.sexp new file mode 100644 index 0000000..c571468 --- /dev/null +++ b/sexp_effects/effects/xector_threshold.sexp @@ -0,0 +1,27 @@ +;; Threshold effect using xector operations +;; Demonstrates where (conditional select) and β (reduction) for normalization + +(require-primitives "xector") + +(define-effect xector_threshold + :params ( + (threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)") + (invert :type bool :default false :desc "Invert the threshold") + ) + (let* ( + ;; Get grayscale luminance as xector + (luma (gray frame)) + + ;; Normalize to 0-1 range + (luma-norm (α/ luma 255)) + + ;; Create boolean mask: pixels above threshold + (mask (if invert + (α< luma-norm threshold) + (α>= luma-norm threshold))) + + ;; Use where to select: white (255) if above threshold, black (0) if below + (out (where mask 255 0))) + + ;; Output as grayscale (same value for R, G, B) + (rgb out out out))) diff --git a/sexp_effects/effects/xector_vignette.sexp b/sexp_effects/effects/xector_vignette.sexp new file mode 100644 index 0000000..d654ca7 --- /dev/null +++ b/sexp_effects/effects/xector_vignette.sexp @@ -0,0 +1,36 @@ +;; Vignette effect using xector operations +;; Demonstrates α (element-wise) and β (reduction) patterns + +(require-primitives "xector") + +(define-effect xector_vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type float :default 1.0 :range [0.5 2]) + ) + (let* ( + ;; Get normalized distance from center for each pixel + (dist (dist-from-center frame)) + + ;; Calculate max distance (corner distance) + (max-dist (* (βmax dist) radius)) + + ;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength) + ;; Using explicit α operators + (factor (α- 1 (α* (α/ dist max-dist) strength))) + + ;; Clamp factor to [0, 1] + (factor (αclamp factor 0 1)) + + ;; Extract channels as xectors + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Apply factor to each channel (implicit element-wise via Xector operators) + (r-out (* r factor)) + (g-out (* g factor)) + (b-out (* b factor))) + + ;; Combine back to frame + (rgb r-out g-out b-out))) diff --git a/sexp_effects/effects/zoom.sexp b/sexp_effects/effects/zoom.sexp new file mode 100644 index 0000000..6e4b9ff --- /dev/null +++ b/sexp_effects/effects/zoom.sexp @@ -0,0 +1,8 @@ +;; Zoom effect - zooms in/out from center +(require-primitives "geometry") + +(define-effect zoom + :params ( + (amount :type int :default 1 :range [0.1 5]) + ) + (geometry:scale-img frame amount amount)) diff --git a/sexp_effects/interpreter.py b/sexp_effects/interpreter.py new file mode 100644 index 0000000..406f6da --- /dev/null +++ b/sexp_effects/interpreter.py @@ -0,0 +1,1085 @@ +""" +S-Expression Effect Interpreter + +Interprets effect definitions written in S-expressions. +Only allows safe primitives - no arbitrary code execution. +""" + +import numpy as np +from typing import Any, Dict, List, Optional, Callable +from pathlib import Path + +from .parser import Symbol, Keyword, parse, parse_file +from .primitives import PRIMITIVES, reset_rng + + +def _is_symbol(x) -> bool: + """Check if x is a Symbol (duck typing to support multiple Symbol classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Symbol' + + +def _is_keyword(x) -> bool: + """Check if x is a Keyword (duck typing to support multiple Keyword classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Keyword' + + +def _symbol_name(x) -> str: + """Get the name from a Symbol.""" + return x.name if hasattr(x, 'name') else str(x) + + +class Environment: + """Lexical environment for variable bindings.""" + + def __init__(self, parent: 'Environment' = None): + self.bindings: Dict[str, Any] = {} + self.parent = parent + + def get(self, name: str) -> Any: + if name in self.bindings: + return self.bindings[name] + if self.parent: + return self.parent.get(name) + raise NameError(f"Undefined variable: {name}") + + def set(self, name: str, value: Any): + self.bindings[name] = value + + def has(self, name: str) -> bool: + if name in self.bindings: + return True + if self.parent: + return self.parent.has(name) + return False + + +class Lambda: + """A user-defined function (lambda).""" + + def __init__(self, params: List[str], body: Any, env: Environment): + self.params = params + self.body = body + self.env = env # Closure environment + + def __repr__(self): + return f"" + + +class EffectDefinition: + """A parsed effect definition.""" + + def __init__(self, name: str, params: Dict[str, Any], body: Any): + self.name = name + self.params = params # {name: (type, default)} + self.body = body + + def __repr__(self): + return f"" + + +class Interpreter: + """ + S-Expression interpreter for effects. + + Provides a safe execution environment where only + whitelisted primitives can be called. + + Args: + minimal_primitives: If True, only load core primitives (arithmetic, comparison, + basic data access). Additional primitives must be loaded with + (require-primitives) or (with-primitives). + If False (default), load all legacy primitives for backward compatibility. + """ + + def __init__(self, minimal_primitives: bool = False): + # Base environment with primitives + self.global_env = Environment() + self.minimal_primitives = minimal_primitives + + if minimal_primitives: + # Load only core primitives + from .primitive_libs.core import PRIMITIVES as CORE_PRIMITIVES + for name, fn in CORE_PRIMITIVES.items(): + self.global_env.set(name, fn) + else: + # Load all legacy primitives for backward compatibility + for name, fn in PRIMITIVES.items(): + self.global_env.set(name, fn) + + # Special values + self.global_env.set('true', True) + self.global_env.set('false', False) + self.global_env.set('nil', None) + + # Loaded effect definitions + self.effects: Dict[str, EffectDefinition] = {} + + def eval(self, expr: Any, env: Environment = None) -> Any: + """Evaluate an S-expression.""" + if env is None: + env = self.global_env + + # Atoms + if isinstance(expr, (int, float, str, bool)): + return expr + + if expr is None: + return None + + # Handle Symbol (duck typing to support both sexp_effects.parser.Symbol and artdag.sexp.parser.Symbol) + if _is_symbol(expr): + return env.get(expr.name) + + # Handle Keyword (duck typing) + if _is_keyword(expr): + return expr # Keywords evaluate to themselves + + if isinstance(expr, np.ndarray): + return expr # Images pass through + + # Lists (function calls / special forms) + if isinstance(expr, list): + if not expr: + return [] + + head = expr[0] + + # Special forms + if _is_symbol(head): + form = head.name + + # Quote + if form == 'quote': + return expr[1] + + # Define + if form == 'define': + name = expr[1] + if _is_symbol(name): + # Simple define: (define name value) + value = self.eval(expr[2], env) + self.global_env.set(name.name, value) + return value + elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]): + # Function define: (define (fn-name args...) body) + # Desugars to: (define fn-name (lambda (args...) body)) + fn_name = name[0].name + params = [p.name if _is_symbol(p) else p for p in name[1:]] + body = expr[2] + fn = Lambda(params, body, env) + self.global_env.set(fn_name, fn) + return fn + else: + raise SyntaxError(f"define requires symbol or (name args...), got {name}") + + # Define-effect + if form == 'define-effect': + return self._define_effect(expr, env) + + # Lambda + if form == 'lambda' or form == 'λ': + params = [p.name if _is_symbol(p) else p for p in expr[1]] + body = expr[2] + return Lambda(params, body, env) + + # Let + if form == 'let': + return self._eval_let(expr, env) + + # Let* + if form == 'let*': + return self._eval_let_star(expr, env) + + # If + if form == 'if': + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + # Cond + if form == 'cond': + return self._eval_cond(expr, env) + + # And + if form == 'and': + result = True + for e in expr[1:]: + result = self.eval(e, env) + if not result: + return False + return result + + # Or + if form == 'or': + for e in expr[1:]: + result = self.eval(e, env) + if result: + return result + return False + + # Not + if form == 'not': + return not self.eval(expr[1], env) + + # Begin (sequence) + if form == 'begin': + result = None + for e in expr[1:]: + result = self.eval(e, env) + return result + + # Thread-first macro: (-> x (f a) (g b)) => (g (f x a) b) + if form == '->': + result = self.eval(expr[1], env) + for form_expr in expr[2:]: + if isinstance(form_expr, list): + # Insert result as first arg: (f a b) => (f result a b) + result = self.eval([form_expr[0], result] + form_expr[1:], env) + else: + # Just a symbol: f => (f result) + result = self.eval([form_expr, result], env) + return result + + # Set! (mutation) + if form == 'set!': + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + value = self.eval(expr[2], env) + # Find and update in appropriate scope + scope = env + while scope: + if name in scope.bindings: + scope.bindings[name] = value + return value + scope = scope.parent + raise NameError(f"Cannot set undefined variable: {name}") + + # State-get / state-set (for effect state) + if form == 'state-get': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + default = self.eval(expr[2], env) if len(expr) > 2 else None + return state.get(key, default) + + if form == 'state-set': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + value = self.eval(expr[2], env) + state[key] = value + return value + + # ascii-fx-zone special form - delays evaluation of expression parameters + if form == 'ascii-fx-zone': + return self._eval_ascii_fx_zone(expr, env) + + # with-primitives - load primitive library and scope to body + if form == 'with-primitives': + return self._eval_with_primitives(expr, env) + + # require-primitives - load primitive library into current scope + if form == 'require-primitives': + return self._eval_require_primitives(expr, env) + + # require - load .sexp file into current scope + if form == 'require': + return self._eval_require(expr, env) + + # Function call + fn = self.eval(head, env) + args = [self.eval(arg, env) for arg in expr[1:]] + + # Handle keyword arguments + pos_args = [] + kw_args = {} + i = 0 + while i < len(args): + if _is_keyword(args[i]): + kw_args[args[i].name] = args[i + 1] if i + 1 < len(args) else None + i += 2 + else: + pos_args.append(args[i]) + i += 1 + + return self._apply(fn, pos_args, kw_args, env) + + raise TypeError(f"Cannot evaluate: {expr}") + + def _wrap_lambda(self, lam: 'Lambda') -> Callable: + """Wrap a Lambda in a Python callable for use by primitives.""" + def wrapper(*args): + new_env = Environment(lam.env) + for i, param in enumerate(lam.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(lam.body, new_env) + return wrapper + + def _apply(self, fn: Any, args: List[Any], kwargs: Dict[str, Any], env: Environment) -> Any: + """Apply a function to arguments.""" + if isinstance(fn, Lambda): + # User-defined function + new_env = Environment(fn.env) + for i, param in enumerate(fn.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(fn.body, new_env) + + elif callable(fn): + # Wrap any Lambda arguments so primitives can call them + wrapped_args = [] + for arg in args: + if isinstance(arg, Lambda): + wrapped_args.append(self._wrap_lambda(arg)) + else: + wrapped_args.append(arg) + + # Inject _interp and _env for primitives that need them + import inspect + try: + sig = inspect.signature(fn) + params = sig.parameters + if '_interp' in params and '_interp' not in kwargs: + kwargs['_interp'] = self + if '_env' in params and '_env' not in kwargs: + kwargs['_env'] = env + except (ValueError, TypeError): + # Some built-in functions don't have inspectable signatures + pass + + # Primitive function + if kwargs: + return fn(*wrapped_args, **kwargs) + return fn(*wrapped_args) + + else: + raise TypeError(f"Cannot call: {fn}") + + def _parse_bindings(self, bindings: list) -> list: + """Parse bindings in either Scheme or Clojure style. + + Scheme: ((x 1) (y 2)) -> [(x, 1), (y, 2)] + Clojure: [x 1 y 2] -> [(x, 1), (y, 2)] + """ + if not bindings: + return [] + + # Check if Clojure style (flat list with symbols and values alternating) + if _is_symbol(bindings[0]): + # Clojure style: [x 1 y 2] + pairs = [] + i = 0 + while i < len(bindings) - 1: + name = bindings[i].name if _is_symbol(bindings[i]) else bindings[i] + value = bindings[i + 1] + pairs.append((name, value)) + i += 2 + return pairs + else: + # Scheme style: ((x 1) (y 2)) + pairs = [] + for binding in bindings: + name = binding[0].name if _is_symbol(binding[0]) else binding[0] + value = binding[1] + pairs.append((name, value)) + return pairs + + def _eval_let(self, expr: Any, env: Environment) -> Any: + """Evaluate let expression: (let ((x 1) (y 2)) body) or (let [x 1 y 2] body) + + Note: Uses sequential binding (like Clojure let / Scheme let*) so each + binding can reference previous bindings. + """ + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Sequential: can see previous bindings + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_let_star(self, expr: Any, env: Environment) -> Any: + """Evaluate let* expression: sequential bindings.""" + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Evaluate in current env + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_cond(self, expr: Any, env: Environment) -> Any: + """Evaluate cond expression.""" + for clause in expr[1:]: + test = clause[0] + if _is_symbol(test) and test.name == 'else': + return self.eval(clause[1], env) + if self.eval(test, env): + return self.eval(clause[1], env) + return None + + def _eval_with_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate with-primitives: scoped primitive library loading. + + Syntax: + (with-primitives "math" + (sin (* x pi))) + + (with-primitives "math" :path "custom/math.py" + body) + + The primitives from the library are only available within the body. + """ + # Parse library name and optional path + lib_name = expr[1] + if _is_symbol(lib_name): + lib_name = lib_name.name + + path = None + body_start = 2 + + # Check for :path keyword + if len(expr) > 2 and _is_keyword(expr[2]) and expr[2].name == 'path': + path = expr[3] + body_start = 4 + + # Load the primitive library + primitives = self.load_primitive_library(lib_name, path) + + # Create new environment with primitives + new_env = Environment(env) + for name, fn in primitives.items(): + new_env.set(name, fn) + + # Evaluate body in new environment + result = None + for e in expr[body_start:]: + result = self.eval(e, new_env) + return result + + def _eval_require_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require-primitives: load primitives into current scope. + + Syntax: + (require-primitives "math" "color" "filters") + + Unlike with-primitives, this loads into the current environment + (typically used at top-level to set up an effect's dependencies). + """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + primitives = self.load_primitive_library(lib_name) + for name, fn in primitives.items(): + env.set(name, fn) + + return None + + def load_primitive_library(self, name: str, path: str = None) -> dict: + """ + Load a primitive library by name or path. + + Returns dict of {name: function}. + """ + from .primitive_libs import load_primitive_library + return load_primitive_library(name, path) + + def _eval_require(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require: load a .sexp file and evaluate its definitions. + + Syntax: + (require "derived") ; loads derived.sexp from sexp_effects/ + (require "path/to/file.sexp") ; loads from explicit path + + Definitions from the file are added to the current environment. + """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + # Find the .sexp file + sexp_path = self._find_sexp_file(lib_name) + if sexp_path is None: + raise ValueError(f"Cannot find sexp file: {lib_name}") + + # Parse and evaluate the file + content = parse_file(sexp_path) + + # Evaluate all top-level expressions + if isinstance(content, list) and content and isinstance(content[0], list): + for e in content: + self.eval(e, env) + else: + self.eval(content, env) + + return None + + def _find_sexp_file(self, name: str) -> Optional[str]: + """Find a .sexp file by name.""" + # Try various locations + candidates = [ + # Explicit path + name, + name + '.sexp', + # In sexp_effects directory + Path(__file__).parent / f'{name}.sexp', + Path(__file__).parent / name, + # In effects directory + Path(__file__).parent / 'effects' / f'{name}.sexp', + Path(__file__).parent / 'effects' / name, + ] + + for path in candidates: + p = Path(path) if not isinstance(path, Path) else path + if p.exists() and p.is_file(): + return str(p) + + return None + + def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any: + """ + Evaluate ascii-fx-zone special form. + + Syntax: + (ascii-fx-zone frame + :cols 80 + :alphabet "standard" + :color_mode "color" + :background "black" + :contrast 1.5 + :char_hue ;; NOT evaluated - passed to primitive + :char_saturation + :char_brightness + :char_scale + :char_rotation + :char_jitter ) + + The expression parameters (:char_hue, etc.) are NOT pre-evaluated. + They are passed as raw S-expressions to the primitive which + evaluates them per-zone with zone context variables injected. + + Requires: (require-primitives "ascii") + """ + # Look up ascii-fx-zone primitive from environment + # It must be loaded via (require-primitives "ascii") + try: + prim_ascii_fx_zone = env.get('ascii-fx-zone') + except NameError: + raise NameError( + "ascii-fx-zone primitive not found. " + "Add (require-primitives \"ascii\") to your effect file." + ) + + # Expression parameter names that should NOT be evaluated + expr_params = {'char_hue', 'char_saturation', 'char_brightness', + 'char_scale', 'char_rotation', 'char_jitter', 'cell_effect'} + + # Parse arguments + frame = self.eval(expr[1], env) # First arg is always the frame + + # Defaults + cols = 80 + char_size = None # If set, overrides cols + alphabet = "standard" + color_mode = "color" + background = "black" + contrast = 1.5 + char_hue = None + char_saturation = None + char_brightness = None + char_scale = None + char_rotation = None + char_jitter = None + cell_effect = None # Lambda for arbitrary per-cell effects + # Convenience params for staged recipes + energy = None + rotation_scale = 0 + # Extra params to pass to zone dict for lambdas + extra_params = {} + + # Parse keyword arguments + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item): + if i + 1 >= len(expr): + break + value_expr = expr[i + 1] + kw_name = item.name + + if kw_name in expr_params: + # Resolve symbol references but don't evaluate expressions + # This handles the case where effect definition passes a param like :char_hue char_hue + resolved = value_expr + if _is_symbol(value_expr): + try: + resolved = env.get(value_expr.name) + except NameError: + resolved = value_expr # Keep as symbol if not found + + if kw_name == 'char_hue': + char_hue = resolved + elif kw_name == 'char_saturation': + char_saturation = resolved + elif kw_name == 'char_brightness': + char_brightness = resolved + elif kw_name == 'char_scale': + char_scale = resolved + elif kw_name == 'char_rotation': + char_rotation = resolved + elif kw_name == 'char_jitter': + char_jitter = resolved + elif kw_name == 'cell_effect': + cell_effect = resolved + else: + # Evaluate normally + value = self.eval(value_expr, env) + if kw_name == 'cols': + cols = int(value) + elif kw_name == 'char_size': + # Handle nil/None values + if value is None or (_is_symbol(value) and value.name == 'nil'): + char_size = None + else: + char_size = int(value) + elif kw_name == 'alphabet': + alphabet = str(value) + elif kw_name == 'color_mode': + color_mode = str(value) + elif kw_name == 'background': + background = str(value) + elif kw_name == 'contrast': + contrast = float(value) + elif kw_name == 'energy': + if value is None or (_is_symbol(value) and value.name == 'nil'): + energy = None + else: + energy = float(value) + elif kw_name == 'rotation_scale': + rotation_scale = float(value) + else: + # Store any other params for lambdas to access + extra_params[kw_name] = value + i += 2 + else: + i += 1 + + # If energy and rotation_scale provided, build rotation expression + # rotation = energy * rotation_scale * position_factor + # position_factor: bottom-left=0, top-right=3 + # Formula: 1.5 * (zone-col-norm + (1 - zone-row-norm)) + if energy is not None and rotation_scale > 0: + # Build expression as S-expression list that will be evaluated per-zone + # (* (* energy rotation_scale) (* 1.5 (+ zone-col-norm (- 1 zone-row-norm)))) + energy_times_scale = energy * rotation_scale + # The position part uses zone variables, so we build it as an expression + char_rotation = [ + Symbol('*'), + energy_times_scale, + [Symbol('*'), 1.5, + [Symbol('+'), Symbol('zone-col-norm'), + [Symbol('-'), 1, Symbol('zone-row-norm')]]] + ] + + # Pull any extra params from environment that aren't standard params + # These are typically passed from recipes for use in cell_effect lambdas + standard_params = { + 'cols', 'char_size', 'alphabet', 'color_mode', 'background', 'contrast', + 'char_hue', 'char_saturation', 'char_brightness', 'char_scale', + 'char_rotation', 'char_jitter', 'cell_effect', 'energy', 'rotation_scale', + 'frame', 't', '_time', '__state__', '__interp__', 'true', 'false', 'nil' + } + # Check environment for extra bindings + current_env = env + while current_env is not None: + for k, v in current_env.bindings.items(): + if k not in standard_params and k not in extra_params and not callable(v): + # Add non-standard, non-callable bindings to extra_params + if isinstance(v, (int, float, str, bool)) or v is None: + extra_params[k] = v + current_env = current_env.parent + + # Call the primitive with interpreter and env for expression evaluation + return prim_ascii_fx_zone( + frame, + cols=cols, + char_size=char_size, + alphabet=alphabet, + color_mode=color_mode, + background=background, + contrast=contrast, + char_hue=char_hue, + char_saturation=char_saturation, + char_brightness=char_brightness, + char_scale=char_scale, + char_rotation=char_rotation, + char_jitter=char_jitter, + cell_effect=cell_effect, + energy=energy, + rotation_scale=rotation_scale, + _interp=self, + _env=env, + **extra_params + ) + + def _define_effect(self, expr: Any, env: Environment) -> EffectDefinition: + """ + Parse effect definition. + + Required syntax: + (define-effect name + :params ( + (param1 :type int :default 8 :desc "description") + ) + body) + + Effects MUST use :params syntax. Legacy ((param default) ...) is not supported. + """ + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + + params = {} + body = None + found_params = False + + # Parse :params and body + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item) and item.name == "params": + # :params syntax + if i + 1 >= len(expr): + raise SyntaxError(f"Effect '{name}': Missing params list after :params keyword") + params_list = expr[i + 1] + params = self._parse_params_block(params_list) + found_params = True + i += 2 + elif _is_keyword(item): + # Skip other keywords (like :desc) + i += 2 + elif body is None: + # First non-keyword item is the body + if isinstance(item, list) and item: + first_elem = item[0] + # Check for legacy syntax and reject it + if isinstance(first_elem, list) and len(first_elem) >= 2: + raise SyntaxError( + f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. " + f"Use :params block instead." + ) + body = item + i += 1 + else: + i += 1 + + if body is None: + raise SyntaxError(f"Effect '{name}': No body found") + + if not found_params: + raise SyntaxError( + f"Effect '{name}': Missing :params block. " + f"For effects with no parameters, use empty :params ()" + ) + + effect = EffectDefinition(name, params, body) + self.effects[name] = effect + return effect + + def _parse_params_block(self, params_list: list) -> Dict[str, Any]: + """ + Parse :params block syntax: + ( + (param_name :type int :default 8 :range [4 32] :desc "description") + ) + """ + params = {} + for param_def in params_list: + if not isinstance(param_def, list) or len(param_def) < 1: + continue + + # First element is the parameter name + first = param_def[0] + if _is_symbol(first): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + continue + + # Parse keyword arguments + default = None + i = 1 + while i < len(param_def): + item = param_def[i] + if _is_keyword(item): + if i + 1 >= len(param_def): + break + kw_value = param_def[i + 1] + + if item.name == "default": + default = kw_value + i += 2 + else: + i += 1 + + params[param_name] = default + + return params + + def load_effect(self, path: str) -> EffectDefinition: + """Load an effect definition from a .sexp file.""" + expr = parse_file(path) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the last defined effect + if self.effects: + return list(self.effects.values())[-1] + return None + + def load_effect_from_string(self, sexp_content: str, effect_name: str = None) -> EffectDefinition: + """Load an effect definition from an S-expression string. + + Args: + sexp_content: The S-expression content as a string + effect_name: Optional name hint (used if effect doesn't define its own name) + + Returns: + The loaded EffectDefinition + """ + expr = parse(sexp_content) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the effect if we can find it by name + if effect_name and effect_name in self.effects: + return self.effects[effect_name] + + # Return the most recently loaded effect + if self.effects: + return list(self.effects.values())[-1] + + return None + + def run_effect(self, name: str, frame, params: Dict[str, Any], + state: Dict[str, Any]) -> tuple: + """ + Run an effect on frame(s). + + Args: + name: Effect name + frame: Input frame (H, W, 3) RGB uint8, or list of frames for multi-input + params: Effect parameters (overrides defaults) + state: Persistent state dict + + Returns: + (output_frame, new_state) + """ + if name not in self.effects: + raise ValueError(f"Unknown effect: {name}") + + effect = self.effects[name] + + # Create environment for this run + env = Environment(self.global_env) + + # Bind frame(s) - support both single frame and list of frames + if isinstance(frame, list): + # Multi-input effect + frames = frame + env.set('frame', frames[0] if frames else None) # Backwards compat + env.set('inputs', frames) + # Named frame bindings + for i, f in enumerate(frames): + env.set(f'frame-{chr(ord("a") + i)}', f) # frame-a, frame-b, etc. + else: + # Single-input effect + env.set('frame', frame) + + # Bind state + if state is None: + state = {} + env.set('__state__', state) + + # Validate that all provided params are known (except internal params) + # Extra params are allowed and will be passed through to cell_effect lambdas + known_params = set(effect.params.keys()) + internal_params = {'_time', 'seed', '_binding', 'effect', 'cid', 'hash', 'effect_path'} + extra_effect_params = {} # Unknown params passed through for cell_effect lambdas + for k in params.keys(): + if k not in known_params and k not in internal_params: + # Allow unknown params - they'll be passed to cell_effect lambdas via zone dict + extra_effect_params[k] = params[k] + + # Bind parameters (defaults + overrides) + for pname, pdefault in effect.params.items(): + value = params.get(pname) + if value is None: + # Evaluate default if it's an expression (list) or a symbol (like 'nil') + if isinstance(pdefault, list) or _is_symbol(pdefault): + value = self.eval(pdefault, env) + else: + value = pdefault + env.set(pname, value) + + # Bind extra params (unknown params passed through for cell_effect lambdas) + for k, v in extra_effect_params.items(): + env.set(k, v) + + # Reset RNG with seed if provided + seed = params.get('seed', 42) + reset_rng(int(seed)) + + # Bind time if provided + time_val = params.get('_time', 0) + env.set('t', time_val) + env.set('_time', time_val) + + # Evaluate body + result = self.eval(effect.body, env) + + # Ensure result is an image + if not isinstance(result, np.ndarray): + result = frame + + return result, state + + def eval_with_zone(self, expr, env: Environment, zone) -> Any: + """ + Evaluate expression with zone-* variables injected. + + Args: + expr: Expression to evaluate (S-expression) + env: Parent environment with bound values + zone: ZoneContext object with cell data + + Zone variables injected: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: + Evaluated result (typically a number) + """ + # Create child environment with zone variables + zone_env = Environment(env) + zone_env.set('zone-row', zone.row) + zone_env.set('zone-col', zone.col) + zone_env.set('zone-row-norm', zone.row_norm) + zone_env.set('zone-col-norm', zone.col_norm) + zone_env.set('zone-lum', zone.luminance) + zone_env.set('zone-sat', zone.saturation) + zone_env.set('zone-hue', zone.hue) + zone_env.set('zone-r', zone.r) + zone_env.set('zone-g', zone.g) + zone_env.set('zone-b', zone.b) + + return self.eval(expr, zone_env) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +_interpreter = None +_interpreter_minimal = None + + +def get_interpreter(minimal_primitives: bool = False) -> Interpreter: + """Get or create the global interpreter. + + Args: + minimal_primitives: If True, return interpreter with only core primitives. + Additional primitives must be loaded with require-primitives or with-primitives. + """ + global _interpreter, _interpreter_minimal + + if minimal_primitives: + if _interpreter_minimal is None: + _interpreter_minimal = Interpreter(minimal_primitives=True) + return _interpreter_minimal + else: + if _interpreter is None: + _interpreter = Interpreter(minimal_primitives=False) + return _interpreter + + +def load_effect(path: str) -> EffectDefinition: + """Load an effect from a .sexp file.""" + return get_interpreter().load_effect(path) + + +def load_effects_dir(directory: str): + """Load all .sexp effects from a directory.""" + interp = get_interpreter() + dir_path = Path(directory) + for path in dir_path.glob('*.sexp'): + try: + interp.load_effect(str(path)) + except Exception as e: + print(f"Warning: Failed to load {path}: {e}") + + +def run_effect(name: str, frame: np.ndarray, params: Dict[str, Any], + state: Dict[str, Any] = None) -> tuple: + """Run an effect.""" + return get_interpreter().run_effect(name, frame, params, state or {}) + + +def list_effects() -> List[str]: + """List loaded effect names.""" + return list(get_interpreter().effects.keys()) + + +# ============================================================================= +# Adapter for existing effect system +# ============================================================================= + +def make_process_frame(effect_path: str) -> Callable: + """ + Create a process_frame function from a .sexp effect. + + This allows S-expression effects to be used with the existing + effect system. + """ + interp = get_interpreter() + interp.load_effect(effect_path) + effect_name = Path(effect_path).stem + + def process_frame(frame: np.ndarray, params: dict, state: dict) -> tuple: + return interp.run_effect(effect_name, frame, params, state) + + return process_frame diff --git a/sexp_effects/parser.py b/sexp_effects/parser.py new file mode 100644 index 0000000..5e17565 --- /dev/null +++ b/sexp_effects/parser.py @@ -0,0 +1,396 @@ +""" +S-expression parser for ArtDAG recipes and plans. + +Supports: +- Lists: (a b c) +- Symbols: foo, bar-baz, -> +- Keywords: :key +- Strings: "hello world" +- Numbers: 42, 3.14, -1.5 +- Comments: ; to end of line +- Vectors: [a b c] (syntactic sugar for lists) +- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts) +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Union +import re + + +@dataclass +class Symbol: + """An unquoted symbol/identifier.""" + name: str + + def __repr__(self): + return f"Symbol({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Symbol): + return self.name == other.name + if isinstance(other, str): + return self.name == other + return False + + def __hash__(self): + return hash(self.name) + + +@dataclass +class Keyword: + """A keyword starting with colon.""" + name: str + + def __repr__(self): + return f"Keyword({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Keyword): + return self.name == other.name + return False + + def __hash__(self): + return hash((':' , self.name)) + + +class ParseError(Exception): + """Error during S-expression parsing.""" + def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1): + self.position = position + self.line = line + self.col = col + super().__init__(f"{message} at line {line}, column {col}") + + +class Tokenizer: + """Tokenize S-expression text into tokens.""" + + # Token patterns + WHITESPACE = re.compile(r'\s+') + COMMENT = re.compile(r';[^\n]*') + STRING = re.compile(r'"(?:[^"\\]|\\.)*"') + NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?') + KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*') + # Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations + SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*') + + def __init__(self, text: str): + self.text = text + self.pos = 0 + self.line = 1 + self.col = 1 + + def _advance(self, count: int = 1): + """Advance position, tracking line/column.""" + for _ in range(count): + if self.pos < len(self.text): + if self.text[self.pos] == '\n': + self.line += 1 + self.col = 1 + else: + self.col += 1 + self.pos += 1 + + def _skip_whitespace_and_comments(self): + """Skip whitespace and comments.""" + while self.pos < len(self.text): + # Whitespace + match = self.WHITESPACE.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + # Comments + match = self.COMMENT.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + break + + def peek(self) -> str | None: + """Peek at current character.""" + self._skip_whitespace_and_comments() + if self.pos >= len(self.text): + return None + return self.text[self.pos] + + def next_token(self) -> Any: + """Get the next token.""" + self._skip_whitespace_and_comments() + + if self.pos >= len(self.text): + return None + + char = self.text[self.pos] + start_line, start_col = self.line, self.col + + # Single-character tokens (parens, brackets, braces) + if char in '()[]{}': + self._advance() + return char + + # String + if char == '"': + match = self.STRING.match(self.text, self.pos) + if not match: + raise ParseError("Unterminated string", self.pos, self.line, self.col) + self._advance(match.end() - self.pos) + # Parse escape sequences + content = match.group()[1:-1] + content = content.replace('\\n', '\n') + content = content.replace('\\t', '\t') + content = content.replace('\\"', '"') + content = content.replace('\\\\', '\\') + return content + + # Keyword + if char == ':': + match = self.KEYWORD.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Keyword(match.group()[1:]) # Strip leading colon + raise ParseError(f"Invalid keyword", self.pos, self.line, self.col) + + # Number (must check before symbol due to - prefix) + if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and + (self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')): + match = self.NUMBER.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + num_str = match.group() + if '.' in num_str or 'e' in num_str or 'E' in num_str: + return float(num_str) + return int(num_str) + + # Symbol + match = self.SYMBOL.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Symbol(match.group()) + + raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col) + + +def parse(text: str) -> Any: + """ + Parse an S-expression string into Python data structures. + + Returns: + Parsed S-expression as nested Python structures: + - Lists become Python lists + - Symbols become Symbol objects + - Keywords become Keyword objects + - Strings become Python strings + - Numbers become int/float + + Example: + >>> parse('(recipe "test" :version "1.0")') + [Symbol('recipe'), 'test', Keyword('version'), '1.0'] + """ + tokenizer = Tokenizer(text) + result = _parse_expr(tokenizer) + + # Check for trailing content + if tokenizer.peek() is not None: + raise ParseError("Unexpected content after expression", + tokenizer.pos, tokenizer.line, tokenizer.col) + + return result + + +def parse_all(text: str) -> List[Any]: + """ + Parse multiple S-expressions from a string. + + Returns list of parsed expressions. + """ + tokenizer = Tokenizer(text) + results = [] + + while tokenizer.peek() is not None: + results.append(_parse_expr(tokenizer)) + + return results + + +def _parse_expr(tokenizer: Tokenizer) -> Any: + """Parse a single expression.""" + token = tokenizer.next_token() + + if token is None: + raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col) + + # List + if token == '(': + return _parse_list(tokenizer, ')') + + # Vector (sugar for list) + if token == '[': + return _parse_list(tokenizer, ']') + + # Map/dict: {:key1 val1 :key2 val2} + if token == '{': + return _parse_map(tokenizer) + + # Unexpected closers + if isinstance(token, str) and token in ')]}': + raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col) + + # Atom + return token + + +def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]: + """Parse a list until the closing delimiter.""" + items = [] + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError(f"Unterminated list, expected {closer!r}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == closer: + tokenizer.next_token() # Consume closer + return items + + items.append(_parse_expr(tokenizer)) + + +def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]: + """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}.""" + result = {} + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError("Unterminated map, expected '}'", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == '}': + tokenizer.next_token() # Consume closer + return result + + # Parse key (should be a keyword like :key) + key_token = _parse_expr(tokenizer) + if isinstance(key_token, Keyword): + key = key_token.name + elif isinstance(key_token, str): + key = key_token + else: + raise ParseError(f"Map key must be keyword or string, got {type(key_token).__name__}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + # Parse value + value = _parse_expr(tokenizer) + result[key] = value + + +def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str: + """ + Serialize a Python data structure back to S-expression format. + + Args: + expr: The expression to serialize + indent: Current indentation level (for pretty printing) + pretty: Whether to use pretty printing with newlines + + Returns: + S-expression string + """ + if isinstance(expr, list): + if not expr: + return "()" + + if pretty: + return _serialize_pretty(expr, indent) + else: + items = [serialize(item, indent, False) for item in expr] + return "(" + " ".join(items) + ")" + + if isinstance(expr, Symbol): + return expr.name + + if isinstance(expr, Keyword): + return f":{expr.name}" + + if isinstance(expr, str): + # Escape special characters + escaped = expr.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n').replace('\t', '\\t') + return f'"{escaped}"' + + if isinstance(expr, bool): + return "true" if expr else "false" + + if isinstance(expr, (int, float)): + return str(expr) + + if expr is None: + return "nil" + + if isinstance(expr, dict): + # Serialize dict as property list: {:key1 val1 :key2 val2} + items = [] + for k, v in expr.items(): + items.append(f":{k}") + items.append(serialize(v, indent, pretty)) + return "{" + " ".join(items) + "}" + + raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}") + + +def _serialize_pretty(expr: List, indent: int) -> str: + """Pretty-print a list expression with smart formatting.""" + if not expr: + return "()" + + prefix = " " * indent + inner_prefix = " " * (indent + 1) + + # Check if this is a simple list that fits on one line + simple = serialize(expr, indent, False) + if len(simple) < 60 and '\n' not in simple: + return simple + + # Start building multiline output + head = serialize(expr[0], indent + 1, False) + parts = [f"({head}"] + + i = 1 + while i < len(expr): + item = expr[i] + + # Group keyword-value pairs on same line + if isinstance(item, Keyword) and i + 1 < len(expr): + key = serialize(item, 0, False) + val = serialize(expr[i + 1], indent + 1, False) + + # If value is short, put on same line + if len(val) < 50 and '\n' not in val: + parts.append(f"{inner_prefix}{key} {val}") + else: + # Value is complex, serialize it pretty + val_pretty = serialize(expr[i + 1], indent + 1, True) + parts.append(f"{inner_prefix}{key} {val_pretty}") + i += 2 + else: + # Regular item + item_str = serialize(item, indent + 1, True) + parts.append(f"{inner_prefix}{item_str}") + i += 1 + + return "\n".join(parts) + ")" + + +def parse_file(path: str) -> Any: + """Parse an S-expression file (supports multiple top-level expressions).""" + with open(path, 'r') as f: + return parse_all(f.read()) + + +def to_sexp(obj: Any) -> str: + """Convert Python object back to S-expression string (alias for serialize).""" + return serialize(obj) diff --git a/sexp_effects/primitive_libs/__init__.py b/sexp_effects/primitive_libs/__init__.py new file mode 100644 index 0000000..47ee174 --- /dev/null +++ b/sexp_effects/primitive_libs/__init__.py @@ -0,0 +1,102 @@ +""" +Primitive Libraries System + +Provides modular loading of primitives. Core primitives are always available, +additional primitive libraries can be loaded on-demand with scoped availability. + +Usage in sexp: + ;; Load at recipe level - available throughout + (primitives math :path "primitive_libs/math.py") + + ;; Or use with-primitives for scoped access + (with-primitives "image" + (blur frame 3)) ;; blur only available inside + + ;; Nested scopes work + (with-primitives "math" + (with-primitives "color" + (hue-shift frame (* (sin t) 30)))) + +Library file format (primitive_libs/math.py): + import math + + def prim_sin(x): return math.sin(x) + def prim_cos(x): return math.cos(x) + + PRIMITIVES = { + 'sin': prim_sin, + 'cos': prim_cos, + } +""" + +import importlib.util +from pathlib import Path +from typing import Dict, Callable, Any, Optional + +# Cache of loaded primitive libraries +_library_cache: Dict[str, Dict[str, Any]] = {} + +# Core primitives - always available, cannot be overridden +CORE_PRIMITIVES: Dict[str, Any] = {} + + +def register_core_primitive(name: str, fn: Callable): + """Register a core primitive that's always available.""" + CORE_PRIMITIVES[name] = fn + + +def load_primitive_library(name: str, path: Optional[str] = None) -> Dict[str, Any]: + """ + Load a primitive library by name or path. + + Args: + name: Library name (e.g., "math", "image", "color") + path: Optional explicit path to library file + + Returns: + Dict of primitive name -> function + """ + # Check cache first + cache_key = path or name + if cache_key in _library_cache: + return _library_cache[cache_key] + + # Find library file + if path: + lib_path = Path(path) + else: + # Look in standard locations + lib_dir = Path(__file__).parent + lib_path = lib_dir / f"{name}.py" + + if not lib_path.exists(): + raise ValueError(f"Primitive library '{name}' not found at {lib_path}") + + if not lib_path.exists(): + raise ValueError(f"Primitive library file not found: {lib_path}") + + # Load the module + spec = importlib.util.spec_from_file_location(f"prim_lib_{name}", lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get PRIMITIVES dict from module + if not hasattr(module, 'PRIMITIVES'): + raise ValueError(f"Primitive library '{name}' missing PRIMITIVES dict") + + primitives = module.PRIMITIVES + + # Cache and return + _library_cache[cache_key] = primitives + return primitives + + +def get_library_names() -> list: + """Get names of available primitive libraries.""" + lib_dir = Path(__file__).parent + return [p.stem for p in lib_dir.glob("*.py") if p.stem != "__init__"] + + +def clear_cache(): + """Clear the library cache (useful for testing).""" + _library_cache.clear() diff --git a/sexp_effects/primitive_libs/arrays.py b/sexp_effects/primitive_libs/arrays.py new file mode 100644 index 0000000..61da196 --- /dev/null +++ b/sexp_effects/primitive_libs/arrays.py @@ -0,0 +1,196 @@ +""" +Array Primitives Library + +Vectorized operations on numpy arrays for coordinate transformations. +""" +import numpy as np + + +# Arithmetic +def prim_arr_add(a, b): + return np.add(a, b) + + +def prim_arr_sub(a, b): + return np.subtract(a, b) + + +def prim_arr_mul(a, b): + return np.multiply(a, b) + + +def prim_arr_div(a, b): + return np.divide(a, b) + + +def prim_arr_mod(a, b): + return np.mod(a, b) + + +def prim_arr_neg(a): + return np.negative(a) + + +# Math functions +def prim_arr_sin(a): + return np.sin(a) + + +def prim_arr_cos(a): + return np.cos(a) + + +def prim_arr_tan(a): + return np.tan(a) + + +def prim_arr_sqrt(a): + return np.sqrt(np.maximum(a, 0)) + + +def prim_arr_pow(a, b): + return np.power(a, b) + + +def prim_arr_abs(a): + return np.abs(a) + + +def prim_arr_exp(a): + return np.exp(a) + + +def prim_arr_log(a): + return np.log(np.maximum(a, 1e-10)) + + +def prim_arr_atan2(y, x): + return np.arctan2(y, x) + + +# Comparison / selection +def prim_arr_min(a, b): + return np.minimum(a, b) + + +def prim_arr_max(a, b): + return np.maximum(a, b) + + +def prim_arr_clip(a, lo, hi): + return np.clip(a, lo, hi) + + +def prim_arr_where(cond, a, b): + return np.where(cond, a, b) + + +def prim_arr_floor(a): + return np.floor(a) + + +def prim_arr_ceil(a): + return np.ceil(a) + + +def prim_arr_round(a): + return np.round(a) + + +# Interpolation +def prim_arr_lerp(a, b, t): + return a + (b - a) * t + + +def prim_arr_smoothstep(edge0, edge1, x): + t = prim_arr_clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +# Creation +def prim_arr_zeros(shape): + return np.zeros(shape, dtype=np.float32) + + +def prim_arr_ones(shape): + return np.ones(shape, dtype=np.float32) + + +def prim_arr_full(shape, value): + return np.full(shape, value, dtype=np.float32) + + +def prim_arr_arange(start, stop, step=1): + return np.arange(start, stop, step, dtype=np.float32) + + +def prim_arr_linspace(start, stop, num): + return np.linspace(start, stop, num, dtype=np.float32) + + +def prim_arr_meshgrid(x, y): + return np.meshgrid(x, y) + + +# Coordinate transforms +def prim_polar_from_center(map_x, map_y, cx, cy): + """Convert Cartesian to polar coordinates centered at (cx, cy).""" + dx = map_x - cx + dy = map_y - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + return (r, theta) + + +def prim_cart_from_polar(r, theta, cx, cy): + """Convert polar to Cartesian, adding center offset.""" + x = r * np.cos(theta) + cx + y = r * np.sin(theta) + cy + return (x, y) + + +PRIMITIVES = { + # Arithmetic + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-neg': prim_arr_neg, + + # Math + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-exp': prim_arr_exp, + 'arr-log': prim_arr_log, + 'arr-atan2': prim_arr_atan2, + + # Selection + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-ceil': prim_arr_ceil, + 'arr-round': prim_arr_round, + + # Interpolation + 'arr-lerp': prim_arr_lerp, + 'arr-smoothstep': prim_arr_smoothstep, + + # Creation + 'arr-zeros': prim_arr_zeros, + 'arr-ones': prim_arr_ones, + 'arr-full': prim_arr_full, + 'arr-arange': prim_arr_arange, + 'arr-linspace': prim_arr_linspace, + 'arr-meshgrid': prim_arr_meshgrid, + + # Coordinates + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, +} diff --git a/sexp_effects/primitive_libs/ascii.py b/sexp_effects/primitive_libs/ascii.py new file mode 100644 index 0000000..858f010 --- /dev/null +++ b/sexp_effects/primitive_libs/ascii.py @@ -0,0 +1,388 @@ +""" +ASCII Art Primitives Library + +ASCII art rendering with per-zone expression evaluation and cell effects. +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont +from typing import Any, Dict, List, Optional, Callable +import colorsys + + +# Character sets +CHAR_SETS = { + "standard": " .:-=+*#%@", + "blocks": " ░▒▓█", + "simple": " .:oO@", + "digits": "0123456789", + "binary": "01", + "ascii": " `.-':_,^=;><+!rc*/z?sLTv)J7(|Fi{C}fI31tlu[neoZ5Yxjya]2ESwqkP6h9d4VpOGbUAKXHm8RD#$Bg0MNWQ%&@", +} + +# Default font +_default_font = None + + +def _get_font(size: int): + """Get monospace font at given size.""" + global _default_font + try: + return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) + except: + return ImageFont.load_default() + + +def _parse_color(color_str: str) -> tuple: + """Parse color string to RGB tuple.""" + if color_str.startswith('#'): + hex_color = color_str[1:] + if len(hex_color) == 3: + hex_color = ''.join(c*2 for c in hex_color) + return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + + colors = { + 'black': (0, 0, 0), 'white': (255, 255, 255), + 'red': (255, 0, 0), 'green': (0, 255, 0), 'blue': (0, 0, 255), + 'yellow': (255, 255, 0), 'cyan': (0, 255, 255), 'magenta': (255, 0, 255), + 'gray': (128, 128, 128), 'grey': (128, 128, 128), + } + return colors.get(color_str.lower(), (0, 0, 0)) + + +def _cell_sample(frame: np.ndarray, cell_size: int): + """Sample frame into cells, returning colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + """ + h, w = frame.shape[:2] + rows = h // cell_size + cols = w // cell_size + + # Crop to exact grid then block-average via cv2 area interpolation. + cropped = frame[:rows * cell_size, :cols * cell_size] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + luminances = ((0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]) / 255.0).astype(np.float32) + + return colors, luminances + + +def _luminance_to_char(lum: float, alphabet: str, contrast: float) -> str: + """Map luminance to character.""" + chars = CHAR_SETS.get(alphabet, alphabet) + lum = ((lum - 0.5) * contrast + 0.5) + lum = max(0, min(1, lum)) + idx = int(lum * (len(chars) - 1)) + return chars[idx] + + +def _render_char_cell(char: str, cell_size: int, color: tuple, bg_color: tuple) -> np.ndarray: + """Render a single character to a cell image.""" + img = Image.new('RGB', (cell_size, cell_size), bg_color) + draw = ImageDraw.Draw(img) + font = _get_font(cell_size) + + # Center the character + bbox = draw.textbbox((0, 0), char, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + x = (cell_size - text_w) // 2 + y = (cell_size - text_h) // 2 - bbox[1] + + draw.text((x, y), char, fill=color, font=font) + return np.array(img) + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int = 80, + char_size: int = None, + alphabet: str = "standard", + color_mode: str = "color", + background: str = "black", + contrast: float = 1.5, + char_hue = None, + char_saturation = None, + char_brightness = None, + char_scale = None, + char_rotation = None, + char_jitter = None, + cell_effect = None, + energy: float = None, + rotation_scale: float = 0, + _interp = None, + _env = None, + **extra_params +) -> np.ndarray: + """ + Render frame as ASCII art with per-zone effects. + + Args: + frame: Input image + cols: Number of character columns + char_size: Cell size in pixels (overrides cols if set) + alphabet: Character set name or custom string + color_mode: "color", "mono", "invert", or color name + background: Background color name or hex + contrast: Contrast for character selection + char_hue/saturation/brightness/scale/rotation/jitter: Per-zone expressions + cell_effect: Lambda (cell, zone) -> cell for per-cell effects + energy: Energy value from audio analysis + rotation_scale: Max rotation degrees + _interp: Interpreter (auto-injected) + _env: Environment (auto-injected) + **extra_params: Additional params passed to zone dict + """ + h, w = frame.shape[:2] + + # Calculate cell size + if char_size is None or char_size == 0: + cell_size = max(4, w // cols) + else: + cell_size = max(4, int(char_size)) + + # Sample cells + colors, luminances = _cell_sample(frame, cell_size) + rows, cols_actual = luminances.shape + + # Parse background color + bg_color = _parse_color(background) + + # Create output image + out_h = rows * cell_size + out_w = cols_actual * cell_size + output = np.full((out_h, out_w, 3), bg_color, dtype=np.uint8) + + # Check if we have cell_effect + has_cell_effect = cell_effect is not None + + # Process each cell + for r in range(rows): + for c in range(cols_actual): + lum = luminances[r, c] + cell_color = tuple(colors[r, c]) + + # Build zone context + zone = { + 'row': r, + 'col': c, + 'row-norm': r / max(1, rows - 1), + 'col-norm': c / max(1, cols_actual - 1), + 'lum': float(lum), + 'r': cell_color[0] / 255, + 'g': cell_color[1] / 255, + 'b': cell_color[2] / 255, + 'cell_size': cell_size, + } + + # Add HSV + r_f, g_f, b_f = cell_color[0]/255, cell_color[1]/255, cell_color[2]/255 + hsv = colorsys.rgb_to_hsv(r_f, g_f, b_f) + zone['hue'] = hsv[0] * 360 + zone['sat'] = hsv[1] + + # Add energy and rotation_scale + if energy is not None: + zone['energy'] = energy + zone['rotation_scale'] = rotation_scale + + # Add extra params + for k, v in extra_params.items(): + if isinstance(v, (int, float, str, bool)) or v is None: + zone[k] = v + + # Get character + char = _luminance_to_char(lum, alphabet, contrast) + zone['char'] = char + + # Determine cell color based on mode + if color_mode == "mono": + render_color = (255, 255, 255) + elif color_mode == "invert": + render_color = tuple(255 - c for c in cell_color) + elif color_mode == "color": + render_color = cell_color + else: + render_color = _parse_color(color_mode) + + zone['color'] = render_color + + # Render character to cell + cell_img = _render_char_cell(char, cell_size, render_color, bg_color) + + # Apply cell_effect if provided + if has_cell_effect and _interp is not None: + cell_img = _apply_cell_effect(cell_img, zone, cell_effect, _interp, _env, extra_params) + + # Paste cell to output + y1, y2 = r * cell_size, (r + 1) * cell_size + x1, x2 = c * cell_size, (c + 1) * cell_size + output[y1:y2, x1:x2] = cell_img + + # Resize to match input dimensions + if output.shape[:2] != frame.shape[:2]: + output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LINEAR) + + return output + + +def _apply_cell_effect(cell_img, zone, cell_effect, interp, env, extra_params): + """Apply cell_effect lambda to a cell image. + + cell_effect is a Lambda object with params and body. + We create a child environment with zone variables and cell, + then evaluate the lambda body. + """ + # Get Environment class from the interpreter's module + Environment = type(env) + + # Create child environment with zone variables + cell_env = Environment(env) + + # Bind zone variables + for k, v in zone.items(): + cell_env.set(k, v) + + # Also bind with zone- prefix for consistency + cell_env.set('zone-row', zone.get('row', 0)) + cell_env.set('zone-col', zone.get('col', 0)) + cell_env.set('zone-row-norm', zone.get('row-norm', 0)) + cell_env.set('zone-col-norm', zone.get('col-norm', 0)) + cell_env.set('zone-lum', zone.get('lum', 0)) + cell_env.set('zone-sat', zone.get('sat', 0)) + cell_env.set('zone-hue', zone.get('hue', 0)) + cell_env.set('zone-r', zone.get('r', 0)) + cell_env.set('zone-g', zone.get('g', 0)) + cell_env.set('zone-b', zone.get('b', 0)) + + # Inject loaded effects as callable functions + if hasattr(interp, 'effects'): + for effect_name in interp.effects: + def make_effect_fn(name): + def effect_fn(frame, *args): + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 2: + params['offset_x'] = args[0] + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['size'] = args[0] + elif name == 'invert': + pass + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + cell_env.set(effect_name, make_effect_fn(effect_name)) + + # Bind cell image and zone dict + cell_env.set('cell', cell_img) + cell_env.set('zone', zone) + + # Evaluate the cell_effect lambda + # Lambda has params and body - we need to bind the params then evaluate + if hasattr(cell_effect, 'params') and hasattr(cell_effect, 'body'): + # Bind lambda parameters: (lambda [cell zone] body) + if len(cell_effect.params) >= 1: + cell_env.set(cell_effect.params[0], cell_img) + if len(cell_effect.params) >= 2: + cell_env.set(cell_effect.params[1], zone) + + result = interp.eval(cell_effect.body, cell_env) + elif isinstance(cell_effect, list): + # Raw S-expression lambda like (lambda [cell zone] body) or (fn [cell zone] body) + # Check if it's a lambda expression + head = cell_effect[0] if cell_effect else None + head_name = head.name if head and hasattr(head, 'name') else str(head) if head else None + is_lambda = head_name in ('lambda', 'fn') + + if is_lambda: + # (lambda [params...] body) + params = cell_effect[1] if len(cell_effect) > 1 else [] + body = cell_effect[2] if len(cell_effect) > 2 else None + + # Bind lambda parameters + if isinstance(params, list) and len(params) >= 1: + param_name = params[0].name if hasattr(params[0], 'name') else str(params[0]) + cell_env.set(param_name, cell_img) + if isinstance(params, list) and len(params) >= 2: + param_name = params[1].name if hasattr(params[1], 'name') else str(params[1]) + cell_env.set(param_name, zone) + + result = interp.eval(body, cell_env) if body else cell_img + else: + # Some other expression - just evaluate it + result = interp.eval(cell_effect, cell_env) + elif callable(cell_effect): + # It's a callable + result = cell_effect(cell_img, zone) + else: + raise ValueError(f"cell_effect must be a Lambda, list, or callable, got {type(cell_effect)}") + + if isinstance(result, np.ndarray) and result.shape == cell_img.shape: + return result + elif isinstance(result, np.ndarray): + # Shape mismatch - resize to fit + result = cv2.resize(result, (cell_img.shape[1], cell_img.shape[0])) + return result + + raise ValueError(f"cell_effect must return an image array, got {type(result)}") + + +def _get_legacy_ascii_primitives(): + """Import ASCII primitives from legacy primitives module. + + These are loaded lazily to avoid import issues during module loading. + By the time a primitive library is loaded, sexp_effects.primitives + is already in sys.modules (imported by sexp_effects.__init__). + """ + from sexp_effects.primitives import ( + prim_cell_sample, + prim_luminance_to_chars, + prim_render_char_grid, + prim_render_char_grid_fx, + prim_alphabet_char, + prim_alphabet_length, + prim_map_char_grid, + prim_map_colors, + prim_make_char_grid, + prim_set_char, + prim_get_char, + prim_char_grid_dimensions, + cell_sample_extended, + ) + return { + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + } + + +PRIMITIVES = { + 'ascii-fx-zone': prim_ascii_fx_zone, + **_get_legacy_ascii_primitives(), +} diff --git a/sexp_effects/primitive_libs/blending.py b/sexp_effects/primitive_libs/blending.py new file mode 100644 index 0000000..0bf345d --- /dev/null +++ b/sexp_effects/primitive_libs/blending.py @@ -0,0 +1,116 @@ +""" +Blending Primitives Library + +Image blending and compositing operations. +""" +import numpy as np + + +def prim_blend_images(a, b, alpha): + """Blend two images: a * (1-alpha) + b * alpha.""" + alpha = max(0.0, min(1.0, alpha)) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a, b, mode): + """Blend using Photoshop-style blend modes.""" + a = a.astype(float) / 255 + b = b.astype(float) / 255 + + if mode == "multiply": + result = a * b + elif mode == "screen": + result = 1 - (1 - a) * (1 - b) + elif mode == "overlay": + mask = a < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "soft-light": + mask = b < 0.5 + result = np.where(mask, + a - (1 - 2 * b) * a * (1 - a), + a + (2 * b - 1) * (np.sqrt(a) - a)) + elif mode == "hard-light": + mask = b < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "color-dodge": + result = np.clip(a / (1 - b + 0.001), 0, 1) + elif mode == "color-burn": + result = 1 - np.clip((1 - a) / (b + 0.001), 0, 1) + elif mode == "difference": + result = np.abs(a - b) + elif mode == "exclusion": + result = a + b - 2 * a * b + elif mode == "add": + result = np.clip(a + b, 0, 1) + elif mode == "subtract": + result = np.clip(a - b, 0, 1) + elif mode == "darken": + result = np.minimum(a, b) + elif mode == "lighten": + result = np.maximum(a, b) + else: + # Default to normal (just return b) + result = b + + return (result * 255).astype(np.uint8) + + +def prim_mask(img, mask_img): + """Apply grayscale mask to image (white=opaque, black=transparent).""" + if len(mask_img.shape) == 3: + mask = mask_img[:, :, 0].astype(float) / 255 + else: + mask = mask_img.astype(float) / 255 + + mask = mask[:, :, np.newaxis] + return (img.astype(float) * mask).astype(np.uint8) + + +def prim_alpha_composite(base, overlay, alpha_channel): + """Composite overlay onto base using alpha channel.""" + if len(alpha_channel.shape) == 3: + alpha = alpha_channel[:, :, 0].astype(float) / 255 + else: + alpha = alpha_channel.astype(float) / 255 + + alpha = alpha[:, :, np.newaxis] + result = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha + return result.astype(np.uint8) + + +def prim_overlay(base, overlay, x, y, alpha=1.0): + """Overlay image at position (x, y) with optional alpha.""" + result = base.copy() + x, y = int(x), int(y) + oh, ow = overlay.shape[:2] + bh, bw = base.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(float) * (1 - alpha) + src.astype(float) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(np.uint8) + + return result + + +PRIMITIVES = { + # Basic blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + + # Masking + 'mask': prim_mask, + 'alpha-composite': prim_alpha_composite, + + # Overlay + 'overlay': prim_overlay, +} diff --git a/sexp_effects/primitive_libs/blending_gpu.py b/sexp_effects/primitive_libs/blending_gpu.py new file mode 100644 index 0000000..c768be3 --- /dev/null +++ b/sexp_effects/primitive_libs/blending_gpu.py @@ -0,0 +1,220 @@ +""" +GPU-Accelerated Blending Primitives Library + +Uses CuPy for CUDA-accelerated image blending and compositing. +Keeps frames on GPU when STREAMING_GPU_PERSIST=1 for maximum performance. +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + GPU_AVAILABLE = True + print("[blending_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[blending_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[blending_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def _get_xp(img): + """Get the array module (numpy or cupy) for the given image.""" + if GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp + return np + + +def prim_blend_images(a, b, alpha): + """Blend two images: a * (1-alpha) + b * alpha.""" + alpha = max(0.0, min(1.0, float(alpha))) + + if GPU_AVAILABLE: + a_gpu = _to_gpu(a) + b_gpu = _to_gpu(b) + result = (a_gpu.astype(cp.float32) * (1 - alpha) + b_gpu.astype(cp.float32) * alpha).astype(cp.uint8) + return _to_cpu(result) + + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a, b, mode): + """Blend using Photoshop-style blend modes.""" + if GPU_AVAILABLE: + a_gpu = _to_gpu(a).astype(cp.float32) / 255 + b_gpu = _to_gpu(b).astype(cp.float32) / 255 + xp = cp + else: + a_gpu = a.astype(float) / 255 + b_gpu = b.astype(float) / 255 + xp = np + + if mode == "multiply": + result = a_gpu * b_gpu + elif mode == "screen": + result = 1 - (1 - a_gpu) * (1 - b_gpu) + elif mode == "overlay": + mask = a_gpu < 0.5 + result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu)) + elif mode == "soft-light": + mask = b_gpu < 0.5 + result = xp.where(mask, + a_gpu - (1 - 2 * b_gpu) * a_gpu * (1 - a_gpu), + a_gpu + (2 * b_gpu - 1) * (xp.sqrt(a_gpu) - a_gpu)) + elif mode == "hard-light": + mask = b_gpu < 0.5 + result = xp.where(mask, 2 * a_gpu * b_gpu, 1 - 2 * (1 - a_gpu) * (1 - b_gpu)) + elif mode == "color-dodge": + result = xp.clip(a_gpu / (1 - b_gpu + 0.001), 0, 1) + elif mode == "color-burn": + result = 1 - xp.clip((1 - a_gpu) / (b_gpu + 0.001), 0, 1) + elif mode == "difference": + result = xp.abs(a_gpu - b_gpu) + elif mode == "exclusion": + result = a_gpu + b_gpu - 2 * a_gpu * b_gpu + elif mode == "add": + result = xp.clip(a_gpu + b_gpu, 0, 1) + elif mode == "subtract": + result = xp.clip(a_gpu - b_gpu, 0, 1) + elif mode == "darken": + result = xp.minimum(a_gpu, b_gpu) + elif mode == "lighten": + result = xp.maximum(a_gpu, b_gpu) + else: + # Default to normal (just return b) + result = b_gpu + + result = (result * 255).astype(xp.uint8) + return _to_cpu(result) + + +def prim_mask(img, mask_img): + """Apply grayscale mask to image (white=opaque, black=transparent).""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + mask_gpu = _to_gpu(mask_img) + + if len(mask_gpu.shape) == 3: + mask = mask_gpu[:, :, 0].astype(cp.float32) / 255 + else: + mask = mask_gpu.astype(cp.float32) / 255 + + mask = mask[:, :, cp.newaxis] + result = (img_gpu.astype(cp.float32) * mask).astype(cp.uint8) + return _to_cpu(result) + + if len(mask_img.shape) == 3: + mask = mask_img[:, :, 0].astype(float) / 255 + else: + mask = mask_img.astype(float) / 255 + + mask = mask[:, :, np.newaxis] + return (img.astype(float) * mask).astype(np.uint8) + + +def prim_alpha_composite(base, overlay, alpha_channel): + """Composite overlay onto base using alpha channel.""" + if GPU_AVAILABLE: + base_gpu = _to_gpu(base) + overlay_gpu = _to_gpu(overlay) + alpha_gpu = _to_gpu(alpha_channel) + + if len(alpha_gpu.shape) == 3: + alpha = alpha_gpu[:, :, 0].astype(cp.float32) / 255 + else: + alpha = alpha_gpu.astype(cp.float32) / 255 + + alpha = alpha[:, :, cp.newaxis] + result = base_gpu.astype(cp.float32) * (1 - alpha) + overlay_gpu.astype(cp.float32) * alpha + return _to_cpu(result.astype(cp.uint8)) + + if len(alpha_channel.shape) == 3: + alpha = alpha_channel[:, :, 0].astype(float) / 255 + else: + alpha = alpha_channel.astype(float) / 255 + + alpha = alpha[:, :, np.newaxis] + result = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha + return result.astype(np.uint8) + + +def prim_overlay(base, overlay, x, y, alpha=1.0): + """Overlay image at position (x, y) with optional alpha.""" + if GPU_AVAILABLE: + base_gpu = _to_gpu(base) + overlay_gpu = _to_gpu(overlay) + result = base_gpu.copy() + + x, y = int(x), int(y) + oh, ow = overlay_gpu.shape[:2] + bh, bw = base_gpu.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay_gpu[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(cp.float32) * (1 - alpha) + src.astype(cp.float32) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(cp.uint8) + + return _to_cpu(result) + + result = base.copy() + x, y = int(x), int(y) + oh, ow = overlay.shape[:2] + bh, bw = base.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(float) * (1 - alpha) + src.astype(float) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(np.uint8) + + return result + + +PRIMITIVES = { + # Basic blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + + # Masking + 'mask': prim_mask, + 'alpha-composite': prim_alpha_composite, + + # Overlay + 'overlay': prim_overlay, +} diff --git a/sexp_effects/primitive_libs/color.py b/sexp_effects/primitive_libs/color.py new file mode 100644 index 0000000..0b6854b --- /dev/null +++ b/sexp_effects/primitive_libs/color.py @@ -0,0 +1,137 @@ +""" +Color Primitives Library + +Color manipulation: RGB, HSV, blending, luminance. +""" +import numpy as np +import colorsys + + +def prim_rgb(r, g, b): + """Create RGB color as [r, g, b] (0-255).""" + return [int(max(0, min(255, r))), + int(max(0, min(255, g))), + int(max(0, min(255, b)))] + + +def prim_red(c): + return c[0] + + +def prim_green(c): + return c[1] + + +def prim_blue(c): + return c[2] + + +def prim_luminance(c): + """Perceived luminance (0-1) using standard weights.""" + return (0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) / 255 + + +def prim_rgb_to_hsv(c): + """Convert RGB [0-255] to HSV [h:0-360, s:0-1, v:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, s, v = colorsys.rgb_to_hsv(r, g, b) + return [h * 360, s, v] + + +def prim_hsv_to_rgb(hsv): + """Convert HSV [h:0-360, s:0-1, v:0-1] to RGB [0-255].""" + h, s, v = hsv[0] / 360, hsv[1], hsv[2] + r, g, b = colorsys.hsv_to_rgb(h, s, v) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_rgb_to_hsl(c): + """Convert RGB [0-255] to HSL [h:0-360, s:0-1, l:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, l, s = colorsys.rgb_to_hls(r, g, b) + return [h * 360, s, l] + + +def prim_hsl_to_rgb(hsl): + """Convert HSL [h:0-360, s:0-1, l:0-1] to RGB [0-255].""" + h, s, l = hsl[0] / 360, hsl[1], hsl[2] + r, g, b = colorsys.hls_to_rgb(h, l, s) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_blend_color(c1, c2, alpha): + """Blend two colors: c1 * (1-alpha) + c2 * alpha.""" + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img): + """Get average color of an image.""" + mean = np.mean(img, axis=(0, 1)) + return [int(mean[0]), int(mean[1]), int(mean[2])] + + +def prim_dominant_color(img, k=1): + """Get dominant color using k-means (simplified: just average for now).""" + return prim_average_color(img) + + +def prim_invert_color(c): + """Invert a color.""" + return [255 - c[0], 255 - c[1], 255 - c[2]] + + +def prim_grayscale_color(c): + """Convert color to grayscale.""" + gray = int(0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) + return [gray, gray, gray] + + +def prim_saturate(c, amount): + """Adjust saturation of color. amount=0 is grayscale, 1 is unchanged, >1 is more saturated.""" + hsv = prim_rgb_to_hsv(c) + hsv[1] = max(0, min(1, hsv[1] * amount)) + return prim_hsv_to_rgb(hsv) + + +def prim_brighten(c, amount): + """Adjust brightness. amount=0 is black, 1 is unchanged, >1 is brighter.""" + return [int(max(0, min(255, c[i] * amount))) for i in range(3)] + + +def prim_shift_hue(c, degrees): + """Shift hue by degrees.""" + hsv = prim_rgb_to_hsv(c) + hsv[0] = (hsv[0] + degrees) % 360 + return prim_hsv_to_rgb(hsv) + + +PRIMITIVES = { + # Construction + 'rgb': prim_rgb, + + # Component access + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + + # Color space conversion + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'rgb->hsl': prim_rgb_to_hsl, + 'hsl->rgb': prim_hsl_to_rgb, + + # Blending + 'blend-color': prim_blend_color, + + # Analysis + 'average-color': prim_average_color, + 'dominant-color': prim_dominant_color, + + # Manipulation + 'invert-color': prim_invert_color, + 'grayscale-color': prim_grayscale_color, + 'saturate': prim_saturate, + 'brighten': prim_brighten, + 'shift-hue': prim_shift_hue, +} diff --git a/sexp_effects/primitive_libs/color_ops.py b/sexp_effects/primitive_libs/color_ops.py new file mode 100644 index 0000000..a0da497 --- /dev/null +++ b/sexp_effects/primitive_libs/color_ops.py @@ -0,0 +1,109 @@ +""" +Color Operations Primitives Library + +Vectorized color adjustments: brightness, contrast, saturation, invert, HSV. +These operate on entire images for fast processing. +""" +import numpy as np +import cv2 + + +def _to_numpy(img): + """Convert GPU frames or CuPy arrays to numpy for CPU processing.""" + # Handle GPUFrame objects + if hasattr(img, 'cpu'): + return img.cpu + # Handle CuPy arrays + if hasattr(img, 'get'): + return img.get() + return img + + +def prim_adjust(img, brightness=0, contrast=1): + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + img = _to_numpy(img) + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img_raw, amount): + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + img = _to_numpy(img_raw) + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img): + """Invert all pixel values.""" + img = _to_numpy(img) + return (255 - img).astype(np.uint8) + + +def prim_shift_hsv(img, h=0, s=1, v=1): + """Shift HSV: h=degrees offset, s/v=multipliers.""" + img = _to_numpy(img) + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +def prim_add_noise(img, amount): + """Add gaussian noise to image.""" + img = _to_numpy(img) + noise = np.random.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img, levels): + """Reduce to N color levels per channel.""" + img = _to_numpy(img) + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_sepia(img, intensity=1.0): + """Apply sepia tone effect.""" + img = _to_numpy(img) + sepia_matrix = np.array([ + [0.393, 0.769, 0.189], + [0.349, 0.686, 0.168], + [0.272, 0.534, 0.131] + ]) + sepia = np.dot(img, sepia_matrix.T) + result = img.astype(np.float32) * (1 - intensity) + sepia * intensity + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_grayscale(img): + """Convert to grayscale (still RGB output).""" + img = _to_numpy(img) + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + return np.stack([gray, gray, gray], axis=-1).astype(np.uint8) + + +PRIMITIVES = { + # Brightness/Contrast + 'adjust': prim_adjust, + + # Saturation + 'mix-gray': prim_mix_gray, + 'grayscale': prim_grayscale, + + # HSV manipulation + 'shift-hsv': prim_shift_hsv, + + # Inversion + 'invert-img': prim_invert_img, + + # Effects + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'sepia': prim_sepia, +} diff --git a/sexp_effects/primitive_libs/color_ops_gpu.py b/sexp_effects/primitive_libs/color_ops_gpu.py new file mode 100644 index 0000000..a4f5272 --- /dev/null +++ b/sexp_effects/primitive_libs/color_ops_gpu.py @@ -0,0 +1,280 @@ +""" +GPU-Accelerated Color Operations Library + +Uses CuPy for CUDA-accelerated color transforms. + +Performance Mode: +- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations +- This dramatically improves performance by avoiding CPU<->GPU transfers +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + GPU_AVAILABLE = True + print("[color_ops_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[color_ops_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[color_ops_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def prim_invert(img): + """Invert image colors.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(255 - img_gpu) + return 255 - img + + +def prim_grayscale(img): + """Convert to grayscale.""" + if img.ndim != 3: + return img + + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + # Standard luminance weights + gray = 0.299 * img_gpu[:, :, 0] + 0.587 * img_gpu[:, :, 1] + 0.114 * img_gpu[:, :, 2] + gray = cp.clip(gray, 0, 255).astype(cp.uint8) + # Stack to 3 channels + result = cp.stack([gray, gray, gray], axis=2) + return _to_cpu(result) + + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray = np.clip(gray, 0, 255).astype(np.uint8) + return np.stack([gray, gray, gray], axis=2) + + +def prim_brightness(img, factor=1.0): + """Adjust brightness by factor.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + result = xp.clip(img_gpu * factor, 0, 255).astype(xp.uint8) + return _to_cpu(result) + return np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8) + + +def prim_contrast(img, factor=1.0): + """Adjust contrast around midpoint.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + img_gpu = _to_gpu(img.astype(np.float32)) + result = xp.clip((img_gpu - 128) * factor + 128, 0, 255).astype(xp.uint8) + return _to_cpu(result) + return np.clip((img.astype(np.float32) - 128) * factor + 128, 0, 255).astype(np.uint8) + + +# CUDA kernel for HSV hue shift +if GPU_AVAILABLE: + _hue_shift_kernel = cp.RawKernel(r''' + extern "C" __global__ + void hue_shift(unsigned char* img, int width, int height, float shift) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int idx = (y * width + x) * 3; + + // Get RGB + float r = img[idx] / 255.0f; + float g = img[idx + 1] / 255.0f; + float b = img[idx + 2] / 255.0f; + + // RGB to HSV + float max_c = fmaxf(r, fmaxf(g, b)); + float min_c = fminf(r, fminf(g, b)); + float delta = max_c - min_c; + + float h = 0.0f, s = 0.0f, v = max_c; + + if (delta > 0.00001f) { + s = delta / max_c; + + if (max_c == r) { + h = 60.0f * fmodf((g - b) / delta, 6.0f); + } else if (max_c == g) { + h = 60.0f * ((b - r) / delta + 2.0f); + } else { + h = 60.0f * ((r - g) / delta + 4.0f); + } + + if (h < 0) h += 360.0f; + } + + // Shift hue + h = fmodf(h + shift, 360.0f); + if (h < 0) h += 360.0f; + + // HSV to RGB + float c = v * s; + float x_val = c * (1.0f - fabsf(fmodf(h / 60.0f, 2.0f) - 1.0f)); + float m = v - c; + + float r_out, g_out, b_out; + if (h < 60) { + r_out = c; g_out = x_val; b_out = 0; + } else if (h < 120) { + r_out = x_val; g_out = c; b_out = 0; + } else if (h < 180) { + r_out = 0; g_out = c; b_out = x_val; + } else if (h < 240) { + r_out = 0; g_out = x_val; b_out = c; + } else if (h < 300) { + r_out = x_val; g_out = 0; b_out = c; + } else { + r_out = c; g_out = 0; b_out = x_val; + } + + img[idx] = (unsigned char)fminf(255.0f, (r_out + m) * 255.0f); + img[idx + 1] = (unsigned char)fminf(255.0f, (g_out + m) * 255.0f); + img[idx + 2] = (unsigned char)fminf(255.0f, (b_out + m) * 255.0f); + } + ''', 'hue_shift') + + +def prim_hue_shift(img, shift=0.0): + """Shift hue by degrees.""" + if img.ndim != 3 or img.shape[2] != 3: + return img + + if not GPU_AVAILABLE: + import cv2 + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.float32) + shift / 2) % 180 + return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + h, w = img.shape[:2] + img_gpu = _to_gpu(img.astype(np.uint8)).copy() + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _hue_shift_kernel(grid, block, (img_gpu, np.int32(w), np.int32(h), np.float32(shift))) + + return _to_cpu(img_gpu) + + +def prim_saturate(img, factor=1.0): + """Adjust saturation by factor.""" + if img.ndim != 3: + return img + + if not GPU_AVAILABLE: + import cv2 + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + # GPU version - simple desaturation blend + img_gpu = _to_gpu(img.astype(np.float32)) + gray = 0.299 * img_gpu[:, :, 0] + 0.587 * img_gpu[:, :, 1] + 0.114 * img_gpu[:, :, 2] + gray = gray[:, :, cp.newaxis] + + if factor < 1.0: + # Desaturate: blend toward gray + result = img_gpu * factor + gray * (1 - factor) + else: + # Oversaturate: extrapolate away from gray + result = gray + (img_gpu - gray) * factor + + result = cp.clip(result, 0, 255).astype(cp.uint8) + return _to_cpu(result) + + +def prim_blend(img1, img2, alpha=0.5): + """Blend two images with alpha.""" + xp = cp if GPU_AVAILABLE else np + + if GPU_AVAILABLE: + img1_gpu = _to_gpu(img1.astype(np.float32)) + img2_gpu = _to_gpu(img2.astype(np.float32)) + result = img1_gpu * (1 - alpha) + img2_gpu * alpha + result = xp.clip(result, 0, 255).astype(xp.uint8) + return _to_cpu(result) + + result = img1.astype(np.float32) * (1 - alpha) + img2.astype(np.float32) * alpha + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_add(img1, img2): + """Add two images (clamped).""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + result = xp.clip(_to_gpu(img1).astype(np.int16) + _to_gpu(img2).astype(np.int16), 0, 255) + return _to_cpu(result.astype(xp.uint8)) + return np.clip(img1.astype(np.int16) + img2.astype(np.int16), 0, 255).astype(np.uint8) + + +def prim_multiply(img1, img2): + """Multiply two images (normalized).""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + result = (_to_gpu(img1).astype(np.float32) * _to_gpu(img2).astype(np.float32)) / 255.0 + result = xp.clip(result, 0, 255).astype(xp.uint8) + return _to_cpu(result) + result = (img1.astype(np.float32) * img2.astype(np.float32)) / 255.0 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_screen(img1, img2): + """Screen blend mode.""" + xp = cp if GPU_AVAILABLE else np + if GPU_AVAILABLE: + i1 = _to_gpu(img1).astype(np.float32) / 255.0 + i2 = _to_gpu(img2).astype(np.float32) / 255.0 + result = 1.0 - (1.0 - i1) * (1.0 - i2) + result = xp.clip(result * 255, 0, 255).astype(xp.uint8) + return _to_cpu(result) + i1 = img1.astype(np.float32) / 255.0 + i2 = img2.astype(np.float32) / 255.0 + result = 1.0 - (1.0 - i1) * (1.0 - i2) + return np.clip(result * 255, 0, 255).astype(np.uint8) + + +# Import CPU primitives as fallbacks +def _get_cpu_primitives(): + """Get all primitives from CPU color_ops module as fallbacks.""" + from sexp_effects.primitive_libs import color_ops + return color_ops.PRIMITIVES + + +# Export functions - start with CPU primitives, then override with GPU versions +PRIMITIVES = _get_cpu_primitives().copy() + +# Override specific primitives with GPU-accelerated versions +PRIMITIVES.update({ + 'invert': prim_invert, + 'grayscale': prim_grayscale, + 'brightness': prim_brightness, + 'contrast': prim_contrast, + 'hue-shift': prim_hue_shift, + 'saturate': prim_saturate, + 'blend': prim_blend, + 'add': prim_add, + 'multiply': prim_multiply, + 'screen': prim_screen, +}) diff --git a/sexp_effects/primitive_libs/core.py b/sexp_effects/primitive_libs/core.py new file mode 100644 index 0000000..34b580a --- /dev/null +++ b/sexp_effects/primitive_libs/core.py @@ -0,0 +1,294 @@ +""" +Core Primitives - Always available, minimal essential set. + +These are the primitives that form the foundation of the language. +They cannot be overridden by libraries. +""" + + +# Arithmetic +def prim_add(*args): + if len(args) == 0: + return 0 + result = args[0] + for arg in args[1:]: + result = result + arg + return result + + +def prim_sub(a, b=None): + if b is None: + return -a + return a - b + + +def prim_mul(*args): + if len(args) == 0: + return 1 + result = args[0] + for arg in args[1:]: + result = result * arg + return result + + +def prim_div(a, b): + return a / b + + +def prim_mod(a, b): + return a % b + + +def prim_abs(x): + return abs(x) + + +def prim_min(*args): + return min(args) + + +def prim_max(*args): + return max(args) + + +def prim_round(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.round(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.round(x) + return round(x) + + +def prim_floor(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.floor(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.floor(x) + import math + return math.floor(x) + + +def prim_ceil(x): + import numpy as np + if hasattr(x, '_data'): # Xector + from .xector import Xector + return Xector(np.ceil(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.ceil(x) + import math + return math.ceil(x) + + +# Comparison +def prim_lt(a, b): + return a < b + + +def prim_gt(a, b): + return a > b + + +def prim_le(a, b): + return a <= b + + +def prim_ge(a, b): + return a >= b + + +def prim_eq(a, b): + if isinstance(a, float) or isinstance(b, float): + return abs(a - b) < 1e-9 + return a == b + + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# Logic +def prim_not(x): + return not x + + +def prim_and(*args): + for a in args: + if not a: + return False + return True + + +def prim_or(*args): + for a in args: + if a: + return True + return False + + +# Basic data access +def prim_get(obj, key, default=None): + """Get value from dict or list.""" + if isinstance(obj, dict): + return obj.get(key, default) + elif isinstance(obj, (list, tuple)): + try: + return obj[int(key)] + except (IndexError, ValueError): + return default + return default + + +def prim_nth(seq, i): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_length(seq): + return len(seq) + + +def prim_list(*args): + return list(args) + + +# Type checking +def prim_is_number(x): + return isinstance(x, (int, float)) + + +def prim_is_string(x): + return isinstance(x, str) + + +def prim_is_list(x): + return isinstance(x, (list, tuple)) + + +def prim_is_dict(x): + return isinstance(x, dict) + + +def prim_is_nil(x): + return x is None + + +# Higher-order / iteration +def prim_reduce(seq, init, fn): + """(reduce seq init fn) — fold left: fn(fn(fn(init, s0), s1), s2) ...""" + acc = init + for item in seq: + acc = fn(acc, item) + return acc + + +def prim_map(seq, fn): + """(map seq fn) — apply fn to each element, return new list.""" + return [fn(item) for item in seq] + + +def prim_range(*args): + """(range end), (range start end), or (range start end step) — integer range.""" + if len(args) == 1: + return list(range(int(args[0]))) + elif len(args) == 2: + return list(range(int(args[0]), int(args[1]))) + elif len(args) >= 3: + return list(range(int(args[0]), int(args[1]), int(args[2]))) + return [] + + +# Random +import random +_rng = random.Random() + +def set_random_seed(seed): + """Set the random seed for deterministic output.""" + global _rng + _rng = random.Random(seed) + +def prim_rand(): + """Return random float in [0, 1).""" + return _rng.random() + +def prim_rand_int(lo, hi): + """Return random integer in [lo, hi].""" + return _rng.randint(int(lo), int(hi)) + +def prim_rand_range(lo, hi): + """Return random float in [lo, hi).""" + return lo + _rng.random() * (hi - lo) + +def prim_map_range(val, from_lo, from_hi, to_lo, to_hi): + """Map value from one range to another.""" + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + +# Core primitives dict +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + 'mod': prim_mod, + 'abs': prim_abs, + 'min': prim_min, + 'max': prim_max, + 'round': prim_round, + 'floor': prim_floor, + 'ceil': prim_ceil, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Logic + 'not': prim_not, + 'and': prim_and, + 'or': prim_or, + + # Data access + 'get': prim_get, + 'nth': prim_nth, + 'first': prim_first, + 'length': prim_length, + 'len': prim_length, + 'list': prim_list, + + # Type predicates + 'number?': prim_is_number, + 'string?': prim_is_string, + 'list?': prim_is_list, + 'dict?': prim_is_dict, + 'nil?': prim_is_nil, + 'is-nil': prim_is_nil, + + # Higher-order / iteration + 'reduce': prim_reduce, + 'fold': prim_reduce, + 'map': prim_map, + 'range': prim_range, + + # Random + 'rand': prim_rand, + 'rand-int': prim_rand_int, + 'rand-range': prim_rand_range, + 'map-range': prim_map_range, +} diff --git a/sexp_effects/primitive_libs/drawing.py b/sexp_effects/primitive_libs/drawing.py new file mode 100644 index 0000000..50e0c45 --- /dev/null +++ b/sexp_effects/primitive_libs/drawing.py @@ -0,0 +1,690 @@ +""" +Drawing Primitives Library + +Draw shapes, text, and characters on images with sophisticated text handling. + +Text Features: +- Font loading from files or system fonts +- Text measurement and fitting +- Alignment (left/center/right, top/middle/bottom) +- Opacity for fade effects +- Multi-line text support +- Shadow and outline effects +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont +import os +import glob as glob_module +from typing import Optional, Tuple, List, Union + + +# ============================================================================= +# Font Management +# ============================================================================= + +# Font cache: (path, size) -> font object +_font_cache = {} + +# Common system font directories +FONT_DIRS = [ + "/usr/share/fonts", + "/usr/local/share/fonts", + "~/.fonts", + "~/.local/share/fonts", + "/System/Library/Fonts", # macOS + "/Library/Fonts", # macOS + "C:/Windows/Fonts", # Windows +] + +# Default fonts to try (in order of preference) +DEFAULT_FONTS = [ + "DejaVuSans.ttf", + "DejaVuSansMono.ttf", + "Arial.ttf", + "Helvetica.ttf", + "FreeSans.ttf", + "LiberationSans-Regular.ttf", +] + + +def _find_font_file(name: str) -> Optional[str]: + """Find a font file by name in system directories.""" + # If it's already a full path + if os.path.isfile(name): + return name + + # Expand user paths + expanded = os.path.expanduser(name) + if os.path.isfile(expanded): + return expanded + + # Search in font directories + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + # Direct match + direct = os.path.join(font_dir, name) + if os.path.isfile(direct): + return direct + + # Recursive search + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower() == name.lower(): + return os.path.join(root, f) + # Also match without extension + base = os.path.splitext(f)[0] + if base.lower() == name.lower(): + return os.path.join(root, f) + + return None + + +def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont: + """Get a default font at the given size.""" + for font_name in DEFAULT_FONTS: + path = _find_font_file(font_name) + if path: + try: + return ImageFont.truetype(path, size) + except: + continue + + # Last resort: PIL default + return ImageFont.load_default() + + +def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont: + """ + Load a font by name or path. + + (make-font "Arial" 32) ; system font by name + (make-font "/path/to/font.ttf" 24) ; font file path + (make-font "DejaVuSans" 48) ; searches common locations + + Returns a font object for use with text primitives. + """ + size = int(size) + + # Check cache + cache_key = (name_or_path, size) + if cache_key in _font_cache: + return _font_cache[cache_key] + + # Find the font file + path = _find_font_file(name_or_path) + if not path: + raise FileNotFoundError(f"Font not found: {name_or_path}") + + # Load and cache + font = ImageFont.truetype(path, size) + _font_cache[cache_key] = font + return font + + +def prim_list_fonts() -> List[str]: + """ + List available system fonts. + + (list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...) + + Returns list of font filenames found in system directories. + """ + fonts = set() + + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower().endswith(('.ttf', '.otf', '.ttc')): + fonts.add(f) + + return sorted(fonts) + + +def prim_font_size(font: ImageFont.FreeTypeFont) -> int: + """ + Get the size of a font. + + (font-size my-font) ; -> 24 + """ + return font.size + + +# ============================================================================= +# Text Measurement +# ============================================================================= + +def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]: + """ + Measure text dimensions. + + (text-size "Hello" my-font) ; -> (width height) + (text-size "Hello" :font-size 32) ; -> (width height) with default font + + For multi-line text, returns total bounding box. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + # Create temporary image for measurement + img = Image.new('RGB', (1, 1)) + draw = ImageDraw.Draw(img) + + bbox = draw.textbbox((0, 0), str(text), font=font) + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + + return (width, height) + + +def prim_text_metrics(font=None, font_size: int = 24) -> dict: + """ + Get font metrics. + + (text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25} + + Useful for precise text layout. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + ascent, descent = font.getmetrics() + return { + 'ascent': ascent, + 'descent': descent, + 'height': ascent + descent, + 'size': font.size, + } + + +def prim_fit_text_size(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> int: + """ + Calculate font size to fit text within bounds. + + (fit-text-size "Hello World" 400 100) ; -> 48 + (fit-text-size "Title" 800 200 :font-name "Arial") + + Returns the largest font size that fits within max_width x max_height. + """ + max_width = int(max_width) + max_height = int(max_height) + min_size = int(min_size) + max_size = int(max_size) + text = str(text) + + # Binary search for optimal size + best_size = min_size + low, high = min_size, max_size + + while low <= high: + mid = (low + high) // 2 + + if font_name: + try: + font = prim_make_font(font_name, mid) + except: + font = _get_default_font(mid) + else: + font = _get_default_font(mid) + + w, h = prim_text_size(text, font) + + if w <= max_width and h <= max_height: + best_size = mid + low = mid + 1 + else: + high = mid - 1 + + return best_size + + +def prim_fit_font(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> ImageFont.FreeTypeFont: + """ + Create a font sized to fit text within bounds. + + (fit-font "Hello World" 400 100) ; -> font object + (fit-font "Title" 800 200 :font-name "Arial") + + Returns a font object at the optimal size. + """ + size = prim_fit_text_size(text, max_width, max_height, + font_name, min_size, max_size) + + if font_name: + try: + return prim_make_font(font_name, size) + except: + pass + + return _get_default_font(size) + + +# ============================================================================= +# Text Drawing +# ============================================================================= + +def prim_text(img: np.ndarray, text: str, + x: int = None, y: int = None, + width: int = None, height: int = None, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "left", valign: str = "top", + fit: bool = False, + shadow: bool = False, shadow_color=None, shadow_offset: int = 2, + outline: bool = False, outline_color=None, outline_width: int = 1, + line_spacing: float = 1.2) -> np.ndarray: + """ + Draw text with alignment, opacity, and effects. + + Basic usage: + (text frame "Hello" :x 100 :y 50) + + Centered in frame: + (text frame "Title" :align "center" :valign "middle") + + Fit to box: + (text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true) + + With fade (for animations): + (text frame "Fading" :x 100 :y 100 :opacity 0.5) + + With effects: + (text frame "Shadow" :x 100 :y 100 :shadow true) + (text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0)) + + Args: + img: Input frame + text: Text to draw + x, y: Position (if not specified, uses alignment in full frame) + width, height: Bounding box (for fit and alignment within box) + font: Font object from make-font + font_size: Size if no font specified + font_name: Font name to load + color: RGB tuple (default white) + opacity: 0.0 (invisible) to 1.0 (opaque) for fading + align: "left", "center", "right" + valign: "top", "middle", "bottom" + fit: If true, auto-size font to fit in box + shadow: Draw drop shadow + shadow_color: Shadow color (default black) + shadow_offset: Shadow offset in pixels + outline: Draw text outline + outline_color: Outline color (default black) + outline_width: Outline thickness + line_spacing: Multiplier for line height (for multi-line) + + Returns: + Frame with text drawn + """ + h, w = img.shape[:2] + text = str(text) + + # Default colors + if color is None: + color = (255, 255, 255) + else: + color = tuple(int(c) for c in color) + + if shadow_color is None: + shadow_color = (0, 0, 0) + else: + shadow_color = tuple(int(c) for c in shadow_color) + + if outline_color is None: + outline_color = (0, 0, 0) + else: + outline_color = tuple(int(c) for c in outline_color) + + # Determine bounding box + if x is None: + x = 0 + if width is None: + width = w + if y is None: + y = 0 + if height is None: + height = h + + x, y = int(x), int(y) + box_width = int(width) if width else w - x + box_height = int(height) if height else h - y + + # Get or create font + if font is None: + if fit: + font = prim_fit_font(text, box_width, box_height, font_name) + elif font_name: + try: + font = prim_make_font(font_name, int(font_size)) + except: + font = _get_default_font(int(font_size)) + else: + font = _get_default_font(int(font_size)) + + # Measure text + text_w, text_h = prim_text_size(text, font) + + # Calculate position based on alignment + if align == "center": + draw_x = x + (box_width - text_w) // 2 + elif align == "right": + draw_x = x + box_width - text_w + else: # left + draw_x = x + + if valign == "middle": + draw_y = y + (box_height - text_h) // 2 + elif valign == "bottom": + draw_y = y + box_height - text_h + else: # top + draw_y = y + + # Create RGBA image for compositing with opacity + pil_img = Image.fromarray(img).convert('RGBA') + + # Create text layer with transparency + text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(text_layer) + + # Draw shadow first (if enabled) + if shadow: + shadow_x = draw_x + shadow_offset + shadow_y = draw_y + shadow_offset + shadow_rgba = shadow_color + (int(255 * opacity * 0.5),) + draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font) + + # Draw outline (if enabled) + if outline: + outline_rgba = outline_color + (int(255 * opacity),) + ow = int(outline_width) + for dx in range(-ow, ow + 1): + for dy in range(-ow, ow + 1): + if dx != 0 or dy != 0: + draw.text((draw_x + dx, draw_y + dy), text, + fill=outline_rgba, font=font) + + # Draw main text + text_rgba = color + (int(255 * opacity),) + draw.text((draw_x, draw_y), text, fill=text_rgba, font=font) + + # Composite + result = Image.alpha_composite(pil_img, text_layer) + return np.array(result.convert('RGB')) + + +def prim_text_box(img: np.ndarray, text: str, + x: int, y: int, width: int, height: int, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "center", valign: str = "middle", + fit: bool = True, + padding: int = 0, + background=None, background_opacity: float = 0.5, + **kwargs) -> np.ndarray: + """ + Draw text fitted within a box, optionally with background. + + (text-box frame "Title" 50 50 400 100) + (text-box frame "Subtitle" 50 160 400 50 + :background (0 0 0) :background-opacity 0.7) + + Convenience wrapper around text() for common box-with-text pattern. + """ + x, y = int(x), int(y) + width, height = int(width), int(height) + padding = int(padding) + + result = img.copy() + + # Draw background if specified + if background is not None: + bg_color = tuple(int(c) for c in background) + + # Create background with opacity + pil_img = Image.fromarray(result).convert('RGBA') + bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + bg_draw = ImageDraw.Draw(bg_layer) + bg_rgba = bg_color + (int(255 * background_opacity),) + bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba) + result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB')) + + # Draw text within padded box + return prim_text(result, text, + x=x + padding, y=y + padding, + width=width - 2 * padding, height=height - 2 * padding, + font=font, font_size=font_size, font_name=font_name, + color=color, opacity=opacity, + align=align, valign=valign, fit=fit, + **kwargs) + + +# ============================================================================= +# Legacy text functions (keep for compatibility) +# ============================================================================= + +def prim_draw_char(img, char, x, y, font_size=16, color=None): + """Draw a single character at (x, y). Legacy function.""" + return prim_text(img, str(char), x=int(x), y=int(y), + font_size=int(font_size), color=color) + + +def prim_draw_text(img, text, x, y, font_size=16, color=None): + """Draw text string at (x, y). Legacy function.""" + return prim_text(img, str(text), x=int(x), y=int(y), + font_size=int(font_size), color=color) + + +# ============================================================================= +# Shape Drawing +# ============================================================================= + +def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0): + """ + Fill a rectangle with color. + + (fill-rect frame 10 10 100 50 (255 0 0)) + (fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5) + """ + if color is None: + color = [255, 255, 255] + + x, y, w, h = int(x), int(y), int(w), int(h) + + if opacity >= 1.0: + result = img.copy() + result[y:y+h, x:x+w] = color + return result + + # With opacity, use alpha compositing + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([x, y, x + w, y + h], fill=fill_rgba) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0): + """Draw rectangle outline.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([int(x), int(y), int(x+w), int(y+h)], + outline=outline_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0): + """Draw a line from (x1, y1) to (x2, y2).""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.line([(int(x1), int(y1)), (int(x2), int(y2))], + fill=line_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, + fill=False, opacity: float = 1.0): + """Draw a circle.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.circle(result, (int(cx), int(cy)), int(radius), + tuple(int(c) for c in color), t) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + cx, cy, r = int(cx), int(cy), int(radius) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + + if fill: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba) + else: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], + outline=rgba, width=int(thickness)) + + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, + thickness=1, fill=False, opacity: float = 1.0): + """Draw an ellipse.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, tuple(int(c) for c in color), t) + return result + + # With opacity (note: PIL doesn't support rotated ellipses easily) + # Fall back to cv2 on a separate layer + layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) + t = -1 if fill else int(thickness) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, rgba, t) + + pil_img = Image.fromarray(img).convert('RGBA') + pil_layer = Image.fromarray(layer) + result = Image.alpha_composite(pil_img, pil_layer) + return np.array(result.convert('RGB')) + + +def prim_draw_polygon(img, points, color=None, thickness=1, + fill=False, opacity: float = 1.0): + """Draw a polygon from list of [x, y] points.""" + if color is None: + color = [255, 255, 255] + + if opacity >= 1.0: + result = img.copy() + pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + if fill: + cv2.fillPoly(result, [pts], tuple(int(c) for c in color)) + else: + cv2.polylines(result, [pts], True, + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + + pts_flat = [(int(p[0]), int(p[1])) for p in points] + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + + if fill: + draw.polygon(pts_flat, fill=rgba) + else: + draw.polygon(pts_flat, outline=rgba, width=int(thickness)) + + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) + + +# ============================================================================= +# PRIMITIVES Export +# ============================================================================= + +PRIMITIVES = { + # Font management + 'make-font': prim_make_font, + 'list-fonts': prim_list_fonts, + 'font-size': prim_font_size, + + # Text measurement + 'text-size': prim_text_size, + 'text-metrics': prim_text_metrics, + 'fit-text-size': prim_fit_text_size, + 'fit-font': prim_fit_font, + + # Text drawing + 'text': prim_text, + 'text-box': prim_text_box, + + # Legacy text (compatibility) + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + + # Rectangles + 'fill-rect': prim_fill_rect, + 'draw-rect': prim_draw_rect, + + # Lines and shapes + 'draw-line': prim_draw_line, + 'draw-circle': prim_draw_circle, + 'draw-ellipse': prim_draw_ellipse, + 'draw-polygon': prim_draw_polygon, +} diff --git a/sexp_effects/primitive_libs/filters.py b/sexp_effects/primitive_libs/filters.py new file mode 100644 index 0000000..a66f107 --- /dev/null +++ b/sexp_effects/primitive_libs/filters.py @@ -0,0 +1,119 @@ +""" +Filters Primitives Library + +Image filters: blur, sharpen, edges, convolution. +""" +import numpy as np +import cv2 + + +def prim_blur(img, radius): + """Gaussian blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img, radius): + """Box blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_median_blur(img, radius): + """Median blur (good for noise removal).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.medianBlur(img, ksize) + + +def prim_bilateral(img, d=9, sigma_color=75, sigma_space=75): + """Bilateral filter (edge-preserving blur).""" + return cv2.bilateralFilter(img, d, sigma_color, sigma_space) + + +def prim_sharpen(img, amount=1.0): + """Sharpen image using unsharp mask.""" + blurred = cv2.GaussianBlur(img, (0, 0), 3) + return cv2.addWeighted(img, 1.0 + amount, blurred, -amount, 0) + + +def prim_edges(img, low=50, high=150): + """Canny edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, low, high) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img, ksize=3): + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=ksize) + sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=ksize) + mag = np.sqrt(sobelx**2 + sobely**2) + mag = np.clip(mag, 0, 255).astype(np.uint8) + return cv2.cvtColor(mag, cv2.COLOR_GRAY2RGB) + + +def prim_laplacian(img, ksize=3): + """Laplacian edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + lap = cv2.Laplacian(gray, cv2.CV_64F, ksize=ksize) + lap = np.abs(lap) + lap = np.clip(lap, 0, 255).astype(np.uint8) + return cv2.cvtColor(lap, cv2.COLOR_GRAY2RGB) + + +def prim_emboss(img): + """Emboss effect.""" + kernel = np.array([[-2, -1, 0], + [-1, 1, 1], + [ 0, 1, 2]]) + result = cv2.filter2D(img, -1, kernel) + return np.clip(result + 128, 0, 255).astype(np.uint8) + + +def prim_dilate(img, size=1): + """Morphological dilation.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.dilate(img, kernel) + + +def prim_erode(img, size=1): + """Morphological erosion.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.erode(img, kernel) + + +def prim_convolve(img, kernel): + """Apply custom convolution kernel.""" + kernel = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, kernel) + + +PRIMITIVES = { + # Blur + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'median-blur': prim_median_blur, + 'bilateral': prim_bilateral, + + # Sharpen + 'sharpen': prim_sharpen, + + # Edges + 'edges': prim_edges, + 'sobel': prim_sobel, + 'laplacian': prim_laplacian, + + # Effects + 'emboss': prim_emboss, + + # Morphology + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Custom + 'convolve': prim_convolve, +} diff --git a/sexp_effects/primitive_libs/geometry.py b/sexp_effects/primitive_libs/geometry.py new file mode 100644 index 0000000..5b385a4 --- /dev/null +++ b/sexp_effects/primitive_libs/geometry.py @@ -0,0 +1,143 @@ +""" +Geometry Primitives Library + +Geometric transforms: rotate, scale, flip, translate, remap. +""" +import numpy as np +import cv2 + + +def prim_translate(img, dx, dy): + """Translate image by (dx, dy) pixels.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_rotate(img, angle, cx=None, cy=None): + """Rotate image by angle degrees around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_scale(img, sx, sy, cx=None, cy=None): + """Scale image by (sx, sy) around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Build transform matrix + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_flip_h(img): + """Flip image horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img): + """Flip image vertically.""" + return cv2.flip(img, 0) + + +def prim_flip(img, direction="horizontal"): + """Flip image in given direction.""" + if direction in ("horizontal", "h"): + return prim_flip_h(img) + elif direction in ("vertical", "v"): + return prim_flip_v(img) + elif direction in ("both", "hv", "vh"): + return cv2.flip(img, -1) + return img + + +def prim_transpose(img): + """Transpose image (swap x and y).""" + return np.transpose(img, (1, 0, 2)) + + +def prim_remap(img, map_x, map_y): + """Remap image using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), + map_y.astype(np.float32), + cv2.INTER_LINEAR) + + +def prim_make_coords(w, h): + """Create coordinate grids for remapping.""" + x = np.arange(w, dtype=np.float32) + y = np.arange(h, dtype=np.float32) + map_x, map_y = np.meshgrid(x, y) + return (map_x, map_y) + + +def prim_perspective(img, src_pts, dst_pts): + """Apply perspective transform.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getPerspectiveTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpPerspective(img, M, (w, h)) + + +def prim_affine(img, src_pts, dst_pts): + """Apply affine transform using 3 point pairs.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getAffineTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpAffine(img, M, (w, h)) + + +def _get_legacy_geometry_primitives(): + """Import geometry primitives from legacy primitives module.""" + from sexp_effects.primitives import ( + prim_coords_x, + prim_coords_y, + prim_ripple_displace, + prim_fisheye_displace, + prim_kaleidoscope_displace, + ) + return { + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'ripple-displace': prim_ripple_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + } + + +PRIMITIVES = { + # Basic transforms + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + + # Flips + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'flip': prim_flip, + 'transpose': prim_transpose, + + # Remapping + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Advanced transforms + 'perspective': prim_perspective, + 'affine': prim_affine, + + # Displace / coordinate ops (from legacy primitives) + **_get_legacy_geometry_primitives(), +} diff --git a/sexp_effects/primitive_libs/geometry_gpu.py b/sexp_effects/primitive_libs/geometry_gpu.py new file mode 100644 index 0000000..d4e3193 --- /dev/null +++ b/sexp_effects/primitive_libs/geometry_gpu.py @@ -0,0 +1,403 @@ +""" +GPU-Accelerated Geometry Primitives Library + +Uses CuPy for CUDA-accelerated image transforms. +Falls back to CPU if GPU unavailable. + +Performance Mode: +- Set STREAMING_GPU_PERSIST=1 to keep frames on GPU between operations +- This dramatically improves performance by avoiding CPU<->GPU transfers +- Frames only transfer to CPU at final output +""" +import os +import numpy as np + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + from cupyx.scipy import ndimage as cpndimage + GPU_AVAILABLE = True + print("[geometry_gpu] CuPy GPU acceleration enabled") +except ImportError: + cp = np + GPU_AVAILABLE = False + print("[geometry_gpu] CuPy not available, using CPU fallback") + +# GPU persistence mode - keep frames on GPU between operations +# Set STREAMING_GPU_PERSIST=1 for maximum performance +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" +if GPU_AVAILABLE and GPU_PERSIST: + print("[geometry_gpu] GPU persistence enabled - frames stay on GPU") + + +def _to_gpu(img): + """Move image to GPU if available.""" + if GPU_AVAILABLE and not isinstance(img, cp.ndarray): + return cp.asarray(img) + return img + + +def _to_cpu(img): + """Move image back to CPU (only if GPU_PERSIST is disabled).""" + if not GPU_PERSIST and GPU_AVAILABLE and isinstance(img, cp.ndarray): + return cp.asnumpy(img) + return img + + +def _ensure_output_format(img): + """Ensure output is in correct format based on GPU_PERSIST setting.""" + return _to_cpu(img) + + +def prim_rotate(img, angle, cx=None, cy=None): + """Rotate image by angle degrees around center (cx, cy). + + Uses fast CUDA kernel when available (< 1ms vs 20ms for scipy). + """ + if not GPU_AVAILABLE: + # Fallback to OpenCV + import cv2 + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + # Use fast CUDA kernel (prim_rotate_gpu defined below) + return prim_rotate_gpu(img, angle, cx, cy) + + +def prim_scale(img, sx, sy, cx=None, cy=None): + """Scale image by (sx, sy) around center (cx, cy).""" + if not GPU_AVAILABLE: + import cv2 + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h)) + + img_gpu = _to_gpu(img) + h, w = img_gpu.shape[:2] + + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Use cupyx.scipy.ndimage.zoom + if img_gpu.ndim == 3: + zoom_factors = (sy, sx, 1) # Don't zoom color channels + else: + zoom_factors = (sy, sx) + + zoomed = cpndimage.zoom(img_gpu, zoom_factors, order=1) + + # Crop/pad to original size + zh, zw = zoomed.shape[:2] + result = cp.zeros_like(img_gpu) + + # Calculate offsets + src_y = max(0, (zh - h) // 2) + src_x = max(0, (zw - w) // 2) + dst_y = max(0, (h - zh) // 2) + dst_x = max(0, (w - zw) // 2) + + copy_h = min(h - dst_y, zh - src_y) + copy_w = min(w - dst_x, zw - src_x) + + result[dst_y:dst_y+copy_h, dst_x:dst_x+copy_w] = zoomed[src_y:src_y+copy_h, src_x:src_x+copy_w] + + return _to_cpu(result) + + +def prim_translate(img, dx, dy): + """Translate image by (dx, dy) pixels.""" + if not GPU_AVAILABLE: + import cv2 + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h)) + + img_gpu = _to_gpu(img) + # Use cupyx.scipy.ndimage.shift + if img_gpu.ndim == 3: + shift = (dy, dx, 0) # Don't shift color channels + else: + shift = (dy, dx) + + shifted = cpndimage.shift(img_gpu, shift, order=1) + return _to_cpu(shifted) + + +def prim_flip_h(img): + """Flip image horizontally.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(img_gpu, axis=1)) + return np.flip(img, axis=1) + + +def prim_flip_v(img): + """Flip image vertically.""" + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(img_gpu, axis=0)) + return np.flip(img, axis=0) + + +def prim_flip(img, direction="horizontal"): + """Flip image in given direction.""" + if direction in ("horizontal", "h"): + return prim_flip_h(img) + elif direction in ("vertical", "v"): + return prim_flip_v(img) + elif direction in ("both", "hv", "vh"): + if GPU_AVAILABLE: + img_gpu = _to_gpu(img) + return _to_cpu(cp.flip(cp.flip(img_gpu, axis=0), axis=1)) + return np.flip(np.flip(img, axis=0), axis=1) + return img + + +# CUDA kernel for ripple effect +if GPU_AVAILABLE: + _ripple_kernel = cp.RawKernel(r''' + extern "C" __global__ + void ripple(const unsigned char* src, unsigned char* dst, + int width, int height, int channels, + float amplitude, float frequency, float decay, + float speed, float time, float cx, float cy) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Distance from center + float dx = x - cx; + float dy = y - cy; + float dist = sqrtf(dx * dx + dy * dy); + + // Ripple displacement + float wave = sinf(dist * frequency * 0.1f - time * speed) * amplitude; + float falloff = expf(-dist * decay * 0.01f); + float displacement = wave * falloff; + + // Direction from center + float len = dist + 0.0001f; // Avoid division by zero + float dir_x = dx / len; + float dir_y = dy / len; + + // Source coordinates + float src_x = x - dir_x * displacement; + float src_y = y - dir_y * displacement; + + // Clamp to bounds + src_x = fmaxf(0.0f, fminf(width - 1.0f, src_x)); + src_y = fmaxf(0.0f, fminf(height - 1.0f, src_y)); + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = min(x0 + 1, width - 1); + int y1 = min(y0 + 1, height - 1); + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float val = v0 * (1 - fy) + v1 * fy; + + dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val)); + } + } + ''', 'ripple') + + +def prim_ripple(img, amplitude=10.0, frequency=8.0, decay=2.0, speed=5.0, + time=0.0, center_x=None, center_y=None): + """Apply ripple distortion effect.""" + h, w = img.shape[:2] + channels = img.shape[2] if img.ndim == 3 else 1 + + if center_x is None: + center_x = w / 2 + if center_y is None: + center_y = h / 2 + + if not GPU_AVAILABLE: + # CPU fallback using coordinate mapping + import cv2 + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx**2 + dy**2) + + wave = np.sin(dist * frequency * 0.1 - time * speed) * amplitude + falloff = np.exp(-dist * decay * 0.01) + displacement = wave * falloff + + length = dist + 0.0001 + dir_x = dx / length + dir_y = dy / length + + map_x = (x_coords - dir_x * displacement).astype(np.float32) + map_y = (y_coords - dir_y * displacement).astype(np.float32) + + return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR) + + # GPU implementation + img_gpu = _to_gpu(img.astype(np.uint8)) + if img_gpu.ndim == 2: + img_gpu = img_gpu[:, :, cp.newaxis] + channels = 1 + + dst = cp.zeros_like(img_gpu) + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _ripple_kernel(grid, block, ( + img_gpu, dst, + np.int32(w), np.int32(h), np.int32(channels), + np.float32(amplitude), np.float32(frequency), np.float32(decay), + np.float32(speed), np.float32(time), + np.float32(center_x), np.float32(center_y) + )) + + result = _to_cpu(dst) + if channels == 1: + result = result[:, :, 0] + return result + + +# CUDA kernel for fast rotation with bilinear interpolation +if GPU_AVAILABLE: + _rotate_kernel = cp.RawKernel(r''' + extern "C" __global__ + void rotate_img(const unsigned char* src, unsigned char* dst, + int width, int height, int channels, + float cos_a, float sin_a, float cx, float cy) { + int x = blockDim.x * blockIdx.x + threadIdx.x; + int y = blockDim.y * blockIdx.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Translate to center, rotate, translate back + float dx = x - cx; + float dy = y - cy; + + float src_x = cos_a * dx + sin_a * dy + cx; + float src_y = -sin_a * dx + cos_a * dy + cy; + + // Check bounds + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[(y * width + x) * channels + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = x0 + 1; + int y1 = y0 + 1; + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float val = v0 * (1 - fy) + v1 * fy; + + dst[(y * width + x) * channels + c] = (unsigned char)fminf(255.0f, fmaxf(0.0f, val)); + } + } + ''', 'rotate_img') + + +def prim_rotate_gpu(img, angle, cx=None, cy=None): + """Fast GPU rotation using custom CUDA kernel.""" + if not GPU_AVAILABLE: + return prim_rotate(img, angle, cx, cy) + + h, w = img.shape[:2] + channels = img.shape[2] if img.ndim == 3 else 1 + + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + img_gpu = _to_gpu(img.astype(np.uint8)) + if img_gpu.ndim == 2: + img_gpu = img_gpu[:, :, cp.newaxis] + channels = 1 + + dst = cp.zeros_like(img_gpu) + + # Convert angle to radians + rad = np.radians(angle) + cos_a = np.cos(rad) + sin_a = np.sin(rad) + + block = (16, 16) + grid = ((w + block[0] - 1) // block[0], (h + block[1] - 1) // block[1]) + + _rotate_kernel(grid, block, ( + img_gpu, dst, + np.int32(w), np.int32(h), np.int32(channels), + np.float32(cos_a), np.float32(sin_a), + np.float32(cx), np.float32(cy) + )) + + result = _to_cpu(dst) + if channels == 1: + result = result[:, :, 0] + return result + + +# Import CPU primitives as fallbacks for functions we don't GPU-accelerate +def _get_cpu_primitives(): + """Get all primitives from CPU geometry module as fallbacks.""" + from sexp_effects.primitive_libs import geometry + return geometry.PRIMITIVES + + +# Export functions - start with CPU primitives, then override with GPU versions +PRIMITIVES = _get_cpu_primitives().copy() + +# Override specific primitives with GPU-accelerated versions +PRIMITIVES.update({ + 'translate': prim_translate, + 'rotate': prim_rotate_gpu if GPU_AVAILABLE else prim_rotate, # Fast CUDA kernel + 'rotate-img': prim_rotate_gpu if GPU_AVAILABLE else prim_rotate, # Alias + 'scale-img': prim_scale, + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'flip': prim_flip, + 'ripple': prim_ripple, # Fast CUDA kernel + # Note: ripple-displace uses CPU version (different API - returns coords, not image) +}) diff --git a/sexp_effects/primitive_libs/image.py b/sexp_effects/primitive_libs/image.py new file mode 100644 index 0000000..2ab922c --- /dev/null +++ b/sexp_effects/primitive_libs/image.py @@ -0,0 +1,150 @@ +""" +Image Primitives Library + +Basic image operations: dimensions, pixels, resize, crop, paste. +""" +import numpy as np +import cv2 + + +def prim_width(img): + if isinstance(img, (list, tuple)): + raise TypeError(f"image:width expects an image array, got {type(img).__name__} with {len(img)} elements") + return img.shape[1] + + +def prim_height(img): + if isinstance(img, (list, tuple)): + import sys + print(f"DEBUG image:height got list: {img[:3]}... (types: {[type(x).__name__ for x in img[:3]]})", file=sys.stderr) + raise TypeError(f"image:height expects an image array, got {type(img).__name__} with {len(img)} elements: {img}") + return img.shape[0] + + +def prim_make_image(w, h, color=None): + """Create a new image filled with color (default black).""" + if color is None: + color = [0, 0, 0] + img = np.zeros((h, w, 3), dtype=np.uint8) + img[:] = color + return img + + +def prim_copy(img): + return img.copy() + + +def prim_pixel(img, x, y): + """Get pixel color at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + if 0 <= x < w and 0 <= y < h: + return list(img[int(y), int(x)]) + return [0, 0, 0] + + +def prim_set_pixel(img, x, y, color): + """Set pixel at (x, y) to color, returns modified image.""" + result = img.copy() + h, w = result.shape[:2] + if 0 <= x < w and 0 <= y < h: + result[int(y), int(x)] = color + return result + + +def prim_sample(img, x, y): + """Bilinear sample at float coordinates, returns [r, g, b] as floats.""" + h, w = img.shape[:2] + x = max(0, min(w - 1.001, x)) + y = max(0, min(h - 1.001, y)) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + top = c00 * (1 - fx) + c10 * fx + bottom = c01 * (1 - fx) + c11 * fx + return list(top * (1 - fy) + bottom * fy) + + +def prim_channel(img, c): + """Extract single channel (0=R, 1=G, 2=B).""" + return img[:, :, c] + + +def prim_merge_channels(r, g, b): + """Merge three single-channel arrays into RGB image.""" + return np.stack([r, g, b], axis=2).astype(np.uint8) + + +def prim_resize(img, w, h, mode="linear"): + """Resize image to w x h.""" + interp = cv2.INTER_LINEAR + if mode == "nearest": + interp = cv2.INTER_NEAREST + elif mode == "cubic": + interp = cv2.INTER_CUBIC + elif mode == "area": + interp = cv2.INTER_AREA + return cv2.resize(img, (int(w), int(h)), interpolation=interp) + + +def prim_crop(img, x, y, w, h): + """Crop rectangle from image.""" + x, y, w, h = int(x), int(y), int(w), int(h) + ih, iw = img.shape[:2] + x = max(0, min(x, iw - 1)) + y = max(0, min(y, ih - 1)) + w = min(w, iw - x) + h = min(h, ih - y) + return img[y:y+h, x:x+w].copy() + + +def prim_paste(dst, src, x, y): + """Paste src onto dst at position (x, y).""" + result = dst.copy() + x, y = int(x), int(y) + sh, sw = src.shape[:2] + dh, dw = dst.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + if sx2 > sx1 and sy2 > sy1: + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = src[sy1:sy2, sx1:sx2] + + return result + + +PRIMITIVES = { + # Dimensions + 'width': prim_width, + 'height': prim_height, + + # Creation + 'make-image': prim_make_image, + 'copy': prim_copy, + + # Pixel access + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + + # Channels + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + + # Geometry + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, +} diff --git a/sexp_effects/primitive_libs/math.py b/sexp_effects/primitive_libs/math.py new file mode 100644 index 0000000..140ad3e --- /dev/null +++ b/sexp_effects/primitive_libs/math.py @@ -0,0 +1,164 @@ +""" +Math Primitives Library + +Trigonometry, rounding, clamping, random numbers, etc. +""" +import math +import random as rand_module + + +def prim_sin(x): + return math.sin(x) + + +def prim_cos(x): + return math.cos(x) + + +def prim_tan(x): + return math.tan(x) + + +def prim_asin(x): + return math.asin(x) + + +def prim_acos(x): + return math.acos(x) + + +def prim_atan(x): + return math.atan(x) + + +def prim_atan2(y, x): + return math.atan2(y, x) + + +def prim_sqrt(x): + return math.sqrt(x) + + +def prim_pow(x, y): + return math.pow(x, y) + + +def prim_exp(x): + return math.exp(x) + + +def prim_log(x, base=None): + if base is None: + return math.log(x) + return math.log(x, base) + + +def prim_abs(x): + return abs(x) + + +def prim_floor(x): + return math.floor(x) + + +def prim_ceil(x): + return math.ceil(x) + + +def prim_round(x): + return round(x) + + +def prim_min(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return min(args[0]) + return min(args) + + +def prim_max(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return max(args[0]) + return max(args) + + +def prim_clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +def prim_lerp(a, b, t): + """Linear interpolation: a + (b - a) * t""" + return a + (b - a) * t + + +def prim_smoothstep(edge0, edge1, x): + """Smooth interpolation between 0 and 1.""" + t = prim_clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +def prim_random(lo=0.0, hi=1.0): + return rand_module.uniform(lo, hi) + + +def prim_randint(lo, hi): + return rand_module.randint(lo, hi) + + +def prim_gaussian(mean=0.0, std=1.0): + return rand_module.gauss(mean, std) + + +def prim_sign(x): + if x > 0: + return 1 + elif x < 0: + return -1 + return 0 + + +def prim_fract(x): + """Fractional part of x.""" + return x - math.floor(x) + + +PRIMITIVES = { + # Trigonometry + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'asin': prim_asin, + 'acos': prim_acos, + 'atan': prim_atan, + 'atan2': prim_atan2, + + # Powers and roots + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'exp': prim_exp, + 'log': prim_log, + + # Rounding + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'sign': prim_sign, + 'fract': prim_fract, + + # Min/max/clamp + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'smoothstep': prim_smoothstep, + + # Random + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + + # Constants + 'pi': math.pi, + 'tau': math.tau, + 'e': math.e, +} diff --git a/sexp_effects/primitive_libs/streaming.py b/sexp_effects/primitive_libs/streaming.py new file mode 100644 index 0000000..ccb6056 --- /dev/null +++ b/sexp_effects/primitive_libs/streaming.py @@ -0,0 +1,593 @@ +""" +Streaming primitives for video/audio processing. + +These primitives handle video source reading and audio analysis, +keeping the interpreter completely generic. + +GPU Acceleration: +- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU) +- Hardware video decoding (NVDEC) is used when available +- Dramatically improves performance on GPU nodes + +Async Prefetching: +- Set STREAMING_PREFETCH=1 to enable background frame prefetching +- Decodes upcoming frames while current frame is being processed +""" + +import os +import numpy as np +import subprocess +import json +import threading +from collections import deque +from pathlib import Path + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + +# GPU persistence mode - output CuPy arrays instead of numpy +# Disabled by default until all primitives support GPU frames +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE + +# Async prefetch mode - decode frames in background thread +PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1" +PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10")) + +# Check for hardware decode support (cached) +_HWDEC_AVAILABLE = None + + +def _check_hwdec(): + """Check if NVIDIA hardware decode is available.""" + global _HWDEC_AVAILABLE + if _HWDEC_AVAILABLE is not None: + return _HWDEC_AVAILABLE + + try: + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2) + if result.returncode != 0: + _HWDEC_AVAILABLE = False + return False + result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5) + _HWDEC_AVAILABLE = "cuda" in result.stdout + except Exception: + _HWDEC_AVAILABLE = False + + return _HWDEC_AVAILABLE + + +class VideoSource: + """Video source with persistent streaming pipe for fast sequential reads.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + self.fps = fps # Output fps for the stream + self._frame_size = None + self._duration = None + self._proc = None # Persistent ffmpeg process + self._stream_time = 0.0 # Current position in stream + self._frame_time = 1.0 / fps # Time per frame at output fps + self._last_read_time = -1 + self._cached_frame = None + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + # Get video info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}") + try: + info = json.loads(result.stdout) + except json.JSONDecodeError: + raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}") + + for stream in info.get("streams", []): + if stream.get("codec_type") == "video": + self._frame_size = (stream.get("width", 720), stream.get("height", 720)) + # Try direct duration field first + if "duration" in stream: + self._duration = float(stream["duration"]) + # Fall back to tags.DURATION (webm format: "00:01:00.124000000") + elif "tags" in stream and "DURATION" in stream["tags"]: + dur_str = stream["tags"]["DURATION"] + parts = dur_str.split(":") + if len(parts) == 3: + h, m, s = parts + self._duration = int(h) * 3600 + int(m) * 60 + float(s) + break + + # Fallback: check format duration if stream duration not found + if self._duration is None and "format" in info and "duration" in info["format"]: + self._duration = float(info["format"]["duration"]) + + if not self._frame_size: + self._frame_size = (720, 720) + + import sys + print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr) + + def _start_stream(self, seek_time: float = 0): + """Start or restart the ffmpeg streaming process. + + Uses NVIDIA hardware decoding (NVDEC) when available for better performance. + """ + if self._proc: + self._proc.kill() + self._proc = None + + # Check file exists before trying to open + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + w, h = self._frame_size + + # Build ffmpeg command with optional hardware decode + cmd = ["ffmpeg", "-v", "error"] + + # Use hardware decode if available (significantly faster) + if _check_hwdec(): + cmd.extend(["-hwaccel", "cuda"]) + + cmd.extend([ + "-ss", f"{seek_time:.3f}", + "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(self.fps), # Output at specified fps + "-" + ]) + + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._stream_time = seek_time + + # Check if process started successfully by reading first bit of stderr + import select + import sys + readable, _, _ = select.select([self._proc.stderr], [], [], 0.5) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr) + + def _read_frame_from_stream(self): + """Read one frame from the stream. + + Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise. + """ + w, h = self._frame_size + frame_size = w * h * 3 + + if not self._proc or self._proc.poll() is not None: + return None + + data = self._proc.stdout.read(frame_size) + if len(data) < frame_size: + return None + + frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() + + # Transfer to GPU if persistence mode enabled + if GPU_PERSIST: + return cp.asarray(frame) + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + if self._cached_frame is not None: + return self._cached_frame + return self.read_at(0) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time using streaming with smart seeking.""" + # Cache check - return same frame for same time + if t == self._last_read_time and self._cached_frame is not None: + return self._cached_frame + + w, h = self._frame_size + + # Loop time if video is shorter + seek_time = t + if self._duration and self._duration > 0: + seek_time = t % self._duration + # If we're within 0.1s of the end, wrap to beginning to avoid EOF issues + if seek_time > self._duration - 0.1: + seek_time = 0.0 + + # Decide whether to seek or continue streaming + # Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead + # Allow small backward tolerance to handle floating point and timing jitter + need_seek = ( + self._proc is None or + self._proc.poll() is not None or + seek_time < self._stream_time - self._frame_time or # More than 1 frame backward + seek_time > self._stream_time + 2.0 + ) + + if need_seek: + import sys + reason = "no proc" if self._proc is None else "proc dead" if self._proc.poll() is not None else "backward" if seek_time < self._stream_time else "jump" + print(f"SEEK {self.path.name}: t={t:.4f} seek={seek_time:.4f} stream={self._stream_time:.4f} ({reason})", file=sys.stderr) + self._start_stream(seek_time) + + # Skip frames to reach target time + skip_retries = 0 + while self._stream_time + self._frame_time <= seek_time: + frame = self._read_frame_from_stream() + if frame is None: + # Stream ended or failed - restart from seek point + import time + skip_retries += 1 + if skip_retries > 3: + # Give up skipping, just start fresh at seek_time + self._start_stream(seek_time) + time.sleep(0.1) + break + self._start_stream(seek_time) + time.sleep(0.05) + continue + self._stream_time += self._frame_time + skip_retries = 0 # Reset on successful read + + # Read the target frame with retry logic + frame = None + max_retries = 3 + for attempt in range(max_retries): + frame = self._read_frame_from_stream() + if frame is not None: + break + + # Stream failed - try restarting + import sys + import time + print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr) + + # Check for ffmpeg errors + if self._proc and self._proc.stderr: + try: + import select + readable, _, _ = select.select([self._proc.stderr], [], [], 0.1) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error: {err}", file=sys.stderr) + except: + pass + + # Wait a bit and restart + time.sleep(0.1) + self._start_stream(seek_time) + + # Give ffmpeg time to start + time.sleep(0.1) + + if frame is None: + import sys + raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries") + else: + self._stream_time += self._frame_time + + self._last_read_time = t + self._cached_frame = frame + return frame + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._frame_size + + def close(self): + if self._proc: + self._proc.kill() + self._proc = None + + +class PrefetchingVideoSource: + """ + Video source with background prefetching for improved performance. + + Wraps VideoSource and adds a background thread that pre-decodes + upcoming frames while the main thread processes the current frame. + """ + + def __init__(self, path: str, fps: float = 30, buffer_size: int = None): + self._source = VideoSource(path, fps) + self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE + self._buffer = {} # time -> frame + self._buffer_lock = threading.Lock() + self._prefetch_time = 0.0 + self._frame_time = 1.0 / fps + self._stop_event = threading.Event() + self._request_event = threading.Event() + self._target_time = 0.0 + + # Start prefetch thread + self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) + self._thread.start() + + import sys + print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr) + + def _prefetch_loop(self): + """Background thread that pre-reads frames.""" + while not self._stop_event.is_set(): + # Wait for work or timeout + self._request_event.wait(timeout=0.01) + self._request_event.clear() + + if self._stop_event.is_set(): + break + + # Prefetch frames ahead of target time + target = self._target_time + with self._buffer_lock: + # Clean old frames (more than 1 second behind) + old_times = [t for t in self._buffer.keys() if t < target - 1.0] + for t in old_times: + del self._buffer[t] + + # Count how many frames we have buffered ahead + buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target) + + # Prefetch if buffer not full + if buffered_ahead < self._buffer_size: + # Find next time to prefetch + prefetch_t = target + with self._buffer_lock: + existing_times = set(self._buffer.keys()) + for _ in range(self._buffer_size): + if prefetch_t not in existing_times: + break + prefetch_t += self._frame_time + + # Read the frame (this is the slow part) + try: + frame = self._source.read_at(prefetch_t) + with self._buffer_lock: + self._buffer[prefetch_t] = frame + except Exception as e: + import sys + print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time, using prefetch buffer if available.""" + self._target_time = t + self._request_event.set() # Wake up prefetch thread + + # Round to frame time for buffer lookup + t_key = round(t / self._frame_time) * self._frame_time + + # Check buffer first + with self._buffer_lock: + if t_key in self._buffer: + return self._buffer[t_key] + # Also check for close matches (within half frame time) + for buf_t, frame in self._buffer.items(): + if abs(buf_t - t) < self._frame_time * 0.5: + return frame + + # Not in buffer - read directly (blocking) + frame = self._source.read_at(t) + + # Store in buffer + with self._buffer_lock: + self._buffer[t_key] = frame + + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + return self.read_at(0) + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._source.size + + @property + def path(self): + return self._source.path + + def close(self): + self._stop_event.set() + self._request_event.set() # Wake up thread to exit + self._thread.join(timeout=1.0) + self._source.close() + + +class AudioAnalyzer: + """Audio analyzer for energy and beat detection.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + self.sample_rate = sample_rate + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Audio file not found: {self.path}") + + # Load audio via ffmpeg + cmd = ["ffmpeg", "-v", "error", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}") + self._audio = np.frombuffer(result.stdout, dtype=np.float32) + if len(self._audio) == 0: + raise RuntimeError(f"Audio file is empty or invalid: {self.path}") + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}") + info = json.loads(result.stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + # Beat detection state + self._flux_history = [] + self._last_beat_time = -1 + self._beat_count = 0 + self._last_beat_check_time = -1 + # Cache beat result for current time (so multiple scans see same result) + self._beat_cache_time = -1 + self._beat_cache_result = False + + def get_energy(self, t: float) -> float: + """Get energy level at time t (0-1).""" + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + """Check if there's a beat at time t.""" + # Return cached result if same time (multiple scans query same frame) + if t == self._beat_cache_time: + return self._beat_cache_result + + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size/2: + self._beat_cache_time = t + self._beat_cache_result = False + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + self._beat_cache_time = t + self._beat_cache_result = False + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + if len(self._flux_history) > 50: + self._flux_history = self._flux_history[-50:] + + if len(self._flux_history) < 5: + self._beat_cache_time = t + self._beat_cache_result = False + return False + + recent = [f for _, f in self._flux_history[-20:]] + threshold = np.mean(recent) + 1.5 * np.std(recent) + + is_beat = flux > threshold and (t - self._last_beat_time) > 0.1 + if is_beat: + self._last_beat_time = t + if t > self._last_beat_check_time: + self._beat_count += 1 + self._last_beat_check_time = t + + # Cache result for this time + self._beat_cache_time = t + self._beat_cache_result = is_beat + return is_beat + + def get_beat_count(self, t: float) -> int: + """Get cumulative beat count up to time t.""" + # Ensure beat detection has run up to this time + self.get_beat(t) + return self._beat_count + + +# === Primitives === + +def prim_make_video_source(path: str, fps: float = 30): + """Create a video source from a file path. + + Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default). + """ + if PREFETCH_ENABLED: + return PrefetchingVideoSource(path, fps) + return VideoSource(path, fps) + + +def prim_source_read(source: VideoSource, t: float = None): + """Read a frame from a video source.""" + import sys + if t is not None: + frame = source.read_at(t) + # Debug: show source and time + if int(t * 10) % 10 == 0: # Every second + print(f"READ {source.path.name}: t={t:.2f} stream={source._stream_time:.2f}", file=sys.stderr) + return frame + return source.read() + + +def prim_source_skip(source: VideoSource): + """Skip a frame (keep pipe in sync).""" + source.skip() + + +def prim_source_size(source: VideoSource): + """Get (width, height) of source.""" + return source.size + + +def prim_make_audio_analyzer(path: str): + """Create an audio analyzer from a file path.""" + return AudioAnalyzer(path) + + +def prim_audio_energy(analyzer: AudioAnalyzer, t: float) -> float: + """Get energy level (0-1) at time t.""" + return analyzer.get_energy(t) + + +def prim_audio_beat(analyzer: AudioAnalyzer, t: float) -> bool: + """Check if there's a beat at time t.""" + return analyzer.get_beat(t) + + +def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int: + """Get cumulative beat count up to time t.""" + return analyzer.get_beat_count(t) + + +def prim_audio_duration(analyzer: AudioAnalyzer) -> float: + """Get audio duration in seconds.""" + return analyzer.duration + + +# Export primitives +PRIMITIVES = { + # Video source + 'make-video-source': prim_make_video_source, + 'source-read': prim_source_read, + 'source-skip': prim_source_skip, + 'source-size': prim_source_size, + + # Audio analyzer + 'make-audio-analyzer': prim_make_audio_analyzer, + 'audio-energy': prim_audio_energy, + 'audio-beat': prim_audio_beat, + 'audio-beat-count': prim_audio_beat_count, + 'audio-duration': prim_audio_duration, +} diff --git a/sexp_effects/primitive_libs/streaming_gpu.py b/sexp_effects/primitive_libs/streaming_gpu.py new file mode 100644 index 0000000..f2aa7ea --- /dev/null +++ b/sexp_effects/primitive_libs/streaming_gpu.py @@ -0,0 +1,1165 @@ +""" +GPU-Accelerated Streaming Primitives + +Provides GPU-native video source and frame processing. +Frames stay on GPU memory throughout the pipeline for maximum performance. + +Architecture: +- GPUFrame: Wrapper that tracks whether data is on CPU or GPU +- GPUVideoSource: Hardware-accelerated decode to GPU memory +- GPU primitives operate directly on GPU frames using fast CUDA kernels +- Transfer to CPU only at final output + +Requirements: +- CuPy for CUDA support +- FFmpeg with NVDEC support (for hardware decode) +- NVIDIA GPU with CUDA capability +""" + +import os +import sys +import json +import subprocess +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, Union + +# Try to import CuPy +try: + import cupy as cp + GPU_AVAILABLE = True +except ImportError: + cp = None + GPU_AVAILABLE = False + +# Try to import fast CUDA kernels from JIT compiler +_FAST_KERNELS_AVAILABLE = False +try: + if GPU_AVAILABLE: + from streaming.jit_compiler import ( + fast_rotate, fast_zoom, fast_blend, fast_hue_shift, + fast_invert, fast_ripple, get_fast_ops + ) + _FAST_KERNELS_AVAILABLE = True + print("[streaming_gpu] Fast CUDA kernels loaded", file=sys.stderr) +except ImportError as e: + print(f"[streaming_gpu] Fast kernels not available: {e}", file=sys.stderr) + +# Check for hardware decode support +_HWDEC_AVAILABLE: Optional[bool] = None +_DECORD_GPU_AVAILABLE: Optional[bool] = None + + +def check_hwdec_available() -> bool: + """Check if NVIDIA hardware decode is available.""" + global _HWDEC_AVAILABLE + if _HWDEC_AVAILABLE is not None: + return _HWDEC_AVAILABLE + + try: + # Check for nvidia-smi (GPU present) + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2) + if result.returncode != 0: + _HWDEC_AVAILABLE = False + return False + + # Check for nvdec in ffmpeg + result = subprocess.run( + ["ffmpeg", "-hwaccels"], + capture_output=True, + text=True, + timeout=5 + ) + _HWDEC_AVAILABLE = "cuda" in result.stdout + except Exception: + _HWDEC_AVAILABLE = False + + return _HWDEC_AVAILABLE + + +def check_decord_gpu_available() -> bool: + """Check if decord with CUDA GPU decode is available.""" + global _DECORD_GPU_AVAILABLE + if _DECORD_GPU_AVAILABLE is not None: + return _DECORD_GPU_AVAILABLE + + try: + import decord + from decord import gpu + # Try to create a GPU context to verify CUDA support + ctx = gpu(0) + _DECORD_GPU_AVAILABLE = True + print("[streaming_gpu] decord GPU (CUDA) decode available", file=sys.stderr) + except Exception as e: + _DECORD_GPU_AVAILABLE = False + print(f"[streaming_gpu] decord GPU not available: {e}", file=sys.stderr) + + return _DECORD_GPU_AVAILABLE + + +class GPUFrame: + """ + Frame container that tracks data location (CPU/GPU). + + Enables zero-copy operations when data is already on the right device. + Lazy transfer - only moves data when actually needed. + """ + + def __init__(self, data: Union[np.ndarray, 'cp.ndarray'], on_gpu: bool = None): + self._cpu_data: Optional[np.ndarray] = None + self._gpu_data = None # Optional[cp.ndarray] + + if on_gpu is None: + # Auto-detect based on type + if GPU_AVAILABLE and isinstance(data, cp.ndarray): + self._gpu_data = data + else: + self._cpu_data = np.asarray(data) + elif on_gpu and GPU_AVAILABLE: + self._gpu_data = cp.asarray(data) if not isinstance(data, cp.ndarray) else data + else: + self._cpu_data = np.asarray(data) if isinstance(data, np.ndarray) else cp.asnumpy(data) + + @property + def cpu(self) -> np.ndarray: + """Get frame as numpy array (transfers from GPU if needed).""" + if self._cpu_data is None: + if self._gpu_data is not None and GPU_AVAILABLE: + self._cpu_data = cp.asnumpy(self._gpu_data) + else: + raise ValueError("No frame data available") + return self._cpu_data + + @property + def gpu(self): + """Get frame as CuPy array (transfers to GPU if needed).""" + if not GPU_AVAILABLE: + raise RuntimeError("GPU not available") + if self._gpu_data is None: + if self._cpu_data is not None: + self._gpu_data = cp.asarray(self._cpu_data) + else: + raise ValueError("No frame data available") + return self._gpu_data + + @property + def is_on_gpu(self) -> bool: + """Check if data is currently on GPU.""" + return self._gpu_data is not None + + @property + def shape(self) -> Tuple[int, ...]: + """Get frame shape.""" + if self._gpu_data is not None: + return self._gpu_data.shape + return self._cpu_data.shape + + @property + def dtype(self): + """Get frame dtype.""" + if self._gpu_data is not None: + return self._gpu_data.dtype + return self._cpu_data.dtype + + def numpy(self) -> np.ndarray: + """Alias for cpu property.""" + return self.cpu + + def cupy(self): + """Alias for gpu property.""" + return self.gpu + + def free_cpu(self): + """Free CPU memory (keep GPU only).""" + if self._gpu_data is not None: + self._cpu_data = None + + def free_gpu(self): + """Free GPU memory (keep CPU only).""" + if self._cpu_data is not None: + self._gpu_data = None + + +class GPUVideoSource: + """ + GPU-accelerated video source using hardware decode. + + Uses decord with CUDA GPU context for true NVDEC decode - frames + decode directly to GPU memory via CUDA. + + Falls back to FFmpeg pipe if decord GPU unavailable (slower due to CPU copy). + """ + + def __init__(self, path: str, fps: float = 30, prefer_gpu: bool = True): + self.path = Path(path) + self.fps = fps + self.prefer_gpu = prefer_gpu and GPU_AVAILABLE + self._use_decord_gpu = self.prefer_gpu and check_decord_gpu_available() + + self._frame_size: Optional[Tuple[int, int]] = None + self._duration: Optional[float] = None + self._video_fps: float = 30.0 + self._total_frames: int = 0 + self._frame_time = 1.0 / fps + self._last_read_time = -1 + self._cached_frame: Optional[GPUFrame] = None + + # Decord VideoReader with GPU context + self._vr = None + self._decord_ctx = None + + # FFmpeg fallback state + self._proc = None + self._stream_time = 0.0 + + # Initialize video source + self._init_video() + + mode = "decord-GPU" if self._use_decord_gpu else ("ffmpeg-hwaccel" if check_hwdec_available() else "ffmpeg-CPU") + print(f"[GPUVideoSource] {self.path.name}: {self._frame_size}, " + f"duration={self._duration:.1f}s, mode={mode}", file=sys.stderr) + + def _init_video(self): + """Initialize video reader (decord GPU or probe for ffmpeg).""" + if self._use_decord_gpu: + try: + from decord import VideoReader, gpu + + # Use GPU context for NVDEC hardware decode + self._decord_ctx = gpu(0) + self._vr = VideoReader(str(self.path), ctx=self._decord_ctx, num_threads=1) + + self._total_frames = len(self._vr) + self._video_fps = self._vr.get_avg_fps() + self._duration = self._total_frames / self._video_fps + + # Get frame size from first frame + first_frame = self._vr[0] + self._frame_size = (first_frame.shape[1], first_frame.shape[0]) + + print(f"[GPUVideoSource] decord GPU initialized: {self._frame_size}, " + f"{self._total_frames} frames @ {self._video_fps:.1f}fps", file=sys.stderr) + return + except Exception as e: + print(f"[GPUVideoSource] decord GPU init failed, falling back to ffmpeg: {e}", file=sys.stderr) + self._use_decord_gpu = False + self._vr = None + self._decord_ctx = None + + # FFmpeg fallback - probe video for metadata + self._probe_video() + + def _probe_video(self): + """Probe video file for metadata (FFmpeg fallback).""" + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + info = json.loads(result.stdout) + + for stream in info.get("streams", []): + if stream.get("codec_type") == "video": + self._frame_size = (stream.get("width", 720), stream.get("height", 720)) + if "duration" in stream: + self._duration = float(stream["duration"]) + elif "tags" in stream and "DURATION" in stream["tags"]: + dur_str = stream["tags"]["DURATION"] + parts = dur_str.split(":") + if len(parts) == 3: + h, m, s = parts + self._duration = int(h) * 3600 + int(m) * 60 + float(s) + # Get fps + if "r_frame_rate" in stream: + fps_str = stream["r_frame_rate"] + if "/" in fps_str: + num, den = fps_str.split("/") + self._video_fps = float(num) / float(den) + break + + if self._duration is None and "format" in info: + if "duration" in info["format"]: + self._duration = float(info["format"]["duration"]) + + if not self._frame_size: + self._frame_size = (720, 720) + if not self._duration: + self._duration = 60.0 + + self._total_frames = int(self._duration * self._video_fps) + + def _start_stream(self, seek_time: float = 0): + """Start ffmpeg decode process (fallback mode).""" + if self._proc: + self._proc.kill() + self._proc = None + + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + w, h = self._frame_size + + # Build ffmpeg command + cmd = ["ffmpeg", "-v", "error"] + + # Hardware decode if available + if check_hwdec_available(): + cmd.extend(["-hwaccel", "cuda"]) + + cmd.extend([ + "-ss", f"{seek_time:.3f}", + "-i", str(self.path), + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(self.fps), + "-" + ]) + + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._stream_time = seek_time + + def _read_frame_raw(self) -> Optional[np.ndarray]: + """Read one frame from ffmpeg pipe (fallback mode).""" + w, h = self._frame_size + frame_size = w * h * 3 + + if not self._proc or self._proc.poll() is not None: + return None + + data = self._proc.stdout.read(frame_size) + if len(data) < frame_size: + return None + + return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() + + def _read_frame_decord_gpu(self, frame_idx: int) -> Optional[GPUFrame]: + """Read frame using decord with GPU context (NVDEC, zero-copy to CuPy).""" + if self._vr is None: + return None + + try: + # Handle looping + frame_idx = frame_idx % max(1, self._total_frames) + + # Decode frame - with GPU context, this uses NVDEC + frame_tensor = self._vr[frame_idx] + + # Convert to CuPy via DLPack (zero-copy GPU transfer) + if GPU_AVAILABLE: + # decord tensors have .to_dlpack() which returns a PyCapsule + # that CuPy can consume for zero-copy GPU transfer + try: + dlpack_capsule = frame_tensor.to_dlpack() + gpu_frame = cp.from_dlpack(dlpack_capsule) + # Log success once per source + if not getattr(self, '_dlpack_logged', False): + print(f"[GPUVideoSource] DLPack zero-copy SUCCESS - frames stay on GPU", file=sys.stderr) + self._dlpack_logged = True + return GPUFrame(gpu_frame, on_gpu=True) + except Exception as dlpack_err: + # Fallback: convert via numpy (involves CPU copy) + if not getattr(self, '_dlpack_fail_logged', False): + print(f"[GPUVideoSource] DLPack FAILED ({dlpack_err}), using CPU copy fallback", file=sys.stderr) + self._dlpack_fail_logged = True + frame_np = frame_tensor.asnumpy() + return GPUFrame(frame_np, on_gpu=True) + else: + return GPUFrame(frame_tensor.asnumpy(), on_gpu=False) + + except Exception as e: + print(f"[GPUVideoSource] decord GPU read error at frame {frame_idx}: {e}", file=sys.stderr) + return None + + def read_at(self, t: float) -> Optional[GPUFrame]: + """ + Read frame at specific time. + + Returns GPUFrame with data on GPU if GPU mode enabled. + """ + # Cache check + if t == self._last_read_time and self._cached_frame is not None: + return self._cached_frame + + # Loop time for shorter videos + seek_time = t + if self._duration and self._duration > 0: + seek_time = t % self._duration + if seek_time > self._duration - 0.1: + seek_time = 0.0 + + self._last_read_time = t + + # Use decord GPU if available (NVDEC decode, zero-copy via DLPack) + if self._use_decord_gpu: + frame_idx = int(seek_time * self._video_fps) + self._cached_frame = self._read_frame_decord_gpu(frame_idx) + if self._cached_frame is not None: + # Free CPU copy if on GPU (saves memory) + if self.prefer_gpu and self._cached_frame.is_on_gpu: + self._cached_frame.free_cpu() + return self._cached_frame + + # FFmpeg fallback + need_seek = ( + self._proc is None or + self._proc.poll() is not None or + seek_time < self._stream_time - self._frame_time or + seek_time > self._stream_time + 2.0 + ) + + if need_seek: + self._start_stream(seek_time) + + # Skip frames to reach target + while self._stream_time + self._frame_time <= seek_time: + frame = self._read_frame_raw() + if frame is None: + self._start_stream(seek_time) + break + self._stream_time += self._frame_time + + # Read target frame + frame_np = self._read_frame_raw() + if frame_np is None: + return self._cached_frame + + self._stream_time += self._frame_time + + # Create GPUFrame - transfer to GPU if in GPU mode + self._cached_frame = GPUFrame(frame_np, on_gpu=self.prefer_gpu) + + # Free CPU copy if on GPU (saves memory) + if self.prefer_gpu and self._cached_frame.is_on_gpu: + self._cached_frame.free_cpu() + + return self._cached_frame + + def read(self) -> Optional[GPUFrame]: + """Read current frame.""" + if self._cached_frame is not None: + return self._cached_frame + return self.read_at(0) + + @property + def size(self) -> Tuple[int, int]: + return self._frame_size + + @property + def duration(self) -> float: + return self._duration + + def close(self): + """Close the video source.""" + if self._proc: + self._proc.kill() + self._proc = None + # Release decord resources + self._vr = None + self._decord_ctx = None + + +# GPU-aware primitive functions + +def gpu_blend(frame_a: GPUFrame, frame_b: GPUFrame, alpha: float = 0.5) -> GPUFrame: + """ + Blend two frames on GPU using fast CUDA kernel. + + Both frames stay on GPU throughout - no CPU transfer. + """ + if not GPU_AVAILABLE: + a = frame_a.cpu.astype(np.float32) + b = frame_b.cpu.astype(np.float32) + result = (a * alpha + b * (1 - alpha)).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + # Use fast CUDA kernel + if _FAST_KERNELS_AVAILABLE: + a_gpu = frame_a.gpu + b_gpu = frame_b.gpu + if a_gpu.dtype != cp.uint8: + a_gpu = cp.clip(a_gpu, 0, 255).astype(cp.uint8) + if b_gpu.dtype != cp.uint8: + b_gpu = cp.clip(b_gpu, 0, 255).astype(cp.uint8) + result = fast_blend(a_gpu, b_gpu, alpha) + return GPUFrame(result, on_gpu=True) + + # Fallback + a = frame_a.gpu.astype(cp.float32) + b = frame_b.gpu.astype(cp.float32) + result = (a * alpha + b * (1 - alpha)).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_resize(frame: GPUFrame, size: Tuple[int, int]) -> GPUFrame: + """Resize frame on GPU using fast CUDA zoom kernel.""" + import cv2 + + if not GPU_AVAILABLE or not frame.is_on_gpu: + resized = cv2.resize(frame.cpu, size) + return GPUFrame(resized, on_gpu=False) + + gpu_data = frame.gpu + h, w = gpu_data.shape[:2] + target_w, target_h = size + + # Use fast zoom kernel if same aspect ratio (pure zoom) + if _FAST_KERNELS_AVAILABLE and target_w == target_h == w == h: + # For uniform zoom we can use the zoom kernel + pass # Fall through to scipy for now - full resize needs different approach + + # CuPy doesn't have built-in resize, use scipy zoom + from cupyx.scipy import ndimage as cpndimage + + zoom_y = target_h / h + zoom_x = target_w / w + + if gpu_data.ndim == 3: + resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x, 1), order=1) + else: + resized = cpndimage.zoom(gpu_data, (zoom_y, zoom_x), order=1) + + return GPUFrame(resized, on_gpu=True) + + +def gpu_zoom(frame: GPUFrame, factor: float, cx: float = None, cy: float = None) -> GPUFrame: + """Zoom frame on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + h, w = frame.cpu.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), 0, factor) + zoomed = cv2.warpAffine(frame.cpu, M, (w, h)) + return GPUFrame(zoomed, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + zoomed = fast_zoom(frame.gpu, factor, cx=cx, cy=cy) + return GPUFrame(zoomed, on_gpu=True) + + # Fallback - basic zoom via slice and resize + return frame + + +def gpu_hue_shift(frame: GPUFrame, degrees: float) -> GPUFrame: + """Shift hue on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + hsv = cv2.cvtColor(frame.cpu, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.float32) + degrees / 2) % 180 + result = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + shifted = fast_hue_shift(gpu_data, degrees) + return GPUFrame(shifted, on_gpu=True) + + # Fallback - no GPU hue shift without fast kernels + return frame + + +def gpu_invert(frame: GPUFrame) -> GPUFrame: + """Invert colors on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = 255 - frame.cpu + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + inverted = fast_invert(gpu_data) + return GPUFrame(inverted, on_gpu=True) + + # Fallback - basic CuPy invert + result = 255 - frame.gpu + return GPUFrame(result, on_gpu=True) + + +def gpu_ripple(frame: GPUFrame, amplitude: float, frequency: float = 8, + decay: float = 2, phase: float = 0, + cx: float = None, cy: float = None) -> GPUFrame: + """Apply ripple effect on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + return frame # No CPU fallback for ripple + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + rippled = fast_ripple( + gpu_data, amplitude, + center_x=cx if cx else w/2, + center_y=cy if cy else h/2, + frequency=frequency, + decay=decay, + speed=1.0, + t=phase + ) + return GPUFrame(rippled, on_gpu=True) + + return frame + + +def gpu_contrast(frame: GPUFrame, factor: float) -> GPUFrame: + """Adjust contrast on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = np.clip((frame.cpu.astype(np.float32) - 128) * factor + 128, 0, 255).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(gpu_data) + ops.contrast(factor) + return GPUFrame(ops.get_output().copy(), on_gpu=True) + + # Fallback + result = cp.clip((frame.gpu.astype(cp.float32) - 128) * factor + 128, 0, 255).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_rotate(frame: GPUFrame, angle: float) -> GPUFrame: + """Rotate frame on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + import cv2 + h, w = frame.cpu.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(frame.cpu, M, (w, h)) + return GPUFrame(rotated, on_gpu=False) + + # Use fast CUDA kernel (< 1ms vs 20ms for scipy) + if _FAST_KERNELS_AVAILABLE: + rotated = fast_rotate(frame.gpu, angle) + return GPUFrame(rotated, on_gpu=True) + + # Fallback to scipy (slow) + from cupyx.scipy import ndimage as cpndimage + rotated = cpndimage.rotate(frame.gpu, angle, reshape=False, order=1) + return GPUFrame(rotated, on_gpu=True) + + +def gpu_brightness(frame: GPUFrame, factor: float) -> GPUFrame: + """Adjust brightness on GPU using fast CUDA kernel.""" + if not GPU_AVAILABLE or not frame.is_on_gpu: + result = np.clip(frame.cpu.astype(np.float32) * factor, 0, 255).astype(np.uint8) + return GPUFrame(result, on_gpu=False) + + # Use fast CUDA kernel + if _FAST_KERNELS_AVAILABLE: + gpu_data = frame.gpu + if gpu_data.dtype != cp.uint8: + gpu_data = cp.clip(gpu_data, 0, 255).astype(cp.uint8) + h, w = gpu_data.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(gpu_data) + ops.brightness(factor) + return GPUFrame(ops.get_output().copy(), on_gpu=True) + + # Fallback + result = cp.clip(frame.gpu.astype(cp.float32) * factor, 0, 255).astype(cp.uint8) + return GPUFrame(result, on_gpu=True) + + +def gpu_composite(frames: list, weights: list = None) -> GPUFrame: + """ + Composite multiple frames with weights. + + All frames processed on GPU for efficiency. + """ + if not frames: + raise ValueError("No frames to composite") + + if len(frames) == 1: + return frames[0] + + if weights is None: + weights = [1.0 / len(frames)] * len(frames) + + # Normalize weights + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + + use_gpu = GPU_AVAILABLE and any(f.is_on_gpu for f in frames) + + if use_gpu: + # All on GPU + target_shape = frames[0].gpu.shape + result = cp.zeros(target_shape, dtype=cp.float32) + + for frame, weight in zip(frames, weights): + gpu_data = frame.gpu.astype(cp.float32) + if gpu_data.shape != target_shape: + # Resize to match + from cupyx.scipy import ndimage as cpndimage + h, w = target_shape[:2] + fh, fw = gpu_data.shape[:2] + zoom_factors = (h/fh, w/fw, 1) if gpu_data.ndim == 3 else (h/fh, w/fw) + gpu_data = cpndimage.zoom(gpu_data, zoom_factors, order=1) + result += gpu_data * weight + + return GPUFrame(cp.clip(result, 0, 255).astype(cp.uint8), on_gpu=True) + else: + # All on CPU + import cv2 + target_shape = frames[0].cpu.shape + result = np.zeros(target_shape, dtype=np.float32) + + for frame, weight in zip(frames, weights): + cpu_data = frame.cpu.astype(np.float32) + if cpu_data.shape != target_shape: + cpu_data = cv2.resize(cpu_data, (target_shape[1], target_shape[0])) + result += cpu_data * weight + + return GPUFrame(np.clip(result, 0, 255).astype(np.uint8), on_gpu=False) + + +# Primitive registration for streaming interpreter + +def _to_gpu_frame(img): + """Convert any image type to GPUFrame, keeping data on GPU if possible.""" + if isinstance(img, GPUFrame): + return img + # Check for CuPy array (stays on GPU) + if GPU_AVAILABLE and hasattr(img, '__cuda_array_interface__'): + # Already a CuPy array - wrap directly + return GPUFrame(img, on_gpu=True) + # Numpy or other - will be uploaded to GPU + return GPUFrame(img, on_gpu=True) + + +def get_primitives(): + """ + Get GPU-aware primitives for registration with interpreter. + + These wrap the GPU functions to work with the sexp interpreter. + All use fast CUDA kernels when available for maximum performance. + + Primitives detect CuPy arrays and keep them on GPU (no CPU round-trips). + """ + def prim_make_video_source_gpu(path: str, fps: float = 30): + """Create GPU-accelerated video source.""" + return GPUVideoSource(path, fps, prefer_gpu=True) + + def prim_gpu_blend(a, b, alpha=0.5): + """Blend two frames using fast CUDA kernel.""" + fa = _to_gpu_frame(a) + fb = _to_gpu_frame(b) + result = gpu_blend(fa, fb, alpha) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_rotate(img, angle): + """Rotate image using fast CUDA kernel (< 1ms).""" + f = _to_gpu_frame(img) + result = gpu_rotate(f, angle) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_brightness(img, factor): + """Adjust brightness using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_brightness(f, factor) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_contrast(img, factor): + """Adjust contrast using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_contrast(f, factor) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_zoom(img, factor, cx=None, cy=None): + """Zoom image using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_zoom(f, factor, cx, cy) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_hue_shift(img, degrees): + """Shift hue using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_hue_shift(f, degrees) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_invert(img): + """Invert colors using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_invert(f) + return result.gpu if result.is_on_gpu else result.cpu + + def prim_gpu_ripple(img, amplitude, frequency=8, decay=2, phase=0, cx=None, cy=None): + """Apply ripple effect using fast CUDA kernel.""" + f = _to_gpu_frame(img) + result = gpu_ripple(f, amplitude, frequency, decay, phase, cx, cy) + return result.gpu if result.is_on_gpu else result.cpu + + return { + 'streaming-gpu:make-video-source': prim_make_video_source_gpu, + 'gpu:blend': prim_gpu_blend, + 'gpu:rotate': prim_gpu_rotate, + 'gpu:brightness': prim_gpu_brightness, + 'gpu:contrast': prim_gpu_contrast, + 'gpu:zoom': prim_gpu_zoom, + 'gpu:hue-shift': prim_gpu_hue_shift, + 'gpu:invert': prim_gpu_invert, + 'gpu:ripple': prim_gpu_ripple, + } + + +# Export +__all__ = [ + 'GPU_AVAILABLE', + 'GPUFrame', + 'GPUVideoSource', + 'gpu_blend', + 'gpu_resize', + 'gpu_rotate', + 'gpu_brightness', + 'gpu_contrast', + 'gpu_zoom', + 'gpu_hue_shift', + 'gpu_invert', + 'gpu_ripple', + 'gpu_composite', + 'get_primitives', + 'check_hwdec_available', + 'PRIMITIVES', +] + + +# Import CPU primitives from streaming.py and include them in PRIMITIVES +# This ensures audio analysis primitives are available when streaming_gpu is loaded +def _get_cpu_primitives(): + from sexp_effects.primitive_libs import streaming + return streaming.PRIMITIVES + + +PRIMITIVES = _get_cpu_primitives().copy() + +# Try to import fused kernel compiler +_FUSED_KERNELS_AVAILABLE = False +_compile_frame_pipeline = None +_compile_autonomous_pipeline = None +try: + if GPU_AVAILABLE: + from streaming.sexp_to_cuda import compile_frame_pipeline as _compile_frame_pipeline + from streaming.sexp_to_cuda import compile_autonomous_pipeline as _compile_autonomous_pipeline + _FUSED_KERNELS_AVAILABLE = True + print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr) +except ImportError as e: + print(f"[streaming_gpu] Fused kernels not available: {e}", file=sys.stderr) + + +# Fused pipeline cache +_FUSED_PIPELINE_CACHE = {} + + +def _normalize_effect_dict(effect): + """Convert effect dict with Keyword keys to string keys.""" + result = {} + for k, v in effect.items(): + # Handle Keyword objects from sexp parser + if hasattr(k, 'name'): # Keyword object + key = k.name + else: + key = str(k) + result[key] = v + return result + + +_FUSED_CALL_COUNT = 0 + +def prim_fused_pipeline(img, effects_list, **dynamic_params): + """ + Apply a fused CUDA kernel pipeline to an image. + + This compiles multiple effects into a single CUDA kernel that processes + the entire pipeline in one GPU pass, eliminating Python interpreter overhead. + + Args: + img: Input image (GPU array or numpy array) + effects_list: List of effect dicts like: + [{'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 90.0}, + {'op': 'ripple', 'amplitude': 10, ...}] + **dynamic_params: Parameters that change per-frame like: + rotate_angle=45, ripple_phase=0.5 + + Returns: + Processed image as GPU array + + Supported ops: rotate, zoom, ripple, invert, hue_shift, brightness, resize + """ + global _FUSED_CALL_COUNT + _FUSED_CALL_COUNT += 1 + if _FUSED_CALL_COUNT <= 5 or _FUSED_CALL_COUNT % 100 == 0: + print(f"[FUSED] call #{_FUSED_CALL_COUNT}, effects={len(effects_list)}, params={list(dynamic_params.keys())}", file=sys.stderr) + + # Normalize effects list - convert Keyword keys to strings + effects_list = [_normalize_effect_dict(e) for e in effects_list] + + # Handle resize separately - it changes dimensions so must happen before fused kernel + resize_ops = [e for e in effects_list if e.get('op') == 'resize'] + other_effects = [e for e in effects_list if e.get('op') != 'resize'] + + # Apply resize first if needed + if resize_ops: + for resize_op in resize_ops: + target_w = int(resize_op.get('width', 640)) + target_h = int(resize_op.get('height', 360)) + # Wrap in GPUFrame if needed + if isinstance(img, GPUFrame): + img = gpu_resize(img, (target_w, target_h)) + img = img.gpu if img.is_on_gpu else img.cpu + else: + frame = GPUFrame(img, on_gpu=hasattr(img, '__cuda_array_interface__')) + img = gpu_resize(frame, (target_w, target_h)) + img = img.gpu if img.is_on_gpu else img.cpu + + # If no other effects, just return the resized image + if not other_effects: + return img + + # Update effects list to exclude resize ops + effects_list = other_effects + + if not _FUSED_KERNELS_AVAILABLE: + # Fallback: apply effects one by one + print(f"[FUSED FALLBACK] Using fallback path for {len(effects_list)} effects", file=sys.stderr) + # Wrap in GPUFrame if needed (GPU functions expect GPUFrame objects) + if isinstance(img, GPUFrame): + result = img + else: + on_gpu = hasattr(img, '__cuda_array_interface__') + result = GPUFrame(img, on_gpu=on_gpu) + for effect in effects_list: + op = effect['op'] + if op == 'rotate': + angle = dynamic_params.get('rotate_angle', effect.get('angle', 0)) + result = gpu_rotate(result, angle) + elif op == 'zoom': + amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0)) + result = gpu_zoom(result, amount) + elif op == 'hue_shift': + degrees = effect.get('degrees', 0) + if abs(degrees) > 0.1: # Only apply if significant shift + result = gpu_hue_shift(result, degrees) + elif op == 'ripple': + amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) + if amplitude > 0.1: # Only apply if amplitude is significant + result = gpu_ripple(result, + amplitude=amplitude, + frequency=effect.get('frequency', 8), + decay=effect.get('decay', 2), + phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)), + cx=effect.get('center_x'), + cy=effect.get('center_y')) + elif op == 'brightness': + factor = effect.get('factor', 1.0) + result = gpu_contrast(result, factor, 0) + elif op == 'invert': + amount = effect.get('amount', 0) + if amount > 0.5: # Only invert if amount > 0.5 + result = gpu_invert(result) + else: + raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") + # Return raw array, not GPUFrame (downstream expects arrays with .flags attribute) + if isinstance(result, GPUFrame): + return result.gpu if result.is_on_gpu else result.cpu + return result + + # Get image dimensions + if hasattr(img, 'shape'): + h, w = img.shape[:2] + else: + raise ValueError("Image must have shape attribute") + + # Create cache key from effects + import hashlib + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects_list]) + cache_key = f"{w}x{h}_{hashlib.md5(ops_key.encode()).hexdigest()}" + + # Compile or get cached pipeline + if cache_key not in _FUSED_PIPELINE_CACHE: + _FUSED_PIPELINE_CACHE[cache_key] = _compile_frame_pipeline(effects_list, w, h) + + pipeline = _FUSED_PIPELINE_CACHE[cache_key] + + # Ensure image is on GPU and uint8 + if hasattr(img, '__cuda_array_interface__'): + gpu_img = img + elif GPU_AVAILABLE: + gpu_img = cp.asarray(img) + else: + gpu_img = img + + # Run the fused pipeline + # Debug: log dynamic params occasionally + import random + if random.random() < 0.01: # 1% of frames + print(f"[fused] dynamic_params: {dynamic_params}", file=sys.stderr) + print(f"[fused] effects: {[(e['op'], e.get('amount'), e.get('amplitude')) for e in effects_list]}", file=sys.stderr) + return pipeline(gpu_img, **dynamic_params) + + +# Autonomous pipeline cache (separate from fused) +_AUTONOMOUS_PIPELINE_CACHE = {} + + +def prim_autonomous_pipeline(img, effects_list, dynamic_expressions, frame_num, fps=30.0): + """ + Apply a fully autonomous CUDA kernel pipeline. + + This computes ALL parameters on GPU - including time-based expressions + like sin(t), t*30, etc. Zero Python in the hot path! + + Args: + img: Input image (GPU array or numpy array) + effects_list: List of effect dicts + dynamic_expressions: Dict mapping param names to CUDA expressions: + {'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)'} + frame_num: Current frame number + fps: Frames per second (default 30) + + Returns: + Processed image as GPU array + + Note: Expressions use CUDA syntax - use sinf() not sin(), etc. + """ + # Normalize effects and expressions + effects_list = [_normalize_effect_dict(e) for e in effects_list] + dynamic_expressions = { + (k.name if hasattr(k, 'name') else str(k)): v + for k, v in dynamic_expressions.items() + } + + if not _FUSED_KERNELS_AVAILABLE or _compile_autonomous_pipeline is None: + # Fallback to regular fused pipeline with Python-computed params + import math + t = float(frame_num) / float(fps) + # Evaluate expressions in Python as fallback + dynamic_params = {} + for key, expr in dynamic_expressions.items(): + try: + # Simple eval with t and math functions + result = eval(expr.replace('f', '').replace('sin', 'math.sin').replace('cos', 'math.cos'), + {'t': t, 'math': math, 'frame_num': frame_num}) + dynamic_params[key] = result + except: + dynamic_params[key] = 0 + return prim_fused_pipeline(img, effects_list, **dynamic_params) + + # Get image dimensions + if hasattr(img, 'shape'): + h, w = img.shape[:2] + else: + raise ValueError("Image must have shape attribute") + + # Create cache key + import hashlib + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects_list]) + expr_key = str(sorted(dynamic_expressions.items())) + cache_key = f"auto_{w}x{h}_{hashlib.md5((ops_key + expr_key).encode()).hexdigest()}" + + # Compile or get cached pipeline + if cache_key not in _AUTONOMOUS_PIPELINE_CACHE: + _AUTONOMOUS_PIPELINE_CACHE[cache_key] = _compile_autonomous_pipeline( + effects_list, w, h, dynamic_expressions) + + pipeline = _AUTONOMOUS_PIPELINE_CACHE[cache_key] + + # Ensure image is on GPU + if hasattr(img, '__cuda_array_interface__'): + gpu_img = img + elif GPU_AVAILABLE: + gpu_img = cp.asarray(img) + else: + gpu_img = img + + # Run - just pass frame_num and fps, kernel does the rest! + return pipeline(gpu_img, int(frame_num), float(fps)) + + +# ============================================================ +# GPU Image Primitives (keep images on GPU) +# ============================================================ + +def gpu_make_image(w, h, color=None): + """Create a new image on GPU filled with color (default black). + + Unlike image:make-image, this keeps the image on GPU memory, + avoiding CPU<->GPU transfers in the pipeline. + """ + if not GPU_AVAILABLE: + # Fallback to CPU + import numpy as np + if color is None: + color = [0, 0, 0] + img = np.zeros((int(h), int(w), 3), dtype=np.uint8) + img[:] = color + return img + + w, h = int(w), int(h) + if color is None: + color = [0, 0, 0] + + # Create on GPU directly + img = cp.zeros((h, w, 3), dtype=cp.uint8) + img[:, :, 0] = int(color[0]) if len(color) > 0 else 0 + img[:, :, 1] = int(color[1]) if len(color) > 1 else 0 + img[:, :, 2] = int(color[2]) if len(color) > 2 else 0 + + return img + + +def gpu_gradient_image(w, h, color1=None, color2=None, direction='horizontal'): + """Create a gradient image on GPU. + + Args: + w, h: Dimensions + color1, color2: Start and end colors [r, g, b] + direction: 'horizontal', 'vertical', 'diagonal' + """ + if not GPU_AVAILABLE: + return gpu_make_image(w, h, color1) + + w, h = int(w), int(h) + if color1 is None: + color1 = [0, 0, 0] + if color2 is None: + color2 = [255, 255, 255] + + img = cp.zeros((h, w, 3), dtype=cp.uint8) + + if direction == 'horizontal': + for c in range(3): + grad = cp.linspace(color1[c], color2[c], w, dtype=cp.float32) + img[:, :, c] = grad[cp.newaxis, :].astype(cp.uint8) + elif direction == 'vertical': + for c in range(3): + grad = cp.linspace(color1[c], color2[c], h, dtype=cp.float32) + img[:, :, c] = grad[:, cp.newaxis].astype(cp.uint8) + elif direction == 'diagonal': + for c in range(3): + x_grad = cp.linspace(0, 1, w, dtype=cp.float32)[cp.newaxis, :] + y_grad = cp.linspace(0, 1, h, dtype=cp.float32)[:, cp.newaxis] + combined = (x_grad + y_grad) / 2 + img[:, :, c] = (color1[c] + (color2[c] - color1[c]) * combined).astype(cp.uint8) + + return img + + +# Add GPU-specific primitives +PRIMITIVES['fused-pipeline'] = prim_fused_pipeline +PRIMITIVES['autonomous-pipeline'] = prim_autonomous_pipeline +PRIMITIVES['gpu-make-image'] = gpu_make_image +PRIMITIVES['gpu-gradient'] = gpu_gradient_image +# (The GPU video source will be added by create_cid_primitives in the task) diff --git a/sexp_effects/primitive_libs/xector.py b/sexp_effects/primitive_libs/xector.py new file mode 100644 index 0000000..fb95dfd --- /dev/null +++ b/sexp_effects/primitive_libs/xector.py @@ -0,0 +1,1382 @@ +""" +Xector Primitives - Parallel array operations for GPU-style data parallelism. + +Inspired by Connection Machine Lisp and hillisp. Xectors are parallel arrays +where operations automatically apply element-wise. + +Usage in sexp: + (require-primitives "xector") + + ;; Extract channels as xectors + (let* ((r (red frame)) + (g (green frame)) + (b (blue frame)) + ;; Operations are element-wise on xectors + (brightness (α+ (α* r 0.299) (α* g 0.587) (α* b 0.114)))) + ;; Reduce to scalar + (βmax brightness)) + + ;; Explicit α for element-wise, implicit also works + (α+ r 10) ;; explicit: add 10 to every element + (+ r 10) ;; implicit: same thing when r is a xector + + ;; β for reductions + (β+ r) ;; sum all elements + (βmax r) ;; maximum element + (βmean r) ;; average + +Operators: + α (alpha) - element-wise: (α+ x y) adds corresponding elements + β (beta) - reduce: (β+ x) sums all elements +""" + +import numpy as np +from typing import Union, Callable, Any + +# Try to use CuPy for GPU acceleration if available +try: + import cupy as cp + HAS_CUPY = True +except ImportError: + cp = None + HAS_CUPY = False + + +class Xector: + """ + Parallel array type for element-wise operations. + + Wraps a numpy/cupy array and provides automatic broadcasting + and element-wise operation semantics. + """ + + def __init__(self, data, shape=None): + """ + Create a Xector from data. + + Args: + data: numpy array, cupy array, scalar, or list + shape: optional shape tuple (for coordinate xectors) + """ + if isinstance(data, Xector): + self._data = data._data + self._shape = data._shape + elif isinstance(data, np.ndarray): + self._data = data.astype(np.float32) + self._shape = shape or data.shape + elif HAS_CUPY and isinstance(data, cp.ndarray): + self._data = data.astype(cp.float32) + self._shape = shape or data.shape + elif isinstance(data, (list, tuple)): + self._data = np.array(data, dtype=np.float32) + self._shape = shape or self._data.shape + else: + # Scalar - will broadcast + self._data = np.float32(data) + self._shape = shape or () + + @property + def data(self): + return self._data + + @property + def shape(self): + return self._shape + + def __len__(self): + return self._data.size + + def __repr__(self): + if self._data.size <= 10: + return f"Xector({self._data})" + return f"Xector(shape={self._shape}, size={self._data.size})" + + def to_numpy(self): + """Convert to numpy array.""" + if HAS_CUPY and isinstance(self._data, cp.ndarray): + return cp.asnumpy(self._data) + return self._data + + def to_gpu(self): + """Move to GPU if available.""" + if HAS_CUPY and not isinstance(self._data, cp.ndarray): + self._data = cp.asarray(self._data) + return self + + # Arithmetic operators - enable implicit element-wise ops + def __add__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data + other_data, self._shape) + + def __radd__(self, other): + return Xector(other + self._data, self._shape) + + def __sub__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data - other_data, self._shape) + + def __rsub__(self, other): + return Xector(other - self._data, self._shape) + + def __mul__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data * other_data, self._shape) + + def __rmul__(self, other): + return Xector(other * self._data, self._shape) + + def __truediv__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data / other_data, self._shape) + + def __rtruediv__(self, other): + return Xector(other / self._data, self._shape) + + def __pow__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data ** other_data, self._shape) + + def __neg__(self): + return Xector(-self._data, self._shape) + + def __abs__(self): + return Xector(np.abs(self._data), self._shape) + + # Comparison operators - return boolean xectors + def __lt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data < other_data, self._shape) + + def __le__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data <= other_data, self._shape) + + def __gt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data > other_data, self._shape) + + def __ge__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data >= other_data, self._shape) + + def __eq__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data == other_data, self._shape) + + def __ne__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data != other_data, self._shape) + + +def _unwrap(x): + """Unwrap Xector to underlying data, or return as-is.""" + if isinstance(x, Xector): + return x._data + return x + + +def _wrap(data, shape=None): + """Wrap result in Xector if it's an array.""" + if isinstance(data, (np.ndarray,)) or (HAS_CUPY and isinstance(data, cp.ndarray)): + return Xector(data, shape) + return data + + +# ============================================================================= +# Frame/Xector Conversion +# ============================================================================= +# NOTE: red, green, blue, gray are derived in derived.sexp using (channel frame n) + +def xector_from_frame(frame): + """Convert entire frame to xector (flattened RGB). (xector frame) -> Xector""" + if isinstance(frame, np.ndarray): + return Xector(frame.flatten().astype(np.float32), frame.shape) + raise TypeError(f"Expected frame array, got {type(frame)}") + + +def xector_to_frame(x, shape=None): + """Convert xector back to frame. (to-frame x) or (to-frame x shape) -> frame""" + data = _unwrap(x) + if shape is None and isinstance(x, Xector): + shape = x._shape + if shape is None: + raise ValueError("Shape required to convert xector to frame") + return np.clip(data, 0, 255).reshape(shape).astype(np.uint8) + + +# ============================================================================= +# Coordinate Generators +# ============================================================================= +# NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived +# in derived.sexp using iota, tile, repeat primitives + + +# ============================================================================= +# Alpha (α) - Element-wise Operations +# ============================================================================= + +def alpha_lift(fn): + """Lift a scalar function to work element-wise on xectors.""" + def lifted(*args): + # Check if any arg is a Xector + has_xector = any(isinstance(a, Xector) for a in args) + if not has_xector: + return fn(*args) + + # Get shape from first xector + shape = None + for a in args: + if isinstance(a, Xector): + shape = a._shape + break + + # Unwrap all args + unwrapped = [_unwrap(a) for a in args] + + # Apply function + result = fn(*unwrapped) + + return _wrap(result, shape) + + return lifted + + +# Element-wise math operations +def alpha_add(*args): + """Element-wise addition. (α+ a b ...) -> Xector""" + if len(args) == 0: + return 0 + result = _unwrap(args[0]) + for a in args[1:]: + result = result + _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_sub(a, b=None): + """Element-wise subtraction. (α- a b) -> Xector""" + if b is None: + return Xector(-_unwrap(a)) if isinstance(a, Xector) else -a + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) - _unwrap(b), shape) + + +def alpha_mul(*args): + """Element-wise multiplication. (α* a b ...) -> Xector""" + if len(args) == 0: + return 1 + result = _unwrap(args[0]) + for a in args[1:]: + result = result * _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_div(a, b): + """Element-wise division. (α/ a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) / _unwrap(b), shape) + + +def alpha_pow(a, b): + """Element-wise power. (α** a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) ** _unwrap(b), shape) + + +def alpha_sqrt(x): + """Element-wise square root. (αsqrt x) -> Xector""" + return _wrap(np.sqrt(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_abs(x): + """Element-wise absolute value. (αabs x) -> Xector""" + return _wrap(np.abs(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_sin(x): + """Element-wise sine. (αsin x) -> Xector""" + return _wrap(np.sin(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_cos(x): + """Element-wise cosine. (αcos x) -> Xector""" + return _wrap(np.cos(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_exp(x): + """Element-wise exponential. (αexp x) -> Xector""" + return _wrap(np.exp(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_log(x): + """Element-wise natural log. (αlog x) -> Xector""" + return _wrap(np.log(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_clamp is derived in derived.sexp as (max2 lo (min2 hi x)) + +def alpha_min(a, b): + """Element-wise minimum. (αmin a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.minimum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_max(a, b): + """Element-wise maximum. (αmax a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.maximum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_mod(a, b): + """Element-wise modulo. (αmod a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) % _unwrap(b), shape) + + +def alpha_floor(x): + """Element-wise floor. (αfloor x) -> Xector""" + return _wrap(np.floor(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_ceil(x): + """Element-wise ceiling. (αceil x) -> Xector""" + return _wrap(np.ceil(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_round(x): + """Element-wise round. (αround x) -> Xector""" + return _wrap(np.round(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_sq is derived in derived.sexp as (* x x) + +# Comparison operators (return boolean xectors) +def alpha_lt(a, b): + """Element-wise less than. (α< a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) < _unwrap(b), shape) + + +def alpha_le(a, b): + """Element-wise less-or-equal. (α<= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) <= _unwrap(b), shape) + + +def alpha_gt(a, b): + """Element-wise greater than. (α> a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) > _unwrap(b), shape) + + +def alpha_ge(a, b): + """Element-wise greater-or-equal. (α>= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) >= _unwrap(b), shape) + + +def alpha_eq(a, b): + """Element-wise equality. (α= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) == _unwrap(b), shape) + + +# Logical operators +def alpha_and(a, b): + """Element-wise logical and. (αand a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_and(_unwrap(a), _unwrap(b)), shape) + + +def alpha_or(a, b): + """Element-wise logical or. (αor a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_or(_unwrap(a), _unwrap(b)), shape) + + +def alpha_not(x): + """Element-wise logical not. (αnot x) -> Xector[bool]""" + return _wrap(np.logical_not(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# ============================================================================= +# Beta (β) - Reduction Operations +# ============================================================================= + +def beta_add(x): + """Sum all elements. (β+ x) -> scalar""" + return float(np.sum(_unwrap(x))) + + +def beta_mul(x): + """Product of all elements. (β* x) -> scalar""" + return float(np.prod(_unwrap(x))) + + +def beta_min(x): + """Minimum element. (βmin x) -> scalar""" + return float(np.min(_unwrap(x))) + + +def beta_max(x): + """Maximum element. (βmax x) -> scalar""" + return float(np.max(_unwrap(x))) + + +def beta_mean(x): + """Mean of all elements. (βmean x) -> scalar""" + return float(np.mean(_unwrap(x))) + + +def beta_std(x): + """Standard deviation. (βstd x) -> scalar""" + return float(np.std(_unwrap(x))) + + +def beta_count(x): + """Count of elements. (βcount x) -> scalar""" + return int(np.size(_unwrap(x))) + + +def beta_any(x): + """True if any element is truthy. (βany x) -> bool""" + return bool(np.any(_unwrap(x))) + + +def beta_all(x): + """True if all elements are truthy. (βall x) -> bool""" + return bool(np.all(_unwrap(x))) + + +# ============================================================================= +# Conditional / Selection +# ============================================================================= + +def xector_where(cond, true_val, false_val): + """ + Conditional select. (where cond true-val false-val) -> Xector + + Like numpy.where - selects elements based on condition. + """ + cond_data = _unwrap(cond) + true_data = _unwrap(true_val) + false_data = _unwrap(false_val) + + # Get shape from condition or values + shape = None + for x in [cond, true_val, false_val]: + if isinstance(x, Xector): + shape = x._shape + break + + result = np.where(cond_data, true_data, false_data) + return _wrap(result, shape) + + +# NOTE: fill, zeros, ones are derived in derived.sexp using iota + +def xector_rand(size_or_frame): + """Create xector of random values [0,1). (rand-x frame) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector(np.random.random(size).astype(np.float32), shape) + + +def xector_randn(size_or_frame, mean=0, std=1): + """Create xector of normal random values. (randn-x frame) or (randn-x frame mean std) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector((np.random.randn(size) * std + mean).astype(np.float32), shape) + + +# ============================================================================= +# Type checking +# ============================================================================= + +def is_xector(x): + """Check if x is a Xector. (xector? x) -> bool""" + return isinstance(x, Xector) + + +# ============================================================================= +# CORE PRIMITIVES: gather, scatter, group-reduce, reshape +# These are the fundamental operations everything else builds on. +# ============================================================================= + +def xector_gather(data, indices): + """ + Parallel index lookup. (gather data indices) -> Xector + + For each index in indices, look up the corresponding value in data. + This is the fundamental operation for remapping/resampling. + + Example: + (gather [10 20 30 40] [2 0 1 2]) ; -> [30 10 20 30] + """ + data_arr = _unwrap(data) + idx_arr = _unwrap(indices).astype(np.int32) + + # Flatten data for 1D indexing + flat_data = data_arr.flatten() + + # Clip indices to valid range + idx_clipped = np.clip(idx_arr, 0, len(flat_data) - 1) + + result = flat_data[idx_clipped] + shape = indices._shape if isinstance(indices, Xector) else None + return Xector(result, shape) + + +def xector_gather_2d(data, row_indices, col_indices): + """ + 2D parallel index lookup. (gather-2d data rows cols) -> Xector + + For each (row, col) pair, look up the value in 2D data. + Essential for grid/cell operations. + + Example: + (gather-2d image-lum cell-rows cell-cols) + """ + data_arr = _unwrap(data) + row_arr = _unwrap(row_indices).astype(np.int32) + col_arr = _unwrap(col_indices).astype(np.int32) + + # Get data shape + if isinstance(data, Xector) and data._shape and len(data._shape) >= 2: + h, w = data._shape[:2] + data_2d = data_arr.reshape(h, w) + elif len(data_arr.shape) >= 2: + h, w = data_arr.shape[:2] + data_2d = data_arr.reshape(h, w) if data_arr.ndim == 1 else data_arr + else: + # Assume square + size = int(np.sqrt(len(data_arr))) + h, w = size, size + data_2d = data_arr.reshape(h, w) + + # Clip indices + row_clipped = np.clip(row_arr, 0, h - 1) + col_clipped = np.clip(col_arr, 0, w - 1) + + result = data_2d[row_clipped.flatten(), col_clipped.flatten()] + shape = row_indices._shape if isinstance(row_indices, Xector) else None + return Xector(result, shape) + + +def xector_scatter(indices, values, size): + """ + Parallel index write. (scatter indices values size) -> Xector + + Create a new xector of given size, writing values at indices. + Later writes overwrite earlier ones at same index. + + Example: + (scatter [0 2 4] [10 20 30] 5) ; -> [10 0 20 0 30] + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + idx_clipped = np.clip(idx_arr, 0, int(size) - 1) + result[idx_clipped] = val_arr + + return Xector(result, (int(size),)) + + +def xector_scatter_add(indices, values, size): + """ + Parallel index accumulate. (scatter-add indices values size) -> Xector + + Like scatter, but adds to existing values instead of overwriting. + Useful for histograms, pooling reductions. + + Example: + (scatter-add [0 0 1] [1 2 3] 3) ; -> [3 3 0] (1+2 at index 0) + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + np.add.at(result, np.clip(idx_arr, 0, int(size) - 1), val_arr) + + return Xector(result, (int(size),)) + + +def xector_group_reduce(values, group_indices, num_groups, op='mean'): + """ + Reduce values by group. (group-reduce values groups num-groups op) -> Xector + + Groups values by group_indices and reduces each group. + This is the primitive for pooling operations. + + Args: + values: Xector of values to reduce + group_indices: Xector of group assignments (integers) + num_groups: Number of groups (output size) + op: 'mean', 'sum', 'max', 'min' + + Example: + ; Pool 4 values into 2 groups + (group-reduce [1 2 3 4] [0 0 1 1] 2 "mean") ; -> [1.5 3.5] + """ + val_arr = _unwrap(values).flatten() + grp_arr = _unwrap(group_indices).astype(np.int32).flatten() + n = int(num_groups) + + if op == 'sum': + result = np.zeros(n, dtype=np.float32) + np.add.at(result, grp_arr, val_arr) + elif op == 'mean': + sums = np.zeros(n, dtype=np.float32) + counts = np.zeros(n, dtype=np.float32) + np.add.at(sums, grp_arr, val_arr) + np.add.at(counts, grp_arr, 1) + result = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + elif op == 'max': + result = np.full(n, -np.inf, dtype=np.float32) + np.maximum.at(result, grp_arr, val_arr) + result[result == -np.inf] = 0 + elif op == 'min': + result = np.full(n, np.inf, dtype=np.float32) + np.minimum.at(result, grp_arr, val_arr) + result[result == np.inf] = 0 + else: + raise ValueError(f"Unknown reduce op: {op}") + + return Xector(result, (n,)) + + +def xector_reshape(x, *dims): + """ + Reshape xector. (reshape x h w) or (reshape x n) -> Xector + + Changes the logical shape of the xector without changing data. + """ + data = _unwrap(x) + if len(dims) == 1: + new_shape = (int(dims[0]),) + else: + new_shape = tuple(int(d) for d in dims) + + return Xector(data.reshape(-1), new_shape) + + +def xector_shape(x): + """Get shape of xector. (shape x) -> list""" + if isinstance(x, Xector): + return list(x._shape) if x._shape else [len(x)] + if isinstance(x, np.ndarray): + return list(x.shape) + return [] + + +def xector_len(x): + """Get length of xector. (xlen x) -> int""" + return len(_unwrap(x).flatten()) + + +def xector_iota(n): + """ + Generate indices 0 to n-1. (iota n) -> Xector + + Fundamental for generating coordinate xectors. + + Example: + (iota 5) ; -> [0 1 2 3 4] + """ + return Xector(np.arange(int(n), dtype=np.float32), (int(n),)) + + +def xector_repeat(x, n): + """ + Repeat each element n times. (repeat x n) -> Xector + + Example: + (repeat [1 2 3] 2) ; -> [1 1 2 2 3 3] + """ + data = _unwrap(x) + result = np.repeat(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +def xector_tile(x, n): + """ + Tile entire xector n times. (tile x n) -> Xector + + Example: + (tile [1 2 3] 2) ; -> [1 2 3 1 2 3] + """ + data = _unwrap(x) + result = np.tile(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +# ============================================================================= +# 2D Grid Helpers (built on primitives above) +# ============================================================================= + +def xector_cell_indices(frame, cell_size): + """ + Compute cell index for each pixel. (cell-indices frame cell-size) -> Xector + + Returns flat index of which cell each pixel belongs to. + This is the bridge between pixel-space and cell-space. + """ + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y = np.repeat(np.arange(h), w) # [0,0,0..., 1,1,1..., ...] + x = np.tile(np.arange(w), h) # [0,1,2..., 0,1,2..., ...] + + cell_row = y // cell_size + cell_col = x // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + cell_idx = np.clip(cell_idx, 0, rows * cols - 1) + + return Xector(cell_idx.astype(np.float32), (h, w)) + + +def xector_local_x(frame, cell_size): + """ + X position within each cell [0, cell_size). (local-x frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + x = np.tile(np.arange(w), h) + local = (x % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y(frame, cell_size): + """ + Y position within each cell [0, cell_size). (local-y frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + y = np.repeat(np.arange(h), w) + local = (y % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_x_norm(frame, cell_size): + """ + Normalized X within cell [0, 1]. (local-x-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + x = np.tile(np.arange(w), h) + local = ((x % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y_norm(frame, cell_size): + """ + Normalized Y within cell [0, 1]. (local-y-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + y = np.repeat(np.arange(h), w) + local = ((y % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_pool_frame(frame, cell_size, op='mean'): + """ + Pool frame to cell values. (pool-frame frame cell-size) -> (r, g, b, lum) Xectors + + Returns tuple of xectors: (red, green, blue, luminance) for cells. + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y = np.repeat(np.arange(h), w) + x = np.tile(np.arange(w), h) + cell_row = np.clip(y // cs, 0, rows - 1) + cell_col = np.clip(x // cs, 0, cols - 1) + cell_idx = cell_row * cols + cell_col + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(np.float32) + g_flat = frame[:, :, 1].flatten().astype(np.float32) + b_flat = frame[:, :, 2].flatten().astype(np.float32) + + # Pool each channel + def pool_channel(data): + sums = np.zeros(num_cells, dtype=np.float32) + counts = np.zeros(num_cells, dtype=np.float32) + np.add.at(sums, cell_idx, data) + np.add.at(counts, cell_idx, 1) + return np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + shape = (rows, cols) + return (Xector(r_pooled, shape), + Xector(g_pooled, shape), + Xector(b_pooled, shape), + Xector(lum, shape)) + + +def xector_cell_row(frame, cell_size): + """ + Cell row index for each pixel. (cell-row frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + + y = np.repeat(np.arange(h), w) + cell_row = np.clip(y // cs, 0, rows - 1).astype(np.float32) + return Xector(cell_row, (h, w)) + + +def xector_cell_col(frame, cell_size): + """ + Cell column index for each pixel. (cell-col frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + cols = w // cs + + x = np.tile(np.arange(w), h) + cell_col = np.clip(x // cs, 0, cols - 1).astype(np.float32) + return Xector(cell_col, (h, w)) + + +def xector_num_cells(frame, cell_size): + """Number of cells. (num-cells frame cell-size) -> (rows, cols, total)""" + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + return (rows, cols, rows * cols) + + +# ============================================================================= +# Scan (Prefix Operations) - cumulative reductions +# ============================================================================= + +def xector_scan_add(x, axis=None): + """ + Cumulative sum (prefix sum). (scan+ x) or (scan+ x :axis 0) + + Returns array where each element is sum of all previous elements. + Useful for integral images, cumulative effects. + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None: + # Reshape to 2D for axis operation + if shape and len(shape) == 2: + result = np.cumsum(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.cumsum(data, axis=int(axis)) + else: + result = np.cumsum(data) + + return _wrap(result, shape) + + +def xector_scan_mul(x, axis=None): + """Cumulative product. (scan* x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.cumprod(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.cumprod(data) + + return _wrap(result, shape) + + +def xector_scan_max(x, axis=None): + """Cumulative maximum. (scan-max x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.maximum.accumulate(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.maximum.accumulate(data) + + return _wrap(result, shape) + + +def xector_scan_min(x, axis=None): + """Cumulative minimum. (scan-min x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.minimum.accumulate(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.minimum.accumulate(data) + + return _wrap(result, shape) + + +# ============================================================================= +# Outer Product - Cartesian operations +# ============================================================================= + +def xector_outer(x, y, op='*'): + """ + Outer product. (outer x y) or (outer x y :op '+') + + Creates 2D result where result[i,j] = op(x[i], y[j]). + Default is multiplication (*). + + Useful for generating 2D patterns from 1D vectors. + """ + x_data = _unwrap(x) + y_data = _unwrap(y) + + ops = { + '*': np.multiply, + '+': np.add, + '-': np.subtract, + '/': np.divide, + 'max': np.maximum, + 'min': np.minimum, + 'and': np.logical_and, + 'or': np.logical_or, + 'xor': np.logical_xor, + } + + op_fn = ops.get(op, np.multiply) + result = op_fn.outer(x_data.flatten(), y_data.flatten()) + + # Return as xector with 2D shape + h, w = len(x_data.flatten()), len(y_data.flatten()) + return _wrap(result.flatten(), (h, w)) + + +def xector_outer_add(x, y): + """Outer sum. (outer+ x y) -> result[i,j] = x[i] + y[j]""" + return xector_outer(x, y, '+') + + +def xector_outer_mul(x, y): + """Outer product. (outer* x y) -> result[i,j] = x[i] * y[j]""" + return xector_outer(x, y, '*') + + +def xector_outer_max(x, y): + """Outer max. (outer-max x y) -> result[i,j] = max(x[i], y[j])""" + return xector_outer(x, y, 'max') + + +def xector_outer_min(x, y): + """Outer min. (outer-min x y) -> result[i,j] = min(x[i], y[j])""" + return xector_outer(x, y, 'min') + + +# ============================================================================= +# Reduce with Axis - dimensional reductions +# ============================================================================= + +def xector_reduce_axis(x, op='sum', axis=0): + """ + Reduce along an axis. (reduce-axis x :op 'sum' :axis 0) + + ops: 'sum', 'mean', 'max', 'min', 'prod', 'std' + axis: 0 (rows), 1 (columns) + + For a frame-sized xector (H*W): + axis=0: reduce across rows -> W values (one per column) + axis=1: reduce across columns -> H values (one per row) + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if shape is None or len(shape) != 2: + # Can't do axis reduction without 2D shape + raise ValueError("reduce-axis requires 2D xector (with shape)") + + h, w = shape + data_2d = data.reshape(h, w) + axis = int(axis) + + ops = { + 'sum': lambda d, a: np.sum(d, axis=a), + '+': lambda d, a: np.sum(d, axis=a), + 'mean': lambda d, a: np.mean(d, axis=a), + 'max': lambda d, a: np.max(d, axis=a), + 'min': lambda d, a: np.min(d, axis=a), + 'prod': lambda d, a: np.prod(d, axis=a), + '*': lambda d, a: np.prod(d, axis=a), + 'std': lambda d, a: np.std(d, axis=a), + } + + op_fn = ops.get(op, ops['sum']) + result = op_fn(data_2d, axis) + + # Result shape: if axis=0, shape is (w,); if axis=1, shape is (h,) + new_shape = (w,) if axis == 0 else (h,) + return _wrap(result.flatten(), new_shape) + + +def xector_sum_axis(x, axis=0): + """Sum along axis. (sum-axis x :axis 0)""" + return xector_reduce_axis(x, 'sum', axis) + + +def xector_mean_axis(x, axis=0): + """Mean along axis. (mean-axis x :axis 0)""" + return xector_reduce_axis(x, 'mean', axis) + + +def xector_max_axis(x, axis=0): + """Max along axis. (max-axis x :axis 0)""" + return xector_reduce_axis(x, 'max', axis) + + +def xector_min_axis(x, axis=0): + """Min along axis. (min-axis x :axis 0)""" + return xector_reduce_axis(x, 'min', axis) + + +# ============================================================================= +# Windowed Operations - sliding window computations +# ============================================================================= + +def xector_window(x, size, op='mean', stride=1): + """ + Sliding window operation. (window x size :op 'mean' :stride 1) + + Applies reduction over sliding windows of given size. + ops: 'sum', 'mean', 'max', 'min' + + For 1D: windows slide along the array + For 2D (with shape): windows are size x size squares + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + size = int(size) + stride = int(stride) + + ops = { + 'sum': np.sum, + 'mean': np.mean, + 'max': np.max, + 'min': np.min, + 'std': np.std, + } + op_fn = ops.get(op, np.mean) + + if shape and len(shape) == 2: + # 2D sliding window + h, w = shape + data_2d = data.reshape(h, w) + + # Use stride tricks for efficient windowing + out_h = (h - size) // stride + 1 + out_w = (w - size) // stride + 1 + + result = np.zeros((out_h, out_w)) + for i in range(out_h): + for j in range(out_w): + window = data_2d[i*stride:i*stride+size, j*stride:j*stride+size] + result[i, j] = op_fn(window) + + return _wrap(result.flatten(), (out_h, out_w)) + else: + # 1D sliding window + n = len(data) + out_n = (n - size) // stride + 1 + result = np.array([op_fn(data[i*stride:i*stride+size]) for i in range(out_n)]) + return _wrap(result, (out_n,)) + + +def xector_window_sum(x, size, stride=1): + """Sliding window sum. (window-sum x size)""" + return xector_window(x, size, 'sum', stride) + + +def xector_window_mean(x, size, stride=1): + """Sliding window mean. (window-mean x size)""" + return xector_window(x, size, 'mean', stride) + + +def xector_window_max(x, size, stride=1): + """Sliding window max. (window-max x size)""" + return xector_window(x, size, 'max', stride) + + +def xector_window_min(x, size, stride=1): + """Sliding window min. (window-min x size)""" + return xector_window(x, size, 'min', stride) + + +def xector_integral_image(frame): + """ + Compute integral image (summed area table). (integral-image frame) + + Each pixel contains sum of all pixels above and to the left. + Enables O(1) box blur at any radius. + + Returns xector with same shape as frame's luminance. + """ + if hasattr(frame, 'shape') and len(frame.shape) == 3: + # Convert frame to grayscale + gray = np.mean(frame, axis=2) + else: + data = _unwrap(frame) + shape = frame._shape if isinstance(frame, Xector) else None + if shape and len(shape) == 2: + gray = data.reshape(shape) + else: + gray = data + + integral = np.cumsum(np.cumsum(gray, axis=0), axis=1) + h, w = integral.shape + return _wrap(integral.flatten(), (h, w)) + + +def xector_box_blur_fast(integral, x, y, radius, width, height): + """ + Fast box blur using integral image. (box-blur-fast integral x y radius w h) + + Given pre-computed integral image, compute average in box centered at (x,y). + O(1) regardless of radius. + """ + integral_data = _unwrap(integral) + shape = integral._shape if isinstance(integral, Xector) else None + + if shape is None or len(shape) != 2: + raise ValueError("box-blur-fast requires 2D integral image") + + h, w = shape + integral_2d = integral_data.reshape(h, w) + + radius = int(radius) + x, y = int(x), int(y) + + # Clamp coordinates + x1 = max(0, x - radius) + y1 = max(0, y - radius) + x2 = min(w - 1, x + radius) + y2 = min(h - 1, y + radius) + + # Sum in rectangle using integral image + total = integral_2d[y2, x2] + if x1 > 0: + total -= integral_2d[y2, x1 - 1] + if y1 > 0: + total -= integral_2d[y1 - 1, x2] + if x1 > 0 and y1 > 0: + total += integral_2d[y1 - 1, x1 - 1] + + count = (x2 - x1 + 1) * (y2 - y1 + 1) + return total / max(count, 1) + + +# ============================================================================= +# PRIMITIVES Export +# ============================================================================= + +PRIMITIVES = { + # Frame/Xector conversion + # NOTE: red, green, blue, gray, rgb are derived in derived.sexp using (channel frame n) + 'xector': xector_from_frame, + 'to-frame': xector_to_frame, + + # Coordinate generators + # NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived + # in derived.sexp using iota, tile, repeat primitives + + # Alpha (α) - element-wise operations + 'α+': alpha_add, + 'α-': alpha_sub, + 'α*': alpha_mul, + 'α/': alpha_div, + 'α**': alpha_pow, + 'αsqrt': alpha_sqrt, + 'αabs': alpha_abs, + 'αsin': alpha_sin, + 'αcos': alpha_cos, + 'αexp': alpha_exp, + 'αlog': alpha_log, + # NOTE: αclamp is derived in derived.sexp as (max2 lo (min2 hi x)) + 'αmin': alpha_min, + 'αmax': alpha_max, + 'αmod': alpha_mod, + 'αfloor': alpha_floor, + 'αceil': alpha_ceil, + 'αround': alpha_round, + # NOTE: α² / αsq is derived in derived.sexp as (* x x) + + # Alpha comparison + 'α<': alpha_lt, + 'α<=': alpha_le, + 'α>': alpha_gt, + 'α>=': alpha_ge, + 'α=': alpha_eq, + + # Alpha logical + 'αand': alpha_and, + 'αor': alpha_or, + 'αnot': alpha_not, + + # ASCII fallbacks for α + 'alpha+': alpha_add, + 'alpha-': alpha_sub, + 'alpha*': alpha_mul, + 'alpha/': alpha_div, + 'alpha**': alpha_pow, + 'alpha-sqrt': alpha_sqrt, + 'alpha-abs': alpha_abs, + 'alpha-sin': alpha_sin, + 'alpha-cos': alpha_cos, + 'alpha-exp': alpha_exp, + 'alpha-log': alpha_log, + 'alpha-min': alpha_min, + 'alpha-max': alpha_max, + 'alpha-mod': alpha_mod, + 'alpha-floor': alpha_floor, + 'alpha-ceil': alpha_ceil, + 'alpha-round': alpha_round, + 'alpha<': alpha_lt, + 'alpha<=': alpha_le, + 'alpha>': alpha_gt, + 'alpha>=': alpha_ge, + 'alpha=': alpha_eq, + 'alpha-and': alpha_and, + 'alpha-or': alpha_or, + 'alpha-not': alpha_not, + + # Beta (β) - reduction operations + 'β+': beta_add, + 'β*': beta_mul, + 'βmin': beta_min, + 'βmax': beta_max, + 'βmean': beta_mean, + 'βstd': beta_std, + 'βcount': beta_count, + 'βany': beta_any, + 'βall': beta_all, + + # ASCII fallbacks for β + 'beta+': beta_add, + 'beta*': beta_mul, + 'beta-min': beta_min, + 'beta-max': beta_max, + 'beta-mean': beta_mean, + 'beta-std': beta_std, + 'beta-count': beta_count, + 'beta-any': beta_any, + 'beta-all': beta_all, + + # Convenience aliases + 'sum': beta_add, + 'product': beta_mul, + 'mean': beta_mean, + + # Conditional / Selection + 'where': xector_where, + # NOTE: fill, zeros, ones are derived in derived.sexp using iota + 'rand-x': xector_rand, + 'randn-x': xector_randn, + + # Type checking + 'xector?': is_xector, + + # =========================================== + # CORE PRIMITIVES - fundamental operations + # =========================================== + + # Gather/Scatter - parallel indexing + 'gather': xector_gather, + 'gather-2d': xector_gather_2d, + 'scatter': xector_scatter, + 'scatter-add': xector_scatter_add, + + # Group reduce - pooling primitive + 'group-reduce': xector_group_reduce, + + # Shape operations + 'reshape': xector_reshape, + 'shape': xector_shape, + 'xlen': xector_len, + + # Index generation + 'iota': xector_iota, + 'repeat': xector_repeat, + 'tile': xector_tile, + + # Cell/Grid helpers (built on primitives) + 'cell-indices': xector_cell_indices, + 'cell-row': xector_cell_row, + 'cell-col': xector_cell_col, + 'local-x': xector_local_x, + 'local-y': xector_local_y, + 'local-x-norm': xector_local_x_norm, + 'local-y-norm': xector_local_y_norm, + 'pool-frame': xector_pool_frame, + 'num-cells': xector_num_cells, + + # Scan (prefix) operations - cumulative reductions + 'scan+': xector_scan_add, + 'scan*': xector_scan_mul, + 'scan-max': xector_scan_max, + 'scan-min': xector_scan_min, + 'scan-add': xector_scan_add, + 'scan-mul': xector_scan_mul, + + # Outer product - Cartesian operations + 'outer': xector_outer, + 'outer+': xector_outer_add, + 'outer*': xector_outer_mul, + 'outer-add': xector_outer_add, + 'outer-mul': xector_outer_mul, + 'outer-max': xector_outer_max, + 'outer-min': xector_outer_min, + + # Reduce with axis - dimensional reductions + 'reduce-axis': xector_reduce_axis, + 'sum-axis': xector_sum_axis, + 'mean-axis': xector_mean_axis, + 'max-axis': xector_max_axis, + 'min-axis': xector_min_axis, + + # Windowed operations - sliding window computations + 'window': xector_window, + 'window-sum': xector_window_sum, + 'window-mean': xector_window_mean, + 'window-max': xector_window_max, + 'window-min': xector_window_min, + + # Integral image - for fast box blur + 'integral-image': xector_integral_image, + 'box-blur-fast': xector_box_blur_fast, +} diff --git a/sexp_effects/primitives.py b/sexp_effects/primitives.py new file mode 100644 index 0000000..9a50356 --- /dev/null +++ b/sexp_effects/primitives.py @@ -0,0 +1,3075 @@ +""" +Safe Primitives for S-Expression Effects + +These are the building blocks that user-defined effects can use. +All primitives operate only on image data - no filesystem, network, etc. +""" + +import numpy as np +import cv2 +from typing import Any, Callable, Dict, List, Tuple, Optional +from dataclasses import dataclass +import math + + +@dataclass +class ZoneContext: + """Context for a single cell/zone in ASCII art grid.""" + row: int + col: int + row_norm: float # Normalized row position 0-1 + col_norm: float # Normalized col position 0-1 + luminance: float # Cell luminance 0-1 + saturation: float # Cell saturation 0-1 + hue: float # Cell hue 0-360 + r: float # Red component 0-1 + g: float # Green component 0-1 + b: float # Blue component 0-1 + + +class DeterministicRNG: + """Seeded RNG for reproducible effects.""" + + def __init__(self, seed: int = 42): + self._rng = np.random.RandomState(seed) + + def random(self, low: float = 0, high: float = 1) -> float: + return self._rng.uniform(low, high) + + def randint(self, low: int, high: int) -> int: + return self._rng.randint(low, high + 1) + + def gaussian(self, mean: float = 0, std: float = 1) -> float: + return self._rng.normal(mean, std) + + +# Global RNG instance (reset per frame with seed param) +_rng = DeterministicRNG() + + +def reset_rng(seed: int): + """Reset the global RNG with a new seed.""" + global _rng + _rng = DeterministicRNG(seed) + + +# ============================================================================= +# Color Names (FFmpeg/X11 compatible) +# ============================================================================= + +NAMED_COLORS = { + # Basic colors + "black": (0, 0, 0), + "white": (255, 255, 255), + "red": (255, 0, 0), + "green": (0, 128, 0), + "blue": (0, 0, 255), + "yellow": (255, 255, 0), + "cyan": (0, 255, 255), + "magenta": (255, 0, 255), + + # Grays + "gray": (128, 128, 128), + "grey": (128, 128, 128), + "darkgray": (169, 169, 169), + "darkgrey": (169, 169, 169), + "lightgray": (211, 211, 211), + "lightgrey": (211, 211, 211), + "dimgray": (105, 105, 105), + "dimgrey": (105, 105, 105), + "silver": (192, 192, 192), + + # Reds + "darkred": (139, 0, 0), + "firebrick": (178, 34, 34), + "crimson": (220, 20, 60), + "indianred": (205, 92, 92), + "lightcoral": (240, 128, 128), + "salmon": (250, 128, 114), + "darksalmon": (233, 150, 122), + "lightsalmon": (255, 160, 122), + "tomato": (255, 99, 71), + "orangered": (255, 69, 0), + "coral": (255, 127, 80), + + # Oranges + "orange": (255, 165, 0), + "darkorange": (255, 140, 0), + + # Yellows + "gold": (255, 215, 0), + "lightyellow": (255, 255, 224), + "lemonchiffon": (255, 250, 205), + "papayawhip": (255, 239, 213), + "moccasin": (255, 228, 181), + "peachpuff": (255, 218, 185), + "palegoldenrod": (238, 232, 170), + "khaki": (240, 230, 140), + "darkkhaki": (189, 183, 107), + + # Greens + "lime": (0, 255, 0), + "limegreen": (50, 205, 50), + "forestgreen": (34, 139, 34), + "darkgreen": (0, 100, 0), + "seagreen": (46, 139, 87), + "mediumseagreen": (60, 179, 113), + "springgreen": (0, 255, 127), + "mediumspringgreen": (0, 250, 154), + "lightgreen": (144, 238, 144), + "palegreen": (152, 251, 152), + "darkseagreen": (143, 188, 143), + "greenyellow": (173, 255, 47), + "chartreuse": (127, 255, 0), + "lawngreen": (124, 252, 0), + "olivedrab": (107, 142, 35), + "olive": (128, 128, 0), + "darkolivegreen": (85, 107, 47), + "yellowgreen": (154, 205, 50), + + # Cyans/Teals + "aqua": (0, 255, 255), + "teal": (0, 128, 128), + "darkcyan": (0, 139, 139), + "lightcyan": (224, 255, 255), + "aquamarine": (127, 255, 212), + "mediumaquamarine": (102, 205, 170), + "paleturquoise": (175, 238, 238), + "turquoise": (64, 224, 208), + "mediumturquoise": (72, 209, 204), + "darkturquoise": (0, 206, 209), + "cadetblue": (95, 158, 160), + + # Blues + "navy": (0, 0, 128), + "darkblue": (0, 0, 139), + "mediumblue": (0, 0, 205), + "royalblue": (65, 105, 225), + "cornflowerblue": (100, 149, 237), + "steelblue": (70, 130, 180), + "dodgerblue": (30, 144, 255), + "deepskyblue": (0, 191, 255), + "lightskyblue": (135, 206, 250), + "skyblue": (135, 206, 235), + "lightsteelblue": (176, 196, 222), + "lightblue": (173, 216, 230), + "powderblue": (176, 224, 230), + "slateblue": (106, 90, 205), + "mediumslateblue": (123, 104, 238), + "darkslateblue": (72, 61, 139), + "midnightblue": (25, 25, 112), + + # Purples/Violets + "purple": (128, 0, 128), + "darkmagenta": (139, 0, 139), + "darkviolet": (148, 0, 211), + "blueviolet": (138, 43, 226), + "darkorchid": (153, 50, 204), + "mediumorchid": (186, 85, 211), + "orchid": (218, 112, 214), + "violet": (238, 130, 238), + "plum": (221, 160, 221), + "thistle": (216, 191, 216), + "lavender": (230, 230, 250), + "indigo": (75, 0, 130), + "mediumpurple": (147, 112, 219), + "fuchsia": (255, 0, 255), + "hotpink": (255, 105, 180), + "deeppink": (255, 20, 147), + "mediumvioletred": (199, 21, 133), + "palevioletred": (219, 112, 147), + + # Pinks + "pink": (255, 192, 203), + "lightpink": (255, 182, 193), + "mistyrose": (255, 228, 225), + + # Browns + "brown": (165, 42, 42), + "maroon": (128, 0, 0), + "saddlebrown": (139, 69, 19), + "sienna": (160, 82, 45), + "chocolate": (210, 105, 30), + "peru": (205, 133, 63), + "sandybrown": (244, 164, 96), + "burlywood": (222, 184, 135), + "tan": (210, 180, 140), + "rosybrown": (188, 143, 143), + "goldenrod": (218, 165, 32), + "darkgoldenrod": (184, 134, 11), + + # Whites + "snow": (255, 250, 250), + "honeydew": (240, 255, 240), + "mintcream": (245, 255, 250), + "azure": (240, 255, 255), + "aliceblue": (240, 248, 255), + "ghostwhite": (248, 248, 255), + "whitesmoke": (245, 245, 245), + "seashell": (255, 245, 238), + "beige": (245, 245, 220), + "oldlace": (253, 245, 230), + "floralwhite": (255, 250, 240), + "ivory": (255, 255, 240), + "antiquewhite": (250, 235, 215), + "linen": (250, 240, 230), + "lavenderblush": (255, 240, 245), + "wheat": (245, 222, 179), + "cornsilk": (255, 248, 220), + "blanchedalmond": (255, 235, 205), + "bisque": (255, 228, 196), + "navajowhite": (255, 222, 173), + + # Special + "transparent": (0, 0, 0), # Note: no alpha support, just black +} + + +def parse_color(color_spec: str) -> Optional[Tuple[int, int, int]]: + """ + Parse a color specification into RGB tuple. + + Supports: + - Named colors: "red", "green", "lime", "navy", etc. + - Hex colors: "#FF0000", "#f00", "0xFF0000" + - Special modes: "color", "mono", "invert" return None (handled separately) + + Returns: + RGB tuple (r, g, b) or None for special modes + """ + if color_spec is None: + return None + + color_spec = str(color_spec).strip().lower() + + # Special modes handled elsewhere + if color_spec in ("color", "mono", "invert"): + return None + + # Check named colors + if color_spec in NAMED_COLORS: + return NAMED_COLORS[color_spec] + + # Handle hex colors + hex_str = None + if color_spec.startswith("#"): + hex_str = color_spec[1:] + elif color_spec.startswith("0x"): + hex_str = color_spec[2:] + elif all(c in "0123456789abcdef" for c in color_spec) and len(color_spec) in (3, 6): + hex_str = color_spec + + if hex_str: + try: + if len(hex_str) == 3: + # Short form: #RGB -> #RRGGBB + r = int(hex_str[0] * 2, 16) + g = int(hex_str[1] * 2, 16) + b = int(hex_str[2] * 2, 16) + return (r, g, b) + elif len(hex_str) == 6: + r = int(hex_str[0:2], 16) + g = int(hex_str[2:4], 16) + b = int(hex_str[4:6], 16) + return (r, g, b) + except ValueError: + pass + + # Unknown color - default to None (will use original colors) + return None + + +# ============================================================================= +# Image Primitives +# ============================================================================= + +def prim_width(img: np.ndarray) -> int: + """Get image width.""" + return img.shape[1] + + +def prim_height(img: np.ndarray) -> int: + """Get image height.""" + return img.shape[0] + + +def prim_make_image(w: int, h: int, color: List[int]) -> np.ndarray: + """Create a new image filled with color.""" + img = np.zeros((int(h), int(w), 3), dtype=np.uint8) + if color: + img[:, :] = color[:3] + return img + + +def prim_copy(img: np.ndarray) -> np.ndarray: + """Copy an image.""" + return img.copy() + + +def prim_pixel(img: np.ndarray, x: int, y: int) -> List[int]: + """Get pixel at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + return list(img[y, x]) + return [0, 0, 0] + + +def prim_set_pixel(img: np.ndarray, x: int, y: int, color: List[int]) -> np.ndarray: + """Set pixel at (x, y). Returns modified image.""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + img[y, x] = color[:3] + return img + + +def prim_sample(img: np.ndarray, x: float, y: float) -> List[float]: + """Bilinear sample at float coordinates.""" + h, w = img.shape[:2] + x = np.clip(x, 0, w - 1) + y = np.clip(y, 0, h - 1) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + c = (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + return list(c) + + +def prim_channel(img: np.ndarray, c: int) -> np.ndarray: + """Extract a single channel as 2D array.""" + return img[:, :, int(c)].copy() + + +def prim_merge_channels(r: np.ndarray, g: np.ndarray, b: np.ndarray) -> np.ndarray: + """Merge three channels into RGB image.""" + return np.stack([r, g, b], axis=-1).astype(np.uint8) + + +def prim_resize(img: np.ndarray, w: int, h: int, mode: str = "linear") -> np.ndarray: + """Resize image. Mode: linear, nearest, area.""" + w, h = int(w), int(h) + if w < 1 or h < 1: + return img + interp = { + "linear": cv2.INTER_LINEAR, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + }.get(mode, cv2.INTER_LINEAR) + return cv2.resize(img, (w, h), interpolation=interp) + + +def prim_crop(img: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray: + """Crop a region from image.""" + ih, iw = img.shape[:2] + x, y, w, h = int(x), int(y), int(w), int(h) + x = max(0, min(x, iw)) + y = max(0, min(y, ih)) + w = max(0, min(w, iw - x)) + h = max(0, min(h, ih - y)) + return img[y:y + h, x:x + w].copy() + + +def prim_paste(dst: np.ndarray, src: np.ndarray, x: int, y: int) -> np.ndarray: + """Paste src onto dst at position (x, y).""" + dh, dw = dst.shape[:2] + sh, sw = src.shape[:2] + x, y = int(x), int(y) + + # Calculate valid regions + sx1 = max(0, -x) + sy1 = max(0, -y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + dx1 = max(0, x) + dy1 = max(0, y) + dx2 = dx1 + (sx2 - sx1) + dy2 = dy1 + (sy2 - sy1) + + if dx2 > dx1 and dy2 > dy1: + dst[dy1:dy2, dx1:dx2] = src[sy1:sy2, sx1:sx2] + + return dst + + +# ============================================================================= +# Color Primitives +# ============================================================================= + +def prim_rgb(r: float, g: float, b: float) -> List[int]: + """Create RGB color.""" + return [int(np.clip(r, 0, 255)), + int(np.clip(g, 0, 255)), + int(np.clip(b, 0, 255))] + + +def prim_red(c: List[int]) -> int: + return c[0] if c else 0 + + +def prim_green(c: List[int]) -> int: + return c[1] if len(c) > 1 else 0 + + +def prim_blue(c: List[int]) -> int: + return c[2] if len(c) > 2 else 0 + + +def prim_luminance(c: List[int]) -> float: + """Calculate luminance (grayscale value).""" + if not c: + return 0 + return 0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2] + + +def prim_rgb_to_hsv(c: List[int]) -> List[float]: + """Convert RGB to HSV.""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + mx, mn = max(r, g, b), min(r, g, b) + diff = mx - mn + + if diff == 0: + h = 0 + elif mx == r: + h = (60 * ((g - b) / diff) + 360) % 360 + elif mx == g: + h = (60 * ((b - r) / diff) + 120) % 360 + else: + h = (60 * ((r - g) / diff) + 240) % 360 + + s = 0 if mx == 0 else diff / mx + v = mx + + return [h, s * 100, v * 100] + + +def prim_hsv_to_rgb(hsv: List[float]) -> List[int]: + """Convert HSV to RGB.""" + h, s, v = hsv[0], hsv[1] / 100, hsv[2] / 100 + c = v * s + x = c * (1 - abs((h / 60) % 2 - 1)) + m = v - c + + if h < 60: + r, g, b = c, x, 0 + elif h < 120: + r, g, b = x, c, 0 + elif h < 180: + r, g, b = 0, c, x + elif h < 240: + r, g, b = 0, x, c + elif h < 300: + r, g, b = x, 0, c + else: + r, g, b = c, 0, x + + return [int((r + m) * 255), int((g + m) * 255), int((b + m) * 255)] + + +def prim_blend_color(c1: List[int], c2: List[int], alpha: float) -> List[int]: + """Blend two colors.""" + alpha = np.clip(alpha, 0, 1) + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img: np.ndarray) -> List[int]: + """Get average color of image/region.""" + return [int(x) for x in img.mean(axis=(0, 1))] + + +# ============================================================================= +# Image Operations (Bulk) +# ============================================================================= + +def prim_map_pixels(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each pixel: fn(x, y, [r,g,b]) -> [r,g,b].""" + result = img.copy() + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + new_color = fn(x, y, color) + if new_color is not None: + result[y, x] = new_color[:3] + return result + + +def prim_map_rows(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each row: fn(y, row) -> row.""" + result = img.copy() + h = img.shape[0] + for y in range(h): + row = img[y].copy() + new_row = fn(y, row) + if new_row is not None: + result[y] = new_row + return result + + +def prim_for_grid(img: np.ndarray, cell_size: int, fn: Callable) -> np.ndarray: + """Iterate over grid cells: fn(gx, gy, cell_img) for side effects.""" + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + for gy in range(rows): + for gx in range(cols): + y, x = gy * cell_size, gx * cell_size + cell = img[y:y + cell_size, x:x + cell_size] + fn(gx, gy, cell) + + return img + + +def prim_fold_pixels(img: np.ndarray, init: Any, fn: Callable) -> Any: + """Fold over pixels: fn(acc, x, y, color) -> acc.""" + acc = init + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + acc = fn(acc, x, y, color) + return acc + + +# ============================================================================= +# Convolution / Filters +# ============================================================================= + +def prim_convolve(img: np.ndarray, kernel: List[List[float]]) -> np.ndarray: + """Apply convolution kernel.""" + k = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, k) + + +def prim_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Gaussian blur.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Box blur (faster than Gaussian).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_edges(img: np.ndarray, low: int = 50, high: int = 150) -> np.ndarray: + """Canny edge detection, returns grayscale edges.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, int(low), int(high)) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img: np.ndarray) -> np.ndarray: + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.float32) + sx = cv2.Sobel(gray, cv2.CV_32F, 1, 0) + sy = cv2.Sobel(gray, cv2.CV_32F, 0, 1) + magnitude = np.sqrt(sx ** 2 + sy ** 2) + magnitude = np.clip(magnitude, 0, 255).astype(np.uint8) + return cv2.cvtColor(magnitude, cv2.COLOR_GRAY2RGB) + + +def prim_dilate(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological dilation.""" + kernel = np.ones((size, size), np.uint8) + return cv2.dilate(img, kernel, iterations=1) + + +def prim_erode(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological erosion.""" + kernel = np.ones((size, size), np.uint8) + return cv2.erode(img, kernel, iterations=1) + + +# ============================================================================= +# Geometric Transforms +# ============================================================================= + +def prim_translate(img: np.ndarray, dx: float, dy: float) -> np.ndarray: + """Translate image.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_rotate(img: np.ndarray, angle: float, cx: float = None, cy: float = None) -> np.ndarray: + """Rotate image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_scale(img: np.ndarray, sx: float, sy: float, cx: float = None, cy: float = None) -> np.ndarray: + """Scale image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_flip_h(img: np.ndarray) -> np.ndarray: + """Flip horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img: np.ndarray) -> np.ndarray: + """Flip vertically.""" + return cv2.flip(img, 0) + + +def prim_remap(img: np.ndarray, map_x: np.ndarray, map_y: np.ndarray) -> np.ndarray: + """Remap using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), map_y.astype(np.float32), + cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + +def prim_make_coords(w: int, h: int) -> Tuple[np.ndarray, np.ndarray]: + """Create coordinate grid (map_x, map_y).""" + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + return map_x, map_y + + +# ============================================================================= +# Blending +# ============================================================================= + +def prim_blend_images(a: np.ndarray, b: np.ndarray, alpha: float) -> np.ndarray: + """Blend two images. Auto-resizes b to match a if sizes differ.""" + alpha = np.clip(alpha, 0, 1) + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a: np.ndarray, b: np.ndarray, mode: str) -> np.ndarray: + """Blend with various modes: add, multiply, screen, overlay, difference. + Auto-resizes b to match a if sizes differ.""" + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + af = a.astype(float) / 255 + bf = b.astype(float) / 255 + + if mode == "add": + result = af + bf + elif mode == "multiply": + result = af * bf + elif mode == "screen": + result = 1 - (1 - af) * (1 - bf) + elif mode == "overlay": + mask = af < 0.5 + result = np.where(mask, 2 * af * bf, 1 - 2 * (1 - af) * (1 - bf)) + elif mode == "difference": + result = np.abs(af - bf) + elif mode == "lighten": + result = np.maximum(af, bf) + elif mode == "darken": + result = np.minimum(af, bf) + else: + result = af + + return (np.clip(result, 0, 1) * 255).astype(np.uint8) + + +def prim_mask(img: np.ndarray, mask_img: np.ndarray) -> np.ndarray: + """Apply grayscale mask to image.""" + if len(mask_img.shape) == 3: + mask = cv2.cvtColor(mask_img, cv2.COLOR_RGB2GRAY) + else: + mask = mask_img + mask_f = mask.astype(float) / 255 + result = img.astype(float) * mask_f[:, :, np.newaxis] + return result.astype(np.uint8) + + +# ============================================================================= +# Drawing +# ============================================================================= + +# Simple font (5x7 bitmap characters) +FONT_5X7 = { + ' ': [0, 0, 0, 0, 0, 0, 0], + '.': [0, 0, 0, 0, 0, 0, 4], + ':': [0, 0, 4, 0, 4, 0, 0], + '-': [0, 0, 0, 14, 0, 0, 0], + '=': [0, 0, 14, 0, 14, 0, 0], + '+': [0, 4, 4, 31, 4, 4, 0], + '*': [0, 4, 21, 14, 21, 4, 0], + '#': [10, 31, 10, 10, 31, 10, 0], + '%': [19, 19, 4, 8, 25, 25, 0], + '@': [14, 17, 23, 21, 23, 16, 14], + '0': [14, 17, 19, 21, 25, 17, 14], + '1': [4, 12, 4, 4, 4, 4, 14], + '2': [14, 17, 1, 2, 4, 8, 31], + '3': [31, 2, 4, 2, 1, 17, 14], + '4': [2, 6, 10, 18, 31, 2, 2], + '5': [31, 16, 30, 1, 1, 17, 14], + '6': [6, 8, 16, 30, 17, 17, 14], + '7': [31, 1, 2, 4, 8, 8, 8], + '8': [14, 17, 17, 14, 17, 17, 14], + '9': [14, 17, 17, 15, 1, 2, 12], +} + +# Add uppercase letters +for i, c in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ'): + FONT_5X7[c] = [0] * 7 # Placeholder + + +def prim_draw_char(img: np.ndarray, char: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw a character at position.""" + # Use OpenCV's built-in font for simplicity + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, char, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_draw_text(img: np.ndarray, text: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw text at position.""" + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, text, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_fill_rect(img: np.ndarray, x: int, y: int, w: int, h: int, + color: List[int]) -> np.ndarray: + """Fill rectangle.""" + x, y, w, h = int(x), int(y), int(w), int(h) + img[y:y + h, x:x + w] = color[:3] + return img + + +def prim_draw_line(img: np.ndarray, x1: int, y1: int, x2: int, y2: int, + color: List[int], thickness: int = 1) -> np.ndarray: + """Draw line.""" + cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), tuple(color[:3]), int(thickness)) + return img + + +# ============================================================================= +# Math Primitives +# ============================================================================= + +def prim_sin(x: float) -> float: + return math.sin(x) + + +def prim_cos(x: float) -> float: + return math.cos(x) + + +def prim_tan(x: float) -> float: + return math.tan(x) + + +def prim_atan2(y, x): + if hasattr(y, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape) + return math.atan2(y, x) + + +def prim_sqrt(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.sqrt(np.maximum(0, x._data)), x._shape) + if isinstance(x, np.ndarray): + return np.sqrt(np.maximum(0, x)) + return math.sqrt(max(0, x)) + + +def prim_pow(x, y): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + y_data = y._data if hasattr(y, '_data') else y + return Xector(np.power(x._data, y_data), x._shape) + return math.pow(x, y) + + +def prim_abs(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.abs(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.abs(x) + return abs(x) + + +def prim_floor(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.floor(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.floor(x) + return int(math.floor(x)) + + +def prim_ceil(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.ceil(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.ceil(x) + return int(math.ceil(x)) + + +def prim_round(x): + if hasattr(x, '_data'): # Xector + from sexp_effects.primitive_libs.xector import Xector + return Xector(np.round(x._data), x._shape) + if isinstance(x, np.ndarray): + return np.round(x) + return int(round(x)) + + +def prim_min(*args) -> float: + return min(args) + + +def prim_max(*args) -> float: + return max(args) + + +def prim_clamp(x: float, lo: float, hi: float) -> float: + return max(lo, min(hi, x)) + + +def prim_lerp(a: float, b: float, t: float) -> float: + """Linear interpolation.""" + return a + (b - a) * t + + +def prim_mod(a: float, b: float) -> float: + return a % b + + +def prim_random(lo: float = 0, hi: float = 1) -> float: + """Random number from global RNG.""" + return _rng.random(lo, hi) + + +def prim_randint(lo: int, hi: int) -> int: + """Random integer from global RNG.""" + return _rng.randint(lo, hi) + + +def prim_gaussian(mean: float = 0, std: float = 1) -> float: + """Gaussian random from global RNG.""" + return _rng.gaussian(mean, std) + + +def prim_assert(condition, message: str = "Assertion failed"): + """Assert that condition is true, raise error with message if false.""" + if not condition: + raise RuntimeError(f"Assertion error: {message}") + return True + + +# ============================================================================= +# Array/List Primitives +# ============================================================================= + +def prim_length(seq) -> int: + return len(seq) + + +def prim_nth(seq, i: int): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_rest(seq): + return seq[1:] if seq else [] + + +def prim_take(seq, n: int): + return seq[:int(n)] + + +def prim_drop(seq, n: int): + return seq[int(n):] + + +def prim_cons(x, seq): + return [x] + list(seq) + + +def prim_append(*seqs): + result = [] + for s in seqs: + result.extend(s) + return result + + +def prim_reverse(seq): + return list(reversed(seq)) + + +def prim_range(start: int, end: int, step: int = 1) -> List[int]: + return list(range(int(start), int(end), int(step))) + + +def prim_roll(arr: np.ndarray, shift: int, axis: int = 0) -> np.ndarray: + """Circular roll of array.""" + return np.roll(arr, int(shift), axis=int(axis)) + + +def prim_list(*args) -> list: + """Create a list.""" + return list(args) + + +# ============================================================================= +# Primitive Registry +# ============================================================================= + +def prim_add(*args): + return sum(args) + +def prim_sub(a, b=None): + if b is None: + return -a # Unary negation + return a - b + +def prim_mul(*args): + result = 1 + for x in args: + result *= x + return result + +def prim_div(a, b): + return a / b if b != 0 else 0 + +def prim_lt(a, b): + return a < b + +def prim_gt(a, b): + return a > b + +def prim_le(a, b): + return a <= b + +def prim_ge(a, b): + return a >= b + +def prim_eq(a, b): + # Handle None/nil comparisons with numpy arrays + if a is None: + return b is None + if b is None: + return a is None + if isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.array_equal(a, b) + return False # array vs non-array + return a == b + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# ============================================================================= +# Vectorized Bulk Operations (true primitives for composing effects) +# ============================================================================= + +def prim_color_matrix(img: np.ndarray, matrix: List[List[float]]) -> np.ndarray: + """Apply a 3x3 color transformation matrix to all pixels.""" + m = np.array(matrix, dtype=np.float32) + result = img.astype(np.float32) @ m.T + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_adjust(img: np.ndarray, brightness: float = 0, contrast: float = 1) -> np.ndarray: + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img: np.ndarray, amount: float) -> np.ndarray: + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img: np.ndarray) -> np.ndarray: + """Invert all pixel values.""" + return (255 - img).astype(np.uint8) + + +def prim_add_noise(img: np.ndarray, amount: float) -> np.ndarray: + """Add gaussian noise to image.""" + noise = _rng._rng.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img: np.ndarray, levels: int) -> np.ndarray: + """Reduce to N color levels per channel.""" + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_shift_hsv(img: np.ndarray, h: float = 0, s: float = 1, v: float = 1) -> np.ndarray: + """Shift HSV: h=degrees offset, s/v=multipliers.""" + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +# ============================================================================= +# Array Math Primitives (vectorized operations on coordinate arrays) +# ============================================================================= + +def prim_arr_add(a: np.ndarray, b) -> np.ndarray: + """Element-wise addition. b can be array or scalar.""" + return (np.asarray(a) + np.asarray(b)).astype(np.float32) + + +def prim_arr_sub(a: np.ndarray, b) -> np.ndarray: + """Element-wise subtraction. b can be array or scalar.""" + return (np.asarray(a) - np.asarray(b)).astype(np.float32) + + +def prim_arr_mul(a: np.ndarray, b) -> np.ndarray: + """Element-wise multiplication. b can be array or scalar.""" + return (np.asarray(a) * np.asarray(b)).astype(np.float32) + + +def prim_arr_div(a: np.ndarray, b) -> np.ndarray: + """Element-wise division. b can be array or scalar.""" + b = np.asarray(b) + # Avoid division by zero + with np.errstate(divide='ignore', invalid='ignore'): + result = np.asarray(a) / np.where(b == 0, 1e-10, b) + return result.astype(np.float32) + + +def prim_arr_mod(a: np.ndarray, b) -> np.ndarray: + """Element-wise modulo.""" + return (np.asarray(a) % np.asarray(b)).astype(np.float32) + + +def prim_arr_sin(a: np.ndarray) -> np.ndarray: + """Element-wise sine.""" + return np.sin(np.asarray(a)).astype(np.float32) + + +def prim_arr_cos(a: np.ndarray) -> np.ndarray: + """Element-wise cosine.""" + return np.cos(np.asarray(a)).astype(np.float32) + + +def prim_arr_tan(a: np.ndarray) -> np.ndarray: + """Element-wise tangent.""" + return np.tan(np.asarray(a)).astype(np.float32) + + +def prim_arr_sqrt(a: np.ndarray) -> np.ndarray: + """Element-wise square root.""" + return np.sqrt(np.maximum(0, np.asarray(a))).astype(np.float32) + + +def prim_arr_pow(a: np.ndarray, b) -> np.ndarray: + """Element-wise power.""" + return np.power(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_abs(a: np.ndarray) -> np.ndarray: + """Element-wise absolute value.""" + return np.abs(np.asarray(a)).astype(np.float32) + + +def prim_arr_neg(a: np.ndarray) -> np.ndarray: + """Element-wise negation.""" + return (-np.asarray(a)).astype(np.float32) + + +def prim_arr_exp(a: np.ndarray) -> np.ndarray: + """Element-wise exponential.""" + return np.exp(np.asarray(a)).astype(np.float32) + + +def prim_arr_atan2(y: np.ndarray, x: np.ndarray) -> np.ndarray: + """Element-wise atan2(y, x).""" + return np.arctan2(np.asarray(y), np.asarray(x)).astype(np.float32) + + +def prim_arr_min(a: np.ndarray, b) -> np.ndarray: + """Element-wise minimum.""" + return np.minimum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_max(a: np.ndarray, b) -> np.ndarray: + """Element-wise maximum.""" + return np.maximum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_clip(a: np.ndarray, lo, hi) -> np.ndarray: + """Element-wise clip to range.""" + return np.clip(np.asarray(a), lo, hi).astype(np.float32) + + +def prim_arr_where(cond: np.ndarray, a, b) -> np.ndarray: + """Element-wise conditional: where cond is true, use a, else b.""" + return np.where(np.asarray(cond), np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_floor(a: np.ndarray) -> np.ndarray: + """Element-wise floor.""" + return np.floor(np.asarray(a)).astype(np.float32) + + +def prim_arr_lerp(a: np.ndarray, b: np.ndarray, t) -> np.ndarray: + """Element-wise linear interpolation.""" + a, b = np.asarray(a), np.asarray(b) + return (a + (b - a) * t).astype(np.float32) + + +# ============================================================================= +# Coordinate Transformation Primitives +# ============================================================================= + +def prim_polar_from_center(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create polar coordinates (r, theta) from image center. + + Usage: + (polar-from-center img) ; center of image + (polar-from-center img cx cy) ; custom center + (polar-from-center w h cx cy) ; explicit dimensions + + Returns: (r, theta) tuple of arrays + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + + return (r, theta) + + +def prim_cart_from_polar(r: np.ndarray, theta: np.ndarray, cx: float, cy: float) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert polar coordinates back to Cartesian. + + Args: + r: radius array + theta: angle array + cx, cy: center point + + Returns: (x, y) tuple of coordinate arrays + """ + x = (cx + r * np.cos(theta)).astype(np.float32) + y = (cy + r * np.sin(theta)).astype(np.float32) + return (x, y) + + +def prim_normalize_coords(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create normalized coordinates (-1 to 1) from center. + + Returns: (x_norm, y_norm) tuple of arrays where center is (0,0) + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + + return (x_norm, y_norm) + + +def prim_coords_x(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get x/first component from coordinate tuple.""" + return coords[0] + + +def prim_coords_y(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get y/second component from coordinate tuple.""" + return coords[1] + + +def prim_make_coords_centered(w: int, h: int, cx: float = None, cy: float = None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create coordinate grids centered at (cx, cy). + Like make-coords but returns coordinates relative to center. + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + return (x_coords - cx, y_coords - cy) + + +# ============================================================================= +# Specialized Distortion Primitives +# ============================================================================= + +def prim_wave_displace(w: int, h: int, axis: str, freq: float, amp: float, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create wave displacement maps. + + Args: + w, h: dimensions + axis: "x" (horizontal waves) or "y" (vertical waves) + freq: wave frequency (waves per image width/height) + amp: wave amplitude in pixels + phase: phase offset in radians + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + + if axis == "x" or axis == "horizontal": + # Horizontal waves: displace x based on y + wave = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + map_x = map_x + wave + elif axis == "y" or axis == "vertical": + # Vertical waves: displace y based on x + wave = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_y = map_y + wave + elif axis == "both": + wave_x = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + wave_y = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_x = map_x + wave_x + map_y = map_y + wave_y + + return (map_x, map_y) + + +def prim_swirl_displace(w: int, h: int, strength: float, radius: float = 0.5, + cx: float = None, cy: float = None, falloff: str = "quadratic") -> Tuple[np.ndarray, np.ndarray]: + """ + Create swirl displacement maps. + + Args: + w, h: dimensions + strength: swirl strength in radians + radius: effect radius as fraction of max dimension + cx, cy: center (defaults to image center) + falloff: "linear", "quadratic", or "gaussian" + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + radius_px = max(w, h) * radius + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + angle = np.arctan2(dy, dx) + + # Normalized distance for falloff + norm_dist = dist / radius_px + + # Calculate falloff factor + if falloff == "linear": + factor = np.maximum(0, 1 - norm_dist) + elif falloff == "gaussian": + factor = np.exp(-norm_dist**2 * 2) + else: # quadratic + factor = np.maximum(0, 1 - norm_dist**2) + + # Apply swirl rotation + new_angle = angle + strength * factor + + # Calculate new coordinates + map_x = (cx + dist * np.cos(new_angle)).astype(np.float32) + map_y = (cy + dist * np.sin(new_angle)).astype(np.float32) + + return (map_x, map_y) + + +def prim_fisheye_displace(w: int, h: int, strength: float, cx: float = None, cy: float = None, + zoom_correct: bool = True) -> Tuple[np.ndarray, np.ndarray]: + """ + Create fisheye/barrel distortion displacement maps. + + Args: + w, h: dimensions + strength: distortion strength (-1 to 1, positive=bulge, negative=pinch) + cx, cy: center (defaults to image center) + zoom_correct: auto-zoom to hide black edges + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize coordinates + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + r = np.sqrt(x_norm**2 + y_norm**2) + + # Apply barrel/pincushion distortion + if strength > 0: + r_distorted = r * (1 + strength * r**2) + else: + r_distorted = r / (1 - strength * r**2 + 0.001) + + # Calculate scale factor + with np.errstate(divide='ignore', invalid='ignore'): + scale = np.where(r > 0, r_distorted / r, 1) + + # Apply zoom correction + if zoom_correct and strength > 0: + zoom = 1 + strength * 0.5 + scale = scale / zoom + + # Calculate new coordinates + map_x = (x_norm * scale * (w / 2) + cx).astype(np.float32) + map_y = (y_norm * scale * (h / 2) + cy).astype(np.float32) + + return (map_x, map_y) + + +def prim_kaleidoscope_displace(w: int, h: int, segments: int, rotation: float = 0, + cx: float = None, cy: float = None, zoom: float = 1.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create kaleidoscope displacement maps. + + Args: + w, h: dimensions + segments: number of symmetry segments (3-16) + rotation: rotation angle in degrees + cx, cy: center (defaults to image center) + zoom: zoom factor + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + segments = max(3, min(int(segments), 16)) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + segment_angle = 2 * np.pi / segments + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Translate to center + x_centered = x_coords - cx + y_centered = y_coords - cy + + # Convert to polar + r = np.sqrt(x_centered**2 + y_centered**2) + theta = np.arctan2(y_centered, x_centered) + + # Apply rotation + theta = theta - np.deg2rad(rotation) + + # Fold angle into first segment and mirror + theta_normalized = theta % (2 * np.pi) + segment_idx = (theta_normalized / segment_angle).astype(int) + theta_in_segment = theta_normalized - segment_idx * segment_angle + + # Mirror alternating segments + mirror_mask = (segment_idx % 2) == 1 + theta_in_segment = np.where(mirror_mask, segment_angle - theta_in_segment, theta_in_segment) + + # Apply zoom + r = r / zoom + + # Convert back to Cartesian + map_x = (r * np.cos(theta_in_segment) + cx).astype(np.float32) + map_y = (r * np.sin(theta_in_segment) + cy).astype(np.float32) + + return (map_x, map_y) + + +# ============================================================================= +# Character/ASCII Art Primitives +# ============================================================================= + +# Character sets ordered by visual density (light to dark) +CHAR_ALPHABETS = { + "standard": " .`'^\",:;Il!i><~+_-?][}{1)(|/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$", + "blocks": " ░▒▓█", + "simple": " .-:=+*#%@", + "digits": " 0123456789", +} + +# Global atlas cache: keyed on (frozenset(chars), cell_size) -> +# (atlas_array, char_to_idx) where atlas_array is (N, cell_size, cell_size) uint8. +_char_atlas_cache = {} +_CHAR_ATLAS_CACHE_MAX = 32 + + +def _get_char_atlas(alphabet: str, cell_size: int) -> dict: + """Get or create character atlas for alphabet (legacy dict version).""" + atlas_arr, char_to_idx = _get_render_atlas(alphabet, cell_size) + # Build legacy dict from array + idx_to_char = {v: k for k, v in char_to_idx.items()} + return {idx_to_char[i]: atlas_arr[i] for i in range(len(atlas_arr))} + + +def _get_render_atlas(unique_chars_or_alphabet, cell_size: int): + """Get or build a stacked numpy atlas for vectorised rendering. + + Args: + unique_chars_or_alphabet: Either an alphabet name (str looked up in + CHAR_ALPHABETS), a literal character string, or a set/frozenset + of characters. + cell_size: Pixel size of each cell. + + Returns: + (atlas_array, char_to_idx) where + atlas_array: (num_chars, cell_size, cell_size) uint8 masks + char_to_idx: dict mapping character -> index in atlas_array + """ + if isinstance(unique_chars_or_alphabet, (set, frozenset)): + chars_tuple = tuple(sorted(unique_chars_or_alphabet)) + else: + resolved = CHAR_ALPHABETS.get(unique_chars_or_alphabet, unique_chars_or_alphabet) + chars_tuple = tuple(resolved) + + cache_key = (chars_tuple, cell_size) + cached = _char_atlas_cache.get(cache_key) + if cached is not None: + return cached + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + n = len(chars_tuple) + atlas = np.zeros((n, cell_size, cell_size), dtype=np.uint8) + char_to_idx = {} + + for i, char in enumerate(chars_tuple): + char_to_idx[char] = i + if char and char != ' ': + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + cv2.putText(atlas[i], char, (text_x, text_y), + font, font_scale, 255, thickness, cv2.LINE_AA) + except Exception: + pass + + # Evict oldest entry if cache is full + if len(_char_atlas_cache) >= _CHAR_ATLAS_CACHE_MAX: + _char_atlas_cache.pop(next(iter(_char_atlas_cache))) + + _char_atlas_cache[cache_key] = (atlas, char_to_idx) + return atlas, char_to_idx + + +def prim_cell_sample(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Sample image into cell grid, returning average colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + + Args: + img: source image + cell_size: size of each cell in pixels + + Returns: (colors, luminances) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32)) + + # Crop to exact grid then block-average via cv2 area interpolation. + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + return (colors, luminances) + + +def cell_sample_extended(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray, List[List[ZoneContext]]]: + """ + Sample image into cell grid, returning colors, luminances, and full zone contexts. + + Args: + img: source image (RGB) + cell_size: size of each cell in pixels + + Returns: (colors, luminances, zone_contexts) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + - zone_contexts: 2D list of ZoneContext objects with full cell data + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32), + [[ZoneContext(0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]]) + + # Crop to grid + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + + # Reshape and average + reshaped = cropped.reshape(rows, cell_size, cols, cell_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(np.uint8) + + # Compute luminance (0-255) + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + # Normalize colors to 0-1 for HSV/saturation calculations + colors_float = colors.astype(np.float32) / 255.0 + + # Compute HSV values for each cell + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + diff = max_c - min_c + + # Saturation + saturation = np.where(max_c > 0, diff / max_c, 0) + + # Hue (0-360) + hue = np.zeros((rows, cols), dtype=np.float32) + # Avoid division by zero + mask = diff > 0 + r, g, b = colors_float[:, :, 0], colors_float[:, :, 1], colors_float[:, :, 2] + + # Red is max + red_max = mask & (max_c == r) + hue[red_max] = 60 * (((g[red_max] - b[red_max]) / diff[red_max]) % 6) + + # Green is max + green_max = mask & (max_c == g) + hue[green_max] = 60 * ((b[green_max] - r[green_max]) / diff[green_max] + 2) + + # Blue is max + blue_max = mask & (max_c == b) + hue[blue_max] = 60 * ((r[blue_max] - g[blue_max]) / diff[blue_max] + 4) + + # Ensure hue is in 0-360 range + hue = hue % 360 + + # Build zone contexts + zone_contexts = [] + for row in range(rows): + row_contexts = [] + for col in range(cols): + ctx = ZoneContext( + row=row, + col=col, + row_norm=row / max(1, rows - 1) if rows > 1 else 0.5, + col_norm=col / max(1, cols - 1) if cols > 1 else 0.5, + luminance=luminances[row, col] / 255.0, # Normalize to 0-1 + saturation=float(saturation[row, col]), + hue=float(hue[row, col]), + r=float(colors_float[row, col, 0]), + g=float(colors_float[row, col, 1]), + b=float(colors_float[row, col, 2]), + ) + row_contexts.append(ctx) + zone_contexts.append(row_contexts) + + return (colors, luminances, zone_contexts) + + +def prim_luminance_to_chars(luminances: np.ndarray, alphabet: str, contrast: float = 1.0) -> List[List[str]]: + """ + Map luminance values to characters from alphabet. + + Args: + luminances: (rows, cols) array of brightness values 0-255 + alphabet: character set name or literal string (light to dark) + contrast: contrast boost factor + + Returns: 2D list of single-character strings + """ + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + num_chars = len(chars) + + # Apply contrast + lum = luminances.astype(np.float32) + if contrast != 1.0: + lum = (lum - 128) * contrast + 128 + lum = np.clip(lum, 0, 255) + + # Map to indices + indices = ((lum / 255) * (num_chars - 1)).astype(np.int32) + indices = np.clip(indices, 0, num_chars - 1) + + # Vectorised conversion via numpy char array lookup + chars_arr = np.array(list(chars)) + char_grid = chars_arr[indices.ravel()].reshape(indices.shape) + + return char_grid.tolist() + + +def prim_render_char_grid(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + cell_size: int, color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False) -> np.ndarray: + """ + Render a grid of characters onto an image. + + Uses vectorised numpy operations instead of per-cell Python loops: + the character atlas is looked up via fancy indexing and the full + mask + colour image are assembled in bulk. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + cell_size: size of each cell + color_mode: "color" (original colors), "mono" (white), "invert", + or any color name/hex value ("green", "lime", "#00ff00") + background_color: background color name/hex ("black", "navy", "#001100") + invert_colors: if True, swap foreground and background colors + + Returns: rendered image + """ + # Parse color_mode - may be a named color or hex value + fg_color = parse_color(color_mode) + + # Parse background_color + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + # Handle invert_colors - swap fg and bg + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + + # --- Build atlas & index grid --- + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + atlas, char_to_idx = _get_render_atlas(unique_chars, cell_size) + + # Convert 2D char list to index array using ordinal lookup table + # (avoids per-cell Python dict lookup). + space_idx = char_to_idx.get(' ', 0) + max_ord = max(ord(ch) for ch in char_to_idx) + 1 + ord_lookup = np.full(max_ord, space_idx, dtype=np.int32) + for ch, idx in char_to_idx.items(): + if ch: + ord_lookup[ord(ch)] = idx + + flat = [ch for row in chars for ch in row] + ords = np.frombuffer(np.array(flat, dtype='U1'), dtype=np.uint32) + char_indices = ord_lookup[ords].reshape(rows, cols) + + # --- Vectorised mask assembly --- + # atlas[char_indices] -> (rows, cols, cell_size, cell_size) + # Transpose to (rows, cell_size, cols, cell_size) then reshape to full image. + all_masks = atlas[char_indices] + full_mask = all_masks.transpose(0, 2, 1, 3).reshape(h, w) + + # Expand per-cell colours to per-pixel (only when needed). + need_color_full = (color_mode in ("color", "invert") + or (fg_color is None and color_mode != "mono")) + + if need_color_full: + color_full = np.repeat( + np.repeat(colors[:rows, :cols], cell_size, axis=0), + cell_size, axis=1) + + # --- Vectorised colour composite --- + # Use element-wise multiply/np.where instead of boolean-indexed scatter + # for much better memory access patterns. + mask_u8 = (full_mask > 0).astype(np.uint8)[:, :, np.newaxis] + + if color_mode == "invert": + # Background is source colour; characters are black. + # result = color_full * (1 - mask) + result = color_full * (1 - mask_u8) + elif fg_color is not None: + # Fixed foreground colour on background. + fg = np.array(fg_color, dtype=np.uint8) + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, fg, bg_arr).astype(np.uint8) + elif color_mode == "mono": + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, np.uint8(255), bg_arr).astype(np.uint8) + else: + # "color" mode – each cell uses its source colour on bg. + if bg == [0, 0, 0]: + result = color_full * mask_u8 + else: + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, color_full, bg_arr).astype(np.uint8) + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_render_char_grid_fx(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + luminances: np.ndarray, cell_size: int, + color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False, + char_jitter: float = 0.0, + char_scale: float = 1.0, + char_rotation: float = 0.0, + char_hue_shift: float = 0.0, + jitter_source: str = "none", + scale_source: str = "none", + rotation_source: str = "none", + hue_source: str = "none") -> np.ndarray: + """ + Render a grid of characters with per-character effects. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + luminances: (rows, cols) array of luminance values (0-255) + cell_size: size of each cell + color_mode: "color", "mono", "invert", or any color name/hex + background_color: background color name/hex + invert_colors: if True, swap foreground and background colors + char_jitter: base jitter amount in pixels + char_scale: base scale factor (1.0 = normal) + char_rotation: base rotation in degrees + char_hue_shift: base hue shift in degrees (0-360) + jitter_source: source for jitter modulation ("none", "luminance", "position", "random") + scale_source: source for scale modulation + rotation_source: source for rotation modulation + hue_source: source for hue modulation + + Per-character effect sources: + "none" - use base value only + "luminance" - modulate by cell luminance (0-1) + "inv_luminance" - modulate by inverse luminance (dark = high) + "saturation" - modulate by cell color saturation + "position_x" - modulate by horizontal position (0-1) + "position_y" - modulate by vertical position (0-1) + "position_diag" - modulate by diagonal position + "random" - random per-cell value (deterministic from position) + "center_dist" - distance from center (0=center, 1=corner) + + Returns: rendered image + """ + # Parse colors + fg_color = parse_color(color_mode) + + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + result = np.full((h, w, 3), bg, dtype=np.uint8) + + # Normalize luminances to 0-1 + lum_normalized = luminances.astype(np.float32) / 255.0 + + # Compute saturation from colors + colors_float = colors.astype(np.float32) / 255.0 + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + saturation = np.where(max_c > 0, (max_c - min_c) / max_c, 0) + + # Helper to get modulation value for a cell + def get_mod_value(source: str, r: int, c: int) -> float: + if source == "none": + return 1.0 + elif source == "luminance": + return lum_normalized[r, c] + elif source == "inv_luminance": + return 1.0 - lum_normalized[r, c] + elif source == "saturation": + return saturation[r, c] + elif source == "position_x": + return c / max(1, cols - 1) if cols > 1 else 0.5 + elif source == "position_y": + return r / max(1, rows - 1) if rows > 1 else 0.5 + elif source == "position_diag": + px = c / max(1, cols - 1) if cols > 1 else 0.5 + py = r / max(1, rows - 1) if rows > 1 else 0.5 + return (px + py) / 2.0 + elif source == "random": + # Deterministic random based on position + seed = (r * 1000 + c) % 10000 + return ((seed * 9301 + 49297) % 233280) / 233280.0 + elif source == "center_dist": + cx, cy = (cols - 1) / 2.0, (rows - 1) / 2.0 + dx = (c - cx) / max(1, cx) if cx > 0 else 0 + dy = (r - cy) / max(1, cy) if cy > 0 else 0 + return min(1.0, math.sqrt(dx*dx + dy*dy)) + else: + return 1.0 + + # Build character atlas at base size + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, we need to render characters larger then transform + max_scale = max(1.0, char_scale * 1.5) # Allow headroom for scaling + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with effects + for r in range(rows): + for c in range(cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell modulation values + jitter_mod = get_mod_value(jitter_source, r, c) + scale_mod = get_mod_value(scale_source, r, c) + rot_mod = get_mod_value(rotation_source, r, c) + hue_mod = get_mod_value(hue_source, r, c) + + # Compute effective values + eff_jitter = char_jitter * jitter_mod + eff_scale = char_scale * (0.5 + 0.5 * scale_mod) if scale_source != "none" else char_scale + eff_rotation = char_rotation * (rot_mod * 2 - 1) # -1 to 1 range + eff_hue_shift = char_hue_shift * hue_mod + + # Apply transformations + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + # Fill cell with source color first + cy1 = max(0, r * cell_size) + cy2 = min(h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + # Convert to HSV, shift hue, convert back + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + # Cast to int to avoid uint8 overflow, then back to uint8 + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < h and 0 <= px < w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def _render_with_cell_effect( + frame: np.ndarray, + chars: List[List[str]], + colors: np.ndarray, + luminances: np.ndarray, + zone_contexts: List[List['ZoneContext']], + cell_size: int, + bg_color: tuple, + fg_color: tuple, + color_mode: str, + cell_effect, # Lambda or callable: (cell_image, zone_dict) -> cell_image + extra_params: dict, + interp, + env, + result: np.ndarray, +) -> np.ndarray: + """ + Render ASCII art using a cell_effect lambda for arbitrary per-cell transforms. + + Each character is rendered to a cell image, the cell_effect is called with + (cell_image, zone_dict), and the returned cell is composited into result. + + This allows arbitrary effects (rotate, blur, etc.) to be applied per-character. + """ + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + out_h, out_w = result.shape[:2] + + # Build character atlas (cell-sized colored characters on transparent bg) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + # Helper to render a single character cell + def render_char_cell(char: str, color: np.ndarray) -> np.ndarray: + """Render a character onto a cell-sized RGB image.""" + cell = np.full((cell_size, cell_size, 3), bg_color, dtype=np.uint8) + if not char or char == ' ': + return cell + + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + + # Render character in white on mask, then apply color + mask = np.zeros((cell_size, cell_size), dtype=np.uint8) + cv2.putText(mask, char, (text_x, text_y), font, font_scale, 255, thickness, cv2.LINE_AA) + + # Apply color where mask is set + for ch in range(3): + cell[:, :, ch] = np.where(mask > 0, color[ch], bg_color[ch]) + except: + pass + + return cell + + # Helper to evaluate cell_effect (handles artdag Lambda objects) + def eval_cell_effect(cell_img: np.ndarray, zone_dict: dict) -> np.ndarray: + """Call cell_effect with (cell_image, zone_dict), handle Lambda objects.""" + if callable(cell_effect): + return cell_effect(cell_img, zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(cell_effect, ArtdagLambda): + # Build env with closure values + eval_env = dict(cell_effect.closure) if cell_effect.closure else {} + # Bind lambda parameters + if len(cell_effect.params) >= 2: + eval_env[cell_effect.params[0]] = cell_img + eval_env[cell_effect.params[1]] = zone_dict + elif len(cell_effect.params) == 1: + # Single param gets zone_dict with cell as 'cell' key + zone_dict['cell'] = cell_img + eval_env[cell_effect.params[0]] = zone_dict + + # Add primitives to eval env + eval_env.update(PRIMITIVES) + + # Add effect runner - allows calling any loaded sexp effect on a cell + # Usage: (apply-effect "effect_name" cell {"param" value ...}) + # Or: (apply-effect "effect_name" cell) for defaults + def apply_effect_fn(effect_name, frame, params=None): + """Run a loaded sexp effect on a frame (cell).""" + if interp and hasattr(interp, 'run_effect'): + if params is None: + params = {} + result, _ = interp.run_effect(effect_name, frame, params, {}) + return result + return frame + eval_env['apply-effect'] = apply_effect_fn + + # Also inject loaded effects directly as callable functions + # These wrappers take positional args in common order for each effect + # Usage: (blur cell 5) or (rotate cell 45) etc. + if interp and hasattr(interp, 'effects'): + for effect_name in interp.effects: + # Create a wrapper that calls run_effect with positional-to-named mapping + def make_effect_fn(name): + def effect_fn(frame, *args): + # Map common positional args to named params + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 1: + params['offset_x'] = args[0] + if len(args) >= 2: + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['block_size'] = args[0] + elif name == 'wave' and len(args) >= 1: + params['amplitude'] = args[0] + if len(args) >= 2: + params['frequency'] = args[1] + elif name == 'noise' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'posterize' and len(args) >= 1: + params['levels'] = args[0] + elif name == 'threshold' and len(args) >= 1: + params['level'] = args[0] + elif name == 'sharpen' and len(args) >= 1: + params['amount'] = args[0] + elif len(args) == 1 and isinstance(args[0], dict): + # Accept dict as single arg + params = args[0] + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + eval_env[effect_name] = make_effect_fn(effect_name) + + result = artdag_evaluate(cell_effect.body, eval_env) + if isinstance(result, np.ndarray): + return result + return cell_img + except ImportError: + pass + + # Fallback: return cell unchanged + return cell_img + + # Render each cell + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + zone = zone_contexts[r][c] + + # Determine character color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + color = np.array([0, 0, 0], dtype=np.uint8) + else: + color = colors[r, c].copy() + + # Render character to cell image + cell_img = render_char_cell(char, color) + + # Build zone dict + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + 'color': color.tolist(), + 'cell_size': cell_size, + } + # Add extra params (energy, rotation_scale, etc.) + if extra_params: + zone_dict.update(extra_params) + + # Call cell_effect + modified_cell = eval_cell_effect(cell_img, zone_dict) + + # Ensure result is valid + if modified_cell is None or not isinstance(modified_cell, np.ndarray): + modified_cell = cell_img + if modified_cell.shape[:2] != (cell_size, cell_size): + # Resize if cell size changed + modified_cell = cv2.resize(modified_cell, (cell_size, cell_size)) + if len(modified_cell.shape) == 2: + # Convert grayscale to RGB + modified_cell = cv2.cvtColor(modified_cell, cv2.COLOR_GRAY2RGB) + + # Composite into result + y1 = r * cell_size + x1 = c * cell_size + y2 = min(y1 + cell_size, out_h) + x2 = min(x1 + cell_size, out_w) + ch = y2 - y1 + cw = x2 - x1 + result[y1:y2, x1:x2] = modified_cell[:ch, :cw] + + # Resize to match original frame if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + bg = list(bg_color) + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int, + char_size_override: int, # If set, overrides cols-based calculation + alphabet: str, + color_mode: str, + background: str, + contrast: float, + char_hue_expr, # Expression, literal, or None + char_sat_expr, # Expression, literal, or None + char_bright_expr, # Expression, literal, or None + char_scale_expr, # Expression, literal, or None + char_rotation_expr, # Expression, literal, or None + char_jitter_expr, # Expression, literal, or None + interp, # Interpreter for expression evaluation + env, # Environment with bound values + extra_params=None, # Extra params to include in zone dict for lambdas + cell_effect=None, # Lambda (cell_image, zone_dict) -> cell_image for arbitrary cell effects +) -> np.ndarray: + """ + Render ASCII art with per-zone expression-driven transforms. + + Args: + frame: Source image (H, W, 3) RGB uint8 + cols: Number of character columns + char_size_override: If set, use this cell size instead of cols-based + alphabet: Character set name or literal string + color_mode: "color", "mono", "invert", or color name/hex + background: Background color name or hex + contrast: Contrast boost for character selection + char_hue_expr: Expression for hue shift (evaluated per zone) + char_sat_expr: Expression for saturation adjustment (evaluated per zone) + char_bright_expr: Expression for brightness adjustment (evaluated per zone) + char_scale_expr: Expression for scale factor (evaluated per zone) + char_rotation_expr: Expression for rotation degrees (evaluated per zone) + char_jitter_expr: Expression for position jitter (evaluated per zone) + interp: Interpreter instance for expression evaluation + env: Environment with bound variables + cell_effect: Optional lambda that receives (cell_image, zone_dict) and returns + a modified cell_image. When provided, each character is rendered + to a cell image, passed to this lambda, and the result composited. + This allows arbitrary effects to be applied per-character. + + Zone variables available in expressions: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: Rendered image + """ + h, w = frame.shape[:2] + # Use char_size if provided, otherwise calculate from cols + if char_size_override is not None: + cell_size = max(4, int(char_size_override)) + else: + cell_size = max(4, w // cols) + + # Get zone data using extended sampling + colors, luminances, zone_contexts = cell_sample_extended(frame, cell_size) + + # Convert luminances to characters + chars = prim_luminance_to_chars(luminances, alphabet, contrast) + + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + + # Parse colors + fg_color = parse_color(color_mode) + if isinstance(background, (list, tuple)): + bg_color = tuple(int(c) for c in background[:3]) + else: + bg_color = parse_color(background) + if bg_color is None: + bg_color = (0, 0, 0) + + # Arrays for per-zone transform values + hue_shifts = np.zeros((grid_rows, grid_cols), dtype=np.float32) + saturations = np.ones((grid_rows, grid_cols), dtype=np.float32) + brightness = np.ones((grid_rows, grid_cols), dtype=np.float32) + scales = np.ones((grid_rows, grid_cols), dtype=np.float32) + rotations = np.zeros((grid_rows, grid_cols), dtype=np.float32) + jitters = np.zeros((grid_rows, grid_cols), dtype=np.float32) + + # Helper to evaluate expression or return literal value + def eval_expr(expr, zone, char): + if expr is None: + return None + if isinstance(expr, (int, float)): + return expr + + # Build zone dict for lambda calls + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + } + # Add extra params (energy, rotation_scale, etc.) for lambdas to access + if extra_params: + zone_dict.update(extra_params) + + # Check if it's a Python callable + if callable(expr): + return expr(zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(expr, ArtdagLambda): + # Build env with zone dict and any closure values + eval_env = dict(expr.closure) if expr.closure else {} + # Bind the lambda parameter to zone_dict + if expr.params: + eval_env[expr.params[0]] = zone_dict + return artdag_evaluate(expr.body, eval_env) + except ImportError: + pass + + # It's an expression - evaluate with zone context (sexp_effects style) + return interp.eval_with_zone(expr, env, zone) + + # Evaluate expressions for each zone + for r in range(grid_rows): + for c in range(grid_cols): + zone = zone_contexts[r][c] + char = chars[r][c] + + val = eval_expr(char_hue_expr, zone, char) + if val is not None: + hue_shifts[r, c] = float(val) + + val = eval_expr(char_sat_expr, zone, char) + if val is not None: + saturations[r, c] = float(val) + + val = eval_expr(char_bright_expr, zone, char) + if val is not None: + brightness[r, c] = float(val) + + val = eval_expr(char_scale_expr, zone, char) + if val is not None: + scales[r, c] = float(val) + + val = eval_expr(char_rotation_expr, zone, char) + if val is not None: + rotations[r, c] = float(val) + + val = eval_expr(char_jitter_expr, zone, char) + if val is not None: + jitters[r, c] = float(val) + + # Now render with computed transform arrays + out_h, out_w = grid_rows * cell_size, grid_cols * cell_size + bg = list(bg_color) + result = np.full((out_h, out_w, 3), bg, dtype=np.uint8) + + # If cell_effect is provided, use the cell-mapper rendering path + if cell_effect is not None: + return _render_with_cell_effect( + frame, chars, colors, luminances, zone_contexts, + cell_size, bg_color, fg_color, color_mode, + cell_effect, extra_params, interp, env, result + ) + + # Build character atlas + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, render characters larger then transform + max_scale = max(1.0, np.max(scales) * 1.5) + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with per-zone effects + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell values + eff_scale = scales[r, c] + eff_rotation = rotations[r, c] + eff_jitter = jitters[r, c] + eff_hue_shift = hue_shifts[r, c] + eff_brightness = brightness[r, c] + eff_saturation = saturations[r, c] + + # Apply transformations to character + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + cy1 = max(0, r * cell_size) + cy2 = min(out_h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(out_w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode - use source colors + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply saturation adjustment + if abs(eff_saturation - 1.0) > 0.01 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_sat = np.clip(int(color_hsv[0, 0, 1] * eff_saturation), 0, 255) + color_hsv[0, 0, 1] = np.uint8(new_sat) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply brightness adjustment + if abs(eff_brightness - 1.0) > 0.01: + color = np.clip(color.astype(np.float32) * eff_brightness, 0, 255).astype(np.uint8) + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < out_h and 0 <= px < out_w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_make_char_grid(rows: int, cols: int, fill_char: str = " ") -> List[List[str]]: + """Create a character grid filled with a character.""" + return [[fill_char for _ in range(cols)] for _ in range(rows)] + + +def prim_set_char(chars: List[List[str]], row: int, col: int, char: str) -> List[List[str]]: + """Set a character at position (returns modified copy).""" + result = [r[:] for r in chars] # shallow copy rows + if 0 <= row < len(result) and 0 <= col < len(result[0]): + result[row][col] = char + return result + + +def prim_get_char(chars: List[List[str]], row: int, col: int) -> str: + """Get character at position.""" + if 0 <= row < len(chars) and 0 <= col < len(chars[0]): + return chars[row][col] + return " " + + +def prim_char_grid_dimensions(chars: List[List[str]]) -> Tuple[int, int]: + """Get (rows, cols) of character grid.""" + if not chars: + return (0, 0) + return (len(chars), len(chars[0]) if chars[0] else 0) + + +def prim_alphabet_char(alphabet: str, index: int) -> str: + """Get character at index from alphabet (wraps around).""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + if not chars: + return " " + return chars[int(index) % len(chars)] + + +def prim_alphabet_length(alphabet: str) -> int: + """Get length of alphabet.""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + return len(chars) + + +def prim_map_char_grid(chars: List[List[str]], luminances: np.ndarray, fn: Callable) -> List[List[str]]: + """ + Map a function over character grid. + + fn receives (row, col, char, luminance) and returns new character. + This allows per-cell character selection based on position, brightness, etc. + + Example: + (map-char-grid chars luminances + (lambda (r c ch lum) + (if (> lum 128) + (alphabet-char "blocks" (floor (/ lum 50))) + ch))) + """ + if not chars or not chars[0]: + return chars + + rows = len(chars) + cols = len(chars[0]) + result = [] + + for r in range(rows): + row = [] + for c in range(cols): + ch = chars[r][c] + lum = float(luminances[r, c]) if r < luminances.shape[0] and c < luminances.shape[1] else 0 + new_ch = fn(r, c, ch, lum) + row.append(str(new_ch) if new_ch else " ") + result.append(row) + + return result + + +def prim_map_colors(colors: np.ndarray, fn: Callable) -> np.ndarray: + """ + Map a function over color grid. + + fn receives (row, col, color) and returns new [r, g, b]. + Color is a list [r, g, b]. + """ + if colors.size == 0: + return colors + + rows, cols = colors.shape[:2] + result = colors.copy() + + for r in range(rows): + for c in range(cols): + color = list(colors[r, c]) + new_color = fn(r, c, color) + if new_color is not None: + result[r, c] = new_color[:3] + + return result + + +# ============================================================================= +# Glitch Art Primitives +# ============================================================================= + +def prim_pixelsort(img: np.ndarray, sort_by: str = "lightness", + threshold_low: float = 50, threshold_high: float = 200, + angle: float = 0, reverse: bool = False) -> np.ndarray: + """ + Pixel sorting glitch effect. + + Args: + img: source image + sort_by: "lightness", "hue", "saturation", "red", "green", "blue" + threshold_low: pixels below this aren't sorted + threshold_high: pixels above this aren't sorted + angle: 0 = horizontal, 90 = vertical + reverse: reverse sort order + """ + h, w = img.shape[:2] + + # Rotate for vertical sorting + if 45 <= (angle % 180) <= 135: + frame = np.transpose(img, (1, 0, 2)) + h, w = frame.shape[:2] + rotated = True + else: + frame = img + rotated = False + + result = frame.copy() + + # Get sort values + if sort_by == "lightness": + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + elif sort_by == "hue": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 0].astype(np.float32) + elif sort_by == "saturation": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 1].astype(np.float32) + elif sort_by == "red": + sort_values = frame[:, :, 0].astype(np.float32) + elif sort_by == "green": + sort_values = frame[:, :, 1].astype(np.float32) + elif sort_by == "blue": + sort_values = frame[:, :, 2].astype(np.float32) + else: + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + + # Create mask + mask = (sort_values >= threshold_low) & (sort_values <= threshold_high) + + # Sort each row + for y in range(h): + row = result[y].copy() + row_mask = mask[y] + row_values = sort_values[y] + + # Find contiguous segments + segments = [] + start = None + for i, val in enumerate(row_mask): + if val and start is None: + start = i + elif not val and start is not None: + segments.append((start, i)) + start = None + if start is not None: + segments.append((start, len(row_mask))) + + # Sort each segment + for seg_start, seg_end in segments: + if seg_end - seg_start > 1: + segment_values = row_values[seg_start:seg_end] + sort_indices = np.argsort(segment_values) + if reverse: + sort_indices = sort_indices[::-1] + row[seg_start:seg_end] = row[seg_start:seg_end][sort_indices] + + result[y] = row + + # Rotate back + if rotated: + result = np.transpose(result, (1, 0, 2)) + + return np.ascontiguousarray(result) + + +def prim_datamosh(img: np.ndarray, prev_frame: np.ndarray, + block_size: int = 32, corruption: float = 0.3, + max_offset: int = 50, color_corrupt: bool = True) -> np.ndarray: + """ + Datamosh/glitch block corruption effect. + + Args: + img: current frame + prev_frame: previous frame (or None) + block_size: size of corruption blocks + corruption: probability 0-1 of corrupting each block + max_offset: maximum pixel shift + color_corrupt: also apply color channel shifts + """ + if corruption <= 0: + return img.copy() + + block_size = max(8, min(int(block_size), 128)) + h, w = img.shape[:2] + result = img.copy() + + for by in range(0, h, block_size): + for bx in range(0, w, block_size): + bh = min(block_size, h - by) + bw = min(block_size, w - bx) + + if _rng.random() < corruption: + corruption_type = _rng.randint(0, 3) + + if corruption_type == 0 and max_offset > 0: + # Shift + ox = _rng.randint(-max_offset, max_offset) + oy = _rng.randint(-max_offset, max_offset) + src_x = max(0, min(bx + ox, w - bw)) + src_y = max(0, min(by + oy, h - bh)) + result[by:by+bh, bx:bx+bw] = img[src_y:src_y+bh, src_x:src_x+bw] + + elif corruption_type == 1 and prev_frame is not None: + # Duplicate from previous frame + if prev_frame.shape == img.shape: + result[by:by+bh, bx:bx+bw] = prev_frame[by:by+bh, bx:bx+bw] + + elif corruption_type == 2 and color_corrupt: + # Color channel shift + block = result[by:by+bh, bx:bx+bw].copy() + shift = _rng.randint(1, 3) + channel = _rng.randint(0, 2) + block[:, :, channel] = np.roll(block[:, :, channel], shift, axis=0) + result[by:by+bh, bx:bx+bw] = block + + else: + # Swap with another block + other_bx = _rng.randint(0, max(0, w - bw)) + other_by = _rng.randint(0, max(0, h - bh)) + temp = result[by:by+bh, bx:bx+bw].copy() + result[by:by+bh, bx:bx+bw] = img[other_by:other_by+bh, other_bx:other_bx+bw] + result[other_by:other_by+bh, other_bx:other_bx+bw] = temp + + return result + + +def prim_ripple_displace(w: int, h: int, freq: float, amp: float, cx: float = None, cy: float = None, + decay: float = 0, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create radial ripple displacement maps. + + Args: + w, h: dimensions + freq: ripple frequency + amp: ripple amplitude in pixels + cx, cy: center + decay: how fast ripples decay with distance (0 = no decay) + phase: phase offset + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + + # Calculate ripple displacement (radial) + ripple = np.sin(2 * np.pi * freq * dist / max(w, h) + phase) * amp + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max(w, h)) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return (map_x, map_y) + + +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Image + 'width': prim_width, + 'height': prim_height, + 'make-image': prim_make_image, + 'copy': prim_copy, + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, + + # Color + 'rgb': prim_rgb, + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'blend-color': prim_blend_color, + 'average-color': prim_average_color, + + # Vectorized bulk operations + 'color-matrix': prim_color_matrix, + 'adjust': prim_adjust, + 'mix-gray': prim_mix_gray, + 'invert-img': prim_invert_img, + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'shift-hsv': prim_shift_hsv, + + # Bulk operations + 'map-pixels': prim_map_pixels, + 'map-rows': prim_map_rows, + 'for-grid': prim_for_grid, + 'fold-pixels': prim_fold_pixels, + + # Filters + 'convolve': prim_convolve, + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'edges': prim_edges, + 'sobel': prim_sobel, + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Geometry + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + 'mask': prim_mask, + + # Drawing + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + 'fill-rect': prim_fill_rect, + 'draw-line': prim_draw_line, + + # Math + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'atan2': prim_atan2, + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'mod': prim_mod, + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + 'assert': prim_assert, + 'pi': math.pi, + 'tau': math.tau, + + # Array + 'length': prim_length, + 'len': prim_length, # alias + 'nth': prim_nth, + 'first': prim_first, + 'rest': prim_rest, + 'take': prim_take, + 'drop': prim_drop, + 'cons': prim_cons, + 'append': prim_append, + 'reverse': prim_reverse, + 'range': prim_range, + 'roll': prim_roll, + 'list': prim_list, + + # Array math (vectorized operations on coordinate arrays) + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-neg': prim_arr_neg, + 'arr-exp': prim_arr_exp, + 'arr-atan2': prim_arr_atan2, + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-lerp': prim_arr_lerp, + + # Coordinate transformations + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, + 'normalize-coords': prim_normalize_coords, + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'make-coords-centered': prim_make_coords_centered, + + # Specialized distortion maps + 'wave-displace': prim_wave_displace, + 'swirl-displace': prim_swirl_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + 'ripple-displace': prim_ripple_displace, + + # Character/ASCII art + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'ascii-fx-zone': prim_ascii_fx_zone, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + + # Glitch art + 'pixelsort': prim_pixelsort, + 'datamosh': prim_datamosh, + +} diff --git a/sexp_effects/test_interpreter.py b/sexp_effects/test_interpreter.py new file mode 100644 index 0000000..550b21a --- /dev/null +++ b/sexp_effects/test_interpreter.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +Test the S-expression effect interpreter. +""" + +import numpy as np +import sys +from pathlib import Path + +# Add parent to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sexp_effects import ( + get_interpreter, + load_effects_dir, + run_effect, + list_effects, + parse, +) + + +def test_parser(): + """Test S-expression parser.""" + print("Testing parser...") + + # Simple expressions + assert parse("42") == 42 + assert parse("3.14") == 3.14 + assert parse('"hello"') == "hello" + assert parse("true") == True + + # Lists + assert parse("(+ 1 2)")[0].name == "+" + assert parse("(+ 1 2)")[1] == 1 + + # Nested + expr = parse("(define x (+ 1 2))") + assert expr[0].name == "define" + + print(" Parser OK") + + +def test_interpreter_basics(): + """Test basic interpreter operations.""" + print("Testing interpreter basics...") + + interp = get_interpreter() + + # Math + assert interp.eval(parse("(+ 1 2)")) == 3 + assert interp.eval(parse("(* 3 4)")) == 12 + assert interp.eval(parse("(- 10 3)")) == 7 + + # Comparison + assert interp.eval(parse("(< 1 2)")) == True + assert interp.eval(parse("(> 1 2)")) == False + + # Let binding + assert interp.eval(parse("(let ((x 5)) x)")) == 5 + assert interp.eval(parse("(let ((x 5) (y 3)) (+ x y))")) == 8 + + # Lambda + result = interp.eval(parse("((lambda (x) (* x 2)) 5)")) + assert result == 10 + + # If + assert interp.eval(parse("(if true 1 2)")) == 1 + assert interp.eval(parse("(if false 1 2)")) == 2 + + print(" Interpreter basics OK") + + +def test_primitives(): + """Test image primitives.""" + print("Testing primitives...") + + interp = get_interpreter() + + # Create test image + img = np.zeros((100, 100, 3), dtype=np.uint8) + img[50, 50] = [255, 128, 64] + + interp.global_env.set('test_img', img) + + # Width/height + assert interp.eval(parse("(width test_img)")) == 100 + assert interp.eval(parse("(height test_img)")) == 100 + + # Pixel + pixel = interp.eval(parse("(pixel test_img 50 50)")) + assert pixel == [255, 128, 64] + + # RGB + color = interp.eval(parse("(rgb 100 150 200)")) + assert color == [100, 150, 200] + + # Luminance + lum = interp.eval(parse("(luminance (rgb 100 100 100))")) + assert abs(lum - 100) < 1 + + print(" Primitives OK") + + +def test_effect_loading(): + """Test loading effects from .sexp files.""" + print("Testing effect loading...") + + # Load all effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + effects = list_effects() + print(f" Loaded {len(effects)} effects: {', '.join(sorted(effects))}") + + assert len(effects) > 0 + print(" Effect loading OK") + + +def test_effect_execution(): + """Test running effects on images.""" + print("Testing effect execution...") + + # Create test image + img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + + # Load effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Test each effect + effects = list_effects() + passed = 0 + failed = [] + + for name in sorted(effects): + try: + result, state = run_effect(name, img.copy(), {'_time': 0.5}, {}) + assert isinstance(result, np.ndarray) + assert result.shape == img.shape + passed += 1 + print(f" {name}: OK") + except Exception as e: + failed.append((name, str(e))) + print(f" {name}: FAILED - {e}") + + print(f" Passed: {passed}/{len(effects)}") + if failed: + print(f" Failed: {[f[0] for f in failed]}") + + return passed, failed + + +def test_ascii_fx_zone(): + """Test ascii_fx_zone effect with zone expressions.""" + print("Testing ascii_fx_zone...") + + interp = get_interpreter() + + # Load the effect + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Create gradient test frame + frame = np.zeros((120, 160, 3), dtype=np.uint8) + for x in range(160): + frame[:, x] = int(x / 160 * 255) + frame = np.stack([frame[:,:,0]]*3, axis=2) + + # Test 1: Basic without expressions + result, _ = run_effect('ascii_fx_zone', frame, {'cols': 20}, {}) + assert result.shape == frame.shape + print(" Basic run: OK") + + # Test 2: With zone-lum expression + expr = parse('(* zone-lum 180)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': expr + }, {}) + assert result.shape == frame.shape + print(" Zone-lum expression: OK") + + # Test 3: With multiple expressions + scale_expr = parse('(+ 0.5 (* zone-lum 0.5))') + rot_expr = parse('(* zone-row-norm 30)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_scale': scale_expr, + 'char_rotation': rot_expr + }, {}) + assert result.shape == frame.shape + print(" Multiple expressions: OK") + + # Test 4: With numeric literals + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': 90, + 'char_scale': 1.2 + }, {}) + assert result.shape == frame.shape + print(" Numeric literals: OK") + + # Test 5: Zone position expressions + col_expr = parse('(* zone-col-norm 360)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': col_expr + }, {}) + assert result.shape == frame.shape + print(" Zone position expression: OK") + + print(" ascii_fx_zone OK") + + +def main(): + print("=" * 60) + print("S-Expression Effect Interpreter Tests") + print("=" * 60) + + test_parser() + test_interpreter_basics() + test_primitives() + test_effect_loading() + test_ascii_fx_zone() + passed, failed = test_effect_execution() + + print("=" * 60) + if not failed: + print("All tests passed!") + else: + print(f"Tests completed with {len(failed)} failures") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/sexp_effects/wgsl_compiler.py b/sexp_effects/wgsl_compiler.py new file mode 100644 index 0000000..0c8b024 --- /dev/null +++ b/sexp_effects/wgsl_compiler.py @@ -0,0 +1,715 @@ +""" +S-Expression to WGSL Compiler + +Compiles sexp effect definitions to WGSL compute shaders for GPU execution. +The compilation happens at effect upload time (AOT), not at runtime. + +Architecture: +- Parse sexp AST +- Analyze primitives used +- Generate WGSL compute shader + +Shader Categories: +1. Per-pixel ops: brightness, invert, grayscale, sepia (1 thread per pixel) +2. Geometric transforms: rotate, scale, wave, ripple (coordinate remap + sample) +3. Neighborhood ops: blur, sharpen, edge detect (sample neighbors) +""" + +from typing import Dict, List, Any, Optional, Tuple, Set +from dataclasses import dataclass, field +from pathlib import Path +import math + +from .parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class WGSLParam: + """A shader parameter (uniform).""" + name: str + wgsl_type: str # f32, i32, u32, vec2f, etc. + default: Any + + +@dataclass +class CompiledEffect: + """Result of compiling an sexp effect to WGSL.""" + name: str + wgsl_code: str + params: List[WGSLParam] + workgroup_size: Tuple[int, int, int] = (16, 16, 1) + # Metadata for runtime + uses_time: bool = False + uses_sampling: bool = False # Needs texture sampler + category: str = "per_pixel" # per_pixel, geometric, neighborhood + + +@dataclass +class CompilerContext: + """Context during compilation.""" + effect_name: str = "" + params: Dict[str, WGSLParam] = field(default_factory=dict) + locals: Dict[str, str] = field(default_factory=dict) # local var -> wgsl expr + required_libs: Set[str] = field(default_factory=set) + uses_time: bool = False + uses_sampling: bool = False + temp_counter: int = 0 + + def fresh_temp(self) -> str: + """Generate a fresh temporary variable name.""" + self.temp_counter += 1 + return f"_t{self.temp_counter}" + + +class SexpToWGSLCompiler: + """ + Compiles S-expression effect definitions to WGSL compute shaders. + """ + + # Map sexp types to WGSL types + TYPE_MAP = { + 'int': 'i32', + 'float': 'f32', + 'bool': 'u32', # WGSL doesn't have bool in storage + 'string': None, # Strings handled specially + } + + # Per-pixel primitives that can be compiled directly + PER_PIXEL_PRIMITIVES = { + 'color_ops:invert-img', + 'color_ops:grayscale', + 'color_ops:sepia', + 'color_ops:adjust', + 'color_ops:adjust-brightness', + 'color_ops:shift-hsv', + 'color_ops:quantize', + } + + # Geometric primitives (coordinate remapping) + GEOMETRIC_PRIMITIVES = { + 'geometry:scale-img', + 'geometry:rotate-img', + 'geometry:translate', + 'geometry:flip-h', + 'geometry:flip-v', + 'geometry:remap', + } + + def __init__(self): + self.ctx: Optional[CompilerContext] = None + + def compile_file(self, path: str) -> CompiledEffect: + """Compile an effect from a .sexp file.""" + with open(path, 'r') as f: + content = f.read() + exprs = parse_all(content) + return self.compile(exprs) + + def compile_string(self, sexp_code: str) -> CompiledEffect: + """Compile an effect from an sexp string.""" + exprs = parse_all(sexp_code) + return self.compile(exprs) + + def compile(self, expr: Any) -> CompiledEffect: + """Compile a parsed sexp expression.""" + self.ctx = CompilerContext() + + # Handle multiple top-level expressions (require-primitives, define-effect) + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self._process_toplevel(e) + else: + self._process_toplevel(expr) + + # Generate the WGSL shader + wgsl = self._generate_wgsl() + + # Determine category based on primitives used + category = self._determine_category() + + return CompiledEffect( + name=self.ctx.effect_name, + wgsl_code=wgsl, + params=list(self.ctx.params.values()), + uses_time=self.ctx.uses_time, + uses_sampling=self.ctx.uses_sampling, + category=category, + ) + + def _process_toplevel(self, expr: Any): + """Process a top-level expression.""" + if not isinstance(expr, list) or not expr: + return + + head = expr[0] + if isinstance(head, Symbol): + if head.name == 'require-primitives': + # Track required primitive libraries + for lib in expr[1:]: + lib_name = lib.name if isinstance(lib, Symbol) else str(lib) + self.ctx.required_libs.add(lib_name) + + elif head.name == 'define-effect': + self._compile_effect_def(expr) + + def _compile_effect_def(self, expr: list): + """Compile a define-effect form.""" + # (define-effect name :params (...) body) + self.ctx.effect_name = expr[1].name if isinstance(expr[1], Symbol) else str(expr[1]) + + # Parse :params and body + i = 2 + body = None + while i < len(expr): + item = expr[i] + if isinstance(item, Keyword) and item.name == 'params': + self._parse_params(expr[i + 1]) + i += 2 + elif isinstance(item, Keyword): + i += 2 # Skip other keywords + else: + body = item + i += 1 + + if body: + self.ctx.body_expr = body + + def _parse_params(self, params_list: list): + """Parse the :params block.""" + for param_def in params_list: + if not isinstance(param_def, list): + continue + + name = param_def[0].name if isinstance(param_def[0], Symbol) else str(param_def[0]) + + # Parse keyword args + param_type = 'float' + default = 0 + + i = 1 + while i < len(param_def): + item = param_def[i] + if isinstance(item, Keyword): + if i + 1 < len(param_def): + val = param_def[i + 1] + if item.name == 'type': + param_type = val.name if isinstance(val, Symbol) else str(val) + elif item.name == 'default': + default = val + i += 2 + else: + i += 1 + + wgsl_type = self.TYPE_MAP.get(param_type, 'f32') + if wgsl_type: + self.ctx.params[name] = WGSLParam(name, wgsl_type, default) + + def _determine_category(self) -> str: + """Determine shader category based on primitives used.""" + for lib in self.ctx.required_libs: + if lib == 'geometry': + return 'geometric' + if lib == 'filters': + return 'neighborhood' + return 'per_pixel' + + def _generate_wgsl(self) -> str: + """Generate the complete WGSL shader code.""" + lines = [] + + # Header comment + lines.append(f"// WGSL Shader: {self.ctx.effect_name}") + lines.append(f"// Auto-generated from sexp effect definition") + lines.append("") + + # Bindings + lines.append("@group(0) @binding(0) var input: array;") + lines.append("@group(0) @binding(1) var output: array;") + lines.append("") + + # Params struct + if self.ctx.params: + lines.append("struct Params {") + lines.append(" width: u32,") + lines.append(" height: u32,") + lines.append(" time: f32,") + for param in self.ctx.params.values(): + lines.append(f" {param.name}: {param.wgsl_type},") + lines.append("}") + lines.append("@group(0) @binding(2) var params: Params;") + else: + lines.append("struct Params {") + lines.append(" width: u32,") + lines.append(" height: u32,") + lines.append(" time: f32,") + lines.append("}") + lines.append("@group(0) @binding(2) var params: Params;") + lines.append("") + + # Helper functions + lines.extend(self._generate_helpers()) + lines.append("") + + # Main compute shader + lines.append("@compute @workgroup_size(16, 16, 1)") + lines.append("fn main(@builtin(global_invocation_id) gid: vec3) {") + lines.append(" let x = gid.x;") + lines.append(" let y = gid.y;") + lines.append(" if (x >= params.width || y >= params.height) { return; }") + lines.append(" let idx = y * params.width + x;") + lines.append("") + + # Compile the effect body + body_code = self._compile_expr(self.ctx.body_expr) + lines.append(f" // Effect: {self.ctx.effect_name}") + lines.append(body_code) + lines.append("}") + + return "\n".join(lines) + + def _generate_helpers(self) -> List[str]: + """Generate WGSL helper functions.""" + helpers = [] + + # Pack/unpack RGB from u32 + helpers.append("fn unpack_rgb(packed: u32) -> vec3 {") + helpers.append(" let r = f32((packed >> 16u) & 0xFFu) / 255.0;") + helpers.append(" let g = f32((packed >> 8u) & 0xFFu) / 255.0;") + helpers.append(" let b = f32(packed & 0xFFu) / 255.0;") + helpers.append(" return vec3(r, g, b);") + helpers.append("}") + helpers.append("") + + helpers.append("fn pack_rgb(rgb: vec3) -> u32 {") + helpers.append(" let r = u32(clamp(rgb.r, 0.0, 1.0) * 255.0);") + helpers.append(" let g = u32(clamp(rgb.g, 0.0, 1.0) * 255.0);") + helpers.append(" let b = u32(clamp(rgb.b, 0.0, 1.0) * 255.0);") + helpers.append(" return (r << 16u) | (g << 8u) | b;") + helpers.append("}") + helpers.append("") + + # Bilinear sampling for geometric transforms + if self.ctx.uses_sampling or 'geometry' in self.ctx.required_libs: + helpers.append("fn sample_bilinear(sx: f32, sy: f32) -> vec3 {") + helpers.append(" let w = f32(params.width);") + helpers.append(" let h = f32(params.height);") + helpers.append(" let cx = clamp(sx, 0.0, w - 1.001);") + helpers.append(" let cy = clamp(sy, 0.0, h - 1.001);") + helpers.append(" let x0 = u32(cx);") + helpers.append(" let y0 = u32(cy);") + helpers.append(" let x1 = min(x0 + 1u, params.width - 1u);") + helpers.append(" let y1 = min(y0 + 1u, params.height - 1u);") + helpers.append(" let fx = cx - f32(x0);") + helpers.append(" let fy = cy - f32(y0);") + helpers.append(" let c00 = unpack_rgb(input[y0 * params.width + x0]);") + helpers.append(" let c10 = unpack_rgb(input[y0 * params.width + x1]);") + helpers.append(" let c01 = unpack_rgb(input[y1 * params.width + x0]);") + helpers.append(" let c11 = unpack_rgb(input[y1 * params.width + x1]);") + helpers.append(" let top = mix(c00, c10, fx);") + helpers.append(" let bot = mix(c01, c11, fx);") + helpers.append(" return mix(top, bot, fy);") + helpers.append("}") + helpers.append("") + + # HSV conversion for color effects + if 'color_ops' in self.ctx.required_libs or 'color' in self.ctx.required_libs: + helpers.append("fn rgb_to_hsv(rgb: vec3) -> vec3 {") + helpers.append(" let mx = max(max(rgb.r, rgb.g), rgb.b);") + helpers.append(" let mn = min(min(rgb.r, rgb.g), rgb.b);") + helpers.append(" let d = mx - mn;") + helpers.append(" var h = 0.0;") + helpers.append(" if (d > 0.0) {") + helpers.append(" if (mx == rgb.r) { h = (rgb.g - rgb.b) / d; }") + helpers.append(" else if (mx == rgb.g) { h = 2.0 + (rgb.b - rgb.r) / d; }") + helpers.append(" else { h = 4.0 + (rgb.r - rgb.g) / d; }") + helpers.append(" h = h / 6.0;") + helpers.append(" if (h < 0.0) { h = h + 1.0; }") + helpers.append(" }") + helpers.append(" let s = select(0.0, d / mx, mx > 0.0);") + helpers.append(" return vec3(h, s, mx);") + helpers.append("}") + helpers.append("") + + helpers.append("fn hsv_to_rgb(hsv: vec3) -> vec3 {") + helpers.append(" let h = hsv.x * 6.0;") + helpers.append(" let s = hsv.y;") + helpers.append(" let v = hsv.z;") + helpers.append(" let c = v * s;") + helpers.append(" let x = c * (1.0 - abs(h % 2.0 - 1.0));") + helpers.append(" let m = v - c;") + helpers.append(" var rgb: vec3;") + helpers.append(" if (h < 1.0) { rgb = vec3(c, x, 0.0); }") + helpers.append(" else if (h < 2.0) { rgb = vec3(x, c, 0.0); }") + helpers.append(" else if (h < 3.0) { rgb = vec3(0.0, c, x); }") + helpers.append(" else if (h < 4.0) { rgb = vec3(0.0, x, c); }") + helpers.append(" else if (h < 5.0) { rgb = vec3(x, 0.0, c); }") + helpers.append(" else { rgb = vec3(c, 0.0, x); }") + helpers.append(" return rgb + vec3(m, m, m);") + helpers.append("}") + helpers.append("") + + return helpers + + def _compile_expr(self, expr: Any, indent: int = 4) -> str: + """Compile an sexp expression to WGSL code.""" + ind = " " * indent + + # Literals + if isinstance(expr, (int, float)): + return f"{ind}// literal: {expr}" + + if isinstance(expr, str): + return f'{ind}// string: "{expr}"' + + # Symbol reference + if isinstance(expr, Symbol): + name = expr.name + if name == 'frame': + return f"{ind}let rgb = unpack_rgb(input[idx]);" + if name == 't' or name == '_time': + self.ctx.uses_time = True + return f"{ind}let t = params.time;" + if name in self.ctx.params: + return f"{ind}let {name} = params.{name};" + if name in self.ctx.locals: + return f"{ind}// local: {name}" + return f"{ind}// unknown symbol: {name}" + + # List (function call or special form) + if isinstance(expr, list) and expr: + head = expr[0] + + if isinstance(head, Symbol): + form = head.name + + # Special forms + if form == 'let' or form == 'let*': + return self._compile_let(expr, indent) + + if form == 'if': + return self._compile_if(expr, indent) + + if form == 'or': + # (or a b) - return a if truthy, else b + return self._compile_or(expr, indent) + + # Primitive calls + if ':' in form: + return self._compile_primitive_call(expr, indent) + + # Arithmetic + if form in ('+', '-', '*', '/'): + return self._compile_arithmetic(expr, indent) + + if form in ('>', '<', '>=', '<=', '='): + return self._compile_comparison(expr, indent) + + if form == 'max': + return self._compile_builtin('max', expr[1:], indent) + + if form == 'min': + return self._compile_builtin('min', expr[1:], indent) + + return f"{ind}// unhandled: {expr}" + + def _compile_let(self, expr: list, indent: int) -> str: + """Compile let/let* binding form.""" + ind = " " * indent + lines = [] + + bindings = expr[1] + body = expr[2] + + # Parse bindings (Clojure style: [x 1 y 2] or Scheme style: ((x 1) (y 2))) + pairs = [] + if bindings and isinstance(bindings[0], Symbol): + # Clojure style + i = 0 + while i < len(bindings) - 1: + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + value = bindings[i + 1] + pairs.append((name, value)) + i += 2 + else: + # Scheme style + for binding in bindings: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + value = binding[1] + pairs.append((name, value)) + + # Compile bindings + for name, value in pairs: + val_code = self._expr_to_wgsl(value) + lines.append(f"{ind}let {name} = {val_code};") + self.ctx.locals[name] = val_code + + # Compile body + body_lines = self._compile_body(body, indent) + lines.append(body_lines) + + return "\n".join(lines) + + def _compile_body(self, body: Any, indent: int) -> str: + """Compile the body of an effect (the final image expression).""" + ind = " " * indent + + # Most effects end with a primitive call that produces the output + if isinstance(body, list) and body: + head = body[0] + if isinstance(head, Symbol) and ':' in head.name: + return self._compile_primitive_call(body, indent) + + # If body is just 'frame', pass through + if isinstance(body, Symbol) and body.name == 'frame': + return f"{ind}output[idx] = input[idx];" + + return f"{ind}// body: {body}" + + def _compile_primitive_call(self, expr: list, indent: int) -> str: + """Compile a primitive function call.""" + ind = " " * indent + head = expr[0] + prim_name = head.name if isinstance(head, Symbol) else str(head) + args = expr[1:] + + # Per-pixel color operations + if prim_name == 'color_ops:invert-img': + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let result = vec3(1.0, 1.0, 1.0) - rgb; +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:grayscale': + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let gray = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b; +{ind}let result = vec3(gray, gray, gray); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:adjust-brightness': + amount = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let adj = f32({amount}) / 255.0; +{ind}let result = clamp(rgb + vec3(adj, adj, adj), vec3(0.0, 0.0, 0.0), vec3(1.0, 1.0, 1.0)); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:adjust': + # (adjust img brightness contrast) + brightness = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + contrast = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let centered = rgb - vec3(0.5, 0.5, 0.5); +{ind}let contrasted = centered * {contrast}; +{ind}let brightened = contrasted + vec3(0.5, 0.5, 0.5) + vec3({brightness}/255.0); +{ind}let result = clamp(brightened, vec3(0.0), vec3(1.0)); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:sepia': + intensity = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let sepia_r = 0.393 * rgb.r + 0.769 * rgb.g + 0.189 * rgb.b; +{ind}let sepia_g = 0.349 * rgb.r + 0.686 * rgb.g + 0.168 * rgb.b; +{ind}let sepia_b = 0.272 * rgb.r + 0.534 * rgb.g + 0.131 * rgb.b; +{ind}let sepia = vec3(sepia_r, sepia_g, sepia_b); +{ind}let result = mix(rgb, sepia, {intensity}); +{ind}output[idx] = pack_rgb(clamp(result, vec3(0.0), vec3(1.0)));""" + + if prim_name == 'color_ops:shift-hsv': + h_shift = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + s_mult = self._expr_to_wgsl(args[2]) if len(args) > 2 else "1.0" + v_mult = self._expr_to_wgsl(args[3]) if len(args) > 3 else "1.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}var hsv = rgb_to_hsv(rgb); +{ind}hsv.x = fract(hsv.x + {h_shift} / 360.0); +{ind}hsv.y = clamp(hsv.y * {s_mult}, 0.0, 1.0); +{ind}hsv.z = clamp(hsv.z * {v_mult}, 0.0, 1.0); +{ind}let result = hsv_to_rgb(hsv); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'color_ops:quantize': + levels = self._expr_to_wgsl(args[1]) if len(args) > 1 else "8.0" + return f"""{ind}let rgb = unpack_rgb(input[idx]); +{ind}let lvl = max(2.0, {levels}); +{ind}let result = floor(rgb * lvl) / lvl; +{ind}output[idx] = pack_rgb(result);""" + + # Geometric transforms + if prim_name == 'geometry:scale-img': + sx = self._expr_to_wgsl(args[1]) if len(args) > 1 else "1.0" + sy = self._expr_to_wgsl(args[2]) if len(args) > 2 else sx + self.ctx.uses_sampling = True + return f"""{ind}let w = f32(params.width); +{ind}let h = f32(params.height); +{ind}let cx = w / 2.0; +{ind}let cy = h / 2.0; +{ind}let sx = f32(x) - cx; +{ind}let sy = f32(y) - cy; +{ind}let src_x = sx / {sx} + cx; +{ind}let src_y = sy / {sy} + cy; +{ind}let result = sample_bilinear(src_x, src_y); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'geometry:rotate-img': + angle = self._expr_to_wgsl(args[1]) if len(args) > 1 else "0.0" + self.ctx.uses_sampling = True + return f"""{ind}let w = f32(params.width); +{ind}let h = f32(params.height); +{ind}let cx = w / 2.0; +{ind}let cy = h / 2.0; +{ind}let angle_rad = {angle} * 3.14159265 / 180.0; +{ind}let cos_a = cos(-angle_rad); +{ind}let sin_a = sin(-angle_rad); +{ind}let dx = f32(x) - cx; +{ind}let dy = f32(y) - cy; +{ind}let src_x = dx * cos_a - dy * sin_a + cx; +{ind}let src_y = dx * sin_a + dy * cos_a + cy; +{ind}let result = sample_bilinear(src_x, src_y); +{ind}output[idx] = pack_rgb(result);""" + + if prim_name == 'geometry:flip-h': + return f"""{ind}let src_idx = y * params.width + (params.width - 1u - x); +{ind}output[idx] = input[src_idx];""" + + if prim_name == 'geometry:flip-v': + return f"""{ind}let src_idx = (params.height - 1u - y) * params.width + x; +{ind}output[idx] = input[src_idx];""" + + # Image library + if prim_name == 'image:blur': + radius = self._expr_to_wgsl(args[1]) if len(args) > 1 else "5" + # Box blur approximation (separable would be better) + return f"""{ind}let radius = i32({radius}); +{ind}var sum = vec3(0.0, 0.0, 0.0); +{ind}var count = 0.0; +{ind}for (var dy = -radius; dy <= radius; dy = dy + 1) {{ +{ind} for (var dx = -radius; dx <= radius; dx = dx + 1) {{ +{ind} let sx = i32(x) + dx; +{ind} let sy = i32(y) + dy; +{ind} if (sx >= 0 && sx < i32(params.width) && sy >= 0 && sy < i32(params.height)) {{ +{ind} let sidx = u32(sy) * params.width + u32(sx); +{ind} sum = sum + unpack_rgb(input[sidx]); +{ind} count = count + 1.0; +{ind} }} +{ind} }} +{ind}}} +{ind}let result = sum / count; +{ind}output[idx] = pack_rgb(result);""" + + # Fallback - passthrough + return f"""{ind}// Unimplemented primitive: {prim_name} +{ind}output[idx] = input[idx];""" + + def _compile_if(self, expr: list, indent: int) -> str: + """Compile if expression.""" + ind = " " * indent + cond = self._expr_to_wgsl(expr[1]) + then_expr = expr[2] + else_expr = expr[3] if len(expr) > 3 else None + + lines = [] + lines.append(f"{ind}if ({cond}) {{") + lines.append(self._compile_body(then_expr, indent + 4)) + if else_expr: + lines.append(f"{ind}}} else {{") + lines.append(self._compile_body(else_expr, indent + 4)) + lines.append(f"{ind}}}") + + return "\n".join(lines) + + def _compile_or(self, expr: list, indent: int) -> str: + """Compile or expression - returns first truthy value.""" + # For numeric context, (or a b) means "a if a != 0 else b" + a = self._expr_to_wgsl(expr[1]) + b = self._expr_to_wgsl(expr[2]) if len(expr) > 2 else "0.0" + return f"select({b}, {a}, {a} != 0.0)" + + def _compile_arithmetic(self, expr: list, indent: int) -> str: + """Compile arithmetic expression to inline WGSL.""" + op = expr[0].name + operands = [self._expr_to_wgsl(arg) for arg in expr[1:]] + + if len(operands) == 1: + if op == '-': + return f"(-{operands[0]})" + return operands[0] + + return f"({f' {op} '.join(operands)})" + + def _compile_comparison(self, expr: list, indent: int) -> str: + """Compile comparison expression.""" + op = expr[0].name + if op == '=': + op = '==' + a = self._expr_to_wgsl(expr[1]) + b = self._expr_to_wgsl(expr[2]) + return f"({a} {op} {b})" + + def _compile_builtin(self, fn: str, args: list, indent: int) -> str: + """Compile builtin function call.""" + compiled_args = [self._expr_to_wgsl(arg) for arg in args] + return f"{fn}({', '.join(compiled_args)})" + + def _expr_to_wgsl(self, expr: Any) -> str: + """Convert an expression to inline WGSL code.""" + if isinstance(expr, (int, float)): + # Ensure floats have decimal point + if isinstance(expr, float) or '.' not in str(expr): + return f"{float(expr)}" + return str(expr) + + if isinstance(expr, str): + return f'"{expr}"' + + if isinstance(expr, Symbol): + name = expr.name + if name == 'frame': + return "rgb" # Assume rgb is already loaded + if name == 't' or name == '_time': + self.ctx.uses_time = True + return "params.time" + if name == 'pi': + return "3.14159265" + if name in self.ctx.params: + return f"params.{name}" + if name in self.ctx.locals: + return name + return name + + if isinstance(expr, list) and expr: + head = expr[0] + if isinstance(head, Symbol): + form = head.name + + # Arithmetic + if form in ('+', '-', '*', '/'): + return self._compile_arithmetic(expr, 0) + + # Comparison + if form in ('>', '<', '>=', '<=', '='): + return self._compile_comparison(expr, 0) + + # Builtins + if form in ('max', 'min', 'abs', 'floor', 'ceil', 'sin', 'cos', 'sqrt'): + args = [self._expr_to_wgsl(a) for a in expr[1:]] + return f"{form}({', '.join(args)})" + + if form == 'or': + return self._compile_or(expr, 0) + + # Image dimension queries + if form == 'image:width': + return "f32(params.width)" + if form == 'image:height': + return "f32(params.height)" + + return f"/* unknown: {expr} */" + + +def compile_effect(sexp_code: str) -> CompiledEffect: + """Convenience function to compile an sexp effect string.""" + compiler = SexpToWGSLCompiler() + return compiler.compile_string(sexp_code) + + +def compile_effect_file(path: str) -> CompiledEffect: + """Convenience function to compile an sexp effect file.""" + compiler = SexpToWGSLCompiler() + return compiler.compile_file(path) diff --git a/storage_providers.py b/storage_providers.py new file mode 100644 index 0000000..1cee65d --- /dev/null +++ b/storage_providers.py @@ -0,0 +1,1009 @@ +""" +Storage provider abstraction for user-attachable storage. + +Supports: +- Pinata (IPFS pinning service) +- web3.storage (IPFS pinning service) +- Local filesystem storage +""" + +import hashlib +import json +import logging +import os +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +import requests + +logger = logging.getLogger(__name__) + + +class StorageProvider(ABC): + """Abstract base class for storage backends.""" + + provider_type: str = "unknown" + + @abstractmethod + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """ + Pin content to storage. + + Args: + cid: SHA3-256 hash of the content + data: Raw bytes to store + filename: Optional filename hint + + Returns: + IPFS CID or provider-specific ID, or None on failure + """ + pass + + @abstractmethod + async def unpin(self, cid: str) -> bool: + """ + Unpin content from storage. + + Args: + cid: SHA3-256 hash of the content + + Returns: + True if unpinned successfully + """ + pass + + @abstractmethod + async def get(self, cid: str) -> Optional[bytes]: + """ + Retrieve content from storage. + + Args: + cid: SHA3-256 hash of the content + + Returns: + Raw bytes or None if not found + """ + pass + + @abstractmethod + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned in this storage.""" + pass + + @abstractmethod + async def test_connection(self) -> tuple[bool, str]: + """ + Test connectivity to the storage provider. + + Returns: + (success, message) tuple + """ + pass + + @abstractmethod + def get_usage(self) -> dict: + """ + Get storage usage statistics. + + Returns: + {used_bytes, capacity_bytes, pin_count} + """ + pass + + +class PinataProvider(StorageProvider): + """Pinata IPFS pinning service provider.""" + + provider_type = "pinata" + + def __init__(self, api_key: str, secret_key: str, capacity_gb: int = 1): + self.api_key = api_key + self.secret_key = secret_key + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.pinata.cloud" + self._usage_cache = None + + def _headers(self) -> dict: + return { + "pinata_api_key": self.api_key, + "pinata_secret_api_key": self.secret_key, + } + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Pinata.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{cid[:16]}.bin", data)} + metadata = { + "name": filename or cid[:16], + "keyvalues": {"cid": cid} + } + response = requests.post( + f"{self.base_url}/pinning/pinFileToIPFS", + files=files, + data={"pinataMetadata": json.dumps(metadata)}, + headers=self._headers(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("IpfsHash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Pinata: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Pinata pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Unpin content from Pinata by finding its CID first.""" + try: + import asyncio + + def do_unpin(): + # First find the pin by cid metadata + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return False + + # Unpin each matching CID + for pin in pins: + cid = pin.get("ipfs_pin_hash") + if cid: + resp = requests.delete( + f"{self.base_url}/pinning/unpin/{cid}", + headers=self._headers(), + timeout=30 + ) + resp.raise_for_status() + return True + + result = await asyncio.to_thread(do_unpin) + logger.info(f"Pinata: Unpinned {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Pinata unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Pinata via IPFS gateway.""" + try: + import asyncio + + def do_get(): + # First find the CID + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return None + + cid = pins[0].get("ipfs_pin_hash") + if not cid: + return None + + # Fetch from gateway + gateway_response = requests.get( + f"https://gateway.pinata.cloud/ipfs/{cid}", + timeout=120 + ) + gateway_response.raise_for_status() + return gateway_response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Pinata get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned on Pinata.""" + try: + import asyncio + + def do_check(): + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][cid]": cid, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + return len(response.json().get("rows", [])) > 0 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Pinata API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/data/testAuthentication", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Pinata successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Pinata usage stats.""" + try: + response = requests.get( + f"{self.base_url}/data/userPinnedDataTotal", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + data = response.json() + return { + "used_bytes": data.get("pin_size_total", 0), + "capacity_bytes": self.capacity_bytes, + "pin_count": data.get("pin_count", 0) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class Web3StorageProvider(StorageProvider): + """web3.storage pinning service provider.""" + + provider_type = "web3storage" + + def __init__(self, api_token: str, capacity_gb: int = 1): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.web3.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to web3.storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={ + **self._headers(), + "X-Name": filename or cid[:16] + }, + timeout=120 + ) + response.raise_for_status() + return response.json().get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"web3.storage: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"web3.storage pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """web3.storage doesn't support unpinning - data is stored permanently.""" + logger.warning("web3.storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from web3.storage - would need CID mapping.""" + # web3.storage requires knowing the CID to fetch + # For now, return None - we'd need to maintain a mapping + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test web3.storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1}, + timeout=10 + ) + response.raise_for_status() + return True, "Connected to web3.storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get web3.storage usage stats.""" + try: + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1000}, + timeout=30 + ) + response.raise_for_status() + uploads = response.json() + total_size = sum(u.get("dagSize", 0) for u in uploads) + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": len(uploads) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class NFTStorageProvider(StorageProvider): + """NFT.Storage pinning service provider (free for NFT data).""" + + provider_type = "nftstorage" + + def __init__(self, api_token: str, capacity_gb: int = 5): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.nft.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to NFT.Storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={**self._headers(), "Content-Type": "application/octet-stream"}, + timeout=120 + ) + response.raise_for_status() + return response.json().get("value", {}).get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"NFT.Storage: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"NFT.Storage pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """NFT.Storage doesn't support unpinning - data is stored permanently.""" + logger.warning("NFT.Storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from NFT.Storage - would need CID mapping.""" + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test NFT.Storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to NFT.Storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get NFT.Storage usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class InfuraIPFSProvider(StorageProvider): + """Infura IPFS pinning service provider.""" + + provider_type = "infura" + + def __init__(self, project_id: str, project_secret: str, capacity_gb: int = 5): + self.project_id = project_id + self.project_secret = project_secret + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://ipfs.infura.io:5001/api/v0" + + def _auth(self) -> tuple: + return (self.project_id, self.project_secret) + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Infura IPFS.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{cid[:16]}.bin", data)} + response = requests.post( + f"{self.base_url}/add", + files=files, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("Hash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Infura IPFS: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Infura IPFS pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Unpin content from Infura IPFS.""" + try: + import asyncio + + def do_unpin(): + response = requests.post( + f"{self.base_url}/pin/rm", + params={"arg": cid}, + auth=self._auth(), + timeout=30 + ) + response.raise_for_status() + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Infura IPFS unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Infura IPFS gateway.""" + try: + import asyncio + + def do_get(): + response = requests.post( + f"{self.base_url}/cat", + params={"arg": cid}, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Infura IPFS get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content is pinned on Infura IPFS.""" + try: + import asyncio + + def do_check(): + response = requests.post( + f"{self.base_url}/pin/ls", + params={"arg": cid}, + auth=self._auth(), + timeout=30 + ) + return response.status_code == 200 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Infura IPFS API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.post( + f"{self.base_url}/id", + auth=self._auth(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Infura IPFS successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid project credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Infura usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class FilebaseProvider(StorageProvider): + """Filebase S3-compatible IPFS pinning service.""" + + provider_type = "filebase" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 5): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://s3.filebase.com" + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or f"{cid[:16]}.bin" + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + # Get CID from response headers + head = s3.head_object(Bucket=self.bucket, Key=key) + return head.get('Metadata', {}).get('cid', cid) + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Filebase: Pinned {cid[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Filebase pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Filebase unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=cid) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Filebase get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists in Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Filebase connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Filebase bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Filebase usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class StorjProvider(StorageProvider): + """Storj decentralized cloud storage (S3-compatible).""" + + provider_type = "storj" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 25): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://gateway.storjshare.io" + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or cid + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + return cid + + result = await asyncio.to_thread(do_pin) + logger.info(f"Storj: Stored {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Storj pin failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Storj unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=cid) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Storj get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=cid) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Storj connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Storj bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Storj usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class LocalStorageProvider(StorageProvider): + """Local filesystem storage provider.""" + + provider_type = "local" + + def __init__(self, base_path: str, capacity_gb: int = 10): + self.base_path = Path(base_path) + self.capacity_bytes = capacity_gb * 1024**3 + # Create directory if it doesn't exist + self.base_path.mkdir(parents=True, exist_ok=True) + + def _get_file_path(self, cid: str) -> Path: + """Get file path for a content hash (using subdirectories).""" + # Use first 2 chars as subdirectory for better filesystem performance + subdir = cid[:2] + return self.base_path / subdir / cid + + async def pin(self, cid: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content locally.""" + try: + import asyncio + + def do_store(): + file_path = self._get_file_path(cid) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(data) + return cid # Use cid as ID for local storage + + result = await asyncio.to_thread(do_store) + logger.info(f"Local: Stored {cid[:16]}...") + return result + except Exception as e: + logger.error(f"Local storage failed: {e}") + return None + + async def unpin(self, cid: str) -> bool: + """Remove content from local storage.""" + try: + import asyncio + + def do_remove(): + file_path = self._get_file_path(cid) + if file_path.exists(): + file_path.unlink() + return True + return False + + return await asyncio.to_thread(do_remove) + except Exception as e: + logger.error(f"Local unpin failed: {e}") + return False + + async def get(self, cid: str) -> Optional[bytes]: + """Get content from local storage.""" + try: + import asyncio + + def do_get(): + file_path = self._get_file_path(cid) + if file_path.exists(): + return file_path.read_bytes() + return None + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Local get failed: {e}") + return None + + async def is_pinned(self, cid: str) -> bool: + """Check if content exists in local storage.""" + return self._get_file_path(cid).exists() + + async def test_connection(self) -> tuple[bool, str]: + """Test local storage is writable.""" + try: + test_file = self.base_path / ".write_test" + test_file.write_text("test") + test_file.unlink() + return True, f"Local storage ready at {self.base_path}" + except Exception as e: + return False, f"Cannot write to {self.base_path}: {e}" + + def get_usage(self) -> dict: + """Get local storage usage stats.""" + try: + total_size = 0 + file_count = 0 + for subdir in self.base_path.iterdir(): + if subdir.is_dir() and len(subdir.name) == 2: + for f in subdir.iterdir(): + if f.is_file(): + total_size += f.stat().st_size + file_count += 1 + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": file_count + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +def create_provider(provider_type: str, config: dict) -> Optional[StorageProvider]: + """ + Factory function to create a storage provider from config. + + Args: + provider_type: One of 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + config: Provider-specific configuration dict + + Returns: + StorageProvider instance or None if invalid + """ + try: + if provider_type == "pinata": + return PinataProvider( + api_key=config["api_key"], + secret_key=config["secret_key"], + capacity_gb=config.get("capacity_gb", 1) + ) + elif provider_type == "web3storage": + return Web3StorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "nftstorage": + return NFTStorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "infura": + return InfuraIPFSProvider( + project_id=config["project_id"], + project_secret=config["project_secret"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "filebase": + return FilebaseProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "storj": + return StorjProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 25) + ) + elif provider_type == "local": + return LocalStorageProvider( + base_path=config["path"], + capacity_gb=config.get("capacity_gb", 10) + ) + else: + logger.error(f"Unknown provider type: {provider_type}") + return None + except KeyError as e: + logger.error(f"Missing config key for {provider_type}: {e}") + return None + except Exception as e: + logger.error(f"Failed to create provider {provider_type}: {e}") + return None diff --git a/streaming/__init__.py b/streaming/__init__.py new file mode 100644 index 0000000..2c007cc --- /dev/null +++ b/streaming/__init__.py @@ -0,0 +1,44 @@ +""" +Streaming video compositor for real-time effect processing. + +This module provides a frame-by-frame streaming architecture that: +- Reads from multiple video sources with automatic looping +- Applies effects inline (no intermediate files) +- Composites layers with time-varying weights +- Outputs to display, file, or stream + +Usage: + from streaming import StreamingCompositor, VideoSource, AudioAnalyzer + + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[...], + compositor_config={...}, + ) + + # With live audio + audio = AudioAnalyzer(device=0) + compositor.run(output="output.mp4", duration=60, audio=audio) + + # With preview window + compositor.run(output="preview", duration=60) + +Backends: + - numpy: Works everywhere, ~3-5 fps (default) + - glsl: Requires GPU, 30+ fps real-time (future) +""" + +from .sources import VideoSource, ImageSource +from .compositor import StreamingCompositor +from .backends import NumpyBackend, get_backend +from .output import DisplayOutput, FileOutput + +__all__ = [ + "StreamingCompositor", + "VideoSource", + "ImageSource", + "NumpyBackend", + "get_backend", + "DisplayOutput", + "FileOutput", +] diff --git a/streaming/audio.py b/streaming/audio.py new file mode 100644 index 0000000..9d20937 --- /dev/null +++ b/streaming/audio.py @@ -0,0 +1,486 @@ +""" +Live audio analysis for reactive effects. + +Provides real-time audio features: +- Energy (RMS amplitude) +- Beat detection +- Frequency bands (bass, mid, high) +""" + +import numpy as np +from typing import Optional +import threading +import time + + +class AudioAnalyzer: + """ + Real-time audio analyzer using sounddevice. + + Captures audio from microphone/line-in and computes + features in real-time for effect parameter bindings. + + Example: + analyzer = AudioAnalyzer(device=0) + analyzer.start() + + # In compositor loop: + energy = analyzer.get_energy() + beat = analyzer.get_beat() + + analyzer.stop() + """ + + def __init__( + self, + device: int = None, + sample_rate: int = 44100, + block_size: int = 1024, + buffer_seconds: float = 0.5, + ): + """ + Initialize audio analyzer. + + Args: + device: Audio input device index (None = default) + sample_rate: Audio sample rate + block_size: Samples per block + buffer_seconds: Ring buffer duration + """ + self.sample_rate = sample_rate + self.block_size = block_size + self.device = device + + # Ring buffer for recent audio + buffer_size = int(sample_rate * buffer_seconds) + self._buffer = np.zeros(buffer_size, dtype=np.float32) + self._buffer_pos = 0 + self._lock = threading.Lock() + + # Beat detection state + self._last_energy = 0 + self._energy_history = [] + self._last_beat_time = 0 + self._beat_threshold = 1.5 # Energy ratio for beat detection + self._min_beat_interval = 0.1 # Min seconds between beats + + # Stream state + self._stream = None + self._running = False + + def _audio_callback(self, indata, frames, time_info, status): + """Called by sounddevice for each audio block.""" + with self._lock: + # Add to ring buffer + data = indata[:, 0] if len(indata.shape) > 1 else indata + n = len(data) + if self._buffer_pos + n <= len(self._buffer): + self._buffer[self._buffer_pos:self._buffer_pos + n] = data + else: + # Wrap around + first = len(self._buffer) - self._buffer_pos + self._buffer[self._buffer_pos:] = data[:first] + self._buffer[:n - first] = data[first:] + self._buffer_pos = (self._buffer_pos + n) % len(self._buffer) + + def start(self): + """Start audio capture.""" + try: + import sounddevice as sd + except ImportError: + print("Warning: sounddevice not installed. Audio analysis disabled.") + print("Install with: pip install sounddevice") + return + + self._stream = sd.InputStream( + device=self.device, + channels=1, + samplerate=self.sample_rate, + blocksize=self.block_size, + callback=self._audio_callback, + ) + self._stream.start() + self._running = True + + def stop(self): + """Stop audio capture.""" + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + self._running = False + + def get_energy(self) -> float: + """ + Get current audio energy (RMS amplitude). + + Returns: + Energy value normalized to 0-1 range (approximately) + """ + with self._lock: + # Use recent samples + recent = 2048 + if self._buffer_pos >= recent: + data = self._buffer[self._buffer_pos - recent:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(recent - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # RMS energy + rms = np.sqrt(np.mean(data ** 2)) + + # Normalize (typical mic input is quite low) + normalized = min(1.0, rms * 10) + + return normalized + + def get_beat(self) -> bool: + """ + Detect if current moment is a beat. + + Simple onset detection based on energy spikes. + + Returns: + True if beat detected, False otherwise + """ + current_energy = self.get_energy() + now = time.time() + + # Update energy history + self._energy_history.append(current_energy) + if len(self._energy_history) > 20: + self._energy_history.pop(0) + + # Need enough history + if len(self._energy_history) < 5: + self._last_energy = current_energy + return False + + # Average recent energy + avg_energy = np.mean(self._energy_history[:-1]) + + # Beat if current energy is significantly above average + is_beat = ( + current_energy > avg_energy * self._beat_threshold and + now - self._last_beat_time > self._min_beat_interval and + current_energy > self._last_energy # Rising edge + ) + + if is_beat: + self._last_beat_time = now + + self._last_energy = current_energy + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """ + Get frequency spectrum divided into bands. + + Args: + bands: Number of frequency bands (default 3: bass, mid, high) + + Returns: + Array of band energies, normalized to 0-1 + """ + with self._lock: + # Use recent samples for FFT + n = 2048 + if self._buffer_pos >= n: + data = self._buffer[self._buffer_pos - n:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(n - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # FFT + fft = np.abs(np.fft.rfft(data * np.hanning(len(data)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + start = i * band_size + end = start + band_size + result[i] = np.mean(fft[start:end]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return self._running + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() + + +class FileAudioAnalyzer: + """ + Audio analyzer that reads from a file (for testing/development). + + Pre-computes analysis and plays back in sync with video. + """ + + def __init__(self, path: str, analysis_data: dict = None): + """ + Initialize from audio file. + + Args: + path: Path to audio file + analysis_data: Pre-computed analysis (times, values, etc.) + """ + self.path = path + self.analysis_data = analysis_data or {} + self._current_time = 0 + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Get energy at current time from pre-computed data.""" + track = self.analysis_data.get("energy", {}) + return self._interpolate(track, self._current_time) + + def get_beat(self) -> bool: + """Check if current time is near a beat.""" + track = self.analysis_data.get("beats", {}) + times = track.get("times", []) + + # Check if we're within 50ms of a beat + for beat_time in times: + if abs(beat_time - self._current_time) < 0.05: + return True + return False + + def _interpolate(self, track: dict, t: float) -> float: + """Interpolate value at time t.""" + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + if t <= times[0]: + return values[0] + if t >= times[-1]: + return values[-1] + + # Find bracket and interpolate + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + return values[i] * (1 - alpha) + values[i + 1] * alpha + + return values[-1] + + @property + def is_running(self) -> bool: + return True + + +class StreamingAudioAnalyzer: + """ + Real-time audio analyzer that streams from a file. + + Reads audio in sync with video time and computes features on-the-fly. + No pre-computation needed - analysis happens as frames are processed. + """ + + def __init__(self, path: str, sample_rate: int = 22050, hop_length: int = 512): + """ + Initialize streaming audio analyzer. + + Args: + path: Path to audio file + sample_rate: Sample rate for analysis + hop_length: Hop length for feature extraction + """ + import subprocess + import json + + self.path = path + self.sample_rate = sample_rate + self.hop_length = hop_length + self._current_time = 0.0 + + # Get audio duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(path)] + result = subprocess.run(cmd, capture_output=True, text=True) + info = json.loads(result.stdout) + self.duration = float(info["format"]["duration"]) + + # Audio buffer and state + self._audio_data = None + self._energy_history = [] + self._last_energy = 0 + self._last_beat_time = -1 + self._beat_threshold = 1.5 + self._min_beat_interval = 0.15 + + # Load audio lazily + self._loaded = False + + def _load_audio(self): + """Load audio data on first use.""" + if self._loaded: + return + + import subprocess + + # Use ffmpeg to decode audio to raw PCM + cmd = [ + "ffmpeg", "-v", "quiet", + "-i", str(self.path), + "-f", "f32le", # 32-bit float, little-endian + "-ac", "1", # mono + "-ar", str(self.sample_rate), + "-" + ] + result = subprocess.run(cmd, capture_output=True) + self._audio_data = np.frombuffer(result.stdout, dtype=np.float32) + self._loaded = True + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Compute energy at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return 0.0 + + # Get sample index for current time + sample_idx = int(self._current_time * self.sample_rate) + window_size = self.hop_length * 2 + + start = max(0, sample_idx - window_size // 2) + end = min(len(self._audio_data), sample_idx + window_size // 2) + + if start >= end: + return 0.0 + + # RMS energy + chunk = self._audio_data[start:end] + rms = np.sqrt(np.mean(chunk ** 2)) + + # Normalize to 0-1 range (approximate) + energy = min(1.0, rms * 3.0) + + self._last_energy = energy + return energy + + def get_beat(self) -> bool: + """Detect beat using spectral flux (change in frequency content).""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return False + + # Get audio chunks for current and previous frame + sample_idx = int(self._current_time * self.sample_rate) + chunk_size = self.hop_length * 2 + + # Current chunk + start = max(0, sample_idx - chunk_size // 2) + end = min(len(self._audio_data), sample_idx + chunk_size // 2) + if end - start < chunk_size // 2: + return False + current_chunk = self._audio_data[start:end] + + # Previous chunk (one hop back) + prev_start = max(0, start - self.hop_length) + prev_end = max(0, end - self.hop_length) + if prev_end <= prev_start: + return False + prev_chunk = self._audio_data[prev_start:prev_end] + + # Compute spectra + current_spec = np.abs(np.fft.rfft(current_chunk * np.hanning(len(current_chunk)))) + prev_spec = np.abs(np.fft.rfft(prev_chunk * np.hanning(len(prev_chunk)))) + + # Spectral flux: sum of positive differences (onset = new frequencies appearing) + min_len = min(len(current_spec), len(prev_spec)) + diff = current_spec[:min_len] - prev_spec[:min_len] + flux = np.sum(np.maximum(0, diff)) # Only count increases + + # Normalize by spectrum size + flux = flux / (min_len + 1) + + # Update flux history + self._energy_history.append((self._current_time, flux)) + while self._energy_history and self._energy_history[0][0] < self._current_time - 1.5: + self._energy_history.pop(0) + + if len(self._energy_history) < 3: + return False + + # Adaptive threshold based on recent flux values + flux_values = [f for t, f in self._energy_history] + mean_flux = np.mean(flux_values) + std_flux = np.std(flux_values) + 0.001 # Avoid division by zero + + # Beat if flux is above mean (more sensitive threshold) + threshold = mean_flux + std_flux * 0.3 # Lower = more sensitive + min_interval = 0.1 # Allow up to 600 BPM + time_ok = self._current_time - self._last_beat_time > min_interval + + is_beat = flux > threshold and time_ok + + if is_beat: + self._last_beat_time = self._current_time + + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """Get frequency spectrum at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return np.zeros(bands) + + sample_idx = int(self._current_time * self.sample_rate) + n = 2048 + + start = max(0, sample_idx - n // 2) + end = min(len(self._audio_data), sample_idx + n // 2) + + if end - start < n // 2: + return np.zeros(bands) + + chunk = self._audio_data[start:end] + + # FFT + fft = np.abs(np.fft.rfft(chunk * np.hanning(len(chunk)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + s, e = i * band_size, (i + 1) * band_size + result[i] = np.mean(fft[s:e]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return True diff --git a/streaming/backends.py b/streaming/backends.py new file mode 100644 index 0000000..80c558a --- /dev/null +++ b/streaming/backends.py @@ -0,0 +1,572 @@ +""" +Effect processing backends. + +Provides abstraction over different rendering backends: +- numpy: CPU-based, works everywhere, ~3-5 fps +- glsl: GPU-based, requires OpenGL, 30+ fps (future) +""" + +import numpy as np +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from pathlib import Path + + +class Backend(ABC): + """Abstract base class for effect processing backends.""" + + @abstractmethod + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process multiple input frames through effects and composite. + + Args: + frames: List of input frames (one per source) + effects_per_frame: List of effect chains (one per source) + compositor_config: How to blend the layers + t: Current time in seconds + analysis_data: Analysis data for binding resolution + + Returns: + Composited output frame + """ + pass + + @abstractmethod + def load_effect(self, effect_path: Path) -> Any: + """Load an effect definition.""" + pass + + +class NumpyBackend(Backend): + """ + CPU-based effect processing using NumPy. + + Uses existing sexp_effects interpreter for effect execution. + Works on any system, but limited to ~3-5 fps for complex effects. + """ + + def __init__(self, recipe_dir: Path = None, minimal_primitives: bool = True): + self.recipe_dir = recipe_dir or Path(".") + self.minimal_primitives = minimal_primitives + self._interpreter = None + self._loaded_effects = {} + + def _get_interpreter(self): + """Lazy-load the sexp interpreter.""" + if self._interpreter is None: + from sexp_effects import get_interpreter + self._interpreter = get_interpreter(minimal_primitives=self.minimal_primitives) + return self._interpreter + + def load_effect(self, effect_path: Path) -> Any: + """Load an effect from sexp file.""" + if isinstance(effect_path, str): + effect_path = Path(effect_path) + effect_key = str(effect_path) + if effect_key not in self._loaded_effects: + interp = self._get_interpreter() + interp.load_effect(str(effect_path)) + self._loaded_effects[effect_key] = effect_path.stem + return self._loaded_effects[effect_key] + + def _resolve_binding(self, value: Any, t: float, analysis_data: Dict) -> Any: + """Resolve a parameter binding to its value at time t.""" + if not isinstance(value, dict): + return value + + if "_binding" in value or "_bind" in value: + source = value.get("source") or value.get("_bind") + feature = value.get("feature", "values") + range_map = value.get("range") + + track = analysis_data.get(source, {}) + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + # Find value at time t (linear interpolation) + if t <= times[0]: + val = values[0] + elif t >= times[-1]: + val = values[-1] + else: + # Binary search for bracket + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + val = values[i] * (1 - alpha) + values[i + 1] * alpha + break + else: + val = values[-1] + + # Apply range mapping + if range_map and len(range_map) == 2: + val = range_map[0] + val * (range_map[1] - range_map[0]) + + return val + + return value + + def _apply_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Apply a single effect to a frame.""" + # Resolve bindings in params + resolved_params = {"_time": t} + for key, value in params.items(): + if key in ("effect", "effect_path", "cid", "analysis_refs"): + continue + resolved_params[key] = self._resolve_binding(value, t, analysis_data) + + # Try fast native effects first + result = self._apply_native_effect(frame, effect_name, resolved_params) + if result is not None: + return result + + # Fall back to sexp interpreter for complex effects + interp = self._get_interpreter() + if effect_name in interp.effects: + result, _ = interp.run_effect(effect_name, frame, resolved_params, {}) + return result + + # Unknown effect - pass through + return frame + + def _apply_native_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + ) -> Optional[np.ndarray]: + """Fast native numpy effects for real-time streaming.""" + import cv2 + + if effect_name == "zoom": + amount = float(params.get("amount", 1.0)) + if abs(amount - 1.0) < 0.01: + return frame + h, w = frame.shape[:2] + # Crop center and resize + new_w, new_h = int(w / amount), int(h / amount) + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + + elif effect_name == "rotate": + angle = float(params.get("angle", 0)) + if abs(angle) < 0.5: + return frame + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + return cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == "brightness": + amount = float(params.get("amount", 1.0)) + return np.clip(frame * amount, 0, 255).astype(np.uint8) + + elif effect_name == "invert": + amount = float(params.get("amount", 1.0)) + if amount < 0.5: + return frame + return 255 - frame + + # Not a native effect + return None + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process frames through effects and composite. + """ + if not frames: + return np.zeros((720, 1280, 3), dtype=np.uint8) + + processed = [] + + # Apply effects to each input frame + for i, (frame, effects) in enumerate(zip(frames, effects_per_frame)): + result = frame.copy() + for effect_config in effects: + effect_name = effect_config.get("effect", "") + if effect_name: + result = self._apply_effect( + result, effect_name, effect_config, t, analysis_data + ) + processed.append(result) + + # Composite layers + if len(processed) == 1: + return processed[0] + + return self._composite(processed, compositor_config, t, analysis_data) + + def _composite( + self, + frames: List[np.ndarray], + config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Composite multiple frames into one.""" + mode = config.get("mode", "alpha") + weights = config.get("weights", [1.0 / len(frames)] * len(frames)) + + # Resolve weight bindings + resolved_weights = [] + for w in weights: + resolved_weights.append(self._resolve_binding(w, t, analysis_data)) + + # Normalize weights + total = sum(resolved_weights) + if total > 0: + resolved_weights = [w / total for w in resolved_weights] + else: + resolved_weights = [1.0 / len(frames)] * len(frames) + + # Resize frames to match first frame + target_h, target_w = frames[0].shape[:2] + resized = [] + for frame in frames: + if frame.shape[:2] != (target_h, target_w): + import cv2 + frame = cv2.resize(frame, (target_w, target_h)) + resized.append(frame.astype(np.float32)) + + # Weighted blend + result = np.zeros_like(resized[0]) + for frame, weight in zip(resized, resolved_weights): + result += frame * weight + + return np.clip(result, 0, 255).astype(np.uint8) + + +class WGPUBackend(Backend): + """ + GPU-based effect processing using wgpu/WebGPU compute shaders. + + Compiles sexp effects to WGSL at load time, executes on GPU. + Achieves 30+ fps real-time processing on supported hardware. + + Requirements: + - wgpu-py library + - Vulkan-capable GPU (or software renderer) + """ + + def __init__(self, recipe_dir: Path = None): + self.recipe_dir = recipe_dir or Path(".") + self._device = None + self._loaded_effects: Dict[str, Any] = {} # name -> compiled shader info + self._numpy_fallback = NumpyBackend(recipe_dir) + # Buffer pool for reuse - keyed by (width, height) + self._buffer_pool: Dict[tuple, Dict] = {} + + def _ensure_device(self): + """Lazy-initialize wgpu device.""" + if self._device is not None: + return + + try: + import wgpu + adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + self._device = adapter.request_device_sync() + print(f"[WGPUBackend] Using GPU: {adapter.info.get('device', 'unknown')}") + except Exception as e: + print(f"[WGPUBackend] GPU init failed: {e}, falling back to CPU") + self._device = None + + def load_effect(self, effect_path: Path) -> Any: + """Load and compile an effect from sexp file to WGSL.""" + effect_key = str(effect_path) + if effect_key in self._loaded_effects: + return self._loaded_effects[effect_key] + + try: + from sexp_effects.wgsl_compiler import compile_effect_file + compiled = compile_effect_file(str(effect_path)) + + self._ensure_device() + if self._device is None: + # Fall back to numpy + return self._numpy_fallback.load_effect(effect_path) + + # Create shader module + import wgpu + shader_module = self._device.create_shader_module(code=compiled.wgsl_code) + + # Create compute pipeline + pipeline = self._device.create_compute_pipeline( + layout="auto", + compute={"module": shader_module, "entry_point": "main"} + ) + + self._loaded_effects[effect_key] = { + 'compiled': compiled, + 'pipeline': pipeline, + 'name': compiled.name, + } + return compiled.name + + except Exception as e: + print(f"[WGPUBackend] Failed to compile {effect_path}: {e}") + # Fall back to numpy for this effect + return self._numpy_fallback.load_effect(effect_path) + + def _resolve_binding(self, value: Any, t: float, analysis_data: Dict) -> Any: + """Resolve a parameter binding to its value at time t.""" + # Delegate to numpy backend's implementation + return self._numpy_fallback._resolve_binding(value, t, analysis_data) + + def _get_or_create_buffers(self, w: int, h: int): + """Get or create reusable buffers for given dimensions.""" + import wgpu + + key = (w, h) + if key in self._buffer_pool: + return self._buffer_pool[key] + + size = w * h * 4 # u32 per pixel + + # Create staging buffer for uploads (MAP_WRITE) + staging_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.MAP_WRITE | wgpu.BufferUsage.COPY_SRC, + mapped_at_creation=False, + ) + + # Create input buffer (STORAGE, receives data from staging) + input_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST, + ) + + # Create output buffer (STORAGE + COPY_SRC for readback) + output_buffer = self._device.create_buffer( + size=size, + usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC, + ) + + # Params buffer (uniform, 256 bytes should be enough) + params_buffer = self._device.create_buffer( + size=256, + usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST, + ) + + self._buffer_pool[key] = { + 'staging': staging_buffer, + 'input': input_buffer, + 'output': output_buffer, + 'params': params_buffer, + 'size': size, + } + return self._buffer_pool[key] + + def _apply_effect_gpu( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + ) -> Optional[np.ndarray]: + """Apply effect using GPU. Returns None if GPU not available.""" + import wgpu + + # Find the loaded effect + effect_info = None + for key, info in self._loaded_effects.items(): + if info.get('name') == effect_name: + effect_info = info + break + + if effect_info is None or self._device is None: + return None + + compiled = effect_info['compiled'] + pipeline = effect_info['pipeline'] + + h, w = frame.shape[:2] + + # Get reusable buffers + buffers = self._get_or_create_buffers(w, h) + + # Pack frame as u32 array (RGB -> packed u32) + r = frame[:, :, 0].astype(np.uint32) + g = frame[:, :, 1].astype(np.uint32) + b = frame[:, :, 2].astype(np.uint32) + packed = (r << 16) | (g << 8) | b + input_data = packed.flatten().astype(np.uint32) + + # Upload input data via queue.write_buffer (more efficient than recreation) + self._device.queue.write_buffer(buffers['input'], 0, input_data.tobytes()) + + # Build params struct + import struct + param_values = [w, h] # width, height as u32 + param_format = "II" # two u32 + + # Add time as f32 + param_values.append(t) + param_format += "f" + + # Add effect-specific params + for param in compiled.params: + val = params.get(param.name, param.default) + if val is None: + val = 0 + if param.wgsl_type == 'f32': + param_values.append(float(val)) + param_format += "f" + elif param.wgsl_type == 'i32': + param_values.append(int(val)) + param_format += "i" + elif param.wgsl_type == 'u32': + param_values.append(int(val)) + param_format += "I" + + # Pad to 16-byte alignment + param_bytes = struct.pack(param_format, *param_values) + while len(param_bytes) % 16 != 0: + param_bytes += b'\x00' + + self._device.queue.write_buffer(buffers['params'], 0, param_bytes) + + # Create bind group (unfortunately this can't be easily reused with different effects) + bind_group = self._device.create_bind_group( + layout=pipeline.get_bind_group_layout(0), + entries=[ + {"binding": 0, "resource": {"buffer": buffers['input']}}, + {"binding": 1, "resource": {"buffer": buffers['output']}}, + {"binding": 2, "resource": {"buffer": buffers['params']}}, + ] + ) + + # Dispatch compute + encoder = self._device.create_command_encoder() + compute_pass = encoder.begin_compute_pass() + compute_pass.set_pipeline(pipeline) + compute_pass.set_bind_group(0, bind_group) + + # Workgroups: ceil(w/16) x ceil(h/16) + wg_x = (w + 15) // 16 + wg_y = (h + 15) // 16 + compute_pass.dispatch_workgroups(wg_x, wg_y, 1) + compute_pass.end() + + self._device.queue.submit([encoder.finish()]) + + # Read back result + result_data = self._device.queue.read_buffer(buffers['output']) + result_packed = np.frombuffer(result_data, dtype=np.uint32).reshape(h, w) + + # Unpack u32 -> RGB + result = np.zeros((h, w, 3), dtype=np.uint8) + result[:, :, 0] = ((result_packed >> 16) & 0xFF).astype(np.uint8) + result[:, :, 1] = ((result_packed >> 8) & 0xFF).astype(np.uint8) + result[:, :, 2] = (result_packed & 0xFF).astype(np.uint8) + + return result + + def _apply_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Apply a single effect to a frame.""" + # Resolve bindings in params + resolved_params = {"_time": t} + for key, value in params.items(): + if key in ("effect", "effect_path", "cid", "analysis_refs"): + continue + resolved_params[key] = self._resolve_binding(value, t, analysis_data) + + # Try GPU first + self._ensure_device() + if self._device is not None: + result = self._apply_effect_gpu(frame, effect_name, resolved_params, t) + if result is not None: + return result + + # Fall back to numpy + return self._numpy_fallback._apply_effect( + frame, effect_name, params, t, analysis_data + ) + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Process frames through effects and composite.""" + if not frames: + return np.zeros((720, 1280, 3), dtype=np.uint8) + + processed = [] + + # Apply effects to each input frame + for i, (frame, effects) in enumerate(zip(frames, effects_per_frame)): + result = frame.copy() + for effect_config in effects: + effect_name = effect_config.get("effect", "") + if effect_name: + result = self._apply_effect( + result, effect_name, effect_config, t, analysis_data + ) + processed.append(result) + + # Composite layers (use numpy backend for now) + if len(processed) == 1: + return processed[0] + + return self._numpy_fallback._composite( + processed, compositor_config, t, analysis_data + ) + + +# Keep GLSLBackend as alias for backwards compatibility +GLSLBackend = WGPUBackend + + +def get_backend(name: str = "numpy", **kwargs) -> Backend: + """ + Get a backend by name. + + Args: + name: "numpy", "wgpu", or "glsl" (alias for wgpu) + **kwargs: Backend-specific options + + Returns: + Backend instance + """ + if name == "numpy": + return NumpyBackend(**kwargs) + elif name in ("wgpu", "glsl", "gpu"): + return WGPUBackend(**kwargs) + else: + raise ValueError(f"Unknown backend: {name}") diff --git a/streaming/compositor.py b/streaming/compositor.py new file mode 100644 index 0000000..477128f --- /dev/null +++ b/streaming/compositor.py @@ -0,0 +1,595 @@ +""" +Streaming video compositor. + +Main entry point for the streaming pipeline. Combines: +- Multiple video sources (with looping) +- Per-source effect chains +- Layer compositing +- Optional live audio analysis +- Output to display/file/stream +""" + +import time +import sys +import numpy as np +from typing import List, Dict, Any, Optional, Union +from pathlib import Path + +from .sources import Source, VideoSource +from .backends import Backend, NumpyBackend, get_backend +from .output import Output, DisplayOutput, FileOutput, MultiOutput + + +class StreamingCompositor: + """ + Real-time streaming video compositor. + + Reads frames from multiple sources, applies effects, composites layers, + and outputs the result - all frame-by-frame without intermediate files. + + Example: + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[ + [{"effect": "rotate", "angle": 45}], + [{"effect": "zoom", "amount": 1.5}], + ], + compositor_config={"mode": "alpha", "weights": [0.5, 0.5]}, + ) + compositor.run(output="preview", duration=60) + """ + + def __init__( + self, + sources: List[Union[str, Source]], + effects_per_source: List[List[Dict]] = None, + compositor_config: Dict = None, + analysis_data: Dict = None, + backend: str = "numpy", + recipe_dir: Path = None, + fps: float = 30, + audio_source: str = None, + ): + """ + Initialize the streaming compositor. + + Args: + sources: List of video paths or Source objects + effects_per_source: List of effect chains, one per source + compositor_config: How to blend layers (mode, weights) + analysis_data: Pre-computed analysis data for bindings + backend: "numpy" or "glsl" + recipe_dir: Directory for resolving relative effect paths + fps: Output frame rate + audio_source: Path to audio file for streaming analysis + """ + self.fps = fps + self.recipe_dir = recipe_dir or Path(".") + self.analysis_data = analysis_data or {} + + # Initialize streaming audio analyzer if audio source provided + self._audio_analyzer = None + self._audio_source = audio_source + if audio_source: + from .audio import StreamingAudioAnalyzer + self._audio_analyzer = StreamingAudioAnalyzer(audio_source) + print(f"Streaming audio: {audio_source}", file=sys.stderr) + + # Initialize sources + self.sources: List[Source] = [] + for src in sources: + if isinstance(src, Source): + self.sources.append(src) + elif isinstance(src, (str, Path)): + self.sources.append(VideoSource(str(src), target_fps=fps)) + else: + raise ValueError(f"Unknown source type: {type(src)}") + + # Effect chains (default: no effects) + self.effects_per_source = effects_per_source or [[] for _ in self.sources] + if len(self.effects_per_source) != len(self.sources): + raise ValueError( + f"effects_per_source length ({len(self.effects_per_source)}) " + f"must match sources length ({len(self.sources)})" + ) + + # Compositor config (default: equal blend) + self.compositor_config = compositor_config or { + "mode": "alpha", + "weights": [1.0 / len(self.sources)] * len(self.sources), + } + + # Initialize backend + self.backend: Backend = get_backend( + backend, + recipe_dir=self.recipe_dir, + ) + + # Load effects + self._load_effects() + + def _load_effects(self): + """Pre-load all effect definitions.""" + for effects in self.effects_per_source: + for effect_config in effects: + effect_path = effect_config.get("effect_path") + if effect_path: + full_path = self.recipe_dir / effect_path + if full_path.exists(): + self.backend.load_effect(full_path) + + def _create_output( + self, + output: Union[str, Output], + size: tuple, + ) -> Output: + """Create output target from string or Output object.""" + if isinstance(output, Output): + return output + + if output == "preview": + return DisplayOutput("Streaming Preview", size, + audio_source=self._audio_source, fps=self.fps) + elif output == "null": + from .output import NullOutput + return NullOutput() + elif isinstance(output, str): + return FileOutput(output, size, fps=self.fps, audio_source=self._audio_source) + else: + raise ValueError(f"Unknown output type: {output}") + + def run( + self, + output: Union[str, Output] = "preview", + duration: float = None, + audio_analyzer=None, + show_fps: bool = True, + recipe_executor=None, + ): + """ + Run the streaming compositor. + + Args: + output: Output target - "preview", filename, or Output object + duration: Duration in seconds (None = run until quit) + audio_analyzer: Optional AudioAnalyzer for live audio reactivity + show_fps: Show FPS counter in console + recipe_executor: Optional StreamingRecipeExecutor for full recipe logic + """ + # Determine output size from first source + output_size = self.sources[0].size + + # Create output + out = self._create_output(output, output_size) + + # Determine duration + if duration is None: + # Run until stopped (or min source duration if not looping) + duration = min(s.duration for s in self.sources) + if duration == float('inf'): + duration = 3600 # 1 hour max for live sources + + total_frames = int(duration * self.fps) + frame_time = 1.0 / self.fps + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {total_frames} frames @ {self.fps}fps", file=sys.stderr) + print(f"Output size: {output_size[0]}x{output_size[1]}", file=sys.stderr) + print(f"Press 'q' to quit (if preview)", file=sys.stderr) + + # Frame loop + start_time = time.time() + frame_count = 0 + fps_update_interval = 30 # Update FPS display every N frames + last_fps_time = start_time + last_fps_count = 0 + + try: + for frame_num in range(total_frames): + if not out.is_open: + print(f"\nOutput closed at frame {frame_num}", file=sys.stderr) + break + + t = frame_num * frame_time + + try: + # Update analysis data from streaming audio (file-based) + energy = 0.0 + is_beat = False + if self._audio_analyzer: + self._update_from_audio(self._audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + elif audio_analyzer: + self._update_from_audio(audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + + # Read frames from all sources + frames = [src.read_frame(t) for src in self.sources] + + # Process through recipe executor if provided + if recipe_executor: + result = self._process_with_executor( + frames, recipe_executor, energy, is_beat, t + ) + else: + # Simple backend processing + result = self.backend.process_frame( + frames, + self.effects_per_source, + self.compositor_config, + t, + self.analysis_data, + ) + + # Output + out.write(result, t) + frame_count += 1 + + # FPS display + if show_fps and frame_count % fps_update_interval == 0: + now = time.time() + elapsed = now - last_fps_time + if elapsed > 0: + current_fps = (frame_count - last_fps_count) / elapsed + progress = frame_num / total_frames * 100 + print( + f"\r {progress:5.1f}% | {current_fps:5.1f} fps | " + f"frame {frame_num}/{total_frames}", + end="", file=sys.stderr + ) + last_fps_time = now + last_fps_count = frame_count + + except Exception as e: + print(f"\nError at frame {frame_num}, t={t:.1f}s: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + break + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources: + if hasattr(src, 'close'): + src.close() + + # Final stats + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", file=sys.stderr) + + def _process_with_executor( + self, + frames: List[np.ndarray], + executor, + energy: float, + is_beat: bool, + t: float, + ) -> np.ndarray: + """ + Process frames using the recipe executor for full pipeline. + + Implements: + 1. process-pair: two clips per source with effects, blended + 2. cycle-crossfade: dynamic composition with zoom and weights + 3. Final effects: whole-spin, ripple + """ + import cv2 + + # Target size from first source + target_h, target_w = frames[0].shape[:2] + + # Resize all frames to target size (letterbox to preserve aspect ratio) + resized_frames = [] + for frame in frames: + fh, fw = frame.shape[:2] + if (fh, fw) != (target_h, target_w): + # Calculate scale to fit while preserving aspect ratio + scale = min(target_w / fw, target_h / fh) + new_w, new_h = int(fw * scale), int(fh * scale) + resized = cv2.resize(frame, (new_w, new_h)) + # Center on black canvas + canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8) + x_off = (target_w - new_w) // 2 + y_off = (target_h - new_h) // 2 + canvas[y_off:y_off+new_h, x_off:x_off+new_w] = resized + resized_frames.append(canvas) + else: + resized_frames.append(frame) + frames = resized_frames + + # Update executor state + executor.on_frame(energy, is_beat, t) + + # Get weights to know which sources are active + weights = executor.get_cycle_weights() + + # Process each source as a "pair" (clip A and B with different effects) + processed_pairs = [] + + for i, frame in enumerate(frames): + # Skip sources with zero weight (but still need placeholder) + if i < len(weights) and weights[i] < 0.001: + processed_pairs.append(None) + continue + # Get effect params for clip A and B + params_a = executor.get_effect_params(i, "a", energy) + params_b = executor.get_effect_params(i, "b", energy) + pair_params = executor.get_pair_params(i) + + # Process clip A + clip_a = self._apply_clip_effects(frame.copy(), params_a, t) + + # Process clip B + clip_b = self._apply_clip_effects(frame.copy(), params_b, t) + + # Blend A and B using pair_mix opacity + opacity = pair_params["blend_opacity"] + blended = cv2.addWeighted( + clip_a, 1 - opacity, + clip_b, opacity, + 0 + ) + + # Apply pair rotation + h, w = blended.shape[:2] + center = (w // 2, h // 2) + angle = pair_params["pair_rotation"] + if abs(angle) > 0.5: + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + blended = cv2.warpAffine(blended, matrix, (w, h)) + + processed_pairs.append(blended) + + # Cycle-crossfade composition + weights = executor.get_cycle_weights() + zooms = executor.get_cycle_zooms() + + # Apply zoom per pair and composite + h, w = target_h, target_w + result = np.zeros((h, w, 3), dtype=np.float32) + + for idx, (pair, weight, zoom) in enumerate(zip(processed_pairs, weights, zooms)): + # Skip zero-weight sources + if pair is None or weight < 0.001: + continue + + orig_shape = pair.shape + + # Apply zoom + if zoom > 1.01: + # Zoom in: crop center and resize up + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = pair[y1:y1+new_h, x1:x1+new_w] + pair = cv2.resize(cropped, (w, h)) + elif zoom < 0.99: + # Zoom out: shrink video and center on black + scaled_w, scaled_h = int(w * zoom), int(h * zoom) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(pair, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + pair = canvas.copy() + + # Draw colored border - size indicates zoom level + border_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)] + color = border_colors[idx % 4] + thickness = max(3, int(10 * weight)) # Thicker border = higher weight + pair = np.ascontiguousarray(pair) + pair[:thickness, :] = color + pair[-thickness:, :] = color + pair[:, :thickness] = color + pair[:, -thickness:] = color + + result += pair.astype(np.float32) * weight + + result = np.clip(result, 0, 255).astype(np.uint8) + + # Apply final effects (whole-spin, ripple) + final_params = executor.get_final_effects(energy) + + # Whole spin + spin_angle = final_params["whole_spin_angle"] + if abs(spin_angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, spin_angle, 1.0) + result = cv2.warpAffine(result, matrix, (w, h)) + + # Ripple effect + amp = final_params["ripple_amplitude"] + if amp > 1: + result = self._apply_ripple(result, amp, + final_params["ripple_cx"], + final_params["ripple_cy"], + t) + + return result + + def _apply_clip_effects(self, frame: np.ndarray, params: dict, t: float) -> np.ndarray: + """Apply per-clip effects: rotate, zoom, invert, hue_shift, ascii.""" + import cv2 + + h, w = frame.shape[:2] + + # Rotate + angle = params["rotate_angle"] + if abs(angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + # Zoom + zoom = params["zoom_amount"] + if abs(zoom - 1.0) > 0.01: + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x1 + new_w), min(h, y1 + new_h) + if x2 > x1 and y2 > y1: + cropped = frame[y1:y2, x1:x2] + frame = cv2.resize(cropped, (w, h)) + + # Invert + if params["invert_amount"] > 0.5: + frame = 255 - frame + + # Hue shift + hue_deg = params["hue_degrees"] + if abs(hue_deg) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.int32) + int(hue_deg / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + # ASCII art + if params["ascii_mix"] > 0.5: + char_size = max(4, int(params["ascii_char_size"])) + frame = self._apply_ascii(frame, char_size) + + return frame + + def _apply_ascii(self, frame: np.ndarray, char_size: int) -> np.ndarray: + """Apply ASCII art effect.""" + import cv2 + from PIL import Image, ImageDraw, ImageFont + + h, w = frame.shape[:2] + chars = " .:-=+*#%@" + + # Get font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size) + except: + font = ImageFont.load_default() + + # Sample cells using area interpolation (fast block average) + rows = h // char_size + cols = w // char_size + if rows < 1 or cols < 1: + return frame + + # Crop to exact grid and downsample + cropped = frame[:rows * char_size, :cols * char_size] + cell_colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * cell_colors[:, :, 0] + + 0.587 * cell_colors[:, :, 1] + + 0.114 * cell_colors[:, :, 2]) / 255.0 + + # Create output image + out_h = rows * char_size + out_w = cols * char_size + output = Image.new('RGB', (out_w, out_h), (0, 0, 0)) + draw = ImageDraw.Draw(output) + + # Draw characters + for r in range(rows): + for c in range(cols): + lum = luminances[r, c] + color = tuple(cell_colors[r, c]) + + # Map luminance to character + idx = int(lum * (len(chars) - 1)) + char = chars[idx] + + # Draw character + x = c * char_size + y = r * char_size + draw.text((x, y), char, fill=color, font=font) + + # Convert back to numpy and resize to original + result = np.array(output) + if result.shape[:2] != (h, w): + result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LINEAR) + + return result + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, t: float = 0) -> np.ndarray: + """Apply ripple distortion effect.""" + import cv2 + + h, w = frame.shape[:2] + center_x, center_y = cx * w, cy * h + max_dim = max(w, h) + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Distance from center + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx*dx + dy*dy) + + # Ripple parameters (matching recipe: frequency=8, decay=2, speed=5) + freq = 8 + decay = 2 + speed = 5 + phase = t * speed * 2 * np.pi + + # Ripple displacement (matching original formula) + ripple = np.sin(2 * np.pi * freq * dist / max_dim + phase) * amplitude + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max_dim) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT) + + def _update_from_audio(self, analyzer, t: float): + """Update analysis data from audio analyzer (streaming or live).""" + # Set time for file-based streaming analyzers + if hasattr(analyzer, 'set_time'): + analyzer.set_time(t) + + # Get current audio features + energy = analyzer.get_energy() if hasattr(analyzer, 'get_energy') else 0 + beat = analyzer.get_beat() if hasattr(analyzer, 'get_beat') else False + + # Update analysis tracks - these can be referenced by effect bindings + self.analysis_data["live_energy"] = { + "times": [t], + "values": [energy], + "duration": float('inf'), + } + self.analysis_data["live_beat"] = { + "times": [t], + "values": [1.0 if beat else 0.0], + "duration": float('inf'), + } + + +def quick_preview( + sources: List[str], + effects: List[List[Dict]] = None, + duration: float = 10, + fps: float = 30, +): + """ + Quick preview helper - show sources with optional effects. + + Example: + quick_preview(["video1.mp4", "video2.mp4"], duration=30) + """ + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects, + fps=fps, + ) + compositor.run(output="preview", duration=duration) diff --git a/streaming/demo.py b/streaming/demo.py new file mode 100644 index 0000000..0b1899f --- /dev/null +++ b/streaming/demo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Demo script for streaming compositor. + +Usage: + # Preview two videos blended + python -m streaming.demo preview video1.mp4 video2.mp4 + + # Record output to file + python -m streaming.demo record video1.mp4 video2.mp4 -o output.mp4 + + # Benchmark (no output) + python -m streaming.demo benchmark video1.mp4 --duration 10 +""" + +import argparse +import sys +from pathlib import Path + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from streaming import StreamingCompositor, VideoSource +from streaming.output import NullOutput + + +def demo_preview(sources: list, duration: float, effects: bool = False): + """Preview sources with optional simple effects.""" + effects_config = None + if effects: + effects_config = [ + [{"effect": "rotate", "angle": 15}], + [{"effect": "zoom", "amount": 1.2}], + ][:len(sources)] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="preview", duration=duration) + + +def demo_record(sources: list, output_path: str, duration: float): + """Record blended output to file.""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output=output_path, duration=duration) + + +def demo_benchmark(sources: list, duration: float): + """Benchmark processing speed (no output).""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="null", duration=duration) + + +def demo_audio_reactive(sources: list, duration: float): + """Preview with live audio reactivity.""" + from streaming.audio import AudioAnalyzer + + # Create compositor with energy-reactive effects + effects_config = [ + [{ + "effect": "zoom", + "amount": {"_binding": True, "source": "live_energy", "feature": "values", "range": [1.0, 1.5]}, + }] + for _ in sources + ] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + + # Start audio analyzer + try: + with AudioAnalyzer() as audio: + print("Audio analyzer started. Make some noise!", file=sys.stderr) + compositor.run(output="preview", duration=duration, audio_analyzer=audio) + except Exception as e: + print(f"Audio not available: {e}", file=sys.stderr) + print("Running without audio...", file=sys.stderr) + compositor.run(output="preview", duration=duration) + + +def main(): + parser = argparse.ArgumentParser(description="Streaming compositor demo") + parser.add_argument("mode", choices=["preview", "record", "benchmark", "audio"], + help="Demo mode") + parser.add_argument("sources", nargs="+", help="Video source files") + parser.add_argument("-o", "--output", help="Output file (for record mode)") + parser.add_argument("-d", "--duration", type=float, default=30, + help="Duration in seconds") + parser.add_argument("--effects", action="store_true", + help="Apply simple effects (for preview)") + + args = parser.parse_args() + + # Verify sources exist + for src in args.sources: + if not Path(src).exists(): + print(f"Error: Source not found: {src}", file=sys.stderr) + sys.exit(1) + + if args.mode == "preview": + demo_preview(args.sources, args.duration, args.effects) + elif args.mode == "record": + if not args.output: + print("Error: --output required for record mode", file=sys.stderr) + sys.exit(1) + demo_record(args.sources, args.output, args.duration) + elif args.mode == "benchmark": + demo_benchmark(args.sources, args.duration) + elif args.mode == "audio": + demo_audio_reactive(args.sources, args.duration) + + +if __name__ == "__main__": + main() diff --git a/streaming/gpu_output.py b/streaming/gpu_output.py new file mode 100644 index 0000000..3034310 --- /dev/null +++ b/streaming/gpu_output.py @@ -0,0 +1,538 @@ +""" +Zero-copy GPU video encoding output. + +Uses PyNvVideoCodec for direct GPU-to-GPU encoding without CPU transfers. +Frames stay on GPU throughout: CuPy → NV12 conversion → NVENC encoding. +""" + +import numpy as np +import subprocess +import sys +import threading +import queue +from pathlib import Path +from typing import Tuple, Optional, Union +import time + +# Try to import GPU libraries +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + +try: + import PyNvVideoCodec as nvc + PYNVCODEC_AVAILABLE = True +except ImportError: + nvc = None + PYNVCODEC_AVAILABLE = False + + +def check_gpu_encode_available() -> bool: + """Check if zero-copy GPU encoding is available.""" + return CUPY_AVAILABLE and PYNVCODEC_AVAILABLE + + +# RGB to NV12 CUDA kernel +_RGB_TO_NV12_KERNEL = None + +def _get_rgb_to_nv12_kernel(): + """Get or create the RGB to NV12 conversion kernel.""" + global _RGB_TO_NV12_KERNEL + if _RGB_TO_NV12_KERNEL is None and CUPY_AVAILABLE: + _RGB_TO_NV12_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void rgb_to_nv12( + const unsigned char* rgb, + unsigned char* y_plane, + unsigned char* uv_plane, + int width, int height +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int rgb_idx = (y * width + x) * 3; + unsigned char r = rgb[rgb_idx]; + unsigned char g = rgb[rgb_idx + 1]; + unsigned char b = rgb[rgb_idx + 2]; + + // RGB to Y (BT.601) + int y_val = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16; + y_plane[y * width + x] = (unsigned char)(y_val > 255 ? 255 : (y_val < 0 ? 0 : y_val)); + + // UV (subsample 2x2) - only process even pixels + if ((x & 1) == 0 && (y & 1) == 0) { + int u_val = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + int v_val = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + + int uv_idx = (y / 2) * width + x; + uv_plane[uv_idx] = (unsigned char)(u_val > 255 ? 255 : (u_val < 0 ? 0 : u_val)); + uv_plane[uv_idx + 1] = (unsigned char)(v_val > 255 ? 255 : (v_val < 0 ? 0 : v_val)); + } +} +''', 'rgb_to_nv12') + return _RGB_TO_NV12_KERNEL + + +class GPUEncoder: + """ + Zero-copy GPU video encoder using PyNvVideoCodec. + + Frames are converted from RGB to NV12 on GPU and encoded directly + without any CPU memory transfers. + """ + + def __init__(self, width: int, height: int, fps: float = 30, crf: int = 23): + if not check_gpu_encode_available(): + raise RuntimeError("GPU encoding not available (need CuPy and PyNvVideoCodec)") + + self.width = width + self.height = height + self.fps = fps + self.crf = crf + + # Create dummy video to get frame buffer template + self._init_frame_buffer() + + # Create encoder with low-latency settings (no B-frames for immediate output) + # Use H264 codec explicitly, with SPS/PPS headers for browser compatibility + self.encoder = nvc.CreateEncoder( + width, height, "NV12", usecpuinputbuffer=False, + codec="h264", # Explicit H.264 (not HEVC) + bf=0, # No B-frames - immediate output + repeatSPSPPS=1, # Include SPS/PPS with each IDR frame + idrPeriod=30, # IDR frame every 30 frames (1 sec at 30fps) + ) + + # CUDA kernel grid/block config + self._block = (16, 16) + self._grid = ((width + 15) // 16, (height + 15) // 16) + + self._frame_count = 0 + self._encoded_data = [] + + print(f"[GPUEncoder] Initialized {width}x{height} @ {fps}fps, zero-copy GPU encoding", file=sys.stderr) + + def _init_frame_buffer(self): + """Initialize frame buffer from dummy decode.""" + # Create minimal dummy video + dummy_path = Path("/tmp/gpu_encoder_dummy.mp4") + subprocess.run([ + "ffmpeg", "-y", "-f", "lavfi", + "-i", f"color=black:size={self.width}x{self.height}:duration=0.1:rate=30", + "-c:v", "h264", "-pix_fmt", "yuv420p", + str(dummy_path) + ], capture_output=True) + + # Decode to get frame buffer + demuxer = nvc.CreateDemuxer(str(dummy_path)) + decoder = nvc.CreateDecoder(gpuid=0, usedevicememory=True) + + self._template_frame = None + for _ in range(30): + packet = demuxer.Demux() + if not packet: + break + frames = decoder.Decode(packet) + if frames: + self._template_frame = frames[0] + break + + if not self._template_frame: + raise RuntimeError("Failed to initialize GPU frame buffer") + + # Wrap frame planes with CuPy for zero-copy access + y_ptr = self._template_frame.GetPtrToPlane(0) + uv_ptr = self._template_frame.GetPtrToPlane(1) + + y_mem = cp.cuda.UnownedMemory(y_ptr, self.height * self.width, None) + self._y_plane = cp.ndarray( + (self.height, self.width), dtype=cp.uint8, + memptr=cp.cuda.MemoryPointer(y_mem, 0) + ) + + uv_mem = cp.cuda.UnownedMemory(uv_ptr, (self.height // 2) * self.width, None) + self._uv_plane = cp.ndarray( + (self.height // 2, self.width), dtype=cp.uint8, + memptr=cp.cuda.MemoryPointer(uv_mem, 0) + ) + + # Keep references to prevent GC + self._decoder = decoder + self._demuxer = demuxer + + # Cleanup dummy file + dummy_path.unlink(missing_ok=True) + + def encode_frame(self, frame: Union[np.ndarray, 'cp.ndarray']) -> bytes: + """ + Encode a frame (RGB format) to H.264. + + Args: + frame: RGB frame as numpy or CuPy array, shape (H, W, 3) + + Returns: + Encoded bytes (may be empty if frame is buffered) + """ + # Ensure frame is on GPU + if isinstance(frame, np.ndarray): + frame_gpu = cp.asarray(frame) + else: + frame_gpu = frame + + # Ensure uint8 + if frame_gpu.dtype != cp.uint8: + frame_gpu = cp.clip(frame_gpu, 0, 255).astype(cp.uint8) + + # Ensure contiguous + if not frame_gpu.flags['C_CONTIGUOUS']: + frame_gpu = cp.ascontiguousarray(frame_gpu) + + # Debug: check input frame has actual data (first few frames only) + if self._frame_count < 3: + frame_sum = float(cp.sum(frame_gpu)) + print(f"[GPUEncoder] Frame {self._frame_count}: shape={frame_gpu.shape}, dtype={frame_gpu.dtype}, sum={frame_sum:.0f}", file=sys.stderr) + if frame_sum < 1000: + print(f"[GPUEncoder] WARNING: Frame appears to be mostly black!", file=sys.stderr) + + # Convert RGB to NV12 on GPU + kernel = _get_rgb_to_nv12_kernel() + kernel(self._grid, self._block, (frame_gpu, self._y_plane, self._uv_plane, self.width, self.height)) + + # CRITICAL: Synchronize CUDA to ensure kernel completes before encoding + cp.cuda.Stream.null.synchronize() + + # Debug: check Y plane has data after conversion (first few frames only) + if self._frame_count < 3: + y_sum = float(cp.sum(self._y_plane)) + print(f"[GPUEncoder] Frame {self._frame_count}: Y plane sum={y_sum:.0f}", file=sys.stderr) + + # Encode (GPU to GPU) + result = self.encoder.Encode(self._template_frame) + self._frame_count += 1 + + return result if result else b'' + + def flush(self) -> bytes: + """Flush encoder and return remaining data.""" + return self.encoder.EndEncode() + + def close(self): + """Close encoder and cleanup.""" + pass + + +class GPUHLSOutput: + """ + GPU-accelerated HLS output with IPFS upload. + + Uses zero-copy GPU encoding and writes HLS segments. + Uploads happen asynchronously in a background thread to avoid stuttering. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + crf: int = 23, + audio_source: str = None, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._on_playlist_update = on_playlist_update + self._is_open = True + self.audio_source = audio_source + + # GPU encoder + self._gpu_encoder = GPUEncoder(size[0], size[1], fps, crf) + + # Segment management + self._current_segment = 0 + self._frames_in_segment = 0 + self._frames_per_segment = int(fps * segment_duration) + self._segment_data = [] + + # Track segment CIDs for IPFS + self.segment_cids = {} + self._playlist_cid = None + self._upload_lock = threading.Lock() + + # Import IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Background upload thread + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Setup ffmpeg for muxing (takes raw H.264, outputs .ts segments) + self._setup_muxer() + + print(f"[GPUHLSOutput] Initialized {size[0]}x{size[1]} @ {fps}fps, GPU encoding", file=sys.stderr) + + def _setup_muxer(self): + """Setup ffmpeg for muxing H.264 to MPEG-TS segments with optional audio.""" + self.local_playlist_path = self.output_dir / "stream.m3u8" + + cmd = [ + "ffmpeg", "-y", + "-f", "h264", # Input is raw H.264 + "-i", "-", + ] + + # Add audio input if provided + if self.audio_source: + cmd.extend(["-i", str(self.audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + cmd.extend([ + "-c:v", "copy", # Just copy video, no re-encoding + ]) + + # Add audio codec if we have audio + if self.audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + cmd.extend([ + "-f", "hls", + "-hls_time", str(self.segment_duration), + "-hls_list_size", "0", + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.local_playlist_path), + ]) + + print(f"[GPUHLSOutput] FFmpeg cmd: {' '.join(cmd)}", file=sys.stderr) + + self._muxer = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, # Capture stderr for debugging + ) + + # Start thread to drain stderr (prevents pipe buffer from filling and blocking FFmpeg) + self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True) + self._stderr_thread.start() + + def _drain_stderr(self): + """Drain FFmpeg stderr to prevent blocking.""" + try: + for line in self._muxer.stderr: + line_str = line.decode('utf-8', errors='replace').strip() + if line_str: + print(f"[FFmpeg] {line_str}", file=sys.stderr) + except Exception as e: + print(f"[FFmpeg stderr] Error reading: {e}", file=sys.stderr) + + def write(self, frame: Union[np.ndarray, 'cp.ndarray'], t: float = 0): + """Write a frame using GPU encoding.""" + if not self._is_open: + return + + # Handle GPUFrame objects (from streaming_gpu primitives) + if hasattr(frame, 'gpu') and hasattr(frame, 'is_on_gpu'): + # It's a GPUFrame - extract the underlying array + frame = frame.gpu if frame.is_on_gpu else frame.cpu + + # GPU encode + encoded = self._gpu_encoder.encode_frame(frame) + + # Send to muxer + if encoded: + try: + self._muxer.stdin.write(encoded) + except BrokenPipeError as e: + print(f"[GPUHLSOutput] FFmpeg pipe broken after {self._frames_in_segment} frames in segment, total segments: {self._current_segment}", file=sys.stderr) + # Check if muxer is still running + if self._muxer.poll() is not None: + print(f"[GPUHLSOutput] FFmpeg exited with code {self._muxer.returncode}", file=sys.stderr) + self._is_open = False + return + except Exception as e: + print(f"[GPUHLSOutput] Error writing to FFmpeg: {e}", file=sys.stderr) + self._is_open = False + return + + self._frames_in_segment += 1 + + # Check for segment completion + if self._frames_in_segment >= self._frames_per_segment: + self._frames_in_segment = 0 + self._check_upload_segments() + + def _upload_worker(self): + """Background worker thread for async IPFS uploads.""" + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + seg_path, seg_num = item + self._do_upload(seg_path, seg_num) + except queue.Empty: + continue + except Exception as e: + print(f"Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, seg_path: Path, seg_num: int): + """Actually perform the upload (runs in background thread).""" + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + with self._upload_lock: + self.segment_cids[seg_num] = cid + print(f"Added to IPFS: {seg_path.name} -> {cid}", file=sys.stderr) + self._update_playlist() + except Exception as e: + print(f"Failed to add to IPFS: {e}", file=sys.stderr) + + def _check_upload_segments(self): + """Check for and queue new segments for async IPFS upload.""" + segments = sorted(self.output_dir.glob("segment_*.ts")) + + for seg_path in segments: + seg_num = int(seg_path.stem.split("_")[1]) + + with self._upload_lock: + if seg_num in self.segment_cids: + continue + + # Check if segment is complete (quick check, no blocking) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue + # Quick non-blocking check + time.sleep(0.01) + size2 = seg_path.stat().st_size + if size1 != size2: + continue + except FileNotFoundError: + continue + + # Queue for async upload (non-blocking!) + self._upload_queue.put((seg_path, seg_num)) + + def _update_playlist(self): + """Generate and upload IPFS-aware playlist.""" + with self._upload_lock: + if not self.segment_cids: + return + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + # Use /ipfs-ts/ path for segments to get correct MIME type (video/mp2t) + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + lines.append(f"{segment_gateway}/{cid}") + + playlist_content = "\n".join(lines) + "\n" + + # Upload playlist + self._playlist_cid = self._ipfs_add_bytes(playlist_content.encode(), pin=True) + if self._playlist_cid and self._on_playlist_update: + self._on_playlist_update(self._playlist_cid) + + def close(self): + """Close output and flush remaining data.""" + if not self._is_open: + return + + self._is_open = False + + # Flush GPU encoder + final_data = self._gpu_encoder.flush() + if final_data: + try: + self._muxer.stdin.write(final_data) + except: + pass + + # Close muxer + try: + self._muxer.stdin.close() + self._muxer.wait(timeout=10) + except: + self._muxer.kill() + + # Final segment upload + self._check_upload_segments() + + # Wait for pending uploads to complete + self._upload_queue.put(None) # Signal shutdown + self._upload_thread.join(timeout=30) + + # Generate final playlist with #EXT-X-ENDLIST for VOD playback + self._generate_final_playlist() + + self._gpu_encoder.close() + + def _generate_final_playlist(self): + """Generate final IPFS playlist with #EXT-X-ENDLIST for completed streams.""" + with self._upload_lock: + if not self.segment_cids: + return + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + "#EXT-X-PLAYLIST-TYPE:VOD", # Mark as VOD for completed streams + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + # Use /ipfs-ts/ path for segments to get correct MIME type (video/mp2t) + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + lines.append(f"{segment_gateway}/{cid}") + + # Mark stream as complete - critical for VOD playback + lines.append("#EXT-X-ENDLIST") + + playlist_content = "\n".join(lines) + "\n" + + # Upload final playlist + self._playlist_cid = self._ipfs_add_bytes(playlist_content.encode(), pin=True) + if self._playlist_cid: + print(f"[GPUHLSOutput] Final VOD playlist: {self._playlist_cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + if self._on_playlist_update: + self._on_playlist_update(self._playlist_cid) + + @property + def is_open(self) -> bool: + return self._is_open + + @property + def playlist_cid(self) -> Optional[str]: + return self._playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + """Get the full IPFS URL for the playlist.""" + if self._playlist_cid: + return f"{self.ipfs_gateway}/{self._playlist_cid}" + return None diff --git a/streaming/jax_typography.py b/streaming/jax_typography.py new file mode 100644 index 0000000..74c0b31 --- /dev/null +++ b/streaming/jax_typography.py @@ -0,0 +1,1642 @@ +""" +JAX Typography Primitives + +Two approaches for text rendering, both compile to JAX/GPU: + +## 1. TextStrip - Pixel-perfect static text + Pre-render entire strings at compile time using PIL. + Perfect sub-pixel anti-aliasing, exact match with PIL. + Use for: static titles, labels, any text without per-character effects. + + S-expression: + (let ((strip (render-text-strip "Hello World" 48))) + (place-text-strip frame strip x y :color white)) + +## 2. Glyph-by-glyph - Dynamic text effects + Individual glyph placement for wave, arc, audio-reactive effects. + Each character can have independent position, color, opacity. + Note: slight anti-aliasing differences vs PIL due to integer positioning. + + S-expression: + ; Wave text - y oscillates with character index + (let ((glyphs (text-glyphs "Wavy" 48))) + (first + (fold glyphs (list frame 0) + (lambda (acc g) + (let ((frm (first acc)) + (cursor (second acc)) + (i (length acc))) ; approximate index + (list + (place-glyph frm (glyph-image g) + (+ x cursor) + (+ y (* amplitude (sin (* i frequency)))) + (glyph-bearing-x g) (glyph-bearing-y g) + white 1.0) + (+ cursor (glyph-advance g)))))))) + + ; Audio-reactive spacing + (let ((glyphs (text-glyphs "Bass" 48)) + (bass (audio-band 0 200))) + (first + (fold glyphs (list frame 0) + (lambda (acc g) + (let ((frm (first acc)) + (cursor (second acc))) + (list + (place-glyph frm (glyph-image g) + (+ x cursor) y + (glyph-bearing-x g) (glyph-bearing-y g) + white 1.0) + (+ cursor (glyph-advance g) (* bass 20)))))))) + +Kerning support: + ; With kerning adjustment + (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size)) +""" + +import math +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from typing import Tuple, Dict, Any, List, Optional +from dataclasses import dataclass + + +# ============================================================================= +# Glyph Data (computed at compile time) +# ============================================================================= + +@dataclass +class GlyphData: + """Glyph data computed at compile time. + + Attributes: + char: The character + image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime + advance: Horizontal advance (distance to next glyph origin) + bearing_x: Left side bearing (x offset from origin to first pixel) + bearing_y: Top bearing (y offset from baseline to top of glyph) + width: Image width + height: Image height + """ + char: str + image: np.ndarray # (H, W, 4) RGBA uint8 + advance: float + bearing_x: float + bearing_y: float + width: int + height: int + + +# Font cache: (font_name, font_size) -> {char: GlyphData} +_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {} + +# Font metrics cache: (font_name, font_size) -> (ascent, descent) +_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {} + +# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment} +# Kerning adjustment is added to advance: new_advance = advance + kerning +# Typically negative (characters move closer together) +_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {} + + +def _load_font(font_name: str = None, font_size: int = 32): + """Load a font. Called at compile time.""" + from PIL import ImageFont + + candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', + '/usr/share/fonts/truetype/freefont/FreeSans.ttf', + ] + + for path in candidates: + if path is None: + continue + try: + return ImageFont.truetype(path, font_size) + except (IOError, OSError): + continue + + return ImageFont.load_default() + + +def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]: + """Get or create glyph cache for a font. Called at compile time.""" + cache_key = (font_name, font_size) + + if cache_key in _GLYPH_CACHE: + return _GLYPH_CACHE[cache_key] + + from PIL import Image, ImageDraw + + font = _load_font(font_name, font_size) + ascent, descent = font.getmetrics() + _METRICS_CACHE[cache_key] = (ascent, descent) + + glyphs = {} + charset = ''.join(chr(i) for i in range(32, 127)) + + temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0)) + temp_draw = ImageDraw.Draw(temp_img) + + for char in charset: + # Get metrics + bbox = temp_draw.textbbox((0, 0), char, font=font) + advance = font.getlength(char) + + x_min, y_min, x_max, y_max = bbox + + # Create glyph image with padding + padding = 2 + img_w = max(int(x_max - x_min) + padding * 2, 1) + img_h = max(int(y_max - y_min) + padding * 2, 1) + + glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0)) + glyph_draw = ImageDraw.Draw(glyph_img) + + # Draw at position accounting for bbox offset + draw_x = padding - x_min + draw_y = padding - y_min + glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font) + + glyphs[char] = GlyphData( + char=char, + image=np.array(glyph_img, dtype=np.uint8), + advance=float(advance), + bearing_x=float(x_min), + bearing_y=float(-y_min), # Distance from baseline to top + width=img_w, + height=img_h, + ) + + _GLYPH_CACHE[cache_key] = glyphs + return glyphs + + +def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]: + """Get or create kerning cache for a font. Called at compile time. + + Kerning is computed as: + kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b) + + This gives the adjustment needed when placing 'b' after 'a'. + Typically negative (characters move closer together). + """ + cache_key = (font_name, font_size) + + if cache_key in _KERNING_CACHE: + return _KERNING_CACHE[cache_key] + + font = _load_font(font_name, font_size) + kerning = {} + + # Compute kerning for all printable ASCII pairs + charset = ''.join(chr(i) for i in range(32, 127)) + + # Pre-compute individual character lengths + char_lengths = {c: font.getlength(c) for c in charset} + + # Compute kerning for each pair + for c1 in charset: + for c2 in charset: + pair_length = font.getlength(c1 + c2) + individual_sum = char_lengths[c1] + char_lengths[c2] + kern = pair_length - individual_sum + + # Only store non-zero kerning to save memory + if abs(kern) > 0.01: + kerning[(c1, c2)] = kern + + _KERNING_CACHE[cache_key] = kerning + return kerning + + +def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float: + """Get kerning adjustment between two characters. Compile-time. + + Returns the adjustment to add to char1's advance when char2 follows. + Typically negative (characters move closer). + + Usage in S-expression: + (+ (glyph-advance g1) (kerning g1 g2)) + """ + kerning_cache = _get_kerning_cache(font_name, font_size) + return kerning_cache.get((char1, char2), 0.0) + + +@dataclass +class TextStrip: + """Pre-rendered text strip with proper sub-pixel anti-aliasing. + + Rendered at compile time using PIL for exact matching. + At runtime, just composite onto frame at integer positions. + + Attributes: + text: The original text + image: RGBA image as numpy array (H, W, 4) + width: Strip width + height: Strip height + baseline_y: Y position of baseline within the strip + bearing_x: Left side bearing of first character + anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right) + anchor_y: Y offset for anchor point (depends on anchor type) + stroke_width: Stroke width used when rendering + """ + text: str + image: np.ndarray + width: int + height: int + baseline_y: int + bearing_x: float + anchor_x: float = 0.0 + anchor_y: float = 0.0 + stroke_width: int = 0 + + +# Text strip cache: cache_key -> TextStrip +_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {} + + +def render_text_strip( + text: str, + font_name: str = None, + font_size: int = 32, + stroke_width: int = 0, + stroke_fill: tuple = None, + anchor: str = "la", # left-ascender (PIL default is "la") + multiline: bool = False, + line_spacing: int = 4, + align: str = "left", +) -> TextStrip: + """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing. + + Args: + text: Text to render + font_name: Path to font file (None for default) + font_size: Font size in pixels + stroke_width: Outline width in pixels (0 for no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black + anchor: PIL anchor code - first char: h=left, m=middle, r=right + second char: a=ascender, t=top, m=middle, s=baseline, d=descender + multiline: If True, handle newlines in text + line_spacing: Extra pixels between lines (for multiline) + align: 'left', 'center', 'right' (for multiline) + + Returns: + TextStrip with pre-rendered text + """ + # Build cache key from all parameters + cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align) + if cache_key in _TEXT_STRIP_CACHE: + return _TEXT_STRIP_CACHE[cache_key] + + from PIL import Image, ImageDraw + + font = _load_font(font_name, font_size) + ascent, descent = font.getmetrics() + + # Default stroke fill to black + if stroke_fill is None: + stroke_fill = (0, 0, 0, 255) + elif len(stroke_fill) == 3: + stroke_fill = (*stroke_fill, 255) + + # Get text bbox (accounting for stroke) + temp = Image.new('RGBA', (1, 1)) + temp_draw = ImageDraw.Draw(temp) + + if multiline: + bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing, + stroke_width=stroke_width) + else: + bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width) + + # bbox is (left, top, right, bottom) relative to origin + x_min, y_min, x_max, y_max = bbox + + # Create image with padding (extra for stroke) + padding = 2 + stroke_width + img_width = max(int(x_max - x_min) + padding * 2, 1) + img_height = max(int(y_max - y_min) + padding * 2, 1) + + # Create RGBA image + img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Draw text at position that puts it in the image + draw_x = padding - x_min + draw_y = padding - y_min + + if multiline: + draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font, + spacing=line_spacing, align=align, + stroke_width=stroke_width, stroke_fill=stroke_fill) + else: + draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill) + + # Baseline is at y=0 in text coordinates, which is at draw_y in image + baseline_y = draw_y + + # Convert to numpy for pixel analysis + img_array = np.array(img, dtype=np.uint8) + + # Calculate anchor offsets + # For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching + h_anchor = anchor[0] if len(anchor) > 0 else 'l' + v_anchor = anchor[1] if len(anchor) > 1 else 'a' + + # Find actual pixel bounds (for middle anchor calculations) + alpha = img_array[:, :, 3] + nonzero_cols = np.where(alpha.max(axis=0) > 0)[0] + nonzero_rows = np.where(alpha.max(axis=1) > 0)[0] + + if len(nonzero_cols) > 0: + pixel_x_min = nonzero_cols.min() + pixel_x_max = nonzero_cols.max() + pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0 + else: + pixel_x_center = img_width / 2.0 + + if len(nonzero_rows) > 0: + pixel_y_min = nonzero_rows.min() + pixel_y_max = nonzero_rows.max() + pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0 + else: + pixel_y_center = img_height / 2.0 + + # Horizontal offset + text_width = x_max - x_min + if h_anchor == 'l': # left edge of text + anchor_x = float(draw_x) + elif h_anchor == 'm': # middle - use actual pixel center for perfect matching + anchor_x = pixel_x_center + elif h_anchor == 'r': # right edge of text + anchor_x = float(draw_x + text_width) + else: + anchor_x = float(draw_x) + + # Vertical offset + # PIL anchor positions are based on font metrics (ascent/descent): + # - 'a' (ascender): at the ascender line → draw_y in strip + # - 't' (top): at top of text bounding box → padding in strip + # - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender + # - 's' (baseline): at baseline = ascent below ascender + # - 'd' (descender): at descender line = ascent + descent below ascender + + if v_anchor == 'a': # ascender + anchor_y = float(draw_y) + elif v_anchor == 't': # top of bbox + anchor_y = float(padding) + elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation) + anchor_y = float(draw_y + (ascent + descent) / 2.0) + elif v_anchor == 's': # baseline + anchor_y = float(draw_y + ascent) + elif v_anchor == 'd': # descender + anchor_y = float(draw_y + ascent + descent) + else: + anchor_y = float(draw_y) # default to ascender + + strip = TextStrip( + text=text, + image=img_array, + width=img_width, + height=img_height, + baseline_y=baseline_y, + bearing_x=float(x_min), + anchor_x=anchor_x, + anchor_y=anchor_y, + stroke_width=stroke_width, + ) + + _TEXT_STRIP_CACHE[cache_key] = strip + return strip + + +# ============================================================================= +# Compile-time functions (called during S-expression compilation) +# ============================================================================= + +def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData: + """Get glyph data for a single character. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + return cache.get(char, cache.get(' ')) + + +def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list: + """Get glyph data for a string. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + space = cache.get(' ') + return [cache.get(c, space) for c in text] + + +def get_font_ascent(font_name: str = None, font_size: int = 32) -> float: + """Get font ascent. Compile-time.""" + _get_glyph_cache(font_name, font_size) # Ensure cache exists + return _METRICS_CACHE[(font_name, font_size)][0] + + +def get_font_descent(font_name: str = None, font_size: int = 32) -> float: + """Get font descent. Compile-time.""" + _get_glyph_cache(font_name, font_size) + return _METRICS_CACHE[(font_name, font_size)][1] + + +# ============================================================================= +# JAX Runtime Primitives +# ============================================================================= + +def place_glyph_jax( + frame: jnp.ndarray, + glyph_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + bearing_x: float, + bearing_y: float, + color: jnp.ndarray, # (3,) RGB 0-255 + opacity: float = 1.0, +) -> jnp.ndarray: + """ + Place a glyph onto a frame. This is the core JAX primitive. + + All positioning math can use traced values (x, y from audio, time, etc.) + The glyph_image is static (determined at compile time). + + Args: + frame: (H, W, 3) RGB frame + glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array) + x: X position of glyph origin (baseline point) + y: Y position of baseline + bearing_x: Left side bearing + bearing_y: Top bearing (from baseline to top) + color: RGB color array + opacity: Opacity 0-1 + + Returns: + Frame with glyph composited + """ + h, w = frame.shape[:2] + gh, gw = glyph_image.shape[:2] + + # Calculate destination position + # bearing_x: how far right of origin the glyph starts (can be negative) + # bearing_y: how far up from baseline the glyph extends + padding = 2 # Must match padding used in glyph creation + dst_x = x + bearing_x - padding + dst_y = y - bearing_y - padding + + # Extract glyph RGB and alpha + glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0 + # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255 + opacity_int = jnp.round(opacity * 255) + glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32) + glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply color tint (glyph is white, multiply by color) + color_normalized = color.astype(jnp.float32) / 255.0 + tinted = glyph_rgb * color_normalized + + from jax.lax import dynamic_update_slice + + # Use padded buffer to avoid XLA's dynamic_update_slice clamping + buf_h = h + 2 * gh + buf_w = w + 2 * gw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + dst_x_int = dst_x.astype(jnp.int32) + dst_y_int = dst_y.astype(jnp.int32) + place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32) + + rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0)) + alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0)) + + rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :] + alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :] + + # Alpha composite using PIL-compatible integer arithmetic + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_text_strip_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, +) -> jnp.ndarray: + """ + Place a pre-rendered text strip onto a frame. + + The strip was rendered at compile time with proper sub-pixel anti-aliasing. + This just composites it at the specified position. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + x: X position for anchor point + y: Y position for anchor point + baseline_y: Y position of baseline within the strip + bearing_x: Left side bearing + color: RGB color + opacity: Opacity 0-1 + anchor_x: X offset of anchor point within strip + anchor_y: Y offset of anchor point within strip + stroke_width: Stroke width used when rendering (affects padding) + + Returns: + Frame with text composited + """ + h, w = frame.shape[:2] + sh, sw = strip_image.shape[:2] + + # Calculate destination position + # Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame + # anchor_x/anchor_y already account for the anchor position within the strip + # Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding) + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + # Extract strip RGB and alpha + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255 + # Use jnp.round (banker's rounding) to match Python's round() used by PIL + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply color tint + color_normalized = color.astype(jnp.float32) / 255.0 + tinted = strip_rgb * color_normalized + + from jax.lax import dynamic_update_slice + + # Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior. + # XLA clamps indices so the update fits, which silently shifts the strip. + # By placing into a buffer padded by strip dimensions, then extracting the + # frame-sized region, we get correct clipping for both overflow and underflow. + buf_h = h + 2 * sh + buf_w = w + 2 * sw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + # Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer + place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32) + + rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0)) + alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0)) + + # Extract frame-sized region (sh, sw are compile-time constants from strip shape) + rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :] + alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :] + + # Alpha composite using PIL-compatible integer arithmetic: + # result = (src * alpha + dst * (255 - alpha) + 127) // 255 + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_glyph_simple( + frame: jnp.ndarray, + glyph: GlyphData, + x: float, + y: float, + color: tuple = (255, 255, 255), + opacity: float = 1.0, +) -> jnp.ndarray: + """ + Convenience wrapper that takes GlyphData directly. + Converts glyph image to JAX array. + + For S-expression use, prefer place_glyph_jax with pre-converted arrays. + """ + glyph_jax = jnp.asarray(glyph.image) + color_jax = jnp.array(color, dtype=jnp.float32) + + return place_glyph_jax( + frame, glyph_jax, x, y, + glyph.bearing_x, glyph.bearing_y, + color_jax, opacity + ) + + +# ============================================================================= +# Gradient Functions (compile-time: generate color maps from strip dimensions) +# ============================================================================= + +def make_linear_gradient( + width: int, + height: int, + color1: tuple, + color2: tuple, + angle: float = 0.0, +) -> np.ndarray: + """Create a linear gradient color map. + + Args: + width, height: Dimensions of the gradient (match strip dimensions) + color1: Start color (R, G, B) 0-255 + color2: End color (R, G, B) 0-255 + angle: Gradient angle in degrees (0 = left-to-right, 90 = top-to-bottom) + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + c1 = np.array(color1[:3], dtype=np.float32) / 255.0 + c2 = np.array(color2[:3], dtype=np.float32) / 255.0 + + # Create coordinate grid + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + # Normalize to [0, 1] + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + # Project onto gradient axis + theta = angle * np.pi / 180.0 + cos_t = np.cos(theta) + sin_t = np.sin(theta) + + # Project (nx - 0.5, ny - 0.5) onto direction vector, then remap to [0, 1] + proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t + # Normalize: max projection is 0.5*|cos|+0.5*|sin| = 0.5*(|cos|+|sin|) + max_proj = 0.5 * (abs(cos_t) + abs(sin_t)) + if max_proj > 0: + t = (proj / max_proj + 1.0) / 2.0 + else: + t = np.full_like(proj, 0.5) + t = np.clip(t, 0.0, 1.0) + + # Interpolate + gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None] + return gradient + + +def make_radial_gradient( + width: int, + height: int, + color1: tuple, + color2: tuple, + center_x: float = 0.5, + center_y: float = 0.5, +) -> np.ndarray: + """Create a radial gradient color map. + + Args: + width, height: Dimensions + color1: Inner color (R, G, B) + color2: Outer color (R, G, B) + center_x, center_y: Center position in [0, 1] (0.5 = center) + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + c1 = np.array(color1[:3], dtype=np.float32) / 255.0 + c2 = np.array(color2[:3], dtype=np.float32) / 255.0 + + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + # Normalize to [0, 1] + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + # Distance from center, normalized so corners are ~1.0 + dx = nx - center_x + dy = ny - center_y + # Max possible distance from center to a corner + max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2) + if max_dist > 0: + t = np.sqrt(dx**2 + dy**2) / max_dist + else: + t = np.zeros_like(dx) + t = np.clip(t, 0.0, 1.0) + + gradient = c1[None, None, :] * (1 - t[:, :, None]) + c2[None, None, :] * t[:, :, None] + return gradient + + +def make_multi_stop_gradient( + width: int, + height: int, + stops: list, + angle: float = 0.0, + radial: bool = False, + center_x: float = 0.5, + center_y: float = 0.5, +) -> np.ndarray: + """Create a multi-stop gradient color map. + + Args: + width, height: Dimensions + stops: List of (position, (R, G, B)) tuples, position in [0, 1] + angle: Gradient angle in degrees (for linear mode) + radial: If True, use radial gradient + center_x, center_y: Center for radial gradient + + Returns: + (height, width, 3) float32 array with values in [0, 1] + """ + if len(stops) < 2: + if len(stops) == 1: + c = np.array(stops[0][1][:3], dtype=np.float32) / 255.0 + return np.broadcast_to(c, (height, width, 3)).copy() + return np.zeros((height, width, 3), dtype=np.float32) + + # Sort stops by position + stops = sorted(stops, key=lambda s: s[0]) + + ys = np.arange(height, dtype=np.float32) + xs = np.arange(width, dtype=np.float32) + yy, xx = np.meshgrid(ys, xs, indexing='ij') + + nx = xx / max(width - 1, 1) + ny = yy / max(height - 1, 1) + + if radial: + dx = nx - center_x + dy = ny - center_y + max_dist = np.sqrt(max(center_x, 1 - center_x)**2 + max(center_y, 1 - center_y)**2) + t = np.sqrt(dx**2 + dy**2) / max(max_dist, 1e-6) + else: + theta = angle * np.pi / 180.0 + cos_t = np.cos(theta) + sin_t = np.sin(theta) + proj = (nx - 0.5) * cos_t + (ny - 0.5) * sin_t + max_proj = 0.5 * (abs(cos_t) + abs(sin_t)) + if max_proj > 0: + t = (proj / max_proj + 1.0) / 2.0 + else: + t = np.full_like(proj, 0.5) + + t = np.clip(t, 0.0, 1.0) + + # Build gradient from stops using piecewise linear interpolation + colors = np.array([np.array(s[1][:3], dtype=np.float32) / 255.0 for s in stops]) + positions = np.array([s[0] for s in stops], dtype=np.float32) + + # Start with first color + gradient = np.broadcast_to(colors[0], (height, width, 3)).copy() + + for i in range(len(stops) - 1): + p0, p1 = positions[i], positions[i + 1] + c0, c1 = colors[i], colors[i + 1] + + if p1 <= p0: + continue + + # Segment interpolation factor + seg_t = np.clip((t - p0) / (p1 - p0), 0.0, 1.0) + # Only apply where t >= p0 + mask = (t >= p0)[:, :, None] + seg_color = c0[None, None, :] * (1 - seg_t[:, :, None]) + c1[None, None, :] * seg_t[:, :, None] + gradient = np.where(mask, seg_color, gradient) + + return gradient + + +def _composite_strip_onto_frame( + frame: jnp.ndarray, + strip_rgb: jnp.ndarray, + strip_alpha: jnp.ndarray, + dst_x: jnp.ndarray, + dst_y: jnp.ndarray, + sh: int, + sw: int, +) -> jnp.ndarray: + """Core compositing: place tinted+alpha strip onto frame using padded buffer. + + Args: + frame: (H, W, 3) RGB uint8 + strip_rgb: (sh, sw, 3) float32 in [0, 1] - pre-tinted strip RGB + strip_alpha: (sh, sw, 1) float32 in [0, 1] - effective alpha + dst_x, dst_y: int32 destination position + sh, sw: strip dimensions (compile-time constants) + + Returns: + Composited frame (H, W, 3) uint8 + """ + h, w = frame.shape[:2] + + buf_h = h + 2 * sh + buf_w = w + 2 * sw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32) + + rgb_buf = lax.dynamic_update_slice(rgb_buf, strip_rgb, (place_y, place_x, 0)) + alpha_buf = lax.dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0)) + + rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :] + alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :] + + # PIL-compatible integer alpha blending + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_text_strip_gradient_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int, + bearing_x: float, + gradient_map: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, +) -> jnp.ndarray: + """Place text strip with gradient coloring instead of solid color. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + gradient_map: (sh, sw, 3) float32 color map in [0, 1] + Other args same as place_text_strip_jax + + Returns: + Composited frame + """ + sh, sw = strip_image.shape[:2] + + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + # Extract alpha with opacity + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply gradient instead of solid color + tinted = strip_rgb * gradient_map + + return _composite_strip_onto_frame(frame, tinted, strip_alpha, dst_x, dst_y, sh, sw) + + +# ============================================================================= +# Strip Rotation (RGBA bilinear interpolation) +# ============================================================================= + +def _sample_rgba(strip, x, y): + """Bilinear sample all 4 RGBA channels from a strip. + + Args: + strip: (H, W, 4) RGBA float32 + x, y: coordinate arrays (flattened) + + Returns: + (r, g, b, a) each same shape as x + """ + h, w = strip.shape[:2] + + x0 = jnp.floor(x).astype(jnp.int32) + y0 = jnp.floor(y).astype(jnp.int32) + x1 = x0 + 1 + y1 = y0 + 1 + + fx = x - x0.astype(jnp.float32) + fy = y - y0.astype(jnp.float32) + + valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) + valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) + valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) + valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) + + x0_safe = jnp.clip(x0, 0, w - 1) + x1_safe = jnp.clip(x1, 0, w - 1) + y0_safe = jnp.clip(y0, 0, h - 1) + y1_safe = jnp.clip(y1, 0, h - 1) + + channels = [] + for c in range(4): + c00 = jnp.where(valid00, strip[y0_safe, x0_safe, c], 0.0) + c10 = jnp.where(valid10, strip[y0_safe, x1_safe, c], 0.0) + c01 = jnp.where(valid01, strip[y1_safe, x0_safe, c], 0.0) + c11 = jnp.where(valid11, strip[y1_safe, x1_safe, c], 0.0) + + val = (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + channels.append(val) + + return channels[0], channels[1], channels[2], channels[3] + + +def rotate_strip_jax( + strip_image: jnp.ndarray, + angle: float, +) -> jnp.ndarray: + """Rotate an RGBA strip by angle (degrees), counter-clockwise. + + Output buffer is sized to contain the full rotated strip. + The output size is ceil(sqrt(w^2 + h^2)), computed at trace time + from the strip's static shape. + + Args: + strip_image: (H, W, 4) RGBA uint8 + angle: Rotation angle in degrees + + Returns: + (out_h, out_w, 4) RGBA uint8 - rotated strip + """ + sh, sw = strip_image.shape[:2] + + # Output size: diagonal of original strip (compile-time constant). + # Ensure output dimensions have same parity as source so that the + # center offset (out - src) / 2 is always an integer. Otherwise + # identity rotations would place content at half-pixel offsets. + diag = int(math.ceil(math.sqrt(sw * sw + sh * sh))) + out_w = diag + ((diag % 2) != (sw % 2)) + out_h = diag + ((diag % 2) != (sh % 2)) + + # Center of input strip and output buffer (pixel-center convention). + # Using (dim-1)/2 ensures integer coords map to integer coords for + # identity rotation regardless of even/odd dimension parity. + src_cx = (sw - 1) / 2.0 + src_cy = (sh - 1) / 2.0 + dst_cx = (out_w - 1) / 2.0 + dst_cy = (out_h - 1) / 2.0 + + # Convert to radians and snap trig values near 0/±1 to exact values. + # Without snapping, e.g. sin(360°) ≈ 1.7e-7 instead of 0, causing + # bilinear blending at pixel edges and 1-value differences. + theta = angle * jnp.pi / 180.0 + cos_t = jnp.cos(theta) + sin_t = jnp.sin(theta) + cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t) + cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t) + cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t) + sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t) + + # Create output coordinate grid + y_coords = jnp.repeat(jnp.arange(out_h), out_w).reshape(out_h, out_w) + x_coords = jnp.tile(jnp.arange(out_w), out_h).reshape(out_h, out_w) + + # Inverse rotation: map output coords to source coords + x_centered = x_coords.astype(jnp.float32) - dst_cx + y_centered = y_coords.astype(jnp.float32) - dst_cy + + src_x = cos_t * x_centered - sin_t * y_centered + src_cx + src_y = sin_t * x_centered + cos_t * y_centered + src_cy + + # Sample all 4 channels + strip_f = strip_image.astype(jnp.float32) + r, g, b, a = _sample_rgba(strip_f, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + jnp.clip(a, 0, 255).reshape(out_h, out_w).astype(jnp.uint8), + ], axis=2) + + +# ============================================================================= +# Shadow Compositing +# ============================================================================= + +def _blur_alpha_channel(alpha: jnp.ndarray, radius: int) -> jnp.ndarray: + """Blur a single-channel alpha array using Gaussian convolution. + + Args: + alpha: (H, W) float32 alpha channel + radius: Blur radius (compile-time constant) + + Returns: + (H, W) float32 blurred alpha + """ + size = radius * 2 + 1 + x = jnp.arange(size, dtype=jnp.float32) - radius + sigma = max(radius / 2.0, 0.5) + gaussian_1d = jnp.exp(-x**2 / (2 * sigma**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + # Use JAX conv with SAME padding + h, w = alpha.shape + data_4d = alpha.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(size, size, 1, 1) + + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + ) + return result.reshape(h, w) + + +def place_text_strip_shadow_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + shadow_offset_x: float = 3.0, + shadow_offset_y: float = 3.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.5, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Place text strip with a drop shadow. + + Composites the strip twice: first as shadow (offset, colored, optionally blurred), + then the text itself on top. + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + shadow_offset_x/y: Shadow offset in pixels + shadow_color: (3,) RGB color for shadow (default black) + shadow_opacity: Shadow opacity 0-1 + shadow_blur_radius: Gaussian blur radius for shadow (0 = sharp, compile-time) + Other args same as place_text_strip_jax + + Returns: + Composited frame + """ + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Shadow pass --- + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + # Shadow alpha from strip alpha + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + # Blur the alpha channel for soft shadow + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] # (sh, sw, 1) + + # Shadow RGB: solid shadow color + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Text pass --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + text_alpha = jnp.floor(strip_a_raw[:, :, None] * opacity_int / 255.0 + 0.5) / 255.0 + + color_norm = color.astype(jnp.float32) / 255.0 + tinted = strip_rgb * color_norm + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw) + + return frame + + +# ============================================================================= +# Combined FX Pipeline +# ============================================================================= + +def place_text_strip_fx_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, + x: float, + y: float, + baseline_y: int = 0, + bearing_x: float = 0.0, + color: jnp.ndarray = None, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, + gradient_map: jnp.ndarray = None, + angle: float = 0.0, + shadow_offset_x: float = 0.0, + shadow_offset_y: float = 0.0, + shadow_color: jnp.ndarray = None, + shadow_opacity: float = 0.0, + shadow_blur_radius: int = 0, +) -> jnp.ndarray: + """Combined text placement with gradient, rotation, and shadow. + + Pipeline order: + 1. Build color layer (solid color or gradient map) + 2. Rotate strip + color layer if angle != 0 + 3. Composite shadow if shadow_opacity > 0 + 4. Composite text + + Note: angle and shadow_blur_radius should be compile-time constants + for optimal JIT performance (they affect buffer shapes/kernel sizes). + + Args: + frame: (H, W, 3) RGB frame + strip_image: (sh, sw, 4) RGBA text strip + x, y: Anchor point position + color: (3,) RGB color (ignored if gradient_map provided) + opacity: Text opacity + gradient_map: (sh, sw, 3) float32 color map in [0,1], or None for solid color + angle: Rotation angle in degrees (0 = no rotation) + shadow_offset_x/y: Shadow offset + shadow_color: (3,) RGB shadow color + shadow_opacity: Shadow opacity (0 = no shadow) + shadow_blur_radius: Shadow blur radius + + Returns: + Composited frame + """ + if color is None: + color = jnp.array([255, 255, 255], dtype=jnp.float32) + if shadow_color is None: + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + sh, sw = strip_image.shape[:2] + + # --- Step 1: Build color layer --- + if gradient_map is not None: + color_layer = gradient_map # (sh, sw, 3) float32 [0, 1] + else: + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (sh, sw, 3)) + + # --- Step 2: Rotate if needed --- + # angle is expected to be a compile-time constant or static value + # We check at Python level to avoid tracing issues with dynamic shapes + use_rotation = not isinstance(angle, (int, float)) or angle != 0.0 + + if use_rotation: + # Rotate the strip + rotated_strip = rotate_strip_jax(strip_image, angle) + rh, rw = rotated_strip.shape[:2] + + # Rotate the color layer by building a 4-channel color+dummy image + # Actually, just re-create color layer at rotated size + if gradient_map is not None: + # Rotate gradient map: pack into 3-channel "image", rotate via sampling + grad_uint8 = jnp.clip(gradient_map * 255, 0, 255).astype(jnp.uint8) + # Create RGBA from gradient (alpha=255 everywhere) + grad_rgba = jnp.concatenate([grad_uint8, jnp.full((sh, sw, 1), 255, dtype=jnp.uint8)], axis=2) + rotated_grad_rgba = rotate_strip_jax(grad_rgba, angle) + color_layer = rotated_grad_rgba[:, :, :3].astype(jnp.float32) / 255.0 + else: + # Solid color: just broadcast to rotated size + color_norm = color.astype(jnp.float32) / 255.0 + color_layer = jnp.broadcast_to(color_norm[None, None, :], (rh, rw, 3)) + + # Update anchor point for rotation (pixel-center convention) + # Rotate the anchor offset around the strip center + theta = angle * jnp.pi / 180.0 + cos_t = jnp.cos(theta) + sin_t = jnp.sin(theta) + cos_t = jnp.where(jnp.abs(cos_t) < 1e-6, 0.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t) < 1e-6, 0.0, sin_t) + cos_t = jnp.where(jnp.abs(cos_t - 1.0) < 1e-6, 1.0, cos_t) + cos_t = jnp.where(jnp.abs(cos_t + 1.0) < 1e-6, -1.0, cos_t) + sin_t = jnp.where(jnp.abs(sin_t - 1.0) < 1e-6, 1.0, sin_t) + sin_t = jnp.where(jnp.abs(sin_t + 1.0) < 1e-6, -1.0, sin_t) + + # Original anchor relative to strip pixel center + src_cx = (sw - 1) / 2.0 + src_cy = (sh - 1) / 2.0 + dst_cx = (rw - 1) / 2.0 + dst_cy = (rh - 1) / 2.0 + + ax_rel = anchor_x - src_cx + ay_rel = anchor_y - src_cy + + # Rotate anchor point (forward rotation, not inverse) + new_ax = -sin_t * ay_rel + cos_t * ax_rel + dst_cx + new_ay = cos_t * ay_rel + sin_t * ax_rel + dst_cy + + strip_image = rotated_strip + anchor_x = new_ax + anchor_y = new_ay + sh, sw = rh, rw + + # --- Step 3: Shadow --- + has_shadow = not isinstance(shadow_opacity, (int, float)) or shadow_opacity > 0 + if has_shadow: + shadow_dst_x = jnp.floor(x - anchor_x + shadow_offset_x + 0.5).astype(jnp.int32) + shadow_dst_y = jnp.floor(y - anchor_y + shadow_offset_y + 0.5).astype(jnp.int32) + + shadow_opacity_int = jnp.round(shadow_opacity * 255) + strip_a_raw = strip_image[:, :, 3].astype(jnp.float32) + + if shadow_blur_radius > 0: + blurred_alpha = _blur_alpha_channel(strip_a_raw / 255.0, shadow_blur_radius) + shadow_alpha = jnp.floor(blurred_alpha * shadow_opacity_int + 0.5) / 255.0 + else: + shadow_alpha = jnp.floor(strip_a_raw * shadow_opacity_int / 255.0 + 0.5) / 255.0 + shadow_alpha = shadow_alpha[:, :, None] + + shadow_color_norm = shadow_color.astype(jnp.float32) / 255.0 + shadow_rgb = jnp.broadcast_to(shadow_color_norm[None, None, :], (sh, sw, 3)) + + frame = _composite_strip_onto_frame(frame, shadow_rgb, shadow_alpha, shadow_dst_x, shadow_dst_y, sh, sw) + + # --- Step 4: Composite text --- + dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32) + dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32) + + strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0 + opacity_int = jnp.round(opacity * 255) + strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32) + text_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + tinted = strip_rgb * color_layer + + frame = _composite_strip_onto_frame(frame, tinted, text_alpha, dst_x, dst_y, sh, sw) + + return frame + + +# ============================================================================= +# S-Expression Primitive Bindings +# ============================================================================= + +def bind_typography_primitives(env: dict) -> dict: + """ + Add typography primitives to an S-expression environment. + + Primitives added: + (text-glyphs text font-size) -> list of glyph data + (glyph-image g) -> JAX array (H, W, 4) + (glyph-advance g) -> float + (glyph-bearing-x g) -> float + (glyph-bearing-y g) -> float + (glyph-width g) -> int + (glyph-height g) -> int + (font-ascent font-size) -> float + (font-descent font-size) -> float + (place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame + """ + + def prim_text_glyphs(text, font_size=32, font_name=None): + """Get list of glyph data for text. Compile-time.""" + return get_glyphs(str(text), font_name, int(font_size)) + + def prim_glyph_image(glyph): + """Get glyph image as JAX array.""" + return jnp.asarray(glyph.image) + + def prim_glyph_advance(glyph): + """Get glyph advance width.""" + return glyph.advance + + def prim_glyph_bearing_x(glyph): + """Get glyph left side bearing.""" + return glyph.bearing_x + + def prim_glyph_bearing_y(glyph): + """Get glyph top bearing.""" + return glyph.bearing_y + + def prim_glyph_width(glyph): + """Get glyph image width.""" + return glyph.width + + def prim_glyph_height(glyph): + """Get glyph image height.""" + return glyph.height + + def prim_font_ascent(font_size=32, font_name=None): + """Get font ascent.""" + return get_font_ascent(font_name, int(font_size)) + + def prim_font_descent(font_size=32, font_name=None): + """Get font descent.""" + return get_font_descent(font_name, int(font_size)) + + def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y, + color=(255, 255, 255), opacity=1.0): + """Place glyph on frame. Runtime JAX operation.""" + color_arr = jnp.array(color, dtype=jnp.float32) + return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y, + color_arr, opacity) + + def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None): + """Get kerning adjustment between two glyphs. Compile-time. + + Returns adjustment to add to glyph1's advance when glyph2 follows. + Typically negative (characters move closer). + + Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size)) + """ + return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size)) + + def prim_char_kerning(char1, char2, font_size=32, font_name=None): + """Get kerning adjustment between two characters. Compile-time.""" + return get_kerning(str(char1), str(char2), font_name, int(font_size)) + + # TextStrip primitives for pre-rendered text with proper anti-aliasing + def prim_render_text_strip(text, font_size=32, font_name=None): + """Render text to a strip at compile time. Perfect anti-aliasing.""" + return render_text_strip(str(text), font_name, int(font_size)) + + def prim_render_text_strip_styled( + text, font_size=32, font_name=None, + stroke_width=0, stroke_fill=None, + anchor="la", multiline=False, line_spacing=4, align="left" + ): + """Render styled text to a strip. Supports stroke, anchors, multiline. + + Args: + text: Text to render + font_size: Size in pixels + font_name: Path to font file + stroke_width: Outline width (0 = no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A) + anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender) + multiline: If True, handle newlines + line_spacing: Extra pixels between lines + align: "left", "center", "right" for multiline + """ + return render_text_strip( + str(text), font_name, int(font_size), + stroke_width=int(stroke_width), + stroke_fill=stroke_fill, + anchor=str(anchor), + multiline=bool(multiline), + line_spacing=int(line_spacing), + align=str(align), + ) + + def prim_text_strip_image(strip): + """Get text strip image as JAX array.""" + return jnp.asarray(strip.image) + + def prim_text_strip_width(strip): + """Get text strip width.""" + return strip.width + + def prim_text_strip_height(strip): + """Get text strip height.""" + return strip.height + + def prim_text_strip_baseline_y(strip): + """Get text strip baseline Y position.""" + return strip.baseline_y + + def prim_text_strip_bearing_x(strip): + """Get text strip left bearing.""" + return strip.bearing_x + + def prim_text_strip_anchor_x(strip): + """Get text strip anchor X offset.""" + return strip.anchor_x + + def prim_text_strip_anchor_y(strip): + """Get text strip anchor Y offset.""" + return strip.anchor_y + + def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0): + """Place pre-rendered text strip on frame. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # --- Gradient primitives --- + + def prim_linear_gradient(strip, color1, color2, angle=0.0): + """Create linear gradient color map for a text strip. Compile-time.""" + grad = make_linear_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(angle)) + return jnp.asarray(grad) + + def prim_radial_gradient(strip, color1, color2, center_x=0.5, center_y=0.5): + """Create radial gradient color map for a text strip. Compile-time.""" + grad = make_radial_gradient(strip.width, strip.height, + tuple(int(c) for c in color1), + tuple(int(c) for c in color2), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_multi_stop_gradient(strip, stops, angle=0.0, radial=False, + center_x=0.5, center_y=0.5): + """Create multi-stop gradient for a text strip. Compile-time. + + stops: list of (position, (R, G, B)) tuples + """ + parsed_stops = [] + for s in stops: + pos = float(s[0]) + color_tuple = tuple(int(c) for c in s[1]) + parsed_stops.append((pos, color_tuple)) + grad = make_multi_stop_gradient(strip.width, strip.height, + parsed_stops, float(angle), + bool(radial), + float(center_x), float(center_y)) + return jnp.asarray(grad) + + def prim_place_text_strip_gradient(frame, strip, x, y, gradient_map, opacity=1.0): + """Place text strip with gradient coloring. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + return place_text_strip_gradient_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + gradient_map, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # --- Rotation primitive --- + + def prim_place_text_strip_rotated(frame, strip, x, y, color=(255, 255, 255), + opacity=1.0, angle=0.0): + """Place text strip with rotation. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + angle=float(angle), + ) + + # --- Shadow primitive --- + + def prim_place_text_strip_shadow(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + shadow_offset_x=3.0, shadow_offset_y=3.0, + shadow_color=(0, 0, 0), shadow_opacity=0.5, + shadow_blur_radius=0): + """Place text strip with shadow. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_shadow_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur_radius), + ) + + # --- Combined FX primitive --- + + def prim_place_text_strip_fx(frame, strip, x, y, + color=(255, 255, 255), opacity=1.0, + gradient=None, angle=0.0, + shadow_offset_x=0.0, shadow_offset_y=0.0, + shadow_color=(0, 0, 0), shadow_opacity=0.0, + shadow_blur=0): + """Place text strip with all effects. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + shadow_color_arr = jnp.array(shadow_color, dtype=jnp.float32) + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color_arr, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width, + gradient_map=gradient, + angle=float(angle), + shadow_offset_x=float(shadow_offset_x), + shadow_offset_y=float(shadow_offset_y), + shadow_color=shadow_color_arr, + shadow_opacity=float(shadow_opacity), + shadow_blur_radius=int(shadow_blur), + ) + + # Add to environment + env.update({ + # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects) + 'text-glyphs': prim_text_glyphs, + 'glyph-image': prim_glyph_image, + 'glyph-advance': prim_glyph_advance, + 'glyph-bearing-x': prim_glyph_bearing_x, + 'glyph-bearing-y': prim_glyph_bearing_y, + 'glyph-width': prim_glyph_width, + 'glyph-height': prim_glyph_height, + 'glyph-kerning': prim_glyph_kerning, + 'char-kerning': prim_char_kerning, + 'font-ascent': prim_font_ascent, + 'font-descent': prim_font_descent, + 'place-glyph': prim_place_glyph, + # TextStrip primitives (for pixel-perfect static text) + 'render-text-strip': prim_render_text_strip, + 'render-text-strip-styled': prim_render_text_strip_styled, + 'text-strip-image': prim_text_strip_image, + 'text-strip-width': prim_text_strip_width, + 'text-strip-height': prim_text_strip_height, + 'text-strip-baseline-y': prim_text_strip_baseline_y, + 'text-strip-bearing-x': prim_text_strip_bearing_x, + 'text-strip-anchor-x': prim_text_strip_anchor_x, + 'text-strip-anchor-y': prim_text_strip_anchor_y, + 'place-text-strip': prim_place_text_strip, + # Gradient primitives + 'linear-gradient': prim_linear_gradient, + 'radial-gradient': prim_radial_gradient, + 'multi-stop-gradient': prim_multi_stop_gradient, + 'place-text-strip-gradient': prim_place_text_strip_gradient, + # Rotation + 'place-text-strip-rotated': prim_place_text_strip_rotated, + # Shadow + 'place-text-strip-shadow': prim_place_text_strip_shadow, + # Combined FX + 'place-text-strip-fx': prim_place_text_strip_fx, + }) + + return env + + +# ============================================================================= +# Example: Render text using primitives (for testing) +# ============================================================================= + +def render_text_primitives( + frame: jnp.ndarray, + text: str, + x: float, + y: float, + font_size: int = 32, + color: tuple = (255, 255, 255), + opacity: float = 1.0, + use_kerning: bool = True, +) -> jnp.ndarray: + """ + Render text using the primitives. + This is what an S-expression would compile to. + + Args: + use_kerning: If True, apply kerning adjustments between characters + """ + glyphs = get_glyphs(text, None, font_size) + color_arr = jnp.array(color, dtype=jnp.float32) + + cursor = x + for i, g in enumerate(glyphs): + glyph_jax = jnp.asarray(g.image) + frame = place_glyph_jax( + frame, glyph_jax, cursor, y, + g.bearing_x, g.bearing_y, + color_arr, opacity + ) + # Advance cursor with optional kerning + advance = g.advance + if use_kerning and i + 1 < len(glyphs): + advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size) + cursor = cursor + advance + + return frame diff --git a/streaming/jit_compiler.py b/streaming/jit_compiler.py new file mode 100644 index 0000000..bb8c97c --- /dev/null +++ b/streaming/jit_compiler.py @@ -0,0 +1,531 @@ +""" +JIT Compiler for sexp frame pipelines. + +Compiles sexp expressions to fused CUDA kernels for maximum performance. +""" + +import cupy as cp +import numpy as np +from typing import Dict, List, Any, Optional, Tuple, Callable +import hashlib +import sys + +# Cache for compiled kernels +_KERNEL_CACHE: Dict[str, Callable] = {} + + +def _generate_kernel_key(ops: List[Tuple]) -> str: + """Generate cache key for operation sequence.""" + return hashlib.md5(str(ops).encode()).hexdigest() + + +# ============================================================================= +# CUDA Kernel Templates +# ============================================================================= + +AFFINE_WARP_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void affine_warp( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float m00, float m01, float m02, + float m10, float m11, float m12 +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Apply inverse affine transform + float src_x = m00 * x + m01 * y + m02; + float src_y = m10 * x + m11 * y + m12; + + int dst_idx = (y * width + x) * channels; + + // Bounds check + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + int x1 = x0 + 1; + int y1 = y0 + 1; + + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + x1) * channels + c]; + float v01 = src[(y1 * width + x0) * channels + c]; + float v11 = src[(y1 * width + x1) * channels + c]; + + float v0 = v00 * (1 - fx) + v10 * fx; + float v1 = v01 * (1 - fx) + v11 * fx; + float v = v0 * (1 - fy) + v1 * fy; + + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'affine_warp') + + +BLEND_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void blend( + const unsigned char* src1, + const unsigned char* src2, + unsigned char* dst, + int size, + float alpha +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + float v = src1[idx] * (1.0f - alpha) + src2[idx] * alpha; + dst[idx] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); +} +''', 'blend') + + +BRIGHTNESS_CONTRAST_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void brightness_contrast( + const unsigned char* src, + unsigned char* dst, + int size, + float brightness, + float contrast +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + float v = src[idx]; + v = (v - 128.0f) * contrast + 128.0f + brightness; + dst[idx] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); +} +''', 'brightness_contrast') + + +HUE_SHIFT_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void hue_shift( + const unsigned char* src, + unsigned char* dst, + int width, int height, + float hue_shift +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + int idx = (y * width + x) * 3; + + float r = src[idx] / 255.0f; + float g = src[idx + 1] / 255.0f; + float b = src[idx + 2] / 255.0f; + + // RGB to HSV + float max_c = fmaxf(r, fmaxf(g, b)); + float min_c = fminf(r, fminf(g, b)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) { + s = delta / max_c; + if (r >= max_c) h = (g - b) / delta; + else if (g >= max_c) h = 2.0f + (b - r) / delta; + else h = 4.0f + (r - g) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + } + + // Apply hue shift + h = fmodf(h + hue_shift + 360.0f, 360.0f); + + // HSV to RGB + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) { r2 = c; g2 = x_val; b2 = 0; } + else if (h < 120) { r2 = x_val; g2 = c; b2 = 0; } + else if (h < 180) { r2 = 0; g2 = c; b2 = x_val; } + else if (h < 240) { r2 = 0; g2 = x_val; b2 = c; } + else if (h < 300) { r2 = x_val; g2 = 0; b2 = c; } + else { r2 = c; g2 = 0; b2 = x_val; } + + dst[idx] = (unsigned char)((r2 + m) * 255); + dst[idx + 1] = (unsigned char)((g2 + m) * 255); + dst[idx + 2] = (unsigned char)((b2 + m) * 255); +} +''', 'hue_shift') + + +INVERT_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void invert( + const unsigned char* src, + unsigned char* dst, + int size +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + dst[idx] = 255 - src[idx]; +} +''', 'invert') + + +ZOOM_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void zoom( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float zoom_factor, + float cx, float cy +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Map to source coordinates (zoom from center) + float src_x = (x - cx) / zoom_factor + cx; + float src_y = (y - cy) / zoom_factor + cy; + + int dst_idx = (y * width + x) * channels; + + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = 0; + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + (x0+1)) * channels + c]; + float v01 = src[((y0+1) * width + x0) * channels + c]; + float v11 = src[((y0+1) * width + (x0+1)) * channels + c]; + + float v = v00*(1-fx)*(1-fy) + v10*fx*(1-fy) + v01*(1-fx)*fy + v11*fx*fy; + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'zoom') + + +RIPPLE_KERNEL = cp.RawKernel(r''' +extern "C" __global__ +void ripple( + const unsigned char* src, + unsigned char* dst, + int width, int height, int channels, + float cx, float cy, + float amplitude, float frequency, float decay, float phase +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + float dx = x - cx; + float dy = y - cy; + float dist = sqrtf(dx * dx + dy * dy); + + // Ripple displacement + float wave = sinf(dist * frequency * 0.1f + phase); + float amp = amplitude * expf(-dist * decay * 0.01f); + + float src_x = x + dx / (dist + 0.001f) * wave * amp; + float src_y = y + dy / (dist + 0.001f) * wave * amp; + + int dst_idx = (y * width + x) * channels; + + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + for (int c = 0; c < channels; c++) { + dst[dst_idx + c] = src[dst_idx + c]; // Keep original on boundary + } + return; + } + + // Bilinear interpolation + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + for (int c = 0; c < channels; c++) { + float v00 = src[(y0 * width + x0) * channels + c]; + float v10 = src[(y0 * width + (x0+1)) * channels + c]; + float v01 = src[((y0+1) * width + x0) * channels + c]; + float v11 = src[((y0+1) * width + (x0+1)) * channels + c]; + + float v = v00*(1-fx)*(1-fy) + v10*fx*(1-fy) + v01*(1-fx)*fy + v11*fx*fy; + dst[dst_idx + c] = (unsigned char)(v < 0 ? 0 : (v > 255 ? 255 : v)); + } +} +''', 'ripple') + + +# ============================================================================= +# Fast GPU Operations +# ============================================================================= + +class FastGPUOps: + """Optimized GPU operations using CUDA kernels.""" + + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.channels = 3 + + # Pre-allocate work buffers + self._buf1 = cp.zeros((height, width, 3), dtype=cp.uint8) + self._buf2 = cp.zeros((height, width, 3), dtype=cp.uint8) + self._current_buf = 0 + + # Grid/block config + self._block_2d = (16, 16) + self._grid_2d = ((width + 15) // 16, (height + 15) // 16) + self._block_1d = 256 + self._grid_1d = (width * height * 3 + 255) // 256 + + def _get_buffers(self): + """Get source and destination buffers (ping-pong).""" + if self._current_buf == 0: + return self._buf1, self._buf2 + return self._buf2, self._buf1 + + def _swap_buffers(self): + """Swap ping-pong buffers.""" + self._current_buf = 1 - self._current_buf + + def set_input(self, frame: cp.ndarray): + """Set input frame.""" + if self._current_buf == 0: + cp.copyto(self._buf1, frame) + else: + cp.copyto(self._buf2, frame) + + def get_output(self) -> cp.ndarray: + """Get current output buffer.""" + if self._current_buf == 0: + return self._buf1 + return self._buf2 + + def rotate(self, angle: float, cx: float = None, cy: float = None): + """Fast GPU rotation.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + # Compute inverse rotation matrix + import math + rad = math.radians(-angle) # Negative for inverse + cos_a = math.cos(rad) + sin_a = math.sin(rad) + + # Inverse affine matrix (rotate around center) + m00 = cos_a + m01 = -sin_a + m02 = cx - cos_a * cx + sin_a * cy + m10 = sin_a + m11 = cos_a + m12 = cy - sin_a * cx - cos_a * cy + + AFFINE_WARP_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(m00), np.float32(m01), np.float32(m02), + np.float32(m10), np.float32(m11), np.float32(m12)) + ) + self._swap_buffers() + + def zoom(self, factor: float, cx: float = None, cy: float = None): + """Fast GPU zoom.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + ZOOM_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(factor), np.float32(cx), np.float32(cy)) + ) + self._swap_buffers() + + def blend(self, other: cp.ndarray, alpha: float): + """Fast GPU blend.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BLEND_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), other.ravel(), dst.ravel(), size, np.float32(alpha)) + ) + self._swap_buffers() + + def brightness(self, factor: float): + """Fast GPU brightness adjustment.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BRIGHTNESS_CONTRAST_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size, np.float32((factor - 1) * 128), np.float32(1.0)) + ) + self._swap_buffers() + + def contrast(self, factor: float): + """Fast GPU contrast adjustment.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + BRIGHTNESS_CONTRAST_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size, np.float32(0), np.float32(factor)) + ) + self._swap_buffers() + + def hue_shift(self, degrees: float): + """Fast GPU hue shift.""" + src, dst = self._get_buffers() + + HUE_SHIFT_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, np.float32(degrees)) + ) + self._swap_buffers() + + def invert(self): + """Fast GPU invert.""" + src, dst = self._get_buffers() + size = self.width * self.height * self.channels + + INVERT_KERNEL( + (self._grid_1d,), (self._block_1d,), + (src.ravel(), dst.ravel(), size) + ) + self._swap_buffers() + + def ripple(self, amplitude: float, cx: float = None, cy: float = None, + frequency: float = 8, decay: float = 2, phase: float = 0): + """Fast GPU ripple effect.""" + if cx is None: + cx = self.width / 2 + if cy is None: + cy = self.height / 2 + + src, dst = self._get_buffers() + + RIPPLE_KERNEL( + self._grid_2d, self._block_2d, + (src, dst, self.width, self.height, self.channels, + np.float32(cx), np.float32(cy), + np.float32(amplitude), np.float32(frequency), + np.float32(decay), np.float32(phase)) + ) + self._swap_buffers() + + +# Global fast ops instance (created per resolution) +_FAST_OPS: Dict[Tuple[int, int], FastGPUOps] = {} + + +def get_fast_ops(width: int, height: int) -> FastGPUOps: + """Get or create FastGPUOps for given resolution.""" + key = (width, height) + if key not in _FAST_OPS: + _FAST_OPS[key] = FastGPUOps(width, height) + return _FAST_OPS[key] + + +# ============================================================================= +# Fast effect functions (drop-in replacements) +# ============================================================================= + +def fast_rotate(frame: cp.ndarray, angle: float, **kwargs) -> cp.ndarray: + """Fast GPU rotation.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.rotate(angle, kwargs.get('cx'), kwargs.get('cy')) + return ops.get_output().copy() + + +def fast_zoom(frame: cp.ndarray, factor: float, **kwargs) -> cp.ndarray: + """Fast GPU zoom.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.zoom(factor, kwargs.get('cx'), kwargs.get('cy')) + return ops.get_output().copy() + + +def fast_blend(frame1: cp.ndarray, frame2: cp.ndarray, alpha: float) -> cp.ndarray: + """Fast GPU blend.""" + h, w = frame1.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame1) + ops.blend(frame2, alpha) + return ops.get_output().copy() + + +def fast_hue_shift(frame: cp.ndarray, degrees: float) -> cp.ndarray: + """Fast GPU hue shift.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.hue_shift(degrees) + return ops.get_output().copy() + + +def fast_invert(frame: cp.ndarray) -> cp.ndarray: + """Fast GPU invert.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.invert() + return ops.get_output().copy() + + +def fast_ripple(frame: cp.ndarray, amplitude: float, **kwargs) -> cp.ndarray: + """Fast GPU ripple.""" + h, w = frame.shape[:2] + ops = get_fast_ops(w, h) + ops.set_input(frame) + ops.ripple( + amplitude, + kwargs.get('center_x', w/2), + kwargs.get('center_y', h/2), + kwargs.get('frequency', 8), + kwargs.get('decay', 2), + kwargs.get('speed', 0) * kwargs.get('t', 0) # phase from speed*time + ) + return ops.get_output().copy() + + +print("[jit_compiler] CUDA kernels loaded", file=sys.stderr) diff --git a/streaming/multi_res_output.py b/streaming/multi_res_output.py new file mode 100644 index 0000000..40c661a --- /dev/null +++ b/streaming/multi_res_output.py @@ -0,0 +1,509 @@ +""" +Multi-Resolution HLS Output with IPFS Storage. + +Renders video at multiple quality levels simultaneously: +- Original resolution (from recipe) +- 720p (streaming quality) +- 360p (mobile/low bandwidth) + +All segments stored on IPFS. Master playlist enables adaptive bitrate streaming. +""" + +import os +import sys +import subprocess +import threading +import queue +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union +from dataclasses import dataclass, field + +import numpy as np + +# Try GPU imports +try: + import cupy as cp + GPU_AVAILABLE = True +except ImportError: + cp = None + GPU_AVAILABLE = False + + +@dataclass +class QualityLevel: + """Configuration for a quality level.""" + name: str + width: int + height: int + bitrate: int # kbps + segment_cids: Dict[int, str] = field(default_factory=dict) + playlist_cid: Optional[str] = None + + +class MultiResolutionHLSOutput: + """ + GPU-accelerated multi-resolution HLS output with IPFS storage. + + Encodes video at multiple quality levels simultaneously using NVENC. + Segments are uploaded to IPFS as they're created. + Generates adaptive bitrate master playlist. + """ + + def __init__( + self, + output_dir: str, + source_size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + audio_source: str = None, + resume_from: Optional[Dict] = None, + ): + """Initialize multi-resolution HLS output. + + Args: + output_dir: Directory for HLS output files + source_size: (width, height) of source frames + fps: Frames per second + segment_duration: Duration of each HLS segment in seconds + ipfs_gateway: IPFS gateway URL for playlist URLs + on_playlist_update: Callback when playlists are updated + audio_source: Optional audio file to mux with video + resume_from: Optional dict to resume from checkpoint with keys: + - segment_cids: Dict of quality -> {seg_num: cid} + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.source_width, self.source_height = source_size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._on_playlist_update = on_playlist_update + self.audio_source = audio_source + self._is_open = True + self._frame_count = 0 + + # Define quality levels + self.qualities: Dict[str, QualityLevel] = {} + self._setup_quality_levels() + + # Restore segment CIDs if resuming (don't re-upload existing segments) + if resume_from and resume_from.get('segment_cids'): + for name, cids in resume_from['segment_cids'].items(): + if name in self.qualities: + self.qualities[name].segment_cids = dict(cids) + print(f"[MultiResHLS] Restored {len(cids)} segment CIDs for {name}", file=sys.stderr) + + # IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Upload queue and thread + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Track master playlist + self._master_playlist_cid = None + + # Setup encoders + self._setup_encoders() + + print(f"[MultiResHLS] Initialized {self.source_width}x{self.source_height} @ {fps}fps", file=sys.stderr) + print(f"[MultiResHLS] Quality levels: {list(self.qualities.keys())}", file=sys.stderr) + + def _setup_quality_levels(self): + """Configure quality levels based on source resolution.""" + # Always include original resolution + self.qualities['original'] = QualityLevel( + name='original', + width=self.source_width, + height=self.source_height, + bitrate=self._estimate_bitrate(self.source_width, self.source_height), + ) + + # Add 720p if source is larger + if self.source_height > 720: + aspect = self.source_width / self.source_height + w720 = int(720 * aspect) + w720 = w720 - (w720 % 2) # Ensure even width + self.qualities['720p'] = QualityLevel( + name='720p', + width=w720, + height=720, + bitrate=2500, + ) + + # Add 360p if source is larger + if self.source_height > 360: + aspect = self.source_width / self.source_height + w360 = int(360 * aspect) + w360 = w360 - (w360 % 2) # Ensure even width + self.qualities['360p'] = QualityLevel( + name='360p', + width=w360, + height=360, + bitrate=800, + ) + + def _estimate_bitrate(self, width: int, height: int) -> int: + """Estimate appropriate bitrate for resolution (in kbps).""" + pixels = width * height + if pixels >= 3840 * 2160: # 4K + return 15000 + elif pixels >= 1920 * 1080: # 1080p + return 5000 + elif pixels >= 1280 * 720: # 720p + return 2500 + elif pixels >= 854 * 480: # 480p + return 1500 + else: + return 800 + + def _setup_encoders(self): + """Setup FFmpeg encoder processes for each quality level.""" + self._encoders: Dict[str, subprocess.Popen] = {} + self._encoder_threads: Dict[str, threading.Thread] = {} + + for name, quality in self.qualities.items(): + # Create output directory for this quality + quality_dir = self.output_dir / name + quality_dir.mkdir(parents=True, exist_ok=True) + + # Build FFmpeg command + cmd = self._build_encoder_cmd(quality, quality_dir) + + print(f"[MultiResHLS] Starting encoder for {name}: {quality.width}x{quality.height}", file=sys.stderr) + + # Start encoder process + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=10**7, # Large buffer to prevent blocking + ) + self._encoders[name] = proc + + # Start stderr drain thread + stderr_thread = threading.Thread( + target=self._drain_stderr, + args=(name, proc), + daemon=True + ) + stderr_thread.start() + self._encoder_threads[name] = stderr_thread + + def _build_encoder_cmd(self, quality: QualityLevel, output_dir: Path) -> List[str]: + """Build FFmpeg command for a quality level.""" + playlist_path = output_dir / "playlist.m3u8" + segment_pattern = output_dir / "segment_%05d.ts" + + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-pixel_format", "rgb24", + "-video_size", f"{self.source_width}x{self.source_height}", + "-framerate", str(self.fps), + "-i", "-", + ] + + # Add audio input if provided + if self.audio_source: + cmd.extend(["-i", str(self.audio_source)]) + # Map video from input 0, audio from input 1 + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Scale if not original resolution + if quality.width != self.source_width or quality.height != self.source_height: + cmd.extend([ + "-vf", f"scale={quality.width}:{quality.height}:flags=lanczos", + ]) + + # NVENC encoding with quality settings + cmd.extend([ + "-c:v", "h264_nvenc", + "-preset", "p4", # Balanced speed/quality + "-tune", "hq", + "-b:v", f"{quality.bitrate}k", + "-maxrate", f"{int(quality.bitrate * 1.5)}k", + "-bufsize", f"{quality.bitrate * 2}k", + "-g", str(int(self.fps * self.segment_duration)), # Keyframe interval = segment duration + "-keyint_min", str(int(self.fps * self.segment_duration)), + "-sc_threshold", "0", # Disable scene change detection for consistent segments + ]) + + # Add audio encoding if audio source provided + if self.audio_source: + cmd.extend([ + "-c:a", "aac", + "-b:a", "128k", + "-shortest", # Stop when shortest stream ends + ]) + + # HLS output + cmd.extend([ + "-f", "hls", + "-hls_time", str(self.segment_duration), + "-hls_list_size", "0", # Keep all segments in playlist + "-hls_flags", "independent_segments+append_list", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(segment_pattern), + str(playlist_path), + ]) + + return cmd + + def _drain_stderr(self, name: str, proc: subprocess.Popen): + """Drain FFmpeg stderr to prevent blocking.""" + try: + for line in proc.stderr: + line_str = line.decode('utf-8', errors='replace').strip() + if line_str and ('error' in line_str.lower() or 'warning' in line_str.lower()): + print(f"[FFmpeg/{name}] {line_str}", file=sys.stderr) + except Exception as e: + print(f"[FFmpeg/{name}] stderr drain error: {e}", file=sys.stderr) + + def write(self, frame: Union[np.ndarray, 'cp.ndarray'], t: float = 0): + """Write a frame to all quality encoders.""" + if not self._is_open: + return + + # Convert GPU frame to CPU if needed + if GPU_AVAILABLE and hasattr(frame, 'get'): + frame = frame.get() # CuPy to NumPy + elif hasattr(frame, 'cpu'): + frame = frame.cpu # GPUFrame to NumPy + elif hasattr(frame, 'gpu') and hasattr(frame, 'is_on_gpu'): + frame = frame.gpu.get() if frame.is_on_gpu else frame.cpu + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + + # Ensure contiguous + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + frame_bytes = frame.tobytes() + + # Write to all encoders + for name, proc in self._encoders.items(): + if proc.poll() is not None: + print(f"[MultiResHLS] Encoder {name} died with code {proc.returncode}", file=sys.stderr) + self._is_open = False + return + + try: + proc.stdin.write(frame_bytes) + except BrokenPipeError: + print(f"[MultiResHLS] Encoder {name} pipe broken", file=sys.stderr) + self._is_open = False + return + + self._frame_count += 1 + + # Check for new segments periodically + if self._frame_count % int(self.fps * self.segment_duration) == 0: + self._check_and_upload_segments() + + def _check_and_upload_segments(self): + """Check for new segments and queue them for upload.""" + for name, quality in self.qualities.items(): + quality_dir = self.output_dir / name + segments = sorted(quality_dir.glob("segment_*.ts")) + + for seg_path in segments: + seg_num = int(seg_path.stem.split("_")[1]) + + if seg_num in quality.segment_cids: + continue # Already uploaded + + # Check if segment is complete (not still being written) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue + time.sleep(0.05) + size2 = seg_path.stat().st_size + if size1 != size2: + continue # Still being written + except FileNotFoundError: + continue + + # Queue for upload + self._upload_queue.put((name, seg_path, seg_num)) + + def _upload_worker(self): + """Background worker for IPFS uploads.""" + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + + quality_name, seg_path, seg_num = item + self._do_upload(quality_name, seg_path, seg_num) + + except queue.Empty: + continue + except Exception as e: + print(f"[MultiResHLS] Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, quality_name: str, seg_path: Path, seg_num: int): + """Upload a segment to IPFS.""" + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + self.qualities[quality_name].segment_cids[seg_num] = cid + print(f"[MultiResHLS] Uploaded {quality_name}/segment_{seg_num:05d}.ts -> {cid[:16]}...", file=sys.stderr) + + # Update playlists after each upload + self._update_playlists() + except Exception as e: + print(f"[MultiResHLS] Failed to upload {seg_path}: {e}", file=sys.stderr) + + def _update_playlists(self): + """Generate and upload IPFS playlists.""" + # Generate quality-specific playlists + for name, quality in self.qualities.items(): + if not quality.segment_cids: + continue + + playlist = self._generate_quality_playlist(quality) + cid = self._ipfs_add_bytes(playlist.encode(), pin=True) + if cid: + quality.playlist_cid = cid + + # Generate master playlist + self._generate_master_playlist() + + def _generate_quality_playlist(self, quality: QualityLevel, finalize: bool = False) -> str: + """Generate HLS playlist for a quality level.""" + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + if finalize: + lines.append("#EXT-X-PLAYLIST-TYPE:VOD") + + # Use /ipfs-ts/ for correct MIME type + segment_gateway = self.ipfs_gateway.replace("/ipfs", "/ipfs-ts") + + for seg_num in sorted(quality.segment_cids.keys()): + cid = quality.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{segment_gateway}/{cid}") + + if finalize: + lines.append("#EXT-X-ENDLIST") + + return "\n".join(lines) + "\n" + + def _generate_master_playlist(self, finalize: bool = False): + """Generate and upload master playlist.""" + lines = ["#EXTM3U", "#EXT-X-VERSION:3"] + + for name, quality in self.qualities.items(): + if not quality.playlist_cid: + continue + + lines.append( + f"#EXT-X-STREAM-INF:BANDWIDTH={quality.bitrate * 1000}," + f"RESOLUTION={quality.width}x{quality.height}," + f"NAME=\"{name}\"" + ) + lines.append(f"{self.ipfs_gateway}/{quality.playlist_cid}") + + if len(lines) <= 2: + return # No quality playlists yet + + master_content = "\n".join(lines) + "\n" + cid = self._ipfs_add_bytes(master_content.encode(), pin=True) + + if cid: + self._master_playlist_cid = cid + print(f"[MultiResHLS] Master playlist: {cid}", file=sys.stderr) + + if self._on_playlist_update: + # Pass both master CID and quality info for dynamic playlist generation + quality_info = { + name: { + "cid": q.playlist_cid, + "width": q.width, + "height": q.height, + "bitrate": q.bitrate, + } + for name, q in self.qualities.items() + if q.playlist_cid + } + self._on_playlist_update(cid, quality_info) + + def close(self): + """Close all encoders and finalize output.""" + if not self._is_open: + return + + self._is_open = False + print(f"[MultiResHLS] Closing after {self._frame_count} frames", file=sys.stderr) + + # Close encoder stdin pipes + for name, proc in self._encoders.items(): + try: + proc.stdin.close() + except: + pass + + # Wait for encoders to finish + for name, proc in self._encoders.items(): + try: + proc.wait(timeout=30) + print(f"[MultiResHLS] Encoder {name} finished with code {proc.returncode}", file=sys.stderr) + except subprocess.TimeoutExpired: + proc.kill() + print(f"[MultiResHLS] Encoder {name} killed (timeout)", file=sys.stderr) + + # Final segment check and upload + self._check_and_upload_segments() + + # Wait for uploads to complete + self._upload_queue.put(None) # Shutdown signal + self._upload_thread.join(timeout=60) + + # Generate final playlists with EXT-X-ENDLIST + for name, quality in self.qualities.items(): + if quality.segment_cids: + playlist = self._generate_quality_playlist(quality, finalize=True) + cid = self._ipfs_add_bytes(playlist.encode(), pin=True) + if cid: + quality.playlist_cid = cid + print(f"[MultiResHLS] Final {name} playlist: {cid} ({len(quality.segment_cids)} segments)", file=sys.stderr) + + # Final master playlist + self._generate_master_playlist(finalize=True) + + print(f"[MultiResHLS] Complete. Master playlist: {self._master_playlist_cid}", file=sys.stderr) + + @property + def is_open(self) -> bool: + return self._is_open + + @property + def playlist_cid(self) -> Optional[str]: + return self._master_playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + if self._master_playlist_cid: + return f"{self.ipfs_gateway}/{self._master_playlist_cid}" + return None + + @property + def segment_cids(self) -> Dict[str, Dict[int, str]]: + """Get all segment CIDs organized by quality.""" + return {name: dict(q.segment_cids) for name, q in self.qualities.items()} diff --git a/streaming/output.py b/streaming/output.py new file mode 100644 index 0000000..b2a4e85 --- /dev/null +++ b/streaming/output.py @@ -0,0 +1,963 @@ +""" +Output targets for streaming compositor. + +Supports: +- Display window (preview) +- File output (recording) +- Stream output (RTMP, etc.) - future +- NVENC hardware encoding (auto-detected) +- CuPy GPU arrays (auto-converted to numpy for output) +""" + +import numpy as np +import subprocess +import threading +import queue +from abc import ABC, abstractmethod +from typing import Tuple, Optional, List, Union +from pathlib import Path + +# Try to import CuPy for GPU array support +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + + +def ensure_numpy(frame: Union[np.ndarray, 'cp.ndarray']) -> np.ndarray: + """Convert frame to numpy array if it's a CuPy array.""" + if CUPY_AVAILABLE and isinstance(frame, cp.ndarray): + return cp.asnumpy(frame) + return frame + +# Cache NVENC availability check +_nvenc_available: Optional[bool] = None + + +def check_nvenc_available() -> bool: + """Check if NVENC hardware encoding is available and working. + + Does a real encode test to catch cases where nvenc is listed + but CUDA libraries aren't loaded. + """ + global _nvenc_available + if _nvenc_available is not None: + return _nvenc_available + + try: + # First check if encoder is listed + result = subprocess.run( + ["ffmpeg", "-encoders"], + capture_output=True, + text=True, + timeout=5 + ) + if "h264_nvenc" not in result.stdout: + _nvenc_available = False + return _nvenc_available + + # Actually try to encode a small test frame + result = subprocess.run( + ["ffmpeg", "-y", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=64x64:rate=1", + "-c:v", "h264_nvenc", "-f", "null", "-"], + capture_output=True, + text=True, + timeout=10 + ) + _nvenc_available = result.returncode == 0 + if not _nvenc_available: + import sys + print("NVENC listed but not working, falling back to libx264", file=sys.stderr) + except Exception: + _nvenc_available = False + + return _nvenc_available + + +def get_encoder_params(codec: str, preset: str, crf: int) -> List[str]: + """ + Get encoder-specific FFmpeg parameters. + + For NVENC (h264_nvenc, hevc_nvenc): + - Uses -cq for constant quality (similar to CRF) + - Presets: p1 (fastest) to p7 (slowest/best quality) + - Mapping: fast->p4, medium->p5, slow->p6 + + For libx264: + - Uses -crf for constant rate factor + - Presets: ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow + """ + if codec in ("h264_nvenc", "hevc_nvenc"): + # Map libx264 presets to NVENC presets + nvenc_preset_map = { + "ultrafast": "p1", + "superfast": "p2", + "veryfast": "p3", + "faster": "p3", + "fast": "p4", + "medium": "p5", + "slow": "p6", + "slower": "p6", + "veryslow": "p7", + } + nvenc_preset = nvenc_preset_map.get(preset, "p4") + + # NVENC quality: 0 (best) to 51 (worst), similar to CRF + # CRF 18 = high quality, CRF 23 = good quality + return [ + "-c:v", codec, + "-preset", nvenc_preset, + "-cq", str(crf), # Constant quality mode + "-rc", "vbr", # Variable bitrate with quality target + ] + else: + # Standard libx264 params + return [ + "-c:v", codec, + "-preset", preset, + "-crf", str(crf), + ] + + +class Output(ABC): + """Abstract base class for output targets.""" + + @abstractmethod + def write(self, frame: np.ndarray, t: float): + """Write a frame to the output.""" + pass + + @abstractmethod + def close(self): + """Close the output and clean up resources.""" + pass + + @property + @abstractmethod + def is_open(self) -> bool: + """Check if output is still open/valid.""" + pass + + +class DisplayOutput(Output): + """ + Display frames using mpv (handles Wayland properly). + + Useful for live preview. Press 'q' to quit. + """ + + def __init__(self, title: str = "Streaming Preview", size: Tuple[int, int] = None, + audio_source: str = None, fps: float = 30): + self.title = title + self.size = size + self.audio_source = audio_source + self.fps = fps + self._is_open = True + self._process = None + self._audio_process = None + + def _start_mpv(self, frame_size: Tuple[int, int]): + """Start mpv process for display.""" + import sys + w, h = frame_size + cmd = [ + "mpv", + "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + f"--title={self.title}", + "-", + ] + print(f"Starting mpv: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Start audio playback if we have an audio source + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Display frame.""" + if not self._is_open: + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Ensure frame is correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + # Start mpv on first frame + if self._process is None: + self._start_mpv((frame.shape[1], frame.shape[0])) + + # Check if mpv is still running + if self._process.poll() is not None: + self._is_open = False + return + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() # Prevent buffering + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the display and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open + + +class FileOutput(Output): + """ + Write frames to a video file using ffmpeg. + + Automatically uses NVENC hardware encoding when available, + falling back to libx264 CPU encoding otherwise. + """ + + def __init__( + self, + path: str, + size: Tuple[int, int], + fps: float = 30, + codec: str = "auto", # "auto", "h264_nvenc", "libx264" + crf: int = 18, + preset: str = "fast", + audio_source: str = None, + ): + self.path = Path(path) + self.size = size + self.fps = fps + self._is_open = True + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # Build ffmpeg command + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + # Explicitly map: video from input 0 (rawvideo), audio from input 1 + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend(["-pix_fmt", "yuv420p"]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "192k", "-shortest"]) + + # Use fragmented mp4 for streamable output while writing + if str(self.path).endswith('.mp4'): + cmd.extend(["-movflags", "frag_keyframe+empty_moov+default_base_moof"]) + + cmd.append(str(self.path)) + + import sys + print(f"FileOutput cmd: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, # Show errors for debugging + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to video file.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the video file.""" + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None + + +class MultiOutput(Output): + """ + Write to multiple outputs simultaneously. + + Useful for recording while showing preview. + """ + + def __init__(self, outputs: list): + self.outputs = outputs + + def write(self, frame: np.ndarray, t: float): + for output in self.outputs: + if output.is_open: + output.write(frame, t) + + def close(self): + for output in self.outputs: + output.close() + + @property + def is_open(self) -> bool: + return any(o.is_open for o in self.outputs) + + +class NullOutput(Output): + """ + Discard frames (for benchmarking). + """ + + def __init__(self): + self._is_open = True + self.frame_count = 0 + + def write(self, frame: np.ndarray, t: float): + self.frame_count += 1 + + def close(self): + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open + + +class PipeOutput(Output): + """ + Pipe frames directly to mpv. + + Launches mpv with rawvideo demuxer and writes frames to stdin. + """ + + def __init__(self, size: Tuple[int, int], fps: float = 30, audio_source: str = None): + self.size = size + self.fps = fps + self.audio_source = audio_source + self._is_open = True + self._process = None + self._audio_process = None + self._started = False + + def _start(self): + """Start mpv and audio on first frame.""" + if self._started: + return + self._started = True + + import sys + w, h = self.size + + # Start mpv + cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + "--title=Streaming", + "-" + ] + print(f"Starting mpv: {w}x{h} @ {self.fps}fps", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + + # Start audio + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to mpv.""" + if not self._is_open: + return + + self._start() + + # Check mpv still running + if self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close mpv and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open + + +class HLSOutput(Output): + """ + Write frames as HLS stream (m3u8 playlist + .ts segments). + + This enables true live streaming where the browser can poll + for new segments as they become available. + + Automatically uses NVENC hardware encoding when available. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, # 4s segments for stability + codec: str = "auto", # "auto", "h264_nvenc", "libx264" + crf: int = 23, + preset: str = "fast", # Better quality than ultrafast + audio_source: str = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self._is_open = True + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # HLS playlist path + self.playlist_path = self.output_dir / "stream.m3u8" + + # Build ffmpeg command for HLS output + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Keyframe interval - must be exactly segment_duration for clean cuts + gop_size = int(fps * segment_duration) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend([ + "-pix_fmt", "yuv420p", + # Force keyframes at exact intervals for clean segment boundaries + "-g", str(gop_size), + "-keyint_min", str(gop_size), + "-sc_threshold", "0", # Disable scene change detection + "-force_key_frames", f"expr:gte(t,n_forced*{segment_duration})", + # Reduce buffering for faster segment availability + "-flush_packets", "1", + ]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + # HLS specific options for smooth live streaming + cmd.extend([ + "-f", "hls", + "-hls_time", str(segment_duration), + "-hls_list_size", "0", # Keep all segments in playlist + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.playlist_path), + ]) + + import sys + print(f"HLSOutput cmd: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, # Show errors for debugging + ) + + # Track segments for status reporting + self.segments_written = 0 + self._last_segment_check = 0 + + def write(self, frame: np.ndarray, t: float): + """Write frame to HLS stream.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + + # Periodically count segments + if t - self._last_segment_check > 1.0: + self._last_segment_check = t + self.segments_written = len(list(self.output_dir.glob("segment_*.ts"))) + + def close(self): + """Close the HLS stream.""" + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + # Final segment count + self.segments_written = len(list(self.output_dir.glob("segment_*.ts"))) + + # Mark playlist as ended (VOD mode) + if self.playlist_path.exists(): + with open(self.playlist_path, "a") as f: + f.write("#EXT-X-ENDLIST\n") + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None + + +class IPFSHLSOutput(Output): + """ + Write frames as HLS stream with segments uploaded to IPFS. + + Each segment is uploaded to IPFS as it's created, enabling distributed + streaming where clients can fetch segments from any IPFS gateway. + + The m3u8 playlist is continuously updated with IPFS URLs and can be + fetched via get_playlist() or the playlist_cid property. + """ + + def __init__( + self, + output_dir: str, + size: Tuple[int, int], + fps: float = 30, + segment_duration: float = 4.0, + codec: str = "auto", + crf: int = 23, + preset: str = "fast", + audio_source: str = None, + ipfs_gateway: str = "https://ipfs.io/ipfs", + on_playlist_update: callable = None, + ): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.size = size + self.fps = fps + self.segment_duration = segment_duration + self.ipfs_gateway = ipfs_gateway.rstrip("/") + self._is_open = True + self._on_playlist_update = on_playlist_update # Callback when playlist CID changes + + # Auto-detect NVENC + if codec == "auto": + codec = "h264_nvenc" if check_nvenc_available() else "libx264" + self.codec = codec + + # Track segment CIDs + self.segment_cids: dict = {} # segment_number -> cid + self._last_segment_checked = -1 + self._playlist_cid: Optional[str] = None + self._upload_lock = threading.Lock() + + # Import IPFS client + from ipfs_client import add_file, add_bytes + self._ipfs_add_file = add_file + self._ipfs_add_bytes = add_bytes + + # Background upload thread for async IPFS uploads + self._upload_queue = queue.Queue() + self._upload_thread = threading.Thread(target=self._upload_worker, daemon=True) + self._upload_thread.start() + + # Local HLS paths + self.local_playlist_path = self.output_dir / "stream.m3u8" + + # Build ffmpeg command for HLS output + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + # Keyframe interval + gop_size = int(fps * segment_duration) + + # Get encoder-specific params + cmd.extend(get_encoder_params(codec, preset, crf)) + cmd.extend([ + "-pix_fmt", "yuv420p", + "-g", str(gop_size), + "-keyint_min", str(gop_size), + "-sc_threshold", "0", + "-force_key_frames", f"expr:gte(t,n_forced*{segment_duration})", + "-flush_packets", "1", + ]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "128k", "-shortest"]) + + # HLS options + cmd.extend([ + "-f", "hls", + "-hls_time", str(segment_duration), + "-hls_list_size", "0", + "-hls_flags", "independent_segments+append_list+split_by_time", + "-hls_segment_type", "mpegts", + "-hls_segment_filename", str(self.output_dir / "segment_%05d.ts"), + str(self.local_playlist_path), + ]) + + import sys + print(f"IPFSHLSOutput: starting ffmpeg", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, + ) + + def _upload_worker(self): + """Background worker thread for async IPFS uploads.""" + import sys + while True: + try: + item = self._upload_queue.get(timeout=1.0) + if item is None: # Shutdown signal + break + seg_path, seg_num = item + self._do_upload(seg_path, seg_num) + except queue.Empty: + continue + except Exception as e: + print(f"Upload worker error: {e}", file=sys.stderr) + + def _do_upload(self, seg_path: Path, seg_num: int): + """Actually perform the upload (runs in background thread).""" + import sys + try: + cid = self._ipfs_add_file(seg_path, pin=True) + if cid: + with self._upload_lock: + self.segment_cids[seg_num] = cid + print(f"IPFS: segment_{seg_num:05d}.ts -> {cid}", file=sys.stderr) + self._update_ipfs_playlist() + except Exception as e: + print(f"Failed to upload segment {seg_num}: {e}", file=sys.stderr) + + def _upload_new_segments(self): + """Check for new segments and queue them for async IPFS upload.""" + import sys + import time + + # Find all segments + segments = sorted(self.output_dir.glob("segment_*.ts")) + + for seg_path in segments: + # Extract segment number from filename + seg_name = seg_path.stem # segment_00000 + seg_num = int(seg_name.split("_")[1]) + + # Skip if already uploaded or queued + with self._upload_lock: + if seg_num in self.segment_cids: + continue + + # Skip if segment is still being written (quick non-blocking check) + try: + size1 = seg_path.stat().st_size + if size1 == 0: + continue # Empty file, still being created + + time.sleep(0.01) # Very short check + size2 = seg_path.stat().st_size + if size1 != size2: + continue # File still being written + except FileNotFoundError: + continue + + # Queue for async upload (non-blocking!) + self._upload_queue.put((seg_path, seg_num)) + + def _update_ipfs_playlist(self): + """Generate and upload IPFS-aware m3u8 playlist.""" + import sys + + with self._upload_lock: + if not self.segment_cids: + return + + # Build m3u8 content with IPFS URLs + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + # Add segments in order + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + playlist_content = "\n".join(lines) + "\n" + + # Upload playlist to IPFS + cid = self._ipfs_add_bytes(playlist_content.encode("utf-8"), pin=True) + if cid: + self._playlist_cid = cid + print(f"IPFS: playlist updated -> {cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + # Notify callback (e.g., to update database for live HLS redirect) + if self._on_playlist_update: + try: + self._on_playlist_update(cid) + except Exception as e: + print(f"IPFS: playlist callback error: {e}", file=sys.stderr) + + def write(self, frame: np.ndarray, t: float): + """Write frame to HLS stream and upload segments to IPFS.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Convert GPU array to numpy if needed + frame = ensure_numpy(frame) + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + return + + # Check for new segments periodically (every second) + current_segment = int(t / self.segment_duration) + if current_segment > self._last_segment_checked: + self._last_segment_checked = current_segment + self._upload_new_segments() + + def close(self): + """Close the HLS stream and finalize IPFS uploads.""" + import sys + + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + # Queue any remaining segments + self._upload_new_segments() + + # Wait for pending uploads to complete + self._upload_queue.put(None) # Signal shutdown + self._upload_thread.join(timeout=30) + + # Generate final playlist with #EXT-X-ENDLIST + if self.segment_cids: + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + "#EXT-X-PLAYLIST-TYPE:VOD", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + lines.append("#EXT-X-ENDLIST") + playlist_content = "\n".join(lines) + "\n" + + cid = self._ipfs_add_bytes(playlist_content.encode("utf-8"), pin=True) + if cid: + self._playlist_cid = cid + print(f"IPFS: final playlist -> {cid} ({len(self.segment_cids)} segments)", file=sys.stderr) + + @property + def playlist_cid(self) -> Optional[str]: + """Get the current playlist CID.""" + return self._playlist_cid + + @property + def playlist_url(self) -> Optional[str]: + """Get the full IPFS URL for the playlist.""" + if self._playlist_cid: + return f"{self.ipfs_gateway}/{self._playlist_cid}" + return None + + def get_playlist(self) -> str: + """Get the current m3u8 playlist content with IPFS URLs.""" + if not self.segment_cids: + return "#EXTM3U\n" + + lines = [ + "#EXTM3U", + "#EXT-X-VERSION:3", + f"#EXT-X-TARGETDURATION:{int(self.segment_duration) + 1}", + "#EXT-X-MEDIA-SEQUENCE:0", + ] + + for seg_num in sorted(self.segment_cids.keys()): + cid = self.segment_cids[seg_num] + lines.append(f"#EXTINF:{self.segment_duration:.3f},") + lines.append(f"{self.ipfs_gateway}/{cid}") + + if not self._is_open: + lines.append("#EXT-X-ENDLIST") + + return "\n".join(lines) + "\n" + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None \ No newline at end of file diff --git a/streaming/pipeline.py b/streaming/pipeline.py new file mode 100644 index 0000000..29dd7e1 --- /dev/null +++ b/streaming/pipeline.py @@ -0,0 +1,846 @@ +""" +Streaming pipeline executor. + +Directly executes compiled sexp recipes frame-by-frame. +No adapter layer - frames and analysis flow through the DAG. +""" + +import sys +import time +import numpy as np +from pathlib import Path +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sources import VideoSource +from .audio import StreamingAudioAnalyzer +from .output import DisplayOutput, FileOutput +from .sexp_interp import SexpInterpreter + + +@dataclass +class FrameContext: + """Context passed through the pipeline for each frame.""" + t: float # Current time + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + analysis: Dict[str, Any] = field(default_factory=dict) + + +class StreamingPipeline: + """ + Executes a compiled sexp recipe as a streaming pipeline. + + Frames flow through the DAG directly - no adapter needed. + Each node is evaluated lazily when its output is requested. + """ + + def __init__(self, compiled_recipe, recipe_dir: Path = None, fps: float = 30, seed: int = 42, + output_size: tuple = None): + self.recipe = compiled_recipe + self.recipe_dir = recipe_dir or Path(".") + self.fps = fps + self.seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # Runtime state + self.sources: Dict[str, VideoSource] = {} + self.audio_analyzer: Optional[StreamingAudioAnalyzer] = None + self.audio_source_path: Optional[str] = None + + # Sexp interpreter for expressions + self.interp = SexpInterpreter() + + # Scan state (node_id -> current value) + self.scan_state: Dict[str, Any] = {} + self.scan_emit: Dict[str, Any] = {} + + # SLICE_ON state + self.slice_on_acc: Dict[str, Any] = {} + self.slice_on_result: Dict[str, Any] = {} + + # Frame cache for current timestep (cleared each frame) + self._frame_cache: Dict[str, np.ndarray] = {} + + # Context for current frame + self.ctx = FrameContext(t=0.0) + + # Output size (w, h) - set after sources are initialized + self._output_size = output_size + + # Initialize + self._init_sources() + self._init_scans() + self._init_slice_on() + + # Set output size from first source if not specified + if self._output_size is None and self.sources: + first_source = next(iter(self.sources.values())) + self._output_size = first_source._size + + def _init_sources(self): + """Initialize video and audio sources.""" + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + config = node.get('config', {}) + path = config.get('path') + if path: + full_path = (self.recipe_dir / path).resolve() + suffix = full_path.suffix.lower() + + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + if not full_path.exists(): + print(f"Warning: video not found: {full_path}", file=sys.stderr) + continue + self.sources[node['id']] = VideoSource( + str(full_path), + target_fps=self.fps + ) + elif suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + if not full_path.exists(): + print(f"Warning: audio not found: {full_path}", file=sys.stderr) + continue + self.audio_source_path = str(full_path) + self.audio_analyzer = StreamingAudioAnalyzer(str(full_path)) + + def _init_scans(self): + """Initialize scan nodes with their initial state.""" + import random + seed_offset = 0 + + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + config = node.get('config', {}) + + # Create RNG for this scan + scan_seed = config.get('seed', self.seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + init_value = self.interp.eval(init_expr, {}) + + self.scan_state[node['id']] = { + 'value': init_value, + 'rng': rng, + 'config': config, + } + + # Compute initial emit + self._update_scan_emit(node['id']) + + def _update_scan_emit(self, node_id: str): + """Update the emit value for a scan.""" + state = self.scan_state[node_id] + config = state['config'] + emit_expr = config.get('emit_expr', config.get('emit', None)) + + if emit_expr is None: + # No emit expression - emit the value directly + self.scan_emit[node_id] = state['value'] + return + + # Build environment from state + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG for interpreter + self.interp.rng = state['rng'] + + self.scan_emit[node_id] = self.interp.eval(emit_expr, env) + + def _step_scan(self, node_id: str): + """Step a scan forward on beat.""" + state = self.scan_state[node_id] + config = state['config'] + step_expr = config.get('step_expr', config.get('step', None)) + + if step_expr is None: + return + + # Build environment + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG + self.interp.rng = state['rng'] + + # Evaluate step + new_value = self.interp.eval(step_expr, env) + state['value'] = new_value + + # Update emit + self._update_scan_emit(node_id) + + def _init_slice_on(self): + """Initialize SLICE_ON nodes.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + init = config.get('init', {}) + self.slice_on_acc[node['id']] = dict(init) + + # Evaluate initial state + self._eval_slice_on(node['id']) + + def _eval_slice_on(self, node_id: str): + """Evaluate a SLICE_ON node's Lambda.""" + node = self.nodes[node_id] + config = node.get('config', {}) + fn = config.get('fn') + videos = config.get('videos', []) + + if not fn: + return + + acc = self.slice_on_acc[node_id] + n_videos = len(videos) + + # Set up environment + self.interp.globals['videos'] = list(range(n_videos)) + + try: + from .sexp_interp import eval_slice_on_lambda + result = eval_slice_on_lambda( + fn, acc, self.ctx.beat_count, 0, 1, + list(range(n_videos)), self.interp + ) + self.slice_on_result[node_id] = result + + # Update accumulator + if 'acc' in result: + self.slice_on_acc[node_id] = result['acc'] + except Exception as e: + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _on_beat(self): + """Called when a beat is detected.""" + self.ctx.beat_count += 1 + + # Step all scans + for node_id in self.scan_state: + self._step_scan(node_id) + + # Step all SLICE_ON nodes + for node_id in self.slice_on_acc: + self._eval_slice_on(node_id) + + def _get_frame(self, node_id: str) -> Optional[np.ndarray]: + """ + Get the output frame for a node at current time. + + Recursively evaluates inputs as needed. + Results are cached for the current timestep. + """ + if node_id in self._frame_cache: + return self._frame_cache[node_id] + + node = self.nodes.get(node_id) + if not node: + return None + + node_type = node.get('type') + + if node_type == 'SOURCE': + frame = self._eval_source(node) + elif node_type == 'SEGMENT': + frame = self._eval_segment(node) + elif node_type == 'EFFECT': + frame = self._eval_effect(node) + elif node_type == 'SLICE_ON': + frame = self._eval_slice_on_frame(node) + else: + # Unknown node type - try to pass through input + inputs = node.get('inputs', []) + frame = self._get_frame(inputs[0]) if inputs else None + + self._frame_cache[node_id] = frame + return frame + + def _eval_source(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SOURCE node.""" + source = self.sources.get(node['id']) + if source: + return source.read_frame(self.ctx.t) + return None + + def _eval_segment(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SEGMENT node (time segment of source).""" + inputs = node.get('inputs', []) + if not inputs: + return None + + config = node.get('config', {}) + start = config.get('start', 0) + duration = config.get('duration') + + # Resolve any bindings + if isinstance(start, dict): + start = self._resolve_binding(start) if start.get('_binding') else 0 + if isinstance(duration, dict): + duration = self._resolve_binding(duration) if duration.get('_binding') else None + + # Adjust time for segment + t_local = self.ctx.t + (start if isinstance(start, (int, float)) else 0) + if duration and isinstance(duration, (int, float)): + t_local = t_local % duration # Loop within segment + + # Get source frame at adjusted time + source_id = inputs[0] + source = self.sources.get(source_id) + if source: + return source.read_frame(t_local) + + return self._get_frame(source_id) + + def _eval_effect(self, node: dict) -> Optional[np.ndarray]: + """Evaluate an EFFECT node.""" + import cv2 + + inputs = node.get('inputs', []) + config = node.get('config', {}) + effect_name = config.get('effect') + + # Get input frame(s) + input_frames = [self._get_frame(inp) for inp in inputs] + input_frames = [f for f in input_frames if f is not None] + + if not input_frames: + return None + + frame = input_frames[0] + + # Resolve bindings in config + params = self._resolve_config(config) + + # Apply effect based on name + if effect_name == 'rotate': + angle = params.get('angle', 0) + if abs(angle) > 0.5: + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == 'zoom': + amount = params.get('amount', 1.0) + if abs(amount - 1.0) > 0.01: + frame = self._apply_zoom(frame, amount) + + elif effect_name == 'invert': + amount = params.get('amount', 0) + if amount > 0.01: + inverted = 255 - frame + frame = cv2.addWeighted(frame, 1 - amount, inverted, amount, 0) + + elif effect_name == 'hue_shift': + degrees = params.get('degrees', 0) + if abs(degrees) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + int(degrees / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + elif effect_name == 'blend': + if len(input_frames) >= 2: + opacity = params.get('opacity', 0.5) + frame = cv2.addWeighted(input_frames[0], 1 - opacity, + input_frames[1], opacity, 0) + + elif effect_name == 'blend_multi': + weights = params.get('weights', []) + if len(input_frames) > 1 and weights: + h, w = input_frames[0].shape[:2] + result = np.zeros((h, w, 3), dtype=np.float32) + for f, wt in zip(input_frames, weights): + if f is not None and wt > 0.001: + if f.shape[:2] != (h, w): + f = cv2.resize(f, (w, h)) + result += f.astype(np.float32) * wt + frame = np.clip(result, 0, 255).astype(np.uint8) + + elif effect_name == 'ripple': + amp = params.get('amplitude', 0) + if amp > 1: + frame = self._apply_ripple(frame, amp, + params.get('center_x', 0.5), + params.get('center_y', 0.5), + params.get('frequency', 8), + params.get('decay', 2), + params.get('speed', 5)) + + return frame + + def _eval_slice_on_frame(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SLICE_ON node - returns composited frame.""" + import cv2 + + config = node.get('config', {}) + video_ids = config.get('videos', []) + result = self.slice_on_result.get(node['id'], {}) + + if not result: + # No result yet - return first video + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get layers and compose info + layers = result.get('layers', []) + compose = result.get('compose', {}) + weights = compose.get('weights', []) + + if not layers or not weights: + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get frames for each layer + frames = [] + for i, layer in enumerate(layers): + video_idx = layer.get('video', i) + if video_idx < len(video_ids): + frame = self._get_frame(video_ids[video_idx]) + + # Apply layer effects (zoom) + effects = layer.get('effects', []) + for eff in effects: + eff_name = eff.get('effect') + if hasattr(eff_name, 'name'): + eff_name = eff_name.name + if eff_name == 'zoom': + zoom_amt = eff.get('amount', 1.0) + if frame is not None: + frame = self._apply_zoom(frame, zoom_amt) + + frames.append(frame) + else: + frames.append(None) + + # Composite with weights - use consistent output size + if self._output_size: + w, h = self._output_size + else: + # Fallback to first non-None frame size + for f in frames: + if f is not None: + h, w = f.shape[:2] + break + else: + return None + + output = np.zeros((h, w, 3), dtype=np.float32) + + for frame, weight in zip(frames, weights): + if frame is None or weight < 0.001: + continue + + # Resize to output size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + + output += frame.astype(np.float32) * weight + + # Normalize weights + total_weight = sum(wt for wt in weights if wt > 0.001) + if total_weight > 0 and abs(total_weight - 1.0) > 0.01: + output /= total_weight + + return np.clip(output, 0, 255).astype(np.uint8) + + def _resolve_config(self, config: dict) -> dict: + """Resolve bindings in effect config to actual values.""" + resolved = {} + + for key, value in config.items(): + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', + 'analysis_refs', 'inputs', 'cid'): + continue + + if isinstance(value, dict) and value.get('_binding'): + resolved[key] = self._resolve_binding(value) + elif isinstance(value, dict) and value.get('_expr'): + resolved[key] = self._resolve_expr(value) + else: + resolved[key] = value + + return resolved + + def _resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to its current value.""" + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get raw value from scan or analysis + if source_id in self.scan_emit: + value = self.scan_emit[source_id] + elif source_id in self.ctx.analysis: + data = self.ctx.analysis[source_id] + value = data.get(feature, data.get('values', [0]))[0] if isinstance(data, dict) else data + else: + # Fallback to energy + value = self.ctx.energy + + # Extract feature from dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def _resolve_expr(self, expr: dict) -> Any: + """Resolve a compiled expression.""" + env = { + 'energy': self.ctx.energy, + 'beat_count': self.ctx.beat_count, + 't': self.ctx.t, + } + + # Add scan values + for scan_id, value in self.scan_emit.items(): + # Use short form if available + env[scan_id] = value + + # Extract the actual expression from _expr wrapper + actual_expr = expr.get('_expr', expr) + return self.interp.eval(actual_expr, env) + + def _apply_zoom(self, frame: np.ndarray, amount: float) -> np.ndarray: + """Apply zoom to frame.""" + import cv2 + h, w = frame.shape[:2] + + if amount > 1.01: + # Zoom in: crop center + new_w, new_h = int(w / amount), int(h / amount) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + elif amount < 0.99: + # Zoom out: shrink and center + scaled_w, scaled_h = int(w * amount), int(h * amount) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(frame, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + return canvas + + return frame + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, frequency: float, + decay: float, speed: float) -> np.ndarray: + """Apply ripple effect.""" + import cv2 + h, w = frame.shape[:2] + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize to center + center_x, center_y = w * cx, h * cy + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx**2 + dy**2) + + # Ripple displacement + phase = self.ctx.t * speed + ripple = amplitude * np.sin(dist / frequency - phase) * np.exp(-dist * decay / max(w, h)) + + # Displace coordinates + angle = np.arctan2(dy, dx) + map_x = (x_coords + ripple * np.cos(angle)).astype(np.float32) + map_y = (y_coords + ripple * np.sin(angle)).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + def _find_output_node(self) -> Optional[str]: + """Find the final output node (MUX or last EFFECT).""" + # Look for MUX node + for node in self.recipe.nodes: + if node.get('type') == 'MUX': + return node['id'] + + # Otherwise find last EFFECT after SLICE_ON + last_effect = None + found_slice_on = False + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + found_slice_on = True + elif node.get('type') == 'EFFECT' and found_slice_on: + last_effect = node['id'] + + return last_effect + + def render_frame(self, t: float) -> Optional[np.ndarray]: + """Render a single frame at time t.""" + # Clear frame cache + self._frame_cache.clear() + + # Update context + self.ctx.t = t + + # Update audio analysis + if self.audio_analyzer: + self.audio_analyzer.set_time(t) + energy = self.audio_analyzer.get_energy() + is_beat = self.audio_analyzer.get_beat() + + # Beat edge detection + was_beat = self.ctx.is_beat + self.ctx.energy = energy + self.ctx.is_beat = is_beat + + if is_beat and not was_beat: + self._on_beat() + + # Store in analysis dict + self.ctx.analysis['live_energy'] = {'values': [energy]} + self.ctx.analysis['live_beat'] = {'values': [1.0 if is_beat else 0.0]} + + # Find output node and render + output_node = self._find_output_node() + if output_node: + frame = self._get_frame(output_node) + # Normalize to output size + if frame is not None and self._output_size: + w, h = self._output_size + if frame.shape[1] != w or frame.shape[0] != h: + import cv2 + frame = cv2.resize(frame, (w, h)) + return frame + + return None + + def run(self, output: str = "preview", duration: float = None): + """ + Run the pipeline. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration or 60s) + """ + # Determine duration + if duration is None: + if self.audio_analyzer: + duration = self.audio_analyzer.duration + else: + duration = 60.0 + + # Create output + if output == "preview": + # Get frame size from first source + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = DisplayOutput(size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + elif isinstance(output, str): + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = FileOutput(output, size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + else: + out = output + + frame_time = 1.0 / self.fps + n_frames = int(duration * self.fps) + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {n_frames} frames @ {self.fps}fps", file=sys.stderr) + + start_time = time.time() + frame_count = 0 + + try: + for frame_num in range(n_frames): + t = frame_num * frame_time + + frame = self.render_frame(t) + + if frame is not None: + out.write(frame, t) + frame_count += 1 + + # Progress + if frame_num % 50 == 0: + elapsed = time.time() - start_time + fps = frame_count / elapsed if elapsed > 0 else 0 + pct = 100 * frame_num / n_frames + print(f"\r{pct:5.1f}% | {fps:5.1f} fps | frame {frame_num}/{n_frames}", + end="", file=sys.stderr) + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources.values(): + src.close() + + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", + file=sys.stderr) + + +def run_pipeline(recipe_path: str, output: str = "preview", + duration: float = None, fps: float = None): + """ + Run a recipe through the streaming pipeline. + + No adapter layer - directly executes the compiled recipe. + """ + from pathlib import Path + + # Add artdag to path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + pipeline.run(output=output, duration=duration) + + +def run_pipeline_piped(recipe_path: str, duration: float = None, fps: float = None): + """ + Run pipeline and pipe directly to mpv with audio. + """ + import subprocess + from pathlib import Path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + # Get frame info + first_source = next(iter(pipeline.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + + # Determine duration + if duration is None: + if pipeline.audio_analyzer: + duration = pipeline.audio_analyzer.duration + else: + duration = 60.0 + + actual_fps = fps or compiled.encoding.get('fps', 30) + n_frames = int(duration * actual_fps) + frame_time = 1.0 / actual_fps + + print(f"Streaming {n_frames} frames @ {actual_fps}fps to mpv", file=sys.stderr) + + # Start mpv + mpv_cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={actual_fps}", + "--title=Streaming Pipeline", + "-" + ] + mpv = subprocess.Popen(mpv_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Start audio if available + audio_proc = None + if pipeline.audio_source_path: + audio_cmd = ["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + pipeline.audio_source_path] + audio_proc = subprocess.Popen(audio_cmd, stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + + try: + import cv2 + for frame_num in range(n_frames): + if mpv.poll() is not None: + break # mpv closed + + t = frame_num * frame_time + frame = pipeline.render_frame(t) + if frame is not None: + # Ensure consistent frame size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + try: + mpv.stdin.write(frame.tobytes()) + mpv.stdin.flush() + except BrokenPipeError: + break + except KeyboardInterrupt: + pass + finally: + if mpv.stdin: + mpv.stdin.close() + mpv.terminate() + if audio_proc: + audio_proc.terminate() + for src in pipeline.sources.values(): + src.close() + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe through streaming pipeline") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_pipeline_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_pipeline(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/streaming/recipe_adapter.py b/streaming/recipe_adapter.py new file mode 100644 index 0000000..2133919 --- /dev/null +++ b/streaming/recipe_adapter.py @@ -0,0 +1,470 @@ +""" +Adapter to run sexp recipes through the streaming compositor. + +Bridges the gap between: +- Existing recipe format (sexp files with stages, effects, analysis) +- Streaming compositor (sources, effect chains, compositor config) +""" + +import sys +from pathlib import Path +from typing import Dict, List, Any, Optional + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + +from .compositor import StreamingCompositor +from .sources import VideoSource +from .audio import FileAudioAnalyzer + + +class RecipeAdapter: + """ + Adapts a compiled sexp recipe to run through the streaming compositor. + + Example: + adapter = RecipeAdapter("effects/quick_test.sexp") + adapter.run(output="preview", duration=60) + """ + + def __init__( + self, + recipe_path: str, + params: Dict[str, Any] = None, + backend: str = "numpy", + ): + """ + Load and prepare a recipe for streaming. + + Args: + recipe_path: Path to .sexp recipe file + params: Parameter overrides + backend: "numpy" or "glsl" + """ + self.recipe_path = Path(recipe_path) + self.recipe_dir = self.recipe_path.parent + self.params = params or {} + self.backend = backend + + # Compile recipe + self._compile() + + def _compile(self): + """Compile the recipe and extract structure.""" + from artdag.sexp.compiler import compile_string + + recipe_text = self.recipe_path.read_text() + self.compiled = compile_string(recipe_text, self.params, recipe_dir=self.recipe_dir) + + # Extract key info + self.sources = {} # name -> path + self.effects_registry = {} # effect_name -> path + self.analyzers = {} # name -> analyzer info + + # Walk nodes to find sources and structure + # nodes is a list in CompiledRecipe + for node in self.compiled.nodes: + node_type = node.get("type", "") + + if node_type == "SOURCE": + config = node.get("config", {}) + path = config.get("path") + if path: + self.sources[node["id"]] = self.recipe_dir / path + + elif node_type == "ANALYZE": + config = node.get("config", {}) + self.analyzers[node["id"]] = { + "analyzer": config.get("analyzer"), + "path": config.get("analyzer_path"), + } + + # Get effects registry from compiled recipe + # registry has 'effects' sub-dict + effects_dict = self.compiled.registry.get("effects", {}) + for name, info in effects_dict.items(): + if info.get("path"): + self.effects_registry[name] = Path(info["path"]) + + def run_analysis(self) -> Dict[str, Any]: + """ + Run analysis phase (energy, beats, etc.). + + Returns: + Dict of analysis track name -> {times, values, duration} + """ + print(f"Running analysis...", file=sys.stderr) + + # Use existing planner's analysis execution + from artdag.sexp.planner import create_plan + + analysis_data = {} + + def on_analysis(node_id: str, results: dict): + analysis_data[node_id] = results + print(f" {node_id[:16]}...: {len(results.get('times', []))} samples", file=sys.stderr) + + # Create plan (runs analysis as side effect) + plan = create_plan( + self.compiled, + inputs={}, + recipe_dir=self.recipe_dir, + on_analysis=on_analysis, + ) + + # Also store named analysis tracks + for name, data in plan.analysis.items(): + analysis_data[name] = data + + return analysis_data + + def build_compositor( + self, + analysis_data: Dict[str, Any] = None, + fps: float = None, + ) -> StreamingCompositor: + """ + Build a streaming compositor from the recipe. + + This is a simplified version that handles common patterns. + Complex recipes may need manual configuration. + + Args: + analysis_data: Pre-computed analysis data + + Returns: + Configured StreamingCompositor + """ + # Extract video and audio sources in SLICE_ON input order + video_sources = [] + audio_source = None + + # Find audio source first + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + audio_source = str(path) + break + + # Find SLICE_ON node to get correct video order + slice_on_inputs = None + for node in self.compiled.nodes: + if node.get('type') == 'SLICE_ON': + # Use 'videos' config key which has the correct order + config = node.get('config', {}) + slice_on_inputs = config.get('videos', []) + break + + if slice_on_inputs: + # Trace each SLICE_ON input back to its SOURCE + node_lookup = {n['id']: n for n in self.compiled.nodes} + + def trace_to_source(node_id, visited=None): + """Trace a node back to its SOURCE, return source path.""" + if visited is None: + visited = set() + if node_id in visited: + return None + visited.add(node_id) + + node = node_lookup.get(node_id) + if not node: + return None + if node.get('type') == 'SOURCE': + return self.sources.get(node_id) + # Recurse through inputs + for inp in node.get('inputs', []): + result = trace_to_source(inp, visited) + if result: + return result + return None + + # Build video_sources in SLICE_ON input order + for inp_id in slice_on_inputs: + source_path = trace_to_source(inp_id) + if source_path: + suffix = source_path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(source_path)) + + # Fallback to definition order if no SLICE_ON + if not video_sources: + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(path)) + + if not video_sources: + raise ValueError("No video sources found in recipe") + + # Build effect chains - use live audio bindings (matching video_sources count) + effects_per_source = self._build_streaming_effects(n_sources=len(video_sources)) + + # Build compositor config from recipe + compositor_config = self._extract_compositor_config(analysis_data) + + return StreamingCompositor( + sources=video_sources, + effects_per_source=effects_per_source, + compositor_config=compositor_config, + analysis_data=analysis_data or {}, + backend=self.backend, + recipe_dir=self.recipe_dir, + fps=fps or self.compiled.encoding.get("fps", 30), + audio_source=audio_source, + ) + + def _build_streaming_effects(self, n_sources: int = None) -> List[List[Dict]]: + """ + Build effect chains for streaming with live audio bindings. + + Replicates the recipe's effect pipeline: + - Per source: rotate, zoom, invert, hue_shift, ascii_art + - All driven by live_energy and live_beat + """ + if n_sources is None: + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + + effects_per_source = [] + + for i in range(n_sources): + # Alternate rotation direction per source + rot_dir = 1 if i % 2 == 0 else -1 + + effects = [ + # Rotate - energy drives angle + { + "effect": "rotate", + "effect_path": str(self.effects_registry.get("rotate", "")), + "angle": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 45 * rot_dir], + }, + }, + # Zoom - energy drives amount + { + "effect": "zoom", + "effect_path": str(self.effects_registry.get("zoom", "")), + "amount": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [1.0, 1.5] if i % 2 == 0 else [1.0, 0.7], + }, + }, + # Invert - beat triggers + { + "effect": "invert", + "effect_path": str(self.effects_registry.get("invert", "")), + "amount": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + # Hue shift - energy drives hue + { + "effect": "hue_shift", + "effect_path": str(self.effects_registry.get("hue_shift", "")), + "degrees": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 180], + }, + }, + # ASCII art - energy drives char size, beat triggers mix + { + "effect": "ascii_art", + "effect_path": str(self.effects_registry.get("ascii_art", "")), + "char_size": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [4, 32], + }, + "mix": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + ] + effects_per_source.append(effects) + + return effects_per_source + + def _extract_effects(self) -> List[List[Dict]]: + """Extract effect chains for each source (legacy, pre-computed analysis).""" + # Simplified: find EFFECT nodes and their configs + effects_per_source = [] + + for node_id, path in self.sources.items(): + if path.suffix.lower() not in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + continue + + # Find effects that depend on this source + # This is simplified - real implementation would trace the DAG + effects = [] + + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + effect_name = config.get("effect") + if effect_name and effect_name in self.effects_registry: + effect_config = { + "effect": effect_name, + "effect_path": str(self.effects_registry[effect_name]), + } + # Copy only effect params (filter out internal fields) + internal_fields = ( + "effect", "effect_path", "cid", "effect_cid", + "effects_registry", "analysis_refs", "inputs", + ) + for k, v in config.items(): + if k not in internal_fields: + effect_config[k] = v + effects.append(effect_config) + break # One effect per source for now + + effects_per_source.append(effects) + + return effects_per_source + + def _extract_compositor_config(self, analysis_data: Dict) -> Dict: + """Extract compositor configuration.""" + # Look for blend_multi or similar composition nodes + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + if config.get("effect") == "blend_multi": + return { + "mode": config.get("mode", "alpha"), + "weights": config.get("weights", []), + } + + # Default: equal blend + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + return { + "mode": "alpha", + "weights": [1.0 / n_sources] * n_sources if n_sources > 0 else [1.0], + } + + def run( + self, + output: str = "preview", + duration: float = None, + fps: float = None, + ): + """ + Run the recipe through streaming compositor. + + Everything streams: video frames read on-demand, audio analyzed in real-time. + No pre-computation. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration) + fps: Frame rate (default from recipe, or 30) + """ + # Build compositor with recipe executor for full pipeline + from .recipe_executor import StreamingRecipeExecutor + + compositor = self.build_compositor(analysis_data={}, fps=fps) + + # Use audio duration if not specified + if duration is None: + if compositor._audio_analyzer: + duration = compositor._audio_analyzer.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + else: + # Live mode - run until quit + print("Live mode - press 'q' to quit", file=sys.stderr) + + # Create sexp executor that interprets the recipe + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(self.compiled, seed=42) + + compositor.run(output=output, duration=duration, recipe_executor=executor) + + +def run_recipe( + recipe_path: str, + output: str = "preview", + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run a recipe through streaming compositor. + + Everything streams in real-time: video frames, audio analysis. + No pre-computation - starts immediately. + + Example: + run_recipe("effects/quick_test.sexp", output="preview", duration=30) + run_recipe("effects/quick_test.sexp", output="preview", fps=5) # Lower fps for slow systems + """ + adapter = RecipeAdapter(recipe_path, params=params) + adapter.run(output=output, duration=duration, fps=fps) + + +def run_recipe_piped( + recipe_path: str, + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run recipe and pipe directly to mpv. + """ + from .output import PipeOutput + + adapter = RecipeAdapter(recipe_path, params=params) + compositor = adapter.build_compositor(analysis_data={}, fps=fps) + + # Get frame size + if compositor.sources: + first_source = compositor.sources[0] + w, h = first_source._size + else: + w, h = 720, 720 + + actual_fps = fps or adapter.compiled.encoding.get('fps', 30) + + # Create pipe output + pipe_out = PipeOutput( + size=(w, h), + fps=actual_fps, + audio_source=compositor._audio_source + ) + + # Create executor + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(adapter.compiled, seed=42) + + # Run with pipe output + compositor.run(output=pipe_out, duration=duration, recipe_executor=executor) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe with streaming compositor") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_recipe_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_recipe(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/streaming/recipe_executor.py b/streaming/recipe_executor.py new file mode 100644 index 0000000..678d9f6 --- /dev/null +++ b/streaming/recipe_executor.py @@ -0,0 +1,415 @@ +""" +Streaming recipe executor. + +Implements the full recipe logic for real-time streaming: +- Scans (state machines that evolve on beats) +- Process-pair template (two clips with sporadic effects, blended) +- Cycle-crossfade (dynamic composition cycling through video pairs) +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + + +@dataclass +class ScanState: + """State for a scan (beat-driven state machine).""" + value: Any = 0 + rng: random.Random = field(default_factory=random.Random) + + +class StreamingScans: + """ + Real-time scan executor. + + Scans are state machines that evolve on each beat. + They drive effect parameters like invert triggers, hue shifts, etc. + """ + + def __init__(self, seed: int = 42, n_sources: int = 4): + self.master_seed = seed + self.n_sources = n_sources + self.scans: Dict[str, ScanState] = {} + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self._init_scans() + + def _init_scans(self): + """Initialize all scans with their own RNG seeds.""" + scan_names = [] + + # Per-pair scans (dynamic based on n_sources) + for i in range(self.n_sources): + scan_names.extend([ + f"inv_a_{i}", f"inv_b_{i}", f"hue_a_{i}", f"hue_b_{i}", + f"ascii_a_{i}", f"ascii_b_{i}", f"pair_mix_{i}", f"pair_rot_{i}", + ]) + + # Global scans + scan_names.extend(["whole_spin", "ripple_gate", "cycle"]) + + for i, name in enumerate(scan_names): + rng = random.Random(self.master_seed + i) + self.scans[name] = ScanState(value=self._init_value(name), rng=rng) + + def _init_value(self, name: str) -> Any: + """Get initial value for a scan.""" + if name.startswith("inv_") or name.startswith("ascii_"): + return 0 # Counter for remaining beats + elif name.startswith("hue_"): + return {"rem": 0, "hue": 0} + elif name.startswith("pair_mix"): + return {"rem": 0, "opacity": 0.5} + elif name.startswith("pair_rot"): + pair_idx = int(name.split("_")[-1]) + rot_dir = 1 if pair_idx % 2 == 0 else -1 + return {"beat": 0, "clen": 25, "dir": rot_dir, "angle": 0} + elif name == "whole_spin": + return { + "phase": 0, # 0 = waiting, 1 = spinning + "beat": 0, # beats into current phase + "plen": 20, # beats in this phase + "dir": 1, # spin direction + "total_angle": 0.0, # cumulative angle after all spins + "spin_start_angle": 0.0, # angle when current spin started + "spin_start_time": 0.0, # time when current spin started + "spin_end_time": 0.0, # estimated time when spin ends + } + elif name == "ripple_gate": + return {"rem": 0, "cx": 0.5, "cy": 0.5} + elif name == "cycle": + return {"cycle": 0, "beat": 0, "clen": 60} + return 0 + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + # Estimate beat interval from last two beats + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + for name, state in self.scans.items(): + state.value = self._step_scan(name, state.value, state.rng, beat_interval) + + def _step_scan(self, name: str, value: Any, rng: random.Random, beat_interval: float = 0.5) -> Any: + """Step a scan forward by one beat.""" + + # Invert scan: 10% chance, lasts 1-5 beats + if name.startswith("inv_"): + if value > 0: + return value - 1 + elif rng.random() < 0.1: + return rng.randint(1, 5) + return 0 + + # Hue scan: 10% chance, random hue 30-330, lasts 1-5 beats + elif name.startswith("hue_"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "hue": value["hue"]} + elif rng.random() < 0.1: + return {"rem": rng.randint(1, 5), "hue": rng.uniform(30, 330)} + return {"rem": 0, "hue": 0} + + # ASCII scan: 5% chance, lasts 1-3 beats + elif name.startswith("ascii_"): + if value > 0: + return value - 1 + elif rng.random() < 0.05: + return rng.randint(1, 3) + return 0 + + # Pair mix: changes every 1-11 beats + elif name.startswith("pair_mix"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "opacity": value["opacity"]} + return {"rem": rng.randint(1, 11), "opacity": rng.choice([0, 0.5, 1.0])} + + # Pair rotation: full rotation every 20-30 beats + elif name.startswith("pair_rot"): + beat = value["beat"] + clen = value["clen"] + dir_ = value["dir"] + angle = value["angle"] + + if beat + 1 < clen: + new_angle = angle + dir_ * (360 / clen) + return {"beat": beat + 1, "clen": clen, "dir": dir_, "angle": new_angle} + else: + return {"beat": 0, "clen": rng.randint(20, 30), "dir": -dir_, "angle": angle} + + # Whole spin: sporadic 720 degree spins (cumulative - stays rotated) + elif name == "whole_spin": + phase = value["phase"] + beat = value["beat"] + plen = value["plen"] + dir_ = value["dir"] + total_angle = value.get("total_angle", 0.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", 0.0) + + if phase == 1: + # Currently spinning + if beat + 1 < plen: + return { + "phase": 1, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Spin complete - update total_angle with final spin + new_total = spin_start_angle + dir_ * 720.0 + return { + "phase": 0, "beat": 0, "plen": rng.randint(20, 40), "dir": dir_, + "total_angle": new_total, + "spin_start_angle": new_total, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time, + } + else: + # Waiting phase + if beat + 1 < plen: + return { + "phase": 0, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Start new spin + new_dir = 1 if rng.random() < 0.5 else -1 + new_plen = rng.randint(10, 25) + spin_duration = new_plen * beat_interval + return { + "phase": 1, "beat": 0, "plen": new_plen, "dir": new_dir, + "total_angle": total_angle, + "spin_start_angle": total_angle, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time + spin_duration, + } + + # Ripple gate: 5% chance, lasts 1-20 beats + elif name == "ripple_gate": + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "cx": value["cx"], "cy": value["cy"]} + elif rng.random() < 0.05: + return {"rem": rng.randint(1, 20), + "cx": rng.uniform(0.1, 0.9), + "cy": rng.uniform(0.1, 0.9)} + return {"rem": 0, "cx": 0.5, "cy": 0.5} + + # Cycle: track which video pair is active + elif name == "cycle": + beat = value["beat"] + clen = value["clen"] + cycle = value["cycle"] + + if beat + 1 < clen: + return {"cycle": cycle, "beat": beat + 1, "clen": clen} + else: + # Move to next pair, vary cycle length + return {"cycle": (cycle + 1) % 4, "beat": 0, + "clen": 40 + (self.beat_count * 7) % 41} + + return value + + def get_emit(self, name: str) -> float: + """Get emitted value for a scan.""" + value = self.scans[name].value + + if name.startswith("inv_") or name.startswith("ascii_"): + return 1.0 if value > 0 else 0.0 + + elif name.startswith("hue_"): + return value["hue"] if value["rem"] > 0 else 0.0 + + elif name.startswith("pair_mix"): + return value["opacity"] + + elif name.startswith("pair_rot"): + return value["angle"] + + elif name == "whole_spin": + # Smooth time-based interpolation during spin + phase = value.get("phase", 0) + if phase == 1: + # Currently spinning - interpolate based on time + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", spin_start_time + 1.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + dir_ = value.get("dir", 1) + + duration = spin_end_time - spin_start_time + if duration > 0: + progress = (self.current_time - spin_start_time) / duration + progress = max(0.0, min(1.0, progress)) # clamp to 0-1 + else: + progress = 1.0 + + return spin_start_angle + progress * 720.0 * dir_ + else: + # Not spinning - return cumulative angle + return value.get("total_angle", 0.0) + + elif name == "ripple_gate": + return 1.0 if value["rem"] > 0 else 0.0 + + elif name == "cycle": + return value + + return 0.0 + + +class StreamingRecipeExecutor: + """ + Executes a recipe in streaming mode. + + Implements: + - process-pair: two video clips with opposite effects, blended + - cycle-crossfade: dynamic cycling through video pairs + - Final effects: whole-spin rotation, ripple + """ + + def __init__(self, n_sources: int = 4, seed: int = 42): + self.n_sources = n_sources + self.scans = StreamingScans(seed, n_sources=n_sources) + self.last_beat_detected = False + self.current_time = 0.0 + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with current audio analysis.""" + self.current_time = t + self.scans.current_time = t + # Update scans on beat + if is_beat and not self.last_beat_detected: + self.scans.on_beat() + self.last_beat_detected = is_beat + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """ + Get effect parameters for a source clip. + + Args: + source_idx: Which video source (0-3) + clip: "a" or "b" (each source has two clips) + energy: Current audio energy (0-1) + """ + suffix = f"_{source_idx}" + + # Rotation ranges alternate + if source_idx % 2 == 0: + rot_range = [0, 45] if clip == "a" else [0, -45] + zoom_range = [1, 1.5] if clip == "a" else [1, 0.5] + else: + rot_range = [0, -45] if clip == "a" else [0, 45] + zoom_range = [1, 0.5] if clip == "a" else [1, 1.5] + + return { + "rotate_angle": rot_range[0] + energy * (rot_range[1] - rot_range[0]), + "zoom_amount": zoom_range[0] + energy * (zoom_range[1] - zoom_range[0]), + "invert_amount": self.scans.get_emit(f"inv_{clip}{suffix}"), + "hue_degrees": self.scans.get_emit(f"hue_{clip}{suffix}"), + "ascii_mix": 0, # Disabled - too slow without GPU + "ascii_char_size": 4 + energy * 28, # 4-32 + } + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair.""" + suffix = f"_{source_idx}" + return { + "blend_opacity": self.scans.get_emit(f"pair_mix{suffix}"), + "pair_rotation": self.scans.get_emit(f"pair_rot{suffix}"), + } + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade composition.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + weights = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + w = 0.9 + elif phase3 < clen * 2: + w = 0.9 - ((phase3 - clen) / clen) * 0.85 + else: + w = 0.05 + elif active == prev: + if phase3 < clen: + w = 0.05 + elif phase3 < clen * 2: + w = 0.05 + ((phase3 - clen) / clen) * 0.85 + else: + w = 0.9 + else: + w = 0.05 + + weights.append(w) + + # Normalize + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + + return weights + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + zooms = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + z = 1.0 + elif phase3 < clen * 2: + z = 1.0 + ((phase3 - clen) / clen) * 1.0 + else: + z = 0.1 + elif active == prev: + if phase3 < clen: + z = 3.0 # Start big + elif phase3 < clen * 2: + z = 3.0 - ((phase3 - clen) / clen) * 2.0 # Shrink to 1.0 + else: + z = 1.0 + else: + z = 0.1 + + zooms.append(z) + + return zooms + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (whole-spin, ripple).""" + ripple_gate = self.scans.get_emit("ripple_gate") + ripple_state = self.scans.scans["ripple_gate"].value + + return { + "whole_spin_angle": self.scans.get_emit("whole_spin"), + "ripple_amplitude": ripple_gate * (5 + energy * 45), # 5-50 + "ripple_cx": ripple_state["cx"], + "ripple_cy": ripple_state["cy"], + } diff --git a/streaming/sexp_executor.py b/streaming/sexp_executor.py new file mode 100644 index 0000000..0151853 --- /dev/null +++ b/streaming/sexp_executor.py @@ -0,0 +1,678 @@ +""" +Streaming S-expression executor. + +Executes compiled sexp recipes in real-time by: +- Evaluating scan expressions on each beat +- Resolving bindings to get effect parameter values +- Applying effects frame-by-frame +- Evaluating SLICE_ON Lambda for cycle crossfade +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sexp_interp import SexpInterpreter, eval_slice_on_lambda + + +@dataclass +class ScanState: + """Runtime state for a scan.""" + node_id: str + name: Optional[str] + value: Any + rng: random.Random + init_expr: dict + step_expr: dict + emit_expr: dict + + +class ExprEvaluator: + """ + Evaluates compiled expression ASTs. + + Expressions are dicts with: + - _expr: True (marks as expression) + - op: operation name + - args: list of arguments + - name: for 'var' ops + - keys: for 'dict' ops + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + + def eval(self, expr: Any, env: Dict[str, Any]) -> Any: + """Evaluate an expression in the given environment.""" + # Literal values + if not isinstance(expr, dict): + return expr + + # Check if it's an expression + if not expr.get('_expr'): + # It's a plain dict - return as-is + return expr + + op = expr.get('op') + args = expr.get('args', []) + + # Evaluate based on operation + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + raise KeyError(f"Unknown variable: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison ops + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic ops + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random ops + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.randint(lo, hi) + elif op == 'rand-range': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.uniform(lo, hi) + + # Logic ops + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown operation: {op}") + + +class SexpStreamingExecutor: + """ + Executes a compiled sexp recipe in streaming mode. + + Reads scan definitions, effect chains, and bindings from the + compiled recipe and executes them frame-by-frame. + """ + + def __init__(self, compiled_recipe, seed: int = 42): + self.recipe = compiled_recipe + self.master_seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # State (must be initialized before _init_scans) + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self.last_beat_detected = False + self.energy = 0.0 + + # Initialize scans + self.scans: Dict[str, ScanState] = {} + self.scan_outputs: Dict[str, Any] = {} # Current emit values by node_id + self._init_scans() + + # Initialize SLICE_ON interpreter + self.sexp_interp = SexpInterpreter(random.Random(seed)) + self._slice_on_lambda = None + self._slice_on_acc = None + self._slice_on_result = None # Last evaluation result {layers, compose, acc} + self._init_slice_on() + + def _init_slice_on(self): + """Initialize SLICE_ON Lambda for cycle crossfade.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + self._slice_on_lambda = config.get('fn') + init = config.get('init', {}) + self._slice_on_acc = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + # Evaluate initial state + self._eval_slice_on() + break + + def _eval_slice_on(self): + """Evaluate the SLICE_ON Lambda with current state.""" + if not self._slice_on_lambda: + return + + n = len(self._get_video_sources()) + videos = list(range(n)) # Placeholder video indices + + try: + result = eval_slice_on_lambda( + self._slice_on_lambda, + self._slice_on_acc, + self.beat_count, + 0.0, # start time (not used for weights) + 1.0, # end time (not used for weights) + videos, + self.sexp_interp, + ) + self._slice_on_result = result + # Update accumulator for next beat + if 'acc' in result: + self._slice_on_acc = result['acc'] + except Exception as e: + import sys + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _init_scans(self): + """Initialize all scan nodes from the recipe.""" + seed_offset = 0 + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + node_id = node['id'] + config = node.get('config', {}) + + # Create RNG with unique seed + scan_seed = config.get('seed', self.master_seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + evaluator = ExprEvaluator(rng) + init_value = evaluator.eval(init_expr, {}) + + self.scans[node_id] = ScanState( + node_id=node_id, + name=node.get('name'), + value=init_value, + rng=rng, + init_expr=init_expr, + step_expr=config.get('step_expr', {}), + emit_expr=config.get('emit_expr', {}), + ) + + # Compute initial emit + self._update_emit(node_id) + + def _update_emit(self, node_id: str): + """Update the emit value for a scan.""" + scan = self.scans[node_id] + evaluator = ExprEvaluator(scan.rng) + + # Build environment from current state + env = self._build_scan_env(scan) + + # Evaluate emit expression + emit_value = evaluator.eval(scan.emit_expr, env) + self.scan_outputs[node_id] = emit_value + + def _build_scan_env(self, scan: ScanState) -> Dict[str, Any]: + """Build environment for scan expression evaluation.""" + env = {} + + # Add state variables + if isinstance(scan.value, dict): + env.update(scan.value) + else: + env['acc'] = scan.value + + # Add beat count + env['beat_count'] = self.beat_count + env['time'] = self.current_time + + return env + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + + # Estimate beat interval + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + # Step each scan + for node_id, scan in self.scans.items(): + evaluator = ExprEvaluator(scan.rng) + env = self._build_scan_env(scan) + + # Evaluate step expression + new_value = evaluator.eval(scan.step_expr, env) + scan.value = new_value + + # Update emit + self._update_emit(node_id) + + # Step the cycle state + self._step_cycle() + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with audio analysis.""" + self.current_time = t + self.energy = energy + + # Update scans on beat (edge detection) + if is_beat and not self.last_beat_detected: + self.on_beat() + self.last_beat_detected = is_beat + + def resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to get the current value.""" + if not isinstance(binding, dict) or not binding.get('_binding'): + return binding + + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get the raw value + if source_id in self.scan_outputs: + value = self.scan_outputs[source_id] + else: + # Might be an analyzer reference - use energy as fallback + value = self.energy + + # Extract feature if value is a dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def get_effect_params(self, effect_node: dict) -> Dict[str, Any]: + """Get resolved parameters for an effect node.""" + config = effect_node.get('config', {}) + params = {} + + for key, value in config.items(): + # Skip internal fields + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', 'analysis_refs'): + continue + + # Resolve bindings + params[key] = self.resolve_binding(value) + + return params + + def get_scan_value(self, name: str) -> Any: + """Get scan output by name.""" + for node_id, scan in self.scans.items(): + if scan.name == name: + return self.scan_outputs.get(node_id) + return None + + def get_all_scan_values(self) -> Dict[str, Any]: + """Get all named scan outputs.""" + result = {} + for node_id, scan in self.scans.items(): + if scan.name: + result[scan.name] = self.scan_outputs.get(node_id) + return result + + # === Compositor interface methods === + + def _get_video_sources(self) -> List[str]: + """Get list of video source node IDs.""" + sources = [] + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + sources.append(node['id']) + # Filter to video only (exclude audio - last one is usually audio) + # Look at file extensions in the paths + return sources[:-1] if len(sources) > 1 else sources + + def _trace_effect_chain(self, start_id: str, stop_at_blend: bool = True) -> List[dict]: + """Trace effect chain from a node, returning effects in order.""" + chain = [] + current_id = start_id + + for _ in range(20): # Max depth + # Find node that uses current as input + next_node = None + for node in self.recipe.nodes: + if current_id in node.get('inputs', []): + if node.get('type') == 'EFFECT': + effect_type = node.get('config', {}).get('effect') + chain.append(node) + if stop_at_blend and effect_type == 'blend': + return chain + next_node = node + break + elif node.get('type') == 'SEGMENT': + next_node = node + break + + if next_node is None: + break + current_id = next_node['id'] + + return chain + + def _find_clip_chains(self, source_idx: int) -> tuple: + """Find effect chains for clip A and B from a source.""" + sources = self._get_video_sources() + if source_idx >= len(sources): + return [], [] + + source_id = sources[source_idx] + + # Find SEGMENT node + segment_id = None + for node in self.recipe.nodes: + if node.get('type') == 'SEGMENT' and source_id in node.get('inputs', []): + segment_id = node['id'] + break + + if not segment_id: + return [], [] + + # Find the two effect chains from segment (clip A and clip B) + chains = [] + for node in self.recipe.nodes: + if segment_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + chain = self._trace_effect_chain(segment_id) + # Get chain starting from this specific branch + branch_chain = [node] + current = node['id'] + for _ in range(10): + found = False + for n in self.recipe.nodes: + if current in n.get('inputs', []) and n.get('type') == 'EFFECT': + branch_chain.append(n) + if n.get('config', {}).get('effect') == 'blend': + break + current = n['id'] + found = True + break + if not found: + break + chains.append(branch_chain) + + # Return first two chains as A and B + chain_a = chains[0] if len(chains) > 0 else [] + chain_b = chains[1] if len(chains) > 1 else [] + return chain_a, chain_b + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """Get effect parameters for a source clip (compositor interface).""" + # Get the correct chain for this clip + chain_a, chain_b = self._find_clip_chains(source_idx) + chain = chain_a if clip == 'a' else chain_b + + # Default params + params = { + "rotate_angle": 0, + "zoom_amount": 1.0, + "invert_amount": 0, + "hue_degrees": 0, + "ascii_mix": 0, + "ascii_char_size": 8, + } + + # Resolve from effects in chain + for eff in chain: + config = eff.get('config', {}) + effect_type = config.get('effect') + + if effect_type == 'rotate': + angle_binding = config.get('angle') + if angle_binding: + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + # Bound to analyzer - use energy with range + range_map = angle_binding.get('range') + if range_map: + lo, hi = range_map + params["rotate_angle"] = lo + energy * (hi - lo) + else: + params["rotate_angle"] = self.resolve_binding(angle_binding) + else: + params["rotate_angle"] = angle_binding if isinstance(angle_binding, (int, float)) else 0 + + elif effect_type == 'zoom': + amount_binding = config.get('amount') + if amount_binding: + if isinstance(amount_binding, dict) and amount_binding.get('_binding'): + range_map = amount_binding.get('range') + if range_map: + lo, hi = range_map + params["zoom_amount"] = lo + energy * (hi - lo) + else: + params["zoom_amount"] = self.resolve_binding(amount_binding) + else: + params["zoom_amount"] = amount_binding if isinstance(amount_binding, (int, float)) else 1.0 + + elif effect_type == 'invert': + amount_binding = config.get('amount') + if amount_binding: + val = self.resolve_binding(amount_binding) + params["invert_amount"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'hue_shift': + deg_binding = config.get('degrees') + if deg_binding: + val = self.resolve_binding(deg_binding) + params["hue_degrees"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'ascii_art': + mix_binding = config.get('mix') + if mix_binding: + val = self.resolve_binding(mix_binding) + params["ascii_mix"] = val if isinstance(val, (int, float)) else 0 + size_binding = config.get('char_size') + if size_binding: + if isinstance(size_binding, dict) and size_binding.get('_binding'): + range_map = size_binding.get('range') + if range_map: + lo, hi = range_map + params["ascii_char_size"] = lo + energy * (hi - lo) + + return params + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair (compositor interface).""" + params = { + "blend_opacity": 0.5, + "pair_rotation": 0, + } + + # Find the blend node for this source + chain_a, _ = self._find_clip_chains(source_idx) + + # The last effect in chain_a should be the blend + blend_node = None + for eff in reversed(chain_a): + if eff.get('config', {}).get('effect') == 'blend': + blend_node = eff + break + + if blend_node: + config = blend_node.get('config', {}) + opacity_binding = config.get('opacity') + if opacity_binding: + val = self.resolve_binding(opacity_binding) + if isinstance(val, (int, float)): + params["blend_opacity"] = val + + # Find rotate after blend (pair rotation) + blend_id = blend_node['id'] + for node in self.recipe.nodes: + if blend_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + if node.get('config', {}).get('effect') == 'rotate': + angle_binding = node.get('config', {}).get('angle') + if angle_binding: + val = self.resolve_binding(angle_binding) + if isinstance(val, (int, float)): + params["pair_rotation"] = val + break + + return params + + def _get_cycle_state(self) -> dict: + """Get current cycle state from SLICE_ON or internal tracking.""" + if not hasattr(self, '_cycle_state'): + # Initialize from SLICE_ON node + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + init = node.get('config', {}).get('init', {}) + self._cycle_state = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + break + else: + self._cycle_state = {'cycle': 0, 'beat': 0, 'clen': 60} + + return self._cycle_state + + def _step_cycle(self): + """Step the cycle state forward on beat by evaluating SLICE_ON Lambda.""" + # Use interpreter to evaluate the Lambda + self._eval_slice_on() + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get weights from interpreted result + if self._slice_on_result: + compose = self._slice_on_result.get('compose', {}) + weights = compose.get('weights', []) + if weights and len(weights) == n: + # Normalize + total = sum(weights) + if total > 0: + return [w / total for w in weights] + + # Fallback: equal weights + return [1.0 / n] * n + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get zooms from interpreted result (layers -> effects -> zoom amount) + if self._slice_on_result: + layers = self._slice_on_result.get('layers', []) + if layers and len(layers) == n: + zooms = [] + for layer in layers: + effects = layer.get('effects', []) + zoom_amt = 1.0 + for eff in effects: + if eff.get('effect') == 'zoom' or (hasattr(eff.get('effect'), 'name') and eff.get('effect').name == 'zoom'): + zoom_amt = eff.get('amount', 1.0) + break + zooms.append(zoom_amt) + return zooms + + # Fallback + return [1.0] * n + + def _get_final_rotate_scan_id(self) -> str: + """Find the scan ID that drives the final rotation (after SLICE_ON).""" + if hasattr(self, '_final_rotate_scan_id'): + return self._final_rotate_scan_id + + # Find SLICE_ON node index + slice_on_idx = None + for i, node in enumerate(self.recipe.nodes): + if node.get('type') == 'SLICE_ON': + slice_on_idx = i + break + + # Find rotate effect after SLICE_ON + if slice_on_idx is not None: + for node in self.recipe.nodes[slice_on_idx + 1:]: + if node.get('type') == 'EFFECT': + config = node.get('config', {}) + if config.get('effect') == 'rotate': + angle_binding = config.get('angle', {}) + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + self._final_rotate_scan_id = angle_binding.get('source') + return self._final_rotate_scan_id + + self._final_rotate_scan_id = None + return None + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (compositor interface).""" + # Get named scans + scan_values = self.get_all_scan_values() + + # Whole spin - get from the specific scan bound to final rotate effect + whole_spin = 0 + final_rotate_scan_id = self._get_final_rotate_scan_id() + if final_rotate_scan_id and final_rotate_scan_id in self.scan_outputs: + val = self.scan_outputs[final_rotate_scan_id] + if isinstance(val, dict) and 'angle' in val: + whole_spin = val['angle'] + elif isinstance(val, (int, float)): + whole_spin = val + + # Ripple + ripple_gate = scan_values.get('ripple-gate', 0) + ripple_cx = scan_values.get('ripple-cx', 0.5) + ripple_cy = scan_values.get('ripple-cy', 0.5) + + if isinstance(ripple_gate, dict): + ripple_gate = ripple_gate.get('gate', 0) if 'gate' in ripple_gate else 1 + + return { + "whole_spin_angle": whole_spin, + "ripple_amplitude": ripple_gate * (5 + energy * 45), + "ripple_cx": ripple_cx if isinstance(ripple_cx, (int, float)) else 0.5, + "ripple_cy": ripple_cy if isinstance(ripple_cy, (int, float)) else 0.5, + } diff --git a/streaming/sexp_interp.py b/streaming/sexp_interp.py new file mode 100644 index 0000000..e3433b2 --- /dev/null +++ b/streaming/sexp_interp.py @@ -0,0 +1,376 @@ +""" +S-expression interpreter for streaming execution. + +Evaluates sexp expressions including: +- let bindings +- lambda definitions and calls +- Arithmetic, comparison, logic operators +- dict/list operations +- Random number generation +""" + +import random +from typing import Any, Dict, List, Callable +from dataclasses import dataclass + + +@dataclass +class Lambda: + """Runtime lambda value.""" + params: List[str] + body: Any + closure: Dict[str, Any] + + +class Symbol: + """Symbol reference.""" + def __init__(self, name: str): + self.name = name + + def __repr__(self): + return f"Symbol({self.name})" + + +class SexpInterpreter: + """ + Interprets S-expressions in real-time. + + Handles the full sexp language used in recipes. + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + self.globals: Dict[str, Any] = {} + + def eval(self, expr: Any, env: Dict[str, Any] = None) -> Any: + """Evaluate an expression in the given environment.""" + if env is None: + env = {} + + # Literals + if isinstance(expr, (int, float, str, bool)) or expr is None: + return expr + + # Symbol lookup + if isinstance(expr, Symbol) or (hasattr(expr, 'name') and hasattr(expr, '__class__') and expr.__class__.__name__ == 'Symbol'): + name = expr.name if hasattr(expr, 'name') else str(expr) + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined symbol: {name}") + + # Compiled expression dict (from compiler) + if isinstance(expr, dict): + if expr.get('_expr'): + return self._eval_compiled_expr(expr, env) + # Plain dict - evaluate values that might be expressions + result = {} + for k, v in expr.items(): + # Some keys should keep Symbol values as strings (effect names, modes) + if k in ('effect', 'mode') and hasattr(v, 'name'): + result[k] = v.name + else: + result[k] = self.eval(v, env) + return result + + # List expression (sexp) + if isinstance(expr, (list, tuple)) and len(expr) > 0: + return self._eval_list(expr, env) + + # Empty list + if isinstance(expr, (list, tuple)): + return [] + + return expr + + def _eval_compiled_expr(self, expr: dict, env: Dict[str, Any]) -> Any: + """Evaluate a compiled expression dict.""" + op = expr.get('op') + args = expr.get('args', []) + + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + return self.rng.randint(self.eval(args[0], env), self.eval(args[1], env)) + elif op == 'rand-range': + return self.rng.uniform(self.eval(args[0], env), self.eval(args[1], env)) + + # Logic + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown op: {op}") + + def _eval_list(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate a list expression (sexp form).""" + if len(expr) == 0: + return [] + + head = expr[0] + + # Get head name + if isinstance(head, Symbol) or (hasattr(head, 'name') and hasattr(head, '__class__')): + head_name = head.name if hasattr(head, 'name') else str(head) + elif isinstance(head, str): + head_name = head + else: + # Not a symbol - check if it's a data list or function call + if isinstance(head, dict): + # List of dicts - evaluate each element as data + return [self.eval(item, env) for item in expr] + # Otherwise evaluate as function call + fn = self.eval(head, env) + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + # Special forms + if head_name == 'let': + return self._eval_let(expr, env) + elif head_name in ('lambda', 'fn'): + return self._eval_lambda(expr, env) + elif head_name == 'if': + return self._eval_if(expr, env) + elif head_name == 'dict': + return self._eval_dict(expr, env) + elif head_name == 'get': + obj = self.eval(expr[1], env) + key = self.eval(expr[2], env) if len(expr) > 2 else expr[2] + if isinstance(key, str): + return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None) + return obj[key] + elif head_name == 'len': + return len(self.eval(expr[1], env)) + elif head_name == 'range': + start = self.eval(expr[1], env) + end = self.eval(expr[2], env) if len(expr) > 2 else start + if len(expr) == 2: + return list(range(end)) + return list(range(start, end)) + elif head_name == 'map': + fn = self.eval(expr[1], env) + lst = self.eval(expr[2], env) + return [self._call(fn, [x], env) for x in lst] + elif head_name == 'mod': + return self.eval(expr[1], env) % self.eval(expr[2], env) + + # Arithmetic + elif head_name == '+': + return self.eval(expr[1], env) + self.eval(expr[2], env) + elif head_name == '-': + if len(expr) == 2: + return -self.eval(expr[1], env) + return self.eval(expr[1], env) - self.eval(expr[2], env) + elif head_name == '*': + return self.eval(expr[1], env) * self.eval(expr[2], env) + elif head_name == '/': + return self.eval(expr[1], env) / self.eval(expr[2], env) + + # Comparison + elif head_name == '<': + return self.eval(expr[1], env) < self.eval(expr[2], env) + elif head_name == '>': + return self.eval(expr[1], env) > self.eval(expr[2], env) + elif head_name == '<=': + return self.eval(expr[1], env) <= self.eval(expr[2], env) + elif head_name == '>=': + return self.eval(expr[1], env) >= self.eval(expr[2], env) + elif head_name == '=': + return self.eval(expr[1], env) == self.eval(expr[2], env) + + # Logic + elif head_name == 'and': + return all(self.eval(a, env) for a in expr[1:]) + elif head_name == 'or': + return any(self.eval(a, env) for a in expr[1:]) + elif head_name == 'not': + return not self.eval(expr[1], env) + + # Function call + else: + fn = env.get(head_name) or self.globals.get(head_name) + if fn is None: + raise NameError(f"Undefined function: {head_name}") + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + def _eval_let(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (let [bindings...] body).""" + bindings = expr[1] + body = expr[2] + + # Create new environment with bindings + new_env = dict(env) + + # Process bindings in pairs + i = 0 + while i < len(bindings): + name = bindings[i] + if isinstance(name, Symbol) or hasattr(name, 'name'): + name = name.name if hasattr(name, 'name') else str(name) + value = self.eval(bindings[i + 1], new_env) + new_env[name] = value + i += 2 + + return self.eval(body, new_env) + + def _eval_lambda(self, expr: list, env: Dict[str, Any]) -> Lambda: + """Evaluate (lambda [params] body).""" + params_expr = expr[1] + body = expr[2] + + # Extract parameter names + params = [] + for p in params_expr: + if isinstance(p, Symbol) or hasattr(p, 'name'): + params.append(p.name if hasattr(p, 'name') else str(p)) + else: + params.append(str(p)) + + return Lambda(params=params, body=body, closure=dict(env)) + + def _eval_if(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + def _eval_dict(self, expr: list, env: Dict[str, Any]) -> dict: + """Evaluate (dict :key val ...).""" + result = {} + i = 1 + while i < len(expr): + key = expr[i] + # Handle keyword syntax (:key) and Keyword objects + if hasattr(key, 'name'): + key = key.name + elif hasattr(key, '__class__') and key.__class__.__name__ == 'Keyword': + key = str(key).lstrip(':') + elif isinstance(key, str) and key.startswith(':'): + key = key[1:] + value = self.eval(expr[i + 1], env) + result[key] = value + i += 2 + return result + + def _call(self, fn: Any, args: List[Any], env: Dict[str, Any]) -> Any: + """Call a function with arguments.""" + if isinstance(fn, Lambda): + # Our own Lambda type + call_env = dict(fn.closure) + for param, arg in zip(fn.params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif hasattr(fn, 'params') and hasattr(fn, 'body'): + # Lambda from parser (artdag.sexp.parser.Lambda) + call_env = dict(env) + if hasattr(fn, 'closure') and fn.closure: + call_env.update(fn.closure) + # Get param names + params = [] + for p in fn.params: + if hasattr(p, 'name'): + params.append(p.name) + else: + params.append(str(p)) + for param, arg in zip(params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif callable(fn): + return fn(*args) + else: + raise TypeError(f"Not callable: {type(fn).__name__}") + + +def eval_slice_on_lambda(lambda_obj, acc: dict, i: int, start: float, end: float, + videos: list, interp: SexpInterpreter = None) -> dict: + """ + Evaluate a SLICE_ON lambda function. + + Args: + lambda_obj: The Lambda object from the compiled recipe + acc: Current accumulator state + i: Beat index + start: Slice start time + end: Slice end time + videos: List of video inputs + interp: Interpreter to use + + Returns: + Dict with 'layers', 'compose', 'acc' keys + """ + if interp is None: + interp = SexpInterpreter() + + # Set up global 'videos' for (len videos) to work + interp.globals['videos'] = videos + + # Build initial environment with lambda parameters + env = dict(lambda_obj.closure) if hasattr(lambda_obj, 'closure') and lambda_obj.closure else {} + env['videos'] = videos + + # Call the lambda + result = interp._call(lambda_obj, [acc, i, start, end], env) + + return result diff --git a/streaming/sexp_to_cuda.py b/streaming/sexp_to_cuda.py new file mode 100644 index 0000000..e4051bd --- /dev/null +++ b/streaming/sexp_to_cuda.py @@ -0,0 +1,706 @@ +""" +Sexp to CUDA Kernel Compiler. + +Compiles sexp frame pipelines to fused CUDA kernels for maximum performance. +Instead of interpreting sexp and launching 10+ kernels per frame, +generates a single kernel that does everything in one pass. +""" + +import cupy as cp +import numpy as np +from typing import Dict, List, Any, Optional, Tuple +import hashlib +import sys +import logging + +logger = logging.getLogger(__name__) + +# Kernel cache +_COMPILED_KERNELS: Dict[str, Any] = {} + + +def compile_frame_pipeline(effects: List[dict], width: int, height: int) -> callable: + """ + Compile a list of effects to a fused CUDA kernel. + + Args: + effects: List of effect dicts like: + [{'op': 'rotate', 'angle': 45.0}, + {'op': 'blend', 'alpha': 0.5, 'src2': }, + {'op': 'hue_shift', 'degrees': 90.0}, + {'op': 'ripple', 'amplitude': 10.0, 'frequency': 8.0, ...}] + width, height: Frame dimensions + + Returns: + Callable that takes input frame and returns output frame + """ + + # Generate cache key + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects]) + cache_key = f"{width}x{height}_{hashlib.md5(ops_key.encode()).hexdigest()}" + + if cache_key in _COMPILED_KERNELS: + return _COMPILED_KERNELS[cache_key] + + # Generate fused kernel code + kernel_code = _generate_fused_kernel(effects, width, height) + + # Compile kernel + kernel = cp.RawKernel(kernel_code, 'fused_pipeline') + + # Create wrapper function + def run_pipeline(frame: cp.ndarray, **dynamic_params) -> cp.ndarray: + """Run the compiled pipeline on a frame.""" + if frame.dtype != cp.uint8: + frame = cp.clip(frame, 0, 255).astype(cp.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = cp.ascontiguousarray(frame) + + output = cp.zeros_like(frame) + + block = (16, 16) + grid = ((width + 15) // 16, (height + 15) // 16) + + # Build parameter array + params = _build_params(effects, dynamic_params) + + kernel(grid, block, (frame, output, width, height, params)) + + return output + + _COMPILED_KERNELS[cache_key] = run_pipeline + return run_pipeline + + +def _generate_fused_kernel(effects: List[dict], width: int, height: int) -> str: + """Generate CUDA kernel code for fused effects pipeline.""" + + # Validate all ops are supported + SUPPORTED_OPS = {'rotate', 'zoom', 'ripple', 'invert', 'hue_shift', 'brightness'} + for effect in effects: + op = effect.get('op') + if op not in SUPPORTED_OPS: + raise ValueError(f"Unsupported CUDA kernel operation: '{op}'. Supported ops: {', '.join(sorted(SUPPORTED_OPS))}. Note: 'resize' must be handled separately before the fused kernel.") + + # Build the kernel + code = r''' +extern "C" __global__ +void fused_pipeline( + const unsigned char* src, + unsigned char* dst, + int width, int height, + const float* params +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Start with source coordinates + float src_x = (float)x; + float src_y = (float)y; + float cx = width / 2.0f; + float cy = height / 2.0f; + + // Track accumulated transforms + float total_cos = 1.0f, total_sin = 0.0f; // rotation + float total_zoom = 1.0f; // zoom + float ripple_dx = 0.0f, ripple_dy = 0.0f; // ripple displacement + + int param_idx = 0; + +''' + + # Add effect-specific code + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'rotate': + code += f''' + // Rotate {i} + {{ + float angle = params[param_idx++] * 3.14159265f / 180.0f; + float c = cosf(angle); + float s = sinf(angle); + // Compose with existing rotation + float nc = total_cos * c - total_sin * s; + float ns = total_cos * s + total_sin * c; + total_cos = nc; + total_sin = ns; + }} +''' + elif op == 'zoom': + code += f''' + // Zoom {i} + {{ + float zoom = params[param_idx++]; + total_zoom *= zoom; + }} +''' + elif op == 'ripple': + code += f''' + // Ripple {i} - matching original formula: sin(dist/freq - phase) * exp(-dist*decay/maxdim) + {{ + float amplitude = params[param_idx++]; + float frequency = params[param_idx++]; + float decay = params[param_idx++]; + float phase = params[param_idx++]; + float rcx = params[param_idx++]; + float rcy = params[param_idx++]; + + float rdx = src_x - rcx; + float rdy = src_y - rcy; + float dist = sqrtf(rdx * rdx + rdy * rdy); + float max_dim = (float)(width > height ? width : height); + + // Original formula: sin(dist / frequency - phase) * exp(-dist * decay / max_dim) + float wave = sinf(dist / frequency - phase); + float amp = amplitude * expf(-dist * decay / max_dim); + + if (dist > 0.001f) {{ + ripple_dx += rdx / dist * wave * amp; + ripple_dy += rdy / dist * wave * amp; + }} + }} +''' + + # Apply all geometric transforms at once + code += ''' + // Apply accumulated geometric transforms + { + // Translate to center + float dx = src_x - cx; + float dy = src_y - cy; + + // Apply rotation + float rx = total_cos * dx + total_sin * dy; + float ry = -total_sin * dx + total_cos * dy; + + // Apply zoom (inverse for sampling) + rx /= total_zoom; + ry /= total_zoom; + + // Translate back and apply ripple + src_x = rx + cx - ripple_dx; + src_y = ry + cy - ripple_dy; + } + + // Sample source with bilinear interpolation + float r, g, b; + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + r = g = b = 0; + } else { + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + int idx00 = (y0 * width + x0) * 3; + int idx10 = (y0 * width + x0 + 1) * 3; + int idx01 = ((y0 + 1) * width + x0) * 3; + int idx11 = ((y0 + 1) * width + x0 + 1) * 3; + + #define BILERP(c) \\ + (src[idx00 + c] * (1-fx) * (1-fy) + \\ + src[idx10 + c] * fx * (1-fy) + \\ + src[idx01 + c] * (1-fx) * fy + \\ + src[idx11 + c] * fx * fy) + + r = BILERP(0); + g = BILERP(1); + b = BILERP(2); + } + +''' + + # Add color transforms + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'invert': + code += f''' + // Invert {i} + {{ + float amount = params[param_idx++]; + if (amount > 0.5f) {{ + r = 255.0f - r; + g = 255.0f - g; + b = 255.0f - b; + }} + }} +''' + elif op == 'hue_shift': + code += f''' + // Hue shift {i} + {{ + float shift = params[param_idx++]; + if (fabsf(shift) > 0.01f) {{ + // RGB to HSV + float rf = r / 255.0f; + float gf = g / 255.0f; + float bf = b / 255.0f; + + float max_c = fmaxf(rf, fmaxf(gf, bf)); + float min_c = fminf(rf, fminf(gf, bf)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) {{ + s = delta / max_c; + if (rf >= max_c) h = (gf - bf) / delta; + else if (gf >= max_c) h = 2.0f + (bf - rf) / delta; + else h = 4.0f + (rf - gf) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + }} + + h = fmodf(h + shift + 360.0f, 360.0f); + + // HSV to RGB + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) {{ r2 = c; g2 = x_val; b2 = 0; }} + else if (h < 120) {{ r2 = x_val; g2 = c; b2 = 0; }} + else if (h < 180) {{ r2 = 0; g2 = c; b2 = x_val; }} + else if (h < 240) {{ r2 = 0; g2 = x_val; b2 = c; }} + else if (h < 300) {{ r2 = x_val; g2 = 0; b2 = c; }} + else {{ r2 = c; g2 = 0; b2 = x_val; }} + + r = (r2 + m) * 255.0f; + g = (g2 + m) * 255.0f; + b = (b2 + m) * 255.0f; + }} + }} +''' + elif op == 'brightness': + code += f''' + // Brightness {i} + {{ + float factor = params[param_idx++]; + r *= factor; + g *= factor; + b *= factor; + }} +''' + + # Write output + code += ''' + // Write output + int dst_idx = (y * width + x) * 3; + dst[dst_idx] = (unsigned char)fminf(255.0f, fmaxf(0.0f, r)); + dst[dst_idx + 1] = (unsigned char)fminf(255.0f, fmaxf(0.0f, g)); + dst[dst_idx + 2] = (unsigned char)fminf(255.0f, fmaxf(0.0f, b)); +} +''' + + return code + + +_BUILD_PARAMS_COUNT = 0 + +def _build_params(effects: List[dict], dynamic_params: dict) -> cp.ndarray: + """Build parameter array for kernel. + + IMPORTANT: Parameters must be built in the same order the kernel consumes them: + 1. First all geometric transforms (rotate, zoom, ripple) in list order + 2. Then all color transforms (invert, hue_shift, brightness) in list order + """ + global _BUILD_PARAMS_COUNT + _BUILD_PARAMS_COUNT += 1 + + # ALWAYS log first few calls - use WARNING to ensure visibility in Celery logs + if _BUILD_PARAMS_COUNT <= 3: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] effects={[e['op'] for e in effects]}") + + params = [] + + # First pass: geometric transforms (matches kernel's first loop) + for effect in effects: + op = effect['op'] + + if op == 'rotate': + params.append(float(dynamic_params.get('rotate_angle', effect.get('angle', 0)))) + elif op == 'zoom': + params.append(float(dynamic_params.get('zoom_amount', effect.get('amount', 1.0)))) + elif op == 'ripple': + amp = float(dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10))) + freq = float(effect.get('frequency', 8)) + decay = float(effect.get('decay', 2)) + phase = float(dynamic_params.get('ripple_phase', effect.get('phase', 0))) + cx = float(effect.get('center_x', 960)) + cy = float(effect.get('center_y', 540)) + params.extend([amp, freq, decay, phase, cx, cy]) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] ripple amp={amp} freq={freq} decay={decay} phase={phase:.2f} cx={cx} cy={cy}") + + # Second pass: color transforms (matches kernel's second loop) + for effect in effects: + op = effect['op'] + + if op == 'invert': + amt = float(effect.get('amount', 0)) + params.append(amt) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] invert amount={amt}") + elif op == 'hue_shift': + deg = float(effect.get('degrees', 0)) + params.append(deg) + if _BUILD_PARAMS_COUNT <= 10 or _BUILD_PARAMS_COUNT % 500 == 0: + logger.warning(f"[BUILD_PARAMS #{_BUILD_PARAMS_COUNT}] hue_shift degrees={deg}") + elif op == 'brightness': + params.append(float(effect.get('factor', 1.0))) + + return cp.array(params, dtype=cp.float32) + + +def compile_autonomous_pipeline(effects: List[dict], width: int, height: int, + dynamic_expressions: dict = None) -> callable: + """ + Compile a fully autonomous pipeline that computes ALL parameters on GPU. + + This eliminates Python from the hot path - the kernel computes time-based + parameters (sin, cos, etc.) directly on GPU. + + Args: + effects: List of effect dicts + width, height: Frame dimensions + dynamic_expressions: Dict mapping param names to expressions, e.g.: + {'rotate_angle': 't * 30', + 'ripple_phase': 't * 2', + 'brightness_factor': '0.8 + 0.4 * sin(t * 2)'} + + Returns: + Callable that takes (frame, frame_num, fps) and returns output frame + """ + if dynamic_expressions is None: + dynamic_expressions = {} + + # Generate cache key + ops_key = str([(e['op'], {k:v for k,v in e.items() if k != 'src2'}) for e in effects]) + expr_key = str(sorted(dynamic_expressions.items())) + cache_key = f"auto_{width}x{height}_{hashlib.md5((ops_key + expr_key).encode()).hexdigest()}" + + if cache_key in _COMPILED_KERNELS: + return _COMPILED_KERNELS[cache_key] + + # Generate autonomous kernel code + kernel_code = _generate_autonomous_kernel(effects, width, height, dynamic_expressions) + + # Compile kernel + kernel = cp.RawKernel(kernel_code, 'autonomous_pipeline') + + # Create wrapper function + def run_autonomous(frame: cp.ndarray, frame_num: int, fps: float = 30.0) -> cp.ndarray: + """Run the autonomous pipeline - no Python in the hot path!""" + if frame.dtype != cp.uint8: + frame = cp.clip(frame, 0, 255).astype(cp.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = cp.ascontiguousarray(frame) + + output = cp.zeros_like(frame) + + block = (16, 16) + grid = ((width + 15) // 16, (height + 15) // 16) + + # Only pass frame_num and fps - kernel computes everything else! + t = float(frame_num) / float(fps) + kernel(grid, block, (frame, output, np.int32(width), np.int32(height), + np.float32(t), np.int32(frame_num))) + + return output + + _COMPILED_KERNELS[cache_key] = run_autonomous + return run_autonomous + + +def _generate_autonomous_kernel(effects: List[dict], width: int, height: int, + dynamic_expressions: dict) -> str: + """Generate CUDA kernel that computes everything autonomously.""" + + # Map simple expressions to CUDA code + def expr_to_cuda(expr: str) -> str: + """Convert simple expression to CUDA.""" + expr = expr.replace('sin(', 'sinf(') + expr = expr.replace('cos(', 'cosf(') + expr = expr.replace('abs(', 'fabsf(') + return expr + + code = r''' +extern "C" __global__ +void autonomous_pipeline( + const unsigned char* src, + unsigned char* dst, + int width, int height, + float t, int frame_num +) { + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // Compute dynamic parameters from time (ALL ON GPU!) +''' + + # Add dynamic parameter calculations + rotate_expr = dynamic_expressions.get('rotate_angle', '0.0f') + ripple_phase_expr = dynamic_expressions.get('ripple_phase', '0.0f') + brightness_expr = dynamic_expressions.get('brightness_factor', '1.0f') + zoom_expr = dynamic_expressions.get('zoom_amount', '1.0f') + + code += f''' + float rotate_angle = {expr_to_cuda(rotate_expr)}; + float ripple_phase = {expr_to_cuda(ripple_phase_expr)}; + float brightness_factor = {expr_to_cuda(brightness_expr)}; + float zoom_amount = {expr_to_cuda(zoom_expr)}; + + // Start with source coordinates + float src_x = (float)x; + float src_y = (float)y; + float cx = width / 2.0f; + float cy = height / 2.0f; + + // Accumulated transforms + float total_cos = 1.0f, total_sin = 0.0f; + float total_zoom = 1.0f; + float ripple_dx = 0.0f, ripple_dy = 0.0f; + +''' + + # Add effect-specific code + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'rotate': + code += f''' + // Rotate {i} + {{ + float angle = rotate_angle * 3.14159265f / 180.0f; + float c = cosf(angle); + float s = sinf(angle); + float nc = total_cos * c - total_sin * s; + float ns = total_cos * s + total_sin * c; + total_cos = nc; + total_sin = ns; + }} +''' + elif op == 'zoom': + code += f''' + // Zoom {i} + {{ + total_zoom *= zoom_amount; + }} +''' + elif op == 'ripple': + amp = float(effect.get('amplitude', 10)) + freq = float(effect.get('frequency', 8)) + decay = float(effect.get('decay', 2)) + rcx = float(effect.get('center_x', width/2)) + rcy = float(effect.get('center_y', height/2)) + code += f''' + // Ripple {i} + {{ + float amplitude = {amp:.1f}f; + float frequency = {freq:.1f}f; + float decay_val = {decay:.1f}f; + float rcx = {rcx:.1f}f; + float rcy = {rcy:.1f}f; + + float rdx = src_x - rcx; + float rdy = src_y - rcy; + float dist = sqrtf(rdx * rdx + rdy * rdy); + + float wave = sinf(dist * frequency * 0.1f + ripple_phase); + float amp = amplitude * expf(-dist * decay_val * 0.01f); + + if (dist > 0.001f) {{ + ripple_dx += rdx / dist * wave * amp; + ripple_dy += rdy / dist * wave * amp; + }} + }} +''' + + # Apply geometric transforms + code += ''' + // Apply accumulated transforms + { + float dx = src_x - cx; + float dy = src_y - cy; + float rx = total_cos * dx + total_sin * dy; + float ry = -total_sin * dx + total_cos * dy; + rx /= total_zoom; + ry /= total_zoom; + src_x = rx + cx - ripple_dx; + src_y = ry + cy - ripple_dy; + } + + // Bilinear sample + float r, g, b; + if (src_x < 0 || src_x >= width - 1 || src_y < 0 || src_y >= height - 1) { + r = g = b = 0; + } else { + int x0 = (int)src_x; + int y0 = (int)src_y; + float fx = src_x - x0; + float fy = src_y - y0; + + int idx00 = (y0 * width + x0) * 3; + int idx10 = (y0 * width + x0 + 1) * 3; + int idx01 = ((y0 + 1) * width + x0) * 3; + int idx11 = ((y0 + 1) * width + x0 + 1) * 3; + + #define BILERP(c) \\ + (src[idx00 + c] * (1-fx) * (1-fy) + \\ + src[idx10 + c] * fx * (1-fy) + \\ + src[idx01 + c] * (1-fx) * fy + \\ + src[idx11 + c] * fx * fy) + + r = BILERP(0); + g = BILERP(1); + b = BILERP(2); + } + +''' + + # Add color transforms + for i, effect in enumerate(effects): + op = effect['op'] + + if op == 'hue_shift': + degrees = float(effect.get('degrees', 0)) + code += f''' + // Hue shift {i} + {{ + float shift = {degrees:.1f}f; + float rf = r / 255.0f; + float gf = g / 255.0f; + float bf = b / 255.0f; + + float max_c = fmaxf(rf, fmaxf(gf, bf)); + float min_c = fminf(rf, fminf(gf, bf)); + float delta = max_c - min_c; + + float h = 0, s = 0, v = max_c; + + if (delta > 0.00001f) {{ + s = delta / max_c; + if (rf >= max_c) h = (gf - bf) / delta; + else if (gf >= max_c) h = 2.0f + (bf - rf) / delta; + else h = 4.0f + (rf - gf) / delta; + h *= 60.0f; + if (h < 0) h += 360.0f; + }} + + h = fmodf(h + shift + 360.0f, 360.0f); + + float c = v * s; + float x_val = c * (1 - fabsf(fmodf(h / 60.0f, 2.0f) - 1)); + float m = v - c; + + float r2, g2, b2; + if (h < 60) {{ r2 = c; g2 = x_val; b2 = 0; }} + else if (h < 120) {{ r2 = x_val; g2 = c; b2 = 0; }} + else if (h < 180) {{ r2 = 0; g2 = c; b2 = x_val; }} + else if (h < 240) {{ r2 = 0; g2 = x_val; b2 = c; }} + else if (h < 300) {{ r2 = x_val; g2 = 0; b2 = c; }} + else {{ r2 = c; g2 = 0; b2 = x_val; }} + + r = (r2 + m) * 255.0f; + g = (g2 + m) * 255.0f; + b = (b2 + m) * 255.0f; + }} +''' + elif op == 'brightness': + code += ''' + // Brightness + { + r *= brightness_factor; + g *= brightness_factor; + b *= brightness_factor; + } +''' + + # Write output + code += ''' + // Write output + int dst_idx = (y * width + x) * 3; + dst[dst_idx] = (unsigned char)fminf(255.0f, fmaxf(0.0f, r)); + dst[dst_idx + 1] = (unsigned char)fminf(255.0f, fmaxf(0.0f, g)); + dst[dst_idx + 2] = (unsigned char)fminf(255.0f, fmaxf(0.0f, b)); +} +''' + + return code + + +# Test the compiler +if __name__ == '__main__': + import time + + print("[sexp_to_cuda] Testing fused kernel compiler...") + print("=" * 60) + + # Define a test pipeline + effects = [ + {'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 30.0}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + frame = cp.random.randint(0, 255, (1080, 1920, 3), dtype=cp.uint8) + + # ===== Test 1: Standard fused kernel (params passed from Python) ===== + print("\n[Test 1] Standard fused kernel (Python computes params)") + pipeline = compile_frame_pipeline(effects, 1920, 1080) + + # Warmup + output = pipeline(frame) + cp.cuda.Stream.null.synchronize() + + # Benchmark with Python param computation + start = time.time() + for i in range(100): + # Simulate Python computing params (like sexp interpreter does) + import math + t = i / 30.0 + angle = t * 30 + phase = t * 2 + brightness = 0.8 + 0.4 * math.sin(t * 2) + output = pipeline(frame, rotate_angle=angle, ripple_phase=phase) + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + print(f" Time: {elapsed/100*1000:.2f}ms per frame") + print(f" FPS: {100/elapsed:.0f}") + + # ===== Test 2: Autonomous kernel (GPU computes everything) ===== + print("\n[Test 2] Autonomous kernel (GPU computes ALL params)") + + dynamic_expressions = { + 'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)', + } + + auto_pipeline = compile_autonomous_pipeline(effects, 1920, 1080, dynamic_expressions) + + # Warmup + output = auto_pipeline(frame, 0, 30.0) + cp.cuda.Stream.null.synchronize() + + # Benchmark - NO Python computation in loop! + start = time.time() + for i in range(100): + output = auto_pipeline(frame, i, 30.0) # Just pass frame_num! + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + print(f" Time: {elapsed/100*1000:.2f}ms per frame") + print(f" FPS: {100/elapsed:.0f}") + + print("\n" + "=" * 60) + print("Autonomous kernel eliminates Python from hot path!") diff --git a/streaming/sexp_to_jax.py b/streaming/sexp_to_jax.py new file mode 100644 index 0000000..db781f2 --- /dev/null +++ b/streaming/sexp_to_jax.py @@ -0,0 +1,4628 @@ +""" +Sexp to JAX Compiler. + +Compiles S-expression effects to JAX functions that run on CPU, GPU, or TPU. +Uses XLA compilation via @jax.jit for automatic kernel fusion. + +Unlike sexp_to_cuda.py which generates CUDA C strings, this compiles +S-expressions directly to JAX operations which XLA then optimizes. + +Usage: + from streaming.sexp_to_jax import compile_effect + + effect_code = ''' + (effect "threshold" + :params ((threshold :default 128)) + :body (let ((g (gray frame))) + (rgb (where (> g threshold) 255 0) + (where (> g threshold) 255 0) + (where (> g threshold) 255 0)))) + ''' + + run_effect = compile_effect(effect_code) + output = run_effect(frame, threshold=128) +""" + +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial +from typing import Any, Dict, List, Callable, Optional, Tuple +import hashlib +import numpy as np + +# Import parser +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) +from sexp_effects.parser import parse, parse_all, Symbol, Keyword + +# Import typography primitives +from streaming.jax_typography import bind_typography_primitives + + +# ============================================================================= +# Compilation Cache +# ============================================================================= + +_COMPILED_EFFECTS: Dict[str, Callable] = {} + + +# ============================================================================= +# Font Atlas for ASCII Effects +# ============================================================================= + +# Character sets for ASCII rendering +ASCII_ALPHABETS = { + 'standard': ' .:-=+*#%@', + 'blocks': ' ░▒▓█', + 'simple': ' .:oO@', + 'digits': ' 0123456789', + 'binary': ' 01', + 'detailed': ' .\'`^",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$', +} + +# Cache for font atlases: (alphabet, char_size, font_name) -> atlas array +_FONT_ATLAS_CACHE: Dict[tuple, np.ndarray] = {} + + +def _create_font_atlas(alphabet: str, char_size: int, font_name: str = None) -> np.ndarray: + """ + Create a font atlas with all characters pre-rendered. + + Uses numpy arrays (not JAX) to avoid tracer issues when called at compile time. + + Args: + alphabet: String of characters to render (ordered by brightness, dark to light) + char_size: Size of each character cell in pixels + font_name: Optional font name/path (uses default monospace if None) + + Returns: + NumPy array of shape (num_chars, char_size, char_size, 3) with rendered characters + Each character is white on black background. + """ + cache_key = (alphabet, char_size, font_name) + if cache_key in _FONT_ATLAS_CACHE: + return _FONT_ATLAS_CACHE[cache_key] + + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError: + # Fallback: create simple block-based atlas without PIL + return _create_block_atlas(alphabet, char_size) + + num_chars = len(alphabet) + atlas = [] + + # Try to load a monospace font + font = None + font_size = int(char_size * 0.9) # Slightly smaller than cell + + # Try various monospace fonts + font_candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + '/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf', + '/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf', + '/System/Library/Fonts/Menlo.ttc', # macOS + '/System/Library/Fonts/Monaco.dfont', # macOS + 'C:\\Windows\\Fonts\\consola.ttf', # Windows + ] + + for font_path in font_candidates: + if font_path is None: + continue + try: + font = ImageFont.truetype(font_path, font_size) + break + except (IOError, OSError): + continue + + if font is None: + # Use default font + try: + font = ImageFont.load_default() + except: + # Ultimate fallback to blocks + return _create_block_atlas(alphabet, char_size) + + for char in alphabet: + # Create image for this character + img = Image.new('RGB', (char_size, char_size), color=(0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Get text bounding box for centering + try: + bbox = draw.textbbox((0, 0), char, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + except AttributeError: + # Older PIL versions + text_width, text_height = draw.textsize(char, font=font) + + # Center the character + x = (char_size - text_width) // 2 + y = (char_size - text_height) // 2 + + # Draw white character on black background + draw.text((x, y), char, fill=(255, 255, 255), font=font) + + # Convert to numpy array (NOT jax array - avoids tracer issues) + char_array = np.array(img, dtype=np.uint8) + atlas.append(char_array) + + atlas = np.stack(atlas, axis=0) + _FONT_ATLAS_CACHE[cache_key] = atlas + return atlas + + +def _create_block_atlas(alphabet: str, char_size: int) -> np.ndarray: + """ + Create a simple block-based atlas without fonts. + Uses numpy to avoid tracer issues. + """ + num_chars = len(alphabet) + atlas = [] + + for i, char in enumerate(alphabet): + # Brightness proportional to position in alphabet + brightness = int(255 * i / max(num_chars - 1, 1)) + + # Create a simple pattern based on character + img = np.full((char_size, char_size, 3), brightness, dtype=np.uint8) + + # Add some texture/pattern for visual interest + # Checkerboard pattern for mid-range characters + if 0.2 < i / num_chars < 0.8: + y_coords, x_coords = np.mgrid[:char_size, :char_size] + checker = ((x_coords + y_coords) % 2 == 0) + variation = int(brightness * 0.2) + img = np.where(checker[:, :, None], + np.clip(img.astype(np.int16) + variation, 0, 255).astype(np.uint8), + np.clip(img.astype(np.int16) - variation, 0, 255).astype(np.uint8)) + + atlas.append(img) + + return np.stack(atlas, axis=0) + + +def _get_alphabet_string(alphabet_name: str) -> str: + """Get the character string for a named alphabet or return as-is if custom.""" + if alphabet_name in ASCII_ALPHABETS: + return ASCII_ALPHABETS[alphabet_name] + return alphabet_name # Assume it's a custom character string + + +# ============================================================================= +# Text Rendering with Font Atlas (JAX-compatible) +# ============================================================================= + +# Default character set for text rendering (printable ASCII) +TEXT_CHARSET = ''.join(chr(i) for i in range(32, 127)) # space to ~ + +# Cache for text font atlases: (font_name, font_size) -> (atlas, char_to_idx, char_width, char_height) +_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {} + + +def _create_text_atlas(font_name: str = None, font_size: int = 32) -> tuple: + """ + Create a font atlas for general text rendering with proper baseline alignment. + + Font Metrics (from typography): + - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l) + - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y) + - Baseline: the line text "sits" on - all characters align to this + - Em-square: the design space, typically = ascender + descender + + Returns: + (atlas, char_to_idx, char_widths, char_height, baseline_offset) + - atlas: numpy array (num_chars, char_height, max_char_width, 4) RGBA + - char_to_idx: dict mapping character to atlas index + - char_widths: numpy array of actual width for each character + - char_height: height of character cells (ascent + descent) + - baseline_offset: pixels from top of cell to baseline (= ascent) + """ + cache_key = (font_name, font_size) + if cache_key in _TEXT_ATLAS_CACHE: + return _TEXT_ATLAS_CACHE[cache_key] + + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError: + raise ImportError("PIL/Pillow required for text rendering") + + # Load font - match drawing.py's font order for consistency + font = None + font_candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', # Same order as drawing.py + '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', + '/usr/share/fonts/truetype/freefont/FreeSans.ttf', + '/System/Library/Fonts/Helvetica.ttc', + '/System/Library/Fonts/Arial.ttf', + 'C:\\Windows\\Fonts\\arial.ttf', + ] + + for font_path in font_candidates: + if font_path is None: + continue + try: + font = ImageFont.truetype(font_path, font_size) + break + except (IOError, OSError): + continue + + if font is None: + font = ImageFont.load_default() + + # Get font metrics - this is the key to proper text layout + # getmetrics() returns (ascent, descent) where: + # ascent = pixels from baseline to top of tallest character + # descent = pixels from baseline to bottom of lowest character + ascent, descent = font.getmetrics() + + # Cell dimensions based on font metrics (not per-character bounding boxes) + cell_height = ascent + descent + 2 # +2 for padding + baseline_y = ascent + 1 # Baseline position within cell (1px padding from top) + + # Find max character width + temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0)) + temp_draw = ImageDraw.Draw(temp_img) + + max_width = 0 + char_widths_dict = {} + + for char in TEXT_CHARSET: + try: + # Use getlength for horizontal advance (proper character spacing) + advance = font.getlength(char) + char_widths_dict[char] = int(advance) + max_width = max(max_width, int(advance)) + except: + char_widths_dict[char] = font_size // 2 + max_width = max(max_width, font_size // 2) + + cell_width = max_width + 2 # +2 for padding + + # Create atlas with all characters - draw same way as prim_text for pixel-perfect match + char_to_idx = {} + char_widths = [] # Advance widths + char_left_bearings = [] # Left bearing (x offset from origin to first pixel) + atlas = [] + + # Position to draw at within each tile (with margin for negative bearings) + draw_x = 5 # Margin for chars with negative left bearing + draw_y = 0 # Top of cell (PIL default without anchor) + + for i, char in enumerate(TEXT_CHARSET): + char_to_idx[char] = i + char_widths.append(char_widths_dict.get(char, cell_width // 2)) + + # Create RGBA image for this character + img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Draw same way as prim_text - at (draw_x, draw_y), no anchor + # This positions the text origin, and glyphs may extend left/right from there + draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font) + + # Get bbox to find left bearing + bbox = draw.textbbox((draw_x, draw_y), char, font=font) + left_bearing = bbox[0] - draw_x # How far left of origin the glyph extends + char_left_bearings.append(left_bearing) + + # Convert to numpy + char_array = np.array(img, dtype=np.uint8) + atlas.append(char_array) + + atlas = np.stack(atlas, axis=0) # (num_chars, char_height, cell_width, 4) + char_widths = np.array(char_widths, dtype=np.int32) + char_left_bearings = np.array(char_left_bearings, dtype=np.int32) + + # Return draw_x (origin offset within tile) so rendering knows where origin is + result = (atlas, char_to_idx, char_widths, cell_height, baseline_y, draw_x, char_left_bearings) + _TEXT_ATLAS_CACHE[cache_key] = result + return result + + +def jax_text_render(frame, text: str, x: int, y: int, + font_name: str = None, font_size: int = 32, + color=(255, 255, 255), opacity: float = 1.0, + align: str = "left", valign: str = "baseline", + shadow: bool = False, shadow_color=(0, 0, 0), + shadow_offset: int = 2): + """ + Render text onto frame using font atlas (JAX-compatible). + + This is designed to be called from within a JIT-compiled function. + The font atlas is created at compile time (using numpy/PIL), + then converted to JAX array for the actual rendering. + + Typography notes: + - Baseline: The line text "sits" on. Most characters rest on this line. + - Ascender: Top of tall letters (b, d, h, k, l) - above baseline + - Descender: Bottom of letters like g, j, p, q, y - below baseline + - For normal text, use valign="baseline" and y = the baseline position + + Args: + frame: Input frame (H, W, 3) + text: Text string to render + x, y: Position reference point (affected by align/valign) + font_name: Font to use (None = default) + font_size: Font size in pixels + color: RGB tuple (0-255) + opacity: 0.0 to 1.0 + align: Horizontal alignment relative to x: + "left" - text starts at x + "center" - text centered on x + "right" - text ends at x + valign: Vertical alignment relative to y: + "baseline" - text baseline at y (default, like normal text) + "top" - top of ascenders at y + "middle" - text vertically centered on y + "bottom" - bottom of descenders at y + shadow: Whether to draw drop shadow + shadow_color: Shadow RGB color + shadow_offset: Shadow offset in pixels + + Returns: + Frame with text rendered + """ + if not text: + return frame + + h, w = frame.shape[:2] + + # Get or create font atlas (this happens at trace time, uses numpy) + atlas_np, char_to_idx, char_widths_np, char_height, baseline_offset, origin_x, left_bearings_np = _create_text_atlas(font_name, font_size) + + # Convert atlas to JAX array + atlas = jnp.asarray(atlas_np) + + # Atlas dimensions + cell_width = atlas.shape[2] + + # Convert text to character indices and compute character widths + # (at trace time, text is static so we can pre-compute) + indices_list = [] + char_x_offsets = [0] # Starting x position for each character + total_width = 0 + + for char in text: + if char in char_to_idx: + idx = char_to_idx[char] + indices_list.append(idx) + char_w = int(char_widths_np[idx]) + else: + indices_list.append(char_to_idx.get(' ', 0)) + char_w = int(char_widths_np[char_to_idx.get(' ', 0)]) + total_width += char_w + char_x_offsets.append(total_width) + + indices = jnp.array(indices_list, dtype=jnp.int32) + num_chars = len(indices_list) + + # Actual text dimensions using proportional widths + text_width = total_width + text_height = char_height + + # Adjust position for horizontal alignment + if align == "center": + x = x - text_width // 2 + elif align == "right": + x = x - text_width + + # Adjust position for vertical alignment + # baseline_offset = pixels from top of cell to baseline + if valign == "baseline": + # y specifies baseline position, so top of text cell is above it + y = y - baseline_offset + elif valign == "middle": + y = y - text_height // 2 + elif valign == "bottom": + y = y - text_height + # valign == "top" needs no adjustment (default) + + # Ensure position is integer + x, y = int(x), int(y) + + # Create text strip with proper character spacing at trace time (using numpy) + # This ensures proportional fonts render correctly + # + # The atlas stores each character drawn at (origin_x, 0) in its tile. + # To place a character at cursor position 'cx': + # - The tile's origin_x should align with cx in the strip + # - So we blit tile to strip starting at (cx - origin_x) + # + # Add padding for characters with negative left bearings + strip_padding = origin_x # Extra space at start for negative bearings + text_strip_np = np.zeros((char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8) + + for i, char in enumerate(text): + if char in char_to_idx: + idx = char_to_idx[char] + char_tile = atlas_np[idx] # (char_height, cell_width, 4) + cx = char_x_offsets[i] + # Position tile so its origin aligns with cursor position + strip_x = strip_padding + cx - origin_x + if strip_x >= 0: + end_x = min(strip_x + cell_width, text_strip_np.shape[1]) + tile_end = end_x - strip_x + text_strip_np[:, strip_x:end_x] = np.maximum( + text_strip_np[:, strip_x:end_x], char_tile[:, :tile_end]) + + # Trim the strip: + # - Left side: trim to first visible pixel (handles negative left bearing) + # - Right side: use computed text_width (preserve advance width spacing) + alpha = text_strip_np[:, :, 3] + cols_with_content = np.any(alpha > 0, axis=0) + if cols_with_content.any(): + first_col = np.argmax(cols_with_content) + # Right edge: use the computed text width from the strip's logical end + right_col = strip_padding + text_width + # Adjust x to account for the left trim offset + x = x + first_col - strip_padding + text_strip_np = text_strip_np[:, first_col:right_col] + else: + # No visible content, return original frame + return frame + + # Convert to JAX + text_strip = jnp.asarray(text_strip_np) + + # Convert color to array + color = jnp.array(color, dtype=jnp.float32) + shadow_color = jnp.array(shadow_color, dtype=jnp.float32) + + # Apply color tint to text strip (white text * color) + text_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * color + text_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity + + # Start with frame as float + result = frame.astype(jnp.float32) + + # Draw shadow first if enabled + if shadow: + sx, sy = x + shadow_offset, y + shadow_offset + shadow_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * shadow_color + shadow_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity * 0.5 + result = _composite_text_strip(result, shadow_rgb, shadow_alpha, sx, sy) + + # Draw main text + result = _composite_text_strip(result, text_rgb, text_alpha, x, y) + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def _composite_text_strip(frame, text_rgb, text_alpha, x, y): + """ + Composite text strip onto frame at position (x, y). + + Uses alpha blending: result = text * alpha + frame * (1 - alpha) + + This is designed to work within JAX tracing. + """ + h, w = frame.shape[:2] + th, tw = text_rgb.shape[:2] + + # Clamp to frame bounds + # Source region (in text strip) + src_x1 = jnp.maximum(0, -x) + src_y1 = jnp.maximum(0, -y) + src_x2 = jnp.minimum(tw, w - x) + src_y2 = jnp.minimum(th, h - y) + + # Destination region (in frame) + dst_x1 = jnp.maximum(0, x) + dst_y1 = jnp.maximum(0, y) + dst_x2 = jnp.minimum(w, x + tw) + dst_y2 = jnp.minimum(h, y + th) + + # Check if there's anything to draw + # (We need to handle this carefully for JAX - can't use Python if with traced values) + # Instead, we'll do the full operation but the slicing will handle bounds + + # Create coordinate grids for the destination region + # We'll use dynamic_slice for JAX-compatible slicing + + # For simplicity and JAX compatibility, we'll create a full-frame text layer + # and composite it - this is less efficient but works with JIT + + # Create full-frame RGBA layer + text_layer_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32) + text_layer_alpha = jnp.zeros((h, w), dtype=jnp.float32) + + # Place text strip in the layer using dynamic_update_slice + # First pad the text strip to handle out-of-bounds + padded_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32) + padded_alpha = jnp.zeros((h, w), dtype=jnp.float32) + + # Calculate valid region + y_start = int(max(0, y)) + y_end = int(min(h, y + th)) + x_start = int(max(0, x)) + x_end = int(min(w, x + tw)) + + src_y_start = int(max(0, -y)) + src_y_end = src_y_start + (y_end - y_start) + src_x_start = int(max(0, -x)) + src_x_end = src_x_start + (x_end - x_start) + + # Only proceed if there's a valid region + if y_end > y_start and x_end > x_start and src_y_end > src_y_start and src_x_end > src_x_start: + # Extract the valid portion of text + valid_rgb = text_rgb[src_y_start:src_y_end, src_x_start:src_x_end] + valid_alpha = text_alpha[src_y_start:src_y_end, src_x_start:src_x_end] + + # Use lax.dynamic_update_slice for JAX compatibility + padded_rgb = lax.dynamic_update_slice(padded_rgb, valid_rgb, (y_start, x_start, 0)) + padded_alpha = lax.dynamic_update_slice(padded_alpha, valid_alpha, (y_start, x_start)) + + # Alpha composite: result = text * alpha + frame * (1 - alpha) + alpha_3d = padded_alpha[:, :, jnp.newaxis] + result = padded_rgb * alpha_3d + frame * (1.0 - alpha_3d) + + return result + + +def jax_text_size(text: str, font_name: str = None, font_size: int = 32) -> tuple: + """ + Measure text dimensions (width, height). + + This can be called at compile time to get text dimensions for layout. + + Returns: + (width, height) tuple in pixels + """ + _, char_to_idx, char_widths, char_height, _, _, _ = _create_text_atlas(font_name, font_size) + + # Sum actual character widths + total_width = 0 + for c in text: + if c in char_to_idx: + total_width += int(char_widths[char_to_idx[c]]) + else: + total_width += int(char_widths[char_to_idx.get(' ', 0)]) + + return (total_width, char_height) + + +def jax_font_metrics(font_name: str = None, font_size: int = 32) -> dict: + """ + Get font metrics for layout calculations. + + Typography terms: + - ascent: pixels from baseline to top of tallest glyph (b, d, h, etc.) + - descent: pixels from baseline to bottom of lowest glyph (g, j, p, etc.) + - height: total height = ascent + descent (plus padding) + - baseline: position of baseline from top of text cell + + Returns: + dict with keys: ascent, descent, height, baseline + """ + _, _, _, char_height, baseline_offset, _, _ = _create_text_atlas(font_name, font_size) + + # baseline_offset is pixels from top to baseline (= ascent + padding) + # descent = height - baseline (approximately) + ascent = baseline_offset - 1 # remove padding + descent = char_height - baseline_offset - 1 # remove padding + + return { + 'ascent': ascent, + 'descent': descent, + 'height': char_height, + 'baseline': baseline_offset, + } + + +def jax_fit_text_size(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, max_size: int = 200) -> int: + """ + Calculate font size to fit text within bounds. + + Binary search for largest size that fits. + """ + best_size = min_size + low, high = min_size, max_size + + while low <= high: + mid = (low + high) // 2 + w, h = jax_text_size(text, font_name, mid) + + if w <= max_width and h <= max_height: + best_size = mid + low = mid + 1 + else: + high = mid - 1 + + return best_size + + +# ============================================================================= +# JAX Primitives - True primitives that can't be derived +# ============================================================================= + +def jax_width(frame): + """Frame width.""" + return frame.shape[1] + + +def jax_height(frame): + """Frame height.""" + return frame.shape[0] + + +def jax_channel(frame, idx): + """Extract channel by index as flat array.""" + # idx must be a static int for indexing + return frame[:, :, int(idx)].flatten().astype(jnp.float32) + + +def jax_merge_channels(r, g, b, shape): + """Merge RGB channels back to frame.""" + h, w = shape + r_img = jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8) + g_img = jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8) + b_img = jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([r_img, g_img, b_img], axis=2) + + +def jax_iota(n): + """Generate [0, 1, 2, ..., n-1].""" + return jnp.arange(n, dtype=jnp.float32) + + +def jax_repeat(x, n): + """Repeat each element n times: [a,b] -> [a,a,b,b].""" + return jnp.repeat(x, n) + + +def jax_tile(x, n): + """Tile array n times: [a,b] -> [a,b,a,b].""" + return jnp.tile(x, n) + + +def jax_gather(data, indices): + """Parallel index lookup.""" + flat_data = data.flatten() + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, len(flat_data) - 1) + return flat_data[idx_clipped] + + +def jax_scatter(indices, values, size): + """Parallel index write (last write wins).""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].set(values) + + +def jax_scatter_add(indices, values, size): + """Parallel index accumulate.""" + result = jnp.zeros(size, dtype=jnp.float32) + idx_clipped = jnp.clip(indices.astype(jnp.int32), 0, size - 1) + return result.at[idx_clipped].add(values) + + +def jax_group_reduce(values, group_indices, num_groups, op='mean'): + """Reduce values by group.""" + grp = group_indices.astype(jnp.int32) + + if op == 'sum': + result = jnp.zeros(num_groups, dtype=jnp.float32) + return result.at[grp].add(values) + elif op == 'mean': + sums = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(values) + counts = jnp.zeros(num_groups, dtype=jnp.float32).at[grp].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + elif op == 'max': + result = jnp.full(num_groups, -jnp.inf, dtype=jnp.float32) + result = result.at[grp].max(values) + return jnp.where(result == -jnp.inf, 0.0, result) + elif op == 'min': + result = jnp.full(num_groups, jnp.inf, dtype=jnp.float32) + result = result.at[grp].min(values) + return jnp.where(result == jnp.inf, 0.0, result) + else: + raise ValueError(f"Unknown reduce op: {op}") + + +def jax_where(cond, true_val, false_val): + """Conditional select.""" + return jnp.where(cond, true_val, false_val) + + +def jax_cell_indices(frame, cell_size): + """Compute cell index for each pixel.""" + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + + cell_row = y_coords // cell_size + cell_col = x_coords // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + return jnp.clip(cell_idx, 0, rows * cols - 1).astype(jnp.float32) + + +def jax_pool_frame(frame, cell_size): + """ + Pool frame to cell values. + Returns tuple: (cell_r, cell_g, cell_b, cell_lum) + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y_coords = jnp.repeat(jnp.arange(h), w) + x_coords = jnp.tile(jnp.arange(w), h) + cell_row = jnp.clip(y_coords // cs, 0, rows - 1) + cell_col = jnp.clip(x_coords // cs, 0, cols - 1) + cell_idx = (cell_row * cols + cell_col).astype(jnp.int32) + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(jnp.float32) + g_flat = frame[:, :, 1].flatten().astype(jnp.float32) + b_flat = frame[:, :, 2].flatten().astype(jnp.float32) + + # Pool each channel (mean) + def pool_channel(data): + sums = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(data) + counts = jnp.zeros(num_cells, dtype=jnp.float32).at[cell_idx].add(1.0) + return jnp.where(counts > 0, sums / counts, 0.0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + return (r_pooled, g_pooled, b_pooled, lum) + + +# ============================================================================= +# Scan (Prefix Operations) - JAX implementations +# ============================================================================= + +def jax_scan_add(x, axis=None): + """Cumulative sum (prefix sum).""" + if axis is not None: + return jnp.cumsum(x, axis=int(axis)) + return jnp.cumsum(x.flatten()) + + +def jax_scan_mul(x, axis=None): + """Cumulative product.""" + if axis is not None: + return jnp.cumprod(x, axis=int(axis)) + return jnp.cumprod(x.flatten()) + + +def jax_scan_max(x, axis=None): + """Cumulative maximum.""" + if axis is not None: + return lax.cummax(x, axis=int(axis)) + return lax.cummax(x.flatten(), axis=0) + + +def jax_scan_min(x, axis=None): + """Cumulative minimum.""" + if axis is not None: + return lax.cummin(x, axis=int(axis)) + return lax.cummin(x.flatten(), axis=0) + + +# ============================================================================= +# Outer Product - JAX implementations +# ============================================================================= + +def jax_outer(x, y, op='*'): + """Outer product with configurable operation.""" + x_flat = x.flatten() + y_flat = y.flatten() + + ops = { + '*': lambda a, b: jnp.outer(a, b), + '+': lambda a, b: a[:, None] + b[None, :], + '-': lambda a, b: a[:, None] - b[None, :], + '/': lambda a, b: a[:, None] / b[None, :], + 'max': lambda a, b: jnp.maximum(a[:, None], b[None, :]), + 'min': lambda a, b: jnp.minimum(a[:, None], b[None, :]), + } + + op_fn = ops.get(op, ops['*']) + return op_fn(x_flat, y_flat) + + +def jax_outer_add(x, y): + """Outer sum.""" + return jax_outer(x, y, '+') + + +def jax_outer_mul(x, y): + """Outer product.""" + return jax_outer(x, y, '*') + + +def jax_outer_max(x, y): + """Outer max.""" + return jax_outer(x, y, 'max') + + +def jax_outer_min(x, y): + """Outer min.""" + return jax_outer(x, y, 'min') + + +# ============================================================================= +# Reduce with Axis - JAX implementations +# ============================================================================= + +def jax_reduce_axis(x, op='sum', axis=0): + """Reduce along an axis.""" + axis = int(axis) + ops = { + 'sum': lambda d: jnp.sum(d, axis=axis), + '+': lambda d: jnp.sum(d, axis=axis), + 'mean': lambda d: jnp.mean(d, axis=axis), + 'max': lambda d: jnp.max(d, axis=axis), + 'min': lambda d: jnp.min(d, axis=axis), + 'prod': lambda d: jnp.prod(d, axis=axis), + '*': lambda d: jnp.prod(d, axis=axis), + 'std': lambda d: jnp.std(d, axis=axis), + } + op_fn = ops.get(op, ops['sum']) + return op_fn(x) + + +def jax_sum_axis(x, axis=0): + """Sum along axis.""" + return jnp.sum(x, axis=int(axis)) + + +def jax_mean_axis(x, axis=0): + """Mean along axis.""" + return jnp.mean(x, axis=int(axis)) + + +def jax_max_axis(x, axis=0): + """Max along axis.""" + return jnp.max(x, axis=int(axis)) + + +def jax_min_axis(x, axis=0): + """Min along axis.""" + return jnp.min(x, axis=int(axis)) + + +# ============================================================================= +# Windowed Operations - JAX implementations +# ============================================================================= + +def jax_window(x, size, op='mean', stride=1): + """ + Sliding window operation. + + For 1D arrays: standard sliding window + For 2D arrays: 2D sliding window (size x size) + """ + size = int(size) + stride = int(stride) + + if x.ndim == 1: + # 1D sliding window using convolution trick + n = len(x) + if op == 'sum': + kernel = jnp.ones(size) + return jnp.convolve(x, kernel, mode='valid')[::stride] + elif op == 'mean': + kernel = jnp.ones(size) / size + return jnp.convolve(x, kernel, mode='valid')[::stride] + else: + # For max/min, use manual approach + out_n = (n - size) // stride + 1 + indices = jnp.arange(out_n) * stride + windows = jax.vmap(lambda i: lax.dynamic_slice(x, (i,), (size,)))(indices) + if op == 'max': + return jnp.max(windows, axis=1) + elif op == 'min': + return jnp.min(windows, axis=1) + else: + return jnp.mean(windows, axis=1) + else: + # 2D sliding window + h, w = x.shape[:2] + out_h = (h - size) // stride + 1 + out_w = (w - size) // stride + 1 + + # Extract all windows using vmap + def extract_window(ij): + i, j = ij // out_w, ij % out_w + return lax.dynamic_slice(x, (i * stride, j * stride), (size, size)) + + indices = jnp.arange(out_h * out_w) + windows = jax.vmap(extract_window)(indices) + + if op == 'sum': + result = jnp.sum(windows, axis=(1, 2)) + elif op == 'mean': + result = jnp.mean(windows, axis=(1, 2)) + elif op == 'max': + result = jnp.max(windows, axis=(1, 2)) + elif op == 'min': + result = jnp.min(windows, axis=(1, 2)) + else: + result = jnp.mean(windows, axis=(1, 2)) + + return result.reshape(out_h, out_w) + + +def jax_window_sum(x, size, stride=1): + """Sliding window sum.""" + return jax_window(x, size, 'sum', stride) + + +def jax_window_mean(x, size, stride=1): + """Sliding window mean.""" + return jax_window(x, size, 'mean', stride) + + +def jax_window_max(x, size, stride=1): + """Sliding window max.""" + return jax_window(x, size, 'max', stride) + + +def jax_window_min(x, size, stride=1): + """Sliding window min.""" + return jax_window(x, size, 'min', stride) + + +def jax_integral_image(frame): + """ + Compute integral image (summed area table). + Enables O(1) box blur at any radius. + """ + if frame.ndim == 3: + # Convert to grayscale + gray = jnp.mean(frame.astype(jnp.float32), axis=2) + else: + gray = frame.astype(jnp.float32) + + # Cumsum along both axes + return jnp.cumsum(jnp.cumsum(gray, axis=0), axis=1) + + +def jax_sample(frame, x, y): + """Bilinear sample at (x, y) coordinates. + + Matches OpenCV cv2.remap with INTER_LINEAR and BORDER_CONSTANT (default): + out-of-bounds samples return 0, then bilinear blend includes those zeros. + """ + h, w = frame.shape[:2] + + # Get integer coords for the 4 sample points + x0 = jnp.floor(x).astype(jnp.int32) + y0 = jnp.floor(y).astype(jnp.int32) + x1 = x0 + 1 + y1 = y0 + 1 + + fx = x - x0.astype(jnp.float32) + fy = y - y0.astype(jnp.float32) + + # Check which sample points are in bounds + valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) + valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) + valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) + valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) + + # Clamp indices for safe array access (values will be masked anyway) + x0_safe = jnp.clip(x0, 0, w - 1) + x1_safe = jnp.clip(x1, 0, w - 1) + y0_safe = jnp.clip(y0, 0, h - 1) + y1_safe = jnp.clip(y1, 0, h - 1) + + # Bilinear interpolation for each channel + def interp_channel(c): + # Sample with 0 for out-of-bounds (BORDER_CONSTANT) + c00 = jnp.where(valid00, frame[y0_safe, x0_safe, c].astype(jnp.float32), 0.0) + c10 = jnp.where(valid10, frame[y0_safe, x1_safe, c].astype(jnp.float32), 0.0) + c01 = jnp.where(valid01, frame[y1_safe, x0_safe, c].astype(jnp.float32), 0.0) + c11 = jnp.where(valid11, frame[y1_safe, x1_safe, c].astype(jnp.float32), 0.0) + + return (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + r = interp_channel(0) + g = interp_channel(1) + b = interp_channel(2) + + return r, g, b + + +# ============================================================================= +# Convolution Operations +# ============================================================================= + +def jax_convolve2d(data, kernel): + """2D convolution on a single channel.""" + # data shape: (H, W), kernel shape: (kH, kW) + # Use JAX's conv with appropriate padding + h, w = data.shape + kh, kw = kernel.shape + + # Reshape for conv: (batch, H, W, channels) and (kH, kW, in_c, out_c) + data_4d = data.reshape(1, h, w, 1) + kernel_4d = kernel.reshape(kh, kw, 1, 1) + + # Convolve with 'SAME' padding + result = lax.conv_general_dilated( + data_4d, kernel_4d, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + + return result.reshape(h, w) + + +def jax_blur(frame, radius=1): + """Gaussian blur.""" + # Create gaussian kernel + size = int(radius) * 2 + 1 + x = jnp.arange(size) - radius + gaussian_1d = jnp.exp(-x**2 / (2 * (radius/2)**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + kernel = jnp.outer(gaussian_1d, gaussian_1d) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_sharpen(frame, amount=1.0): + """Sharpen using unsharp mask.""" + kernel = jnp.array([ + [0, -1, 0], + [-1, 5, -1], + [0, -1, 0] + ], dtype=jnp.float32) + + # Adjust kernel based on amount + center = 4 * amount + 1 + kernel = kernel.at[1, 1].set(center) + kernel = kernel * amount + jnp.array([[0,0,0],[0,1,0],[0,0,0]]) * (1 - amount) + + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_edge_detect(frame): + """Sobel edge detection.""" + # Sobel kernels + sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) + sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) + + # Convert to grayscale first + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + gx = jax_convolve2d(gray, sobel_x) + gy = jax_convolve2d(gray, sobel_y) + + edges = jnp.sqrt(gx**2 + gy**2) + edges = jnp.clip(edges, 0, 255).astype(jnp.uint8) + + return jnp.stack([edges, edges, edges], axis=2) + + +def jax_emboss(frame): + """Emboss effect.""" + kernel = jnp.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]], dtype=jnp.float32) + + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + embossed = jax_convolve2d(gray, kernel) + 128 + embossed = jnp.clip(embossed, 0, 255).astype(jnp.uint8) + + return jnp.stack([embossed, embossed, embossed], axis=2) + + +# ============================================================================= +# Color Space Conversion +# ============================================================================= + +def jax_rgb_to_hsv(r, g, b): + """Convert RGB to HSV. All inputs/outputs are 0-255 range.""" + r, g, b = r / 255.0, g / 255.0, b / 255.0 + + max_c = jnp.maximum(jnp.maximum(r, g), b) + min_c = jnp.minimum(jnp.minimum(r, g), b) + diff = max_c - min_c + + # Value + v = max_c + + # Saturation + s = jnp.where(max_c > 0, diff / max_c, 0.0) + + # Hue + h = jnp.where(diff == 0, 0.0, + jnp.where(max_c == r, (60 * ((g - b) / diff) + 360) % 360, + jnp.where(max_c == g, 60 * ((b - r) / diff) + 120, + 60 * ((r - g) / diff) + 240))) + + return h, s * 255, v * 255 + + +def jax_hsv_to_rgb(h, s, v): + """Convert HSV to RGB. H is 0-360, S and V are 0-255.""" + h = h % 360 + s, v = s / 255.0, v / 255.0 + + c = v * s + x = c * (1 - jnp.abs((h / 60) % 2 - 1)) + m = v - c + + h_sector = (h / 60).astype(jnp.int32) % 6 + + r = jnp.where(h_sector == 0, c, + jnp.where(h_sector == 1, x, + jnp.where(h_sector == 2, 0, + jnp.where(h_sector == 3, 0, + jnp.where(h_sector == 4, x, c))))) + + g = jnp.where(h_sector == 0, x, + jnp.where(h_sector == 1, c, + jnp.where(h_sector == 2, c, + jnp.where(h_sector == 3, x, + jnp.where(h_sector == 4, 0, 0))))) + + b = jnp.where(h_sector == 0, 0, + jnp.where(h_sector == 1, 0, + jnp.where(h_sector == 2, x, + jnp.where(h_sector == 3, c, + jnp.where(h_sector == 4, c, x))))) + + return (r + m) * 255, (g + m) * 255, (b + m) * 255 + + +def jax_adjust_saturation(frame, factor): + """Adjust saturation by factor (1.0 = unchanged).""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + s = jnp.clip(s * factor, 0, 255) + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +def jax_shift_hue(frame, degrees): + """Shift hue by degrees.""" + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + + h, s, v = jax_rgb_to_hsv(r, g, b) + h = (h + degrees) % 360 + r2, g2, b2 = jax_hsv_to_rgb(h, s, v) + + h_dim, w_dim = frame.shape[:2] + return jnp.stack([ + jnp.clip(r2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(g2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8), + jnp.clip(b2, 0, 255).reshape(h_dim, w_dim).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Color Adjustment Operations +# ============================================================================= + +def jax_adjust_brightness(frame, amount): + """Adjust brightness by amount (-255 to 255).""" + result = frame.astype(jnp.float32) + amount + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_adjust_contrast(frame, factor): + """Adjust contrast by factor (1.0 = unchanged).""" + result = (frame.astype(jnp.float32) - 128) * factor + 128 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_invert(frame): + """Invert colors.""" + return 255 - frame + + +def jax_posterize(frame, levels): + """Reduce to N color levels per channel.""" + levels = int(levels) + if levels < 2: + levels = 2 + step = 255.0 / (levels - 1) + result = jnp.round(frame.astype(jnp.float32) / step) * step + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_threshold(frame, level, invert=False): + """Binary threshold.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + if invert: + binary = jnp.where(gray < level, 255, 0).astype(jnp.uint8) + else: + binary = jnp.where(gray >= level, 255, 0).astype(jnp.uint8) + + return jnp.stack([binary, binary, binary], axis=2) + + +def jax_sepia(frame): + """Apply sepia tone.""" + r = frame[:, :, 0].astype(jnp.float32) + g = frame[:, :, 1].astype(jnp.float32) + b = frame[:, :, 2].astype(jnp.float32) + + new_r = r * 0.393 + g * 0.769 + b * 0.189 + new_g = r * 0.349 + g * 0.686 + b * 0.168 + new_b = r * 0.272 + g * 0.534 + b * 0.131 + + return jnp.stack([ + jnp.clip(new_r, 0, 255).astype(jnp.uint8), + jnp.clip(new_g, 0, 255).astype(jnp.uint8), + jnp.clip(new_b, 0, 255).astype(jnp.uint8) + ], axis=2) + + +def jax_grayscale(frame): + """Convert to grayscale.""" + gray = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + gray = gray.astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + +# ============================================================================= +# Geometry Operations +# ============================================================================= + +def jax_flip_horizontal(frame): + """Flip horizontally.""" + return frame[:, ::-1, :] + + +def jax_flip_vertical(frame): + """Flip vertically.""" + return frame[::-1, :, :] + + +def jax_rotate(frame, angle, center_x=None, center_y=None): + """Rotate frame by angle (degrees), matching OpenCV convention. + + Positive angle = counter-clockwise rotation. + """ + h, w = frame.shape[:2] + if center_x is None: + center_x = w / 2 + if center_y is None: + center_y = h / 2 + + # Convert to radians + theta = angle * jnp.pi / 180 + cos_t, sin_t = jnp.cos(theta), jnp.sin(theta) + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # OpenCV getRotationMatrix2D gives FORWARD transform M = [[cos,sin],[-sin,cos]] + # For sampling we need INVERSE: M^-1 = [[cos,-sin],[sin,cos]] + # So: src_x = cos(θ)*(x-cx) - sin(θ)*(y-cy) + cx + # src_y = sin(θ)*(x-cx) + cos(θ)*(y-cy) + cy + x_centered = x_coords - center_x + y_centered = y_coords - center_y + + src_x = cos_t * x_centered - sin_t * y_centered + center_x + src_y = sin_t * x_centered + cos_t * y_centered + center_y + + # Sample using bilinear interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_scale(frame, scale_x, scale_y=None): + """Scale frame (zoom). Matches OpenCV behavior with black out-of-bounds.""" + if scale_y is None: + scale_y = scale_x + + h, w = frame.shape[:2] + center_x, center_y = w / 2, h / 2 + + # Create coordinate grids + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + + # Scale from center (inverse mapping: dst -> src) + src_x = (x_coords - center_x) / scale_x + center_x + src_y = (y_coords - center_y) / scale_y + center_y + + # Sample using bilinear interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) + r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + +def jax_resize(frame, new_width, new_height): + """Resize frame to new dimensions.""" + h, w = frame.shape[:2] + new_h, new_w = int(new_height), int(new_width) + + # Create coordinate grids for new size + y_coords = jnp.repeat(jnp.arange(new_h), new_w) + x_coords = jnp.tile(jnp.arange(new_w), new_h) + + # Map to source coordinates + src_x = x_coords * (w - 1) / (new_w - 1) + src_y = y_coords * (h - 1) / (new_h - 1) + + r, g, b = jax_sample(frame, src_x, src_y) + + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(new_h, new_w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(new_h, new_w).astype(jnp.uint8) + ], axis=2) + + +# ============================================================================= +# Blending Operations +# ============================================================================= + +def _resize_to_match(frame1, frame2): + """Resize frame2 to match frame1's dimensions if they differ. + + Uses jax.image.resize for bilinear interpolation. + Returns frame2 resized to frame1's shape. + """ + h1, w1 = frame1.shape[:2] + h2, w2 = frame2.shape[:2] + + # If same size, return as-is + if h1 == h2 and w1 == w2: + return frame2 + + # Resize frame2 to match frame1 + # jax.image.resize expects (height, width, channels) and target shape + return jax.image.resize( + frame2.astype(jnp.float32), + (h1, w1, frame2.shape[2]), + method='bilinear' + ).astype(jnp.uint8) + + +def jax_blend(frame1, frame2, alpha): + """Blend two frames. alpha=0 -> frame1, alpha=1 -> frame2. + + Auto-resizes frame2 to match frame1 if dimensions differ. + """ + frame2 = _resize_to_match(frame1, frame2) + return (frame1.astype(jnp.float32) * (1 - alpha) + + frame2.astype(jnp.float32) * alpha).astype(jnp.uint8) + + +def jax_blend_add(frame1, frame2): + """Additive blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) + frame2.astype(jnp.float32) + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_multiply(frame1, frame2): + """Multiply blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + result = frame1.astype(jnp.float32) * frame2.astype(jnp.float32) / 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def jax_blend_screen(frame1, frame2): + """Screen blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = 1 - (1 - f1) * (1 - f2) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +def jax_blend_overlay(frame1, frame2): + """Overlay blend. Auto-resizes frame2 to match frame1.""" + frame2 = _resize_to_match(frame1, frame2) + f1 = frame1.astype(jnp.float32) / 255 + f2 = frame2.astype(jnp.float32) / 255 + result = jnp.where(f1 < 0.5, + 2 * f1 * f2, + 1 - 2 * (1 - f1) * (1 - f2)) + return jnp.clip(result * 255, 0, 255).astype(jnp.uint8) + + +# ============================================================================= +# Utility +# ============================================================================= + +def make_jax_key(seed: int = 42, frame_num = 0, op_id: int = 0): + """Create a JAX random key that varies with frame and operation. + + Uses jax.random.fold_in to mix frame_num (which may be traced) into the key. + This allows JIT compilation without recompiling for each frame. + + Args: + seed: Base seed for determinism (must be concrete) + frame_num: Frame number for variation (can be traced) + op_id: Operation ID for variation (must be concrete) + + Returns: + JAX PRNGKey + """ + # Create base key from seed and op_id (both concrete) + base_key = jax.random.PRNGKey(seed + op_id * 1000003) + # Fold in frame_num (can be traced value) + return jax.random.fold_in(base_key, frame_num) + + +def jax_rand_range(lo, hi, frame_num=0, op_id=0, seed=42): + """Random float in [lo, hi), varies with frame.""" + key = make_jax_key(seed, frame_num, op_id) + return lo + jax.random.uniform(key) * (hi - lo) + + +def jax_is_nil(x): + """Check if value is None/nil.""" + return x is None + + +# ============================================================================= +# S-expression to JAX Compiler +# ============================================================================= + +class JaxCompiler: + """Compiles S-expressions to JAX functions.""" + + def __init__(self): + self.env = {} # Variable bindings during compilation + self.params = {} # Effect parameters + self.primitives = {} # Loaded primitive libraries + self.derived = {} # Loaded derived functions + + def load_derived(self, path: str): + """Load derived operations from a .sexp file.""" + with open(path, 'r') as f: + code = f.read() + exprs = parse_all(code) + + # Evaluate all define expressions to populate derived functions + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + self._eval_define(expr[1:], self.derived) + + def compile_effect(self, sexp) -> Callable: + """ + Compile an effect S-expression to a JAX function. + + Supports both formats: + (effect "name" :params (...) :body ...) + (define-effect name :params (...) body) + + Args: + sexp: Parsed S-expression + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + if not isinstance(sexp, list) or len(sexp) < 2: + raise ValueError("Effect must be a list") + + head = sexp[0] + if not isinstance(head, Symbol): + raise ValueError("Effect must start with a symbol") + + form = head.name + + # Handle both 'effect' and 'define-effect' formats + if form == 'effect': + # (effect "name" :params (...) :body ...) + name = sexp[1] if len(sexp) > 1 else "unnamed" + if isinstance(name, Symbol): + name = name.name + start_idx = 2 + elif form == 'define-effect': + # (define-effect name :params (...) body) + name = sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1]) + start_idx = 2 + else: + raise ValueError(f"Expected 'effect' or 'define-effect', got '{form}'") + + params_spec = [] + body = None + + i = start_idx + while i < len(sexp): + item = sexp[i] + if isinstance(item, Keyword): + if item.name == 'params' and i + 1 < len(sexp): + params_spec = sexp[i + 1] + i += 2 + elif item.name == 'body' and i + 1 < len(sexp): + body = sexp[i + 1] + i += 2 + elif item.name in ('desc', 'type', 'range'): + # Skip metadata keywords + i += 2 + else: + i += 2 # Skip unknown keywords with their values + else: + # Assume it's the body if we haven't seen one + if body is None: + body = item + i += 1 + + if body is None: + raise ValueError(f"Effect '{name}' must have a body") + + # Extract parameter names, defaults, and static params (strings, bools) + param_info, static_params = self._parse_params(params_spec) + + # Capture derived functions for the closure + derived_fns = self.derived.copy() + + # Create the JAX function + def effect_fn(frame, **kwargs): + # Set up environment + h, w = frame.shape[:2] + # Get frame_num for deterministic random variation + frame_num = kwargs.get('frame_num', 0) + # Get seed from recipe config (passed via kwargs) + seed = kwargs.get('seed', 42) + env = { + 'frame': frame, + 'width': w, + 'height': h, + '_shape': (h, w), + # Time variables (default to 0, can be overridden via kwargs) + 't': kwargs.get('t', kwargs.get('_time', 0.0)), + '_time': kwargs.get('_time', kwargs.get('t', 0.0)), + 'time': kwargs.get('time', kwargs.get('t', 0.0)), + # Frame number for random key generation + 'frame_num': frame_num, + 'frame-num': frame_num, + '_frame_num': frame_num, + # Seed from recipe for deterministic random + '_seed': seed, + # Counter for unique random keys within same frame + '_rand_op_counter': 0, + # Common constants + 'pi': jnp.pi, + 'PI': jnp.pi, + } + + # Add derived functions + env.update(derived_fns) + + # Add typography primitives + bind_typography_primitives(env) + + # Add parameters with defaults + for pname, pdefault in param_info.items(): + if pname in kwargs: + env[pname] = kwargs[pname] + elif isinstance(pdefault, list): + # Unevaluated S-expression default - evaluate it + env[pname] = self._eval(pdefault, env) + else: + env[pname] = pdefault + + # Evaluate body + result = self._eval(body, env) + + # Ensure result is a frame + if isinstance(result, tuple) and len(result) == 3: + # RGB tuple - merge to frame + r, g, b = result + return jax_merge_channels(r, g, b, (h, w)) + elif result.ndim == 3: + return result + else: + # Single channel - replicate to RGB + h, w = env['_shape'] + gray = jnp.clip(result.reshape(h, w), 0, 255).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # JIT compile with static args for string/bool parameters and seed + # seed must be static for PRNGKey, but frame_num can be traced via fold_in + all_static = set(static_params) | {'seed'} + return jax.jit(effect_fn, static_argnames=list(all_static)) + + def _parse_params(self, params_spec) -> Tuple[Dict[str, Any], set]: + """Parse parameter specifications. + + Returns: + Tuple of (param_defaults, static_params) + - param_defaults: Dict mapping param names to default values + - static_params: Set of param names that should be static (strings, bools) + """ + result = {} + static_params = set() + if not isinstance(params_spec, list): + return result, static_params + + for param in params_spec: + if isinstance(param, Symbol): + result[param.name] = 0.0 + elif isinstance(param, list) and len(param) >= 1: + pname = param[0].name if isinstance(param[0], Symbol) else str(param[0]) + pdefault = 0.0 + ptype = None + + # Look for :default and :type keywords + i = 1 + while i < len(param): + if isinstance(param[i], Keyword): + kw = param[i].name + if kw == 'default' and i + 1 < len(param): + pdefault = param[i + 1] + if isinstance(pdefault, Symbol): + if pdefault.name == 'nil': + pdefault = None + elif pdefault.name == 'true': + pdefault = True + elif pdefault.name == 'false': + pdefault = False + i += 2 + elif kw == 'type' and i + 1 < len(param): + ptype = param[i + 1] + if isinstance(ptype, Symbol): + ptype = ptype.name + i += 2 + else: + i += 1 + else: + i += 1 + + result[pname] = pdefault + + # Mark string and bool parameters as static (can't be traced by JAX) + if ptype in ('string', 'bool') or isinstance(pdefault, (str, bool)): + static_params.add(pname) + + return result, static_params + + def _eval(self, expr, env: Dict[str, Any]) -> Any: + """Evaluate an S-expression in the given environment.""" + + # Already-evaluated values (e.g., from threading macros) + # JAX arrays, NumPy arrays, tuples, etc. + if hasattr(expr, 'shape'): # JAX/NumPy array + return expr + if isinstance(expr, tuple): # e.g., (r, g, b) from rgb + return expr + + # Literals - keep as Python numbers for static operations + if isinstance(expr, (int, float)): + return expr + + if isinstance(expr, str): + return expr + + # Symbols - variable lookup + if isinstance(expr, Symbol): + name = expr.name + if name in env: + return env[name] + if name == 'nil': + return None + if name == 'true': + return True + if name == 'false': + return False + raise NameError(f"Unknown symbol: {name}") + + # Lists - function calls + if isinstance(expr, list) and len(expr) > 0: + head = expr[0] + + if isinstance(head, Symbol): + op = head.name + args = expr[1:] + + # Special forms + if op == 'let' or op == 'let*': + return self._eval_let(args, env) + if op == 'if': + return self._eval_if(args, env) + if op == 'lambda' or op == 'λ': + return self._eval_lambda(args, env) + if op == 'define': + return self._eval_define(args, env) + + # Built-in operations + return self._eval_call(op, args, env) + + # Empty list + if isinstance(expr, list) and len(expr) == 0: + return [] + + raise ValueError(f"Cannot evaluate: {expr}") + + def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]): + """Extract a keyword argument from args list. + + Looks for :key value pattern in args and evaluates the value. + Returns default if not found. + """ + i = 0 + while i < len(args): + if isinstance(args[i], Keyword) and args[i].name == key: + if i + 1 < len(args): + val = self._eval(args[i + 1], env) + # Handle Symbol values (e.g., :op 'sum -> 'sum') + if isinstance(val, Symbol): + return val.name + return val + return default + i += 1 + return default + + def _eval_let(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body).""" + if len(args) < 2: + raise ValueError("let requires bindings and body") + + bindings = args[0] + body = args[1] + + new_env = env.copy() + + # Handle both ((var val) ...) and [var val var2 val2 ...] syntax + if isinstance(bindings, list): + # Check if it's a flat list [var val var2 val2 ...] or nested ((var val) ...) + if bindings and isinstance(bindings[0], Symbol): + # Flat list: [var val var2 val2 ...] + i = 0 + while i < len(bindings) - 1: + var = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[var] = val + i += 2 + else: + # Nested list: ((var val) (var2 val2) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + var = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[var] = val + + return self._eval(body, new_env) + + def _eval_if(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + if len(args) < 2: + raise ValueError("if requires condition and then-branch") + + cond = self._eval(args[0], env) + + # Handle None as falsy (important for optional params like overlay) + if cond is None: + return self._eval(args[2], env) if len(args) > 2 else None + + # For Python scalar bools, use normal Python if + # This allows side effects and None values + if isinstance(cond, bool): + if cond: + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + + # For NumPy/JAX scalar bools with concrete values + if hasattr(cond, 'item') and cond.shape == (): + try: + if bool(cond.item()): + return self._eval(args[1], env) + else: + return self._eval(args[2], env) if len(args) > 2 else None + except: + pass # Fall through to jnp.where for traced values + + # For traced values, evaluate both branches and use jnp.where + then_val = self._eval(args[1], env) + else_val = self._eval(args[2], env) if len(args) > 2 else 0.0 + + # Handle None by converting to zeros + if then_val is None: + then_val = 0.0 + if else_val is None: + else_val = 0.0 + + # Convert lists to tuples + if isinstance(then_val, list): + then_val = tuple(then_val) + if isinstance(else_val, list): + else_val = tuple(else_val) + + # Handle tuple results (e.g., from rgb in map-pixels) + if isinstance(then_val, tuple) and isinstance(else_val, tuple): + return tuple(jnp.where(cond, t, e) for t, e in zip(then_val, else_val)) + + return jnp.where(cond, then_val, else_val) + + def _eval_lambda(self, args, env: Dict[str, Any]) -> Callable: + """Evaluate (lambda (params) body).""" + if len(args) < 2: + raise ValueError("lambda requires parameters and body") + + params = [p.name if isinstance(p, Symbol) else str(p) for p in args[0]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + return fn + + def _eval_define(self, args, env: Dict[str, Any]) -> Any: + """Evaluate (define name value) or (define (name params) body).""" + if len(args) < 2: + raise ValueError("define requires name and value") + + name_part = args[0] + + if isinstance(name_part, list): + # Function definition: (define (name params) body) + fn_name = name_part[0].name if isinstance(name_part[0], Symbol) else str(name_part[0]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in name_part[1:]] + body = args[1] + captured_env = env.copy() + + def fn(*fn_args): + local_env = captured_env.copy() + for pname, pval in zip(params, fn_args): + local_env[pname] = pval + return self._eval(body, local_env) + + env[fn_name] = fn + return fn + else: + # Variable definition + var_name = name_part.name if isinstance(name_part, Symbol) else str(name_part) + val = self._eval(args[1], env) + env[var_name] = val + return val + + def _eval_call(self, op: str, args: List, env: Dict[str, Any]) -> Any: + """Evaluate a function call.""" + + # Check if it's a user-defined function + if op in env and callable(env[op]): + fn = env[op] + eval_args = [self._eval(a, env) for a in args] + return fn(*eval_args) + + # Arithmetic + if op == '+': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 0.0 + for v in vals[1:]: + result = result + v + return result + + if op == '-': + if len(args) == 1: + return -self._eval(args[0], env) + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result - v + return result + + if op == '*': + vals = [self._eval(a, env) for a in args] + result = vals[0] if vals else 1.0 + for v in vals[1:]: + result = result * v + return result + + if op == '/': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result / v + return result + + if op == 'mod': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a % b + + if op == 'pow' or op == '**': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return jnp.power(a, b) + + # Comparison + if op == '<': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a < b + if op == '>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a > b + if op == '<=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a <= b + if op == '>=': + a, b = self._eval(args[0], env), self._eval(args[1], env) + return a >= b + if op == '=' or op == '==': + a, b = self._eval(args[0], env), self._eval(args[1], env) + # For scalar Python types, return Python bool to enable trace-time if + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a == b) + return a == b + if op == '!=' or op == '<>': + a, b = self._eval(args[0], env), self._eval(args[1], env) + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return bool(a != b) + return a != b + + # Logic + if op == 'and': + vals = [self._eval(a, env) for a in args] + # Use Python and for concrete Python bools (e.g., shape comparisons) + if all(isinstance(v, (bool, np.bool_)) for v in vals): + result = True + for v in vals: + result = result and bool(v) + return result + # Otherwise use JAX logical_and + result = vals[0] + for v in vals[1:]: + result = jnp.logical_and(result, v) + return result + + if op == 'or': + # Lisp-style or: returns first truthy value, not boolean + # (or a b c) returns a if a is truthy, else b if b is truthy, else c + for arg in args: + val = self._eval(arg, env) + # Check if value is truthy + if val is None: + continue + if isinstance(val, (bool, np.bool_)): + if val: + return val + continue + if isinstance(val, (int, float)): + if val: + return val + continue + if hasattr(val, 'shape'): + # JAX/numpy array - return it (considered truthy) + return val + # For other types, check truthiness + if val: + return val + # All values were falsy, return the last one + return self._eval(args[-1], env) if args else None + + if op == 'not': + val = self._eval(args[0], env) + if isinstance(val, (bool, np.bool_)): + return not bool(val) + return jnp.logical_not(val) + + # Math functions + if op == 'sqrt': + return jnp.sqrt(self._eval(args[0], env)) + if op == 'sin': + return jnp.sin(self._eval(args[0], env)) + if op == 'cos': + return jnp.cos(self._eval(args[0], env)) + if op == 'tan': + return jnp.tan(self._eval(args[0], env)) + if op == 'exp': + return jnp.exp(self._eval(args[0], env)) + if op == 'log': + return jnp.log(self._eval(args[0], env)) + if op == 'abs': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return abs(x) + return jnp.abs(x) + if op == 'floor': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.floor(x) + return jnp.floor(x) + if op == 'ceil': + import math + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return math.ceil(x) + return jnp.ceil(x) + if op == 'round': + x = self._eval(args[0], env) + if isinstance(x, (int, float)): + return round(x) + return jnp.round(x) + + # Frame primitives + if op == 'width': + return env['width'] + if op == 'height': + return env['height'] + + if op == 'channel': + frame = self._eval(args[0], env) + idx = self._eval(args[1], env) + # idx should be a Python int (literal from S-expression) + return jax_channel(frame, idx) + + if op == 'merge-channels' or op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + if op == 'sample': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + return jax_sample(frame, x, y) + + if op == 'cell-indices': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_cell_indices(frame, cell_size) + + if op == 'pool-frame': + frame = self._eval(args[0], env) + cell_size = self._eval(args[1], env) + return jax_pool_frame(frame, cell_size) + + # Xector primitives + if op == 'iota': + n = self._eval(args[0], env) + return jax_iota(int(n)) + + if op == 'repeat': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_repeat(x, int(n)) + + if op == 'tile': + x = self._eval(args[0], env) + n = self._eval(args[1], env) + return jax_tile(x, int(n)) + + if op == 'gather': + data = self._eval(args[0], env) + indices = self._eval(args[1], env) + return jax_gather(data, indices) + + if op == 'scatter': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter(indices, values, size) + + if op == 'scatter-add': + indices = self._eval(args[0], env) + values = self._eval(args[1], env) + size = int(self._eval(args[2], env)) + return jax_scatter_add(indices, values, size) + + if op == 'group-reduce': + values = self._eval(args[0], env) + groups = self._eval(args[1], env) + num_groups = int(self._eval(args[2], env)) + reduce_op = args[3] if len(args) > 3 else 'mean' + if isinstance(reduce_op, Symbol): + reduce_op = reduce_op.name + return jax_group_reduce(values, groups, num_groups, reduce_op) + + if op == 'where': + cond = self._eval(args[0], env) + true_val = self._eval(args[1], env) + false_val = self._eval(args[2], env) + # Handle None values + if true_val is None: + true_val = 0.0 + if false_val is None: + false_val = 0.0 + return jax_where(cond, true_val, false_val) + + if op == 'len' or op == 'length': + x = self._eval(args[0], env) + if isinstance(x, (list, tuple)): + return len(x) + return x.size + + # Beta reductions + if op in ('β+', 'beta+', 'sum'): + return jnp.sum(self._eval(args[0], env)) + if op in ('β*', 'beta*', 'product'): + return jnp.prod(self._eval(args[0], env)) + if op in ('βmin', 'beta-min'): + return jnp.min(self._eval(args[0], env)) + if op in ('βmax', 'beta-max'): + return jnp.max(self._eval(args[0], env)) + if op in ('βmean', 'beta-mean', 'mean'): + return jnp.mean(self._eval(args[0], env)) + if op in ('βstd', 'beta-std'): + return jnp.std(self._eval(args[0], env)) + if op in ('βany', 'beta-any'): + return jnp.any(self._eval(args[0], env)) + if op in ('βall', 'beta-all'): + return jnp.all(self._eval(args[0], env)) + + # Scan (prefix) operations + if op in ('scan+', 'scan-add'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_add(x, axis) + if op in ('scan*', 'scan-mul'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_mul(x, axis) + if op == 'scan-max': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_max(x, axis) + if op == 'scan-min': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_min(x, axis) + + # Outer product operations + if op == 'outer': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + op_type = self._eval_kwarg(args, 'op', '*', env) + return jax_outer(x, y, op_type) + if op in ('outer+', 'outer-add'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_add(x, y) + if op in ('outer*', 'outer-mul'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_mul(x, y) + if op == 'outer-max': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_max(x, y) + if op == 'outer-min': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_min(x, y) + + # Reduce with axis operations + if op == 'reduce-axis': + x = self._eval(args[0], env) + reduce_op = self._eval_kwarg(args, 'op', 'sum', env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_reduce_axis(x, reduce_op, axis) + if op == 'sum-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_sum_axis(x, axis) + if op == 'mean-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_mean_axis(x, axis) + if op == 'max-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_max_axis(x, axis) + if op == 'min-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_min_axis(x, axis) + + # Windowed operations + if op == 'window': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + win_op = self._eval_kwarg(args, 'op', 'mean', env) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window(x, size, win_op, stride) + if op == 'window-sum': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_sum(x, size, stride) + if op == 'window-mean': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_mean(x, size, stride) + if op == 'window-max': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_max(x, size, stride) + if op == 'window-min': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_min(x, size, stride) + + # Integral image + if op == 'integral-image': + frame = self._eval(args[0], env) + return jax_integral_image(frame) + + # Convenience - min/max of two values (handle both scalars and arrays) + if op == 'min' or op == 'min2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return min(a, b) + return jnp.minimum(jnp.asarray(a), jnp.asarray(b)) + if op == 'max' or op == 'max2': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + # Use Python min/max for scalar Python values to preserve type + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return max(a, b) + return jnp.maximum(jnp.asarray(a), jnp.asarray(b)) + if op == 'clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + # List operations + if op == 'list': + return tuple(self._eval(a, env) for a in args) + + if op == 'nth': + seq = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[idx] if 0 <= idx < len(seq) else None + return seq[idx] # For arrays + + if op == 'first': + seq = self._eval(args[0], env) + return seq[0] if len(seq) > 0 else None + + if op == 'second': + seq = self._eval(args[0], env) + return seq[1] if len(seq) > 1 else None + + # Random (JAX-compatible) + # Get frame_num for deterministic variation - can be traced, fold_in handles it + frame_num = env.get('_frame_num', env.get('frame_num', 0)) + # Convert to int32 for fold_in if needed (but keep as JAX array if traced) + if frame_num is None: + frame_num = 0 + elif isinstance(frame_num, (int, float)): + frame_num = int(frame_num) + # If it's a JAX array, leave it as-is for tracing + + # Increment operation counter for unique keys within same frame + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + + if op == 'rand' or op == 'rand-x': + # For size-based random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + shape = (n,) + else: + n = size.size + shape = size.shape + elif hasattr(size, 'size'): + n = size.size + shape = (n,) + else: + n = int(size) + shape = (n,) + # Use deterministic key that varies with frame + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, shape).flatten() + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.uniform(key, ()) + + if op == 'randn' or op == 'randn-x': + # Normal random + if args: + size = self._eval(args[0], env) + if hasattr(size, 'shape'): + # For frames (3D), use h*w (channel size), not h*w*c + if size.ndim == 3: + n = size.shape[0] * size.shape[1] # h * w + else: + n = size.size + elif hasattr(size, 'size'): + n = size.size + else: + n = int(size) + mean = self._eval(args[1], env) if len(args) > 1 else 0.0 + std = self._eval(args[2], env) if len(args) > 2 else 1.0 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, (n,)) * std + mean + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) + + if op == 'rand-range' or op == 'core:rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + seed = env.get('_seed', 42) + return jax_rand_range(lo, hi, frame_num, op_counter, seed) + + # ===================================================================== + # Convolution operations + # ===================================================================== + if op == 'blur' or op == 'image:blur': + frame = self._eval(args[0], env) + radius = self._eval(args[1], env) if len(args) > 1 else 1 + # Convert traced value to concrete for kernel size + if hasattr(radius, 'item'): + radius = int(radius.item()) + elif hasattr(radius, '__float__'): + radius = int(float(radius)) + else: + radius = int(radius) + return jax_blur(frame, max(1, radius)) + + if op == 'gaussian': + first_arg = self._eval(args[0], env) + # Check if first arg is a frame (blur) or scalar (random) + if hasattr(first_arg, 'shape') and first_arg.ndim == 3: + # Frame - apply gaussian blur + sigma = self._eval(args[1], env) if len(args) > 1 else 1.0 + radius = max(1, int(sigma * 3)) + return jax_blur(first_arg, radius) + else: + # Scalar args - generate gaussian random value + mean = float(first_arg) if not isinstance(first_arg, (int, float)) else first_arg + std = self._eval(args[1], env) if len(args) > 1 else 1.0 + # Return a single random value + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + return jax.random.normal(key, ()) * std + mean + + if op == 'sharpen' or op == 'image:sharpen': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + return jax_sharpen(frame, amount) + + if op == 'edge-detect' or op == 'image:edge-detect': + frame = self._eval(args[0], env) + return jax_edge_detect(frame) + + if op == 'emboss': + frame = self._eval(args[0], env) + return jax_emboss(frame) + + if op == 'convolve': + frame = self._eval(args[0], env) + kernel = self._eval(args[1], env) + # Convert kernel to array if it's a list + if isinstance(kernel, (list, tuple)): + kernel = jnp.array(kernel, dtype=jnp.float32) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'add-noise': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) if len(args) > 1 else 0.1 + h, w = frame.shape[:2] + # Use frame-varying key for noise + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + seed = env.get('_seed', 42) + key = make_jax_key(seed, frame_num, op_counter) + noise = jax.random.uniform(key, frame.shape) * 2 - 1 # [-1, 1] + result = frame.astype(jnp.float32) + noise * amount * 255 + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + if op == 'translate': + frame = self._eval(args[0], env) + dx = self._eval(args[1], env) + dy = self._eval(args[2], env) if len(args) > 2 else 0 + h, w = frame.shape[:2] + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w) + src_x = (x_coords - dx).flatten() + src_y = (y_coords - dy).flatten() + r, g, b = jax_sample(frame, src_x, src_y) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'image:crop' or op == 'crop': + frame = self._eval(args[0], env) + x = int(self._eval(args[1], env)) + y = int(self._eval(args[2], env)) + w = int(self._eval(args[3], env)) + h = int(self._eval(args[4], env)) + return frame[y:y+h, x:x+w, :] + + if op == 'dilate': + frame = self._eval(args[0], env) + size = int(self._eval(args[1], env)) if len(args) > 1 else 3 + # Simple dilation using max pooling approximation + kernel = jnp.ones((size, size), dtype=jnp.float32) / (size * size) + h, w = frame.shape[:2] + r = jax_convolve2d(frame[:, :, 0].astype(jnp.float32), kernel) * (size * size) + g = jax_convolve2d(frame[:, :, 1].astype(jnp.float32), kernel) * (size * size) + b = jax_convolve2d(frame[:, :, 2].astype(jnp.float32), kernel) * (size * size) + return jnp.stack([ + jnp.clip(r, 0, 255).astype(jnp.uint8), + jnp.clip(g, 0, 255).astype(jnp.uint8), + jnp.clip(b, 0, 255).astype(jnp.uint8) + ], axis=2) + + if op == 'map-rows': + frame = self._eval(args[0], env) + fn = args[1] # S-expression function + h, w = frame.shape[:2] + # For each row, apply the function + results = [] + for row_idx in range(h): + row_env = env.copy() + row_env['row'] = frame[row_idx, :, :] + row_env['row-idx'] = row_idx + + # Check if fn is a lambda + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + # Bind lambda params to y and row + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + row_env[param_name] = row_idx # y + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + row_env[param_name] = frame[row_idx, :, :] # row + result_row = self._eval(body, row_env) + results.append(result_row) + continue + + result_row = self._eval(fn, row_env) + # If result is a function, call it + if callable(result_row): + result_row = result_row(row_idx, frame[row_idx, :, :]) + results.append(result_row) + return jnp.stack(results, axis=0) + + # ===================================================================== + # Text rendering operations + # ===================================================================== + if op == 'text': + frame = self._eval(args[0], env) + text_str = self._eval(args[1], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + + # Extract keyword arguments + x = self._eval_kwarg(args, 'x', None, env) + y = self._eval_kwarg(args, 'y', None, env) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + color = self._eval_kwarg(args, 'color', (255, 255, 255), env) + opacity = self._eval_kwarg(args, 'opacity', 1.0, env) + align = self._eval_kwarg(args, 'align', 'left', env) + valign = self._eval_kwarg(args, 'valign', 'top', env) + shadow = self._eval_kwarg(args, 'shadow', False, env) + shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env) + shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env) + fit = self._eval_kwarg(args, 'fit', False, env) + width = self._eval_kwarg(args, 'width', None, env) + height = self._eval_kwarg(args, 'height', None, env) + + # Handle color as list/tuple + if isinstance(color, (list, tuple)): + color = tuple(int(c) for c in color[:3]) + if isinstance(shadow_color, (list, tuple)): + shadow_color = tuple(int(c) for c in shadow_color[:3]) + + h, w_frame = frame.shape[:2] + + # Default position to 0,0 or center based on alignment + if x is None: + if align == 'center': + x = w_frame // 2 + elif align == 'right': + x = w_frame + else: + x = 0 + if y is None: + if valign == 'middle': + y = h // 2 + elif valign == 'bottom': + y = h + else: + y = 0 + + # Auto-fit text to bounds + if fit and width is not None and height is not None: + font_size = jax_fit_text_size(text_str, int(width), int(height), + font_name, min_size=8, max_size=200) + + return jax_text_render(frame, text_str, int(x), int(y), + font_name=font_name, font_size=int(font_size), + color=color, opacity=float(opacity), + align=str(align), valign=str(valign), + shadow=bool(shadow), shadow_color=shadow_color, + shadow_offset=int(shadow_offset)) + + if op == 'text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return jax_text_size(text_str, font_name, int(font_size)) + + if op == 'fit-text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + max_width = int(self._eval(args[1], env)) + max_height = int(self._eval(args[2], env)) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return jax_fit_text_size(text_str, max_width, max_height, font_name) + + # ===================================================================== + # Color operations + # ===================================================================== + if op == 'rgb->hsv' or op == 'rgb-to-hsv': + # Handle both (rgb->hsv r g b) and (rgb->hsv c) where c is tuple + if len(args) == 1: + c = self._eval(args[0], env) + if isinstance(c, tuple) and len(c) == 3: + r, g, b = c + else: + # Assume it's a list-like + r, g, b = c[0], c[1], c[2] + else: + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + return jax_rgb_to_hsv(r, g, b) + + if op == 'hsv->rgb' or op == 'hsv-to-rgb': + # Handle both (hsv->rgb h s v) and (hsv->rgb hsv-list) + if len(args) == 1: + hsv = self._eval(args[0], env) + if isinstance(hsv, (tuple, list)) and len(hsv) >= 3: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h, s, v = hsv[0], hsv[1], hsv[2] + else: + h = self._eval(args[0], env) + s = self._eval(args[1], env) + v = self._eval(args[2], env) + return jax_hsv_to_rgb(h, s, v) + + if op == 'adjust-brightness' or op == 'color_ops:adjust-brightness': + frame = self._eval(args[0], env) + amount = self._eval(args[1], env) + return jax_adjust_brightness(frame, amount) + + if op == 'adjust-contrast' or op == 'color_ops:adjust-contrast': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_contrast(frame, factor) + + if op == 'adjust-saturation' or op == 'color_ops:adjust-saturation': + frame = self._eval(args[0], env) + factor = self._eval(args[1], env) + return jax_adjust_saturation(frame, factor) + + if op == 'shift-hsv' or op == 'color_ops:shift-hsv' or op == 'hue-shift': + frame = self._eval(args[0], env) + degrees = self._eval(args[1], env) + return jax_shift_hue(frame, degrees) + + if op == 'invert' or op == 'invert-img' or op == 'color_ops:invert-img': + frame = self._eval(args[0], env) + return jax_invert(frame) + + if op == 'posterize' or op == 'color_ops:posterize': + frame = self._eval(args[0], env) + levels = self._eval(args[1], env) + return jax_posterize(frame, levels) + + if op == 'threshold' or op == 'color_ops:threshold': + frame = self._eval(args[0], env) + level = self._eval(args[1], env) + invert = self._eval(args[2], env) if len(args) > 2 else False + return jax_threshold(frame, level, invert) + + if op == 'sepia' or op == 'color_ops:sepia': + frame = self._eval(args[0], env) + return jax_sepia(frame) + + if op == 'grayscale' or op == 'image:grayscale': + frame = self._eval(args[0], env) + return jax_grayscale(frame) + + # ===================================================================== + # Geometry operations + # ===================================================================== + if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img': + frame = self._eval(args[0], env) + direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal' + if direction == 'vertical' or direction == 'v': + return jax_flip_vertical(frame) + return jax_flip_horizontal(frame) + + if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v': + frame = self._eval(args[0], env) + return jax_flip_vertical(frame) + + if op == 'rotate' or op == 'rotate-img' or op == 'geometry:rotate-img': + frame = self._eval(args[0], env) + angle = self._eval(args[1], env) + return jax_rotate(frame, angle) + + if op == 'scale' or op == 'scale-img' or op == 'geometry:scale-img': + frame = self._eval(args[0], env) + scale_x = self._eval(args[1], env) + scale_y = self._eval(args[2], env) if len(args) > 2 else None + return jax_scale(frame, scale_x, scale_y) + + if op == 'resize' or op == 'image:resize': + frame = self._eval(args[0], env) + new_w = self._eval(args[1], env) + new_h = self._eval(args[2], env) + return jax_resize(frame, new_w, new_h) + + # ===================================================================== + # Geometry distortion effects + # ===================================================================== + if op == 'geometry:fisheye-coords' or op == 'fisheye': + # Signature: (w h strength cx cy zoom_correct) or (frame strength) + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + # (w h strength cx cy zoom_correct) signature + w = int(first_arg) + h = int(self._eval(args[1], env)) + strength = self._eval(args[2], env) if len(args) > 2 else 0.5 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = first_arg + strength = self._eval(args[1], env) if len(args) > 1 else 0.5 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Fisheye distortion + r_new = r + strength * r * (1 - r / max_r) + + src_x = r_new * jnp.cos(theta) + cx + src_y = r_new * jnp.sin(theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:swirl-coords' or op == 'swirl': + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + w = int(first_arg) + h = int(self._eval(args[1], env)) + amount = self._eval(args[2], env) if len(args) > 2 else 1.0 + frame = None + else: + frame = first_arg + amount = self._eval(args[1], env) if len(args) > 1 else 1.0 + h, w = frame.shape[:2] + + cx, cy = w / 2, h / 2 + max_r = jnp.sqrt(float(cx*cx + cy*cy)) + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + swirl_angle = amount * (1 - r / max_r) + new_theta = theta + swirl_angle + + src_x = r * jnp.cos(new_theta) + cx + src_y = r * jnp.sin(new_theta) + cy + + if frame is None: + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Wave effect (frame-first signature for simple usage) + if op == 'wave-distort': + first_arg = self._eval(args[0], env) + frame = first_arg + amp_x = float(self._eval(args[1], env)) if len(args) > 1 else 10.0 + amp_y = float(self._eval(args[2], env)) if len(args) > 2 else 10.0 + freq_x = float(self._eval(args[3], env)) if len(args) > 3 else 0.1 + freq_y = float(self._eval(args[4], env)) if len(args) > 4 else 0.1 + h, w = frame.shape[:2] + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + src_x = x_coords + amp_x * jnp.sin(y_coords * freq_y) + src_y = y_coords + amp_y * jnp.sin(x_coords * freq_x) + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:ripple-displace' or op == 'ripple': + # Match Python prim_ripple_displace signature: + # (w h freq amp cx cy decay phase) or (frame ...) + first_arg = self._eval(args[0], env) + if not hasattr(first_arg, 'shape'): + # Coordinate-only mode: (w h freq amp cx cy decay phase) + w = int(first_arg) + h = int(self._eval(args[1], env)) + freq = self._eval(args[2], env) if len(args) > 2 else 5.0 + amp = self._eval(args[3], env) if len(args) > 3 else 10.0 + cx = self._eval(args[4], env) if len(args) > 4 else w / 2 + cy = self._eval(args[5], env) if len(args) > 5 else h / 2 + decay = self._eval(args[6], env) if len(args) > 6 else 0.0 + phase = self._eval(args[7], env) if len(args) > 7 else 0.0 + frame = None + else: + # Frame mode: (frame :amplitude A :frequency F :center_x CX ...) + frame = first_arg + h, w = frame.shape[:2] + # Parse keyword args + amp = 10.0 + freq = 5.0 + cx = w / 2 + cy = h / 2 + decay = 0.0 + phase = 0.0 + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + kw = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + if kw == 'amplitude': + amp = val + elif kw == 'frequency': + freq = val + elif kw == 'center_x': + cx = val * w if val <= 1 else val # normalized or absolute + elif kw == 'center_y': + cy = val * h if val <= 1 else val + elif kw == 'decay': + decay = val + elif kw == 'speed': + # speed affects phase via time + t = env.get('t', 0) + phase = t * val * 2 * jnp.pi + elif kw == 'phase': + phase = val + i += 2 + else: + i += 1 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + dist = jnp.sqrt(dx*dx + dy*dy) + + # Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp + max_dim = jnp.maximum(w, h) + ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp + + # Apply decay (when decay=0, exp(0)=1 so no effect) + decay_factor = jnp.exp(-decay * dist / max_dim) + ripple = ripple * decay_factor + + # Radial displacement - use ADDITION to match Python prim_ripple_displace + # Python (primitives.py line 2890-2891): + # map_x = x_coords + ripple * norm_dx + # map_y = y_coords + ripple * norm_dy + # where norm_dx = dx/dist = cos(angle), norm_dy = dy/dist = sin(angle) + angle = jnp.arctan2(dy, dx) + src_x = x_coords + ripple * jnp.cos(angle) + src_y = y_coords + ripple * jnp.sin(angle) + + if frame is None: + return {'x': src_x, 'y': src_y} + + # Sample using bilinear interpolation (jax_sample clamps coords, + # matching OpenCV's default remap behavior) + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:coords-x' or op == 'coords-x': + # Extract x coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('x', coords.get('map_x')) + return coords[0] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:coords-y' or op == 'coords-y': + # Extract y coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('y', coords.get('map_y')) + return coords[1] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:remap' or op == 'remap': + # Remap image using coordinate maps: (frame map_x map_y) + # OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords + frame = self._eval(args[0], env) + map_x = self._eval(args[1], env) + map_y = self._eval(args[2], env) + + h, w = frame.shape[:2] + + # Flatten coordinate maps + src_x = map_x.flatten() + src_y = map_y.flatten() + + # Sample using bilinear interpolation (jax_sample clamps coords internally, + # matching OpenCV's default behavior) + r_out, g_out, b_out = jax_sample(frame, src_x, src_y) + + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:kaleidoscope-coords' or op == 'kaleidoscope': + # Two signatures: (frame segments) or (w h segments cx cy) + if len(args) >= 3 and not hasattr(self._eval(args[0], env), 'shape'): + # (w h segments cx cy) signature + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + segments = int(self._eval(args[2], env)) if len(args) > 2 else 6 + cx = self._eval(args[3], env) if len(args) > 3 else w / 2 + cy = self._eval(args[4], env) if len(args) > 4 else h / 2 + frame = None + else: + frame = self._eval(args[0], env) + segments = int(self._eval(args[1], env)) if len(args) > 1 else 6 + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + dx = x_coords - cx + dy = y_coords - cy + r = jnp.sqrt(dx*dx + dy*dy) + theta = jnp.arctan2(dy, dx) + + # Mirror into segments + segment_angle = 2 * jnp.pi / segments + theta_mod = theta % segment_angle + theta_mirror = jnp.where( + (jnp.floor(theta / segment_angle) % 2) == 0, + theta_mod, + segment_angle - theta_mod + ) + + src_x = r * jnp.cos(theta_mirror) + cx + src_y = r * jnp.sin(theta_mirror) + cy + + if frame is None: + # Return coordinate arrays + return {'x': src_x, 'y': src_y} + + r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # Geometry coordinate extraction + if op == 'geometry:coords-x': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['x'] + return coords[0] if isinstance(coords, tuple) else coords + + if op == 'geometry:coords-y': + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords['y'] + return coords[1] if isinstance(coords, tuple) else coords + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x_coords = self._eval(args[1], env) + y_coords = self._eval(args[2], env) + h, w = frame.shape[:2] + r_out, g_out, b_out = jax_sample(frame, x_coords.flatten(), y_coords.flatten()) + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blend' or op == 'blend-images' or op == 'blending:blend-images': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(frame1, frame2, alpha) + + if op == 'blend-add': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_add(frame1, frame2) + + if op == 'blend-multiply': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_multiply(frame1, frame2) + + if op == 'blend-screen': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_screen(frame1, frame2) + + if op == 'blend-overlay': + frame1 = self._eval(args[0], env) + frame2 = self._eval(args[1], env) + return jax_blend_overlay(frame1, frame2) + + # ===================================================================== + # Image dimension queries (namespaced aliases) + # ===================================================================== + if op == 'image:width': + if args: + frame = self._eval(args[0], env) + return frame.shape[1] # width is second dimension (h, w, c) + return env['width'] + + if op == 'image:height': + if args: + frame = self._eval(args[0], env) + return frame.shape[0] # height is first dimension (h, w, c) + return env['height'] + + # ===================================================================== + # Utility + # ===================================================================== + if op == 'is-nil' or op == 'core:is-nil' or op == 'nil?': + x = self._eval(args[0], env) + return jax_is_nil(x) + + # ===================================================================== + # Xector channel operations (shortcuts) + # ===================================================================== + if op == 'red': + val = self._eval(args[0], env) + # Works on frames or pixel tuples + if isinstance(val, tuple): + return val[0] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 0) + else: + return val # Assume it's already a channel + + if op == 'green': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[1] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 1) + else: + return val + + if op == 'blue': + val = self._eval(args[0], env) + if isinstance(val, tuple): + return val[2] + elif hasattr(val, 'shape') and val.ndim == 3: + return jax_channel(val, 2) + else: + return val + + if op == 'gray' or op == 'luminance': + val = self._eval(args[0], env) + # Handle tuple (r, g, b) from map-pixels + if isinstance(val, tuple) and len(val) == 3: + r, g, b = val + return r * 0.299 + g * 0.587 + b * 0.114 + # Handle frame + frame = val + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + return r * 0.299 + g * 0.587 + b * 0.114 + + if op == 'rgb': + r = self._eval(args[0], env) + g = self._eval(args[1], env) + b = self._eval(args[2], env) + # For scalars (e.g., in map-pixels), return tuple + r_is_scalar = isinstance(r, (int, float)) or (hasattr(r, 'shape') and r.shape == ()) + g_is_scalar = isinstance(g, (int, float)) or (hasattr(g, 'shape') and g.shape == ()) + b_is_scalar = isinstance(b, (int, float)) or (hasattr(b, 'shape') and b.shape == ()) + if r_is_scalar and g_is_scalar and b_is_scalar: + return (r, g, b) + return jax_merge_channels(r, g, b, env['_shape']) + + # ===================================================================== + # Coordinate operations + # ===================================================================== + if op == 'x-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + + if op == 'y-coords': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + return jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + if op == 'dist-from-center': + frame = self._eval(args[0], env) + h, w = frame.shape[:2] + cx, cy = w / 2, h / 2 + x = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) - cx + y = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) - cy + return jnp.sqrt(x*x + y*y) + + # ===================================================================== + # Alpha operations (element-wise on xectors) + # ===================================================================== + if op == 'α/' or op == 'alpha/': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a / b + + if op == 'α+' or op == 'alpha+': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result + v + return result + + if op == 'α*' or op == 'alpha*': + vals = [self._eval(a, env) for a in args] + result = vals[0] + for v in vals[1:]: + result = result * v + return result + + if op == 'α-' or op == 'alpha-': + if len(args) == 1: + return -self._eval(args[0], env) + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a - b + + if op == 'αclamp' or op == 'alpha-clamp': + x = self._eval(args[0], env) + lo = self._eval(args[1], env) + hi = self._eval(args[2], env) + return jnp.clip(x, lo, hi) + + if op == 'αmin' or op == 'alpha-min': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.minimum(a, b) + + if op == 'αmax' or op == 'alpha-max': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.maximum(a, b) + + if op == 'αsqrt' or op == 'alpha-sqrt': + return jnp.sqrt(self._eval(args[0], env)) + + if op == 'αsin' or op == 'alpha-sin': + return jnp.sin(self._eval(args[0], env)) + + if op == 'αcos' or op == 'alpha-cos': + return jnp.cos(self._eval(args[0], env)) + + if op == 'αmod' or op == 'alpha-mod': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a % b + + if op == 'α²' or op == 'αsq' or op == 'alpha-sq': + x = self._eval(args[0], env) + return x * x + + if op == 'α<' or op == 'alpha<': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a < b + + if op == 'α>' or op == 'alpha>': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a > b + + if op == 'α<=' or op == 'alpha<=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a <= b + + if op == 'α>=' or op == 'alpha>=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a >= b + + if op == 'α=' or op == 'alpha=': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return a == b + + if op == 'αfloor' or op == 'alpha-floor': + return jnp.floor(self._eval(args[0], env)) + + if op == 'αceil' or op == 'alpha-ceil': + return jnp.ceil(self._eval(args[0], env)) + + if op == 'αround' or op == 'alpha-round': + return jnp.round(self._eval(args[0], env)) + + if op == 'αabs' or op == 'alpha-abs': + return jnp.abs(self._eval(args[0], env)) + + if op == 'αexp' or op == 'alpha-exp': + return jnp.exp(self._eval(args[0], env)) + + if op == 'αlog' or op == 'alpha-log': + return jnp.log(self._eval(args[0], env)) + + if op == 'αor' or op == 'alpha-or': + # Element-wise logical OR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_or(a, b) + + if op == 'αand' or op == 'alpha-and': + # Element-wise logical AND + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_and(a, b) + + if op == 'αnot' or op == 'alpha-not': + # Element-wise logical NOT + return jnp.logical_not(self._eval(args[0], env)) + + if op == 'αxor' or op == 'alpha-xor': + # Element-wise logical XOR + a = self._eval(args[0], env) + b = self._eval(args[1], env) + return jnp.logical_xor(a, b) + + # ===================================================================== + # Threading/arrow operations + # ===================================================================== + if op == '->': + # Thread-first macro: (-> x (f a) (g b)) = (g (f x a) b) + val = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list): + # Insert val as first argument + fn_name = form[0].name if isinstance(form[0], Symbol) else form[0] + new_args = [val] + [self._eval(a, env) for a in form[1:]] + val = self._eval_call(fn_name, [val] + form[1:], env) + else: + # Simple function call + fn_name = form.name if isinstance(form, Symbol) else form + val = self._eval_call(fn_name, [args[0]], env) + return val + + # ===================================================================== + # Range and iteration + # ===================================================================== + if op == 'range': + if len(args) == 1: + end = int(self._eval(args[0], env)) + return list(range(end)) + elif len(args) == 2: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + return list(range(start, end)) + else: + start = int(self._eval(args[0], env)) + end = int(self._eval(args[1], env)) + step = int(self._eval(args[2], env)) + return list(range(start, end, step)) + + if op == 'reduce' or op == 'fold': + # (reduce seq init fn) - left fold + seq = self._eval(args[0], env) + acc = self._eval(args[1], env) + fn = args[2] # Lambda S-expression + + # Handle lambda + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + for item in seq: + fn_env = env.copy() + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + fn_env[param_name] = acc + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + fn_env[param_name] = item + acc = self._eval(body, fn_env) + return acc + + # Fallback - try evaluating fn and calling it + fn_eval = self._eval(fn, env) + if callable(fn_eval): + for item in seq: + acc = fn_eval(acc, item) + return acc + + if op == 'fold-indexed': + # (fold-indexed seq init fn) - fold with index + # fn takes (acc item index) or (acc item index cursor) for typography + seq = self._eval(args[0], env) + acc = self._eval(args[1], env) + fn = args[2] # Lambda S-expression + + # Handle lambda + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + for idx, item in enumerate(seq): + fn_env = env.copy() + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + fn_env[param_name] = acc + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + fn_env[param_name] = item + if len(params) >= 3: + param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) + fn_env[param_name] = idx + acc = self._eval(body, fn_env) + return acc + + # Fallback + fn_eval = self._eval(fn, env) + if callable(fn_eval): + for idx, item in enumerate(seq): + acc = fn_eval(acc, item, idx) + return acc + + # ===================================================================== + # Map-pixels (apply function to each pixel) + # ===================================================================== + if op == 'map-pixels': + frame = self._eval(args[0], env) + fn = args[1] # Lambda or S-expression + h, w = frame.shape[:2] + + # Extract channels + r = frame[:, :, 0].flatten().astype(jnp.float32) + g = frame[:, :, 1].flatten().astype(jnp.float32) + b = frame[:, :, 2].flatten().astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w, dtype=jnp.float32), h) + y_coords = jnp.repeat(jnp.arange(h, dtype=jnp.float32), w) + + # Set up pixel environment + pixel_env = env.copy() + pixel_env['r'] = r + pixel_env['g'] = g + pixel_env['b'] = b + pixel_env['x'] = x_coords + pixel_env['y'] = y_coords + # Also provide c (color) as a tuple for lambda (x y c) style + pixel_env['c'] = (r, g, b) + + # If fn is a lambda, we need to handle it specially + if isinstance(fn, list) and len(fn) >= 2: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + # Lambda: (lambda (x y c) body) + params = fn[1] + body = fn[2] + # Bind parameters + if len(params) >= 1: + param_name = params[0].name if isinstance(params[0], Symbol) else str(params[0]) + pixel_env[param_name] = x_coords + if len(params) >= 2: + param_name = params[1].name if isinstance(params[1], Symbol) else str(params[1]) + pixel_env[param_name] = y_coords + if len(params) >= 3: + param_name = params[2].name if isinstance(params[2], Symbol) else str(params[2]) + pixel_env[param_name] = (r, g, b) + result = self._eval(body, pixel_env) + else: + result = self._eval(fn, pixel_env) + else: + result = self._eval(fn, pixel_env) + + if isinstance(result, tuple) and len(result) == 3: + nr, ng, nb = result + return jax_merge_channels(nr, ng, nb, (h, w)) + elif hasattr(result, 'shape') and result.ndim == 3: + return result + else: + # Single channel result + if hasattr(result, 'flatten'): + result = result.flatten() + gray = jnp.clip(result, 0, 255).reshape(h, w).astype(jnp.uint8) + return jnp.stack([gray, gray, gray], axis=2) + + # ===================================================================== + # State operations (return unchanged for stateless JIT) + # ===================================================================== + if op == 'state-get': + key = self._eval(args[0], env) + default = self._eval(args[1], env) if len(args) > 1 else None + return default # State not supported in JIT, return default + + if op == 'state-set': + return None # No-op in JIT + + # ===================================================================== + # Cell/grid operations + # ===================================================================== + if op == 'local-x-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size) / max(1, cell_size - 1) + + if op == 'local-y-norm': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size) / max(1, cell_size - 1) + + if op == 'local-x': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return (x % cell_size).astype(jnp.float32) + + if op == 'local-y': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return (y % cell_size).astype(jnp.float32) + + if op == 'cell-row': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + y = jnp.repeat(jnp.arange(h), w) + return jnp.floor(y / cell_size) + + if op == 'cell-col': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + h, w = frame.shape[:2] + x = jnp.tile(jnp.arange(w), h) + return jnp.floor(x / cell_size) + + if op == 'num-rows': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[0] // cell_size + + if op == 'num-cols': + frame = self._eval(args[0], env) + cell_size = int(self._eval(args[1], env)) + return frame.shape[1] // cell_size + + # ===================================================================== + # Control flow + # ===================================================================== + if op == 'cond': + # (cond (test1 expr1) (test2 expr2) ... (else exprN)) + # For JAX compatibility, build a nested jnp.where structure + # Start from the else clause and work backwards + + # Collect clauses + clauses = [] + else_expr = None + for clause in args: + if isinstance(clause, list) and len(clause) >= 2: + test = clause[0] + if isinstance(test, Symbol) and test.name == 'else': + else_expr = clause[1] + else: + clauses.append((test, clause[1])) + + # If no else, default to None/0 + if else_expr is not None: + result = self._eval(else_expr, env) + else: + result = 0 + + # Build nested where from last to first + for test_expr, val_expr in reversed(clauses): + cond_val = self._eval(test_expr, env) + then_val = self._eval(val_expr, env) + + # Check if condition is array or scalar + if hasattr(cond_val, 'shape') and cond_val.shape != (): + # Array condition - use jnp.where + result = jnp.where(cond_val, then_val, result) + else: + # Scalar - can use Python if + if cond_val: + result = then_val + + return result + + if op == 'set!' or op == 'set': + # Mutation - not really supported in JAX, but we can update env + var = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + val = self._eval(args[1], env) + env[var] = val + return val + + if op == 'begin' or op == 'do': + # Evaluate all expressions, return last + result = None + for expr in args: + result = self._eval(expr, env) + return result + + # ===================================================================== + # Additional math + # ===================================================================== + if op == 'sq' or op == 'square': + x = self._eval(args[0], env) + return x * x + + if op == 'lerp': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + t = self._eval(args[2], env) + return a * (1 - t) + b * t + + if op == 'smoothstep': + edge0 = self._eval(args[0], env) + edge1 = self._eval(args[1], env) + x = self._eval(args[2], env) + t = jnp.clip((x - edge0) / (edge1 - edge0), 0, 1) + return t * t * (3 - 2 * t) + + if op == 'atan2': + y = self._eval(args[0], env) + x = self._eval(args[1], env) + return jnp.arctan2(y, x) + + if op == 'fract' or op == 'frac': + x = self._eval(args[0], env) + return x - jnp.floor(x) + + # ===================================================================== + # Frame copy and construction operations + # ===================================================================== + if op == 'pixel': + # Get pixel at (x, y) from frame + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + # Convert to int and clip to bounds + if isinstance(x, (int, float)): + x = max(0, min(int(x), w - 1)) + else: + x = jnp.clip(x, 0, w - 1).astype(jnp.int32) + if isinstance(y, (int, float)): + y = max(0, min(int(y), h - 1)) + else: + y = jnp.clip(y, 0, h - 1).astype(jnp.int32) + r = frame[y, x, 0] + g = frame[y, x, 1] + b = frame[y, x, 2] + return (r, g, b) + + if op == 'copy': + frame = self._eval(args[0], env) + return frame.copy() if hasattr(frame, 'copy') else jnp.array(frame) + + if op == 'make-image': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + if len(args) > 2: + color = self._eval(args[2], env) + if isinstance(color, (list, tuple)): + r, g, b = int(color[0]), int(color[1]), int(color[2]) + else: + r = g = b = int(color) + else: + r = g = b = 0 + img = jnp.zeros((h, w, 3), dtype=jnp.uint8) + img = img.at[:, :, 0].set(r) + img = img.at[:, :, 1].set(g) + img = img.at[:, :, 2].set(b) + return img + + if op == 'paste': + dest = self._eval(args[0], env) + src = self._eval(args[1], env) + x = int(self._eval(args[2], env)) + y = int(self._eval(args[3], env)) + sh, sw = src.shape[:2] + dh, dw = dest.shape[:2] + # Clip to dest bounds + x1, y1 = max(0, x), max(0, y) + x2, y2 = min(dw, x + sw), min(dh, y + sh) + sx1, sy1 = x1 - x, y1 - y + sx2, sy2 = sx1 + (x2 - x1), sy1 + (y2 - y1) + result = dest.copy() if hasattr(dest, 'copy') else jnp.array(dest) + result = result.at[y1:y2, x1:x2, :].set(src[sy1:sy2, sx1:sx2, :]) + return result + + # ===================================================================== + # Blending operations + # ===================================================================== + if op == 'blending:blend-images' or op == 'blend-images': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + alpha = self._eval(args[2], env) if len(args) > 2 else 0.5 + return jax_blend(a, b, alpha) + + if op == 'blending:blend-mode' or op == 'blend-mode': + a = self._eval(args[0], env) + b = self._eval(args[1], env) + mode = self._eval(args[2], env) if len(args) > 2 else 'add' + if mode == 'add': + return jax_blend_add(a, b) + elif mode == 'multiply': + return jax_blend_multiply(a, b) + elif mode == 'screen': + return jax_blend_screen(a, b) + elif mode == 'overlay': + return jax_blend_overlay(a, b) + elif mode == 'lighten': + return jnp.maximum(a, b) + elif mode == 'darken': + return jnp.minimum(a, b) + elif mode == 'difference': + return jnp.abs(a.astype(jnp.int16) - b.astype(jnp.int16)).astype(jnp.uint8) + else: + return jax_blend(a, b, 0.5) + + # ===================================================================== + # Geometry coordinate operations + # ===================================================================== + if op == 'geometry:wave-coords' or op == 'wave-coords': + w = int(self._eval(args[0], env)) + h = int(self._eval(args[1], env)) + axis = self._eval(args[2], env) if len(args) > 2 else 'x' + freq = self._eval(args[3], env) if len(args) > 3 else 1.0 + amplitude = self._eval(args[4], env) if len(args) > 4 else 10.0 + phase = self._eval(args[5], env) if len(args) > 5 else 0.0 + + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + + if axis == 'x' or axis == 'horizontal': + # Wave displaces X based on Y + offset = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + src_x = x_coords + offset + src_y = y_coords + elif axis == 'y' or axis == 'vertical': + # Wave displaces Y based on X + offset = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + src_y = y_coords + offset + else: # both + offset_x = amplitude * jnp.sin(2 * jnp.pi * freq * y_coords / h + phase) + offset_y = amplitude * jnp.sin(2 * jnp.pi * freq * x_coords / w + phase) + src_x = x_coords + offset_x + src_y = y_coords + offset_y + + return {'x': src_x, 'y': src_y} + + if op == 'geometry:coords-x' or op == 'coords-x': + coords = self._eval(args[0], env) + return coords['x'] + + if op == 'geometry:coords-y' or op == 'coords-y': + coords = self._eval(args[0], env) + return coords['y'] + + if op == 'geometry:remap' or op == 'remap': + frame = self._eval(args[0], env) + x = self._eval(args[1], env) + y = self._eval(args[2], env) + h, w = frame.shape[:2] + r, g, b = jax_sample(frame, x.flatten(), y.flatten()) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + # ===================================================================== + # Glitch effects + # ===================================================================== + if op == 'pixelsort': + frame = self._eval(args[0], env) + sort_by = self._eval(args[1], env) if len(args) > 1 else 'lightness' + thresh_lo = int(self._eval(args[2], env)) if len(args) > 2 else 50 + thresh_hi = int(self._eval(args[3], env)) if len(args) > 3 else 200 + angle = int(self._eval(args[4], env)) if len(args) > 4 else 0 + reverse = self._eval(args[5], env) if len(args) > 5 else False + + h, w = frame.shape[:2] + result = frame.copy() + + # Get luminance for thresholding + lum = (frame[:, :, 0].astype(jnp.float32) * 0.299 + + frame[:, :, 1].astype(jnp.float32) * 0.587 + + frame[:, :, 2].astype(jnp.float32) * 0.114) + + # Sort each row + for y in range(h): + row_lum = lum[y, :] + row = frame[y, :, :] + + # Find mask of pixels to sort + mask = (row_lum >= thresh_lo) & (row_lum <= thresh_hi) + + # Get indices where we should sort + sort_indices = jnp.where(mask, jnp.arange(w), -1) + + # Simple sort by luminance for the row + if sort_by == 'lightness': + sort_key = row_lum + elif sort_by == 'hue': + # Approximate hue from RGB + sort_key = jnp.arctan2(row[:, 1].astype(jnp.float32) - row[:, 2].astype(jnp.float32), + row[:, 0].astype(jnp.float32) - 0.5 * (row[:, 1].astype(jnp.float32) + row[:, 2].astype(jnp.float32))) + else: + sort_key = row_lum + + # Sort pixels in masked region + sorted_indices = jnp.argsort(sort_key) + if reverse: + sorted_indices = sorted_indices[::-1] + + # Apply partial sort (only where mask is true) + # This is a simplified version - full pixelsort is more complex + result = result.at[y, :, :].set(row[sorted_indices]) + + return result + + if op == 'datamosh': + frame = self._eval(args[0], env) + prev = self._eval(args[1], env) + block_size = int(self._eval(args[2], env)) if len(args) > 2 else 32 + corruption = float(self._eval(args[3], env)) if len(args) > 3 else 0.3 + max_offset = int(self._eval(args[4], env)) if len(args) > 4 else 50 + color_corrupt = self._eval(args[5], env) if len(args) > 5 else True + + h, w = frame.shape[:2] + + # Use deterministic random for JIT with frame variation + seed = env.get('_seed', 42) + op_counter = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_counter + 1 + key = make_jax_key(seed, frame_num, op_counter) + + num_blocks_y = h // block_size + num_blocks_x = w // block_size + total_blocks = num_blocks_y * num_blocks_x + + # Pre-generate all random values at once (vectorized) + key, k1, k2, k3, k4, k5 = jax.random.split(key, 6) + corrupt_mask = jax.random.uniform(k1, (total_blocks,)) < corruption + offsets_y = jax.random.randint(k2, (total_blocks,), -max_offset, max_offset + 1) + offsets_x = jax.random.randint(k3, (total_blocks,), -max_offset, max_offset + 1) + channels = jax.random.randint(k4, (total_blocks,), 0, 3) + color_shifts = jax.random.randint(k5, (total_blocks,), -50, 51) + + # Create coordinate grids for blocks + by_grid = jnp.arange(num_blocks_y) + bx_grid = jnp.arange(num_blocks_x) + + # Create block coordinate arrays + by_coords = jnp.repeat(by_grid, num_blocks_x) # [0,0,0..., 1,1,1..., ...] + bx_coords = jnp.tile(bx_grid, num_blocks_y) # [0,1,2..., 0,1,2..., ...] + + # Create pixel coordinate grids + y_coords, x_coords = jnp.mgrid[:h, :w] + + # Determine which block each pixel belongs to + pixel_block_y = y_coords // block_size + pixel_block_x = x_coords // block_size + pixel_block_idx = pixel_block_y * num_blocks_x + pixel_block_x + + # Clamp to valid block indices (for pixels outside the block grid) + pixel_block_idx = jnp.clip(pixel_block_idx, 0, total_blocks - 1) + + # Get the corrupt mask, offsets for each pixel's block + pixel_corrupt = corrupt_mask[pixel_block_idx] + pixel_offset_y = offsets_y[pixel_block_idx] + pixel_offset_x = offsets_x[pixel_block_idx] + + # Calculate source coordinates with offset (clamped) + src_y = jnp.clip(y_coords + pixel_offset_y, 0, h - 1) + src_x = jnp.clip(x_coords + pixel_offset_x, 0, w - 1) + + # Sample from previous frame at offset positions + prev_sampled = prev[src_y, src_x, :] + + # Where corrupt mask is true, use prev_sampled; else use frame + result = jnp.where(pixel_corrupt[:, :, None], prev_sampled, frame) + + # Apply color corruption to corrupted blocks + if color_corrupt: + pixel_channel = channels[pixel_block_idx] + pixel_shift = color_shifts[pixel_block_idx].astype(jnp.int16) + + # Create per-channel shift arrays (only shift the selected channel) + shift_r = jnp.where((pixel_channel == 0) & pixel_corrupt, pixel_shift, 0) + shift_g = jnp.where((pixel_channel == 1) & pixel_corrupt, pixel_shift, 0) + shift_b = jnp.where((pixel_channel == 2) & pixel_corrupt, pixel_shift, 0) + + result_int = result.astype(jnp.int16) + result_int = result_int.at[:, :, 0].add(shift_r) + result_int = result_int.at[:, :, 1].add(shift_g) + result_int = result_int.at[:, :, 2].add(shift_b) + result = jnp.clip(result_int, 0, 255).astype(jnp.uint8) + + return result + + # ===================================================================== + # ASCII Art Operations (using pre-rendered font atlas) + # ===================================================================== + + if op == 'cell-sample': + # (cell-sample frame char_size) -> (colors, luminances) + # Downsample frame into cells, return average colors and luminances + frame = self._eval(args[0], env) + char_size = int(self._eval(args[1], env)) if len(args) > 1 else 8 + + h, w = frame.shape[:2] + num_rows = h // char_size + num_cols = w // char_size + + # Crop to exact multiple of char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + + # Reshape to (num_rows, char_size, num_cols, char_size, 3) + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + + # Average over char_size dimensions -> (num_rows, num_cols, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminance per cell + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + return (colors, luminances) + + if op == 'luminance-to-chars': + # (luminance-to-chars luminances alphabet contrast) -> char_indices + # Map luminance values to character indices + luminances = self._eval(args[0], env) + alphabet = self._eval(args[1], env) if len(args) > 1 else 'standard' + contrast = float(self._eval(args[2], env)) if len(args) > 2 else 1.5 + + # Get alphabet string + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Apply contrast + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + + # Map to character indices (0 = darkest, num_chars-1 = brightest) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + return char_indices + + if op == 'render-char-grid': + # (render-char-grid frame chars colors char_size color_mode background_color invert_colors) + # Render character grid using font atlas + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + char_size = int(self._eval(args[3], env)) if len(args) > 3 else 8 + color_mode = self._eval(args[4], env) if len(args) > 4 else 'color' + background_color = self._eval(args[5], env) if len(args) > 5 else 'black' + invert_colors = self._eval(args[6], env) if len(args) > 6 else 0 + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get the alphabet used (stored in env or default) + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + + # Get or create font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + # Try to parse hex color + try: + if background_color.startswith('#'): + bg_hex = background_color[1:] + bg = jnp.array([int(bg_hex[0:2], 16), + int(bg_hex[2:4], 16), + int(bg_hex[4:6], 16)], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + except: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create output image starting with background + output_h = num_rows * char_size + output_w = num_cols * char_size + result = jnp.broadcast_to(bg, (output_h, output_w, 3)).copy() + + # Gather characters from atlas based on indices + # char_indices shape: (num_rows, num_cols) + # font_atlas shape: (num_chars, char_size, char_size, 3) + # Convert numpy atlas to JAX for indexing with traced indices + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices] # (num_rows*num_cols, char_size, char_size, 3) + + # Reshape to grid + char_tiles = char_tiles.reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create coordinate grids for output pixels + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = y_out // char_size + cell_col = x_out // char_size + local_y = y_out % char_size + local_x = x_out % char_size + + # Clamp to valid ranges + cell_row = jnp.clip(cell_row, 0, num_rows - 1) + cell_col = jnp.clip(cell_col, 0, num_cols - 1) + + # Get character pixel values + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + + # Get character brightness (for masking) + char_brightness = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Handle color modes + if color_mode == 'mono': + # White characters on background + fg_color = jnp.array([255, 255, 255], dtype=jnp.uint8) + fg = jnp.broadcast_to(fg_color, char_pixels.shape) + elif color_mode == 'invert': + # Inverted cell colors + cell_colors = colors[cell_row, cell_col] + fg = 255 - cell_colors + else: + # 'color' mode - use cell colors + fg = colors[cell_row, cell_col] + + # Blend foreground onto background based on character brightness + if invert_colors: + # Swap fg and bg + fg, bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape), fg + result = (fg.astype(jnp.float32) * (1 - char_brightness) + + bg_broadcast.astype(jnp.float32) * char_brightness) + else: + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_brightness) + + fg.astype(jnp.float32) * char_brightness) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize back to original frame size if needed + if result.shape[0] != h or result.shape[1] != w: + # Simple nearest-neighbor resize + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = (jnp.arange(h) * y_scale).astype(jnp.int32) + x_src = (jnp.arange(w) * x_scale).astype(jnp.int32) + y_src = jnp.clip(y_src, 0, result.shape[0] - 1) + x_src = jnp.clip(x_src, 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'ascii-fx-zone': + # Complex ASCII effect with per-zone expressions + # (ascii-fx-zone frame :cols cols :alphabet alphabet ...) + frame = self._eval(args[0], env) + + # Parse keyword arguments + kwargs = {} + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + if i + 1 < len(args): + kwargs[key] = args[i + 1] + i += 2 + else: + i += 1 + + # Get parameters + cols = int(self._eval(kwargs.get('cols', 80), env)) + char_size_param = kwargs.get('char_size') + alphabet = self._eval(kwargs.get('alphabet', 'standard'), env) + color_mode = self._eval(kwargs.get('color_mode', 'color'), env) + background = self._eval(kwargs.get('background', 'black'), env) + contrast = float(self._eval(kwargs.get('contrast', 1.5), env)) + + h, w = frame.shape[:2] + + # Calculate char_size from cols if not specified + if char_size_param is not None: + char_size_val = self._eval(char_size_param, env) + if char_size_val is not None: + char_size = int(char_size_val) + else: + char_size = w // cols + else: + char_size = w // cols + char_size = max(4, min(char_size, 64)) + + # Store alphabet for render-char-grid to use + env['_ascii_alphabet'] = alphabet + + # Cell sampling + num_rows = h // char_size + num_cols = w // char_size + cropped = frame[:num_rows * char_size, :num_cols * char_size, :] + reshaped = cropped.reshape(num_rows, char_size, num_cols, char_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(jnp.uint8) + + # Compute luminances + colors_float = colors.astype(jnp.float32) + luminances = (0.299 * colors_float[:, :, 0] + + 0.587 * colors_float[:, :, 1] + + 0.114 * colors_float[:, :, 2]) / 255.0 + + # Get alphabet and map luminances to chars + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + lum_adjusted = jnp.clip((luminances - 0.5) * contrast + 0.5, 0, 1) + char_indices = (lum_adjusted * (num_chars - 1)).astype(jnp.int32) + char_indices = jnp.clip(char_indices, 0, num_chars - 1) + + # Handle optional per-zone effects (char_hue, char_saturation, etc.) + # These would modify colors based on zone position + char_hue = kwargs.get('char_hue') + char_saturation = kwargs.get('char_saturation') + char_brightness = kwargs.get('char_brightness') + + if char_hue is not None or char_saturation is not None or char_brightness is not None: + # Create zone coordinate arrays for expression evaluation + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + # Bind zone variables + zone_env = env.copy() + zone_env['zone-row'] = row_coords + zone_env['zone-col'] = col_coords + zone_env['zone-row-norm'] = row_norm + zone_env['zone-col-norm'] = col_norm + zone_env['zone-lum'] = luminances + + # Apply color modifications (simplified - full version would use HSV) + if char_brightness is not None: + brightness_mult = self._eval(char_brightness, zone_env) + if brightness_mult is not None: + colors = jnp.clip(colors.astype(jnp.float32) * brightness_mult[:, :, None], + 0, 255).astype(jnp.uint8) + + # Render using font atlas + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background color + if background == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original dimensions + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'render-char-grid-fx': + # Enhanced render with per-character effects + # (render-char-grid-fx frame chars colors luminances char_size + # color_mode bg_color invert_colors + # char_jitter char_scale char_rotation char_hue_shift + # jitter_source scale_source rotation_source hue_source) + frame = self._eval(args[0], env) + char_indices = self._eval(args[1], env) + colors = self._eval(args[2], env) + luminances = self._eval(args[3], env) + char_size = int(self._eval(args[4], env)) if len(args) > 4 else 8 + color_mode = self._eval(args[5], env) if len(args) > 5 else 'color' + background_color = self._eval(args[6], env) if len(args) > 6 else 'black' + invert_colors = self._eval(args[7], env) if len(args) > 7 else 0 + + # Per-char effect amounts + char_jitter = float(self._eval(args[8], env)) if len(args) > 8 else 0 + char_scale = float(self._eval(args[9], env)) if len(args) > 9 else 1.0 + char_rotation = float(self._eval(args[10], env)) if len(args) > 10 else 0 + char_hue_shift = float(self._eval(args[11], env)) if len(args) > 11 else 0 + + # Modulation sources + jitter_source = self._eval(args[12], env) if len(args) > 12 else 'none' + scale_source = self._eval(args[13], env) if len(args) > 13 else 'none' + rotation_source = self._eval(args[14], env) if len(args) > 14 else 'none' + hue_source = self._eval(args[15], env) if len(args) > 15 else 'none' + + h, w = frame.shape[:2] + num_rows, num_cols = char_indices.shape + + # Get alphabet + alphabet = env.get('_ascii_alphabet', 'standard') + alpha_str = _get_alphabet_string(alphabet) + font_atlas = _create_font_atlas(alpha_str, char_size) + + # Parse background + if background_color == 'black': + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + elif background_color == 'white': + bg = jnp.array([255, 255, 255], dtype=jnp.uint8) + else: + bg = jnp.array([0, 0, 0], dtype=jnp.uint8) + + # Create modulation values based on source + def get_modulation(source, lums, num_rows, num_cols): + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + row_norm = row_coords / max(num_rows - 1, 1) + col_norm = col_coords / max(num_cols - 1, 1) + + if source == 'luminance': + return lums + elif source == 'inv_luminance': + return 1.0 - lums + elif source == 'position_x': + return col_norm + elif source == 'position_y': + return row_norm + elif source == 'position_diag': + return (row_norm + col_norm) / 2 + elif source == 'center_dist': + cy, cx = 0.5, 0.5 + dist = jnp.sqrt((row_norm - cy)**2 + (col_norm - cx)**2) + return jnp.clip(dist / 0.707, 0, 1) # Normalize by max diagonal + elif source == 'random': + # Use frame-varying key for random source + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key = make_jax_key(seed, frame_num, op_ctr) + return jax.random.uniform(key, (num_rows, num_cols)) + else: + return jnp.zeros((num_rows, num_cols)) + + # Get modulation values + jitter_mod = get_modulation(jitter_source, luminances, num_rows, num_cols) + scale_mod = get_modulation(scale_source, luminances, num_rows, num_cols) + rotation_mod = get_modulation(rotation_source, luminances, num_rows, num_cols) + hue_mod = get_modulation(hue_source, luminances, num_rows, num_cols) + + # Gather characters - convert numpy atlas to JAX for traced indexing + font_atlas_jax = jnp.asarray(font_atlas) + flat_indices = char_indices.flatten() + char_tiles = font_atlas_jax[flat_indices].reshape(num_rows, num_cols, char_size, char_size, 3) + + # Create output + output_h = num_rows * char_size + output_w = num_cols * char_size + + y_out, x_out = jnp.mgrid[:output_h, :output_w] + cell_row = jnp.clip(y_out // char_size, 0, num_rows - 1) + cell_col = jnp.clip(x_out // char_size, 0, num_cols - 1) + local_y = y_out % char_size + local_x = x_out % char_size + + # Apply jitter if enabled + if char_jitter > 0: + jitter_amount = jitter_mod[cell_row, cell_col] * char_jitter + # Use frame-varying key for jitter + seed = env.get('_seed', 42) + op_ctr = env.get('_rand_op_counter', 0) + env['_rand_op_counter'] = op_ctr + 1 + key1, key2 = jax.random.split(make_jax_key(seed, frame_num, op_ctr), 2) + # Generate deterministic jitter per cell + jitter_y = jax.random.uniform(key1, (num_rows, num_cols), minval=-1, maxval=1) + jitter_x = jax.random.uniform(key2, (num_rows, num_cols), minval=-1, maxval=1) + offset_y = (jitter_y[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + offset_x = (jitter_x[cell_row, cell_col] * jitter_amount).astype(jnp.int32) + local_y = jnp.clip(local_y + offset_y, 0, char_size - 1) + local_x = jnp.clip(local_x + offset_x, 0, char_size - 1) + + char_pixels = char_tiles[cell_row, cell_col, local_y, local_x] + char_bright = char_pixels.mean(axis=-1, keepdims=True) / 255.0 + + # Get foreground colors + if color_mode == 'mono': + fg = jnp.full_like(char_pixels, 255) + else: + fg = colors[cell_row, cell_col] + + # Apply hue shift if enabled + if char_hue_shift > 0 and color_mode == 'color': + hue_shift_amount = hue_mod[cell_row, cell_col] * char_hue_shift + # Simple hue rotation via channel cycling + fg_float = fg.astype(jnp.float32) + shift_frac = (hue_shift_amount / 120.0) % 3 # Cycle through RGB + # Simplified: blend channels based on shift + r, g, b = fg_float[:,:,0], fg_float[:,:,1], fg_float[:,:,2] + shift_frac_2d = shift_frac[:, :, None] if shift_frac.ndim == 2 else shift_frac + # Just do a simple tint for now + fg = jnp.clip(fg_float + hue_shift_amount[:, :, None] * 0.5, 0, 255).astype(jnp.uint8) + + # Blend + bg_broadcast = jnp.broadcast_to(bg, char_pixels.shape) + if invert_colors: + result = (fg.astype(jnp.float32) * (1 - char_bright) + + bg_broadcast.astype(jnp.float32) * char_bright) + else: + result = (bg_broadcast.astype(jnp.float32) * (1 - char_bright) + + fg.astype(jnp.float32) * char_bright) + + result = jnp.clip(result, 0, 255).astype(jnp.uint8) + + # Resize to original + if result.shape[0] != h or result.shape[1] != w: + y_scale = result.shape[0] / h + x_scale = result.shape[1] / w + y_src = jnp.clip((jnp.arange(h) * y_scale).astype(jnp.int32), 0, result.shape[0] - 1) + x_src = jnp.clip((jnp.arange(w) * x_scale).astype(jnp.int32), 0, result.shape[1] - 1) + result = result[y_src[:, None], x_src[None, :], :] + + return result + + if op == 'alphabet-char': + # (alphabet-char alphabet-name index) -> char_index in that alphabet + alphabet = self._eval(args[0], env) + index = self._eval(args[1], env) + + alpha_str = _get_alphabet_string(alphabet) + num_chars = len(alpha_str) + + # Handle both scalar and array indices + if hasattr(index, 'shape'): + index = jnp.clip(index.astype(jnp.int32), 0, num_chars - 1) + else: + index = max(0, min(int(index), num_chars - 1)) + + return index + + if op == 'map-char-grid': + # (map-char-grid base-chars luminances (lambda (r c ch lum) ...)) + # Map over character grid, allowing per-cell character selection + base_chars = self._eval(args[0], env) + luminances = self._eval(args[1], env) + fn = args[2] # Lambda expression + + num_rows, num_cols = base_chars.shape + + # For JAX compatibility, we can't use Python loops with traced values + # Instead, we'll evaluate the lambda for the whole grid at once + if isinstance(fn, list) and len(fn) >= 3: + head = fn[0] + if isinstance(head, Symbol) and head.name in ('lambda', 'λ'): + params = fn[1] + body = fn[2] + + # Create grid coordinates + row_coords, col_coords = jnp.mgrid[:num_rows, :num_cols] + + # Bind parameters for whole-grid evaluation + fn_env = env.copy() + + # Params: (r c ch lum) + if len(params) >= 1: + fn_env[params[0].name if isinstance(params[0], Symbol) else params[0]] = row_coords + if len(params) >= 2: + fn_env[params[1].name if isinstance(params[1], Symbol) else params[1]] = col_coords + if len(params) >= 3: + fn_env[params[2].name if isinstance(params[2], Symbol) else params[2]] = base_chars + if len(params) >= 4: + # Luminances scaled to 0-255 range + fn_env[params[3].name if isinstance(params[3], Symbol) else params[3]] = (luminances * 255).astype(jnp.int32) + + # Evaluate body - should return new character indices + result = self._eval(body, fn_env) + if hasattr(result, 'shape'): + return result.astype(jnp.int32) + return base_chars + + return base_chars + + # ===================================================================== + # List operations + # ===================================================================== + if op == 'take': + seq = self._eval(args[0], env) + n = int(self._eval(args[1], env)) + if isinstance(seq, (list, tuple)): + return seq[:n] + return seq[:n] # Works for arrays too + + if op == 'cons': + item = self._eval(args[0], env) + seq = self._eval(args[1], env) + if isinstance(seq, list): + return [item] + seq + elif isinstance(seq, tuple): + return (item,) + seq + return jnp.concatenate([jnp.array([item]), seq]) + + if op == 'roll': + arr = self._eval(args[0], env) + shift = self._eval(args[1], env) + axis = self._eval(args[2], env) if len(args) > 2 else 0 + # Convert to int for concrete values, keep as-is for JAX traced values + if isinstance(shift, (int, float)): + shift = int(shift) + elif hasattr(shift, 'astype'): + shift = shift.astype(jnp.int32) + if isinstance(axis, (int, float)): + axis = int(axis) + return jnp.roll(arr, shift, axis=axis) + + # ===================================================================== + # Pi constant + # ===================================================================== + if op == 'pi': + return jnp.pi + + raise ValueError(f"Unknown operation: {op}") + + +# ============================================================================= +# Public API +# ============================================================================= + +def compile_effect(code: str) -> Callable: + """ + Compile an S-expression effect to a JAX function. + + Args: + code: S-expression effect code + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + # Check cache + cache_key = hashlib.md5(code.encode()).hexdigest() + if cache_key in _COMPILED_EFFECTS: + return _COMPILED_EFFECTS[cache_key] + + # Parse and compile + sexp = parse(code) + compiler = JaxCompiler() + fn = compiler.compile_effect(sexp) + + _COMPILED_EFFECTS[cache_key] = fn + return fn + + +def compile_effect_file(path: str, derived_paths: List[str] = None) -> Callable: + """ + Compile an effect from a .sexp file. + + Args: + path: Path to the .sexp effect file + derived_paths: Optional list of paths to derived.sexp files to load + + Returns: + JIT-compiled function: (frame, **params) -> frame + """ + with open(path, 'r') as f: + code = f.read() + + # Parse all expressions in file + exprs = parse_all(code) + + # Create compiler + compiler = JaxCompiler() + + # Load derived files if specified + if derived_paths: + for dp in derived_paths: + compiler.load_derived(dp) + + # Process expressions - find require statements and the effect + effect_sexp = None + effect_dir = Path(path).parent + + for expr in exprs: + if not isinstance(expr, list) or len(expr) < 2: + continue + + head = expr[0] + if not isinstance(head, Symbol): + continue + + if head.name == 'require': + # (require "derived") or (require "path/to/file") + req_path = expr[1] + if isinstance(req_path, str): + # Resolve relative to effect file + if not req_path.endswith('.sexp'): + req_path = req_path + '.sexp' + full_path = effect_dir / req_path + if not full_path.exists(): + # Try sexp_effects directory + full_path = Path(__file__).parent.parent / 'sexp_effects' / req_path + if full_path.exists(): + compiler.load_derived(str(full_path)) + + elif head.name == 'require-primitives': + # (require-primitives "lib") - currently ignored for JAX + # JAX has all primitives built-in + pass + + elif head.name in ('effect', 'define-effect'): + effect_sexp = expr + + if effect_sexp is None: + raise ValueError(f"No effect definition found in {path}") + + return compiler.compile_effect(effect_sexp) + + +def load_derived(derived_path: str = None) -> Dict[str, Callable]: + """ + Load derived operations from derived.sexp. + + Returns dict of compiled functions that can be used in effects. + """ + if derived_path is None: + derived_path = Path(__file__).parent.parent / 'sexp_effects' / 'derived.sexp' + + with open(derived_path, 'r') as f: + code = f.read() + + exprs = parse_all(code) + compiler = JaxCompiler() + env = {} + + for expr in exprs: + if isinstance(expr, list) and len(expr) >= 3: + head = expr[0] + if isinstance(head, Symbol) and head.name == 'define': + compiler._eval_define(expr[1:], env) + + return env + + +# ============================================================================= +# Test / Demo +# ============================================================================= + +if __name__ == '__main__': + import numpy as np + + # Test effect + test_effect = ''' + (effect "threshold-test" + :params ((threshold :default 128)) + :body (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (gray (+ (* r 0.299) (* g 0.587) (* b 0.114))) + (mask (where (> gray threshold) 255 0))) + (merge-channels mask mask mask))) + ''' + + print("Compiling effect...") + run_effect = compile_effect(test_effect) + + # Create test frame + print("Creating test frame...") + frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) + + # Run effect + print("Running effect (first run includes JIT compilation)...") + import time + + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"First run (with JIT): {(t1-t0)*1000:.2f}ms") + + # Second run should be faster + t0 = time.time() + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"Second run (cached): {(t1-t0)*1000:.2f}ms") + + # Multiple runs + t0 = time.time() + for _ in range(100): + result = run_effect(frame, threshold=128) + t1 = time.time() + print(f"100 runs: {(t1-t0)*1000:.2f}ms total, {(t1-t0)*10:.2f}ms avg") + + print(f"\nResult shape: {result.shape}, dtype: {result.dtype}") + print("Done!") diff --git a/streaming/sources.py b/streaming/sources.py new file mode 100644 index 0000000..71e7e53 --- /dev/null +++ b/streaming/sources.py @@ -0,0 +1,281 @@ +""" +Video and image sources with looping support. +""" + +import numpy as np +import subprocess +import json +from pathlib import Path +from typing import Optional, Tuple +from abc import ABC, abstractmethod + + +class Source(ABC): + """Abstract base class for frame sources.""" + + @abstractmethod + def read_frame(self, t: float) -> np.ndarray: + """Read frame at time t (with looping if needed).""" + pass + + @property + @abstractmethod + def duration(self) -> float: + """Source duration in seconds.""" + pass + + @property + @abstractmethod + def size(self) -> Tuple[int, int]: + """Frame size as (width, height).""" + pass + + @property + @abstractmethod + def fps(self) -> float: + """Frames per second.""" + pass + + +class VideoSource(Source): + """ + Video file source with automatic looping. + + Reads frames on-demand, seeking as needed. When time exceeds + duration, wraps around (loops). + """ + + def __init__(self, path: str, target_fps: float = 30): + self.path = Path(path) + self.target_fps = target_fps + + # Initialize decode state first (before _probe which could fail) + self._process: Optional[subprocess.Popen] = None + self._current_start: Optional[float] = None + self._frame_buffer: Optional[np.ndarray] = None + self._buffer_time: Optional[float] = None + + self._duration = None + self._size = None + self._fps = None + + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + self._probe() + + def _probe(self): + """Get video metadata.""" + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_format", "-show_streams", + str(self.path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + data = json.loads(result.stdout) + + # Get duration + self._duration = float(data["format"]["duration"]) + + # Get video stream info + for stream in data["streams"]: + if stream["codec_type"] == "video": + self._size = (int(stream["width"]), int(stream["height"])) + # Parse fps from r_frame_rate (e.g., "30/1" or "30000/1001") + fps_parts = stream.get("r_frame_rate", "30/1").split("/") + self._fps = float(fps_parts[0]) / float(fps_parts[1]) + break + + @property + def duration(self) -> float: + return self._duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def _start_decode(self, start_time: float): + """Start ffmpeg decode process from given time.""" + if self._process: + try: + self._process.stdout.close() + except: + pass + self._process.terminate() + try: + self._process.wait(timeout=1) + except: + self._process.kill() + self._process.wait() + + w, h = self._size + cmd = [ + "ffmpeg", "-v", "quiet", + "-ss", str(start_time), + "-i", str(self.path), + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-r", str(self.target_fps), + "-" + ] + self._process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + bufsize=w * h * 3 * 4, # Buffer a few frames + ) + self._current_start = start_time + self._buffer_time = start_time + + def read_frame(self, t: float) -> np.ndarray: + """ + Read frame at time t. + + If t exceeds duration, wraps around (loops). + Seeks if needed, otherwise reads sequentially. + """ + # Wrap time for looping + t_wrapped = t % self._duration + + # Check if we need to seek (loop point or large time jump) + need_seek = ( + self._process is None or + self._buffer_time is None or + abs(t_wrapped - self._buffer_time) > 1.0 / self.target_fps * 2 + ) + + if need_seek: + self._start_decode(t_wrapped) + + # Read frame + w, h = self._size + frame_size = w * h * 3 + + # Try to read with retries for seek settling + for attempt in range(3): + raw = self._process.stdout.read(frame_size) + if len(raw) == frame_size: + break + # End of stream or seek not ready - restart from beginning + self._start_decode(0) + + if len(raw) < frame_size: + # Still no data - return last frame or black + if self._frame_buffer is not None: + return self._frame_buffer.copy() + return np.zeros((h, w, 3), dtype=np.uint8) + + frame = np.frombuffer(raw, dtype=np.uint8).reshape((h, w, 3)) + self._frame_buffer = frame # Cache for fallback + self._buffer_time = t_wrapped + 1.0 / self.target_fps + + return frame + + def close(self): + """Clean up resources.""" + if self._process: + self._process.terminate() + self._process.wait() + self._process = None + + def __del__(self): + self.close() + + def __repr__(self): + return f"VideoSource({self.path.name}, {self._size[0]}x{self._size[1]}, {self._duration:.1f}s)" + + +class ImageSource(Source): + """ + Static image source (returns same frame for any time). + + Useful for backgrounds, overlays, etc. + """ + + def __init__(self, path: str): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Image not found: {path}") + + # Load image + import cv2 + self._frame = cv2.imread(str(self.path)) + self._frame = cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB) + self._size = (self._frame.shape[1], self._frame.shape[0]) + + @property + def duration(self) -> float: + return float('inf') # Images last forever + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return 30.0 # Arbitrary + + def read_frame(self, t: float) -> np.ndarray: + return self._frame.copy() + + def __repr__(self): + return f"ImageSource({self.path.name}, {self._size[0]}x{self._size[1]})" + + +class LiveSource(Source): + """ + Live video capture source (webcam, capture card, etc.). + + Time parameter is ignored - always returns latest frame. + """ + + def __init__(self, device: int = 0, size: Tuple[int, int] = (1280, 720), fps: float = 30): + import cv2 + self._cap = cv2.VideoCapture(device) + self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, size[0]) + self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, size[1]) + self._cap.set(cv2.CAP_PROP_FPS, fps) + + # Get actual settings + self._size = ( + int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + ) + self._fps = self._cap.get(cv2.CAP_PROP_FPS) + + if not self._cap.isOpened(): + raise RuntimeError(f"Could not open video device {device}") + + @property + def duration(self) -> float: + return float('inf') # Live - no duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def read_frame(self, t: float) -> np.ndarray: + """Read latest frame (t is ignored for live sources).""" + import cv2 + ret, frame = self._cap.read() + if not ret: + return np.zeros((self._size[1], self._size[0], 3), dtype=np.uint8) + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + def close(self): + self._cap.release() + + def __del__(self): + self.close() + + def __repr__(self): + return f"LiveSource({self._size[0]}x{self._size[1]}, {self._fps}fps)" diff --git a/streaming/stream_sexp.py b/streaming/stream_sexp.py new file mode 100644 index 0000000..07acb2a --- /dev/null +++ b/streaming/stream_sexp.py @@ -0,0 +1,1098 @@ +""" +Generic Streaming S-expression Interpreter. + +Executes streaming sexp recipes frame-by-frame. +The sexp defines the pipeline logic - interpreter just provides primitives. + +Primitives: + (read source-name) - read frame from source + (rotate frame :angle N) - rotate frame + (zoom frame :amount N) - zoom frame + (invert frame :amount N) - invert colors + (hue-shift frame :degrees N) - shift hue + (blend frame1 frame2 :opacity N) - blend two frames + (blend-weighted [frames...] [weights...]) - weighted blend + (ripple frame :amplitude N :cx N :cy N ...) - ripple effect + + (bind scan-name :field) - get scan state field + (map value [lo hi]) - map 0-1 value to range + energy - current energy (0-1) + beat - 1 if beat, 0 otherwise + t - current time + beat-count - total beats so far + +Example sexp: + (stream "test" + :fps 30 + (source vid "video.mp4") + (audio aud "music.mp3") + + (scan spin beat + :init {:angle 0 :dir 1} + :step (dict :angle (+ angle (* dir 10)) :dir dir)) + + (frame + (-> (read vid) + (rotate :angle (bind spin :angle)) + (zoom :amount (map energy [1 1.5]))))) +""" + +import sys +import time +import json +import hashlib +import numpy as np +import subprocess +from pathlib import Path +from dataclasses import dataclass, field +from typing import Dict, List, Any, Optional, Tuple, Union + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) +from artdag.sexp.parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class StreamContext: + """Runtime context for streaming.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + output_size: Tuple[int, int] = (720, 720) + + +class StreamCache: + """Cache for streaming data.""" + + def __init__(self, cache_dir: Path, recipe_hash: str): + self.cache_dir = cache_dir / recipe_hash + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.analysis_buffer: Dict[str, List] = {} + self.scan_states: Dict[str, List] = {} + self.keyframe_interval = 5.0 + + def record_analysis(self, name: str, t: float, value: float): + if name not in self.analysis_buffer: + self.analysis_buffer[name] = [] + t = float(t) if hasattr(t, 'item') else t + value = float(value) if hasattr(value, 'item') else value + self.analysis_buffer[name].append((t, value)) + + def record_scan_state(self, name: str, t: float, state: dict): + if name not in self.scan_states: + self.scan_states[name] = [] + states = self.scan_states[name] + if not states or t - states[-1][0] >= self.keyframe_interval: + t = float(t) if hasattr(t, 'item') else t + clean = {k: (float(v) if hasattr(v, 'item') else v) for k, v in state.items()} + self.scan_states[name].append((t, clean)) + + def flush(self): + for name, data in self.analysis_buffer.items(): + path = self.cache_dir / f"analysis_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(data) + path.write_text(json.dumps(existing)) + self.analysis_buffer.clear() + + for name, states in self.scan_states.items(): + path = self.cache_dir / f"scan_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(states) + path.write_text(json.dumps(existing)) + self.scan_states.clear() + + +class VideoSource: + """Video source - reads frames sequentially.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + # Get info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + + for s in info.get("streams", []): + if s.get("codec_type") == "video": + self.width = s.get("width", 720) + self.height = s.get("height", 720) + break + else: + self.width, self.height = 720, 720 + + self.duration = float(info.get("format", {}).get("duration", 60)) + self.size = (self.width, self.height) + + # Start decoder + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", "-r", str(fps), "-"] + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) + self._frame_size = self.width * self.height * 3 + self._current_frame = None + + def read(self) -> Optional[np.ndarray]: + """Read next frame.""" + data = self._proc.stdout.read(self._frame_size) + if len(data) < self._frame_size: + return self._current_frame # Return last frame if stream ends + self._current_frame = np.frombuffer(data, dtype=np.uint8).reshape( + self.height, self.width, 3).copy() + return self._current_frame + + def skip(self): + """Read and discard frame (keep pipe in sync).""" + self._proc.stdout.read(self._frame_size) + + def close(self): + if self._proc: + self._proc.terminate() + self._proc.wait() + + +class AudioAnalyzer: + """Real-time audio analysis.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + + # Load audio + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + self._audio = np.frombuffer( + subprocess.run(cmd, capture_output=True).stdout, dtype=np.float32) + self.sample_rate = sample_rate + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + self._flux_history = [] + self._last_beat_time = -1 + + def get_energy(self, t: float) -> float: + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size//2: + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + while self._flux_history and self._flux_history[0][0] < t - 1.5: + self._flux_history.pop(0) + + if len(self._flux_history) < 3: + return False + + vals = [f for _, f in self._flux_history] + threshold = np.mean(vals) + np.std(vals) * 0.3 + 0.001 + + is_beat = flux > threshold and t - self._last_beat_time > 0.1 + if is_beat: + self._last_beat_time = t + return is_beat + + +class StreamInterpreter: + """ + Generic streaming sexp interpreter. + + Evaluates the frame pipeline expression each frame. + """ + + def __init__(self, sexp_path: str, cache_dir: str = None): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + recipe_hash = hashlib.sha256(text.encode()).hexdigest()[:16] + cache_path = Path(cache_dir) if cache_dir else self.sexp_dir / ".stream_cache" + self.cache = StreamCache(cache_path, recipe_hash) + + self.ctx = StreamContext(fps=self.config.get('fps', 30)) + self.sources: Dict[str, VideoSource] = {} + self.frames: Dict[str, np.ndarray] = {} # Current frame per source + self._sources_read: set = set() # Track which sources read this frame + self.audios: Dict[str, AudioAnalyzer] = {} # Multiple named audio sources + self.audio_paths: Dict[str, str] = {} + self.audio_state: Dict[str, dict] = {} # Per-audio: {energy, is_beat, beat_count, last_beat} + self.scans: Dict[str, dict] = {} + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} # name -> Python function + self.effects: Dict[str, dict] = {} # name -> {params, body} + self.macros: Dict[str, dict] = {} # name -> {params, body} + self.primitive_lib_dir = self.sexp_dir.parent / "sexp_effects" / "primitive_libs" + + self.frame_pipeline = None # The (frame ...) expression + + import random + self.rng = random.Random(self.config.get('seed', 42)) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file.""" + import importlib.util + + # Try multiple paths + lib_paths = [ + self.primitive_lib_dir / f"{lib_name}.py", + self.sexp_dir / "primitive_libs" / f"{lib_name}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{lib_name}.py", + ] + + lib_path = None + for p in lib_paths: + if p.exists(): + lib_path = p + break + + if not lib_path: + print(f"Warning: primitive library '{lib_name}' not found", file=sys.stderr) + return + + spec = importlib.util.spec_from_file_location(lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Extract all prim_* functions + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] # Remove 'prim_' prefix + self.primitives[prim_name] = func + # Also register with dashes instead of underscores + dash_name = prim_name.replace('_', '-') + self.primitives[dash_name] = func + # Also register with -img suffix (sexp convention) + self.primitives[dash_name + '-img'] = func + count += 1 + + # Also check for PRIMITIVES dict (some modules use this for additional exports) + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + self.primitives[name] = func + # Also register underscore version + underscore_name = name.replace('-', '_') + self.primitives[underscore_name] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions)", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + print(f"Warning: effect file not found: {effect_path}", file=sys.stderr) + return + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + # Parse params list + params_list = form[i + 1] + for p in params_list: + if isinstance(p, list) and p: + pname = p[0].name if isinstance(p[0], Symbol) else str(p[0]) + pdef = {'default': 0} + j = 1 + while j < len(p): + if isinstance(p[j], Keyword): + pdef[p[j].name] = p[j + 1] if j + 1 < len(p) else None + j += 2 + else: + j += 1 + params[pname] = pdef + i += 2 + else: + i += 2 + else: + # Body expression + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body, 'path': str(effect_path)} + print(f"Effect: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [] + body = None + + if len(form) > 2 and isinstance(form[2], list): + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + if len(form) > 3: + body = form[3] + + self.macros[name] = {'params': params, 'body': body} + print(f"Macro: {name}", file=sys.stderr) + + def _init(self): + """Initialize sources, scans, and pipeline from sexp.""" + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + # === External loading === + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + # (effect name :path "...") + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'include': + # (include :path "...") + i = 1 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) # Reuse effect loader for includes + i += 2 + else: + i += 1 + + # === Sources === + + elif cmd == 'source': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.sources[name] = VideoSource(str(full), self.ctx.fps) + print(f"Source: {name} -> {full}", file=sys.stderr) + else: + print(f"Warning: {full} not found", file=sys.stderr) + + elif cmd == 'audio': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.audios[name] = AudioAnalyzer(str(full)) + self.audio_paths[name] = str(full) + self.audio_state[name] = {'energy': 0.0, 'is_beat': False, 'beat_count': 0, 'last_beat': False} + print(f"Audio: {name} -> {full}", file=sys.stderr) + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + # Trigger can be: + # (beat audio-name) - trigger on beat from specific audio + # beat - legacy: trigger on beat from first audio + trigger_expr = form[2] + if isinstance(trigger_expr, list) and len(trigger_expr) >= 2: + # (beat audio-name) + trigger_type = trigger_expr[0].name if isinstance(trigger_expr[0], Symbol) else str(trigger_expr[0]) + trigger_audio = trigger_expr[1].name if isinstance(trigger_expr[1], Symbol) else str(trigger_expr[1]) + trigger = (trigger_type, trigger_audio) + else: + # Legacy bare symbol + trigger = trigger_expr.name if isinstance(trigger_expr, Symbol) else str(trigger_expr) + + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], {}) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger, + } + trigger_str = f"{trigger[0]} {trigger[1]}" if isinstance(trigger, tuple) else trigger + print(f"Scan: {name} (on {trigger_str})", file=sys.stderr) + + elif cmd == 'frame': + # (frame expr) - the pipeline to evaluate each frame + self.frame_pipeline = form[1] if len(form) > 1 else None + + # Set output size from first source + if self.sources: + first = next(iter(self.sources.values())) + self.ctx.output_size = first.size + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + import cv2 + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, Symbol): + name = expr.name + # Built-in values + if name == 't' or name == '_time': + return self.ctx.t + if name == 'pi': + import math + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + return 0 + + if isinstance(expr, Keyword): + return expr.name + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check if op is a closure in environment + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + # Invoke closure + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + # Insert result as first arg + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Audio analysis (explicit) === + + if op == 'energy': + # (energy audio-name) - get current energy from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['energy'] + return 0.0 + + if op == 'beat': + # (beat audio-name) - 1 if beat this frame, 0 otherwise + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return 1.0 if self.audio_state[audio_name]['is_beat'] else 0.0 + return 0.0 + + if op == 'beat-count': + # (beat-count audio-name) - total beats from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['beat_count'] + return 0 + + # === Frame operations === + + if op == 'read': + # (read source-name) - get current frame from source (lazy read) + name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if name not in self.frames: + if name in self.sources: + self.frames[name] = self.sources[name].read() + self._sources_read.add(name) + return self.frames.get(name) + + # === Binding and mapping === + + if op == 'bind': + # (bind scan-name :field) or (bind scan-name) + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + field = None + if len(args) > 1 and isinstance(args[1], Keyword): + field = args[1].name + + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if field: + return state.get(field, 0) + return state + return 0 + + if op == 'map': + # (map value [lo hi]) + val = self._eval(args[0], env) + range_list = self._eval(args[1], env) if len(args) > 1 else [0, 1] + if isinstance(range_list, list) and len(range_list) >= 2: + lo, hi = range_list[0], range_list[1] + return lo + val * (hi - lo) + return val + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + if op == 'map-range': + # (map-range val from-lo from-hi to-lo to-hi) + val = self._eval(args[0], env) + from_lo = self._eval(args[1], env) + from_hi = self._eval(args[2], env) + to_lo = self._eval(args[3], env) + to_hi = self._eval(args[4], env) + # Normalize val to 0-1 in source range, then scale to target range + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + # Lisp-style or: returns first truthy value, or last value if none truthy + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + # (cond pred1 expr1 pred2 expr2 ... true else-expr) + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + # (lambda (params...) body) - create a closure + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + # Return a closure dict that captures the current env + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + # Support both formats: + # (let [name val name val ...] body) - flat vector + # (let ((name val) (name val) ...) body) - nested list + # Note: our let already evaluates sequentially like let* + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + # Nested format: ((name val) (name val) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + # Flat format: [name val name val ...] + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Random === + + if op == 'rand': + return self.rng.random() + if op == 'rand-int': + lo = int(self._eval(args[0], env)) + hi = int(self._eval(args[1], env)) + return self.rng.randint(lo, hi) + if op == 'rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + return lo + self.rng.random() * (hi - lo) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + result[args[i].name] = self._eval(args[i + 1], env) if i + 1 < len(args) else None + i += 2 + else: + i += 1 + return result + + if op == 'get': + d = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(d, dict): + return d.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, list) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + lst = self._eval(args[0], env) + return len(lst) if isinstance(lst, (list, dict, str)) else 0 + + # === External effects === + if op in self.effects: + effect = self.effects[op] + effect_env = dict(env) + effect_env['t'] = self.ctx.t + + # Set defaults for all params + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + # Parse args: first is frame, then positional params, then kwargs + positional_idx = 0 + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + # Keyword arg + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + # Positional arg + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + return self._eval(effect['body'], effect_env) + + # === External primitives === + if op in self.primitives: + prim_func = self.primitives[op] + # Evaluate all args + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + # Call primitive + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # === Macros === + if op in self.macros: + macro = self.macros[op] + # Bind macro params to args (unevaluated) + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + macro_env[pname] = args[i] if i < len(args) else None + # Expand and evaluate + return self._eval(macro['body'], macro_env) + + # === Primitive-style call (name-with-dashes -> prim_name_with_underscores) === + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # Unknown - return as-is + return expr + + def _step_scans(self): + """Step scans on beat from specific audio.""" + for name, scan in self.scans.items(): + trigger = scan['trigger'] + + # Check if this scan should step + should_step = False + audio_name = None + + if isinstance(trigger, tuple) and trigger[0] == 'beat': + # Explicit: (beat audio-name) + audio_name = trigger[1] + if audio_name in self.audio_state: + should_step = self.audio_state[audio_name]['is_beat'] + elif trigger == 'beat': + # Legacy: use first audio + if self.audio_state: + audio_name = next(iter(self.audio_state)) + should_step = self.audio_state[audio_name]['is_beat'] + + if should_step and audio_name: + state = self.audio_state[audio_name] + env = dict(scan['state']) + env['beat_count'] = state['beat_count'] + env['t'] = self.ctx.t + env['energy'] = state['energy'] + + if scan['step']: + new_state = self._eval(scan['step'], env) + if isinstance(new_state, dict): + scan['state'] = new_state + elif new_state is not None: + scan['state'] = {'acc': new_state} + + self.cache.record_scan_state(name, self.ctx.t, scan['state']) + + def run(self, duration: float = None, output: str = "pipe"): + """Run the streaming pipeline.""" + from .output import PipeOutput, DisplayOutput, FileOutput + + self._init() + + if not self.sources: + print("Error: no sources", file=sys.stderr) + return + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w, h = self.ctx.output_size + + # Duration from first audio or default + if duration is None: + if self.audios: + first_audio = next(iter(self.audios.values())) + duration = first_audio.duration + else: + duration = 60.0 + + n_frames = int(duration * self.ctx.fps) + frame_time = 1.0 / self.ctx.fps + + print(f"Streaming {n_frames} frames @ {self.ctx.fps}fps", file=sys.stderr) + + # Use first audio for playback sync + first_audio_path = next(iter(self.audio_paths.values())) if self.audio_paths else None + + # Output + if output == "pipe": + out = PipeOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + else: + out = FileOutput(output, size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + + try: + for frame_num in range(n_frames): + if not out.is_open: + print(f"\nOutput closed at {frame_num}", file=sys.stderr) + break + + self.ctx.t = frame_num * frame_time + self.ctx.frame_num = frame_num + + # Update all audio states + for audio_name, analyzer in self.audios.items(): + state = self.audio_state[audio_name] + energy = analyzer.get_energy(self.ctx.t) + is_beat_raw = analyzer.get_beat(self.ctx.t) + is_beat = is_beat_raw and not state['last_beat'] + state['last_beat'] = is_beat_raw + + state['energy'] = energy + state['is_beat'] = is_beat + if is_beat: + state['beat_count'] += 1 + + self.cache.record_analysis(f'{audio_name}_energy', self.ctx.t, energy) + self.cache.record_analysis(f'{audio_name}_beat', self.ctx.t, 1.0 if is_beat else 0.0) + + # Step scans + self._step_scans() + + # Clear frames - will be read lazily + self.frames.clear() + self._sources_read = set() + + # Evaluate pipeline (reads happen on-demand) + result = self._eval(self.frame_pipeline, {}) + + # Skip unread sources to keep pipes in sync + for name, src in self.sources.items(): + if name not in self._sources_read: + src.skip() + + # Ensure output size + if result is not None: + import cv2 + if result.shape[:2] != (h, w): + # Handle CuPy arrays - cv2 can't resize them directly + if hasattr(result, '__cuda_array_interface__'): + # Use GPU resize via cupyx.scipy + try: + import cupy as cp + from cupyx.scipy import ndimage as cpndimage + curr_h, curr_w = result.shape[:2] + zoom_y = h / curr_h + zoom_x = w / curr_w + if result.ndim == 3: + result = cpndimage.zoom(result, (zoom_y, zoom_x, 1), order=1) + else: + result = cpndimage.zoom(result, (zoom_y, zoom_x), order=1) + except ImportError: + # Fallback to CPU resize + result = cv2.resize(cp.asnumpy(result), (w, h)) + else: + result = cv2.resize(result, (w, h)) + out.write(result, self.ctx.t) + + # Progress + if frame_num % 30 == 0: + pct = 100 * frame_num / n_frames + # Show beats from first audio + total_beats = 0 + if self.audio_state: + first_state = next(iter(self.audio_state.values())) + total_beats = first_state['beat_count'] + print(f"\r{pct:5.1f}% | beats:{total_beats}", + end="", file=sys.stderr) + sys.stderr.flush() + + if frame_num % 300 == 0: + self.cache.flush() + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + finally: + out.close() + for src in self.sources.values(): + src.close() + self.cache.flush() + + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path) + if fps: + interp.ctx.fps = fps + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None, help="Override fps (default: from sexp)") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps) diff --git a/streaming/stream_sexp_generic.py b/streaming/stream_sexp_generic.py new file mode 100644 index 0000000..0619589 --- /dev/null +++ b/streaming/stream_sexp_generic.py @@ -0,0 +1,1739 @@ +""" +Fully Generic Streaming S-expression Interpreter. + +The interpreter knows NOTHING about video, audio, or any domain. +All domain logic comes from primitives loaded via (require-primitives ...). + +Built-in forms: + - Control: if, cond, let, let*, lambda, -> + - Arithmetic: +, -, *, /, mod, map-range + - Comparison: <, >, =, <=, >=, and, or, not + - Data: dict, get, list, nth, len, quote + - Random: rand, rand-int, rand-range + - Scan: bind (access scan state) + +Everything else comes from primitives or effects. + +Context (ctx) is passed explicitly to frame evaluation: + - ctx.t: current time + - ctx.frame-num: current frame number + - ctx.fps: frames per second +""" + +import sys +import os +import time +import json +import hashlib +import math +import numpy as np +from pathlib import Path +from dataclasses import dataclass +from typing import Dict, List, Any, Optional, Tuple, Callable + +# Use local sexp_effects parser (supports namespaced symbols like math:sin) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from sexp_effects.parser import parse, parse_all, Symbol, Keyword + +# JAX backend (optional - loaded on demand) +_JAX_AVAILABLE = False +_jax_compiler = None + +def _init_jax(): + """Lazily initialize JAX compiler.""" + global _JAX_AVAILABLE, _jax_compiler + if _jax_compiler is not None: + return _JAX_AVAILABLE + try: + from streaming.sexp_to_jax import JaxCompiler, compile_effect_file + _jax_compiler = {'JaxCompiler': JaxCompiler, 'compile_effect_file': compile_effect_file} + _JAX_AVAILABLE = True + print("JAX backend initialized", file=sys.stderr) + except ImportError as e: + print(f"JAX backend not available: {e}", file=sys.stderr) + _JAX_AVAILABLE = False + return _JAX_AVAILABLE + + +@dataclass +class Context: + """Runtime context passed to frame evaluation.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + + +class DeferredEffectChain: + """ + Represents a chain of JAX effects that haven't been executed yet. + + Allows effects to be accumulated through let bindings and fused + into a single JIT-compiled function when the result is needed. + """ + __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num') + + def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int): + self.effects = effects # List of effect names, innermost first + self.params_list = params_list # List of param dicts, matching effects + self.base_frame = base_frame # The actual frame array at the start + self.t = t + self.frame_num = frame_num + + def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain': + """Add another effect to the chain (outermost).""" + return DeferredEffectChain( + self.effects + [effect_name], + self.params_list + [params], + self.base_frame, + self.t, + self.frame_num + ) + + @property + def shape(self): + """Allow shape check without forcing execution.""" + return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None + + +class StreamInterpreter: + """ + Fully generic streaming sexp interpreter. + + No domain-specific knowledge - just evaluates expressions + and calls primitives. + """ + + def __init__(self, sexp_path: str, actor_id: Optional[str] = None, use_jax: bool = False): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + self.actor_id = actor_id # For friendly name resolution + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + # Global environment for def bindings + self.globals: Dict[str, Any] = {} + + # Scans + self.scans: Dict[str, dict] = {} + + # Audio playback path (for syncing output) + self.audio_playback: Optional[str] = None + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} + self.effects: Dict[str, dict] = {} + self.macros: Dict[str, dict] = {} + + # JAX backend for accelerated effect evaluation + self.use_jax = use_jax + self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects + self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects + self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains + self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains + self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env + if use_jax: + if _init_jax(): + print("JAX acceleration enabled", file=sys.stderr) + else: + print("Warning: JAX requested but not available, falling back to interpreter", file=sys.stderr) + self.use_jax = False + # Try multiple locations for primitive_libs + possible_paths = [ + self.sexp_dir.parent / "sexp_effects" / "primitive_libs", # recipes/ subdir + self.sexp_dir / "sexp_effects" / "primitive_libs", # app root + Path(__file__).parent.parent / "sexp_effects" / "primitive_libs", # relative to interpreter + ] + self.primitive_lib_dir = next((p for p in possible_paths if p.exists()), possible_paths[0]) + + self.frame_pipeline = None + + # External config files (set before run()) + self.sources_config: Optional[Path] = None + self.audio_config: Optional[Path] = None + + # Error tracking + self.errors: List[str] = [] + + # Callback for live streaming (called when IPFS playlist is updated) + self.on_playlist_update: callable = None + + # Callback for progress updates (called periodically during rendering) + # Signature: on_progress(percent: float, frame_num: int, total_frames: int) + self.on_progress: callable = None + + # Callback for checkpoint saves (called at segment boundaries for resumability) + # Signature: on_checkpoint(checkpoint: dict) + # checkpoint contains: frame_num, t, scans + self.on_checkpoint: callable = None + + # Frames per segment for checkpoint timing (default 4 seconds at 30fps = 120 frames) + self._frames_per_segment: int = 120 + + def _resolve_name(self, name: str) -> Optional[Path]: + """Resolve a friendly name to a file path using the naming service.""" + try: + # Import here to avoid circular imports + from tasks.streaming import resolve_asset + path = resolve_asset(name, self.actor_id) + if path: + return path + except Exception as e: + print(f"Warning: failed to resolve name '{name}': {e}", file=sys.stderr) + return None + + def _record_error(self, msg: str): + """Record an error that occurred during evaluation.""" + self.errors.append(msg) + print(f"ERROR: {msg}", file=sys.stderr) + + def _maybe_to_numpy(self, val, for_gpu_primitive: bool = False): + """Convert GPU frames/CuPy arrays to numpy for CPU primitives. + + If for_gpu_primitive=True, preserve GPU data (CuPy arrays stay on GPU). + """ + if val is None: + return val + + # For GPU primitives, keep data on GPU + if for_gpu_primitive: + # Handle GPUFrame - return the GPU array + if hasattr(val, 'gpu') and hasattr(val, 'is_on_gpu'): + if val.is_on_gpu: + return val.gpu + return val.cpu + # CuPy arrays pass through unchanged + if hasattr(val, '__cuda_array_interface__'): + return val + return val + + # For CPU primitives, convert to numpy + # Handle GPUFrame objects (have .cpu property) + if hasattr(val, 'cpu'): + return val.cpu + # Handle CuPy arrays (have .get() method) + if hasattr(val, 'get') and callable(val.get): + return val.get() + return val + + def _load_config_file(self, config_path): + """Load a config file and process its definitions.""" + config_path = Path(config_path) # Accept str or Path + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + text = config_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'def': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Config: {name}", file=sys.stderr) + + elif cmd == 'audio-playback': + # Path relative to working directory (consistent with other paths) + path = str(form[1]).strip('"') + self.audio_playback = str(Path(path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file. + + Prefers GPU-accelerated versions (*_gpu.py) when available. + Uses cached modules from sys.modules to ensure consistent state + (e.g., same RNG instance for all interpreters). + """ + import importlib.util + + # Try GPU version first, then fall back to CPU version + lib_names_to_try = [f"{lib_name}_gpu", lib_name] + + lib_path = None + actual_lib_name = lib_name + + for try_lib in lib_names_to_try: + lib_paths = [ + self.primitive_lib_dir / f"{try_lib}.py", + self.sexp_dir / "primitive_libs" / f"{try_lib}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{try_lib}.py", + ] + for p in lib_paths: + if p.exists(): + lib_path = p + actual_lib_name = try_lib + break + if lib_path: + break + + if not lib_path: + raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}") + + # Use cached module if already imported to preserve state (e.g., RNG) + # This is critical for deterministic random number sequences + # Check multiple possible module keys (standard import paths and our cache) + possible_keys = [ + f"sexp_effects.primitive_libs.{actual_lib_name}", + f"sexp_primitives.{actual_lib_name}", + ] + + module = None + for key in possible_keys: + if key in sys.modules: + module = sys.modules[key] + break + + if module is None: + spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Cache for future use under our key + sys.modules[f"sexp_primitives.{actual_lib_name}"] = module + + # Check if this is a GPU-accelerated module + is_gpu = actual_lib_name.endswith('_gpu') + gpu_tag = " [GPU]" if is_gpu else "" + + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] + dash_name = prim_name.replace('_', '-') + # Register with original lib_name namespace (geometry:rotate, not geometry_gpu:rotate) + # Don't overwrite if already registered (allows pre-registration of overrides) + key = f"{lib_name}:{dash_name}" + if key not in self.primitives: + self.primitives[key] = func + count += 1 + + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + # Register with original lib_name namespace + # Don't overwrite if already registered + dash_name = name.replace('_', '-') + key = f"{lib_name}:{dash_name}" + if key not in self.primitives: + self.primitives[key] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions){gpu_tag}", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + raise FileNotFoundError(f"Effect/include file not found: {effect_path}") + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + for pdef in form[i + 1]: + if isinstance(pdef, list) and pdef: + pname = pdef[0].name if isinstance(pdef[0], Symbol) else str(pdef[0]) + pinfo = {'default': 0} + j = 1 + while j < len(pdef): + if isinstance(pdef[j], Keyword) and j + 1 < len(pdef): + pinfo[pdef[j].name] = pdef[j + 1] + j += 2 + else: + j += 1 + params[pname] = pinfo + i += 2 + else: + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body} + self.jax_effect_paths[name] = effect_path # Track source for JAX compilation + print(f"Effect: {name}", file=sys.stderr) + + # Try to compile with JAX if enabled + if self.use_jax and _JAX_AVAILABLE: + self._compile_jax_effect(name, effect_path) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'effect': + # Handle (effect name :path "...") or (effect name :name "...") in included files + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'include': + # Handle (include :path "...") or (include :name "...") in included files + i = 1 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'scan': + # Handle scans from included files + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + def _compile_jax_effect(self, name: str, effect_path: Path): + """Compile an effect with JAX for accelerated execution.""" + if not _JAX_AVAILABLE or name in self.jax_effects: + return + + try: + compile_effect_file = _jax_compiler['compile_effect_file'] + jax_fn = compile_effect_file(str(effect_path)) + self.jax_effects[name] = jax_fn + print(f" [JAX compiled: {name}]", file=sys.stderr) + except Exception as e: + # Silently fall back to interpreter for unsupported effects + if 'Unknown operation' not in str(e): + print(f" [JAX skip {name}: {str(e)[:50]}]", file=sys.stderr) + + def _apply_jax_effect(self, name: str, frame: np.ndarray, params: Dict[str, Any], t: float, frame_num: int) -> Optional[np.ndarray]: + """Apply a JAX-compiled effect to a frame.""" + if name not in self.jax_effects: + return None + + try: + jax_fn = self.jax_effects[name] + # Handle GPU frames (CuPy) - need to move to CPU for CPU JAX + # JAX handles numpy and JAX arrays natively, no conversion needed + if hasattr(frame, 'cpu'): + frame = frame.cpu + elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'): + frame = frame.get() # CuPy array -> numpy + + # Get seed from config for deterministic random + seed = self.config.get('seed', 42) + + # Call JAX function with parameters + # Return JAX array directly - don't block or convert per-effect + # Conversion to numpy happens once at frame write time + return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params) + except Exception as e: + # Fall back to interpreter on error + print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr) + return None + + def _is_jax_effect_expr(self, expr) -> bool: + """Check if an expression is a JAX-compiled effect call.""" + if not isinstance(expr, list) or not expr: + return False + head = expr[0] + if not isinstance(head, Symbol): + return False + return head.name in self.jax_effects + + def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]: + """ + Extract a chain of JAX effects from an expression. + + Returns: (effect_names, params_list, base_frame_expr) or None if not a chain. + effect_names and params_list are in execution order (innermost first). + + For (effect1 (effect2 frame :p1 v1) :p2 v2): + Returns: (['effect2', 'effect1'], [params2, params1], frame_expr) + """ + if not self._is_jax_effect_expr(expr): + return None + + chain = [] + params_list = [] + current = expr + + while self._is_jax_effect_expr(current): + head = current[0] + effect_name = head.name + args = current[1:] + + # Extract params for this effect + effect = self.effects[effect_name] + effect_params = {} + for pname, pdef in effect['params'].items(): + effect_params[pname] = pdef.get('default', 0) + + # Find the frame argument (first positional) and other params + frame_arg = None + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_params[pname] = self._eval(args[i + 1], env) + i += 2 + else: + if frame_arg is None: + frame_arg = args[i] # First positional is frame + i += 1 + + chain.append(effect_name) + params_list.append(effect_params) + + if frame_arg is None: + return None # No frame argument found + + # Check if frame_arg is another effect call + if self._is_jax_effect_expr(frame_arg): + current = frame_arg + else: + # End of chain - frame_arg is the base frame + # Reverse to get innermost-first execution order + chain.reverse() + params_list.reverse() + return (chain, params_list, frame_arg) + + return None + + def _get_chain_key(self, effect_names: list, params_list: list) -> str: + """Generate a cache key for an effect chain. + + Includes static param values in the key since they affect compilation. + """ + parts = [] + for name, params in zip(effect_names, params_list): + param_parts = [] + for pname in sorted(params.keys()): + pval = params[pname] + # Include static values in key (strings, bools) + if isinstance(pval, (str, bool)): + param_parts.append(f"{pname}={pval}") + else: + param_parts.append(pname) + parts.append(f"{name}:{','.join(param_parts)}") + return '|'.join(parts) + + def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]: + """ + Compile a chain of effects into a single fused JAX function. + + Args: + effect_names: List of effect names in order [innermost, ..., outermost] + params_list: List of param dicts for each effect (used to detect static types) + + Returns: + JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame + """ + if not _JAX_AVAILABLE: + return None + + try: + import jax + + # Get the individual JAX functions + jax_fns = [self.jax_effects[name] for name in effect_names] + + # Pre-extract param names and identify static params from actual values + effect_param_names = [] + static_params = ['seed'] # seed is always static + for i, (name, params) in enumerate(zip(effect_names, params_list)): + param_names = list(params.keys()) + effect_param_names.append(param_names) + # Check actual values to identify static types + for pname, pval in params.items(): + if isinstance(pval, (str, bool)): + static_params.append(f"_p{i}_{pname}") + + def fused_fn(frame, t, frame_num, seed, **kwargs): + result = frame + for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)): + # Extract params for this effect from kwargs + effect_kwargs = {} + for pname in param_names: + key = f"_p{i}_{pname}" + if key in kwargs: + effect_kwargs[pname] = kwargs[key] + result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs) + return result + + # JIT with static params (seed + any string/bool params) + return jax.jit(fused_fn, static_argnames=static_params) + except Exception as e: + print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr) + return None + + def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int): + """Apply a chain of effects, using fused compilation if available.""" + chain_key = self._get_chain_key(effect_names, params_list) + + # Try to get or compile fused chain + if chain_key not in self.jax_fused_chains: + fused_fn = self._compile_effect_chain(effect_names, params_list) + self.jax_fused_chains[chain_key] = fused_fn + if fused_fn: + print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr) + + fused_fn = self.jax_fused_chains.get(chain_key) + + if fused_fn is not None: + # Build kwargs with prefixed param names + kwargs = {} + for i, params in enumerate(params_list): + for pname, pval in params.items(): + kwargs[f"_p{i}_{pname}"] = pval + + seed = self.config.get('seed', 42) + + # Handle GPU frames + if hasattr(frame, 'cpu'): + frame = frame.cpu + elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'): + frame = frame.get() + + try: + return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs) + except Exception as e: + print(f"Fused chain error: {e}", file=sys.stderr) + + # Fall back to sequential application + result = frame + for name, params in zip(effect_names, params_list): + result = self._apply_jax_effect(name, result, params, t, frame_num) + if result is None: + return None + return result + + def _force_deferred(self, deferred: DeferredEffectChain): + """Execute a deferred effect chain and return the actual array.""" + if len(deferred.effects) == 0: + return deferred.base_frame + + return self._apply_effect_chain( + deferred.effects, + deferred.params_list, + deferred.base_frame, + deferred.t, + deferred.frame_num + ) + + def _maybe_force(self, value): + """Force a deferred chain if needed, otherwise return as-is.""" + if isinstance(value, DeferredEffectChain): + return self._force_deferred(value) + return value + + def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]: + """ + Compile a vmapped version of an effect chain for batch processing. + + Args: + effect_names: List of effect names in order [innermost, ..., outermost] + params_list: List of param dicts (used to detect static types) + + Returns: + Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames + Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar + """ + if not _JAX_AVAILABLE: + return None + + try: + import jax + import jax.numpy as jnp + + # Get the individual JAX functions + jax_fns = [self.jax_effects[name] for name in effect_names] + + # Pre-extract param info + effect_param_names = [] + static_params = set() + for i, (name, params) in enumerate(zip(effect_names, params_list)): + param_names = list(params.keys()) + effect_param_names.append(param_names) + for pname, pval in params.items(): + if isinstance(pval, (str, bool)): + static_params.add(f"_p{i}_{pname}") + + # Single-frame function (will be vmapped) + def single_frame_fn(frame, t, frame_num, seed, **kwargs): + result = frame + for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)): + effect_kwargs = {} + for pname in param_names: + key = f"_p{i}_{pname}" + if key in kwargs: + effect_kwargs[pname] = kwargs[key] + result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs) + return result + + # Return unbatched function - we'll vmap at call time with proper in_axes + return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params)) + except Exception as e: + print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr) + return None + + def _apply_batched_chain(self, effect_names: list, params_list_batch: list, + frames: list, ts: list, frame_nums: list) -> Optional[list]: + """ + Apply an effect chain to a batch of frames using vmap. + + Args: + effect_names: List of effect names + params_list_batch: List of params_list for each frame in batch + frames: List of input frames + ts: List of time values + frame_nums: List of frame numbers + + Returns: + List of output frames, or None on failure + """ + if not frames: + return [] + + # Use first frame's params for chain key (assume same structure) + chain_key = self._get_chain_key(effect_names, params_list_batch[0]) + batch_key = f"batch:{chain_key}" + + # Compile batched version if needed + if batch_key not in self.jax_batched_chains: + batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0]) + self.jax_batched_chains[batch_key] = batched_fn + if batched_fn: + print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr) + + batched_fn = self.jax_batched_chains.get(batch_key) + + if batched_fn is not None: + try: + import jax + import jax.numpy as jnp + + # Stack frames into batch array + frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames]) + ts_array = jnp.array(ts) + frame_nums_array = jnp.array(frame_nums) + + # Build kwargs - all numeric params as arrays for vmap + kwargs = {} + static_kwargs = {} # Non-vmapped (strings, bools) + + for i, plist in enumerate(zip(*[p for p in params_list_batch])): + for j, pname in enumerate(params_list_batch[0][i].keys()): + key = f"_p{i}_{pname}" + values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]] + + first = values[0] + if isinstance(first, (str, bool)): + # Static params - not vmapped + static_kwargs[key] = first + elif isinstance(first, (int, float)): + # Always batch numeric params for simplicity + kwargs[key] = jnp.array(values) + elif hasattr(first, 'shape'): + kwargs[key] = jnp.stack(values) + else: + kwargs[key] = jnp.array(values) + + seed = self.config.get('seed', 42) + + # Create wrapper that unpacks the params dict + def single_call(frame, t, frame_num, params_dict): + return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs) + + # vmap over frame, t, frame_num, and the params dict (as pytree) + vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0)) + + # Stack kwargs into a dict of arrays (pytree with matching structure) + results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs) + + # Unstack results + return [results[i] for i in range(len(frames))] + except Exception as e: + print(f"Batched chain error: {e}", file=sys.stderr) + + # Fall back to sequential + return None + + def _init(self): + """Initialize from sexp - load primitives, effects, defs, scans.""" + # Set random seed for deterministic output + seed = self.config.get('seed', 42) + try: + from sexp_effects.primitive_libs.core import set_random_seed + set_random_seed(seed) + except ImportError: + pass + + # Load external config files first (they can override recipe definitions) + if self.sources_config: + self._load_config_file(self.sources_config) + if self.audio_config: + self._load_config_file(self.audio_config) + + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + # Resolve friendly name to path + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve effect name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'include': + i = 1 + while i < len(form): + if isinstance(form[i], Keyword): + kw = form[i].name + if kw == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + elif kw == 'name': + # Resolve friendly name to path + fname = str(form[i + 1]).strip('"') + resolved = self._resolve_name(fname) + if resolved: + self._load_effect(resolved) + else: + raise RuntimeError(f"Could not resolve include name '{fname}' - make sure it's uploaded and you're logged in") + i += 2 + else: + i += 1 + else: + i += 1 + + elif cmd == 'audio-playback': + # (audio-playback "path") - set audio file for playback sync + # Skip if already set by config file + if self.audio_playback is None: + path = str(form[1]).strip('"') + # Try to resolve as friendly name first + resolved = self._resolve_name(path) + if resolved: + self.audio_playback = str(resolved) + else: + # Fall back to relative path + self.audio_playback = str((self.sexp_dir / path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + elif cmd == 'def': + # (def name expr) - evaluate and store in globals + # Skip if already defined by config file + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + if name in self.globals: + print(f"Def: {name} (from config, skipped)", file=sys.stderr) + continue + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Def: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + elif cmd == 'frame': + self.frame_pipeline = form[1] if len(form) > 1 else None + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, bool): + return expr + + if isinstance(expr, Symbol): + name = expr.name + # Built-in constants + if name == 'pi': + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Global lookup + if name in self.globals: + return self.globals[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + raise NameError(f"Undefined variable: {name}") + + if isinstance(expr, Keyword): + return expr.name + + # Handle dicts from new parser - evaluate values + if isinstance(expr, dict): + return {k: self._eval(v, env) for k, v in expr.items()} + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check for closure call + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + if op in self.globals: + val = self.globals[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Binding === + + if op == 'bind': + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if len(args) > 1: + key = args[1].name if isinstance(args[1], Keyword) else str(args[1]) + return state.get(key, 0) + return state + return 0 + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + result[key] = val + i += 2 + else: + i += 1 + return result + + if op == 'get': + obj = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(obj, dict): + return obj.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'quote': + return args[0] if args else None + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + val = self._eval(args[0], env) + return len(val) if hasattr(val, '__len__') else 0 + + if op == 'map': + seq = self._eval(args[0], env) + fn = self._eval(args[1], env) + if not isinstance(seq, (list, tuple)): + return [] + # Handle closure (lambda from sexp) + if isinstance(fn, dict) and fn.get('_type') == 'closure': + results = [] + for item in seq: + closure_env = dict(fn['env']) + if fn['params']: + closure_env[fn['params'][0]] = item + results.append(self._eval(fn['body'], closure_env)) + return results + # Handle Python callable + if callable(fn): + return [fn(item) for item in seq] + return [] + + # === Effects === + + if op in self.effects: + # Try to detect and fuse effect chains for JAX acceleration + if self.use_jax and op in self.jax_effects: + chain_info = self._extract_effect_chain(expr, env) + if chain_info is not None: + effect_names, params_list, base_frame_expr = chain_info + # Only use chain if we have 2+ effects (worth fusing) + if len(effect_names) >= 2: + base_frame = self._eval(base_frame_expr, env) + if base_frame is not None and hasattr(base_frame, 'shape'): + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num) + if result is not None: + return result + # Fall through if chain application fails + + effect = self.effects[op] + effect_env = dict(env) + + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + positional_idx = 0 + frame_val = None + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + frame_val = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + # Try JAX-accelerated execution with deferred chaining + if self.use_jax and op in self.jax_effects and frame_val is not None: + # Build params dict for JAX (exclude 'frame') + jax_params = {k: self._maybe_force(v) for k, v in effect_env.items() + if k != 'frame' and k in effect['params']} + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + + # Check if input is a deferred chain - if so, extend it + if isinstance(frame_val, DeferredEffectChain): + return frame_val.extend(op, jax_params) + + # Check if input is a valid frame - create new deferred chain + if hasattr(frame_val, 'shape'): + return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num) + + # Fall through to interpreter if not a valid frame + + # Force any deferred frame before interpreter evaluation + if isinstance(frame_val, DeferredEffectChain): + frame_val = self._force_deferred(frame_val) + effect_env['frame'] = frame_val + + return self._eval(effect['body'], effect_env) + + # === Primitives === + + if op in self.primitives: + prim_func = self.primitives[op] + # Check if this is a GPU primitive (preserves GPU arrays) + is_gpu_prim = op.startswith('gpu:') or 'gpu' in op.lower() + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + # Force deferred chains before passing to primitives + v = self._maybe_force(v) + kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) + i += 2 + else: + val = self._eval(args[i], env) + # Force deferred chains before passing to primitives + val = self._maybe_force(val) + evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + self._record_error(f"Primitive {op} error: {e}") + raise RuntimeError(f"Primitive {op} failed: {e}") + + # === Macros (function-like: args evaluated before binding) === + + if op in self.macros: + macro = self.macros[op] + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + # Evaluate args in calling environment before binding + macro_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(macro['body'], macro_env) + + # Underscore variant lookup + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + # Check if this is a GPU primitive (preserves GPU arrays) + is_gpu_prim = 'gpu' in prim_name.lower() + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) + i += 2 + else: + evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim)) + i += 1 + + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + self._record_error(f"Primitive {op} error: {e}") + raise RuntimeError(f"Primitive {op} failed: {e}") + + # Unknown function call - raise meaningful error + raise RuntimeError(f"Unknown function or primitive: '{op}'. " + f"Available primitives: {sorted(list(self.primitives.keys())[:10])}... " + f"Available effects: {sorted(list(self.effects.keys())[:10])}... " + f"Available macros: {sorted(list(self.macros.keys())[:10])}...") + + def _step_scans(self, ctx: Context, env: dict): + """Step scans based on trigger evaluation.""" + for name, scan in self.scans.items(): + trigger_expr = scan['trigger'] + + # Evaluate trigger in context + should_step = self._eval(trigger_expr, env) + + if should_step: + state = scan['state'] + step_env = dict(state) + step_env.update(env) + + new_state = self._eval(scan['step'], step_env) + if isinstance(new_state, dict): + scan['state'] = new_state + else: + scan['state'] = {'acc': new_state} + + def _restore_checkpoint(self, checkpoint: dict): + """Restore scan states from a checkpoint. + + Called when resuming a render from a previous checkpoint. + + Args: + checkpoint: Dict with 'scans' key containing {scan_name: state_dict} + """ + scans_state = checkpoint.get('scans', {}) + for name, state in scans_state.items(): + if name in self.scans: + self.scans[name]['state'] = dict(state) + print(f"Restored scan '{name}' state from checkpoint", file=sys.stderr) + + def _get_checkpoint_state(self) -> dict: + """Get current scan states for checkpointing. + + Returns: + Dict mapping scan names to their current state dicts + """ + return {name: dict(scan['state']) for name, scan in self.scans.items()} + + def run(self, duration: float = None, output: str = "pipe", resume_from: dict = None): + """Run the streaming pipeline. + + Args: + duration: Duration in seconds (auto-detected from audio if None) + output: Output mode ("pipe", "preview", path/hls, path/ipfs-hls, or file path) + resume_from: Checkpoint dict to resume from, with keys: + - frame_num: Last completed frame + - t: Time value for checkpoint frame + - scans: Dict of scan states to restore + - segment_cids: Dict of quality -> {seg_num: cid} for output resume + """ + # Import output classes - handle both package and direct execution + try: + from .output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput + from .gpu_output import GPUHLSOutput, check_gpu_encode_available + from .multi_res_output import MultiResolutionHLSOutput + except ImportError: + from output import PipeOutput, DisplayOutput, FileOutput, HLSOutput, IPFSHLSOutput + try: + from gpu_output import GPUHLSOutput, check_gpu_encode_available + except ImportError: + GPUHLSOutput = None + check_gpu_encode_available = lambda: False + try: + from multi_res_output import MultiResolutionHLSOutput + except ImportError: + MultiResolutionHLSOutput = None + + self._init() + + # Restore checkpoint state if resuming + if resume_from: + self._restore_checkpoint(resume_from) + print(f"Resuming from frame {resume_from.get('frame_num', 0)}", file=sys.stderr) + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w = self.config.get('width', 720) + h = self.config.get('height', 720) + fps = self.config.get('fps', 30) + + if duration is None: + # Try to get duration from audio if available + for name, val in self.globals.items(): + if hasattr(val, 'duration'): + duration = val.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + break + else: + duration = 60.0 + + n_frames = int(duration * fps) + frame_time = 1.0 / fps + + print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr) + + # Create context + ctx = Context(fps=fps) + + # Output (with optional audio sync) + # Resolve audio path lazily here if it wasn't resolved during parsing + audio = self.audio_playback + if audio and not Path(audio).exists(): + # Try to resolve as friendly name (may have failed during parsing) + audio_name = Path(audio).name # Get just the name part + resolved = self._resolve_name(audio_name) + if resolved and resolved.exists(): + audio = str(resolved) + print(f"Lazy resolved audio: {audio}", file=sys.stderr) + else: + raise FileNotFoundError(f"Audio file not found: {audio}") + if output == "pipe": + out = PipeOutput(size=(w, h), fps=fps, audio_source=audio) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=fps, audio_source=audio) + elif output.endswith("/hls"): + # HLS output - output is a directory path ending in /hls + hls_dir = output[:-4] # Remove /hls suffix + out = HLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio) + elif output.endswith("/ipfs-hls"): + # IPFS HLS output - multi-resolution adaptive streaming + hls_dir = output[:-9] # Remove /ipfs-hls suffix + import os + ipfs_gateway = os.environ.get("IPFS_GATEWAY_URL", "https://ipfs.io/ipfs") + + # Build resume state for output if resuming + output_resume = None + if resume_from and resume_from.get('segment_cids'): + output_resume = {'segment_cids': resume_from['segment_cids']} + + # Use multi-resolution output (renders original + 720p + 360p) + if MultiResolutionHLSOutput is not None: + print(f"[StreamInterpreter] Using multi-resolution HLS output ({w}x{h} + 720p + 360p)", file=sys.stderr) + out = MultiResolutionHLSOutput( + hls_dir, + source_size=(w, h), + fps=fps, + ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update, + audio_source=audio, + resume_from=output_resume, + ) + # Fallback to GPU single-resolution if multi-res not available + elif GPUHLSOutput is not None and check_gpu_encode_available(): + print(f"[StreamInterpreter] Using GPU zero-copy encoding (single resolution)", file=sys.stderr) + out = GPUHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update) + else: + out = IPFSHLSOutput(hls_dir, size=(w, h), fps=fps, audio_source=audio, ipfs_gateway=ipfs_gateway, + on_playlist_update=self.on_playlist_update) + else: + out = FileOutput(output, size=(w, h), fps=fps, audio_source=audio) + + # Calculate frames per segment based on fps and segment duration (4 seconds default) + segment_duration = 4.0 + self._frames_per_segment = int(fps * segment_duration) + + # Determine start frame (resume from checkpoint + 1, or 0) + start_frame = 0 + if resume_from and resume_from.get('frame_num') is not None: + start_frame = resume_from['frame_num'] + 1 + print(f"Starting from frame {start_frame} (checkpoint was at {resume_from['frame_num']})", file=sys.stderr) + + try: + frame_times = [] + profile_interval = 30 # Profile every N frames + scan_times = [] + eval_times = [] + write_times = [] + + # Batch accumulation for JAX + batch_deferred = [] # Accumulated DeferredEffectChains + batch_times = [] # Corresponding time values + batch_start_frame = 0 + + def flush_batch(): + """Execute accumulated batch and write results.""" + nonlocal batch_deferred, batch_times + if not batch_deferred: + return + + t_flush = time.time() + + # Check if all chains have same structure (can batch) + first = batch_deferred[0] + can_batch = ( + self.use_jax and + len(batch_deferred) >= 2 and + all(d.effects == first.effects for d in batch_deferred) + ) + + if can_batch: + # Try batched execution + frames = [d.base_frame for d in batch_deferred] + ts = [d.t for d in batch_deferred] + frame_nums = [d.frame_num for d in batch_deferred] + params_batch = [d.params_list for d in batch_deferred] + + results = self._apply_batched_chain( + first.effects, params_batch, frames, ts, frame_nums + ) + + if results is not None: + # Write batched results + for result, t in zip(results, batch_times): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + batch_deferred = [] + batch_times = [] + return + + # Fall back to sequential execution + for deferred, t in zip(batch_deferred, batch_times): + result = self._force_deferred(deferred) + if result is not None and hasattr(result, 'shape'): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + + batch_deferred = [] + batch_times = [] + + for frame_num in range(start_frame, n_frames): + if not out.is_open: + break + + frame_start = time.time() + ctx.t = frame_num * frame_time + ctx.frame_num = frame_num + + # Build frame environment with context + frame_env = { + 'ctx': { + 't': ctx.t, + 'frame-num': ctx.frame_num, + 'fps': ctx.fps, + }, + 't': ctx.t, # Also expose t directly for convenience + 'frame-num': ctx.frame_num, + } + + # Step scans + t0 = time.time() + self._step_scans(ctx, frame_env) + scan_times.append(time.time() - t0) + + # Evaluate pipeline + t1 = time.time() + result = self._eval(self.frame_pipeline, frame_env) + eval_times.append(time.time() - t1) + + t2 = time.time() + if result is not None: + if isinstance(result, DeferredEffectChain): + # Accumulate for batching + batch_deferred.append(result) + batch_times.append(ctx.t) + + # Flush when batch is full + if len(batch_deferred) >= self.jax_batch_size: + flush_batch() + else: + # Not deferred - flush any pending batch first, then write + flush_batch() + if hasattr(result, 'shape'): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, ctx.t) + write_times.append(time.time() - t2) + + frame_elapsed = time.time() - frame_start + frame_times.append(frame_elapsed) + + # Checkpoint at segment boundaries (every _frames_per_segment frames) + if frame_num > 0 and frame_num % self._frames_per_segment == 0: + if self.on_checkpoint: + try: + checkpoint = { + 'frame_num': frame_num, + 't': ctx.t, + 'scans': self._get_checkpoint_state(), + } + self.on_checkpoint(checkpoint) + except Exception as e: + print(f"Warning: checkpoint callback failed: {e}", file=sys.stderr) + + # Progress with timing and profile breakdown + if frame_num % profile_interval == 0 and frame_num > 0: + pct = 100 * frame_num / n_frames + avg_ms = 1000 * sum(frame_times[-profile_interval:]) / max(1, len(frame_times[-profile_interval:])) + avg_scan = 1000 * sum(scan_times[-profile_interval:]) / max(1, len(scan_times[-profile_interval:])) + avg_eval = 1000 * sum(eval_times[-profile_interval:]) / max(1, len(eval_times[-profile_interval:])) + avg_write = 1000 * sum(write_times[-profile_interval:]) / max(1, len(write_times[-profile_interval:])) + target_ms = 1000 * frame_time + print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms] scan={avg_scan:.0f}ms eval={avg_eval:.0f}ms write={avg_write:.0f}ms", end="", file=sys.stderr, flush=True) + + # Call progress callback if set (for Celery task state updates) + if self.on_progress: + try: + self.on_progress(pct, frame_num, n_frames) + except Exception as e: + print(f"Warning: progress callback failed: {e}", file=sys.stderr) + + # Flush any remaining batch + flush_batch() + + finally: + out.close() + # Store output for access to properties like playlist_cid + self.output = out + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None, + sources_config: str = None, audio_config: str = None, use_jax: bool = False): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path, use_jax=use_jax) + if fps: + interp.config['fps'] = fps + if sources_config: + interp.sources_config = Path(sources_config) + if audio_config: + interp.audio_config = Path(audio_config) + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None) + parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file") + parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file") + parser.add_argument("--jax", action="store_true", help="Enable JAX acceleration for effects") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps, + sources_config=args.sources_config, audio_config=args.audio_config, + use_jax=args.jax) diff --git a/tasks/__init__.py b/tasks/__init__.py new file mode 100644 index 0000000..6a07c25 --- /dev/null +++ b/tasks/__init__.py @@ -0,0 +1,13 @@ +# art-celery/tasks - Celery tasks for streaming video rendering +# +# Tasks: +# 1. run_stream - Execute a streaming S-expression recipe +# 2. upload_to_ipfs - Background IPFS upload for media files + +from .streaming import run_stream +from .ipfs_upload import upload_to_ipfs + +__all__ = [ + "run_stream", + "upload_to_ipfs", +] diff --git a/tasks/ipfs_upload.py b/tasks/ipfs_upload.py new file mode 100644 index 0000000..541f850 --- /dev/null +++ b/tasks/ipfs_upload.py @@ -0,0 +1,93 @@ +""" +Background IPFS upload task. + +Uploads files to IPFS in the background after initial local storage. +This allows fast uploads while still getting IPFS CIDs eventually. +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Optional + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from celery_app import app +import ipfs_client + +logger = logging.getLogger(__name__) + + +@app.task(bind=True, max_retries=3, default_retry_delay=60) +def upload_to_ipfs(self, local_cid: str, actor_id: str) -> Optional[str]: + """ + Upload a locally cached file to IPFS in the background. + + Args: + local_cid: The local content hash of the file + actor_id: The user who uploaded the file + + Returns: + IPFS CID if successful, None if failed + """ + from cache_manager import get_cache_manager + import asyncio + import database + + logger.info(f"Background IPFS upload starting for {local_cid[:16]}...") + + try: + cache_mgr = get_cache_manager() + + # Get the file path from local cache + file_path = cache_mgr.get_by_cid(local_cid) + if not file_path or not file_path.exists(): + logger.error(f"File not found for local CID {local_cid[:16]}...") + return None + + # Upload to IPFS + logger.info(f"Uploading {file_path} to IPFS...") + ipfs_cid = ipfs_client.add_file(file_path) + + if not ipfs_cid: + logger.error(f"IPFS upload failed for {local_cid[:16]}...") + raise self.retry(exc=Exception("IPFS upload failed")) + + logger.info(f"IPFS upload successful: {local_cid[:16]}... -> {ipfs_cid[:16]}...") + + # Update database with IPFS CID + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Initialize database pool if needed + loop.run_until_complete(database.init_pool()) + + # Update cache_items table + loop.run_until_complete( + database.update_cache_item_ipfs_cid(local_cid, ipfs_cid) + ) + + # Update friendly_names table to use IPFS CID instead of local hash + # This ensures assets can be fetched by remote workers via IPFS + try: + loop.run_until_complete( + database.update_friendly_name_cid(actor_id, local_cid, ipfs_cid) + ) + logger.info(f"Friendly name updated: {local_cid[:16]}... -> {ipfs_cid[:16]}...") + except Exception as e: + logger.warning(f"Failed to update friendly name CID: {e}") + + # Create index from IPFS CID to local cache + cache_mgr._set_content_index(ipfs_cid, local_cid) + + logger.info(f"Database updated with IPFS CID for {local_cid[:16]}...") + finally: + loop.close() + + return ipfs_cid + + except Exception as e: + logger.error(f"Background IPFS upload error: {e}") + raise self.retry(exc=e) diff --git a/tasks/streaming.py b/tasks/streaming.py new file mode 100644 index 0000000..7ac6057 --- /dev/null +++ b/tasks/streaming.py @@ -0,0 +1,724 @@ +""" +Streaming video rendering task. + +Executes S-expression recipes for frame-by-frame video processing. +Supports CID and friendly name references for assets. +Supports pause/resume/restart for long renders. +""" + +import hashlib +import logging +import os +import signal +import sys +import tempfile +from pathlib import Path +from typing import Dict, Optional + +from celery import current_task + +# Add parent directory to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from celery_app import app +from cache_manager import get_cache_manager + +logger = logging.getLogger(__name__) + + +class PauseRequested(Exception): + """Raised when user requests pause via SIGTERM.""" + pass + +# Debug: verify module is being loaded +print(f"DEBUG MODULE LOAD: tasks/streaming.py loaded at {__file__}", file=sys.stderr) + +# Module-level event loop for database operations +_resolve_loop = None +_db_initialized = False + + +def resolve_asset(ref: str, actor_id: Optional[str] = None) -> Optional[Path]: + """ + Resolve an asset reference (CID or friendly name) to a file path. + + Args: + ref: CID or friendly name (e.g., "my-video" or "QmXyz...") + actor_id: User ID for friendly name resolution + + Returns: + Path to the asset file, or None if not found + """ + global _resolve_loop, _db_initialized + import sys + print(f"RESOLVE_ASSET: ref={ref}, actor_id={actor_id}", file=sys.stderr) + cache_mgr = get_cache_manager() + + # Try as direct CID first + path = cache_mgr.get_by_cid(ref) + print(f"RESOLVE_ASSET: get_by_cid({ref}) = {path}", file=sys.stderr) + if path and path.exists(): + logger.info(f"Resolved {ref[:16]}... as CID to {path}") + return path + + # Try as friendly name if actor_id provided + print(f"RESOLVE_ASSET: trying friendly name lookup, actor_id={actor_id}", file=sys.stderr) + if actor_id: + from database import resolve_friendly_name_sync, get_ipfs_cid_sync + + try: + # Use synchronous database functions to avoid event loop issues + cid = resolve_friendly_name_sync(actor_id, ref) + print(f"RESOLVE_ASSET: resolve_friendly_name_sync({actor_id}, {ref}) = {cid}", file=sys.stderr) + + if cid: + path = cache_mgr.get_by_cid(cid) + print(f"RESOLVE_ASSET: get_by_cid({cid}) = {path}", file=sys.stderr) + if path and path.exists(): + print(f"RESOLVE_ASSET: SUCCESS - resolved to {path}", file=sys.stderr) + logger.info(f"Resolved '{ref}' via friendly name to {path}") + return path + + # File not in local cache - look up IPFS CID and fetch + # The cid from friendly_names is internal, need to get ipfs_cid from cache_items + ipfs_cid = get_ipfs_cid_sync(cid) + if not ipfs_cid or ipfs_cid == cid: + # No separate IPFS CID, try using the cid directly (might be IPFS CID) + ipfs_cid = cid + print(f"RESOLVE_ASSET: file not local, trying IPFS fetch for {ipfs_cid}", file=sys.stderr) + import ipfs_client + content = ipfs_client.get_bytes(ipfs_cid, use_gateway_fallback=True) + if content: + # Save to local cache + import tempfile + from pathlib import Path + with tempfile.NamedTemporaryFile(delete=False, suffix='.sexp') as tmp: + tmp.write(content) + tmp_path = Path(tmp.name) + # Store in cache + cached_file, _ = cache_mgr.put(tmp_path, node_type="effect", skip_ipfs=True) + # Index by IPFS CID for future lookups + cache_mgr._set_content_index(cid, cached_file.cid) + print(f"RESOLVE_ASSET: fetched from IPFS and cached at {cached_file.path}", file=sys.stderr) + logger.info(f"Fetched '{ref}' from IPFS and cached at {cached_file.path}") + return cached_file.path + else: + print(f"RESOLVE_ASSET: IPFS fetch failed for {cid}", file=sys.stderr) + except Exception as e: + print(f"RESOLVE_ASSET: ERROR - {e}", file=sys.stderr) + logger.warning(f"Failed to resolve friendly name '{ref}': {e}") + + logger.warning(f"Could not resolve asset reference: {ref}") + return None + + +class CIDVideoSource: + """ + Video source that resolves CIDs to file paths. + + Wraps the streaming VideoSource to work with cached assets. + """ + + def __init__(self, cid: str, fps: float = 30, actor_id: Optional[str] = None): + self.cid = cid + self.fps = fps + self.actor_id = actor_id + self._source = None + + def _ensure_source(self): + if self._source is None: + logger.info(f"CIDVideoSource._ensure_source: resolving cid={self.cid} with actor_id={self.actor_id}") + path = resolve_asset(self.cid, self.actor_id) + if not path: + raise ValueError(f"Could not resolve video source '{self.cid}' for actor_id={self.actor_id}") + + logger.info(f"CIDVideoSource._ensure_source: resolved to path={path}") + # Use GPU-accelerated video source if available + try: + from sexp_effects.primitive_libs.streaming_gpu import GPUVideoSource, GPU_AVAILABLE + if GPU_AVAILABLE: + logger.info(f"CIDVideoSource: using GPUVideoSource for {path}") + self._source = GPUVideoSource(str(path), self.fps, prefer_gpu=True) + else: + raise ImportError("GPU not available") + except (ImportError, Exception) as e: + logger.info(f"CIDVideoSource: falling back to CPU VideoSource ({e})") + from sexp_effects.primitive_libs.streaming import VideoSource + self._source = VideoSource(str(path), self.fps) + + def read_at(self, t: float): + self._ensure_source() + return self._source.read_at(t) + + def read(self): + self._ensure_source() + return self._source.read() + + @property + def size(self): + self._ensure_source() + return self._source.size + + @property + def duration(self): + self._ensure_source() + return self._source._duration + + @property + def path(self): + self._ensure_source() + return self._source.path + + @property + def _stream_time(self): + self._ensure_source() + return self._source._stream_time + + def skip(self): + self._ensure_source() + return self._source.skip() + + def close(self): + if self._source: + self._source.close() + + +class CIDAudioAnalyzer: + """ + Audio analyzer that resolves CIDs to file paths. + """ + + def __init__(self, cid: str, actor_id: Optional[str] = None): + self.cid = cid + self.actor_id = actor_id + self._analyzer = None + + def _ensure_analyzer(self): + if self._analyzer is None: + path = resolve_asset(self.cid, self.actor_id) + if not path: + raise ValueError(f"Could not resolve audio source: {self.cid}") + + from sexp_effects.primitive_libs.streaming import AudioAnalyzer + self._analyzer = AudioAnalyzer(str(path)) + + def get_energy(self, t: float) -> float: + self._ensure_analyzer() + return self._analyzer.get_energy(t) + + def get_beat(self, t: float) -> bool: + self._ensure_analyzer() + return self._analyzer.get_beat(t) + + def get_beat_count(self, t: float) -> int: + self._ensure_analyzer() + return self._analyzer.get_beat_count(t) + + @property + def duration(self): + self._ensure_analyzer() + return self._analyzer.duration + + +def create_cid_primitives(actor_id: Optional[str] = None): + """ + Create CID-aware primitive functions. + + Returns dict of primitives that resolve CIDs before creating sources. + """ + from celery.utils.log import get_task_logger + cid_logger = get_task_logger(__name__) + def prim_make_video_source_cid(cid: str, fps: float = 30): + cid_logger.warning(f"DEBUG: CID-aware make-video-source: cid={cid}, actor_id={actor_id}") + return CIDVideoSource(cid, fps, actor_id) + + def prim_make_audio_analyzer_cid(cid: str): + cid_logger.warning(f"DEBUG: CID-aware make-audio-analyzer: cid={cid}, actor_id={actor_id}") + return CIDAudioAnalyzer(cid, actor_id) + + return { + 'streaming:make-video-source': prim_make_video_source_cid, + 'streaming:make-audio-analyzer': prim_make_audio_analyzer_cid, + } + + +@app.task(bind=True, name='tasks.run_stream') +def run_stream( + self, + run_id: str, + recipe_sexp: str, + output_name: str = "output.mp4", + duration: Optional[float] = None, + fps: Optional[float] = None, + actor_id: Optional[str] = None, + sources_sexp: Optional[str] = None, + audio_sexp: Optional[str] = None, + resume: bool = False, +) -> dict: + """ + Execute a streaming S-expression recipe. + + Args: + run_id: The run ID for database tracking + recipe_sexp: The recipe S-expression content + output_name: Name for the output file + duration: Optional duration override (seconds) + fps: Optional FPS override + actor_id: User ID for friendly name resolution + sources_sexp: Optional sources config S-expression + audio_sexp: Optional audio config S-expression + resume: If True, load checkpoint and resume from where we left off + + Returns: + Dict with output_cid, output_path, and status + """ + global _resolve_loop, _db_initialized + task_id = self.request.id + logger.info(f"Starting stream task {task_id} for run {run_id} (resume={resume})") + + # Handle graceful pause (SIGTERM from Celery revoke) + pause_requested = False + original_sigterm = signal.getsignal(signal.SIGTERM) + + def handle_sigterm(signum, frame): + nonlocal pause_requested + pause_requested = True + logger.info(f"Pause requested for run {run_id} (SIGTERM received)") + + signal.signal(signal.SIGTERM, handle_sigterm) + + self.update_state(state='INITIALIZING', meta={'progress': 0}) + + # Get the app directory for primitive/effect paths + app_dir = Path(__file__).parent.parent # celery/ + sexp_effects_dir = app_dir / "sexp_effects" + effects_dir = app_dir / "effects" + templates_dir = app_dir / "templates" + + # Create temp directory for work + work_dir = Path(tempfile.mkdtemp(prefix="stream_")) + recipe_path = work_dir / "recipe.sexp" + + # Write output to shared cache for live streaming access + cache_dir = Path(os.environ.get("CACHE_DIR", "/data/cache")) + stream_dir = cache_dir / "streaming" / run_id + stream_dir.mkdir(parents=True, exist_ok=True) + # Use IPFS HLS output for distributed streaming - segments uploaded to IPFS + output_path = str(stream_dir) + "/ipfs-hls" # /ipfs-hls suffix triggers IPFS HLS mode + + # Create symlinks to effect directories so relative paths work + (work_dir / "sexp_effects").symlink_to(sexp_effects_dir) + (work_dir / "effects").symlink_to(effects_dir) + (work_dir / "templates").symlink_to(templates_dir) + + try: + # Write recipe to temp file + recipe_path.write_text(recipe_sexp) + + # Write optional config files + sources_path = None + if sources_sexp: + sources_path = work_dir / "sources.sexp" + sources_path.write_text(sources_sexp) + + audio_path = None + if audio_sexp: + audio_path = work_dir / "audio.sexp" + audio_path.write_text(audio_sexp) + + self.update_state(state='RENDERING', meta={'progress': 5}) + + # Import the streaming interpreter + from streaming.stream_sexp_generic import StreamInterpreter + + # Load checkpoint if resuming + checkpoint = None + if resume: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + checkpoint = _resolve_loop.run_until_complete(database.get_run_checkpoint(run_id)) + if checkpoint: + logger.info(f"Loaded checkpoint for run {run_id}: frame {checkpoint.get('frame_num')}") + else: + logger.warning(f"No checkpoint found for run {run_id}, starting from beginning") + except Exception as e: + logger.error(f"Failed to load checkpoint: {e}") + checkpoint = None + + # Create interpreter (pass actor_id for friendly name resolution) + interp = StreamInterpreter(str(recipe_path), actor_id=actor_id) + + # Set primitive library directory explicitly + interp.primitive_lib_dir = sexp_effects_dir / "primitive_libs" + + if fps: + interp.config['fps'] = fps + if sources_path: + interp.sources_config = sources_path + if audio_path: + interp.audio_config = audio_path + + # Override primitives with CID-aware versions + cid_prims = create_cid_primitives(actor_id) + from celery.utils.log import get_task_logger + task_logger = get_task_logger(__name__) + task_logger.warning(f"DEBUG: Overriding primitives: {list(cid_prims.keys())}") + task_logger.warning(f"DEBUG: Primitives before: {list(interp.primitives.keys())[:10]}...") + interp.primitives.update(cid_prims) + task_logger.warning(f"DEBUG: streaming:make-video-source is now: {type(interp.primitives.get('streaming:make-video-source'))}") + + # Set up callback to update database when IPFS playlist is created (for live HLS redirect) + def on_playlist_update(playlist_cid, quality_playlists=None): + """Update database with playlist CID and quality info. + + Args: + playlist_cid: Master playlist CID + quality_playlists: Dict of quality name -> {cid, width, height, bitrate} + """ + global _resolve_loop, _db_initialized + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_playlist(run_id, playlist_cid, quality_playlists)) + logger.info(f"Updated pending run {run_id} with IPFS playlist: {playlist_cid}, qualities: {list(quality_playlists.keys()) if quality_playlists else []}") + except Exception as e: + logger.error(f"Failed to update playlist CID in database: {e}") + + interp.on_playlist_update = on_playlist_update + + # Set up progress callback to update Celery task state + def on_progress(pct, frame_num, total_frames): + nonlocal pause_requested + # Scale progress: 5% (start) to 85% (before caching) + scaled_progress = 5 + (pct * 0.8) # 5% to 85% + self.update_state(state='RENDERING', meta={ + 'progress': scaled_progress, + 'frame': frame_num, + 'total_frames': total_frames, + 'percent': pct, + }) + + interp.on_progress = on_progress + + # Set up checkpoint callback to save state at segment boundaries + def on_checkpoint(ckpt): + """Save checkpoint state to database.""" + nonlocal pause_requested + global _resolve_loop, _db_initialized + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + + # Get total frames from interpreter config + total_frames = None + if hasattr(interp, 'output') and hasattr(interp.output, '_frame_count'): + # Estimate total frames based on duration + fps_val = interp.config.get('fps', 30) + for name, val in interp.globals.items(): + if hasattr(val, 'duration'): + total_frames = int(val.duration * fps_val) + break + + _resolve_loop.run_until_complete(database.update_pending_run_checkpoint( + run_id=run_id, + checkpoint_frame=ckpt['frame_num'], + checkpoint_t=ckpt['t'], + checkpoint_scans=ckpt.get('scans'), + total_frames=total_frames, + )) + logger.info(f"Saved checkpoint for run {run_id}: frame {ckpt['frame_num']}") + + # Check if pause was requested after checkpoint + if pause_requested: + logger.info(f"Pause requested after checkpoint, raising PauseRequested") + raise PauseRequested("Render paused by user") + + except PauseRequested: + raise # Re-raise to stop the render + except Exception as e: + logger.error(f"Failed to save checkpoint: {e}") + + interp.on_checkpoint = on_checkpoint + + # Build resume state for the interpreter (includes segment CIDs for output) + resume_from = None + if checkpoint: + resume_from = { + 'frame_num': checkpoint.get('frame_num'), + 't': checkpoint.get('t'), + 'scans': checkpoint.get('scans', {}), + } + # Add segment CIDs if available (from quality_playlists in checkpoint) + # Note: We need to extract segment_cids from the output's state, which isn't + # directly stored. For now, the output will re-check existing segments on disk. + + # Run rendering to file + logger.info(f"Rendering to {output_path}" + (f" (resuming from frame {resume_from['frame_num']})" if resume_from else "")) + render_paused = False + try: + interp.run(duration=duration, output=str(output_path), resume_from=resume_from) + except PauseRequested: + # Graceful pause - checkpoint already saved + render_paused = True + logger.info(f"Render paused for run {run_id}") + + # Restore original signal handler + signal.signal(signal.SIGTERM, original_sigterm) + + if render_paused: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status(run_id, 'paused')) + except Exception as e: + logger.error(f"Failed to update status to paused: {e}") + return {"status": "paused", "run_id": run_id, "task_id": task_id} + + # Check for interpreter errors + if interp.errors: + error_msg = f"Rendering failed with {len(interp.errors)} errors: {interp.errors[0]}" + raise RuntimeError(error_msg) + + self.update_state(state='CACHING', meta={'progress': 90}) + + # Get IPFS playlist CID if available (from IPFSHLSOutput) + ipfs_playlist_cid = None + ipfs_playlist_url = None + segment_cids = {} + if hasattr(interp, 'output') and hasattr(interp.output, 'playlist_cid'): + ipfs_playlist_cid = interp.output.playlist_cid + ipfs_playlist_url = interp.output.playlist_url + segment_cids = getattr(interp.output, 'segment_cids', {}) + logger.info(f"IPFS HLS: playlist={ipfs_playlist_cid}, segments={len(segment_cids)}") + + # Update pending run with playlist CID for live HLS redirect + if ipfs_playlist_cid: + import asyncio + import database + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_playlist(run_id, ipfs_playlist_cid)) + logger.info(f"Updated pending run {run_id} with IPFS playlist CID: {ipfs_playlist_cid}") + except Exception as e: + logger.error(f"Failed to update pending run with playlist CID: {e}") + raise # Fail fast - database errors should not be silently ignored + + # HLS output creates playlist and segments + # - Single-res: stream_dir/stream.m3u8 and stream_dir/segment_*.ts + # - Multi-res: stream_dir/original/playlist.m3u8 and stream_dir/original/segment_*.ts + hls_playlist = stream_dir / "stream.m3u8" + if not hls_playlist.exists(): + # Try multi-res output path + hls_playlist = stream_dir / "original" / "playlist.m3u8" + + # Validate HLS output (must have playlist and at least one segment) + if not hls_playlist.exists(): + raise RuntimeError("HLS playlist not created - rendering likely failed") + + segments = list(stream_dir.glob("segment_*.ts")) + if not segments: + # Try multi-res output path + segments = list(stream_dir.glob("original/segment_*.ts")) + if not segments: + raise RuntimeError("No HLS segments created - rendering likely failed") + + logger.info(f"HLS rendering complete: {len(segments)} segments created, IPFS playlist: {ipfs_playlist_cid}") + + # Mux HLS segments into a single MP4 for persistent cache storage + final_mp4 = stream_dir / "output.mp4" + import subprocess + mux_cmd = [ + "ffmpeg", "-y", + "-i", str(hls_playlist), + "-c", "copy", # Just copy streams, no re-encoding + "-movflags", "+faststart", # Move moov atom to start for web playback + "-fflags", "+genpts", # Generate proper timestamps + str(final_mp4) + ] + logger.info(f"Muxing HLS to MP4: {' '.join(mux_cmd)}") + result = subprocess.run(mux_cmd, capture_output=True, text=True) + if result.returncode != 0: + logger.warning(f"HLS mux failed: {result.stderr}") + # Fall back to using the first segment for caching + final_mp4 = segments[0] + + # Store output in cache + if final_mp4.exists(): + cache_mgr = get_cache_manager() + cached_file, ipfs_cid = cache_mgr.put( + source_path=final_mp4, + node_type="STREAM_OUTPUT", + node_id=f"stream_{task_id}", + ) + + logger.info(f"Stream output cached: CID={cached_file.cid}, IPFS={ipfs_cid}") + + # Save to database - reuse the module-level loop to avoid pool conflicts + import asyncio + import database + + try: + # Reuse or create event loop + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + + # Initialize database pool if needed + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + + # Get recipe CID from pending_run + pending = _resolve_loop.run_until_complete(database.get_pending_run(run_id)) + recipe_cid = pending.get("recipe", "streaming") if pending else "streaming" + + # Save to run_cache for completed runs + logger.info(f"Saving run {run_id} to run_cache with actor_id={actor_id}") + _resolve_loop.run_until_complete(database.save_run_cache( + run_id=run_id, + output_cid=cached_file.cid, + recipe=recipe_cid, + inputs=[], + ipfs_cid=ipfs_cid, + actor_id=actor_id, + )) + # Register output as video type so frontend displays it correctly + _resolve_loop.run_until_complete(database.add_item_type( + cid=cached_file.cid, + actor_id=actor_id, + item_type="video", + path=str(cached_file.path), + description=f"Stream output from run {run_id}", + )) + logger.info(f"Registered output {cached_file.cid} as video type") + # Update pending run status + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="completed", + )) + logger.info(f"Saved run {run_id} to database with actor_id={actor_id}") + except Exception as db_err: + logger.error(f"Failed to save run to database: {db_err}") + raise RuntimeError(f"Database error saving run {run_id}: {db_err}") from db_err + + return { + "status": "completed", + "run_id": run_id, + "task_id": task_id, + "output_cid": cached_file.cid, + "ipfs_cid": ipfs_cid, + "output_path": str(cached_file.path), + # IPFS HLS streaming info + "ipfs_playlist_cid": ipfs_playlist_cid, + "ipfs_playlist_url": ipfs_playlist_url, + "ipfs_segment_count": len(segment_cids), + } + else: + # Update pending run status to failed - reuse module loop + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="failed", + error="Output file not created", + )) + except Exception as db_err: + logger.warning(f"Failed to update run status: {db_err}") + + return { + "status": "failed", + "run_id": run_id, + "task_id": task_id, + "error": "Output file not created", + } + + except Exception as e: + logger.error(f"Stream task {task_id} failed: {e}") + import traceback + traceback.print_exc() + + # Update pending run status to failed - reuse module loop + import asyncio + import database + + try: + if _resolve_loop is None or _resolve_loop.is_closed(): + _resolve_loop = asyncio.new_event_loop() + asyncio.set_event_loop(_resolve_loop) + _db_initialized = False + if not _db_initialized: + _resolve_loop.run_until_complete(database.init_db()) + _db_initialized = True + _resolve_loop.run_until_complete(database.update_pending_run_status( + run_id=run_id, + status="failed", + error=str(e), + )) + except Exception as db_err: + logger.warning(f"Failed to update run status: {db_err}") + + return { + "status": "failed", + "run_id": run_id, + "task_id": task_id, + "error": str(e), + } + + finally: + # Cleanup temp directory only - NOT the streaming directory! + # The streaming directory contains HLS segments that may still be uploading + # to IPFS. Deleting it prematurely causes upload failures and missing segments. + # + # stream_dir cleanup should happen via: + # 1. A separate cleanup task that runs after confirming IPFS uploads succeeded + # 2. Or a periodic cleanup job that removes old streaming dirs + import shutil + if work_dir.exists(): + shutil.rmtree(work_dir, ignore_errors=True) + # NOTE: stream_dir is intentionally NOT deleted here to allow IPFS uploads to complete + # TODO: Implement a deferred cleanup mechanism for stream_dir after IPFS confirmation diff --git a/templates/crossfade-zoom.sexp b/templates/crossfade-zoom.sexp new file mode 100644 index 0000000..fc6d9ad --- /dev/null +++ b/templates/crossfade-zoom.sexp @@ -0,0 +1,25 @@ +;; Crossfade with Zoom Transition +;; +;; Macro for transitioning between two frames with a zoom effect. +;; Active frame zooms out while next frame zooms in. +;; +;; Required context: +;; - zoom effect must be loaded +;; - blend effect must be loaded +;; +;; Parameters: +;; active-frame: current frame +;; next-frame: frame to transition to +;; fade-amt: transition progress (0 = all active, 1 = all next) +;; +;; Usage: +;; (include :path "../templates/crossfade-zoom.sexp") +;; ... +;; (crossfade-zoom active-frame next-frame 0.5) + +(defmacro crossfade-zoom (active-frame next-frame fade-amt) + (let [active-zoom (+ 1.0 fade-amt) + active-zoomed (zoom active-frame :amount active-zoom) + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (zoom next-frame :amount next-zoom)] + (blend active-zoomed next-zoomed :opacity fade-amt))) diff --git a/templates/cycle-crossfade.sexp b/templates/cycle-crossfade.sexp new file mode 100644 index 0000000..40a87ca --- /dev/null +++ b/templates/cycle-crossfade.sexp @@ -0,0 +1,65 @@ +;; cycle-crossfade template +;; +;; Generalized cycling zoom-crossfade for any number of video layers. +;; Cycles through videos with smooth zoom-based crossfade transitions. +;; +;; Parameters: +;; beat-data - beat analysis node (drives timing) +;; input-videos - list of video nodes to cycle through +;; init-clen - initial cycle length in beats +;; +;; NOTE: The parameter is named "input-videos" (not "videos") because +;; template substitution replaces param names everywhere in the AST. +;; The planner's _expand_slice_on injects env["videos"] at plan time, +;; so (len videos) inside the lambda references that injected value. + +(deftemplate cycle-crossfade + (beat-data input-videos init-clen) + + (slice-on beat-data + :videos input-videos + :init {:cycle 0 :beat 0 :clen init-clen} + :fn (lambda [acc i start end] + (let [beat (get acc "beat") + clen (get acc "clen") + active (get acc "cycle") + n (len videos) + phase3 (* beat 3) + wt (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (- 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.0)) + (if (= active prev) + (if (< phase3 clen) 0.0 + (if (< phase3 (* clen 2)) + (* (/ (- phase3 clen) clen) 1.0) + 1.0)) + 0.0)))) + zm (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + ;; Active video: normal -> zoom out during transition -> tiny + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (+ 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.1)) + (if (= active prev) + ;; Incoming video: tiny -> zoom in during transition -> normal + (if (< phase3 clen) 0.1 + (if (< phase3 (* clen 2)) + (+ 0.1 (* (/ (- phase3 clen) clen) 0.9)) + 1.0)) + 0.1)))) + new-acc (if (< (+ beat 1) clen) + (dict :cycle active :beat (+ beat 1) :clen clen) + (dict :cycle (mod (+ active 1) n) :beat 0 + :clen (+ 40 (mod (* i 7) 41))))] + {:layers (map (lambda [p] + {:video p :effects [{:effect zoom :amount (zm p)}]}) + (range 0 n)) + :compose {:effect blend_multi :mode "alpha" + :weights (map (lambda [p] (wt p)) (range 0 n))} + :acc new-acc})))) diff --git a/templates/process-pair.sexp b/templates/process-pair.sexp new file mode 100644 index 0000000..6720cd2 --- /dev/null +++ b/templates/process-pair.sexp @@ -0,0 +1,112 @@ +;; process-pair template +;; +;; Reusable video-pair processor: takes a single video source, creates two +;; clips (A and B) with opposite rotations and sporadic effects, blends them, +;; and applies a per-pair slow rotation driven by a beat scan. +;; +;; All sporadic triggers (invert, hue-shift, ascii) and pair-level controls +;; (blend opacity, rotation) are defined internally using seed offsets. +;; +;; Parameters: +;; video - source video node +;; energy - energy analysis node (drives rotation/zoom amounts) +;; beat-data - beat analysis node (drives sporadic triggers) +;; rng - RNG object from (make-rng seed) for auto-derived seeds +;; rot-dir - initial rotation direction: 1 (clockwise) or -1 (anti-clockwise) +;; rot-a/b - rotation ranges for clip A/B (e.g. [0 45]) +;; zoom-a/b - zoom ranges for clip A/B (e.g. [1 1.5]) + +(deftemplate process-pair + (video energy beat-data rng rot-dir rot-a rot-b zoom-a zoom-b) + + ;; --- Sporadic triggers for clip A --- + + ;; Invert: 10% chance per beat, lasts 1-5 beats + (def inv-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + ;; Hue shift: 10% chance, random hue 30-330 deg, lasts 1-5 beats + (def hue-a (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + ;; ASCII art: 5% chance, lasts 1-3 beats + (def ascii-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Sporadic triggers for clip B (offset seeds) --- + + (def inv-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + (def hue-b (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + (def ascii-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Pair-level controls --- + + ;; Internal A/B blend: randomly show A (0), both (0.5), or B (1), every 1-11 beats + (def pair-mix (scan beat-data :rng rng + :init (dict :rem 0 :opacity 0.5) + :step (if (> rem 0) + (dict :rem (- rem 1) :opacity opacity) + (dict :rem (rand-int 1 11) :opacity (* (rand-int 0 2) 0.5))) + :emit opacity)) + + ;; Per-pair rotation: one full rotation every 20-30 beats, alternating direction + (def pair-rot (scan beat-data :rng rng + :init (dict :beat 0 :clen 25 :dir rot-dir :angle 0) + :step (if (< (+ beat 1) clen) + (dict :beat (+ beat 1) :clen clen :dir dir + :angle (+ angle (* dir (/ 360 clen)))) + (dict :beat 0 :clen (rand-int 20 30) :dir (* dir -1) + :angle angle)) + :emit angle)) + + ;; --- Clip A processing --- + (def clip-a (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-a (-> clip-a + (effect rotate :angle (bind energy values :range rot-a)) + (effect zoom :amount (bind energy values :range zoom-a)) + (effect invert :amount (bind inv-a values)) + (effect hue_shift :degrees (bind hue-a values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-a values)) + )) + + ;; --- Clip B processing --- + (def clip-b (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-b (-> clip-b + (effect rotate :angle (bind energy values :range rot-b)) + (effect zoom :amount (bind energy values :range zoom-b)) + (effect invert :amount (bind inv-b values)) + (effect hue_shift :degrees (bind hue-b values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-b values)) + )) + + ;; --- Blend A+B and apply pair rotation --- + (-> rotated-a + (effect blend rotated-b + :mode "alpha" :opacity (bind pair-mix values) :resize_mode "fit") + (effect rotate + :angle (bind pair-rot values)))) diff --git a/templates/scan-oscillating-spin.sexp b/templates/scan-oscillating-spin.sexp new file mode 100644 index 0000000..051f079 --- /dev/null +++ b/templates/scan-oscillating-spin.sexp @@ -0,0 +1,28 @@ +;; Oscillating Spin Scan +;; +;; Accumulates rotation angle on each beat, reversing direction +;; periodically for an oscillating effect. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; +;; Provides scan: spin +;; Bind with: (bind spin :angle) ;; cumulative rotation angle +;; +;; Behavior: +;; - Rotates 14.4 degrees per beat (completes 360 in 25 beats) +;; - After 20-30 beats, reverses direction +;; - Creates a swinging/oscillating rotation effect +;; +;; Usage: +;; (include :path "../templates/scan-oscillating-spin.sexp") +;; +;; In frame: +;; (rotate frame :angle (bind spin :angle)) + +(scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :left 25} + :step (if (> left 0) + (dict :angle (+ angle (* dir 14.4)) :dir dir :left (- left 1)) + (dict :angle angle :dir (* dir -1) + :left (+ 20 (mod (streaming:audio-beat-count music t) 11))))) diff --git a/templates/scan-ripple-drops.sexp b/templates/scan-ripple-drops.sexp new file mode 100644 index 0000000..7caf720 --- /dev/null +++ b/templates/scan-ripple-drops.sexp @@ -0,0 +1,41 @@ +;; Beat-Triggered Ripple Drops Scan +;; +;; Creates random ripple drops triggered by audio beats. +;; Each drop has a random center position and duration. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; - core primitives loaded +;; +;; Provides scan: ripple-state +;; Bind with: (bind ripple-state :gate) ;; 0 or 1 +;; (bind ripple-state :cx) ;; center x (0-1) +;; (bind ripple-state :cy) ;; center y (0-1) +;; +;; Parameters: +;; trigger-chance: probability per beat (default 0.15) +;; min-duration: minimum beats (default 1) +;; max-duration: maximum beats (default 15) +;; +;; Usage: +;; (include :path "../templates/scan-ripple-drops.sexp") +;; ;; Uses default: 15% chance, 1-15 beat duration +;; +;; In frame: +;; (let [rip-gate (bind ripple-state :gate) +;; rip-amp (* rip-gate (core:map-range e 0 1 5 50))] +;; (ripple frame +;; :amplitude rip-amp +;; :center_x (bind ripple-state :cx) +;; :center_y (bind ripple-state :cy))) + +(scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 0.5 :cy 0.5 :left 0} + :step (if (> left 0) + (dict :gate 1 :cx cx :cy cy :left (- left 1)) + (if (< (core:rand) 0.15) + (dict :gate 1 + :cx (+ 0.2 (* (core:rand) 0.6)) + :cy (+ 0.2 (* (core:rand) 0.6)) + :left (+ 1 (mod (streaming:audio-beat-count music t) 15))) + (dict :gate 0 :cx 0.5 :cy 0.5 :left 0)))) diff --git a/templates/standard-effects.sexp b/templates/standard-effects.sexp new file mode 100644 index 0000000..ce4a92f --- /dev/null +++ b/templates/standard-effects.sexp @@ -0,0 +1,22 @@ +;; Standard Effects Bundle +;; +;; Loads commonly-used video effects. +;; Include after primitives are loaded. +;; +;; Effects provided: +;; - rotate: rotation by angle +;; - zoom: scale in/out +;; - blend: alpha blend two frames +;; - ripple: water ripple distortion +;; - invert: color inversion +;; - hue_shift: hue rotation +;; +;; Usage: +;; (include :path "../templates/standard-effects.sexp") + +(effect rotate :name "fx-rotate") +(effect zoom :name "fx-zoom") +(effect blend :name "fx-blend") +(effect ripple :name "fx-ripple") +(effect invert :name "fx-invert") +(effect hue_shift :name "fx-hue-shift") diff --git a/templates/standard-primitives.sexp b/templates/standard-primitives.sexp new file mode 100644 index 0000000..6e2c62d --- /dev/null +++ b/templates/standard-primitives.sexp @@ -0,0 +1,14 @@ +;; Standard Primitives Bundle +;; +;; Loads all commonly-used primitive libraries. +;; Include this at the top of streaming recipes. +;; +;; Usage: +;; (include :path "../templates/standard-primitives.sexp") + +(require-primitives "geometry") +(require-primitives "core") +(require-primitives "image") +(require-primitives "blending") +(require-primitives "color_ops") +(require-primitives "streaming") diff --git a/templates/stream-process-pair.sexp b/templates/stream-process-pair.sexp new file mode 100644 index 0000000..55f408e --- /dev/null +++ b/templates/stream-process-pair.sexp @@ -0,0 +1,72 @@ +;; stream-process-pair template (streaming-compatible) +;; +;; Macro for processing a video source pair with full effects. +;; Reads source, applies A/B effects (rotate, zoom, invert, hue), blends, +;; and applies pair-level rotation. +;; +;; Required context (must be defined in calling scope): +;; - sources: array of video sources +;; - pair-configs: array of {:dir :rot-a :rot-b :zoom-a :zoom-b} configs +;; - pair-states: array from (bind pairs :states) +;; - now: current time (t) +;; - e: audio energy (0-1) +;; +;; Required effects (must be loaded): +;; - rotate, zoom, invert, hue_shift, blend +;; +;; Usage: +;; (include :path "../templates/stream-process-pair.sexp") +;; ...in frame pipeline... +;; (let [pair-states (bind pairs :states) +;; now t +;; e (streaming:audio-energy music now)] +;; (process-pair 0)) ;; process source at index 0 + +(require-primitives "core") + +(defmacro process-pair (src-idx) + (let [src (nth sources src-idx) + frame (streaming:source-read src now) + cfg (nth pair-configs src-idx) + state (nth pair-states src-idx) + + ;; Get state values (invert uses countdown > 0) + inv-a-active (if (> (get state :inv-a) 0) 1 0) + inv-b-active (if (> (get state :inv-b) 0) 1 0) + ;; Hue is active only when countdown > 0 + hue-a-val (if (> (get state :hue-a) 0) (get state :hue-a-val) 0) + hue-b-val (if (> (get state :hue-b) 0) (get state :hue-b-val) 0) + mix-opacity (get state :mix) + pair-rot-angle (* (get state :angle) (get cfg :dir)) + + ;; Get config values for energy-mapped ranges + rot-a-max (get cfg :rot-a) + rot-b-max (get cfg :rot-b) + zoom-a-max (get cfg :zoom-a) + zoom-b-max (get cfg :zoom-b) + + ;; Energy-driven rotation and zoom + rot-a (core:map-range e 0 1 0 rot-a-max) + rot-b (core:map-range e 0 1 0 rot-b-max) + zoom-a (core:map-range e 0 1 1 zoom-a-max) + zoom-b (core:map-range e 0 1 1 zoom-b-max) + + ;; Apply effects to clip A + clip-a (-> frame + (rotate :angle rot-a) + (zoom :amount zoom-a) + (invert :amount inv-a-active) + (hue_shift :degrees hue-a-val)) + + ;; Apply effects to clip B + clip-b (-> frame + (rotate :angle rot-b) + (zoom :amount zoom-b) + (invert :amount inv-b-active) + (hue_shift :degrees hue-b-val)) + + ;; Blend A+B + blended (blend clip-a clip-b :opacity mix-opacity)] + + ;; Apply pair-level rotation + (rotate blended :angle pair-rot-angle))) diff --git a/test_autonomous.sexp b/test_autonomous.sexp new file mode 100644 index 0000000..9e190a2 --- /dev/null +++ b/test_autonomous.sexp @@ -0,0 +1,36 @@ +;; Autonomous Pipeline Test +;; +;; Uses the autonomous-pipeline primitive which computes ALL parameters +;; on GPU - including sin/cos expressions. Zero Python in the hot path! + +(stream "autonomous_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + + ;; Effects pipeline (what effects to apply) + (def effects + [{:op "rotate" :angle 0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Dynamic expressions (computed on GPU!) + ;; These use CUDA syntax: sinf(), cosf(), t (time), frame_num + (def expressions + {:rotate_angle "t * 30.0f" + :ripple_phase "t * 2.0f" + :brightness_factor "0.8f + 0.4f * sinf(t * 2.0f)"}) + + ;; Frame pipeline - creates image ON GPU and applies autonomous pipeline + (frame + (let [;; Create base image ON GPU (no CPU involvement!) + base (streaming_gpu:gpu-make-image 1920 1080 [128 100 200])] + + ;; Apply autonomous pipeline - ALL EFFECTS + ALL MATH ON GPU! + (streaming_gpu:autonomous-pipeline base effects expressions frame-num 30.0)))) diff --git a/test_autonomous_prealloc.py b/test_autonomous_prealloc.py new file mode 100644 index 0000000..5fde7f7 --- /dev/null +++ b/test_autonomous_prealloc.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +""" +Test autonomous pipeline with pre-allocated buffer. +This eliminates ALL Python from the hot path. +""" + +import time +import sys +sys.path.insert(0, '/app') + +import cupy as cp +from streaming.sexp_to_cuda import compile_autonomous_pipeline + +def test_autonomous_prealloc(): + width, height = 1920, 1080 + n_frames = 300 + fps = 30.0 + + print(f"Testing {n_frames} frames at {width}x{height}") + print("=" * 60) + + # Pre-allocate frame buffer (stays on GPU) + frame = cp.zeros((height, width, 3), dtype=cp.uint8) + frame[:, :, 0] = 128 # R + frame[:, :, 1] = 100 # G + frame[:, :, 2] = 200 # B + + # Define effects + effects = [ + {'op': 'rotate', 'angle': 0}, + {'op': 'hue_shift', 'degrees': 30}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + # Dynamic expressions (computed on GPU) + dynamic_expressions = { + 'rotate_angle': 't * 30.0f', + 'ripple_phase': 't * 2.0f', + 'brightness_factor': '0.8f + 0.4f * sinf(t * 2.0f)', + } + + # Compile autonomous pipeline + print("Compiling autonomous pipeline...") + pipeline = compile_autonomous_pipeline(effects, width, height, dynamic_expressions) + + # Warmup + output = pipeline(frame, 0, fps) + cp.cuda.Stream.null.synchronize() + + # Benchmark - ZERO Python in the hot path! + print(f"Running {n_frames} frames...") + start = time.time() + for i in range(n_frames): + output = pipeline(frame, i, fps) + cp.cuda.Stream.null.synchronize() + elapsed = time.time() - start + + ms_per_frame = elapsed / n_frames * 1000 + actual_fps = n_frames / elapsed + + print("=" * 60) + print(f"Time: {ms_per_frame:.2f}ms per frame") + print(f"FPS: {actual_fps:.0f}") + print(f"Real-time: {actual_fps / 30:.1f}x (at 30fps target)") + print("=" * 60) + + # Compare with original baseline + print(f"\nOriginal Python sexp: ~150ms = 6 fps") + print(f"Autonomous GPU: {ms_per_frame:.2f}ms = {actual_fps:.0f} fps") + print(f"Speedup: {150 / ms_per_frame:.0f}x faster!") + + +if __name__ == '__main__': + test_autonomous_prealloc() diff --git a/test_full_optimized.py b/test_full_optimized.py new file mode 100644 index 0000000..6d7ae48 --- /dev/null +++ b/test_full_optimized.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Full Optimized GPU Pipeline Test + +This demonstrates the maximum performance achievable: +1. Pre-allocated GPU frame buffer +2. Autonomous CUDA kernel (all params computed on GPU) +3. GPU HLS encoder (zero-copy NVENC) + +The entire pipeline runs on GPU with zero CPU involvement per frame! +""" + +import time +import sys +import os + +sys.path.insert(0, '/app') + +import cupy as cp +import numpy as np +from streaming.sexp_to_cuda import compile_autonomous_pipeline + +# Try to import GPU encoder +try: + from streaming.gpu_output import GPUHLSOutput, check_gpu_encode_available + GPU_ENCODE = check_gpu_encode_available() +except: + GPU_ENCODE = False + +def run_optimized_stream(duration: float = 10.0, fps: float = 30.0, output_dir: str = '/tmp/optimized'): + width, height = 1920, 1080 + n_frames = int(duration * fps) + + print("=" * 60) + print("FULL OPTIMIZED GPU PIPELINE") + print("=" * 60) + print(f"Resolution: {width}x{height}") + print(f"Duration: {duration}s ({n_frames} frames @ {fps}fps)") + print(f"GPU encode: {GPU_ENCODE}") + print("=" * 60) + + # Pre-allocate frame buffer on GPU + print("\n[1/4] Pre-allocating GPU frame buffer...") + frame = cp.zeros((height, width, 3), dtype=cp.uint8) + # Create a gradient pattern + y_grad = cp.linspace(0, 255, height, dtype=cp.float32)[:, cp.newaxis] + x_grad = cp.linspace(0, 255, width, dtype=cp.float32)[cp.newaxis, :] + frame[:, :, 0] = (y_grad * 0.5).astype(cp.uint8) # R + frame[:, :, 1] = (x_grad * 0.5).astype(cp.uint8) # G + frame[:, :, 2] = 128 # B + + # Define effects + effects = [ + {'op': 'rotate', 'angle': 0}, + {'op': 'hue_shift', 'degrees': 30}, + {'op': 'ripple', 'amplitude': 20, 'frequency': 12, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + {'op': 'brightness', 'factor': 1.0}, + ] + + # Dynamic expressions (computed on GPU) + dynamic_expressions = { + 'rotate_angle': 't * 45.0f', # 45 degrees per second + 'ripple_phase': 't * 3.0f', # Ripple animation + 'brightness_factor': '0.7f + 0.3f * sinf(t * 2.0f)', # Pulsing brightness + } + + # Compile autonomous pipeline + print("[2/4] Compiling autonomous CUDA kernel...") + pipeline = compile_autonomous_pipeline(effects, width, height, dynamic_expressions) + + # Setup output + print("[3/4] Setting up output...") + os.makedirs(output_dir, exist_ok=True) + + if GPU_ENCODE: + print(" Using GPU HLS encoder (zero-copy)") + out = GPUHLSOutput(output_dir, size=(width, height), fps=fps) + else: + print(" Using ffmpeg encoder") + import subprocess + cmd = [ + 'ffmpeg', '-y', + '-f', 'rawvideo', '-vcodec', 'rawvideo', + '-pix_fmt', 'rgb24', '-s', f'{width}x{height}', '-r', str(fps), + '-i', '-', + '-c:v', 'h264_nvenc', '-preset', 'p4', '-cq', '18', + '-pix_fmt', 'yuv420p', + f'{output_dir}/output.mp4' + ] + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Warmup + output = pipeline(frame, 0, fps) + cp.cuda.Stream.null.synchronize() + + # Run the pipeline! + print(f"[4/4] Running {n_frames} frames...") + print("-" * 60) + + frame_times = [] + start_total = time.time() + + for i in range(n_frames): + frame_start = time.time() + + # Apply effects (autonomous kernel - all on GPU!) + output = pipeline(frame, i, fps) + + # Write output + if GPU_ENCODE: + out.write(output, i / fps) + else: + # Transfer to CPU for ffmpeg (slower path) + cpu_frame = cp.asnumpy(output) + proc.stdin.write(cpu_frame.tobytes()) + + cp.cuda.Stream.null.synchronize() + frame_times.append(time.time() - frame_start) + + # Progress + if (i + 1) % 30 == 0: + avg_ms = sum(frame_times[-30:]) / 30 * 1000 + print(f" Frame {i+1}/{n_frames}: {avg_ms:.1f}ms/frame") + + total_time = time.time() - start_total + + # Cleanup + if GPU_ENCODE: + out.close() + else: + proc.stdin.close() + proc.wait() + + # Results + print("-" * 60) + avg_ms = sum(frame_times) / len(frame_times) * 1000 + actual_fps = n_frames / total_time + + print("\nRESULTS:") + print("=" * 60) + print(f"Total time: {total_time:.2f}s") + print(f"Avg per frame: {avg_ms:.2f}ms") + print(f"Actual FPS: {actual_fps:.0f}") + print(f"Real-time: {actual_fps / fps:.1f}x") + print("=" * 60) + + if GPU_ENCODE: + print(f"\nOutput: {output_dir}/*.ts (HLS segments)") + else: + print(f"\nOutput: {output_dir}/output.mp4") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--duration', type=float, default=10.0) + parser.add_argument('-o', '--output', default='/tmp/optimized') + parser.add_argument('--fps', type=float, default=30.0) + args = parser.parse_args() + + run_optimized_stream(args.duration, args.fps, args.output) diff --git a/test_funky_text.py b/test_funky_text.py new file mode 100644 index 0000000..342ef1c --- /dev/null +++ b/test_funky_text.py @@ -0,0 +1,542 @@ +#!/usr/bin/env python3 +""" +Funky comparison tests: PIL vs TextStrip system. +Tests colors, opacity, fonts, sizes, edge positions, clipping, overlaps, etc. +""" + +import numpy as np +import jax.numpy as jnp +from PIL import Image, ImageDraw, ImageFont +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font +) + +FONTS = { + 'sans': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + 'bold': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', + 'serif': '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf', + 'serif_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf', + 'mono': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', + 'mono_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf', + 'narrow': '/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf', + 'italic': '/usr/share/fonts/truetype/freefont/FreeSansOblique.ttf', +} + + +def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100), + fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0, + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with PIL directly, including color/opacity.""" + frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8) + # For opacity, render to RGBA then composite + txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(txt_layer) + font = _load_font(font_path, font_size) + + if stroke_fill is None: + stroke_fill = (0, 0, 0) + + # PIL fill with alpha for opacity + alpha_int = int(round(opacity * 255)) + fill_rgba = (*fill, alpha_int) + stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None + + if multiline: + draw.multiline_text((x, y), text, fill=fill_rgba, font=font, + stroke_width=stroke_width, stroke_fill=stroke_rgba, + spacing=line_spacing, align=align, anchor=anchor) + else: + draw.text((x, y), text, fill=fill_rgba, font=font, + stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor) + + # Composite onto background + bg_img = Image.fromarray(frame).convert('RGBA') + result = Image.alpha_composite(bg_img, txt_layer) + return np.array(result.convert('RGB')) + + +def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100), + fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0, + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with TextStrip system.""" + frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8) + + strip = render_text_strip( + text, font_path, font_size, + stroke_width=stroke_width, stroke_fill=stroke_fill, + anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align + ) + strip_img = jnp.asarray(strip.image) + color = jnp.array(fill, dtype=jnp.float32) + + result = place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + return np.array(result) + + +def compare(name, tolerance=0, **kwargs): + """Compare PIL and TextStrip rendering.""" + pil = render_pil(**kwargs) + strip = render_strip(**kwargs) + + diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff == 0: + print(f" PASS: {name}") + return True + + if tolerance > 0: + best_diff = diff.copy() + for dy in range(-tolerance, tolerance + 1): + for dx in range(-tolerance, tolerance + 1): + if dy == 0 and dx == 0: + continue + shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1) + sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16)) + best_diff = np.minimum(best_diff, sdiff) + if best_diff.max() == 0: + print(f" PASS: {name} (within {tolerance}px tolerance)") + return True + + print(f" FAIL: {name}") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + Image.fromarray(pil).save(f"/tmp/pil_{name}.png") + Image.fromarray(strip).save(f"/tmp/strip_{name}.png") + diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png") + print(f" Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png") + return False + + +def test_colors(): + """Test various text colors on various backgrounds.""" + print("\n--- Colors ---") + results = [] + + # White on black (baseline) + results.append(compare("color_white_on_black", + text="Hello", x=20, y=30, fill=(255, 255, 255), bg=(0, 0, 0))) + + # Red on black + results.append(compare("color_red", + text="Red Text", x=20, y=30, fill=(255, 0, 0), bg=(0, 0, 0))) + + # Green on black + results.append(compare("color_green", + text="Green!", x=20, y=30, fill=(0, 255, 0), bg=(0, 0, 0))) + + # Blue on black + results.append(compare("color_blue", + text="Blue sky", x=20, y=30, fill=(0, 100, 255), bg=(0, 0, 0))) + + # Yellow on dark gray + results.append(compare("color_yellow_on_gray", + text="Yellow", x=20, y=30, fill=(255, 255, 0), bg=(40, 40, 40))) + + # Magenta on white + results.append(compare("color_magenta_on_white", + text="Magenta", x=20, y=30, fill=(255, 0, 255), bg=(255, 255, 255))) + + # Subtle: gray text on slightly lighter gray + results.append(compare("color_subtle_gray", + text="Subtle", x=20, y=30, fill=(128, 128, 128), bg=(64, 64, 64))) + + # Orange on deep blue + results.append(compare("color_orange_on_blue", + text="Warm", x=20, y=30, fill=(255, 165, 0), bg=(0, 0, 80))) + + return results + + +def test_opacity(): + """Test different opacity levels.""" + print("\n--- Opacity ---") + results = [] + + results.append(compare("opacity_100", + text="Full", x=20, y=30, opacity=1.0)) + + results.append(compare("opacity_75", + text="75%", x=20, y=30, opacity=0.75)) + + results.append(compare("opacity_50", + text="Half", x=20, y=30, opacity=0.5)) + + results.append(compare("opacity_25", + text="Quarter", x=20, y=30, opacity=0.25)) + + results.append(compare("opacity_10", + text="Ghost", x=20, y=30, opacity=0.1)) + + # Opacity on colored background + results.append(compare("opacity_on_colored_bg", + text="Overlay", x=20, y=30, fill=(255, 255, 255), bg=(100, 0, 0), + opacity=0.5)) + + # Color + opacity combo + results.append(compare("opacity_red_on_green", + text="Blend", x=20, y=30, fill=(255, 0, 0), bg=(0, 100, 0), + opacity=0.6)) + + return results + + +def test_fonts(): + """Test different fonts and sizes.""" + print("\n--- Fonts & Sizes ---") + results = [] + + for label, path in FONTS.items(): + results.append(compare(f"font_{label}", + text="Quick Fox", x=20, y=30, font_path=path, font_size=28, + frame_size=(300, 80))) + + # Tiny text + results.append(compare("size_tiny", + text="Tiny text at 12px", x=10, y=15, font_size=12, + frame_size=(200, 40))) + + # Big text + results.append(compare("size_big", + text="BIG", x=20, y=10, font_size=72, + frame_size=(300, 100))) + + # Huge text + results.append(compare("size_huge", + text="XL", x=10, y=10, font_size=120, + frame_size=(300, 160))) + + return results + + +def test_anchors(): + """Test all anchor combinations.""" + print("\n--- Anchors ---") + results = [] + + # All horizontal x vertical combos + for h in ['l', 'm', 'r']: + for v in ['a', 'm', 's', 'd']: + anchor = h + v + results.append(compare(f"anchor_{anchor}", + text="Anchor", x=200, y=50, anchor=anchor, + frame_size=(400, 100))) + + return results + + +def test_strokes(): + """Test various stroke widths and colors.""" + print("\n--- Strokes ---") + results = [] + + for sw in [1, 2, 3, 4, 6, 8]: + results.append(compare(f"stroke_w{sw}", + text="Stroke", x=30, y=20, font_size=40, + stroke_width=sw, stroke_fill=(0, 0, 0), + frame_size=(300, 80))) + + # Colored strokes + results.append(compare("stroke_red", + text="Red outline", x=20, y=20, font_size=36, + stroke_width=3, stroke_fill=(255, 0, 0), + frame_size=(350, 80))) + + results.append(compare("stroke_white_on_black", + text="Glow", x=20, y=20, font_size=40, + fill=(255, 255, 255), stroke_width=4, stroke_fill=(0, 0, 255), + frame_size=(250, 80))) + + # Stroke with bold font + results.append(compare("stroke_bold", + text="Bold+Stroke", x=20, y=20, + font_path=FONTS['bold'], font_size=36, + stroke_width=3, stroke_fill=(0, 0, 0), + frame_size=(400, 80))) + + # Stroke + colored text on colored bg + results.append(compare("stroke_colored_on_bg", + text="Party", x=20, y=20, font_size=48, + fill=(255, 255, 0), bg=(50, 0, 80), + stroke_width=3, stroke_fill=(255, 0, 0), + frame_size=(300, 80))) + + return results + + +def test_edge_clipping(): + """Test text at frame edges - clipping behavior.""" + print("\n--- Edge Clipping ---") + results = [] + + # Text at very left edge + results.append(compare("clip_left_edge", + text="LEFT", x=0, y=30, frame_size=(200, 80))) + + # Text partially off right edge + results.append(compare("clip_right_edge", + text="RIGHT SIDE", x=150, y=30, frame_size=(200, 80))) + + # Text at top edge + results.append(compare("clip_top", + text="TOP", x=20, y=0, frame_size=(200, 80))) + + # Text at bottom edge - partially clipped + results.append(compare("clip_bottom", + text="BOTTOM", x=20, y=55, font_size=40, + frame_size=(200, 80))) + + # Large text overflowing all sides from center + results.append(compare("clip_overflow_center", + text="HUGE", x=75, y=40, font_size=100, anchor="mm", + frame_size=(150, 80))) + + # Corner placement + results.append(compare("clip_corner_br", + text="Corner", x=350, y=70, font_size=36, + frame_size=(400, 100))) + + return results + + +def test_multiline_fancy(): + """Test multiline with various styles.""" + print("\n--- Multiline Fancy ---") + results = [] + + # Right-aligned (1px tolerance: same sub-pixel issue as center alignment) + results.append(compare("multi_right", + text="Right\nAligned\nText", x=380, y=20, + frame_size=(400, 150), + multiline=True, anchor="ra", align="right", + tolerance=1)) + + # Center + stroke + results.append(compare("multi_center_stroke", + text="Title\nSubtitle", x=200, y=20, + font_size=32, frame_size=(400, 120), + multiline=True, anchor="ma", align="center", + stroke_width=2, stroke_fill=(0, 0, 0), + tolerance=1)) + + # Wide line spacing + results.append(compare("multi_wide_spacing", + text="Line A\nLine B\nLine C", x=20, y=10, + frame_size=(300, 200), + multiline=True, line_spacing=20)) + + # Zero extra spacing + results.append(compare("multi_tight_spacing", + text="Tight\nPacked\nLines", x=20, y=10, + frame_size=(300, 150), + multiline=True, line_spacing=0)) + + # Many lines + results.append(compare("multi_many_lines", + text="One\nTwo\nThree\nFour\nFive\nSix", x=20, y=5, + font_size=20, frame_size=(200, 200), + multiline=True, line_spacing=4)) + + # Bold multiline with stroke + results.append(compare("multi_bold_stroke", + text="BOLD\nSTROKE", x=20, y=10, + font_path=FONTS['bold'], font_size=48, + stroke_width=3, stroke_fill=(200, 0, 0), + frame_size=(350, 150), multiline=True)) + + # Multiline on colored bg with opacity + results.append(compare("multi_opacity_on_bg", + text="Semi\nTransparent", x=20, y=10, + fill=(255, 255, 0), bg=(0, 50, 100), opacity=0.7, + frame_size=(300, 120), multiline=True)) + + return results + + +def test_special_chars(): + """Test special characters and edge cases.""" + print("\n--- Special Characters ---") + results = [] + + # Numbers and symbols + results.append(compare("chars_numbers", + text="0123456789", x=20, y=30, frame_size=(300, 80))) + + results.append(compare("chars_punctuation", + text="Hello, World! (v2.0)", x=10, y=30, frame_size=(350, 80))) + + results.append(compare("chars_symbols", + text="@#$%^&*+=", x=20, y=30, frame_size=(300, 80))) + + # Single character + results.append(compare("chars_single", + text="X", x=50, y=30, font_size=48, frame_size=(100, 80))) + + # Very long text (clipped) + results.append(compare("chars_long", + text="The quick brown fox jumps over the lazy dog", x=10, y=30, + font_size=24, frame_size=(400, 80))) + + # Mixed case + results.append(compare("chars_mixed_case", + text="AaBbCcDdEeFf", x=10, y=30, frame_size=(350, 80))) + + return results + + +def test_combos(): + """Complex combinations of features.""" + print("\n--- Combos ---") + results = [] + + # Big bold stroke + color + opacity + results.append(compare("combo_all_features", + text="EPIC", x=20, y=10, + font_path=FONTS['bold'], font_size=64, + fill=(255, 200, 0), bg=(20, 0, 40), opacity=0.85, + stroke_width=4, stroke_fill=(180, 0, 0), + frame_size=(350, 100))) + + # Small mono on busy background + results.append(compare("combo_mono_code", + text="fn main() {}", x=10, y=15, + font_path=FONTS['mono'], font_size=16, + fill=(0, 255, 100), bg=(30, 30, 30), + frame_size=(250, 50))) + + # Serif italic multiline with stroke + results.append(compare("combo_serif_italic_multi", + text="Once upon\na time...", x=20, y=10, + font_path=FONTS['italic'], font_size=28, + stroke_width=1, stroke_fill=(80, 80, 80), + frame_size=(300, 120), multiline=True)) + + # Narrow font, big stroke, center anchored + results.append(compare("combo_narrow_stroke_center", + text="NARROW", x=150, y=40, + font_path=FONTS['narrow'], font_size=44, + stroke_width=5, stroke_fill=(0, 0, 0), + anchor="mm", frame_size=(300, 80))) + + # Multiple strips on same frame (simulated by sequential placement) + results.append(compare("combo_opacity_blend", + text="Layered", x=20, y=30, + fill=(255, 0, 0), bg=(0, 0, 255), opacity=0.5, + font_size=48, frame_size=(300, 80))) + + return results + + +def test_multi_strip_overlay(): + """Test placing multiple strips on the same frame.""" + print("\n--- Multi-Strip Overlay ---") + results = [] + + frame_size = (400, 150) + bg = (20, 20, 40) + + # PIL version - multiple draw calls + font1 = _load_font(FONTS['bold'], 48) + font2 = _load_font(None, 24) + font3 = _load_font(FONTS['mono'], 18) + + pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8) + txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(txt_layer) + draw.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1, + stroke_width=2, stroke_fill=(0, 0, 0, 255)) + draw.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2) + draw.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3) + bg_img = Image.fromarray(pil_frame).convert('RGBA') + pil_result = np.array(Image.alpha_composite(bg_img, txt_layer).convert('RGB')) + + # Strip version - multiple placements + frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8) + + s1 = render_text_strip("TITLE", FONTS['bold'], 48, stroke_width=2, stroke_fill=(0, 0, 0)) + s2 = render_text_strip("Subtitle here", None, 24) + s3 = render_text_strip("code_snippet()", FONTS['mono'], 18) + + frame = place_text_strip_jax( + frame, jnp.asarray(s1.image), 20, 10, + s1.baseline_y, s1.bearing_x, + jnp.array([255, 255, 0], dtype=jnp.float32), 1.0, + anchor_x=s1.anchor_x, anchor_y=s1.anchor_y, + stroke_width=s1.stroke_width) + + frame = place_text_strip_jax( + frame, jnp.asarray(s2.image), 20, 70, + s2.baseline_y, s2.bearing_x, + jnp.array([200, 200, 200], dtype=jnp.float32), 1.0, + anchor_x=s2.anchor_x, anchor_y=s2.anchor_y, + stroke_width=s2.stroke_width) + + frame = place_text_strip_jax( + frame, jnp.asarray(s3.image), 20, 110, + s3.baseline_y, s3.bearing_x, + jnp.array([0, 255, 128], dtype=jnp.float32), 200/255, + anchor_x=s3.anchor_x, anchor_y=s3.anchor_y, + stroke_width=s3.stroke_width) + + strip_result = np.array(frame) + + diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff <= 1: + print(f" PASS: multi_overlay (max_diff={max_diff})") + results.append(True) + else: + print(f" FAIL: multi_overlay") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png") + Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png") + diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png") + print(f" Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png") + results.append(False) + + return results + + +def main(): + print("=" * 60) + print("Funky TextStrip vs PIL Comparison") + print("=" * 60) + + all_results = [] + all_results.extend(test_colors()) + all_results.extend(test_opacity()) + all_results.extend(test_fonts()) + all_results.extend(test_anchors()) + all_results.extend(test_strokes()) + all_results.extend(test_edge_clipping()) + all_results.extend(test_multiline_fancy()) + all_results.extend(test_special_chars()) + all_results.extend(test_combos()) + all_results.extend(test_multi_strip_overlay()) + + print("\n" + "=" * 60) + passed = sum(all_results) + total = len(all_results) + print(f"Results: {passed}/{total} passed") + if passed == total: + print("ALL TESTS PASSED!") + else: + failed = [i for i, r in enumerate(all_results) if not r] + print(f"FAILED: {total - passed} tests (indices: {failed})") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/test_fused_direct.py b/test_fused_direct.py new file mode 100644 index 0000000..5b87462 --- /dev/null +++ b/test_fused_direct.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Direct test of fused pipeline primitive. + +Compares performance of: +1. Fused kernel (single CUDA kernel for all effects) +2. Separate kernels (one CUDA kernel per effect) +""" + +import time +import sys + +# Check for CuPy +try: + import cupy as cp + print("[test] CuPy available") +except ImportError: + print("[test] CuPy not available - can't run test") + sys.exit(1) + +# Add path for imports +sys.path.insert(0, '/app') + +from streaming.sexp_to_cuda import compile_frame_pipeline +from streaming.jit_compiler import fast_rotate, fast_hue_shift, fast_ripple + +def test_fused_vs_separate(): + """Compare fused vs separate kernel performance.""" + + width, height = 1920, 1080 + n_frames = 100 + + # Create test frame + frame = cp.random.randint(0, 255, (height, width, 3), dtype=cp.uint8) + + # Define effects pipeline + effects = [ + {'op': 'rotate', 'angle': 45.0}, + {'op': 'hue_shift', 'degrees': 30.0}, + {'op': 'ripple', 'amplitude': 15, 'frequency': 10, 'decay': 2, 'phase': 0, 'center_x': 960, 'center_y': 540}, + ] + + print(f"\n[test] Testing {n_frames} frames at {width}x{height}") + print(f"[test] Effects: rotate, hue_shift, ripple\n") + + # ========== Test fused kernel ========== + print("[test] Compiling fused kernel...") + pipeline = compile_frame_pipeline(effects, width, height) + + # Warmup + output = pipeline(frame, rotate_angle=45, ripple_phase=0) + cp.cuda.Stream.null.synchronize() + + print("[test] Running fused kernel benchmark...") + start = time.time() + for i in range(n_frames): + output = pipeline(frame, rotate_angle=i * 3.6, ripple_phase=i * 0.1) + cp.cuda.Stream.null.synchronize() + fused_time = time.time() - start + + fused_ms = fused_time / n_frames * 1000 + fused_fps = n_frames / fused_time + print(f"[test] Fused kernel: {fused_ms:.2f}ms/frame ({fused_fps:.0f} fps)") + + # ========== Test separate kernels ========== + print("\n[test] Running separate kernels benchmark...") + + # Warmup + temp = fast_rotate(frame, 45.0) + temp = fast_hue_shift(temp, 30.0) + temp = fast_ripple(temp, 15, frequency=10, decay=2, phase=0, center_x=960, center_y=540) + cp.cuda.Stream.null.synchronize() + + start = time.time() + for i in range(n_frames): + temp = fast_rotate(frame, i * 3.6) + temp = fast_hue_shift(temp, 30.0) + temp = fast_ripple(temp, 15, frequency=10, decay=2, phase=i * 0.1, center_x=960, center_y=540) + cp.cuda.Stream.null.synchronize() + separate_time = time.time() - start + + separate_ms = separate_time / n_frames * 1000 + separate_fps = n_frames / separate_time + print(f"[test] Separate kernels: {separate_ms:.2f}ms/frame ({separate_fps:.0f} fps)") + + # ========== Summary ========== + speedup = separate_time / fused_time + print(f"\n{'='*50}") + print(f"SPEEDUP: {speedup:.1f}x faster with fused kernel") + print(f"") + print(f"Fused: {fused_ms:.2f}ms ({fused_fps:.0f} fps)") + print(f"Separate: {separate_ms:.2f}ms ({separate_fps:.0f} fps)") + print(f"{'='*50}") + + # Compare with original Python sexp interpreter baseline (126-205ms) + python_baseline_ms = 150 # Approximate from profiling + vs_python = python_baseline_ms / fused_ms + print(f"\nVs Python sexp interpreter (~{python_baseline_ms}ms): {vs_python:.0f}x faster!") + + +if __name__ == '__main__': + test_fused_vs_separate() diff --git a/test_fused_pipeline.sexp b/test_fused_pipeline.sexp new file mode 100644 index 0000000..72ee033 --- /dev/null +++ b/test_fused_pipeline.sexp @@ -0,0 +1,44 @@ +;; Test Fused Pipeline - Should be much faster than interpreted +;; +;; This uses the fused-pipeline primitive which compiles all effects +;; into a single CUDA kernel instead of interpreting them one by one. + +(stream "fused_pipeline_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + (require-primitives "math") + + ;; Define the effects pipeline (compiled to single CUDA kernel) + ;; Each effect is a map with :op and effect-specific parameters + (def effects-pipeline + [{:op "rotate" :angle 0} + {:op "zoom" :amount 1.0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Frame pipeline + (frame + (let [;; Create a gradient image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters (change per frame) + angle (* t 30) + zoom (+ 1.0 (* 0.2 (math:sin (* t 0.5)))) + phase (* t 2)] + + ;; Apply fused pipeline - all effects in ONE CUDA kernel! + (streaming_gpu:fused-pipeline base effects-pipeline + :rotate_angle angle + :zoom_amount zoom + :ripple_phase phase)))) diff --git a/test_heavy_fused.sexp b/test_heavy_fused.sexp new file mode 100644 index 0000000..d421cfb --- /dev/null +++ b/test_heavy_fused.sexp @@ -0,0 +1,39 @@ +;; Heavy Fused Pipeline Test +;; +;; Tests with many effects to show the full benefit of kernel fusion + +(stream "heavy_fused_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "image") + (require-primitives "math") + + ;; Define heavy effects pipeline (4 effects fused into one kernel) + (def heavy-pipeline + [{:op "rotate" :angle 0} + {:op "hue_shift" :degrees 30} + {:op "ripple" :amplitude 15 :frequency 10 :decay 2 :phase 0 :center_x 960 :center_y 540} + {:op "brightness" :factor 1.0}]) + + ;; Frame pipeline + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters + angle (* t 30) + phase (* t 2)] + + ;; 4 effects in ONE kernel call! + (streaming_gpu:fused-pipeline base heavy-pipeline + :rotate_angle angle + :ripple_phase phase)))) diff --git a/test_heavy_interpreted.sexp b/test_heavy_interpreted.sexp new file mode 100644 index 0000000..72c7965 --- /dev/null +++ b/test_heavy_interpreted.sexp @@ -0,0 +1,38 @@ +;; Heavy Interpreted Pipeline Test +;; +;; Same effects as test_heavy_fused.sexp but using separate primitive calls + +(stream "heavy_interpreted_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "geometry_gpu") + (require-primitives "color_ops") + (require-primitives "image") + (require-primitives "math") + + ;; Frame pipeline - INTERPRETED (separate calls) + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Dynamic parameters + angle (* t 30) + zoom (+ 1.0 (* 0.1 (math:sin (* t 0.5)))) + phase (* t 2) + + ;; Apply 4 effects one by one (INTERPRETED - many kernel launches) + step1 (geometry_gpu:rotate base angle) + step2 (color_ops:hue-shift step1 30) + step3 (geometry_gpu:ripple step2 15 :frequency 10 :decay 2 :time phase) + step4 (color_ops:brightness step3 1.0)] + + step4))) diff --git a/test_interpreted_vs_fused.sexp b/test_interpreted_vs_fused.sexp new file mode 100644 index 0000000..0a365bd --- /dev/null +++ b/test_interpreted_vs_fused.sexp @@ -0,0 +1,37 @@ +;; Test: Interpreted vs Fused Pipeline Comparison +;; +;; This simulates a typical effects pipeline using the INTERPRETED approach +;; (calling primitives one by one through Python). + +(stream "interpreted_pipeline_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "streaming_gpu") + (require-primitives "geometry_gpu") + (require-primitives "color_ops") + (require-primitives "image") + (require-primitives "math") + + ;; Frame pipeline - INTERPRETED approach (one primitive call at a time) + (frame + (let [;; Create base image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Apply effects one by one (INTERPRETED - slow) + angle (* t 30) + rotated (geometry_gpu:rotate base angle) + + hued (color_ops:hue-shift rotated 30) + + brightness (+ 0.8 (* 0.4 (math:sin (* t 2)))) + bright (color_ops:brightness hued brightness)] + + bright))) diff --git a/test_pil_options.py b/test_pil_options.py new file mode 100644 index 0000000..fb5ffb9 --- /dev/null +++ b/test_pil_options.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Explore PIL text options and test if we can match them. +""" + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +def load_font(font_name=None, font_size=32): + """Load a font.""" + candidates = [ + font_name, + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', + '/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', + ] + for path in candidates: + if path is None: + continue + try: + return ImageFont.truetype(path, font_size) + except (IOError, OSError): + continue + return ImageFont.load_default() + + +def test_pil_options(): + """Test various PIL text options.""" + + # Create a test frame + frame_size = (600, 400) + + font = load_font(None, 36) + font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36) + font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36) + + tests = [] + + # Test 1: Basic text + def basic_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font) + return img, "basic" + tests.append(basic_text) + + # Test 2: Stroke/outline + def stroke_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font, + stroke_width=2, stroke_fill=(255, 0, 0, 255)) + return img, "stroke" + tests.append(stroke_text) + + # Test 3: Bold font + def bold_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold) + return img, "bold" + tests.append(bold_text) + + # Test 4: Italic font + def italic_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic) + return img, "italic" + tests.append(italic_text) + + # Test 5: Different anchors + def anchor_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + # Draw crosshairs at anchor points + for x in [100, 300, 500]: + draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255)) + draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255)) + + draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm") + draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm") + draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm") + return img, "anchor" + tests.append(anchor_text) + + # Test 6: Multiline text + def multiline_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three", + fill=(255, 255, 255, 255), font=font, spacing=10) + return img, "multiline" + tests.append(multiline_text) + + # Test 7: Semi-transparent text + def alpha_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font) + draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font) + draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font) + return img, "alpha" + tests.append(alpha_text) + + # Test 8: Colored text + def colored_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font) + draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font) + draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font) + draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font) + return img, "colored" + tests.append(colored_text) + + # Test 9: Large stroke + def large_stroke(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font, + stroke_width=5, stroke_fill=(0, 0, 0, 255)) + return img, "large_stroke" + tests.append(large_stroke) + + # Test 10: Emoji (if supported) + def emoji_text(): + img = Image.new('RGBA', frame_size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + try: + # Try to find an emoji font + emoji_font = None + emoji_paths = [ + '/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf', + '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf', + ] + for p in emoji_paths: + try: + emoji_font = ImageFont.truetype(p, 36) + break + except: + pass + if emoji_font: + draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font) + else: + draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font) + except Exception as e: + draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font) + return img, "emoji" + tests.append(emoji_text) + + # Run all tests + print("PIL Text Options Test") + print("=" * 60) + + for test_fn in tests: + img, name = test_fn() + fname = f"/tmp/pil_test_{name}.png" + img.save(fname) + print(f"Saved: {fname}") + + print("\nCheck /tmp/pil_test_*.png for results") + + # Print available parameters + print("\n" + "=" * 60) + print("PIL draw.text() parameters:") + print(" - xy: position tuple") + print(" - text: string to draw") + print(" - fill: color (R,G,B) or (R,G,B,A)") + print(" - font: ImageFont object") + print(" - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)") + print(" - spacing: line spacing for multiline") + print(" - align: 'left', 'center', 'right' for multiline") + print(" - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)") + print(" - features: OpenType features list") + print(" - language: language code for shaping") + print(" - stroke_width: outline width in pixels") + print(" - stroke_fill: outline color") + print(" - embedded_color: use embedded color glyphs (emoji)") + + +if __name__ == "__main__": + test_pil_options() diff --git a/test_styled_text.py b/test_styled_text.py new file mode 100644 index 0000000..925a7fb --- /dev/null +++ b/test_styled_text.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Test styled TextStrip rendering against PIL. +""" + +import numpy as np +import jax.numpy as jnp +from PIL import Image, ImageDraw, ImageFont + +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font +) + + +def render_pil(text, x, y, font_size=36, frame_size=(400, 100), + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with PIL directly.""" + frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8) + img = Image.fromarray(frame) + draw = ImageDraw.Draw(img) + + font = _load_font(None, font_size) + + # Default stroke fill + if stroke_fill is None: + stroke_fill = (0, 0, 0) + + if multiline: + draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill, + spacing=line_spacing, align=align, anchor=anchor) + else: + draw.text((x, y), text, fill=(255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill, anchor=anchor) + + return np.array(img) + + +def render_strip(text, x, y, font_size=36, frame_size=(400, 100), + stroke_width=0, stroke_fill=None, anchor="la", + multiline=False, line_spacing=4, align="left"): + """Render with TextStrip.""" + frame = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8) + + strip = render_text_strip( + text, None, font_size, + stroke_width=stroke_width, stroke_fill=stroke_fill, + anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align + ) + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + return np.array(result) + + +def compare(name, text, x, y, font_size=36, frame_size=(400, 100), + tolerance=0, **kwargs): + """Compare PIL and TextStrip rendering. + + tolerance=0: exact pixel match required + tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences + in center-aligned multiline text where the strip is pre-rendered + at a different base position than the final placement) + """ + pil = render_pil(text, x, y, font_size, frame_size, **kwargs) + strip = render_strip(text, x, y, font_size, frame_size, **kwargs) + + diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16)) + max_diff = diff.max() + pixels_diff = (diff > 0).any(axis=2).sum() + + if max_diff == 0: + print(f"PASS: {name}") + print(f" Max diff: 0, Pixels different: 0") + return True + + if tolerance > 0: + # Check if the difference is just a sub-pixel position shift: + # for each shifted version, compute the minimum diff + best_diff = diff.copy() + for dy in range(-tolerance, tolerance + 1): + for dx in range(-tolerance, tolerance + 1): + if dy == 0 and dx == 0: + continue + shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1) + sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16)) + best_diff = np.minimum(best_diff, sdiff) + max_shift_diff = best_diff.max() + pixels_shift_diff = (best_diff > 0).any(axis=2).sum() + if max_shift_diff == 0: + print(f"PASS: {name} (within {tolerance}px position tolerance)") + print(f" Raw diff: {max_diff}, After shift tolerance: 0") + return True + + status = "FAIL" + print(f"{status}: {name}") + print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}") + + # Save debug images + Image.fromarray(pil).save(f"/tmp/pil_{name}.png") + Image.fromarray(strip).save(f"/tmp/strip_{name}.png") + diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8) + Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png") + print(f" Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png") + + return False + + +def main(): + print("=" * 60) + print("Styled TextStrip vs PIL Comparison") + print("=" * 60) + + results = [] + + # Basic text + results.append(compare("basic", "Hello World", 20, 50)) + + # Stroke/outline + results.append(compare("stroke_2", "Outlined", 20, 50, + stroke_width=2, stroke_fill=(255, 0, 0))) + + results.append(compare("stroke_5", "Big Outline", 30, 60, font_size=48, + frame_size=(500, 120), + stroke_width=5, stroke_fill=(0, 0, 0))) + + # Anchors - center + results.append(compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100), + anchor="mm")) + + # Anchors - right + results.append(compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100), + anchor="rm")) + + # Multiline + results.append(compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20, + frame_size=(400, 150), + multiline=True, line_spacing=8)) + + # Multiline centered (1px tolerance: sub-pixel rendering differs because + # the strip is pre-rendered at an integer position while PIL's center + # alignment uses fractional getlength values for the 'm' anchor shift) + results.append(compare("multiline_center", "Short\nMedium Length\nX", 200, 20, + frame_size=(400, 150), + multiline=True, anchor="ma", align="center", + tolerance=1)) + + # Stroke + multiline + results.append(compare("stroke_multiline", "Line A\nLine B", 20, 20, + frame_size=(400, 120), + stroke_width=2, stroke_fill=(0, 0, 255), + multiline=True)) + + print("=" * 60) + passed = sum(results) + total = len(results) + print(f"Results: {passed}/{total} passed") + + if passed == total: + print("ALL TESTS PASSED!") + else: + print(f"FAILED: {total - passed} tests") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/test_typography_fx.py b/test_typography_fx.py new file mode 100644 index 0000000..e57186e --- /dev/null +++ b/test_typography_fx.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 +""" +Tests for typography FX: gradients, rotation, shadow, and combined effects. +""" + +import numpy as np +import jax +import jax.numpy as jnp +from PIL import Image + +from streaming.jax_typography import ( + render_text_strip, place_text_strip_jax, _load_font, + make_linear_gradient, make_radial_gradient, make_multi_stop_gradient, + place_text_strip_gradient_jax, rotate_strip_jax, + place_text_strip_shadow_jax, place_text_strip_fx_jax, + bind_typography_primitives, +) + + +def make_frame(w=400, h=200): + """Create a dark gray test frame.""" + return jnp.full((h, w, 3), 40, dtype=jnp.uint8) + + +def get_strip(text="Hello", font_size=48): + """Get a pre-rendered text strip.""" + return render_text_strip(text, None, font_size) + + +def has_visible_pixels(frame, threshold=50): + """Check if frame has pixels above threshold.""" + return int(frame.max()) > threshold + + +def save_debug(name, frame): + """Save frame for visual inspection.""" + arr = np.array(frame) if not isinstance(frame, np.ndarray) else frame + Image.fromarray(arr).save(f"/tmp/fx_{name}.png") + + +# ============================================================================= +# Gradient Tests +# ============================================================================= + +def test_linear_gradient_shape(): + grad = make_linear_gradient(100, 50, (255, 0, 0), (0, 0, 255)) + assert grad.shape == (50, 100, 3), f"Expected (50, 100, 3), got {grad.shape}" + assert grad.dtype in (np.float32, np.float64), f"Expected float, got {grad.dtype}" + # Left edge should be red-ish, right edge blue-ish + assert grad[25, 0, 0] > 0.8, f"Left edge should be red, got R={grad[25, 0, 0]}" + assert grad[25, -1, 2] > 0.8, f"Right edge should be blue, got B={grad[25, -1, 2]}" + print("PASS: test_linear_gradient_shape") + return True + + +def test_linear_gradient_angle(): + # 90 degrees: top-to-bottom + grad = make_linear_gradient(100, 100, (255, 0, 0), (0, 0, 255), angle=90.0) + # Top row should be red, bottom row should be blue + assert grad[0, 50, 0] > 0.8, "Top should be red" + assert grad[-1, 50, 2] > 0.8, "Bottom should be blue" + print("PASS: test_linear_gradient_angle") + return True + + +def test_radial_gradient_shape(): + grad = make_radial_gradient(100, 100, (255, 255, 0), (0, 0, 128)) + assert grad.shape == (100, 100, 3) + # Center should be yellow (color1) + assert grad[50, 50, 0] > 0.9, "Center should be yellow (R)" + assert grad[50, 50, 1] > 0.9, "Center should be yellow (G)" + # Corner should be closer to dark blue (color2) + assert grad[0, 0, 2] > grad[50, 50, 2], "Corner should have more blue" + print("PASS: test_radial_gradient_shape") + return True + + +def test_multi_stop_gradient(): + stops = [ + (0.0, (255, 0, 0)), + (0.5, (0, 255, 0)), + (1.0, (0, 0, 255)), + ] + grad = make_multi_stop_gradient(100, 10, stops) + assert grad.shape == (10, 100, 3) + # Left: red, Middle: green, Right: blue + assert grad[5, 0, 0] > 0.8, "Left should be red" + assert grad[5, 50, 1] > 0.8, "Middle should be green" + assert grad[5, -1, 2] > 0.8, "Right should be blue" + print("PASS: test_multi_stop_gradient") + return True + + +def test_place_gradient(): + """Test gradient text rendering produces visible output.""" + frame = make_frame() + strip = get_strip() + grad = make_linear_gradient(strip.width, strip.height, + (255, 0, 0), (0, 0, 255)) + grad_jax = jnp.asarray(grad) + strip_img = jnp.asarray(strip.image) + + result = place_text_strip_gradient_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + grad_jax, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + assert result.shape == frame.shape + # Should have visible colored pixels + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Gradient text should be visible" + save_debug("gradient", result) + print("PASS: test_place_gradient") + return True + + +# ============================================================================= +# Rotation Tests +# ============================================================================= + +def test_rotate_strip_identity(): + """Rotation by 0 degrees should preserve content.""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + + rotated = rotate_strip_jax(strip_img, 0.0) + # Output is larger (diagonal size) + assert rotated.shape[2] == 4, "Should be RGBA" + assert rotated.shape[0] >= strip.height + assert rotated.shape[1] >= strip.width + + # Alpha should have non-zero pixels (text was preserved) + assert rotated[:, :, 3].max() > 200, "Should have visible alpha" + print("PASS: test_rotate_strip_identity") + return True + + +def test_rotate_strip_90(): + """Rotation by 90 degrees.""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + + rotated = rotate_strip_jax(strip_img, 90.0) + assert rotated.shape[2] == 4 + # Should still have visible content + assert rotated[:, :, 3].max() > 200, "Rotated strip should have visible alpha" + save_debug("rotated_90", np.array(rotated)) + print("PASS: test_rotate_strip_90") + return True + + +def test_rotate_360_exact(): + """360-degree rotation must be pixel-exact (regression test for trig snapping).""" + strip = get_strip() + strip_img = jnp.asarray(strip.image) + sh, sw = strip.height, strip.width + + rotated = rotate_strip_jax(strip_img, 360.0) + rh, rw = rotated.shape[:2] + off_y = (rh - sh) // 2 + off_x = (rw - sw) // 2 + + crop = np.array(rotated[off_y:off_y+sh, off_x:off_x+sw]) + orig = np.array(strip_img) + d = np.abs(crop.astype(np.int16) - orig.astype(np.int16)) + max_diff = int(d.max()) + assert max_diff == 0, f"360° rotation should be exact, max_diff={max_diff}" + print("PASS: test_rotate_360_exact") + return True + + +def test_place_rotated(): + """Test rotated text placement produces visible output.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 0], dtype=jnp.float32) + + result = place_text_strip_fx_jax( + frame, strip_img, 200.0, 100.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + angle=30.0, + ) + + assert result.shape == frame.shape + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Rotated text should be visible" + save_debug("rotated_30", result) + print("PASS: test_place_rotated") + return True + + +# ============================================================================= +# Shadow Tests +# ============================================================================= + +def test_shadow_basic(): + """Test shadow produces visible offset copy.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_shadow_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=5.0, shadow_offset_y=5.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.8, + ) + + assert result.shape == frame.shape + # Should have both bright (text) and dark (shadow) pixels + assert result.max() > 200, "Should have bright text" + save_debug("shadow_basic", result) + print("PASS: test_shadow_basic") + return True + + +def test_shadow_blur(): + """Test blurred shadow.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + result = place_text_strip_shadow_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=4.0, shadow_offset_y=4.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.7, + shadow_blur_radius=3, + ) + + assert result.shape == frame.shape + save_debug("shadow_blur", result) + print("PASS: test_shadow_blur") + return True + + +# ============================================================================= +# Combined FX Tests +# ============================================================================= + +def test_fx_combined(): + """Test combined gradient + shadow + rotation.""" + frame = make_frame(500, 300) + strip = get_strip("FX Test", 64) + strip_img = jnp.asarray(strip.image) + + grad = make_linear_gradient(strip.width, strip.height, + (255, 100, 0), (0, 100, 255)) + grad_jax = jnp.asarray(grad) + + result = place_text_strip_fx_jax( + frame, strip_img, 250.0, 150.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + gradient_map=grad_jax, + angle=15.0, + shadow_offset_x=4.0, shadow_offset_y=4.0, + shadow_color=jnp.array([0, 0, 0], dtype=jnp.float32), + shadow_opacity=0.6, + shadow_blur_radius=2, + ) + + assert result.shape == frame.shape + diff = jnp.abs(result.astype(jnp.int16) - frame.astype(jnp.int16)) + assert diff.max() > 50, "Combined FX should produce visible output" + save_debug("fx_combined", result) + print("PASS: test_fx_combined") + return True + + +def test_fx_no_effects(): + """FX function with no effects should match basic place_text_strip_jax.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + + # Using FX function with defaults + result_fx = place_text_strip_fx_jax( + frame, strip_img, 50.0, 100.0, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + # Using original function + result_orig = place_text_strip_jax( + frame, strip_img, 50.0, 100.0, + strip.baseline_y, strip.bearing_x, + color, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + diff = jnp.abs(result_fx.astype(jnp.int16) - result_orig.astype(jnp.int16)) + max_diff = int(diff.max()) + assert max_diff == 0, f"FX with no effects should match original, max diff={max_diff}" + print("PASS: test_fx_no_effects") + return True + + +# ============================================================================= +# S-Expression Binding Tests +# ============================================================================= + +def test_sexp_bindings(): + """Test that all new primitives are registered.""" + env = {} + bind_typography_primitives(env) + + expected = [ + 'linear-gradient', 'radial-gradient', 'multi-stop-gradient', + 'place-text-strip-gradient', 'place-text-strip-rotated', + 'place-text-strip-shadow', 'place-text-strip-fx', + ] + for name in expected: + assert name in env, f"Missing binding: {name}" + + print("PASS: test_sexp_bindings") + return True + + +def test_sexp_gradient_primitive(): + """Test gradient primitive via binding.""" + env = {} + bind_typography_primitives(env) + + strip = env['render-text-strip']("Test", 36) + grad = env['linear-gradient'](strip, (255, 0, 0), (0, 0, 255)) + + assert grad.shape == (strip.height, strip.width, 3) + print("PASS: test_sexp_gradient_primitive") + return True + + +def test_sexp_fx_primitive(): + """Test combined FX primitive via binding.""" + env = {} + bind_typography_primitives(env) + + strip = env['render-text-strip']("FX", 36) + frame = make_frame() + + result = env['place-text-strip-fx']( + frame, strip, 100.0, 80.0, + color=(255, 200, 0), opacity=0.9, + shadow_offset_x=3, shadow_offset_y=3, + shadow_opacity=0.5, + ) + assert result.shape == frame.shape + print("PASS: test_sexp_fx_primitive") + return True + + +# ============================================================================= +# JIT Compilation Test +# ============================================================================= + +def test_jit_fx(): + """Test that place_text_strip_fx_jax can be JIT compiled.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + color = jnp.array([255, 255, 255], dtype=jnp.float32) + shadow_color = jnp.array([0, 0, 0], dtype=jnp.float32) + + # JIT compile with static args for angle and blur radius + @jax.jit + def render(frame, x, y, opacity): + return place_text_strip_fx_jax( + frame, strip_img, x, y, + baseline_y=strip.baseline_y, bearing_x=strip.bearing_x, + color=color, opacity=opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + shadow_offset_x=3.0, shadow_offset_y=3.0, + shadow_color=shadow_color, + shadow_opacity=0.5, + shadow_blur_radius=2, + ) + + # First call traces, second uses cache + result1 = render(frame, 50.0, 100.0, 1.0) + result2 = render(frame, 60.0, 90.0, 0.8) + + assert result1.shape == frame.shape + assert result2.shape == frame.shape + print("PASS: test_jit_fx") + return True + + +def test_jit_gradient(): + """Test that gradient placement can be JIT compiled.""" + frame = make_frame() + strip = get_strip() + strip_img = jnp.asarray(strip.image) + grad = jnp.asarray(make_linear_gradient(strip.width, strip.height, + (255, 0, 0), (0, 0, 255))) + + @jax.jit + def render(frame, x, y): + return place_text_strip_gradient_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + grad, 1.0, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + ) + + result = render(frame, 50.0, 100.0) + assert result.shape == frame.shape + print("PASS: test_jit_gradient") + return True + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 60) + print("Typography FX Tests") + print("=" * 60) + + tests = [ + # Gradients + test_linear_gradient_shape, + test_linear_gradient_angle, + test_radial_gradient_shape, + test_multi_stop_gradient, + test_place_gradient, + # Rotation + test_rotate_strip_identity, + test_rotate_strip_90, + test_rotate_360_exact, + test_place_rotated, + # Shadow + test_shadow_basic, + test_shadow_blur, + # Combined FX + test_fx_combined, + test_fx_no_effects, + # S-expression bindings + test_sexp_bindings, + test_sexp_gradient_primitive, + test_sexp_fx_primitive, + # JIT compilation + test_jit_fx, + test_jit_gradient, + ] + + results = [] + for test in tests: + try: + results.append(test()) + except Exception as e: + print(f"FAIL: {test.__name__}: {e}") + import traceback + traceback.print_exc() + results.append(False) + + print("=" * 60) + passed = sum(r for r in results if r) + total = len(results) + print(f"Results: {passed}/{total} passed") + if passed == total: + print("ALL TESTS PASSED!") + else: + print(f"FAILED: {total - passed} tests") + print("=" * 60) + return passed == total + + +if __name__ == "__main__": + import sys + sys.exit(0 if main() else 1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9849184 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests for art-celery diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a99ae02 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,93 @@ +""" +Pytest fixtures for art-celery tests. +""" + +import pytest +from typing import Any, Dict, List + + +@pytest.fixture +def sample_compiled_nodes() -> List[Dict[str, Any]]: + """Sample nodes as produced by the S-expression compiler.""" + return [ + { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "cat"}, + "inputs": [], + "name": None, + }, + { + "id": "source_2", + "type": "SOURCE", + "config": { + "input": True, + "name": "Second Video", + "description": "A user-provided video", + }, + "inputs": [], + "name": "second-video", + }, + { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "identity"}, + "inputs": ["source_1"], + "name": None, + }, + { + "id": "effect_2", + "type": "EFFECT", + "config": {"effect": "invert", "intensity": 1.0}, + "inputs": ["source_2"], + "name": None, + }, + { + "id": "sequence_1", + "type": "SEQUENCE", + "config": {}, + "inputs": ["effect_1", "effect_2"], + "name": None, + }, + ] + + +@pytest.fixture +def sample_registry() -> Dict[str, Dict[str, Any]]: + """Sample registry with assets and effects.""" + return { + "assets": { + "cat": { + "cid": "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ", + "url": "https://example.com/cat.jpg", + }, + }, + "effects": { + "identity": { + "cid": "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko", + }, + "invert": { + "cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + }, + }, + } + + +@pytest.fixture +def sample_recipe( + sample_compiled_nodes: List[Dict[str, Any]], + sample_registry: Dict[str, Dict[str, Any]], +) -> Dict[str, Any]: + """Sample compiled recipe.""" + return { + "name": "test-recipe", + "version": "1.0", + "description": "A test recipe", + "owner": "@test@example.com", + "registry": sample_registry, + "dag": { + "nodes": sample_compiled_nodes, + "output": "sequence_1", + }, + "recipe_id": "Qmtest123", + } diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a4eb163 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,42 @@ +""" +Tests for authentication service. +""" + +import pytest + + +class TestUserContext: + """Tests for UserContext dataclass.""" + + def test_user_context_accepts_l2_server(self) -> None: + """ + Regression test: UserContext must accept l2_server field. + + Bug found 2026-01-12: auth_service.py passes l2_server to UserContext + but the art-common library was pinned to old version without this field. + """ + from artdag_common.middleware.auth import UserContext + + # This should not raise TypeError + ctx = UserContext( + username="testuser", + actor_id="@testuser@example.com", + token="test-token", + l2_server="https://l2.example.com", + ) + + assert ctx.username == "testuser" + assert ctx.actor_id == "@testuser@example.com" + assert ctx.token == "test-token" + assert ctx.l2_server == "https://l2.example.com" + + def test_user_context_l2_server_optional(self) -> None: + """l2_server should be optional (default None).""" + from artdag_common.middleware.auth import UserContext + + ctx = UserContext( + username="testuser", + actor_id="@testuser@example.com", + ) + + assert ctx.l2_server is None diff --git a/tests/test_cache_manager.py b/tests/test_cache_manager.py new file mode 100644 index 0000000..da2b5ab --- /dev/null +++ b/tests/test_cache_manager.py @@ -0,0 +1,397 @@ +# tests/test_cache_manager.py +"""Tests for the L1 cache manager.""" + +import tempfile +import time +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from cache_manager import ( + L1CacheManager, + L2SharedChecker, + CachedFile, + file_hash, +) + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def mock_l2(): + """Mock L2 server responses.""" + with patch("cache_manager.requests") as mock_requests: + mock_requests.get.return_value = Mock(status_code=404) + yield mock_requests + + +@pytest.fixture +def manager(temp_dir, mock_l2): + """Create a cache manager instance.""" + return L1CacheManager( + cache_dir=temp_dir / "cache", + l2_server="http://mock-l2:8200", + ) + + +def create_test_file(path: Path, content: str = "test content") -> Path: + """Create a test file with content.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + +class TestFileHash: + """Tests for file_hash function.""" + + def test_consistent_hash(self, temp_dir): + """Same content produces same hash.""" + file1 = create_test_file(temp_dir / "f1.txt", "hello") + file2 = create_test_file(temp_dir / "f2.txt", "hello") + + assert file_hash(file1) == file_hash(file2) + + def test_different_content_different_hash(self, temp_dir): + """Different content produces different hash.""" + file1 = create_test_file(temp_dir / "f1.txt", "hello") + file2 = create_test_file(temp_dir / "f2.txt", "world") + + assert file_hash(file1) != file_hash(file2) + + def test_sha3_256_length(self, temp_dir): + """Hash is SHA3-256 (64 hex chars).""" + f = create_test_file(temp_dir / "f.txt", "test") + assert len(file_hash(f)) == 64 + + +class TestL2SharedChecker: + """Tests for L2 shared status checking.""" + + def test_not_shared_returns_false(self, mock_l2): + """Non-existent content returns False.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=404) + + assert checker.is_shared("abc123") is False + + def test_shared_returns_true(self, mock_l2): + """Published content returns True.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=200) + + assert checker.is_shared("abc123") is True + + def test_caches_result(self, mock_l2): + """Results are cached to avoid repeated API calls.""" + checker = L2SharedChecker("http://mock:8200", cache_ttl=60) + mock_l2.get.return_value = Mock(status_code=200) + + checker.is_shared("abc123") + checker.is_shared("abc123") + + # Should only call API once + assert mock_l2.get.call_count == 1 + + def test_mark_shared(self, mock_l2): + """mark_shared updates cache without API call.""" + checker = L2SharedChecker("http://mock:8200") + + checker.mark_shared("abc123") + + assert checker.is_shared("abc123") is True + assert mock_l2.get.call_count == 0 + + def test_invalidate(self, mock_l2): + """invalidate clears cache for a hash.""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.return_value = Mock(status_code=200) + + checker.is_shared("abc123") + checker.invalidate("abc123") + + mock_l2.get.return_value = Mock(status_code=404) + assert checker.is_shared("abc123") is False + + def test_error_returns_true(self, mock_l2): + """API errors return True (safe - prevents accidental deletion).""" + checker = L2SharedChecker("http://mock:8200") + mock_l2.get.side_effect = Exception("Network error") + + # On error, assume IS shared to prevent accidental deletion + assert checker.is_shared("abc123") is True + + +class TestL1CacheManagerStorage: + """Tests for cache storage operations.""" + + def test_put_and_get_by_cid(self, manager, temp_dir): + """Can store and retrieve by content hash.""" + test_file = create_test_file(temp_dir / "input.txt", "hello world") + + cached, cid = manager.put(test_file, node_type="test") + + retrieved_path = manager.get_by_cid(cached.cid) + assert retrieved_path is not None + assert retrieved_path.read_text() == "hello world" + + def test_put_with_custom_node_id(self, manager, temp_dir): + """Can store with custom node_id.""" + test_file = create_test_file(temp_dir / "input.txt", "content") + + cached, cid = manager.put(test_file, node_id="custom-node-123", node_type="test") + + assert cached.node_id == "custom-node-123" + assert manager.get_by_node_id("custom-node-123") is not None + + def test_has_content(self, manager, temp_dir): + """has_content checks existence.""" + test_file = create_test_file(temp_dir / "input.txt", "data") + + cached, cid = manager.put(test_file, node_type="test") + + assert manager.has_content(cached.cid) is True + assert manager.has_content("nonexistent") is False + + def test_list_all(self, manager, temp_dir): + """list_all returns all cached files.""" + f1 = create_test_file(temp_dir / "f1.txt", "one") + f2 = create_test_file(temp_dir / "f2.txt", "two") + + manager.put(f1, node_type="test") + manager.put(f2, node_type="test") + + all_files = manager.list_all() + assert len(all_files) == 2 + + def test_deduplication(self, manager, temp_dir): + """Same content is not stored twice.""" + f1 = create_test_file(temp_dir / "f1.txt", "identical") + f2 = create_test_file(temp_dir / "f2.txt", "identical") + + cached1, cid1 = manager.put(f1, node_type="test") + cached2, cid2 = manager.put(f2, node_type="test") + + assert cached1.cid == cached2.cid + assert len(manager.list_all()) == 1 + + +class TestL1CacheManagerActivities: + """Tests for activity tracking.""" + + def test_record_simple_activity(self, manager, temp_dir): + """Can record a simple activity.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + activity = manager.record_simple_activity( + input_hashes=[input_cached.cid], + output_cid=output_cached.cid, + run_id="run-001", + ) + + assert activity.activity_id == "run-001" + assert input_cached.cid in activity.input_ids + assert activity.output_id == output_cached.cid + + def test_list_activities(self, manager, temp_dir): + """Can list all activities.""" + for i in range(3): + inp = create_test_file(temp_dir / f"in{i}.txt", f"input{i}") + out = create_test_file(temp_dir / f"out{i}.txt", f"output{i}") + inp_c, _ = manager.put(inp, node_type="source") + out_c, _ = manager.put(out, node_type="effect") + manager.record_simple_activity([inp_c.cid], out_c.cid) + + activities = manager.list_activities() + assert len(activities) == 3 + + def test_find_activities_by_inputs(self, manager, temp_dir): + """Can find activities with same inputs.""" + input_file = create_test_file(temp_dir / "shared_input.txt", "shared") + input_cached, _ = manager.put(input_file, node_type="source") + + # Two activities with same input + out1 = create_test_file(temp_dir / "out1.txt", "output1") + out2 = create_test_file(temp_dir / "out2.txt", "output2") + out1_c, _ = manager.put(out1, node_type="effect") + out2_c, _ = manager.put(out2, node_type="effect") + + manager.record_simple_activity([input_cached.cid], out1_c.cid, "run1") + manager.record_simple_activity([input_cached.cid], out2_c.cid, "run2") + + found = manager.find_activities_by_inputs([input_cached.cid]) + assert len(found) == 2 + + +class TestL1CacheManagerDeletionRules: + """Tests for deletion rules enforcement.""" + + def test_can_delete_orphaned_item(self, manager, temp_dir): + """Orphaned items can be deleted.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cached, _ = manager.put(test_file, node_type="test") + + can_delete, reason = manager.can_delete(cached.cid) + assert can_delete is True + + def test_cannot_delete_activity_input(self, manager, temp_dir): + """Activity inputs cannot be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + can_delete, reason = manager.can_delete(input_cached.cid) + assert can_delete is False + assert "input" in reason.lower() + + def test_cannot_delete_activity_output(self, manager, temp_dir): + """Activity outputs cannot be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + can_delete, reason = manager.can_delete(output_cached.cid) + assert can_delete is False + assert "output" in reason.lower() + + def test_cannot_delete_pinned_item(self, manager, temp_dir): + """Pinned items cannot be deleted.""" + test_file = create_test_file(temp_dir / "shared.txt", "shared") + cached, _ = manager.put(test_file, node_type="test") + + # Mark as pinned (published) + manager.pin(cached.cid, reason="published") + + can_delete, reason = manager.can_delete(cached.cid) + assert can_delete is False + assert "pinned" in reason + + def test_delete_orphaned_item(self, manager, temp_dir): + """Can delete orphaned items.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cached, _ = manager.put(test_file, node_type="test") + + success, msg = manager.delete_by_cid(cached.cid) + + assert success is True + assert manager.has_content(cached.cid) is False + + def test_delete_protected_item_fails(self, manager, temp_dir): + """Cannot delete protected items.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + ) + + success, msg = manager.delete_by_cid(input_cached.cid) + + assert success is False + assert manager.has_content(input_cached.cid) is True + + +class TestL1CacheManagerActivityDiscard: + """Tests for activity discard functionality.""" + + def test_can_discard_unshared_activity(self, manager, temp_dir): + """Activities with no shared items can be discarded.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + activity = manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + can_discard, reason = manager.can_discard_activity("run-001") + assert can_discard is True + + def test_cannot_discard_activity_with_pinned_output(self, manager, temp_dir): + """Activities with pinned outputs cannot be discarded.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + # Mark output as pinned (published) + manager.pin(output_cached.cid, reason="published") + + can_discard, reason = manager.can_discard_activity("run-001") + assert can_discard is False + assert "pinned" in reason + + def test_discard_activity_cleans_up(self, manager, temp_dir): + """Discarding activity cleans up orphaned items.""" + input_file = create_test_file(temp_dir / "input.txt", "input") + output_file = create_test_file(temp_dir / "output.txt", "output") + + input_cached, _ = manager.put(input_file, node_type="source") + output_cached, _ = manager.put(output_file, node_type="effect") + + manager.record_simple_activity( + [input_cached.cid], + output_cached.cid, + "run-001", + ) + + success, msg = manager.discard_activity("run-001") + + assert success is True + assert manager.get_activity("run-001") is None + + +class TestL1CacheManagerStats: + """Tests for cache statistics.""" + + def test_get_stats(self, manager, temp_dir): + """get_stats returns cache statistics.""" + f1 = create_test_file(temp_dir / "f1.txt", "content1") + f2 = create_test_file(temp_dir / "f2.txt", "content2") + + manager.put(f1, node_type="test") + manager.put(f2, node_type="test") + + stats = manager.get_stats() + + assert stats["total_entries"] == 2 + assert stats["total_size_bytes"] > 0 + assert "activities" in stats diff --git a/tests/test_dag_transform.py b/tests/test_dag_transform.py new file mode 100644 index 0000000..32d45f8 --- /dev/null +++ b/tests/test_dag_transform.py @@ -0,0 +1,492 @@ +""" +Tests for DAG transformation and input binding. + +These tests verify the critical path that was causing bugs: +- Node transformation from compiled format to artdag format +- Asset/effect CID resolution from registry +- Variable input name mapping +- Input binding +""" + +import json +import logging +import pytest +from typing import Any, Dict, List, Optional, Tuple + +# Standalone implementations of the functions for testing +# This avoids importing the full app which requires external dependencies + +logger = logging.getLogger(__name__) + + +def is_variable_input(config: Dict[str, Any]) -> bool: + """Check if a SOURCE node config represents a variable input.""" + return bool(config.get("input")) + + +def transform_node( + node: Dict[str, Any], + assets: Dict[str, Dict[str, Any]], + effects: Dict[str, Dict[str, Any]], +) -> Dict[str, Any]: + """Transform a compiled node to artdag execution format.""" + node_id = node.get("id", "") + config = dict(node.get("config", {})) + + if node.get("type") == "SOURCE" and "asset" in config: + asset_name = config["asset"] + if asset_name in assets: + config["cid"] = assets[asset_name].get("cid") + + if node.get("type") == "EFFECT" and "effect" in config: + effect_name = config["effect"] + if effect_name in effects: + config["cid"] = effects[effect_name].get("cid") + + return { + "node_id": node_id, + "node_type": node.get("type", "EFFECT"), + "config": config, + "inputs": node.get("inputs", []), + "name": node.get("name"), + } + + +def build_input_name_mapping(nodes: Dict[str, Dict[str, Any]]) -> Dict[str, str]: + """Build a mapping from input names to node IDs for variable inputs.""" + input_name_to_node: Dict[str, str] = {} + + for node_id, node in nodes.items(): + if node.get("node_type") != "SOURCE": + continue + + config = node.get("config", {}) + if not is_variable_input(config): + continue + + input_name_to_node[node_id] = node_id + + name = config.get("name") + if name: + input_name_to_node[name] = node_id + input_name_to_node[name.lower().replace(" ", "_")] = node_id + input_name_to_node[name.lower().replace(" ", "-")] = node_id + + node_name = node.get("name") + if node_name: + input_name_to_node[node_name] = node_id + input_name_to_node[node_name.replace("-", "_")] = node_id + + return input_name_to_node + + +def bind_inputs( + nodes: Dict[str, Dict[str, Any]], + input_name_to_node: Dict[str, str], + user_inputs: Dict[str, str], +) -> List[str]: + """Bind user-provided input CIDs to source nodes.""" + warnings: List[str] = [] + + for input_name, cid in user_inputs.items(): + if input_name in nodes: + node = nodes[input_name] + if node.get("node_type") == "SOURCE": + node["config"]["cid"] = cid + continue + + if input_name in input_name_to_node: + node_id = input_name_to_node[input_name] + node = nodes[node_id] + node["config"]["cid"] = cid + continue + + warnings.append(f"Input '{input_name}' not found in recipe") + + return warnings + + +def prepare_dag_for_execution( + recipe: Dict[str, Any], + user_inputs: Dict[str, str], +) -> Tuple[str, List[str]]: + """Prepare a recipe DAG for execution.""" + recipe_dag = recipe.get("dag") + if not recipe_dag or not isinstance(recipe_dag, dict): + raise ValueError("Recipe has no DAG definition") + + dag_copy = json.loads(json.dumps(recipe_dag)) + nodes = dag_copy.get("nodes", {}) + + registry = recipe.get("registry", {}) + assets = registry.get("assets", {}) if registry else {} + effects = registry.get("effects", {}) if registry else {} + + if isinstance(nodes, list): + nodes_dict: Dict[str, Dict[str, Any]] = {} + for node in nodes: + node_id = node.get("id") + if node_id: + nodes_dict[node_id] = transform_node(node, assets, effects) + nodes = nodes_dict + dag_copy["nodes"] = nodes + + input_name_to_node = build_input_name_mapping(nodes) + warnings = bind_inputs(nodes, input_name_to_node, user_inputs) + + if "output" in dag_copy: + dag_copy["output_id"] = dag_copy.pop("output") + + if "metadata" not in dag_copy: + dag_copy["metadata"] = {} + + return json.dumps(dag_copy), warnings + + +class TestTransformNode: + """Tests for transform_node function.""" + + def test_source_node_with_asset_resolves_cid( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """SOURCE nodes with asset reference should get CID from registry.""" + node = { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "cat"}, + "inputs": [], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["node_id"] == "source_1" + assert result["node_type"] == "SOURCE" + assert result["config"]["cid"] == "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + def test_effect_node_resolves_cid( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """EFFECT nodes should get CID from registry.""" + node = { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "invert", "intensity": 1.0}, + "inputs": ["source_1"], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["node_id"] == "effect_1" + assert result["node_type"] == "EFFECT" + assert result["config"]["effect"] == "invert" + assert result["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + assert result["config"]["intensity"] == 1.0 + + def test_variable_input_node_preserves_config( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """Variable input SOURCE nodes should preserve their config.""" + node = { + "id": "source_2", + "type": "SOURCE", + "config": { + "input": True, + "name": "Second Video", + "description": "User-provided video", + }, + "inputs": [], + "name": "second-video", + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert result["config"]["input"] is True + assert result["config"]["name"] == "Second Video" + assert result["name"] == "second-video" + # No CID yet - will be bound at runtime + assert "cid" not in result["config"] + + def test_unknown_asset_not_resolved( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """Unknown assets should not crash, just not get CID.""" + node = { + "id": "source_1", + "type": "SOURCE", + "config": {"asset": "unknown_asset"}, + "inputs": [], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + assert "cid" not in result["config"] + + +class TestBuildInputNameMapping: + """Tests for build_input_name_mapping function.""" + + def test_maps_by_node_id(self) -> None: + """Should map by node_id.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["source_2"] == "source_2" + + def test_maps_by_config_name(self) -> None: + """Should map by config.name with various formats.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["Second Video"] == "source_2" + assert mapping["second_video"] == "source_2" + assert mapping["second-video"] == "source_2" + + def test_maps_by_def_name(self) -> None: + """Should map by node.name (def binding name).""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + "name": "inverted-video", + } + } + + mapping = build_input_name_mapping(nodes) + + assert mapping["inverted-video"] == "source_2" + assert mapping["inverted_video"] == "source_2" + + def test_ignores_non_source_nodes(self) -> None: + """Should not include non-SOURCE nodes.""" + nodes = { + "effect_1": { + "node_id": "effect_1", + "node_type": "EFFECT", + "config": {"effect": "invert"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert "effect_1" not in mapping + + def test_ignores_fixed_sources(self) -> None: + """Should not include SOURCE nodes without 'input' flag.""" + nodes = { + "source_1": { + "node_id": "source_1", + "node_type": "SOURCE", + "config": {"asset": "cat", "cid": "Qm123"}, + "inputs": [], + } + } + + mapping = build_input_name_mapping(nodes) + + assert "source_1" not in mapping + + +class TestBindInputs: + """Tests for bind_inputs function.""" + + def test_binds_by_direct_node_id(self) -> None: + """Should bind when using node_id directly.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + mapping = {"source_2": "source_2"} + user_inputs = {"source_2": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert nodes["source_2"]["config"]["cid"] == "QmUserInput123" + assert len(warnings) == 0 + + def test_binds_by_name_lookup(self) -> None: + """Should bind when using input name.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Second Video"}, + "inputs": [], + } + } + mapping = { + "source_2": "source_2", + "Second Video": "source_2", + "second-video": "source_2", + } + user_inputs = {"second-video": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert nodes["source_2"]["config"]["cid"] == "QmUserInput123" + assert len(warnings) == 0 + + def test_warns_on_unknown_input(self) -> None: + """Should warn when input name not found.""" + nodes = { + "source_2": { + "node_id": "source_2", + "node_type": "SOURCE", + "config": {"input": True, "name": "Test"}, + "inputs": [], + } + } + mapping = {"source_2": "source_2", "Test": "source_2"} + user_inputs = {"unknown-input": "QmUserInput123"} + + warnings = bind_inputs(nodes, mapping, user_inputs) + + assert len(warnings) == 1 + assert "unknown-input" in warnings[0] + + +class TestRegressions: + """Tests for specific bugs that were found in production.""" + + def test_effect_cid_key_not_effect_hash( + self, + sample_registry: Dict[str, Dict[str, Any]], + ) -> None: + """ + Regression test: Effect CID must use 'cid' key, not 'effect_hash'. + + Bug found 2026-01-12: The transform_node function was setting + config["effect_hash"] but the executor looks for config["cid"]. + This caused "Unknown effect: invert" errors. + """ + node = { + "id": "effect_1", + "type": "EFFECT", + "config": {"effect": "invert"}, + "inputs": ["source_1"], + } + assets = sample_registry["assets"] + effects = sample_registry["effects"] + + result = transform_node(node, assets, effects) + + # MUST use 'cid' key - executor checks config.get("cid") + assert "cid" in result["config"], "Effect CID must be stored as 'cid'" + assert "effect_hash" not in result["config"], "Must not use 'effect_hash' key" + assert result["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + def test_source_cid_binding_persists( + self, + sample_recipe: Dict[str, Any], + ) -> None: + """ + Regression test: Bound CIDs must appear in final DAG JSON. + + Bug found 2026-01-12: Variable input CIDs were being bound but + not appearing in the serialized DAG sent to the executor. + """ + user_inputs = {"second-video": "QmTestUserInput123"} + + dag_json, _ = prepare_dag_for_execution(sample_recipe, user_inputs) + dag = json.loads(dag_json) + + # The bound CID must be in the final JSON + source_2 = dag["nodes"]["source_2"] + assert source_2["config"]["cid"] == "QmTestUserInput123" + + +class TestPrepareDagForExecution: + """Integration tests for prepare_dag_for_execution.""" + + def test_full_pipeline(self, sample_recipe: Dict[str, Any]) -> None: + """Test the full DAG preparation pipeline.""" + user_inputs = { + "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + } + + dag_json, warnings = prepare_dag_for_execution(sample_recipe, user_inputs) + + # Parse result + dag = json.loads(dag_json) + + # Check structure + assert "nodes" in dag + assert "output_id" in dag + assert dag["output_id"] == "sequence_1" + + # Check fixed source has CID + source_1 = dag["nodes"]["source_1"] + assert source_1["config"]["cid"] == "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + # Check variable input was bound + source_2 = dag["nodes"]["source_2"] + assert source_2["config"]["cid"] == "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + + # Check effects have CIDs + effect_1 = dag["nodes"]["effect_1"] + assert effect_1["config"]["cid"] == "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko" + + effect_2 = dag["nodes"]["effect_2"] + assert effect_2["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + assert effect_2["config"]["intensity"] == 1.0 + + # No warnings + assert len(warnings) == 0 + + def test_missing_input_produces_warning( + self, + sample_recipe: Dict[str, Any], + ) -> None: + """Missing inputs should produce warnings but not fail.""" + user_inputs = {} # No inputs provided + + dag_json, warnings = prepare_dag_for_execution(sample_recipe, user_inputs) + + # Should still produce valid JSON + dag = json.loads(dag_json) + assert "nodes" in dag + + # Variable input should not have CID + source_2 = dag["nodes"]["source_2"] + assert "cid" not in source_2["config"] + + def test_raises_on_missing_dag(self) -> None: + """Should raise ValueError if recipe has no DAG.""" + recipe = {"name": "broken", "registry": {}} + + with pytest.raises(ValueError, match="no DAG"): + prepare_dag_for_execution(recipe, {}) diff --git a/tests/test_effect_loading.py b/tests/test_effect_loading.py new file mode 100644 index 0000000..3e41467 --- /dev/null +++ b/tests/test_effect_loading.py @@ -0,0 +1,327 @@ +""" +Tests for effect loading from cache and IPFS. + +These tests verify that: +- Effects can be loaded from the local cache directory +- IPFS gateway configuration is correct for Docker environments +- The effect executor correctly resolves CIDs from config +""" + +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch, MagicMock + +import pytest + + +# Minimal effect loading implementation for testing +# This mirrors the logic in artdag/nodes/effect.py + + +def get_effects_cache_dir_impl(env_vars: Dict[str, str]) -> Optional[Path]: + """Get the effects cache directory from environment or default.""" + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = env_vars.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + if effects_dir.exists(): + return effects_dir + + # Try default locations + for base in [Path.home() / ".artdag" / "cache", Path("/var/cache/artdag")]: + effects_dir = base / "_effects" + if effects_dir.exists(): + return effects_dir + + return None + + +def effect_path_for_cid(effects_dir: Path, effect_cid: str) -> Path: + """Get the expected path for an effect given its CID.""" + return effects_dir / effect_cid / "effect.py" + + +class TestEffectCacheDirectory: + """Tests for effect cache directory resolution.""" + + def test_cache_dir_from_env(self, tmp_path: Path) -> None: + """CACHE_DIR env var should determine effects directory.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + + env = {"CACHE_DIR": str(tmp_path)} + result = get_effects_cache_dir_impl(env) + + assert result == effects_dir + + def test_artdag_cache_dir_fallback(self, tmp_path: Path) -> None: + """ARTDAG_CACHE_DIR should work as fallback.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + + env = {"ARTDAG_CACHE_DIR": str(tmp_path)} + result = get_effects_cache_dir_impl(env) + + assert result == effects_dir + + def test_no_env_returns_none_if_no_default_exists(self) -> None: + """Should return None if no cache directory exists.""" + env = {} + result = get_effects_cache_dir_impl(env) + + # Will return None unless default dirs exist + # This is expected behavior + if result is not None: + assert result.exists() + + +class TestEffectPathResolution: + """Tests for effect path resolution.""" + + def test_effect_path_structure(self, tmp_path: Path) -> None: + """Effect should be at _effects/{cid}/effect.py.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffect123" + + path = effect_path_for_cid(effects_dir, effect_cid) + + assert path == effects_dir / effect_cid / "effect.py" + + def test_effect_file_exists_after_upload(self, tmp_path: Path) -> None: + """After upload, effect.py should exist in the right location.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Simulate effect upload (as done by app/routers/effects.py) + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + effect_source = '''""" +@effect invert +@version 1.0.0 +""" + +def process_frame(frame, params, state): + return 255 - frame, state +''' + (effect_dir / "effect.py").write_text(effect_source) + + # Verify the path structure + expected_path = effect_path_for_cid(effects_dir, effect_cid) + assert expected_path.exists() + assert "process_frame" in expected_path.read_text() + + +class TestIPFSAPIConfiguration: + """Tests for IPFS API configuration (consistent across codebase).""" + + def test_ipfs_api_multiaddr_conversion(self) -> None: + """ + IPFS_API multiaddr should convert to correct URL. + + Both ipfs_client.py and artdag/nodes/effect.py now use IPFS_API + with multiaddr format for consistency. + """ + import re + + def multiaddr_to_url(multiaddr: str) -> str: + """Convert multiaddr to URL (same logic as ipfs_client.py).""" + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + return "http://127.0.0.1:5001" + + # Docker config + docker_api = "/dns/ipfs/tcp/5001" + url = multiaddr_to_url(docker_api) + assert url == "http://ipfs:5001" + + # Local dev config + local_api = "/ip4/127.0.0.1/tcp/5001" + url = multiaddr_to_url(local_api) + assert url == "http://127.0.0.1:5001" + + def test_all_ipfs_access_uses_api_not_gateway(self) -> None: + """ + All IPFS access should use IPFS_API (port 5001), not IPFS_GATEWAY (port 8080). + + Fixed 2026-01-12: artdag/nodes/effect.py was using a separate IPFS_GATEWAY + variable. Now it uses IPFS_API like ipfs_client.py for consistency. + """ + # The API endpoint that both modules use + api_endpoint = "/api/v0/cat" + + # This is correct - using the API + assert "api/v0" in api_endpoint + + # Gateway endpoint would be /ipfs/{cid} - we don't use this anymore + gateway_pattern = "/ipfs/" + assert gateway_pattern not in api_endpoint + + +class TestEffectExecutorConfigResolution: + """Tests for how the effect executor resolves CID from config.""" + + def test_executor_should_use_cid_key(self) -> None: + """ + Effect executor must look for 'cid' key in config. + + The transform_node function sets config["cid"] for effects. + The executor must read from the same key. + """ + config = { + "effect": "invert", + "cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + "intensity": 1.0, + } + + # Simulate executor CID extraction (from artdag/nodes/effect.py:258) + effect_cid = config.get("cid") or config.get("hash") + + assert effect_cid == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + def test_executor_should_not_use_effect_hash(self) -> None: + """ + Regression test: 'effect_hash' is not a valid config key. + + Bug found 2026-01-12: transform_node was using config["effect_hash"] + but executor only checks config["cid"] or config["hash"]. + """ + config = { + "effect": "invert", + "effect_hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + } + + # This simulates the buggy behavior where effect_hash was set + # but executor doesn't look for it + effect_cid = config.get("cid") or config.get("hash") + + # The bug: effect_hash is ignored, effect_cid is None + assert effect_cid is None, "effect_hash should NOT be recognized" + + def test_hash_key_is_legacy_fallback(self) -> None: + """'hash' key should work as legacy fallback for 'cid'.""" + config = { + "effect": "invert", + "hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J", + } + + effect_cid = config.get("cid") or config.get("hash") + + assert effect_cid == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + +class TestEffectLoadingIntegration: + """Integration tests for complete effect loading path.""" + + def test_effect_loads_from_cache_when_present(self, tmp_path: Path) -> None: + """Effect should load from cache without hitting IPFS.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Create effect file in cache + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +def process_frame(frame, params, state): + """Invert colors.""" + return 255 - frame, state +''') + + # Verify the effect can be found + effect_path = effect_path_for_cid(effects_dir, effect_cid) + assert effect_path.exists() + + # Load and verify it has the expected function + import importlib.util + spec = importlib.util.spec_from_file_location("test_effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + assert hasattr(module, "process_frame") + + def test_effect_fetch_uses_ipfs_api(self, tmp_path: Path) -> None: + """Effect fetch should use IPFS API endpoint, not gateway.""" + import re + + def multiaddr_to_url(multiaddr: str) -> str: + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + return "http://127.0.0.1:5001" + + # In Docker, IPFS_API=/dns/ipfs/tcp/5001 + docker_multiaddr = "/dns/ipfs/tcp/5001" + base_url = multiaddr_to_url(docker_multiaddr) + effect_cid = "QmTestCid123" + + # Should use API endpoint + api_url = f"{base_url}/api/v0/cat?arg={effect_cid}" + + assert "ipfs:5001" in api_url + assert "/api/v0/cat" in api_url + assert "127.0.0.1" not in api_url + + +class TestSharedVolumeScenario: + """ + Tests simulating the Docker shared volume scenario. + + In Docker: + - l1-server uploads effect to /data/cache/_effects/{cid}/effect.py + - l1-worker should find it at the same path via shared volume + """ + + def test_effect_visible_on_shared_volume(self, tmp_path: Path) -> None: + """Effect uploaded on server should be visible to worker.""" + # Simulate shared volume mounted at /data/cache on both containers + shared_volume = tmp_path / "data" / "cache" + effects_dir = shared_volume / "_effects" + + # Server uploads effect + effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + effect_upload_dir = effects_dir / effect_cid + effect_upload_dir.mkdir(parents=True) + (effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s') + (effect_upload_dir / "metadata.json").write_text('{"cid": "' + effect_cid + '"}') + + # Worker should find the effect + env_vars = {"CACHE_DIR": str(shared_volume)} + worker_effects_dir = get_effects_cache_dir_impl(env_vars) + + assert worker_effects_dir is not None + assert worker_effects_dir == effects_dir + + worker_effect_path = effect_path_for_cid(worker_effects_dir, effect_cid) + assert worker_effect_path.exists() + + def test_effect_cid_matches_registry(self, tmp_path: Path) -> None: + """CID in recipe registry must match the uploaded effect directory name.""" + shared_volume = tmp_path + effects_dir = shared_volume / "_effects" + + # The CID used in the recipe registry + registry_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J" + + # Upload creates directory with CID as name + effect_upload_dir = effects_dir / registry_cid + effect_upload_dir.mkdir(parents=True) + (effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s') + + # Executor receives the same CID from DAG config + dag_config_cid = registry_cid # This comes from transform_node + + # These must match for the lookup to work + assert dag_config_cid == registry_cid + + # And the path must exist + lookup_path = effects_dir / dag_config_cid / "effect.py" + assert lookup_path.exists() diff --git a/tests/test_effects_web.py b/tests/test_effects_web.py new file mode 100644 index 0000000..5ad461f --- /dev/null +++ b/tests/test_effects_web.py @@ -0,0 +1,367 @@ +""" +Tests for Effects web UI. + +Tests effect metadata parsing, listing, and templates. +""" + +import pytest +import re +from pathlib import Path +from unittest.mock import MagicMock + + +def parse_effect_metadata_standalone(source: str) -> dict: + """ + Standalone copy of parse_effect_metadata for testing. + + This avoids import issues with the router module. + """ + metadata = { + "name": "", + "version": "1.0.0", + "author": "", + "temporal": False, + "description": "", + "params": [], + "dependencies": [], + "requires_python": ">=3.10", + } + + # Parse PEP 723 dependencies + pep723_match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL) + if pep723_match: + block = pep723_match.group(1) + deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL) + if deps_match: + metadata["dependencies"] = re.findall(r'"([^"]+)"', deps_match.group(1)) + python_match = re.search(r'# requires-python = "([^"]+)"', block) + if python_match: + metadata["requires_python"] = python_match.group(1) + + # Parse docstring @-tags + docstring_match = re.search(r'"""(.*?)"""', source, re.DOTALL) + if not docstring_match: + docstring_match = re.search(r"'''(.*?)'''", source, re.DOTALL) + + if docstring_match: + docstring = docstring_match.group(1) + lines = docstring.split("\n") + + current_param = None + desc_lines = [] + in_description = False + + for line in lines: + stripped = line.strip() + + if stripped.startswith("@effect "): + metadata["name"] = stripped[8:].strip() + in_description = False + + elif stripped.startswith("@version "): + metadata["version"] = stripped[9:].strip() + + elif stripped.startswith("@author "): + metadata["author"] = stripped[8:].strip() + + elif stripped.startswith("@temporal "): + val = stripped[10:].strip().lower() + metadata["temporal"] = val in ("true", "yes", "1") + + elif stripped.startswith("@description"): + in_description = True + desc_lines = [] + + elif stripped.startswith("@param "): + if in_description: + metadata["description"] = " ".join(desc_lines) + in_description = False + if current_param: + metadata["params"].append(current_param) + parts = stripped[7:].split() + if len(parts) >= 2: + current_param = { + "name": parts[0], + "type": parts[1], + "description": "", + } + else: + current_param = None + + elif stripped.startswith("@range ") and current_param: + range_parts = stripped[7:].split() + if len(range_parts) >= 2: + try: + current_param["range"] = [float(range_parts[0]), float(range_parts[1])] + except ValueError: + pass + + elif stripped.startswith("@default ") and current_param: + current_param["default"] = stripped[9:].strip() + + elif stripped.startswith("@example"): + if in_description: + metadata["description"] = " ".join(desc_lines) + in_description = False + if current_param: + metadata["params"].append(current_param) + current_param = None + + elif in_description and stripped: + desc_lines.append(stripped) + + elif current_param and stripped and not stripped.startswith("@"): + current_param["description"] = stripped + + if in_description: + metadata["description"] = " ".join(desc_lines) + + if current_param: + metadata["params"].append(current_param) + + return metadata + + +class TestEffectMetadataParsing: + """Tests for parse_effect_metadata function.""" + + def test_parses_pep723_dependencies(self) -> None: + """Should extract dependencies from PEP 723 script block.""" + source = ''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["numpy", "opencv-python"] +# /// +""" +@effect test_effect +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["dependencies"] == ["numpy", "opencv-python"] + assert meta["requires_python"] == ">=3.10" + + def test_parses_effect_name(self) -> None: + """Should extract effect name from @effect tag.""" + source = ''' +""" +@effect brightness +@version 2.0.0 +@author @artist@example.com +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["name"] == "brightness" + assert meta["version"] == "2.0.0" + assert meta["author"] == "@artist@example.com" + + def test_parses_parameters(self) -> None: + """Should extract parameter definitions.""" + source = ''' +""" +@effect brightness +@param level float +@range -1.0 1.0 +@default 0.0 +Brightness adjustment level +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert len(meta["params"]) == 1 + param = meta["params"][0] + assert param["name"] == "level" + assert param["type"] == "float" + assert param["range"] == [-1.0, 1.0] + assert param["default"] == "0.0" + + def test_parses_temporal_flag(self) -> None: + """Should parse temporal flag correctly.""" + source_temporal = ''' +""" +@effect motion_blur +@temporal true +""" +def process_frame(frame, params, state): + return frame, state +''' + source_not_temporal = ''' +""" +@effect brightness +@temporal false +""" +def process_frame(frame, params, state): + return frame, state +''' + + assert parse_effect_metadata_standalone(source_temporal)["temporal"] is True + assert parse_effect_metadata_standalone(source_not_temporal)["temporal"] is False + + def test_handles_missing_metadata(self) -> None: + """Should return sensible defaults for minimal source.""" + source = ''' +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert meta["name"] == "" + assert meta["version"] == "1.0.0" + assert meta["dependencies"] == [] + assert meta["params"] == [] + + def test_parses_description(self) -> None: + """Should extract description text.""" + source = ''' +""" +@effect test +@description +This is a multi-line +description of the effect. +@param x float +""" +def process_frame(frame, params, state): + return frame, state +''' + meta = parse_effect_metadata_standalone(source) + + assert "multi-line" in meta["description"] + + +class TestHomePageEffectsCount: + """Test that effects count is shown on home page.""" + + def test_home_template_has_effects_card(self) -> None: + """Home page should display effects count.""" + path = Path('/home/giles/art/art-celery/app/templates/home.html') + content = path.read_text() + + assert 'stats.effects' in content, \ + "Home page should display stats.effects count" + assert 'href="/effects"' in content, \ + "Home page should link to /effects" + + def test_home_route_provides_effects_count(self) -> None: + """Home route should provide effects count in stats.""" + path = Path('/home/giles/art/art-celery/app/routers/home.py') + content = path.read_text() + + assert 'stats["effects"]' in content, \ + "Home route should populate stats['effects']" + + +class TestNavigationIncludesEffects: + """Test that Effects link is in navigation.""" + + def test_base_template_has_effects_link(self) -> None: + """Base template should have Effects navigation link.""" + base_path = Path('/home/giles/art/art-celery/app/templates/base.html') + content = base_path.read_text() + + assert 'href="/effects"' in content + assert "Effects" in content + assert "active_tab == 'effects'" in content + + def test_effects_link_between_recipes_and_media(self) -> None: + """Effects link should be positioned between Recipes and Media.""" + base_path = Path('/home/giles/art/art-celery/app/templates/base.html') + content = base_path.read_text() + + recipes_pos = content.find('href="/recipes"') + effects_pos = content.find('href="/effects"') + media_pos = content.find('href="/media"') + + assert recipes_pos < effects_pos < media_pos, \ + "Effects link should be between Recipes and Media" + + +class TestEffectsTemplatesExist: + """Tests for effects template files.""" + + def test_list_template_exists(self) -> None: + """List template should exist.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + assert path.exists(), "effects/list.html template should exist" + + def test_detail_template_exists(self) -> None: + """Detail template should exist.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + assert path.exists(), "effects/detail.html template should exist" + + def test_list_template_extends_base(self) -> None: + """List template should extend base.html.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + content = path.read_text() + assert '{% extends "base.html" %}' in content + + def test_detail_template_extends_base(self) -> None: + """Detail template should extend base.html.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert '{% extends "base.html" %}' in content + + def test_list_template_has_upload_button(self) -> None: + """List template should have upload functionality.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/list.html') + content = path.read_text() + assert 'Upload Effect' in content or 'upload' in content.lower() + + def test_detail_template_shows_parameters(self) -> None: + """Detail template should display parameters.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'params' in content.lower() or 'parameter' in content.lower() + + def test_detail_template_shows_source_code(self) -> None: + """Detail template should show source code.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'source' in content.lower() + assert 'language-python' in content + + def test_detail_template_shows_dependencies(self) -> None: + """Detail template should display dependencies.""" + path = Path('/home/giles/art/art-celery/app/templates/effects/detail.html') + content = path.read_text() + assert 'dependencies' in content.lower() + + +class TestEffectsRouterExists: + """Tests for effects router configuration.""" + + def test_effects_router_file_exists(self) -> None: + """Effects router should exist.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + assert path.exists(), "effects.py router should exist" + + def test_effects_router_has_list_endpoint(self) -> None: + """Effects router should have list endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.get("")' in content or "@router.get('')" in content + + def test_effects_router_has_detail_endpoint(self) -> None: + """Effects router should have detail endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.get("/{cid}")' in content + + def test_effects_router_has_upload_endpoint(self) -> None: + """Effects router should have upload endpoint.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert '@router.post("/upload")' in content + + def test_effects_router_renders_templates(self) -> None: + """Effects router should render HTML templates.""" + path = Path('/home/giles/art/art-celery/app/routers/effects.py') + content = path.read_text() + assert 'effects/list.html' in content + assert 'effects/detail.html' in content diff --git a/tests/test_execute_recipe.py b/tests/test_execute_recipe.py new file mode 100644 index 0000000..0a32326 --- /dev/null +++ b/tests/test_execute_recipe.py @@ -0,0 +1,529 @@ +""" +Tests for execute_recipe SOURCE node resolution. + +These tests verify that SOURCE nodes with :input true are correctly +resolved from input_hashes at execution time. +""" + +import pytest +from unittest.mock import MagicMock, patch +from typing import Dict, Any + + +class MockStep: + """Mock ExecutionStep for testing.""" + + def __init__( + self, + step_id: str, + node_type: str, + config: Dict[str, Any], + cache_id: str, + input_steps: list = None, + level: int = 0, + ): + self.step_id = step_id + self.node_type = node_type + self.config = config + self.cache_id = cache_id + self.input_steps = input_steps or [] + self.level = level + self.name = config.get("name") + self.outputs = [] + + +class MockPlan: + """Mock ExecutionPlanSexp for testing.""" + + def __init__(self, steps: list, output_step_id: str, plan_id: str = "test-plan"): + self.steps = steps + self.output_step_id = output_step_id + self.plan_id = plan_id + + def to_string(self, pretty: bool = False) -> str: + return "(plan test)" + + +def resolve_source_cid(step_config: Dict[str, Any], input_hashes: Dict[str, str]) -> str: + """ + Resolve CID for a SOURCE node. + + This is the logic that should be in execute_recipe - extracted here for unit testing. + """ + source_cid = step_config.get("cid") + + # If source has :input true, resolve CID from input_hashes + if not source_cid and step_config.get("input"): + source_name = step_config.get("name", "") + # Try various key formats for lookup + name_variants = [ + source_name, + source_name.lower().replace(" ", "-"), + source_name.lower().replace(" ", "_"), + source_name.lower(), + ] + for variant in name_variants: + if variant in input_hashes: + source_cid = input_hashes[variant] + break + + if not source_cid: + raise ValueError( + f"SOURCE '{source_name}' not found in input_hashes. " + f"Available: {list(input_hashes.keys())}" + ) + + return source_cid + + +class TestSourceCidResolution: + """Tests for SOURCE node CID resolution from input_hashes.""" + + def test_source_with_fixed_cid(self): + """SOURCE with :cid should use that directly.""" + config = {"cid": "QmFixedCid123"} + input_hashes = {} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmFixedCid123" + + def test_source_with_input_true_exact_match(self): + """SOURCE with :input true should resolve from input_hashes by exact name.""" + config = {"input": True, "name": "my-video"} + input_hashes = {"my-video": "QmInputVideo456"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmInputVideo456" + + def test_source_with_input_true_normalized_dash(self): + """SOURCE with :input true should resolve from normalized dash format.""" + config = {"input": True, "name": "Second Video"} + input_hashes = {"second-video": "QmSecondVideo789"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmSecondVideo789" + + def test_source_with_input_true_normalized_underscore(self): + """SOURCE with :input true should resolve from normalized underscore format.""" + config = {"input": True, "name": "Second Video"} + input_hashes = {"second_video": "QmSecondVideoUnderscore"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmSecondVideoUnderscore" + + def test_source_with_input_true_lowercase(self): + """SOURCE with :input true should resolve from lowercase format.""" + config = {"input": True, "name": "MyVideo"} + input_hashes = {"myvideo": "QmLowercaseVideo"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmLowercaseVideo" + + def test_source_with_input_true_missing_raises(self): + """SOURCE with :input true should raise if not in input_hashes.""" + config = {"input": True, "name": "Missing Video"} + input_hashes = {"other-video": "QmOther123"} + + with pytest.raises(ValueError) as excinfo: + resolve_source_cid(config, input_hashes) + + assert "Missing Video" in str(excinfo.value) + assert "not found in input_hashes" in str(excinfo.value) + assert "other-video" in str(excinfo.value) # Shows available keys + + def test_source_without_cid_or_input_returns_none(self): + """SOURCE without :cid or :input should return None.""" + config = {"name": "some-source"} + input_hashes = {} + + cid = resolve_source_cid(config, input_hashes) + assert cid is None + + def test_source_priority_cid_over_input(self): + """SOURCE with both :cid and :input true should use :cid.""" + config = {"cid": "QmDirectCid", "input": True, "name": "my-video"} + input_hashes = {"my-video": "QmInputHashCid"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmDirectCid" # :cid takes priority + + +class TestSourceNameVariants: + """Tests for the various input name normalization formats.""" + + @pytest.mark.parametrize("source_name,input_key", [ + ("video", "video"), + ("My Video", "my-video"), + ("My Video", "my_video"), + ("My Video", "My Video"), + ("CamelCase", "camelcase"), + ("multiple spaces", "multiple spaces"), # Only single replace + ]) + def test_name_variant_matching(self, source_name: str, input_key: str): + """Various name formats should match.""" + config = {"input": True, "name": source_name} + input_hashes = {input_key: "QmTestCid"} + + cid = resolve_source_cid(config, input_hashes) + assert cid == "QmTestCid" + + +class TestExecuteRecipeIntegration: + """Integration tests for execute_recipe with SOURCE nodes.""" + + def test_recipe_with_user_input_source(self): + """ + Recipe execution should resolve SOURCE nodes with :input true. + + This is the bug that was causing "No executor for node type: SOURCE". + """ + # Create mock plan with a SOURCE that has :input true + source_step = MockStep( + step_id="source_1", + node_type="SOURCE", + config={"input": True, "name": "Second Video", "description": "User input"}, + cache_id="abc123", + level=0, + ) + + effect_step = MockStep( + step_id="effect_1", + node_type="EFFECT", + config={"effect": "invert"}, + cache_id="def456", + input_steps=["source_1"], + level=1, + ) + + plan = MockPlan( + steps=[source_step, effect_step], + output_step_id="effect_1", + ) + + # Input hashes provided by user + input_hashes = { + "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + } + + # Verify source CID resolution works + resolved_cid = resolve_source_cid(source_step.config, input_hashes) + assert resolved_cid == "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS" + + def test_recipe_with_fixed_and_user_sources(self): + """ + Recipe with both fixed (asset) and user-input sources. + + This is the dog-invert-concat recipe pattern: + - Fixed source: cat asset with known CID + - User input: Second Video from input_hashes + """ + fixed_source = MockStep( + step_id="source_cat", + node_type="SOURCE", + config={"cid": "QmCatVideo123", "asset": "cat"}, + cache_id="cat_cache", + level=0, + ) + + user_source = MockStep( + step_id="source_user", + node_type="SOURCE", + config={"input": True, "name": "Second Video"}, + cache_id="user_cache", + level=0, + ) + + input_hashes = { + "second-video": "QmUserProvidedVideo456" + } + + # Fixed source uses its cid + fixed_cid = resolve_source_cid(fixed_source.config, input_hashes) + assert fixed_cid == "QmCatVideo123" + + # User source resolves from input_hashes + user_cid = resolve_source_cid(user_source.config, input_hashes) + assert user_cid == "QmUserProvidedVideo456" + + +class TestCompoundNodeHandling: + """ + Tests for COMPOUND node handling. + + COMPOUND nodes are collapsed effect chains that should be executed + sequentially through their respective effect executors. + + Bug fixed: "No executor for node type: COMPOUND" + """ + + def test_compound_node_has_filter_chain(self): + """COMPOUND nodes must have a filter_chain config.""" + step = MockStep( + step_id="compound_1", + node_type="COMPOUND", + config={ + "filter_chain": [ + {"type": "EFFECT", "config": {"effect": "identity"}}, + {"type": "EFFECT", "config": {"effect": "dog"}}, + ], + "inputs": ["source_1"], + }, + cache_id="compound_cache", + level=1, + ) + + assert step.node_type == "COMPOUND" + assert "filter_chain" in step.config + assert len(step.config["filter_chain"]) == 2 + + def test_compound_filter_chain_has_effects(self): + """COMPOUND filter_chain should contain EFFECT items with effect names.""" + filter_chain = [ + {"type": "EFFECT", "config": {"effect": "identity", "cid": "Qm123"}}, + {"type": "EFFECT", "config": {"effect": "dog", "cid": "Qm456"}}, + ] + + for item in filter_chain: + assert item["type"] == "EFFECT" + assert "effect" in item["config"] + assert "cid" in item["config"] + + def test_compound_requires_inputs(self): + """COMPOUND nodes must have input steps.""" + step = MockStep( + step_id="compound_1", + node_type="COMPOUND", + config={"filter_chain": [], "inputs": []}, + cache_id="compound_cache", + input_steps=[], + level=1, + ) + + # Empty inputs should be detected as error + assert len(step.input_steps) == 0 + # The execute_recipe should raise ValueError for empty inputs + + +class TestCacheIdLookup: + """ + Tests for cache lookup by code-addressed cache_id. + + Bug fixed: Cache lookups by cache_id (code hash) were failing because + only IPFS CID was indexed. Now we also index by node_id when different. + """ + + def test_cache_id_is_code_addressed(self): + """cache_id should be a SHA3-256 hash (64 hex chars), not IPFS CID.""" + # Code-addressed hash example + cache_id = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4" + + assert len(cache_id) == 64 + assert all(c in "0123456789abcdef" for c in cache_id) + assert not cache_id.startswith("Qm") # Not IPFS CID + + def test_ipfs_cid_format(self): + """IPFS CIDs start with 'Qm' (v0) or 'bafy' (v1).""" + ipfs_cid_v0 = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + ipfs_cid_v1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi" + + assert ipfs_cid_v0.startswith("Qm") + assert ipfs_cid_v1.startswith("bafy") + + def test_cache_id_differs_from_ipfs_cid(self): + """ + Code-addressed cache_id is computed BEFORE execution. + IPFS CID is computed AFTER execution from file content. + They will differ for the same logical step. + """ + # Same step can have: + cache_id = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4" + ipfs_cid = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ" + + assert cache_id != ipfs_cid + # Both should be indexable for cache lookups + + +class TestPlanStepTypes: + """ + Tests verifying all node types in S-expression plans are handled. + + These tests document the node types that execute_recipe must handle. + """ + + def test_source_node_types(self): + """SOURCE nodes: fixed asset or user input.""" + # Fixed asset source + fixed = MockStep("s1", "SOURCE", {"cid": "Qm123"}, "cache1") + assert fixed.node_type == "SOURCE" + assert "cid" in fixed.config + + # User input source + user = MockStep("s2", "SOURCE", {"input": True, "name": "video"}, "cache2") + assert user.node_type == "SOURCE" + assert user.config.get("input") is True + + def test_effect_node_type(self): + """EFFECT nodes: single effect application.""" + step = MockStep( + "e1", "EFFECT", + {"effect": "invert", "cid": "QmEffect123", "intensity": 1.0}, + "cache3" + ) + assert step.node_type == "EFFECT" + assert "effect" in step.config + + def test_compound_node_type(self): + """COMPOUND nodes: collapsed effect chains.""" + step = MockStep( + "c1", "COMPOUND", + {"filter_chain": [{"type": "EFFECT", "config": {}}]}, + "cache4" + ) + assert step.node_type == "COMPOUND" + assert "filter_chain" in step.config + + def test_sequence_node_type(self): + """SEQUENCE nodes: concatenate multiple clips.""" + step = MockStep( + "seq1", "SEQUENCE", + {"transition": {"type": "cut"}}, + "cache5", + input_steps=["clip1", "clip2"] + ) + assert step.node_type == "SEQUENCE" + + +class TestNodeTypeCaseSensitivity: + """ + Tests for node type case handling. + + Bug fixed: S-expression plans use lowercase (source, compound, effect) + but code was checking uppercase (SOURCE, COMPOUND, EFFECT). + """ + + def test_source_lowercase_from_sexp(self): + """S-expression plans produce lowercase node types.""" + # From plan: (source :cid "Qm...") + step = MockStep("s1", "source", {"cid": "Qm123"}, "cache1") + + # Code should handle lowercase + assert step.node_type.upper() == "SOURCE" + + def test_compound_lowercase_from_sexp(self): + """COMPOUND from S-expression is lowercase.""" + # From plan: (compound :filter_chain ...) + step = MockStep("c1", "compound", {"filter_chain": []}, "cache2") + + assert step.node_type.upper() == "COMPOUND" + + def test_effect_lowercase_from_sexp(self): + """EFFECT from S-expression is lowercase.""" + # From plan: (effect :effect "invert" ...) + step = MockStep("e1", "effect", {"effect": "invert"}, "cache3") + + assert step.node_type.upper() == "EFFECT" + + def test_sequence_lowercase_from_sexp(self): + """SEQUENCE from S-expression is lowercase.""" + # From plan: (sequence :transition ...) + step = MockStep("seq1", "sequence", {"transition": {}}, "cache4") + + assert step.node_type.upper() == "SEQUENCE" + + def test_node_type_comparison_must_be_case_insensitive(self): + """ + Node type comparisons must be case-insensitive. + + This is the actual bug - checking step.node_type == "SOURCE" + fails when step.node_type is "source" from S-expression. + """ + sexp_types = ["source", "compound", "effect", "sequence"] + code_types = ["SOURCE", "COMPOUND", "EFFECT", "SEQUENCE"] + + for sexp, code in zip(sexp_types, code_types): + # Wrong: direct comparison fails + assert sexp != code, f"{sexp} should not equal {code}" + + # Right: case-insensitive comparison works + assert sexp.upper() == code, f"{sexp}.upper() should equal {code}" + + +class TestExecuteRecipeErrorHandling: + """Tests for error handling in execute_recipe.""" + + def test_missing_input_hash_error_message(self): + """Error should list available input keys when source not found.""" + config = {"input": True, "name": "Unknown Video"} + input_hashes = {"video-a": "Qm1", "video-b": "Qm2"} + + with pytest.raises(ValueError) as excinfo: + resolve_source_cid(config, input_hashes) + + error_msg = str(excinfo.value) + assert "Unknown Video" in error_msg + assert "video-a" in error_msg or "video-b" in error_msg + + def test_source_no_cid_no_input_error(self): + """SOURCE without cid or input flag should return None (invalid).""" + config = {"name": "broken-source"} # Missing both cid and input + input_hashes = {} + + result = resolve_source_cid(config, input_hashes) + assert result is None # execute_recipe should catch this + + +class TestRecipeOutputRequired: + """ + Tests verifying recipes must produce output to succeed. + + Bug: Recipe was returning success=True with output_cid=None + """ + + def test_recipe_without_output_should_fail(self): + """ + Recipe execution must fail if no output is produced. + + This catches the bug where execute_recipe returned success=True + but output_cid was None. + """ + # Simulate the check that should happen in execute_recipe + output_cid = None + step_results = {"step1": {"status": "executed", "path": "/tmp/x"}} + + # This is the logic that should be in execute_recipe + success = output_cid is not None + + assert success is False, "Recipe with no output_cid should fail" + + def test_recipe_with_output_should_succeed(self): + """Recipe with valid output should succeed.""" + output_cid = "QmOutputCid123" + + success = output_cid is not None + assert success is True + + def test_output_step_result_must_have_cid(self): + """Output step result must contain cid or cache_id.""" + # Step result without cid + bad_result = {"status": "executed", "path": "/tmp/output.mkv"} + output_cid = bad_result.get("cid") or bad_result.get("cache_id") + assert output_cid is None, "Should detect missing cid" + + # Step result with cid + good_result = {"status": "executed", "path": "/tmp/output.mkv", "cid": "Qm123"} + output_cid = good_result.get("cid") or good_result.get("cache_id") + assert output_cid == "Qm123" + + def test_output_step_must_exist_in_results(self): + """Output step must be present in step_results.""" + step_results = { + "source_1": {"status": "source", "cid": "QmSrc"}, + "effect_1": {"status": "executed", "cid": "QmEffect"}, + # Note: output_step "sequence_1" is missing! + } + + output_step_id = "sequence_1" + output_result = step_results.get(output_step_id, {}) + output_cid = output_result.get("cid") + + assert output_cid is None, "Missing output step should result in no cid" diff --git a/tests/test_frame_compatibility.py b/tests/test_frame_compatibility.py new file mode 100644 index 0000000..f12cce0 --- /dev/null +++ b/tests/test_frame_compatibility.py @@ -0,0 +1,185 @@ +""" +Integration tests for GPU/CPU frame compatibility. + +These tests verify that all primitives work correctly with both: +- numpy arrays (CPU frames) +- CuPy arrays (GPU frames) +- GPUFrame wrapper objects + +Run with: pytest tests/test_frame_compatibility.py -v +""" + +import pytest +import numpy as np +import sys +import os + +# Add parent to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Try to import CuPy +try: + import cupy as cp + HAS_GPU = True +except ImportError: + cp = None + HAS_GPU = False + + +def create_test_frame(on_gpu=False): + """Create a test frame (100x100 RGB).""" + frame = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + if on_gpu and HAS_GPU: + return cp.asarray(frame) + return frame + + +class MockGPUFrame: + """Mock GPUFrame for testing without full GPU stack.""" + def __init__(self, data): + self._data = data + + @property + def cpu(self): + if HAS_GPU and hasattr(self._data, 'get'): + return self._data.get() + return self._data + + @property + def gpu(self): + if HAS_GPU: + if hasattr(self._data, 'get'): + return self._data + return cp.asarray(self._data) + raise RuntimeError("No GPU available") + + +class TestColorOps: + """Test color_ops primitives with different frame types.""" + + def test_shift_hsv_numpy(self): + """shift-hsv should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = create_test_frame(on_gpu=False) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) + assert result.shape == frame.shape + + @pytest.mark.skipif(not HAS_GPU, reason="No GPU") + def test_shift_hsv_cupy(self): + """shift-hsv should work with CuPy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = create_test_frame(on_gpu=True) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) # Should return numpy + + def test_shift_hsv_gpuframe(self): + """shift-hsv should work with GPUFrame objects.""" + from sexp_effects.primitive_libs.color_ops import prim_shift_hsv + frame = MockGPUFrame(create_test_frame(on_gpu=False)) + result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9) + assert isinstance(result, np.ndarray) + + def test_invert_numpy(self): + """invert-img should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_invert_img + frame = create_test_frame(on_gpu=False) + result = prim_invert_img(frame) + assert isinstance(result, np.ndarray) + + def test_adjust_numpy(self): + """adjust should work with numpy arrays.""" + from sexp_effects.primitive_libs.color_ops import prim_adjust + frame = create_test_frame(on_gpu=False) + result = prim_adjust(frame, brightness=10, contrast=1.2) + assert isinstance(result, np.ndarray) + + +class TestGeometry: + """Test geometry primitives with different frame types.""" + + def test_rotate_numpy(self): + """rotate should work with numpy arrays.""" + from sexp_effects.primitive_libs.geometry import prim_rotate + frame = create_test_frame(on_gpu=False) + result = prim_rotate(frame, 45) + assert isinstance(result, np.ndarray) + + def test_scale_numpy(self): + """scale should work with numpy arrays.""" + from sexp_effects.primitive_libs.geometry import prim_scale + frame = create_test_frame(on_gpu=False) + result = prim_scale(frame, 1.5) + assert isinstance(result, np.ndarray) + + +class TestBlending: + """Test blending primitives with different frame types.""" + + def test_blend_numpy(self): + """blend should work with numpy arrays.""" + from sexp_effects.primitive_libs.blending import prim_blend + frame_a = create_test_frame(on_gpu=False) + frame_b = create_test_frame(on_gpu=False) + result = prim_blend(frame_a, frame_b, 0.5) + assert isinstance(result, np.ndarray) + + +class TestInterpreterConversion: + """Test the interpreter's frame conversion.""" + + def test_maybe_to_numpy_none(self): + """_maybe_to_numpy should handle None.""" + from streaming.stream_sexp_generic import StreamInterpreter + # Create minimal interpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + assert interp._maybe_to_numpy(None) is None + + def test_maybe_to_numpy_numpy(self): + """_maybe_to_numpy should pass through numpy arrays.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = create_test_frame(on_gpu=False) + result = interp._maybe_to_numpy(frame) + assert result is frame # Should be same object + + @pytest.mark.skipif(not HAS_GPU, reason="No GPU") + def test_maybe_to_numpy_cupy(self): + """_maybe_to_numpy should convert CuPy to numpy.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = create_test_frame(on_gpu=True) + result = interp._maybe_to_numpy(frame) + assert isinstance(result, np.ndarray) + + def test_maybe_to_numpy_gpuframe(self): + """_maybe_to_numpy should convert GPUFrame to numpy.""" + from streaming.stream_sexp_generic import StreamInterpreter + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))') + f.flush() + interp = StreamInterpreter(f.name) + + frame = MockGPUFrame(create_test_frame(on_gpu=False)) + result = interp._maybe_to_numpy(frame) + assert isinstance(result, np.ndarray) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_item_visibility.py b/tests/test_item_visibility.py new file mode 100644 index 0000000..be6c50b --- /dev/null +++ b/tests/test_item_visibility.py @@ -0,0 +1,272 @@ +""" +Tests for item visibility in L1 web UI. + +Bug found 2026-01-12: L1 run succeeded but web UI not showing: +- Runs +- Recipes +- Created media + +Root causes identified: +1. Recipes: owner field filtering but owner never set in loaded recipes +2. Media: item_types table entries not created on upload/import +3. Run outputs: outputs not registered in item_types table +""" + +import pytest +from pathlib import Path +from unittest.mock import MagicMock, AsyncMock, patch +import tempfile + + +class TestRecipeVisibility: + """Tests for recipe listing visibility.""" + + def test_recipe_filter_allows_none_owner(self) -> None: + """ + Regression test: The recipe filter should allow recipes where owner is None. + + Bug: recipe_service.list_recipes() filtered by owner == actor_id, + but owner field is None in recipes loaded from S-expression files. + This caused ALL recipes to be filtered out. + + Fix: The filter is now: actor_id is None or owner is None or owner == actor_id + """ + # Simulate the filter logic from recipe_service.list_recipes + # OLD (broken): if actor_id is None or owner == actor_id + # NEW (fixed): if actor_id is None or owner is None or owner == actor_id + + test_cases = [ + # (actor_id, owner, expected_visible, description) + (None, None, True, "No filter, no owner -> visible"), + (None, "@someone@example.com", True, "No filter, has owner -> visible"), + ("@testuser@example.com", None, True, "Has filter, no owner -> visible (shared)"), + ("@testuser@example.com", "@testuser@example.com", True, "Filter matches owner -> visible"), + ("@testuser@example.com", "@other@example.com", False, "Filter doesn't match -> hidden"), + ] + + for actor_id, owner, expected_visible, description in test_cases: + # This is the FIXED filter logic from recipe_service.py line 86 + is_visible = actor_id is None or owner is None or owner == actor_id + + assert is_visible == expected_visible, f"Failed: {description}" + + def test_recipe_filter_old_logic_was_broken(self) -> None: + """Document that the old filter logic excluded all recipes with owner=None.""" + # OLD filter: actor_id is None or owner == actor_id + # This broke when owner=None and actor_id was provided + + actor_id = "@testuser@example.com" + owner = None # This is what compiled sexp produces + + # OLD logic (broken): + old_logic_visible = actor_id is None or owner == actor_id + assert old_logic_visible is False, "Old logic incorrectly hid owner=None recipes" + + # NEW logic (fixed): + new_logic_visible = actor_id is None or owner is None or owner == actor_id + assert new_logic_visible is True, "New logic should show owner=None recipes" + + +class TestMediaVisibility: + """Tests for media visibility after upload.""" + + @pytest.mark.asyncio + async def test_upload_content_creates_item_type_record(self) -> None: + """ + Test: Uploaded media must be registered in item_types table via save_item_metadata. + + The save_item_metadata function creates entries in item_types table, + enabling the media to appear in list_media queries. + """ + import importlib.util + spec = importlib.util.spec_from_file_location( + "cache_service", + "/home/giles/art/art-celery/app/services/cache_service.py" + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + CacheService = module.CacheService + + # Create mocks + mock_db = AsyncMock() + mock_db.create_cache_item = AsyncMock() + mock_db.save_item_metadata = AsyncMock() + + mock_cache = MagicMock() + cached_result = MagicMock() + cached_result.cid = "QmUploadedContent123" + mock_cache.put.return_value = (cached_result, "QmIPFSCid123") + + service = CacheService(database=mock_db, cache_manager=mock_cache) + + # Upload content + cid, ipfs_cid, error = await service.upload_content( + content=b"test video content", + filename="test.mp4", + actor_id="@testuser@example.com", + ) + + assert error is None, f"Upload failed: {error}" + assert cid is not None + + # Verify save_item_metadata was called (which creates item_types entry) + mock_db.save_item_metadata.assert_called_once() + + # Verify it was called with correct actor_id and a media type (not mime type) + call_kwargs = mock_db.save_item_metadata.call_args[1] + assert call_kwargs.get('actor_id') == "@testuser@example.com", \ + "save_item_metadata must be called with the uploading user's actor_id" + # item_type should be media category like "video", "image", "audio", "unknown" + # NOT mime type like "video/mp4" + item_type = call_kwargs.get('item_type') + assert item_type in ("video", "image", "audio", "unknown"), \ + f"item_type should be media category, got '{item_type}'" + + @pytest.mark.asyncio + async def test_import_from_ipfs_creates_item_type_record(self) -> None: + """ + Test: Imported media must be registered in item_types table via save_item_metadata. + + The save_item_metadata function creates entries in item_types table with + detected media type, enabling the media to appear in list_media queries. + """ + import importlib.util + spec = importlib.util.spec_from_file_location( + "cache_service", + "/home/giles/art/art-celery/app/services/cache_service.py" + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + CacheService = module.CacheService + + mock_db = AsyncMock() + mock_db.create_cache_item = AsyncMock() + mock_db.save_item_metadata = AsyncMock() + + mock_cache = MagicMock() + cached_result = MagicMock() + cached_result.cid = "QmImportedContent123" + mock_cache.put.return_value = (cached_result, "QmIPFSCid456") + + service = CacheService(database=mock_db, cache_manager=mock_cache) + service.cache_dir = Path(tempfile.gettempdir()) + + # We need to mock the ipfs_client module at the right location + import importlib + ipfs_module = MagicMock() + ipfs_module.get_file = MagicMock(return_value=True) + + # Patch at module level + with patch.dict('sys.modules', {'ipfs_client': ipfs_module}): + # Import from IPFS + cid, error = await service.import_from_ipfs( + ipfs_cid="QmSourceIPFSCid", + actor_id="@testuser@example.com", + ) + + # Verify save_item_metadata was called (which creates item_types entry) + mock_db.save_item_metadata.assert_called_once() + + # Verify it was called with detected media type (not hardcoded "media") + call_kwargs = mock_db.save_item_metadata.call_args[1] + item_type = call_kwargs.get('item_type') + assert item_type in ("video", "image", "audio", "unknown"), \ + f"item_type should be detected media category, got '{item_type}'" + + +class TestRunOutputVisibility: + """Tests for run output visibility.""" + + @pytest.mark.asyncio + async def test_completed_run_output_visible_in_media_list(self) -> None: + """ + Run outputs should be accessible in media listings. + + When a run completes, its output should be registered in item_types + so it appears in the user's media gallery. + """ + # This test documents the expected behavior + # Run outputs are stored in run_cache but should also be in item_types + # for the media gallery to show them + + # The fix should either: + # 1. Add item_types entry when run completes, OR + # 2. Modify list_media to also check run_cache outputs + pass # Placeholder for implementation test + + +class TestDatabaseItemTypes: + """Tests for item_types database operations.""" + + @pytest.mark.asyncio + async def test_add_item_type_function_exists(self) -> None: + """Verify add_item_type function exists and has correct signature.""" + import database + + assert hasattr(database, 'add_item_type'), \ + "database.add_item_type function should exist" + + # Check it's an async function + import inspect + assert inspect.iscoroutinefunction(database.add_item_type), \ + "add_item_type should be an async function" + + @pytest.mark.asyncio + async def test_get_user_items_returns_items_from_item_types(self) -> None: + """ + Verify get_user_items queries item_types table. + + If item_types has no entries for a user, they see no media. + """ + # This is a documentation test showing the data flow: + # 1. User uploads content -> should create item_types entry + # 2. list_media -> calls get_user_items -> queries item_types + # 3. If step 1 didn't create item_types entry, step 2 returns empty + pass + + +class TestTemplateRendering: + """Tests for template variable passing.""" + + def test_cache_not_found_template_receives_content_hash(self) -> None: + """ + Regression test: cache/not_found.html template requires content_hash. + + Bug: The template uses {{ content_hash[:24] }} but the route + doesn't pass content_hash to the render context. + + Error: jinja2.exceptions.UndefinedError: 'content_hash' is undefined + """ + # This test documents the bug - the template expects content_hash + # but the route at /app/app/routers/cache.py line 57 doesn't provide it + pass # Will verify fix by checking route code + + +class TestOwnerFieldInRecipes: + """Tests for owner field handling in recipes.""" + + def test_sexp_recipe_has_none_owner_by_default(self) -> None: + """ + S-expression recipes have owner=None by default. + + The compiled recipe includes owner field but it's None, + so the list_recipes filter must allow owner=None to show + shared/public recipes. + """ + sample_sexp = """ + (recipe "test" + (-> (source :input true :name "video") + (fx identity))) + """ + + from artdag.sexp import compile_string + + compiled = compile_string(sample_sexp) + recipe_dict = compiled.to_dict() + + # The compiled recipe has owner field but it's None + assert recipe_dict.get("owner") is None, \ + "Compiled S-expression should have owner=None" + + # This means the filter must allow owner=None for recipes to be visible + # The fix: if actor_id is None or owner is None or owner == actor_id diff --git a/tests/test_jax_pipeline_integration.py b/tests/test_jax_pipeline_integration.py new file mode 100644 index 0000000..8f9fb93 --- /dev/null +++ b/tests/test_jax_pipeline_integration.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python3 +"""Integration tests comparing JAX and Python rendering pipelines. + +These tests ensure the JAX-compiled effect chains produce identical output +to the Python/NumPy path. They test: +1. Full effect pipelines through both interpreters +2. Multi-frame sequences (to catch phase-dependent bugs) +3. Compiled effect chain fusion +4. Edge cases like shrinking/zooming that affect boundary handling +""" +import os +import sys +import pytest +import numpy as np +import shutil +from pathlib import Path + +# Ensure the art-celery module is importable +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sexp_effects.primitive_libs import core as core_mod + + +# Path to test resources +TEST_DIR = Path('/home/giles/art/test') +EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects' +TEMPLATES_DIR = TEST_DIR / 'templates' + + +def create_test_image(h=96, w=128): + """Create a test image with distinct patterns.""" + import cv2 + img = np.zeros((h, w, 3), dtype=np.uint8) + + # Create gradient background + for y in range(h): + for x in range(w): + img[y, x] = [ + int(255 * x / w), # R: horizontal gradient + int(255 * y / h), # G: vertical gradient + 128 # B: constant + ] + + # Add features + cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1) + cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1) + + return img + + +@pytest.fixture(scope='module') +def test_env(tmp_path_factory): + """Set up test environment with sexp files and test media.""" + test_dir = tmp_path_factory.mktemp('sexp_test') + original_dir = os.getcwd() + os.chdir(test_dir) + + # Create directories + (test_dir / 'effects').mkdir() + (test_dir / 'sexp_effects' / 'effects').mkdir(parents=True) + + # Create test image + import cv2 + test_img = create_test_image() + cv2.imwrite(str(test_dir / 'test_image.png'), test_img) + + # Copy required effect files + for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']: + src = EFFECTS_DIR / f'{effect}.sexp' + dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp' + if src.exists(): + shutil.copy(src, dst) + + yield { + 'dir': test_dir, + 'image_path': test_dir / 'test_image.png', + 'test_img': test_img, + } + + os.chdir(original_dir) + + +def create_sexp_file(test_dir, content, filename='test.sexp'): + """Create a test sexp file.""" + path = test_dir / 'effects' / filename + with open(path, 'w') as f: + f.write(content) + return str(path) + + +class TestJaxPythonPipelineEquivalence: + """Test that JAX and Python pipelines produce equivalent output.""" + + def test_single_rotate_effect(self, test_env): + """Test that a single rotate effect matches between Python and JAX.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + + (frame (rotate frame :angle 15 :speed 0)) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + # Python path + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + # JAX path + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + # Inject test image into globals + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = py_interp._eval(py_interp.frame_pipeline, frame_env) + jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env) + + # Force deferred if needed + py_result = np.asarray(py_interp._maybe_force(py_result)) + jax_result = np.asarray(jax_interp._maybe_force(jax_result)) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold" + + def test_rotate_then_zoom(self, test_env): + """Test rotate followed by zoom - tests effect chain fusion.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (-> (rotate frame :angle 15 :speed 0) + (zoom :amount 0.95 :speed 0))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold" + + def test_zoom_shrink_boundary_handling(self, test_env): + """Test zoom with shrinking factor - critical for boundary handling.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (zoom frame :amount 0.8 :speed 0)) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + # Check corners specifically - these are most affected by boundary handling + h, w = test_img.shape[:2] + corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)] + for y, x in corners: + py_val = py_result[y, x] + jax_val = jax_result[y, x] + corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max() + assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}" + + def test_blend_two_clips(self, test_env): + """Test blending two effect chains - the core bug scenario.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + + # Check edge region specifically + edge_diff = diff[0, :].mean() + + assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold" + assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold" + + def test_blend_with_invert(self, test_env): + """Test blending with invert - matches the problematic recipe pattern.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (require-primitives "color_ops") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + (effect invert :path "../sexp_effects/effects/invert.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0) + (invert :amount 1)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold" + + +class TestDeferredEffectChainFusion: + """Test the DeferredEffectChain fusion mechanism specifically.""" + + def test_manual_vs_fused_chain(self, test_env): + """Compare manual application vs fused DeferredEffectChain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + # Create minimal sexp to load effects + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # Manual step-by-step application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + rot_angle = -5.0 + zoom_amount = 0.95 + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain application + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + diff = np.abs(manual_result.astype(float) - fused_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}" + + # Check specific pixels + h, w = test_img.shape[:2] + for y in [0, h//2, h-1]: + for x in [0, w//2, w-1]: + manual_val = manual_result[y, x] + fused_val = fused_result[y, x] + pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max() + assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}" + + def test_chain_with_shrink_zoom_boundary(self, test_env): + """Test that shrinking zoom handles boundaries correctly in chain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # Parameters that shrink the image (zoom < 1.0) + rot_angle = -4.555 + zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries + + # Manual application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + # Check top edge specifically - this is where boundary issues manifest + top_edge_manual = manual_result[0, :] + top_edge_fused = fused_result[0, :] + + edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float)) + mean_edge_diff = np.mean(edge_diff) + + assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}" + + # Check for zeros at edges that shouldn't be there + manual_edge_sum = np.sum(top_edge_manual) + fused_edge_sum = np.sum(top_edge_fused) + + if manual_edge_sum > 100: # If manual has significant values + assert fused_edge_sum > manual_edge_sum * 0.5, \ + f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}" + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_jax_primitives.py b/tests/test_jax_primitives.py new file mode 100644 index 0000000..5fad678 --- /dev/null +++ b/tests/test_jax_primitives.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +""" +Test framework to verify JAX primitives match Python primitives. + +Compares output of each primitive through: +1. Python/NumPy path +2. JAX path (CPU) +3. JAX path (GPU) - if available + +Reports any mismatches with detailed diffs. +""" +import sys +sys.path.insert(0, '/home/giles/art/art-celery') + +import numpy as np +import json +from pathlib import Path +from typing import Dict, List, Tuple, Any, Optional +from dataclasses import dataclass, field + +# Test configuration +TEST_WIDTH = 64 +TEST_HEIGHT = 48 +TOLERANCE_MEAN = 1.0 # Max allowed mean difference +TOLERANCE_MAX = 10.0 # Max allowed single pixel difference +TOLERANCE_PCT = 0.95 # Min % of pixels within ±1 + + +@dataclass +class TestResult: + primitive: str + passed: bool + python_mean: float = 0.0 + jax_mean: float = 0.0 + diff_mean: float = 0.0 + diff_max: float = 0.0 + pct_within_1: float = 0.0 + error: str = "" + + +def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'): + """Create a test frame with known pattern.""" + if pattern == 'gradient': + # Diagonal gradient + y, x = np.mgrid[0:h, 0:w] + r = (x * 255 / w).astype(np.uint8) + g = (y * 255 / h).astype(np.uint8) + b = ((x + y) * 127 / (w + h)).astype(np.uint8) + return np.stack([r, g, b], axis=2) + elif pattern == 'checkerboard': + y, x = np.mgrid[0:h, 0:w] + check = ((x // 8) + (y // 8)) % 2 + v = (check * 255).astype(np.uint8) + return np.stack([v, v, v], axis=2) + elif pattern == 'solid': + return np.full((h, w, 3), 128, dtype=np.uint8) + else: + return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + + +def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]: + """Compare two outputs, return (mean_diff, max_diff, pct_within_1).""" + if py_out is None or jax_out is None: + return float('inf'), float('inf'), 0.0 + + if isinstance(py_out, dict) and isinstance(jax_out, dict): + # Compare coordinate maps + diffs = [] + for k in py_out: + if k in jax_out: + py_arr = np.asarray(py_out[k]) + jax_arr = np.asarray(jax_out[k]) + if py_arr.shape == jax_arr.shape: + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + diffs.append(diff) + if diffs: + all_diff = np.concatenate([d.flatten() for d in diffs]) + return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1)) + return float('inf'), float('inf'), 0.0 + + py_arr = np.asarray(py_out) + jax_arr = np.asarray(jax_out) + + if py_arr.shape != jax_arr.shape: + return float('inf'), float('inf'), 0.0 + + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1)) + + +# ============================================================================ +# Primitive Test Definitions +# ============================================================================ + +PRIMITIVE_TESTS = { + # Geometry primitives + 'geometry:ripple-displace': { + 'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5], + 'returns': 'coords', + }, + 'geometry:rotate-img': { + 'args': ['frame', 45], + 'returns': 'frame', + }, + 'geometry:scale-img': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'geometry:flip-h': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'geometry:flip-v': { + 'args': ['frame'], + 'returns': 'frame', + }, + + # Color operations + 'color_ops:invert': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:grayscale': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:brightness': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:contrast': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:hue-shift': { + 'args': ['frame', 90], + 'returns': 'frame', + }, + + # Image operations + 'image:width': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:height': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:channel': { + 'args': ['frame', 0], + 'returns': 'array', + }, + + # Blending + 'blending:blend': { + 'args': ['frame', 'frame2', 0.5], + 'returns': 'frame', + }, + 'blending:blend-add': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, + 'blending:blend-multiply': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, +} + + +def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the Python interpreter.""" + if prim_name not in interp.primitives: + return None + + func = interp.primitives[prim_name] + args = [] + for a in test_def['args']: + if a == 'frame': + args.append(frame.copy()) + elif a == 'frame2': + args.append(frame2.copy()) + else: + args.append(a) + + try: + return func(*args) + except Exception as e: + return None + + +def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the JAX compiler.""" + try: + from streaming.sexp_to_jax import JaxCompiler + import jax.numpy as jnp + + compiler = JaxCompiler() + + # Build a simple expression to test the primitive + from sexp_effects.parser import Symbol, Keyword + + args = [] + env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)} + + for a in test_def['args']: + if a == 'frame': + args.append(Symbol('frame')) + elif a == 'frame2': + args.append(Symbol('frame2')) + else: + args.append(a) + + # Create expression: (prim_name arg1 arg2 ...) + expr = [Symbol(prim_name)] + args + + result = compiler._eval(expr, env) + + if hasattr(result, '__array__'): + return np.asarray(result) + return result + + except Exception as e: + return None + + +def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult: + """Test a single primitive.""" + frame = create_test_frame(pattern='gradient') + frame2 = create_test_frame(pattern='checkerboard') + + result = TestResult(primitive=prim_name, passed=False) + + # Run Python version + try: + py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2) + if py_out is not None and hasattr(py_out, 'shape'): + result.python_mean = float(np.mean(py_out)) + except Exception as e: + result.error = f"Python error: {e}" + return result + + # Run JAX version + try: + jax_out = run_jax_primitive(prim_name, test_def, frame, frame2) + if jax_out is not None and hasattr(jax_out, 'shape'): + result.jax_mean = float(np.mean(jax_out)) + except Exception as e: + result.error = f"JAX error: {e}" + return result + + if py_out is None: + result.error = "Python returned None" + return result + if jax_out is None: + result.error = "JAX returned None" + return result + + # Compare + diff_mean, diff_max, pct = compare_outputs(py_out, jax_out) + result.diff_mean = diff_mean + result.diff_max = diff_max + result.pct_within_1 = pct + + # Check pass/fail + result.passed = ( + diff_mean <= TOLERANCE_MEAN and + diff_max <= TOLERANCE_MAX and + pct >= TOLERANCE_PCT + ) + + if not result.passed: + result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}" + + return result + + +def discover_primitives(interp) -> List[str]: + """Discover all primitives available in the interpreter.""" + return sorted(interp.primitives.keys()) + + +def run_all_tests(verbose=True): + """Run all primitive tests.""" + import warnings + warnings.filterwarnings('ignore') + + import os + os.chdir('/home/giles/art/test') + + from streaming.stream_sexp_generic import StreamInterpreter + from sexp_effects.primitive_libs import core as core_mod + + core_mod.set_random_seed(42) + + # Create interpreter to get Python primitives + interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False) + interp._init() + + results = [] + + print("=" * 60) + print("JAX PRIMITIVE TEST SUITE") + print("=" * 60) + + # Test defined primitives + for prim_name, test_def in PRIMITIVE_TESTS.items(): + result = test_primitive(interp, prim_name, test_def) + results.append(result) + + status = "✓ PASS" if result.passed else "✗ FAIL" + if verbose: + print(f"{status} {prim_name}") + if not result.passed: + print(f" {result.error}") + + # Summary + passed = sum(1 for r in results if r.passed) + failed = sum(1 for r in results if not r.passed) + + print("\n" + "=" * 60) + print(f"SUMMARY: {passed} passed, {failed} failed") + print("=" * 60) + + if failed > 0: + print("\nFailed primitives:") + for r in results: + if not r.passed: + print(f" - {r.primitive}: {r.error}") + + return results + + +if __name__ == '__main__': + run_all_tests() diff --git a/tests/test_naming_service.py b/tests/test_naming_service.py new file mode 100644 index 0000000..98d8e52 --- /dev/null +++ b/tests/test_naming_service.py @@ -0,0 +1,246 @@ +""" +Tests for the friendly naming service. +""" + +import re +import pytest +from pathlib import Path + + +# Copy the pure functions from naming_service for testing +# This avoids import issues with the app module + +CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz" + + +def normalize_name(name: str) -> str: + """Copy of normalize_name for testing.""" + name = name.lower() + name = re.sub(r"[\s_]+", "-", name) + name = re.sub(r"[^a-z0-9-]", "", name) + name = re.sub(r"-+", "-", name) + name = name.strip("-") + return name or "unnamed" + + +def parse_friendly_name(friendly_name: str): + """Copy of parse_friendly_name for testing.""" + parts = friendly_name.strip().split(" ", 1) + base_name = parts[0] + version_id = parts[1] if len(parts) > 1 else None + return base_name, version_id + + +def format_friendly_name(base_name: str, version_id: str) -> str: + """Copy of format_friendly_name for testing.""" + return f"{base_name} {version_id}" + + +def format_l2_name(actor_id: str, base_name: str, version_id: str) -> str: + """Copy of format_l2_name for testing.""" + return f"{actor_id} {base_name} {version_id}" + + +class TestNameNormalization: + """Tests for name normalization.""" + + def test_normalize_simple_name(self) -> None: + """Simple names should be lowercased.""" + assert normalize_name("Brightness") == "brightness" + + def test_normalize_spaces_to_dashes(self) -> None: + """Spaces should be converted to dashes.""" + assert normalize_name("My Cool Effect") == "my-cool-effect" + + def test_normalize_underscores_to_dashes(self) -> None: + """Underscores should be converted to dashes.""" + assert normalize_name("brightness_v2") == "brightness-v2" + + def test_normalize_removes_special_chars(self) -> None: + """Special characters should be removed.""" + # Special chars are removed (not replaced with dashes) + assert normalize_name("Test!!!Effect") == "testeffect" + assert normalize_name("cool@effect#1") == "cooleffect1" + # But spaces/underscores become dashes first + assert normalize_name("Test Effect!") == "test-effect" + + def test_normalize_collapses_dashes(self) -> None: + """Multiple dashes should be collapsed.""" + assert normalize_name("test--effect") == "test-effect" + assert normalize_name("test___effect") == "test-effect" + + def test_normalize_strips_edge_dashes(self) -> None: + """Leading/trailing dashes should be stripped.""" + assert normalize_name("-test-effect-") == "test-effect" + + def test_normalize_empty_returns_unnamed(self) -> None: + """Empty names should return 'unnamed'.""" + assert normalize_name("") == "unnamed" + assert normalize_name("---") == "unnamed" + assert normalize_name("!!!") == "unnamed" + + +class TestFriendlyNameParsing: + """Tests for friendly name parsing.""" + + def test_parse_base_name_only(self) -> None: + """Parsing base name only returns None for version.""" + base, version = parse_friendly_name("my-effect") + assert base == "my-effect" + assert version is None + + def test_parse_with_version(self) -> None: + """Parsing with version returns both parts.""" + base, version = parse_friendly_name("my-effect 01hw3x9k") + assert base == "my-effect" + assert version == "01hw3x9k" + + def test_parse_strips_whitespace(self) -> None: + """Parsing should strip leading/trailing whitespace.""" + base, version = parse_friendly_name(" my-effect ") + assert base == "my-effect" + assert version is None + + +class TestFriendlyNameFormatting: + """Tests for friendly name formatting.""" + + def test_format_friendly_name(self) -> None: + """Format combines base and version with space.""" + assert format_friendly_name("my-effect", "01hw3x9k") == "my-effect 01hw3x9k" + + def test_format_l2_name(self) -> None: + """L2 format includes actor ID.""" + result = format_l2_name("@alice@example.com", "my-effect", "01hw3x9k") + assert result == "@alice@example.com my-effect 01hw3x9k" + + +class TestDatabaseSchemaExists: + """Tests that verify database schema includes friendly_names table.""" + + def test_schema_has_friendly_names_table(self) -> None: + """Database schema should include friendly_names table.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "CREATE TABLE IF NOT EXISTS friendly_names" in content + + def test_schema_has_required_columns(self) -> None: + """Friendly names table should have required columns.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "actor_id" in content + assert "base_name" in content + assert "version_id" in content + assert "item_type" in content + assert "display_name" in content + + def test_schema_has_unique_constraints(self) -> None: + """Friendly names table should have unique constraints.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + # Unique on (actor_id, base_name, version_id) + assert "UNIQUE(actor_id, base_name, version_id)" in content + # Unique on (actor_id, cid) + assert "UNIQUE(actor_id, cid)" in content + + +class TestDatabaseFunctionsExist: + """Tests that verify database functions exist.""" + + def test_create_friendly_name_exists(self) -> None: + """create_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def create_friendly_name(" in content + + def test_get_friendly_name_by_cid_exists(self) -> None: + """get_friendly_name_by_cid function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def get_friendly_name_by_cid(" in content + + def test_resolve_friendly_name_exists(self) -> None: + """resolve_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def resolve_friendly_name(" in content + + def test_list_friendly_names_exists(self) -> None: + """list_friendly_names function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def list_friendly_names(" in content + + def test_delete_friendly_name_exists(self) -> None: + """delete_friendly_name function should exist.""" + path = Path(__file__).parent.parent / "database.py" + content = path.read_text() + assert "async def delete_friendly_name(" in content + + +class TestNamingServiceModuleExists: + """Tests that verify naming service module structure.""" + + def test_module_file_exists(self) -> None: + """Naming service module file should exist.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + assert path.exists() + + def test_module_has_normalize_name(self) -> None: + """Module should have normalize_name function.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "def normalize_name(" in content + + def test_module_has_generate_version_id(self) -> None: + """Module should have generate_version_id function.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "def generate_version_id(" in content + + def test_module_has_naming_service_class(self) -> None: + """Module should have NamingService class.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "class NamingService:" in content + + def test_naming_service_has_assign_name(self) -> None: + """NamingService should have assign_name method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def assign_name(" in content + + def test_naming_service_has_resolve(self) -> None: + """NamingService should have resolve method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def resolve(" in content + + def test_naming_service_has_get_by_cid(self) -> None: + """NamingService should have get_by_cid method.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "async def get_by_cid(" in content + + +class TestVersionIdProperties: + """Tests for version ID format properties (using actual function).""" + + def test_version_id_format(self) -> None: + """Version ID should use base32-crockford alphabet.""" + # Read the naming service to verify it uses the right alphabet + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert 'CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz"' in content + + def test_version_id_uses_hmac(self) -> None: + """Version ID generation should use HMAC for server verification.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "hmac.new(" in content + + def test_version_id_uses_timestamp(self) -> None: + """Version ID generation should be timestamp-based.""" + path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py" + content = path.read_text() + assert "time.time()" in content diff --git a/tests/test_recipe_visibility.py b/tests/test_recipe_visibility.py new file mode 100644 index 0000000..a0fc93c --- /dev/null +++ b/tests/test_recipe_visibility.py @@ -0,0 +1,150 @@ +""" +Tests for recipe visibility in web UI. + +Bug found 2026-01-12: Recipes not showing in list even after upload. +""" + +import pytest +from pathlib import Path + + +class TestRecipeListingFlow: + """Tests for recipe listing data flow.""" + + def test_cache_manager_has_list_by_type(self) -> None: + """L1CacheManager should have list_by_type method.""" + # Read cache_manager.py and verify list_by_type exists + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'def list_by_type' in content, \ + "L1CacheManager should have list_by_type method" + + def test_list_by_type_returns_node_id(self) -> None: + """list_by_type should return entry.node_id values (IPFS CID).""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Find list_by_type function and verify it appends entry.node_id + assert 'cids.append(entry.node_id)' in content, \ + "list_by_type should append entry.node_id (IPFS CID) to results" + + def test_recipe_service_uses_database_items(self) -> None: + """Recipe service should use database.get_user_items for listing.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'get_user_items' in content, \ + "Recipe service should use database.get_user_items for listing" + + def test_recipe_upload_uses_recipe_node_type(self) -> None: + """Recipe upload should store with node_type='recipe'.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'node_type="recipe"' in content, \ + "Recipe upload should use node_type='recipe'" + + def test_get_by_cid_uses_find_by_cid(self) -> None: + """get_by_cid should use cache.find_by_cid to locate entries.""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Verify get_by_cid uses find_by_cid + assert 'find_by_cid(cid)' in content, \ + "get_by_cid should use find_by_cid to locate entries" + + def test_no_duplicate_get_by_cid_methods(self) -> None: + """ + Regression test: There should only be ONE get_by_cid method. + + Bug: Two get_by_cid methods existed, the second shadowed the first, + breaking recipe lookup because the comprehensive method was hidden. + """ + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + # Count occurrences of 'def get_by_cid' + count = content.count('def get_by_cid') + assert count == 1, \ + f"Should have exactly 1 get_by_cid method, found {count}" + + +class TestRecipeFilterLogic: + """Tests for recipe filtering logic via database.""" + + def test_recipes_filtered_by_actor_id(self) -> None: + """list_recipes should filter by actor_id parameter.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'actor_id' in content, \ + "list_recipes should accept actor_id parameter" + + def test_recipes_uses_item_type_filter(self) -> None: + """list_recipes should filter by item_type='recipe'.""" + path = Path('/home/giles/art/art-celery/app/services/recipe_service.py') + content = path.read_text() + + assert 'item_type="recipe"' in content, \ + "Recipe listing should filter by item_type='recipe'" + + +class TestCacheEntryHasCid: + """Tests for cache entry cid field.""" + + def test_artdag_cache_entry_has_cid(self) -> None: + """artdag CacheEntry should have cid field.""" + from artdag import CacheEntry + import dataclasses + + fields = {f.name for f in dataclasses.fields(CacheEntry)} + assert 'cid' in fields, \ + "CacheEntry should have cid field" + + def test_artdag_cache_put_computes_cid(self) -> None: + """artdag Cache.put should compute and store cid in metadata.""" + from artdag import Cache + import inspect + + source = inspect.getsource(Cache.put) + assert '"cid":' in source or "'cid':" in source, \ + "Cache.put should store cid in metadata" + + +class TestListByTypeReturnsEntries: + """Tests for list_by_type returning cached entries.""" + + def test_list_by_type_iterates_cache_entries(self) -> None: + """list_by_type should iterate self.cache.list_entries().""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'self.cache.list_entries()' in content, \ + "list_by_type should iterate cache entries" + + def test_list_by_type_filters_by_node_type(self) -> None: + """list_by_type should filter entries by node_type.""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'entry.node_type == node_type' in content, \ + "list_by_type should filter by node_type" + + def test_list_by_type_returns_node_id(self) -> None: + """list_by_type should return entry.node_id (IPFS CID).""" + path = Path('/home/giles/art/art-celery/cache_manager.py') + content = path.read_text() + + assert 'cids.append(entry.node_id)' in content, \ + "list_by_type should append entry.node_id (IPFS CID)" + + def test_artdag_cache_list_entries_returns_all(self) -> None: + """artdag Cache.list_entries should return all entries.""" + from artdag import Cache + import inspect + + source = inspect.getsource(Cache.list_entries) + # Should return self._entries.values() + assert '_entries' in source, \ + "list_entries should access _entries dict" diff --git a/tests/test_run_artifacts.py b/tests/test_run_artifacts.py new file mode 100644 index 0000000..e6042c4 --- /dev/null +++ b/tests/test_run_artifacts.py @@ -0,0 +1,111 @@ +""" +Tests for run artifacts data structure and template variables. + +Bug found 2026-01-12: runs/detail.html template expects artifact.cid +but the route provides artifact.hash, causing UndefinedError. +""" + +import pytest +from pathlib import Path + + +class TestCacheNotFoundTemplate: + """Tests for cache/not_found.html template.""" + + def test_template_uses_cid_not_content_hash(self) -> None: + """ + Regression test: not_found.html must use 'cid' variable. + + Bug: Template used 'content_hash' but route passes 'cid'. + Fix: Changed template to use 'cid'. + """ + template_path = Path('/home/giles/art/art-celery/app/templates/cache/not_found.html') + content = template_path.read_text() + + assert 'cid' in content, "Template should use 'cid' variable" + assert 'content_hash' not in content, \ + "Template should not use 'content_hash' (route passes 'cid')" + + def test_route_passes_cid_to_template(self) -> None: + """Verify route passes 'cid' variable to not_found template.""" + router_path = Path('/home/giles/art/art-celery/app/routers/cache.py') + content = router_path.read_text() + + # Find the render call for not_found.html + assert 'cid=cid' in content, \ + "Route should pass cid=cid to not_found.html template" + + +class TestInputPreviewsDataStructure: + """Tests for input_previews dict keys matching template expectations.""" + + def test_run_card_template_expects_inp_cid(self) -> None: + """Run card template uses inp.cid for input previews.""" + path = Path('/home/giles/art/art-celery/app/templates/runs/_run_card.html') + content = path.read_text() + + assert 'inp.cid' in content, \ + "Run card template should use inp.cid for input previews" + + def test_input_previews_use_cid_key(self) -> None: + """ + Regression test: input_previews must use 'cid' key not 'hash'. + + Bug: Router created input_previews with 'hash' key but template expected 'cid'. + """ + path = Path('/home/giles/art/art-celery/app/routers/runs.py') + content = path.read_text() + + # Should have: "cid": input_hash, not "hash": input_hash + assert '"cid": input_hash' in content, \ + "input_previews should use 'cid' key" + assert '"hash": input_hash' not in content, \ + "input_previews should not use 'hash' key (template expects 'cid')" + + +class TestArtifactDataStructure: + """Tests for artifact dict keys matching template expectations.""" + + def test_template_expects_cid_key(self) -> None: + """Template uses artifact.cid - verify this expectation.""" + template_path = Path('/home/giles/art/art-celery/app/templates/runs/detail.html') + content = template_path.read_text() + + # Template uses artifact.cid in multiple places + assert 'artifact.cid' in content, "Template should reference artifact.cid" + + def test_run_service_artifacts_have_cid_key(self) -> None: + """ + Regression test: get_run_artifacts must return dicts with 'cid' key. + + Bug: Service returned artifacts with 'hash' key but template expected 'cid'. + Fix: Changed service to use 'cid' key for consistency. + """ + # Read the run_service.py file and check it uses 'cid' not 'hash' + service_path = Path('/home/giles/art/art-celery/app/services/run_service.py') + content = service_path.read_text() + + # Find the get_artifact_info function and check it returns 'cid' + # The function should have: "cid": cid, not "hash": cid + assert '"cid": cid' in content or "'cid': cid" in content, \ + "get_run_artifacts should return artifacts with 'cid' key, not 'hash'" + + def test_router_artifacts_have_cid_key(self) -> None: + """ + Regression test: inline artifact creation in router must use 'cid' key. + + Bug: Router created artifacts with 'hash' key but template expected 'cid'. + """ + router_path = Path('/home/giles/art/art-celery/app/routers/runs.py') + content = router_path.read_text() + + # Check that artifacts.append uses 'cid' key + # Should have: "cid": output_cid, not "hash": output_cid + # Count occurrences of the patterns + hash_pattern_count = content.count('"hash": output_cid') + cid_pattern_count = content.count('"cid": output_cid') + + assert cid_pattern_count > 0, \ + "Router should create artifacts with 'cid' key" + assert hash_pattern_count == 0, \ + "Router should not use 'hash' key for artifacts (template expects 'cid')" diff --git a/tests/test_xector.py b/tests/test_xector.py new file mode 100644 index 0000000..0d006e5 --- /dev/null +++ b/tests/test_xector.py @@ -0,0 +1,305 @@ +""" +Tests for xector primitives - parallel array operations. +""" + +import pytest +import numpy as np +from sexp_effects.primitive_libs.xector import ( + Xector, + xector_red, xector_green, xector_blue, xector_rgb, + xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm, + xector_dist_from_center, + alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp, + alpha_sin, alpha_cos, alpha_sq, + alpha_lt, alpha_gt, alpha_eq, + beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count, + xector_where, xector_fill, xector_zeros, xector_ones, + is_xector, +) + + +class TestXectorBasics: + """Test Xector class basic operations.""" + + def test_create_from_list(self): + x = Xector([1, 2, 3]) + assert len(x) == 3 + assert is_xector(x) + + def test_create_from_numpy(self): + arr = np.array([1.0, 2.0, 3.0]) + x = Xector(arr) + assert len(x) == 3 + np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32)) + + def test_implicit_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = a + b + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_implicit_mul(self): + a = Xector([1, 2, 3]) + b = Xector([2, 2, 2]) + c = a * b + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_scalar_broadcast(self): + a = Xector([1, 2, 3]) + c = a + 10 + np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13]) + + def test_scalar_broadcast_rmul(self): + a = Xector([1, 2, 3]) + c = 2 * a + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + +class TestAlphaOperations: + """Test α (element-wise) operations.""" + + def test_alpha_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = alpha_add(a, b) + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_alpha_add_multi(self): + a = Xector([1, 2, 3]) + b = Xector([1, 1, 1]) + c = Xector([10, 10, 10]) + d = alpha_add(a, b, c) + np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14]) + + def test_alpha_mul_scalar(self): + a = Xector([1, 2, 3]) + c = alpha_mul(a, 2) + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_alpha_sqrt(self): + a = Xector([1, 4, 9, 16]) + c = alpha_sqrt(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4]) + + def test_alpha_clamp(self): + a = Xector([-5, 0, 5, 10, 15]) + c = alpha_clamp(a, 0, 10) + np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10]) + + def test_alpha_sin_cos(self): + a = Xector([0, np.pi/2, np.pi]) + s = alpha_sin(a) + c = alpha_cos(a) + np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5) + np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5) + + def test_alpha_sq(self): + a = Xector([1, 2, 3, 4]) + c = alpha_sq(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16]) + + def test_alpha_comparison(self): + a = Xector([1, 2, 3, 4]) + b = Xector([2, 2, 2, 2]) + lt = alpha_lt(a, b) + gt = alpha_gt(a, b) + eq = alpha_eq(a, b) + np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False]) + np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True]) + np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False]) + + +class TestBetaOperations: + """Test β (reduction) operations.""" + + def test_beta_add(self): + a = Xector([1, 2, 3, 4]) + assert beta_add(a) == 10 + + def test_beta_mul(self): + a = Xector([1, 2, 3, 4]) + assert beta_mul(a) == 24 + + def test_beta_min_max(self): + a = Xector([3, 1, 4, 1, 5, 9, 2, 6]) + assert beta_min(a) == 1 + assert beta_max(a) == 9 + + def test_beta_mean(self): + a = Xector([1, 2, 3, 4]) + assert beta_mean(a) == 2.5 + + def test_beta_count(self): + a = Xector([1, 2, 3, 4, 5]) + assert beta_count(a) == 5 + + +class TestFrameConversion: + """Test frame/xector conversion.""" + + def test_extract_channels(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[255, 0, 0], [0, 255, 0]], + [[0, 0, 255], [128, 128, 128]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + assert len(r) == 4 + np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128]) + np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128]) + np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128]) + + def test_rgb_roundtrip(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[100, 150, 200], [50, 75, 100]], + [[200, 100, 50], [25, 50, 75]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + reconstructed = xector_rgb(r, g, b) + np.testing.assert_array_equal(reconstructed, frame) + + def test_modify_and_reconstruct(self): + frame = np.array([ + [[100, 100, 100], [100, 100, 100]], + [[100, 100, 100], [100, 100, 100]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + # Double red channel + r_doubled = r * 2 + + result = xector_rgb(r_doubled, g, b) + + # Red should be 200, others unchanged + assert result[0, 0, 0] == 200 + assert result[0, 0, 1] == 100 + assert result[0, 0, 2] == 100 + + +class TestCoordinates: + """Test coordinate generation.""" + + def test_x_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + x = xector_x_coords(frame) + # Should be [0,1,2, 0,1,2] (x coords repeated for each row) + np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2]) + + def test_y_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + y = xector_y_coords(frame) + # Should be [0,0,0, 1,1,1] (y coords for each pixel) + np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1]) + + def test_normalized_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_x_norm(frame) + y = xector_y_norm(frame) + + # x should go 0 to 1 across width + assert x.to_numpy()[0] == 0 + assert x.to_numpy()[2] == 1 + + # y should go 0 to 1 down height + assert y.to_numpy()[0] == 0 + assert y.to_numpy()[3] == 1 + + +class TestConditional: + """Test conditional operations.""" + + def test_where(self): + cond = Xector([True, False, True, False]) + true_val = Xector([1, 1, 1, 1]) + false_val = Xector([0, 0, 0, 0]) + + result = xector_where(cond, true_val, false_val) + np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0]) + + def test_where_with_comparison(self): + a = Xector([1, 5, 3, 7]) + threshold = 4 + + # Elements > 4 become 255, others become 0 + result = xector_where(alpha_gt(a, threshold), 255, 0) + np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255]) + + def test_fill(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_fill(42, frame) + assert len(x) == 6 + assert all(v == 42 for v in x.to_numpy()) + + def test_zeros_ones(self): + frame = np.zeros((2, 2, 3), dtype=np.uint8) + z = xector_zeros(frame) + o = xector_ones(frame) + + assert all(v == 0 for v in z.to_numpy()) + assert all(v == 1 for v in o.to_numpy()) + + +class TestInterpreterIntegration: + """Test xector operations through the interpreter.""" + + def test_xector_vignette_effect(self): + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + + # Load the xector vignette effect + interp.load_effect('sexp_effects/effects/xector_vignette.sexp') + + # Create a test frame (white) + frame = np.full((100, 100, 3), 255, dtype=np.uint8) + + # Run effect + result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {}) + + # Center should be brighter than corners + center = result[50, 50] + corner = result[0, 0] + + assert center.mean() > corner.mean(), "Center should be brighter than corners" + # Corners should be darkened + assert corner.mean() < 255, "Corners should be darkened" + + def test_implicit_elementwise(self): + """Test that regular + works element-wise on xectors.""" + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + # Load xector primitives + from sexp_effects.primitive_libs.xector import PRIMITIVES + for name, fn in PRIMITIVES.items(): + interp.global_env.set(name, fn) + + # Parse and eval a simple xector expression + from sexp_effects.parser import parse + expr = parse('(+ (red frame) 10)') + + # Create test frame + frame = np.full((2, 2, 3), 100, dtype=np.uint8) + interp.global_env.set('frame', frame) + + result = interp.eval(expr) + + # Should be a xector with values 110 + assert is_xector(result) + np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110]) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) From f54b0fb5dae7670168abddc577443fca22fa30d1 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:07:31 +0000 Subject: [PATCH 03/24] Squashed 'l2/' content from commit 79caa24 git-subtree-dir: l2 git-subtree-split: 79caa24e2129bf6e2cee819327d5622425306b67 --- .env.example | 20 + .gitea/workflows/ci.yml | 62 + .gitignore | 11 + Dockerfile | 23 + README.md | 389 +++ anchoring.py | 334 ++ app/__init__.py | 116 + app/config.py | 56 + app/dependencies.py | 80 + app/routers/__init__.py | 25 + app/routers/activities.py | 99 + app/routers/anchors.py | 203 ++ app/routers/assets.py | 244 ++ app/routers/auth.py | 223 ++ app/routers/federation.py | 115 + app/routers/renderers.py | 93 + app/routers/storage.py | 254 ++ app/routers/users.py | 161 + app/templates/404.html | 11 + app/templates/activities/list.html | 39 + app/templates/anchors/list.html | 47 + app/templates/assets/list.html | 58 + app/templates/auth/already_logged_in.html | 12 + app/templates/auth/login.html | 37 + app/templates/auth/register.html | 45 + app/templates/base.html | 47 + app/templates/home.html | 42 + app/templates/renderers/list.html | 52 + app/templates/storage/list.html | 41 + artdag-client.tar.gz | Bin 0 -> 7982 bytes auth.py | 213 ++ db.py | 1215 +++++++ deploy.sh | 19 + docker-compose.yml | 90 + docker-stack.yml | 91 + ipfs_client.py | 226 ++ keys.py | 119 + migrate.py | 245 ++ requirements.txt | 13 + server.py | 26 + server_legacy.py | 3765 +++++++++++++++++++++ setup_keys.py | 51 + storage_providers.py | 1009 ++++++ 43 files changed, 10021 insertions(+) create mode 100644 .env.example create mode 100644 .gitea/workflows/ci.yml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 anchoring.py create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/dependencies.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/activities.py create mode 100644 app/routers/anchors.py create mode 100644 app/routers/assets.py create mode 100644 app/routers/auth.py create mode 100644 app/routers/federation.py create mode 100644 app/routers/renderers.py create mode 100644 app/routers/storage.py create mode 100644 app/routers/users.py create mode 100644 app/templates/404.html create mode 100644 app/templates/activities/list.html create mode 100644 app/templates/anchors/list.html create mode 100644 app/templates/assets/list.html create mode 100644 app/templates/auth/already_logged_in.html create mode 100644 app/templates/auth/login.html create mode 100644 app/templates/auth/register.html create mode 100644 app/templates/base.html create mode 100644 app/templates/home.html create mode 100644 app/templates/renderers/list.html create mode 100644 app/templates/storage/list.html create mode 100644 artdag-client.tar.gz create mode 100644 auth.py create mode 100644 db.py create mode 100755 deploy.sh create mode 100644 docker-compose.yml create mode 100644 docker-stack.yml create mode 100644 ipfs_client.py create mode 100644 keys.py create mode 100755 migrate.py create mode 100644 requirements.txt create mode 100644 server.py create mode 100644 server_legacy.py create mode 100755 setup_keys.py create mode 100644 storage_providers.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..d0bb2cf --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# L2 Server Configuration + +# PostgreSQL password (REQUIRED - no default) +POSTGRES_PASSWORD=changeme-generate-with-openssl-rand-hex-16 + +# Domain for this ActivityPub server +ARTDAG_DOMAIN=artdag.rose-ash.com + +# JWT secret for token signing (generate with: openssl rand -hex 32) +JWT_SECRET=your-secret-here-generate-with-openssl-rand-hex-32 + +# L1 server URL for fetching content (images/videos) +L1_PUBLIC_URL=https://celery-artdag.rose-ash.com + +# Effects repository URL for linking to effect source code +EFFECTS_REPO_URL=https://git.rose-ash.com/art-dag/effects + +# Notes: +# - ARTDAG_USER removed - now multi-actor, each registered user is their own actor +# - L1 URL can also come from provenance data per-asset diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..30d34ea --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,62 @@ +name: Build and Deploy + +on: + push: + branches: [main] + +env: + REGISTRY: registry.rose-ash.com:5000 + IMAGE: l2-server + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install tools + run: | + apt-get update && apt-get install -y --no-install-recommends openssh-client + + - name: Set up SSH + env: + SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + mkdir -p ~/.ssh + echo "$SSH_KEY" > ~/.ssh/id_rsa + chmod 600 ~/.ssh/id_rsa + ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true + + - name: Pull latest code on server + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/activity-pub + git fetch origin main + git reset --hard origin/main + " + + - name: Build and push image + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/activity-pub + docker build --build-arg CACHEBUST=\$(date +%s) -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:latest -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:${{ github.sha }} . + docker push ${{ env.REGISTRY }}/${{ env.IMAGE }}:latest + docker push ${{ env.REGISTRY }}/${{ env.IMAGE }}:${{ github.sha }} + " + + - name: Deploy stack + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd /root/art-dag/activity-pub + docker stack deploy -c docker-compose.yml activitypub + echo 'Waiting for services to update...' + sleep 10 + docker stack services activitypub + " diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..705d35d --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +__pycache__/ +*.py[cod] +.venv/ +venv/ + +# Private keys - NEVER commit these +*.pem +keys/ + +# Secrets +.env diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..409aadf --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install git for pip to clone dependencies +RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* + +# Install dependencies +COPY requirements.txt . +ARG CACHEBUST=1 +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application +COPY . . + +# Create data directory +RUN mkdir -p /data/l2 + +ENV PYTHONUNBUFFERED=1 +ENV ARTDAG_DATA=/data/l2 + +# Default command runs the server +CMD ["python", "server.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..31f8c36 --- /dev/null +++ b/README.md @@ -0,0 +1,389 @@ +# Art DAG L2 Server - ActivityPub + +Ownership registry and ActivityPub federation for Art DAG. Manages asset provenance, cryptographic anchoring, and distributed identity. + +## Features + +- **Asset Registry**: Content-addressed assets with provenance tracking +- **ActivityPub Federation**: Standard protocol for distributed social networking +- **OpenTimestamps Anchoring**: Cryptographic proof of existence on Bitcoin blockchain +- **L1 Integration**: Record and verify L1 rendering runs +- **Storage Providers**: S3, IPFS, and local storage backends +- **Scoped Authentication**: Secure token-based auth for federated L1 servers + +## Dependencies + +- **PostgreSQL**: Primary data storage +- **artdag-common**: Shared templates and middleware +- **cryptography**: RSA key generation and signing +- **httpx**: Async HTTP client for federation + +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Configure +export ARTDAG_DOMAIN=artdag.example.com +export ARTDAG_USER=giles +export DATABASE_URL=postgresql://artdag:$POSTGRES_PASSWORD@localhost:5432/artdag +export L1_SERVERS=https://celery-artdag.example.com + +# Generate signing keys (required for federation) +python setup_keys.py + +# Start server +python server.py +``` + +## Docker Deployment + +```bash +docker stack deploy -c docker-compose.yml artdag-l2 +``` + +## Configuration + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `ARTDAG_DOMAIN` | `artdag.rose-ash.com` | Domain for ActivityPub actors | +| `ARTDAG_USER` | `giles` | Default username | +| `ARTDAG_DATA` | `~/.artdag/l2` | Data directory | +| `DATABASE_URL` | **(required)** | PostgreSQL connection | +| `L1_SERVERS` | - | Comma-separated list of L1 server URLs | +| `JWT_SECRET` | (generated) | JWT signing secret | +| `HOST` | `0.0.0.0` | Server bind address | +| `PORT` | `8200` | Server port | + +### JWT Secret + +The JWT secret signs authentication tokens. Without a persistent secret, tokens are invalidated on restart. + +```bash +# Generate a secret +openssl rand -hex 32 + +# Set in environment +export JWT_SECRET="your-generated-secret" + +# Or use Docker secrets (recommended for production) +echo "your-secret" | docker secret create jwt_secret - +``` + +### RSA Keys + +ActivityPub requires RSA keys for signing activities: + +```bash +# Generate keys +python setup_keys.py + +# Or with custom paths +python setup_keys.py --data-dir /data/l2 --user giles +``` + +Keys stored in `$ARTDAG_DATA/keys/`: +- `{username}.pem` - Private key (chmod 600) +- `{username}.pub` - Public key (in actor profile) + +## Web UI + +| Path | Description | +|------|-------------| +| `/` | Home page with stats | +| `/login` | Login form | +| `/register` | Registration form | +| `/logout` | Log out | +| `/assets` | Browse registered assets | +| `/asset/{name}` | Asset detail page | +| `/activities` | Published activities | +| `/activity/{id}` | Activity detail | +| `/users` | Registered users | +| `/renderers` | L1 renderer connections | +| `/anchors/ui` | OpenTimestamps management | +| `/storage` | Storage provider config | +| `/download/client` | Download CLI client | + +## API Reference + +Interactive docs: http://localhost:8200/docs + +### Authentication + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/auth/register` | Register new user | +| POST | `/auth/login` | Login, get JWT token | +| GET | `/auth/me` | Get current user info | +| POST | `/auth/verify` | Verify token (for L1 servers) | + +### Assets + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/assets` | List all assets | +| GET | `/assets/{name}` | Get asset by name | +| POST | `/assets` | Register new asset | +| PATCH | `/assets/{name}` | Update asset metadata | +| POST | `/assets/record-run` | Record L1 run as asset | +| POST | `/assets/publish-cache` | Publish L1 cache item | +| GET | `/assets/by-run-id/{run_id}` | Find asset by L1 run ID | + +### ActivityPub + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/.well-known/webfinger` | Actor discovery | +| GET | `/users/{username}` | Actor profile | +| GET | `/users/{username}/outbox` | Published activities | +| POST | `/users/{username}/inbox` | Receive activities | +| GET | `/users/{username}/followers` | Followers list | +| GET | `/objects/{hash}` | Get object by content hash | +| GET | `/activities` | List activities (paginated) | +| GET | `/activities/{ref}` | Get activity by reference | +| GET | `/activity/{index}` | Get activity by index | + +### OpenTimestamps Anchoring + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/anchors/create` | Create timestamp anchor | +| GET | `/anchors` | List all anchors | +| GET | `/anchors/{merkle_root}` | Get anchor details | +| GET | `/anchors/{merkle_root}/tree` | Get merkle tree | +| GET | `/anchors/verify/{activity_id}` | Verify activity timestamp | +| POST | `/anchors/{merkle_root}/upgrade` | Upgrade pending timestamp | +| GET | `/anchors/ui` | Anchor management UI | +| POST | `/anchors/test-ots` | Test OTS functionality | + +### Renderers (L1 Connections) + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/renderers` | List attached L1 servers | +| GET | `/renderers/attach` | Initiate L1 attachment | +| POST | `/renderers/detach` | Detach from L1 server | + +### Storage Providers + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/storage` | List storage providers | +| POST | `/storage` | Add provider (form) | +| POST | `/storage/add` | Add provider (JSON) | +| GET | `/storage/{id}` | Get provider details | +| PATCH | `/storage/{id}` | Update provider | +| DELETE | `/storage/{id}` | Delete provider | +| POST | `/storage/{id}/test` | Test connection | +| GET | `/storage/type/{type}` | Get form for provider type | + +## L1 Renderer Integration + +L2 coordinates with L1 rendering servers for distributed processing. + +### Configuration + +```bash +# Single L1 server +export L1_SERVERS=https://celery-artdag.rose-ash.com + +# Multiple L1 servers +export L1_SERVERS=https://server1.example.com,https://server2.example.com +``` + +### Attachment Flow + +1. User visits `/renderers` and clicks "Attach" +2. L2 creates a **scoped token** bound to the specific L1 +3. User redirected to L1's `/auth?auth_token=...` +4. L1 calls L2's `/auth/verify` to validate +5. L2 checks token scope matches requesting L1 +6. L1 sets local cookie, attachment recorded in `user_renderers` + +### Security + +- **Scoped tokens**: Tokens bound to specific L1; can't be used elsewhere +- **No shared secrets**: L1 verifies via L2's `/auth/verify` endpoint +- **Federated logout**: L2 revokes tokens on all attached L1s + +## OpenTimestamps Anchoring + +Cryptographic proof of existence using Bitcoin blockchain. + +### How It Works + +1. Activities are collected into merkle trees +2. Merkle root submitted to Bitcoin via OpenTimestamps +3. Pending proofs upgraded when Bitcoin confirms +4. Final proof verifiable without trusted third parties + +### Verification + +```bash +# Verify an activity's timestamp +curl https://artdag.example.com/anchors/verify/123 + +# Returns: +{ + "activity_id": 123, + "merkle_root": "abc123...", + "status": "confirmed", + "bitcoin_block": 800000, + "verified_at": "2026-01-01T..." +} +``` + +## Data Model + +### PostgreSQL Tables + +| Table | Description | +|-------|-------------| +| `users` | Registered users with hashed passwords | +| `assets` | Asset registry with content hashes | +| `activities` | Signed ActivityPub activities | +| `followers` | Follower relationships | +| `anchors` | OpenTimestamps anchor records | +| `anchor_activities` | Activity-to-anchor mappings | +| `user_renderers` | L1 attachment records | +| `revoked_tokens` | Token revocation list | +| `storage_providers` | Storage configurations | + +### Asset Structure + +```json +{ + "name": "my-video", + "content_hash": "sha3-256:abc123...", + "asset_type": "video", + "owner": "@giles@artdag.rose-ash.com", + "created_at": "2026-01-01T...", + "provenance": { + "inputs": [...], + "recipe": "beat-sync", + "l1_server": "https://celery-artdag.rose-ash.com", + "run_id": "..." + }, + "tags": ["art", "generated"] +} +``` + +### Activity Structure + +```json +{ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "Create", + "actor": "https://artdag.rose-ash.com/users/giles", + "object": { + "type": "Document", + "name": "my-video", + "content": "sha3-256:abc123...", + "attributedTo": "https://artdag.rose-ash.com/users/giles" + }, + "published": "2026-01-01T..." +} +``` + +## CLI Commands + +### Register Asset + +```bash +curl -X POST https://artdag.example.com/assets \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "my-video", + "content_hash": "abc123...", + "asset_type": "video", + "tags": ["art", "generated"] + }' +``` + +### Record L1 Run + +```bash +curl -X POST https://artdag.example.com/assets/record-run \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "run_id": "uuid-from-l1", + "l1_server": "https://celery-artdag.rose-ash.com", + "output_name": "my-rendered-video" + }' +``` + +### Publish L1 Cache Item + +```bash +curl -X POST https://artdag.example.com/assets/publish-cache \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "content_hash": "abc123...", + "l1_server": "https://celery-artdag.rose-ash.com", + "name": "my-asset", + "asset_type": "video" + }' +``` + +## Architecture + +``` +L2 Server (FastAPI) + │ + ├── Web UI (Jinja2 + HTMX + Tailwind) + │ + ├── /assets → Asset Registry + │ │ + │ └── PostgreSQL (assets table) + │ + ├── /users/{user}/outbox → ActivityPub + │ │ + │ ├── Sign activities (RSA) + │ └── PostgreSQL (activities table) + │ + ├── /anchors → OpenTimestamps + │ │ + │ ├── Merkle tree construction + │ └── Bitcoin anchoring + │ + ├── /auth/verify → L1 Token Verification + │ │ + │ └── Scoped token validation + │ + └── /storage → Storage Providers + │ + ├── S3 (boto3) + ├── IPFS (ipfs_client) + └── Local filesystem +``` + +## Federation + +L2 implements ActivityPub for federated asset sharing. + +### Discovery + +```bash +# Webfinger lookup +curl "https://artdag.example.com/.well-known/webfinger?resource=acct:giles@artdag.example.com" +``` + +### Actor Profile + +```bash +curl -H "Accept: application/activity+json" \ + https://artdag.example.com/users/giles +``` + +### Outbox + +```bash +curl -H "Accept: application/activity+json" \ + https://artdag.example.com/users/giles/outbox +``` diff --git a/anchoring.py b/anchoring.py new file mode 100644 index 0000000..49f6f1a --- /dev/null +++ b/anchoring.py @@ -0,0 +1,334 @@ +# art-activity-pub/anchoring.py +""" +Merkle tree anchoring to Bitcoin via OpenTimestamps. + +Provides provable timestamps for ActivityPub activities without running +our own blockchain. Activities are hashed into a merkle tree, the root +is submitted to OpenTimestamps (free), and the proof is stored on IPFS. + +The merkle tree + OTS proof provides cryptographic evidence that +activities existed at a specific time, anchored to Bitcoin. +""" + +import hashlib +import json +import logging +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import List, Optional + +import requests + +logger = logging.getLogger(__name__) + +# Backup file location (should be on persistent volume) +ANCHOR_BACKUP_DIR = Path(os.getenv("ANCHOR_BACKUP_DIR", "/data/anchors")) +ANCHOR_BACKUP_FILE = ANCHOR_BACKUP_DIR / "anchors.jsonl" + +# OpenTimestamps calendar servers +OTS_SERVERS = [ + "https://a.pool.opentimestamps.org", + "https://b.pool.opentimestamps.org", + "https://a.pool.eternitywall.com", +] + + +def _ensure_backup_dir(): + """Ensure backup directory exists.""" + ANCHOR_BACKUP_DIR.mkdir(parents=True, exist_ok=True) + + +def build_merkle_tree(items: List[str]) -> Optional[dict]: + """ + Build a merkle tree from a list of strings (activity IDs). + + Args: + items: List of activity IDs to include + + Returns: + Dict with root, tree structure, and metadata, or None if empty + """ + if not items: + return None + + # Sort for deterministic ordering + items = sorted(items) + + # Hash each item to create leaves + leaves = [hashlib.sha256(item.encode()).hexdigest() for item in items] + + # Build tree bottom-up + tree_levels = [leaves] + current_level = leaves + + while len(current_level) > 1: + next_level = [] + for i in range(0, len(current_level), 2): + left = current_level[i] + # If odd number, duplicate last node + right = current_level[i + 1] if i + 1 < len(current_level) else left + # Hash pair together + combined = hashlib.sha256((left + right).encode()).hexdigest() + next_level.append(combined) + tree_levels.append(next_level) + current_level = next_level + + root = current_level[0] + + return { + "root": root, + "tree": tree_levels, + "items": items, + "item_count": len(items), + "created_at": datetime.now(timezone.utc).isoformat() + } + + +def get_merkle_proof(tree: dict, item: str) -> Optional[List[dict]]: + """ + Get merkle proof for a specific item. + + Args: + tree: Merkle tree dict from build_merkle_tree + item: The item to prove membership for + + Returns: + List of proof steps, or None if item not in tree + """ + items = tree["items"] + if item not in items: + return None + + # Find leaf index + sorted_items = sorted(items) + leaf_index = sorted_items.index(item) + leaf_hash = hashlib.sha256(item.encode()).hexdigest() + + proof = [] + tree_levels = tree["tree"] + current_index = leaf_index + + for level in tree_levels[:-1]: # Skip root level + sibling_index = current_index ^ 1 # XOR to get sibling + if sibling_index < len(level): + sibling_hash = level[sibling_index] + proof.append({ + "hash": sibling_hash, + "position": "right" if current_index % 2 == 0 else "left" + }) + current_index //= 2 + + return proof + + +def verify_merkle_proof(item: str, proof: List[dict], root: str) -> bool: + """ + Verify a merkle proof. + + Args: + item: The item to verify + proof: Proof steps from get_merkle_proof + root: Expected merkle root + + Returns: + True if proof is valid + """ + current_hash = hashlib.sha256(item.encode()).hexdigest() + + for step in proof: + sibling = step["hash"] + if step["position"] == "right": + combined = current_hash + sibling + else: + combined = sibling + current_hash + current_hash = hashlib.sha256(combined.encode()).hexdigest() + + return current_hash == root + + +def submit_to_opentimestamps(hash_hex: str) -> Optional[bytes]: + """ + Submit a hash to OpenTimestamps for Bitcoin anchoring. + + Args: + hash_hex: Hex-encoded SHA256 hash to timestamp + + Returns: + Incomplete .ots proof bytes, or None on failure + + Note: + The returned proof is "incomplete" - it becomes complete + after Bitcoin confirms (usually 1-2 hours). Use upgrade_ots_proof + to get the complete proof later. + """ + hash_bytes = bytes.fromhex(hash_hex) + + for server in OTS_SERVERS: + try: + resp = requests.post( + f"{server}/digest", + data=hash_bytes, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=10 + ) + if resp.status_code == 200: + logger.info(f"Submitted to OpenTimestamps via {server}") + return resp.content + except Exception as e: + logger.warning(f"OTS server {server} failed: {e}") + continue + + logger.error("All OpenTimestamps servers failed") + return None + + +def upgrade_ots_proof(ots_proof: bytes) -> Optional[bytes]: + """ + Upgrade an incomplete OTS proof to a complete Bitcoin-anchored proof. + + Args: + ots_proof: Incomplete .ots proof bytes + + Returns: + Complete .ots proof bytes, or None if not yet confirmed + + Note: + This should be called periodically (e.g., hourly) until + the proof is complete. Bitcoin confirmation takes ~1-2 hours. + """ + for server in OTS_SERVERS: + try: + resp = requests.post( + f"{server}/upgrade", + data=ots_proof, + headers={"Content-Type": "application/octet-stream"}, + timeout=10 + ) + if resp.status_code == 200 and len(resp.content) > len(ots_proof): + logger.info(f"OTS proof upgraded via {server}") + return resp.content + except Exception as e: + logger.warning(f"OTS upgrade via {server} failed: {e}") + continue + + return None + + +def append_to_backup(anchor_record: dict): + """ + Append anchor record to persistent JSONL backup file. + + Args: + anchor_record: Dict with anchor metadata + """ + _ensure_backup_dir() + + with open(ANCHOR_BACKUP_FILE, "a") as f: + f.write(json.dumps(anchor_record, sort_keys=True) + "\n") + + logger.info(f"Anchor backed up to {ANCHOR_BACKUP_FILE}") + + +def load_backup_anchors() -> List[dict]: + """ + Load all anchors from backup file. + + Returns: + List of anchor records + """ + if not ANCHOR_BACKUP_FILE.exists(): + return [] + + anchors = [] + with open(ANCHOR_BACKUP_FILE, "r") as f: + for line in f: + line = line.strip() + if line: + try: + anchors.append(json.loads(line)) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON in backup: {line[:50]}...") + + return anchors + + +def get_latest_anchor_from_backup() -> Optional[dict]: + """Get the most recent anchor from backup.""" + anchors = load_backup_anchors() + return anchors[-1] if anchors else None + + +async def create_anchor( + activity_ids: List[str], + db_module, + ipfs_module +) -> Optional[dict]: + """ + Create a new anchor for a batch of activities. + + Args: + activity_ids: List of activity UUIDs to anchor + db_module: Database module with anchor functions + ipfs_module: IPFS client module + + Returns: + Anchor record dict, or None on failure + """ + if not activity_ids: + logger.info("No activities to anchor") + return None + + # Build merkle tree + tree = build_merkle_tree(activity_ids) + if not tree: + return None + + root = tree["root"] + logger.info(f"Built merkle tree: {len(activity_ids)} activities, root={root[:16]}...") + + # Store tree on IPFS + try: + tree_cid = ipfs_module.add_json(tree) + logger.info(f"Merkle tree stored on IPFS: {tree_cid}") + except Exception as e: + logger.error(f"Failed to store tree on IPFS: {e}") + tree_cid = None + + # Submit to OpenTimestamps + ots_proof = submit_to_opentimestamps(root) + + # Store OTS proof on IPFS too + ots_cid = None + if ots_proof and ipfs_module: + try: + ots_cid = ipfs_module.add_bytes(ots_proof) + logger.info(f"OTS proof stored on IPFS: {ots_cid}") + except Exception as e: + logger.warning(f"Failed to store OTS proof on IPFS: {e}") + + # Create anchor record + anchor_record = { + "merkle_root": root, + "tree_ipfs_cid": tree_cid, + "ots_proof_cid": ots_cid, + "activity_count": len(activity_ids), + "first_activity_id": activity_ids[0], + "last_activity_id": activity_ids[-1], + "created_at": datetime.now(timezone.utc).isoformat(), + "confirmed_at": None, + "bitcoin_txid": None + } + + # Save to database + if db_module: + try: + await db_module.create_anchor(anchor_record) + await db_module.mark_activities_anchored(activity_ids, root) + except Exception as e: + logger.error(f"Failed to save anchor to database: {e}") + + # Append to backup file (persistent) + append_to_backup(anchor_record) + + return anchor_record diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..1062a13 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,116 @@ +""" +Art-DAG L2 Server Application Factory. + +Creates and configures the FastAPI application with all routers and middleware. +""" + +from pathlib import Path +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, HTMLResponse + +from artdag_common import create_jinja_env +from artdag_common.middleware.auth import get_user_from_cookie + +from .config import settings + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage database connection pool lifecycle.""" + import db + await db.init_pool() + yield + await db.close_pool() + + +def create_app() -> FastAPI: + """ + Create and configure the L2 FastAPI application. + + Returns: + Configured FastAPI instance + """ + app = FastAPI( + title="Art-DAG L2 Server", + description="ActivityPub server for Art-DAG ownership and federation", + version="1.0.0", + lifespan=lifespan, + ) + + # Coop fragment pre-fetch — inject nav-tree, auth-menu, cart-mini + _FRAG_SKIP = ("/auth/", "/.well-known/", "/health", + "/internal/", "/static/", "/inbox") + + @app.middleware("http") + async def coop_fragments_middleware(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _FRAG_SKIP) + or request.headers.get("hx-request") + ): + request.state.nav_tree_html = "" + request.state.auth_menu_html = "" + request.state.cart_mini_html = "" + return await call_next(request) + + from artdag_common.fragments import fetch_fragments as _fetch_frags + + user = get_user_from_cookie(request) + auth_params = {"email": user.email} if user and user.email else {} + nav_params = {"app_name": "artdag", "path": path} + + try: + nav_tree_html, auth_menu_html, cart_mini_html = await _fetch_frags([ + ("blog", "nav-tree", nav_params), + ("account", "auth-menu", auth_params or None), + ("cart", "cart-mini", None), + ]) + except Exception: + nav_tree_html = auth_menu_html = cart_mini_html = "" + + request.state.nav_tree_html = nav_tree_html + request.state.auth_menu_html = auth_menu_html + request.state.cart_mini_html = cart_mini_html + + return await call_next(request) + + # Initialize Jinja2 templates + template_dir = Path(__file__).parent / "templates" + app.state.templates = create_jinja_env(template_dir) + + # Custom 404 handler + @app.exception_handler(404) + async def not_found_handler(request: Request, exc): + from artdag_common.middleware import wants_html + if wants_html(request): + from artdag_common import render + return render(app.state.templates, "404.html", request, + user=None, + ) + return JSONResponse({"detail": "Not found"}, status_code=404) + + # Include routers + from .routers import auth, assets, activities, anchors, storage, users, renderers + + # Root routes + app.include_router(auth.router, prefix="/auth", tags=["auth"]) + app.include_router(users.router, tags=["users"]) + + # Feature routers + app.include_router(assets.router, prefix="/assets", tags=["assets"]) + app.include_router(activities.router, prefix="/activities", tags=["activities"]) + app.include_router(anchors.router, prefix="/anchors", tags=["anchors"]) + app.include_router(storage.router, prefix="/storage", tags=["storage"]) + app.include_router(renderers.router, prefix="/renderers", tags=["renderers"]) + + # WebFinger and ActivityPub discovery + from .routers import federation + app.include_router(federation.router, tags=["federation"]) + + return app + + +# Create the default app instance +app = create_app() diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..d88d435 --- /dev/null +++ b/app/config.py @@ -0,0 +1,56 @@ +""" +L2 Server Configuration. + +Environment-based settings for the ActivityPub server. +""" + +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class Settings: + """L2 Server configuration.""" + + # Domain and URLs + domain: str = os.environ.get("ARTDAG_DOMAIN", "artdag.rose-ash.com") + l1_public_url: str = os.environ.get("L1_PUBLIC_URL", "https://celery-artdag.rose-ash.com") + effects_repo_url: str = os.environ.get("EFFECTS_REPO_URL", "https://git.rose-ash.com/art-dag/effects") + ipfs_gateway_url: str = os.environ.get("IPFS_GATEWAY_URL", "") + + # L1 servers + l1_servers: list = None + + # Cookie domain for cross-subdomain auth + cookie_domain: str = None + + # Data directory + data_dir: Path = None + + # JWT settings + jwt_secret: str = os.environ.get("JWT_SECRET", "") + jwt_algorithm: str = "HS256" + access_token_expire_minutes: int = 60 * 24 * 30 # 30 days + + def __post_init__(self): + # Parse L1 servers + l1_str = os.environ.get("L1_SERVERS", "https://celery-artdag.rose-ash.com") + self.l1_servers = [s.strip() for s in l1_str.split(",") if s.strip()] + + # Cookie domain + env_cookie = os.environ.get("COOKIE_DOMAIN") + if env_cookie: + self.cookie_domain = env_cookie + else: + parts = self.domain.split(".") + if len(parts) >= 2: + self.cookie_domain = "." + ".".join(parts[-2:]) + + # Data directory + self.data_dir = Path(os.environ.get("ARTDAG_DATA", str(Path.home() / ".artdag" / "l2"))) + self.data_dir.mkdir(parents=True, exist_ok=True) + (self.data_dir / "assets").mkdir(exist_ok=True) + + +settings = Settings() diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000..d10d063 --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,80 @@ +""" +L2 Server Dependency Injection. + +Provides common dependencies for routes. +""" + +from typing import Optional + +from fastapi import Request, HTTPException, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from .config import settings + +security = HTTPBearer(auto_error=False) + + +def get_templates(request: Request): + """Get Jinja2 templates from app state.""" + return request.app.state.templates + + +async def get_current_user(request: Request) -> Optional[dict]: + """ + Get current user from cookie or header. + + Returns user dict or None if not authenticated. + """ + from auth import verify_token, get_token_claims + + # Try cookie first + token = request.cookies.get("auth_token") + + # Try Authorization header + if not token: + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + + if not token: + return None + + # Verify token + username = verify_token(token) + if not username: + return None + + # Get full claims + claims = get_token_claims(token) + if not claims: + return None + + return { + "username": username, + "actor_id": f"https://{settings.domain}/users/{username}", + "token": token, + **claims, + } + + +async def require_auth(request: Request) -> dict: + """ + Require authentication. + + Raises HTTPException 401 if not authenticated. + """ + user = await get_current_user(request) + if not user: + raise HTTPException(401, "Authentication required") + return user + + +def get_user_from_cookie(request: Request) -> Optional[str]: + """Get username from cookie (for HTML pages).""" + from auth import verify_token + + token = request.cookies.get("auth_token") + if not token: + return None + + return verify_token(token) diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..8365296 --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1,25 @@ +""" +L2 Server Routers. + +Each router handles a specific domain of functionality. +""" + +from . import auth +from . import assets +from . import activities +from . import anchors +from . import storage +from . import users +from . import renderers +from . import federation + +__all__ = [ + "auth", + "assets", + "activities", + "anchors", + "storage", + "users", + "renderers", + "federation", +] diff --git a/app/routers/activities.py b/app/routers/activities.py new file mode 100644 index 0000000..10740c8 --- /dev/null +++ b/app/routers/activities.py @@ -0,0 +1,99 @@ +""" +Activity routes for L2 server. + +Handles ActivityPub activities and outbox. +""" + +import logging +from typing import Optional + +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import JSONResponse + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json + +from ..config import settings +from ..dependencies import get_templates, require_auth, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("") +async def list_activities( + request: Request, + offset: int = 0, + limit: int = 20, +): + """List recent activities.""" + import db + + username = get_user_from_cookie(request) + + activities, total = await db.get_activities_paginated(limit=limit, offset=offset) + has_more = offset + len(activities) < total + + if wants_json(request): + return {"activities": activities, "offset": offset, "limit": limit} + + templates = get_templates(request) + return render(templates, "activities/list.html", request, + activities=activities, + user={"username": username} if username else None, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="activities", + ) + + +@router.get("/{activity_id}") +async def get_activity( + activity_id: str, + request: Request, +): + """Get activity details.""" + import db + + activity = await db.get_activity(activity_id) + if not activity: + raise HTTPException(404, "Activity not found") + + # ActivityPub response + if "application/activity+json" in request.headers.get("accept", ""): + return JSONResponse( + content=activity.get("activity_json", activity), + media_type="application/activity+json", + ) + + if wants_json(request): + return activity + + username = get_user_from_cookie(request) + templates = get_templates(request) + return render(templates, "activities/detail.html", request, + activity=activity, + user={"username": username} if username else None, + active_tab="activities", + ) + + +@router.post("") +async def create_activity( + request: Request, + user: dict = Depends(require_auth), +): + """Create a new activity (internal use).""" + import db + import json + + body = await request.json() + + activity_id = await db.create_activity( + actor=user["actor_id"], + activity_type=body.get("type", "Create"), + object_data=body.get("object"), + ) + + return {"activity_id": activity_id, "created": True} diff --git a/app/routers/anchors.py b/app/routers/anchors.py new file mode 100644 index 0000000..6bfb6a5 --- /dev/null +++ b/app/routers/anchors.py @@ -0,0 +1,203 @@ +""" +Anchor routes for L2 server. + +Handles OpenTimestamps anchoring and verification. +""" + +import logging +from typing import Optional + +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse, FileResponse + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json + +from ..config import settings +from ..dependencies import get_templates, require_auth, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("") +async def list_anchors( + request: Request, + offset: int = 0, + limit: int = 20, +): + """List user's anchors.""" + import db + + username = get_user_from_cookie(request) + if not username: + if wants_json(request): + raise HTTPException(401, "Authentication required") + from fastapi.responses import RedirectResponse + return RedirectResponse(url="/login", status_code=302) + + anchors = await db.get_anchors_paginated(offset=offset, limit=limit) + has_more = len(anchors) >= limit + + if wants_json(request): + return {"anchors": anchors, "offset": offset, "limit": limit} + + templates = get_templates(request) + return render(templates, "anchors/list.html", request, + anchors=anchors, + user={"username": username}, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="anchors", + ) + + +@router.post("") +async def create_anchor( + request: Request, + user: dict = Depends(require_auth), +): + """Create a new timestamp anchor.""" + import db + import anchoring + + body = await request.json() + content_hash = body.get("content_hash") + + if not content_hash: + raise HTTPException(400, "content_hash required") + + # Create OTS timestamp + try: + ots_data = await anchoring.create_timestamp(content_hash) + except Exception as e: + logger.error(f"Failed to create timestamp: {e}") + raise HTTPException(500, f"Timestamping failed: {e}") + + # Save anchor + anchor_id = await db.create_anchor( + username=user["username"], + content_hash=content_hash, + ots_data=ots_data, + ) + + return { + "anchor_id": anchor_id, + "content_hash": content_hash, + "status": "pending", + "message": "Anchor created, pending Bitcoin confirmation", + } + + +@router.get("/{anchor_id}") +async def get_anchor( + anchor_id: str, + request: Request, +): + """Get anchor details.""" + import db + + anchor = await db.get_anchor(anchor_id) + if not anchor: + raise HTTPException(404, "Anchor not found") + + if wants_json(request): + return anchor + + username = get_user_from_cookie(request) + templates = get_templates(request) + return render(templates, "anchors/detail.html", request, + anchor=anchor, + user={"username": username} if username else None, + active_tab="anchors", + ) + + +@router.get("/{anchor_id}/ots") +async def download_ots(anchor_id: str): + """Download OTS proof file.""" + import db + + anchor = await db.get_anchor(anchor_id) + if not anchor: + raise HTTPException(404, "Anchor not found") + + ots_data = anchor.get("ots_data") + if not ots_data: + raise HTTPException(404, "OTS data not available") + + # Return as file download + from fastapi.responses import Response + return Response( + content=ots_data, + media_type="application/octet-stream", + headers={ + "Content-Disposition": f"attachment; filename={anchor['content_hash']}.ots" + }, + ) + + +@router.post("/{anchor_id}/verify") +async def verify_anchor( + anchor_id: str, + request: Request, + user: dict = Depends(require_auth), +): + """Verify anchor status (check Bitcoin confirmation).""" + import db + import anchoring + + anchor = await db.get_anchor(anchor_id) + if not anchor: + raise HTTPException(404, "Anchor not found") + + try: + result = await anchoring.verify_timestamp( + anchor["content_hash"], + anchor["ots_data"], + ) + + # Update anchor status + if result.get("confirmed"): + await db.update_anchor( + anchor_id, + status="confirmed", + bitcoin_block=result.get("block_height"), + confirmed_at=result.get("confirmed_at"), + ) + + if wants_html(request): + if result.get("confirmed"): + return HTMLResponse( + f'Confirmed in block {result["block_height"]}' + ) + return HTMLResponse('Pending confirmation') + + return result + + except Exception as e: + logger.error(f"Verification failed: {e}") + raise HTTPException(500, f"Verification failed: {e}") + + +@router.delete("/{anchor_id}") +async def delete_anchor( + anchor_id: str, + user: dict = Depends(require_auth), +): + """Delete an anchor.""" + import db + + anchor = await db.get_anchor(anchor_id) + if not anchor: + raise HTTPException(404, "Anchor not found") + + if anchor.get("username") != user["username"]: + raise HTTPException(403, "Not authorized") + + success = await db.delete_anchor(anchor_id) + if not success: + raise HTTPException(400, "Failed to delete anchor") + + return {"deleted": True} diff --git a/app/routers/assets.py b/app/routers/assets.py new file mode 100644 index 0000000..cd8f5fd --- /dev/null +++ b/app/routers/assets.py @@ -0,0 +1,244 @@ +""" +Asset management routes for L2 server. + +Handles asset registration, listing, and publishing. +""" + +import logging +from typing import Optional, List + +from fastapi import APIRouter, Request, Depends, HTTPException, Form +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json + +from ..config import settings +from ..dependencies import get_templates, require_auth, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class AssetCreate(BaseModel): + name: str + content_hash: str + ipfs_cid: Optional[str] = None + asset_type: str # image, video, effect, recipe + tags: List[str] = [] + metadata: dict = {} + provenance: Optional[dict] = None + + +class RecordRunRequest(BaseModel): + run_id: str + recipe: str + inputs: List[str] + output_hash: str + ipfs_cid: Optional[str] = None + provenance: Optional[dict] = None + + +@router.get("") +async def list_assets( + request: Request, + offset: int = 0, + limit: int = 20, + asset_type: Optional[str] = None, +): + """List user's assets.""" + import db + + username = get_user_from_cookie(request) + if not username: + if wants_json(request): + raise HTTPException(401, "Authentication required") + from fastapi.responses import RedirectResponse + return RedirectResponse(url="/login", status_code=302) + + assets = await db.get_user_assets(username, offset=offset, limit=limit, asset_type=asset_type) + has_more = len(assets) >= limit + + if wants_json(request): + return {"assets": assets, "offset": offset, "limit": limit, "has_more": has_more} + + templates = get_templates(request) + return render(templates, "assets/list.html", request, + assets=assets, + user={"username": username}, + offset=offset, + limit=limit, + has_more=has_more, + active_tab="assets", + ) + + +@router.post("") +async def create_asset( + req: AssetCreate, + user: dict = Depends(require_auth), +): + """Register a new asset.""" + import db + + asset = await db.create_asset({ + "owner": user["username"], + "name": req.name, + "content_hash": req.content_hash, + "ipfs_cid": req.ipfs_cid, + "asset_type": req.asset_type, + "tags": req.tags or [], + "metadata": req.metadata or {}, + "provenance": req.provenance, + }) + + if not asset: + raise HTTPException(400, "Failed to create asset") + + return {"asset_id": asset.get("name"), "message": "Asset registered"} + + +@router.get("/{asset_id}") +async def get_asset( + asset_id: str, + request: Request, +): + """Get asset details.""" + import db + + username = get_user_from_cookie(request) + + asset = await db.get_asset(asset_id) + if not asset: + raise HTTPException(404, "Asset not found") + + if wants_json(request): + return asset + + templates = get_templates(request) + return render(templates, "assets/detail.html", request, + asset=asset, + user={"username": username} if username else None, + active_tab="assets", + ) + + +@router.delete("/{asset_id}") +async def delete_asset( + asset_id: str, + user: dict = Depends(require_auth), +): + """Delete an asset.""" + import db + + asset = await db.get_asset(asset_id) + if not asset: + raise HTTPException(404, "Asset not found") + + if asset.get("owner") != user["username"]: + raise HTTPException(403, "Not authorized") + + success = await db.delete_asset(asset_id) + if not success: + raise HTTPException(400, "Failed to delete asset") + + return {"deleted": True} + + +@router.post("/record-run") +async def record_run( + req: RecordRunRequest, + user: dict = Depends(require_auth), +): + """Record a run completion and register output as asset.""" + import db + + # Create asset for output + asset = await db.create_asset({ + "owner": user["username"], + "name": f"{req.recipe}-{req.run_id[:8]}", + "content_hash": req.output_hash, + "ipfs_cid": req.ipfs_cid, + "asset_type": "render", + "metadata": { + "run_id": req.run_id, + "recipe": req.recipe, + "inputs": req.inputs, + }, + "provenance": req.provenance, + }) + asset_id = asset.get("name") if asset else None + + # Record run + await db.record_run( + run_id=req.run_id, + username=user["username"], + recipe=req.recipe, + inputs=req.inputs or [], + output_hash=req.output_hash, + ipfs_cid=req.ipfs_cid, + asset_id=asset_id, + ) + + return { + "run_id": req.run_id, + "asset_id": asset_id, + "recorded": True, + } + + +@router.get("/by-run-id/{run_id}") +async def get_asset_by_run_id(run_id: str): + """Get asset by run ID (for L1 cache lookup).""" + import db + + run = await db.get_run(run_id) + if not run: + raise HTTPException(404, "Run not found") + + return { + "run_id": run_id, + "output_hash": run.get("output_hash"), + "ipfs_cid": run.get("ipfs_cid"), + "provenance_cid": run.get("provenance_cid"), + } + + +@router.post("/{asset_id}/publish") +async def publish_asset( + asset_id: str, + request: Request, + user: dict = Depends(require_auth), +): + """Publish asset to IPFS.""" + import db + import ipfs_client + + asset = await db.get_asset(asset_id) + if not asset: + raise HTTPException(404, "Asset not found") + + if asset.get("owner") != user["username"]: + raise HTTPException(403, "Not authorized") + + # Already published? + if asset.get("ipfs_cid"): + return {"ipfs_cid": asset["ipfs_cid"], "already_published": True} + + # Get content from L1 + content_hash = asset.get("content_hash") + for l1_url in settings.l1_servers: + try: + import requests + resp = requests.get(f"{l1_url}/cache/{content_hash}/raw", timeout=30) + if resp.status_code == 200: + # Pin to IPFS + cid = await ipfs_client.add_bytes(resp.content) + if cid: + await db.update_asset(asset_id, {"ipfs_cid": cid}) + return {"ipfs_cid": cid, "published": True} + except Exception as e: + logger.warning(f"Failed to fetch from {l1_url}: {e}") + + raise HTTPException(400, "Failed to publish - content not found on any L1") diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..4691caf --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,223 @@ +""" +Authentication routes for L2 server. + +Handles login, registration, logout, and token verification. +""" + +import hashlib +from datetime import datetime, timezone + +from fastapi import APIRouter, Request, Form, HTTPException, Depends +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from artdag_common import render +from artdag_common.middleware import wants_html + +from ..config import settings +from ..dependencies import get_templates, get_user_from_cookie + +router = APIRouter() +security = HTTPBearer(auto_error=False) + + +@router.get("/login", response_class=HTMLResponse) +async def login_page(request: Request, return_to: str = None): + """Login page.""" + username = get_user_from_cookie(request) + + if username: + templates = get_templates(request) + return render(templates, "auth/already_logged_in.html", request, + user={"username": username}, + ) + + templates = get_templates(request) + return render(templates, "auth/login.html", request, + return_to=return_to, + ) + + +@router.post("/login", response_class=HTMLResponse) +async def login_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + return_to: str = Form(None), +): + """Handle login form submission.""" + from auth import authenticate_user, create_access_token + + if not username or not password: + return HTMLResponse( + '
Username and password are required
' + ) + + user = await authenticate_user(settings.data_dir, username.strip(), password) + if not user: + return HTMLResponse( + '
Invalid username or password
' + ) + + token = create_access_token(user.username, l2_server=f"https://{settings.domain}") + + # Handle return_to redirect + if return_to and return_to.startswith("http"): + separator = "&" if "?" in return_to else "?" + redirect_url = f"{return_to}{separator}auth_token={token.access_token}" + response = HTMLResponse(f''' +
Login successful! Redirecting...
+ + ''') + else: + response = HTMLResponse(''' +
Login successful! Redirecting...
+ + ''') + + response.set_cookie( + key="auth_token", + value=token.access_token, + httponly=True, + max_age=60 * 60 * 24 * 30, + samesite="lax", + secure=True, + ) + return response + + +@router.get("/register", response_class=HTMLResponse) +async def register_page(request: Request): + """Registration page.""" + username = get_user_from_cookie(request) + + if username: + templates = get_templates(request) + return render(templates, "auth/already_logged_in.html", request, + user={"username": username}, + ) + + templates = get_templates(request) + return render(templates, "auth/register.html", request) + + +@router.post("/register", response_class=HTMLResponse) +async def register_submit( + request: Request, + username: str = Form(...), + password: str = Form(...), + password2: str = Form(...), + email: str = Form(None), +): + """Handle registration form submission.""" + from auth import create_user, create_access_token + + if not username or not password: + return HTMLResponse('
Username and password are required
') + + if password != password2: + return HTMLResponse('
Passwords do not match
') + + if len(password) < 6: + return HTMLResponse('
Password must be at least 6 characters
') + + try: + user = await create_user(settings.data_dir, username.strip(), password, email) + except ValueError as e: + return HTMLResponse(f'
{str(e)}
') + + token = create_access_token(user.username, l2_server=f"https://{settings.domain}") + + response = HTMLResponse(''' +
Registration successful! Redirecting...
+ + ''') + response.set_cookie( + key="auth_token", + value=token.access_token, + httponly=True, + max_age=60 * 60 * 24 * 30, + samesite="lax", + secure=True, + ) + return response + + +@router.get("/logout") +async def logout(request: Request): + """Handle logout.""" + import db + import requests + from auth import get_token_claims + + token = request.cookies.get("auth_token") + claims = get_token_claims(token) if token else None + username = claims.get("sub") if claims else None + + if username and token and claims: + # Revoke token in database + token_hash = hashlib.sha256(token.encode()).hexdigest() + expires_at = datetime.fromtimestamp(claims.get("exp", 0), tz=timezone.utc) + await db.revoke_token(token_hash, username, expires_at) + + # Revoke on attached L1 servers + attached = await db.get_user_renderers(username) + for l1_url in attached: + try: + requests.post( + f"{l1_url}/auth/revoke-user", + json={"username": username, "l2_server": f"https://{settings.domain}"}, + timeout=5, + ) + except Exception: + pass + + response = RedirectResponse(url="/", status_code=302) + response.delete_cookie("auth_token") + return response + + +@router.get("/verify") +async def verify_token( + request: Request, + credentials: HTTPAuthorizationCredentials = Depends(security), +): + """ + Verify a token is valid. + + Called by L1 servers to verify tokens during auth callback. + Returns user info if valid, 401 if not. + """ + import db + from auth import verify_token as verify_jwt, get_token_claims + + # Get token from Authorization header or query param + token = None + if credentials: + token = credentials.credentials + else: + # Try Authorization header manually (for clients that don't use Bearer format) + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + + if not token: + raise HTTPException(401, "No token provided") + + # Verify JWT signature and expiry + username = verify_jwt(token) + if not username: + raise HTTPException(401, "Invalid or expired token") + + # Check if token is revoked + claims = get_token_claims(token) + if claims: + token_hash = hashlib.sha256(token.encode()).hexdigest() + if await db.is_token_revoked(token_hash): + raise HTTPException(401, "Token has been revoked") + + return { + "valid": True, + "username": username, + "claims": claims, + } diff --git a/app/routers/federation.py b/app/routers/federation.py new file mode 100644 index 0000000..ab3fb4f --- /dev/null +++ b/app/routers/federation.py @@ -0,0 +1,115 @@ +""" +Federation routes for L2 server. + +Handles WebFinger, nodeinfo, and ActivityPub discovery. +""" + +import logging + +from fastapi import APIRouter, Request, HTTPException +from fastapi.responses import JSONResponse + +from ..config import settings + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/.well-known/webfinger") +async def webfinger(resource: str): + """WebFinger endpoint for actor discovery.""" + import db + + # Parse resource (acct:username@domain) + if not resource.startswith("acct:"): + raise HTTPException(400, "Invalid resource format") + + parts = resource[5:].split("@") + if len(parts) != 2: + raise HTTPException(400, "Invalid resource format") + + username, domain = parts + + if domain != settings.domain: + raise HTTPException(404, "User not on this server") + + user = await db.get_user(username) + if not user: + raise HTTPException(404, "User not found") + + return JSONResponse( + content={ + "subject": resource, + "aliases": [f"https://{settings.domain}/users/{username}"], + "links": [ + { + "rel": "self", + "type": "application/activity+json", + "href": f"https://{settings.domain}/users/{username}", + }, + { + "rel": "http://webfinger.net/rel/profile-page", + "type": "text/html", + "href": f"https://{settings.domain}/users/{username}", + }, + ], + }, + media_type="application/jrd+json", + ) + + +@router.get("/.well-known/nodeinfo") +async def nodeinfo_index(): + """NodeInfo index.""" + return JSONResponse( + content={ + "links": [ + { + "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0", + "href": f"https://{settings.domain}/nodeinfo/2.0", + } + ] + }, + media_type="application/json", + ) + + +@router.get("/nodeinfo/2.0") +async def nodeinfo(): + """NodeInfo 2.0 endpoint.""" + import db + + user_count = await db.count_users() + activity_count = await db.count_activities() + + return JSONResponse( + content={ + "version": "2.0", + "software": { + "name": "artdag", + "version": "1.0.0", + }, + "protocols": ["activitypub"], + "usage": { + "users": {"total": user_count, "activeMonth": user_count}, + "localPosts": activity_count, + }, + "openRegistrations": True, + "metadata": { + "nodeName": "Art-DAG", + "nodeDescription": "Content-addressable media processing with ActivityPub federation", + }, + }, + media_type="application/json", + ) + + +@router.get("/.well-known/host-meta") +async def host_meta(): + """Host-meta endpoint.""" + xml = f''' + + +''' + from fastapi.responses import Response + return Response(content=xml, media_type="application/xrd+xml") diff --git a/app/routers/renderers.py b/app/routers/renderers.py new file mode 100644 index 0000000..4b9edf6 --- /dev/null +++ b/app/routers/renderers.py @@ -0,0 +1,93 @@ +""" +Renderer (L1) management routes for L2 server. + +L1 servers are configured via environment variable L1_SERVERS. +Users connect to renderers to create and run recipes. +""" + +import logging +from typing import Optional + +import requests +from fastapi import APIRouter, Request, Depends, HTTPException +from fastapi.responses import HTMLResponse, RedirectResponse + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json + +from ..config import settings +from ..dependencies import get_templates, require_auth, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def check_renderer_health(url: str, timeout: float = 5.0) -> bool: + """Check if a renderer is healthy.""" + try: + resp = requests.get(f"{url}/", timeout=timeout) + return resp.status_code == 200 + except Exception: + return False + + +@router.get("") +async def list_renderers(request: Request): + """List configured L1 renderers.""" + # Get user if logged in + username = get_user_from_cookie(request) + user = None + if username: + # Get token for connection links + token = request.cookies.get("auth_token", "") + user = {"username": username, "token": token} + + # Build server list with health status + servers = [] + for url in settings.l1_servers: + servers.append({ + "url": url, + "healthy": check_renderer_health(url), + }) + + if wants_json(request): + return {"servers": servers} + + templates = get_templates(request) + return render(templates, "renderers/list.html", request, + servers=servers, + user=user, + active_tab="renderers", + ) + + +@router.get("/{path:path}") +async def renderer_catchall(path: str, request: Request): + """Catch-all for invalid renderer URLs - redirect to list.""" + if wants_json(request): + raise HTTPException(404, "Not found") + return RedirectResponse(url="/renderers", status_code=302) + + +@router.post("") +@router.post("/{path:path}") +async def renderer_post_catchall(request: Request, path: str = ""): + """ + Catch-all for POST requests. + + The old API expected JSON POST to attach renderers. + Now renderers are env-configured, so redirect to the list. + """ + if wants_json(request): + return { + "error": "Renderers are now configured via environment. See /renderers for available servers.", + "servers": settings.l1_servers, + } + + templates = get_templates(request) + return render(templates, "renderers/list.html", request, + servers=[{"url": url, "healthy": check_renderer_health(url)} for url in settings.l1_servers], + user=get_user_from_cookie(request), + error="Renderers are configured by the system administrator. Use the Connect button to access a renderer.", + active_tab="renderers", + ) diff --git a/app/routers/storage.py b/app/routers/storage.py new file mode 100644 index 0000000..f9cdcaf --- /dev/null +++ b/app/routers/storage.py @@ -0,0 +1,254 @@ +""" +Storage provider routes for L2 server. + +Manages user storage backends. +""" + +import logging +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Request, Depends, HTTPException, Form +from fastapi.responses import HTMLResponse +from pydantic import BaseModel + +from artdag_common import render +from artdag_common.middleware import wants_html, wants_json + +from ..config import settings +from ..dependencies import get_templates, require_auth, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +STORAGE_PROVIDERS_INFO = { + "pinata": {"name": "Pinata", "desc": "1GB free, IPFS pinning", "color": "blue"}, + "web3storage": {"name": "web3.storage", "desc": "IPFS + Filecoin", "color": "green"}, + "nftstorage": {"name": "NFT.Storage", "desc": "Free for NFTs", "color": "pink"}, + "infura": {"name": "Infura IPFS", "desc": "5GB free", "color": "orange"}, + "filebase": {"name": "Filebase", "desc": "5GB free, S3+IPFS", "color": "cyan"}, + "storj": {"name": "Storj", "desc": "25GB free", "color": "indigo"}, + "local": {"name": "Local Storage", "desc": "Your own disk", "color": "purple"}, +} + + +class AddStorageRequest(BaseModel): + provider_type: str + config: Dict[str, Any] + capacity_gb: int = 5 + provider_name: Optional[str] = None + + +@router.get("") +async def list_storage(request: Request): + """List user's storage providers.""" + import db + + username = get_user_from_cookie(request) + if not username: + if wants_json(request): + raise HTTPException(401, "Authentication required") + from fastapi.responses import RedirectResponse + return RedirectResponse(url="/login", status_code=302) + + storages = await db.get_user_storage(username) + + if wants_json(request): + return {"storages": storages} + + templates = get_templates(request) + return render(templates, "storage/list.html", request, + storages=storages, + user={"username": username}, + providers_info=STORAGE_PROVIDERS_INFO, + active_tab="storage", + ) + + +@router.post("") +async def add_storage( + req: AddStorageRequest, + user: dict = Depends(require_auth), +): + """Add a storage provider.""" + import db + import storage_providers + + if req.provider_type not in STORAGE_PROVIDERS_INFO: + raise HTTPException(400, f"Invalid provider type: {req.provider_type}") + + # Test connection + provider = storage_providers.create_provider(req.provider_type, { + **req.config, + "capacity_gb": req.capacity_gb, + }) + if not provider: + raise HTTPException(400, "Failed to create provider") + + success, message = await provider.test_connection() + if not success: + raise HTTPException(400, f"Connection failed: {message}") + + # Save + storage_id = await db.add_user_storage( + username=user["username"], + provider_type=req.provider_type, + provider_name=req.provider_name, + config=req.config, + capacity_gb=req.capacity_gb, + ) + + return {"id": storage_id, "message": "Storage provider added"} + + +@router.post("/add", response_class=HTMLResponse) +async def add_storage_form( + request: Request, + provider_type: str = Form(...), + provider_name: Optional[str] = Form(None), + capacity_gb: int = Form(5), + api_key: Optional[str] = Form(None), + secret_key: Optional[str] = Form(None), + api_token: Optional[str] = Form(None), + project_id: Optional[str] = Form(None), + project_secret: Optional[str] = Form(None), + access_key: Optional[str] = Form(None), + bucket: Optional[str] = Form(None), + path: Optional[str] = Form(None), +): + """Add storage via HTML form.""" + import db + import storage_providers + + username = get_user_from_cookie(request) + if not username: + return HTMLResponse('
Not authenticated
', status_code=401) + + # Build config + config = {} + if provider_type == "pinata": + if not api_key or not secret_key: + return HTMLResponse('
Pinata requires API Key and Secret Key
') + config = {"api_key": api_key, "secret_key": secret_key} + elif provider_type in ["web3storage", "nftstorage"]: + if not api_token: + return HTMLResponse(f'
{provider_type} requires API Token
') + config = {"api_token": api_token} + elif provider_type == "infura": + if not project_id or not project_secret: + return HTMLResponse('
Infura requires Project ID and Secret
') + config = {"project_id": project_id, "project_secret": project_secret} + elif provider_type in ["filebase", "storj"]: + if not access_key or not secret_key or not bucket: + return HTMLResponse('
Requires Access Key, Secret Key, and Bucket
') + config = {"access_key": access_key, "secret_key": secret_key, "bucket": bucket} + elif provider_type == "local": + if not path: + return HTMLResponse('
Local storage requires a path
') + config = {"path": path} + else: + return HTMLResponse(f'
Unknown provider: {provider_type}
') + + # Test + provider = storage_providers.create_provider(provider_type, {**config, "capacity_gb": capacity_gb}) + if provider: + success, message = await provider.test_connection() + if not success: + return HTMLResponse(f'
Connection failed: {message}
') + + # Save + storage_id = await db.add_user_storage( + username=username, + provider_type=provider_type, + provider_name=provider_name, + config=config, + capacity_gb=capacity_gb, + ) + + return HTMLResponse(f''' +
Storage provider added!
+ + ''') + + +@router.get("/{storage_id}") +async def get_storage( + storage_id: int, + user: dict = Depends(require_auth), +): + """Get storage details.""" + import db + + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage not found") + + if storage.get("username") != user["username"]: + raise HTTPException(403, "Not authorized") + + return storage + + +@router.delete("/{storage_id}") +async def delete_storage( + storage_id: int, + request: Request, + user: dict = Depends(require_auth), +): + """Delete a storage provider.""" + import db + + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage not found") + + if storage.get("username") != user["username"]: + raise HTTPException(403, "Not authorized") + + success = await db.remove_user_storage(storage_id) + + if wants_html(request): + return HTMLResponse("") + + return {"deleted": True} + + +@router.post("/{storage_id}/test") +async def test_storage( + storage_id: int, + request: Request, + user: dict = Depends(require_auth), +): + """Test storage connectivity.""" + import db + import storage_providers + import json + + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage not found") + + if storage.get("username") != user["username"]: + raise HTTPException(403, "Not authorized") + + config = storage["config"] + if isinstance(config, str): + config = json.loads(config) + + provider = storage_providers.create_provider(storage["provider_type"], { + **config, + "capacity_gb": storage.get("capacity_gb", 5), + }) + + if not provider: + if wants_html(request): + return HTMLResponse('Failed to create provider') + return {"success": False, "message": "Failed to create provider"} + + success, message = await provider.test_connection() + + if wants_html(request): + color = "green" if success else "red" + return HTMLResponse(f'{message}') + + return {"success": success, "message": message} diff --git a/app/routers/users.py b/app/routers/users.py new file mode 100644 index 0000000..1715418 --- /dev/null +++ b/app/routers/users.py @@ -0,0 +1,161 @@ +""" +User profile routes for L2 server. + +Handles ActivityPub actor profiles. +""" + +import logging + +from fastapi import APIRouter, Request, HTTPException +from fastapi.responses import JSONResponse + +from artdag_common import render +from artdag_common.middleware import wants_html + +from ..config import settings +from ..dependencies import get_templates, get_user_from_cookie + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/users/{username}") +async def get_user_profile( + username: str, + request: Request, +): + """Get user profile (ActivityPub actor).""" + import db + + user = await db.get_user(username) + if not user: + raise HTTPException(404, "User not found") + + # ActivityPub response + accept = request.headers.get("accept", "") + if "application/activity+json" in accept or "application/ld+json" in accept: + actor = { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1", + ], + "type": "Person", + "id": f"https://{settings.domain}/users/{username}", + "name": user.get("display_name", username), + "preferredUsername": username, + "inbox": f"https://{settings.domain}/users/{username}/inbox", + "outbox": f"https://{settings.domain}/users/{username}/outbox", + "publicKey": { + "id": f"https://{settings.domain}/users/{username}#main-key", + "owner": f"https://{settings.domain}/users/{username}", + "publicKeyPem": user.get("public_key", ""), + }, + } + return JSONResponse(content=actor, media_type="application/activity+json") + + # HTML profile page + current_user = get_user_from_cookie(request) + assets = await db.get_user_assets(username, limit=12) + + templates = get_templates(request) + return render(templates, "users/profile.html", request, + profile=user, + assets=assets, + user={"username": current_user} if current_user else None, + ) + + +@router.get("/users/{username}/outbox") +async def get_outbox( + username: str, + request: Request, + page: bool = False, +): + """Get user's outbox (ActivityPub).""" + import db + + user = await db.get_user(username) + if not user: + raise HTTPException(404, "User not found") + + actor_id = f"https://{settings.domain}/users/{username}" + + if not page: + # Return collection summary + total = await db.count_user_activities(username) + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "OrderedCollection", + "id": f"{actor_id}/outbox", + "totalItems": total, + "first": f"{actor_id}/outbox?page=true", + }, + media_type="application/activity+json", + ) + + # Return paginated activities + activities = await db.get_user_activities(username, limit=20) + items = [a.get("activity_json", a) for a in activities] + + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "OrderedCollectionPage", + "id": f"{actor_id}/outbox?page=true", + "partOf": f"{actor_id}/outbox", + "orderedItems": items, + }, + media_type="application/activity+json", + ) + + +@router.post("/users/{username}/inbox") +async def receive_inbox( + username: str, + request: Request, +): + """Receive ActivityPub inbox message.""" + import db + + user = await db.get_user(username) + if not user: + raise HTTPException(404, "User not found") + + # TODO: Verify HTTP signature + # TODO: Process activity (Follow, Like, Announce, etc.) + + body = await request.json() + logger.info(f"Received inbox activity for {username}: {body.get('type')}") + + # For now, just acknowledge + return {"status": "accepted"} + + +@router.get("/") +async def home(request: Request): + """Home page.""" + import db + import markdown + + username = get_user_from_cookie(request) + + # Get recent activities + activities, _ = await db.get_activities_paginated(limit=10) + + # Get README if exists + readme_html = "" + try: + from pathlib import Path + readme_path = Path(__file__).parent.parent.parent / "README.md" + if readme_path.exists(): + readme_html = markdown.markdown(readme_path.read_text(), extensions=['tables', 'fenced_code']) + except Exception: + pass + + templates = get_templates(request) + return render(templates, "home.html", request, + user={"username": username} if username else None, + activities=activities, + readme_html=readme_html, + ) diff --git a/app/templates/404.html b/app/templates/404.html new file mode 100644 index 0000000..f6dbdcb --- /dev/null +++ b/app/templates/404.html @@ -0,0 +1,11 @@ +{% extends "base.html" %} + +{% block title %}Not Found - Art-DAG{% endblock %} + +{% block content %} +
+{% endblock %} diff --git a/app/templates/activities/list.html b/app/templates/activities/list.html new file mode 100644 index 0000000..d8a63cf --- /dev/null +++ b/app/templates/activities/list.html @@ -0,0 +1,39 @@ +{% extends "base.html" %} + +{% block title %}Activities - Art-DAG{% endblock %} + +{% block content %} +
+
+

Activities

+
+ + {% if activities %} + + + {% if has_more %} +
+ Load More +
+ {% endif %} + {% else %} +
+

No activities yet.

+
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/anchors/list.html b/app/templates/anchors/list.html new file mode 100644 index 0000000..3626852 --- /dev/null +++ b/app/templates/anchors/list.html @@ -0,0 +1,47 @@ +{% extends "base.html" %} + +{% block title %}Anchors - Art-DAG{% endblock %} + +{% block content %} +
+
+

Bitcoin Anchors

+
+ + {% if anchors %} +
+ {% for anchor in anchors %} +
+
+ {{ anchor.merkle_root[:16] }}... + {% if anchor.confirmed_at %} + Confirmed + {% else %} + Pending + {% endif %} +
+
+ {{ anchor.activity_count or 0 }} activities | Created: {{ anchor.created_at }} +
+ {% if anchor.bitcoin_txid %} +
+ TX: {{ anchor.bitcoin_txid }} +
+ {% endif %} +
+ {% endfor %} +
+ + {% if has_more %} +
+ Load More +
+ {% endif %} + {% else %} +
+

No anchors yet.

+
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/assets/list.html b/app/templates/assets/list.html new file mode 100644 index 0000000..b82f3b1 --- /dev/null +++ b/app/templates/assets/list.html @@ -0,0 +1,58 @@ +{% extends "base.html" %} + +{% block title %}Assets - Art-DAG{% endblock %} + +{% block content %} +
+

Your Assets

+ + {% if assets %} + + + {% if has_more %} +
+ Loading more... +
+ {% endif %} + + {% else %} +
+

No assets yet

+

Create content on an L1 renderer and publish it here.

+
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/auth/already_logged_in.html b/app/templates/auth/already_logged_in.html new file mode 100644 index 0000000..aa94799 --- /dev/null +++ b/app/templates/auth/already_logged_in.html @@ -0,0 +1,12 @@ +{% extends "base.html" %} + +{% block title %}Already Logged In - Art-DAG{% endblock %} + +{% block content %} +
+
+ You are already logged in as {{ user.username }} +
+

Go to home page

+
+{% endblock %} diff --git a/app/templates/auth/login.html b/app/templates/auth/login.html new file mode 100644 index 0000000..0ba4e66 --- /dev/null +++ b/app/templates/auth/login.html @@ -0,0 +1,37 @@ +{% extends "base.html" %} + +{% block title %}Login - Art-DAG{% endblock %} + +{% block content %} +
+

Login

+ +
+ +
+ {% if return_to %} + + {% endif %} + +
+ + +
+ +
+ + +
+ + +
+ +

+ Don't have an account? Register +

+
+{% endblock %} diff --git a/app/templates/auth/register.html b/app/templates/auth/register.html new file mode 100644 index 0000000..8a1837e --- /dev/null +++ b/app/templates/auth/register.html @@ -0,0 +1,45 @@ +{% extends "base.html" %} + +{% block title %}Register - Art-DAG{% endblock %} + +{% block content %} +
+

Register

+ +
+ +
+
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ + +
+ +

+ Already have an account? Login +

+
+{% endblock %} diff --git a/app/templates/base.html b/app/templates/base.html new file mode 100644 index 0000000..380ef13 --- /dev/null +++ b/app/templates/base.html @@ -0,0 +1,47 @@ +{% extends "_base.html" %} + +{% block brand %} +Rose Ash +| +Art-DAG +/ +L2 +{% endblock %} + +{% block cart_mini %} +{% if request and request.state.cart_mini_html %} + {{ request.state.cart_mini_html | safe }} +{% endif %} +{% endblock %} + +{% block nav_tree %} +{% if request and request.state.nav_tree_html %} + {{ request.state.nav_tree_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block auth_menu_mobile %} +{% if request and request.state.auth_menu_html %} + {{ request.state.auth_menu_html | safe }} +{% endif %} +{% endblock %} + +{% block sub_nav %} + +{% endblock %} diff --git a/app/templates/home.html b/app/templates/home.html new file mode 100644 index 0000000..1898981 --- /dev/null +++ b/app/templates/home.html @@ -0,0 +1,42 @@ +{% extends "base.html" %} + +{% block title %}Art-DAG{% endblock %} + +{% block content %} +
+ {% if readme_html %} +
+ {{ readme_html | safe }} +
+ {% else %} +
+

Art-DAG

+

Content-Addressable Media with ActivityPub Federation

+ + {% if not user %} +
+ Login + Register +
+ {% endif %} +
+ {% endif %} + + {% if activities %} +

Recent Activity

+
+ {% for activity in activities %} +
+
+ {{ activity.actor }} + {{ activity.created_at }} +
+
+ {{ activity.type }}: {{ activity.summary or activity.object_type }} +
+
+ {% endfor %} +
+ {% endif %} +
+{% endblock %} diff --git a/app/templates/renderers/list.html b/app/templates/renderers/list.html new file mode 100644 index 0000000..66f93b8 --- /dev/null +++ b/app/templates/renderers/list.html @@ -0,0 +1,52 @@ +{% extends "base.html" %} + +{% block content %} +
+

Renderers

+ +

+ Renderers are L1 servers that process your media. Connect to a renderer to create and run recipes. +

+ + {% if error %} +
+ {{ error }} +
+ {% endif %} + + {% if success %} +
+ {{ success }} +
+ {% endif %} + +
+ {% for server in servers %} +
+
+ + {{ server.url }} + + {% if server.healthy %} + Online + {% else %} + Offline + {% endif %} +
+ +
+ {% else %} +

No renderers configured.

+ {% endfor %} +
+ +
+

Renderers are configured by the system administrator.

+
+
+{% endblock %} diff --git a/app/templates/storage/list.html b/app/templates/storage/list.html new file mode 100644 index 0000000..a3aebf5 --- /dev/null +++ b/app/templates/storage/list.html @@ -0,0 +1,41 @@ +{% extends "base.html" %} + +{% block title %}Storage - Art-DAG{% endblock %} + +{% block content %} +
+
+

Storage Providers

+ + Add Storage + +
+ + {% if storages %} +
+ {% for storage in storages %} +
+
+ {{ storage.name or storage.provider_type }} + + {{ storage.provider_type }} + +
+
+ {% if storage.endpoint %} + {{ storage.endpoint }} + {% elif storage.bucket %} + Bucket: {{ storage.bucket }} + {% endif %} +
+
+ {% endfor %} +
+ {% else %} +
+

No storage providers configured.

+ Add one now +
+ {% endif %} +
+{% endblock %} diff --git a/artdag-client.tar.gz b/artdag-client.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..a4ec7f4490aabd26f65cc0b6c59bf4fed443139d GIT binary patch literal 7982 zcma*a({>z+0szq1w$a#Ztj20=Hdf;aCbrSoR%6?0Cbn(coU_*bg1g@~c{BpV|3F^~ z+qtguCB9Y}2U>{etTkH5DP6?aW((HtPqZy16_Z%CwD_>1vZ&(G;gr)b2ySJA-yfkV zAhC)Mw2hejWSOz7xrsi-iy+&a1zgzH8qewCp8-3or|!4ltEv0AFVT<1%L{;eE+Egt zTTpPiLvUSRd1q6p>|PQ;hFNxzL03+Ue0wl}-c?+>QFy(5djG(dbx*w|^gQ^JtqZ5| z9Do(7k2>N^9wO|*sZ_Q}wAb|McN-5eEoJW{$vV$&KL0eK5R@J@W@Nus+oZn`3I+I^rA6 zdW5Fy#+=~ItF&p%i)_s4c0cInrU+&3{Jt`vjZGR)t68PY|TbZb`ty@v3? zK3B8v^tx$Jx78v<3h@C?&SY?YFZHW)E2BAON&Q2vG9VQiJh6Xk!C0yn;Ps!M^f(4M z#1?NZ-h`aM-X9u(tFhjJ@GhL_3PBPtxdEx1)HB@AiT47&ehv(mH`9v`RBLEi-%=vp zsn`_PjSSJgC6&lWslW-dFblV# zHxOhz(;#Xt`l^R-Zr&F@ODV&XRznwXO(`%vOkhU~jJ5i6x-AIqPY6Xb0p-lFKPs39 z{l7!%r@2%-yvkeyAw#w#DQ|>i0|b8}v7JZUUtPYB6r&Xw$dO&_vPR=k^^AjsadyAL z6?Nag#T{{ZQu<4CnYWfU*7p+>Xs*ic7|-mdIdtiB8mZM^XJ-0r&UndI{AxKjFC|QA zHM8M5Nzwl{uJxN{{stHIu?lt407G;psr;ECMIPjcHm`i%@X&$a^?H)E6}|%6KwC9?#vgytMM%HgDEnHO^;~ z71DJGl#?ODFbsIN6u<~OoQ2h6Gk+wT_$I`><%AUo2X$u{&<~3hQ~=J_<~LL))fMb? zwuQ)+rQ*MbEa^cM?;&j~M}+xq){95iXLT`9Yz2eST$8A&OAKW2(bvbfBOi!20l-d< zIOEhdt9T6LWtt@=918pxp|p2@?S8g7nRH?6@tyD|nX!9-3LqkiaytP&X+7fg=C4#W zauFzxdJtGkvHa;~00i}UWfofOgVr7(6e4J~or6;cO$oj6cW?rKC(~eC2fDdH3a46x zGKpU!NpWX}p9U+GiJ#^|AvtC+`d*Fj($z&}Q#AD{{H1WD=|z@|FBrK%^l!wn9=T%; zqPJJ%R?Z5O3&bzvXP%PPYfo7&V~>Rjk}rqCC@Y+-lsi!Jt)6^NI*~^?=1^KV$am<= z-=kpVe}t$(DFXCQ49cq<|4crLDds>HjE^05iDF@)E2Ri34PcCS->s5qXx<_>zkEla zS8J8#M$LVv7)B&Otz@w$e>R)`87bPn_QyZ%ci~VLT()`4!nEG_P#a7x4gfCuoSE%r z&g*@*g!qLDWvOnIta(^)$0@z{&pg!z@-lpQg?!8SoD=z}h?u5Y>TzP$Tkht^A&DlA zfGCNKR6>#~>__rwMl$>N9Cx0a#E5uO2hthtQm2+g$8@`N-=A?DW1ncI5XGZWMt9dy zbMIG9iNU0Mwf&%YCVBGdL24$F!FWhr6V~n(;mGa+(fRD84`d)&aAcO^?;7L`g9kQo z86b}1F~b6e=G0Wmd6~vO($`~BKc5p@a+$;<7!sBq3*}@ADjpCDkUug>!W?!=7Pm93 ztC#6pTPb5zyfv^Vf^S}jTE!98vV7;U>)8+f9Q1(qb!I0U1UMvR^$~^(X2C8Oo;I_s z?FdWZ&oIOFe>`N$?A%OtqSFN%KW~B8;8{ZnssFl5%1N3@V-!=|s)b^If;BFbyLM%w zD13gDwR`s%KB^h zzt)xwUvIhnJ+uy4Fa)UN?zy3UfS7WgI@hguFWuJ@UKG^)rvezRu8~I#U-6-m`tV7S zzxfVm-~?t%s8+hFu_^zu?Ry?S=UN?+eQ<6 zI3g611?$sXu*+D621fsYRP-3xBgqc-zMahueNiXa@4L4Y7>@QGx(?m-H2YaD~ow;o5-WO18CzXV|eo?jPJ8e4Y5`XErP&{@K>#;BNPMR z%P>2eBHB}-c8l^0rNwWd#_0mp#n(vCr+Ensw0oD1n6_i)z3E&p*d8ec-)+yOf}?`x z!0pTR2ziaW5xr6Pq>-Fa82KU0vFSNakw{kjZ$eY_VAuAs@>rq7b(@5LcwicmWU;$P zdntVq|L~_vEy%-o14<0>9eT=_&^eTfc!Q|-obChnR6vf7|Lw{6j8HEVb8bT+#ByH= z+Uh>ly;D6Skb>!^id7N9lj0oHECt{gINp#~d~j*|aieLBBVBBiaBgPt>}fOc7K+(~ zCrr(RCZ$T#1djQrlS^6YsLfg3^cww`>nuac)UCZ+Yh0z!q*sDV(Xk@QDc7J^+XdM$ zeuu7Htbq6(ArEWtvkO^xV$lhay?jpokFA4DpU(MODa8~>;)6Fo-F_}%R6~v_*pKio z#>{}j-xEBWN64YZkh!rJK#K|EXNXZ}AaAUS$dHSx9*XV=2O6WpzAFx_R{i9%9-IWD zb$e;3RdtLcK%CK>K9}V=Tqm9F|M*bBO_W5^f7oGJSZ2{o4gWe#2i!5|rHi zo<&3CqEp5R;qK@y$-?_+S#gP#0^6PEHIMY(1b&(v?I*@f5@-6pu&*W>QLQI#c)rW` z3nHhdb-tEp(jP)tp?j72k6DQW`Oi&8o&`;ZgqQ97JsKX$+}^UU<_M$37%$a>5-Uj~ zj8@Z=P|b3di*KPig^9;GSQx@@g%<%emVc8C`!Bh20vhIZy}C6ZEVsp^xJBXfZOS8; zP66!?h*+#f`^iaiE(p^6FPyOroX2V(EgokuX29jHbH`S%dkHCbcZL_*4mL zQ3C1M4JChL?8ym0VKC+$%stm2Nyia`;Lkkh#a9pV(Pt2F;bU!HA&DdMQLSBQ(`?!p zc4nWX1RI%G3SS7byhd@qlohSgOb6?H?g_Xg(&iKKP|;8d2Xs3LA~YPr=+W z2?)?S@4Hg{lF8cZkj-ur?ywMemnvn53lg&zqe36D{d`pr8<^v_A$-~$exegW(xt{jJQ8vP&kv%Gnf}LlO)~&~5Fauqgj*BMzI91fr`EPsZ*Rv66NVzsyEQ6o*oT53c zg8NYFl+-mRVh1Q4dt6}cy-lftFO=L{N7*dGWN_e?fnRt*{Rb!94r+sh*2R=Hc)a?$ z4!$;ssgJ!Ezfudqs|8Arh)=OB6NlfSnAf?=7!Z~qt9GIurAVpRk(()dSjL2KuE#kWN%Vh!k zXNm=D{yGI6w!hX=^kE4FVz72FHI&B5WjMAXrg!7>`j0#t3B@j#h76jPnMK;2Z)x{C{bnC`y&uQ%zyS-QuX<7iY%(M z^j%J4dCGR)XVJy0K5BC5E(>Ua#zZ z0GI~&W4%h$MrM&g8^AbsDKnf!_3@7vNwWJNc4#ybJ-|(5kaVX)l~!n9u%uLp}pWM2`r1CgUns|^Np{qbL+wLk>+Q(e%X;=;45Uen*^_kxDHyG>emu$P_ z&|w~Q2HMo02UmW@TKpwR(lrn|RcXFIhj(dw{e^f?c!ldv@C{EYR=7Vdvb`Jc;1vr)Yja^at^) ztKsoqH`8P{{-tv$IJZAZzJ@nqgtY+E{+=matr5qo6eo8 zc~M~<^XdN7w?|c6GKp%~`-?jgx!a5iKI?QXBQq6e_-ru#@d$4uZJ*XA|EPal-4z~& zUK-=qO&&I_=KMIRh^O#aqAh(U-u{dBe%!H`@hyygBe@^9wC~?D&GBH)+G}0vr&jmg zD{uOlT*qRy2VVOxhKF1hvFoeK1hC|t`?sjg z6bP&zv0I>E(Ft!3Jy7MPBC1M0zN4wIXus-Q(t|NIhyKw?RgwQ9xdglia6Q44n{Nhjd3~rUz@EL)+vE>f%YH~UA}CruqZ#pLjBFxL)GFy zIfq0A%tI!iJ`l0e3Z`2(xZkg|q!MHF(NkCiu5~Ke+koWz%O4I~qrYL=Tr$Kn!ALPl zmc{zKg&F=Ju|k7p2R+rNS53FqP5as9>A73E7c8m&6OwIkZH7_Jaf7Ch^!%_DmW5eg}smE-mSLoV1Nr zE$A*qpr9@DP*IJqV1P3gc{b%q&JxXZX#kmHZ&Mbm*9fFV={%?f^CE&(ndefqFq0c9 zj%r;E4#Z>MCmTX7dI#8NDNPPr`(%g~&W_qF8)5P_O=)w;0G=gz5W{zAuLIXxqhGa0 zzl;1ZTS@fVP%8;=8O%Wh>)IyKCSi_@yKb*CM3hmda~pTl*~9w1B>kqYGvT=jOZ`*f)a4m|=?TXdT+-dfiz;Yk4AR#6VaWIE&CnvuR z!`4({91X|g#(r)3uLmmqZ%5jr$sU_8wCE6{1veSjS|)p&?&F9_8?!b@>MdM4Q*fE; zG0J(~UF6k|!APmp!}3jJ@=aFl#jdxR$J!D`WX9es3r~w;6)$>3FvZBSgVpNY58-A{O8lU-fX7 zdi>RP6oq0_QE z9KlJpAYHYU8)dGU*elGT04J0_MFZ>gs%~?0s~{t#YxFuQg!MWBa@teLH6LAB+<4!wg)dyo>9 z>nhCCGT`GaN13LPCYS9S54uX_;cL&IP!-B&q?c^yHO;U`cU!xAH74U~1POL&4jA@y zS1ZIPtrE5*3b1u1>tA3jW707}(z*ebB#(tB@eI@2yZc8Ild8iChH{5@68g!S8bX ziDs{2rih8_^fd0okA5Cz+rsEVJr-i3zrWP{2>;%PRB}arD4%6Qk@gW7r`h^%3LoC7 zg>;g~PAyiK72VUF%Ne@5t|T4SU^bTGACT&#N-;UydG*;e4V`Lv>t@7terrvruu4u< zSz7xZKCJMom7k5wqxydL+Vfji*C3ox*Uon0vqCrsS-`62x#|w|Lb?y~tdR(3$Ba5G zn&KoZ!5rvS&bi2e#s(YF337{6!W53;@8-krm7klXN5R&!&;rxV|FFm(7J&-3EO}pL z8_Xvy%O;p4t9dar1e}22gWEd_xiN|H9$Z$TJ(~&v>)dy5N6IyB?|iY1O9VYKyM3HB zk{5|dD3??qcjUi^(+Po18;AU@(v58 z^pgixLf?yKC~}Hbx09WDkXwt7owDF(mm{1)liqqwEV~XhA&AiD8KQTYmxaq{>a3*V zzl{6Tr;JH~c?oJ`7mCot7|5V=hHTp(7^cCjjAL+Q0^nNWo-UAxRzj{LeJxl0AE{@$ z3Y&o`kQxjJMD8NUeTxQo5qP0=`w*PDZ+wFjm&spUgQXCMXd)XI84tFM@&71YjU45KZ^M5eJ0jwx)nEZ`y2+GLrfPH7eljo zPE#A3mey{EMd9nVxySDtJ{?+8t1vP1s>cZ-;HHijPCWY{GTBA>FZB1uaz_ebGR9IU z8kASCt>h8rGh}|xPL|u7hSt9rrEV|x#jHz$+!3L6ugGDzXypc7V5zcuu>RI@)X52{ znndr1v*0+Z%}n50>G9y+$Q5M;hdpkQ1C^D#KHV1>_7171*juOm^U+8ff7%gfR*xIx zex8G*x?q@!m%oW~-qpNAw<%y;4O1{k*Q}|e6 zDcb+7re5+`Z1-( zFF=szH)T?%LH(_vApSp63NY|$*jEUFf$G`yTf=+2%XB`yE%zfF5t?=F^)E3bU-opU zbrd+Y5@^Mf*`!VpPTC+OKAW< zHGVz$%P$zB*v}#fLL-H4f_R!is(g-rk?_(L1D@F}_(54A#;f(VyhO@*QO`TalSHwi zUoZP4W*S`NW&p803Jl_UdOc!hnMD!mtI#xUD!#L>Q}hv;LWZ=uchkhu#;+qdW2}cY*4#=g69Z6+scm2e@nWIDb*Tt>)C>keXJG!(G>2A?pJ@kd1s%_h^s1w>t|D z&qBg7kJz4-?yioGK_lcKiAyZBTi)epLkb&sLK6>+f>meX#erzr>i_uHYzdmxA zt~y%um-G$%2OawROVJ`L7At&vgnKvVWt^8}D-4pn{E+INJp>?oY%|L1hw?%Iu$rj?A?^lr{~f1z|FhJ_4?J_v2$H;d+pBX}*&5UM;-$sNlq zq~<{!q)0V;j`NCWR!s7*DQLev{e@7w*~bvtT$R9g=ppm9`v7Ra;5~dYT_|sAW^mHk z^`e|hU}B%$0ytc}IXIfUKx&S93m~@i#V|gvd-)P{OTgO<4ku;U1hwR&ZRL!ZZRh%R zNg`YyRC{}7k!?TYXJ~)1Ky&uSv7J|Uot;Kp0lZk)Sk|<07?3&fw}%#{Cb`#&6%RzM z@Js5Ag09|Z`tUMYzQ44L5n|u0+JyU9(PSF&YDNFaQZyV}{!uzoL^(Vm1mQu-7d*g_ zzuMlesxJ?1>Q-B*|MOmtBUx8`P>`CVxhYY zhiXpS!AwP!ojw(FuusfK$djFL^uSjA3}XU}5oZOJh|rp%?bj2#Wu&`W%Mt?<8tC zTT@4&r?Bz}@rBWBeG2spCmKt{Shf)bR;m@v#&Y=^7uY@1=#%?!gFA~URz~dNndRJq fs@%m!FRke_UExZr$NxJW$6!bxm-!b22m<0iX<@vG literal 0 HcmV?d00001 diff --git a/auth.py b/auth.py new file mode 100644 index 0000000..a56e13c --- /dev/null +++ b/auth.py @@ -0,0 +1,213 @@ +""" +Authentication for Art DAG L2 Server. + +User registration, login, and JWT tokens. +""" + +import os +import secrets +from datetime import datetime, timezone, timedelta +from pathlib import Path +from typing import Optional + +import bcrypt +from jose import JWTError, jwt +from pydantic import BaseModel + +import db + +# JWT settings +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_DAYS = 30 + + +def load_jwt_secret() -> str: + """Load JWT secret from Docker secret, env var, or generate.""" + # Try Docker secret first + secret_path = Path("/run/secrets/jwt_secret") + if secret_path.exists(): + return secret_path.read_text().strip() + + # Try environment variable + if os.environ.get("JWT_SECRET"): + return os.environ["JWT_SECRET"] + + # Generate one (tokens won't persist across restarts!) + print("WARNING: No JWT_SECRET configured. Tokens will be invalidated on restart.") + return secrets.token_hex(32) + + +SECRET_KEY = load_jwt_secret() + + +class User(BaseModel): + """A registered user.""" + username: str + password_hash: str + created_at: str + email: Optional[str] = None + + +class UserCreate(BaseModel): + """Request to register a user.""" + username: str + password: str + email: Optional[str] = None + + +class UserLogin(BaseModel): + """Request to login.""" + username: str + password: str + + +class Token(BaseModel): + """JWT token response.""" + access_token: str + token_type: str = "bearer" + username: str + expires_at: str + + +# Keep DATA_DIR for keys (RSA keys still stored as files) +DATA_DIR = Path(os.environ.get("ARTDAG_DATA", str(Path.home() / ".artdag" / "l2"))) + + +def hash_password(password: str) -> str: + """Hash a password (truncate to 72 bytes for bcrypt).""" + # Truncate to 72 bytes (bcrypt limit) + pw_bytes = password.encode('utf-8')[:72] + return bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode('utf-8') + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash.""" + pw_bytes = plain_password.encode('utf-8')[:72] + return bcrypt.checkpw(pw_bytes, hashed_password.encode('utf-8')) + + +async def create_user(data_dir: Path, username: str, password: str, email: Optional[str] = None) -> User: + """Create a new user with ActivityPub keys.""" + from keys import generate_keypair + + if await db.user_exists(username): + raise ValueError(f"Username already exists: {username}") + + password_hash = hash_password(password) + user_data = await db.create_user(username, password_hash, email) + + # Generate ActivityPub keys for this user + generate_keypair(data_dir, username) + + # Convert datetime to ISO string if needed + created_at = user_data.get("created_at") + if hasattr(created_at, 'isoformat'): + created_at = created_at.isoformat() + + return User( + username=username, + password_hash=password_hash, + created_at=created_at, + email=email + ) + + +async def authenticate_user(data_dir: Path, username: str, password: str) -> Optional[User]: + """Authenticate a user by username and password.""" + user_data = await db.get_user(username) + + if not user_data: + return None + + if not verify_password(password, user_data["password_hash"]): + return None + + # Convert datetime to ISO string if needed + created_at = user_data.get("created_at") + if hasattr(created_at, 'isoformat'): + created_at = created_at.isoformat() + + return User( + username=user_data["username"], + password_hash=user_data["password_hash"], + created_at=created_at, + email=user_data.get("email") + ) + + +def create_access_token(username: str, l2_server: str = None, l1_server: str = None) -> Token: + """Create a JWT access token. + + Args: + username: The username + l2_server: The L2 server URL (e.g., https://artdag.rose-ash.com) + Required for L1 to verify tokens with the correct L2. + l1_server: Optional L1 server URL to scope the token to. + If set, token only works for this specific L1. + """ + expires = datetime.now(timezone.utc) + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS) + + payload = { + "sub": username, + "username": username, # Also include as username for compatibility + "exp": expires, + "iat": datetime.now(timezone.utc) + } + + # Include l2_server so L1 knows which L2 to verify with + if l2_server: + payload["l2_server"] = l2_server + + # Include l1_server to scope token to specific L1 + if l1_server: + payload["l1_server"] = l1_server + + token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) + + return Token( + access_token=token, + username=username, + expires_at=expires.isoformat() + ) + + +def verify_token(token: str) -> Optional[str]: + """Verify a JWT token, return username if valid.""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username = payload.get("sub") + return username + except JWTError: + return None + + +def get_token_claims(token: str) -> Optional[dict]: + """Decode token and return all claims. Returns None if invalid.""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + return payload + except JWTError: + return None + + +async def get_current_user(data_dir: Path, token: str) -> Optional[User]: + """Get current user from token.""" + username = verify_token(token) + if not username: + return None + + user_data = await db.get_user(username) + if not user_data: + return None + + # Convert datetime to ISO string if needed + created_at = user_data.get("created_at") + if hasattr(created_at, 'isoformat'): + created_at = created_at.isoformat() + + return User( + username=user_data["username"], + password_hash=user_data["password_hash"], + created_at=created_at, + email=user_data.get("email") + ) diff --git a/db.py b/db.py new file mode 100644 index 0000000..205271d --- /dev/null +++ b/db.py @@ -0,0 +1,1215 @@ +""" +Database module for Art DAG L2 Server. + +Uses asyncpg for async PostgreSQL access with connection pooling. +""" + +import json +import os +from datetime import datetime, timezone +from typing import Optional +from contextlib import asynccontextmanager + +import asyncpg + +# Connection pool (initialized on startup) + + +def _parse_timestamp(ts) -> datetime: + """Parse a timestamp string or datetime to datetime object.""" + if ts is None: + return datetime.now(timezone.utc) + if isinstance(ts, datetime): + return ts + # Parse ISO format string + if isinstance(ts, str): + if ts.endswith('Z'): + ts = ts[:-1] + '+00:00' + return datetime.fromisoformat(ts) + return datetime.now(timezone.utc) + + +_pool: Optional[asyncpg.Pool] = None + +# Configuration from environment +DATABASE_URL = os.environ.get("DATABASE_URL") +if not DATABASE_URL: + raise RuntimeError("DATABASE_URL environment variable is required") + +# Schema for database initialization +SCHEMA = """ +-- Users table +CREATE TABLE IF NOT EXISTS users ( + username VARCHAR(255) PRIMARY KEY, + password_hash VARCHAR(255) NOT NULL, + email VARCHAR(255), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Assets table +CREATE TABLE IF NOT EXISTS assets ( + name VARCHAR(255) PRIMARY KEY, + content_hash VARCHAR(128) NOT NULL, + ipfs_cid VARCHAR(128), + asset_type VARCHAR(50) NOT NULL, + tags JSONB DEFAULT '[]'::jsonb, + metadata JSONB DEFAULT '{}'::jsonb, + url TEXT, + provenance JSONB, + description TEXT, + origin JSONB, + owner VARCHAR(255) NOT NULL REFERENCES users(username), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ +); + +-- Activities table (activity_id is content-addressable run_id hash) +CREATE TABLE IF NOT EXISTS activities ( + activity_id VARCHAR(64) PRIMARY KEY, + activity_type VARCHAR(50) NOT NULL, + actor_id TEXT NOT NULL, + object_data JSONB NOT NULL, + published TIMESTAMPTZ NOT NULL, + signature JSONB, + anchor_root VARCHAR(64) -- Merkle root this activity is anchored to +); + +-- Anchors table (Bitcoin timestamps via OpenTimestamps) +CREATE TABLE IF NOT EXISTS anchors ( + id SERIAL PRIMARY KEY, + merkle_root VARCHAR(64) NOT NULL UNIQUE, + tree_ipfs_cid VARCHAR(128), + ots_proof_cid VARCHAR(128), + activity_count INTEGER NOT NULL, + first_activity_id VARCHAR(64), + last_activity_id VARCHAR(64), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + confirmed_at TIMESTAMPTZ, + bitcoin_txid VARCHAR(64) +); + +-- Followers table +CREATE TABLE IF NOT EXISTS followers ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL REFERENCES users(username), + acct VARCHAR(255) NOT NULL, + url TEXT NOT NULL, + public_key TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(username, acct) +); + +-- User's attached L1 renderers +CREATE TABLE IF NOT EXISTS user_renderers ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL REFERENCES users(username), + l1_url TEXT NOT NULL, + attached_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(username, l1_url) +); + +-- Revoked tokens (for federated logout) +CREATE TABLE IF NOT EXISTS revoked_tokens ( + token_hash VARCHAR(64) PRIMARY KEY, + username VARCHAR(255) NOT NULL, + revoked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL +); + +-- User storage providers (IPFS pinning services, local storage, etc.) +-- Users can have multiple configs of the same provider type +CREATE TABLE IF NOT EXISTS user_storage ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL REFERENCES users(username), + provider_type VARCHAR(50) NOT NULL, -- 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + provider_name VARCHAR(255), -- User-friendly name + description TEXT, -- User description to distinguish configs + config JSONB NOT NULL DEFAULT '{}', -- API keys, endpoints, paths + capacity_gb INTEGER NOT NULL, -- Total capacity user is contributing + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Track what's stored where +CREATE TABLE IF NOT EXISTS storage_pins ( + id SERIAL PRIMARY KEY, + content_hash VARCHAR(64) NOT NULL, + storage_id INTEGER NOT NULL REFERENCES user_storage(id) ON DELETE CASCADE, + ipfs_cid VARCHAR(128), + pin_type VARCHAR(20) NOT NULL, -- 'user_content', 'donated', 'system' + size_bytes BIGINT, + pinned_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(content_hash, storage_id) +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at); +CREATE INDEX IF NOT EXISTS idx_assets_content_hash ON assets(content_hash); +CREATE INDEX IF NOT EXISTS idx_assets_owner ON assets(owner); +CREATE INDEX IF NOT EXISTS idx_assets_created_at ON assets(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_assets_tags ON assets USING GIN(tags); +CREATE INDEX IF NOT EXISTS idx_activities_actor_id ON activities(actor_id); +CREATE INDEX IF NOT EXISTS idx_activities_published ON activities(published DESC); +CREATE INDEX IF NOT EXISTS idx_activities_anchor ON activities(anchor_root); +CREATE INDEX IF NOT EXISTS idx_anchors_created ON anchors(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_followers_username ON followers(username); +CREATE INDEX IF NOT EXISTS idx_revoked_tokens_expires ON revoked_tokens(expires_at); +CREATE INDEX IF NOT EXISTS idx_user_storage_username ON user_storage(username); +CREATE INDEX IF NOT EXISTS idx_storage_pins_hash ON storage_pins(content_hash); +CREATE INDEX IF NOT EXISTS idx_storage_pins_storage ON storage_pins(storage_id); + +-- Add source URL columns to assets if they don't exist +DO $$ BEGIN + ALTER TABLE assets ADD COLUMN source_url TEXT; +EXCEPTION WHEN duplicate_column THEN NULL; +END $$; + +DO $$ BEGIN + ALTER TABLE assets ADD COLUMN source_type VARCHAR(50); +EXCEPTION WHEN duplicate_column THEN NULL; +END $$; + +-- Add description column to user_storage if it doesn't exist +DO $$ BEGIN + ALTER TABLE user_storage ADD COLUMN description TEXT; +EXCEPTION WHEN duplicate_column THEN NULL; +END $$; +""" + + +async def init_pool(): + """Initialize the connection pool and create tables. Call on app startup.""" + global _pool + _pool = await asyncpg.create_pool( + DATABASE_URL, + min_size=2, + max_size=10, + command_timeout=60 + ) + # Create tables if they don't exist + async with _pool.acquire() as conn: + await conn.execute(SCHEMA) + + +async def close_pool(): + """Close the connection pool. Call on app shutdown.""" + global _pool + if _pool: + await _pool.close() + _pool = None + + +def get_pool() -> asyncpg.Pool: + """Get the connection pool.""" + if _pool is None: + raise RuntimeError("Database pool not initialized") + return _pool + + +@asynccontextmanager +async def get_connection(): + """Get a connection from the pool.""" + async with get_pool().acquire() as conn: + yield conn + + +@asynccontextmanager +async def transaction(): + """ + Get a connection with an active transaction. + + Usage: + async with db.transaction() as conn: + await create_asset_tx(conn, asset1) + await create_asset_tx(conn, asset2) + await create_activity_tx(conn, activity) + # Commits on exit, rolls back on exception + """ + async with get_pool().acquire() as conn: + async with conn.transaction(): + yield conn + + +# ============ Users ============ + +async def get_user(username: str) -> Optional[dict]: + """Get user by username.""" + async with get_connection() as conn: + row = await conn.fetchrow( + "SELECT username, password_hash, email, created_at FROM users WHERE username = $1", + username + ) + if row: + return dict(row) + return None + + +async def get_all_users() -> dict[str, dict]: + """Get all users as a dict indexed by username.""" + async with get_connection() as conn: + rows = await conn.fetch( + "SELECT username, password_hash, email, created_at FROM users ORDER BY username" + ) + return {row["username"]: dict(row) for row in rows} + + +async def create_user(username: str, password_hash: str, email: Optional[str] = None) -> dict: + """Create a new user.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """INSERT INTO users (username, password_hash, email) + VALUES ($1, $2, $3) + RETURNING username, password_hash, email, created_at""", + username, password_hash, email + ) + return dict(row) + + +async def user_exists(username: str) -> bool: + """Check if user exists.""" + async with get_connection() as conn: + result = await conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)", + username + ) + return result + + +# ============ Assets ============ + +async def get_asset(name: str) -> Optional[dict]: + """Get asset by name.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE name = $1""", + name + ) + if row: + return _parse_asset_row(row) + return None + + +async def get_asset_by_hash(content_hash: str) -> Optional[dict]: + """Get asset by content hash.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE content_hash = $1""", + content_hash + ) + if row: + return _parse_asset_row(row) + return None + + +async def get_asset_by_run_id(run_id: str) -> Optional[dict]: + """Get asset by run_id stored in provenance.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE provenance->>'run_id' = $1""", + run_id + ) + if row: + return _parse_asset_row(row) + return None + + +async def get_all_assets() -> dict[str, dict]: + """Get all assets as a dict indexed by name.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets ORDER BY created_at DESC""" + ) + return {row["name"]: _parse_asset_row(row) for row in rows} + + +async def get_assets_paginated(limit: int = 100, offset: int = 0) -> tuple[list[tuple[str, dict]], int]: + """Get paginated assets, returns (list of (name, asset) tuples, total_count).""" + async with get_connection() as conn: + total = await conn.fetchval("SELECT COUNT(*) FROM assets") + rows = await conn.fetch( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets ORDER BY created_at DESC LIMIT $1 OFFSET $2""", + limit, offset + ) + return [(row["name"], _parse_asset_row(row)) for row in rows], total + + +async def get_assets_by_owner(owner: str) -> dict[str, dict]: + """Get all assets owned by a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE owner = $1 ORDER BY created_at DESC""", + owner + ) + return {row["name"]: _parse_asset_row(row) for row in rows} + + +async def create_asset(asset: dict) -> dict: + """Create a new asset.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """INSERT INTO assets (name, content_hash, ipfs_cid, asset_type, tags, metadata, + url, provenance, description, origin, owner, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING *""", + asset["name"], + asset["content_hash"], + asset.get("ipfs_cid"), + asset["asset_type"], + json.dumps(asset.get("tags", [])), + json.dumps(asset.get("metadata", {})), + asset.get("url"), + json.dumps(asset.get("provenance")) if asset.get("provenance") else None, + asset.get("description"), + json.dumps(asset.get("origin")) if asset.get("origin") else None, + asset["owner"], + _parse_timestamp(asset.get("created_at")) + ) + return _parse_asset_row(row) + + +async def update_asset(name: str, updates: dict) -> Optional[dict]: + """Update an existing asset.""" + # Build dynamic UPDATE query + set_clauses = [] + values = [] + idx = 1 + + for key, value in updates.items(): + if key in ("tags", "metadata", "provenance", "origin"): + set_clauses.append(f"{key} = ${idx}") + values.append(json.dumps(value) if value is not None else None) + else: + set_clauses.append(f"{key} = ${idx}") + values.append(value) + idx += 1 + + set_clauses.append(f"updated_at = ${idx}") + values.append(datetime.now(timezone.utc)) + idx += 1 + + values.append(name) # WHERE clause + + async with get_connection() as conn: + row = await conn.fetchrow( + f"""UPDATE assets SET {', '.join(set_clauses)} + WHERE name = ${idx} RETURNING *""", + *values + ) + if row: + return _parse_asset_row(row) + return None + + +async def asset_exists(name: str) -> bool: + """Check if asset exists.""" + async with get_connection() as conn: + return await conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM assets WHERE name = $1)", + name + ) + + +def _parse_asset_row(row) -> dict: + """Parse a database row into an asset dict, handling JSONB fields.""" + asset = dict(row) + # Convert datetime to ISO string + if asset.get("created_at"): + asset["created_at"] = asset["created_at"].isoformat() + if asset.get("updated_at"): + asset["updated_at"] = asset["updated_at"].isoformat() + # Ensure JSONB fields are dicts (handle string case) + for field in ("tags", "metadata", "provenance", "origin"): + if isinstance(asset.get(field), str): + try: + asset[field] = json.loads(asset[field]) + except (json.JSONDecodeError, TypeError): + pass + return asset + + +# ============ Assets (Transaction variants) ============ + +async def get_asset_by_hash_tx(conn, content_hash: str) -> Optional[dict]: + """Get asset by content hash within a transaction.""" + row = await conn.fetchrow( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE content_hash = $1""", + content_hash + ) + if row: + return _parse_asset_row(row) + return None + + +async def asset_exists_by_name_tx(conn, name: str) -> bool: + """Check if asset name exists within a transaction.""" + return await conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM assets WHERE name = $1)", + name + ) + + +async def get_asset_by_name_tx(conn, name: str) -> Optional[dict]: + """Get asset by name within a transaction.""" + row = await conn.fetchrow( + """SELECT name, content_hash, ipfs_cid, asset_type, tags, metadata, url, + provenance, description, origin, owner, created_at, updated_at + FROM assets WHERE name = $1""", + name + ) + if row: + return _parse_asset_row(row) + return None + + +async def create_asset_tx(conn, asset: dict) -> dict: + """Create a new asset within a transaction.""" + row = await conn.fetchrow( + """INSERT INTO assets (name, content_hash, ipfs_cid, asset_type, tags, metadata, + url, provenance, description, origin, owner, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING *""", + asset["name"], + asset["content_hash"], + asset.get("ipfs_cid"), + asset["asset_type"], + json.dumps(asset.get("tags", [])), + json.dumps(asset.get("metadata", {})), + asset.get("url"), + json.dumps(asset.get("provenance")) if asset.get("provenance") else None, + asset.get("description"), + json.dumps(asset.get("origin")) if asset.get("origin") else None, + asset["owner"], + _parse_timestamp(asset.get("created_at")) + ) + return _parse_asset_row(row) + + +# ============ Activities ============ + +async def get_activity(activity_id: str) -> Optional[dict]: + """Get activity by ID.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities WHERE activity_id = $1""", + activity_id + ) + if row: + return _parse_activity_row(row) + return None + + +async def get_activity_by_index(index: int) -> Optional[dict]: + """Get activity by index (for backward compatibility with URL scheme).""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities ORDER BY published ASC LIMIT 1 OFFSET $1""", + index + ) + if row: + return _parse_activity_row(row) + return None + + +async def get_all_activities() -> list[dict]: + """Get all activities ordered by published date (oldest first for index compatibility).""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities ORDER BY published ASC""" + ) + return [_parse_activity_row(row) for row in rows] + + +async def get_activities_paginated(limit: int = 100, offset: int = 0) -> tuple[list[dict], int]: + """Get paginated activities (newest first), returns (activities, total_count).""" + async with get_connection() as conn: + total = await conn.fetchval("SELECT COUNT(*) FROM activities") + rows = await conn.fetch( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities ORDER BY published DESC LIMIT $1 OFFSET $2""", + limit, offset + ) + return [_parse_activity_row(row) for row in rows], total + + +async def get_activities_by_actor(actor_id: str) -> list[dict]: + """Get all activities by an actor.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities WHERE actor_id = $1 ORDER BY published DESC""", + actor_id + ) + return [_parse_activity_row(row) for row in rows] + + +async def create_activity(activity: dict) -> dict: + """Create a new activity.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """INSERT INTO activities (activity_id, activity_type, actor_id, object_data, published, signature) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING *""", + activity["activity_id"], + activity["activity_type"], + activity["actor_id"], + json.dumps(activity["object_data"]), + _parse_timestamp(activity["published"]), + json.dumps(activity.get("signature")) if activity.get("signature") else None + ) + return _parse_activity_row(row) + + +async def count_activities() -> int: + """Get total activity count.""" + async with get_connection() as conn: + return await conn.fetchval("SELECT COUNT(*) FROM activities") + + +def _parse_activity_row(row) -> dict: + """Parse a database row into an activity dict, handling JSONB fields.""" + activity = dict(row) + # Convert datetime to ISO string + if activity.get("published"): + activity["published"] = activity["published"].isoformat() + # Ensure JSONB fields are dicts (handle string case) + for field in ("object_data", "signature"): + if isinstance(activity.get(field), str): + try: + activity[field] = json.loads(activity[field]) + except (json.JSONDecodeError, TypeError): + pass + return activity + + +# ============ Activities (Transaction variants) ============ + +async def create_activity_tx(conn, activity: dict) -> dict: + """Create a new activity within a transaction.""" + row = await conn.fetchrow( + """INSERT INTO activities (activity_id, activity_type, actor_id, object_data, published, signature) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING *""", + activity["activity_id"], + activity["activity_type"], + activity["actor_id"], + json.dumps(activity["object_data"]), + _parse_timestamp(activity["published"]), + json.dumps(activity.get("signature")) if activity.get("signature") else None + ) + return _parse_activity_row(row) + + +# ============ Followers ============ + +async def get_followers(username: str) -> list[dict]: + """Get followers for a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT id, username, acct, url, public_key, created_at + FROM followers WHERE username = $1""", + username + ) + return [dict(row) for row in rows] + + +async def get_all_followers() -> list: + """Get all followers (for backward compatibility with old global list).""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT DISTINCT url FROM followers""" + ) + return [row["url"] for row in rows] + + +async def add_follower(username: str, acct: str, url: str, public_key: Optional[str] = None) -> dict: + """Add a follower.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """INSERT INTO followers (username, acct, url, public_key) + VALUES ($1, $2, $3, $4) + ON CONFLICT (username, acct) DO UPDATE SET url = $3, public_key = $4 + RETURNING *""", + username, acct, url, public_key + ) + return dict(row) + + +async def remove_follower(username: str, acct: str) -> bool: + """Remove a follower.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM followers WHERE username = $1 AND acct = $2", + username, acct + ) + return result == "DELETE 1" + + +# ============ Stats ============ + +async def get_stats() -> dict: + """Get counts for dashboard.""" + async with get_connection() as conn: + assets = await conn.fetchval("SELECT COUNT(*) FROM assets") + activities = await conn.fetchval("SELECT COUNT(*) FROM activities") + users = await conn.fetchval("SELECT COUNT(*) FROM users") + return {"assets": assets, "activities": activities, "users": users} + + +# ============ Anchors (Bitcoin timestamps) ============ + +async def get_unanchored_activities() -> list[dict]: + """Get all activities not yet anchored to Bitcoin.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities WHERE anchor_root IS NULL ORDER BY published ASC""" + ) + return [_parse_activity_row(row) for row in rows] + + +async def create_anchor(anchor: dict) -> dict: + """Create an anchor record.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """INSERT INTO anchors (merkle_root, tree_ipfs_cid, ots_proof_cid, + activity_count, first_activity_id, last_activity_id) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING *""", + anchor["merkle_root"], + anchor.get("tree_ipfs_cid"), + anchor.get("ots_proof_cid"), + anchor["activity_count"], + anchor.get("first_activity_id"), + anchor.get("last_activity_id") + ) + return dict(row) + + +async def mark_activities_anchored(activity_ids: list[str], merkle_root: str) -> int: + """Mark activities as anchored with the given merkle root.""" + async with get_connection() as conn: + result = await conn.execute( + """UPDATE activities SET anchor_root = $1 + WHERE activity_id = ANY($2::text[])""", + merkle_root, + activity_ids + ) + # Returns "UPDATE N" + return int(result.split()[1]) if result else 0 + + +async def get_anchor(merkle_root: str) -> Optional[dict]: + """Get anchor by merkle root.""" + async with get_connection() as conn: + row = await conn.fetchrow( + "SELECT * FROM anchors WHERE merkle_root = $1", + merkle_root + ) + if row: + result = dict(row) + if result.get("first_activity_id"): + result["first_activity_id"] = str(result["first_activity_id"]) + if result.get("last_activity_id"): + result["last_activity_id"] = str(result["last_activity_id"]) + if result.get("created_at"): + result["created_at"] = result["created_at"].isoformat() + if result.get("confirmed_at"): + result["confirmed_at"] = result["confirmed_at"].isoformat() + return result + return None + + +async def get_all_anchors() -> list[dict]: + """Get all anchors, newest first.""" + async with get_connection() as conn: + rows = await conn.fetch( + "SELECT * FROM anchors ORDER BY created_at DESC" + ) + results = [] + for row in rows: + result = dict(row) + if result.get("first_activity_id"): + result["first_activity_id"] = str(result["first_activity_id"]) + if result.get("last_activity_id"): + result["last_activity_id"] = str(result["last_activity_id"]) + if result.get("created_at"): + result["created_at"] = result["created_at"].isoformat() + if result.get("confirmed_at"): + result["confirmed_at"] = result["confirmed_at"].isoformat() + results.append(result) + return results + + +async def get_anchors_paginated(offset: int = 0, limit: int = 20) -> list[dict]: + """Get anchors with pagination, newest first.""" + async with get_connection() as conn: + rows = await conn.fetch( + "SELECT * FROM anchors ORDER BY created_at DESC LIMIT $1 OFFSET $2", + limit, offset + ) + results = [] + for row in rows: + result = dict(row) + if result.get("first_activity_id"): + result["first_activity_id"] = str(result["first_activity_id"]) + if result.get("last_activity_id"): + result["last_activity_id"] = str(result["last_activity_id"]) + if result.get("created_at"): + result["created_at"] = result["created_at"].isoformat() + if result.get("confirmed_at"): + result["confirmed_at"] = result["confirmed_at"].isoformat() + results.append(result) + return results + + +async def update_anchor_confirmed(merkle_root: str, bitcoin_txid: str) -> bool: + """Mark anchor as confirmed with Bitcoin txid.""" + async with get_connection() as conn: + result = await conn.execute( + """UPDATE anchors SET confirmed_at = NOW(), bitcoin_txid = $1 + WHERE merkle_root = $2""", + bitcoin_txid, merkle_root + ) + return result == "UPDATE 1" + + +async def get_anchor_stats() -> dict: + """Get anchoring statistics.""" + async with get_connection() as conn: + total_anchors = await conn.fetchval("SELECT COUNT(*) FROM anchors") + confirmed_anchors = await conn.fetchval( + "SELECT COUNT(*) FROM anchors WHERE confirmed_at IS NOT NULL" + ) + pending_anchors = await conn.fetchval( + "SELECT COUNT(*) FROM anchors WHERE confirmed_at IS NULL" + ) + anchored_activities = await conn.fetchval( + "SELECT COUNT(*) FROM activities WHERE anchor_root IS NOT NULL" + ) + unanchored_activities = await conn.fetchval( + "SELECT COUNT(*) FROM activities WHERE anchor_root IS NULL" + ) + return { + "total_anchors": total_anchors, + "confirmed_anchors": confirmed_anchors, + "pending_anchors": pending_anchors, + "anchored_activities": anchored_activities, + "unanchored_activities": unanchored_activities + } + + +# ============ User Renderers (L1 attachments) ============ + +async def get_user_renderers(username: str) -> list[str]: + """Get L1 renderer URLs attached by a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + "SELECT l1_url FROM user_renderers WHERE username = $1 ORDER BY attached_at", + username + ) + return [row["l1_url"] for row in rows] + + +async def attach_renderer(username: str, l1_url: str) -> bool: + """Attach a user to an L1 renderer. Returns True if newly attached.""" + async with get_connection() as conn: + try: + await conn.execute( + """INSERT INTO user_renderers (username, l1_url) + VALUES ($1, $2) + ON CONFLICT (username, l1_url) DO NOTHING""", + username, l1_url + ) + return True + except Exception: + return False + + +async def detach_renderer(username: str, l1_url: str) -> bool: + """Detach a user from an L1 renderer. Returns True if was attached.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM user_renderers WHERE username = $1 AND l1_url = $2", + username, l1_url + ) + return "DELETE 1" in result + + +# ============ User Storage ============ + +async def get_user_storage(username: str) -> list[dict]: + """Get all storage providers for a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT id, username, provider_type, provider_name, description, config, + capacity_gb, is_active, created_at, updated_at + FROM user_storage WHERE username = $1 + ORDER BY provider_type, created_at""", + username + ) + return [dict(row) for row in rows] + + +async def get_user_storage_by_type(username: str, provider_type: str) -> list[dict]: + """Get storage providers of a specific type for a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT id, username, provider_type, provider_name, description, config, + capacity_gb, is_active, created_at, updated_at + FROM user_storage WHERE username = $1 AND provider_type = $2 + ORDER BY created_at""", + username, provider_type + ) + return [dict(row) for row in rows] + + +async def get_storage_by_id(storage_id: int) -> Optional[dict]: + """Get a storage provider by ID.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT id, username, provider_type, provider_name, description, config, + capacity_gb, is_active, created_at, updated_at + FROM user_storage WHERE id = $1""", + storage_id + ) + return dict(row) if row else None + + +async def add_user_storage( + username: str, + provider_type: str, + provider_name: str, + config: dict, + capacity_gb: int, + description: Optional[str] = None +) -> Optional[int]: + """Add a storage provider for a user. Returns storage ID.""" + async with get_connection() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO user_storage (username, provider_type, provider_name, description, config, capacity_gb) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id""", + username, provider_type, provider_name, description, json.dumps(config), capacity_gb + ) + return row["id"] if row else None + except Exception: + return None + + +async def update_user_storage( + storage_id: int, + provider_name: Optional[str] = None, + description: Optional[str] = None, + config: Optional[dict] = None, + capacity_gb: Optional[int] = None, + is_active: Optional[bool] = None +) -> bool: + """Update a storage provider.""" + updates = [] + params = [] + param_num = 1 + + if provider_name is not None: + updates.append(f"provider_name = ${param_num}") + params.append(provider_name) + param_num += 1 + if description is not None: + updates.append(f"description = ${param_num}") + params.append(description) + param_num += 1 + if config is not None: + updates.append(f"config = ${param_num}") + params.append(json.dumps(config)) + param_num += 1 + if capacity_gb is not None: + updates.append(f"capacity_gb = ${param_num}") + params.append(capacity_gb) + param_num += 1 + if is_active is not None: + updates.append(f"is_active = ${param_num}") + params.append(is_active) + param_num += 1 + + if not updates: + return False + + updates.append("updated_at = NOW()") + params.append(storage_id) + + async with get_connection() as conn: + result = await conn.execute( + f"UPDATE user_storage SET {', '.join(updates)} WHERE id = ${param_num}", + *params + ) + return "UPDATE 1" in result + + +async def remove_user_storage(storage_id: int) -> bool: + """Remove a storage provider. Cascades to storage_pins.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM user_storage WHERE id = $1", + storage_id + ) + return "DELETE 1" in result + + +async def get_storage_usage(storage_id: int) -> dict: + """Get storage usage stats for a provider.""" + async with get_connection() as conn: + row = await conn.fetchrow( + """SELECT + COUNT(*) as pin_count, + COALESCE(SUM(size_bytes), 0) as used_bytes + FROM storage_pins WHERE storage_id = $1""", + storage_id + ) + return {"pin_count": row["pin_count"], "used_bytes": row["used_bytes"]} + + +async def add_storage_pin( + content_hash: str, + storage_id: int, + ipfs_cid: Optional[str], + pin_type: str, + size_bytes: int +) -> Optional[int]: + """Add a pin record. Returns pin ID.""" + async with get_connection() as conn: + try: + row = await conn.fetchrow( + """INSERT INTO storage_pins (content_hash, storage_id, ipfs_cid, pin_type, size_bytes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (content_hash, storage_id) DO UPDATE SET + ipfs_cid = EXCLUDED.ipfs_cid, + pin_type = EXCLUDED.pin_type, + size_bytes = EXCLUDED.size_bytes, + pinned_at = NOW() + RETURNING id""", + content_hash, storage_id, ipfs_cid, pin_type, size_bytes + ) + return row["id"] if row else None + except Exception: + return None + + +async def remove_storage_pin(content_hash: str, storage_id: int) -> bool: + """Remove a pin record.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM storage_pins WHERE content_hash = $1 AND storage_id = $2", + content_hash, storage_id + ) + return "DELETE 1" in result + + +async def get_pins_for_content(content_hash: str) -> list[dict]: + """Get all storage locations where content is pinned.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT sp.*, us.provider_type, us.provider_name, us.username + FROM storage_pins sp + JOIN user_storage us ON sp.storage_id = us.id + WHERE sp.content_hash = $1""", + content_hash + ) + return [dict(row) for row in rows] + + +async def get_all_active_storage() -> list[dict]: + """Get all active storage providers (for distributed pinning).""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT us.id, us.username, us.provider_type, us.provider_name, us.description, + us.config, us.capacity_gb, us.is_active, us.created_at, us.updated_at, + COALESCE(SUM(sp.size_bytes), 0) as used_bytes, + COUNT(sp.id) as pin_count + FROM user_storage us + LEFT JOIN storage_pins sp ON us.id = sp.storage_id + WHERE us.is_active = true + GROUP BY us.id + ORDER BY us.provider_type, us.created_at""" + ) + return [dict(row) for row in rows] + + +# ============ Token Revocation ============ + +async def revoke_token(token_hash: str, username: str, expires_at) -> bool: + """Revoke a token. Returns True if newly revoked.""" + async with get_connection() as conn: + try: + await conn.execute( + """INSERT INTO revoked_tokens (token_hash, username, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (token_hash) DO NOTHING""", + token_hash, username, expires_at + ) + return True + except Exception: + return False + + +async def is_token_revoked(token_hash: str) -> bool: + """Check if a token has been revoked.""" + async with get_connection() as conn: + row = await conn.fetchrow( + "SELECT 1 FROM revoked_tokens WHERE token_hash = $1 AND expires_at > NOW()", + token_hash + ) + return row is not None + + +async def cleanup_expired_revocations() -> int: + """Remove expired revocation entries. Returns count removed.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM revoked_tokens WHERE expires_at < NOW()" + ) + # Extract count from "DELETE N" + try: + return int(result.split()[-1]) + except (ValueError, IndexError): + return 0 + + +# ============ Additional helper functions ============ + +async def get_user_assets(username: str, offset: int = 0, limit: int = 20, asset_type: str = None) -> list[dict]: + """Get assets owned by a user with pagination.""" + async with get_connection() as conn: + if asset_type: + rows = await conn.fetch( + """SELECT * FROM assets WHERE owner = $1 AND asset_type = $2 + ORDER BY created_at DESC LIMIT $3 OFFSET $4""", + username, asset_type, limit, offset + ) + else: + rows = await conn.fetch( + """SELECT * FROM assets WHERE owner = $1 + ORDER BY created_at DESC LIMIT $2 OFFSET $3""", + username, limit, offset + ) + return [dict(row) for row in rows] + + +async def delete_asset(asset_id: str) -> bool: + """Delete an asset by name/id.""" + async with get_connection() as conn: + result = await conn.execute("DELETE FROM assets WHERE name = $1", asset_id) + return "DELETE 1" in result + + +async def count_users() -> int: + """Count total users.""" + async with get_connection() as conn: + return await conn.fetchval("SELECT COUNT(*) FROM users") + + +async def count_user_activities(username: str) -> int: + """Count activities by a user.""" + async with get_connection() as conn: + return await conn.fetchval( + "SELECT COUNT(*) FROM activities WHERE actor_id LIKE $1", + f"%{username}%" + ) + + +async def get_user_activities(username: str, limit: int = 20, offset: int = 0) -> list[dict]: + """Get activities by a user.""" + async with get_connection() as conn: + rows = await conn.fetch( + """SELECT activity_id, activity_type, actor_id, object_data, published, signature + FROM activities WHERE actor_id LIKE $1 + ORDER BY published DESC LIMIT $2 OFFSET $3""", + f"%{username}%", limit, offset + ) + return [_parse_activity_row(row) for row in rows] + + +async def get_renderer(renderer_id: str) -> Optional[dict]: + """Get a renderer by ID/URL.""" + async with get_connection() as conn: + row = await conn.fetchrow( + "SELECT * FROM user_renderers WHERE l1_url = $1", + renderer_id + ) + return dict(row) if row else None + + +async def update_anchor(anchor_id: str, **updates) -> bool: + """Update an anchor.""" + async with get_connection() as conn: + if "bitcoin_txid" in updates: + result = await conn.execute( + """UPDATE anchors SET bitcoin_txid = $1, confirmed_at = NOW() + WHERE merkle_root = $2""", + updates["bitcoin_txid"], anchor_id + ) + return "UPDATE 1" in result + return False + + +async def delete_anchor(anchor_id: str) -> bool: + """Delete an anchor.""" + async with get_connection() as conn: + result = await conn.execute( + "DELETE FROM anchors WHERE merkle_root = $1", anchor_id + ) + return "DELETE 1" in result + + +async def record_run(run_id: str, username: str, recipe: str, inputs: list, + output_hash: str, ipfs_cid: str = None, asset_id: str = None) -> dict: + """Record a completed run.""" + async with get_connection() as conn: + # Check if runs table exists, if not just return the data + try: + row = await conn.fetchrow( + """INSERT INTO runs (run_id, username, recipe, inputs, output_hash, ipfs_cid, asset_id, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (run_id) DO UPDATE SET + output_hash = EXCLUDED.output_hash, + ipfs_cid = EXCLUDED.ipfs_cid, + asset_id = EXCLUDED.asset_id + RETURNING *""", + run_id, username, recipe, json.dumps(inputs), output_hash, ipfs_cid, asset_id + ) + return dict(row) if row else None + except Exception: + # Table might not exist + return {"run_id": run_id, "username": username, "recipe": recipe} + + +async def get_run(run_id: str) -> Optional[dict]: + """Get a run by ID.""" + async with get_connection() as conn: + try: + row = await conn.fetchrow("SELECT * FROM runs WHERE run_id = $1", run_id) + if row: + result = dict(row) + if result.get("inputs") and isinstance(result["inputs"], str): + result["inputs"] = json.loads(result["inputs"]) + return result + except Exception: + pass + return None diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000..aac2460 --- /dev/null +++ b/deploy.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +cd "$(dirname "$0")" + +echo "=== Pulling latest code ===" +git pull + +echo "=== Building Docker image ===" +docker build --build-arg CACHEBUST=$(date +%s) -t git.rose-ash.com/art-dag/l2-server:latest . + +echo "=== Redeploying activitypub stack ===" +docker stack deploy -c docker-compose.yml activitypub + +echo "=== Restarting proxy nginx ===" +docker service update --force proxy_nginx + +echo "=== Done ===" +docker stack services activitypub diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..0f67e81 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,90 @@ +version: "3.8" + +services: + postgres: + image: postgres:16-alpine + env_file: + - .env + environment: + POSTGRES_USER: artdag + POSTGRES_DB: artdag + volumes: + - postgres_data:/var/lib/postgresql/data + networks: + - internal + healthcheck: + test: ["CMD-SHELL", "pg_isready -U artdag"] + interval: 5s + timeout: 5s + retries: 5 + deploy: + placement: + constraints: + - node.labels.gpu != true + + ipfs: + image: ipfs/kubo:latest + ports: + - "4002:4001" # Swarm TCP (4002 external, L1 uses 4001) + - "4002:4001/udp" # Swarm UDP + volumes: + - ipfs_data:/data/ipfs + networks: + - internal + - externalnet # For gateway access + deploy: + replicas: 1 + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + + l2-server: + image: registry.rose-ash.com:5000/l2-server:latest + env_file: + - .env + environment: + - ARTDAG_DATA=/data/l2 + - IPFS_API=/dns/ipfs/tcp/5001 + - ANCHOR_BACKUP_DIR=/data/anchors + # Coop app internal URLs for fragment composition + - INTERNAL_URL_BLOG=http://blog:8000 + - INTERNAL_URL_CART=http://cart:8000 + - INTERNAL_URL_ACCOUNT=http://account:8000 + # DATABASE_URL, ARTDAG_DOMAIN, ARTDAG_USER, JWT_SECRET from .env file + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8200/')"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 15s + volumes: + - l2_data:/data/l2 # Still needed for RSA keys + - anchor_backup:/data/anchors # Persistent anchor proofs (survives DB wipes) + networks: + - internal + - externalnet + depends_on: + - postgres + - ipfs + deploy: + replicas: 1 + update_config: + order: start-first + restart_policy: + condition: on-failure + placement: + constraints: + - node.labels.gpu != true + +volumes: + l2_data: + postgres_data: + ipfs_data: + anchor_backup: # Persistent - don't delete when resetting DB + +networks: + internal: + externalnet: + external: true diff --git a/docker-stack.yml b/docker-stack.yml new file mode 100644 index 0000000..3411aeb --- /dev/null +++ b/docker-stack.yml @@ -0,0 +1,91 @@ +version: "3.8" + +# Full Art DAG stack for Docker Swarm deployment +# Deploy with: docker stack deploy -c docker-stack.yml artdag + +services: + # Redis for L1 + redis: + image: redis:7-alpine + volumes: + - redis_data:/data + networks: + - artdag + deploy: + replicas: 1 + placement: + constraints: + - node.role == manager + restart_policy: + condition: on-failure + + # L1 Server (API) + l1-server: + image: git.rose-ash.com/art-dag/l1-server:latest + ports: + - "8100:8100" + env_file: + - .env + environment: + - REDIS_URL=redis://redis:6379/5 + - CACHE_DIR=/data/cache + # L1_PUBLIC_URL, L2_SERVER, L2_DOMAIN from .env file + volumes: + - l1_cache:/data/cache + depends_on: + - redis + networks: + - artdag + deploy: + replicas: 1 + restart_policy: + condition: on-failure + + # L1 Worker (Celery) + l1-worker: + image: git.rose-ash.com/art-dag/l1-server:latest + command: celery -A celery_app worker --loglevel=info + environment: + - REDIS_URL=redis://redis:6379/5 + - CACHE_DIR=/data/cache + - C_FORCE_ROOT=true + volumes: + - l1_cache:/data/cache + depends_on: + - redis + networks: + - artdag + deploy: + replicas: 2 + restart_policy: + condition: on-failure + + # L2 Server (ActivityPub) + l2-server: + image: git.rose-ash.com/art-dag/l2-server:latest + ports: + - "8200:8200" + env_file: + - .env + environment: + - ARTDAG_DATA=/data/l2 + # ARTDAG_DOMAIN, JWT_SECRET from .env file (multi-actor, no ARTDAG_USER) + volumes: + - l2_data:/data/l2 + depends_on: + - l1-server + networks: + - artdag + deploy: + replicas: 1 + restart_policy: + condition: on-failure + +volumes: + redis_data: + l1_cache: + l2_data: + +networks: + artdag: + driver: overlay diff --git a/ipfs_client.py b/ipfs_client.py new file mode 100644 index 0000000..108327b --- /dev/null +++ b/ipfs_client.py @@ -0,0 +1,226 @@ +# art-activity-pub/ipfs_client.py +""" +IPFS client for Art DAG L2 server. + +Provides functions to fetch, pin, and add content to IPFS. +Uses direct HTTP API calls for compatibility with all Kubo versions. +""" + +import json +import logging +import os +import re +from typing import Optional + +import requests + + +class IPFSError(Exception): + """Raised when an IPFS operation fails.""" + pass + +logger = logging.getLogger(__name__) + +# IPFS API multiaddr - default to local, docker uses /dns/ipfs/tcp/5001 +IPFS_API = os.getenv("IPFS_API", "/ip4/127.0.0.1/tcp/5001") + +# Connection timeout in seconds +IPFS_TIMEOUT = int(os.getenv("IPFS_TIMEOUT", "60")) + + +def _multiaddr_to_url(multiaddr: str) -> str: + """Convert IPFS multiaddr to HTTP URL.""" + # Handle /dns/hostname/tcp/port format + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +# Base URL for IPFS API +IPFS_BASE_URL = _multiaddr_to_url(IPFS_API) + + +def get_bytes(cid: str) -> Optional[bytes]: + """ + Retrieve content from IPFS by CID. + + Args: + cid: IPFS CID to retrieve + + Returns: + Content as bytes or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/cat" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + data = response.content + + logger.info(f"Retrieved from IPFS: {cid} ({len(data)} bytes)") + return data + except Exception as e: + logger.error(f"Failed to get from IPFS: {e}") + return None + + +def pin(cid: str) -> bool: + """ + Pin a CID on this node. + + Args: + cid: IPFS CID to pin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/add" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Pinned on IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to pin on IPFS: {e}") + return False + + +def unpin(cid: str) -> bool: + """ + Unpin a CID from this node. + + Args: + cid: IPFS CID to unpin + + Returns: + True on success, False on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/rm" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Unpinned from IPFS: {cid}") + return True + except Exception as e: + logger.error(f"Failed to unpin from IPFS: {e}") + return False + + +def is_available() -> bool: + """ + Check if IPFS daemon is available. + + Returns: + True if IPFS is available, False otherwise + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def get_node_id() -> Optional[str]: + """ + Get this IPFS node's peer ID. + + Returns: + Peer ID string or None on failure + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/id" + response = requests.post(url, timeout=IPFS_TIMEOUT) + response.raise_for_status() + return response.json().get("ID") + except Exception as e: + logger.error(f"Failed to get node ID: {e}") + return None + + +def add_bytes(data: bytes, pin: bool = True) -> str: + """ + Add bytes data to IPFS and optionally pin it. + + Args: + data: Bytes to add + pin: Whether to pin the data (default: True) + + Returns: + IPFS CID + + Raises: + IPFSError: If adding fails + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/add" + params = {"pin": str(pin).lower()} + files = {"file": ("data", data)} + + response = requests.post(url, params=params, files=files, timeout=IPFS_TIMEOUT) + response.raise_for_status() + result = response.json() + cid = result["Hash"] + + logger.info(f"Added to IPFS: {len(data)} bytes -> {cid}") + return cid + except Exception as e: + logger.error(f"Failed to add bytes to IPFS: {e}") + raise IPFSError(f"Failed to add bytes to IPFS: {e}") from e + + +def add_json(data: dict) -> str: + """ + Serialize dict to JSON and add to IPFS. + + Args: + data: Dictionary to serialize and store + + Returns: + IPFS CID + + Raises: + IPFSError: If adding fails + """ + json_bytes = json.dumps(data, indent=2, sort_keys=True).encode('utf-8') + return add_bytes(json_bytes, pin=True) + + +def pin_or_raise(cid: str) -> None: + """ + Pin a CID on IPFS. Raises exception on failure. + + Args: + cid: IPFS CID to pin + + Raises: + IPFSError: If pinning fails + """ + try: + url = f"{IPFS_BASE_URL}/api/v0/pin/add" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + logger.info(f"Pinned on IPFS: {cid}") + except Exception as e: + logger.error(f"Failed to pin on IPFS: {e}") + raise IPFSError(f"Failed to pin {cid}: {e}") from e diff --git a/keys.py b/keys.py new file mode 100644 index 0000000..247a558 --- /dev/null +++ b/keys.py @@ -0,0 +1,119 @@ +""" +Key management for ActivityPub signing. + +Keys are stored in DATA_DIR/keys/: +- {username}.pem - Private key (chmod 600) +- {username}.pub - Public key +""" + +import base64 +import hashlib +import json +from datetime import datetime, timezone +from pathlib import Path + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa, padding + + +def get_keys_dir(data_dir: Path) -> Path: + """Get keys directory, create if needed.""" + keys_dir = data_dir / "keys" + keys_dir.mkdir(parents=True, exist_ok=True) + return keys_dir + + +def generate_keypair(data_dir: Path, username: str) -> tuple[str, str]: + """Generate RSA keypair for a user. + + Returns (private_pem, public_pem) + """ + keys_dir = get_keys_dir(data_dir) + private_path = keys_dir / f"{username}.pem" + public_path = keys_dir / f"{username}.pub" + + # Generate key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + # Serialize private key + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode() + + # Serialize public key + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ).decode() + + # Save keys + private_path.write_text(private_pem) + private_path.chmod(0o600) + public_path.write_text(public_pem) + + return private_pem, public_pem + + +def load_private_key(data_dir: Path, username: str): + """Load private key for signing.""" + keys_dir = get_keys_dir(data_dir) + private_path = keys_dir / f"{username}.pem" + + if not private_path.exists(): + raise FileNotFoundError(f"Private key not found: {private_path}") + + private_pem = private_path.read_text() + return serialization.load_pem_private_key( + private_pem.encode(), + password=None + ) + + +def load_public_key_pem(data_dir: Path, username: str) -> str: + """Load public key PEM for actor profile.""" + keys_dir = get_keys_dir(data_dir) + public_path = keys_dir / f"{username}.pub" + + if not public_path.exists(): + raise FileNotFoundError(f"Public key not found: {public_path}") + + return public_path.read_text() + + +def has_keys(data_dir: Path, username: str) -> bool: + """Check if keys exist for user.""" + keys_dir = get_keys_dir(data_dir) + return (keys_dir / f"{username}.pem").exists() + + +def sign_data(private_key, data: str) -> str: + """Sign data with private key, return base64 signature.""" + signature = private_key.sign( + data.encode(), + padding.PKCS1v15(), + hashes.SHA256() + ) + return base64.b64encode(signature).decode() + + +def create_signature(data_dir: Path, username: str, domain: str, activity: dict) -> dict: + """Create RsaSignature2017 for an activity.""" + private_key = load_private_key(data_dir, username) + + # Create canonical JSON for signing + canonical = json.dumps(activity, sort_keys=True, separators=(',', ':')) + + # Sign + signature_value = sign_data(private_key, canonical) + + return { + "type": "RsaSignature2017", + "creator": f"https://{domain}/users/{username}#main-key", + "created": datetime.now(timezone.utc).isoformat(), + "signatureValue": signature_value + } diff --git a/migrate.py b/migrate.py new file mode 100755 index 0000000..146c487 --- /dev/null +++ b/migrate.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +""" +Migration script: JSON files to PostgreSQL. + +Usage: + python migrate.py [--dry-run] + +Migrates: +- users.json -> users table +- registry.json -> assets table +- activities.json -> activities table +- followers.json -> followers table + +Does NOT migrate: +- keys/ directory (stays as files) +""" + +import asyncio +import json +import os +import sys +from pathlib import Path +from datetime import datetime, timezone +from uuid import UUID + +import asyncpg + +# Configuration +DATA_DIR = Path(os.environ.get("ARTDAG_DATA", str(Path.home() / ".artdag" / "l2"))) +DATABASE_URL = os.environ.get("DATABASE_URL") +if not DATABASE_URL: + raise RuntimeError("DATABASE_URL environment variable is required") + +SCHEMA = """ +-- Drop existing tables (careful in production!) +DROP TABLE IF EXISTS followers CASCADE; +DROP TABLE IF EXISTS activities CASCADE; +DROP TABLE IF EXISTS assets CASCADE; +DROP TABLE IF EXISTS users CASCADE; + +-- Users table +CREATE TABLE users ( + username VARCHAR(255) PRIMARY KEY, + password_hash VARCHAR(255) NOT NULL, + email VARCHAR(255), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Assets table +CREATE TABLE assets ( + name VARCHAR(255) PRIMARY KEY, + content_hash VARCHAR(128) NOT NULL, + asset_type VARCHAR(50) NOT NULL, + tags JSONB DEFAULT '[]'::jsonb, + metadata JSONB DEFAULT '{}'::jsonb, + url TEXT, + provenance JSONB, + description TEXT, + origin JSONB, + owner VARCHAR(255) NOT NULL REFERENCES users(username), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ +); + +-- Activities table +CREATE TABLE activities ( + activity_id UUID PRIMARY KEY, + activity_type VARCHAR(50) NOT NULL, + actor_id TEXT NOT NULL, + object_data JSONB NOT NULL, + published TIMESTAMPTZ NOT NULL, + signature JSONB +); + +-- Followers table +CREATE TABLE followers ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL REFERENCES users(username), + acct VARCHAR(255) NOT NULL, + url TEXT NOT NULL, + public_key TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(username, acct) +); + +-- Indexes +CREATE INDEX idx_users_created_at ON users(created_at); +CREATE INDEX idx_assets_content_hash ON assets(content_hash); +CREATE INDEX idx_assets_owner ON assets(owner); +CREATE INDEX idx_assets_created_at ON assets(created_at DESC); +CREATE INDEX idx_assets_tags ON assets USING GIN(tags); +CREATE INDEX idx_activities_actor_id ON activities(actor_id); +CREATE INDEX idx_activities_published ON activities(published DESC); +CREATE INDEX idx_followers_username ON followers(username); +""" + + +async def migrate(dry_run: bool = False): + """Run the migration.""" + print(f"Migrating from {DATA_DIR} to PostgreSQL") + print(f"Database: {DATABASE_URL}") + print(f"Dry run: {dry_run}") + print() + + # Load JSON files + users = load_json(DATA_DIR / "users.json") or {} + registry = load_json(DATA_DIR / "registry.json") or {"assets": {}} + activities_data = load_json(DATA_DIR / "activities.json") or {"activities": []} + followers = load_json(DATA_DIR / "followers.json") or [] + + assets = registry.get("assets", {}) + activities = activities_data.get("activities", []) + + print(f"Found {len(users)} users") + print(f"Found {len(assets)} assets") + print(f"Found {len(activities)} activities") + print(f"Found {len(followers)} followers") + print() + + if dry_run: + print("DRY RUN - no changes made") + return + + # Connect and migrate + conn = await asyncpg.connect(DATABASE_URL) + try: + # Create schema + print("Creating schema...") + await conn.execute(SCHEMA) + + # Migrate users + print("Migrating users...") + for username, user_data in users.items(): + await conn.execute( + """INSERT INTO users (username, password_hash, email, created_at) + VALUES ($1, $2, $3, $4)""", + username, + user_data["password_hash"], + user_data.get("email"), + parse_timestamp(user_data.get("created_at")) + ) + print(f" Migrated {len(users)} users") + + # Migrate assets + print("Migrating assets...") + for name, asset in assets.items(): + await conn.execute( + """INSERT INTO assets (name, content_hash, asset_type, tags, metadata, + url, provenance, description, origin, owner, + created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)""", + name, + asset["content_hash"], + asset["asset_type"], + json.dumps(asset.get("tags", [])), + json.dumps(asset.get("metadata", {})), + asset.get("url"), + json.dumps(asset.get("provenance")) if asset.get("provenance") else None, + asset.get("description"), + json.dumps(asset.get("origin")) if asset.get("origin") else None, + asset["owner"], + parse_timestamp(asset.get("created_at")), + parse_timestamp(asset.get("updated_at")) + ) + print(f" Migrated {len(assets)} assets") + + # Migrate activities + print("Migrating activities...") + for activity in activities: + await conn.execute( + """INSERT INTO activities (activity_id, activity_type, actor_id, + object_data, published, signature) + VALUES ($1, $2, $3, $4, $5, $6)""", + UUID(activity["activity_id"]), + activity["activity_type"], + activity["actor_id"], + json.dumps(activity["object_data"]), + parse_timestamp(activity["published"]), + json.dumps(activity.get("signature")) if activity.get("signature") else None + ) + print(f" Migrated {len(activities)} activities") + + # Migrate followers + print("Migrating followers...") + if followers and users: + first_user = list(users.keys())[0] + migrated = 0 + for follower in followers: + if isinstance(follower, str): + # Old format: just URL string + await conn.execute( + """INSERT INTO followers (username, acct, url) + VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING""", + first_user, + follower, + follower + ) + migrated += 1 + elif isinstance(follower, dict): + await conn.execute( + """INSERT INTO followers (username, acct, url, public_key) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING""", + follower.get("username", first_user), + follower.get("acct", follower.get("url", "")), + follower["url"], + follower.get("public_key") + ) + migrated += 1 + print(f" Migrated {migrated} followers") + else: + print(" No followers to migrate") + + print() + print("Migration complete!") + + finally: + await conn.close() + + +def load_json(path: Path) -> dict | list | None: + """Load JSON file if it exists.""" + if path.exists(): + with open(path) as f: + return json.load(f) + return None + + +def parse_timestamp(ts: str | None) -> datetime | None: + """Parse ISO timestamp string to datetime.""" + if not ts: + return datetime.now(timezone.utc) + try: + # Handle various ISO formats + if ts.endswith('Z'): + ts = ts[:-1] + '+00:00' + return datetime.fromisoformat(ts) + except Exception: + return datetime.now(timezone.utc) + + +if __name__ == "__main__": + dry_run = "--dry-run" in sys.argv + asyncio.run(migrate(dry_run)) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..94d1e5a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +fastapi>=0.109.0 +uvicorn>=0.27.0 +requests>=2.31.0 +httpx>=0.27.0 +cryptography>=42.0.0 +bcrypt>=4.0.0 +python-jose[cryptography]>=3.3.0 +markdown>=3.5.0 +python-multipart>=0.0.6 +asyncpg>=0.29.0 +boto3>=1.34.0 +# Shared components +git+https://git.rose-ash.com/art-dag/common.git@889ea98 diff --git a/server.py b/server.py new file mode 100644 index 0000000..9c00b57 --- /dev/null +++ b/server.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +Art DAG L2 Server - ActivityPub + +Minimal entry point that uses the modular app factory. +All routes are defined in app/routers/. +All templates are in app/templates/. +""" + +import logging +import os + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(levelname)s %(name)s: %(message)s' +) + +# Import the app from the factory +from app import app + +if __name__ == "__main__": + import uvicorn + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", "8200")) + uvicorn.run("server:app", host=host, port=port, workers=4) diff --git a/server_legacy.py b/server_legacy.py new file mode 100644 index 0000000..7ab9a56 --- /dev/null +++ b/server_legacy.py @@ -0,0 +1,3765 @@ +#!/usr/bin/env python3 +""" +Art DAG L2 Server - ActivityPub + +Manages ownership registry, activities, and federation. +- Registry of owned assets +- ActivityPub actor endpoints +- Sign and publish Create activities +- Federation with other servers +""" + +import hashlib +import json +import logging +import os +import uuid +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s %(levelname)s %(name)s: %(message)s' +) +logger = logging.getLogger(__name__) + +from fastapi import FastAPI, HTTPException, Request, Response, Depends, Cookie, Form +from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse, FileResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from pydantic import BaseModel +import requests +import markdown + +import db +from auth import ( + UserCreate, UserLogin, Token, User, + create_user, authenticate_user, create_access_token, + verify_token, get_token_claims, get_current_user +) + +# Configuration +DOMAIN = os.environ.get("ARTDAG_DOMAIN", "artdag.rose-ash.com") +DATA_DIR = Path(os.environ.get("ARTDAG_DATA", str(Path.home() / ".artdag" / "l2"))) +L1_PUBLIC_URL = os.environ.get("L1_PUBLIC_URL", "https://celery-artdag.rose-ash.com") +EFFECTS_REPO_URL = os.environ.get("EFFECTS_REPO_URL", "https://git.rose-ash.com/art-dag/effects") +IPFS_GATEWAY_URL = os.environ.get("IPFS_GATEWAY_URL", "") + +# Known L1 renderers (comma-separated URLs) +L1_SERVERS_STR = os.environ.get("L1_SERVERS", "https://celery-artdag.rose-ash.com") +L1_SERVERS = [s.strip() for s in L1_SERVERS_STR.split(",") if s.strip()] + +# Cookie domain for sharing auth across subdomains (e.g., ".rose-ash.com") +# If not set, derives from DOMAIN (strips first subdomain, adds leading dot) +def _get_cookie_domain(): + env_val = os.environ.get("COOKIE_DOMAIN") + if env_val: + return env_val + # Derive from DOMAIN: artdag.rose-ash.com -> .rose-ash.com + parts = DOMAIN.split(".") + if len(parts) >= 2: + return "." + ".".join(parts[-2:]) + return None + +COOKIE_DOMAIN = _get_cookie_domain() + +# Ensure data directory exists +DATA_DIR.mkdir(parents=True, exist_ok=True) +(DATA_DIR / "assets").mkdir(exist_ok=True) + + +def compute_run_id(input_hashes: list[str], recipe: str, recipe_hash: str = None) -> str: + """ + Compute a deterministic run_id from inputs and recipe. + + The run_id is a SHA3-256 hash of: + - Sorted input content hashes + - Recipe identifier (recipe_hash if provided, else "effect:{recipe}") + + This makes runs content-addressable: same inputs + recipe = same run_id. + Must match the L1 implementation exactly. + """ + data = { + "inputs": sorted(input_hashes), + "recipe": recipe_hash or f"effect:{recipe}", + "version": "1", # For future schema changes + } + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha3_256(json_str.encode()).hexdigest() + +# Load README +README_PATH = Path(__file__).parent / "README.md" +README_CONTENT = "" +if README_PATH.exists(): + README_CONTENT = README_PATH.read_text() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage database connection pool lifecycle.""" + await db.init_pool() + yield + await db.close_pool() + + +app = FastAPI( + title="Art DAG L2 Server", + description="ActivityPub server for Art DAG ownership and federation", + version="0.1.0", + lifespan=lifespan +) + + +@app.exception_handler(404) +async def not_found_handler(request: Request, exc): + """Custom 404 page.""" + accept = request.headers.get("accept", "") + if "text/html" in accept and "application/json" not in accept: + content = ''' + + ''' + username = get_user_from_cookie(request) + return HTMLResponse(base_html("Not Found", content, username), status_code=404) + return JSONResponse({"detail": "Not found"}, status_code=404) + + +# ============ Data Models ============ + +class Asset(BaseModel): + """An owned asset.""" + name: str + content_hash: str + ipfs_cid: Optional[str] = None # IPFS content identifier + asset_type: str # image, video, effect, recipe, infrastructure + tags: list[str] = [] + metadata: dict = {} + url: Optional[str] = None + provenance: Optional[dict] = None + created_at: str = "" + + +class Activity(BaseModel): + """An ActivityPub activity.""" + activity_id: str + activity_type: str # Create, Update, Delete, Announce + actor_id: str + object_data: dict + published: str + signature: Optional[dict] = None + + +class RegisterRequest(BaseModel): + """Request to register an asset.""" + name: str + content_hash: str + ipfs_cid: Optional[str] = None # IPFS content identifier + asset_type: str + tags: list[str] = [] + metadata: dict = {} + url: Optional[str] = None + provenance: Optional[dict] = None + + +class RecordRunRequest(BaseModel): + """Request to record an L1 run.""" + run_id: str + l1_server: str # URL of the L1 server that has this run + output_name: Optional[str] = None # Deprecated - assets now named by content_hash + + +class PublishCacheRequest(BaseModel): + """Request to publish a cache item from L1.""" + content_hash: str + ipfs_cid: Optional[str] = None # IPFS content identifier + asset_name: str + asset_type: str = "image" + origin: dict # {type: "self"|"external", url?: str, note?: str} + description: Optional[str] = None + tags: list[str] = [] + metadata: dict = {} + + +class UpdateAssetRequest(BaseModel): + """Request to update an existing asset.""" + description: Optional[str] = None + tags: Optional[list[str]] = None + metadata: Optional[dict] = None + origin: Optional[dict] = None + ipfs_cid: Optional[str] = None # IPFS content identifier + + +class AddStorageRequest(BaseModel): + """Request to add a storage provider.""" + provider_type: str # 'pinata', 'web3storage', 'local' + provider_name: Optional[str] = None # User-friendly name + config: dict # Provider-specific config (api_key, path, etc.) + capacity_gb: int # Storage capacity in GB + + +class UpdateStorageRequest(BaseModel): + """Request to update a storage provider.""" + config: Optional[dict] = None + capacity_gb: Optional[int] = None + is_active: Optional[bool] = None + + +class SetAssetSourceRequest(BaseModel): + """Request to set source URL for an asset.""" + source_url: str + source_type: str # 'youtube', 'local', 'url' + + +# ============ Storage (Database) ============ + +async def load_registry() -> dict: + """Load registry from database.""" + assets = await db.get_all_assets() + return {"version": "1.0", "assets": assets} + + +async def load_activities() -> list: + """Load activities from database.""" + return await db.get_all_activities() + + +def load_actor(username: str) -> dict: + """Load actor data for a specific user with public key if available.""" + actor = { + "id": f"https://{DOMAIN}/users/{username}", + "type": "Person", + "preferredUsername": username, + "name": username, + "inbox": f"https://{DOMAIN}/users/{username}/inbox", + "outbox": f"https://{DOMAIN}/users/{username}/outbox", + "followers": f"https://{DOMAIN}/users/{username}/followers", + "following": f"https://{DOMAIN}/users/{username}/following", + } + + # Add public key if available + from keys import has_keys, load_public_key_pem + if has_keys(DATA_DIR, username): + actor["publicKey"] = { + "id": f"https://{DOMAIN}/users/{username}#main-key", + "owner": f"https://{DOMAIN}/users/{username}", + "publicKeyPem": load_public_key_pem(DATA_DIR, username) + } + + return actor + + +async def user_exists(username: str) -> bool: + """Check if a user exists.""" + return await db.user_exists(username) + + +async def load_followers() -> list: + """Load followers list from database.""" + return await db.get_all_followers() + + +# ============ Signing ============ + +from keys import has_keys, load_public_key_pem, create_signature + + +def sign_activity(activity: dict, username: str) -> dict: + """Sign an activity with the user's RSA private key.""" + if not has_keys(DATA_DIR, username): + # No keys - use placeholder (for testing) + activity["signature"] = { + "type": "RsaSignature2017", + "creator": f"https://{DOMAIN}/users/{username}#main-key", + "created": datetime.now(timezone.utc).isoformat(), + "signatureValue": "NO_KEYS_CONFIGURED" + } + else: + activity["signature"] = create_signature(DATA_DIR, username, DOMAIN, activity) + return activity + + +# ============ HTML Templates ============ + +# Tailwind CSS config for L2 - dark theme to match L1 +TAILWIND_CONFIG = ''' + + + +''' + + +def base_html(title: str, content: str, username: str = None) -> str: + """Base HTML template with Tailwind CSS dark theme.""" + user_section = f''' +
+ Logged in as {username} + + Logout + +
+ ''' if username else ''' +
+ Login + | + Register +
+ ''' + + return f''' + + + + + {title} - Art DAG L2 + {TAILWIND_CONFIG} + + +
+
+

+ Art DAG L2 +

+ {user_section} +
+ + + +
+ {content} +
+
+ +''' + + +def get_user_from_cookie(request: Request) -> Optional[str]: + """Get username from auth cookie.""" + token = request.cookies.get("auth_token") + if token: + return verify_token(token) + return None + + +def wants_html(request: Request) -> bool: + """Check if request wants HTML (browser) vs JSON (API).""" + accept = request.headers.get("accept", "") + return "text/html" in accept and "application/json" not in accept and "application/activity+json" not in accept + + +def format_date(value, length: int = 10) -> str: + """Format a date value (datetime or string) to a string, sliced to length.""" + if value is None: + return "" + if hasattr(value, 'isoformat'): + return value.isoformat()[:length] + if isinstance(value, str): + return value[:length] + return "" + + +# ============ Auth UI Endpoints ============ + +@app.get("/login", response_class=HTMLResponse) +async def ui_login_page(request: Request, return_to: str = None): + """Login page. Accepts optional return_to URL for redirect after login.""" + username = get_user_from_cookie(request) + if username: + return HTMLResponse(base_html("Already Logged In", f''' +
+ You are already logged in as {username} +
+

Go to home page

+ ''', username)) + + # Hidden field for return_to URL + return_to_field = f'' if return_to else '' + + content = f''' +

Login

+
+
+ {return_to_field} +
+ + +
+
+ + +
+ +
+

Don't have an account? Register

+ ''' + return HTMLResponse(base_html("Login", content)) + + +@app.post("/login", response_class=HTMLResponse) +async def ui_login_submit(request: Request): + """Handle login form submission.""" + form = await request.form() + username = form.get("username", "").strip() + password = form.get("password", "") + return_to = form.get("return_to", "").strip() + + if not username or not password: + return HTMLResponse('
Username and password are required
') + + user = await authenticate_user(DATA_DIR, username, password) + if not user: + return HTMLResponse('
Invalid username or password
') + + token = create_access_token(user.username, l2_server=f"https://{DOMAIN}") + + # If return_to is specified, redirect there with token for the other site to set its own cookie + if return_to and return_to.startswith("http"): + # Append token to return_to URL for the target site to set its own cookie + separator = "&" if "?" in return_to else "?" + redirect_url = f"{return_to}{separator}auth_token={token.access_token}" + response = HTMLResponse(f''' +
Login successful! Redirecting...
+ + ''') + else: + response = HTMLResponse(f''' +
Login successful! Redirecting...
+ + ''') + + # Set cookie for L2 only (L1 servers set their own cookies via /auth endpoint) + response.set_cookie( + key="auth_token", + value=token.access_token, + httponly=True, + max_age=60 * 60 * 24 * 30, # 30 days + samesite="lax", + secure=True + ) + return response + + +@app.get("/register", response_class=HTMLResponse) +async def ui_register_page(request: Request): + """Register page.""" + username = get_user_from_cookie(request) + if username: + return HTMLResponse(base_html("Already Logged In", f''' +
+ You are already logged in as {username} +
+

Go to home page

+ ''', username)) + + content = ''' +

Register

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+

Already have an account? Login

+ ''' + return HTMLResponse(base_html("Register", content)) + + +@app.post("/register", response_class=HTMLResponse) +async def ui_register_submit(request: Request): + """Handle register form submission.""" + form = await request.form() + username = form.get("username", "").strip() + email = form.get("email", "").strip() or None + password = form.get("password", "") + password2 = form.get("password2", "") + + if not username or not password: + return HTMLResponse('
Username and password are required
') + + if password != password2: + return HTMLResponse('
Passwords do not match
') + + if len(password) < 6: + return HTMLResponse('
Password must be at least 6 characters
') + + try: + user = await create_user(DATA_DIR, username, password, email) + except ValueError as e: + return HTMLResponse(f'
{str(e)}
') + + token = create_access_token(user.username, l2_server=f"https://{DOMAIN}") + + response = HTMLResponse(f''' +
Registration successful! Redirecting...
+ + ''') + response.set_cookie( + key="auth_token", + value=token.access_token, + httponly=True, + max_age=60 * 60 * 24 * 30, # 30 days + samesite="lax", + secure=True + ) + return response + + +@app.get("/logout") +async def logout(request: Request): + """Handle logout - clear cookie, revoke token on L2 and attached L1s, and redirect to home.""" + token = request.cookies.get("auth_token") + claims = get_token_claims(token) if token else None + username = claims.get("sub") if claims else None + + if username and token and claims: + # Revoke token in L2 database (so even if L1 ignores revoke, token won't verify) + token_hash = hashlib.sha256(token.encode()).hexdigest() + expires_at = datetime.fromtimestamp(claims.get("exp", 0), tz=timezone.utc) + await db.revoke_token(token_hash, username, expires_at) + + # Revoke ALL tokens for this user on attached L1 renderers + # (L1 may have scoped tokens different from L2's token) + attached = await db.get_user_renderers(username) + for l1_url in attached: + try: + requests.post( + f"{l1_url}/auth/revoke-user", + json={"username": username, "l2_server": f"https://{DOMAIN}"}, + timeout=5 + ) + except Exception as e: + logger.warning(f"Failed to revoke user tokens on {l1_url}: {e}") + + # Remove all attachments for this user + for l1_url in attached: + await db.detach_renderer(username, l1_url) + + response = RedirectResponse(url="/", status_code=302) + # Delete both legacy (no domain) and new (shared domain) cookies + response.delete_cookie("auth_token") + if COOKIE_DOMAIN: + response.delete_cookie("auth_token", domain=COOKIE_DOMAIN) + return response + + +# ============ HTML Rendering Helpers ============ + +async def ui_activity_detail(activity_index: int, request: Request): + """Activity detail page with full content display. Helper function for HTML rendering.""" + username = get_user_from_cookie(request) + activities = await load_activities() + + if activity_index < 0 or activity_index >= len(activities): + content = ''' +

Activity Not Found

+

This activity does not exist.

+

← Back to Activities

+ ''' + return HTMLResponse(base_html("Activity Not Found", content, username)) + + activity = activities[activity_index] + return await _render_activity_detail(activity, request) + + +async def ui_activity_detail_by_data(activity: dict, request: Request): + """Activity detail page taking activity data directly.""" + return await _render_activity_detail(activity, request) + + +async def _render_activity_detail(activity: dict, request: Request): + """Core activity detail rendering logic.""" + username = get_user_from_cookie(request) + activity_type = activity.get("activity_type", "") + activity_id = activity.get("activity_id", "") + actor_id = activity.get("actor_id", "") + actor_name = actor_id.split("/")[-1] if actor_id else "unknown" + published = format_date(activity.get("published")) + obj = activity.get("object_data", {}) + + # Object details + obj_name = obj.get("name", "Untitled") + obj_type = obj.get("type", "") + content_hash_obj = obj.get("contentHash", {}) + content_hash = content_hash_obj.get("value", "") if isinstance(content_hash_obj, dict) else "" + media_type = obj.get("mediaType", "") + description = obj.get("summary", "") or obj.get("content", "") + + # Provenance from object - or fallback to registry asset + provenance = obj.get("provenance", {}) + origin = obj.get("origin", {}) + + # Fallback: if activity doesn't have provenance, look up the asset from registry + if not provenance or not origin: + registry = await load_registry() + assets = registry.get("assets", {}) + # Find asset by content_hash or name + for asset_name, asset_data in assets.items(): + if asset_data.get("content_hash") == content_hash or asset_data.get("name") == obj_name: + if not provenance: + provenance = asset_data.get("provenance", {}) + if not origin: + origin = asset_data.get("origin", {}) + break + + # Type colors + type_color = "bg-green-600" if activity_type == "Create" else "bg-yellow-600" if activity_type == "Update" else "bg-gray-600" + obj_type_color = "bg-blue-600" if "Image" in obj_type else "bg-purple-600" if "Video" in obj_type else "bg-gray-600" + + # Determine L1 server and asset type + l1_server = provenance.get("l1_server", L1_PUBLIC_URL).rstrip("/") if provenance else L1_PUBLIC_URL.rstrip("/") + is_video = "Video" in obj_type or "video" in media_type + + # Content display + if is_video: + content_html = f''' + + ''' + elif "Image" in obj_type or "image" in media_type: + content_html = f''' +
+ {obj_name} + +
+ ''' + else: + content_html = f''' +
+

Content type: {media_type or obj_type}

+ + Download + +
+ ''' + + # Origin display + origin_html = 'Not specified' + if origin: + origin_type = origin.get("type", "") + if origin_type == "self": + origin_html = 'Original content by author' + elif origin_type == "external": + origin_url = origin.get("url", "") + origin_note = origin.get("note", "") + origin_html = f'{origin_url}' + if origin_note: + origin_html += f'

{origin_note}

' + + # Provenance section + provenance_html = "" + if provenance and provenance.get("recipe"): + recipe = provenance.get("recipe", "") + inputs = provenance.get("inputs", []) + l1_run_id = provenance.get("l1_run_id", "") + rendered_at = format_date(provenance.get("rendered_at")) + effects_commit = provenance.get("effects_commit", "") + effect_url = provenance.get("effect_url") + infrastructure = provenance.get("infrastructure", {}) + + if not effect_url: + if effects_commit and effects_commit != "unknown": + effect_url = f"{EFFECTS_REPO_URL}/src/commit/{effects_commit}/{recipe}" + else: + effect_url = f"{EFFECTS_REPO_URL}/src/branch/main/{recipe}" + + # Build inputs display - show actual content as thumbnails + inputs_html = "" + for inp in inputs: + inp_hash = inp.get("content_hash", "") if isinstance(inp, dict) else inp + if inp_hash: + inputs_html += f''' +
+
+ + Input +
+
+ {inp_hash[:16]}... + view +
+
+ ''' + + # Infrastructure display + infra_html = "" + if infrastructure: + software = infrastructure.get("software", {}) + hardware = infrastructure.get("hardware", {}) + if software or hardware: + infra_parts = [] + if software: + infra_parts.append(f"Software: {software.get('name', 'unknown')}") + if hardware: + infra_parts.append(f"Hardware: {hardware.get('name', 'unknown')}") + infra_html = f'

{" | ".join(infra_parts)}

' + + provenance_html = f''' +
+

Provenance

+

This content was created by applying an effect to input content.

+
+
+

Effect

+ + + + + {recipe} + + {f'
Commit: {effects_commit[:12]}...
' if effects_commit else ''} +
+
+

Input(s)

+ {inputs_html if inputs_html else 'No inputs recorded'} +
+
+

L1 Run

+ {l1_run_id[:20]}... +
+
+

Rendered

+ {rendered_at if rendered_at else 'Unknown'} + {infra_html} +
+
+
+ ''' + + content = f''' +

← Back to Activities

+ +
+ {activity_type} +

{obj_name}

+ {obj_type} +
+ + {content_html} + +
+
+
+

Actor

+ {actor_name} +
+ +
+

Description

+

{description if description else 'No description'}

+
+ +
+

Origin

+ {origin_html} +
+
+ +
+
+

Content Hash

+ {content_hash} +
+ +
+

Published

+ {published} +
+ +
+

Activity ID

+ {activity_id} +
+
+
+ + {provenance_html} + +
+

ActivityPub

+
+

+ Object URL: + https://{DOMAIN}/objects/{content_hash} +

+

+ Actor: + {actor_id} +

+
+
+ ''' + return HTMLResponse(base_html(f"Activity: {obj_name}", content, username)) + + +async def ui_asset_detail(name: str, request: Request): + """Asset detail page with content preview and provenance. Helper function for HTML rendering.""" + username = get_user_from_cookie(request) + registry = await load_registry() + assets = registry.get("assets", {}) + + if name not in assets: + content = f''' +

Asset Not Found

+

No asset named "{name}" exists.

+

← Back to Assets

+ ''' + return HTMLResponse(base_html("Asset Not Found", content, username)) + + asset = assets[name] + owner = asset.get("owner", "unknown") + content_hash = asset.get("content_hash", "") + ipfs_cid = asset.get("ipfs_cid", "") + asset_type = asset.get("asset_type", "") + tags = asset.get("tags", []) + description = asset.get("description", "") + origin = asset.get("origin") or {} + provenance = asset.get("provenance") or {} + metadata = asset.get("metadata") or {} + created_at = format_date(asset.get("created_at")) + + type_color = "bg-blue-600" if asset_type == "image" else "bg-purple-600" if asset_type == "video" else "bg-gray-600" + + # Determine L1 server URL for content + l1_server = provenance.get("l1_server", L1_PUBLIC_URL).rstrip("/") + + # Content display - image or video from L1 + if asset_type == "video": + # Use iOS-compatible MP4 endpoint + content_html = f''' + + ''' + elif asset_type == "image": + content_html = f''' +
+ {name} + +
+ ''' + elif asset_type == "recipe": + # Fetch recipe source from L1 or IPFS + recipe_source = "" + try: + resp = requests.get(f"{l1_server}/cache/{content_hash}", timeout=10, headers={"Accept": "text/plain"}) + if resp.status_code == 200: + recipe_source = resp.text + except Exception: + pass + + if not recipe_source and ipfs_cid: + # Try IPFS + try: + import ipfs_client + recipe_bytes = ipfs_client.get_bytes(ipfs_cid) + if recipe_bytes: + recipe_source = recipe_bytes.decode('utf-8') + except Exception: + pass + + import html as html_module + recipe_source_escaped = html_module.escape(recipe_source) if recipe_source else "(Could not load recipe source)" + + content_html = f''' +
+

Recipe Source

+
{recipe_source_escaped}
+ +
+ ''' + else: + content_html = f''' +
+

Content type: {asset_type}

+ + Download + +
+ ''' + + # Origin display + origin_html = 'Not specified' + if origin: + origin_type = origin.get("type", "unknown") + if origin_type == "self": + origin_html = 'Original content by author' + elif origin_type == "external": + origin_url = origin.get("url", "") + origin_note = origin.get("note", "") + origin_html = f'{origin_url}' + if origin_note: + origin_html += f'

{origin_note}

' + + # Tags display + tags_html = 'No tags' + if tags: + tags_html = " ".join([f'{t}' for t in tags]) + + # IPFS display + if ipfs_cid: + local_gateway = f'Local' if IPFS_GATEWAY_URL else '' + ipfs_html = f'''{ipfs_cid} +
+ {local_gateway} + ipfs.io + dweb.link +
''' + else: + ipfs_html = 'Not on IPFS' + + # Provenance section - for rendered outputs + provenance_html = "" + if provenance: + recipe = provenance.get("recipe", "") + inputs = provenance.get("inputs", []) + l1_run_id = provenance.get("l1_run_id", "") + rendered_at = format_date(provenance.get("rendered_at")) + effects_commit = provenance.get("effects_commit", "") + infrastructure = provenance.get("infrastructure", {}) + + # Use stored effect_url or build fallback + effect_url = provenance.get("effect_url") + if not effect_url: + # Fallback for older records + if effects_commit and effects_commit != "unknown": + effect_url = f"{EFFECTS_REPO_URL}/src/commit/{effects_commit}/{recipe}" + else: + effect_url = f"{EFFECTS_REPO_URL}/src/branch/main/{recipe}" + + # Build inputs display - show actual content as thumbnails + inputs_html = "" + for inp in inputs: + inp_hash = inp.get("content_hash", "") if isinstance(inp, dict) else inp + if inp_hash: + inputs_html += f''' +
+
+ + Input +
+
+ {inp_hash[:16]}... + view +
+
+ ''' + + # Infrastructure display + infra_html = "" + if infrastructure: + software = infrastructure.get("software", {}) + hardware = infrastructure.get("hardware", {}) + if software or hardware: + infra_html = f''' +
+

Infrastructure

+
+ {f"Software: {software.get('name', 'unknown')}" if software else ""} + {f" ({software.get('content_hash', '')[:16]}...)" if software.get('content_hash') else ""} + {" | " if software and hardware else ""} + {f"Hardware: {hardware.get('name', 'unknown')}" if hardware else ""} +
+
+ ''' + + provenance_html = f''' +
+

Provenance

+

This asset was created by applying an effect to input content.

+
+
+

Effect

+ + + + + {recipe} + + {f'
Commit: {effects_commit[:12]}...
' if effects_commit else ''} +
+
+

Input(s)

+ {inputs_html if inputs_html else 'No inputs recorded'} +
+
+

L1 Run

+ {l1_run_id[:16]}... +
+
+

Rendered

+ {rendered_at if rendered_at else 'Unknown'} +
+ {infra_html} +
+
+ ''' + + content = f''' +

← Back to Assets

+ +
+

{name}

+ {asset_type} +
+ + {content_html} + +
+
+
+

Owner

+ {owner} +
+ +
+

Description

+

{description if description else 'No description'}

+
+ +
+

Origin

+ {origin_html} +
+
+ +
+
+

Content Hash

+ {content_hash} +
+ +
+

IPFS

+ {ipfs_html} +
+ +
+

Created

+ {created_at} +
+ +
+

Tags

+
{tags_html}
+
+
+
+ + {provenance_html} + +
+

ActivityPub

+
+

+ Object URL: + https://{DOMAIN}/objects/{content_hash} +

+

+ Owner Actor: + https://{DOMAIN}/users/{owner} +

+
+
+ ''' + return HTMLResponse(base_html(f"Asset: {name}", content, username)) + + +async def ui_user_detail(username: str, request: Request): + """User detail page showing their published assets. Helper function for HTML rendering.""" + current_user = get_user_from_cookie(request) + + if not await user_exists(username): + content = f''' +

User Not Found

+

No user named "{username}" exists.

+

← Back to Users

+ ''' + return HTMLResponse(base_html("User Not Found", content, current_user)) + + # Get user's assets + registry = await load_registry() + all_assets = registry.get("assets", {}) + user_assets = {name: asset for name, asset in all_assets.items() if asset.get("owner") == username} + + # Get user's activities + all_activities = await load_activities() + actor_id = f"https://{DOMAIN}/users/{username}" + user_activities = [a for a in all_activities if a.get("actor_id") == actor_id] + + webfinger = f"@{username}@{DOMAIN}" + + # Assets table + if user_assets: + rows = "" + for name, asset in sorted(user_assets.items(), key=lambda x: x[1].get("created_at", ""), reverse=True): + hash_short = asset.get("content_hash", "")[:16] + "..." + asset_type = asset.get("asset_type", "") + type_color = "bg-blue-600" if asset_type == "image" else "bg-purple-600" if asset_type == "video" else "bg-gray-600" + rows += f''' + + + {name} + + {asset_type} + {hash_short} + {", ".join(asset.get("tags", []))} + + ''' + assets_html = f''' +
+ + + + + + + + + + + {rows} + +
NameTypeContent HashTags
+
+ ''' + else: + assets_html = '

No published assets yet.

' + + content = f''' +

← Back to Users

+ +
+

{username}

+ {webfinger} +
+ +
+
+
{len(user_assets)}
+
Published Assets
+
+
+
{len(user_activities)}
+
Activities
+
+
+ +
+

ActivityPub

+

+ Actor URL: https://{DOMAIN}/users/{username} +

+

+ Outbox: https://{DOMAIN}/users/{username}/outbox +

+
+ +

Published Assets ({len(user_assets)})

+ {assets_html} + ''' + return HTMLResponse(base_html(f"User: {username}", content, current_user)) + + +# ============ API Endpoints ============ + +@app.get("/") +async def root(request: Request): + """Server info. HTML shows home page with counts, JSON returns stats.""" + registry = await load_registry() + activities = await load_activities() + users = await db.get_all_users() + + assets_count = len(registry.get("assets", {})) + activities_count = len(activities) + users_count = len(users) + + if wants_html(request): + username = get_user_from_cookie(request) + readme_html = markdown.markdown(README_CONTENT, extensions=['tables', 'fenced_code']) + content = f''' + +
+ {readme_html} +
+ ''' + return HTMLResponse(base_html("Home", content, username)) + + return { + "name": "Art DAG L2 Server", + "version": "0.1.0", + "domain": DOMAIN, + "assets_count": assets_count, + "activities_count": activities_count, + "users_count": users_count + } + + +# ============ Auth Endpoints ============ + +security = HTTPBearer(auto_error=False) + + +async def get_optional_user( + credentials: HTTPAuthorizationCredentials = Depends(security) +) -> Optional[User]: + """Get current user if authenticated, None otherwise.""" + if not credentials: + return None + return await get_current_user(DATA_DIR, credentials.credentials) + + +async def get_required_user( + credentials: HTTPAuthorizationCredentials = Depends(security) +) -> User: + """Get current user, raise 401 if not authenticated.""" + if not credentials: + raise HTTPException(401, "Not authenticated") + user = await get_current_user(DATA_DIR, credentials.credentials) + if not user: + raise HTTPException(401, "Invalid token") + return user + + +@app.post("/auth/register", response_model=Token) +async def register(req: UserCreate): + """Register a new user.""" + try: + user = await create_user(DATA_DIR, req.username, req.password, req.email) + except ValueError as e: + raise HTTPException(400, str(e)) + + return create_access_token(user.username, l2_server=f"https://{DOMAIN}") + + +@app.post("/auth/login", response_model=Token) +async def login(req: UserLogin): + """Login and get access token.""" + user = await authenticate_user(DATA_DIR, req.username, req.password) + if not user: + raise HTTPException(401, "Invalid username or password") + + return create_access_token(user.username, l2_server=f"https://{DOMAIN}") + + +@app.get("/auth/me") +async def get_me(user: User = Depends(get_required_user)): + """Get current user info.""" + return { + "username": user.username, + "email": user.email, + "created_at": user.created_at + } + + +class VerifyRequest(BaseModel): + l1_server: str # URL of the L1 server requesting verification + + +@app.post("/auth/verify") +async def verify_auth( + request: VerifyRequest, + credentials: HTTPAuthorizationCredentials = Depends(security) +): + """Verify a token and return username. Only authorized L1 servers can call this.""" + if not credentials: + raise HTTPException(401, "No token provided") + + token = credentials.credentials + + # Check L1 is authorized + l1_normalized = request.l1_server.rstrip("/") + authorized = any(l1_normalized == s.rstrip("/") for s in L1_SERVERS) + if not authorized: + raise HTTPException(403, f"L1 server not authorized: {request.l1_server}") + + # Check if token is revoked (L2-side revocation) + token_hash = hashlib.sha256(token.encode()).hexdigest() + if await db.is_token_revoked(token_hash): + raise HTTPException(401, "Token has been revoked") + + # Verify token and get claims + claims = get_token_claims(token) + if not claims: + raise HTTPException(401, "Invalid token") + + username = claims.get("sub") + if not username: + raise HTTPException(401, "Invalid token") + + # Check token scope - if token is scoped to an L1, it must match + token_l1_server = claims.get("l1_server") + if token_l1_server: + token_l1_normalized = token_l1_server.rstrip("/") + if token_l1_normalized != l1_normalized: + raise HTTPException(403, f"Token is scoped to {token_l1_server}, not {request.l1_server}") + + # Record the attachment (L1 successfully verified user's token) + await db.attach_renderer(username, l1_normalized) + + return {"username": username, "valid": True, "l1_server": request.l1_server} + + +@app.get("/.well-known/webfinger") +async def webfinger(resource: str): + """WebFinger endpoint for actor discovery.""" + # Parse acct:username@domain + if not resource.startswith("acct:"): + raise HTTPException(400, "Resource must be acct: URI") + + acct = resource[5:] # Remove "acct:" + if "@" not in acct: + raise HTTPException(400, "Invalid acct format") + + username, domain = acct.split("@", 1) + + if domain != DOMAIN: + raise HTTPException(404, f"Unknown domain: {domain}") + + if not await user_exists(username): + raise HTTPException(404, f"Unknown user: {username}") + + return JSONResponse( + content={ + "subject": resource, + "links": [ + { + "rel": "self", + "type": "application/activity+json", + "href": f"https://{DOMAIN}/users/{username}" + } + ] + }, + media_type="application/jrd+json" + ) + + +@app.get("/users") +async def get_users_list(request: Request, page: int = 1, limit: int = 20): + """Get all users. HTML for browsers (with infinite scroll), JSON for APIs (with pagination).""" + all_users = list((await db.get_all_users()).items()) + total = len(all_users) + + # Sort by username + all_users.sort(key=lambda x: x[0]) + + # Pagination + start = (page - 1) * limit + end = start + limit + users_page = all_users[start:end] + has_more = end < total + + if wants_html(request): + username = get_user_from_cookie(request) + + if not users_page: + if page == 1: + content = ''' +

Users

+

No users registered yet.

+ ''' + else: + return HTMLResponse("") # Empty for infinite scroll + else: + rows = "" + for uname, user_data in users_page: + webfinger = f"@{uname}@{DOMAIN}" + created_at = format_date(user_data.get("created_at")) + rows += f''' + + + {uname} + + {webfinger} + {created_at} + + ''' + + # For infinite scroll, just return rows if not first page + if page > 1: + if has_more: + rows += f''' + + Loading more... + + ''' + return HTMLResponse(rows) + + # First page - full content + infinite_scroll_trigger = "" + if has_more: + infinite_scroll_trigger = f''' + + Loading more... + + ''' + + content = f''' +

Users ({total} total)

+
+ + + + + + + + + + {rows} + {infinite_scroll_trigger} + +
UsernameWebFingerCreated
+
+ ''' + + return HTMLResponse(base_html("Users", content, username)) + + # JSON response for APIs + return { + "users": [{"username": uname, **data} for uname, data in users_page], + "pagination": { + "page": page, + "limit": limit, + "total": total, + "has_more": has_more + } + } + + +@app.get("/users/{username}") +async def get_actor(username: str, request: Request): + """Get actor profile for any registered user. Content negotiation: HTML for browsers, JSON for APIs.""" + if not await user_exists(username): + if wants_html(request): + content = f''' +

User Not Found

+

No user named "{username}" exists.

+

← Back to Users

+ ''' + return HTMLResponse(base_html("User Not Found", content, get_user_from_cookie(request))) + raise HTTPException(404, f"Unknown user: {username}") + + if wants_html(request): + # Render user detail page + return await ui_user_detail(username, request) + + actor = load_actor(username) + + # Add ActivityPub context + actor["@context"] = [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1" + ] + + return JSONResponse( + content=actor, + media_type="application/activity+json" + ) + + +@app.get("/users/{username}/outbox") +async def get_outbox(username: str, page: bool = False): + """Get actor's outbox (activities they created).""" + if not await user_exists(username): + raise HTTPException(404, f"Unknown user: {username}") + + # Filter activities by this user's actor_id + all_activities = await load_activities() + actor_id = f"https://{DOMAIN}/users/{username}" + user_activities = [a for a in all_activities if a.get("actor_id") == actor_id] + + if not page: + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "id": f"https://{DOMAIN}/users/{username}/outbox", + "type": "OrderedCollection", + "totalItems": len(user_activities), + "first": f"https://{DOMAIN}/users/{username}/outbox?page=true" + }, + media_type="application/activity+json" + ) + + # Return activities page + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "id": f"https://{DOMAIN}/users/{username}/outbox?page=true", + "type": "OrderedCollectionPage", + "partOf": f"https://{DOMAIN}/users/{username}/outbox", + "orderedItems": user_activities + }, + media_type="application/activity+json" + ) + + +@app.post("/users/{username}/inbox") +async def post_inbox(username: str, request: Request): + """Receive activities from other servers.""" + if not await user_exists(username): + raise HTTPException(404, f"Unknown user: {username}") + + body = await request.json() + activity_type = body.get("type") + + # Handle Follow requests + if activity_type == "Follow": + follower_url = body.get("actor") + # Add follower to database + await db.add_follower(username, follower_url, follower_url) + + # Send Accept (in production, do this async) + # For now just acknowledge + return {"status": "accepted"} + + # Handle other activity types + return {"status": "received"} + + +@app.get("/users/{username}/followers") +async def get_followers(username: str): + """Get actor's followers.""" + if not await user_exists(username): + raise HTTPException(404, f"Unknown user: {username}") + + # TODO: Per-user followers - for now use global followers + followers = await load_followers() + + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "id": f"https://{DOMAIN}/users/{username}/followers", + "type": "OrderedCollection", + "totalItems": len(followers), + "orderedItems": followers + }, + media_type="application/activity+json" + ) + + +# ============ Assets Endpoints ============ + +@app.get("/assets") +async def get_registry(request: Request, page: int = 1, limit: int = 20): + """Get registry. HTML for browsers (with infinite scroll), JSON for APIs (with pagination).""" + registry = await load_registry() + all_assets = list(registry.get("assets", {}).items()) + total = len(all_assets) + + # Sort by created_at descending + all_assets.sort(key=lambda x: x[1].get("created_at", ""), reverse=True) + + # Pagination + start = (page - 1) * limit + end = start + limit + assets_page = all_assets[start:end] + has_more = end < total + + if wants_html(request): + username = get_user_from_cookie(request) + + if not assets_page: + if page == 1: + content = ''' +

Registry

+

No assets registered yet.

+ ''' + else: + return HTMLResponse("") # Empty for infinite scroll + else: + rows = "" + for name, asset in assets_page: + asset_type = asset.get("asset_type", "") + type_color = "bg-blue-600" if asset_type == "image" else "bg-purple-600" if asset_type == "video" else "bg-gray-600" + owner = asset.get("owner", "unknown") + content_hash = asset.get("content_hash", "")[:16] + "..." + rows += f''' + + {name} + {asset_type} + + {owner} + + {content_hash} + + View + + + ''' + + # For infinite scroll, just return rows if not first page + if page > 1: + if has_more: + rows += f''' + + Loading more... + + ''' + return HTMLResponse(rows) + + # First page - full content + infinite_scroll_trigger = "" + if has_more: + infinite_scroll_trigger = f''' + + Loading more... + + ''' + + content = f''' +

Registry ({total} assets)

+
+ + + + + + + + + + + + {rows} + {infinite_scroll_trigger} + +
NameTypeOwnerHash
+
+ ''' + + return HTMLResponse(base_html("Registry", content, username)) + + # JSON response for APIs + return { + "assets": {name: asset for name, asset in assets_page}, + "pagination": { + "page": page, + "limit": limit, + "total": total, + "has_more": has_more + } + } + + +@app.get("/asset/{name}") +async def get_asset_by_name_legacy(name: str): + """Legacy route - redirect to /assets/{name}.""" + return RedirectResponse(url=f"/assets/{name}", status_code=301) + + +@app.get("/assets/{name}") +async def get_asset(name: str, request: Request): + """Get asset by name. HTML for browsers (default), JSON only if explicitly requested.""" + registry = await load_registry() + + # Check if JSON explicitly requested + accept = request.headers.get("accept", "") + wants_json = "application/json" in accept and "text/html" not in accept + + if name not in registry.get("assets", {}): + if wants_json: + raise HTTPException(404, f"Asset not found: {name}") + content = f''' +

Asset Not Found

+

No asset named "{name}" exists.

+

← Back to Assets

+ ''' + return HTMLResponse(base_html("Asset Not Found", content, get_user_from_cookie(request))) + + if wants_json: + return registry["assets"][name] + + # Default to HTML for browsers + return await ui_asset_detail(name, request) + + +@app.get("/assets/by-run-id/{run_id}") +async def get_asset_by_run_id(run_id: str): + """ + Get asset by content-addressable run_id. + + Returns the asset info including output_hash and ipfs_cid for L1 recovery. + The run_id is stored in the asset's provenance when the run is recorded. + """ + asset = await db.get_asset_by_run_id(run_id) + if not asset: + raise HTTPException(404, f"No asset found for run_id: {run_id}") + + return { + "run_id": run_id, + "asset_name": asset.get("name"), + "output_hash": asset.get("content_hash"), + "ipfs_cid": asset.get("ipfs_cid"), + "provenance_cid": asset.get("provenance", {}).get("provenance_cid") if asset.get("provenance") else None, + } + + +@app.patch("/assets/{name}") +async def update_asset(name: str, req: UpdateAssetRequest, user: User = Depends(get_required_user)): + """Update an existing asset's metadata. Creates an Update activity.""" + asset = await db.get_asset(name) + if not asset: + raise HTTPException(404, f"Asset not found: {name}") + + # Check ownership + if asset.get("owner") != user.username: + raise HTTPException(403, f"Not authorized to update asset owned by {asset.get('owner')}") + + # Build updates dict + updates = {} + if req.description is not None: + updates["description"] = req.description + if req.tags is not None: + updates["tags"] = req.tags + if req.metadata is not None: + updates["metadata"] = {**asset.get("metadata", {}), **req.metadata} + if req.origin is not None: + updates["origin"] = req.origin + if req.ipfs_cid is not None: + updates["ipfs_cid"] = req.ipfs_cid + # Pin on IPFS (fire-and-forget, don't block) + import threading + threading.Thread(target=_pin_ipfs_async, args=(req.ipfs_cid,), daemon=True).start() + + # Update asset in database + updated_asset = await db.update_asset(name, updates) + + # Create Update activity + activity = { + "activity_id": str(uuid.uuid4()), + "activity_type": "Update", + "actor_id": f"https://{DOMAIN}/users/{user.username}", + "object_data": { + "type": updated_asset.get("asset_type", "Object").capitalize(), + "name": name, + "id": f"https://{DOMAIN}/objects/{updated_asset['content_hash']}", + "contentHash": { + "algorithm": "sha3-256", + "value": updated_asset["content_hash"] + }, + "attributedTo": f"https://{DOMAIN}/users/{user.username}", + "summary": req.description, + "tag": req.tags or updated_asset.get("tags", []) + }, + "published": updated_asset.get("updated_at", datetime.now(timezone.utc).isoformat()) + } + + # Sign activity with the user's keys + activity = sign_activity(activity, user.username) + + # Save activity to database + await db.create_activity(activity) + + return {"asset": updated_asset, "activity": activity} + + +def _pin_ipfs_async(cid: str): + """Pin IPFS content in background thread.""" + try: + import ipfs_client + if ipfs_client.is_available(): + ipfs_client.pin(cid) + logger.info(f"Pinned IPFS content: {cid}") + except Exception as e: + logger.warning(f"Failed to pin IPFS content {cid}: {e}") + + +async def _register_asset_impl(req: RegisterRequest, owner: str): + """ + Internal implementation for registering an asset atomically. + + Requires IPFS CID - content must be on IPFS before registering. + Uses a transaction for all DB operations. + """ + import ipfs_client + from ipfs_client import IPFSError + + logger.info(f"register_asset: Starting for {req.name} (hash={req.content_hash[:16]}...)") + + # ===== PHASE 1: VALIDATION ===== + # IPFS CID is required + if not req.ipfs_cid: + raise HTTPException(400, "IPFS CID is required for registration") + + # Check if name exists - return existing asset if so + existing = await db.get_asset(req.name) + if existing: + logger.info(f"register_asset: Asset {req.name} already exists, returning existing") + return {"asset": existing, "activity": None, "existing": True} + + # ===== PHASE 2: IPFS OPERATIONS (non-blocking) ===== + import asyncio + logger.info(f"register_asset: Pinning CID {req.ipfs_cid[:16]}... on IPFS") + try: + await asyncio.to_thread(ipfs_client.pin_or_raise, req.ipfs_cid) + logger.info("register_asset: CID pinned successfully") + except IPFSError as e: + logger.error(f"register_asset: IPFS pin failed: {e}") + raise HTTPException(500, f"IPFS operation failed: {e}") + + # ===== PHASE 3: DB TRANSACTION ===== + now = datetime.now(timezone.utc).isoformat() + + try: + async with db.transaction() as conn: + # Check name again inside transaction (race condition protection) + if await db.asset_exists_by_name_tx(conn, req.name): + # Race condition - another request created it first, return existing + existing = await db.get_asset(req.name) + logger.info(f"register_asset: Asset {req.name} created by concurrent request") + return {"asset": existing, "activity": None, "existing": True} + + # Create asset + asset = { + "name": req.name, + "content_hash": req.content_hash, + "ipfs_cid": req.ipfs_cid, + "asset_type": req.asset_type, + "tags": req.tags, + "metadata": req.metadata, + "url": req.url, + "provenance": req.provenance, + "owner": owner, + "created_at": now + } + created_asset = await db.create_asset_tx(conn, asset) + + # Create ownership activity + object_data = { + "type": req.asset_type.capitalize(), + "name": req.name, + "id": f"https://{DOMAIN}/objects/{req.content_hash}", + "contentHash": { + "algorithm": "sha3-256", + "value": req.content_hash + }, + "attributedTo": f"https://{DOMAIN}/users/{owner}" + } + + # Include provenance in activity object_data if present + if req.provenance: + object_data["provenance"] = req.provenance + + activity = { + "activity_id": req.content_hash, # Content-addressable by content hash + "activity_type": "Create", + "actor_id": f"https://{DOMAIN}/users/{owner}", + "object_data": object_data, + "published": now + } + activity = sign_activity(activity, owner) + created_activity = await db.create_activity_tx(conn, activity) + + # Transaction commits here on successful exit + + except HTTPException: + raise + except Exception as e: + logger.error(f"register_asset: Database transaction failed: {e}") + raise HTTPException(500, f"Failed to register asset: {e}") + + logger.info(f"register_asset: Successfully registered {req.name}") + return {"asset": created_asset, "activity": created_activity} + + +@app.post("/assets") +async def register_asset(req: RegisterRequest, user: User = Depends(get_required_user)): + """Register a new asset and create ownership activity. Requires authentication.""" + return await _register_asset_impl(req, user.username) + + +@app.post("/assets/record-run") +@app.post("/registry/record-run") # Legacy route +async def record_run(req: RecordRunRequest, user: User = Depends(get_required_user)): + """ + Record an L1 run and register the output atomically. + + Ensures all operations succeed or none do: + 1. All input assets registered (if not already on L2) + pinned on IPFS + 2. Output asset registered + pinned on IPFS + 3. Recipe serialized to JSON, stored on IPFS, CID saved in provenance + """ + import ipfs_client + from ipfs_client import IPFSError + + # ===== PHASE 1: PREPARATION (read-only, non-blocking) ===== + import asyncio + l1_url = req.l1_server.rstrip('/') + + logger.info(f"record_run: Starting for run_id={req.run_id} from {l1_url}") + + # Helper to fetch from L1 without blocking event loop + def fetch_l1_run(): + import time as _time + url = f"{l1_url}/runs/{req.run_id}" + logger.info(f"record_run: Fetching run from L1: {url}") + t0 = _time.time() + resp = requests.get(url, timeout=30) + logger.info(f"record_run: L1 request took {_time.time()-t0:.3f}s, status={resp.status_code}") + if resp.status_code == 404: + raise ValueError(f"Run not found on L1: {req.run_id}") + resp.raise_for_status() + try: + return resp.json() + except Exception: + body_preview = resp.text[:200] if resp.text else "(empty)" + logger.error(f"L1 returned non-JSON for {url}: status={resp.status_code}, body={body_preview}") + raise ValueError(f"L1 returned invalid response: {body_preview[:100]}") + + def fetch_l1_cache(content_hash): + logger.debug(f"record_run: Fetching cache {content_hash[:16]}... from L1") + url = f"{l1_url}/cache/{content_hash}" + resp = requests.get(url, headers={"Accept": "application/json"}, timeout=10) + if resp.status_code == 404: + raise ValueError(f"Cache item not found on L1: {content_hash[:16]}...") + resp.raise_for_status() + try: + return resp.json() + except Exception as e: + # Log what we actually got back + body_preview = resp.text[:200] if resp.text else "(empty)" + logger.error(f"L1 returned non-JSON for {url}: status={resp.status_code}, body={body_preview}") + raise ValueError(f"L1 returned invalid response (status={resp.status_code}): {body_preview[:100]}") + + # Fetch run from L1 + try: + run = await asyncio.to_thread(fetch_l1_run) + logger.info(f"record_run: Fetched run, status={run.get('status')}, inputs={len(run.get('inputs', []))}") + except Exception as e: + logger.error(f"record_run: Failed to fetch run from L1: {e}") + raise HTTPException(400, f"Failed to fetch run from L1 ({l1_url}): {e}") + + if run.get("status") != "completed": + raise HTTPException(400, f"Run not completed: {run.get('status')}") + + output_hash = run.get("output_hash") + if not output_hash: + raise HTTPException(400, "Run has no output hash") + + # Fetch output cache info from L1 (must exist - it's new) + logger.info(f"record_run: Fetching output cache {output_hash[:16]}... from L1") + try: + cache_info = await asyncio.to_thread(fetch_l1_cache, output_hash) + output_media_type = cache_info.get("media_type", "image") + output_ipfs_cid = cache_info.get("ipfs_cid") + logger.info(f"record_run: Output has IPFS CID: {output_ipfs_cid[:16] if output_ipfs_cid else 'None'}...") + except Exception as e: + logger.error(f"record_run: Failed to fetch output cache info: {e}") + raise HTTPException(400, f"Failed to fetch output cache info: {e}") + + if not output_ipfs_cid: + logger.error("record_run: Output has no IPFS CID") + raise HTTPException(400, "Output has no IPFS CID - cannot publish") + + # Gather input info: check L2 first, then fall back to L1 + input_hashes = run.get("inputs", []) + input_infos = [] # List of {content_hash, ipfs_cid, media_type, existing_asset} + logger.info(f"record_run: Gathering info for {len(input_hashes)} inputs") + + for input_hash in input_hashes: + # Check if already on L2 + existing = await db.get_asset_by_hash(input_hash) + if existing and existing.get("ipfs_cid"): + logger.info(f"record_run: Input {input_hash[:16]}... found on L2") + input_infos.append({ + "content_hash": input_hash, + "ipfs_cid": existing["ipfs_cid"], + "media_type": existing.get("asset_type", "image"), + "existing_asset": existing + }) + else: + # Not on L2, try L1 + logger.info(f"record_run: Input {input_hash[:16]}... not on L2, fetching from L1") + try: + inp_info = await asyncio.to_thread(fetch_l1_cache, input_hash) + ipfs_cid = inp_info.get("ipfs_cid") + if not ipfs_cid: + logger.error(f"record_run: Input {input_hash[:16]}... has no IPFS CID") + raise HTTPException(400, f"Input {input_hash[:16]}... has no IPFS CID (not on L2 or L1)") + input_infos.append({ + "content_hash": input_hash, + "ipfs_cid": ipfs_cid, + "media_type": inp_info.get("media_type", "image"), + "existing_asset": None + }) + except HTTPException: + raise + except Exception as e: + logger.error(f"record_run: Failed to fetch input {input_hash[:16]}... from L1: {e}") + raise HTTPException(400, f"Input {input_hash[:16]}... not on L2 and failed to fetch from L1: {e}") + + # Prepare recipe data + recipe_data = run.get("recipe") + if not recipe_data: + recipe_data = { + "name": run.get("recipe_name", "unknown"), + "effect_url": run.get("effect_url"), + "effects_commit": run.get("effects_commit"), + } + + # Build registered_inputs list - all referenced by content_hash + registered_inputs = [] + for inp in input_infos: + registered_inputs.append({ + "content_hash": inp["content_hash"], + "ipfs_cid": inp["ipfs_cid"] + }) + + # ===== PHASE 2: IPFS OPERATIONS (non-blocking for event loop) ===== + def do_ipfs_operations(): + """Run IPFS operations in thread pool to not block event loop.""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + # Collect all CIDs to pin (inputs + output) + cids_to_pin = [inp["ipfs_cid"] for inp in input_infos] + [output_ipfs_cid] + logger.info(f"record_run: Pinning {len(cids_to_pin)} CIDs on IPFS") + + # Pin all in parallel + with ThreadPoolExecutor(max_workers=5) as executor: + futures = {executor.submit(ipfs_client.pin_or_raise, cid): cid for cid in cids_to_pin} + for future in as_completed(futures): + future.result() # Raises IPFSError if failed + logger.info("record_run: All CIDs pinned successfully") + + # Store recipe on IPFS + logger.info("record_run: Storing recipe on IPFS") + recipe_cid = ipfs_client.add_json(recipe_data) + + # Build and store full provenance on IPFS + # Compute content-addressable run_id from inputs + recipe + recipe_name = recipe_data.get("name", "unknown") if isinstance(recipe_data, dict) else str(recipe_data) + run_id = compute_run_id(input_hashes, recipe_name) + provenance = { + "run_id": run_id, # Content-addressable run identifier + "inputs": registered_inputs, + "output": { + "content_hash": output_hash, + "ipfs_cid": output_ipfs_cid + }, + "recipe": recipe_data, + "recipe_cid": recipe_cid, + "effect_url": run.get("effect_url"), + "effects_commit": run.get("effects_commit"), + "l1_server": l1_url, + "l1_run_id": req.run_id, + "rendered_at": run.get("completed_at"), + "infrastructure": run.get("infrastructure") + } + logger.info("record_run: Storing provenance on IPFS") + provenance_cid = ipfs_client.add_json(provenance) + + return recipe_cid, provenance_cid, provenance + + try: + import asyncio + recipe_cid, provenance_cid, provenance = await asyncio.to_thread(do_ipfs_operations) + logger.info(f"record_run: Recipe CID: {recipe_cid[:16]}..., Provenance CID: {provenance_cid[:16]}...") + except IPFSError as e: + logger.error(f"record_run: IPFS operation failed: {e}") + raise HTTPException(500, f"IPFS operation failed: {e}") + + # ===== PHASE 3: DB TRANSACTION (all-or-nothing) ===== + logger.info("record_run: Starting DB transaction") + now = datetime.now(timezone.utc).isoformat() + + # Add provenance_cid to provenance for storage in DB + provenance["provenance_cid"] = provenance_cid + + try: + async with db.transaction() as conn: + # Register input assets (if not already on L2) - named by content_hash + for inp in input_infos: + if not inp["existing_asset"]: + media_type = inp["media_type"] + tags = ["auto-registered", "input"] + if media_type == "recipe": + tags.append("recipe") + input_asset = { + "name": inp["content_hash"], # Use content_hash as name + "content_hash": inp["content_hash"], + "ipfs_cid": inp["ipfs_cid"], + "asset_type": media_type, + "tags": tags, + "metadata": {"auto_registered_from_run": req.run_id}, + "owner": user.username, + "created_at": now + } + await db.create_asset_tx(conn, input_asset) + + # Check if output already exists (by content_hash) - return existing if so + existing = await db.get_asset_by_name_tx(conn, output_hash) + if existing: + logger.info(f"record_run: Output {output_hash[:16]}... already exists") + # Check if activity already exists for this run + existing_activity = await db.get_activity(provenance["run_id"]) + if existing_activity: + logger.info(f"record_run: Activity {provenance['run_id'][:16]}... also exists") + return {"asset": existing, "activity": existing_activity, "existing": True} + # Asset exists but no activity - create one + logger.info(f"record_run: Creating activity for existing asset") + object_data = { + "type": existing.get("asset_type", "image").capitalize(), + "name": output_hash, + "id": f"https://{DOMAIN}/objects/{output_hash}", + "contentHash": { + "algorithm": "sha3-256", + "value": output_hash + }, + "attributedTo": f"https://{DOMAIN}/users/{user.username}", + "provenance": provenance + } + activity = { + "activity_id": provenance["run_id"], + "activity_type": "Create", + "actor_id": f"https://{DOMAIN}/users/{user.username}", + "object_data": object_data, + "published": now + } + activity = sign_activity(activity, user.username) + created_activity = await db.create_activity_tx(conn, activity) + return {"asset": existing, "activity": created_activity, "existing": True} + + # Create output asset with provenance - named by content_hash + output_asset = { + "name": output_hash, # Use content_hash as name + "content_hash": output_hash, + "ipfs_cid": output_ipfs_cid, + "asset_type": output_media_type, + "tags": ["rendered", "l1"], + "metadata": {"l1_server": l1_url, "l1_run_id": req.run_id}, + "provenance": provenance, + "owner": user.username, + "created_at": now + } + created_asset = await db.create_asset_tx(conn, output_asset) + + # Create activity - all referenced by content_hash + object_data = { + "type": output_media_type.capitalize(), + "name": output_hash, # Use content_hash as name + "id": f"https://{DOMAIN}/objects/{output_hash}", + "contentHash": { + "algorithm": "sha3-256", + "value": output_hash + }, + "attributedTo": f"https://{DOMAIN}/users/{user.username}", + "provenance": provenance + } + + activity = { + "activity_id": provenance["run_id"], # Content-addressable run_id + "activity_type": "Create", + "actor_id": f"https://{DOMAIN}/users/{user.username}", + "object_data": object_data, + "published": now + } + activity = sign_activity(activity, user.username) + created_activity = await db.create_activity_tx(conn, activity) + + # Transaction commits here on successful exit + + except HTTPException: + raise + except Exception as e: + logger.error(f"record_run: Database transaction failed: {e}") + raise HTTPException(500, f"Failed to record run: {e}") + + logger.info(f"record_run: Successfully published {output_hash[:16]}... with {len(registered_inputs)} inputs") + return {"asset": created_asset, "activity": created_activity} + + +@app.post("/assets/publish-cache") +async def publish_cache(req: PublishCacheRequest, user: User = Depends(get_required_user)): + """ + Publish a cache item from L1 with metadata atomically. + + Requires origin to be set (self or external URL). + Requires IPFS CID - content must be on IPFS before publishing. + Creates a new asset and Create activity in a single transaction. + """ + import ipfs_client + from ipfs_client import IPFSError + + logger.info(f"publish_cache: Starting for {req.asset_name} (hash={req.content_hash[:16]}...)") + + # ===== PHASE 1: VALIDATION ===== + # Validate origin + if not req.origin or "type" not in req.origin: + raise HTTPException(400, "Origin is required for publishing (type: 'self' or 'external')") + + origin_type = req.origin.get("type") + if origin_type not in ("self", "external"): + raise HTTPException(400, "Origin type must be 'self' or 'external'") + + if origin_type == "external" and not req.origin.get("url"): + raise HTTPException(400, "External origin requires a URL") + + # IPFS CID is now required + if not req.ipfs_cid: + raise HTTPException(400, "IPFS CID is required for publishing") + + # Check if asset name already exists + if await db.asset_exists(req.asset_name): + raise HTTPException(400, f"Asset name already exists: {req.asset_name}") + + # ===== PHASE 2: IPFS OPERATIONS (non-blocking) ===== + import asyncio + logger.info(f"publish_cache: Pinning CID {req.ipfs_cid[:16]}... on IPFS") + try: + await asyncio.to_thread(ipfs_client.pin_or_raise, req.ipfs_cid) + logger.info("publish_cache: CID pinned successfully") + except IPFSError as e: + logger.error(f"publish_cache: IPFS pin failed: {e}") + raise HTTPException(500, f"IPFS operation failed: {e}") + + # ===== PHASE 3: DB TRANSACTION ===== + logger.info("publish_cache: Starting DB transaction") + now = datetime.now(timezone.utc).isoformat() + + try: + async with db.transaction() as conn: + # Check name again inside transaction (race condition protection) + if await db.asset_exists_by_name_tx(conn, req.asset_name): + raise HTTPException(400, f"Asset name already exists: {req.asset_name}") + + # Create asset + asset = { + "name": req.asset_name, + "content_hash": req.content_hash, + "ipfs_cid": req.ipfs_cid, + "asset_type": req.asset_type, + "tags": req.tags, + "description": req.description, + "origin": req.origin, + "metadata": req.metadata, + "owner": user.username, + "created_at": now + } + created_asset = await db.create_asset_tx(conn, asset) + + # Create ownership activity with origin info + object_data = { + "type": req.asset_type.capitalize(), + "name": req.asset_name, + "id": f"https://{DOMAIN}/objects/{req.content_hash}", + "contentHash": { + "algorithm": "sha3-256", + "value": req.content_hash + }, + "attributedTo": f"https://{DOMAIN}/users/{user.username}", + "tag": req.tags + } + + if req.description: + object_data["summary"] = req.description + + # Include origin in ActivityPub object + if origin_type == "self": + object_data["generator"] = { + "type": "Application", + "name": "Art DAG", + "note": "Original content created by the author" + } + else: + object_data["source"] = { + "type": "Link", + "href": req.origin.get("url"), + "name": req.origin.get("note", "External source") + } + + activity = { + "activity_id": req.content_hash, # Content-addressable by content hash + "activity_type": "Create", + "actor_id": f"https://{DOMAIN}/users/{user.username}", + "object_data": object_data, + "published": now + } + activity = sign_activity(activity, user.username) + created_activity = await db.create_activity_tx(conn, activity) + + # Transaction commits here on successful exit + + except HTTPException: + raise + except Exception as e: + logger.error(f"publish_cache: Database transaction failed: {e}") + raise HTTPException(500, f"Failed to publish cache item: {e}") + + logger.info(f"publish_cache: Successfully published {req.asset_name}") + return {"asset": created_asset, "activity": created_activity} + + +# ============ Activities Endpoints ============ + +@app.get("/activities") +async def get_activities(request: Request, page: int = 1, limit: int = 20): + """Get activities. HTML for browsers (with infinite scroll), JSON for APIs (with pagination).""" + all_activities = await load_activities() + total = len(all_activities) + + # Reverse for newest first + all_activities = list(reversed(all_activities)) + + # Pagination + start = (page - 1) * limit + end = start + limit + activities_page = all_activities[start:end] + has_more = end < total + + if wants_html(request): + username = get_user_from_cookie(request) + + if not activities_page: + if page == 1: + content = ''' +

Activities

+

No activities yet.

+ ''' + else: + return HTMLResponse("") # Empty for infinite scroll + else: + rows = "" + for i, activity in enumerate(activities_page): + activity_index = total - 1 - (start + i) # Original index + obj = activity.get("object_data", {}) + activity_type = activity.get("activity_type", "") + type_color = "bg-green-600" if activity_type == "Create" else "bg-yellow-600" if activity_type == "Update" else "bg-gray-600" + actor_id = activity.get("actor_id", "") + actor_name = actor_id.split("/")[-1] if actor_id else "unknown" + rows += f''' + + {activity_type} + {obj.get("name", "Untitled")} + + {actor_name} + + {format_date(activity.get("published"))} + + View + + + ''' + + # For infinite scroll, just return rows if not first page + if page > 1: + if has_more: + rows += f''' + + Loading more... + + ''' + return HTMLResponse(rows) + + # First page - full content with header + infinite_scroll_trigger = "" + if has_more: + infinite_scroll_trigger = f''' + + Loading more... + + ''' + + content = f''' +

Activities ({total} total)

+
+ + + + + + + + + + + + {rows} + {infinite_scroll_trigger} + +
TypeObjectActorPublished
+
+ ''' + + return HTMLResponse(base_html("Activities", content, username)) + + # JSON response for APIs + return { + "activities": activities_page, + "pagination": { + "page": page, + "limit": limit, + "total": total, + "has_more": has_more + } + } + + +@app.get("/activities/{activity_ref}") +async def get_activity_detail(activity_ref: str, request: Request): + """Get single activity by index or activity_id. HTML for browsers (default), JSON only if explicitly requested.""" + + # Check if JSON explicitly requested + accept = request.headers.get("accept", "") + wants_json = ("application/json" in accept or "application/activity+json" in accept) and "text/html" not in accept + + activity = None + activity_index = None + + # Check if it's a numeric index or an activity_id (hash) + if activity_ref.isdigit(): + # Numeric index (legacy) + activity_index = int(activity_ref) + activities = await load_activities() + if 0 <= activity_index < len(activities): + activity = activities[activity_index] + else: + # Activity ID (hash) - look up directly + activity = await db.get_activity(activity_ref) + if activity: + # Find index for UI rendering + activities = await load_activities() + for i, a in enumerate(activities): + if a.get("activity_id") == activity_ref: + activity_index = i + break + + if not activity: + if wants_json: + raise HTTPException(404, "Activity not found") + content = ''' +

Activity Not Found

+

This activity does not exist.

+

← Back to Activities

+ ''' + return HTMLResponse(base_html("Activity Not Found", content, get_user_from_cookie(request))) + + if wants_json: + return activity + + # Default to HTML for browsers + if activity_index is not None: + return await ui_activity_detail(activity_index, request) + else: + # Render activity directly if no index found + return await ui_activity_detail_by_data(activity, request) + + +@app.get("/activity/{activity_index}") +async def get_activity_legacy(activity_index: int): + """Legacy route - redirect to /activities/{activity_index}.""" + return RedirectResponse(url=f"/activities/{activity_index}", status_code=301) + + +@app.get("/objects/{content_hash}") +async def get_object(content_hash: str, request: Request): + """Get object by content hash. Content negotiation: HTML for browsers, JSON for APIs.""" + registry = await load_registry() + + # Find asset by hash + for name, asset in registry.get("assets", {}).items(): + if asset.get("content_hash") == content_hash: + # Check Accept header - only return JSON if explicitly requested + accept = request.headers.get("accept", "") + wants_json = ("application/json" in accept or "application/activity+json" in accept) and "text/html" not in accept + + if not wants_json: + # Default: redirect to detail page for browsers + return RedirectResponse(url=f"/assets/{name}", status_code=303) + + owner = asset.get("owner", "unknown") + return JSONResponse( + content={ + "@context": "https://www.w3.org/ns/activitystreams", + "id": f"https://{DOMAIN}/objects/{content_hash}", + "type": asset.get("asset_type", "Object").capitalize(), + "name": name, + "contentHash": { + "algorithm": "sha3-256", + "value": content_hash + }, + "attributedTo": f"https://{DOMAIN}/users/{owner}", + "published": asset.get("created_at") + }, + media_type="application/activity+json" + ) + + raise HTTPException(404, f"Object not found: {content_hash}") + + +# ============ Anchoring (Bitcoin timestamps) ============ + +@app.post("/anchors/create") +async def create_anchor_endpoint(request: Request): + """ + Create a new anchor for all unanchored activities. + + Builds a merkle tree, stores it on IPFS, and submits to OpenTimestamps + for Bitcoin anchoring. The anchor proof is backed up to persistent storage. + """ + import anchoring + import ipfs_client + + # Check auth (cookie or header) + username = get_user_from_cookie(request) + if not username: + if wants_html(request): + return HTMLResponse(''' +
+ Error: Login required +
+ ''') + raise HTTPException(401, "Authentication required") + + # Get unanchored activities + unanchored = await db.get_unanchored_activities() + if not unanchored: + if wants_html(request): + return HTMLResponse(''' +
+ No unanchored activities to anchor. +
+ ''') + return {"message": "No unanchored activities", "anchored": 0} + + activity_ids = [a["activity_id"] for a in unanchored] + + # Create anchor + anchor = await anchoring.create_anchor(activity_ids, db, ipfs_client) + + if anchor: + if wants_html(request): + return HTMLResponse(f''' +
+ Success! Anchored {len(activity_ids)} activities.
+ Merkle root: {anchor["merkle_root"][:32]}...
+ Refresh page to see the new anchor. +
+ ''') + return { + "message": f"Anchored {len(activity_ids)} activities", + "merkle_root": anchor["merkle_root"], + "tree_ipfs_cid": anchor.get("tree_ipfs_cid"), + "activity_count": anchor["activity_count"] + } + else: + if wants_html(request): + return HTMLResponse(''' +
+ Failed! Could not create anchor. +
+ ''') + raise HTTPException(500, "Failed to create anchor") + + +@app.get("/anchors") +async def list_anchors(): + """List all anchors.""" + anchors = await db.get_all_anchors() + stats = await db.get_anchor_stats() + return { + "anchors": anchors, + "stats": stats + } + + +@app.get("/anchors/{merkle_root}") +async def get_anchor_endpoint(merkle_root: str): + """Get anchor details by merkle root.""" + anchor = await db.get_anchor(merkle_root) + if not anchor: + raise HTTPException(404, f"Anchor not found: {merkle_root}") + return anchor + + +@app.get("/anchors/{merkle_root}/tree") +async def get_anchor_tree(merkle_root: str): + """Get the full merkle tree from IPFS.""" + import asyncio + import ipfs_client + + anchor = await db.get_anchor(merkle_root) + if not anchor: + raise HTTPException(404, f"Anchor not found: {merkle_root}") + + tree_cid = anchor.get("tree_ipfs_cid") + if not tree_cid: + raise HTTPException(404, "Anchor has no tree on IPFS") + + try: + tree_bytes = await asyncio.to_thread(ipfs_client.get_bytes, tree_cid) + if tree_bytes: + return json.loads(tree_bytes) + except Exception as e: + raise HTTPException(500, f"Failed to fetch tree from IPFS: {e}") + + +@app.get("/anchors/verify/{activity_id}") +async def verify_activity_anchor(activity_id: str): + """ + Verify an activity's anchor proof. + + Returns the merkle proof showing this activity is included in an anchored batch. + """ + import anchoring + import ipfs_client + + # Get activity + activity = await db.get_activity(activity_id) + if not activity: + raise HTTPException(404, f"Activity not found: {activity_id}") + + anchor_root = activity.get("anchor_root") + if not anchor_root: + return {"verified": False, "reason": "Activity not yet anchored"} + + # Get anchor + anchor = await db.get_anchor(anchor_root) + if not anchor: + return {"verified": False, "reason": "Anchor record not found"} + + # Get tree from IPFS (non-blocking) + import asyncio + tree_cid = anchor.get("tree_ipfs_cid") + if not tree_cid: + return {"verified": False, "reason": "Merkle tree not on IPFS"} + + try: + tree_bytes = await asyncio.to_thread(ipfs_client.get_bytes, tree_cid) + tree = json.loads(tree_bytes) if tree_bytes else None + except Exception: + return {"verified": False, "reason": "Failed to fetch tree from IPFS"} + + if not tree: + return {"verified": False, "reason": "Could not load merkle tree"} + + # Get proof + proof = anchoring.get_merkle_proof(tree, activity_id) + if not proof: + return {"verified": False, "reason": "Activity not in merkle tree"} + + # Verify proof + valid = anchoring.verify_merkle_proof(activity_id, proof, anchor_root) + + return { + "verified": valid, + "activity_id": activity_id, + "merkle_root": anchor_root, + "tree_ipfs_cid": tree_cid, + "proof": proof, + "bitcoin_txid": anchor.get("bitcoin_txid"), + "confirmed_at": anchor.get("confirmed_at") + } + + +@app.post("/anchors/{merkle_root}/upgrade") +async def upgrade_anchor_proof(merkle_root: str): + """ + Try to upgrade an OTS proof from pending to confirmed. + + Bitcoin confirmation typically takes 1-2 hours. Call this periodically + to check if the proof has been included in a Bitcoin block. + """ + import anchoring + import ipfs_client + import asyncio + + anchor = await db.get_anchor(merkle_root) + if not anchor: + raise HTTPException(404, f"Anchor not found: {merkle_root}") + + if anchor.get("confirmed_at"): + return {"status": "already_confirmed", "bitcoin_txid": anchor.get("bitcoin_txid")} + + # Get current OTS proof from IPFS + ots_cid = anchor.get("ots_proof_cid") + if not ots_cid: + return {"status": "no_proof", "message": "No OTS proof stored"} + + try: + ots_proof = await asyncio.to_thread(ipfs_client.get_bytes, ots_cid) + if not ots_proof: + return {"status": "error", "message": "Could not fetch OTS proof from IPFS"} + except Exception as e: + return {"status": "error", "message": f"IPFS error: {e}"} + + # Try to upgrade + upgraded = await asyncio.to_thread(anchoring.upgrade_ots_proof, ots_proof) + + if upgraded and len(upgraded) > len(ots_proof): + # Store upgraded proof on IPFS + try: + new_cid = await asyncio.to_thread(ipfs_client.add_bytes, upgraded) + # TODO: Update anchor record with new CID and confirmed status + return { + "status": "upgraded", + "message": "Proof upgraded - Bitcoin confirmation received", + "new_ots_cid": new_cid, + "proof_size": len(upgraded) + } + except Exception as e: + return {"status": "error", "message": f"Failed to store upgraded proof: {e}"} + else: + return { + "status": "pending", + "message": "Not yet confirmed on Bitcoin. Try again in ~1 hour.", + "proof_size": len(ots_proof) if ots_proof else 0 + } + + +@app.get("/anchors/ui", response_class=HTMLResponse) +async def anchors_ui(request: Request): + """Anchors UI page - view and test OpenTimestamps anchoring.""" + username = get_user_from_cookie(request) + + anchors = await db.get_all_anchors() + stats = await db.get_anchor_stats() + + # Build anchors table rows + rows = "" + for anchor in anchors: + status = "confirmed" if anchor.get("confirmed_at") else "pending" + status_class = "text-green-400" if status == "confirmed" else "text-yellow-400" + merkle_root = anchor.get("merkle_root", "")[:16] + "..." + + rows += f''' + + {merkle_root} + {anchor.get("activity_count", 0)} + {status} + {format_date(anchor.get("created_at"), 16)} + + + + + + ''' + + if not rows: + rows = 'No anchors yet' + + content = f''' + + +

Bitcoin Anchoring via OpenTimestamps

+ +
+
+
{stats.get("total_anchors", 0)}
+
Total Anchors
+
+
+
{stats.get("confirmed_anchors", 0)}
+
Confirmed
+
+
+
{stats.get("pending_anchors", 0)}
+
Pending
+
+
+ +
+

Test Anchoring

+

Create a test anchor for unanchored activities, or test the OTS connection.

+
+ + +
+
+
+ +

Anchors

+
+ + + + + + + + + + + + {rows} + +
Merkle RootActivitiesStatusCreatedActions
+
+ +
+

How it works:

+
    +
  1. Activities are batched and hashed into a merkle tree
  2. +
  3. The merkle root is submitted to OpenTimestamps
  4. +
  5. OTS aggregates hashes and anchors to Bitcoin (~1-2 hours)
  6. +
  7. Once confirmed, anyone can verify the timestamp
  8. +
+
+ ''' + + return HTMLResponse(base_html("Anchors", content, username)) + + +@app.post("/anchors/test-ots", response_class=HTMLResponse) +async def test_ots_connection(): + """Test OpenTimestamps connection by submitting a test hash.""" + import anchoring + import hashlib + import asyncio + + # Create a test hash + test_data = f"test-{datetime.now(timezone.utc).isoformat()}" + test_hash = hashlib.sha256(test_data.encode()).hexdigest() + + # Try to submit + try: + ots_proof = await asyncio.to_thread(anchoring.submit_to_opentimestamps, test_hash) + if ots_proof: + return HTMLResponse(f''' +
+ Success! OpenTimestamps is working.
+ Test hash: {test_hash[:32]}...
+ Proof size: {len(ots_proof)} bytes +
+ ''') + else: + return HTMLResponse(''' +
+ Failed! Could not reach OpenTimestamps servers. +
+ ''') + except Exception as e: + return HTMLResponse(f''' +
+ Error: {str(e)} +
+ ''') + + +# ============ Renderers (L1 servers) ============ + +@app.get("/renderers", response_class=HTMLResponse) +async def renderers_page(request: Request): + """Page to manage L1 renderer attachments.""" + username = get_user_from_cookie(request) + + if not username: + content = ''' +

Renderers

+

Log in to manage your renderer connections.

+ ''' + return HTMLResponse(base_html("Renderers", content)) + + # Get user's attached renderers + attached = await db.get_user_renderers(username) + from urllib.parse import quote + + # Build renderer list + rows = [] + for l1_url in L1_SERVERS: + is_attached = l1_url in attached + # Extract display name from URL + display_name = l1_url.replace("https://", "").replace("http://", "") + + if is_attached: + status = 'Attached' + action = f''' + + Open + + + ''' + else: + status = 'Not attached' + # Attach via endpoint that creates scoped token (not raw token in URL) + attach_url = f"/renderers/attach?l1_url={quote(l1_url, safe='')}" + action = f''' + + Attach + + ''' + + row_id = l1_url.replace("://", "-").replace("/", "-").replace(".", "-") + rows.append(f''' +
+
+
{display_name}
+
{l1_url}
+
+
+ {status} + {action} +
+
+ ''') + + content = f''' +

Renderers

+

Connect to L1 rendering servers. After attaching, you can run effects and manage media on that renderer.

+
+ {"".join(rows) if rows else '

No renderers configured.

'} +
+ ''' + return HTMLResponse(base_html("Renderers", content, username)) + + +@app.get("/renderers/attach") +async def attach_renderer_redirect(request: Request, l1_url: str): + """Create a scoped token and redirect to L1 for attachment.""" + username = get_user_from_cookie(request) + if not username: + return RedirectResponse(url="/login", status_code=302) + + # Verify L1 is in our allowed list + l1_normalized = l1_url.rstrip("/") + if not any(l1_normalized == s.rstrip("/") for s in L1_SERVERS): + raise HTTPException(403, f"L1 server not authorized: {l1_url}") + + # Create a scoped token that only works for this specific L1 + scoped_token = create_access_token( + username, + l2_server=f"https://{DOMAIN}", + l1_server=l1_normalized + ) + + # Redirect to L1 with scoped token + redirect_url = f"{l1_normalized}/auth?auth_token={scoped_token.access_token}" + return RedirectResponse(url=redirect_url, status_code=302) + + +@app.post("/renderers/detach", response_class=HTMLResponse) +async def detach_renderer(request: Request): + """Detach from an L1 renderer.""" + username = get_user_from_cookie(request) + if not username: + return HTMLResponse('
Not logged in
') + + form = await request.form() + l1_url = form.get("l1_url", "") + + await db.detach_renderer(username, l1_url) + + # Return updated row with link to attach endpoint (not raw token) + display_name = l1_url.replace("https://", "").replace("http://", "") + from urllib.parse import quote + attach_url = f"/renderers/attach?l1_url={quote(l1_url, safe='')}" + row_id = l1_url.replace("://", "-").replace("/", "-").replace(".", "-") + + return HTMLResponse(f''' +
+
+
{display_name}
+
{l1_url}
+
+
+ Not attached + + Attach + +
+
+ ''') + + +# ============ User Storage ============ + +import storage_providers + + +@app.get("/storage") +async def list_storage(request: Request, user: User = Depends(get_optional_user)): + """List user's storage providers. HTML for browsers (default), JSON only if explicitly requested.""" + # Check if JSON explicitly requested + accept = request.headers.get("accept", "") + wants_json = "application/json" in accept and "text/html" not in accept + + # For browser sessions, also check cookie authentication + username = user.username if user else get_user_from_cookie(request) + + if not username: + if wants_json: + raise HTTPException(401, "Authentication required") + return RedirectResponse(url="/login", status_code=302) + + storages = await db.get_user_storage(username) + + # Add usage stats to each storage + for storage in storages: + usage = await db.get_storage_usage(storage["id"]) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + # Mask sensitive config keys for display + if storage.get("config"): + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + masked = {} + for k, v in config.items(): + if "key" in k.lower() or "token" in k.lower() or "secret" in k.lower(): + masked[k] = v[:4] + "..." + v[-4:] if len(str(v)) > 8 else "****" + else: + masked[k] = v + storage["config_display"] = masked + + if wants_json: + return {"storages": storages} + + # Default to HTML for browsers + return await ui_storage_page(username, storages, request) + + +@app.post("/storage") +async def add_storage(req: AddStorageRequest, user: User = Depends(get_required_user)): + """Add a storage provider.""" + # Validate provider type + valid_types = ["pinata", "web3storage", "nftstorage", "infura", "filebase", "storj", "local"] + if req.provider_type not in valid_types: + raise HTTPException(400, f"Invalid provider type: {req.provider_type}") + + # Test the provider connection before saving + provider = storage_providers.create_provider(req.provider_type, { + **req.config, + "capacity_gb": req.capacity_gb + }) + if not provider: + raise HTTPException(400, "Failed to create provider with given config") + + success, message = await provider.test_connection() + if not success: + raise HTTPException(400, f"Provider connection failed: {message}") + + # Save to database + provider_name = req.provider_name or f"{req.provider_type}-{user.username}" + storage_id = await db.add_user_storage( + username=user.username, + provider_type=req.provider_type, + provider_name=provider_name, + config=req.config, + capacity_gb=req.capacity_gb + ) + + if not storage_id: + raise HTTPException(500, "Failed to save storage provider") + + return {"id": storage_id, "message": f"Storage provider added: {provider_name}"} + + +@app.post("/storage/add") +async def add_storage_form( + request: Request, + provider_type: str = Form(...), + provider_name: Optional[str] = Form(None), + description: Optional[str] = Form(None), + capacity_gb: int = Form(5), + api_key: Optional[str] = Form(None), + secret_key: Optional[str] = Form(None), + api_token: Optional[str] = Form(None), + project_id: Optional[str] = Form(None), + project_secret: Optional[str] = Form(None), + access_key: Optional[str] = Form(None), + bucket: Optional[str] = Form(None), + path: Optional[str] = Form(None), +): + """Add a storage provider via HTML form (cookie auth).""" + username = get_user_from_cookie(request) + if not username: + return HTMLResponse('
Not authenticated
', status_code=401) + + # Validate provider type + valid_types = ["pinata", "web3storage", "nftstorage", "infura", "filebase", "storj", "local"] + if provider_type not in valid_types: + return HTMLResponse(f'
Invalid provider type: {provider_type}
') + + # Build config based on provider type + config = {} + if provider_type == "pinata": + if not api_key or not secret_key: + return HTMLResponse('
Pinata requires API Key and Secret Key
') + config = {"api_key": api_key, "secret_key": secret_key} + elif provider_type == "web3storage": + if not api_token: + return HTMLResponse('
web3.storage requires API Token
') + config = {"api_token": api_token} + elif provider_type == "nftstorage": + if not api_token: + return HTMLResponse('
NFT.Storage requires API Token
') + config = {"api_token": api_token} + elif provider_type == "infura": + if not project_id or not project_secret: + return HTMLResponse('
Infura requires Project ID and Project Secret
') + config = {"project_id": project_id, "project_secret": project_secret} + elif provider_type == "filebase": + if not access_key or not secret_key or not bucket: + return HTMLResponse('
Filebase requires Access Key, Secret Key, and Bucket
') + config = {"access_key": access_key, "secret_key": secret_key, "bucket": bucket} + elif provider_type == "storj": + if not access_key or not secret_key or not bucket: + return HTMLResponse('
Storj requires Access Key, Secret Key, and Bucket
') + config = {"access_key": access_key, "secret_key": secret_key, "bucket": bucket} + elif provider_type == "local": + if not path: + return HTMLResponse('
Local storage requires a path
') + config = {"path": path} + + # Test the provider connection before saving + provider = storage_providers.create_provider(provider_type, { + **config, + "capacity_gb": capacity_gb + }) + if not provider: + return HTMLResponse('
Failed to create provider with given config
') + + success, message = await provider.test_connection() + if not success: + return HTMLResponse(f'
Provider connection failed: {message}
') + + # Save to database + name = provider_name or f"{provider_type}-{username}-{len(await db.get_user_storage_by_type(username, provider_type)) + 1}" + storage_id = await db.add_user_storage( + username=username, + provider_type=provider_type, + provider_name=name, + config=config, + capacity_gb=capacity_gb, + description=description + ) + + if not storage_id: + return HTMLResponse('
Failed to save storage provider
') + + return HTMLResponse(f''' +
Storage provider "{name}" added successfully!
+ + ''') + + +@app.get("/storage/{storage_id}") +async def get_storage(storage_id: int, user: User = Depends(get_required_user)): + """Get a specific storage provider.""" + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage provider not found") + if storage["username"] != user.username: + raise HTTPException(403, "Not authorized") + + usage = await db.get_storage_usage(storage_id) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + storage["donated_gb"] = storage["capacity_gb"] // 2 + + return storage + + +@app.patch("/storage/{storage_id}") +async def update_storage(storage_id: int, req: UpdateStorageRequest, user: User = Depends(get_required_user)): + """Update a storage provider.""" + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage provider not found") + if storage["username"] != user.username: + raise HTTPException(403, "Not authorized") + + # If updating config, test the new connection + if req.config: + existing_config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + new_config = {**existing_config, **req.config} + provider = storage_providers.create_provider(storage["provider_type"], { + **new_config, + "capacity_gb": req.capacity_gb or storage["capacity_gb"] + }) + if provider: + success, message = await provider.test_connection() + if not success: + raise HTTPException(400, f"Provider connection failed: {message}") + + success = await db.update_user_storage( + storage_id, + config=req.config, + capacity_gb=req.capacity_gb, + is_active=req.is_active + ) + + if not success: + raise HTTPException(500, "Failed to update storage provider") + + return {"message": "Storage provider updated"} + + +@app.delete("/storage/{storage_id}") +async def remove_storage(storage_id: int, request: Request, user: User = Depends(get_optional_user)): + """Remove a storage provider.""" + # Support both Bearer token and cookie auth + username = user.username if user else get_user_from_cookie(request) + if not username: + raise HTTPException(401, "Not authenticated") + + storage = await db.get_storage_by_id(storage_id) + if not storage: + raise HTTPException(404, "Storage provider not found") + if storage["username"] != username: + raise HTTPException(403, "Not authorized") + + success = await db.remove_user_storage(storage_id) + if not success: + raise HTTPException(500, "Failed to remove storage provider") + + # Return empty string for HTMX to remove the element + if wants_html(request): + return HTMLResponse("") + + return {"message": "Storage provider removed"} + + +@app.post("/storage/{storage_id}/test") +async def test_storage(storage_id: int, request: Request, user: User = Depends(get_optional_user)): + """Test storage provider connectivity.""" + # Support both Bearer token and cookie auth + username = user.username if user else get_user_from_cookie(request) + if not username: + if wants_html(request): + return HTMLResponse('Not authenticated', status_code=401) + raise HTTPException(401, "Not authenticated") + + storage = await db.get_storage_by_id(storage_id) + if not storage: + if wants_html(request): + return HTMLResponse('Storage not found', status_code=404) + raise HTTPException(404, "Storage provider not found") + if storage["username"] != username: + if wants_html(request): + return HTMLResponse('Not authorized', status_code=403) + raise HTTPException(403, "Not authorized") + + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + provider = storage_providers.create_provider(storage["provider_type"], { + **config, + "capacity_gb": storage["capacity_gb"] + }) + + if not provider: + if wants_html(request): + return HTMLResponse('Failed to create provider') + raise HTTPException(500, "Failed to create provider") + + success, message = await provider.test_connection() + + if wants_html(request): + if success: + return HTMLResponse(f'{message}') + return HTMLResponse(f'{message}') + + return {"success": success, "message": message} + + +STORAGE_PROVIDERS_INFO = { + "pinata": {"name": "Pinata", "desc": "1GB free, IPFS pinning", "color": "blue"}, + "web3storage": {"name": "web3.storage", "desc": "IPFS + Filecoin", "color": "green"}, + "nftstorage": {"name": "NFT.Storage", "desc": "Free for NFTs", "color": "pink"}, + "infura": {"name": "Infura IPFS", "desc": "5GB free", "color": "orange"}, + "filebase": {"name": "Filebase", "desc": "5GB free, S3+IPFS", "color": "cyan"}, + "storj": {"name": "Storj", "desc": "25GB free", "color": "indigo"}, + "local": {"name": "Local Storage", "desc": "Your own disk", "color": "purple"}, +} + + +async def ui_storage_page(username: str, storages: list, request: Request) -> HTMLResponse: + """Render main storage settings page showing provider types.""" + # Count configs per type + type_counts = {} + for s in storages: + pt = s["provider_type"] + type_counts[pt] = type_counts.get(pt, 0) + 1 + + # Build provider type cards + cards = "" + for ptype, info in STORAGE_PROVIDERS_INFO.items(): + count = type_counts.get(ptype, 0) + count_badge = f'{count}' if count > 0 else "" + cards += f''' + +
+ {info["name"]} + {count_badge} +
+
{info["desc"]}
+
+ ''' + + # Total stats + total_capacity = sum(s["capacity_gb"] for s in storages) + total_used = sum(s.get("used_bytes", 0) for s in storages) + total_pins = sum(s.get("pin_count", 0) for s in storages) + + content = f''' +
+
+

Storage Providers

+
+ +
+

+ Attach your own storage to help power the network. 50% of your capacity is donated to store + shared content, making popular assets more resilient. +

+ +
+
+
{len(storages)}
+
Total Configs
+
+
+
{total_capacity} GB
+
Total Capacity
+
+
+
{total_used / (1024**3):.1f} GB
+
Used
+
+
+
{total_pins}
+
Total Pins
+
+
+ +

Select Provider Type

+
+ {cards} +
+
+
+ ''' + + return HTMLResponse(base_html("Storage", content, username)) + + +@app.get("/storage/type/{provider_type}") +async def storage_type_page(provider_type: str, request: Request, user: User = Depends(get_optional_user)): + """Page for managing storage configs of a specific type.""" + username = user.username if user else get_user_from_cookie(request) + if not username: + return RedirectResponse(url="/login", status_code=302) + + if provider_type not in STORAGE_PROVIDERS_INFO: + raise HTTPException(404, "Invalid provider type") + + storages = await db.get_user_storage_by_type(username, provider_type) + + # Add usage stats + for storage in storages: + usage = await db.get_storage_usage(storage["id"]) + storage["used_bytes"] = usage["used_bytes"] + storage["pin_count"] = usage["pin_count"] + # Mask sensitive config keys + if storage.get("config"): + config = storage["config"] if isinstance(storage["config"], dict) else json.loads(storage["config"]) + masked = {} + for k, v in config.items(): + if "key" in k.lower() or "token" in k.lower() or "secret" in k.lower(): + masked[k] = v[:4] + "..." + v[-4:] if len(str(v)) > 8 else "****" + else: + masked[k] = v + storage["config_display"] = masked + + info = STORAGE_PROVIDERS_INFO[provider_type] + return await ui_storage_type_page(username, provider_type, info, storages, request) + + +async def ui_storage_type_page(username: str, provider_type: str, info: dict, storages: list, request: Request) -> HTMLResponse: + """Render per-type storage management page.""" + + def format_bytes(b): + if b > 1024**3: + return f"{b / 1024**3:.1f} GB" + if b > 1024**2: + return f"{b / 1024**2:.1f} MB" + if b > 1024: + return f"{b / 1024:.1f} KB" + return f"{b} bytes" + + # Build storage rows + storage_rows = "" + for s in storages: + status_class = "bg-green-600" if s["is_active"] else "bg-gray-600" + status_text = "Active" if s["is_active"] else "Inactive" + config_display = s.get("config_display", {}) + config_html = ", ".join(f"{k}: {v}" for k, v in config_display.items() if k != "path") + desc = s.get("description") or "" + desc_html = f'
{desc}
' if desc else "" + + storage_rows += f''' +
+
+
+

{s["provider_name"] or provider_type}

+ {desc_html} +
+
+ {status_text} + + +
+
+
+
+
Capacity
+
{s["capacity_gb"]} GB
+
+
+
Donated
+
{s["capacity_gb"] // 2} GB
+
+
+
Used
+
{format_bytes(s["used_bytes"])}
+
+
+
Pins
+
{s["pin_count"]}
+
+
+
{config_html}
+
+
+ ''' + + if not storages: + storage_rows = f'

No {info["name"]} configs yet. Add one below.

' + + # Build form fields based on provider type + form_fields = "" + if provider_type == "pinata": + form_fields = ''' +
+ + +
+
+ + +
+ ''' + elif provider_type in ("web3storage", "nftstorage"): + form_fields = ''' +
+ + +
+ ''' + elif provider_type == "infura": + form_fields = ''' +
+ + +
+
+ + +
+ ''' + elif provider_type in ("filebase", "storj"): + form_fields = ''' +
+ + +
+
+ + +
+
+ + +
+ ''' + elif provider_type == "local": + form_fields = ''' +
+ + +
+ ''' + + content = f''' +
+
+ ← Back +

{info["name"]} Storage

+
+ +
+

Your {info["name"]} Configs

+
+ {storage_rows} +
+
+ +
+

Add New {info["name"]} Config

+
+ + + {form_fields} + +
+ + +
+
+ + +
+
+ + +
+ +
+ +
+
+
+
+
+ ''' + + return HTMLResponse(base_html(f"{info['name']} Storage", content, username)) + + +# ============ Client Download ============ + +CLIENT_TARBALL = Path(__file__).parent / "artdag-client.tar.gz" + +@app.get("/download/client") +async def download_client(): + """Download the Art DAG CLI client.""" + if not CLIENT_TARBALL.exists(): + raise HTTPException(404, "Client package not found") + return FileResponse( + CLIENT_TARBALL, + media_type="application/gzip", + filename="artdag-client.tar.gz" + ) + + +# ============================================================================ +# Help / Documentation Routes +# ============================================================================ + +L2_DOCS_DIR = Path(__file__).parent +COMMON_DOCS_DIR = Path(__file__).parent.parent / "common" + +L2_DOCS_MAP = { + "l2": ("L2 Server (ActivityPub)", L2_DOCS_DIR / "README.md"), + "common": ("Common Library", COMMON_DOCS_DIR / "README.md"), +} + + +@app.get("/help", response_class=HTMLResponse) +async def help_index(request: Request): + """Documentation index page.""" + username = get_user_from_cookie(request) + + # Build doc links + doc_links = "" + for key, (title, path) in L2_DOCS_MAP.items(): + if path.exists(): + doc_links += f''' + +

{title}

+

View documentation

+
+ ''' + + content = f''' +
+

Documentation

+
+ {doc_links} +
+
+ ''' + return HTMLResponse(base_html("Help", content, username)) + + +@app.get("/help/{doc_name}", response_class=HTMLResponse) +async def help_page(doc_name: str, request: Request): + """Render a README as HTML.""" + if doc_name not in L2_DOCS_MAP: + raise HTTPException(404, f"Documentation '{doc_name}' not found") + + title, doc_path = L2_DOCS_MAP[doc_name] + if not doc_path.exists(): + raise HTTPException(404, f"Documentation file not found") + + username = get_user_from_cookie(request) + + # Read and render markdown + md_content = doc_path.read_text() + html_content = markdown.markdown(md_content, extensions=['tables', 'fenced_code']) + + content = f''' +
+ +
+ {html_content} +
+
+ ''' + return HTMLResponse(base_html(title, content, username)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run("server:app", host="0.0.0.0", port=8200, workers=4) diff --git a/setup_keys.py b/setup_keys.py new file mode 100755 index 0000000..1042d5a --- /dev/null +++ b/setup_keys.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +""" +Generate RSA keypair for ActivityPub signing. + +Usage: + python setup_keys.py [--data-dir /path/to/data] [--user username] +""" + +import argparse +import os +from pathlib import Path + +from keys import generate_keypair, has_keys, get_keys_dir + + +def main(): + parser = argparse.ArgumentParser(description="Generate RSA keypair for L2 server") + parser.add_argument("--data-dir", default=os.environ.get("ARTDAG_DATA", str(Path.home() / ".artdag" / "l2")), + help="Data directory") + parser.add_argument("--user", default=os.environ.get("ARTDAG_USER", "giles"), + help="Username") + parser.add_argument("--force", action="store_true", + help="Overwrite existing keys") + + args = parser.parse_args() + data_dir = Path(args.data_dir) + username = args.user + + print(f"Data directory: {data_dir}") + print(f"Username: {username}") + + if has_keys(data_dir, username) and not args.force: + print(f"\nKeys already exist for {username}!") + print(f" Private: {get_keys_dir(data_dir) / f'{username}.pem'}") + print(f" Public: {get_keys_dir(data_dir) / f'{username}.pub'}") + print("\nUse --force to regenerate (will invalidate existing signatures)") + return + + print("\nGenerating RSA-2048 keypair...") + private_pem, public_pem = generate_keypair(data_dir, username) + + keys_dir = get_keys_dir(data_dir) + print(f"\nKeys generated:") + print(f" Private: {keys_dir / f'{username}.pem'} (chmod 600)") + print(f" Public: {keys_dir / f'{username}.pub'}") + print(f"\nPublic key (for verification):") + print(public_pem) + + +if __name__ == "__main__": + main() diff --git a/storage_providers.py b/storage_providers.py new file mode 100644 index 0000000..46dee08 --- /dev/null +++ b/storage_providers.py @@ -0,0 +1,1009 @@ +""" +Storage provider abstraction for user-attachable storage. + +Supports: +- Pinata (IPFS pinning service) +- web3.storage (IPFS pinning service) +- Local filesystem storage +""" + +import hashlib +import json +import logging +import os +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +import requests + +logger = logging.getLogger(__name__) + + +class StorageProvider(ABC): + """Abstract base class for storage backends.""" + + provider_type: str = "unknown" + + @abstractmethod + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """ + Pin content to storage. + + Args: + content_hash: SHA3-256 hash of the content + data: Raw bytes to store + filename: Optional filename hint + + Returns: + IPFS CID or provider-specific ID, or None on failure + """ + pass + + @abstractmethod + async def unpin(self, content_hash: str) -> bool: + """ + Unpin content from storage. + + Args: + content_hash: SHA3-256 hash of the content + + Returns: + True if unpinned successfully + """ + pass + + @abstractmethod + async def get(self, content_hash: str) -> Optional[bytes]: + """ + Retrieve content from storage. + + Args: + content_hash: SHA3-256 hash of the content + + Returns: + Raw bytes or None if not found + """ + pass + + @abstractmethod + async def is_pinned(self, content_hash: str) -> bool: + """Check if content is pinned in this storage.""" + pass + + @abstractmethod + async def test_connection(self) -> tuple[bool, str]: + """ + Test connectivity to the storage provider. + + Returns: + (success, message) tuple + """ + pass + + @abstractmethod + def get_usage(self) -> dict: + """ + Get storage usage statistics. + + Returns: + {used_bytes, capacity_bytes, pin_count} + """ + pass + + +class PinataProvider(StorageProvider): + """Pinata IPFS pinning service provider.""" + + provider_type = "pinata" + + def __init__(self, api_key: str, secret_key: str, capacity_gb: int = 1): + self.api_key = api_key + self.secret_key = secret_key + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.pinata.cloud" + self._usage_cache = None + + def _headers(self) -> dict: + return { + "pinata_api_key": self.api_key, + "pinata_secret_api_key": self.secret_key, + } + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Pinata.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{content_hash[:16]}.bin", data)} + metadata = { + "name": filename or content_hash[:16], + "keyvalues": {"content_hash": content_hash} + } + response = requests.post( + f"{self.base_url}/pinning/pinFileToIPFS", + files=files, + data={"pinataMetadata": json.dumps(metadata)}, + headers=self._headers(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("IpfsHash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Pinata: Pinned {content_hash[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Pinata pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """Unpin content from Pinata by finding its CID first.""" + try: + import asyncio + + def do_unpin(): + # First find the pin by content_hash metadata + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][content_hash]": content_hash, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return False + + # Unpin each matching CID + for pin in pins: + cid = pin.get("ipfs_pin_hash") + if cid: + resp = requests.delete( + f"{self.base_url}/pinning/unpin/{cid}", + headers=self._headers(), + timeout=30 + ) + resp.raise_for_status() + return True + + result = await asyncio.to_thread(do_unpin) + logger.info(f"Pinata: Unpinned {content_hash[:16]}...") + return result + except Exception as e: + logger.error(f"Pinata unpin failed: {e}") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from Pinata via IPFS gateway.""" + try: + import asyncio + + def do_get(): + # First find the CID + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][content_hash]": content_hash, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + pins = response.json().get("rows", []) + + if not pins: + return None + + cid = pins[0].get("ipfs_pin_hash") + if not cid: + return None + + # Fetch from gateway + gateway_response = requests.get( + f"https://gateway.pinata.cloud/ipfs/{cid}", + timeout=120 + ) + gateway_response.raise_for_status() + return gateway_response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Pinata get failed: {e}") + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content is pinned on Pinata.""" + try: + import asyncio + + def do_check(): + response = requests.get( + f"{self.base_url}/data/pinList", + params={"metadata[keyvalues][content_hash]": content_hash, "status": "pinned"}, + headers=self._headers(), + timeout=30 + ) + response.raise_for_status() + return len(response.json().get("rows", [])) > 0 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Pinata API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/data/testAuthentication", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Pinata successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Pinata usage stats.""" + try: + response = requests.get( + f"{self.base_url}/data/userPinnedDataTotal", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + data = response.json() + return { + "used_bytes": data.get("pin_size_total", 0), + "capacity_bytes": self.capacity_bytes, + "pin_count": data.get("pin_count", 0) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class Web3StorageProvider(StorageProvider): + """web3.storage pinning service provider.""" + + provider_type = "web3storage" + + def __init__(self, api_token: str, capacity_gb: int = 1): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.web3.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to web3.storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={ + **self._headers(), + "X-Name": filename or content_hash[:16] + }, + timeout=120 + ) + response.raise_for_status() + return response.json().get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"web3.storage: Pinned {content_hash[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"web3.storage pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """web3.storage doesn't support unpinning - data is stored permanently.""" + logger.warning("web3.storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from web3.storage - would need CID mapping.""" + # web3.storage requires knowing the CID to fetch + # For now, return None - we'd need to maintain a mapping + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test web3.storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1}, + timeout=10 + ) + response.raise_for_status() + return True, "Connected to web3.storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get web3.storage usage stats.""" + try: + response = requests.get( + f"{self.base_url}/user/uploads", + headers=self._headers(), + params={"size": 1000}, + timeout=30 + ) + response.raise_for_status() + uploads = response.json() + total_size = sum(u.get("dagSize", 0) for u in uploads) + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": len(uploads) + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class NFTStorageProvider(StorageProvider): + """NFT.Storage pinning service provider (free for NFT data).""" + + provider_type = "nftstorage" + + def __init__(self, api_token: str, capacity_gb: int = 5): + self.api_token = api_token + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://api.nft.storage" + + def _headers(self) -> dict: + return {"Authorization": f"Bearer {self.api_token}"} + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to NFT.Storage.""" + try: + import asyncio + + def do_pin(): + response = requests.post( + f"{self.base_url}/upload", + data=data, + headers={**self._headers(), "Content-Type": "application/octet-stream"}, + timeout=120 + ) + response.raise_for_status() + return response.json().get("value", {}).get("cid") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"NFT.Storage: Pinned {content_hash[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"NFT.Storage pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """NFT.Storage doesn't support unpinning - data is stored permanently.""" + logger.warning("NFT.Storage: Unpinning not supported (permanent storage)") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from NFT.Storage - would need CID mapping.""" + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content is pinned - would need CID mapping.""" + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test NFT.Storage API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.get( + f"{self.base_url}/", + headers=self._headers(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to NFT.Storage successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid API token" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get NFT.Storage usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class InfuraIPFSProvider(StorageProvider): + """Infura IPFS pinning service provider.""" + + provider_type = "infura" + + def __init__(self, project_id: str, project_secret: str, capacity_gb: int = 5): + self.project_id = project_id + self.project_secret = project_secret + self.capacity_bytes = capacity_gb * 1024**3 + self.base_url = "https://ipfs.infura.io:5001/api/v0" + + def _auth(self) -> tuple: + return (self.project_id, self.project_secret) + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Infura IPFS.""" + try: + import asyncio + + def do_pin(): + files = {"file": (filename or f"{content_hash[:16]}.bin", data)} + response = requests.post( + f"{self.base_url}/add", + files=files, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.json().get("Hash") + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Infura IPFS: Pinned {content_hash[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Infura IPFS pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """Unpin content from Infura IPFS.""" + try: + import asyncio + + def do_unpin(): + response = requests.post( + f"{self.base_url}/pin/rm", + params={"arg": content_hash}, + auth=self._auth(), + timeout=30 + ) + response.raise_for_status() + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Infura IPFS unpin failed: {e}") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from Infura IPFS gateway.""" + try: + import asyncio + + def do_get(): + response = requests.post( + f"{self.base_url}/cat", + params={"arg": content_hash}, + auth=self._auth(), + timeout=120 + ) + response.raise_for_status() + return response.content + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Infura IPFS get failed: {e}") + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content is pinned on Infura IPFS.""" + try: + import asyncio + + def do_check(): + response = requests.post( + f"{self.base_url}/pin/ls", + params={"arg": content_hash}, + auth=self._auth(), + timeout=30 + ) + return response.status_code == 200 + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Infura IPFS API connectivity.""" + try: + import asyncio + + def do_test(): + response = requests.post( + f"{self.base_url}/id", + auth=self._auth(), + timeout=10 + ) + response.raise_for_status() + return True, "Connected to Infura IPFS successfully" + + return await asyncio.to_thread(do_test) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + return False, "Invalid project credentials" + return False, f"HTTP error: {e}" + except Exception as e: + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Infura usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class FilebaseProvider(StorageProvider): + """Filebase S3-compatible IPFS pinning service.""" + + provider_type = "filebase" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 5): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://s3.filebase.com" + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Pin content to Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or f"{content_hash[:16]}.bin" + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + # Get CID from response headers + head = s3.head_object(Bucket=self.bucket, Key=key) + return head.get('Metadata', {}).get('cid', content_hash) + + cid = await asyncio.to_thread(do_pin) + logger.info(f"Filebase: Pinned {content_hash[:16]}... as {cid}") + return cid + except Exception as e: + logger.error(f"Filebase pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """Remove content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=content_hash) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Filebase unpin failed: {e}") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=content_hash) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Filebase get failed: {e}") + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content exists in Filebase.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=content_hash) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Filebase connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Filebase bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Filebase usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class StorjProvider(StorageProvider): + """Storj decentralized cloud storage (S3-compatible).""" + + provider_type = "storj" + + def __init__(self, access_key: str, secret_key: str, bucket: str, capacity_gb: int = 25): + self.access_key = access_key + self.secret_key = secret_key + self.bucket = bucket + self.capacity_bytes = capacity_gb * 1024**3 + self.endpoint = "https://gateway.storjshare.io" + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_pin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + key = filename or content_hash + s3.put_object(Bucket=self.bucket, Key=key, Body=data) + return content_hash + + result = await asyncio.to_thread(do_pin) + logger.info(f"Storj: Stored {content_hash[:16]}...") + return result + except Exception as e: + logger.error(f"Storj pin failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """Remove content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_unpin(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.delete_object(Bucket=self.bucket, Key=content_hash) + return True + + return await asyncio.to_thread(do_unpin) + except Exception as e: + logger.error(f"Storj unpin failed: {e}") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_get(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + response = s3.get_object(Bucket=self.bucket, Key=content_hash) + return response['Body'].read() + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Storj get failed: {e}") + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content exists on Storj.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_check(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_object(Bucket=self.bucket, Key=content_hash) + return True + + return await asyncio.to_thread(do_check) + except Exception: + return False + + async def test_connection(self) -> tuple[bool, str]: + """Test Storj connectivity.""" + try: + import asyncio + import boto3 + from botocore.config import Config + + def do_test(): + s3 = boto3.client( + 's3', + endpoint_url=self.endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + config=Config(signature_version='s3v4') + ) + s3.head_bucket(Bucket=self.bucket) + return True, f"Connected to Storj bucket '{self.bucket}'" + + return await asyncio.to_thread(do_test) + except Exception as e: + if "404" in str(e): + return False, f"Bucket '{self.bucket}' not found" + if "403" in str(e): + return False, "Invalid credentials or no access to bucket" + return False, f"Connection failed: {e}" + + def get_usage(self) -> dict: + """Get Storj usage stats.""" + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +class LocalStorageProvider(StorageProvider): + """Local filesystem storage provider.""" + + provider_type = "local" + + def __init__(self, base_path: str, capacity_gb: int = 10): + self.base_path = Path(base_path) + self.capacity_bytes = capacity_gb * 1024**3 + # Create directory if it doesn't exist + self.base_path.mkdir(parents=True, exist_ok=True) + + def _get_file_path(self, content_hash: str) -> Path: + """Get file path for a content hash (using subdirectories).""" + # Use first 2 chars as subdirectory for better filesystem performance + subdir = content_hash[:2] + return self.base_path / subdir / content_hash + + async def pin(self, content_hash: str, data: bytes, filename: Optional[str] = None) -> Optional[str]: + """Store content locally.""" + try: + import asyncio + + def do_store(): + file_path = self._get_file_path(content_hash) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(data) + return content_hash # Use content_hash as ID for local storage + + result = await asyncio.to_thread(do_store) + logger.info(f"Local: Stored {content_hash[:16]}...") + return result + except Exception as e: + logger.error(f"Local storage failed: {e}") + return None + + async def unpin(self, content_hash: str) -> bool: + """Remove content from local storage.""" + try: + import asyncio + + def do_remove(): + file_path = self._get_file_path(content_hash) + if file_path.exists(): + file_path.unlink() + return True + return False + + return await asyncio.to_thread(do_remove) + except Exception as e: + logger.error(f"Local unpin failed: {e}") + return False + + async def get(self, content_hash: str) -> Optional[bytes]: + """Get content from local storage.""" + try: + import asyncio + + def do_get(): + file_path = self._get_file_path(content_hash) + if file_path.exists(): + return file_path.read_bytes() + return None + + return await asyncio.to_thread(do_get) + except Exception as e: + logger.error(f"Local get failed: {e}") + return None + + async def is_pinned(self, content_hash: str) -> bool: + """Check if content exists in local storage.""" + return self._get_file_path(content_hash).exists() + + async def test_connection(self) -> tuple[bool, str]: + """Test local storage is writable.""" + try: + test_file = self.base_path / ".write_test" + test_file.write_text("test") + test_file.unlink() + return True, f"Local storage ready at {self.base_path}" + except Exception as e: + return False, f"Cannot write to {self.base_path}: {e}" + + def get_usage(self) -> dict: + """Get local storage usage stats.""" + try: + total_size = 0 + file_count = 0 + for subdir in self.base_path.iterdir(): + if subdir.is_dir() and len(subdir.name) == 2: + for f in subdir.iterdir(): + if f.is_file(): + total_size += f.stat().st_size + file_count += 1 + return { + "used_bytes": total_size, + "capacity_bytes": self.capacity_bytes, + "pin_count": file_count + } + except Exception: + return {"used_bytes": 0, "capacity_bytes": self.capacity_bytes, "pin_count": 0} + + +def create_provider(provider_type: str, config: dict) -> Optional[StorageProvider]: + """ + Factory function to create a storage provider from config. + + Args: + provider_type: One of 'pinata', 'web3storage', 'nftstorage', 'infura', 'filebase', 'storj', 'local' + config: Provider-specific configuration dict + + Returns: + StorageProvider instance or None if invalid + """ + try: + if provider_type == "pinata": + return PinataProvider( + api_key=config["api_key"], + secret_key=config["secret_key"], + capacity_gb=config.get("capacity_gb", 1) + ) + elif provider_type == "web3storage": + return Web3StorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "nftstorage": + return NFTStorageProvider( + api_token=config["api_token"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "infura": + return InfuraIPFSProvider( + project_id=config["project_id"], + project_secret=config["project_secret"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "filebase": + return FilebaseProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 5) + ) + elif provider_type == "storj": + return StorjProvider( + access_key=config["access_key"], + secret_key=config["secret_key"], + bucket=config["bucket"], + capacity_gb=config.get("capacity_gb", 25) + ) + elif provider_type == "local": + return LocalStorageProvider( + base_path=config["path"], + capacity_gb=config.get("capacity_gb", 10) + ) + else: + logger.error(f"Unknown provider type: {provider_type}") + return None + except KeyError as e: + logger.error(f"Missing config key for {provider_type}: {e}") + return None + except Exception as e: + logger.error(f"Failed to create provider {provider_type}: {e}") + return None From ea9015f65baf39d0fa4b7d8451f04a89b57d9fb9 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:08:41 +0000 Subject: [PATCH 04/24] Squashed 'common/' content from commit ff185b4 git-subtree-dir: common git-subtree-split: ff185b42f0fa577446c3d00da3438dc148ee8102 --- README.md | 293 ++++++++++++++++++ artdag_common/__init__.py | 18 ++ artdag_common/constants.py | 76 +++++ artdag_common/fragments.py | 91 ++++++ artdag_common/middleware/__init__.py | 16 + .../__pycache__/auth.cpython-310.pyc | Bin 0 -> 6751 bytes artdag_common/middleware/auth.py | 276 +++++++++++++++++ .../middleware/content_negotiation.py | 174 +++++++++++ artdag_common/models/__init__.py | 25 ++ artdag_common/models/requests.py | 74 +++++ artdag_common/models/responses.py | 96 ++++++ artdag_common/rendering.py | 160 ++++++++++ artdag_common/templates/_base.html | 91 ++++++ artdag_common/templates/components/badge.html | 64 ++++ artdag_common/templates/components/card.html | 45 +++ artdag_common/templates/components/dag.html | 176 +++++++++++ .../templates/components/media_preview.html | 98 ++++++ .../templates/components/pagination.html | 82 +++++ artdag_common/templates/components/table.html | 51 +++ artdag_common/utils/__init__.py | 19 ++ artdag_common/utils/formatting.py | 165 ++++++++++ artdag_common/utils/media.py | 166 ++++++++++ artdag_common/utils/pagination.py | 85 +++++ pyproject.toml | 22 ++ 24 files changed, 2363 insertions(+) create mode 100644 README.md create mode 100644 artdag_common/__init__.py create mode 100644 artdag_common/constants.py create mode 100644 artdag_common/fragments.py create mode 100644 artdag_common/middleware/__init__.py create mode 100644 artdag_common/middleware/__pycache__/auth.cpython-310.pyc create mode 100644 artdag_common/middleware/auth.py create mode 100644 artdag_common/middleware/content_negotiation.py create mode 100644 artdag_common/models/__init__.py create mode 100644 artdag_common/models/requests.py create mode 100644 artdag_common/models/responses.py create mode 100644 artdag_common/rendering.py create mode 100644 artdag_common/templates/_base.html create mode 100644 artdag_common/templates/components/badge.html create mode 100644 artdag_common/templates/components/card.html create mode 100644 artdag_common/templates/components/dag.html create mode 100644 artdag_common/templates/components/media_preview.html create mode 100644 artdag_common/templates/components/pagination.html create mode 100644 artdag_common/templates/components/table.html create mode 100644 artdag_common/utils/__init__.py create mode 100644 artdag_common/utils/formatting.py create mode 100644 artdag_common/utils/media.py create mode 100644 artdag_common/utils/pagination.py create mode 100644 pyproject.toml diff --git a/README.md b/README.md new file mode 100644 index 0000000..73d1dd5 --- /dev/null +++ b/README.md @@ -0,0 +1,293 @@ +# artdag-common + +Shared components for Art-DAG L1 (celery) and L2 (activity-pub) servers. + +## Features + +- **Jinja2 Templating**: Unified template environment with shared base templates +- **Reusable Components**: Cards, tables, pagination, DAG visualization, media preview +- **Authentication Middleware**: Cookie and JWT token parsing +- **Content Negotiation**: HTML/JSON/ActivityPub format detection +- **Utility Functions**: Hash truncation, file size formatting, status colors + +## Installation + +```bash +pip install -e /path/to/artdag-common + +# Or add to requirements.txt +-e file:../common +``` + +## Quick Start + +```python +from fastapi import FastAPI, Request +from artdag_common import create_jinja_env, render + +app = FastAPI() + +# Initialize templates with app-specific directory +templates = create_jinja_env("app/templates") + +@app.get("/") +async def home(request: Request): + return render(templates, "home.html", request, title="Home") +``` + +## Package Structure + +``` +artdag_common/ +├── __init__.py # Package exports +├── constants.py # CDN URLs, colors, configs +├── rendering.py # Jinja2 environment and helpers +├── middleware/ +│ ├── auth.py # Authentication utilities +│ └── content_negotiation.py # Accept header parsing +├── models/ +│ ├── requests.py # Shared request models +│ └── responses.py # Shared response models +├── utils/ +│ ├── formatting.py # Text/date formatting +│ ├── media.py # Media type detection +│ └── pagination.py # Pagination helpers +└── templates/ + ├── base.html # Base layout template + └── components/ + ├── badge.html # Status/type badges + ├── card.html # Info cards + ├── dag.html # Cytoscape DAG visualization + ├── media_preview.html # Video/image/audio preview + ├── pagination.html # HTMX pagination + └── table.html # Styled tables +``` + +## Jinja2 Templates + +### Base Template + +The `base.html` template provides: +- Dark theme with Tailwind CSS +- HTMX integration +- Navigation slot +- Content block +- Optional Cytoscape.js block + +```html +{% extends "base.html" %} + +{% block title %}My Page{% endblock %} + +{% block content %} +

Hello World

+{% endblock %} +``` + +### Reusable Components + +#### Card + +```html +{% include "components/card.html" %} +``` + +```html + +
+ {% block card_title %}Title{% endblock %} + {% block card_content %}Content{% endblock %} +
+``` + +#### Badge + +Status and type badges with appropriate colors: + +```html +{% from "components/badge.html" import status_badge, type_badge %} + +{{ status_badge("completed") }} +{{ status_badge("failed") }} +{{ type_badge("video") }} +``` + +#### DAG Visualization + +Interactive Cytoscape.js graph: + +```html +{% include "components/dag.html" %} +``` + +Requires passing `nodes` and `edges` data to template context. + +#### Media Preview + +Responsive media preview with format detection: + +```html +{% include "components/media_preview.html" %} +``` + +Supports video, audio, and image formats. + +#### Pagination + +HTMX-powered infinite scroll pagination: + +```html +{% include "components/pagination.html" %} +``` + +## Template Rendering + +### Full Page Render + +```python +from artdag_common import render + +@app.get("/runs/{run_id}") +async def run_detail(run_id: str, request: Request): + run = get_run(run_id) + return render(templates, "runs/detail.html", request, run=run) +``` + +### Fragment Render (HTMX) + +```python +from artdag_common import render_fragment + +@app.get("/runs/{run_id}/status") +async def run_status_fragment(run_id: str): + run = get_run(run_id) + html = render_fragment(templates, "components/status.html", status=run.status) + return HTMLResponse(html) +``` + +## Authentication Middleware + +### UserContext + +```python +from artdag_common.middleware.auth import UserContext, get_user_from_cookie + +@app.get("/profile") +async def profile(request: Request): + user = get_user_from_cookie(request) + if not user: + return RedirectResponse("/login") + return {"username": user.username, "actor_id": user.actor_id} +``` + +### Token Parsing + +```python +from artdag_common.middleware.auth import get_user_from_header, decode_jwt_claims + +@app.get("/api/me") +async def api_me(request: Request): + user = get_user_from_header(request) + if not user: + raise HTTPException(401, "Not authenticated") + return {"user": user.username} +``` + +## Content Negotiation + +Detect what response format the client wants: + +```python +from artdag_common.middleware.content_negotiation import wants_html, wants_json, wants_activity_json + +@app.get("/users/{username}") +async def user_profile(username: str, request: Request): + user = get_user(username) + + if wants_activity_json(request): + return ActivityPubActor(user) + elif wants_json(request): + return user.dict() + else: + return render(templates, "users/profile.html", request, user=user) +``` + +## Constants + +### CDN URLs + +```python +from artdag_common import TAILWIND_CDN, HTMX_CDN, CYTOSCAPE_CDN + +# Available in templates as globals: +# {{ TAILWIND_CDN }} +# {{ HTMX_CDN }} +# {{ CYTOSCAPE_CDN }} +``` + +### Node Colors + +```python +from artdag_common import NODE_COLORS + +# { +# "SOURCE": "#3b82f6", # Blue +# "EFFECT": "#22c55e", # Green +# "OUTPUT": "#a855f7", # Purple +# "ANALYSIS": "#f59e0b", # Amber +# "_LIST": "#6366f1", # Indigo +# "default": "#6b7280", # Gray +# } +``` + +### Status Colors + +```python +STATUS_COLORS = { + "completed": "bg-green-600", + "cached": "bg-blue-600", + "running": "bg-yellow-600", + "pending": "bg-gray-600", + "failed": "bg-red-600", +} +``` + +## Custom Jinja2 Filters + +The following filters are available in all templates: + +| Filter | Usage | Description | +|--------|-------|-------------| +| `truncate_hash` | `{{ hash\|truncate_hash }}` | Shorten hash to 16 chars with ellipsis | +| `format_size` | `{{ bytes\|format_size }}` | Format bytes as KB/MB/GB | +| `status_color` | `{{ status\|status_color }}` | Get Tailwind class for status | + +Example: + +```html + + {{ run.status }} + + +{{ content_hash|truncate_hash }} + +{{ file_size|format_size }} +``` + +## Development + +```bash +cd /root/art-dag/common + +# Install in development mode +pip install -e . + +# Run tests +pytest +``` + +## Dependencies + +- `fastapi>=0.100.0` - Web framework +- `jinja2>=3.1.0` - Templating engine +- `pydantic>=2.0.0` - Data validation diff --git a/artdag_common/__init__.py b/artdag_common/__init__.py new file mode 100644 index 0000000..e7fd6e5 --- /dev/null +++ b/artdag_common/__init__.py @@ -0,0 +1,18 @@ +""" +Art-DAG Common Library + +Shared components for L1 (celery) and L2 (activity-pub) servers. +""" + +from .constants import NODE_COLORS, TAILWIND_CDN, HTMX_CDN, CYTOSCAPE_CDN +from .rendering import create_jinja_env, render, render_fragment + +__all__ = [ + "NODE_COLORS", + "TAILWIND_CDN", + "HTMX_CDN", + "CYTOSCAPE_CDN", + "create_jinja_env", + "render", + "render_fragment", +] diff --git a/artdag_common/constants.py b/artdag_common/constants.py new file mode 100644 index 0000000..ee8862d --- /dev/null +++ b/artdag_common/constants.py @@ -0,0 +1,76 @@ +""" +Shared constants for Art-DAG servers. +""" + +# CDN URLs +TAILWIND_CDN = "https://cdn.tailwindcss.com?plugins=typography" +HTMX_CDN = "https://unpkg.com/htmx.org@1.9.10" +CYTOSCAPE_CDN = "https://cdnjs.cloudflare.com/ajax/libs/cytoscape/3.28.1/cytoscape.min.js" +DAGRE_CDN = "https://cdnjs.cloudflare.com/ajax/libs/dagre/0.8.5/dagre.min.js" +CYTOSCAPE_DAGRE_CDN = "https://cdn.jsdelivr.net/npm/cytoscape-dagre@2.5.0/cytoscape-dagre.min.js" + +# Node colors for DAG visualization +NODE_COLORS = { + "SOURCE": "#3b82f6", # Blue - input sources + "EFFECT": "#22c55e", # Green - processing effects + "OUTPUT": "#a855f7", # Purple - final outputs + "ANALYSIS": "#f59e0b", # Amber - analysis nodes + "_LIST": "#6366f1", # Indigo - list aggregation + "default": "#6b7280", # Gray - unknown types +} + +# Status colors +STATUS_COLORS = { + "completed": "bg-green-600", + "cached": "bg-blue-600", + "running": "bg-yellow-600", + "pending": "bg-gray-600", + "failed": "bg-red-600", +} + +# Tailwind dark theme configuration +TAILWIND_CONFIG = """ + + +""" + +# Default pagination settings +DEFAULT_PAGE_SIZE = 20 +MAX_PAGE_SIZE = 100 diff --git a/artdag_common/fragments.py b/artdag_common/fragments.py new file mode 100644 index 0000000..321949b --- /dev/null +++ b/artdag_common/fragments.py @@ -0,0 +1,91 @@ +"""Fragment client for fetching HTML fragments from coop apps. + +Lightweight httpx-based client (no Quart dependency) for Art-DAG to consume +coop app fragments like nav-tree, auth-menu, and cart-mini. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Sequence + +import httpx + +log = logging.getLogger(__name__) + +FRAGMENT_HEADER = "X-Fragment-Request" + +_client: httpx.AsyncClient | None = None +_DEFAULT_TIMEOUT = 2.0 + + +def _get_client() -> httpx.AsyncClient: + global _client + if _client is None or _client.is_closed: + _client = httpx.AsyncClient( + timeout=httpx.Timeout(_DEFAULT_TIMEOUT), + follow_redirects=False, + ) + return _client + + +def _internal_url(app_name: str) -> str: + """Resolve internal base URL for a coop app. + + Looks up ``INTERNAL_URL_{APP}`` first, falls back to ``http://{app}:8000``. + """ + env_key = f"INTERNAL_URL_{app_name.upper()}" + return os.getenv(env_key, f"http://{app_name}:8000").rstrip("/") + + +async def fetch_fragment( + app_name: str, + fragment_type: str, + *, + params: dict | None = None, + timeout: float = _DEFAULT_TIMEOUT, + required: bool = False, +) -> str: + """Fetch an HTML fragment from a coop app. + + Returns empty string on failure by default (required=False). + """ + base = _internal_url(app_name) + url = f"{base}/internal/fragments/{fragment_type}" + try: + resp = await _get_client().get( + url, + params=params, + headers={FRAGMENT_HEADER: "1"}, + timeout=timeout, + ) + if resp.status_code == 200: + return resp.text + msg = f"Fragment {app_name}/{fragment_type} returned {resp.status_code}" + log.warning(msg) + if required: + raise RuntimeError(msg) + return "" + except RuntimeError: + raise + except Exception as exc: + msg = f"Fragment {app_name}/{fragment_type} failed: {exc}" + log.warning(msg) + if required: + raise RuntimeError(msg) from exc + return "" + + +async def fetch_fragments( + requests: Sequence[tuple[str, str, dict | None]], + *, + timeout: float = _DEFAULT_TIMEOUT, + required: bool = False, +) -> list[str]: + """Fetch multiple fragments concurrently.""" + return list(await asyncio.gather(*( + fetch_fragment(app, ftype, params=params, timeout=timeout, required=required) + for app, ftype, params in requests + ))) diff --git a/artdag_common/middleware/__init__.py b/artdag_common/middleware/__init__.py new file mode 100644 index 0000000..185158d --- /dev/null +++ b/artdag_common/middleware/__init__.py @@ -0,0 +1,16 @@ +""" +Middleware and FastAPI dependencies for Art-DAG servers. +""" + +from .auth import UserContext, get_user_from_cookie, get_user_from_header, require_auth +from .content_negotiation import wants_html, wants_json, ContentType + +__all__ = [ + "UserContext", + "get_user_from_cookie", + "get_user_from_header", + "require_auth", + "wants_html", + "wants_json", + "ContentType", +] diff --git a/artdag_common/middleware/__pycache__/auth.cpython-310.pyc b/artdag_common/middleware/__pycache__/auth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a262852ded9cf12140f6429bd9d9817d96a8657 GIT binary patch literal 6751 zcmcIp&vP8db)MflyI2q)2tkA`S#3z7urax$6o#>3S|&qLzpOwC37N@OsDb!g8f@Oj03p)7}4pt9-~c<;xuHgNxT3RV6Cl>z>6fV9-8g!0GMj ze*OD<-+Ql_&(GH!JfHk`C*AmtW}hlPuHaP{UB}^0%i%8fGPmcoJeTr*=J$eD z&bZgDh^vs1LGgtBT(+tM%qubG>@2-kWdD8<{9O(OYOOxQ_7Nad?Ht z4;&tU=6&Y27I{_7@tO$5;;PdP_#Cf42wEqPl+N=LC|%+U;*^nHHnNNSBubyF2GynN z5K)P7 zX(!QX-e-Rgt2p5&N%(@i1UA!T9Dn&=%5NeS3mFUoTvrmAJD|ppgNQBdJfi>@AAgtx%^g`}k zcFeDT+2KJe5)~1wIw%iO9`i^9I3q$(0Q6MNzE_L6kFhD&^1j9v?0t_uu%+gUHbo{K zEja7Symw@an?*c?a6jo`<&er{JLQ{2sPns`Uqn*qq3m~{4+jtBeZ1H3sydRfbH{n) za`&+d1bBPr-VvJdR1LCZ#He4kJ{f`Srg|qIYIcPUr0Ayi8_loWD-BOlH${{wEz<$E zFi;}ve(A_lDE{h)=Xdg+IKQ1{LY+^fruUqM3uSq--7deTP&-dJY7Rz4y_`s!CUTl) z>>;W9d#7jit2ryws1+SX@!z81Q@KtHIdEP<70Xl`|jA48+)&4?~$*AN8uO;^w{0^#vTsKkaPyOS|1=sg^V zGS{_n_?f#OjU$lWIK0Cs?PgK{O5BKqm%`Q222X-NF$a4^WFR!r~aWeqV)FVSkl2)y28 zC8Wys8|-$X#HEYpz*srhgS~(A!%gni?V++OBya^Uc&Z6_uI-!t z-$4IBX4`1j29wDyd;iuJGoMCHIJ-R}qO&bk!_NDv+%6|bt0=*+!QM`kzWUL7lL9My z!)WY`_zcil+GY2WEalV15BvNq+syldKZtu+Nj zZMrHb#PD_zn$7sqH_dD#0{3QPuBg}}QH8%PbP?Q9dA|s=JmIQ{>}D6$+m|kKK^t!I zzd*bGmMOR{WiE@_Pm^pYY(D;p7}?xO+opYISGU3IX-_pO@B^`NA6{Kak!dHRc=E(a+8rj9gr19!z;KB$dQ zAAKD}@(naUq{I;lLcE4ie}T+IiJ>?isqrIHS6BXc45i`WFJaO5Jq_iBlKD`XACG;g zQv`KdGr5Xs7nppE(z7NgAc0x6deAWVLyJVE9d_A5 z!kCba0om$FC2iW{R^-7w02%u!h$T>Cuy@$-+xv}#I4htaVY5?>zaQt0gF5GfBSejr; z5al8208$>VIHw(gDi0y`Du4#dvh4gV!FXbwU?U#82X-dXm9971m{=L=b=PTB5;)VA z`7Ik5_O$AI;9G!k({AZ~mSEi4NfpCCXmfX%8E3L|iqihJH6Cfdo0tBPg`q*=E<9O= zCK;RUmAcx2Gn>I4h)fAc$%%U~CFJD8cT&JXgq&PX0nM%5%J*C<2QRP`Ajldj# z5e^8EMP-nTh?^2whrY!WXZ#u)WGfJ?t) zA|IVwV@qS3l`Df$coZG@*4zM$?D^~_Y_MB9f($WQu+EO?+$~MD4S|RXj_@o@GnyfyKXpFoI6S=nO2_g(l1~xI?7LsMa#h2I!8(x(y84Aj z@7TQqns^U!^}A5}v8Pr^7F6L39r^MXc(pfabJO|KMYJUk*vQyS6c}Ufcl7 z9O7ZxaH{)r z%{HP~obV3^)hb3AL$5bj+Gi4Rq9|#WLVY88{eIFLWWqWX)b>g;7|@xluboGnr4e-t zzGgiLs^#YT78iev6s4Ek-F9QnsP4Q z>Bu#_hP}0;K?Pl08E6kI4&K&2E~bFV2O&`c@E0Y&J_dhX6IC3=A-2fca8M@ZKo>qZ zCOfHP8GKSXJ7YvsG;w=jfw*zD6N`JgPDxVSMVWP)5$VjIL8&uZb3mnk2EtrqLJOcc z)uN})|BQZc5wm|RmaTA`WY$fe1WbCUWP`njbOz1DSX9wCdq|cryYUwkz^rtA&2klQ z-yW?srn+@RF`7-Q&6cAB>fTxmI{wK+d7P{BKX?^|J)fH6mhT{IkO^cH$IF|Of*CSj zF!sthCEpdV8h(T*Y@NMD&0C}*!e>p|BeP- zLJ}ZE;&SUpc<*|cn66a2gSPeu@?hfAYVT2pz2uf4g`>i?R zo49}@12)4PHhN_M_GD(y(hM{0*}Db15SHJ*19m-t&=fv!Q0zS;)DL&A(%v5s9LrU2 zuxloErf-H_Aa=*CXx`_Ft%IABVW!z%Utl~L;lhI&N`~!_6s~N+flcl=qQbu`Mnwpt zKNONIUlH$Ujk#u{B34jC1&9|Bx*bZf*{B+qw+Q*LH&8R7l5v0OcEMif6w&0@-W(th zBY(ln7NmNMW?_&>BgzXjY~9w9|A(2JqH>cMtG+wIPal8CO|x+gSGLf06St}J zC0~^5a_zc~O#NhsdYqm7942CJOBX%4jjkI0=GGRvCDzB!PwD4Hxr(1+?%*a}C`meA zM0zx!Xn_`DeItUqB$YQYh&`W57dKeRAhj+ODT7hY;$&IgoPbqH`kx!)Ul~`>+B@1= zTMIvkUiMz1lMB8+J)`(buP%P~g%?*={|`L!>*W9d literal 0 HcmV?d00001 diff --git a/artdag_common/middleware/auth.py b/artdag_common/middleware/auth.py new file mode 100644 index 0000000..b227894 --- /dev/null +++ b/artdag_common/middleware/auth.py @@ -0,0 +1,276 @@ +""" +Authentication middleware and dependencies. + +Provides common authentication patterns for L1 and L2 servers. +Each server can extend or customize these as needed. +""" + +from dataclasses import dataclass +from typing import Callable, Optional, Awaitable, Any +import base64 +import json + +from fastapi import Request, HTTPException, Depends +from fastapi.responses import RedirectResponse + + +@dataclass +class UserContext: + """User context extracted from authentication.""" + username: str + actor_id: str # Full actor ID like "@user@server.com" + token: Optional[str] = None + l2_server: Optional[str] = None # L2 server URL for this user + email: Optional[str] = None # User's email address + + @property + def display_name(self) -> str: + """Get display name (username without @ prefix).""" + return self.username.lstrip("@") + + +def get_user_from_cookie(request: Request) -> Optional[UserContext]: + """ + Extract user context from session cookie. + + Supports two cookie formats: + 1. artdag_session: base64-encoded JSON {"username": "user", "actor_id": "@user@server.com"} + 2. auth_token: raw JWT token (used by L1 servers) + + Args: + request: FastAPI request + + Returns: + UserContext if valid cookie found, None otherwise + """ + # Try artdag_session cookie first (base64-encoded JSON) + cookie = request.cookies.get("artdag_session") + if cookie: + try: + data = json.loads(base64.b64decode(cookie)) + username = data.get("username", "") + actor_id = data.get("actor_id", "") + if not actor_id and username: + actor_id = f"@{username}" + return UserContext( + username=username, + actor_id=actor_id, + email=data.get("email", ""), + ) + except (json.JSONDecodeError, ValueError, KeyError): + pass + + # Try auth_token cookie (raw JWT, used by L1) + token = request.cookies.get("auth_token") + if token: + claims = decode_jwt_claims(token) + if claims: + username = claims.get("username") or claims.get("sub", "") + actor_id = claims.get("actor_id") or claims.get("actor") + if not actor_id and username: + actor_id = f"@{username}" + return UserContext( + username=username, + actor_id=actor_id or "", + token=token, + email=claims.get("email", ""), + ) + + return None + + +def get_user_from_header(request: Request) -> Optional[UserContext]: + """ + Extract user context from Authorization header. + + Supports: + - Bearer format (JWT or opaque token) + - Basic format + + Args: + request: FastAPI request + + Returns: + UserContext if valid header found, None otherwise + """ + auth_header = request.headers.get("Authorization", "") + + if auth_header.startswith("Bearer "): + token = auth_header[7:] + # Attempt to decode JWT claims + claims = decode_jwt_claims(token) + if claims: + username = claims.get("username") or claims.get("sub", "") + actor_id = claims.get("actor_id") or claims.get("actor") + # Default actor_id to @username if not provided + if not actor_id and username: + actor_id = f"@{username}" + return UserContext( + username=username, + actor_id=actor_id or "", + token=token, + ) + + return None + + +def decode_jwt_claims(token: str) -> Optional[dict]: + """ + Decode JWT claims without verification. + + This is useful for extracting user info from a token + when full verification is handled elsewhere. + + Args: + token: JWT token string + + Returns: + Claims dict if valid JWT format, None otherwise + """ + try: + parts = token.split(".") + if len(parts) != 3: + return None + + # Decode payload (second part) + payload = parts[1] + # Add padding if needed + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + + return json.loads(base64.urlsafe_b64decode(payload)) + except (json.JSONDecodeError, ValueError): + return None + + +def create_auth_dependency( + token_validator: Optional[Callable[[str], Awaitable[Optional[dict]]]] = None, + allow_cookie: bool = True, + allow_header: bool = True, +): + """ + Create a customized auth dependency for a specific server. + + Args: + token_validator: Optional async function to validate tokens with backend + allow_cookie: Whether to check cookies for auth + allow_header: Whether to check Authorization header + + Returns: + FastAPI dependency function + """ + async def get_current_user(request: Request) -> Optional[UserContext]: + ctx = None + + # Try header first (API clients) + if allow_header: + ctx = get_user_from_header(request) + if ctx and token_validator: + # Validate token with backend + validated = await token_validator(ctx.token) + if not validated: + ctx = None + + # Fall back to cookie (browser) + if ctx is None and allow_cookie: + ctx = get_user_from_cookie(request) + + return ctx + + return get_current_user + + +async def require_auth(request: Request) -> UserContext: + """ + Dependency that requires authentication. + + Raises HTTPException 401 if not authenticated. + Use with Depends() in route handlers. + + Example: + @app.get("/protected") + async def protected_route(user: UserContext = Depends(require_auth)): + return {"user": user.username} + """ + # Try header first + ctx = get_user_from_header(request) + if ctx is None: + ctx = get_user_from_cookie(request) + + if ctx is None: + # Check Accept header to determine response type + accept = request.headers.get("accept", "") + if "text/html" in accept: + raise HTTPException( + status_code=302, + headers={"Location": "/login"} + ) + raise HTTPException( + status_code=401, + detail="Authentication required" + ) + + return ctx + + +def require_owner(resource_owner_field: str = "username"): + """ + Dependency factory that requires the user to own the resource. + + Args: + resource_owner_field: Field name on the resource that contains owner username + + Returns: + Dependency function + + Example: + @app.delete("/items/{item_id}") + async def delete_item( + item: Item = Depends(get_item), + user: UserContext = Depends(require_owner("created_by")) + ): + ... + """ + async def check_ownership( + request: Request, + user: UserContext = Depends(require_auth), + ) -> UserContext: + # The actual ownership check must be done in the route + # after fetching the resource + return user + + return check_ownership + + +def set_auth_cookie(response: Any, user: UserContext, max_age: int = 86400 * 30) -> None: + """ + Set authentication cookie on response. + + Args: + response: FastAPI response object + user: User context to store + max_age: Cookie max age in seconds (default 30 days) + """ + cookie_data = { + "username": user.username, + "actor_id": user.actor_id, + } + if user.email: + cookie_data["email"] = user.email + data = json.dumps(cookie_data) + cookie_value = base64.b64encode(data.encode()).decode() + + response.set_cookie( + key="artdag_session", + value=cookie_value, + max_age=max_age, + httponly=True, + samesite="lax", + secure=True, # Require HTTPS in production + ) + + +def clear_auth_cookie(response: Any) -> None: + """Clear authentication cookie.""" + response.delete_cookie(key="artdag_session") diff --git a/artdag_common/middleware/content_negotiation.py b/artdag_common/middleware/content_negotiation.py new file mode 100644 index 0000000..aaa47c8 --- /dev/null +++ b/artdag_common/middleware/content_negotiation.py @@ -0,0 +1,174 @@ +""" +Content negotiation utilities. + +Helps determine what response format the client wants. +""" + +from enum import Enum +from typing import Optional + +from fastapi import Request + + +class ContentType(Enum): + """Response content types.""" + HTML = "text/html" + JSON = "application/json" + ACTIVITY_JSON = "application/activity+json" + XML = "application/xml" + + +def wants_html(request: Request) -> bool: + """ + Check if the client wants HTML response. + + Returns True if: + - Accept header contains text/html + - Accept header contains application/xhtml+xml + - No Accept header (browser default) + + Args: + request: FastAPI request + + Returns: + True if HTML is preferred + """ + accept = request.headers.get("accept", "") + + # No accept header usually means browser + if not accept: + return True + + # Check for HTML preferences + if "text/html" in accept: + return True + if "application/xhtml" in accept: + return True + + return False + + +def wants_json(request: Request) -> bool: + """ + Check if the client wants JSON response. + + Returns True if: + - Accept header contains application/json + - Accept header does NOT contain text/html + - Request has .json suffix (convention) + + Args: + request: FastAPI request + + Returns: + True if JSON is preferred + """ + accept = request.headers.get("accept", "") + + # Explicit JSON preference + if "application/json" in accept: + # But not if HTML is also requested (browsers often send both) + if "text/html" not in accept: + return True + + # Check URL suffix convention + if request.url.path.endswith(".json"): + return True + + return False + + +def wants_activity_json(request: Request) -> bool: + """ + Check if the client wants ActivityPub JSON-LD response. + + Used for federation with other ActivityPub servers. + + Args: + request: FastAPI request + + Returns: + True if ActivityPub format is preferred + """ + accept = request.headers.get("accept", "") + + if "application/activity+json" in accept: + return True + if "application/ld+json" in accept: + return True + + return False + + +def get_preferred_type(request: Request) -> ContentType: + """ + Determine the preferred content type from Accept header. + + Args: + request: FastAPI request + + Returns: + ContentType enum value + """ + if wants_activity_json(request): + return ContentType.ACTIVITY_JSON + if wants_json(request): + return ContentType.JSON + return ContentType.HTML + + +def is_htmx_request(request: Request) -> bool: + """ + Check if this is an HTMX request (partial page update). + + HTMX requests set the HX-Request header. + + Args: + request: FastAPI request + + Returns: + True if this is an HTMX request + """ + return request.headers.get("HX-Request") == "true" + + +def get_htmx_target(request: Request) -> Optional[str]: + """ + Get the HTMX target element ID. + + Args: + request: FastAPI request + + Returns: + Target element ID or None + """ + return request.headers.get("HX-Target") + + +def get_htmx_trigger(request: Request) -> Optional[str]: + """ + Get the HTMX trigger element ID. + + Args: + request: FastAPI request + + Returns: + Trigger element ID or None + """ + return request.headers.get("HX-Trigger") + + +def is_ios_request(request: Request) -> bool: + """ + Check if request is from iOS device. + + Useful for video format selection (iOS prefers MP4). + + Args: + request: FastAPI request + + Returns: + True if iOS user agent detected + """ + user_agent = request.headers.get("user-agent", "").lower() + return "iphone" in user_agent or "ipad" in user_agent diff --git a/artdag_common/models/__init__.py b/artdag_common/models/__init__.py new file mode 100644 index 0000000..d0d43c7 --- /dev/null +++ b/artdag_common/models/__init__.py @@ -0,0 +1,25 @@ +""" +Shared Pydantic models for Art-DAG servers. +""" + +from .requests import ( + PaginationParams, + PublishRequest, + StorageConfigRequest, + MetadataUpdateRequest, +) +from .responses import ( + PaginatedResponse, + ErrorResponse, + SuccessResponse, +) + +__all__ = [ + "PaginationParams", + "PublishRequest", + "StorageConfigRequest", + "MetadataUpdateRequest", + "PaginatedResponse", + "ErrorResponse", + "SuccessResponse", +] diff --git a/artdag_common/models/requests.py b/artdag_common/models/requests.py new file mode 100644 index 0000000..1c34d45 --- /dev/null +++ b/artdag_common/models/requests.py @@ -0,0 +1,74 @@ +""" +Request models shared across L1 and L2 servers. +""" + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + +from ..constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE + + +class PaginationParams(BaseModel): + """Common pagination parameters.""" + page: int = Field(default=1, ge=1, description="Page number (1-indexed)") + limit: int = Field( + default=DEFAULT_PAGE_SIZE, + ge=1, + le=MAX_PAGE_SIZE, + description="Items per page" + ) + + @property + def offset(self) -> int: + """Calculate offset for database queries.""" + return (self.page - 1) * self.limit + + +class PublishRequest(BaseModel): + """Request to publish content to L2/storage.""" + name: str = Field(..., min_length=1, max_length=255) + description: Optional[str] = Field(default=None, max_length=2000) + tags: List[str] = Field(default_factory=list) + storage_id: Optional[str] = Field(default=None, description="Target storage provider") + + +class MetadataUpdateRequest(BaseModel): + """Request to update content metadata.""" + name: Optional[str] = Field(default=None, max_length=255) + description: Optional[str] = Field(default=None, max_length=2000) + tags: Optional[List[str]] = Field(default=None) + metadata: Optional[Dict[str, Any]] = Field(default=None) + + +class StorageConfigRequest(BaseModel): + """Request to configure a storage provider.""" + provider_type: str = Field(..., description="Provider type (pinata, web3storage, local, etc.)") + name: str = Field(..., min_length=1, max_length=100) + api_key: Optional[str] = Field(default=None) + api_secret: Optional[str] = Field(default=None) + endpoint: Optional[str] = Field(default=None) + config: Optional[Dict[str, Any]] = Field(default_factory=dict) + is_default: bool = Field(default=False) + + +class RecipeRunRequest(BaseModel): + """Request to run a recipe.""" + recipe_id: str = Field(..., description="Recipe content hash or ID") + inputs: Dict[str, str] = Field(..., description="Map of input name to content hash") + features: List[str] = Field( + default=["beats", "energy"], + description="Analysis features to extract" + ) + + +class PlanRequest(BaseModel): + """Request to generate an execution plan.""" + recipe_yaml: str = Field(..., description="Recipe YAML content") + input_hashes: Dict[str, str] = Field(..., description="Map of input name to content hash") + features: List[str] = Field(default=["beats", "energy"]) + + +class ExecutePlanRequest(BaseModel): + """Request to execute a generated plan.""" + plan_json: str = Field(..., description="JSON-serialized execution plan") + run_id: Optional[str] = Field(default=None, description="Optional run ID for tracking") diff --git a/artdag_common/models/responses.py b/artdag_common/models/responses.py new file mode 100644 index 0000000..447e70c --- /dev/null +++ b/artdag_common/models/responses.py @@ -0,0 +1,96 @@ +""" +Response models shared across L1 and L2 servers. +""" + +from typing import Optional, List, Dict, Any, Generic, TypeVar +from pydantic import BaseModel, Field + +T = TypeVar("T") + + +class PaginatedResponse(BaseModel, Generic[T]): + """Generic paginated response.""" + data: List[Any] = Field(default_factory=list) + pagination: Dict[str, Any] = Field(default_factory=dict) + + @classmethod + def create( + cls, + items: List[Any], + page: int, + limit: int, + total: int, + ) -> "PaginatedResponse": + """Create a paginated response.""" + return cls( + data=items, + pagination={ + "page": page, + "limit": limit, + "total": total, + "has_more": page * limit < total, + "total_pages": (total + limit - 1) // limit, + } + ) + + +class ErrorResponse(BaseModel): + """Standard error response.""" + error: str = Field(..., description="Error message") + detail: Optional[str] = Field(default=None, description="Detailed error info") + code: Optional[str] = Field(default=None, description="Error code") + + +class SuccessResponse(BaseModel): + """Standard success response.""" + success: bool = Field(default=True) + message: Optional[str] = Field(default=None) + data: Optional[Dict[str, Any]] = Field(default=None) + + +class RunStatus(BaseModel): + """Run execution status.""" + run_id: str + status: str = Field(..., description="pending, running, completed, failed") + recipe: Optional[str] = None + plan_id: Optional[str] = None + output_hash: Optional[str] = None + output_ipfs_cid: Optional[str] = None + total_steps: int = 0 + cached_steps: int = 0 + completed_steps: int = 0 + error: Optional[str] = None + + +class CacheItemResponse(BaseModel): + """Cached content item response.""" + content_hash: str + media_type: Optional[str] = None + size: Optional[int] = None + name: Optional[str] = None + description: Optional[str] = None + tags: List[str] = Field(default_factory=list) + ipfs_cid: Optional[str] = None + created_at: Optional[str] = None + + +class RecipeResponse(BaseModel): + """Recipe response.""" + recipe_id: str + name: str + description: Optional[str] = None + inputs: List[Dict[str, Any]] = Field(default_factory=list) + outputs: List[str] = Field(default_factory=list) + node_count: int = 0 + created_at: Optional[str] = None + + +class StorageProviderResponse(BaseModel): + """Storage provider configuration response.""" + storage_id: str + provider_type: str + name: str + is_default: bool = False + is_connected: bool = False + usage_bytes: Optional[int] = None + pin_count: int = 0 diff --git a/artdag_common/rendering.py b/artdag_common/rendering.py new file mode 100644 index 0000000..e5edacf --- /dev/null +++ b/artdag_common/rendering.py @@ -0,0 +1,160 @@ +""" +Jinja2 template rendering system for Art-DAG servers. + +Provides a unified template environment that can load from: +1. The shared artdag_common/templates directory +2. App-specific template directories + +Usage: + from artdag_common import create_jinja_env, render + + # In app initialization + templates = create_jinja_env("app/templates") + + # In route handler + return render(templates, "runs/detail.html", request, run=run, user=user) +""" + +from pathlib import Path +from typing import Any, Optional, Union + +from fastapi import Request +from fastapi.responses import HTMLResponse +from jinja2 import Environment, ChoiceLoader, FileSystemLoader, PackageLoader, select_autoescape + +from .constants import ( + TAILWIND_CDN, + HTMX_CDN, + CYTOSCAPE_CDN, + DAGRE_CDN, + CYTOSCAPE_DAGRE_CDN, + TAILWIND_CONFIG, + NODE_COLORS, + STATUS_COLORS, +) + + +def create_jinja_env(*template_dirs: Union[str, Path]) -> Environment: + """ + Create a Jinja2 environment with the shared templates and optional app-specific dirs. + + Args: + *template_dirs: Additional template directories to search (app-specific) + + Returns: + Configured Jinja2 Environment + + Example: + env = create_jinja_env("/app/templates", "/app/custom") + """ + loaders = [] + + # Add app-specific directories first (higher priority) + for template_dir in template_dirs: + path = Path(template_dir) + if path.exists(): + loaders.append(FileSystemLoader(str(path))) + + # Add shared templates from this package (lower priority, fallback) + loaders.append(PackageLoader("artdag_common", "templates")) + + env = Environment( + loader=ChoiceLoader(loaders), + autoescape=select_autoescape(["html", "xml"]), + trim_blocks=True, + lstrip_blocks=True, + ) + + # Add global context available to all templates + env.globals.update({ + "TAILWIND_CDN": TAILWIND_CDN, + "HTMX_CDN": HTMX_CDN, + "CYTOSCAPE_CDN": CYTOSCAPE_CDN, + "DAGRE_CDN": DAGRE_CDN, + "CYTOSCAPE_DAGRE_CDN": CYTOSCAPE_DAGRE_CDN, + "TAILWIND_CONFIG": TAILWIND_CONFIG, + "NODE_COLORS": NODE_COLORS, + "STATUS_COLORS": STATUS_COLORS, + }) + + # Add custom filters + env.filters["truncate_hash"] = truncate_hash + env.filters["format_size"] = format_size + env.filters["status_color"] = status_color + + return env + + +def render( + env: Environment, + template_name: str, + request: Request, + status_code: int = 200, + **context: Any, +) -> HTMLResponse: + """ + Render a template to an HTMLResponse. + + Args: + env: Jinja2 environment + template_name: Template file path (e.g., "runs/detail.html") + request: FastAPI request object + status_code: HTTP status code (default 200) + **context: Template context variables + + Returns: + HTMLResponse with rendered content + """ + template = env.get_template(template_name) + html = template.render(request=request, **context) + return HTMLResponse(html, status_code=status_code) + + +def render_fragment( + env: Environment, + template_name: str, + **context: Any, +) -> str: + """ + Render a template fragment to a string (for HTMX partial updates). + + Args: + env: Jinja2 environment + template_name: Template file path + **context: Template context variables + + Returns: + Rendered HTML string + """ + template = env.get_template(template_name) + return template.render(**context) + + +# Custom Jinja2 filters + +def truncate_hash(value: str, length: int = 16) -> str: + """Truncate a hash to specified length with ellipsis.""" + if not value: + return "" + if len(value) <= length: + return value + return f"{value[:length]}..." + + +def format_size(size_bytes: Optional[int]) -> str: + """Format file size in human-readable form.""" + if size_bytes is None: + return "Unknown" + if size_bytes < 1024: + return f"{size_bytes} B" + elif size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f} KB" + elif size_bytes < 1024 * 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f} MB" + else: + return f"{size_bytes / (1024 * 1024 * 1024):.1f} GB" + + +def status_color(status: str) -> str: + """Get Tailwind CSS class for a status.""" + return STATUS_COLORS.get(status, STATUS_COLORS["pending"]) diff --git a/artdag_common/templates/_base.html b/artdag_common/templates/_base.html new file mode 100644 index 0000000..deeb67b --- /dev/null +++ b/artdag_common/templates/_base.html @@ -0,0 +1,91 @@ + + + + + + {% block title %}Art-DAG{% endblock %} + + + + + + + + + + + + {% block head %}{% endblock %} + + + + + + {% block header %} + {# Coop-style header: sky banner with title, nav-tree, auth-menu, cart-mini #} +
+
+
+ {# Cart mini #} + {% block cart_mini %}{% endblock %} + + {# Site title #} + + + {# Desktop nav: nav-tree + auth-menu #} + +
+
+ {# Mobile auth #} +
+ {% block auth_menu_mobile %}{% endblock %} +
+
+ {% endblock %} + + {# App-specific sub-nav (Runs, Recipes, Effects, etc.) #} + {% block sub_nav %}{% endblock %} + +
+ {% block content %}{% endblock %} +
+ + {% block footer %}{% endblock %} + {% block scripts %}{% endblock %} + + diff --git a/artdag_common/templates/components/badge.html b/artdag_common/templates/components/badge.html new file mode 100644 index 0000000..8c9f484 --- /dev/null +++ b/artdag_common/templates/components/badge.html @@ -0,0 +1,64 @@ +{# +Badge component for status and type indicators. + +Usage: + {% from "components/badge.html" import badge, status_badge, type_badge %} + + {{ badge("Active", "green") }} + {{ status_badge("completed") }} + {{ type_badge("EFFECT") }} +#} + +{% macro badge(text, color="gray", class="") %} + + {{ text }} + +{% endmacro %} + +{% macro status_badge(status, class="") %} +{% set colors = { + "completed": "green", + "cached": "blue", + "running": "yellow", + "pending": "gray", + "failed": "red", + "active": "green", + "inactive": "gray", +} %} +{% set color = colors.get(status, "gray") %} + + {% if status == "running" %} + + + + + {% endif %} + {{ status | capitalize }} + +{% endmacro %} + +{% macro type_badge(node_type, class="") %} +{% set colors = { + "SOURCE": "blue", + "EFFECT": "green", + "OUTPUT": "purple", + "ANALYSIS": "amber", + "_LIST": "indigo", +} %} +{% set color = colors.get(node_type, "gray") %} + + {{ node_type }} + +{% endmacro %} + +{% macro role_badge(role, class="") %} +{% set colors = { + "input": "blue", + "output": "purple", + "intermediate": "gray", +} %} +{% set color = colors.get(role, "gray") %} + + {{ role | capitalize }} + +{% endmacro %} diff --git a/artdag_common/templates/components/card.html b/artdag_common/templates/components/card.html new file mode 100644 index 0000000..04f5c54 --- /dev/null +++ b/artdag_common/templates/components/card.html @@ -0,0 +1,45 @@ +{# +Card component for displaying information. + +Usage: + {% include "components/card.html" with title="Status", content="Active", class="col-span-2" %} + +Or as a block: + {% call card(title="Details") %} +

Card content here

+ {% endcall %} +#} + +{% macro card(title=None, class="") %} +
+ {% if title %} +

{{ title }}

+ {% endif %} +
+ {{ caller() if caller else "" }} +
+
+{% endmacro %} + +{% macro stat_card(title, value, color="white", class="") %} +
+
{{ value }}
+
{{ title }}
+
+{% endmacro %} + +{% macro info_card(title, items, class="") %} +
+ {% if title %} +

{{ title }}

+ {% endif %} +
+ {% for label, value in items %} +
+
{{ label }}
+
{{ value }}
+
+ {% endfor %} +
+
+{% endmacro %} diff --git a/artdag_common/templates/components/dag.html b/artdag_common/templates/components/dag.html new file mode 100644 index 0000000..fa17fdc --- /dev/null +++ b/artdag_common/templates/components/dag.html @@ -0,0 +1,176 @@ +{# +Cytoscape.js DAG visualization component. + +Usage: + {% from "components/dag.html" import dag_container, dag_scripts, dag_legend %} + + {# In head block #} + {{ dag_scripts() }} + + {# In content #} + {{ dag_container(id="plan-dag", height="400px") }} + {{ dag_legend() }} + + {# In scripts block #} + +#} + +{% macro dag_scripts() %} + + + + +{% endmacro %} + +{% macro dag_container(id="dag-container", height="400px", class="") %} +
+ +{% endmacro %} + +{% macro dag_legend(node_types=None) %} +{% set types = node_types or ["SOURCE", "EFFECT", "_LIST"] %} +
+ {% for type in types %} + + + {{ type }} + + {% endfor %} + + + Cached + +
+{% endmacro %} diff --git a/artdag_common/templates/components/media_preview.html b/artdag_common/templates/components/media_preview.html new file mode 100644 index 0000000..ec810ae --- /dev/null +++ b/artdag_common/templates/components/media_preview.html @@ -0,0 +1,98 @@ +{# +Media preview component for videos, images, and audio. + +Usage: + {% from "components/media_preview.html" import media_preview, video_player, image_preview, audio_player %} + + {{ media_preview(content_hash, media_type, title="Preview") }} + {{ video_player(src="/cache/abc123/mp4", poster="/cache/abc123/thumb") }} +#} + +{% macro media_preview(content_hash, media_type, title=None, class="", show_download=True) %} +
+ {% if title %} +
+

{{ title }}

+
+ {% endif %} + +
+ {% if media_type == "video" %} + {{ video_player("/cache/" + content_hash + "/mp4") }} + {% elif media_type == "image" %} + {{ image_preview("/cache/" + content_hash + "/raw") }} + {% elif media_type == "audio" %} + {{ audio_player("/cache/" + content_hash + "/raw") }} + {% else %} +
+ + + +

Preview not available

+
+ {% endif %} +
+ + {% if show_download %} + + {% endif %} +
+{% endmacro %} + +{% macro video_player(src, poster=None, autoplay=False, muted=True, loop=False, class="") %} + +{% endmacro %} + +{% macro image_preview(src, alt="", class="") %} +{{ alt }} +{% endmacro %} + +{% macro audio_player(src, class="") %} +
+ +
+{% endmacro %} + +{% macro thumbnail(content_hash, media_type, size="w-24 h-24", class="") %} +
+ {% if media_type == "image" %} + + {% elif media_type == "video" %} + + + + {% elif media_type == "audio" %} + + + + {% else %} + + + + {% endif %} +
+{% endmacro %} diff --git a/artdag_common/templates/components/pagination.html b/artdag_common/templates/components/pagination.html new file mode 100644 index 0000000..ec1b4a5 --- /dev/null +++ b/artdag_common/templates/components/pagination.html @@ -0,0 +1,82 @@ +{# +Pagination component with HTMX infinite scroll support. + +Usage: + {% from "components/pagination.html" import infinite_scroll_trigger, page_links %} + + {# Infinite scroll (HTMX) #} + {{ infinite_scroll_trigger(url="/items?page=2", colspan=3, has_more=True) }} + + {# Traditional pagination #} + {{ page_links(current_page=1, total_pages=5, base_url="/items") }} +#} + +{% macro infinite_scroll_trigger(url, colspan=1, has_more=True, target=None) %} +{% if has_more %} + + + + + + + + Loading more... + + + +{% endif %} +{% endmacro %} + +{% macro page_links(current_page, total_pages, base_url, class="") %} + +{% endmacro %} + +{% macro page_info(page, limit, total) %} +
+ Showing {{ (page - 1) * limit + 1 }}-{{ [page * limit, total] | min }} of {{ total }} +
+{% endmacro %} diff --git a/artdag_common/templates/components/table.html b/artdag_common/templates/components/table.html new file mode 100644 index 0000000..1c00fc4 --- /dev/null +++ b/artdag_common/templates/components/table.html @@ -0,0 +1,51 @@ +{# +Table component with dark theme styling. + +Usage: + {% from "components/table.html" import table, table_row %} + + {% call table(columns=["Name", "Status", "Actions"]) %} + {% for item in items %} + {{ table_row([item.name, item.status, actions_html]) }} + {% endfor %} + {% endcall %} +#} + +{% macro table(columns, class="", id="") %} +
+ + + + {% for col in columns %} + + {% endfor %} + + + + {{ caller() }} + +
{{ col }}
+
+{% endmacro %} + +{% macro table_row(cells, class="", href=None) %} + + {% for cell in cells %} + + {% if href and loop.first %} + {{ cell }} + {% else %} + {{ cell | safe }} + {% endif %} + + {% endfor %} + +{% endmacro %} + +{% macro empty_row(colspan, message="No items found") %} + + + {{ message }} + + +{% endmacro %} diff --git a/artdag_common/utils/__init__.py b/artdag_common/utils/__init__.py new file mode 100644 index 0000000..192edfa --- /dev/null +++ b/artdag_common/utils/__init__.py @@ -0,0 +1,19 @@ +""" +Utility functions shared across Art-DAG servers. +""" + +from .pagination import paginate, get_pagination_params +from .media import detect_media_type, get_media_extension, is_streamable +from .formatting import format_date, format_size, truncate_hash, format_duration + +__all__ = [ + "paginate", + "get_pagination_params", + "detect_media_type", + "get_media_extension", + "is_streamable", + "format_date", + "format_size", + "truncate_hash", + "format_duration", +] diff --git a/artdag_common/utils/formatting.py b/artdag_common/utils/formatting.py new file mode 100644 index 0000000..3dcc3a8 --- /dev/null +++ b/artdag_common/utils/formatting.py @@ -0,0 +1,165 @@ +""" +Formatting utilities for display. +""" + +from datetime import datetime +from typing import Optional, Union + + +def format_date( + value: Optional[Union[str, datetime]], + length: int = 10, + include_time: bool = False, +) -> str: + """ + Format a date/datetime for display. + + Args: + value: Date string or datetime object + length: Length to truncate to (default 10 for YYYY-MM-DD) + include_time: Whether to include time portion + + Returns: + Formatted date string + """ + if value is None: + return "" + + if isinstance(value, str): + # Parse ISO format string + try: + if "T" in value: + dt = datetime.fromisoformat(value.replace("Z", "+00:00")) + else: + return value[:length] + except ValueError: + return value[:length] + else: + dt = value + + if include_time: + return dt.strftime("%Y-%m-%d %H:%M") + return dt.strftime("%Y-%m-%d") + + +def format_size(size_bytes: Optional[int]) -> str: + """ + Format file size in human-readable form. + + Args: + size_bytes: Size in bytes + + Returns: + Human-readable size string (e.g., "1.5 MB") + """ + if size_bytes is None: + return "Unknown" + if size_bytes < 0: + return "Unknown" + if size_bytes == 0: + return "0 B" + + units = ["B", "KB", "MB", "GB", "TB"] + unit_index = 0 + size = float(size_bytes) + + while size >= 1024 and unit_index < len(units) - 1: + size /= 1024 + unit_index += 1 + + if unit_index == 0: + return f"{int(size)} {units[unit_index]}" + return f"{size:.1f} {units[unit_index]}" + + +def truncate_hash(value: str, length: int = 16, suffix: str = "...") -> str: + """ + Truncate a hash or long string with ellipsis. + + Args: + value: String to truncate + length: Maximum length before truncation + suffix: Suffix to add when truncated + + Returns: + Truncated string + """ + if not value: + return "" + if len(value) <= length: + return value + return f"{value[:length]}{suffix}" + + +def format_duration(seconds: Optional[float]) -> str: + """ + Format duration in human-readable form. + + Args: + seconds: Duration in seconds + + Returns: + Human-readable duration string (e.g., "2m 30s") + """ + if seconds is None or seconds < 0: + return "Unknown" + + if seconds < 1: + return f"{int(seconds * 1000)}ms" + + if seconds < 60: + return f"{seconds:.1f}s" + + minutes = int(seconds // 60) + remaining_seconds = int(seconds % 60) + + if minutes < 60: + if remaining_seconds: + return f"{minutes}m {remaining_seconds}s" + return f"{minutes}m" + + hours = minutes // 60 + remaining_minutes = minutes % 60 + + if remaining_minutes: + return f"{hours}h {remaining_minutes}m" + return f"{hours}h" + + +def format_count(count: int) -> str: + """ + Format a count with abbreviation for large numbers. + + Args: + count: Number to format + + Returns: + Formatted string (e.g., "1.2K", "3.5M") + """ + if count < 1000: + return str(count) + if count < 1000000: + return f"{count / 1000:.1f}K" + if count < 1000000000: + return f"{count / 1000000:.1f}M" + return f"{count / 1000000000:.1f}B" + + +def format_percentage(value: float, decimals: int = 1) -> str: + """ + Format a percentage value. + + Args: + value: Percentage value (0-100 or 0-1) + decimals: Number of decimal places + + Returns: + Formatted percentage string + """ + # Assume 0-1 if less than 1 + if value <= 1: + value *= 100 + + if decimals == 0: + return f"{int(value)}%" + return f"{value:.{decimals}f}%" diff --git a/artdag_common/utils/media.py b/artdag_common/utils/media.py new file mode 100644 index 0000000..ef0eaee --- /dev/null +++ b/artdag_common/utils/media.py @@ -0,0 +1,166 @@ +""" +Media type detection and handling utilities. +""" + +from pathlib import Path +from typing import Optional +import mimetypes + +# Initialize mimetypes database +mimetypes.init() + +# Media type categories +VIDEO_TYPES = {"video/mp4", "video/webm", "video/quicktime", "video/x-msvideo", "video/avi"} +IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"} +AUDIO_TYPES = {"audio/mpeg", "audio/wav", "audio/ogg", "audio/flac", "audio/aac", "audio/mp3"} + +# File extension mappings +EXTENSION_TO_CATEGORY = { + # Video + ".mp4": "video", + ".webm": "video", + ".mov": "video", + ".avi": "video", + ".mkv": "video", + # Image + ".jpg": "image", + ".jpeg": "image", + ".png": "image", + ".gif": "image", + ".webp": "image", + ".svg": "image", + # Audio + ".mp3": "audio", + ".wav": "audio", + ".ogg": "audio", + ".flac": "audio", + ".aac": "audio", + ".m4a": "audio", +} + + +def detect_media_type(path: Path) -> str: + """ + Detect the media category for a file. + + Args: + path: Path to the file + + Returns: + Category string: "video", "image", "audio", or "unknown" + """ + if not path: + return "unknown" + + # Try extension first + ext = path.suffix.lower() + if ext in EXTENSION_TO_CATEGORY: + return EXTENSION_TO_CATEGORY[ext] + + # Try mimetypes + mime_type, _ = mimetypes.guess_type(str(path)) + if mime_type: + if mime_type in VIDEO_TYPES or mime_type.startswith("video/"): + return "video" + if mime_type in IMAGE_TYPES or mime_type.startswith("image/"): + return "image" + if mime_type in AUDIO_TYPES or mime_type.startswith("audio/"): + return "audio" + + return "unknown" + + +def get_mime_type(path: Path) -> str: + """ + Get the MIME type for a file. + + Args: + path: Path to the file + + Returns: + MIME type string or "application/octet-stream" + """ + mime_type, _ = mimetypes.guess_type(str(path)) + return mime_type or "application/octet-stream" + + +def get_media_extension(media_type: str) -> str: + """ + Get the typical file extension for a media type. + + Args: + media_type: Media category or MIME type + + Returns: + File extension with dot (e.g., ".mp4") + """ + if media_type == "video": + return ".mp4" + if media_type == "image": + return ".png" + if media_type == "audio": + return ".mp3" + + # Try as MIME type + ext = mimetypes.guess_extension(media_type) + return ext or "" + + +def is_streamable(path: Path) -> bool: + """ + Check if a file type is streamable (video/audio). + + Args: + path: Path to the file + + Returns: + True if the file can be streamed + """ + media_type = detect_media_type(path) + return media_type in ("video", "audio") + + +def needs_conversion(path: Path, target_format: str = "mp4") -> bool: + """ + Check if a video file needs format conversion. + + Args: + path: Path to the file + target_format: Target format (default mp4) + + Returns: + True if conversion is needed + """ + media_type = detect_media_type(path) + if media_type != "video": + return False + + ext = path.suffix.lower().lstrip(".") + return ext != target_format + + +def get_video_src( + content_hash: str, + original_path: Optional[Path] = None, + is_ios: bool = False, +) -> str: + """ + Get the appropriate video source URL. + + For iOS devices, prefer MP4 format. + + Args: + content_hash: Content hash for the video + original_path: Optional original file path + is_ios: Whether the client is iOS + + Returns: + URL path for the video source + """ + if is_ios: + return f"/cache/{content_hash}/mp4" + + if original_path and original_path.suffix.lower() in (".mp4", ".webm"): + return f"/cache/{content_hash}/raw" + + return f"/cache/{content_hash}/mp4" diff --git a/artdag_common/utils/pagination.py b/artdag_common/utils/pagination.py new file mode 100644 index 0000000..f892f95 --- /dev/null +++ b/artdag_common/utils/pagination.py @@ -0,0 +1,85 @@ +""" +Pagination utilities. +""" + +from typing import List, Any, Tuple, Optional + +from fastapi import Request + +from ..constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE + + +def get_pagination_params(request: Request) -> Tuple[int, int]: + """ + Extract pagination parameters from request query string. + + Args: + request: FastAPI request + + Returns: + Tuple of (page, limit) + """ + try: + page = int(request.query_params.get("page", 1)) + page = max(1, page) + except ValueError: + page = 1 + + try: + limit = int(request.query_params.get("limit", DEFAULT_PAGE_SIZE)) + limit = max(1, min(limit, MAX_PAGE_SIZE)) + except ValueError: + limit = DEFAULT_PAGE_SIZE + + return page, limit + + +def paginate( + items: List[Any], + page: int = 1, + limit: int = DEFAULT_PAGE_SIZE, +) -> Tuple[List[Any], dict]: + """ + Paginate a list of items. + + Args: + items: Full list of items + page: Page number (1-indexed) + limit: Items per page + + Returns: + Tuple of (paginated items, pagination info dict) + """ + total = len(items) + start = (page - 1) * limit + end = start + limit + + paginated = items[start:end] + + return paginated, { + "page": page, + "limit": limit, + "total": total, + "has_more": end < total, + "total_pages": (total + limit - 1) // limit if total > 0 else 1, + } + + +def calculate_offset(page: int, limit: int) -> int: + """Calculate database offset from page and limit.""" + return (page - 1) * limit + + +def build_pagination_info( + page: int, + limit: int, + total: int, +) -> dict: + """Build pagination info dictionary.""" + return { + "page": page, + "limit": limit, + "total": total, + "has_more": page * limit < total, + "total_pages": (total + limit - 1) // limit if total > 0 else 1, + } diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8205b9b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "artdag-common" +version = "0.1.3" +description = "Shared components for Art-DAG L1 and L2 servers" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.100.0", + "jinja2>=3.1.0", + "pydantic>=2.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["artdag_common"] From cc2dcbddd46da50449eb817eac79295dcda245aa Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:09:39 +0000 Subject: [PATCH 05/24] Squashed 'core/' content from commit 4957443 git-subtree-dir: core git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07 --- .gitignore | 47 + README.md | 110 ++ artdag/__init__.py | 61 + artdag/activities.py | 371 ++++ artdag/activitypub/__init__.py | 33 + artdag/activitypub/activity.py | 203 ++ artdag/activitypub/actor.py | 206 +++ artdag/activitypub/ownership.py | 226 +++ artdag/activitypub/signatures.py | 163 ++ artdag/analysis/__init__.py | 26 + artdag/analysis/analyzer.py | 282 +++ artdag/analysis/audio.py | 336 ++++ artdag/analysis/schema.py | 352 ++++ artdag/analysis/video.py | 266 +++ artdag/cache.py | 464 +++++ artdag/cli.py | 724 ++++++++ artdag/client.py | 201 ++ artdag/dag.py | 344 ++++ artdag/effects/__init__.py | 55 + artdag/effects/binding.py | 311 ++++ artdag/effects/frame_processor.py | 347 ++++ artdag/effects/loader.py | 455 +++++ artdag/effects/meta.py | 247 +++ artdag/effects/runner.py | 259 +++ artdag/effects/sandbox.py | 431 +++++ artdag/engine.py | 246 +++ artdag/executor.py | 106 ++ artdag/nodes/__init__.py | 11 + artdag/nodes/compose.py | 548 ++++++ artdag/nodes/effect.py | 520 ++++++ artdag/nodes/encoding.py | 50 + artdag/nodes/source.py | 62 + artdag/nodes/transform.py | 224 +++ artdag/planning/__init__.py | 29 + artdag/planning/planner.py | 756 ++++++++ artdag/planning/schema.py | 594 ++++++ artdag/planning/tree_reduction.py | 231 +++ artdag/registry/__init__.py | 20 + artdag/registry/registry.py | 294 +++ artdag/server.py | 253 +++ artdag/sexp/__init__.py | 75 + artdag/sexp/compiler.py | 2463 +++++++++++++++++++++++++ artdag/sexp/effect_loader.py | 337 ++++ artdag/sexp/evaluator.py | 869 +++++++++ artdag/sexp/external_tools.py | 292 +++ artdag/sexp/ffmpeg_compiler.py | 616 +++++++ artdag/sexp/parser.py | 425 +++++ artdag/sexp/planner.py | 2187 ++++++++++++++++++++++ artdag/sexp/primitives.py | 620 +++++++ artdag/sexp/scheduler.py | 779 ++++++++ artdag/sexp/stage_cache.py | 412 +++++ artdag/sexp/test_ffmpeg_compiler.py | 146 ++ artdag/sexp/test_primitives.py | 201 ++ artdag/sexp/test_stage_cache.py | 324 ++++ artdag/sexp/test_stage_compiler.py | 286 +++ artdag/sexp/test_stage_integration.py | 739 ++++++++ artdag/sexp/test_stage_planner.py | 228 +++ artdag/sexp/test_stage_scheduler.py | 323 ++++ docs/EXECUTION_MODEL.md | 384 ++++ docs/IPFS_PRIMARY_ARCHITECTURE.md | 443 +++++ docs/L1_STORAGE.md | 181 ++ docs/OFFLINE_TESTING.md | 211 +++ effects/identity/README.md | 35 + effects/identity/requirements.txt | 2 + examples/simple_sequence.yaml | 42 + examples/test_local.sh | 54 + examples/test_plan.py | 93 + pyproject.toml | 62 + scripts/compute_repo_hash.py | 67 + scripts/install-ffglitch.sh | 82 + scripts/register_identity_effect.py | 83 + scripts/setup_actor.py | 120 ++ scripts/sign_assets.py | 143 ++ tests/__init__.py | 1 + tests/test_activities.py | 613 ++++++ tests/test_cache.py | 163 ++ tests/test_dag.py | 271 +++ tests/test_engine.py | 464 +++++ tests/test_executor.py | 110 ++ tests/test_ipfs_access.py | 301 +++ 80 files changed, 25711 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 artdag/__init__.py create mode 100644 artdag/activities.py create mode 100644 artdag/activitypub/__init__.py create mode 100644 artdag/activitypub/activity.py create mode 100644 artdag/activitypub/actor.py create mode 100644 artdag/activitypub/ownership.py create mode 100644 artdag/activitypub/signatures.py create mode 100644 artdag/analysis/__init__.py create mode 100644 artdag/analysis/analyzer.py create mode 100644 artdag/analysis/audio.py create mode 100644 artdag/analysis/schema.py create mode 100644 artdag/analysis/video.py create mode 100644 artdag/cache.py create mode 100644 artdag/cli.py create mode 100644 artdag/client.py create mode 100644 artdag/dag.py create mode 100644 artdag/effects/__init__.py create mode 100644 artdag/effects/binding.py create mode 100644 artdag/effects/frame_processor.py create mode 100644 artdag/effects/loader.py create mode 100644 artdag/effects/meta.py create mode 100644 artdag/effects/runner.py create mode 100644 artdag/effects/sandbox.py create mode 100644 artdag/engine.py create mode 100644 artdag/executor.py create mode 100644 artdag/nodes/__init__.py create mode 100644 artdag/nodes/compose.py create mode 100644 artdag/nodes/effect.py create mode 100644 artdag/nodes/encoding.py create mode 100644 artdag/nodes/source.py create mode 100644 artdag/nodes/transform.py create mode 100644 artdag/planning/__init__.py create mode 100644 artdag/planning/planner.py create mode 100644 artdag/planning/schema.py create mode 100644 artdag/planning/tree_reduction.py create mode 100644 artdag/registry/__init__.py create mode 100644 artdag/registry/registry.py create mode 100644 artdag/server.py create mode 100644 artdag/sexp/__init__.py create mode 100644 artdag/sexp/compiler.py create mode 100644 artdag/sexp/effect_loader.py create mode 100644 artdag/sexp/evaluator.py create mode 100644 artdag/sexp/external_tools.py create mode 100644 artdag/sexp/ffmpeg_compiler.py create mode 100644 artdag/sexp/parser.py create mode 100644 artdag/sexp/planner.py create mode 100644 artdag/sexp/primitives.py create mode 100644 artdag/sexp/scheduler.py create mode 100644 artdag/sexp/stage_cache.py create mode 100644 artdag/sexp/test_ffmpeg_compiler.py create mode 100644 artdag/sexp/test_primitives.py create mode 100644 artdag/sexp/test_stage_cache.py create mode 100644 artdag/sexp/test_stage_compiler.py create mode 100644 artdag/sexp/test_stage_integration.py create mode 100644 artdag/sexp/test_stage_planner.py create mode 100644 artdag/sexp/test_stage_scheduler.py create mode 100644 docs/EXECUTION_MODEL.md create mode 100644 docs/IPFS_PRIMARY_ARCHITECTURE.md create mode 100644 docs/L1_STORAGE.md create mode 100644 docs/OFFLINE_TESTING.md create mode 100644 effects/identity/README.md create mode 100644 effects/identity/requirements.txt create mode 100644 examples/simple_sequence.yaml create mode 100755 examples/test_local.sh create mode 100755 examples/test_plan.py create mode 100644 pyproject.toml create mode 100644 scripts/compute_repo_hash.py create mode 100755 scripts/install-ffglitch.sh create mode 100644 scripts/register_identity_effect.py create mode 100644 scripts/setup_actor.py create mode 100644 scripts/sign_assets.py create mode 100644 tests/__init__.py create mode 100644 tests/test_activities.py create mode 100644 tests/test_cache.py create mode 100644 tests/test_dag.py create mode 100644 tests/test_engine.py create mode 100644 tests/test_executor.py create mode 100644 tests/test_ipfs_access.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d2f90c --- /dev/null +++ b/.gitignore @@ -0,0 +1,47 @@ +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Environment +.env +.venv +env/ +venv/ + +# Private keys (ActivityPub secrets) +.cache/ + +# Test outputs +test_cache/ +test_plan_output.json +analysis.json +plan.json +plan_with_analysis.json diff --git a/README.md b/README.md new file mode 100644 index 0000000..27602a4 --- /dev/null +++ b/README.md @@ -0,0 +1,110 @@ +# artdag + +Content-addressed DAG execution engine with ActivityPub ownership. + +## Features + +- **Content-addressed nodes**: `node_id = SHA3-256(type + config + inputs)` for automatic deduplication +- **Quantum-resistant hashing**: SHA-3 throughout for future-proof integrity +- **ActivityPub ownership**: Cryptographically signed ownership claims +- **Federated identity**: `@user@artdag.rose-ash.com` style identities +- **Pluggable executors**: Register custom node types +- **Built-in video primitives**: SOURCE, SEGMENT, RESIZE, TRANSFORM, SEQUENCE, MUX, BLEND + +## Installation + +```bash +pip install -e . +``` + +### Optional: External Effect Tools + +Some effects can use external tools for better performance: + +**Pixelsort** (glitch art pixel sorting): +```bash +# Rust CLI (recommended - fast) +cargo install --git https://github.com/Void-ux/pixelsort.git pixelsort + +# Or Python CLI +pip install git+https://github.com/Blotz/pixelsort-cli +``` + +**Datamosh** (video glitch/corruption): +```bash +# FFglitch (recommended) +./scripts/install-ffglitch.sh + +# Or Python CLI +pip install git+https://github.com/tiberiuiancu/datamoshing +``` + +Check available tools: +```bash +python -m artdag.sexp.external_tools +``` + +## Quick Start + +```python +from artdag import Engine, DAGBuilder, Registry +from artdag.activitypub import OwnershipManager + +# Create ownership manager +manager = OwnershipManager("./my_registry") + +# Create your identity +actor = manager.create_actor("alice", "Alice") +print(f"Created: {actor.handle}") # @alice@artdag.rose-ash.com + +# Register an asset with ownership +asset, activity = manager.register_asset( + actor=actor, + name="my_image", + path="/path/to/image.jpg", + tags=["photo", "art"], +) +print(f"Owned: {asset.name} (hash: {asset.content_hash})") + +# Build and execute a DAG +engine = Engine("./cache") +builder = DAGBuilder() + +source = builder.source(str(asset.path)) +resized = builder.resize(source, width=1920, height=1080) +builder.set_output(resized) + +result = engine.execute(builder.build()) +print(f"Output: {result.output_path}") +``` + +## Architecture + +``` +artdag/ +├── dag.py # Node, DAG, DAGBuilder +├── cache.py # Content-addressed file cache +├── executor.py # Base executor + registry +├── engine.py # DAG execution engine +├── activitypub/ # Identity + ownership +│ ├── actor.py # Actor identity with RSA keys +│ ├── activity.py # Create, Announce activities +│ ├── signatures.py # RSA signing/verification +│ └── ownership.py # Links actors to assets +├── nodes/ # Built-in executors +│ ├── source.py # SOURCE +│ ├── transform.py # SEGMENT, RESIZE, TRANSFORM +│ ├── compose.py # SEQUENCE, LAYER, MUX, BLEND +│ └── effect.py # EFFECT (identity, etc.) +└── effects/ # Effect implementations + └── identity/ # The foundational identity effect +``` + +## Related Repos + +- **Registry**: https://git.rose-ash.com/art-dag/registry - Asset registry with ownership proofs +- **Recipes**: https://git.rose-ash.com/art-dag/recipes - DAG recipes using effects + +## License + +MIT diff --git a/artdag/__init__.py b/artdag/__init__.py new file mode 100644 index 0000000..4b8abe2 --- /dev/null +++ b/artdag/__init__.py @@ -0,0 +1,61 @@ +# artdag - Content-addressed DAG execution engine with ActivityPub ownership +# +# A standalone execution engine that processes directed acyclic graphs (DAGs) +# where each node represents an operation. Nodes are content-addressed for +# automatic caching and deduplication. +# +# Core concepts: +# - Node: An operation with type, config, and inputs +# - DAG: A graph of nodes with a designated output node +# - Executor: Implements the actual operation for a node type +# - Engine: Executes DAGs by resolving dependencies and running executors + +from .dag import Node, DAG, DAGBuilder, NodeType +from .cache import Cache, CacheEntry +from .executor import Executor, register_executor, get_executor +from .engine import Engine +from .registry import Registry, Asset +from .activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + +# Analysis and planning modules (optional, require extra dependencies) +try: + from .analysis import Analyzer, AnalysisResult +except ImportError: + Analyzer = None + AnalysisResult = None + +try: + from .planning import RecipePlanner, ExecutionPlan, ExecutionStep +except ImportError: + RecipePlanner = None + ExecutionPlan = None + ExecutionStep = None + +__all__ = [ + # Core + "Node", + "DAG", + "DAGBuilder", + "NodeType", + "Cache", + "CacheEntry", + "Executor", + "register_executor", + "get_executor", + "Engine", + "Registry", + "Asset", + "Activity", + "ActivityStore", + "ActivityManager", + "make_is_shared_fn", + # Analysis (optional) + "Analyzer", + "AnalysisResult", + # Planning (optional) + "RecipePlanner", + "ExecutionPlan", + "ExecutionStep", +] + +__version__ = "0.1.0" diff --git a/artdag/activities.py b/artdag/activities.py new file mode 100644 index 0000000..0919ee7 --- /dev/null +++ b/artdag/activities.py @@ -0,0 +1,371 @@ +# artdag/activities.py +""" +Persistent activity (job) tracking for cache management. + +Activities represent executions of DAGs. They track: +- Input node IDs (sources) +- Output node ID (terminal node) +- Intermediate node IDs (everything in between) + +This enables deletion rules: +- Shared items (ActivityPub published) cannot be deleted +- Inputs/outputs of activities cannot be deleted +- Intermediates can be deleted (reconstructible) +- Activities can only be discarded if no items are shared +""" + +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set + +from .cache import Cache, CacheEntry +from .dag import DAG + +logger = logging.getLogger(__name__) + + +def make_is_shared_fn(activitypub_store: "ActivityStore") -> Callable[[str], bool]: + """ + Create an is_shared function from an ActivityPub ActivityStore. + + Args: + activitypub_store: The ActivityPub activity store + (from artdag.activitypub.activity) + + Returns: + Function that checks if a cid has been published + """ + def is_shared(cid: str) -> bool: + activities = activitypub_store.find_by_object_hash(cid) + return any(a.activity_type == "Create" for a in activities) + return is_shared + + +@dataclass +class Activity: + """ + A recorded execution of a DAG. + + Tracks which cache entries are inputs, outputs, and intermediates + to enforce deletion rules. + """ + activity_id: str + input_ids: List[str] # Source node cache IDs + output_id: str # Terminal node cache ID + intermediate_ids: List[str] # Everything in between + created_at: float + status: str = "completed" # pending|running|completed|failed + dag_snapshot: Optional[Dict[str, Any]] = None # Serialized DAG for reconstruction + + def to_dict(self) -> Dict[str, Any]: + return { + "activity_id": self.activity_id, + "input_ids": self.input_ids, + "output_id": self.output_id, + "intermediate_ids": self.intermediate_ids, + "created_at": self.created_at, + "status": self.status, + "dag_snapshot": self.dag_snapshot, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Activity": + return cls( + activity_id=data["activity_id"], + input_ids=data["input_ids"], + output_id=data["output_id"], + intermediate_ids=data["intermediate_ids"], + created_at=data["created_at"], + status=data.get("status", "completed"), + dag_snapshot=data.get("dag_snapshot"), + ) + + @classmethod + def from_dag(cls, dag: DAG, activity_id: str = None) -> "Activity": + """ + Create an Activity from a DAG. + + Classifies nodes as inputs, output, or intermediates. + """ + if activity_id is None: + activity_id = str(uuid.uuid4()) + + # Find input nodes (nodes with no inputs - sources) + input_ids = [] + for node_id, node in dag.nodes.items(): + if not node.inputs: + input_ids.append(node_id) + + # Output is the terminal node + output_id = dag.output_id + + # Intermediates are everything else + intermediate_ids = [] + for node_id in dag.nodes: + if node_id not in input_ids and node_id != output_id: + intermediate_ids.append(node_id) + + return cls( + activity_id=activity_id, + input_ids=sorted(input_ids), + output_id=output_id, + intermediate_ids=sorted(intermediate_ids), + created_at=time.time(), + status="completed", + dag_snapshot=dag.to_dict(), + ) + + @property + def all_node_ids(self) -> List[str]: + """All node IDs involved in this activity.""" + return self.input_ids + [self.output_id] + self.intermediate_ids + + +class ActivityStore: + """ + Persistent storage for activities. + + Provides methods to check deletion eligibility and perform deletions. + """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._activities: Dict[str, Activity] = {} + self._load() + + def _index_path(self) -> Path: + return self.store_dir / "activities.json" + + def _load(self): + """Load activities from disk.""" + index_path = self._index_path() + if index_path.exists(): + try: + with open(index_path) as f: + data = json.load(f) + self._activities = { + a["activity_id"]: Activity.from_dict(a) + for a in data.get("activities", []) + } + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Failed to load activities: {e}") + self._activities = {} + + def _save(self): + """Save activities to disk.""" + data = { + "version": "1.0", + "activities": [a.to_dict() for a in self._activities.values()], + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def add(self, activity: Activity) -> None: + """Add an activity.""" + self._activities[activity.activity_id] = activity + self._save() + + def get(self, activity_id: str) -> Optional[Activity]: + """Get an activity by ID.""" + return self._activities.get(activity_id) + + def remove(self, activity_id: str) -> bool: + """Remove an activity record (does not delete cache entries).""" + if activity_id not in self._activities: + return False + del self._activities[activity_id] + self._save() + return True + + def list(self) -> List[Activity]: + """List all activities.""" + return list(self._activities.values()) + + def find_by_input_ids(self, input_ids: List[str]) -> List[Activity]: + """Find activities with the same inputs (for UI grouping).""" + sorted_inputs = sorted(input_ids) + return [ + a for a in self._activities.values() + if sorted(a.input_ids) == sorted_inputs + ] + + def find_using_node(self, node_id: str) -> List[Activity]: + """Find all activities that reference a node ID.""" + return [ + a for a in self._activities.values() + if node_id in a.all_node_ids + ] + + def __len__(self) -> int: + return len(self._activities) + + +class ActivityManager: + """ + Manages activities and cache deletion with sharing rules. + + Deletion rules: + 1. Shared items (ActivityPub published) cannot be deleted + 2. Inputs/outputs of activities cannot be deleted + 3. Intermediates can be deleted (reconstructible) + 4. Activities can only be discarded if no items are shared + """ + + def __init__( + self, + cache: Cache, + activity_store: ActivityStore, + is_shared_fn: Callable[[str], bool], + ): + """ + Args: + cache: The L1 cache + activity_store: Activity persistence + is_shared_fn: Function that checks if a cid is shared + (published via ActivityPub) + """ + self.cache = cache + self.activities = activity_store + self._is_shared = is_shared_fn + + def record_activity(self, dag: DAG) -> Activity: + """Record a completed DAG execution as an activity.""" + activity = Activity.from_dag(dag) + self.activities.add(activity) + return activity + + def is_shared(self, node_id: str) -> bool: + """Check if a cache entry is shared (published via ActivityPub).""" + entry = self.cache.get_entry(node_id) + if not entry or not entry.cid: + return False + return self._is_shared(entry.cid) + + def can_delete_cache_entry(self, node_id: str) -> bool: + """ + Check if a cache entry can be deleted. + + Returns False if: + - Entry is shared (ActivityPub published) + - Entry is an input or output of any activity + """ + # Check if shared + if self.is_shared(node_id): + return False + + # Check if it's an input or output of any activity + for activity in self.activities.list(): + if node_id in activity.input_ids: + return False + if node_id == activity.output_id: + return False + + # It's either an intermediate or orphaned - can delete + return True + + def can_discard_activity(self, activity_id: str) -> bool: + """ + Check if an activity can be discarded. + + Returns False if any cache entry (input, output, or intermediate) + is shared via ActivityPub. + """ + activity = self.activities.get(activity_id) + if not activity: + return False + + # Check if any item is shared + for node_id in activity.all_node_ids: + if self.is_shared(node_id): + return False + + return True + + def discard_activity(self, activity_id: str) -> bool: + """ + Discard an activity and delete its intermediate cache entries. + + Returns False if the activity cannot be discarded (has shared items). + + When discarded: + - Intermediate cache entries are deleted + - The activity record is removed + - Inputs remain (may be used by other activities) + - Output is deleted if orphaned (not shared, not used elsewhere) + """ + if not self.can_discard_activity(activity_id): + return False + + activity = self.activities.get(activity_id) + if not activity: + return False + + output_id = activity.output_id + intermediate_ids = list(activity.intermediate_ids) + + # Remove the activity record first + self.activities.remove(activity_id) + + # Delete intermediates + for node_id in intermediate_ids: + self.cache.remove(node_id) + logger.debug(f"Deleted intermediate: {node_id}") + + # Check if output is now orphaned + if self._is_orphaned(output_id) and not self.is_shared(output_id): + self.cache.remove(output_id) + logger.debug(f"Deleted orphaned output: {output_id}") + + # Inputs remain - they may be used by other activities + # But check if any are orphaned now + for input_id in activity.input_ids: + if self._is_orphaned(input_id) and not self.is_shared(input_id): + self.cache.remove(input_id) + logger.debug(f"Deleted orphaned input: {input_id}") + + return True + + def _is_orphaned(self, node_id: str) -> bool: + """Check if a node is not referenced by any activity.""" + for activity in self.activities.list(): + if node_id in activity.all_node_ids: + return False + return True + + def get_deletable_entries(self) -> List[CacheEntry]: + """Get all cache entries that can be deleted.""" + deletable = [] + for entry in self.cache.list_entries(): + if self.can_delete_cache_entry(entry.node_id): + deletable.append(entry) + return deletable + + def get_discardable_activities(self) -> List[Activity]: + """Get all activities that can be discarded.""" + return [ + a for a in self.activities.list() + if self.can_discard_activity(a.activity_id) + ] + + def cleanup_intermediates(self) -> int: + """ + Delete all intermediate cache entries. + + Intermediates are safe to delete as they can be reconstructed + from inputs using the DAG. + + Returns: + Number of entries deleted + """ + deleted = 0 + for activity in self.activities.list(): + for node_id in activity.intermediate_ids: + if self.cache.has(node_id): + self.cache.remove(node_id) + deleted += 1 + return deleted diff --git a/artdag/activitypub/__init__.py b/artdag/activitypub/__init__.py new file mode 100644 index 0000000..e9abbdc --- /dev/null +++ b/artdag/activitypub/__init__.py @@ -0,0 +1,33 @@ +# primitive/activitypub/__init__.py +""" +ActivityPub implementation for Art DAG. + +Provides decentralized identity and ownership for assets. +Domain: artdag.rose-ash.com + +Core concepts: +- Actor: A user identity with cryptographic keys +- Object: An asset (image, video, etc.) +- Activity: An action (Create, Announce, Like, etc.) +- Signature: Cryptographic proof of authorship +""" + +from .actor import Actor, ActorStore +from .activity import Activity, CreateActivity, ActivityStore +from .signatures import sign_activity, verify_signature, verify_activity_ownership +from .ownership import OwnershipManager, OwnershipRecord + +__all__ = [ + "Actor", + "ActorStore", + "Activity", + "CreateActivity", + "ActivityStore", + "sign_activity", + "verify_signature", + "verify_activity_ownership", + "OwnershipManager", + "OwnershipRecord", +] + +DOMAIN = "artdag.rose-ash.com" diff --git a/artdag/activitypub/activity.py b/artdag/activitypub/activity.py new file mode 100644 index 0000000..d7ab9b8 --- /dev/null +++ b/artdag/activitypub/activity.py @@ -0,0 +1,203 @@ +# primitive/activitypub/activity.py +""" +ActivityPub Activity types. + +Activities represent actions taken by actors on objects. +Key activity types for Art DAG: +- Create: Actor creates/claims ownership of an object +- Announce: Actor shares/boosts an object +- Like: Actor endorses an object +""" + +import json +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .actor import Actor, DOMAIN + + +def _generate_id() -> str: + """Generate unique activity ID.""" + return str(uuid.uuid4()) + + +@dataclass +class Activity: + """ + Base ActivityPub Activity. + + Attributes: + activity_id: Unique identifier + activity_type: Type (Create, Announce, Like, etc.) + actor_id: ID of the actor performing the activity + object_data: The object of the activity + published: ISO timestamp + signature: Cryptographic signature (added after signing) + """ + activity_id: str + activity_type: str + actor_id: str + object_data: Dict[str, Any] + published: str = field(default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())) + signature: Optional[Dict[str, Any]] = None + + def to_activitypub(self) -> Dict[str, Any]: + """Return ActivityPub JSON-LD representation.""" + activity = { + "@context": "https://www.w3.org/ns/activitystreams", + "type": self.activity_type, + "id": f"https://{DOMAIN}/activities/{self.activity_id}", + "actor": self.actor_id, + "object": self.object_data, + "published": self.published, + } + if self.signature: + activity["signature"] = self.signature + return activity + + def to_dict(self) -> Dict[str, Any]: + """Serialize for storage.""" + return { + "activity_id": self.activity_id, + "activity_type": self.activity_type, + "actor_id": self.actor_id, + "object_data": self.object_data, + "published": self.published, + "signature": self.signature, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Activity": + """Deserialize from storage.""" + return cls( + activity_id=data["activity_id"], + activity_type=data["activity_type"], + actor_id=data["actor_id"], + object_data=data["object_data"], + published=data.get("published", ""), + signature=data.get("signature"), + ) + + +@dataclass +class CreateActivity(Activity): + """ + Create activity - establishes ownership of an object. + + Used when an actor creates or claims an asset. + """ + activity_type: str = field(default="Create", init=False) + + @classmethod + def for_asset( + cls, + actor: Actor, + asset_name: str, + cid: str, + asset_type: str = "Image", + metadata: Dict[str, Any] = None, + ) -> "CreateActivity": + """ + Create a Create activity for an asset. + + Args: + actor: The actor claiming ownership + asset_name: Name of the asset + cid: SHA-3 hash of the asset content + asset_type: ActivityPub object type (Image, Video, Audio, etc.) + metadata: Additional metadata + + Returns: + CreateActivity establishing ownership + """ + object_data = { + "type": asset_type, + "name": asset_name, + "id": f"https://{DOMAIN}/objects/{cid}", + "contentHash": { + "algorithm": "sha3-256", + "value": cid, + }, + "attributedTo": actor.id, + } + if metadata: + object_data["metadata"] = metadata + + return cls( + activity_id=_generate_id(), + actor_id=actor.id, + object_data=object_data, + ) + + +class ActivityStore: + """ + Persistent storage for activities. + + Activities are stored as an append-only log for auditability. + """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._activities: List[Activity] = [] + self._load() + + def _log_path(self) -> Path: + return self.store_dir / "activities.json" + + def _load(self): + """Load activities from disk.""" + log_path = self._log_path() + if log_path.exists(): + with open(log_path) as f: + data = json.load(f) + self._activities = [ + Activity.from_dict(a) for a in data.get("activities", []) + ] + + def _save(self): + """Save activities to disk.""" + data = { + "version": "1.0", + "activities": [a.to_dict() for a in self._activities], + } + with open(self._log_path(), "w") as f: + json.dump(data, f, indent=2) + + def add(self, activity: Activity) -> None: + """Add an activity to the log.""" + self._activities.append(activity) + self._save() + + def get(self, activity_id: str) -> Optional[Activity]: + """Get an activity by ID.""" + for a in self._activities: + if a.activity_id == activity_id: + return a + return None + + def list(self) -> List[Activity]: + """List all activities.""" + return list(self._activities) + + def find_by_actor(self, actor_id: str) -> List[Activity]: + """Find activities by actor.""" + return [a for a in self._activities if a.actor_id == actor_id] + + def find_by_object_hash(self, cid: str) -> List[Activity]: + """Find activities referencing an object by hash.""" + results = [] + for a in self._activities: + obj_hash = a.object_data.get("contentHash", {}) + if isinstance(obj_hash, dict) and obj_hash.get("value") == cid: + results.append(a) + elif a.object_data.get("contentHash") == cid: + results.append(a) + return results + + def __len__(self) -> int: + return len(self._activities) diff --git a/artdag/activitypub/actor.py b/artdag/activitypub/actor.py new file mode 100644 index 0000000..8e0deed --- /dev/null +++ b/artdag/activitypub/actor.py @@ -0,0 +1,206 @@ +# primitive/activitypub/actor.py +""" +ActivityPub Actor management. + +An Actor is an identity with: +- Username and display name +- RSA key pair for signing +- ActivityPub-compliant JSON-LD representation +""" + +import json +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa, padding + +DOMAIN = "artdag.rose-ash.com" + + +def _generate_keypair() -> tuple[bytes, bytes]: + """Generate RSA key pair for signing.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return private_pem, public_pem + + +@dataclass +class Actor: + """ + An ActivityPub Actor (identity). + + Attributes: + username: Unique username (e.g., "giles") + display_name: Human-readable name + public_key: PEM-encoded public key + private_key: PEM-encoded private key (kept secret) + created_at: Timestamp of creation + """ + username: str + display_name: str + public_key: bytes + private_key: bytes + created_at: float = field(default_factory=time.time) + domain: str = DOMAIN + + @property + def id(self) -> str: + """ActivityPub actor ID (URL).""" + return f"https://{self.domain}/users/{self.username}" + + @property + def handle(self) -> str: + """Fediverse handle.""" + return f"@{self.username}@{self.domain}" + + @property + def inbox(self) -> str: + """ActivityPub inbox URL.""" + return f"{self.id}/inbox" + + @property + def outbox(self) -> str: + """ActivityPub outbox URL.""" + return f"{self.id}/outbox" + + @property + def key_id(self) -> str: + """Key ID for HTTP Signatures.""" + return f"{self.id}#main-key" + + def to_activitypub(self) -> Dict[str, Any]: + """Return ActivityPub JSON-LD representation.""" + return { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1", + ], + "type": "Person", + "id": self.id, + "preferredUsername": self.username, + "name": self.display_name, + "inbox": self.inbox, + "outbox": self.outbox, + "publicKey": { + "id": self.key_id, + "owner": self.id, + "publicKeyPem": self.public_key.decode("utf-8"), + }, + } + + def to_dict(self) -> Dict[str, Any]: + """Serialize for storage.""" + return { + "username": self.username, + "display_name": self.display_name, + "public_key": self.public_key.decode("utf-8"), + "private_key": self.private_key.decode("utf-8"), + "created_at": self.created_at, + "domain": self.domain, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Actor": + """Deserialize from storage.""" + return cls( + username=data["username"], + display_name=data["display_name"], + public_key=data["public_key"].encode("utf-8"), + private_key=data["private_key"].encode("utf-8"), + created_at=data.get("created_at", time.time()), + domain=data.get("domain", DOMAIN), + ) + + @classmethod + def create(cls, username: str, display_name: str = None) -> "Actor": + """Create a new actor with generated keys.""" + private_pem, public_pem = _generate_keypair() + return cls( + username=username, + display_name=display_name or username, + public_key=public_pem, + private_key=private_pem, + ) + + +class ActorStore: + """ + Persistent storage for actors. + + Structure: + store_dir/ + actors.json # Index of all actors + keys/ + .private.pem + .public.pem + """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._actors: Dict[str, Actor] = {} + self._load() + + def _index_path(self) -> Path: + return self.store_dir / "actors.json" + + def _load(self): + """Load actors from disk.""" + index_path = self._index_path() + if index_path.exists(): + with open(index_path) as f: + data = json.load(f) + self._actors = { + username: Actor.from_dict(actor_data) + for username, actor_data in data.get("actors", {}).items() + } + + def _save(self): + """Save actors to disk.""" + data = { + "version": "1.0", + "domain": DOMAIN, + "actors": { + username: actor.to_dict() + for username, actor in self._actors.items() + }, + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def create(self, username: str, display_name: str = None) -> Actor: + """Create and store a new actor.""" + if username in self._actors: + raise ValueError(f"Actor {username} already exists") + + actor = Actor.create(username, display_name) + self._actors[username] = actor + self._save() + return actor + + def get(self, username: str) -> Optional[Actor]: + """Get an actor by username.""" + return self._actors.get(username) + + def list(self) -> list[Actor]: + """List all actors.""" + return list(self._actors.values()) + + def __contains__(self, username: str) -> bool: + return username in self._actors + + def __len__(self) -> int: + return len(self._actors) diff --git a/artdag/activitypub/ownership.py b/artdag/activitypub/ownership.py new file mode 100644 index 0000000..8290871 --- /dev/null +++ b/artdag/activitypub/ownership.py @@ -0,0 +1,226 @@ +# primitive/activitypub/ownership.py +""" +Ownership integration between ActivityPub and Registry. + +Connects actors, activities, and assets to establish provable ownership. +""" + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .actor import Actor, ActorStore +from .activity import Activity, CreateActivity, ActivityStore +from .signatures import sign_activity, verify_activity_ownership +from ..registry import Registry, Asset + + +@dataclass +class OwnershipRecord: + """ + A verified ownership record linking actor to asset. + + Attributes: + actor_handle: The actor's fediverse handle + asset_name: Name of the owned asset + cid: SHA-3 hash of the asset + activity_id: ID of the Create activity establishing ownership + verified: Whether the signature has been verified + """ + actor_handle: str + asset_name: str + cid: str + activity_id: str + verified: bool = False + + +class OwnershipManager: + """ + Manages ownership relationships between actors and assets. + + Integrates: + - ActorStore: Identity management + - Registry: Asset storage + - ActivityStore: Ownership activities + """ + + def __init__(self, base_dir: Path | str): + self.base_dir = Path(base_dir) + self.base_dir.mkdir(parents=True, exist_ok=True) + + # Initialize stores + self.actors = ActorStore(self.base_dir / "actors") + self.activities = ActivityStore(self.base_dir / "activities") + self.registry = Registry(self.base_dir / "registry") + + def create_actor(self, username: str, display_name: str = None) -> Actor: + """Create a new actor identity.""" + return self.actors.create(username, display_name) + + def get_actor(self, username: str) -> Optional[Actor]: + """Get an actor by username.""" + return self.actors.get(username) + + def register_asset( + self, + actor: Actor, + name: str, + cid: str, + url: str = None, + local_path: Path | str = None, + tags: List[str] = None, + metadata: Dict[str, Any] = None, + ) -> tuple[Asset, Activity]: + """ + Register an asset and establish ownership. + + Creates the asset in the registry and a signed Create activity + proving the actor's ownership. + + Args: + actor: The actor claiming ownership + name: Name for the asset + cid: SHA-3-256 hash of the content + url: Public URL (canonical location) + local_path: Optional local path + tags: Optional tags + metadata: Optional metadata + + Returns: + Tuple of (Asset, signed CreateActivity) + """ + # Add to registry + asset = self.registry.add( + name=name, + cid=cid, + url=url, + local_path=local_path, + tags=tags, + metadata=metadata, + ) + + # Create ownership activity + activity = CreateActivity.for_asset( + actor=actor, + asset_name=name, + cid=asset.cid, + asset_type=self._asset_type_to_ap(asset.asset_type), + metadata=metadata, + ) + + # Sign the activity + signed_activity = sign_activity(activity, actor) + + # Store the activity + self.activities.add(signed_activity) + + return asset, signed_activity + + def _asset_type_to_ap(self, asset_type: str) -> str: + """Convert registry asset type to ActivityPub type.""" + type_map = { + "image": "Image", + "video": "Video", + "audio": "Audio", + "unknown": "Document", + } + return type_map.get(asset_type, "Document") + + def get_owner(self, asset_name: str) -> Optional[Actor]: + """ + Get the owner of an asset. + + Finds the earliest Create activity for the asset and returns + the actor if the signature is valid. + """ + asset = self.registry.get(asset_name) + if not asset: + return None + + # Find Create activities for this asset + activities = self.activities.find_by_object_hash(asset.cid) + create_activities = [a for a in activities if a.activity_type == "Create"] + + if not create_activities: + return None + + # Get the earliest (first owner) + earliest = min(create_activities, key=lambda a: a.published) + + # Extract username from actor_id + # Format: https://artdag.rose-ash.com/users/{username} + actor_id = earliest.actor_id + if "/users/" in actor_id: + username = actor_id.split("/users/")[-1] + actor = self.actors.get(username) + if actor and verify_activity_ownership(earliest, actor): + return actor + + return None + + def verify_ownership(self, asset_name: str, actor: Actor) -> bool: + """ + Verify that an actor owns an asset. + + Checks for a valid signed Create activity linking the actor + to the asset. + """ + asset = self.registry.get(asset_name) + if not asset: + return False + + activities = self.activities.find_by_object_hash(asset.cid) + for activity in activities: + if activity.activity_type == "Create" and activity.actor_id == actor.id: + if verify_activity_ownership(activity, actor): + return True + + return False + + def list_owned_assets(self, actor: Actor) -> List[Asset]: + """List all assets owned by an actor.""" + activities = self.activities.find_by_actor(actor.id) + owned = [] + + for activity in activities: + if activity.activity_type == "Create": + # Find asset by hash + obj_hash = activity.object_data.get("contentHash", {}) + if isinstance(obj_hash, dict): + hash_value = obj_hash.get("value") + else: + hash_value = obj_hash + + if hash_value: + asset = self.registry.find_by_hash(hash_value) + if asset: + owned.append(asset) + + return owned + + def get_ownership_records(self) -> List[OwnershipRecord]: + """Get all ownership records.""" + records = [] + + for activity in self.activities.list(): + if activity.activity_type != "Create": + continue + + # Extract info + actor_id = activity.actor_id + username = actor_id.split("/users/")[-1] if "/users/" in actor_id else "unknown" + actor = self.actors.get(username) + + obj_hash = activity.object_data.get("contentHash", {}) + hash_value = obj_hash.get("value") if isinstance(obj_hash, dict) else obj_hash + + records.append(OwnershipRecord( + actor_handle=actor.handle if actor else f"@{username}@unknown", + asset_name=activity.object_data.get("name", "unknown"), + cid=hash_value or "unknown", + activity_id=activity.activity_id, + verified=verify_activity_ownership(activity, actor) if actor else False, + )) + + return records diff --git a/artdag/activitypub/signatures.py b/artdag/activitypub/signatures.py new file mode 100644 index 0000000..099524c --- /dev/null +++ b/artdag/activitypub/signatures.py @@ -0,0 +1,163 @@ +# primitive/activitypub/signatures.py +""" +Cryptographic signatures for ActivityPub. + +Uses RSA-SHA256 signatures compatible with HTTP Signatures spec +and Linked Data Signatures for ActivityPub. +""" + +import base64 +import hashlib +import json +import time +from typing import Any, Dict + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.exceptions import InvalidSignature + +from .actor import Actor +from .activity import Activity + + +def _canonicalize(data: Dict[str, Any]) -> str: + """ + Canonicalize JSON for signing. + + Uses JCS (JSON Canonicalization Scheme) - sorted keys, no whitespace. + """ + return json.dumps(data, sort_keys=True, separators=(",", ":")) + + +def _hash_sha256(data: str) -> bytes: + """Hash string with SHA-256.""" + return hashlib.sha256(data.encode()).digest() + + +def sign_activity(activity: Activity, actor: Actor) -> Activity: + """ + Sign an activity with the actor's private key. + + Uses Linked Data Signatures with RsaSignature2017. + + Args: + activity: The activity to sign + actor: The actor whose key signs the activity + + Returns: + Activity with signature attached + """ + # Load private key + private_key = serialization.load_pem_private_key( + actor.private_key, + password=None, + ) + + # Create signature options + created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + # Canonicalize the activity (without signature) + activity_data = activity.to_activitypub() + activity_data.pop("signature", None) + canonical = _canonicalize(activity_data) + + # Create the data to sign: hash of options + hash of document + options = { + "@context": "https://w3id.org/security/v1", + "type": "RsaSignature2017", + "creator": actor.key_id, + "created": created, + } + options_hash = _hash_sha256(_canonicalize(options)) + document_hash = _hash_sha256(canonical) + to_sign = options_hash + document_hash + + # Sign with RSA-SHA256 + signature_bytes = private_key.sign( + to_sign, + padding.PKCS1v15(), + hashes.SHA256(), + ) + signature_value = base64.b64encode(signature_bytes).decode("utf-8") + + # Attach signature to activity + activity.signature = { + "type": "RsaSignature2017", + "creator": actor.key_id, + "created": created, + "signatureValue": signature_value, + } + + return activity + + +def verify_signature(activity: Activity, public_key_pem: bytes) -> bool: + """ + Verify an activity's signature. + + Args: + activity: The activity with signature + public_key_pem: PEM-encoded public key + + Returns: + True if signature is valid + """ + if not activity.signature: + return False + + try: + # Load public key + public_key = serialization.load_pem_public_key(public_key_pem) + + # Reconstruct signature options + options = { + "@context": "https://w3id.org/security/v1", + "type": activity.signature["type"], + "creator": activity.signature["creator"], + "created": activity.signature["created"], + } + + # Canonicalize activity without signature + activity_data = activity.to_activitypub() + activity_data.pop("signature", None) + canonical = _canonicalize(activity_data) + + # Recreate signed data + options_hash = _hash_sha256(_canonicalize(options)) + document_hash = _hash_sha256(canonical) + signed_data = options_hash + document_hash + + # Decode and verify signature + signature_bytes = base64.b64decode(activity.signature["signatureValue"]) + public_key.verify( + signature_bytes, + signed_data, + padding.PKCS1v15(), + hashes.SHA256(), + ) + return True + + except (InvalidSignature, KeyError, ValueError): + return False + + +def verify_activity_ownership(activity: Activity, actor: Actor) -> bool: + """ + Verify that an activity was signed by the claimed actor. + + Args: + activity: The activity to verify + actor: The claimed actor + + Returns: + True if the activity was signed by this actor + """ + if not activity.signature: + return False + + # Check creator matches actor + if activity.signature.get("creator") != actor.key_id: + return False + + # Verify signature + return verify_signature(activity, actor.public_key) diff --git a/artdag/analysis/__init__.py b/artdag/analysis/__init__.py new file mode 100644 index 0000000..2ab2b81 --- /dev/null +++ b/artdag/analysis/__init__.py @@ -0,0 +1,26 @@ +# artdag/analysis - Audio and video feature extraction +# +# Provides the Analysis phase of the 3-phase execution model: +# 1. ANALYZE - Extract features from inputs +# 2. PLAN - Generate execution plan with cache IDs +# 3. EXECUTE - Run steps with caching + +from .schema import ( + AnalysisResult, + AudioFeatures, + VideoFeatures, + BeatInfo, + EnergyEnvelope, + SpectrumBands, +) +from .analyzer import Analyzer + +__all__ = [ + "Analyzer", + "AnalysisResult", + "AudioFeatures", + "VideoFeatures", + "BeatInfo", + "EnergyEnvelope", + "SpectrumBands", +] diff --git a/artdag/analysis/analyzer.py b/artdag/analysis/analyzer.py new file mode 100644 index 0000000..fd1bdbe --- /dev/null +++ b/artdag/analysis/analyzer.py @@ -0,0 +1,282 @@ +# artdag/analysis/analyzer.py +""" +Main Analyzer class for the Analysis phase. + +Coordinates audio and video feature extraction with caching. +""" + +import json +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List, Optional + +from .schema import AnalysisResult, AudioFeatures, VideoFeatures +from .audio import analyze_audio, FEATURE_ALL as AUDIO_ALL +from .video import analyze_video, FEATURE_ALL as VIDEO_ALL + +logger = logging.getLogger(__name__) + + +class AnalysisCache: + """ + Simple file-based cache for analysis results. + + Stores results as JSON files keyed by analysis cache_id. + """ + + def __init__(self, cache_dir: Path): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _path_for(self, cache_id: str) -> Path: + """Get cache file path for a cache_id.""" + return self.cache_dir / f"{cache_id}.json" + + def get(self, cache_id: str) -> Optional[AnalysisResult]: + """Retrieve cached analysis result.""" + path = self._path_for(cache_id) + if not path.exists(): + return None + + try: + with open(path, "r") as f: + data = json.load(f) + return AnalysisResult.from_dict(data) + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Failed to load analysis cache {cache_id}: {e}") + return None + + def put(self, result: AnalysisResult) -> None: + """Store analysis result in cache.""" + path = self._path_for(result.cache_id) + with open(path, "w") as f: + json.dump(result.to_dict(), f, indent=2) + + def has(self, cache_id: str) -> bool: + """Check if analysis result is cached.""" + return self._path_for(cache_id).exists() + + def remove(self, cache_id: str) -> bool: + """Remove cached analysis result.""" + path = self._path_for(cache_id) + if path.exists(): + path.unlink() + return True + return False + + +class Analyzer: + """ + Analyzes media inputs to extract features. + + The Analyzer is the first phase of the 3-phase execution model. + It extracts features from inputs that inform downstream processing. + + Example: + analyzer = Analyzer(cache_dir=Path("./analysis_cache")) + + # Analyze a music file for beats + result = analyzer.analyze( + input_path=Path("/path/to/music.mp3"), + input_hash="abc123...", + features=["beats", "energy"] + ) + + print(f"Tempo: {result.tempo} BPM") + print(f"Beats: {result.beat_times}") + """ + + def __init__( + self, + cache_dir: Optional[Path] = None, + content_cache: Optional["Cache"] = None, # artdag.Cache for input lookup + ): + """ + Initialize the Analyzer. + + Args: + cache_dir: Directory for analysis cache. If None, no caching. + content_cache: artdag Cache for looking up inputs by hash + """ + self.cache = AnalysisCache(cache_dir) if cache_dir else None + self.content_cache = content_cache + + def get_input_path(self, input_hash: str, input_path: Optional[Path] = None) -> Path: + """ + Resolve input to a file path. + + Args: + input_hash: Content hash of the input + input_path: Optional direct path to file + + Returns: + Path to the input file + + Raises: + ValueError: If input cannot be resolved + """ + if input_path and input_path.exists(): + return input_path + + if self.content_cache: + entry = self.content_cache.get(input_hash) + if entry: + return Path(entry.output_path) + + raise ValueError(f"Cannot resolve input {input_hash}: no path provided and not in cache") + + def analyze( + self, + input_hash: str, + features: List[str], + input_path: Optional[Path] = None, + media_type: Optional[str] = None, + ) -> AnalysisResult: + """ + Analyze an input file and extract features. + + Args: + input_hash: Content hash of the input (for cache key) + features: List of features to extract: + Audio: "beats", "tempo", "energy", "spectrum", "onsets" + Video: "metadata", "motion_tempo", "scene_changes" + Meta: "all" (extracts all relevant features) + input_path: Optional direct path to file + media_type: Optional hint ("audio", "video", or None for auto-detect) + + Returns: + AnalysisResult with extracted features + """ + # Compute cache ID + temp_result = AnalysisResult( + input_hash=input_hash, + features_requested=sorted(features), + ) + cache_id = temp_result.cache_id + + # Check cache + if self.cache and self.cache.has(cache_id): + cached = self.cache.get(cache_id) + if cached: + logger.info(f"Analysis cache hit: {cache_id[:16]}...") + return cached + + # Resolve input path + path = self.get_input_path(input_hash, input_path) + logger.info(f"Analyzing {path} for features: {features}") + + # Detect media type if not specified + if media_type is None: + media_type = self._detect_media_type(path) + + # Extract features + audio_features = None + video_features = None + + # Normalize features + if "all" in features: + audio_features_list = [AUDIO_ALL] + video_features_list = [VIDEO_ALL] + else: + audio_features_list = [f for f in features if f in ("beats", "tempo", "energy", "spectrum", "onsets")] + video_features_list = [f for f in features if f in ("metadata", "motion_tempo", "scene_changes")] + + if media_type in ("audio", "video") and audio_features_list: + try: + audio_features = analyze_audio(path, features=audio_features_list) + except Exception as e: + logger.warning(f"Audio analysis failed: {e}") + + if media_type == "video" and video_features_list: + try: + video_features = analyze_video(path, features=video_features_list) + except Exception as e: + logger.warning(f"Video analysis failed: {e}") + + result = AnalysisResult( + input_hash=input_hash, + features_requested=sorted(features), + audio=audio_features, + video=video_features, + analyzed_at=datetime.now(timezone.utc).isoformat(), + ) + + # Cache result + if self.cache: + self.cache.put(result) + + return result + + def analyze_multiple( + self, + inputs: Dict[str, Path], + features: List[str], + ) -> Dict[str, AnalysisResult]: + """ + Analyze multiple inputs. + + Args: + inputs: Dict mapping input_hash to file path + features: Features to extract from all inputs + + Returns: + Dict mapping input_hash to AnalysisResult + """ + results = {} + for input_hash, input_path in inputs.items(): + try: + results[input_hash] = self.analyze( + input_hash=input_hash, + features=features, + input_path=input_path, + ) + except Exception as e: + logger.error(f"Analysis failed for {input_hash}: {e}") + raise + + return results + + def _detect_media_type(self, path: Path) -> str: + """ + Detect if file is audio or video. + + Args: + path: Path to media file + + Returns: + "audio" or "video" + """ + import subprocess + import json + + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_streams", + str(path) + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data = json.loads(result.stdout) + streams = data.get("streams", []) + + has_video = any(s.get("codec_type") == "video" for s in streams) + has_audio = any(s.get("codec_type") == "audio" for s in streams) + + if has_video: + return "video" + elif has_audio: + return "audio" + else: + return "unknown" + + except (subprocess.CalledProcessError, json.JSONDecodeError): + # Fall back to extension-based detection + ext = path.suffix.lower() + if ext in (".mp4", ".mov", ".avi", ".mkv", ".webm"): + return "video" + elif ext in (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac"): + return "audio" + return "unknown" diff --git a/artdag/analysis/audio.py b/artdag/analysis/audio.py new file mode 100644 index 0000000..4ee034b --- /dev/null +++ b/artdag/analysis/audio.py @@ -0,0 +1,336 @@ +# artdag/analysis/audio.py +""" +Audio feature extraction. + +Uses librosa for beat detection, energy analysis, and spectral features. +Falls back to basic ffprobe if librosa is not available. +""" + +import json +import logging +import subprocess +from pathlib import Path +from typing import List, Optional, Tuple + +from .schema import AudioFeatures, BeatInfo, EnergyEnvelope, SpectrumBands + +logger = logging.getLogger(__name__) + +# Feature names for requesting specific analysis +FEATURE_BEATS = "beats" +FEATURE_TEMPO = "tempo" +FEATURE_ENERGY = "energy" +FEATURE_SPECTRUM = "spectrum" +FEATURE_ONSETS = "onsets" +FEATURE_ALL = "all" + + +def _get_audio_info_ffprobe(path: Path) -> Tuple[float, int, int]: + """Get basic audio info using ffprobe.""" + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_streams", + "-select_streams", "a:0", + str(path) + ] + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data = json.loads(result.stdout) + if not data.get("streams"): + raise ValueError("No audio stream found") + + stream = data["streams"][0] + duration = float(stream.get("duration", 0)) + sample_rate = int(stream.get("sample_rate", 44100)) + channels = int(stream.get("channels", 2)) + return duration, sample_rate, channels + except (subprocess.CalledProcessError, json.JSONDecodeError, KeyError) as e: + logger.warning(f"ffprobe failed: {e}") + raise ValueError(f"Could not read audio info: {e}") + + +def _extract_audio_to_wav(path: Path, duration: Optional[float] = None) -> Path: + """Extract audio to temporary WAV file for librosa processing.""" + import tempfile + wav_path = Path(tempfile.mktemp(suffix=".wav")) + + cmd = ["ffmpeg", "-y", "-i", str(path)] + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend([ + "-vn", # No video + "-acodec", "pcm_s16le", + "-ar", "22050", # Resample to 22050 Hz for librosa + "-ac", "1", # Mono + str(wav_path) + ]) + + try: + subprocess.run(cmd, capture_output=True, check=True) + return wav_path + except subprocess.CalledProcessError as e: + logger.error(f"Audio extraction failed: {e.stderr}") + raise ValueError(f"Could not extract audio: {e}") + + +def analyze_beats(path: Path, sample_rate: int = 22050) -> BeatInfo: + """ + Detect beats and tempo using librosa. + + Args: + path: Path to audio file (or pre-extracted WAV) + sample_rate: Sample rate for analysis + + Returns: + BeatInfo with beat times, tempo, and confidence + """ + try: + import librosa + except ImportError: + raise ImportError("librosa required for beat detection. Install with: pip install librosa") + + # Load audio + y, sr = librosa.load(str(path), sr=sample_rate, mono=True) + + # Detect tempo and beats + tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr) + + # Convert frames to times + beat_times = librosa.frames_to_time(beat_frames, sr=sr).tolist() + + # Estimate confidence from onset strength consistency + onset_env = librosa.onset.onset_strength(y=y, sr=sr) + beat_strength = onset_env[beat_frames] if len(beat_frames) > 0 else [] + confidence = float(beat_strength.mean() / onset_env.max()) if len(beat_strength) > 0 and onset_env.max() > 0 else 0.5 + + # Detect downbeats (first beat of each bar) + # Use beat phase to estimate bar positions + downbeat_times = None + if len(beat_times) >= 4: + # Assume 4/4 time signature, downbeats every 4 beats + downbeat_times = [beat_times[i] for i in range(0, len(beat_times), 4)] + + return BeatInfo( + beat_times=beat_times, + tempo=float(tempo) if hasattr(tempo, '__float__') else float(tempo[0]) if len(tempo) > 0 else 120.0, + confidence=min(1.0, max(0.0, confidence)), + downbeat_times=downbeat_times, + time_signature=4, + ) + + +def analyze_energy(path: Path, window_ms: float = 50.0, sample_rate: int = 22050) -> EnergyEnvelope: + """ + Extract energy (loudness) envelope. + + Args: + path: Path to audio file + window_ms: Analysis window size in milliseconds + sample_rate: Sample rate for analysis + + Returns: + EnergyEnvelope with times and normalized values + """ + try: + import librosa + import numpy as np + except ImportError: + raise ImportError("librosa and numpy required. Install with: pip install librosa numpy") + + y, sr = librosa.load(str(path), sr=sample_rate, mono=True) + + # Calculate frame size from window_ms + hop_length = int(sr * window_ms / 1000) + + # RMS energy + rms = librosa.feature.rms(y=y, hop_length=hop_length)[0] + + # Normalize to 0-1 + rms_max = rms.max() + if rms_max > 0: + rms_normalized = rms / rms_max + else: + rms_normalized = rms + + # Generate time points + times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop_length) + + return EnergyEnvelope( + times=times.tolist(), + values=rms_normalized.tolist(), + window_ms=window_ms, + ) + + +def analyze_spectrum( + path: Path, + band_ranges: Optional[dict] = None, + window_ms: float = 50.0, + sample_rate: int = 22050 +) -> SpectrumBands: + """ + Extract frequency band envelopes. + + Args: + path: Path to audio file + band_ranges: Dict mapping band name to (low_hz, high_hz) + window_ms: Analysis window size + sample_rate: Sample rate + + Returns: + SpectrumBands with bass, mid, high envelopes + """ + try: + import librosa + import numpy as np + except ImportError: + raise ImportError("librosa and numpy required") + + if band_ranges is None: + band_ranges = { + "bass": (20, 200), + "mid": (200, 2000), + "high": (2000, 20000), + } + + y, sr = librosa.load(str(path), sr=sample_rate, mono=True) + hop_length = int(sr * window_ms / 1000) + + # Compute STFT + n_fft = 2048 + stft = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length)) + + # Frequency bins + freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft) + + def band_energy(low_hz: float, high_hz: float) -> List[float]: + """Sum energy in frequency band.""" + mask = (freqs >= low_hz) & (freqs <= high_hz) + if not mask.any(): + return [0.0] * stft.shape[1] + band = stft[mask, :].sum(axis=0) + # Normalize + band_max = band.max() + if band_max > 0: + band = band / band_max + return band.tolist() + + times = librosa.frames_to_time(np.arange(stft.shape[1]), sr=sr, hop_length=hop_length) + + return SpectrumBands( + bass=band_energy(*band_ranges["bass"]), + mid=band_energy(*band_ranges["mid"]), + high=band_energy(*band_ranges["high"]), + times=times.tolist(), + band_ranges=band_ranges, + ) + + +def analyze_onsets(path: Path, sample_rate: int = 22050) -> List[float]: + """ + Detect onset times (note/sound starts). + + Args: + path: Path to audio file + sample_rate: Sample rate + + Returns: + List of onset times in seconds + """ + try: + import librosa + except ImportError: + raise ImportError("librosa required") + + y, sr = librosa.load(str(path), sr=sample_rate, mono=True) + + # Detect onsets + onset_frames = librosa.onset.onset_detect(y=y, sr=sr) + onset_times = librosa.frames_to_time(onset_frames, sr=sr) + + return onset_times.tolist() + + +def analyze_audio( + path: Path, + features: Optional[List[str]] = None, +) -> AudioFeatures: + """ + Extract audio features from file. + + Args: + path: Path to audio/video file + features: List of features to extract. Options: + - "beats": Beat detection (tempo, beat times) + - "energy": Loudness envelope + - "spectrum": Frequency band envelopes + - "onsets": Note onset times + - "all": All features + + Returns: + AudioFeatures with requested analysis + """ + if features is None: + features = [FEATURE_ALL] + + # Normalize features + if FEATURE_ALL in features: + features = [FEATURE_BEATS, FEATURE_ENERGY, FEATURE_SPECTRUM, FEATURE_ONSETS] + + # Get basic info via ffprobe + duration, sample_rate, channels = _get_audio_info_ffprobe(path) + + result = AudioFeatures( + duration=duration, + sample_rate=sample_rate, + channels=channels, + ) + + # Check if librosa is available for advanced features + try: + import librosa # noqa: F401 + has_librosa = True + except ImportError: + has_librosa = False + if any(f in features for f in [FEATURE_BEATS, FEATURE_ENERGY, FEATURE_SPECTRUM, FEATURE_ONSETS]): + logger.warning("librosa not available, skipping advanced audio features") + + if not has_librosa: + return result + + # Extract audio to WAV for librosa + wav_path = None + try: + wav_path = _extract_audio_to_wav(path) + + if FEATURE_BEATS in features or FEATURE_TEMPO in features: + try: + result.beats = analyze_beats(wav_path) + except Exception as e: + logger.warning(f"Beat detection failed: {e}") + + if FEATURE_ENERGY in features: + try: + result.energy = analyze_energy(wav_path) + except Exception as e: + logger.warning(f"Energy analysis failed: {e}") + + if FEATURE_SPECTRUM in features: + try: + result.spectrum = analyze_spectrum(wav_path) + except Exception as e: + logger.warning(f"Spectrum analysis failed: {e}") + + if FEATURE_ONSETS in features: + try: + result.onsets = analyze_onsets(wav_path) + except Exception as e: + logger.warning(f"Onset detection failed: {e}") + + finally: + # Clean up temporary WAV file + if wav_path and wav_path.exists(): + wav_path.unlink() + + return result diff --git a/artdag/analysis/schema.py b/artdag/analysis/schema.py new file mode 100644 index 0000000..4b9825b --- /dev/null +++ b/artdag/analysis/schema.py @@ -0,0 +1,352 @@ +# artdag/analysis/schema.py +""" +Data structures for analysis results. + +Analysis extracts features from input media that inform downstream processing. +Results are cached by: analysis_cache_id = SHA3-256(input_hash + sorted(features)) +""" + +import hashlib +import json +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """Create stable hash from arbitrary data.""" + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class BeatInfo: + """ + Beat detection results. + + Attributes: + beat_times: List of beat positions in seconds + tempo: Estimated tempo in BPM + confidence: Tempo detection confidence (0-1) + downbeat_times: First beat of each bar (if detected) + time_signature: Detected or assumed time signature (e.g., 4) + """ + beat_times: List[float] + tempo: float + confidence: float = 1.0 + downbeat_times: Optional[List[float]] = None + time_signature: int = 4 + + def to_dict(self) -> Dict[str, Any]: + return { + "beat_times": self.beat_times, + "tempo": self.tempo, + "confidence": self.confidence, + "downbeat_times": self.downbeat_times, + "time_signature": self.time_signature, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BeatInfo": + return cls( + beat_times=data["beat_times"], + tempo=data["tempo"], + confidence=data.get("confidence", 1.0), + downbeat_times=data.get("downbeat_times"), + time_signature=data.get("time_signature", 4), + ) + + +@dataclass +class EnergyEnvelope: + """ + Energy (loudness) over time. + + Attributes: + times: Time points in seconds + values: Energy values (0-1, normalized) + window_ms: Analysis window size in milliseconds + """ + times: List[float] + values: List[float] + window_ms: float = 50.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "times": self.times, + "values": self.values, + "window_ms": self.window_ms, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EnergyEnvelope": + return cls( + times=data["times"], + values=data["values"], + window_ms=data.get("window_ms", 50.0), + ) + + def at_time(self, t: float) -> float: + """Interpolate energy value at given time.""" + if not self.times: + return 0.0 + if t <= self.times[0]: + return self.values[0] + if t >= self.times[-1]: + return self.values[-1] + + # Binary search for bracketing indices + lo, hi = 0, len(self.times) - 1 + while hi - lo > 1: + mid = (lo + hi) // 2 + if self.times[mid] <= t: + lo = mid + else: + hi = mid + + # Linear interpolation + t0, t1 = self.times[lo], self.times[hi] + v0, v1 = self.values[lo], self.values[hi] + alpha = (t - t0) / (t1 - t0) if t1 != t0 else 0 + return v0 + alpha * (v1 - v0) + + +@dataclass +class SpectrumBands: + """ + Frequency band envelopes over time. + + Attributes: + bass: Low frequency envelope (20-200 Hz typical) + mid: Mid frequency envelope (200-2000 Hz typical) + high: High frequency envelope (2000-20000 Hz typical) + times: Time points in seconds + band_ranges: Frequency ranges for each band in Hz + """ + bass: List[float] + mid: List[float] + high: List[float] + times: List[float] + band_ranges: Dict[str, Tuple[float, float]] = field(default_factory=lambda: { + "bass": (20, 200), + "mid": (200, 2000), + "high": (2000, 20000), + }) + + def to_dict(self) -> Dict[str, Any]: + return { + "bass": self.bass, + "mid": self.mid, + "high": self.high, + "times": self.times, + "band_ranges": self.band_ranges, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SpectrumBands": + return cls( + bass=data["bass"], + mid=data["mid"], + high=data["high"], + times=data["times"], + band_ranges=data.get("band_ranges", { + "bass": (20, 200), + "mid": (200, 2000), + "high": (2000, 20000), + }), + ) + + +@dataclass +class AudioFeatures: + """ + All extracted audio features. + + Attributes: + duration: Audio duration in seconds + sample_rate: Sample rate in Hz + channels: Number of audio channels + beats: Beat detection results + energy: Energy envelope + spectrum: Frequency band envelopes + onsets: Note/sound onset times + """ + duration: float + sample_rate: int + channels: int + beats: Optional[BeatInfo] = None + energy: Optional[EnergyEnvelope] = None + spectrum: Optional[SpectrumBands] = None + onsets: Optional[List[float]] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "duration": self.duration, + "sample_rate": self.sample_rate, + "channels": self.channels, + "beats": self.beats.to_dict() if self.beats else None, + "energy": self.energy.to_dict() if self.energy else None, + "spectrum": self.spectrum.to_dict() if self.spectrum else None, + "onsets": self.onsets, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AudioFeatures": + return cls( + duration=data["duration"], + sample_rate=data["sample_rate"], + channels=data["channels"], + beats=BeatInfo.from_dict(data["beats"]) if data.get("beats") else None, + energy=EnergyEnvelope.from_dict(data["energy"]) if data.get("energy") else None, + spectrum=SpectrumBands.from_dict(data["spectrum"]) if data.get("spectrum") else None, + onsets=data.get("onsets"), + ) + + +@dataclass +class VideoFeatures: + """ + Extracted video features. + + Attributes: + duration: Video duration in seconds + frame_rate: Frames per second + width: Frame width in pixels + height: Frame height in pixels + codec: Video codec name + motion_tempo: Estimated tempo from motion analysis (optional) + scene_changes: Times of detected scene changes + """ + duration: float + frame_rate: float + width: int + height: int + codec: str = "" + motion_tempo: Optional[float] = None + scene_changes: Optional[List[float]] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "duration": self.duration, + "frame_rate": self.frame_rate, + "width": self.width, + "height": self.height, + "codec": self.codec, + "motion_tempo": self.motion_tempo, + "scene_changes": self.scene_changes, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VideoFeatures": + return cls( + duration=data["duration"], + frame_rate=data["frame_rate"], + width=data["width"], + height=data["height"], + codec=data.get("codec", ""), + motion_tempo=data.get("motion_tempo"), + scene_changes=data.get("scene_changes"), + ) + + +@dataclass +class AnalysisResult: + """ + Complete analysis result for an input. + + Combines audio and video features with metadata for caching. + + Attributes: + input_hash: Content hash of the analyzed input + features_requested: List of features that were requested + audio: Audio features (if input has audio) + video: Video features (if input has video) + cache_id: Computed cache ID for this analysis + analyzed_at: Timestamp of analysis + """ + input_hash: str + features_requested: List[str] + audio: Optional[AudioFeatures] = None + video: Optional[VideoFeatures] = None + cache_id: Optional[str] = None + analyzed_at: Optional[str] = None + + def __post_init__(self): + """Compute cache_id if not provided.""" + if self.cache_id is None: + self.cache_id = self._compute_cache_id() + + def _compute_cache_id(self) -> str: + """ + Compute cache ID from input hash and requested features. + + cache_id = SHA3-256(input_hash + sorted(features_requested)) + """ + content = { + "input_hash": self.input_hash, + "features": sorted(self.features_requested), + } + return _stable_hash(content) + + def to_dict(self) -> Dict[str, Any]: + return { + "input_hash": self.input_hash, + "features_requested": self.features_requested, + "audio": self.audio.to_dict() if self.audio else None, + "video": self.video.to_dict() if self.video else None, + "cache_id": self.cache_id, + "analyzed_at": self.analyzed_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnalysisResult": + return cls( + input_hash=data["input_hash"], + features_requested=data["features_requested"], + audio=AudioFeatures.from_dict(data["audio"]) if data.get("audio") else None, + video=VideoFeatures.from_dict(data["video"]) if data.get("video") else None, + cache_id=data.get("cache_id"), + analyzed_at=data.get("analyzed_at"), + ) + + def to_json(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "AnalysisResult": + """Deserialize from JSON string.""" + return cls.from_dict(json.loads(json_str)) + + # Convenience accessors + @property + def tempo(self) -> Optional[float]: + """Get tempo if beats were analyzed.""" + return self.audio.beats.tempo if self.audio and self.audio.beats else None + + @property + def beat_times(self) -> Optional[List[float]]: + """Get beat times if beats were analyzed.""" + return self.audio.beats.beat_times if self.audio and self.audio.beats else None + + @property + def downbeat_times(self) -> Optional[List[float]]: + """Get downbeat times if analyzed.""" + return self.audio.beats.downbeat_times if self.audio and self.audio.beats else None + + @property + def duration(self) -> float: + """Get duration from video or audio.""" + if self.video: + return self.video.duration + if self.audio: + return self.audio.duration + return 0.0 + + @property + def dimensions(self) -> Optional[Tuple[int, int]]: + """Get video dimensions if available.""" + if self.video: + return (self.video.width, self.video.height) + return None diff --git a/artdag/analysis/video.py b/artdag/analysis/video.py new file mode 100644 index 0000000..94d4152 --- /dev/null +++ b/artdag/analysis/video.py @@ -0,0 +1,266 @@ +# artdag/analysis/video.py +""" +Video feature extraction. + +Uses ffprobe for basic metadata and optional OpenCV for motion analysis. +""" + +import json +import logging +import subprocess +from fractions import Fraction +from pathlib import Path +from typing import List, Optional + +from .schema import VideoFeatures + +logger = logging.getLogger(__name__) + +# Feature names +FEATURE_METADATA = "metadata" +FEATURE_MOTION_TEMPO = "motion_tempo" +FEATURE_SCENE_CHANGES = "scene_changes" +FEATURE_ALL = "all" + + +def _parse_frame_rate(rate_str: str) -> float: + """Parse frame rate string like '30000/1001' or '30'.""" + try: + if "/" in rate_str: + frac = Fraction(rate_str) + return float(frac) + return float(rate_str) + except (ValueError, ZeroDivisionError): + return 30.0 # Default + + +def analyze_metadata(path: Path) -> VideoFeatures: + """ + Extract video metadata using ffprobe. + + Args: + path: Path to video file + + Returns: + VideoFeatures with basic metadata + """ + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_streams", + "-show_format", + "-select_streams", "v:0", + str(path) + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data = json.loads(result.stdout) + except (subprocess.CalledProcessError, json.JSONDecodeError) as e: + raise ValueError(f"Could not read video info: {e}") + + if not data.get("streams"): + raise ValueError("No video stream found") + + stream = data["streams"][0] + fmt = data.get("format", {}) + + # Get duration from format or stream + duration = float(fmt.get("duration", stream.get("duration", 0))) + + # Parse frame rate + frame_rate = _parse_frame_rate(stream.get("avg_frame_rate", "30")) + + return VideoFeatures( + duration=duration, + frame_rate=frame_rate, + width=int(stream.get("width", 0)), + height=int(stream.get("height", 0)), + codec=stream.get("codec_name", ""), + ) + + +def analyze_scene_changes(path: Path, threshold: float = 0.3) -> List[float]: + """ + Detect scene changes using ffmpeg scene detection. + + Args: + path: Path to video file + threshold: Scene change threshold (0-1, lower = more sensitive) + + Returns: + List of scene change times in seconds + """ + cmd = [ + "ffmpeg", "-i", str(path), + "-vf", f"select='gt(scene,{threshold})',showinfo", + "-f", "null", "-" + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True) + stderr = result.stderr + except subprocess.CalledProcessError as e: + logger.warning(f"Scene detection failed: {e}") + return [] + + # Parse scene change times from ffmpeg output + scene_times = [] + for line in stderr.split("\n"): + if "pts_time:" in line: + try: + # Extract pts_time value + for part in line.split(): + if part.startswith("pts_time:"): + time_str = part.split(":")[1] + scene_times.append(float(time_str)) + break + except (ValueError, IndexError): + continue + + return scene_times + + +def analyze_motion_tempo(path: Path, sample_duration: float = 30.0) -> Optional[float]: + """ + Estimate tempo from video motion periodicity. + + Analyzes optical flow or frame differences to detect rhythmic motion. + This is useful for matching video speed to audio tempo. + + Args: + path: Path to video file + sample_duration: Duration to analyze (seconds) + + Returns: + Estimated motion tempo in BPM, or None if not detectable + """ + try: + import cv2 + import numpy as np + except ImportError: + logger.warning("OpenCV not available, skipping motion tempo analysis") + return None + + cap = cv2.VideoCapture(str(path)) + if not cap.isOpened(): + logger.warning(f"Could not open video: {path}") + return None + + try: + fps = cap.get(cv2.CAP_PROP_FPS) + if fps <= 0: + fps = 30.0 + + max_frames = int(sample_duration * fps) + frame_diffs = [] + prev_gray = None + + frame_count = 0 + while frame_count < max_frames: + ret, frame = cap.read() + if not ret: + break + + # Convert to grayscale and resize for speed + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + gray = cv2.resize(gray, (160, 90)) + + if prev_gray is not None: + # Calculate frame difference + diff = cv2.absdiff(gray, prev_gray) + frame_diffs.append(np.mean(diff)) + + prev_gray = gray + frame_count += 1 + + if len(frame_diffs) < 60: # Need at least 2 seconds at 30fps + return None + + # Convert to numpy array + motion = np.array(frame_diffs) + + # Normalize + motion = motion - motion.mean() + if motion.std() > 0: + motion = motion / motion.std() + + # Autocorrelation to find periodicity + n = len(motion) + acf = np.correlate(motion, motion, mode="full")[n-1:] + acf = acf / acf[0] # Normalize + + # Find peaks in autocorrelation (potential beat periods) + # Look for periods between 0.3s (200 BPM) and 2s (30 BPM) + min_lag = int(0.3 * fps) + max_lag = min(int(2.0 * fps), len(acf) - 1) + + if max_lag <= min_lag: + return None + + # Find the highest peak in the valid range + search_range = acf[min_lag:max_lag] + if len(search_range) == 0: + return None + + peak_idx = np.argmax(search_range) + min_lag + peak_value = acf[peak_idx] + + # Only report if peak is significant + if peak_value < 0.1: + return None + + # Convert lag to BPM + period_seconds = peak_idx / fps + bpm = 60.0 / period_seconds + + # Sanity check + if 30 <= bpm <= 200: + return round(bpm, 1) + + return None + + finally: + cap.release() + + +def analyze_video( + path: Path, + features: Optional[List[str]] = None, +) -> VideoFeatures: + """ + Extract video features from file. + + Args: + path: Path to video file + features: List of features to extract. Options: + - "metadata": Basic video info (always included) + - "motion_tempo": Estimated tempo from motion + - "scene_changes": Scene change detection + - "all": All features + + Returns: + VideoFeatures with requested analysis + """ + if features is None: + features = [FEATURE_METADATA] + + if FEATURE_ALL in features: + features = [FEATURE_METADATA, FEATURE_MOTION_TEMPO, FEATURE_SCENE_CHANGES] + + # Basic metadata is always extracted + result = analyze_metadata(path) + + if FEATURE_MOTION_TEMPO in features: + try: + result.motion_tempo = analyze_motion_tempo(path) + except Exception as e: + logger.warning(f"Motion tempo analysis failed: {e}") + + if FEATURE_SCENE_CHANGES in features: + try: + result.scene_changes = analyze_scene_changes(path) + except Exception as e: + logger.warning(f"Scene change detection failed: {e}") + + return result diff --git a/artdag/cache.py b/artdag/cache.py new file mode 100644 index 0000000..6012dba --- /dev/null +++ b/artdag/cache.py @@ -0,0 +1,464 @@ +# primitive/cache.py +""" +Content-addressed file cache for node outputs. + +Each node's output is stored at: cache_dir / node_id / output_file +This enables automatic reuse when the same operation is requested. +""" + +import json +import logging +import shutil +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +def _file_hash(path: Path, algorithm: str = "sha3_256") -> str: + """ + Compute content hash of a file. + + Uses SHA-3 (Keccak) by default for quantum resistance. + """ + import hashlib + hasher = hashlib.new(algorithm) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class CacheEntry: + """Metadata about a cached output.""" + node_id: str + output_path: Path + created_at: float + size_bytes: int + node_type: str + cid: str = "" # Content identifier (IPFS CID or local hash) + execution_time: float = 0.0 + + def to_dict(self) -> Dict: + return { + "node_id": self.node_id, + "output_path": str(self.output_path), + "created_at": self.created_at, + "size_bytes": self.size_bytes, + "node_type": self.node_type, + "cid": self.cid, + "execution_time": self.execution_time, + } + + @classmethod + def from_dict(cls, data: Dict) -> "CacheEntry": + # Support both "cid" and legacy "content_hash" + cid = data.get("cid") or data.get("content_hash", "") + return cls( + node_id=data["node_id"], + output_path=Path(data["output_path"]), + created_at=data["created_at"], + size_bytes=data["size_bytes"], + node_type=data["node_type"], + cid=cid, + execution_time=data.get("execution_time", 0.0), + ) + + +@dataclass +class CacheStats: + """Statistics about cache usage.""" + total_entries: int = 0 + total_size_bytes: int = 0 + hits: int = 0 + misses: int = 0 + hit_rate: float = 0.0 + + def record_hit(self): + self.hits += 1 + self._update_rate() + + def record_miss(self): + self.misses += 1 + self._update_rate() + + def _update_rate(self): + total = self.hits + self.misses + self.hit_rate = self.hits / total if total > 0 else 0.0 + + +class Cache: + """ + Code-addressed file cache. + + The filesystem IS the index - no JSON index files needed. + Each node's hash is its directory name. + + Structure: + cache_dir/ + / + output.ext # Actual output file + metadata.json # Per-node metadata (optional) + """ + + def __init__(self, cache_dir: Path | str): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.stats = CacheStats() + + def _node_dir(self, node_id: str) -> Path: + """Get the cache directory for a node.""" + return self.cache_dir / node_id + + def _find_output_file(self, node_dir: Path) -> Optional[Path]: + """Find the output file in a node directory.""" + if not node_dir.exists() or not node_dir.is_dir(): + return None + for f in node_dir.iterdir(): + if f.is_file() and f.name.startswith("output."): + return f + return None + + def get(self, node_id: str) -> Optional[Path]: + """ + Get cached output path for a node. + + Checks filesystem directly - no in-memory index. + Returns the output path if cached, None otherwise. + """ + node_dir = self._node_dir(node_id) + output_file = self._find_output_file(node_dir) + + if output_file: + self.stats.record_hit() + logger.debug(f"Cache hit: {node_id[:16]}...") + return output_file + + self.stats.record_miss() + return None + + def put(self, node_id: str, source_path: Path, node_type: str, + execution_time: float = 0.0, move: bool = False) -> Path: + """ + Store a file in the cache. + + Args: + node_id: The code-addressed node ID (hash) + source_path: Path to the file to cache + node_type: Type of the node (for metadata) + execution_time: How long the node took to execute + move: If True, move the file instead of copying + + Returns: + Path to the cached file + """ + node_dir = self._node_dir(node_id) + node_dir.mkdir(parents=True, exist_ok=True) + + # Preserve extension + ext = source_path.suffix or ".out" + output_path = node_dir / f"output{ext}" + + # Copy or move file (skip if already in place) + source_resolved = Path(source_path).resolve() + output_resolved = output_path.resolve() + if source_resolved != output_resolved: + if move: + shutil.move(source_path, output_path) + else: + shutil.copy2(source_path, output_path) + + # Compute content hash (IPFS CID of the result) + cid = _file_hash(output_path) + + # Store per-node metadata (optional, for stats/debugging) + metadata = { + "node_id": node_id, + "output_path": str(output_path), + "created_at": time.time(), + "size_bytes": output_path.stat().st_size, + "node_type": node_type, + "cid": cid, + "execution_time": execution_time, + } + metadata_path = node_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.debug(f"Cached: {node_id[:16]}... ({metadata['size_bytes']} bytes)") + return output_path + + def has(self, node_id: str) -> bool: + """Check if a node is cached (without affecting stats).""" + return self._find_output_file(self._node_dir(node_id)) is not None + + def remove(self, node_id: str) -> bool: + """Remove a node from the cache.""" + node_dir = self._node_dir(node_id) + if node_dir.exists(): + shutil.rmtree(node_dir) + return True + return False + + def clear(self): + """Clear all cached entries.""" + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + shutil.rmtree(node_dir) + self.stats = CacheStats() + + def get_stats(self) -> CacheStats: + """Get cache statistics (scans filesystem).""" + stats = CacheStats() + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + output_file = self._find_output_file(node_dir) + if output_file: + stats.total_entries += 1 + stats.total_size_bytes += output_file.stat().st_size + stats.hits = self.stats.hits + stats.misses = self.stats.misses + stats.hit_rate = self.stats.hit_rate + return stats + + def list_entries(self) -> List[CacheEntry]: + """List all cache entries (scans filesystem).""" + entries = [] + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + entry = self._load_entry_from_disk(node_dir.name) + if entry: + entries.append(entry) + return entries + + def _load_entry_from_disk(self, node_id: str) -> Optional[CacheEntry]: + """Load entry metadata from disk.""" + node_dir = self._node_dir(node_id) + metadata_path = node_dir / "metadata.json" + output_file = self._find_output_file(node_dir) + + if not output_file: + return None + + if metadata_path.exists(): + try: + with open(metadata_path) as f: + data = json.load(f) + return CacheEntry.from_dict(data) + except (json.JSONDecodeError, KeyError): + pass + + # Fallback: create entry from filesystem + return CacheEntry( + node_id=node_id, + output_path=output_file, + created_at=output_file.stat().st_mtime, + size_bytes=output_file.stat().st_size, + node_type="unknown", + cid=_file_hash(output_file), + ) + + def get_entry(self, node_id: str) -> Optional[CacheEntry]: + """Get cache entry metadata (without affecting stats).""" + return self._load_entry_from_disk(node_id) + + def find_by_cid(self, cid: str) -> Optional[CacheEntry]: + """Find a cache entry by its content hash (scans filesystem).""" + for entry in self.list_entries(): + if entry.cid == cid: + return entry + return None + + def prune(self, max_size_bytes: int = None, max_age_seconds: float = None) -> int: + """ + Prune cache based on size or age. + + Args: + max_size_bytes: Remove oldest entries until under this size + max_age_seconds: Remove entries older than this + + Returns: + Number of entries removed + """ + removed = 0 + now = time.time() + entries = self.list_entries() + + # Remove by age first + if max_age_seconds is not None: + for entry in entries: + if now - entry.created_at > max_age_seconds: + self.remove(entry.node_id) + removed += 1 + + # Then by size (remove oldest first) + if max_size_bytes is not None: + stats = self.get_stats() + if stats.total_size_bytes > max_size_bytes: + sorted_entries = sorted(entries, key=lambda e: e.created_at) + total_size = stats.total_size_bytes + for entry in sorted_entries: + if total_size <= max_size_bytes: + break + self.remove(entry.node_id) + total_size -= entry.size_bytes + removed += 1 + + return removed + + def get_output_path(self, node_id: str, extension: str = ".mkv") -> Path: + """Get the output path for a node (creates directory if needed).""" + node_dir = self._node_dir(node_id) + node_dir.mkdir(parents=True, exist_ok=True) + return node_dir / f"output{extension}" + + # Effect storage methods + + def _effects_dir(self) -> Path: + """Get the effects subdirectory.""" + effects_dir = self.cache_dir / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + return effects_dir + + def store_effect(self, source: str) -> str: + """ + Store an effect in the cache. + + Args: + source: Effect source code + + Returns: + Content hash (cache ID) of the effect + """ + import hashlib as _hashlib + + # Compute content hash + cid = _hashlib.sha3_256(source.encode("utf-8")).hexdigest() + + # Try to load full metadata if effects module available + try: + from .effects.loader import load_effect + loaded = load_effect(source) + meta_dict = loaded.meta.to_dict() + dependencies = loaded.dependencies + requires_python = loaded.requires_python + except ImportError: + # Fallback: store without parsed metadata + meta_dict = {} + dependencies = [] + requires_python = ">=3.10" + + effect_dir = self._effects_dir() / cid + effect_dir.mkdir(parents=True, exist_ok=True) + + # Store source + source_path = effect_dir / "effect.py" + source_path.write_text(source, encoding="utf-8") + + # Store metadata + metadata = { + "cid": cid, + "meta": meta_dict, + "dependencies": dependencies, + "requires_python": requires_python, + "stored_at": time.time(), + } + metadata_path = effect_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Stored effect '{loaded.meta.name}' with hash {cid[:16]}...") + return cid + + def get_effect(self, cid: str) -> Optional[str]: + """ + Get effect source by content hash. + + Args: + cid: SHA3-256 hash of effect source + + Returns: + Effect source code if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + source_path = effect_dir / "effect.py" + + if not source_path.exists(): + return None + + return source_path.read_text(encoding="utf-8") + + def get_effect_path(self, cid: str) -> Optional[Path]: + """ + Get path to effect source file. + + Args: + cid: SHA3-256 hash of effect source + + Returns: + Path to effect.py if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + source_path = effect_dir / "effect.py" + + if not source_path.exists(): + return None + + return source_path + + def get_effect_metadata(self, cid: str) -> Optional[dict]: + """ + Get effect metadata by content hash. + + Args: + cid: SHA3-256 hash of effect source + + Returns: + Metadata dict if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + metadata_path = effect_dir / "metadata.json" + + if not metadata_path.exists(): + return None + + try: + with open(metadata_path) as f: + return json.load(f) + except (json.JSONDecodeError, KeyError): + return None + + def has_effect(self, cid: str) -> bool: + """Check if an effect is cached.""" + effect_dir = self._effects_dir() / cid + return (effect_dir / "effect.py").exists() + + def list_effects(self) -> List[dict]: + """List all cached effects with their metadata.""" + effects = [] + effects_dir = self._effects_dir() + + if not effects_dir.exists(): + return effects + + for effect_dir in effects_dir.iterdir(): + if effect_dir.is_dir(): + metadata = self.get_effect_metadata(effect_dir.name) + if metadata: + effects.append(metadata) + + return effects + + def remove_effect(self, cid: str) -> bool: + """Remove an effect from the cache.""" + effect_dir = self._effects_dir() / cid + + if not effect_dir.exists(): + return False + + shutil.rmtree(effect_dir) + logger.info(f"Removed effect {cid[:16]}...") + return True diff --git a/artdag/cli.py b/artdag/cli.py new file mode 100644 index 0000000..9aa5c8c --- /dev/null +++ b/artdag/cli.py @@ -0,0 +1,724 @@ +#!/usr/bin/env python3 +""" +Art DAG CLI + +Command-line interface for the 3-phase execution model: + artdag analyze - Extract features from inputs + artdag plan - Generate execution plan + artdag execute - Run the plan + artdag run-recipe - Full pipeline + +Usage: + artdag analyze -i :[@] [--features ] + artdag plan -i : [--analysis ] + artdag execute [--dry-run] + artdag run-recipe -i :[@] +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +def parse_input(input_str: str) -> Tuple[str, str, Optional[str]]: + """ + Parse input specification: name:hash[@path] + + Returns (name, hash, path or None) + """ + if "@" in input_str: + name_hash, path = input_str.rsplit("@", 1) + else: + name_hash = input_str + path = None + + if ":" not in name_hash: + raise ValueError(f"Invalid input format: {input_str}. Expected name:hash[@path]") + + name, hash_value = name_hash.split(":", 1) + return name, hash_value, path + + +def parse_inputs(input_list: List[str]) -> Tuple[Dict[str, str], Dict[str, str]]: + """ + Parse list of input specifications. + + Returns (input_hashes, input_paths) + """ + input_hashes = {} + input_paths = {} + + for input_str in input_list: + name, hash_value, path = parse_input(input_str) + input_hashes[name] = hash_value + if path: + input_paths[name] = path + + return input_hashes, input_paths + + +def cmd_analyze(args): + """Run analysis phase.""" + from .analysis import Analyzer + + # Parse inputs + input_hashes, input_paths = parse_inputs(args.input) + + # Parse features + features = args.features.split(",") if args.features else ["all"] + + # Create analyzer + cache_dir = Path(args.cache_dir) if args.cache_dir else Path("./analysis_cache") + analyzer = Analyzer(cache_dir=cache_dir) + + # Analyze each input + results = {} + for name, hash_value in input_hashes.items(): + path = input_paths.get(name) + if path: + path = Path(path) + + print(f"Analyzing {name} ({hash_value[:16]}...)...") + + result = analyzer.analyze( + input_hash=hash_value, + features=features, + input_path=path, + ) + + results[hash_value] = result.to_dict() + + # Print summary + if result.audio and result.audio.beats: + print(f" Tempo: {result.audio.beats.tempo:.1f} BPM") + print(f" Beats: {len(result.audio.beats.beat_times)}") + if result.video: + print(f" Duration: {result.video.duration:.1f}s") + print(f" Dimensions: {result.video.width}x{result.video.height}") + + # Write output + output_path = Path(args.output) if args.output else Path("analysis.json") + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + + print(f"\nAnalysis saved to: {output_path}") + + +def cmd_plan(args): + """Run planning phase.""" + from .analysis import AnalysisResult + from .planning import RecipePlanner, Recipe + + # Load recipe + recipe = Recipe.from_file(Path(args.recipe)) + print(f"Recipe: {recipe.name} v{recipe.version}") + + # Parse inputs + input_hashes, _ = parse_inputs(args.input) + + # Load analysis if provided + analysis = {} + if args.analysis: + with open(args.analysis, "r") as f: + analysis_data = json.load(f) + for hash_value, data in analysis_data.items(): + analysis[hash_value] = AnalysisResult.from_dict(data) + + # Create planner + planner = RecipePlanner(use_tree_reduction=not args.no_tree_reduction) + + # Generate plan + print("Generating execution plan...") + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + analysis=analysis, + ) + + # Print summary + print(f"\nPlan ID: {plan.plan_id[:16]}...") + print(f"Steps: {len(plan.steps)}") + + steps_by_level = plan.get_steps_by_level() + max_level = max(steps_by_level.keys()) if steps_by_level else 0 + print(f"Levels: {max_level + 1}") + + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f" Level {level}: {len(steps)} steps (parallel)") + + # Write output + output_path = Path(args.output) if args.output else Path("plan.json") + with open(output_path, "w") as f: + f.write(plan.to_json()) + + print(f"\nPlan saved to: {output_path}") + + +def cmd_execute(args): + """Run execution phase.""" + from .planning import ExecutionPlan + from .cache import Cache + from .executor import get_executor + from .dag import NodeType + from . import nodes # Register built-in executors + + # Load plan + with open(args.plan, "r") as f: + plan = ExecutionPlan.from_json(f.read()) + + print(f"Executing plan: {plan.plan_id[:16]}...") + print(f"Steps: {len(plan.steps)}") + + if args.dry_run: + print("\n=== DRY RUN ===") + + # Check cache status + cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache")) + steps_by_level = plan.get_steps_by_level() + + cached_count = 0 + pending_count = 0 + + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"\nLevel {level}:") + for step in steps: + if cache.has(step.cache_id): + print(f" [CACHED] {step.step_id}: {step.node_type}") + cached_count += 1 + else: + print(f" [PENDING] {step.step_id}: {step.node_type}") + pending_count += 1 + + print(f"\nSummary: {cached_count} cached, {pending_count} pending") + return + + # Execute locally (for testing - production uses Celery) + cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache")) + + cache_paths = {} + for name, hash_value in plan.input_hashes.items(): + if cache.has(hash_value): + entry = cache.get(hash_value) + cache_paths[hash_value] = str(entry.output_path) + + steps_by_level = plan.get_steps_by_level() + executed = 0 + cached = 0 + + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"\nLevel {level}: {len(steps)} steps") + + for step in steps: + if cache.has(step.cache_id): + cached_path = cache.get(step.cache_id) + cache_paths[step.cache_id] = str(cached_path) + cache_paths[step.step_id] = str(cached_path) + print(f" [CACHED] {step.step_id}") + cached += 1 + continue + + print(f" [RUNNING] {step.step_id}: {step.node_type}...") + + # Get executor + try: + node_type = NodeType[step.node_type] + except KeyError: + node_type = step.node_type + + executor = get_executor(node_type) + if executor is None: + print(f" ERROR: No executor for {step.node_type}") + continue + + # Resolve inputs + input_paths = [] + for input_id in step.input_steps: + if input_id in cache_paths: + input_paths.append(Path(cache_paths[input_id])) + else: + input_step = plan.get_step(input_id) + if input_step and input_step.cache_id in cache_paths: + input_paths.append(Path(cache_paths[input_step.cache_id])) + + if len(input_paths) != len(step.input_steps): + print(f" ERROR: Missing inputs") + continue + + # Execute + output_path = cache.get_output_path(step.cache_id) + try: + result_path = executor.execute(step.config, input_paths, output_path) + cache.put(step.cache_id, result_path, node_type=step.node_type) + cache_paths[step.cache_id] = str(result_path) + cache_paths[step.step_id] = str(result_path) + print(f" [DONE] -> {result_path}") + executed += 1 + except Exception as e: + print(f" [FAILED] {e}") + + # Final output + output_step = plan.get_step(plan.output_step) + output_path = cache_paths.get(output_step.cache_id) if output_step else None + + print(f"\n=== Complete ===") + print(f"Cached: {cached}") + print(f"Executed: {executed}") + if output_path: + print(f"Output: {output_path}") + + +def cmd_run_recipe(args): + """Run complete pipeline: analyze → plan → execute.""" + from .analysis import Analyzer, AnalysisResult + from .planning import RecipePlanner, Recipe + from .cache import Cache + from .executor import get_executor + from .dag import NodeType + from . import nodes # Register built-in executors + + # Load recipe + recipe = Recipe.from_file(Path(args.recipe)) + print(f"Recipe: {recipe.name} v{recipe.version}") + + # Parse inputs + input_hashes, input_paths = parse_inputs(args.input) + + # Parse features + features = args.features.split(",") if args.features else ["beats", "energy"] + + cache_dir = Path(args.cache_dir) if args.cache_dir else Path("./cache") + + # Phase 1: Analyze + print("\n=== Phase 1: Analysis ===") + analyzer = Analyzer(cache_dir=cache_dir / "analysis") + + analysis = {} + for name, hash_value in input_hashes.items(): + path = input_paths.get(name) + if path: + path = Path(path) + print(f"Analyzing {name}...") + + result = analyzer.analyze( + input_hash=hash_value, + features=features, + input_path=path, + ) + analysis[hash_value] = result + + if result.audio and result.audio.beats: + print(f" Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats") + + # Phase 2: Plan + print("\n=== Phase 2: Planning ===") + + # Check for cached plan + plans_dir = cache_dir / "plans" + plans_dir.mkdir(parents=True, exist_ok=True) + + # Generate plan to get plan_id (deterministic hash) + planner = RecipePlanner(use_tree_reduction=True) + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + analysis=analysis, + ) + + plan_cache_path = plans_dir / f"{plan.plan_id}.json" + + if plan_cache_path.exists(): + print(f"Plan cached: {plan.plan_id[:16]}...") + from .planning import ExecutionPlan + with open(plan_cache_path, "r") as f: + plan = ExecutionPlan.from_json(f.read()) + else: + # Save plan to cache + with open(plan_cache_path, "w") as f: + f.write(plan.to_json()) + print(f"Plan saved: {plan.plan_id[:16]}...") + + print(f"Plan: {len(plan.steps)} steps") + steps_by_level = plan.get_steps_by_level() + print(f"Levels: {len(steps_by_level)}") + + # Phase 3: Execute + print("\n=== Phase 3: Execution ===") + + cache = Cache(cache_dir) + + # Build initial cache paths + cache_paths = {} + for name, hash_value in input_hashes.items(): + path = input_paths.get(name) + if path: + cache_paths[hash_value] = path + cache_paths[name] = path + + executed = 0 + cached = 0 + + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"\nLevel {level}: {len(steps)} steps") + + for step in steps: + if cache.has(step.cache_id): + cached_path = cache.get(step.cache_id) + cache_paths[step.cache_id] = str(cached_path) + cache_paths[step.step_id] = str(cached_path) + print(f" [CACHED] {step.step_id}") + cached += 1 + continue + + # Handle SOURCE specially + if step.node_type == "SOURCE": + cid = step.config.get("cid") + if cid in cache_paths: + cache_paths[step.cache_id] = cache_paths[cid] + cache_paths[step.step_id] = cache_paths[cid] + print(f" [SOURCE] {step.step_id}") + continue + + print(f" [RUNNING] {step.step_id}: {step.node_type}...") + + try: + node_type = NodeType[step.node_type] + except KeyError: + node_type = step.node_type + + executor = get_executor(node_type) + if executor is None: + print(f" SKIP: No executor for {step.node_type}") + continue + + # Resolve inputs + input_paths_list = [] + for input_id in step.input_steps: + if input_id in cache_paths: + input_paths_list.append(Path(cache_paths[input_id])) + else: + input_step = plan.get_step(input_id) + if input_step and input_step.cache_id in cache_paths: + input_paths_list.append(Path(cache_paths[input_step.cache_id])) + + if len(input_paths_list) != len(step.input_steps): + print(f" ERROR: Missing inputs for {step.step_id}") + continue + + output_path = cache.get_output_path(step.cache_id) + try: + result_path = executor.execute(step.config, input_paths_list, output_path) + cache.put(step.cache_id, result_path, node_type=step.node_type) + cache_paths[step.cache_id] = str(result_path) + cache_paths[step.step_id] = str(result_path) + print(f" [DONE]") + executed += 1 + except Exception as e: + print(f" [FAILED] {e}") + + # Final output + output_step = plan.get_step(plan.output_step) + output_path = cache_paths.get(output_step.cache_id) if output_step else None + + print(f"\n=== Complete ===") + print(f"Cached: {cached}") + print(f"Executed: {executed}") + if output_path: + print(f"Output: {output_path}") + + +def cmd_run_recipe_ipfs(args): + """Run complete pipeline with IPFS-primary mode. + + Everything stored on IPFS: + - Inputs (media files) + - Analysis results (JSON) + - Execution plans (JSON) + - Step outputs (media files) + """ + import hashlib + import shutil + import tempfile + + from .analysis import Analyzer, AnalysisResult + from .planning import RecipePlanner, Recipe, ExecutionPlan + from .executor import get_executor + from .dag import NodeType + from . import nodes # Register built-in executors + + # Check for ipfs_client + try: + from art_celery import ipfs_client + except ImportError: + # Try relative import for when running from art-celery + try: + import ipfs_client + except ImportError: + print("Error: ipfs_client not available. Install art-celery or run from art-celery directory.") + sys.exit(1) + + # Check IPFS availability + if not ipfs_client.is_available(): + print("Error: IPFS daemon not available. Start IPFS with 'ipfs daemon'") + sys.exit(1) + + print("=== IPFS-Primary Mode ===") + print(f"IPFS Node: {ipfs_client.get_node_id()[:16]}...") + + # Load recipe + recipe_path = Path(args.recipe) + recipe = Recipe.from_file(recipe_path) + print(f"\nRecipe: {recipe.name} v{recipe.version}") + + # Parse inputs + input_hashes, input_paths = parse_inputs(args.input) + + # Parse features + features = args.features.split(",") if args.features else ["beats", "energy"] + + # Phase 0: Register on IPFS + print("\n=== Phase 0: Register on IPFS ===") + + # Register recipe + recipe_bytes = recipe_path.read_bytes() + recipe_cid = ipfs_client.add_bytes(recipe_bytes) + print(f"Recipe CID: {recipe_cid}") + + # Register inputs + input_cids = {} + for name, hash_value in input_hashes.items(): + path = input_paths.get(name) + if path: + cid = ipfs_client.add_file(Path(path)) + if cid: + input_cids[name] = cid + print(f"Input '{name}': {cid}") + else: + print(f"Error: Failed to add input '{name}' to IPFS") + sys.exit(1) + + # Phase 1: Analyze + print("\n=== Phase 1: Analysis ===") + + # Create temp dir for analysis + work_dir = Path(tempfile.mkdtemp(prefix="artdag_ipfs_")) + analysis_cids = {} + analysis = {} + + try: + for name, hash_value in input_hashes.items(): + input_cid = input_cids.get(name) + if not input_cid: + continue + + print(f"Analyzing {name}...") + + # Fetch from IPFS to temp + temp_input = work_dir / f"input_{name}.mkv" + if not ipfs_client.get_file(input_cid, temp_input): + print(f" Error: Failed to fetch from IPFS") + continue + + # Run analysis + analyzer = Analyzer(cache_dir=None) + result = analyzer.analyze( + input_hash=hash_value, + features=features, + input_path=temp_input, + ) + + if result.audio and result.audio.beats: + print(f" Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats") + + # Store analysis on IPFS + analysis_cid = ipfs_client.add_json(result.to_dict()) + if analysis_cid: + analysis_cids[hash_value] = analysis_cid + analysis[hash_value] = result + print(f" Analysis CID: {analysis_cid}") + + # Phase 2: Plan + print("\n=== Phase 2: Planning ===") + + planner = RecipePlanner(use_tree_reduction=True) + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + analysis=analysis if analysis else None, + ) + + # Store plan on IPFS + import json + plan_dict = json.loads(plan.to_json()) + plan_cid = ipfs_client.add_json(plan_dict) + print(f"Plan ID: {plan.plan_id[:16]}...") + print(f"Plan CID: {plan_cid}") + print(f"Steps: {len(plan.steps)}") + + steps_by_level = plan.get_steps_by_level() + print(f"Levels: {len(steps_by_level)}") + + # Phase 3: Execute + print("\n=== Phase 3: Execution ===") + + # CID results + cid_results = dict(input_cids) + step_cids = {} + + executed = 0 + cached = 0 + + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"\nLevel {level}: {len(steps)} steps") + + for step in steps: + # Handle SOURCE + if step.node_type == "SOURCE": + source_name = step.config.get("name") or step.step_id + cid = cid_results.get(source_name) + if cid: + step_cids[step.step_id] = cid + print(f" [SOURCE] {step.step_id}") + continue + + print(f" [RUNNING] {step.step_id}: {step.node_type}...") + + try: + node_type = NodeType[step.node_type] + except KeyError: + node_type = step.node_type + + executor = get_executor(node_type) + if executor is None: + print(f" SKIP: No executor for {step.node_type}") + continue + + # Fetch inputs from IPFS + input_paths_list = [] + for i, input_step_id in enumerate(step.input_steps): + input_cid = step_cids.get(input_step_id) or cid_results.get(input_step_id) + if not input_cid: + print(f" ERROR: Missing input CID for {input_step_id}") + continue + + temp_path = work_dir / f"step_{step.step_id}_input_{i}.mkv" + if not ipfs_client.get_file(input_cid, temp_path): + print(f" ERROR: Failed to fetch {input_cid}") + continue + input_paths_list.append(temp_path) + + if len(input_paths_list) != len(step.input_steps): + print(f" ERROR: Missing inputs") + continue + + # Execute + output_path = work_dir / f"step_{step.step_id}_output.mkv" + try: + result_path = executor.execute(step.config, input_paths_list, output_path) + + # Add to IPFS + output_cid = ipfs_client.add_file(result_path) + if output_cid: + step_cids[step.step_id] = output_cid + print(f" [DONE] CID: {output_cid}") + executed += 1 + else: + print(f" [FAILED] Could not add to IPFS") + except Exception as e: + print(f" [FAILED] {e}") + + # Final output + output_step = plan.get_step(plan.output_step) + output_cid = step_cids.get(output_step.step_id) if output_step else None + + print(f"\n=== Complete ===") + print(f"Executed: {executed}") + if output_cid: + print(f"Output CID: {output_cid}") + print(f"Fetch with: ipfs get {output_cid}") + + # Summary of all CIDs + print(f"\n=== All CIDs ===") + print(f"Recipe: {recipe_cid}") + print(f"Plan: {plan_cid}") + for name, cid in input_cids.items(): + print(f"Input '{name}': {cid}") + for hash_val, cid in analysis_cids.items(): + print(f"Analysis '{hash_val[:16]}...': {cid}") + if output_cid: + print(f"Output: {output_cid}") + + finally: + # Cleanup temp + shutil.rmtree(work_dir, ignore_errors=True) + + +def main(): + parser = argparse.ArgumentParser( + prog="artdag", + description="Art DAG - Declarative media composition", + ) + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # analyze command + analyze_parser = subparsers.add_parser("analyze", help="Extract features from inputs") + analyze_parser.add_argument("recipe", help="Recipe YAML file") + analyze_parser.add_argument("-i", "--input", action="append", required=True, + help="Input: name:hash[@path]") + analyze_parser.add_argument("--features", help="Features to extract (comma-separated)") + analyze_parser.add_argument("-o", "--output", help="Output file (default: analysis.json)") + analyze_parser.add_argument("--cache-dir", help="Analysis cache directory") + + # plan command + plan_parser = subparsers.add_parser("plan", help="Generate execution plan") + plan_parser.add_argument("recipe", help="Recipe YAML file") + plan_parser.add_argument("-i", "--input", action="append", required=True, + help="Input: name:hash") + plan_parser.add_argument("--analysis", help="Analysis JSON file") + plan_parser.add_argument("-o", "--output", help="Output file (default: plan.json)") + plan_parser.add_argument("--no-tree-reduction", action="store_true", + help="Disable tree reduction optimization") + + # execute command + execute_parser = subparsers.add_parser("execute", help="Execute a plan") + execute_parser.add_argument("plan", help="Plan JSON file") + execute_parser.add_argument("--dry-run", action="store_true", + help="Show what would execute") + execute_parser.add_argument("--cache-dir", help="Cache directory") + + # run-recipe command + run_parser = subparsers.add_parser("run-recipe", help="Full pipeline: analyze → plan → execute") + run_parser.add_argument("recipe", help="Recipe YAML file") + run_parser.add_argument("-i", "--input", action="append", required=True, + help="Input: name:hash[@path]") + run_parser.add_argument("--features", help="Features to extract (comma-separated)") + run_parser.add_argument("--cache-dir", help="Cache directory") + run_parser.add_argument("--ipfs-primary", action="store_true", + help="Use IPFS-primary mode (everything on IPFS, no local cache)") + + args = parser.parse_args() + + if args.command == "analyze": + cmd_analyze(args) + elif args.command == "plan": + cmd_plan(args) + elif args.command == "execute": + cmd_execute(args) + elif args.command == "run-recipe": + if getattr(args, 'ipfs_primary', False): + cmd_run_recipe_ipfs(args) + else: + cmd_run_recipe(args) + else: + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/artdag/client.py b/artdag/client.py new file mode 100644 index 0000000..21a1ab5 --- /dev/null +++ b/artdag/client.py @@ -0,0 +1,201 @@ +# primitive/client.py +""" +Client SDK for the primitive execution server. + +Provides a simple API for submitting DAGs and retrieving results. + +Usage: + client = PrimitiveClient("http://localhost:8080") + + # Build a DAG + builder = DAGBuilder() + source = builder.source("/path/to/video.mp4") + segment = builder.segment(source, duration=5.0) + builder.set_output(segment) + dag = builder.build() + + # Execute and wait for result + result = client.execute(dag) + print(f"Output: {result.output_path}") +""" + +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional +from urllib.request import urlopen, Request +from urllib.error import HTTPError, URLError + +from .dag import DAG, DAGBuilder + + +@dataclass +class ExecutionResult: + """Result from server execution.""" + success: bool + output_path: Optional[Path] = None + error: Optional[str] = None + execution_time: float = 0.0 + nodes_executed: int = 0 + nodes_cached: int = 0 + + +@dataclass +class CacheStats: + """Cache statistics from server.""" + total_entries: int = 0 + total_size_bytes: int = 0 + hits: int = 0 + misses: int = 0 + hit_rate: float = 0.0 + + +class PrimitiveClient: + """ + Client for the primitive execution server. + + Args: + base_url: Server URL (e.g., "http://localhost:8080") + timeout: Request timeout in seconds + """ + + def __init__(self, base_url: str = "http://localhost:8080", timeout: float = 300): + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + def _request(self, method: str, path: str, data: dict = None) -> dict: + """Make HTTP request to server.""" + url = f"{self.base_url}{path}" + + if data is not None: + body = json.dumps(data).encode() + headers = {"Content-Type": "application/json"} + else: + body = None + headers = {} + + req = Request(url, data=body, headers=headers, method=method) + + try: + with urlopen(req, timeout=self.timeout) as response: + return json.loads(response.read().decode()) + except HTTPError as e: + error_body = e.read().decode() + try: + error_data = json.loads(error_body) + raise RuntimeError(error_data.get("error", str(e))) + except json.JSONDecodeError: + raise RuntimeError(f"HTTP {e.code}: {error_body}") + except URLError as e: + raise ConnectionError(f"Failed to connect to server: {e}") + + def health(self) -> bool: + """Check if server is healthy.""" + try: + result = self._request("GET", "/health") + return result.get("status") == "ok" + except Exception: + return False + + def submit(self, dag: DAG) -> str: + """ + Submit a DAG for execution. + + Args: + dag: The DAG to execute + + Returns: + Job ID for tracking + """ + result = self._request("POST", "/execute", dag.to_dict()) + return result["job_id"] + + def status(self, job_id: str) -> str: + """ + Get job status. + + Args: + job_id: Job ID from submit() + + Returns: + Status: "pending", "running", "completed", or "failed" + """ + result = self._request("GET", f"/status/{job_id}") + return result["status"] + + def result(self, job_id: str) -> Optional[ExecutionResult]: + """ + Get job result (non-blocking). + + Args: + job_id: Job ID from submit() + + Returns: + ExecutionResult if complete, None if still running + """ + data = self._request("GET", f"/result/{job_id}") + + if not data.get("ready", False): + return None + + return ExecutionResult( + success=data.get("success", False), + output_path=Path(data["output_path"]) if data.get("output_path") else None, + error=data.get("error"), + execution_time=data.get("execution_time", 0), + nodes_executed=data.get("nodes_executed", 0), + nodes_cached=data.get("nodes_cached", 0), + ) + + def wait(self, job_id: str, poll_interval: float = 0.5) -> ExecutionResult: + """ + Wait for job completion and return result. + + Args: + job_id: Job ID from submit() + poll_interval: Seconds between status checks + + Returns: + ExecutionResult + """ + while True: + result = self.result(job_id) + if result is not None: + return result + time.sleep(poll_interval) + + def execute(self, dag: DAG, poll_interval: float = 0.5) -> ExecutionResult: + """ + Submit DAG and wait for result. + + Convenience method combining submit() and wait(). + + Args: + dag: The DAG to execute + poll_interval: Seconds between status checks + + Returns: + ExecutionResult + """ + job_id = self.submit(dag) + return self.wait(job_id, poll_interval) + + def cache_stats(self) -> CacheStats: + """Get cache statistics.""" + data = self._request("GET", "/cache/stats") + return CacheStats( + total_entries=data.get("total_entries", 0), + total_size_bytes=data.get("total_size_bytes", 0), + hits=data.get("hits", 0), + misses=data.get("misses", 0), + hit_rate=data.get("hit_rate", 0.0), + ) + + def clear_cache(self) -> None: + """Clear the server cache.""" + self._request("DELETE", "/cache") + + +# Re-export DAGBuilder for convenience +__all__ = ["PrimitiveClient", "ExecutionResult", "CacheStats", "DAGBuilder"] diff --git a/artdag/dag.py b/artdag/dag.py new file mode 100644 index 0000000..735b7a2 --- /dev/null +++ b/artdag/dag.py @@ -0,0 +1,344 @@ +# primitive/dag.py +""" +Core DAG data structures. + +Nodes are content-addressed: node_id = hash(type + config + input_ids) +This enables automatic caching and deduplication. +""" + +import hashlib +import json +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any, Dict, List, Optional + + +class NodeType(Enum): + """Built-in node types.""" + # Source operations + SOURCE = auto() # Load file from path + + # Transform operations + SEGMENT = auto() # Extract time range + RESIZE = auto() # Scale/crop/pad + TRANSFORM = auto() # Visual effects (color, blur, etc.) + + # Compose operations + SEQUENCE = auto() # Concatenate in time + LAYER = auto() # Stack spatially (overlay) + MUX = auto() # Combine video + audio streams + BLEND = auto() # Blend two inputs + AUDIO_MIX = auto() # Mix multiple audio streams + SWITCH = auto() # Time-based input switching + + # Analysis operations + ANALYZE = auto() # Extract features (audio, motion, etc.) + + # Generation operations + GENERATE = auto() # Create content (text, graphics, etc.) + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """ + Create stable hash from arbitrary data. + + Uses SHA-3 (Keccak) for quantum resistance. + Returns full hash - no truncation. + + Args: + data: Data to hash (will be JSON serialized) + algorithm: Hash algorithm (default: sha3_256) + + Returns: + Full hex digest + """ + # Convert to JSON with sorted keys for stability + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class Node: + """ + A node in the execution DAG. + + Attributes: + node_type: The operation type (NodeType enum or string for custom types) + config: Operation-specific configuration + inputs: List of input node IDs (resolved during execution) + node_id: Content-addressed ID (computed from type + config + inputs) + name: Optional human-readable name for debugging + """ + node_type: NodeType | str + config: Dict[str, Any] = field(default_factory=dict) + inputs: List[str] = field(default_factory=list) + node_id: Optional[str] = None + name: Optional[str] = None + + def __post_init__(self): + """Compute node_id if not provided.""" + if self.node_id is None: + self.node_id = self._compute_id() + + def _compute_id(self) -> str: + """Compute content-addressed ID from node contents.""" + type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type) + content = { + "type": type_str, + "config": self.config, + "inputs": sorted(self.inputs), # Sort for stability + } + return _stable_hash(content) + + def to_dict(self) -> Dict[str, Any]: + """Serialize node to dictionary.""" + type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type) + return { + "node_id": self.node_id, + "node_type": type_str, + "config": self.config, + "inputs": self.inputs, + "name": self.name, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Node": + """Deserialize node from dictionary.""" + type_str = data["node_type"] + try: + node_type = NodeType[type_str] + except KeyError: + node_type = type_str # Custom type as string + + return cls( + node_type=node_type, + config=data.get("config", {}), + inputs=data.get("inputs", []), + node_id=data.get("node_id"), + name=data.get("name"), + ) + + +@dataclass +class DAG: + """ + A directed acyclic graph of nodes. + + Attributes: + nodes: Dictionary mapping node_id -> Node + output_id: The ID of the final output node + metadata: Optional metadata about the DAG (source, version, etc.) + """ + nodes: Dict[str, Node] = field(default_factory=dict) + output_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def add_node(self, node: Node) -> str: + """Add a node to the DAG, returning its ID.""" + if node.node_id in self.nodes: + # Node already exists (deduplication via content addressing) + return node.node_id + self.nodes[node.node_id] = node + return node.node_id + + def set_output(self, node_id: str) -> None: + """Set the output node.""" + if node_id not in self.nodes: + raise ValueError(f"Node {node_id} not in DAG") + self.output_id = node_id + + def get_node(self, node_id: str) -> Node: + """Get a node by ID.""" + if node_id not in self.nodes: + raise KeyError(f"Node {node_id} not found") + return self.nodes[node_id] + + def topological_order(self) -> List[str]: + """Return nodes in topological order (dependencies first).""" + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + visited.add(node_id) + node = self.nodes[node_id] + for input_id in node.inputs: + visit(input_id) + order.append(node_id) + + # Visit all nodes (not just output, in case of disconnected components) + for node_id in self.nodes: + visit(node_id) + + return order + + def validate(self) -> List[str]: + """Validate DAG structure. Returns list of errors (empty if valid).""" + errors = [] + + if self.output_id is None: + errors.append("No output node set") + elif self.output_id not in self.nodes: + errors.append(f"Output node {self.output_id} not in DAG") + + # Check all input references are valid + for node_id, node in self.nodes.items(): + for input_id in node.inputs: + if input_id not in self.nodes: + errors.append(f"Node {node_id} references missing input {input_id}") + + # Check for cycles (skip if we already found missing inputs) + if not any("missing" in e for e in errors): + try: + self.topological_order() + except (RecursionError, KeyError): + errors.append("DAG contains cycles or invalid references") + + return errors + + def to_dict(self) -> Dict[str, Any]: + """Serialize DAG to dictionary.""" + return { + "nodes": {nid: node.to_dict() for nid, node in self.nodes.items()}, + "output_id": self.output_id, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DAG": + """Deserialize DAG from dictionary.""" + dag = cls(metadata=data.get("metadata", {})) + for node_data in data.get("nodes", {}).values(): + dag.add_node(Node.from_dict(node_data)) + dag.output_id = data.get("output_id") + return dag + + def to_json(self) -> str: + """Serialize DAG to JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "DAG": + """Deserialize DAG from JSON string.""" + return cls.from_dict(json.loads(json_str)) + + +class DAGBuilder: + """ + Fluent builder for constructing DAGs. + + Example: + builder = DAGBuilder() + source = builder.source("/path/to/video.mp4") + segment = builder.segment(source, duration=5.0) + builder.set_output(segment) + dag = builder.build() + """ + + def __init__(self): + self.dag = DAG() + + def _add(self, node_type: NodeType | str, config: Dict[str, Any], + inputs: List[str] = None, name: str = None) -> str: + """Add a node and return its ID.""" + node = Node( + node_type=node_type, + config=config, + inputs=inputs or [], + name=name, + ) + return self.dag.add_node(node) + + # Source operations + + def source(self, path: str, name: str = None) -> str: + """Add a SOURCE node.""" + return self._add(NodeType.SOURCE, {"path": path}, name=name) + + # Transform operations + + def segment(self, input_id: str, duration: float = None, + offset: float = 0, precise: bool = True, name: str = None) -> str: + """Add a SEGMENT node.""" + config = {"offset": offset, "precise": precise} + if duration is not None: + config["duration"] = duration + return self._add(NodeType.SEGMENT, config, [input_id], name=name) + + def resize(self, input_id: str, width: int, height: int, + mode: str = "fit", name: str = None) -> str: + """Add a RESIZE node.""" + return self._add( + NodeType.RESIZE, + {"width": width, "height": height, "mode": mode}, + [input_id], + name=name + ) + + def transform(self, input_id: str, effects: Dict[str, Any], + name: str = None) -> str: + """Add a TRANSFORM node.""" + return self._add(NodeType.TRANSFORM, {"effects": effects}, [input_id], name=name) + + # Compose operations + + def sequence(self, input_ids: List[str], transition: Dict[str, Any] = None, + name: str = None) -> str: + """Add a SEQUENCE node.""" + config = {"transition": transition or {"type": "cut"}} + return self._add(NodeType.SEQUENCE, config, input_ids, name=name) + + def layer(self, input_ids: List[str], configs: List[Dict] = None, + name: str = None) -> str: + """Add a LAYER node.""" + return self._add( + NodeType.LAYER, + {"inputs": configs or [{}] * len(input_ids)}, + input_ids, + name=name + ) + + def mux(self, video_id: str, audio_id: str, shortest: bool = True, + name: str = None) -> str: + """Add a MUX node.""" + return self._add( + NodeType.MUX, + {"video_stream": 0, "audio_stream": 1, "shortest": shortest}, + [video_id, audio_id], + name=name + ) + + def blend(self, input1_id: str, input2_id: str, mode: str = "overlay", + opacity: float = 0.5, name: str = None) -> str: + """Add a BLEND node.""" + return self._add( + NodeType.BLEND, + {"mode": mode, "opacity": opacity}, + [input1_id, input2_id], + name=name + ) + + def audio_mix(self, input_ids: List[str], gains: List[float] = None, + normalize: bool = True, name: str = None) -> str: + """Add an AUDIO_MIX node to mix multiple audio streams.""" + config = {"normalize": normalize} + if gains is not None: + config["gains"] = gains + return self._add(NodeType.AUDIO_MIX, config, input_ids, name=name) + + # Output + + def set_output(self, node_id: str) -> "DAGBuilder": + """Set the output node.""" + self.dag.set_output(node_id) + return self + + def build(self) -> DAG: + """Build and validate the DAG.""" + errors = self.dag.validate() + if errors: + raise ValueError(f"Invalid DAG: {errors}") + return self.dag diff --git a/artdag/effects/__init__.py b/artdag/effects/__init__.py new file mode 100644 index 0000000..701765b --- /dev/null +++ b/artdag/effects/__init__.py @@ -0,0 +1,55 @@ +""" +Cacheable effect system. + +Effects are single Python files with: +- PEP 723 embedded dependencies +- @-tag metadata in docstrings +- Frame-by-frame or whole-video API + +Effects are cached by content hash (SHA3-256) and executed in +sandboxed environments for determinism. +""" + +from .meta import EffectMeta, ParamSpec, ExecutionContext +from .loader import load_effect, load_effect_file, LoadedEffect, compute_cid +from .binding import ( + AnalysisData, + ResolvedBinding, + resolve_binding, + resolve_all_bindings, + bindings_to_lookup_table, + has_bindings, + extract_binding_sources, +) +from .sandbox import Sandbox, SandboxConfig, SandboxResult, is_bwrap_available, get_venv_path +from .runner import run_effect, run_effect_from_cache, EffectExecutor + +__all__ = [ + # Meta types + "EffectMeta", + "ParamSpec", + "ExecutionContext", + # Loader + "load_effect", + "load_effect_file", + "LoadedEffect", + "compute_cid", + # Binding + "AnalysisData", + "ResolvedBinding", + "resolve_binding", + "resolve_all_bindings", + "bindings_to_lookup_table", + "has_bindings", + "extract_binding_sources", + # Sandbox + "Sandbox", + "SandboxConfig", + "SandboxResult", + "is_bwrap_available", + "get_venv_path", + # Runner + "run_effect", + "run_effect_from_cache", + "EffectExecutor", +] diff --git a/artdag/effects/binding.py b/artdag/effects/binding.py new file mode 100644 index 0000000..9017185 --- /dev/null +++ b/artdag/effects/binding.py @@ -0,0 +1,311 @@ +""" +Parameter binding resolution. + +Resolves bind expressions to per-frame lookup tables at plan time. +Binding options: + - :range [lo hi] - map 0-1 to output range + - :smooth N - smoothing window in seconds + - :offset N - time offset in seconds + - :on-event V - value on discrete events + - :decay N - exponential decay after event + - :noise N - add deterministic noise (seeded) + - :seed N - explicit RNG seed +""" + +import hashlib +import math +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class AnalysisData: + """ + Analysis data for binding resolution. + + Attributes: + frame_rate: Video frame rate + total_frames: Total number of frames + features: Dict mapping feature name to per-frame values + events: Dict mapping event name to list of frame indices + """ + + frame_rate: float + total_frames: int + features: Dict[str, List[float]] # feature -> [value_per_frame] + events: Dict[str, List[int]] # event -> [frame_indices] + + def get_feature(self, name: str, frame: int) -> float: + """Get feature value at frame, interpolating if needed.""" + if name not in self.features: + return 0.0 + values = self.features[name] + if not values: + return 0.0 + if frame >= len(values): + return values[-1] + return values[frame] + + def get_events_in_range( + self, name: str, start_frame: int, end_frame: int + ) -> List[int]: + """Get event frames in range.""" + if name not in self.events: + return [] + return [f for f in self.events[name] if start_frame <= f < end_frame] + + +@dataclass +class ResolvedBinding: + """ + Resolved binding with per-frame values. + + Attributes: + param_name: Parameter this binding applies to + values: List of values, one per frame + """ + + param_name: str + values: List[float] + + def get(self, frame: int) -> float: + """Get value at frame.""" + if frame >= len(self.values): + return self.values[-1] if self.values else 0.0 + return self.values[frame] + + +def resolve_binding( + binding: Dict[str, Any], + analysis: AnalysisData, + param_name: str, + cache_id: str = None, +) -> ResolvedBinding: + """ + Resolve a binding specification to per-frame values. + + Args: + binding: Binding spec with source, feature, and options + analysis: Analysis data with features and events + param_name: Name of the parameter being bound + cache_id: Cache ID for deterministic seeding + + Returns: + ResolvedBinding with values for each frame + """ + feature = binding.get("feature") + if not feature: + raise ValueError(f"Binding for {param_name} missing feature") + + # Get base values + values = [] + is_event = feature in analysis.events + + if is_event: + # Event-based binding + on_event = binding.get("on_event", 1.0) + decay = binding.get("decay", 0.0) + values = _resolve_event_binding( + analysis.events.get(feature, []), + analysis.total_frames, + analysis.frame_rate, + on_event, + decay, + ) + else: + # Continuous feature binding + feature_values = analysis.features.get(feature, []) + if not feature_values: + # No data, use zeros + values = [0.0] * analysis.total_frames + else: + # Extend to total frames if needed + values = list(feature_values) + while len(values) < analysis.total_frames: + values.append(values[-1] if values else 0.0) + + # Apply offset + offset = binding.get("offset") + if offset: + offset_frames = int(offset * analysis.frame_rate) + values = _apply_offset(values, offset_frames) + + # Apply smoothing + smooth = binding.get("smooth") + if smooth: + window_frames = int(smooth * analysis.frame_rate) + values = _apply_smoothing(values, window_frames) + + # Apply range mapping + range_spec = binding.get("range") + if range_spec: + lo, hi = range_spec + values = _apply_range(values, lo, hi) + + # Apply noise + noise = binding.get("noise") + if noise: + seed = binding.get("seed") + if seed is None and cache_id: + # Derive seed from cache_id for determinism + seed = int(hashlib.sha256(cache_id.encode()).hexdigest()[:8], 16) + values = _apply_noise(values, noise, seed or 0) + + return ResolvedBinding(param_name=param_name, values=values) + + +def _resolve_event_binding( + event_frames: List[int], + total_frames: int, + frame_rate: float, + on_event: float, + decay: float, +) -> List[float]: + """ + Resolve event-based binding with optional decay. + + Args: + event_frames: List of frame indices where events occur + total_frames: Total number of frames + frame_rate: Video frame rate + on_event: Value at event + decay: Decay time constant in seconds (0 = instant) + + Returns: + List of values per frame + """ + values = [0.0] * total_frames + + if not event_frames: + return values + + event_set = set(event_frames) + + if decay <= 0: + # No decay - just mark event frames + for f in event_frames: + if 0 <= f < total_frames: + values[f] = on_event + else: + # Apply exponential decay + decay_frames = decay * frame_rate + for f in event_frames: + if f < 0 or f >= total_frames: + continue + # Apply decay from this event forward + for i in range(f, total_frames): + elapsed = i - f + decayed = on_event * math.exp(-elapsed / decay_frames) + if decayed < 0.001: + break + values[i] = max(values[i], decayed) + + return values + + +def _apply_offset(values: List[float], offset_frames: int) -> List[float]: + """Shift values by offset frames (positive = delay).""" + if offset_frames == 0: + return values + + n = len(values) + result = [0.0] * n + + for i in range(n): + src = i - offset_frames + if 0 <= src < n: + result[i] = values[src] + + return result + + +def _apply_smoothing(values: List[float], window_frames: int) -> List[float]: + """Apply moving average smoothing.""" + if window_frames <= 1: + return values + + n = len(values) + result = [] + half = window_frames // 2 + + for i in range(n): + start = max(0, i - half) + end = min(n, i + half + 1) + avg = sum(values[start:end]) / (end - start) + result.append(avg) + + return result + + +def _apply_range(values: List[float], lo: float, hi: float) -> List[float]: + """Map values from 0-1 to lo-hi range.""" + return [lo + v * (hi - lo) for v in values] + + +def _apply_noise(values: List[float], amount: float, seed: int) -> List[float]: + """Add deterministic noise to values.""" + rng = random.Random(seed) + return [v + rng.uniform(-amount, amount) for v in values] + + +def resolve_all_bindings( + config: Dict[str, Any], + analysis: AnalysisData, + cache_id: str = None, +) -> Dict[str, ResolvedBinding]: + """ + Resolve all bindings in a config dict. + + Looks for values with _binding: True marker. + + Args: + config: Node config with potential bindings + analysis: Analysis data + cache_id: Cache ID for seeding + + Returns: + Dict mapping param name to resolved binding + """ + resolved = {} + + for key, value in config.items(): + if isinstance(value, dict) and value.get("_binding"): + resolved[key] = resolve_binding(value, analysis, key, cache_id) + + return resolved + + +def bindings_to_lookup_table( + bindings: Dict[str, ResolvedBinding], +) -> Dict[str, List[float]]: + """ + Convert resolved bindings to simple lookup tables. + + Returns dict mapping param name to list of per-frame values. + This format is JSON-serializable for inclusion in execution plans. + """ + return {name: binding.values for name, binding in bindings.items()} + + +def has_bindings(config: Dict[str, Any]) -> bool: + """Check if config contains any bindings.""" + for value in config.values(): + if isinstance(value, dict) and value.get("_binding"): + return True + return False + + +def extract_binding_sources(config: Dict[str, Any]) -> List[str]: + """ + Extract all analysis source references from bindings. + + Returns list of node IDs that provide analysis data. + """ + sources = [] + for value in config.values(): + if isinstance(value, dict) and value.get("_binding"): + source = value.get("source") + if source and source not in sources: + sources.append(source) + return sources diff --git a/artdag/effects/frame_processor.py b/artdag/effects/frame_processor.py new file mode 100644 index 0000000..c2a04d2 --- /dev/null +++ b/artdag/effects/frame_processor.py @@ -0,0 +1,347 @@ +""" +FFmpeg pipe-based frame processing. + +Processes video through Python frame-by-frame effects using FFmpeg pipes: + FFmpeg decode -> Python process_frame -> FFmpeg encode + +This avoids writing intermediate frames to disk. +""" + +import logging +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclass +class VideoInfo: + """Video metadata.""" + + width: int + height: int + frame_rate: float + total_frames: int + duration: float + pixel_format: str = "rgb24" + + +def probe_video(path: Path) -> VideoInfo: + """ + Get video information using ffprobe. + + Args: + path: Path to video file + + Returns: + VideoInfo with dimensions, frame rate, etc. + """ + cmd = [ + "ffprobe", + "-v", "error", + "-select_streams", "v:0", + "-show_entries", "stream=width,height,r_frame_rate,nb_frames,duration", + "-of", "csv=p=0", + str(path), + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"ffprobe failed: {result.stderr}") + + parts = result.stdout.strip().split(",") + if len(parts) < 4: + raise RuntimeError(f"Unexpected ffprobe output: {result.stdout}") + + width = int(parts[0]) + height = int(parts[1]) + + # Parse frame rate (could be "30/1" or "30") + fr_parts = parts[2].split("/") + if len(fr_parts) == 2: + frame_rate = float(fr_parts[0]) / float(fr_parts[1]) + else: + frame_rate = float(fr_parts[0]) + + # nb_frames might be N/A + total_frames = 0 + duration = 0.0 + try: + total_frames = int(parts[3]) + except (ValueError, IndexError): + pass + + try: + duration = float(parts[4]) if len(parts) > 4 else 0.0 + except (ValueError, IndexError): + pass + + if total_frames == 0 and duration > 0: + total_frames = int(duration * frame_rate) + + return VideoInfo( + width=width, + height=height, + frame_rate=frame_rate, + total_frames=total_frames, + duration=duration, + ) + + +FrameProcessor = Callable[[np.ndarray, Dict[str, Any], Any], Tuple[np.ndarray, Any]] + + +def process_video( + input_path: Path, + output_path: Path, + process_frame: FrameProcessor, + params: Dict[str, Any], + bindings: Dict[str, List[float]] = None, + initial_state: Any = None, + pixel_format: str = "rgb24", + output_codec: str = "libx264", + output_options: List[str] = None, +) -> Tuple[Path, Any]: + """ + Process video through frame-by-frame effect. + + Args: + input_path: Input video path + output_path: Output video path + process_frame: Function (frame, params, state) -> (frame, state) + params: Static parameter dict + bindings: Per-frame parameter lookup tables + initial_state: Initial state for process_frame + pixel_format: Pixel format for frame data + output_codec: Video codec for output + output_options: Additional ffmpeg output options + + Returns: + Tuple of (output_path, final_state) + """ + bindings = bindings or {} + output_options = output_options or [] + + # Probe input + info = probe_video(input_path) + logger.info(f"Processing {info.width}x{info.height} @ {info.frame_rate}fps") + + # Calculate bytes per frame + if pixel_format == "rgb24": + bytes_per_pixel = 3 + elif pixel_format == "rgba": + bytes_per_pixel = 4 + else: + bytes_per_pixel = 3 # Default to RGB + + frame_size = info.width * info.height * bytes_per_pixel + + # Start decoder process + decode_cmd = [ + "ffmpeg", + "-i", str(input_path), + "-f", "rawvideo", + "-pix_fmt", pixel_format, + "-", + ] + + # Start encoder process + encode_cmd = [ + "ffmpeg", + "-y", + "-f", "rawvideo", + "-pix_fmt", pixel_format, + "-s", f"{info.width}x{info.height}", + "-r", str(info.frame_rate), + "-i", "-", + "-i", str(input_path), # For audio + "-map", "0:v", + "-map", "1:a?", + "-c:v", output_codec, + "-c:a", "aac", + *output_options, + str(output_path), + ] + + logger.debug(f"Decoder: {' '.join(decode_cmd)}") + logger.debug(f"Encoder: {' '.join(encode_cmd)}") + + decoder = subprocess.Popen( + decode_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + + encoder = subprocess.Popen( + encode_cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + state = initial_state + frame_idx = 0 + + try: + while True: + # Read frame from decoder + raw_frame = decoder.stdout.read(frame_size) + if len(raw_frame) < frame_size: + break + + # Convert to numpy + frame = np.frombuffer(raw_frame, dtype=np.uint8) + frame = frame.reshape((info.height, info.width, bytes_per_pixel)) + + # Build per-frame params + frame_params = dict(params) + for param_name, values in bindings.items(): + if frame_idx < len(values): + frame_params[param_name] = values[frame_idx] + + # Process frame + processed, state = process_frame(frame, frame_params, state) + + # Ensure correct shape and dtype + if processed.shape != frame.shape: + raise ValueError( + f"Frame shape mismatch: {processed.shape} vs {frame.shape}" + ) + processed = processed.astype(np.uint8) + + # Write to encoder + encoder.stdin.write(processed.tobytes()) + frame_idx += 1 + + if frame_idx % 100 == 0: + logger.debug(f"Processed frame {frame_idx}") + + except Exception as e: + logger.error(f"Frame processing failed at frame {frame_idx}: {e}") + raise + finally: + decoder.stdout.close() + decoder.wait() + encoder.stdin.close() + encoder.wait() + + if encoder.returncode != 0: + stderr = encoder.stderr.read().decode() if encoder.stderr else "" + raise RuntimeError(f"Encoder failed: {stderr}") + + logger.info(f"Processed {frame_idx} frames") + return output_path, state + + +def process_video_batch( + input_path: Path, + output_path: Path, + process_frames: Callable[[List[np.ndarray], Dict[str, Any]], List[np.ndarray]], + params: Dict[str, Any], + batch_size: int = 30, + pixel_format: str = "rgb24", + output_codec: str = "libx264", +) -> Path: + """ + Process video in batches for effects that need temporal context. + + Args: + input_path: Input video path + output_path: Output video path + process_frames: Function (frames_batch, params) -> processed_batch + params: Parameter dict + batch_size: Number of frames per batch + pixel_format: Pixel format + output_codec: Output codec + + Returns: + Output path + """ + info = probe_video(input_path) + + if pixel_format == "rgb24": + bytes_per_pixel = 3 + elif pixel_format == "rgba": + bytes_per_pixel = 4 + else: + bytes_per_pixel = 3 + + frame_size = info.width * info.height * bytes_per_pixel + + decode_cmd = [ + "ffmpeg", + "-i", str(input_path), + "-f", "rawvideo", + "-pix_fmt", pixel_format, + "-", + ] + + encode_cmd = [ + "ffmpeg", + "-y", + "-f", "rawvideo", + "-pix_fmt", pixel_format, + "-s", f"{info.width}x{info.height}", + "-r", str(info.frame_rate), + "-i", "-", + "-i", str(input_path), + "-map", "0:v", + "-map", "1:a?", + "-c:v", output_codec, + "-c:a", "aac", + str(output_path), + ] + + decoder = subprocess.Popen( + decode_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + + encoder = subprocess.Popen( + encode_cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + batch = [] + total_processed = 0 + + try: + while True: + raw_frame = decoder.stdout.read(frame_size) + if len(raw_frame) < frame_size: + # Process remaining batch + if batch: + processed = process_frames(batch, params) + for frame in processed: + encoder.stdin.write(frame.astype(np.uint8).tobytes()) + total_processed += 1 + break + + frame = np.frombuffer(raw_frame, dtype=np.uint8) + frame = frame.reshape((info.height, info.width, bytes_per_pixel)) + batch.append(frame) + + if len(batch) >= batch_size: + processed = process_frames(batch, params) + for frame in processed: + encoder.stdin.write(frame.astype(np.uint8).tobytes()) + total_processed += 1 + batch = [] + + finally: + decoder.stdout.close() + decoder.wait() + encoder.stdin.close() + encoder.wait() + + if encoder.returncode != 0: + stderr = encoder.stderr.read().decode() if encoder.stderr else "" + raise RuntimeError(f"Encoder failed: {stderr}") + + logger.info(f"Processed {total_processed} frames in batches of {batch_size}") + return output_path diff --git a/artdag/effects/loader.py b/artdag/effects/loader.py new file mode 100644 index 0000000..47ee36c --- /dev/null +++ b/artdag/effects/loader.py @@ -0,0 +1,455 @@ +""" +Effect file loader. + +Parses effect files with: +- PEP 723 inline script metadata for dependencies +- @-tag docstrings for effect metadata +- META object for programmatic access +""" + +import ast +import hashlib +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from .meta import EffectMeta, ParamSpec + + +@dataclass +class LoadedEffect: + """ + A loaded effect with all metadata. + + Attributes: + source: Original source code + cid: SHA3-256 hash of source + meta: Extracted EffectMeta + dependencies: List of pip dependencies + requires_python: Python version requirement + module: Compiled module (if loaded) + """ + + source: str + cid: str + meta: EffectMeta + dependencies: List[str] = field(default_factory=list) + requires_python: str = ">=3.10" + module: Any = None + + def has_frame_api(self) -> bool: + """Check if effect has frame-by-frame API.""" + return self.meta.api_type == "frame" + + def has_video_api(self) -> bool: + """Check if effect has whole-video API.""" + return self.meta.api_type == "video" + + +def compute_cid(source: str) -> str: + """Compute SHA3-256 hash of effect source.""" + return hashlib.sha3_256(source.encode("utf-8")).hexdigest() + + +def parse_pep723_metadata(source: str) -> Tuple[List[str], str]: + """ + Parse PEP 723 inline script metadata. + + Looks for: + # /// script + # requires-python = ">=3.10" + # dependencies = ["numpy", "opencv-python"] + # /// + + Returns: + Tuple of (dependencies list, requires_python string) + """ + dependencies = [] + requires_python = ">=3.10" + + # Match the script block + pattern = r"# /// script\n(.*?)# ///" + match = re.search(pattern, source, re.DOTALL) + + if not match: + return dependencies, requires_python + + block = match.group(1) + + # Parse dependencies + deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL) + if deps_match: + deps_str = deps_match.group(1) + # Extract quoted strings + dependencies = re.findall(r'"([^"]+)"', deps_str) + + # Parse requires-python + python_match = re.search(r'# requires-python = "([^"]+)"', block) + if python_match: + requires_python = python_match.group(1) + + return dependencies, requires_python + + +def parse_docstring_metadata(docstring: str) -> Dict[str, Any]: + """ + Parse @-tag metadata from docstring. + + Supports: + @effect name + @version 1.0.0 + @author @user@domain + @temporal false + @description + Multi-line description text. + + @param name type + @range lo hi + @default value + Description text. + + @example + (fx effect :param value) + + Returns: + Dictionary with extracted metadata + """ + if not docstring: + return {} + + result = { + "name": "", + "version": "1.0.0", + "author": "", + "temporal": False, + "description": "", + "params": [], + "examples": [], + } + + lines = docstring.strip().split("\n") + i = 0 + current_param = None + + while i < len(lines): + line = lines[i].strip() + + if line.startswith("@effect "): + result["name"] = line[8:].strip() + + elif line.startswith("@version "): + result["version"] = line[9:].strip() + + elif line.startswith("@author "): + result["author"] = line[8:].strip() + + elif line.startswith("@temporal "): + val = line[10:].strip().lower() + result["temporal"] = val in ("true", "yes", "1") + + elif line.startswith("@description"): + # Collect multi-line description + desc_lines = [] + i += 1 + while i < len(lines): + next_line = lines[i] + if next_line.strip().startswith("@"): + i -= 1 # Back up to process this tag + break + desc_lines.append(next_line) + i += 1 + result["description"] = "\n".join(desc_lines).strip() + + elif line.startswith("@param "): + # Parse parameter: @param name type + parts = line[7:].split() + if len(parts) >= 2: + current_param = { + "name": parts[0], + "type": parts[1], + "range": None, + "default": None, + "description": "", + } + # Collect param details + desc_lines = [] + i += 1 + while i < len(lines): + next_line = lines[i] + stripped = next_line.strip() + + if stripped.startswith("@range "): + range_parts = stripped[7:].split() + if len(range_parts) >= 2: + try: + current_param["range"] = ( + float(range_parts[0]), + float(range_parts[1]), + ) + except ValueError: + pass + + elif stripped.startswith("@default "): + current_param["default"] = stripped[9:].strip() + + elif stripped.startswith("@param ") or stripped.startswith("@example"): + i -= 1 # Back up + break + + elif stripped.startswith("@"): + i -= 1 + break + + elif stripped: + desc_lines.append(stripped) + + i += 1 + + current_param["description"] = " ".join(desc_lines) + result["params"].append(current_param) + current_param = None + + elif line.startswith("@example"): + # Collect example + example_lines = [] + i += 1 + while i < len(lines): + next_line = lines[i] + if next_line.strip().startswith("@") and not next_line.strip().startswith("@example"): + if next_line.strip().startswith("@example"): + i -= 1 + break + if next_line.strip().startswith("@example"): + i -= 1 + break + example_lines.append(next_line) + i += 1 + example = "\n".join(example_lines).strip() + if example: + result["examples"].append(example) + + i += 1 + + return result + + +def extract_meta_from_ast(source: str) -> Optional[Dict[str, Any]]: + """ + Extract META object from source AST. + + Looks for: + META = EffectMeta(...) + + Returns the keyword arguments if found. + """ + try: + tree = ast.parse(source) + except SyntaxError: + return None + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "META": + if isinstance(node.value, ast.Call): + return _extract_call_kwargs(node.value) + return None + + +def _extract_call_kwargs(call: ast.Call) -> Dict[str, Any]: + """Extract keyword arguments from an AST Call node.""" + result = {} + + for keyword in call.keywords: + if keyword.arg is None: + continue + value = _ast_to_value(keyword.value) + if value is not None: + result[keyword.arg] = value + + return result + + +def _ast_to_value(node: ast.expr) -> Any: + """Convert AST node to Python value.""" + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Str): # Python 3.7 compat + return node.s + elif isinstance(node, ast.Num): # Python 3.7 compat + return node.n + elif isinstance(node, ast.NameConstant): # Python 3.7 compat + return node.value + elif isinstance(node, ast.List): + return [_ast_to_value(elt) for elt in node.elts] + elif isinstance(node, ast.Tuple): + return tuple(_ast_to_value(elt) for elt in node.elts) + elif isinstance(node, ast.Dict): + return { + _ast_to_value(k): _ast_to_value(v) + for k, v in zip(node.keys, node.values) + if k is not None + } + elif isinstance(node, ast.Call): + # Handle ParamSpec(...) calls + if isinstance(node.func, ast.Name) and node.func.id == "ParamSpec": + return _extract_call_kwargs(node) + return None + + +def get_module_docstring(source: str) -> str: + """Extract the module-level docstring from source.""" + try: + tree = ast.parse(source) + except SyntaxError: + return "" + + if tree.body and isinstance(tree.body[0], ast.Expr): + if isinstance(tree.body[0].value, ast.Constant): + return tree.body[0].value.value + elif isinstance(tree.body[0].value, ast.Str): # Python 3.7 compat + return tree.body[0].value.s + return "" + + +def load_effect(source: str) -> LoadedEffect: + """ + Load an effect from source code. + + Parses: + 1. PEP 723 metadata for dependencies + 2. Module docstring for @-tag metadata + 3. META object for programmatic metadata + + Priority: META object > docstring > defaults + + Args: + source: Effect source code + + Returns: + LoadedEffect with all metadata + + Raises: + ValueError: If effect is invalid + """ + cid = compute_cid(source) + + # Parse PEP 723 metadata + dependencies, requires_python = parse_pep723_metadata(source) + + # Parse docstring metadata + docstring = get_module_docstring(source) + doc_meta = parse_docstring_metadata(docstring) + + # Try to extract META from AST + ast_meta = extract_meta_from_ast(source) + + # Build EffectMeta, preferring META object over docstring + name = "" + if ast_meta and "name" in ast_meta: + name = ast_meta["name"] + elif doc_meta.get("name"): + name = doc_meta["name"] + + if not name: + raise ValueError("Effect must have a name (@effect or META.name)") + + version = ast_meta.get("version") if ast_meta else doc_meta.get("version", "1.0.0") + temporal = ast_meta.get("temporal") if ast_meta else doc_meta.get("temporal", False) + author = ast_meta.get("author") if ast_meta else doc_meta.get("author", "") + description = ast_meta.get("description") if ast_meta else doc_meta.get("description", "") + examples = ast_meta.get("examples") if ast_meta else doc_meta.get("examples", []) + + # Build params + params = [] + if ast_meta and "params" in ast_meta: + for p in ast_meta["params"]: + if isinstance(p, dict): + type_map = {"float": float, "int": int, "bool": bool, "str": str} + param_type = type_map.get(p.get("param_type", "float"), float) + if isinstance(p.get("param_type"), type): + param_type = p["param_type"] + params.append( + ParamSpec( + name=p.get("name", ""), + param_type=param_type, + default=p.get("default"), + range=p.get("range"), + description=p.get("description", ""), + ) + ) + elif doc_meta.get("params"): + for p in doc_meta["params"]: + type_map = {"float": float, "int": int, "bool": bool, "str": str} + param_type = type_map.get(p.get("type", "float"), float) + + default = p.get("default") + if default is not None: + try: + default = param_type(default) + except (ValueError, TypeError): + pass + + params.append( + ParamSpec( + name=p["name"], + param_type=param_type, + default=default, + range=p.get("range"), + description=p.get("description", ""), + ) + ) + + # Determine API type by checking for function definitions + api_type = "frame" # default + try: + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + if node.name == "process": + api_type = "video" + break + elif node.name == "process_frame": + api_type = "frame" + break + except SyntaxError: + pass + + meta = EffectMeta( + name=name, + version=version if isinstance(version, str) else "1.0.0", + temporal=bool(temporal), + params=params, + author=author if isinstance(author, str) else "", + description=description if isinstance(description, str) else "", + examples=examples if isinstance(examples, list) else [], + dependencies=dependencies, + requires_python=requires_python, + api_type=api_type, + ) + + return LoadedEffect( + source=source, + cid=cid, + meta=meta, + dependencies=dependencies, + requires_python=requires_python, + ) + + +def load_effect_file(path: Path) -> LoadedEffect: + """Load an effect from a file path.""" + source = path.read_text(encoding="utf-8") + return load_effect(source) + + +def compute_deps_hash(dependencies: List[str]) -> str: + """ + Compute hash of sorted dependencies. + + Used for venv caching - same deps = same hash = reuse venv. + """ + sorted_deps = sorted(dep.lower().strip() for dep in dependencies) + deps_str = "\n".join(sorted_deps) + return hashlib.sha3_256(deps_str.encode("utf-8")).hexdigest() diff --git a/artdag/effects/meta.py b/artdag/effects/meta.py new file mode 100644 index 0000000..810623a --- /dev/null +++ b/artdag/effects/meta.py @@ -0,0 +1,247 @@ +""" +Effect metadata types. + +Defines the core dataclasses for effect metadata: +- ParamSpec: Parameter specification with type, range, and default +- EffectMeta: Complete effect metadata including params and flags +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type, Union + + +@dataclass +class ParamSpec: + """ + Specification for an effect parameter. + + Attributes: + name: Parameter name (used in recipes as :name) + param_type: Python type (float, int, bool, str) + default: Default value if not specified + range: Optional (min, max) tuple for numeric types + description: Human-readable description + choices: Optional list of allowed values (for enums) + """ + + name: str + param_type: Type + default: Any = None + range: Optional[Tuple[float, float]] = None + description: str = "" + choices: Optional[List[Any]] = None + + def validate(self, value: Any) -> Any: + """ + Validate and coerce a parameter value. + + Args: + value: Input value to validate + + Returns: + Validated and coerced value + + Raises: + ValueError: If value is invalid + """ + if value is None: + if self.default is not None: + return self.default + raise ValueError(f"Parameter '{self.name}' requires a value") + + # Type coercion + try: + if self.param_type == bool: + if isinstance(value, str): + value = value.lower() in ("true", "1", "yes") + else: + value = bool(value) + elif self.param_type == int: + value = int(value) + elif self.param_type == float: + value = float(value) + elif self.param_type == str: + value = str(value) + else: + value = self.param_type(value) + except (ValueError, TypeError) as e: + raise ValueError( + f"Parameter '{self.name}' expects {self.param_type.__name__}, " + f"got {type(value).__name__}: {e}" + ) + + # Range check for numeric types + if self.range is not None and self.param_type in (int, float): + min_val, max_val = self.range + if value < min_val or value > max_val: + raise ValueError( + f"Parameter '{self.name}' must be in range " + f"[{min_val}, {max_val}], got {value}" + ) + + # Choices check + if self.choices is not None and value not in self.choices: + raise ValueError( + f"Parameter '{self.name}' must be one of {self.choices}, got {value}" + ) + + return value + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + d = { + "name": self.name, + "type": self.param_type.__name__, + "description": self.description, + } + if self.default is not None: + d["default"] = self.default + if self.range is not None: + d["range"] = list(self.range) + if self.choices is not None: + d["choices"] = self.choices + return d + + +@dataclass +class EffectMeta: + """ + Complete metadata for an effect. + + Attributes: + name: Effect name (used in recipes) + version: Semantic version string + temporal: If True, effect needs complete input (can't be collapsed) + params: List of parameter specifications + author: Optional author identifier + description: Human-readable description + examples: List of example S-expression usages + dependencies: List of Python package dependencies + requires_python: Minimum Python version + api_type: "frame" for frame-by-frame, "video" for whole-video + """ + + name: str + version: str = "1.0.0" + temporal: bool = False + params: List[ParamSpec] = field(default_factory=list) + author: str = "" + description: str = "" + examples: List[str] = field(default_factory=list) + dependencies: List[str] = field(default_factory=list) + requires_python: str = ">=3.10" + api_type: str = "frame" # "frame" or "video" + + def get_param(self, name: str) -> Optional[ParamSpec]: + """Get a parameter spec by name.""" + for param in self.params: + if param.name == name: + return param + return None + + def validate_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate all parameters. + + Args: + params: Dictionary of parameter values + + Returns: + Dictionary with validated/coerced values including defaults + + Raises: + ValueError: If any parameter is invalid + """ + result = {} + for spec in self.params: + value = params.get(spec.name) + result[spec.name] = spec.validate(value) + return result + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "name": self.name, + "version": self.version, + "temporal": self.temporal, + "params": [p.to_dict() for p in self.params], + "author": self.author, + "description": self.description, + "examples": self.examples, + "dependencies": self.dependencies, + "requires_python": self.requires_python, + "api_type": self.api_type, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EffectMeta": + """Create from dictionary.""" + params = [] + for p in data.get("params", []): + # Map type name back to Python type + type_map = {"float": float, "int": int, "bool": bool, "str": str} + param_type = type_map.get(p.get("type", "float"), float) + params.append( + ParamSpec( + name=p["name"], + param_type=param_type, + default=p.get("default"), + range=tuple(p["range"]) if p.get("range") else None, + description=p.get("description", ""), + choices=p.get("choices"), + ) + ) + + return cls( + name=data["name"], + version=data.get("version", "1.0.0"), + temporal=data.get("temporal", False), + params=params, + author=data.get("author", ""), + description=data.get("description", ""), + examples=data.get("examples", []), + dependencies=data.get("dependencies", []), + requires_python=data.get("requires_python", ">=3.10"), + api_type=data.get("api_type", "frame"), + ) + + +@dataclass +class ExecutionContext: + """ + Context passed to effect execution. + + Provides controlled access to resources within sandbox. + """ + + input_paths: List[str] + output_path: str + params: Dict[str, Any] + seed: int # Deterministic seed for RNG + frame_rate: float = 30.0 + width: int = 1920 + height: int = 1080 + + # Resolved bindings (frame -> param value lookup) + bindings: Dict[str, List[float]] = field(default_factory=dict) + + def get_param_at_frame(self, param_name: str, frame: int) -> Any: + """ + Get parameter value at a specific frame. + + If parameter has a binding, looks up the bound value. + Otherwise returns the static parameter value. + """ + if param_name in self.bindings: + binding_values = self.bindings[param_name] + if frame < len(binding_values): + return binding_values[frame] + # Past end of binding data, use last value + return binding_values[-1] if binding_values else self.params.get(param_name) + return self.params.get(param_name) + + def get_rng(self) -> "random.Random": + """Get a seeded random number generator.""" + import random + + return random.Random(self.seed) diff --git a/artdag/effects/runner.py b/artdag/effects/runner.py new file mode 100644 index 0000000..2f58c12 --- /dev/null +++ b/artdag/effects/runner.py @@ -0,0 +1,259 @@ +""" +Effect runner. + +Main entry point for executing cached effects with sandboxing. +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .binding import AnalysisData, bindings_to_lookup_table, resolve_all_bindings +from .loader import load_effect, LoadedEffect +from .meta import ExecutionContext +from .sandbox import Sandbox, SandboxConfig, SandboxResult, get_venv_path + +logger = logging.getLogger(__name__) + + +def run_effect( + effect_source: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + cache_id: str = None, + seed: int = 0, + trust_level: str = "untrusted", + timeout: int = 3600, +) -> SandboxResult: + """ + Run an effect with full sandboxing. + + This is the main entry point for effect execution. + + Args: + effect_source: Effect source code + input_paths: List of input file paths + output_path: Output file path + params: Effect parameters (may contain bindings) + analysis: Optional analysis data for binding resolution + cache_id: Cache ID for deterministic seeding + seed: RNG seed (overrides cache_id-based seed) + trust_level: "untrusted" or "trusted" + timeout: Maximum execution time in seconds + + Returns: + SandboxResult with success status and output + """ + # Load and validate effect + loaded = load_effect(effect_source) + logger.info(f"Running effect '{loaded.meta.name}' v{loaded.meta.version}") + + # Resolve bindings if analysis data available + bindings = {} + if analysis: + resolved = resolve_all_bindings(params, analysis, cache_id) + bindings = bindings_to_lookup_table(resolved) + # Remove binding dicts from params, keeping only resolved values + params = { + k: v for k, v in params.items() + if not (isinstance(v, dict) and v.get("_binding")) + } + + # Validate parameters + validated_params = loaded.meta.validate_params(params) + + # Get or create venv for dependencies + venv_path = None + if loaded.dependencies: + venv_path = get_venv_path(loaded.dependencies) + + # Configure sandbox + config = SandboxConfig( + trust_level=trust_level, + venv_path=venv_path, + timeout=timeout, + ) + + # Write effect to temp file + import tempfile + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=False, + ) as f: + f.write(effect_source) + effect_path = Path(f.name) + + try: + with Sandbox(config) as sandbox: + result = sandbox.run_effect( + effect_path=effect_path, + input_paths=input_paths, + output_path=output_path, + params=validated_params, + bindings=bindings, + seed=seed, + ) + finally: + effect_path.unlink(missing_ok=True) + + return result + + +def run_effect_from_cache( + cache, + effect_hash: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + cache_id: str = None, + seed: int = 0, + trust_level: str = "untrusted", + timeout: int = 3600, +) -> SandboxResult: + """ + Run an effect from cache by content hash. + + Args: + cache: Cache instance + effect_hash: Content hash of effect + input_paths: Input file paths + output_path: Output file path + params: Effect parameters + analysis: Optional analysis data + cache_id: Cache ID for seeding + seed: RNG seed + trust_level: "untrusted" or "trusted" + timeout: Max execution time + + Returns: + SandboxResult + """ + effect_source = cache.get_effect(effect_hash) + if not effect_source: + return SandboxResult( + success=False, + error=f"Effect not found in cache: {effect_hash[:16]}...", + ) + + return run_effect( + effect_source=effect_source, + input_paths=input_paths, + output_path=output_path, + params=params, + analysis=analysis, + cache_id=cache_id, + seed=seed, + trust_level=trust_level, + timeout=timeout, + ) + + +def check_effect_temporal(cache, effect_hash: str) -> bool: + """ + Check if an effect is temporal (can't be collapsed). + + Args: + cache: Cache instance + effect_hash: Content hash of effect + + Returns: + True if effect is temporal + """ + metadata = cache.get_effect_metadata(effect_hash) + if not metadata: + return False + + meta = metadata.get("meta", {}) + return meta.get("temporal", False) + + +def get_effect_api_type(cache, effect_hash: str) -> str: + """ + Get the API type of an effect. + + Args: + cache: Cache instance + effect_hash: Content hash of effect + + Returns: + "frame" or "video" + """ + metadata = cache.get_effect_metadata(effect_hash) + if not metadata: + return "frame" + + meta = metadata.get("meta", {}) + return meta.get("api_type", "frame") + + +class EffectExecutor: + """ + Executor for cached effects. + + Provides a higher-level interface for effect execution. + """ + + def __init__(self, cache, trust_level: str = "untrusted"): + """ + Initialize executor. + + Args: + cache: Cache instance + trust_level: Default trust level + """ + self.cache = cache + self.trust_level = trust_level + + def execute( + self, + effect_hash: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + step_cache_id: str = None, + ) -> SandboxResult: + """ + Execute an effect. + + Args: + effect_hash: Content hash of effect + input_paths: Input file paths + output_path: Output path + params: Effect parameters + analysis: Analysis data for bindings + step_cache_id: Step cache ID for seeding + + Returns: + SandboxResult + """ + # Check effect metadata for trust level override + metadata = self.cache.get_effect_metadata(effect_hash) + trust_level = self.trust_level + if metadata: + # L1 owner can mark effect as trusted + if metadata.get("trust_level") == "trusted": + trust_level = "trusted" + + return run_effect_from_cache( + cache=self.cache, + effect_hash=effect_hash, + input_paths=input_paths, + output_path=output_path, + params=params, + analysis=analysis, + cache_id=step_cache_id, + trust_level=trust_level, + ) + + def is_temporal(self, effect_hash: str) -> bool: + """Check if effect is temporal.""" + return check_effect_temporal(self.cache, effect_hash) + + def get_api_type(self, effect_hash: str) -> str: + """Get effect API type.""" + return get_effect_api_type(self.cache, effect_hash) diff --git a/artdag/effects/sandbox.py b/artdag/effects/sandbox.py new file mode 100644 index 0000000..d0d545e --- /dev/null +++ b/artdag/effects/sandbox.py @@ -0,0 +1,431 @@ +""" +Sandbox for effect execution. + +Uses bubblewrap (bwrap) for Linux namespace isolation. +Provides controlled access to: + - Input files (read-only) + - Output file (write) + - stderr (logging) + - Seeded RNG +""" + +import hashlib +import json +import logging +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class SandboxConfig: + """ + Sandbox configuration. + + Attributes: + trust_level: "untrusted" (full isolation) or "trusted" (allows subprocess) + venv_path: Path to effect's virtual environment + wheel_cache: Shared wheel cache directory + timeout: Maximum execution time in seconds + memory_limit: Memory limit in bytes (0 = unlimited) + allow_network: Whether to allow network access + """ + + trust_level: str = "untrusted" + venv_path: Optional[Path] = None + wheel_cache: Path = field(default_factory=lambda: Path("/var/cache/artdag/wheels")) + timeout: int = 3600 # 1 hour default + memory_limit: int = 0 + allow_network: bool = False + + +def is_bwrap_available() -> bool: + """Check if bubblewrap is available.""" + try: + result = subprocess.run( + ["bwrap", "--version"], + capture_output=True, + text=True, + ) + return result.returncode == 0 + except FileNotFoundError: + return False + + +def get_venv_path(dependencies: List[str], cache_dir: Path = None) -> Path: + """ + Get or create venv for given dependencies. + + Uses hash of sorted dependencies for cache key. + + Args: + dependencies: List of pip package specifiers + cache_dir: Base directory for venv cache + + Returns: + Path to venv directory + """ + cache_dir = cache_dir or Path("/var/cache/artdag/venvs") + cache_dir.mkdir(parents=True, exist_ok=True) + + # Compute deps hash + sorted_deps = sorted(dep.lower().strip() for dep in dependencies) + deps_str = "\n".join(sorted_deps) + deps_hash = hashlib.sha3_256(deps_str.encode()).hexdigest()[:16] + + venv_path = cache_dir / deps_hash + + if venv_path.exists(): + logger.debug(f"Reusing venv at {venv_path}") + return venv_path + + # Create new venv + logger.info(f"Creating venv for {len(dependencies)} deps at {venv_path}") + + subprocess.run( + ["python", "-m", "venv", str(venv_path)], + check=True, + ) + + # Install dependencies + pip_path = venv_path / "bin" / "pip" + wheel_cache = Path("/var/cache/artdag/wheels") + + if dependencies: + cmd = [ + str(pip_path), + "install", + "--cache-dir", str(wheel_cache), + *dependencies, + ] + subprocess.run(cmd, check=True) + + return venv_path + + +@dataclass +class SandboxResult: + """Result of sandboxed execution.""" + + success: bool + output_path: Optional[Path] = None + stderr: str = "" + exit_code: int = 0 + error: Optional[str] = None + + +class Sandbox: + """ + Sandboxed effect execution environment. + + Uses bubblewrap for namespace isolation when available, + falls back to subprocess with restricted permissions. + """ + + def __init__(self, config: SandboxConfig = None): + self.config = config or SandboxConfig() + self._temp_dirs: List[Path] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + + def cleanup(self): + """Clean up temporary directories.""" + for temp_dir in self._temp_dirs: + if temp_dir.exists(): + shutil.rmtree(temp_dir, ignore_errors=True) + self._temp_dirs = [] + + def _create_temp_dir(self) -> Path: + """Create a temporary directory for sandbox use.""" + temp_dir = Path(tempfile.mkdtemp(prefix="artdag_sandbox_")) + self._temp_dirs.append(temp_dir) + return temp_dir + + def run_effect( + self, + effect_path: Path, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + bindings: Dict[str, List[float]] = None, + seed: int = 0, + ) -> SandboxResult: + """ + Run an effect in the sandbox. + + Args: + effect_path: Path to effect.py + input_paths: List of input file paths + output_path: Output file path + params: Effect parameters + bindings: Per-frame parameter bindings + seed: RNG seed for determinism + + Returns: + SandboxResult with success status and output + """ + bindings = bindings or {} + + # Create work directory + work_dir = self._create_temp_dir() + config_path = work_dir / "config.json" + effect_copy = work_dir / "effect.py" + + # Copy effect to work dir + shutil.copy(effect_path, effect_copy) + + # Write config file + config_data = { + "input_paths": [str(p) for p in input_paths], + "output_path": str(output_path), + "params": params, + "bindings": bindings, + "seed": seed, + } + config_path.write_text(json.dumps(config_data)) + + if is_bwrap_available() and self.config.trust_level == "untrusted": + return self._run_with_bwrap( + effect_copy, config_path, input_paths, output_path, work_dir + ) + else: + return self._run_subprocess( + effect_copy, config_path, input_paths, output_path, work_dir + ) + + def _run_with_bwrap( + self, + effect_path: Path, + config_path: Path, + input_paths: List[Path], + output_path: Path, + work_dir: Path, + ) -> SandboxResult: + """Run effect with bubblewrap isolation.""" + logger.info("Running effect in bwrap sandbox") + + # Build bwrap command + cmd = [ + "bwrap", + # New PID namespace + "--unshare-pid", + # No network + "--unshare-net", + # Read-only root filesystem + "--ro-bind", "/", "/", + # Read-write work directory + "--bind", str(work_dir), str(work_dir), + # Read-only input files + ] + + for input_path in input_paths: + cmd.extend(["--ro-bind", str(input_path), str(input_path)]) + + # Bind output directory as writable + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + cmd.extend(["--bind", str(output_dir), str(output_dir)]) + + # Bind venv if available + if self.config.venv_path and self.config.venv_path.exists(): + cmd.extend(["--ro-bind", str(self.config.venv_path), str(self.config.venv_path)]) + python_path = self.config.venv_path / "bin" / "python" + else: + python_path = Path("/usr/bin/python3") + + # Add runner script + runner_script = self._get_runner_script() + runner_path = work_dir / "runner.py" + runner_path.write_text(runner_script) + + # Run the effect + cmd.extend([ + str(python_path), + str(runner_path), + str(effect_path), + str(config_path), + ]) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=self.config.timeout, + ) + + if result.returncode == 0 and output_path.exists(): + return SandboxResult( + success=True, + output_path=output_path, + stderr=result.stderr, + exit_code=0, + ) + else: + return SandboxResult( + success=False, + stderr=result.stderr, + exit_code=result.returncode, + error=result.stderr or "Effect execution failed", + ) + + except subprocess.TimeoutExpired: + return SandboxResult( + success=False, + error=f"Effect timed out after {self.config.timeout}s", + exit_code=-1, + ) + except Exception as e: + return SandboxResult( + success=False, + error=str(e), + exit_code=-1, + ) + + def _run_subprocess( + self, + effect_path: Path, + config_path: Path, + input_paths: List[Path], + output_path: Path, + work_dir: Path, + ) -> SandboxResult: + """Run effect in subprocess (fallback without bwrap).""" + logger.warning("Running effect without sandbox isolation") + + # Create runner script + runner_script = self._get_runner_script() + runner_path = work_dir / "runner.py" + runner_path.write_text(runner_script) + + # Determine Python path + if self.config.venv_path and self.config.venv_path.exists(): + python_path = self.config.venv_path / "bin" / "python" + else: + python_path = "python3" + + cmd = [ + str(python_path), + str(runner_path), + str(effect_path), + str(config_path), + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=self.config.timeout, + cwd=str(work_dir), + ) + + if result.returncode == 0 and output_path.exists(): + return SandboxResult( + success=True, + output_path=output_path, + stderr=result.stderr, + exit_code=0, + ) + else: + return SandboxResult( + success=False, + stderr=result.stderr, + exit_code=result.returncode, + error=result.stderr or "Effect execution failed", + ) + + except subprocess.TimeoutExpired: + return SandboxResult( + success=False, + error=f"Effect timed out after {self.config.timeout}s", + exit_code=-1, + ) + except Exception as e: + return SandboxResult( + success=False, + error=str(e), + exit_code=-1, + ) + + def _get_runner_script(self) -> str: + """Get the runner script that executes effects.""" + return '''#!/usr/bin/env python3 +"""Effect runner script - executed in sandbox.""" + +import importlib.util +import json +import sys +from pathlib import Path + +def load_effect(effect_path): + """Load effect module from path.""" + spec = importlib.util.spec_from_file_location("effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +def main(): + if len(sys.argv) < 3: + print("Usage: runner.py ", file=sys.stderr) + sys.exit(1) + + effect_path = Path(sys.argv[1]) + config_path = Path(sys.argv[2]) + + # Load config + config = json.loads(config_path.read_text()) + + input_paths = [Path(p) for p in config["input_paths"]] + output_path = Path(config["output_path"]) + params = config["params"] + bindings = config.get("bindings", {}) + seed = config.get("seed", 0) + + # Load effect + effect = load_effect(effect_path) + + # Check API type + if hasattr(effect, "process"): + # Whole-video API + from artdag.effects.meta import ExecutionContext + ctx = ExecutionContext( + input_paths=[str(p) for p in input_paths], + output_path=str(output_path), + params=params, + seed=seed, + bindings=bindings, + ) + effect.process(input_paths, output_path, params, ctx) + + elif hasattr(effect, "process_frame"): + # Frame-by-frame API + from artdag.effects.frame_processor import process_video + + result_path, _ = process_video( + input_path=input_paths[0], + output_path=output_path, + process_frame=effect.process_frame, + params=params, + bindings=bindings, + ) + + else: + print("Effect must have process() or process_frame()", file=sys.stderr) + sys.exit(1) + + print(f"Success: {output_path}", file=sys.stderr) + +if __name__ == "__main__": + main() +''' diff --git a/artdag/engine.py b/artdag/engine.py new file mode 100644 index 0000000..0e70154 --- /dev/null +++ b/artdag/engine.py @@ -0,0 +1,246 @@ +# primitive/engine.py +""" +DAG execution engine. + +Executes DAGs by: +1. Resolving nodes in topological order +2. Checking cache for each node +3. Running executors for cache misses +4. Storing results in cache +""" + +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from .dag import DAG, Node, NodeType +from .cache import Cache +from .executor import Executor, get_executor + +logger = logging.getLogger(__name__) + + +@dataclass +class ExecutionResult: + """Result of executing a DAG.""" + success: bool + output_path: Optional[Path] = None + error: Optional[str] = None + execution_time: float = 0.0 + nodes_executed: int = 0 + nodes_cached: int = 0 + node_results: Dict[str, Path] = field(default_factory=dict) + + +@dataclass +class NodeProgress: + """Progress update for a node.""" + node_id: str + node_type: str + status: str # "pending", "running", "cached", "completed", "failed" + progress: float = 0.0 # 0.0 to 1.0 + message: str = "" + + +# Progress callback type +ProgressCallback = Callable[[NodeProgress], None] + + +class Engine: + """ + DAG execution engine. + + Manages cache, resolves dependencies, and runs executors. + """ + + def __init__(self, cache_dir: Path | str): + self.cache = Cache(cache_dir) + self._progress_callback: Optional[ProgressCallback] = None + + def set_progress_callback(self, callback: ProgressCallback): + """Set callback for progress updates.""" + self._progress_callback = callback + + def _report_progress(self, progress: NodeProgress): + """Report progress to callback if set.""" + if self._progress_callback: + try: + self._progress_callback(progress) + except Exception as e: + logger.warning(f"Progress callback error: {e}") + + def execute(self, dag: DAG) -> ExecutionResult: + """ + Execute a DAG and return the result. + + Args: + dag: The DAG to execute + + Returns: + ExecutionResult with output path or error + """ + start_time = time.time() + node_results: Dict[str, Path] = {} + nodes_executed = 0 + nodes_cached = 0 + + # Validate DAG + errors = dag.validate() + if errors: + return ExecutionResult( + success=False, + error=f"Invalid DAG: {errors}", + execution_time=time.time() - start_time, + ) + + # Get topological order + try: + order = dag.topological_order() + except Exception as e: + return ExecutionResult( + success=False, + error=f"Failed to order DAG: {e}", + execution_time=time.time() - start_time, + ) + + # Execute each node + for node_id in order: + node = dag.get_node(node_id) + type_str = node.node_type.name if isinstance(node.node_type, NodeType) else str(node.node_type) + + # Report starting + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="pending", + message=f"Processing {type_str}", + )) + + # Check cache first + cached_path = self.cache.get(node_id) + if cached_path is not None: + node_results[node_id] = cached_path + nodes_cached += 1 + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="cached", + progress=1.0, + message="Using cached result", + )) + continue + + # Get executor + executor = get_executor(node.node_type) + if executor is None: + return ExecutionResult( + success=False, + error=f"No executor for node type: {node.node_type}", + execution_time=time.time() - start_time, + node_results=node_results, + ) + + # Resolve input paths + input_paths = [] + for input_id in node.inputs: + if input_id not in node_results: + return ExecutionResult( + success=False, + error=f"Missing input {input_id} for node {node_id}", + execution_time=time.time() - start_time, + node_results=node_results, + ) + input_paths.append(node_results[input_id]) + + # Determine output path + output_path = self.cache.get_output_path(node_id, ".mkv") + + # Execute + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="running", + progress=0.5, + message=f"Executing {type_str}", + )) + + node_start = time.time() + try: + result_path = executor.execute( + config=node.config, + inputs=input_paths, + output_path=output_path, + ) + node_time = time.time() - node_start + + # Store in cache (file is already at output_path) + self.cache.put( + node_id=node_id, + source_path=result_path, + node_type=type_str, + execution_time=node_time, + move=False, # Already in place + ) + + node_results[node_id] = result_path + nodes_executed += 1 + + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="completed", + progress=1.0, + message=f"Completed in {node_time:.2f}s", + )) + + except Exception as e: + logger.error(f"Node {node_id} failed: {e}") + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="failed", + message=str(e), + )) + return ExecutionResult( + success=False, + error=f"Node {node_id} ({type_str}) failed: {e}", + execution_time=time.time() - start_time, + node_results=node_results, + nodes_executed=nodes_executed, + nodes_cached=nodes_cached, + ) + + # Get final output + output_path = node_results.get(dag.output_id) + + return ExecutionResult( + success=True, + output_path=output_path, + execution_time=time.time() - start_time, + nodes_executed=nodes_executed, + nodes_cached=nodes_cached, + node_results=node_results, + ) + + def execute_node(self, node: Node, inputs: List[Path]) -> Path: + """ + Execute a single node (bypassing DAG structure). + + Useful for testing individual executors. + """ + executor = get_executor(node.node_type) + if executor is None: + raise ValueError(f"No executor for node type: {node.node_type}") + + output_path = self.cache.get_output_path(node.node_id, ".mkv") + return executor.execute(node.config, inputs, output_path) + + def get_cache_stats(self): + """Get cache statistics.""" + return self.cache.get_stats() + + def clear_cache(self): + """Clear the cache.""" + self.cache.clear() diff --git a/artdag/executor.py b/artdag/executor.py new file mode 100644 index 0000000..e2deba8 --- /dev/null +++ b/artdag/executor.py @@ -0,0 +1,106 @@ +# primitive/executor.py +""" +Executor base class and registry. + +Executors implement the actual operations for each node type. +They are registered by node type and looked up during execution. +""" + +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Type + +from .dag import NodeType + +logger = logging.getLogger(__name__) + +# Global executor registry +_EXECUTORS: Dict[NodeType | str, Type["Executor"]] = {} + + +class Executor(ABC): + """ + Base class for node executors. + + Subclasses implement execute() to perform the actual operation. + """ + + @abstractmethod + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + """ + Execute the node operation. + + Args: + config: Node configuration + inputs: Paths to input files (from resolved input nodes) + output_path: Where to write the output + + Returns: + Path to the output file + """ + pass + + def validate_config(self, config: Dict[str, Any]) -> List[str]: + """ + Validate node configuration. + + Returns list of error messages (empty if valid). + Override in subclasses for specific validation. + """ + return [] + + def estimate_output_size(self, config: Dict[str, Any], input_sizes: List[int]) -> int: + """ + Estimate output size in bytes. + + Override for better estimates. Default returns sum of inputs. + """ + return sum(input_sizes) if input_sizes else 0 + + +def register_executor(node_type: NodeType | str) -> Callable: + """ + Decorator to register an executor for a node type. + + Usage: + @register_executor(NodeType.SOURCE) + class SourceExecutor(Executor): + ... + """ + def decorator(cls: Type[Executor]) -> Type[Executor]: + if node_type in _EXECUTORS: + logger.warning(f"Overwriting executor for {node_type}") + _EXECUTORS[node_type] = cls + return cls + return decorator + + +def get_executor(node_type: NodeType | str) -> Optional[Executor]: + """ + Get an executor instance for a node type. + + Returns None if no executor is registered. + """ + executor_cls = _EXECUTORS.get(node_type) + if executor_cls is None: + return None + return executor_cls() + + +def list_executors() -> Dict[str, Type[Executor]]: + """List all registered executors.""" + return { + (k.name if isinstance(k, NodeType) else k): v + for k, v in _EXECUTORS.items() + } + + +def clear_executors(): + """Clear all registered executors (for testing).""" + _EXECUTORS.clear() diff --git a/artdag/nodes/__init__.py b/artdag/nodes/__init__.py new file mode 100644 index 0000000..e821b54 --- /dev/null +++ b/artdag/nodes/__init__.py @@ -0,0 +1,11 @@ +# primitive/nodes/__init__.py +""" +Built-in node executors. + +Import this module to register all built-in executors. +""" + +from . import source +from . import transform +from . import compose +from . import effect diff --git a/artdag/nodes/compose.py b/artdag/nodes/compose.py new file mode 100644 index 0000000..a7121c6 --- /dev/null +++ b/artdag/nodes/compose.py @@ -0,0 +1,548 @@ +# primitive/nodes/compose.py +""" +Compose executors: Combine multiple media inputs. + +Primitives: SEQUENCE, LAYER, MUX, BLEND +""" + +import logging +import shutil +import subprocess +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor +from .encoding import WEB_ENCODING_ARGS_STR, get_web_encoding_args + +logger = logging.getLogger(__name__) + + +def _get_duration(path: Path) -> float: + """Get media duration in seconds.""" + cmd = [ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + return float(result.stdout.strip()) + + +def _get_video_info(path: Path) -> dict: + """Get video width, height, frame rate, and sample rate.""" + cmd = [ + "ffprobe", "-v", "error", + "-select_streams", "v:0", + "-show_entries", "stream=width,height,r_frame_rate", + "-of", "csv=p=0", + str(path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + parts = result.stdout.strip().split(",") + width = int(parts[0]) if len(parts) > 0 and parts[0] else 1920 + height = int(parts[1]) if len(parts) > 1 and parts[1] else 1080 + fps_str = parts[2] if len(parts) > 2 else "30/1" + # Parse frame rate (e.g., "30/1" or "30000/1001") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) if float(den) != 0 else 30 + else: + fps = float(fps_str) if fps_str else 30 + + # Get audio sample rate + cmd_audio = [ + "ffprobe", "-v", "error", + "-select_streams", "a:0", + "-show_entries", "stream=sample_rate", + "-of", "csv=p=0", + str(path) + ] + result_audio = subprocess.run(cmd_audio, capture_output=True, text=True) + sample_rate = int(result_audio.stdout.strip()) if result_audio.stdout.strip() else 44100 + + return {"width": width, "height": height, "fps": fps, "sample_rate": sample_rate} + + +@register_executor(NodeType.SEQUENCE) +class SequenceExecutor(Executor): + """ + Concatenate inputs in time order. + + Config: + transition: Transition config + type: "cut" | "crossfade" | "fade" + duration: Transition duration in seconds + target_size: How to determine output dimensions when inputs differ + "first": Use first input's dimensions (default) + "last": Use last input's dimensions + "largest": Use largest width and height from all inputs + "explicit": Use width/height config values + width: Target width (when target_size="explicit") + height: Target height (when target_size="explicit") + background: Padding color for letterbox/pillarbox (default: "black") + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 1: + raise ValueError("SEQUENCE requires at least one input") + + if len(inputs) == 1: + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(inputs[0], output_path) + return output_path + + transition = config.get("transition", {"type": "cut"}) + transition_type = transition.get("type", "cut") + transition_duration = transition.get("duration", 0.5) + + # Size handling config + target_size = config.get("target_size", "first") + width = config.get("width") + height = config.get("height") + background = config.get("background", "black") + + if transition_type == "cut": + return self._concat_cut(inputs, output_path, target_size, width, height, background) + elif transition_type == "crossfade": + return self._concat_crossfade(inputs, output_path, transition_duration) + elif transition_type == "fade": + return self._concat_fade(inputs, output_path, transition_duration) + else: + raise ValueError(f"Unknown transition type: {transition_type}") + + def _concat_cut( + self, + inputs: List[Path], + output_path: Path, + target_size: str = "first", + width: int = None, + height: int = None, + background: str = "black", + ) -> Path: + """ + Concatenate with scaling/padding to handle different resolutions. + + Args: + inputs: Input video paths + output_path: Output path + target_size: How to determine output size: + - "first": Use first input's dimensions (default) + - "last": Use last input's dimensions + - "largest": Use largest dimensions from all inputs + - "explicit": Use width/height params + width: Explicit width (when target_size="explicit") + height: Explicit height (when target_size="explicit") + background: Padding color (default: black) + """ + output_path.parent.mkdir(parents=True, exist_ok=True) + + n = len(inputs) + input_args = [] + for p in inputs: + input_args.extend(["-i", str(p)]) + + # Get video info for all inputs + infos = [_get_video_info(p) for p in inputs] + + # Determine target dimensions + if target_size == "explicit" and width and height: + target_w, target_h = width, height + elif target_size == "last": + target_w, target_h = infos[-1]["width"], infos[-1]["height"] + elif target_size == "largest": + target_w = max(i["width"] for i in infos) + target_h = max(i["height"] for i in infos) + else: # "first" or default + target_w, target_h = infos[0]["width"], infos[0]["height"] + + # Use common frame rate (from first input) and sample rate + target_fps = infos[0]["fps"] + target_sr = max(i["sample_rate"] for i in infos) + + # Build filter for each input: scale to fit + pad to target size + filter_parts = [] + for i in range(n): + # Scale to fit within target, maintaining aspect ratio, then pad + vf = ( + f"[{i}:v]scale={target_w}:{target_h}:force_original_aspect_ratio=decrease," + f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2:color={background}," + f"setsar=1,fps={target_fps:.6f}[v{i}]" + ) + # Resample audio to common rate + af = f"[{i}:a]aresample={target_sr}[a{i}]" + filter_parts.append(vf) + filter_parts.append(af) + + # Build concat filter + stream_labels = "".join(f"[v{i}][a{i}]" for i in range(n)) + filter_parts.append(f"{stream_labels}concat=n={n}:v=1:a=1[outv][outa]") + + filter_complex = ";".join(filter_parts) + + cmd = [ + "ffmpeg", "-y", + *input_args, + "-filter_complex", filter_complex, + "-map", "[outv]", + "-map", "[outa]", + *get_web_encoding_args(), + str(output_path) + ] + + logger.debug(f"SEQUENCE cut: {n} clips -> {target_w}x{target_h} (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Concat failed: {result.stderr}") + + return output_path + + def _concat_crossfade( + self, + inputs: List[Path], + output_path: Path, + duration: float, + ) -> Path: + """Concatenate with crossfade transitions.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + durations = [_get_duration(p) for p in inputs] + n = len(inputs) + input_args = " ".join(f"-i {p}" for p in inputs) + + # Build xfade filter chain + filter_parts = [] + current = "[0:v]" + + for i in range(1, n): + offset = sum(durations[:i]) - duration * i + next_input = f"[{i}:v]" + output_label = f"[v{i}]" if i < n - 1 else "[outv]" + filter_parts.append( + f"{current}{next_input}xfade=transition=fade:duration={duration}:offset={offset}{output_label}" + ) + current = output_label + + # Audio crossfade chain + audio_current = "[0:a]" + for i in range(1, n): + next_input = f"[{i}:a]" + output_label = f"[a{i}]" if i < n - 1 else "[outa]" + filter_parts.append( + f"{audio_current}{next_input}acrossfade=d={duration}{output_label}" + ) + audio_current = output_label + + filter_complex = ";".join(filter_parts) + + cmd = f'ffmpeg -y {input_args} -filter_complex "{filter_complex}" -map [outv] -map [outa] {WEB_ENCODING_ARGS_STR} {output_path}' + + logger.debug(f"SEQUENCE crossfade: {n} clips (web-optimized)") + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + if result.returncode != 0: + logger.warning(f"Crossfade failed, falling back to cut: {result.stderr[:200]}") + return self._concat_cut(inputs, output_path) + + return output_path + + def _concat_fade( + self, + inputs: List[Path], + output_path: Path, + duration: float, + ) -> Path: + """Concatenate with fade out/in transitions.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + faded_paths = [] + for i, path in enumerate(inputs): + clip_dur = _get_duration(path) + faded_path = output_path.parent / f"_faded_{i}.mkv" + + cmd = [ + "ffmpeg", "-y", + "-i", str(path), + "-vf", f"fade=in:st=0:d={duration},fade=out:st={clip_dur - duration}:d={duration}", + "-af", f"afade=in:st=0:d={duration},afade=out:st={clip_dur - duration}:d={duration}", + "-c:v", "libx264", "-preset", "ultrafast", "-crf", "18", + "-c:a", "aac", + str(faded_path) + ] + subprocess.run(cmd, capture_output=True, check=True) + faded_paths.append(faded_path) + + result = self._concat_cut(faded_paths, output_path) + + for p in faded_paths: + p.unlink() + + return result + + +@register_executor(NodeType.LAYER) +class LayerExecutor(Executor): + """ + Layer inputs spatially (overlay/composite). + + Config: + inputs: List of per-input configs + position: [x, y] offset + opacity: 0.0-1.0 + scale: Scale factor + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 1: + raise ValueError("LAYER requires at least one input") + + if len(inputs) == 1: + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(inputs[0], output_path) + return output_path + + input_configs = config.get("inputs", [{}] * len(inputs)) + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_args = " ".join(f"-i {p}" for p in inputs) + n = len(inputs) + filter_parts = [] + current = "[0:v]" + + for i in range(1, n): + cfg = input_configs[i] if i < len(input_configs) else {} + x, y = cfg.get("position", [0, 0]) + opacity = cfg.get("opacity", 1.0) + scale = cfg.get("scale", 1.0) + + scale_label = f"[s{i}]" + if scale != 1.0: + filter_parts.append(f"[{i}:v]scale=iw*{scale}:ih*{scale}{scale_label}") + overlay_input = scale_label + else: + overlay_input = f"[{i}:v]" + + output_label = f"[v{i}]" if i < n - 1 else "[outv]" + + if opacity < 1.0: + filter_parts.append( + f"{overlay_input}format=rgba,colorchannelmixer=aa={opacity}[a{i}]" + ) + overlay_input = f"[a{i}]" + + filter_parts.append( + f"{current}{overlay_input}overlay=x={x}:y={y}:format=auto{output_label}" + ) + current = output_label + + filter_complex = ";".join(filter_parts) + + cmd = f'ffmpeg -y {input_args} -filter_complex "{filter_complex}" -map [outv] -map 0:a? {WEB_ENCODING_ARGS_STR} {output_path}' + + logger.debug(f"LAYER: {n} inputs (web-optimized)") + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Layer failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.MUX) +class MuxExecutor(Executor): + """ + Combine video and audio streams. + + Config: + video_stream: Index of video input (default: 0) + audio_stream: Index of audio input (default: 1) + shortest: End when shortest stream ends (default: True) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 2: + raise ValueError("MUX requires at least 2 inputs (video + audio)") + + video_idx = config.get("video_stream", 0) + audio_idx = config.get("audio_stream", 1) + shortest = config.get("shortest", True) + + video_path = inputs[video_idx] + audio_path = inputs[audio_idx] + + output_path.parent.mkdir(parents=True, exist_ok=True) + + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "aac", + "-map", "0:v:0", + "-map", "1:a:0", + ] + + if shortest: + cmd.append("-shortest") + + cmd.append(str(output_path)) + + logger.debug(f"MUX: video={video_path.name} + audio={audio_path.name}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Mux failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.BLEND) +class BlendExecutor(Executor): + """ + Blend two inputs using a blend mode. + + Config: + mode: Blend mode (multiply, screen, overlay, add, etc.) + opacity: 0.0-1.0 for second input + """ + + BLEND_MODES = { + "multiply": "multiply", + "screen": "screen", + "overlay": "overlay", + "add": "addition", + "subtract": "subtract", + "average": "average", + "difference": "difference", + "lighten": "lighten", + "darken": "darken", + } + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 2: + raise ValueError("BLEND requires exactly 2 inputs") + + mode = config.get("mode", "overlay") + opacity = config.get("opacity", 0.5) + + if mode not in self.BLEND_MODES: + raise ValueError(f"Unknown blend mode: {mode}") + + output_path.parent.mkdir(parents=True, exist_ok=True) + blend_mode = self.BLEND_MODES[mode] + + if opacity < 1.0: + filter_complex = ( + f"[1:v]format=rgba,colorchannelmixer=aa={opacity}[b];" + f"[0:v][b]blend=all_mode={blend_mode}" + ) + else: + filter_complex = f"[0:v][1:v]blend=all_mode={blend_mode}" + + cmd = [ + "ffmpeg", "-y", + "-i", str(inputs[0]), + "-i", str(inputs[1]), + "-filter_complex", filter_complex, + "-map", "0:a?", + *get_web_encoding_args(), + str(output_path) + ] + + logger.debug(f"BLEND: {mode} (opacity={opacity}) (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Blend failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.AUDIO_MIX) +class AudioMixExecutor(Executor): + """ + Mix multiple audio streams. + + Config: + gains: List of gain values per input (0.0-2.0, default 1.0) + normalize: Normalize output to prevent clipping (default True) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 2: + raise ValueError("AUDIO_MIX requires at least 2 inputs") + + gains = config.get("gains", [1.0] * len(inputs)) + normalize = config.get("normalize", True) + + # Pad gains list if too short + while len(gains) < len(inputs): + gains.append(1.0) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Build filter: apply volume to each input, then mix + filter_parts = [] + mix_inputs = [] + + for i, gain in enumerate(gains[:len(inputs)]): + if gain != 1.0: + filter_parts.append(f"[{i}:a]volume={gain}[a{i}]") + mix_inputs.append(f"[a{i}]") + else: + mix_inputs.append(f"[{i}:a]") + + # amix filter + normalize_flag = 1 if normalize else 0 + mix_filter = f"{''.join(mix_inputs)}amix=inputs={len(inputs)}:normalize={normalize_flag}[aout]" + filter_parts.append(mix_filter) + + filter_complex = ";".join(filter_parts) + + cmd = [ + "ffmpeg", "-y", + ] + for p in inputs: + cmd.extend(["-i", str(p)]) + + cmd.extend([ + "-filter_complex", filter_complex, + "-map", "[aout]", + "-c:a", "aac", + str(output_path) + ]) + + logger.debug(f"AUDIO_MIX: {len(inputs)} inputs, gains={gains[:len(inputs)]}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Audio mix failed: {result.stderr}") + + return output_path diff --git a/artdag/nodes/effect.py b/artdag/nodes/effect.py new file mode 100644 index 0000000..7b36a3d --- /dev/null +++ b/artdag/nodes/effect.py @@ -0,0 +1,520 @@ +# artdag/nodes/effect.py +""" +Effect executor: Apply effects from the registry or IPFS. + +Primitives: EFFECT + +Effects can be: +1. Built-in (registered with @register_effect) +2. Stored in IPFS (referenced by CID) +""" + +import importlib.util +import logging +import os +import re +import shutil +import tempfile +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, TypeVar + +import requests + +from ..executor import Executor, register_executor + +logger = logging.getLogger(__name__) + +# Type alias for effect functions: (input_path, output_path, config) -> output_path +EffectFn = Callable[[Path, Path, Dict[str, Any]], Path] + +# Type variable for decorator +F = TypeVar("F", bound=Callable[..., Any]) + +# IPFS API multiaddr - same as ipfs_client.py for consistency +# Docker uses /dns/ipfs/tcp/5001, local dev uses /ip4/127.0.0.1/tcp/5001 +IPFS_API = os.environ.get("IPFS_API", "/ip4/127.0.0.1/tcp/5001") + +# Connection timeout in seconds +IPFS_TIMEOUT = int(os.environ.get("IPFS_TIMEOUT", "30")) + + +def _get_ipfs_base_url() -> str: + """ + Convert IPFS multiaddr to HTTP URL. + + Matches the conversion logic in ipfs_client.py for consistency. + """ + multiaddr = IPFS_API + + # Handle /dns/hostname/tcp/port format (Docker) + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format (local) + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +def _get_effects_cache_dir() -> Optional[Path]: + """Get the effects cache directory from environment or default.""" + # Check both env var names (CACHE_DIR used by art-celery, ARTDAG_CACHE_DIR for standalone) + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = os.environ.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + if effects_dir.exists(): + return effects_dir + + # Try default locations + for base in [Path.home() / ".artdag" / "cache", Path("/var/cache/artdag")]: + effects_dir = base / "_effects" + if effects_dir.exists(): + return effects_dir + + return None + + +def _fetch_effect_from_ipfs(cid: str, effect_path: Path) -> bool: + """ + Fetch an effect from IPFS and cache locally. + + Uses the IPFS API endpoint (/api/v0/cat) for consistency with ipfs_client.py. + This works reliably in Docker where IPFS_API=/dns/ipfs/tcp/5001. + + Returns True on success, False on failure. + """ + try: + # Use IPFS API (same as ipfs_client.py) + base_url = _get_ipfs_base_url() + url = f"{base_url}/api/v0/cat" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + # Cache locally + effect_path.parent.mkdir(parents=True, exist_ok=True) + effect_path.write_bytes(response.content) + logger.info(f"Fetched effect from IPFS: {cid[:16]}...") + return True + + except Exception as e: + logger.error(f"Failed to fetch effect from IPFS {cid[:16]}...: {e}") + return False + + +def _parse_pep723_dependencies(source: str) -> List[str]: + """ + Parse PEP 723 dependencies from effect source code. + + Returns list of package names (e.g., ["numpy", "opencv-python"]). + """ + match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL) + if not match: + return [] + + block = match.group(1) + deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL) + if not deps_match: + return [] + + return re.findall(r'"([^"]+)"', deps_match.group(1)) + + +def _ensure_dependencies(dependencies: List[str], effect_cid: str) -> bool: + """ + Ensure effect dependencies are installed. + + Installs missing packages using pip. Returns True on success. + """ + if not dependencies: + return True + + missing = [] + for dep in dependencies: + # Extract package name (strip version specifiers) + pkg_name = re.split(r'[<>=!]', dep)[0].strip() + # Normalize name (pip uses underscores, imports use underscores or hyphens) + pkg_name_normalized = pkg_name.replace('-', '_').lower() + + try: + __import__(pkg_name_normalized) + except ImportError: + # Some packages have different import names + try: + # Try original name with hyphens replaced + __import__(pkg_name.replace('-', '_')) + except ImportError: + missing.append(dep) + + if not missing: + return True + + logger.info(f"Installing effect dependencies for {effect_cid[:16]}...: {missing}") + + try: + import subprocess + import sys + + result = subprocess.run( + [sys.executable, "-m", "pip", "install", "--quiet"] + missing, + capture_output=True, + text=True, + timeout=120, + ) + + if result.returncode != 0: + logger.error(f"Failed to install dependencies: {result.stderr}") + return False + + logger.info(f"Installed dependencies: {missing}") + return True + + except Exception as e: + logger.error(f"Error installing dependencies: {e}") + return False + + +def _load_cached_effect(effect_cid: str) -> Optional[EffectFn]: + """ + Load an effect by CID, fetching from IPFS if not cached locally. + + Returns the effect function or None if not found. + """ + effects_dir = _get_effects_cache_dir() + + # Create cache dir if needed + if not effects_dir: + # Try to create default cache dir + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = os.environ.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + break + + if not effects_dir: + effects_dir = Path.home() / ".artdag" / "cache" / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + + effect_path = effects_dir / effect_cid / "effect.py" + + # If not cached locally, fetch from IPFS + if not effect_path.exists(): + if not _fetch_effect_from_ipfs(effect_cid, effect_path): + logger.warning(f"Effect not found: {effect_cid[:16]}...") + return None + + # Parse and install dependencies before loading + try: + source = effect_path.read_text() + dependencies = _parse_pep723_dependencies(source) + if dependencies: + logger.info(f"Effect {effect_cid[:16]}... requires: {dependencies}") + if not _ensure_dependencies(dependencies, effect_cid): + logger.error(f"Failed to install dependencies for effect {effect_cid[:16]}...") + return None + except Exception as e: + logger.error(f"Error parsing effect dependencies: {e}") + # Continue anyway - the effect might work without the deps check + + # Load the effect module + try: + spec = importlib.util.spec_from_file_location("cached_effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Check for frame-by-frame API + if hasattr(module, "process_frame"): + return _wrap_frame_effect(module, effect_path) + + # Check for whole-video API + if hasattr(module, "process"): + return _wrap_video_effect(module) + + # Check for old-style effect function + if hasattr(module, "effect"): + return module.effect + + logger.warning(f"Effect has no recognized API: {effect_cid[:16]}...") + return None + + except Exception as e: + logger.error(f"Failed to load effect {effect_cid[:16]}...: {e}") + return None + + +def _wrap_frame_effect(module: ModuleType, effect_path: Path) -> EffectFn: + """Wrap a frame-by-frame effect to work with the executor API.""" + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run frame-by-frame effect through FFmpeg pipes.""" + try: + from ..effects.frame_processor import process_video + except ImportError: + logger.error("Frame processor not available - falling back to copy") + shutil.copy2(input_path, output_path) + return output_path + + # Extract params from config (excluding internal keys) + params = {k: v for k, v in config.items() + if k not in ("effect", "hash", "_binding")} + + # Get bindings if present + bindings = {} + for key, value in config.items(): + if isinstance(value, dict) and value.get("_resolved_values"): + bindings[key] = value["_resolved_values"] + + output_path.parent.mkdir(parents=True, exist_ok=True) + actual_output = output_path.with_suffix(".mp4") + + process_video( + input_path=input_path, + output_path=actual_output, + process_frame=module.process_frame, + params=params, + bindings=bindings, + ) + + return actual_output + + return wrapped_effect + + +def _wrap_video_effect(module: ModuleType) -> EffectFn: + """Wrap a whole-video effect to work with the executor API.""" + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run whole-video effect.""" + from ..effects.meta import ExecutionContext + + params = {k: v for k, v in config.items() + if k not in ("effect", "hash", "_binding")} + + output_path.parent.mkdir(parents=True, exist_ok=True) + + ctx = ExecutionContext( + input_paths=[str(input_path)], + output_path=str(output_path), + params=params, + seed=hash(str(input_path)) & 0xFFFFFFFF, + ) + + module.process([input_path], output_path, params, ctx) + return output_path + + return wrapped_effect + + +# Effect registry - maps effect names to implementations +_EFFECTS: Dict[str, EffectFn] = {} + + +def register_effect(name: str) -> Callable[[F], F]: + """Decorator to register an effect implementation.""" + def decorator(func: F) -> F: + _EFFECTS[name] = func # type: ignore[assignment] + return func + return decorator + + +def get_effect(name: str) -> Optional[EffectFn]: + """Get an effect implementation by name.""" + return _EFFECTS.get(name) + + +# Built-in effects + +@register_effect("identity") +def effect_identity(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """ + Identity effect - returns input unchanged. + + This is the foundational effect: identity(x) = x + """ + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Remove existing output if any + if output_path.exists() or output_path.is_symlink(): + output_path.unlink() + + # Preserve extension from input + actual_output = output_path.with_suffix(input_path.suffix) + if actual_output.exists() or actual_output.is_symlink(): + actual_output.unlink() + + # Symlink to input (zero-copy identity) + os.symlink(input_path.resolve(), actual_output) + logger.debug(f"EFFECT identity: {input_path.name} -> {actual_output}") + + return actual_output + + +def _get_sexp_effect(effect_path: str, recipe_dir: Path = None) -> Optional[EffectFn]: + """ + Load a sexp effect from a .sexp file. + + Args: + effect_path: Relative path to the .sexp effect file + recipe_dir: Base directory for resolving paths + + Returns: + Effect function or None if not a sexp effect + """ + if not effect_path or not effect_path.endswith(".sexp"): + return None + + try: + from ..sexp.effect_loader import SexpEffectLoader + except ImportError: + logger.warning("Sexp effect loader not available") + return None + + try: + loader = SexpEffectLoader(recipe_dir or Path.cwd()) + return loader.load_effect_path(effect_path) + except Exception as e: + logger.error(f"Failed to load sexp effect from {effect_path}: {e}") + return None + + +def _get_python_primitive_effect(effect_name: str) -> Optional[EffectFn]: + """ + Get a Python primitive frame processor effect. + + Checks if the effect has a python_primitive in FFmpegCompiler.EFFECT_MAPPINGS + and wraps it for the executor API. + """ + try: + from ..sexp.ffmpeg_compiler import FFmpegCompiler + from ..sexp.primitives import get_primitive + from ..effects.frame_processor import process_video + except ImportError: + return None + + compiler = FFmpegCompiler() + primitive_name = compiler.has_python_primitive(effect_name) + if not primitive_name: + return None + + primitive_fn = get_primitive(primitive_name) + if not primitive_fn: + logger.warning(f"Python primitive '{primitive_name}' not found for effect '{effect_name}'") + return None + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run Python primitive effect via frame processor.""" + # Extract params (excluding internal keys) + params = {k: v for k, v in config.items() + if k not in ("effect", "cid", "hash", "effect_path", "_binding")} + + # Get bindings if present + bindings = {} + for key, value in config.items(): + if isinstance(value, dict) and value.get("_resolved_values"): + bindings[key] = value["_resolved_values"] + + output_path.parent.mkdir(parents=True, exist_ok=True) + actual_output = output_path.with_suffix(".mp4") + + # Wrap primitive to match frame processor signature + def process_frame(frame, frame_params, state): + # Call primitive with frame and params + result = primitive_fn(frame, **frame_params) + return result, state + + process_video( + input_path=input_path, + output_path=actual_output, + process_frame=process_frame, + params=params, + bindings=bindings, + ) + + logger.info(f"Processed effect '{effect_name}' via Python primitive '{primitive_name}'") + return actual_output + + return wrapped_effect + + +@register_executor("EFFECT") +class EffectExecutor(Executor): + """ + Apply an effect from the registry or IPFS. + + Config: + effect: Name of the effect to apply + cid: IPFS CID for the effect (fetched from IPFS if not cached) + hash: Legacy alias for cid (backwards compatibility) + params: Optional parameters for the effect + + Inputs: + Single input file to transform + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + effect_name = config.get("effect") + # Support both "cid" (new) and "hash" (legacy) + effect_cid = config.get("cid") or config.get("hash") + + if not effect_name: + raise ValueError("EFFECT requires 'effect' config") + + if len(inputs) != 1: + raise ValueError(f"EFFECT expects 1 input, got {len(inputs)}") + + # Try IPFS effect first if CID provided + effect_fn: Optional[EffectFn] = None + if effect_cid: + effect_fn = _load_cached_effect(effect_cid) + if effect_fn: + logger.info(f"Running effect '{effect_name}' (cid={effect_cid[:16]}...)") + + # Try sexp effect from effect_path (.sexp file) + if effect_fn is None: + effect_path = config.get("effect_path") + if effect_path and effect_path.endswith(".sexp"): + effect_fn = _get_sexp_effect(effect_path) + if effect_fn: + logger.info(f"Running effect '{effect_name}' via sexp definition") + + # Try Python primitive (from FFmpegCompiler.EFFECT_MAPPINGS) + if effect_fn is None: + effect_fn = _get_python_primitive_effect(effect_name) + if effect_fn: + logger.info(f"Running effect '{effect_name}' via Python primitive") + + # Fall back to built-in effect + if effect_fn is None: + effect_fn = get_effect(effect_name) + + if effect_fn is None: + raise ValueError(f"Unknown effect: {effect_name}") + + # Pass full config (effect can extract what it needs) + return effect_fn(inputs[0], output_path, config) + + def validate_config(self, config: Dict[str, Any]) -> List[str]: + errors = [] + if "effect" not in config: + errors.append("EFFECT requires 'effect' config") + else: + # If CID provided, we'll load from IPFS - skip built-in check + has_cid = config.get("cid") or config.get("hash") + if not has_cid and get_effect(config["effect"]) is None: + errors.append(f"Unknown effect: {config['effect']}") + return errors diff --git a/artdag/nodes/encoding.py b/artdag/nodes/encoding.py new file mode 100644 index 0000000..863d062 --- /dev/null +++ b/artdag/nodes/encoding.py @@ -0,0 +1,50 @@ +# artdag/nodes/encoding.py +""" +Web-optimized video encoding settings. + +Provides common FFmpeg arguments for producing videos that: +- Stream efficiently (faststart) +- Play on all browsers (H.264 High profile) +- Support seeking (regular keyframes) +""" + +from typing import List + +# Standard web-optimized video encoding arguments +WEB_VIDEO_ARGS: List[str] = [ + "-c:v", "libx264", + "-preset", "fast", + "-crf", "18", + "-profile:v", "high", + "-level", "4.1", + "-pix_fmt", "yuv420p", # Ensure broad compatibility + "-movflags", "+faststart", # Enable streaming before full download + "-g", "48", # Keyframe every ~2 seconds at 24fps (for seeking) +] + +# Standard audio encoding arguments +WEB_AUDIO_ARGS: List[str] = [ + "-c:a", "aac", + "-b:a", "192k", +] + + +def get_web_encoding_args() -> List[str]: + """Get FFmpeg args for web-optimized video+audio encoding.""" + return WEB_VIDEO_ARGS + WEB_AUDIO_ARGS + + +def get_web_video_args() -> List[str]: + """Get FFmpeg args for web-optimized video encoding only.""" + return WEB_VIDEO_ARGS.copy() + + +def get_web_audio_args() -> List[str]: + """Get FFmpeg args for web-optimized audio encoding only.""" + return WEB_AUDIO_ARGS.copy() + + +# For shell commands (string format) +WEB_VIDEO_ARGS_STR = " ".join(WEB_VIDEO_ARGS) +WEB_AUDIO_ARGS_STR = " ".join(WEB_AUDIO_ARGS) +WEB_ENCODING_ARGS_STR = f"{WEB_VIDEO_ARGS_STR} {WEB_AUDIO_ARGS_STR}" diff --git a/artdag/nodes/source.py b/artdag/nodes/source.py new file mode 100644 index 0000000..1fc7ef1 --- /dev/null +++ b/artdag/nodes/source.py @@ -0,0 +1,62 @@ +# primitive/nodes/source.py +""" +Source executors: Load media from paths. + +Primitives: SOURCE +""" + +import logging +import os +import shutil +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor + +logger = logging.getLogger(__name__) + + +@register_executor(NodeType.SOURCE) +class SourceExecutor(Executor): + """ + Load source media from a path. + + Config: + path: Path to source file + + Creates a symlink to the source file for zero-copy loading. + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + source_path = Path(config["path"]) + + if not source_path.exists(): + raise FileNotFoundError(f"Source file not found: {source_path}") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Use symlink for zero-copy + if output_path.exists() or output_path.is_symlink(): + output_path.unlink() + + # Preserve extension from source + actual_output = output_path.with_suffix(source_path.suffix) + if actual_output.exists() or actual_output.is_symlink(): + actual_output.unlink() + + os.symlink(source_path.resolve(), actual_output) + logger.debug(f"SOURCE: {source_path.name} -> {actual_output}") + + return actual_output + + def validate_config(self, config: Dict[str, Any]) -> List[str]: + errors = [] + if "path" not in config: + errors.append("SOURCE requires 'path' config") + return errors diff --git a/artdag/nodes/transform.py b/artdag/nodes/transform.py new file mode 100644 index 0000000..e91ba6f --- /dev/null +++ b/artdag/nodes/transform.py @@ -0,0 +1,224 @@ +# primitive/nodes/transform.py +""" +Transform executors: Modify single media inputs. + +Primitives: SEGMENT, RESIZE, TRANSFORM +""" + +import logging +import subprocess +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor +from .encoding import get_web_encoding_args, get_web_video_args + +logger = logging.getLogger(__name__) + + +@register_executor(NodeType.SEGMENT) +class SegmentExecutor(Executor): + """ + Extract a time segment from media. + + Config: + offset: Start time in seconds (default: 0) + duration: Duration in seconds (optional, default: to end) + precise: Use frame-accurate seeking (default: True) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("SEGMENT requires exactly one input") + + input_path = inputs[0] + offset = config.get("offset", 0) + duration = config.get("duration") + precise = config.get("precise", True) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + if precise: + # Frame-accurate: decode-seek (slower but precise) + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if offset > 0: + cmd.extend(["-ss", str(offset)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend([*get_web_encoding_args(), str(output_path)]) + else: + # Fast: input-seek at keyframes (may be slightly off) + cmd = ["ffmpeg", "-y"] + if offset > 0: + cmd.extend(["-ss", str(offset)]) + cmd.extend(["-i", str(input_path)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend(["-c", "copy", str(output_path)]) + + logger.debug(f"SEGMENT: offset={offset}, duration={duration}, precise={precise}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Segment failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.RESIZE) +class ResizeExecutor(Executor): + """ + Resize media to target dimensions. + + Config: + width: Target width + height: Target height + mode: "fit" (letterbox), "fill" (crop), "stretch", "pad" + background: Background color for pad mode (default: black) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("RESIZE requires exactly one input") + + input_path = inputs[0] + width = config["width"] + height = config["height"] + mode = config.get("mode", "fit") + background = config.get("background", "black") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + if mode == "fit": + # Scale to fit, add letterboxing + vf = f"scale={width}:{height}:force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}" + elif mode == "fill": + # Scale to fill, crop excess + vf = f"scale={width}:{height}:force_original_aspect_ratio=increase,crop={width}:{height}" + elif mode == "stretch": + # Stretch to exact size + vf = f"scale={width}:{height}" + elif mode == "pad": + # Scale down only if larger, then pad + vf = f"scale='min({width},iw)':'min({height},ih)':force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}" + else: + raise ValueError(f"Unknown resize mode: {mode}") + + cmd = [ + "ffmpeg", "-y", + "-i", str(input_path), + "-vf", vf, + *get_web_video_args(), + "-c:a", "copy", + str(output_path) + ] + + logger.debug(f"RESIZE: {width}x{height} ({mode}) (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Resize failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.TRANSFORM) +class TransformExecutor(Executor): + """ + Apply visual effects to media. + + Config: + effects: Dict of effect -> value + saturation: 0.0-2.0 (1.0 = normal) + contrast: 0.0-2.0 (1.0 = normal) + brightness: -1.0 to 1.0 (0.0 = normal) + gamma: 0.1-10.0 (1.0 = normal) + hue: degrees shift + blur: blur radius + sharpen: sharpen amount + speed: playback speed multiplier + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("TRANSFORM requires exactly one input") + + input_path = inputs[0] + effects = config.get("effects", {}) + + if not effects: + # No effects - just copy + import shutil + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(input_path, output_path) + return output_path + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Build filter chain + vf_parts = [] + af_parts = [] + + # Color adjustments via eq filter + eq_parts = [] + if "saturation" in effects: + eq_parts.append(f"saturation={effects['saturation']}") + if "contrast" in effects: + eq_parts.append(f"contrast={effects['contrast']}") + if "brightness" in effects: + eq_parts.append(f"brightness={effects['brightness']}") + if "gamma" in effects: + eq_parts.append(f"gamma={effects['gamma']}") + if eq_parts: + vf_parts.append(f"eq={':'.join(eq_parts)}") + + # Hue adjustment + if "hue" in effects: + vf_parts.append(f"hue=h={effects['hue']}") + + # Blur + if "blur" in effects: + vf_parts.append(f"boxblur={effects['blur']}") + + # Sharpen + if "sharpen" in effects: + vf_parts.append(f"unsharp=5:5:{effects['sharpen']}:5:5:0") + + # Speed change + if "speed" in effects: + speed = effects["speed"] + vf_parts.append(f"setpts={1/speed}*PTS") + af_parts.append(f"atempo={speed}") + + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + + if vf_parts: + cmd.extend(["-vf", ",".join(vf_parts)]) + if af_parts: + cmd.extend(["-af", ",".join(af_parts)]) + + cmd.extend([*get_web_encoding_args(), str(output_path)]) + + logger.debug(f"TRANSFORM: {list(effects.keys())} (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Transform failed: {result.stderr}") + + return output_path diff --git a/artdag/planning/__init__.py b/artdag/planning/__init__.py new file mode 100644 index 0000000..1d5c89f --- /dev/null +++ b/artdag/planning/__init__.py @@ -0,0 +1,29 @@ +# artdag/planning - Execution plan generation +# +# Provides the Planning phase of the 3-phase execution model: +# 1. ANALYZE - Extract features from inputs +# 2. PLAN - Generate execution plan with cache IDs +# 3. EXECUTE - Run steps with caching + +from .schema import ( + ExecutionStep, + ExecutionPlan, + StepStatus, + StepOutput, + StepInput, + PlanInput, +) +from .planner import RecipePlanner, Recipe +from .tree_reduction import TreeReducer + +__all__ = [ + "ExecutionStep", + "ExecutionPlan", + "StepStatus", + "StepOutput", + "StepInput", + "PlanInput", + "RecipePlanner", + "Recipe", + "TreeReducer", +] diff --git a/artdag/planning/planner.py b/artdag/planning/planner.py new file mode 100644 index 0000000..18f30d8 --- /dev/null +++ b/artdag/planning/planner.py @@ -0,0 +1,756 @@ +# artdag/planning/planner.py +""" +Recipe planner - converts recipes into execution plans. + +The planner is the second phase of the 3-phase execution model. +It takes a recipe and analysis results and generates a complete +execution plan with pre-computed cache IDs. +""" + +import hashlib +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml + +from .schema import ExecutionPlan, ExecutionStep, StepOutput, StepInput, PlanInput +from .tree_reduction import TreeReducer, reduce_sequence +from ..analysis import AnalysisResult + + +def _infer_media_type(node_type: str, config: Dict[str, Any] = None) -> str: + """Infer media type from node type and config.""" + config = config or {} + + # Audio operations + if node_type in ("AUDIO", "MIX_AUDIO", "EXTRACT_AUDIO"): + return "audio/wav" + if "audio" in node_type.lower(): + return "audio/wav" + + # Image operations + if node_type in ("FRAME", "THUMBNAIL", "IMAGE"): + return "image/png" + + # Default to video + return "video/mp4" + +logger = logging.getLogger(__name__) + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """Create stable hash from arbitrary data.""" + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class RecipeNode: + """A node in the recipe DAG.""" + id: str + type: str + config: Dict[str, Any] + inputs: List[str] + + +@dataclass +class Recipe: + """Parsed recipe structure.""" + name: str + version: str + description: str + nodes: List[RecipeNode] + output: str + registry: Dict[str, Any] + owner: str + raw_yaml: str + + @property + def recipe_hash(self) -> str: + """Compute hash of recipe content.""" + return _stable_hash({"yaml": self.raw_yaml}) + + @classmethod + def from_yaml(cls, yaml_content: str) -> "Recipe": + """Parse recipe from YAML string.""" + data = yaml.safe_load(yaml_content) + + nodes = [] + for node_data in data.get("dag", {}).get("nodes", []): + # Handle both 'inputs' as list and 'inputs' as dict + inputs = node_data.get("inputs", []) + if isinstance(inputs, dict): + # Extract input references from dict structure + input_list = [] + for key, value in inputs.items(): + if isinstance(value, str): + input_list.append(value) + elif isinstance(value, list): + input_list.extend(value) + inputs = input_list + elif isinstance(inputs, str): + inputs = [inputs] + + nodes.append(RecipeNode( + id=node_data["id"], + type=node_data["type"], + config=node_data.get("config", {}), + inputs=inputs, + )) + + return cls( + name=data.get("name", "unnamed"), + version=data.get("version", "1.0"), + description=data.get("description", ""), + nodes=nodes, + output=data.get("dag", {}).get("output", ""), + registry=data.get("registry", {}), + owner=data.get("owner", ""), + raw_yaml=yaml_content, + ) + + @classmethod + def from_file(cls, path: Path) -> "Recipe": + """Load recipe from YAML file.""" + with open(path, "r") as f: + return cls.from_yaml(f.read()) + + +class RecipePlanner: + """ + Generates execution plans from recipes. + + The planner: + 1. Parses the recipe + 2. Resolves fixed inputs from registry + 3. Maps variable inputs to provided hashes + 4. Expands MAP/iteration nodes + 5. Applies tree reduction for SEQUENCE nodes + 6. Computes cache IDs for all steps + """ + + def __init__(self, use_tree_reduction: bool = True): + """ + Initialize the planner. + + Args: + use_tree_reduction: Whether to use tree reduction for SEQUENCE + """ + self.use_tree_reduction = use_tree_reduction + + def plan( + self, + recipe: Recipe, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + seed: Optional[int] = None, + ) -> ExecutionPlan: + """ + Generate an execution plan from a recipe. + + Args: + recipe: The parsed recipe + input_hashes: Mapping from input name to content hash + analysis: Analysis results for inputs (keyed by hash) + seed: Random seed for deterministic planning + + Returns: + ExecutionPlan with pre-computed cache IDs + """ + logger.info(f"Planning recipe: {recipe.name}") + + # Build node lookup + nodes_by_id = {n.id: n for n in recipe.nodes} + + # Topologically sort nodes + sorted_ids = self._topological_sort(recipe.nodes) + + # Resolve registry references + registry_hashes = self._resolve_registry(recipe.registry) + + # Build PlanInput objects from input_hashes + plan_inputs = [] + for name, cid in input_hashes.items(): + # Try to find matching SOURCE node for media type + media_type = "application/octet-stream" + for node in recipe.nodes: + if node.id == name and node.type == "SOURCE": + media_type = _infer_media_type("SOURCE", node.config) + break + + plan_inputs.append(PlanInput( + name=name, + cache_id=cid, + cid=cid, + media_type=media_type, + )) + + # Generate steps + steps = [] + step_id_map = {} # Maps recipe node ID to step ID(s) + step_name_map = {} # Maps recipe node ID to human-readable name + analysis_cache_ids = {} + + for node_id in sorted_ids: + node = nodes_by_id[node_id] + logger.debug(f"Processing node: {node.id} ({node.type})") + + new_steps, output_step_id = self._process_node( + node=node, + step_id_map=step_id_map, + step_name_map=step_name_map, + input_hashes=input_hashes, + registry_hashes=registry_hashes, + analysis=analysis or {}, + recipe_name=recipe.name, + ) + + steps.extend(new_steps) + step_id_map[node_id] = output_step_id + # Track human-readable name for this node + if new_steps: + step_name_map[node_id] = new_steps[-1].name + + # Find output step + output_step = step_id_map.get(recipe.output) + if not output_step: + raise ValueError(f"Output node '{recipe.output}' not found") + + # Determine output name + output_name = f"{recipe.name}.output" + output_step_obj = next((s for s in steps if s.step_id == output_step), None) + if output_step_obj and output_step_obj.outputs: + output_name = output_step_obj.outputs[0].name + + # Build analysis cache IDs + if analysis: + analysis_cache_ids = { + h: a.cache_id for h, a in analysis.items() + if a.cache_id + } + + # Create plan + plan = ExecutionPlan( + plan_id=None, # Computed in __post_init__ + name=f"{recipe.name}_plan", + recipe_id=recipe.name, + recipe_name=recipe.name, + recipe_hash=recipe.recipe_hash, + seed=seed, + inputs=plan_inputs, + steps=steps, + output_step=output_step, + output_name=output_name, + analysis_cache_ids=analysis_cache_ids, + input_hashes=input_hashes, + metadata={ + "recipe_version": recipe.version, + "recipe_description": recipe.description, + "owner": recipe.owner, + }, + ) + + # Compute all cache IDs and then generate outputs + plan.compute_all_cache_ids() + plan.compute_levels() + + # Now add outputs to each step (needs cache_id to be computed first) + self._add_step_outputs(plan, recipe.name) + + # Recompute plan_id after outputs are added + plan.plan_id = plan._compute_plan_id() + + logger.info(f"Generated plan with {len(steps)} steps") + return plan + + def _add_step_outputs(self, plan: ExecutionPlan, recipe_name: str) -> None: + """Add output definitions to each step after cache_ids are computed.""" + for step in plan.steps: + if step.outputs: + continue # Already has outputs + + # Generate output name from step name + base_name = step.name or step.step_id + output_name = f"{recipe_name}.{base_name}.out" + + media_type = _infer_media_type(step.node_type, step.config) + + step.add_output( + name=output_name, + media_type=media_type, + index=0, + metadata={}, + ) + + def plan_from_yaml( + self, + yaml_content: str, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + ) -> ExecutionPlan: + """ + Generate plan from YAML string. + + Args: + yaml_content: Recipe YAML content + input_hashes: Mapping from input name to content hash + analysis: Analysis results + + Returns: + ExecutionPlan + """ + recipe = Recipe.from_yaml(yaml_content) + return self.plan(recipe, input_hashes, analysis) + + def plan_from_file( + self, + recipe_path: Path, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + ) -> ExecutionPlan: + """ + Generate plan from recipe file. + + Args: + recipe_path: Path to recipe YAML file + input_hashes: Mapping from input name to content hash + analysis: Analysis results + + Returns: + ExecutionPlan + """ + recipe = Recipe.from_file(recipe_path) + return self.plan(recipe, input_hashes, analysis) + + def _topological_sort(self, nodes: List[RecipeNode]) -> List[str]: + """Topologically sort recipe nodes.""" + nodes_by_id = {n.id: n for n in nodes} + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + if node_id not in nodes_by_id: + return # External input + visited.add(node_id) + node = nodes_by_id[node_id] + for input_id in node.inputs: + visit(input_id) + order.append(node_id) + + for node in nodes: + visit(node.id) + + return order + + def _resolve_registry(self, registry: Dict[str, Any]) -> Dict[str, str]: + """ + Resolve registry references to content hashes. + + Args: + registry: Registry section from recipe + + Returns: + Mapping from name to content hash + """ + hashes = {} + + # Assets + for name, asset_data in registry.get("assets", {}).items(): + if isinstance(asset_data, dict) and "hash" in asset_data: + hashes[name] = asset_data["hash"] + elif isinstance(asset_data, str): + hashes[name] = asset_data + + # Effects + for name, effect_data in registry.get("effects", {}).items(): + if isinstance(effect_data, dict) and "hash" in effect_data: + hashes[f"effect:{name}"] = effect_data["hash"] + elif isinstance(effect_data, str): + hashes[f"effect:{name}"] = effect_data + + return hashes + + def _process_node( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + step_name_map: Dict[str, str], + input_hashes: Dict[str, str], + registry_hashes: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process a recipe node into execution steps. + + Args: + node: Recipe node to process + step_id_map: Mapping from processed node IDs to step IDs + step_name_map: Mapping from node IDs to human-readable names + input_hashes: User-provided input hashes + registry_hashes: Registry-resolved hashes + analysis: Analysis results + recipe_name: Name of the recipe (for generating readable names) + + Returns: + Tuple of (new steps, output step ID) + """ + # SOURCE nodes + if node.type == "SOURCE": + return self._process_source(node, input_hashes, registry_hashes, recipe_name) + + # SOURCE_LIST nodes + if node.type == "SOURCE_LIST": + return self._process_source_list(node, input_hashes, recipe_name) + + # ANALYZE nodes + if node.type == "ANALYZE": + return self._process_analyze(node, step_id_map, analysis, recipe_name) + + # MAP nodes + if node.type == "MAP": + return self._process_map(node, step_id_map, input_hashes, analysis, recipe_name) + + # SEQUENCE nodes (may use tree reduction) + if node.type == "SEQUENCE": + return self._process_sequence(node, step_id_map, recipe_name) + + # SEGMENT_AT nodes + if node.type == "SEGMENT_AT": + return self._process_segment_at(node, step_id_map, analysis, recipe_name) + + # Standard nodes (SEGMENT, RESIZE, TRANSFORM, LAYER, MUX, BLEND, etc.) + return self._process_standard(node, step_id_map, recipe_name) + + def _process_source( + self, + node: RecipeNode, + input_hashes: Dict[str, str], + registry_hashes: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """Process SOURCE node.""" + config = dict(node.config) + + # Variable input? + if config.get("input"): + # Look up in user-provided inputs + if node.id not in input_hashes: + raise ValueError(f"Missing input for SOURCE node '{node.id}'") + cid = input_hashes[node.id] + # Fixed asset from registry? + elif config.get("asset"): + asset_name = config["asset"] + if asset_name not in registry_hashes: + raise ValueError(f"Asset '{asset_name}' not found in registry") + cid = registry_hashes[asset_name] + else: + raise ValueError(f"SOURCE node '{node.id}' has no input or asset") + + # Human-readable name + display_name = config.get("name", node.id) + step_name = f"{recipe_name}.inputs.{display_name}" if recipe_name else display_name + + step = ExecutionStep( + step_id=node.id, + node_type="SOURCE", + config={"input_ref": node.id, "cid": cid}, + input_steps=[], + cache_id=cid, # SOURCE cache_id is just the content hash + name=step_name, + ) + + return [step], step.step_id + + def _process_source_list( + self, + node: RecipeNode, + input_hashes: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SOURCE_LIST node. + + Creates individual SOURCE steps for each item in the list. + """ + # Look for list input + if node.id not in input_hashes: + raise ValueError(f"Missing input for SOURCE_LIST node '{node.id}'") + + input_value = input_hashes[node.id] + + # Parse as comma-separated list if string + if isinstance(input_value, str): + items = [h.strip() for h in input_value.split(",")] + else: + items = list(input_value) + + display_name = node.config.get("name", node.id) + base_name = f"{recipe_name}.{display_name}" if recipe_name else display_name + + steps = [] + for i, cid in enumerate(items): + step = ExecutionStep( + step_id=f"{node.id}_{i}", + node_type="SOURCE", + config={"input_ref": f"{node.id}[{i}]", "cid": cid}, + input_steps=[], + cache_id=cid, + name=f"{base_name}[{i}]", + ) + steps.append(step) + + # Return list marker as output + list_step = ExecutionStep( + step_id=node.id, + node_type="_LIST", + config={"items": [s.step_id for s in steps]}, + input_steps=[s.step_id for s in steps], + name=f"{base_name}.list", + ) + steps.append(list_step) + + return steps, list_step.step_id + + def _process_analyze( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process ANALYZE node. + + ANALYZE nodes reference pre-computed analysis results. + """ + input_step = step_id_map.get(node.inputs[0]) if node.inputs else None + if not input_step: + raise ValueError(f"ANALYZE node '{node.id}' has no input") + + feature = node.config.get("feature", "all") + step_name = f"{recipe_name}.analysis.{feature}" if recipe_name else f"analysis.{feature}" + + step = ExecutionStep( + step_id=node.id, + node_type="ANALYZE", + config={ + "feature": feature, + **node.config, + }, + input_steps=[input_step], + name=step_name, + ) + + return [step], step.step_id + + def _process_map( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + input_hashes: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process MAP node - expand iteration over list. + + MAP applies an operation to each item in a list. + """ + operation = node.config.get("operation", "TRANSFORM") + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + + # Get items input + items_ref = node.config.get("items") or ( + node.inputs[0] if isinstance(node.inputs, list) else + node.inputs.get("items") if isinstance(node.inputs, dict) else None + ) + + if not items_ref: + raise ValueError(f"MAP node '{node.id}' has no items input") + + # Resolve items to list of step IDs + if items_ref in step_id_map: + # Reference to SOURCE_LIST output + items_step = step_id_map[items_ref] + # TODO: expand list items + logger.warning(f"MAP node '{node.id}' references list step, expansion TBD") + item_steps = [items_step] + else: + item_steps = [items_ref] + + # Generate step for each item + steps = [] + output_steps = [] + + for i, item_step in enumerate(item_steps): + step_id = f"{node.id}_{i}" + + if operation == "RANDOM_SLICE": + step = ExecutionStep( + step_id=step_id, + node_type="SEGMENT", + config={ + "random": True, + "seed_from": node.config.get("seed_from"), + "index": i, + }, + input_steps=[item_step], + name=f"{base_name}.slice[{i}]", + ) + elif operation == "TRANSFORM": + step = ExecutionStep( + step_id=step_id, + node_type="TRANSFORM", + config=node.config.get("effects", {}), + input_steps=[item_step], + name=f"{base_name}.transform[{i}]", + ) + elif operation == "ANALYZE": + step = ExecutionStep( + step_id=step_id, + node_type="ANALYZE", + config={"feature": node.config.get("feature", "all")}, + input_steps=[item_step], + name=f"{base_name}.analyze[{i}]", + ) + else: + step = ExecutionStep( + step_id=step_id, + node_type=operation, + config=node.config, + input_steps=[item_step], + name=f"{base_name}.{operation.lower()}[{i}]", + ) + + steps.append(step) + output_steps.append(step_id) + + # Create list output + list_step = ExecutionStep( + step_id=node.id, + node_type="_LIST", + config={"items": output_steps}, + input_steps=output_steps, + name=f"{base_name}.results", + ) + steps.append(list_step) + + return steps, list_step.step_id + + def _process_sequence( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SEQUENCE node. + + Uses tree reduction for parallel composition if enabled. + """ + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + + # Resolve input steps + input_steps = [] + for input_id in node.inputs: + if input_id in step_id_map: + input_steps.append(step_id_map[input_id]) + else: + input_steps.append(input_id) + + if len(input_steps) == 0: + raise ValueError(f"SEQUENCE node '{node.id}' has no inputs") + + if len(input_steps) == 1: + # Single input, no sequence needed + return [], input_steps[0] + + transition_config = node.config.get("transition", {"type": "cut"}) + config = {"transition": transition_config} + + if self.use_tree_reduction and len(input_steps) > 2: + # Use tree reduction + reduction_steps, output_id = reduce_sequence( + input_steps, + transition_config=config, + id_prefix=node.id, + ) + + steps = [] + for i, (step_id, inputs, step_config) in enumerate(reduction_steps): + step = ExecutionStep( + step_id=step_id, + node_type="SEQUENCE", + config=step_config, + input_steps=inputs, + name=f"{base_name}.reduce[{i}]", + ) + steps.append(step) + + return steps, output_id + else: + # Direct sequence + step = ExecutionStep( + step_id=node.id, + node_type="SEQUENCE", + config=config, + input_steps=input_steps, + name=f"{base_name}.concat", + ) + return [step], step.step_id + + def _process_segment_at( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SEGMENT_AT node - cut at specific times. + + Creates SEGMENT steps for each time range. + """ + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + times_from = node.config.get("times_from") + distribute = node.config.get("distribute", "round_robin") + + # TODO: Resolve times from analysis + # For now, create a placeholder + step = ExecutionStep( + step_id=node.id, + node_type="SEGMENT_AT", + config=node.config, + input_steps=[step_id_map.get(i, i) for i in node.inputs], + name=f"{base_name}.segment", + ) + + return [step], step.step_id + + def _process_standard( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """Process standard transformation/composition node.""" + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + input_steps = [step_id_map.get(i, i) for i in node.inputs] + + step = ExecutionStep( + step_id=node.id, + node_type=node.type, + config=node.config, + input_steps=input_steps, + name=f"{base_name}.{node.type.lower()}", + ) + + return [step], step.step_id diff --git a/artdag/planning/schema.py b/artdag/planning/schema.py new file mode 100644 index 0000000..9831d16 --- /dev/null +++ b/artdag/planning/schema.py @@ -0,0 +1,594 @@ +# artdag/planning/schema.py +""" +Data structures for execution plans. + +An ExecutionPlan contains all steps needed to execute a recipe, +with pre-computed cache IDs for each step. +""" + +import hashlib +import json +import os +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional + + +# Cluster key for trust domains +# Systems with the same key produce the same cache_ids and can share work +# Systems with different keys have isolated cache namespaces +CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY") + + +def set_cluster_key(key: Optional[str]) -> None: + """Set the cluster key programmatically.""" + global CLUSTER_KEY + CLUSTER_KEY = key + + +def get_cluster_key() -> Optional[str]: + """Get the current cluster key.""" + return CLUSTER_KEY + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """ + Create stable hash from arbitrary data. + + If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create + isolated trust domains. Systems with the same key can share work; + systems with different keys have separate cache namespaces. + """ + # Mix in cluster key if set + if CLUSTER_KEY: + data = {"_cluster_key": CLUSTER_KEY, "_data": data} + + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +class StepStatus(Enum): + """Status of an execution step.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + CACHED = "cached" + FAILED = "failed" + SKIPPED = "skipped" + + +@dataclass +class StepOutput: + """ + A single output from an execution step. + + Nodes may produce multiple outputs (e.g., split_on_beats produces N segments). + Each output has a human-readable name and a cache_id for storage. + + Attributes: + name: Human-readable name (e.g., "beats.split.segment[0]") + cache_id: Content-addressed hash for caching + media_type: MIME type of the output (e.g., "video/mp4", "audio/wav") + index: Output index for multi-output nodes + metadata: Optional additional metadata (time_range, etc.) + """ + name: str + cache_id: str + media_type: str = "application/octet-stream" + index: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "cache_id": self.cache_id, + "media_type": self.media_type, + "index": self.index, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "StepOutput": + return cls( + name=data["name"], + cache_id=data["cache_id"], + media_type=data.get("media_type", "application/octet-stream"), + index=data.get("index", 0), + metadata=data.get("metadata", {}), + ) + + +@dataclass +class StepInput: + """ + Reference to an input for a step. + + Inputs can reference outputs from other steps by name. + + Attributes: + name: Input slot name (e.g., "video", "audio", "segments") + source: Source output name (e.g., "beats.split.segment[0]") + cache_id: Resolved cache_id of the source (populated during planning) + """ + name: str + source: str + cache_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "source": self.source, + "cache_id": self.cache_id, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "StepInput": + return cls( + name=data["name"], + source=data["source"], + cache_id=data.get("cache_id"), + ) + + +@dataclass +class ExecutionStep: + """ + A single step in the execution plan. + + Each step has a pre-computed cache_id that uniquely identifies + its output based on its configuration and input cache_ids. + + Steps can produce multiple outputs (e.g., split_on_beats produces N segments). + Each output has its own cache_id derived from the step's cache_id + index. + + Attributes: + name: Human-readable name relating to recipe (e.g., "beats.split") + step_id: Unique identifier (hash) for this step + node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.) + config: Configuration for the primitive + input_steps: IDs of steps this depends on (legacy, use inputs for new code) + inputs: Structured input references with names and sources + cache_id: Pre-computed cache ID (hash of config + input cache_ids) + outputs: List of outputs this step produces + estimated_duration: Optional estimated execution time + level: Dependency level (0 = no dependencies, higher = more deps) + """ + step_id: str + node_type: str + config: Dict[str, Any] + input_steps: List[str] = field(default_factory=list) + inputs: List[StepInput] = field(default_factory=list) + cache_id: Optional[str] = None + outputs: List[StepOutput] = field(default_factory=list) + name: Optional[str] = None + estimated_duration: Optional[float] = None + level: int = 0 + + def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str: + """ + Compute cache ID from configuration and input cache IDs. + + cache_id = SHA3-256(node_type + config + sorted(input_cache_ids)) + + Args: + input_cache_ids: Mapping from input step_id/name to their cache_id + + Returns: + The computed cache_id + """ + # Use structured inputs if available, otherwise fall back to input_steps + if self.inputs: + resolved_inputs = [ + inp.cache_id or input_cache_ids.get(inp.source, inp.source) + for inp in sorted(self.inputs, key=lambda x: x.name) + ] + else: + resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)] + + content = { + "node_type": self.node_type, + "config": self.config, + "inputs": resolved_inputs, + } + self.cache_id = _stable_hash(content) + return self.cache_id + + def compute_output_cache_id(self, index: int) -> str: + """ + Compute cache ID for a specific output index. + + output_cache_id = SHA3-256(step_cache_id + index) + + Args: + index: The output index + + Returns: + Cache ID for that output + """ + if not self.cache_id: + raise ValueError("Step cache_id must be computed first") + content = {"step_cache_id": self.cache_id, "output_index": index} + return _stable_hash(content) + + def add_output( + self, + name: str, + media_type: str = "application/octet-stream", + index: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> StepOutput: + """ + Add an output to this step. + + Args: + name: Human-readable output name + media_type: MIME type of the output + index: Output index (defaults to next available) + metadata: Optional metadata + + Returns: + The created StepOutput + """ + if index is None: + index = len(self.outputs) + + cache_id = self.compute_output_cache_id(index) + output = StepOutput( + name=name, + cache_id=cache_id, + media_type=media_type, + index=index, + metadata=metadata or {}, + ) + self.outputs.append(output) + return output + + def get_output(self, index: int = 0) -> Optional[StepOutput]: + """Get output by index.""" + if index < len(self.outputs): + return self.outputs[index] + return None + + def get_output_by_name(self, name: str) -> Optional[StepOutput]: + """Get output by name.""" + for output in self.outputs: + if output.name == name: + return output + return None + + def to_dict(self) -> Dict[str, Any]: + return { + "step_id": self.step_id, + "name": self.name, + "node_type": self.node_type, + "config": self.config, + "input_steps": self.input_steps, + "inputs": [inp.to_dict() for inp in self.inputs], + "cache_id": self.cache_id, + "outputs": [out.to_dict() for out in self.outputs], + "estimated_duration": self.estimated_duration, + "level": self.level, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep": + inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])] + outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])] + return cls( + step_id=data["step_id"], + node_type=data["node_type"], + config=data.get("config", {}), + input_steps=data.get("input_steps", []), + inputs=inputs, + cache_id=data.get("cache_id"), + outputs=outputs, + name=data.get("name"), + estimated_duration=data.get("estimated_duration"), + level=data.get("level", 0), + ) + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str: str) -> "ExecutionStep": + return cls.from_dict(json.loads(json_str)) + + +@dataclass +class PlanInput: + """ + An input to the execution plan. + + Attributes: + name: Human-readable name from recipe (e.g., "source_video") + cache_id: Content hash of the input file + cid: Same as cache_id (for clarity) + media_type: MIME type of the input + """ + name: str + cache_id: str + cid: str + media_type: str = "application/octet-stream" + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "cache_id": self.cache_id, + "cid": self.cid, + "media_type": self.media_type, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PlanInput": + return cls( + name=data["name"], + cache_id=data["cache_id"], + cid=data.get("cid", data["cache_id"]), + media_type=data.get("media_type", "application/octet-stream"), + ) + + +@dataclass +class ExecutionPlan: + """ + Complete execution plan for a recipe. + + Contains all steps in topological order with pre-computed cache IDs. + The plan is deterministic: same recipe + same inputs = same plan. + + Attributes: + name: Human-readable plan name from recipe + plan_id: Hash of the entire plan (for deduplication) + recipe_id: Source recipe identifier + recipe_name: Human-readable recipe name + recipe_hash: Hash of the recipe content + seed: Random seed used for planning + steps: List of steps in execution order + output_step: ID of the final output step + output_name: Human-readable name of the final output + inputs: Structured input definitions + analysis_cache_ids: Cache IDs of analysis results used + input_hashes: Content hashes of input files (legacy, use inputs) + created_at: When the plan was generated + metadata: Optional additional metadata + """ + plan_id: Optional[str] + recipe_id: str + recipe_hash: str + steps: List[ExecutionStep] + output_step: str + name: Optional[str] = None + recipe_name: Optional[str] = None + seed: Optional[int] = None + output_name: Optional[str] = None + inputs: List[PlanInput] = field(default_factory=list) + analysis_cache_ids: Dict[str, str] = field(default_factory=dict) + input_hashes: Dict[str, str] = field(default_factory=dict) + created_at: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if self.created_at is None: + self.created_at = datetime.now(timezone.utc).isoformat() + if self.plan_id is None: + self.plan_id = self._compute_plan_id() + + def _compute_plan_id(self) -> str: + """Compute plan ID from contents.""" + content = { + "recipe_hash": self.recipe_hash, + "steps": [s.to_dict() for s in self.steps], + "input_hashes": self.input_hashes, + "analysis_cache_ids": self.analysis_cache_ids, + } + return _stable_hash(content) + + def compute_all_cache_ids(self) -> None: + """ + Compute cache IDs for all steps in dependency order. + + Must be called after all steps are added to ensure + cache IDs propagate correctly through dependencies. + """ + # Build step lookup + step_by_id = {s.step_id: s for s in self.steps} + + # Cache IDs start with input hashes + cache_ids = dict(self.input_hashes) + + # Process in order (assumes topological order) + for step in self.steps: + # For SOURCE steps referencing inputs, use input hash + if step.node_type == "SOURCE" and step.config.get("input_ref"): + ref = step.config["input_ref"] + if ref in self.input_hashes: + step.cache_id = self.input_hashes[ref] + cache_ids[step.step_id] = step.cache_id + continue + + # For other steps, compute from inputs + input_cache_ids = {} + for input_step_id in step.input_steps: + if input_step_id in cache_ids: + input_cache_ids[input_step_id] = cache_ids[input_step_id] + elif input_step_id in step_by_id: + # Step should have been processed already + input_cache_ids[input_step_id] = step_by_id[input_step_id].cache_id + else: + raise ValueError(f"Input step {input_step_id} not found for {step.step_id}") + + step.compute_cache_id(input_cache_ids) + cache_ids[step.step_id] = step.cache_id + + # Recompute plan_id with final cache IDs + self.plan_id = self._compute_plan_id() + + def compute_levels(self) -> int: + """ + Compute dependency levels for all steps. + + Level 0 = no dependencies (can start immediately) + Level N = depends on steps at level N-1 + + Returns: + Maximum level (number of sequential dependency levels) + """ + step_by_id = {s.step_id: s for s in self.steps} + levels = {} + + def compute_level(step_id: str) -> int: + if step_id in levels: + return levels[step_id] + + step = step_by_id.get(step_id) + if step is None: + return 0 # Input from outside the plan + + if not step.input_steps: + levels[step_id] = 0 + step.level = 0 + return 0 + + max_input_level = max(compute_level(s) for s in step.input_steps) + level = max_input_level + 1 + levels[step_id] = level + step.level = level + return level + + for step in self.steps: + compute_level(step.step_id) + + return max(levels.values()) if levels else 0 + + def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]: + """ + Group steps by dependency level. + + Steps at the same level can execute in parallel. + + Returns: + Dict mapping level -> list of steps at that level + """ + by_level: Dict[int, List[ExecutionStep]] = {} + for step in self.steps: + by_level.setdefault(step.level, []).append(step) + return by_level + + def get_step(self, step_id: str) -> Optional[ExecutionStep]: + """Get step by ID.""" + for step in self.steps: + if step.step_id == step_id: + return step + return None + + def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]: + """Get step by cache ID.""" + for step in self.steps: + if step.cache_id == cache_id: + return step + return None + + def get_step_by_name(self, name: str) -> Optional[ExecutionStep]: + """Get step by human-readable name.""" + for step in self.steps: + if step.name == name: + return step + return None + + def get_all_outputs(self) -> Dict[str, StepOutput]: + """ + Get all outputs from all steps, keyed by output name. + + Returns: + Dict mapping output name -> StepOutput + """ + outputs = {} + for step in self.steps: + for output in step.outputs: + outputs[output.name] = output + return outputs + + def get_output_cache_ids(self) -> Dict[str, str]: + """ + Get mapping of output names to cache IDs. + + Returns: + Dict mapping output name -> cache_id + """ + return { + output.name: output.cache_id + for step in self.steps + for output in step.outputs + } + + def to_dict(self) -> Dict[str, Any]: + return { + "plan_id": self.plan_id, + "name": self.name, + "recipe_id": self.recipe_id, + "recipe_name": self.recipe_name, + "recipe_hash": self.recipe_hash, + "seed": self.seed, + "inputs": [i.to_dict() for i in self.inputs], + "steps": [s.to_dict() for s in self.steps], + "output_step": self.output_step, + "output_name": self.output_name, + "analysis_cache_ids": self.analysis_cache_ids, + "input_hashes": self.input_hashes, + "created_at": self.created_at, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan": + inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])] + return cls( + plan_id=data.get("plan_id"), + name=data.get("name"), + recipe_id=data["recipe_id"], + recipe_name=data.get("recipe_name"), + recipe_hash=data["recipe_hash"], + seed=data.get("seed"), + inputs=inputs, + steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])], + output_step=data["output_step"], + output_name=data.get("output_name"), + analysis_cache_ids=data.get("analysis_cache_ids", {}), + input_hashes=data.get("input_hashes", {}), + created_at=data.get("created_at"), + metadata=data.get("metadata", {}), + ) + + def to_json(self, indent: int = 2) -> str: + return json.dumps(self.to_dict(), indent=indent) + + @classmethod + def from_json(cls, json_str: str) -> "ExecutionPlan": + return cls.from_dict(json.loads(json_str)) + + def summary(self) -> str: + """Get a human-readable summary of the plan.""" + by_level = self.get_steps_by_level() + max_level = max(by_level.keys()) if by_level else 0 + + lines = [ + f"Execution Plan: {self.plan_id[:16]}...", + f"Recipe: {self.recipe_id}", + f"Steps: {len(self.steps)}", + f"Levels: {max_level + 1}", + "", + ] + + for level in sorted(by_level.keys()): + steps = by_level[level] + lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)") + for step in steps: + cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]" + lines.append(f" - {step.step_id}: {step.node_type} {cache_status}") + + return "\n".join(lines) diff --git a/artdag/planning/tree_reduction.py b/artdag/planning/tree_reduction.py new file mode 100644 index 0000000..3ab4147 --- /dev/null +++ b/artdag/planning/tree_reduction.py @@ -0,0 +1,231 @@ +# artdag/planning/tree_reduction.py +""" +Tree reduction for parallel composition. + +Instead of sequential pairwise composition: + A → AB → ABC → ABCD (3 sequential steps) + +Use parallel tree reduction: + A ─┬─ AB ─┬─ ABCD + B ─┘ │ + C ─┬─ CD ─┘ + D ─┘ + +This reduces O(N) to O(log N) levels of sequential dependency. +""" + +import math +from dataclasses import dataclass +from typing import List, Tuple, Any, Dict + + +@dataclass +class ReductionNode: + """A node in the reduction tree.""" + node_id: str + input_ids: List[str] + level: int + position: int # Position within level + + +class TreeReducer: + """ + Generates tree reduction plans for parallel composition. + + Used to convert N inputs into optimal parallel SEQUENCE operations. + """ + + def __init__(self, node_type: str = "SEQUENCE"): + """ + Initialize the reducer. + + Args: + node_type: The composition node type (SEQUENCE, AUDIO_MIX, etc.) + """ + self.node_type = node_type + + def reduce( + self, + input_ids: List[str], + id_prefix: str = "reduce", + ) -> Tuple[List[ReductionNode], str]: + """ + Generate a tree reduction plan for the given inputs. + + Args: + input_ids: List of input step IDs to reduce + id_prefix: Prefix for generated node IDs + + Returns: + Tuple of (list of reduction nodes, final output node ID) + """ + if len(input_ids) == 0: + raise ValueError("Cannot reduce empty input list") + + if len(input_ids) == 1: + # Single input, no reduction needed + return [], input_ids[0] + + if len(input_ids) == 2: + # Two inputs, single reduction + node_id = f"{id_prefix}_final" + node = ReductionNode( + node_id=node_id, + input_ids=input_ids, + level=0, + position=0, + ) + return [node], node_id + + # Build tree levels + nodes = [] + current_level = list(input_ids) + level_num = 0 + + while len(current_level) > 1: + next_level = [] + position = 0 + + # Pair up nodes at current level + i = 0 + while i < len(current_level): + if i + 1 < len(current_level): + # Pair available + left = current_level[i] + right = current_level[i + 1] + node_id = f"{id_prefix}_L{level_num}_P{position}" + node = ReductionNode( + node_id=node_id, + input_ids=[left, right], + level=level_num, + position=position, + ) + nodes.append(node) + next_level.append(node_id) + i += 2 + else: + # Odd one out, promote to next level + next_level.append(current_level[i]) + i += 1 + + position += 1 + + current_level = next_level + level_num += 1 + + # The last remaining node is the output + output_id = current_level[0] + + # Rename final node for clarity + if nodes and nodes[-1].node_id == output_id: + nodes[-1].node_id = f"{id_prefix}_final" + output_id = f"{id_prefix}_final" + + return nodes, output_id + + def get_reduction_depth(self, n: int) -> int: + """ + Calculate the number of reduction levels needed. + + Args: + n: Number of inputs + + Returns: + Number of sequential reduction levels (log2(n) ceiling) + """ + if n <= 1: + return 0 + return math.ceil(math.log2(n)) + + def get_total_operations(self, n: int) -> int: + """ + Calculate total number of reduction operations. + + Args: + n: Number of inputs + + Returns: + Total composition operations (always n-1) + """ + return max(0, n - 1) + + def reduce_with_config( + self, + input_ids: List[str], + base_config: Dict[str, Any], + id_prefix: str = "reduce", + ) -> Tuple[List[Tuple[ReductionNode, Dict[str, Any]]], str]: + """ + Generate reduction plan with configuration for each node. + + Args: + input_ids: List of input step IDs + base_config: Base configuration to use for each reduction + id_prefix: Prefix for generated node IDs + + Returns: + Tuple of (list of (node, config) pairs, final output ID) + """ + nodes, output_id = self.reduce(input_ids, id_prefix) + result = [(node, dict(base_config)) for node in nodes] + return result, output_id + + +def reduce_sequence( + input_ids: List[str], + transition_config: Dict[str, Any] = None, + id_prefix: str = "seq", +) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]: + """ + Convenience function for SEQUENCE reduction. + + Args: + input_ids: Input step IDs to sequence + transition_config: Transition configuration (default: cut) + id_prefix: Prefix for generated step IDs + + Returns: + Tuple of (list of (step_id, inputs, config), final step ID) + """ + if transition_config is None: + transition_config = {"transition": {"type": "cut"}} + + reducer = TreeReducer("SEQUENCE") + nodes, output_id = reducer.reduce(input_ids, id_prefix) + + result = [ + (node.node_id, node.input_ids, dict(transition_config)) + for node in nodes + ] + + return result, output_id + + +def reduce_audio_mix( + input_ids: List[str], + mix_config: Dict[str, Any] = None, + id_prefix: str = "mix", +) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]: + """ + Convenience function for AUDIO_MIX reduction. + + Args: + input_ids: Input step IDs to mix + mix_config: Mix configuration + id_prefix: Prefix for generated step IDs + + Returns: + Tuple of (list of (step_id, inputs, config), final step ID) + """ + if mix_config is None: + mix_config = {"normalize": True} + + reducer = TreeReducer("AUDIO_MIX") + nodes, output_id = reducer.reduce(input_ids, id_prefix) + + result = [ + (node.node_id, node.input_ids, dict(mix_config)) + for node in nodes + ] + + return result, output_id diff --git a/artdag/registry/__init__.py b/artdag/registry/__init__.py new file mode 100644 index 0000000..3163387 --- /dev/null +++ b/artdag/registry/__init__.py @@ -0,0 +1,20 @@ +# primitive/registry/__init__.py +""" +Art DAG Registry. + +The registry is the foundational data structure that maps named assets +to their source paths or content-addressed IDs. Assets in the registry +can be referenced by DAGs. + +Example: + registry = Registry("/path/to/registry") + registry.add("cat", "/path/to/cat.jpg", tags=["animal", "photo"]) + + # Later, in a DAG: + builder = DAGBuilder() + cat = builder.source(registry.get("cat").path) +""" + +from .registry import Registry, Asset + +__all__ = ["Registry", "Asset"] diff --git a/artdag/registry/registry.py b/artdag/registry/registry.py new file mode 100644 index 0000000..3290411 --- /dev/null +++ b/artdag/registry/registry.py @@ -0,0 +1,294 @@ +# primitive/registry/registry.py +""" +Asset registry for the Art DAG. + +The registry stores named assets with metadata, enabling: +- Named references to source files +- Tagging and categorization +- Content-addressed deduplication +- Asset discovery and search +""" + +import hashlib +import json +import shutil +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +def _file_hash(path: Path, algorithm: str = "sha3_256") -> str: + """ + Compute content hash of a file. + + Uses SHA-3 (Keccak) by default for quantum resistance. + SHA-3-256 provides 128-bit security against quantum attacks (Grover's algorithm). + + Args: + path: File to hash + algorithm: Hash algorithm (sha3_256, sha3_512, sha256, blake2b) + + Returns: + Full hex digest (no truncation) + """ + hasher = hashlib.new(algorithm) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class Asset: + """ + A registered asset in the Art DAG. + + The cid is the true identifier. URL and local_path are + locations where the content can be fetched. + + Attributes: + name: Unique name for the asset + cid: SHA-3-256 hash - the canonical identifier + url: Public URL (canonical location) + local_path: Optional local path (for local execution) + asset_type: Type of asset (image, video, audio, etc.) + tags: List of tags for categorization + metadata: Additional metadata (dimensions, duration, etc.) + created_at: Timestamp when added to registry + """ + name: str + cid: str + url: Optional[str] = None + local_path: Optional[Path] = None + asset_type: str = "unknown" + tags: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=time.time) + + @property + def path(self) -> Optional[Path]: + """Backwards compatible path property.""" + return self.local_path + + def to_dict(self) -> Dict[str, Any]: + data = { + "name": self.name, + "cid": self.cid, + "asset_type": self.asset_type, + "tags": self.tags, + "metadata": self.metadata, + "created_at": self.created_at, + } + if self.url: + data["url"] = self.url + if self.local_path: + data["local_path"] = str(self.local_path) + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Asset": + local_path = data.get("local_path") or data.get("path") # backwards compat + return cls( + name=data["name"], + cid=data["cid"], + url=data.get("url"), + local_path=Path(local_path) if local_path else None, + asset_type=data.get("asset_type", "unknown"), + tags=data.get("tags", []), + metadata=data.get("metadata", {}), + created_at=data.get("created_at", time.time()), + ) + + +class Registry: + """ + The Art DAG registry. + + Stores named assets that can be referenced by DAGs. + + Structure: + registry_dir/ + registry.json # Index of all assets + assets/ # Optional: copied asset files + / + + """ + + def __init__(self, registry_dir: Path | str, copy_assets: bool = False): + """ + Initialize the registry. + + Args: + registry_dir: Directory to store registry data + copy_assets: If True, copy assets into registry (content-addressed) + """ + self.registry_dir = Path(registry_dir) + self.registry_dir.mkdir(parents=True, exist_ok=True) + self.copy_assets = copy_assets + self._assets: Dict[str, Asset] = {} + self._load() + + def _index_path(self) -> Path: + return self.registry_dir / "registry.json" + + def _assets_dir(self) -> Path: + return self.registry_dir / "assets" + + def _load(self): + """Load registry from disk.""" + index_path = self._index_path() + if index_path.exists(): + with open(index_path) as f: + data = json.load(f) + self._assets = { + name: Asset.from_dict(asset_data) + for name, asset_data in data.get("assets", {}).items() + } + + def _save(self): + """Save registry to disk.""" + data = { + "version": "1.0", + "assets": {name: asset.to_dict() for name, asset in self._assets.items()}, + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def add( + self, + name: str, + cid: str, + url: str = None, + local_path: Path | str = None, + asset_type: str = None, + tags: List[str] = None, + metadata: Dict[str, Any] = None, + ) -> Asset: + """ + Add an asset to the registry. + + Args: + name: Unique name for the asset + cid: SHA-3-256 hash of the content (the canonical identifier) + url: Public URL where the asset can be fetched + local_path: Optional local path (for local execution) + asset_type: Type of asset (image, video, audio, etc.) + tags: List of tags for categorization + metadata: Additional metadata + + Returns: + The created Asset + """ + # Auto-detect asset type from URL or path extension + if asset_type is None: + ext = None + if url: + ext = Path(url.split("?")[0]).suffix.lower() + elif local_path: + ext = Path(local_path).suffix.lower() + if ext: + type_map = { + ".jpg": "image", ".jpeg": "image", ".png": "image", + ".gif": "image", ".webp": "image", ".bmp": "image", + ".mp4": "video", ".mkv": "video", ".avi": "video", + ".mov": "video", ".webm": "video", + ".mp3": "audio", ".wav": "audio", ".flac": "audio", + ".ogg": "audio", ".aac": "audio", + } + asset_type = type_map.get(ext, "unknown") + else: + asset_type = "unknown" + + asset = Asset( + name=name, + cid=cid, + url=url, + local_path=Path(local_path).resolve() if local_path else None, + asset_type=asset_type, + tags=tags or [], + metadata=metadata or {}, + ) + + self._assets[name] = asset + self._save() + return asset + + def add_from_file( + self, + name: str, + path: Path | str, + url: str = None, + asset_type: str = None, + tags: List[str] = None, + metadata: Dict[str, Any] = None, + ) -> Asset: + """ + Add an asset from a local file (computes hash automatically). + + Args: + name: Unique name for the asset + path: Path to the source file + url: Optional public URL + asset_type: Type of asset (auto-detected if not provided) + tags: List of tags for categorization + metadata: Additional metadata + + Returns: + The created Asset + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Asset file not found: {path}") + + cid = _file_hash(path) + + return self.add( + name=name, + cid=cid, + url=url, + local_path=path, + asset_type=asset_type, + tags=tags, + metadata=metadata, + ) + + def get(self, name: str) -> Optional[Asset]: + """Get an asset by name.""" + return self._assets.get(name) + + def remove(self, name: str) -> bool: + """Remove an asset from the registry.""" + if name not in self._assets: + return False + del self._assets[name] + self._save() + return True + + def list(self) -> List[Asset]: + """List all assets.""" + return list(self._assets.values()) + + def find_by_tag(self, tag: str) -> List[Asset]: + """Find assets with a specific tag.""" + return [a for a in self._assets.values() if tag in a.tags] + + def find_by_type(self, asset_type: str) -> List[Asset]: + """Find assets of a specific type.""" + return [a for a in self._assets.values() if a.asset_type == asset_type] + + def find_by_hash(self, cid: str) -> Optional[Asset]: + """Find an asset by content hash.""" + for asset in self._assets.values(): + if asset.cid == cid: + return asset + return None + + def __contains__(self, name: str) -> bool: + return name in self._assets + + def __len__(self) -> int: + return len(self._assets) + + def __iter__(self): + return iter(self._assets.values()) diff --git a/artdag/server.py b/artdag/server.py new file mode 100644 index 0000000..f10374c --- /dev/null +++ b/artdag/server.py @@ -0,0 +1,253 @@ +# primitive/server.py +""" +HTTP server for primitive execution engine. + +Provides a REST API for submitting DAGs and retrieving results. + +Endpoints: + POST /execute - Submit DAG for execution + GET /status/:id - Get execution status + GET /result/:id - Get execution result + GET /cache/stats - Get cache statistics + DELETE /cache - Clear cache +""" + +import json +import logging +import threading +import uuid +from dataclasses import dataclass, field +from http.server import HTTPServer, BaseHTTPRequestHandler +from pathlib import Path +from typing import Any, Dict, Optional +from urllib.parse import urlparse + +from .dag import DAG +from .engine import Engine, ExecutionResult +from . import nodes # Register built-in executors + +logger = logging.getLogger(__name__) + + +@dataclass +class Job: + """A pending or completed execution job.""" + job_id: str + dag: DAG + status: str = "pending" # pending, running, completed, failed + result: Optional[ExecutionResult] = None + error: Optional[str] = None + + +class PrimitiveServer: + """ + HTTP server for the primitive engine. + + Usage: + server = PrimitiveServer(cache_dir="/tmp/primitive_cache", port=8080) + server.start() # Blocking + """ + + def __init__(self, cache_dir: Path | str, host: str = "127.0.0.1", port: int = 8080): + self.cache_dir = Path(cache_dir) + self.host = host + self.port = port + self.engine = Engine(self.cache_dir) + self.jobs: Dict[str, Job] = {} + self._lock = threading.Lock() + + def submit_job(self, dag: DAG) -> str: + """Submit a DAG for execution, return job ID.""" + job_id = str(uuid.uuid4())[:8] + job = Job(job_id=job_id, dag=dag) + + with self._lock: + self.jobs[job_id] = job + + # Execute in background thread + thread = threading.Thread(target=self._execute_job, args=(job_id,)) + thread.daemon = True + thread.start() + + return job_id + + def _execute_job(self, job_id: str): + """Execute a job in background.""" + with self._lock: + job = self.jobs.get(job_id) + if not job: + return + job.status = "running" + + try: + result = self.engine.execute(job.dag) + with self._lock: + job.result = result + job.status = "completed" if result.success else "failed" + if not result.success: + job.error = result.error + except Exception as e: + logger.exception(f"Job {job_id} failed") + with self._lock: + job.status = "failed" + job.error = str(e) + + def get_job(self, job_id: str) -> Optional[Job]: + """Get job by ID.""" + with self._lock: + return self.jobs.get(job_id) + + def _create_handler(server_instance): + """Create request handler with access to server instance.""" + + class RequestHandler(BaseHTTPRequestHandler): + server_ref = server_instance + + def log_message(self, format, *args): + logger.debug(format % args) + + def _send_json(self, data: Any, status: int = 200): + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(data).encode()) + + def _send_error(self, message: str, status: int = 400): + self._send_json({"error": message}, status) + + def do_GET(self): + parsed = urlparse(self.path) + path = parsed.path + + if path.startswith("/status/"): + job_id = path[8:] + job = self.server_ref.get_job(job_id) + if not job: + self._send_error("Job not found", 404) + return + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "error": job.error, + }) + + elif path.startswith("/result/"): + job_id = path[8:] + job = self.server_ref.get_job(job_id) + if not job: + self._send_error("Job not found", 404) + return + if job.status == "pending" or job.status == "running": + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "ready": False, + }) + return + + result = job.result + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "ready": True, + "success": result.success if result else False, + "output_path": str(result.output_path) if result and result.output_path else None, + "error": job.error, + "execution_time": result.execution_time if result else 0, + "nodes_executed": result.nodes_executed if result else 0, + "nodes_cached": result.nodes_cached if result else 0, + }) + + elif path == "/cache/stats": + stats = self.server_ref.engine.get_cache_stats() + self._send_json({ + "total_entries": stats.total_entries, + "total_size_bytes": stats.total_size_bytes, + "hits": stats.hits, + "misses": stats.misses, + "hit_rate": stats.hit_rate, + }) + + elif path == "/health": + self._send_json({"status": "ok"}) + + else: + self._send_error("Not found", 404) + + def do_POST(self): + if self.path == "/execute": + try: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length).decode() + data = json.loads(body) + + dag = DAG.from_dict(data) + job_id = self.server_ref.submit_job(dag) + + self._send_json({ + "job_id": job_id, + "status": "pending", + }) + except json.JSONDecodeError as e: + self._send_error(f"Invalid JSON: {e}") + except Exception as e: + self._send_error(str(e), 500) + else: + self._send_error("Not found", 404) + + def do_DELETE(self): + if self.path == "/cache": + self.server_ref.engine.clear_cache() + self._send_json({"status": "cleared"}) + else: + self._send_error("Not found", 404) + + return RequestHandler + + def start(self): + """Start the HTTP server (blocking).""" + handler = self._create_handler() + server = HTTPServer((self.host, self.port), handler) + logger.info(f"Primitive server starting on {self.host}:{self.port}") + print(f"Primitive server running on http://{self.host}:{self.port}") + try: + server.serve_forever() + except KeyboardInterrupt: + print("\nShutting down...") + server.shutdown() + + def start_background(self) -> threading.Thread: + """Start the server in a background thread.""" + thread = threading.Thread(target=self.start) + thread.daemon = True + thread.start() + return thread + + +def main(): + """CLI entry point.""" + import argparse + + parser = argparse.ArgumentParser(description="Primitive execution server") + parser.add_argument("--host", default="127.0.0.1", help="Host to bind to") + parser.add_argument("--port", type=int, default=8080, help="Port to bind to") + parser.add_argument("--cache-dir", default="/tmp/primitive_cache", help="Cache directory") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging") + + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + server = PrimitiveServer( + cache_dir=args.cache_dir, + host=args.host, + port=args.port, + ) + server.start() + + +if __name__ == "__main__": + main() diff --git a/artdag/sexp/__init__.py b/artdag/sexp/__init__.py new file mode 100644 index 0000000..08b646f --- /dev/null +++ b/artdag/sexp/__init__.py @@ -0,0 +1,75 @@ +""" +S-expression parsing, compilation, and planning for ArtDAG. + +This module provides: +- parser: Parse S-expression text into Python data structures +- compiler: Compile recipe S-expressions into DAG format +- planner: Generate execution plans from recipes +""" + +from .parser import ( + parse, + parse_all, + serialize, + Symbol, + Keyword, + ParseError, +) + +from .compiler import ( + compile_recipe, + compile_string, + CompiledRecipe, + CompileError, + ParamDef, + _parse_params, +) + +from .planner import ( + create_plan, + ExecutionPlanSexp, + PlanStep, + step_to_task_sexp, + task_cache_id, +) + +from .scheduler import ( + PlanScheduler, + PlanResult, + StepResult, + schedule_plan, + step_to_sexp, + step_sexp_to_string, + verify_step_cache_id, +) + +__all__ = [ + # Parser + 'parse', + 'parse_all', + 'serialize', + 'Symbol', + 'Keyword', + 'ParseError', + # Compiler + 'compile_recipe', + 'compile_string', + 'CompiledRecipe', + 'CompileError', + 'ParamDef', + '_parse_params', + # Planner + 'create_plan', + 'ExecutionPlanSexp', + 'PlanStep', + 'step_to_task_sexp', + 'task_cache_id', + # Scheduler + 'PlanScheduler', + 'PlanResult', + 'StepResult', + 'schedule_plan', + 'step_to_sexp', + 'step_sexp_to_string', + 'verify_step_cache_id', +] diff --git a/artdag/sexp/compiler.py b/artdag/sexp/compiler.py new file mode 100644 index 0000000..9729312 --- /dev/null +++ b/artdag/sexp/compiler.py @@ -0,0 +1,2463 @@ +""" +Compiler for S-expression recipes. + +Transforms S-expression recipes into internal DAG format. +Handles: +- Threading macro expansion (->) +- def bindings for named nodes +- Registry resolution (assets, effects) +- Node ID generation (content-addressed) +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple +import hashlib +import json + +from .parser import Symbol, Keyword, Lambda, parse, serialize +from pathlib import Path + + +def compute_content_cid(content: str) -> str: + """Compute content-addressed ID (SHA256 hash) for content. + + This is used for effects, recipes, and other text content that + will be stored in the cache. The cid can be used to fetch the + content from cache or IPFS. + """ + return hashlib.sha256(content.encode()).hexdigest() + + +def compute_file_cid(file_path: Path) -> str: + """Compute content-addressed ID for a file. + + Args: + file_path: Path to the file + + Returns: + SHA3-256 hash of file contents + """ + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + content = file_path.read_text() + return compute_content_cid(content) + + +def _serialize_for_hash(obj) -> str: + """Serialize any value to canonical S-expression string for hashing.""" + if obj is None: + return "nil" + if isinstance(obj, bool): + return "true" if obj else "false" + if isinstance(obj, (int, float)): + return str(obj) + if isinstance(obj, str): + escaped = obj.replace('\\', '\\\\').replace('"', '\\"') + return f'"{escaped}"' + if isinstance(obj, Symbol): + return obj.name + if isinstance(obj, Keyword): + return f":{obj.name}" + if isinstance(obj, Lambda): + params = " ".join(obj.params) + body = _serialize_for_hash(obj.body) + return f"(fn [{params}] {body})" + if isinstance(obj, dict): + items = [] + for k, v in sorted(obj.items()): + items.append(f":{k} {_serialize_for_hash(v)}") + return "{" + " ".join(items) + "}" + if isinstance(obj, list): + items = [_serialize_for_hash(x) for x in obj] + return "(" + " ".join(items) + ")" + return str(obj) + + +class CompileError(Exception): + """Error during recipe compilation.""" + pass + + +@dataclass +class ParamDef: + """Definition of a recipe parameter.""" + name: str + param_type: str # "string", "int", "float", "bool" + default: Any + description: str = "" + range_min: Optional[float] = None + range_max: Optional[float] = None + choices: Optional[List[str]] = None # For enum-like params + + +@dataclass +class CompiledStage: + """A compiled stage with dependencies and outputs.""" + name: str + requires: List[str] # Names of required stages + inputs: List[str] # Names of bindings consumed from required stages + outputs: List[str] # Names of bindings produced by this stage + node_ids: List[str] # Node IDs created in this stage + output_bindings: Dict[str, str] # output_name -> node_id mapping + + +@dataclass +class CompiledRecipe: + """Result of compiling an S-expression recipe.""" + name: str + version: str + description: str + owner: Optional[str] + registry: Dict[str, Dict[str, Any]] # {assets: {...}, effects: {...}} + nodes: List[Dict[str, Any]] # List of node definitions + output_node_id: str + encoding: Dict[str, Any] = field(default_factory=dict) # {codec, crf, preset, audio_codec} + metadata: Dict[str, Any] = field(default_factory=dict) + params: List[ParamDef] = field(default_factory=list) # Declared parameters + stages: List[CompiledStage] = field(default_factory=list) # Compiled stages + stage_order: List[str] = field(default_factory=list) # Topologically sorted stage names + minimal_primitives: bool = False # If True, only core primitives available + source_text: str = "" # Original source text for stable hashing + resolved_params: Dict[str, Any] = field(default_factory=dict) # Resolved parameter values + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format (compatible with YAML structure).""" + return { + "name": self.name, + "version": self.version, + "description": self.description, + "owner": self.owner, + "registry": self.registry, + "dag": { + "nodes": self.nodes, + "output": self.output_node_id, + }, + "encoding": self.encoding, + "metadata": self.metadata, + } + + +@dataclass +class CompilerContext: + """Compilation context tracking bindings and nodes.""" + registry: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {"assets": {}, "effects": {}, "analyzers": {}, "constructs": {}, "templates": {}, "includes": {}}) + template_call_count: int = 0 + bindings: Dict[str, str] = field(default_factory=dict) # name -> node_id + nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict) # node_id -> node + + # Recipe directory for resolving relative paths + recipe_dir: Optional[Path] = None + + # Stage tracking + current_stage: Optional[str] = None # Name of stage currently being compiled + defined_stages: Dict[str, 'CompiledStage'] = field(default_factory=dict) # stage_name -> CompiledStage + stage_bindings: Dict[str, Dict[str, str]] = field(default_factory=dict) # stage_name -> {binding_name -> node_id} + pre_stage_bindings: Dict[str, Any] = field(default_factory=dict) # bindings defined before any stage + stage_node_ids: List[str] = field(default_factory=list) # node IDs created in current stage + + def add_node(self, node_type: str, config: Dict[str, Any], + inputs: List[str] = None, name: str = None) -> str: + """ + Add a node and return its code-addressed ID. + + The node_id is a hash of the S-expression subtree (type, config, inputs), + creating a Merkle-tree like a blockchain - each node's hash includes all + upstream hashes. This is computed purely from the plan, before execution. + + The node_id is a pre-computed "bucket" where the computation result will + be stored. Same plan = same buckets = automatic cache reuse. + """ + # Build canonical S-expression for hashing + # Inputs are already code-addressed node IDs (hashes) + canonical = { + "type": node_type, + "config": config, + "inputs": inputs or [], + } + # Hash the canonical S-expression form using SHA3-256 + canonical_sexp = _serialize_for_hash(canonical) + node_id = hashlib.sha3_256(canonical_sexp.encode()).hexdigest() + + # Check for collision (same hash = same computation, reuse) + if node_id in self.nodes: + return node_id + + self.nodes[node_id] = { + "id": node_id, + "type": node_type, + "config": config, + "inputs": inputs or [], + "name": name, + } + + # Track node in current stage + if self.current_stage is not None: + self.stage_node_ids.append(node_id) + + return node_id + + def get_accessible_bindings(self, stage_inputs: List[str] = None) -> Dict[str, Any]: + """ + Get bindings accessible to the current stage. + + If inside a stage with declared inputs, only those inputs plus pre-stage + bindings are accessible. If outside a stage, all bindings are accessible. + """ + if self.current_stage is None: + return dict(self.bindings) + + # Start with pre-stage bindings (sources, etc.) + accessible = dict(self.pre_stage_bindings) + + # Add declared inputs from required stages + if stage_inputs: + for input_name in stage_inputs: + # Look for the binding in required stages + for stage_name, stage in self.defined_stages.items(): + if input_name in stage.output_bindings: + accessible[input_name] = stage.output_bindings[input_name] + break + else: + # Check if it's in pre-stage bindings (might be a source) + if input_name not in accessible: + raise CompileError( + f"Stage '{self.current_stage}' declares input '{input_name}' " + f"but it's not produced by any required stage" + ) + + return accessible + + +def _topological_sort_stages(stages: Dict[str, 'CompiledStage']) -> List[str]: + """ + Topologically sort stages by their dependencies. + + Returns list of stage names in execution order (dependencies first). + """ + if not stages: + return [] + + # Build dependency graph + in_degree = {name: 0 for name in stages} + dependents = {name: [] for name in stages} + + for name, stage in stages.items(): + for req in stage.requires: + if req in stages: + dependents[req].append(name) + in_degree[name] += 1 + + # Kahn's algorithm + queue = [name for name, degree in in_degree.items() if degree == 0] + result = [] + + while queue: + # Sort for deterministic ordering + queue.sort() + current = queue.pop(0) + result.append(current) + + for dependent in dependents[current]: + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + if len(result) != len(stages): + # This shouldn't happen if we validated cycles earlier + missing = set(stages.keys()) - set(result) + raise CompileError(f"Circular stage dependency detected: {missing}") + + return result + + +def _parse_encoding(value: Any) -> Dict[str, Any]: + """ + Parse encoding settings from S-expression. + + Expects a list like: (:codec "libx264" :crf 18 :preset "fast" :audio-codec "aac") + Returns: {"codec": "libx264", "crf": 18, "preset": "fast", "audio_codec": "aac"} + """ + if not isinstance(value, list): + raise CompileError(f"Encoding must be a list, got {type(value).__name__}") + + result = {} + i = 0 + while i < len(value): + item = value[i] + if isinstance(item, Keyword): + if i + 1 >= len(value): + raise CompileError(f"Encoding keyword {item.name} missing value") + # Convert kebab-case to snake_case for Python + key = item.name.replace("-", "_") + result[key] = value[i + 1] + i += 2 + else: + raise CompileError(f"Expected keyword in encoding, got {type(item).__name__}") + return result + + +def _parse_params(value: Any) -> List[ParamDef]: + """ + Parse parameter definitions from S-expression. + + Syntax: + :params ( + (param_name :type string :default "value" :desc "Description") + (param_name :type float :default 1.0 :range [0 10] :desc "Description") + (param_name :type string :default "a" :choices ["a" "b" "c"] :desc "Description") + ) + + Supported types: string, int, float, bool + Optional: :range [min max], :choices [...], :desc "..." + """ + if not isinstance(value, list): + raise CompileError(f"Params must be a list, got {type(value).__name__}") + + params = [] + for param_def in value: + if not isinstance(param_def, list) or len(param_def) < 1: + raise CompileError(f"Invalid param definition: {param_def}") + + # First element is the parameter name + first = param_def[0] + if isinstance(first, Symbol): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + raise CompileError(f"Param name must be symbol or string, got {type(first).__name__}") + + # Parse keyword arguments + param_type = "string" + default = None + desc = "" + range_min = None + range_max = None + choices = None + + i = 1 + while i < len(param_def): + item = param_def[i] + if isinstance(item, Keyword): + if i + 1 >= len(param_def): + raise CompileError(f"Param keyword {item.name} missing value") + kw_value = param_def[i + 1] + + if item.name == "type": + if isinstance(kw_value, Symbol): + param_type = kw_value.name + else: + param_type = str(kw_value) + elif item.name == "default": + # Convert nil symbol to Python None + if isinstance(kw_value, Symbol) and kw_value.name == "nil": + default = None + else: + default = kw_value + elif item.name == "desc" or item.name == "description": + desc = str(kw_value) + elif item.name == "range": + if isinstance(kw_value, list) and len(kw_value) >= 2: + range_min = float(kw_value[0]) + range_max = float(kw_value[1]) + else: + raise CompileError(f"Param range must be [min max], got {kw_value}") + elif item.name == "choices": + if isinstance(kw_value, list): + choices = [str(c) if not isinstance(c, Symbol) else c.name for c in kw_value] + else: + raise CompileError(f"Param choices must be a list, got {kw_value}") + else: + raise CompileError(f"Unknown param keyword :{item.name}") + i += 2 + else: + i += 1 + + # Convert default to appropriate type + if default is not None: + if param_type == "int": + default = int(default) + elif param_type == "float": + default = float(default) + elif param_type == "bool": + if isinstance(default, (int, float)): + default = bool(default) + elif isinstance(default, str): + default = default.lower() in ("true", "1", "yes") + elif param_type == "string": + default = str(default) + + params.append(ParamDef( + name=param_name, + param_type=param_type, + default=default, + description=desc, + range_min=range_min, + range_max=range_max, + choices=choices, + )) + + return params + + +def compile_recipe(sexp: Any, initial_bindings: Dict[str, Any] = None, recipe_dir: Path = None, source_text: str = "") -> CompiledRecipe: + """ + Compile an S-expression recipe into internal format. + + Args: + sexp: Parsed S-expression (list starting with 'recipe' symbol) + initial_bindings: Optional dict of name -> value bindings to inject before compilation. + These can be referenced as variables in the recipe. + recipe_dir: Directory containing the recipe file, for resolving relative paths. + source_text: Original source text for stable hashing. + + Returns: + CompiledRecipe with nodes and registry + + Example: + >>> sexp = parse('(recipe "test" :version "1.0" (-> (source cat) (effect identity)))') + >>> result = compile_recipe(sexp) + >>> # With parameters: + >>> result = compile_recipe(sexp, {"effect_num": 5}) + """ + if not isinstance(sexp, list) or len(sexp) < 2: + raise CompileError("Recipe must be a list starting with 'recipe'") + + head = sexp[0] + if not (isinstance(head, Symbol) and head.name == "recipe"): + raise CompileError(f"Expected 'recipe', got {head}") + + # Extract recipe name + if len(sexp) < 2 or not isinstance(sexp[1], str): + raise CompileError("Recipe name must be a string") + name = sexp[1] + + # Parse keyword arguments and body + ctx = CompilerContext(recipe_dir=recipe_dir) + + version = "1.0" + description = "" + owner = None + encoding = {} + params = [] + body_exprs = [] + minimal_primitives = False + + i = 2 + while i < len(sexp): + item = sexp[i] + + if isinstance(item, Keyword): + if i + 1 >= len(sexp): + raise CompileError(f"Keyword {item.name} missing value") + value = sexp[i + 1] + + if item.name == "version": + version = str(value) + elif item.name == "description": + description = str(value) + elif item.name == "owner": + owner = str(value) + elif item.name == "encoding": + encoding = _parse_encoding(value) + elif item.name == "params": + params = _parse_params(value) + elif item.name == "minimal-primitives": + # Handle boolean value (could be Symbol('true') or Python bool) + if isinstance(value, Symbol): + minimal_primitives = value.name.lower() == "true" + else: + minimal_primitives = bool(value) + else: + raise CompileError(f"Unknown keyword :{item.name}") + i += 2 + else: + # Body expression + body_exprs.append(item) + i += 1 + + # Create bindings from params with their default values + # Initial bindings override param defaults + for param in params: + if initial_bindings and param.name in initial_bindings: + ctx.bindings[param.name] = initial_bindings[param.name] + else: + ctx.bindings[param.name] = param.default + + # Inject any additional initial bindings not covered by params + if initial_bindings: + for k, v in initial_bindings.items(): + if k not in ctx.bindings: + ctx.bindings[k] = v + + # Compile body expressions + # Track when we encounter the first stage to capture pre-stage bindings + output_node_id = None + first_stage_seen = False + + for expr in body_exprs: + # Check if this is a stage form + is_stage_form = ( + isinstance(expr, list) and + len(expr) > 0 and + isinstance(expr[0], Symbol) and + expr[0].name == "stage" + ) + + # Before the first stage, capture bindings as pre-stage bindings + if is_stage_form and not first_stage_seen: + first_stage_seen = True + ctx.pre_stage_bindings = dict(ctx.bindings) + + result = _compile_expr(expr, ctx) + if result is not None: + output_node_id = result + + if output_node_id is None: + raise CompileError("Recipe has no output (no DAG expression)") + + # Build stage order (topological sort) + stage_order = _topological_sort_stages(ctx.defined_stages) + + # Collect stages in order + stages = [ctx.defined_stages[name] for name in stage_order] + + return CompiledRecipe( + name=name, + version=version, + description=description, + owner=owner, + registry=ctx.registry, + nodes=list(ctx.nodes.values()), + output_node_id=output_node_id, + encoding=encoding, + params=params, + stages=stages, + stage_order=stage_order, + minimal_primitives=minimal_primitives, + source_text=source_text, + resolved_params=initial_bindings or {}, + ) + + +def _compile_expr(expr: Any, ctx: CompilerContext) -> Optional[str]: + """ + Compile an expression, returning node_id if it produces a node. + + Handles: + - (asset name :hash "..." :url "...") + - (effect name :hash "..." :url "...") + - (def name expr) + - (-> expr expr ...) + - (source ...), (effect ...), (sequence ...), etc. + """ + if not isinstance(expr, list) or len(expr) == 0: + # Atom - could be a reference + if isinstance(expr, Symbol): + # Look up binding + if expr.name in ctx.bindings: + return ctx.bindings[expr.name] + raise CompileError(f"Undefined symbol: {expr.name}") + return None + + head = expr[0] + if not isinstance(head, Symbol): + raise CompileError(f"Expected symbol at head of expression, got {head}") + + name = head.name + + # Registry declarations + if name == "asset": + return _compile_asset(expr, ctx) + if name == "effect": + return _compile_effect_decl(expr, ctx) + if name == "analyzer": + return _compile_analyzer_decl(expr, ctx) + if name == "construct": + return _compile_construct_decl(expr, ctx) + + # Template definition + if name == "deftemplate": + return _compile_deftemplate(expr, ctx) + + # Include - load and evaluate external sexp file + if name == "include": + return _compile_include(expr, ctx) + + # Binding + if name == "def": + return _compile_def(expr, ctx) + + # Stage form + if name == "stage": + return _compile_stage(expr, ctx) + + # Threading macro + if name == "->": + return _compile_threading(expr, ctx) + + # Node types + if name == "source": + return _compile_source(expr, ctx) + if name in ("effect", "fx"): + return _compile_effect_node(expr, ctx) + if name == "segment": + return _compile_segment(expr, ctx) + if name == "resize": + return _compile_resize(expr, ctx) + if name == "sequence": + return _compile_sequence(expr, ctx) + # Note: layer and blend are now regular effects, not special forms + # Use: (effect layer bg fg :x 0 :y 0) or (effect blend a b :mode "overlay") + if name == "mux": + return _compile_mux(expr, ctx) + if name == "analyze": + return _compile_analyze(expr, ctx) + if name == "scan": + return _compile_scan(expr, ctx) + if name == "blend-multi": + return _compile_blend_multi(expr, ctx) + if name == "make-rng": + return _compile_make_rng(expr, ctx) + if name == "next-seed": + return _compile_next_seed(expr, ctx) + + # Check if it's a registered construct call BEFORE built-in slice-on + # This allows user-defined constructs to override built-ins + if name in ctx.registry.get("constructs", {}): + return _compile_construct_call(expr, ctx) + + if name == "slice-on": + return _compile_slice_on(expr, ctx) + + # Binding expression for parameter linking + if name == "bind": + return _compile_bind(expr, ctx) + + # Pure functions that can be evaluated at compile time + PURE_FUNCTIONS = { + "max", "min", "floor", "ceil", "round", "abs", + "+", "-", "*", "/", "mod", "sqrt", "pow", + "len", "get", "first", "last", "nth", + "=", "<", ">", "<=", ">=", "not=", + "and", "or", "not", + "inc", "dec", + "chunk-every", + "list", "dict", + "assert", + } + if name in PURE_FUNCTIONS: + # Evaluate using the evaluator + from .evaluator import evaluate + # Build env from ctx.bindings + env = dict(ctx.bindings) + try: + result = evaluate(expr, env) + return result + except Exception as e: + raise CompileError(f"Error evaluating {name}: {e}") + + # Template invocation + if name in ctx.registry.get("templates", {}): + return _compile_template_call(expr, ctx) + + raise CompileError(f"Unknown expression type: {name}") + + +def _parse_kwargs(expr: List, start: int = 1) -> Tuple[List[Any], Dict[str, Any]]: + """ + Parse positional args and keyword args from expression. + + Returns (positional_args, keyword_dict) + """ + positional = [] + kwargs = {} + + i = start + while i < len(expr): + item = expr[i] + if isinstance(item, Keyword): + if i + 1 >= len(expr): + raise CompileError(f"Keyword :{item.name} missing value") + kwargs[item.name] = expr[i + 1] + i += 2 + else: + positional.append(item) + i += 1 + + return positional, kwargs + + +def _compile_asset(expr: List, ctx: CompilerContext) -> None: + """Compile (asset name :cid "..." :url "...") or legacy (asset name :hash "...")""" + if len(expr) < 2: + raise CompileError("asset requires a name") + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + _, kwargs = _parse_kwargs(expr, 2) + + # Support both :cid (new IPFS) and :hash (legacy SHA3-256) + asset_cid = kwargs.get("cid") or kwargs.get("hash") + if not asset_cid: + raise CompileError(f"asset {name} requires :cid or :hash") + + ctx.registry["assets"][name] = { + "cid": asset_cid, + "url": kwargs.get("url"), + } + return None + + +def _resolve_effect_path(path: str, ctx: CompilerContext) -> Optional[Path]: + """Resolve an effect path relative to recipe directory. + + Args: + path: Relative or absolute path to effect file + ctx: Compiler context with recipe_dir + + Returns: + Resolved absolute Path, or None if not found + """ + effect_path = Path(path) + + # Already absolute + if effect_path.is_absolute() and effect_path.exists(): + return effect_path + + # Try relative to recipe directory + if ctx.recipe_dir: + recipe_relative = ctx.recipe_dir / path + if recipe_relative.exists(): + return recipe_relative.resolve() + + # Try relative to cwd + import os + cwd = Path(os.getcwd()) + cwd_relative = cwd / path + if cwd_relative.exists(): + return cwd_relative.resolve() + + return None + + +def _compile_effect_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile effect - either declaration or node. + + Declaration: (effect name :cid "..." :url "...") or legacy (effect name :hash "...") + Node: (effect effect-name) or (effect effect-name input-node) + """ + if len(expr) < 2: + raise CompileError("effect requires at least a name") + + # Check if this is a declaration (has :cid or :hash) + _, kwargs = _parse_kwargs(expr, 2) + + # Support both :cid (new) and :hash (legacy) + effect_cid = kwargs.get("cid") or kwargs.get("hash") + + if effect_cid or "path" in kwargs: + # Declaration + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + # Handle temporal flag - could be Symbol('true') or Python bool + temporal = kwargs.get("temporal", False) + if isinstance(temporal, Symbol): + temporal = temporal.name.lower() == "true" + + effect_path = kwargs.get("path") + + # Compute cid from file content if path provided and no cid + if effect_path and not effect_cid: + resolved_path = _resolve_effect_path(effect_path, ctx) + if resolved_path and resolved_path.exists(): + effect_cid = compute_file_cid(resolved_path) + effect_path = str(resolved_path) # Store absolute path + + ctx.registry["effects"][name] = { + "cid": effect_cid, + "path": effect_path, + "url": kwargs.get("url"), + "temporal": temporal, + } + return None + + # Otherwise it's a node - delegate to effect node compiler + return _compile_effect_node(expr, ctx) + + +def _compile_analyzer_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile analyzer declaration. + + Declaration: (analyzer name :path "..." :cid "...") + + Example: + (analyzer beats :path "../analyzers/beats/analyzer.py") + """ + if len(expr) < 2: + raise CompileError("analyzer requires at least a name") + + _, kwargs = _parse_kwargs(expr, 2) + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + ctx.registry["analyzers"][name] = { + "cid": kwargs.get("cid"), + "path": kwargs.get("path"), + "url": kwargs.get("url"), + } + return None + + +def _compile_construct_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile construct declaration. + + Declaration: (construct name :path "...") + + Example: + (construct beat-alternate :path "constructs/beat-alternate.sexp") + """ + if len(expr) < 2: + raise CompileError("construct requires at least a name") + + _, kwargs = _parse_kwargs(expr, 2) + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + ctx.registry["constructs"][name] = { + "path": kwargs.get("path"), + "cid": kwargs.get("cid"), + "url": kwargs.get("url"), + } + return None + + +def _compile_construct_call(expr: List, ctx: CompilerContext) -> str: + """ + Compile a call to a user-defined construct. + + Creates a CONSTRUCT node that will be expanded at plan time. + + Example: + (beat-alternate beats-data (list video-a video-b)) + """ + name = expr[0].name + construct_info = ctx.registry["constructs"][name] + + # Get positional args and kwargs + args, kwargs = _parse_kwargs(expr, 1) + + # Resolve input references + resolved_args = [] + node_inputs = [] # Track actual node IDs for inputs + + for arg in args: + if isinstance(arg, Symbol) and arg.name in ctx.bindings: + node_id = ctx.bindings[arg.name] + resolved_args.append(node_id) + node_inputs.append(node_id) + elif isinstance(arg, list) and arg and isinstance(arg[0], Symbol): + # Check if it's a literal list expression like (list video-a video-b) + if arg[0].name == "list": + # Resolve each element of the list + list_items = [] + for item in arg[1:]: + if isinstance(item, Symbol) and item.name in ctx.bindings: + list_items.append(ctx.bindings[item.name]) + node_inputs.append(ctx.bindings[item.name]) + else: + list_items.append(item) + resolved_args.append(list_items) + else: + # Try to compile as an expression + try: + node_id = _compile_expr(arg, ctx) + if node_id: + resolved_args.append(node_id) + node_inputs.append(node_id) + else: + resolved_args.append(arg) + except CompileError: + resolved_args.append(arg) + else: + resolved_args.append(arg) + + # Also scan kwargs for Symbol references to nodes (like analysis nodes) + # Helper to extract node IDs from a value (handles nested lists/dicts) + def extract_node_ids(val): + if isinstance(val, str) and len(val) == 64: + return [val] + elif isinstance(val, list): + ids = [] + for item in val: + ids.extend(extract_node_ids(item)) + return ids + elif isinstance(val, dict): + ids = [] + for v in val.values(): + ids.extend(extract_node_ids(v)) + return ids + return [] + + for key, value in kwargs.items(): + if isinstance(value, Symbol) and value.name in ctx.bindings: + binding_value = ctx.bindings[value.name] + # If it's a node ID (string hash), add to inputs + if isinstance(binding_value, str) and len(binding_value) == 64: + node_inputs.append(binding_value) + # Also scan lists/dicts for node IDs (e.g., video_infos list) + elif isinstance(binding_value, (list, dict)): + node_inputs.extend(extract_node_ids(binding_value)) + + node_id = ctx.add_node( + "CONSTRUCT", + { + "construct_name": name, + "construct_path": construct_info.get("path"), + "args": resolved_args, + # Include bindings so reducer lambda can reference video sources etc. + "bindings": dict(ctx.bindings), + **kwargs, + }, + inputs=node_inputs, + ) + return node_id + + +def _compile_include(expr: List, ctx: CompilerContext) -> None: + """ + Compile (include :path "...") or (include name :path "..."). + + Loads an external .sexp file and processes its declarations/definitions. + Supports analyzer, effect, construct declarations and def bindings. + + Forms: + (include :path "libs/standard-effects.sexp") ; declaration-only + (include :cid "bafy...") ; from L1/L2 cache + (include preset-name :path "presets/all.sexp") ; binds result to name + + Included files can contain: + - (analyzer name :path "...") declarations + - (effect name :path "...") declarations + - (construct name :path "...") declarations + - (deftemplate name (params...) body...) template definitions + - (def name value) bindings + + For web-based systems: + - :cid loads from L1 local cache or L2 shared cache + - :path is for local development + + Example library file (libs/standard-analyzers.sexp): + ;; Standard audio analyzers + (analyzer beats :path "../artdag-analyzers/beats/analyzer.py") + (analyzer bass :path "../artdag-analyzers/bass/analyzer.py") + (analyzer energy :path "../artdag-analyzers/energy/analyzer.py") + + Example usage: + (include :path "libs/standard-analyzers.sexp") + (include :path "libs/all-effects.sexp") + ;; Now beats, bass, energy analyzers and all effects are available + """ + from pathlib import Path + from .parser import parse_all + from .evaluator import evaluate + + _, kwargs = _parse_kwargs(expr, 1) + + # Name is optional - check if first arg is a symbol (name) or keyword + name = None + if len(expr) >= 2 and isinstance(expr[1], Symbol) and not str(expr[1].name).startswith(":"): + name = expr[1].name + _, kwargs = _parse_kwargs(expr, 2) + + path = kwargs.get("path") + cid = kwargs.get("cid") + + if not path and not cid: + raise CompileError("include requires :path or :cid") + + content = None + + if cid: + # Load from content-addressed cache (L1 local / L2 shared) + content = _load_from_cache(cid, ctx) + + if content is None and path: + # Load from local path + include_path = Path(path) + + # Try relative to recipe directory first + if hasattr(ctx, 'recipe_dir') and ctx.recipe_dir: + recipe_relative = ctx.recipe_dir / path + if recipe_relative.exists(): + include_path = recipe_relative + + # Try relative to cwd + if not include_path.exists(): + import os + cwd = Path(os.getcwd()) + include_path = cwd / path + + if not include_path.exists(): + raise CompileError(f"Include file not found: {path}") + + content = include_path.read_text() + + # Track included file by CID for upload/caching + include_cid = compute_content_cid(content) + ctx.registry["includes"][str(include_path.resolve())] = include_cid + + if content is None: + raise CompileError(f"Could not load include: path={path}, cid={cid}") + + # Parse the included file + sexp_list = parse_all(content) + if not isinstance(sexp_list, list): + sexp_list = [sexp_list] + + # Build an environment from current bindings + env = dict(ctx.bindings) + + for sexp in sexp_list: + if isinstance(sexp, list) and sexp and isinstance(sexp[0], Symbol): + form = sexp[0].name + + if form == "def": + # (def name value) - evaluate and add to bindings + if len(sexp) != 3: + raise CompileError(f"Invalid def in include: {sexp}") + def_name = sexp[1] + if isinstance(def_name, Symbol): + def_name = def_name.name + def_value = evaluate(sexp[2], env) + env[def_name] = def_value + ctx.bindings[def_name] = def_value + + elif form == "analyzer": + # (analyzer name :path "..." [:cid "..."]) + _compile_analyzer_decl(sexp, ctx) + + elif form == "effect": + # (effect name :path "..." [:cid "..."]) + _compile_effect_decl(sexp, ctx) + + elif form == "construct": + # (construct name :path "..." [:cid "..."]) + _compile_construct_decl(sexp, ctx) + + elif form == "deftemplate": + # (deftemplate name (params...) body...) + _compile_deftemplate(sexp, ctx) + + else: + # Try to evaluate as expression + result = evaluate(sexp, env) + # If a name was provided, bind the last result + if name and result is not None: + ctx.bindings[name] = result + else: + # Evaluate as expression (e.g., bare list literal) + result = evaluate(sexp, env) + if name and result is not None: + ctx.bindings[name] = result + + return None + + +def _load_from_cache(cid: str, ctx: CompilerContext) -> Optional[str]: + """ + Load content from L1 (local) or L2 (shared) cache by CID. + + Cache hierarchy: + L1: Local file cache (~/.artdag/cache/{cid}) + L2: Shared/network cache (IPFS, HTTP gateway, etc.) + + Returns file content as string, or None if not found. + """ + from pathlib import Path + import os + + # L1: Local cache directory + cache_dir = Path(os.path.expanduser("~/.artdag/cache")) + l1_path = cache_dir / cid + + if l1_path.exists(): + return l1_path.read_text() + + # L2: Try shared cache sources + content = _load_from_l2(cid, ctx) + + if content: + # Store in L1 for future use + cache_dir.mkdir(parents=True, exist_ok=True) + l1_path.write_text(content) + + return content + + +def _load_from_l2(cid: str, ctx: CompilerContext) -> Optional[str]: + """ + Load content from L2 shared cache. + + Supports: + - IPFS gateways (if CID starts with 'bafy' or 'Qm') + - HTTP URLs (if configured in ctx.l2_sources) + - Custom backends (extensible) + + Returns content as string, or None if not available. + """ + import urllib.request + import urllib.error + + # IPFS gateway (public, for development) + if cid.startswith("bafy") or cid.startswith("Qm"): + gateways = [ + f"https://ipfs.io/ipfs/{cid}", + f"https://dweb.link/ipfs/{cid}", + f"https://cloudflare-ipfs.com/ipfs/{cid}", + ] + for gateway_url in gateways: + try: + with urllib.request.urlopen(gateway_url, timeout=10) as response: + return response.read().decode('utf-8') + except (urllib.error.URLError, urllib.error.HTTPError): + continue + + # Custom L2 sources from context (e.g., private cache server) + l2_sources = getattr(ctx, 'l2_sources', []) + for source in l2_sources: + try: + url = f"{source}/{cid}" + with urllib.request.urlopen(url, timeout=10) as response: + return response.read().decode('utf-8') + except (urllib.error.URLError, urllib.error.HTTPError): + continue + + return None + + +def _compile_def(expr: List, ctx: CompilerContext) -> None: + """Compile (def name expr)""" + if len(expr) != 3: + raise CompileError("def requires exactly 2 arguments: name and expression") + + name = expr[1] + if not isinstance(name, Symbol): + raise CompileError(f"def name must be a symbol, got {name}") + + # If binding already exists (e.g. from command-line param), don't override + # This allows recipes to specify defaults that command-line params can override + if name.name in ctx.bindings: + return None + + body = expr[2] + + # Check if body is a simple value (number, string, etc.) + if isinstance(body, (int, float, str, bool)): + ctx.bindings[name.name] = body + return None + + node_id = _compile_expr(body, ctx) + + # Multi-scan dict emit: expand field bindings + if isinstance(node_id, dict) and node_id.get("_multi_scan"): + for field_name, field_node_id in node_id["fields"].items(): + binding_name = f"{name.name}-{field_name}" + ctx.bindings[binding_name] = field_node_id + if field_node_id in ctx.nodes: + ctx.nodes[field_node_id]["name"] = binding_name + return None + + # If result is a simple value (from evaluated pure function), store it directly + # This includes lists, tuples, dicts from pure functions like `list` + if isinstance(node_id, (int, float, str, bool, list, tuple, dict)): + ctx.bindings[name.name] = node_id + return None + + if node_id is None: + raise CompileError(f"def body must produce a node or value") + + # Store binding for reference resolution + ctx.bindings[name.name] = node_id + + # Also store the name on the node so planner can reference it + if node_id in ctx.nodes: + ctx.nodes[node_id]["name"] = name.name + + return None + + +def _compile_stage(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile (stage :name :requires [...] :inputs [...] :outputs [...] body...). + + Stage form enables explicit dependency declaration, parallel execution, + and variable scoping. + + Example: + (stage :analyze-a + :outputs [beats-a] + (def beats-a (-> audio-a (analyze beats)))) + + (stage :plan-a + :requires [:analyze-a] + :inputs [beats-a] + :outputs [segments-a] + (def segments-a (make-segments :beats beats-a))) + """ + if len(expr) < 2: + raise CompileError("stage requires at least a name") + + # Parse stage name (first element after 'stage' should be a keyword like :analyze-a) + # The stage name is NOT a key-value pair - it's a standalone keyword + stage_name = None + start_idx = 1 + + if len(expr) > 1: + first_arg = expr[1] + if isinstance(first_arg, Keyword): + stage_name = first_arg.name + start_idx = 2 + elif isinstance(first_arg, Symbol): + stage_name = first_arg.name + start_idx = 2 + + if stage_name is None: + raise CompileError("stage requires a name (e.g., (stage :analyze-a ...))") + + # Now parse remaining kwargs and body + args, kwargs = _parse_kwargs(expr, start_idx) + + # Parse requires, inputs, outputs + requires = [] + if "requires" in kwargs: + req_val = kwargs["requires"] + if isinstance(req_val, list): + for r in req_val: + if isinstance(r, Keyword): + requires.append(r.name) + elif isinstance(r, Symbol): + requires.append(r.name) + elif isinstance(r, str): + requires.append(r) + else: + raise CompileError(f"Invalid require: {r}") + else: + raise CompileError(":requires must be a list") + + inputs = [] + if "inputs" in kwargs: + inp_val = kwargs["inputs"] + if isinstance(inp_val, list): + for i in inp_val: + if isinstance(i, Symbol): + inputs.append(i.name) + elif isinstance(i, str): + inputs.append(i) + else: + raise CompileError(f"Invalid input: {i}") + else: + raise CompileError(":inputs must be a list") + + outputs = [] + if "outputs" in kwargs: + out_val = kwargs["outputs"] + if isinstance(out_val, list): + for o in out_val: + if isinstance(o, Symbol): + outputs.append(o.name) + elif isinstance(o, str): + outputs.append(o) + else: + raise CompileError(f"Invalid output: {o}") + else: + raise CompileError(":outputs must be a list") + + # Validate requires - must reference defined stages + for req in requires: + if req not in ctx.defined_stages: + raise CompileError( + f"Stage '{stage_name}' requires undefined stage '{req}'" + ) + + # Validate inputs - must be produced by required stages + for inp in inputs: + found = False + for req in requires: + if inp in ctx.defined_stages[req].output_bindings: + found = True + break + if not found and inp not in ctx.pre_stage_bindings: + raise CompileError( + f"Stage '{stage_name}' declares input '{inp}' " + f"which is not an output of any required stage or pre-stage binding" + ) + + # Check for circular dependencies (simple check for now) + # A more thorough check would use topological sort + visited = set() + def check_cycle(stage: str, path: List[str]): + if stage in path: + cycle = " -> ".join(path + [stage]) + raise CompileError(f"Circular stage dependency: {cycle}") + if stage in visited: + return + visited.add(stage) + if stage in ctx.defined_stages: + for req in ctx.defined_stages[stage].requires: + check_cycle(req, path + [stage]) + + for req in requires: + check_cycle(req, [stage_name]) + + # Save context state before entering stage + prev_stage = ctx.current_stage + prev_stage_node_ids = ctx.stage_node_ids + + # Enter stage context + ctx.current_stage = stage_name + ctx.stage_node_ids = [] + + # Build accessible bindings for this stage + stage_ctx_bindings = dict(ctx.pre_stage_bindings) + + # Add input bindings from required stages + for inp in inputs: + for req in requires: + if inp in ctx.defined_stages[req].output_bindings: + stage_ctx_bindings[inp] = ctx.defined_stages[req].output_bindings[inp] + break + + # Save current bindings and set up stage bindings + prev_bindings = ctx.bindings + ctx.bindings = stage_ctx_bindings + + # Compile body expressions + # Body expressions are lists or symbols after the stage name and kwargs + # Start from index 2 (after 'stage' and stage name) + body_exprs = [] + i = 2 # Skip 'stage' and stage name + while i < len(expr): + item = expr[i] + if isinstance(item, Keyword): + # Skip keyword and its value + i += 2 + elif isinstance(item, (list, Symbol)): + # Include both list expressions and symbol references + body_exprs.append(item) + i += 1 + else: + i += 1 + + last_result = None + for body_expr in body_exprs: + result = _compile_expr(body_expr, ctx) + if result is not None: + last_result = result + + # Collect output bindings + output_bindings = {} + for out in outputs: + if out in ctx.bindings: + output_bindings[out] = ctx.bindings[out] + else: + raise CompileError( + f"Stage '{stage_name}' declares output '{out}' " + f"but it was not defined in the stage body" + ) + + # Create CompiledStage + compiled_stage = CompiledStage( + name=stage_name, + requires=requires, + inputs=inputs, + outputs=outputs, + node_ids=ctx.stage_node_ids, + output_bindings=output_bindings, + ) + + # Register the stage + ctx.defined_stages[stage_name] = compiled_stage + ctx.stage_bindings[stage_name] = output_bindings + + # Restore context state + ctx.current_stage = prev_stage + ctx.stage_node_ids = prev_stage_node_ids + ctx.bindings = prev_bindings + + # Make stage outputs available to subsequent stages via bindings + ctx.bindings.update(output_bindings) + + return last_result + + +def _compile_threading(expr: List, ctx: CompilerContext) -> str: + """ + Compile (-> expr1 expr2 expr3 ...) + + Each expression's output becomes the implicit first input of the next. + """ + if len(expr) < 2: + raise CompileError("-> requires at least one expression") + + prev_node_id = None + + for i, sub_expr in enumerate(expr[1:]): + if prev_node_id is not None: + # Inject previous node as first input + sub_expr = _inject_input(sub_expr, prev_node_id) + + prev_node_id = _compile_expr(sub_expr, ctx) + + if prev_node_id is None: + raise CompileError(f"Expression {i} in -> chain produced no node") + + return prev_node_id + + +def _inject_input(expr: Any, input_id: str) -> List: + """Inject an input node ID into an expression.""" + if not isinstance(expr, list): + # Symbol reference - wrap in a node that takes input + if isinstance(expr, Symbol): + # Assume it's an effect name + return [Symbol("effect"), expr, Symbol(f"__input_{input_id}")] + raise CompileError(f"Cannot inject input into {expr}") + + # For node expressions, we'll handle the input in the compiler + # Mark it with a special __prev__ reference + return expr + [Symbol("__prev__"), input_id] + + +def _resolve_input(arg: Any, ctx: CompilerContext, prev_id: str = None) -> str: + """Resolve an argument to a node ID.""" + if isinstance(arg, Symbol): + if arg.name == "__prev__": + if prev_id is None: + raise CompileError("__prev__ used outside threading context") + return prev_id + if arg.name.startswith("__input_"): + return arg.name[8:] # Strip __input_ prefix + if arg.name in ctx.bindings: + return ctx.bindings[arg.name] + raise CompileError(f"Undefined reference: {arg.name}") + + if isinstance(arg, str): + # Direct node ID + return arg + + if isinstance(arg, list): + # Nested expression + return _compile_expr(arg, ctx) + + raise CompileError(f"Cannot resolve input: {arg}") + + +def _extract_prev_id(args: List, kwargs: Dict) -> Tuple[List, Dict, Optional[str]]: + """Extract __prev__ marker from args if present.""" + prev_id = None + new_args = [] + + i = 0 + while i < len(args): + if isinstance(args[i], Symbol) and args[i].name == "__prev__": + if i + 1 < len(args): + prev_id = args[i + 1] + i += 2 + continue + new_args.append(args[i]) + i += 1 + + return new_args, kwargs, prev_id + + +def _compile_source(expr: List, ctx: CompilerContext) -> str: + """ + Compile (source asset-name), (source :input "name" ...), or (source :path "file.mkv" ...). + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, _ = _extract_prev_id(args, kwargs) + + if "input" in kwargs: + # Variable input - :input can be followed by a name string + input_val = kwargs["input"] + if isinstance(input_val, str): + # (source :input "User Video" :description "...") + name = input_val + else: + # (source :input true :name "User Video") + name = kwargs.get("name", "Input") + config = { + "input": True, + "name": name, + "description": kwargs.get("description", ""), + } + elif "path" in kwargs: + # Local file path - for development/testing + # (source :path "dog.mkv" :description "Input video") + path = kwargs["path"] + config = { + "path": path, + "description": kwargs.get("description", ""), + } + elif args: + # Asset reference + asset_name = args[0] + if isinstance(asset_name, Symbol): + asset_name = asset_name.name + config = {"asset": asset_name} + else: + raise CompileError("source requires asset name, :input flag, or :path") + + return ctx.add_node("SOURCE", config) + + +def _compile_effect_node(expr: List, ctx: CompilerContext) -> str: + """ + Compile (effect effect-name [input-nodes...] :param value ...). + + Single input: + (effect rotate video :angle 45) + (-> video (effect rotate :angle 45)) + + Multi-input (blend, layer, etc.): + (effect blend video-a video-b :mode "overlay") + (-> video-a (effect blend video-b :mode "overlay")) + + Parameters can be literals or bind expressions: + (effect brightness video :level (bind analysis :energy :range [0 1])) + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + if not args: + raise CompileError("effect requires effect name") + + effect_name = args[0] + if isinstance(effect_name, Symbol): + effect_name = effect_name.name + + config = {"effect": effect_name} + + # Look up effect info from registry + effects_registry = ctx.registry.get("effects", {}) + if effect_name in effects_registry: + effect_info = effects_registry[effect_name] + if isinstance(effect_info, dict): + if "path" in effect_info: + config["effect_path"] = effect_info["path"] + if "cid" in effect_info and effect_info["cid"]: + config["effect_cid"] = effect_info["cid"] + elif isinstance(effect_info, str): + config["effect_path"] = effect_info + + # Include full effects_registry with cids for workers to fetch dependencies + # Only include effects that have cids (content-addressed) + effects_with_cids = {} + for name, info in effects_registry.items(): + if isinstance(info, dict) and info.get("cid"): + effects_with_cids[name] = info["cid"] + if effects_with_cids: + config["effects_registry"] = effects_with_cids + + # Process parameter values, looking for bind expressions + # Also track analysis references for workers + analysis_refs = set() + for k, v in kwargs.items(): + if k not in ("hash", "url"): + processed = _process_value(v, ctx) + config[k] = processed + # Extract analysis references from bind expressions + _extract_analysis_refs(processed, analysis_refs) + + if analysis_refs: + config["analysis_refs"] = list(analysis_refs) + + # Collect inputs - first from threading (prev_id), then from additional args + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args[1:]: + # Handle list of inputs: (effect blend [video-a video-b] :mode "overlay") + if isinstance(arg, list) and arg and not isinstance(arg[0], Symbol): + for item in arg: + inputs.append(_resolve_input(item, ctx, prev_id)) + else: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + # Auto-detect multi-input effects + if len(inputs) > 1: + config["multi_input"] = True + + return ctx.add_node("EFFECT", config, inputs) + + +def _extract_analysis_refs(value: Any, refs: set) -> None: + """Extract analysis node references from a processed value. + + Bind expressions contain references to analysis nodes. This function + extracts those references so workers know which analysis data they need. + """ + if isinstance(value, dict): + # Check if this is a bind expression (has _binding flag or source/ref key) + if value.get("_binding") or "bind" in value or "ref" in value or "source" in value: + ref = value.get("source") or value.get("ref") or value.get("bind") + if ref: + refs.add(ref) + # Recursively check nested dicts + for v in value.values(): + _extract_analysis_refs(v, refs) + elif isinstance(value, list): + for item in value: + _extract_analysis_refs(item, refs) + + +def _compile_segment(expr: List, ctx: CompilerContext) -> str: + """Compile (segment :start 0.0 :end 2.0 [input]).""" + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + config = {} + analysis_refs = set() + + if "start" in kwargs: + val = _process_value(kwargs["start"], ctx) + # Binding dicts are preserved for runtime resolution, None values are skipped + if val is not None: + config["start"] = val if isinstance(val, dict) and val.get("_binding") else float(val) + _extract_analysis_refs(config.get("start"), analysis_refs) + if "end" in kwargs: + val = _process_value(kwargs["end"], ctx) + if val is not None: + config["end"] = val if isinstance(val, dict) and val.get("_binding") else float(val) + _extract_analysis_refs(config.get("end"), analysis_refs) + if "duration" in kwargs: + val = _process_value(kwargs["duration"], ctx) + if val is not None: + config["duration"] = val if isinstance(val, dict) and val.get("_binding") else float(val) + _extract_analysis_refs(config.get("duration"), analysis_refs) + + if analysis_refs: + config["analysis_refs"] = list(analysis_refs) + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("SEGMENT", config, inputs) + + +def _compile_resize(expr: List, ctx: CompilerContext) -> str: + """ + Compile (resize width height :mode "linear" [input]). + + Resize is now an EFFECT that uses the sexp resize-frame effect. + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + if len(args) < 2: + raise CompileError("resize requires width and height") + + # Create EFFECT node with resize effect + # Note: param names match resize.sexp (target-w, target-h to avoid primitive conflict) + config = { + "effect": "resize-frame", + "effect_path": "sexp_effects/effects/resize-frame.sexp", + "target-w": int(args[0]), + "target-h": int(args[1]), + "mode": kwargs.get("mode", "linear"), + } + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args[2:]: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("EFFECT", config, inputs) + + +def _compile_sequence(expr: List, ctx: CompilerContext) -> str: + """ + Compile (sequence node1 node2 ... :resize-mode :fit :priority :width). + + Options: + :transition - transition between clips (default: cut) + :resize-mode - fit | crop | stretch | cover (default: none) + :priority - width | height (which dimension to match exactly) + :target-width - explicit target width + :target-height - explicit target height + :pad-color - color for fit mode padding (default: black) + :crop-gravity - center | top | bottom | left | right (default: center) + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + config = { + "transition": kwargs.get("transition", {"type": "cut"}), + } + + # Add normalize config if specified + resize_mode = kwargs.get("resize-mode") + if isinstance(resize_mode, (Symbol, Keyword)): + resize_mode = resize_mode.name + if resize_mode: + config["resize_mode"] = resize_mode + + priority = kwargs.get("priority") + if isinstance(priority, (Symbol, Keyword)): + priority = priority.name + if priority: + config["priority"] = priority + + if kwargs.get("target-width"): + config["target_width"] = kwargs["target-width"] + if kwargs.get("target-height"): + config["target_height"] = kwargs["target-height"] + + pad_color = kwargs.get("pad-color") + if isinstance(pad_color, (Symbol, Keyword)): + pad_color = pad_color.name + config["pad_color"] = pad_color or "black" + + crop_gravity = kwargs.get("crop-gravity") + if isinstance(crop_gravity, (Symbol, Keyword)): + crop_gravity = crop_gravity.name + config["crop_gravity"] = crop_gravity or "center" + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("SEQUENCE", config, inputs) + + +def _compile_mux(expr: List, ctx: CompilerContext) -> str: + """Compile (mux video-node audio-node).""" + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + config = { + "video_stream": 0, + "audio_stream": 1, + "shortest": kwargs.get("shortest", True), + } + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + if len(inputs) < 2: + raise CompileError("mux requires video and audio inputs") + + return ctx.add_node("MUX", config, inputs) + + +def _compile_slice_on(expr: List, ctx: CompilerContext) -> str: + """ + Compile slice-on with either legacy or lambda syntax. + + Legacy syntax: + (slice-on video analysis :times path :effect fx :pattern pat) + + Lambda syntax: + (slice-on analysis + :times times + :init 0 + :fn (lambda [acc i start end] + {:source video + :effects (if (odd? i) [invert] []) + :acc (inc acc)})) + + Args: + video: input video node (legacy) or omitted (lambda) + analysis: analysis node with times array + :times - path to times array in analysis + :effect - effect to apply (legacy, optional) + :pattern - all, odd, even, alternate (legacy, default: all) + :init - initial accumulator value (lambda) + :fn - reducer lambda function (lambda) + """ + from .parser import Lambda + + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + # Check for lambda mode + reducer_fn = kwargs.get("fn") + + # Parse lambda if it's a list + if isinstance(reducer_fn, list): + reducer_fn = _parse_lambda(reducer_fn) + + # Lambda mode: only analysis input required (sources come from fn) + # Legacy mode: requires video and analysis inputs + if reducer_fn is not None: + # Lambda mode - just need analysis input + if len(args) < 1: + raise CompileError("slice-on requires analysis input") + analysis_input = _resolve_input(args[0], ctx, prev_id) + inputs = [analysis_input] + else: + # Legacy mode - need video and analysis inputs + if len(args) < 2: + raise CompileError("slice-on requires video and analysis inputs") + video_input = _resolve_input(args[0], ctx, prev_id) + analysis_input = _resolve_input(args[1], ctx, prev_id) + inputs = [video_input, analysis_input] + + times_path = kwargs.get("times", "times") + if isinstance(times_path, Symbol): + times_path = times_path.name + + config = { + "times_path": times_path, + "fn": reducer_fn, + "init": kwargs.get("init", 0), + # Include bindings so lambda can reference video sources etc. + "bindings": dict(ctx.bindings), + } + + # Optional :videos list for multi-source composition mode + videos_list = kwargs.get("videos") + if videos_list is not None: + if not isinstance(videos_list, list): + raise CompileError(":videos must be a list") + resolved_videos = [] + for v in videos_list: + resolved_videos.append(_resolve_input(v, ctx, None)) + config["videos"] = resolved_videos + # Add to inputs so planner knows about dependencies + for vid in resolved_videos: + if vid not in inputs: + inputs.append(vid) + + return ctx.add_node("SLICE_ON", config, inputs) + + +def _parse_lambda(expr: List): + """Parse a lambda expression list into a Lambda object.""" + from .parser import Lambda, Symbol + + if not expr or not isinstance(expr[0], Symbol): + raise CompileError("Invalid lambda expression") + + name = expr[0].name + if name not in ("lambda", "fn"): + raise CompileError(f"Expected lambda or fn, got {name}") + + if len(expr) < 3: + raise CompileError("lambda requires params and body") + + params = expr[1] + if not isinstance(params, list): + raise CompileError("lambda params must be a list") + + param_names = [] + for p in params: + if isinstance(p, Symbol): + param_names.append(p.name) + elif isinstance(p, str): + param_names.append(p) + else: + raise CompileError(f"Invalid lambda param: {p}") + + return Lambda(param_names, expr[2]) + + +def _compile_analyze(expr: List, ctx: CompilerContext) -> str: + """ + Compile (analyze analyzer-name :param value ...). + + Example: + (analyze beats) + (analyze beats :min-bpm 120 :max-bpm 180) + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + # First arg is analyzer name + if not args: + raise CompileError("analyze requires analyzer name") + + analyzer_name = args[0] + if isinstance(analyzer_name, Symbol): + analyzer_name = analyzer_name.name + + # Look up analyzer in registry + analyzer_entry = ctx.registry.get("analyzers", {}).get(analyzer_name, {}) + + config = { + "analyzer": analyzer_name, + "analyzer_path": analyzer_entry.get("path"), + "cid": analyzer_entry.get("cid"), + } + # Add params (kwargs) to config + config.update(kwargs) + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args[1:]: # Skip analyzer name + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("ANALYZE", config, inputs) + + +def _compile_bind(expr: List, ctx: CompilerContext) -> Dict[str, Any]: + """ + Compile (bind source feature :option value ...). + + Returns a binding specification dict (not a node ID). + + Examples: + (bind analysis :energy) + (bind analysis :energy :range [0 1]) + (bind analysis :beats :on-event 1.0 :decay 0.1) + (bind analysis :energy :range [0 1] :smooth 0.05 :noise 0.1 :seed 42) + """ + args, kwargs = _parse_kwargs(expr, 1) + + if len(args) < 2: + raise CompileError("bind requires source and feature: (bind source :feature ...)") + + source = args[0] + feature = args[1] + + # Source can be a symbol reference + source_ref = None + if isinstance(source, Symbol): + if source.name in ctx.bindings: + source_ref = ctx.bindings[source.name] + else: + source_ref = source.name + + # Feature should be a keyword + feature_name = None + if isinstance(feature, Keyword): + feature_name = feature.name + elif isinstance(feature, Symbol): + feature_name = feature.name + else: + raise CompileError(f"bind feature must be a keyword, got {feature}") + + binding = { + "_binding": True, # Marker for binding resolution + "source": source_ref, + "feature": feature_name, + } + + # Add optional binding modifiers + if "range" in kwargs: + range_val = kwargs["range"] + if isinstance(range_val, list) and len(range_val) == 2: + binding["range"] = [float(range_val[0]), float(range_val[1])] + else: + raise CompileError("bind :range must be [lo hi]") + + if "smooth" in kwargs: + binding["smooth"] = float(kwargs["smooth"]) + + if "offset" in kwargs: + binding["offset"] = float(kwargs["offset"]) + + if "on-event" in kwargs: + binding["on_event"] = float(kwargs["on-event"]) + + if "decay" in kwargs: + binding["decay"] = float(kwargs["decay"]) + + if "noise" in kwargs: + binding["noise"] = float(kwargs["noise"]) + + if "seed" in kwargs: + binding["seed"] = int(kwargs["seed"]) + + return binding + + +def _process_value(value: Any, ctx: CompilerContext) -> Any: + """ + Process a value, resolving nested expressions like bind and math. + + Returns the processed value (could be a binding dict, expression dict, node ref, or literal). + + Supported expressions: + (bind source feature :range [lo hi]) - bind to analysis data + (+ a b), (- a b), (* a b), (/ a b), (mod a b) - math operations + time - current frame time in seconds + frame - current frame number + """ + # Math operators that create runtime expressions + MATH_OPS = {'+', '-', '*', '/', 'mod', 'min', 'max', 'abs', 'sin', 'cos', + 'if', '<', '>', '<=', '>=', '=', + 'rand', 'rand-int', 'rand-range', + 'floor', 'ceil', 'nth'} + + if isinstance(value, Symbol): + # Special runtime symbols + if value.name == "time": + return {"_expr": True, "op": "time"} + if value.name == "frame": + return {"_expr": True, "op": "frame"} + # Resolve symbol from bindings + if value.name in ctx.bindings: + return ctx.bindings[value.name] + # Return as-is if not found (could be an effect reference, etc.) + return value + + if isinstance(value, list) and len(value) > 0: + head = value[0] + head_name = head.name if isinstance(head, Symbol) else None + + if head_name == "bind": + return _compile_bind(value, ctx) + + # Handle lambda expressions - parse but don't compile + if head_name in ("lambda", "fn"): + return _parse_lambda(value) + + # Handle dict expressions - keyword-value pairs for runtime dict construction + if head_name == "dict": + keys = [] + vals = [] + i = 1 + while i < len(value): + if isinstance(value[i], Keyword): + keys.append(value[i].name) + if i + 1 < len(value): + vals.append(_process_value(value[i + 1], ctx)) + i += 2 + else: + i += 1 + return {"_expr": True, "op": "dict", "keys": keys, "args": vals} + + # Handle math expressions - preserve for runtime evaluation + if head_name in MATH_OPS: + processed_args = [_process_value(arg, ctx) for arg in value[1:]] + return {"_expr": True, "op": head_name, "args": processed_args} + + # Could be other nested expressions + return _compile_expr(value, ctx) + + return value + + +def _compile_scan_expr(value: Any, ctx: CompilerContext) -> Any: + """ + Compile an expression for use in scan step/emit. + + Like _process_value but treats unbound symbols as runtime variable + references (for acc, dict fields like rem/hue, etc.). + """ + SCAN_OPS = { + '+', '-', '*', '/', 'mod', 'min', 'max', 'abs', 'sin', 'cos', + 'if', '<', '>', '<=', '>=', '=', + 'rand', 'rand-int', 'rand-range', + 'floor', 'ceil', 'nth', + } + + if isinstance(value, (int, float)): + return value + + if isinstance(value, Keyword): + return value.name + + if isinstance(value, Symbol): + # Known runtime symbols + if value.name in ("time", "frame"): + return {"_expr": True, "op": value.name} + # Check bindings for compile-time constants (e.g., recipe params) + if value.name in ctx.bindings: + bound = ctx.bindings[value.name] + if isinstance(bound, (int, float, str, bool)): + return bound + # Runtime variable reference (acc, rem, hue, etc.) + return {"_expr": True, "op": "var", "name": value.name} + + if isinstance(value, list) and len(value) > 0: + head = value[0] + head_name = head.name if isinstance(head, Symbol) else None + + if head_name == "dict": + # (dict :key1 val1 :key2 val2) + keys = [] + args = [] + i = 1 + while i < len(value): + if isinstance(value[i], Keyword): + keys.append(value[i].name) + if i + 1 < len(value): + args.append(_compile_scan_expr(value[i + 1], ctx)) + i += 2 + else: + i += 1 + return {"_expr": True, "op": "dict", "keys": keys, "args": args} + + if head_name in SCAN_OPS: + processed_args = [_compile_scan_expr(arg, ctx) for arg in value[1:]] + return {"_expr": True, "op": head_name, "args": processed_args} + + # Fall through to _process_value for bind expressions, etc. + return _process_value(value, ctx) + + return value + + +def _eval_const_expr(value, ctx: 'CompilerContext'): + """Evaluate a compile-time constant expression. + + Supports literals, symbol lookups in ctx.bindings, and basic arithmetic. + Used for values like scan :seed that must resolve to a number at compile time. + """ + if isinstance(value, (int, float)): + return value + if isinstance(value, Symbol): + if value.name in ctx.bindings: + bound = ctx.bindings[value.name] + if isinstance(bound, (int, float)): + return bound + raise CompileError(f"Cannot resolve symbol '{value.name}' to a constant") + if isinstance(value, list) and len(value) >= 1: + head = value[0] + if isinstance(head, Symbol): + name = head.name + if name == 'next-seed' and len(value) == 2: + rng_val = _resolve_rng_value(value[1], ctx) + return _derive_seed(rng_val) + args = [_eval_const_expr(a, ctx) for a in value[1:]] + if name == '+' and len(args) >= 2: + return args[0] + args[1] + if name == '-' and len(args) >= 2: + return args[0] - args[1] + if name == '*' and len(args) >= 2: + return args[0] * args[1] + if name == '/' and len(args) >= 2: + return args[0] / args[1] if args[1] != 0 else 0 + if name == 'mod' and len(args) >= 2: + return args[0] % args[1] if args[1] != 0 else 0 + raise CompileError(f"Unsupported constant expression operator: {name}") + raise CompileError(f"Cannot evaluate as constant: {value}") + + +def _derive_seed(rng_val: dict) -> int: + """Derive next unique seed from RNG value, incrementing counter.""" + master = rng_val["master_seed"] + counter = rng_val["_counter"] + digest = hashlib.sha256(f"{master}:{counter[0]}".encode()).hexdigest()[:8] + seed = int(digest, 16) + counter[0] += 1 + return seed + + +def _resolve_rng_value(ref, ctx) -> dict: + """Resolve a reference to an RNG value dict.""" + if isinstance(ref, dict) and ref.get("_rng"): + return ref + if isinstance(ref, Symbol): + if ref.name in ctx.bindings: + val = ctx.bindings[ref.name] + if isinstance(val, dict) and val.get("_rng"): + return val + raise CompileError(f"Symbol '{ref.name}' is not an RNG value") + raise CompileError(f"Expected RNG value, got {type(ref).__name__}") + + +def _compile_make_rng(expr, ctx): + """(make-rng SEED) -> compile-time RNG value dict.""" + if len(expr) != 2: + raise CompileError("make-rng requires exactly 1 argument: seed") + seed_val = _eval_const_expr(expr[1], ctx) + return {"_rng": True, "master_seed": int(seed_val), "_counter": [0]} + + +def _compile_next_seed(expr, ctx): + """(next-seed RNG) -> integer seed drawn from RNG.""" + if len(expr) != 2: + raise CompileError("next-seed requires exactly 1 argument: rng") + rng_val = _resolve_rng_value(expr[1], ctx) + return _derive_seed(rng_val) + + +def _compile_scan(expr: List, ctx: CompilerContext) -> str: + """ + Compile (scan source :seed N :init EXPR :step EXPR :emit EXPR). + + Creates a SCAN node that produces a time-series by iterating over + source analysis events with a step function and emit expression. + + The accumulator can be a number or a dict. Dict field names become + accessible as variables in step/emit expressions. + + The :seed parameter supports compile-time constant expressions, + e.g. (+ seed 100) where seed is a template parameter. + + Examples: + ;; Simple counter accumulator + (scan beat-data :seed 42 :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0)) + + ;; Dict accumulator with named fields + (scan beat-data :seed 101 :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0)) + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + # Resolve source input + if prev_id: + source_input = prev_id if isinstance(prev_id, str) else str(prev_id) + elif args: + source_input = _resolve_input(args[0], ctx, None) + else: + raise CompileError("scan requires a source input") + + if "rng" in kwargs: + rng_val = _resolve_rng_value(kwargs["rng"], ctx) + seed = _derive_seed(rng_val) + else: + seed = kwargs.get("seed", 0) + seed = _eval_const_expr(seed, ctx) + + if "step" not in kwargs: + raise CompileError("scan requires :step expression") + if "emit" not in kwargs: + raise CompileError("scan requires :emit expression") + + init_expr = _compile_scan_expr(kwargs.get("init", 0), ctx) + step_expr = _compile_scan_expr(kwargs["step"], ctx) + + emit_raw = kwargs["emit"] + if isinstance(emit_raw, dict): + result = {} + for field_name, field_expr in emit_raw.items(): + field_emit = _compile_scan_expr(field_expr, ctx) + config = { + "seed": int(seed), + "init": init_expr, + "step_expr": step_expr, + "emit_expr": field_emit, + } + node_id = ctx.add_node("SCAN", config, inputs=[source_input]) + result[field_name] = node_id + return {"_multi_scan": True, "fields": result} + + emit_expr = _compile_scan_expr(emit_raw, ctx) + + config = { + "seed": int(seed), + "init": init_expr, + "step_expr": step_expr, + "emit_expr": emit_expr, + } + + return ctx.add_node("SCAN", config, inputs=[source_input]) + + +def _compile_blend_multi(expr: List, ctx: CompilerContext) -> str: + """Compile (blend-multi :videos [...] :weights [...] :mode M :resize_mode R). + + Produces a single EFFECT node that takes N video inputs and N weight + bindings, blending them in one pass via the blend_multi effect. + """ + _, kwargs = _parse_kwargs(expr, 1) + + videos = kwargs.get("videos") + weights = kwargs.get("weights") + mode = kwargs.get("mode", "alpha") + resize_mode = kwargs.get("resize_mode", "fit") + + if not videos or not weights: + raise CompileError("blend-multi requires :videos and :weights") + if not isinstance(videos, list) or not isinstance(weights, list): + raise CompileError("blend-multi :videos and :weights must be lists") + if len(videos) != len(weights): + raise CompileError( + f"blend-multi: videos ({len(videos)}) and weights " + f"({len(weights)}) must be same length" + ) + if len(videos) < 2: + raise CompileError("blend-multi requires at least 2 videos") + + # Resolve video symbols to node IDs — these become the multi-input list + input_ids = [] + for v in videos: + input_ids.append(_resolve_input(v, ctx, None)) + + # Process each weight symbol into a binding dict {_binding, source, feature} + weight_bindings = [] + for w in weights: + bind_expr = [Symbol("bind"), w, Symbol("values")] + weight_bindings.append(_process_value(bind_expr, ctx)) + + # Build EFFECT config + effects_registry = ctx.registry.get("effects", {}) + config = { + "effect": "blend_multi", + "multi_input": True, + "weights": weight_bindings, + "mode": mode, + "resize_mode": resize_mode, + } + + # Attach effect path / cid from registry + if "blend_multi" in effects_registry: + effect_info = effects_registry["blend_multi"] + if isinstance(effect_info, dict): + if "path" in effect_info: + config["effect_path"] = effect_info["path"] + if "cid" in effect_info and effect_info["cid"]: + config["effect_cid"] = effect_info["cid"] + + # Include effects registry for workers + effects_with_cids = {} + for name, info in effects_registry.items(): + if isinstance(info, dict) and info.get("cid"): + effects_with_cids[name] = info["cid"] + if effects_with_cids: + config["effects_registry"] = effects_with_cids + + # Extract analysis refs so workers know which analysis data they need + analysis_refs = set() + for wb in weight_bindings: + _extract_analysis_refs(wb, analysis_refs) + if analysis_refs: + config["analysis_refs"] = list(analysis_refs) + + return ctx.add_node("EFFECT", config, input_ids) + + +def _compile_deftemplate(expr: List, ctx: CompilerContext) -> None: + """Compile (deftemplate NAME (PARAMS...) BODY...). + + Stores the template definition in the registry for later invocation. + Returns None (definition only, no nodes). + """ + if len(expr) < 4: + raise CompileError("deftemplate requires name, params, and body") + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + params = expr[2] + if not isinstance(params, list): + raise CompileError("deftemplate params must be a list") + + param_names = [] + for p in params: + if isinstance(p, Symbol): + param_names.append(p.name) + else: + raise CompileError(f"deftemplate param must be a symbol, got {p}") + + body_forms = expr[3:] + + ctx.registry["templates"][name] = { + "params": param_names, + "body": body_forms, + } + return None + + +def _substitute_template(expr, params_map, local_names, prefix): + """Deep walk s-expression tree, substituting params and prefixing locals.""" + if isinstance(expr, Symbol): + if expr.name in params_map: + return params_map[expr.name] + if expr.name in local_names: + return Symbol(prefix + expr.name) + return expr + if isinstance(expr, list): + return [_substitute_template(e, params_map, local_names, prefix) for e in expr] + if isinstance(expr, dict): + if expr.get("_rng"): + return expr # preserve shared mutable counter + return {k: _substitute_template(v, params_map, local_names, prefix) for k, v in expr.items()} + return expr # numbers, strings, keywords, etc. + + +def _compile_template_call(expr: List, ctx: CompilerContext) -> str: + """Compile a call to a user-defined template. + + Expands the template body with parameter substitution and local name + prefixing, then compiles each resulting form. + """ + name = expr[0].name + template = ctx.registry["templates"][name] + param_names = template["params"] + body_forms = template["body"] + + # Parse keyword args from invocation + _, kwargs = _parse_kwargs(expr, 1) + + # Build param -> value map + params_map = {} + for pname in param_names: + # Convert param name to kwarg key (hyphens match keyword names) + key = pname + if key not in kwargs: + raise CompileError(f"Template '{name}' missing parameter :{key}") + params_map[pname] = kwargs[key] + + # Generate unique prefix + prefix = f"_t{ctx.template_call_count}_" + ctx.template_call_count += 1 + + # Collect local names: scan body for (def NAME ...) forms + local_names = set() + for form in body_forms: + if isinstance(form, list) and len(form) >= 2: + if isinstance(form[0], Symbol) and form[0].name == "def": + if isinstance(form[1], Symbol): + local_names.add(form[1].name) + + # Substitute and compile each body form + last_node_id = None + for form in body_forms: + substituted = _substitute_template(form, params_map, local_names, prefix) + result = _compile_expr(substituted, ctx) + if result is not None: + last_node_id = result + + return last_node_id + + +def compile_string(text: str, initial_bindings: Dict[str, Any] = None, recipe_dir: Path = None) -> CompiledRecipe: + """ + Compile an S-expression recipe string. + + Convenience function combining parse + compile. + + Args: + text: S-expression recipe string + initial_bindings: Optional dict of name -> value bindings to inject before compilation. + These can be referenced as variables in the recipe. + recipe_dir: Directory containing the recipe file, for resolving relative paths to effects etc. + """ + sexp = parse(text) + return compile_recipe(sexp, initial_bindings, recipe_dir=recipe_dir, source_text=text) diff --git a/artdag/sexp/effect_loader.py b/artdag/sexp/effect_loader.py new file mode 100644 index 0000000..bd7ce62 --- /dev/null +++ b/artdag/sexp/effect_loader.py @@ -0,0 +1,337 @@ +""" +Sexp effect loader. + +Loads sexp effect definitions (define-effect forms) and creates +frame processors that evaluate the sexp body with primitives. + +Effects must use :params syntax: + + (define-effect name + :params ( + (param1 :type int :default 8 :range [4 32] :desc "description") + (param2 :type string :default "value" :desc "description") + ) + body) + +For effects with no parameters, use empty :params (): + + (define-effect name + :params () + body) + +Unknown parameters passed to effects will raise an error. + +Usage: + loader = SexpEffectLoader() + effect_fn = loader.load_effect_file(Path("effects/ascii_art.sexp")) + output = effect_fn(input_path, output_path, config) +""" + +import logging +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + +from .parser import parse_all, Symbol, Keyword +from .evaluator import evaluate +from .primitives import PRIMITIVES +from .compiler import ParamDef, _parse_params, CompileError + +logger = logging.getLogger(__name__) + + +def _parse_define_effect(sexp) -> tuple: + """ + Parse a define-effect form. + + Required syntax: + (define-effect name + :params ( + (param1 :type int :default 8 :range [4 32] :desc "description") + ) + body) + + Effects MUST use :params syntax. Legacy ((param default) ...) syntax is not supported. + + Returns (name, params_with_defaults, param_defs, body) + where param_defs is a list of ParamDef objects + """ + if not isinstance(sexp, list) or len(sexp) < 3: + raise ValueError(f"Invalid define-effect form: {sexp}") + + head = sexp[0] + if not (isinstance(head, Symbol) and head.name == "define-effect"): + raise ValueError(f"Expected define-effect, got {head}") + + name = sexp[1] + if isinstance(name, Symbol): + name = name.name + + params_with_defaults = {} + param_defs: List[ParamDef] = [] + body = None + found_params = False + + # Parse :params and body + i = 2 + while i < len(sexp): + item = sexp[i] + if isinstance(item, Keyword) and item.name == "params": + # :params syntax + if i + 1 >= len(sexp): + raise ValueError(f"Effect '{name}': Missing params list after :params keyword") + try: + param_defs = _parse_params(sexp[i + 1]) + # Build params_with_defaults from ParamDef objects + for pd in param_defs: + params_with_defaults[pd.name] = pd.default + except CompileError as e: + raise ValueError(f"Effect '{name}': Error parsing :params: {e}") + found_params = True + i += 2 + elif isinstance(item, Keyword): + # Skip other keywords we don't recognize + i += 2 + elif body is None: + # First non-keyword item is the body + if isinstance(item, list) and item: + first_elem = item[0] + # Check for legacy syntax and reject it + if isinstance(first_elem, list) and len(first_elem) >= 2: + raise ValueError( + f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. " + f"Use :params block instead:\n" + f" :params (\n" + f" (param_name :type int :default 0 :desc \"description\")\n" + f" )" + ) + body = item + i += 1 + else: + i += 1 + + if body is None: + raise ValueError(f"Effect '{name}': No body found") + + if not found_params: + raise ValueError( + f"Effect '{name}': Missing :params block. Effects must declare parameters.\n" + f"For effects with no parameters, use empty :params ():\n" + f" (define-effect {name}\n" + f" :params ()\n" + f" body)" + ) + + return name, params_with_defaults, param_defs, body + + +def _create_process_frame( + effect_name: str, + params_with_defaults: Dict[str, Any], + param_defs: List[ParamDef], + body: Any, +) -> Callable: + """ + Create a process_frame function that evaluates the sexp body. + + The function signature is: (frame, params, state) -> (frame, state) + """ + import math + + def process_frame(frame: np.ndarray, params: Dict[str, Any], state: Any): + """Evaluate sexp effect body on a frame.""" + # Build environment with primitives + env = dict(PRIMITIVES) + + # Add math functions + env["floor"] = lambda x: int(math.floor(x)) + env["ceil"] = lambda x: int(math.ceil(x)) + env["round"] = lambda x: int(round(x)) + env["abs"] = abs + env["min"] = min + env["max"] = max + env["sqrt"] = math.sqrt + env["sin"] = math.sin + env["cos"] = math.cos + + # Add list operations + env["list"] = lambda *args: tuple(args) + env["nth"] = lambda coll, i: coll[int(i)] if coll else None + + # Bind frame + env["frame"] = frame + + # Validate that all provided params are known + known_params = set(params_with_defaults.keys()) + for k in params.keys(): + if k not in known_params: + raise ValueError( + f"Effect '{effect_name}': Unknown parameter '{k}'. " + f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}" + ) + + # Bind parameters (defaults + overrides from config) + for param_name, default in params_with_defaults.items(): + # Use config value if provided, otherwise default + if param_name in params: + env[param_name] = params[param_name] + elif default is not None: + env[param_name] = default + + # Evaluate the body + try: + result = evaluate(body, env) + if isinstance(result, np.ndarray): + return result, state + else: + logger.warning(f"Effect {effect_name} returned {type(result)}, expected ndarray") + return frame, state + except Exception as e: + logger.error(f"Error evaluating effect {effect_name}: {e}") + raise + + return process_frame + + +def load_sexp_effect(source: str, base_path: Optional[Path] = None) -> tuple: + """ + Load a sexp effect from source code. + + Args: + source: Sexp source code + base_path: Base path for resolving relative imports + + Returns: + (effect_name, process_frame_fn, params_with_defaults, param_defs) + where param_defs is a list of ParamDef objects for introspection + """ + exprs = parse_all(source) + + # Find define-effect form + define_effect = None + if isinstance(exprs, list): + for expr in exprs: + if isinstance(expr, list) and expr and isinstance(expr[0], Symbol): + if expr[0].name == "define-effect": + define_effect = expr + break + elif isinstance(exprs, list) and exprs and isinstance(exprs[0], Symbol): + if exprs[0].name == "define-effect": + define_effect = exprs + + if not define_effect: + raise ValueError("No define-effect form found in sexp effect") + + name, params_with_defaults, param_defs, body = _parse_define_effect(define_effect) + process_frame = _create_process_frame(name, params_with_defaults, param_defs, body) + + return name, process_frame, params_with_defaults, param_defs + + +def load_sexp_effect_file(path: Path) -> tuple: + """ + Load a sexp effect from file. + + Returns: + (effect_name, process_frame_fn, params_with_defaults, param_defs) + where param_defs is a list of ParamDef objects for introspection + """ + source = path.read_text() + return load_sexp_effect(source, base_path=path.parent) + + +class SexpEffectLoader: + """ + Loader for sexp effect definitions. + + Creates effect functions compatible with the EffectExecutor. + """ + + def __init__(self, recipe_dir: Optional[Path] = None): + """ + Initialize loader. + + Args: + recipe_dir: Base directory for resolving relative effect paths + """ + self.recipe_dir = recipe_dir or Path.cwd() + # Cache loaded effects with their param_defs for introspection + self._loaded_effects: Dict[str, tuple] = {} + + def load_effect_path(self, effect_path: str) -> Callable: + """ + Load a sexp effect from a relative path. + + Args: + effect_path: Relative path to effect .sexp file + + Returns: + Effect function (input_path, output_path, config) -> output_path + """ + from ..effects.frame_processor import process_video + + full_path = self.recipe_dir / effect_path + if not full_path.exists(): + raise FileNotFoundError(f"Sexp effect not found: {full_path}") + + name, process_frame_fn, params_defaults, param_defs = load_sexp_effect_file(full_path) + logger.info(f"Loaded sexp effect: {name} from {effect_path}") + + # Cache for introspection + self._loaded_effects[effect_path] = (name, params_defaults, param_defs) + + def effect_fn(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run sexp effect via frame processor.""" + # Extract params (excluding internal keys) + params = dict(params_defaults) # Start with defaults + for k, v in config.items(): + if k not in ("effect", "cid", "hash", "effect_path", "_binding"): + params[k] = v + + # Get bindings if present + bindings = {} + for key, value in config.items(): + if isinstance(value, dict) and value.get("_resolved_values"): + bindings[key] = value["_resolved_values"] + + output_path.parent.mkdir(parents=True, exist_ok=True) + actual_output = output_path.with_suffix(".mp4") + + process_video( + input_path=input_path, + output_path=actual_output, + process_frame=process_frame_fn, + params=params, + bindings=bindings, + ) + + logger.info(f"Processed sexp effect '{name}' from {effect_path}") + return actual_output + + return effect_fn + + def get_effect_params(self, effect_path: str) -> List[ParamDef]: + """ + Get parameter definitions for an effect. + + Args: + effect_path: Relative path to effect .sexp file + + Returns: + List of ParamDef objects describing the effect's parameters + """ + if effect_path not in self._loaded_effects: + # Load the effect to get its params + full_path = self.recipe_dir / effect_path + if not full_path.exists(): + raise FileNotFoundError(f"Sexp effect not found: {full_path}") + name, _, params_defaults, param_defs = load_sexp_effect_file(full_path) + self._loaded_effects[effect_path] = (name, params_defaults, param_defs) + + return self._loaded_effects[effect_path][2] + + +def get_sexp_effect_loader(recipe_dir: Optional[Path] = None) -> SexpEffectLoader: + """Get a sexp effect loader instance.""" + return SexpEffectLoader(recipe_dir) diff --git a/artdag/sexp/evaluator.py b/artdag/sexp/evaluator.py new file mode 100644 index 0000000..5e3b175 --- /dev/null +++ b/artdag/sexp/evaluator.py @@ -0,0 +1,869 @@ +""" +Expression evaluator for S-expression DSL. + +Supports: +- Arithmetic: +, -, *, /, mod, sqrt, pow, abs, floor, ceil, round, min, max, clamp +- Comparison: =, <, >, <=, >= +- Logic: and, or, not +- Predicates: odd?, even?, zero?, nil? +- Conditionals: if, cond, case +- Data: list, dict/map construction, get +- Lambda calls +""" + +from typing import Any, Dict, List, Callable +from .parser import Symbol, Keyword, Lambda, Binding + + +class EvalError(Exception): + """Error during expression evaluation.""" + pass + + +# Built-in functions +BUILTINS: Dict[str, Callable] = {} + + +def builtin(name: str): + """Decorator to register a builtin function.""" + def decorator(fn): + BUILTINS[name] = fn + return fn + return decorator + + +@builtin("+") +def add(*args): + return sum(args) + + +@builtin("-") +def sub(a, b=None): + if b is None: + return -a + return a - b + + +@builtin("*") +def mul(*args): + result = 1 + for a in args: + result *= a + return result + + +@builtin("/") +def div(a, b): + return a / b + + +@builtin("mod") +def mod(a, b): + return a % b + + +@builtin("sqrt") +def sqrt(x): + return x ** 0.5 + + +@builtin("pow") +def power(x, n): + return x ** n + + +@builtin("abs") +def absolute(x): + return abs(x) + + +@builtin("floor") +def floor_fn(x): + import math + return math.floor(x) + + +@builtin("ceil") +def ceil_fn(x): + import math + return math.ceil(x) + + +@builtin("round") +def round_fn(x, ndigits=0): + return round(x, int(ndigits)) + + +@builtin("min") +def min_fn(*args): + if len(args) == 1 and isinstance(args[0], (list, tuple)): + return min(args[0]) + return min(args) + + +@builtin("max") +def max_fn(*args): + if len(args) == 1 and isinstance(args[0], (list, tuple)): + return max(args[0]) + return max(args) + + +@builtin("clamp") +def clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +@builtin("=") +def eq(a, b): + return a == b + + +@builtin("<") +def lt(a, b): + return a < b + + +@builtin(">") +def gt(a, b): + return a > b + + +@builtin("<=") +def lte(a, b): + return a <= b + + +@builtin(">=") +def gte(a, b): + return a >= b + + +@builtin("odd?") +def is_odd(n): + return n % 2 == 1 + + +@builtin("even?") +def is_even(n): + return n % 2 == 0 + + +@builtin("zero?") +def is_zero(n): + return n == 0 + + +@builtin("nil?") +def is_nil(x): + return x is None + + +@builtin("not") +def not_fn(x): + return not x + + +@builtin("inc") +def inc(n): + return n + 1 + + +@builtin("dec") +def dec(n): + return n - 1 + + +@builtin("list") +def make_list(*args): + return list(args) + + +@builtin("assert") +def assert_true(condition, message="Assertion failed"): + if not condition: + raise RuntimeError(f"Assertion error: {message}") + return True + + +@builtin("get") +def get(coll, key, default=None): + if isinstance(coll, dict): + # Try the key directly first + result = coll.get(key, None) + if result is not None: + return result + # If key is a Keyword, also try its string name (for Python dicts with string keys) + if isinstance(key, Keyword): + result = coll.get(key.name, None) + if result is not None: + return result + # Return the default + return default + elif isinstance(coll, list): + return coll[key] if 0 <= key < len(coll) else default + else: + raise EvalError(f"get: expected dict or list, got {type(coll).__name__}: {str(coll)[:100]}") + + +@builtin("dict?") +def is_dict(x): + return isinstance(x, dict) + + +@builtin("list?") +def is_list(x): + return isinstance(x, list) + + +@builtin("nil?") +def is_nil(x): + return x is None + + +@builtin("number?") +def is_number(x): + return isinstance(x, (int, float)) + + +@builtin("string?") +def is_string(x): + return isinstance(x, str) + + +@builtin("len") +def length(coll): + return len(coll) + + +@builtin("first") +def first(coll): + return coll[0] if coll else None + + +@builtin("last") +def last(coll): + return coll[-1] if coll else None + + +@builtin("chunk-every") +def chunk_every(coll, n): + """Split collection into chunks of n elements.""" + n = int(n) + return [coll[i:i+n] for i in range(0, len(coll), n)] + + +@builtin("rest") +def rest(coll): + return coll[1:] if coll else [] + + +@builtin("nth") +def nth(coll, n): + return coll[n] if 0 <= n < len(coll) else None + + +@builtin("concat") +def concat(*colls): + """Concatenate multiple lists/sequences.""" + result = [] + for c in colls: + if c is not None: + result.extend(c) + return result + + +@builtin("cons") +def cons(x, coll): + """Prepend x to collection.""" + return [x] + list(coll) if coll else [x] + + +@builtin("append") +def append(coll, x): + """Append x to collection.""" + return list(coll) + [x] if coll else [x] + + +@builtin("range") +def make_range(start, end, step=1): + """Create a range of numbers.""" + return list(range(int(start), int(end), int(step))) + + +@builtin("zip-pairs") +def zip_pairs(coll): + """Zip consecutive pairs: [a,b,c,d] -> [[a,b],[b,c],[c,d]].""" + if not coll or len(coll) < 2: + return [] + return [[coll[i], coll[i+1]] for i in range(len(coll)-1)] + + +@builtin("dict") +def make_dict(*pairs): + """Create dict from key-value pairs: (dict :a 1 :b 2).""" + result = {} + i = 0 + while i < len(pairs) - 1: + key = pairs[i] + if isinstance(key, Keyword): + key = key.name + result[key] = pairs[i + 1] + i += 2 + return result + + +@builtin("keys") +def keys(d): + """Get the keys of a dict as a list.""" + if not isinstance(d, dict): + raise EvalError(f"keys: expected dict, got {type(d).__name__}") + return list(d.keys()) + + +@builtin("vals") +def vals(d): + """Get the values of a dict as a list.""" + if not isinstance(d, dict): + raise EvalError(f"vals: expected dict, got {type(d).__name__}") + return list(d.values()) + + +@builtin("merge") +def merge(*dicts): + """Merge multiple dicts, later dicts override earlier.""" + result = {} + for d in dicts: + if d is not None: + if not isinstance(d, dict): + raise EvalError(f"merge: expected dict, got {type(d).__name__}") + result.update(d) + return result + + +@builtin("assoc") +def assoc(d, *pairs): + """Associate keys with values in a dict: (assoc d :a 1 :b 2).""" + if d is None: + result = {} + elif isinstance(d, dict): + result = dict(d) + else: + raise EvalError(f"assoc: expected dict or nil, got {type(d).__name__}") + + i = 0 + while i < len(pairs) - 1: + key = pairs[i] + if isinstance(key, Keyword): + key = key.name + result[key] = pairs[i + 1] + i += 2 + return result + + +@builtin("dissoc") +def dissoc(d, *keys_to_remove): + """Remove keys from a dict: (dissoc d :a :b).""" + if d is None: + return {} + if not isinstance(d, dict): + raise EvalError(f"dissoc: expected dict or nil, got {type(d).__name__}") + + result = dict(d) + for key in keys_to_remove: + if isinstance(key, Keyword): + key = key.name + result.pop(key, None) + return result + + +@builtin("into") +def into(target, coll): + """Convert a collection into another collection type. + + (into [] {:a 1 :b 2}) -> [["a" 1] ["b" 2]] + (into {} [[:a 1] [:b 2]]) -> {"a": 1, "b": 2} + (into [] [1 2 3]) -> [1 2 3] + """ + if isinstance(target, list): + if isinstance(coll, dict): + return [[k, v] for k, v in coll.items()] + elif isinstance(coll, (list, tuple)): + return list(coll) + else: + raise EvalError(f"into: cannot convert {type(coll).__name__} into list") + elif isinstance(target, dict): + if isinstance(coll, dict): + return dict(coll) + elif isinstance(coll, (list, tuple)): + result = {} + for item in coll: + if isinstance(item, (list, tuple)) and len(item) >= 2: + key = item[0] + if isinstance(key, Keyword): + key = key.name + result[key] = item[1] + else: + raise EvalError(f"into: expected [key value] pairs, got {item}") + return result + else: + raise EvalError(f"into: cannot convert {type(coll).__name__} into dict") + else: + raise EvalError(f"into: unsupported target type {type(target).__name__}") + + +@builtin("filter") +def filter_fn(pred, coll): + """Filter collection by predicate. Pred must be a lambda.""" + if not isinstance(pred, Lambda): + raise EvalError(f"filter: expected lambda as predicate, got {type(pred).__name__}") + + result = [] + for item in coll: + # Evaluate predicate with item + local_env = {} + if pred.closure: + local_env.update(pred.closure) + local_env[pred.params[0]] = item + + # Inline evaluation of pred.body + from . import evaluator + if evaluator.evaluate(pred.body, local_env): + result.append(item) + return result + + +@builtin("some") +def some(pred, coll): + """Return first truthy value of (pred item) for items in coll, or nil.""" + if not isinstance(pred, Lambda): + raise EvalError(f"some: expected lambda as predicate, got {type(pred).__name__}") + + for item in coll: + local_env = {} + if pred.closure: + local_env.update(pred.closure) + local_env[pred.params[0]] = item + + from . import evaluator + result = evaluator.evaluate(pred.body, local_env) + if result: + return result + return None + + +@builtin("every?") +def every(pred, coll): + """Return true if (pred item) is truthy for all items in coll.""" + if not isinstance(pred, Lambda): + raise EvalError(f"every?: expected lambda as predicate, got {type(pred).__name__}") + + for item in coll: + local_env = {} + if pred.closure: + local_env.update(pred.closure) + local_env[pred.params[0]] = item + + from . import evaluator + if not evaluator.evaluate(pred.body, local_env): + return False + return True + + +@builtin("empty?") +def is_empty(coll): + """Return true if collection is empty.""" + if coll is None: + return True + return len(coll) == 0 + + +@builtin("contains?") +def contains(coll, key): + """Check if collection contains key (for dicts) or element (for lists).""" + if isinstance(coll, dict): + if isinstance(key, Keyword): + key = key.name + return key in coll + elif isinstance(coll, (list, tuple)): + return key in coll + return False + + +def evaluate(expr: Any, env: Dict[str, Any] = None) -> Any: + """ + Evaluate an S-expression in the given environment. + + Args: + expr: The expression to evaluate + env: Variable bindings (name -> value) + + Returns: + The result of evaluation + """ + if env is None: + env = {} + + # Literals + if isinstance(expr, (int, float, str, bool)) or expr is None: + return expr + + # Symbol - variable lookup + if isinstance(expr, Symbol): + name = expr.name + if name in env: + return env[name] + if name in BUILTINS: + return BUILTINS[name] + if name == "true": + return True + if name == "false": + return False + if name == "nil": + return None + raise EvalError(f"Undefined symbol: {name}") + + # Keyword - return as-is (used as map keys) + if isinstance(expr, Keyword): + return expr.name + + # Lambda - return as-is (it's a value) + if isinstance(expr, Lambda): + return expr + + # Binding - return as-is (resolved at execution time) + if isinstance(expr, Binding): + return expr + + # Dict literal + if isinstance(expr, dict): + return {k: evaluate(v, env) for k, v in expr.items()} + + # List - function call or special form + if isinstance(expr, list): + if not expr: + return [] + + head = expr[0] + + # If head is a string/number/etc (not Symbol), treat as data list + if not isinstance(head, (Symbol, Lambda, list)): + return [evaluate(x, env) for x in expr] + + # Special forms + if isinstance(head, Symbol): + name = head.name + + # if - conditional + if name == "if": + if len(expr) < 3: + raise EvalError("if requires condition and then-branch") + cond_result = evaluate(expr[1], env) + if cond_result: + return evaluate(expr[2], env) + elif len(expr) > 3: + return evaluate(expr[3], env) + return None + + # cond - multi-way conditional + # Supports both Clojure style: (cond test1 result1 test2 result2 :else default) + # and Scheme style: (cond (test1 result1) (test2 result2) (else default)) + if name == "cond": + clauses = expr[1:] + # Check if Clojure style (flat list) or Scheme style (nested pairs) + # Scheme style: first clause is (test result) - exactly 2 elements + # Clojure style: first clause is a test expression like (= x 0) - 3+ elements + first_is_scheme_clause = ( + clauses and + isinstance(clauses[0], list) and + len(clauses[0]) == 2 and + not (isinstance(clauses[0][0], Symbol) and clauses[0][0].name in ('=', '<', '>', '<=', '>=', '!=', 'not=', 'and', 'or')) + ) + if first_is_scheme_clause: + # Scheme style: ((test result) ...) + for clause in clauses: + if not isinstance(clause, list) or len(clause) < 2: + raise EvalError("cond clause must be (test result)") + test = clause[0] + # Check for else/default + if isinstance(test, Symbol) and test.name in ("else", ":else"): + return evaluate(clause[1], env) + if isinstance(test, Keyword) and test.name == "else": + return evaluate(clause[1], env) + if evaluate(test, env): + return evaluate(clause[1], env) + else: + # Clojure style: test1 result1 test2 result2 ... + i = 0 + while i < len(clauses) - 1: + test = clauses[i] + result = clauses[i + 1] + # Check for :else + if isinstance(test, Keyword) and test.name == "else": + return evaluate(result, env) + if isinstance(test, Symbol) and test.name == ":else": + return evaluate(result, env) + if evaluate(test, env): + return evaluate(result, env) + i += 2 + return None + + # case - switch on value + # (case expr val1 result1 val2 result2 :else default) + if name == "case": + if len(expr) < 2: + raise EvalError("case requires expression to match") + match_val = evaluate(expr[1], env) + clauses = expr[2:] + i = 0 + while i < len(clauses) - 1: + test = clauses[i] + result = clauses[i + 1] + # Check for :else / else + if isinstance(test, Keyword) and test.name == "else": + return evaluate(result, env) + if isinstance(test, Symbol) and test.name in (":else", "else"): + return evaluate(result, env) + # Evaluate test value and compare + test_val = evaluate(test, env) + if match_val == test_val: + return evaluate(result, env) + i += 2 + return None + + # and - short-circuit + if name == "and": + result = True + for arg in expr[1:]: + result = evaluate(arg, env) + if not result: + return result + return result + + # or - short-circuit + if name == "or": + result = False + for arg in expr[1:]: + result = evaluate(arg, env) + if result: + return result + return result + + # let and let* - local bindings (both bind sequentially in this impl) + if name in ("let", "let*"): + if len(expr) < 3: + raise EvalError(f"{name} requires bindings and body") + bindings = expr[1] + + local_env = dict(env) + + if isinstance(bindings, list): + # Check if it's ((name value) ...) style (Lisp let* style) + if bindings and isinstance(bindings[0], list): + for binding in bindings: + if len(binding) != 2: + raise EvalError(f"{name} binding must be (name value)") + var_name = binding[0] + if isinstance(var_name, Symbol): + var_name = var_name.name + value = evaluate(binding[1], local_env) + local_env[var_name] = value + # Vector-style [name value ...] + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var_name = bindings[i] + if isinstance(var_name, Symbol): + var_name = var_name.name + value = evaluate(bindings[i + 1], local_env) + local_env[var_name] = value + else: + raise EvalError(f"{name} bindings must be [name value ...] or ((name value) ...)") + else: + raise EvalError(f"{name} bindings must be a list") + + return evaluate(expr[2], local_env) + + # lambda / fn - create function with closure + if name in ("lambda", "fn"): + if len(expr) < 3: + raise EvalError("lambda requires params and body") + params = expr[1] + if not isinstance(params, list): + raise EvalError("lambda params must be a list") + param_names = [] + for p in params: + if isinstance(p, Symbol): + param_names.append(p.name) + elif isinstance(p, str): + param_names.append(p) + else: + raise EvalError(f"Invalid param: {p}") + # Capture current environment as closure + return Lambda(param_names, expr[2], dict(env)) + + # quote - return unevaluated + if name == "quote": + return expr[1] if len(expr) > 1 else None + + # bind - create binding to analysis data + # (bind analysis-var) + # (bind analysis-var :range [0.3 1.0]) + # (bind analysis-var :range [0 100] :transform sqrt) + if name == "bind": + if len(expr) < 2: + raise EvalError("bind requires analysis reference") + analysis_ref = expr[1] + if isinstance(analysis_ref, Symbol): + symbol_name = analysis_ref.name + # Look up the symbol in environment + if symbol_name in env: + resolved = env[symbol_name] + # If resolved is actual analysis data (dict with times/values or + # S-expression list with Keywords), keep the symbol name as reference + # for later lookup at execution time + if isinstance(resolved, dict) and ("times" in resolved or "values" in resolved): + analysis_ref = symbol_name # Use name as reference, not the data + elif isinstance(resolved, list) and any(isinstance(x, Keyword) for x in resolved): + # Parsed S-expression analysis data ([:times [...] :duration ...]) + analysis_ref = symbol_name + else: + analysis_ref = resolved + else: + raise EvalError(f"bind: undefined symbol '{symbol_name}' - must reference analysis data") + + # Parse optional :range [min max] and :transform + range_min, range_max = 0.0, 1.0 + transform = None + i = 2 + while i < len(expr): + if isinstance(expr[i], Keyword): + kw = expr[i].name + if kw == "range" and i + 1 < len(expr): + range_val = evaluate(expr[i + 1], env) # Evaluate to get actual value + if isinstance(range_val, list) and len(range_val) >= 2: + range_min = float(range_val[0]) + range_max = float(range_val[1]) + i += 2 + elif kw == "transform" and i + 1 < len(expr): + t = expr[i + 1] + if isinstance(t, Symbol): + transform = t.name + elif isinstance(t, str): + transform = t + i += 2 + else: + i += 1 + else: + i += 1 + + return Binding(analysis_ref, range_min=range_min, range_max=range_max, transform=transform) + + # Vector literal [a b c] + if name == "vec" or name == "vector": + return [evaluate(e, env) for e in expr[1:]] + + # map - (map fn coll) + if name == "map": + if len(expr) != 3: + raise EvalError("map requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"map requires lambda, got {type(fn)}") + result = [] + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = item + result.append(evaluate(fn.body, local_env)) + return result + + # map-indexed - (map-indexed fn coll) + if name == "map-indexed": + if len(expr) != 3: + raise EvalError("map-indexed requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"map-indexed requires lambda, got {type(fn)}") + if len(fn.params) < 2: + raise EvalError("map-indexed lambda needs (i item) params") + result = [] + for i, item in enumerate(coll): + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = i + local_env[fn.params[1]] = item + result.append(evaluate(fn.body, local_env)) + return result + + # reduce - (reduce fn init coll) + if name == "reduce": + if len(expr) != 4: + raise EvalError("reduce requires fn, init, and collection") + fn = evaluate(expr[1], env) + acc = evaluate(expr[2], env) + coll = evaluate(expr[3], env) + if not isinstance(fn, Lambda): + raise EvalError(f"reduce requires lambda, got {type(fn)}") + if len(fn.params) < 2: + raise EvalError("reduce lambda needs (acc item) params") + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = acc + local_env[fn.params[1]] = item + acc = evaluate(fn.body, local_env) + return acc + + # for-each - (for-each fn coll) - iterate with side effects + if name == "for-each": + if len(expr) != 3: + raise EvalError("for-each requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"for-each requires lambda, got {type(fn)}") + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = item + evaluate(fn.body, local_env) + return None + + # Function call + fn = evaluate(head, env) + args = [evaluate(arg, env) for arg in expr[1:]] + + # Call builtin + if callable(fn): + return fn(*args) + + # Call lambda + if isinstance(fn, Lambda): + if len(args) != len(fn.params): + raise EvalError(f"Lambda expects {len(fn.params)} args, got {len(args)}") + # Start with closure (captured env), then overlay calling env, then params + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + for name, value in zip(fn.params, args): + local_env[name] = value + return evaluate(fn.body, local_env) + + raise EvalError(f"Not callable: {fn}") + + raise EvalError(f"Cannot evaluate: {expr!r}") + + +def make_env(**kwargs) -> Dict[str, Any]: + """Create an environment with initial bindings.""" + return dict(kwargs) diff --git a/artdag/sexp/external_tools.py b/artdag/sexp/external_tools.py new file mode 100644 index 0000000..fea13e2 --- /dev/null +++ b/artdag/sexp/external_tools.py @@ -0,0 +1,292 @@ +""" +External tool runners for effects that can't be done in FFmpeg. + +Supports: +- datamosh: via ffglitch or datamoshing Python CLI +- pixelsort: via Rust pixelsort or Python pixelsort-cli +""" + +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def find_tool(tool_names: List[str]) -> Optional[str]: + """Find first available tool from a list of candidates.""" + for name in tool_names: + path = shutil.which(name) + if path: + return path + return None + + +def check_python_package(package: str) -> bool: + """Check if a Python package is installed.""" + try: + result = subprocess.run( + ["python3", "-c", f"import {package}"], + capture_output=True, + timeout=5, + ) + return result.returncode == 0 + except Exception: + return False + + +# Tool detection +DATAMOSH_TOOLS = ["ffgac", "ffedit"] # ffglitch tools +PIXELSORT_TOOLS = ["pixelsort"] # Rust CLI + + +def get_available_tools() -> Dict[str, Optional[str]]: + """Get dictionary of available external tools.""" + return { + "datamosh": find_tool(DATAMOSH_TOOLS), + "pixelsort": find_tool(PIXELSORT_TOOLS), + "datamosh_python": "datamoshing" if check_python_package("datamoshing") else None, + "pixelsort_python": "pixelsort" if check_python_package("pixelsort") else None, + } + + +def run_datamosh( + input_path: Path, + output_path: Path, + params: Dict[str, Any], +) -> Tuple[bool, str]: + """ + Run datamosh effect using available tool. + + Args: + input_path: Input video file + output_path: Output video file + params: Effect parameters (corruption, block_size, etc.) + + Returns: + (success, error_message) + """ + tools = get_available_tools() + + corruption = params.get("corruption", 0.3) + + # Try ffglitch first + if tools.get("datamosh"): + ffgac = tools["datamosh"] + try: + # ffglitch approach: remove I-frames to create datamosh effect + # This is a simplified version - full datamosh needs custom scripts + with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f: + # Write a simple ffglitch script that corrupts motion vectors + f.write(f""" +// Datamosh script - corrupt motion vectors +let corruption = {corruption}; + +export function glitch_frame(frame, stream) {{ + if (frame.pict_type === 'P' && Math.random() < corruption) {{ + // Corrupt motion vectors + let dominated = frame.mv?.forward?.overflow; + if (dominated) {{ + for (let i = 0; i < dominated.length; i++) {{ + if (Math.random() < corruption) {{ + dominated[i] = [ + Math.floor(Math.random() * 64 - 32), + Math.floor(Math.random() * 64 - 32) + ]; + }} + }} + }} + }} + return frame; +}} +""") + script_path = f.name + + cmd = [ + ffgac, + "-i", str(input_path), + "-s", script_path, + "-o", str(output_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + Path(script_path).unlink(missing_ok=True) + + if result.returncode == 0: + return True, "" + return False, result.stderr[:500] + + except subprocess.TimeoutExpired: + return False, "Datamosh timeout" + except Exception as e: + return False, str(e) + + # Fall back to Python datamoshing package + if tools.get("datamosh_python"): + try: + cmd = [ + "python3", "-m", "datamoshing", + str(input_path), + str(output_path), + "--mode", "iframe_removal", + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode == 0: + return True, "" + return False, result.stderr[:500] + except Exception as e: + return False, str(e) + + return False, "No datamosh tool available. Install ffglitch or: pip install datamoshing" + + +def run_pixelsort( + input_path: Path, + output_path: Path, + params: Dict[str, Any], +) -> Tuple[bool, str]: + """ + Run pixelsort effect using available tool. + + Args: + input_path: Input image/frame file + output_path: Output image file + params: Effect parameters (sort_by, threshold_low, threshold_high, angle) + + Returns: + (success, error_message) + """ + tools = get_available_tools() + + sort_by = params.get("sort_by", "lightness") + threshold_low = params.get("threshold_low", 50) + threshold_high = params.get("threshold_high", 200) + angle = params.get("angle", 0) + + # Try Rust pixelsort first (faster) + if tools.get("pixelsort"): + try: + cmd = [ + tools["pixelsort"], + str(input_path), + "-o", str(output_path), + "--sort", sort_by, + "-r", str(angle), + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + if result.returncode == 0: + return True, "" + return False, result.stderr[:500] + except Exception as e: + return False, str(e) + + # Fall back to Python pixelsort-cli + if tools.get("pixelsort_python"): + try: + cmd = [ + "python3", "-m", "pixelsort", + "--image_path", str(input_path), + "--output", str(output_path), + "--angle", str(angle), + "--threshold", str(threshold_low / 255), # Normalize to 0-1 + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + if result.returncode == 0: + return True, "" + return False, result.stderr[:500] + except Exception as e: + return False, str(e) + + return False, "No pixelsort tool available. Install: cargo install pixelsort or pip install pixelsort-cli" + + +def run_pixelsort_video( + input_path: Path, + output_path: Path, + params: Dict[str, Any], + fps: float = 30, +) -> Tuple[bool, str]: + """ + Run pixelsort on a video by processing each frame. + + This extracts frames, processes them, then reassembles. + """ + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + frames_in = tmpdir / "frame_%04d.png" + frames_out = tmpdir / "sorted_%04d.png" + + # Extract frames + extract_cmd = [ + "ffmpeg", "-y", + "-i", str(input_path), + "-vf", f"fps={fps}", + str(frames_in), + ] + result = subprocess.run(extract_cmd, capture_output=True, timeout=300) + if result.returncode != 0: + return False, "Failed to extract frames" + + # Process each frame + frame_files = sorted(tmpdir.glob("frame_*.png")) + for i, frame in enumerate(frame_files): + out_frame = tmpdir / f"sorted_{i:04d}.png" + success, error = run_pixelsort(frame, out_frame, params) + if not success: + return False, f"Frame {i}: {error}" + + # Reassemble + # Get audio from original + reassemble_cmd = [ + "ffmpeg", "-y", + "-framerate", str(fps), + "-i", str(tmpdir / "sorted_%04d.png"), + "-i", str(input_path), + "-map", "0:v", "-map", "1:a?", + "-c:v", "libx264", "-preset", "fast", + "-c:a", "copy", + str(output_path), + ] + result = subprocess.run(reassemble_cmd, capture_output=True, timeout=300) + if result.returncode != 0: + return False, "Failed to reassemble video" + + return True, "" + + +def run_external_effect( + effect_name: str, + input_path: Path, + output_path: Path, + params: Dict[str, Any], + is_video: bool = True, +) -> Tuple[bool, str]: + """ + Run an external effect tool. + + Args: + effect_name: Name of effect (datamosh, pixelsort) + input_path: Input file + output_path: Output file + params: Effect parameters + is_video: Whether input is video (vs single image) + + Returns: + (success, error_message) + """ + if effect_name == "datamosh": + return run_datamosh(input_path, output_path, params) + elif effect_name == "pixelsort": + if is_video: + return run_pixelsort_video(input_path, output_path, params) + else: + return run_pixelsort(input_path, output_path, params) + else: + return False, f"Unknown external effect: {effect_name}" + + +if __name__ == "__main__": + # Print available tools + print("Available external tools:") + for name, path in get_available_tools().items(): + status = path if path else "NOT INSTALLED" + print(f" {name}: {status}") diff --git a/artdag/sexp/ffmpeg_compiler.py b/artdag/sexp/ffmpeg_compiler.py new file mode 100644 index 0000000..d69508e --- /dev/null +++ b/artdag/sexp/ffmpeg_compiler.py @@ -0,0 +1,616 @@ +""" +FFmpeg filter compiler for sexp effects. + +Compiles sexp effect definitions to FFmpeg filter expressions, +with support for dynamic parameters via sendcmd scripts. + +Usage: + compiler = FFmpegCompiler() + + # Compile an effect with static params + filter_str = compiler.compile_effect("brightness", {"amount": 50}) + # -> "eq=brightness=0.196" + + # Compile with dynamic binding to analysis data + filter_str, sendcmd = compiler.compile_effect_with_binding( + "brightness", + {"amount": {"_bind": "bass-data", "range_min": 0, "range_max": 100}}, + analysis_data={"bass-data": {"times": [...], "values": [...]}}, + segment_start=0.0, + segment_duration=5.0, + ) + # -> ("eq=brightness=0.5", "0.0 [eq] brightness 0.5;\n0.05 [eq] brightness 0.6;...") +""" + +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# FFmpeg filter mappings for common effects +# Maps effect name -> {filter: str, params: {param_name: {ffmpeg_param, scale, offset}}} +EFFECT_MAPPINGS = { + "invert": { + "filter": "negate", + "params": {}, + }, + "grayscale": { + "filter": "colorchannelmixer", + "static": "0.3:0.4:0.3:0:0.3:0.4:0.3:0:0.3:0.4:0.3", + "params": {}, + }, + "sepia": { + "filter": "colorchannelmixer", + "static": "0.393:0.769:0.189:0:0.349:0.686:0.168:0:0.272:0.534:0.131", + "params": {}, + }, + "brightness": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0}, + }, + }, + "contrast": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0}, + }, + }, + "saturation": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "saturation", "scale": 1.0, "offset": 0}, + }, + }, + "hue_shift": { + "filter": "hue", + "params": { + "degrees": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0}, + }, + }, + "blur": { + "filter": "gblur", + "params": { + "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0}, + }, + }, + "sharpen": { + "filter": "unsharp", + "params": { + "amount": {"ffmpeg_param": "la", "scale": 1.0, "offset": 0}, + }, + }, + "pixelate": { + # Scale down then up to create pixelation effect + "filter": "scale", + "static": "iw/8:ih/8:flags=neighbor,scale=iw*8:ih*8:flags=neighbor", + "params": {}, + }, + "vignette": { + "filter": "vignette", + "params": { + "strength": {"ffmpeg_param": "a", "scale": 1.0, "offset": 0}, + }, + }, + "noise": { + "filter": "noise", + "params": { + "amount": {"ffmpeg_param": "alls", "scale": 1.0, "offset": 0}, + }, + }, + "flip": { + "filter": "hflip", # Default horizontal + "params": {}, + }, + "mirror": { + "filter": "hflip", + "params": {}, + }, + "rotate": { + "filter": "rotate", + "params": { + "angle": {"ffmpeg_param": "a", "scale": math.pi/180, "offset": 0}, # degrees to radians + }, + }, + "zoom": { + "filter": "zoompan", + "params": { + "factor": {"ffmpeg_param": "z", "scale": 1.0, "offset": 0}, + }, + }, + "posterize": { + # Use lutyuv to quantize levels (approximate posterization) + "filter": "lutyuv", + "static": "y=floor(val/32)*32:u=floor(val/32)*32:v=floor(val/32)*32", + "params": {}, + }, + "threshold": { + # Use geq for thresholding + "filter": "geq", + "static": "lum='if(gt(lum(X,Y),128),255,0)':cb=128:cr=128", + "params": {}, + }, + "edge_detect": { + "filter": "edgedetect", + "params": { + "low": {"ffmpeg_param": "low", "scale": 1/255, "offset": 0}, + "high": {"ffmpeg_param": "high", "scale": 1/255, "offset": 0}, + }, + }, + "swirl": { + "filter": "lenscorrection", # Approximate with lens distortion + "params": { + "strength": {"ffmpeg_param": "k1", "scale": 0.1, "offset": 0}, + }, + }, + "fisheye": { + "filter": "lenscorrection", + "params": { + "strength": {"ffmpeg_param": "k1", "scale": 1.0, "offset": 0}, + }, + }, + "wave": { + # Wave displacement using geq - need r/g/b for RGB mode + "filter": "geq", + "static": "r='r(X+10*sin(Y/20),Y)':g='g(X+10*sin(Y/20),Y)':b='b(X+10*sin(Y/20),Y)'", + "params": {}, + }, + "rgb_split": { + # Chromatic aberration using geq + "filter": "geq", + "static": "r='p(X+5,Y)':g='p(X,Y)':b='p(X-5,Y)'", + "params": {}, + }, + "scanlines": { + "filter": "drawgrid", + "params": { + "spacing": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0}, + }, + }, + "film_grain": { + "filter": "noise", + "params": { + "intensity": {"ffmpeg_param": "alls", "scale": 100, "offset": 0}, + }, + }, + "crt": { + "filter": "vignette", # Simplified - just vignette for CRT look + "params": {}, + }, + "bloom": { + "filter": "gblur", # Simplified bloom = blur overlay + "params": { + "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0}, + }, + }, + "color_cycle": { + "filter": "hue", + "params": { + "speed": {"ffmpeg_param": "h", "scale": 360.0, "offset": 0, "time_expr": True}, + }, + "time_based": True, # Uses time expression + }, + "strobe": { + # Strobe using select to drop frames + "filter": "select", + "static": "'mod(n,4)'", + "params": {}, + }, + "echo": { + # Echo using tmix + "filter": "tmix", + "static": "frames=4:weights='1 0.5 0.25 0.125'", + "params": {}, + }, + "trails": { + # Trails using tblend + "filter": "tblend", + "static": "all_mode=average", + "params": {}, + }, + "kaleidoscope": { + # 4-way mirror kaleidoscope using FFmpeg filter chain + # Crops top-left quadrant, mirrors horizontally, then vertically + "filter": "crop", + "complex": True, + "static": "iw/2:ih/2:0:0[q];[q]split[q1][q2];[q1]hflip[qr];[q2][qr]hstack[top];[top]split[t1][t2];[t2]vflip[bot];[t1][bot]vstack", + "params": {}, + }, + "emboss": { + "filter": "convolution", + "static": "-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2", + "params": {}, + }, + "neon_glow": { + # Edge detect + negate for neon-like effect + "filter": "edgedetect", + "static": "mode=colormix:high=0.1", + "params": {}, + }, + "ascii_art": { + # Requires Python frame processing - no FFmpeg equivalent + "filter": None, + "python_primitive": "ascii_art_frame", + "params": { + "char_size": {"default": 8}, + "alphabet": {"default": "standard"}, + "color_mode": {"default": "color"}, + }, + }, + "ascii_zones": { + # Requires Python frame processing - zone-based ASCII + "filter": None, + "python_primitive": "ascii_zones_frame", + "params": { + "char_size": {"default": 8}, + "zone_threshold": {"default": 128}, + }, + }, + "datamosh": { + # External tool: ffglitch or datamoshing CLI, falls back to Python + "filter": None, + "external_tool": "datamosh", + "python_primitive": "datamosh_frame", + "params": { + "block_size": {"default": 32}, + "corruption": {"default": 0.3}, + }, + }, + "pixelsort": { + # External tool: pixelsort CLI (Rust or Python), falls back to Python + "filter": None, + "external_tool": "pixelsort", + "python_primitive": "pixelsort_frame", + "params": { + "sort_by": {"default": "lightness"}, + "threshold_low": {"default": 50}, + "threshold_high": {"default": 200}, + "angle": {"default": 0}, + }, + }, + "ripple": { + # Use geq for ripple displacement + "filter": "geq", + "static": "lum='lum(X+5*sin(hypot(X-W/2,Y-H/2)/10),Y+5*cos(hypot(X-W/2,Y-H/2)/10))'", + "params": {}, + }, + "tile_grid": { + # Use tile filter for grid + "filter": "tile", + "static": "2x2", + "params": {}, + }, + "outline": { + "filter": "edgedetect", + "params": {}, + }, + "color-adjust": { + "filter": "eq", + "params": { + "brightness": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0}, + "contrast": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0}, + }, + }, +} + + +class FFmpegCompiler: + """Compiles sexp effects to FFmpeg filters with sendcmd support.""" + + def __init__(self, effect_mappings: Dict = None): + self.mappings = effect_mappings or EFFECT_MAPPINGS + + def get_full_mapping(self, effect_name: str) -> Optional[Dict]: + """Get full mapping for an effect (including external tools and python primitives).""" + mapping = self.mappings.get(effect_name) + if not mapping: + # Try with underscores/hyphens converted + normalized = effect_name.replace("-", "_").replace(" ", "_").lower() + mapping = self.mappings.get(normalized) + return mapping + + def get_mapping(self, effect_name: str) -> Optional[Dict]: + """Get FFmpeg filter mapping for an effect (returns None for non-FFmpeg effects).""" + mapping = self.get_full_mapping(effect_name) + # Return None if no mapping or no FFmpeg filter + if mapping and mapping.get("filter") is None: + return None + return mapping + + def has_external_tool(self, effect_name: str) -> Optional[str]: + """Check if effect uses an external tool. Returns tool name or None.""" + mapping = self.get_full_mapping(effect_name) + if mapping: + return mapping.get("external_tool") + return None + + def has_python_primitive(self, effect_name: str) -> Optional[str]: + """Check if effect uses a Python primitive. Returns primitive name or None.""" + mapping = self.get_full_mapping(effect_name) + if mapping: + return mapping.get("python_primitive") + return None + + def is_complex_filter(self, effect_name: str) -> bool: + """Check if effect uses a complex filter chain.""" + mapping = self.get_full_mapping(effect_name) + return bool(mapping and mapping.get("complex")) + + def compile_effect( + self, + effect_name: str, + params: Dict[str, Any], + ) -> Optional[str]: + """ + Compile an effect to an FFmpeg filter string with static params. + + Returns None if effect has no FFmpeg mapping. + """ + mapping = self.get_mapping(effect_name) + if not mapping: + return None + + filter_name = mapping["filter"] + + # Handle static filters (no params) + if "static" in mapping: + return f"{filter_name}={mapping['static']}" + + if not mapping.get("params"): + return filter_name + + # Build param string + filter_params = [] + for param_name, param_config in mapping["params"].items(): + if param_name in params: + value = params[param_name] + # Skip if it's a binding (handled separately) + if isinstance(value, dict) and ("_bind" in value or "_binding" in value): + continue + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + offset = param_config.get("offset", 0) + # Handle various value types + if isinstance(value, (int, float)): + ffmpeg_value = value * scale + offset + filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + elif isinstance(value, str): + filter_params.append(f"{ffmpeg_param}={value}") + elif isinstance(value, list) and value and isinstance(value[0], (int, float)): + ffmpeg_value = value[0] * scale + offset + filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + + if filter_params: + return f"{filter_name}={':'.join(filter_params)}" + return filter_name + + def compile_effect_with_bindings( + self, + effect_name: str, + params: Dict[str, Any], + analysis_data: Dict[str, Dict], + segment_start: float, + segment_duration: float, + sample_interval: float = 0.04, # ~25 fps + ) -> Tuple[Optional[str], Optional[str], List[str]]: + """ + Compile an effect with dynamic bindings to a filter + sendcmd script. + + Returns: + (filter_string, sendcmd_script, bound_param_names) + - filter_string: Initial FFmpeg filter (may have placeholder values) + - sendcmd_script: Script content for sendcmd filter + - bound_param_names: List of params that have bindings + """ + mapping = self.get_mapping(effect_name) + if not mapping: + return None, None, [] + + filter_name = mapping["filter"] + static_params = [] + bound_params = [] + sendcmd_lines = [] + + # Handle time-based effects (use FFmpeg expressions with 't') + if mapping.get("time_based"): + for param_name, param_config in mapping.get("params", {}).items(): + if param_name in params: + value = params[param_name] + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + if isinstance(value, (int, float)): + # Create time expression: h='t*speed*scale' + static_params.append(f"{ffmpeg_param}='t*{value}*{scale}'") + else: + static_params.append(f"{ffmpeg_param}='t*{scale}'") + if static_params: + filter_str = f"{filter_name}={':'.join(static_params)}" + else: + filter_str = f"{filter_name}=h='t*360'" # Default rotation + return filter_str, None, [] + + # Process each param + for param_name, param_config in mapping.get("params", {}).items(): + if param_name not in params: + continue + + value = params[param_name] + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + offset = param_config.get("offset", 0) + + # Check if it's a binding + if isinstance(value, dict) and ("_bind" in value or "_binding" in value): + bind_ref = value.get("_bind") or value.get("_binding") + range_min = value.get("range_min", 0.0) + range_max = value.get("range_max", 1.0) + transform = value.get("transform") + + # Get analysis data + analysis = analysis_data.get(bind_ref) + if not analysis: + # Try without -data suffix + analysis = analysis_data.get(bind_ref.replace("-data", "")) + + if analysis and "times" in analysis and "values" in analysis: + times = analysis["times"] + values = analysis["values"] + + # Generate sendcmd entries for this segment + segment_end = segment_start + segment_duration + t = 0.0 # Time relative to segment start + + while t < segment_duration: + abs_time = segment_start + t + + # Find analysis value at this time + raw_value = self._interpolate_value(times, values, abs_time) + + # Apply transform + if transform == "sqrt": + raw_value = math.sqrt(max(0, raw_value)) + elif transform == "pow2": + raw_value = raw_value ** 2 + elif transform == "log": + raw_value = math.log(max(0.001, raw_value)) + + # Map to range + mapped_value = range_min + raw_value * (range_max - range_min) + + # Apply FFmpeg scaling + ffmpeg_value = mapped_value * scale + offset + + # Add sendcmd line (time relative to segment) + sendcmd_lines.append(f"{t:.3f} [{filter_name}] {ffmpeg_param} {ffmpeg_value:.4f};") + + t += sample_interval + + bound_params.append(param_name) + # Use initial value for the filter string + initial_value = self._interpolate_value(times, values, segment_start) + initial_mapped = range_min + initial_value * (range_max - range_min) + initial_ffmpeg = initial_mapped * scale + offset + static_params.append(f"{ffmpeg_param}={initial_ffmpeg:.4f}") + else: + # No analysis data, use range midpoint + mid_value = (range_min + range_max) / 2 + ffmpeg_value = mid_value * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + else: + # Static value - handle various types + if isinstance(value, (int, float)): + ffmpeg_value = value * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + elif isinstance(value, str): + # String value - use as-is (e.g., for direction parameters) + static_params.append(f"{ffmpeg_param}={value}") + elif isinstance(value, list) and value: + # List - try to use first numeric element + first = value[0] + if isinstance(first, (int, float)): + ffmpeg_value = first * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + # Skip other types + + # Handle static filters + if "static" in mapping: + filter_str = f"{filter_name}={mapping['static']}" + elif static_params: + filter_str = f"{filter_name}={':'.join(static_params)}" + else: + filter_str = filter_name + + # Combine sendcmd lines + sendcmd_script = "\n".join(sendcmd_lines) if sendcmd_lines else None + + return filter_str, sendcmd_script, bound_params + + def _interpolate_value( + self, + times: List[float], + values: List[float], + target_time: float, + ) -> float: + """Interpolate a value from analysis data at a given time.""" + if not times or not values: + return 0.5 + + # Find surrounding indices + if target_time <= times[0]: + return values[0] + if target_time >= times[-1]: + return values[-1] + + # Binary search for efficiency + lo, hi = 0, len(times) - 1 + while lo < hi - 1: + mid = (lo + hi) // 2 + if times[mid] <= target_time: + lo = mid + else: + hi = mid + + # Linear interpolation + t0, t1 = times[lo], times[hi] + v0, v1 = values[lo], values[hi] + + if t1 == t0: + return v0 + + alpha = (target_time - t0) / (t1 - t0) + return v0 + alpha * (v1 - v0) + + +def generate_sendcmd_filter( + effects: List[Dict], + analysis_data: Dict[str, Dict], + segment_start: float, + segment_duration: float, +) -> Tuple[str, Optional[Path]]: + """ + Generate FFmpeg filter chain with sendcmd for dynamic effects. + + Args: + effects: List of effect configs with name and params + analysis_data: Analysis data keyed by name + segment_start: Segment start time in source + segment_duration: Segment duration + + Returns: + (filter_chain_string, sendcmd_file_path or None) + """ + import tempfile + + compiler = FFmpegCompiler() + filters = [] + all_sendcmd_lines = [] + + for effect in effects: + effect_name = effect.get("effect") + params = {k: v for k, v in effect.items() if k != "effect"} + + filter_str, sendcmd, _ = compiler.compile_effect_with_bindings( + effect_name, + params, + analysis_data, + segment_start, + segment_duration, + ) + + if filter_str: + filters.append(filter_str) + if sendcmd: + all_sendcmd_lines.append(sendcmd) + + if not filters: + return "", None + + filter_chain = ",".join(filters) + + # NOTE: sendcmd disabled - FFmpeg's sendcmd filter has compatibility issues. + # Bindings use their initial value (sampled at segment start time). + # This is acceptable since each segment is only ~8 seconds. + # The binding value is still music-reactive (varies per segment). + sendcmd_path = None + + return filter_chain, sendcmd_path diff --git a/artdag/sexp/parser.py b/artdag/sexp/parser.py new file mode 100644 index 0000000..8f7b4a4 --- /dev/null +++ b/artdag/sexp/parser.py @@ -0,0 +1,425 @@ +""" +S-expression parser for ArtDAG recipes and plans. + +Supports: +- Lists: (a b c) +- Symbols: foo, bar-baz, -> +- Keywords: :key +- Strings: "hello world" +- Numbers: 42, 3.14, -1.5 +- Comments: ; to end of line +- Vectors: [a b c] (syntactic sugar for lists) +- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts) +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Union +import re + + +@dataclass +class Symbol: + """An unquoted symbol/identifier.""" + name: str + + def __repr__(self): + return f"Symbol({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Symbol): + return self.name == other.name + if isinstance(other, str): + return self.name == other + return False + + def __hash__(self): + return hash(self.name) + + +@dataclass +class Keyword: + """A keyword starting with colon.""" + name: str + + def __repr__(self): + return f"Keyword({self.name!r})" + + def __eq__(self, other): + if isinstance(other, Keyword): + return self.name == other.name + return False + + def __hash__(self): + return hash((':' , self.name)) + + +@dataclass +class Lambda: + """A lambda/anonymous function with closure.""" + params: List[str] # Parameter names + body: Any # Expression body + closure: Dict = None # Captured environment (optional for backwards compat) + + def __repr__(self): + return f"Lambda({self.params}, {self.body!r})" + + +@dataclass +class Binding: + """A binding to analysis data for dynamic effect parameters.""" + analysis_ref: str # Name of analysis variable + track: str = None # Optional track name (e.g., "bass", "energy") + range_min: float = 0.0 # Output range minimum + range_max: float = 1.0 # Output range maximum + transform: str = None # Optional transform: "sqrt", "pow2", "log", etc. + + def __repr__(self): + t = f", transform={self.transform!r}" if self.transform else "" + return f"Binding({self.analysis_ref!r}, track={self.track!r}, range=[{self.range_min}, {self.range_max}]{t})" + + +class ParseError(Exception): + """Error during S-expression parsing.""" + def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1): + self.position = position + self.line = line + self.col = col + super().__init__(f"{message} at line {line}, column {col}") + + +class Tokenizer: + """Tokenize S-expression text into tokens.""" + + # Token patterns + WHITESPACE = re.compile(r'\s+') + COMMENT = re.compile(r';[^\n]*') + STRING = re.compile(r'"(?:[^"\\]|\\.)*"') + NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?') + KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*') + SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*') + + def __init__(self, text: str): + self.text = text + self.pos = 0 + self.line = 1 + self.col = 1 + + def _advance(self, count: int = 1): + """Advance position, tracking line/column.""" + for _ in range(count): + if self.pos < len(self.text): + if self.text[self.pos] == '\n': + self.line += 1 + self.col = 1 + else: + self.col += 1 + self.pos += 1 + + def _skip_whitespace_and_comments(self): + """Skip whitespace and comments.""" + while self.pos < len(self.text): + # Whitespace + match = self.WHITESPACE.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + # Comments + match = self.COMMENT.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + continue + + break + + def peek(self) -> str | None: + """Peek at current character.""" + self._skip_whitespace_and_comments() + if self.pos >= len(self.text): + return None + return self.text[self.pos] + + def next_token(self) -> Any: + """Get the next token.""" + self._skip_whitespace_and_comments() + + if self.pos >= len(self.text): + return None + + char = self.text[self.pos] + start_line, start_col = self.line, self.col + + # Single-character tokens (parens, brackets, braces) + if char in '()[]{}': + self._advance() + return char + + # String + if char == '"': + match = self.STRING.match(self.text, self.pos) + if not match: + raise ParseError("Unterminated string", self.pos, self.line, self.col) + self._advance(match.end() - self.pos) + # Parse escape sequences + content = match.group()[1:-1] + content = content.replace('\\n', '\n') + content = content.replace('\\t', '\t') + content = content.replace('\\"', '"') + content = content.replace('\\\\', '\\') + return content + + # Keyword + if char == ':': + match = self.KEYWORD.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Keyword(match.group()[1:]) # Strip leading colon + raise ParseError(f"Invalid keyword", self.pos, self.line, self.col) + + # Number (must check before symbol due to - prefix) + if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and + (self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')): + match = self.NUMBER.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + num_str = match.group() + if '.' in num_str or 'e' in num_str or 'E' in num_str: + return float(num_str) + return int(num_str) + + # Symbol + match = self.SYMBOL.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Symbol(match.group()) + + raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col) + + +def parse(text: str) -> Any: + """ + Parse an S-expression string into Python data structures. + + Returns: + Parsed S-expression as nested Python structures: + - Lists become Python lists + - Symbols become Symbol objects + - Keywords become Keyword objects + - Strings become Python strings + - Numbers become int/float + + Example: + >>> parse('(recipe "test" :version "1.0")') + [Symbol('recipe'), 'test', Keyword('version'), '1.0'] + """ + tokenizer = Tokenizer(text) + result = _parse_expr(tokenizer) + + # Check for trailing content + if tokenizer.peek() is not None: + raise ParseError("Unexpected content after expression", + tokenizer.pos, tokenizer.line, tokenizer.col) + + return result + + +def parse_all(text: str) -> List[Any]: + """ + Parse multiple S-expressions from a string. + + Returns list of parsed expressions. + """ + tokenizer = Tokenizer(text) + results = [] + + while tokenizer.peek() is not None: + results.append(_parse_expr(tokenizer)) + + return results + + +def _parse_expr(tokenizer: Tokenizer) -> Any: + """Parse a single expression.""" + token = tokenizer.next_token() + + if token is None: + raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col) + + # List + if token == '(': + return _parse_list(tokenizer, ')') + + # Vector (sugar for list) + if token == '[': + return _parse_list(tokenizer, ']') + + # Map/dict: {:key1 val1 :key2 val2} + if token == '{': + return _parse_map(tokenizer) + + # Unexpected closers + if token in (')', ']', '}'): + raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col) + + # Atom + return token + + +def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]: + """Parse a list until the closing delimiter.""" + items = [] + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError(f"Unterminated list, expected {closer!r}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == closer: + tokenizer.next_token() # Consume closer + return items + + items.append(_parse_expr(tokenizer)) + + +def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]: + """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}.""" + result = {} + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError("Unterminated map, expected '}'", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == '}': + tokenizer.next_token() # Consume closer + return result + + # Parse key (should be a keyword like :key) + key_token = _parse_expr(tokenizer) + if isinstance(key_token, Keyword): + key = key_token.name + elif isinstance(key_token, str): + key = key_token + else: + raise ParseError(f"Map key must be keyword or string, got {type(key_token).__name__}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + # Parse value + value = _parse_expr(tokenizer) + result[key] = value + + +def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str: + """ + Serialize a Python data structure back to S-expression format. + + Args: + expr: The expression to serialize + indent: Current indentation level (for pretty printing) + pretty: Whether to use pretty printing with newlines + + Returns: + S-expression string + """ + if isinstance(expr, list): + if not expr: + return "()" + + if pretty: + return _serialize_pretty(expr, indent) + else: + items = [serialize(item, indent, False) for item in expr] + return "(" + " ".join(items) + ")" + + if isinstance(expr, Symbol): + return expr.name + + if isinstance(expr, Keyword): + return f":{expr.name}" + + if isinstance(expr, Lambda): + params = " ".join(expr.params) + body = serialize(expr.body, indent, pretty) + return f"(fn [{params}] {body})" + + if isinstance(expr, Binding): + # analysis_ref can be a string, node ID, or dict - serialize it properly + if isinstance(expr.analysis_ref, str): + ref_str = f'"{expr.analysis_ref}"' + else: + ref_str = serialize(expr.analysis_ref, indent, pretty) + s = f"(bind {ref_str} :range [{expr.range_min} {expr.range_max}]" + if expr.transform: + s += f" :transform {expr.transform}" + return s + ")" + + if isinstance(expr, str): + # Escape special characters + escaped = expr.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n').replace('\t', '\\t') + return f'"{escaped}"' + + if isinstance(expr, bool): + return "true" if expr else "false" + + if isinstance(expr, (int, float)): + return str(expr) + + if expr is None: + return "nil" + + if isinstance(expr, dict): + # Serialize dict as property list: {:key1 val1 :key2 val2} + items = [] + for k, v in expr.items(): + items.append(f":{k}") + items.append(serialize(v, indent, pretty)) + return "{" + " ".join(items) + "}" + + raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}") + + +def _serialize_pretty(expr: List, indent: int) -> str: + """Pretty-print a list expression with smart formatting.""" + if not expr: + return "()" + + prefix = " " * indent + inner_prefix = " " * (indent + 1) + + # Check if this is a simple list that fits on one line + simple = serialize(expr, indent, False) + if len(simple) < 60 and '\n' not in simple: + return simple + + # Start building multiline output + head = serialize(expr[0], indent + 1, False) + parts = [f"({head}"] + + i = 1 + while i < len(expr): + item = expr[i] + + # Group keyword-value pairs on same line + if isinstance(item, Keyword) and i + 1 < len(expr): + key = serialize(item, 0, False) + val = serialize(expr[i + 1], indent + 1, False) + + # If value is short, put on same line + if len(val) < 50 and '\n' not in val: + parts.append(f"{inner_prefix}{key} {val}") + else: + # Value is complex, serialize it pretty + val_pretty = serialize(expr[i + 1], indent + 1, True) + parts.append(f"{inner_prefix}{key} {val_pretty}") + i += 2 + else: + # Regular item + item_str = serialize(item, indent + 1, True) + parts.append(f"{inner_prefix}{item_str}") + i += 1 + + return "\n".join(parts) + ")" diff --git a/artdag/sexp/planner.py b/artdag/sexp/planner.py new file mode 100644 index 0000000..ecd6595 --- /dev/null +++ b/artdag/sexp/planner.py @@ -0,0 +1,2187 @@ +""" +Execution plan generation from S-expression recipes. + +The planner: +1. Takes a compiled recipe + input content hashes +2. Runs analyzers to get concrete data (beat times, etc.) +3. Expands dynamic nodes (SLICE_ON) into primitive operations +4. Resolves all registry references to content hashes +5. Generates an execution plan with pre-computed cache IDs + +Plans are S-expressions with all references resolved to hashes, +ready for distribution to Celery workers. +""" + +import hashlib +import importlib.util +import json +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Callable + +from .parser import Symbol, Keyword, Binding, serialize +from .compiler import CompiledRecipe + + +# Node types that can be collapsed into a single FFmpeg filter chain +COLLAPSIBLE_TYPES = {"EFFECT", "SEGMENT"} + +# Node types that are boundaries (sources, merges, or special processing) +BOUNDARY_TYPES = {"SOURCE", "SEQUENCE", "MUX", "ANALYZE", "SCAN", "LIST"} + +# Node types that need expansion during planning +EXPANDABLE_TYPES = {"SLICE_ON", "CONSTRUCT"} + + +def _load_module(module_path: Path, module_name: str = "module"): + """Load a Python module from file path.""" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _run_analyzer( + analyzer_path: Path, + input_path: Path, + params: Dict[str, Any], +) -> Dict[str, Any]: + """Run an analyzer module and return results.""" + analyzer = _load_module(analyzer_path, "analyzer") + return analyzer.analyze(input_path, params) + + +def _pre_execute_segment( + node: Dict, + input_path: Path, + work_dir: Path, +) -> Path: + """ + Pre-execute a SEGMENT node during planning. + + This is needed when ANALYZE depends on a SEGMENT output. + Returns path to the segmented file. + """ + import subprocess + import tempfile + + config = node.get("config", {}) + start = config.get("start", 0) + duration = config.get("duration") + end = config.get("end") + + # Detect if input is audio-only + suffix = input_path.suffix.lower() + is_audio = suffix in ('.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a') + + if is_audio: + output_ext = ".m4a" # Use m4a for aac codec + else: + output_ext = ".mp4" + + output_path = work_dir / f"segment_{node['id'][:16]}{output_ext}" + + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if start: + cmd.extend(["-ss", str(start)]) + if duration: + cmd.extend(["-t", str(duration)]) + elif end: + cmd.extend(["-t", str(end - start)]) + + if is_audio: + cmd.extend(["-c:a", "aac", str(output_path)]) + else: + cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18", + "-c:a", "aac", str(output_path)]) + + subprocess.run(cmd, check=True, capture_output=True) + return output_path + + +def _serialize_for_hash(obj) -> str: + """Serialize any value to canonical S-expression string for hashing.""" + from .parser import Lambda + + if obj is None: + return "nil" + if isinstance(obj, bool): + return "true" if obj else "false" + if isinstance(obj, (int, float)): + return str(obj) + if isinstance(obj, str): + escaped = obj.replace('\\', '\\\\').replace('"', '\\"') + return f'"{escaped}"' + if isinstance(obj, Symbol): + return obj.name + if isinstance(obj, Keyword): + return f":{obj.name}" + if isinstance(obj, Lambda): + params = " ".join(obj.params) + body = _serialize_for_hash(obj.body) + return f"(fn [{params}] {body})" + if isinstance(obj, Binding): + # analysis_ref can be a string, node ID, or dict - serialize it properly + if isinstance(obj.analysis_ref, str): + ref_str = f'"{obj.analysis_ref}"' + else: + ref_str = _serialize_for_hash(obj.analysis_ref) + return f"(bind {ref_str} :range [{obj.range_min} {obj.range_max}])" + if isinstance(obj, dict): + items = [] + for k, v in sorted(obj.items()): + items.append(f":{k} {_serialize_for_hash(v)}") + return "{" + " ".join(items) + "}" + if isinstance(obj, list): + items = [_serialize_for_hash(x) for x in obj] + return "(" + " ".join(items) + ")" + return str(obj) + + +def _stable_hash(data: Any, cluster_key: str = None) -> str: + """Create stable SHA3-256 hash from data using S-expression serialization.""" + if cluster_key: + data = {"_cluster_key": cluster_key, "_data": data} + sexp_str = _serialize_for_hash(data) + return hashlib.sha3_256(sexp_str.encode()).hexdigest() + + +@dataclass +class PlanStep: + """A step in the execution plan.""" + step_id: str + node_type: str + config: Dict[str, Any] + inputs: List[str] # List of input step_ids + cache_id: str + level: int = 0 + stage: Optional[str] = None # Stage this step belongs to + + def to_sexp(self) -> List: + """Convert to S-expression.""" + sexp = [Symbol("step"), self.step_id] + + # Add cache-id + sexp.extend([Keyword("cache-id"), self.cache_id]) + + # Add level if > 0 + if self.level > 0: + sexp.extend([Keyword("level"), self.level]) + + # Add stage info if present + if self.stage: + sexp.extend([Keyword("stage"), self.stage]) + + # Add the node expression + node_sexp = [Symbol(self.node_type.lower())] + + # Add config as keywords + for key, value in self.config.items(): + # Convert Binding to sexp form + if isinstance(value, Binding): + value = [Symbol("bind"), value.analysis_ref, + Keyword("range"), [value.range_min, value.range_max]] + node_sexp.extend([Keyword(key), value]) + + # Add inputs if any + if self.inputs: + node_sexp.extend([Keyword("inputs"), self.inputs]) + + sexp.append(node_sexp) + return sexp + + +@dataclass +class StagePlan: + """A stage in the execution plan.""" + stage_name: str + steps: List[PlanStep] + requires: List[str] # Names of required stages + output_bindings: Dict[str, str] # binding_name -> cache_id of output + level: int = 0 # Stage level for parallel execution + + +@dataclass +class ExecutionPlanSexp: + """Execution plan as S-expression.""" + plan_id: str + steps: List[PlanStep] + output_step_id: str + source_hash: str = "" # CID of recipe source + params: Dict[str, Any] = field(default_factory=dict) # Resolved parameter values + params_hash: str = "" # Hash of params for quick comparison + inputs: Dict[str, str] = field(default_factory=dict) # name -> hash + analysis: Dict[str, Dict] = field(default_factory=dict) # name -> {times, values} + metadata: Dict[str, Any] = field(default_factory=dict) + stage_plans: List[StagePlan] = field(default_factory=list) # Stage-level plans + stage_order: List[str] = field(default_factory=list) # Topologically sorted stage names + stage_levels: Dict[str, int] = field(default_factory=dict) # stage_name -> level + effects_registry: Dict[str, Dict] = field(default_factory=dict) # effect_name -> {path, cid, ...} + minimal_primitives: bool = False # If True, interpreter uses only core primitives + + def to_sexp(self) -> List: + """Convert entire plan to S-expression.""" + sexp = [Symbol("plan")] + + # Metadata - purely content-addressed + sexp.extend([Keyword("id"), self.plan_id]) + sexp.extend([Keyword("source-cid"), self.source_hash]) # CID of recipe source + + # Parameters + if self.params: + sexp.extend([Keyword("params-hash"), self.params_hash]) + params_sexp = [Symbol("params")] + for name, value in self.params.items(): + params_sexp.append([Symbol(name), value]) + sexp.append(params_sexp) + + # Input bindings + if self.inputs: + inputs_sexp = [Symbol("inputs")] + for name, hash_val in self.inputs.items(): + inputs_sexp.append([Symbol(name), hash_val]) + sexp.append(inputs_sexp) + + # Analysis data (for effect parameter bindings) + if self.analysis: + analysis_sexp = [Symbol("analysis")] + for name, data in self.analysis.items(): + track_sexp = [Symbol(name)] + if isinstance(data, dict) and "_cache_id" in data: + track_sexp.extend([Keyword("cache-id"), data["_cache_id"]]) + else: + if "times" in data: + track_sexp.extend([Keyword("times"), data["times"]]) + if "values" in data: + track_sexp.extend([Keyword("values"), data["values"]]) + analysis_sexp.append(track_sexp) + sexp.append(analysis_sexp) + + # Stage information + if self.stage_plans: + stages_sexp = [Symbol("stages")] + for stage_plan in self.stage_plans: + stage_sexp = [ + Keyword("name"), stage_plan.stage_name, + Keyword("level"), stage_plan.level, + ] + if stage_plan.requires: + stage_sexp.extend([Keyword("requires"), stage_plan.requires]) + if stage_plan.output_bindings: + outputs_sexp = [] + for name, cache_id in stage_plan.output_bindings.items(): + outputs_sexp.append([Symbol(name), Keyword("cache-id"), cache_id]) + stage_sexp.extend([Keyword("outputs"), outputs_sexp]) + stages_sexp.append(stage_sexp) + sexp.append(stages_sexp) + + # Effects registry - for loading explicitly declared effects + if self.effects_registry: + registry_sexp = [Symbol("effects-registry")] + for name, info in self.effects_registry.items(): + effect_sexp = [Symbol(name)] + if info.get("path"): + effect_sexp.extend([Keyword("path"), info["path"]]) + if info.get("cid"): + effect_sexp.extend([Keyword("cid"), info["cid"]]) + registry_sexp.append(effect_sexp) + sexp.append(registry_sexp) + + # Minimal primitives flag + if self.minimal_primitives: + sexp.extend([Keyword("minimal-primitives"), True]) + + # Steps + for step in self.steps: + sexp.append(step.to_sexp()) + + # Output reference + sexp.extend([Keyword("output"), self.output_step_id]) + + return sexp + + def to_string(self, pretty: bool = True) -> str: + """Serialize plan to S-expression string.""" + return serialize(self.to_sexp(), pretty=pretty) + + +def _expand_list_inputs(nodes: List[Dict]) -> List[Dict]: + """ + Expand LIST node inputs in SEQUENCE nodes. + + When a SEQUENCE has a LIST as input, replace it with all the LIST's inputs. + LIST nodes that are referenced by non-SEQUENCE nodes (e.g., EFFECT chains) + are promoted to SEQUENCE nodes so they produce a concatenated output. + Unreferenced LIST nodes are removed. + """ + nodes_by_id = {n["id"]: n for n in nodes} + list_nodes = {n["id"]: n for n in nodes if n["type"] == "LIST"} + + if not list_nodes: + return nodes + + # Determine which LIST nodes are referenced by SEQUENCE vs other node types + list_consumed_by_seq = set() + list_referenced_by_other = set() + for node in nodes: + if node["type"] == "LIST": + continue + for inp in node.get("inputs", []): + if inp in list_nodes: + if node["type"] == "SEQUENCE": + list_consumed_by_seq.add(inp) + else: + list_referenced_by_other.add(inp) + + result = [] + for node in nodes: + if node["type"] == "LIST": + if node["id"] in list_referenced_by_other: + # Promote to SEQUENCE — non-SEQUENCE nodes reference this LIST + result.append({ + "id": node["id"], + "type": "SEQUENCE", + "config": node.get("config", {}), + "inputs": node.get("inputs", []), + }) + # Otherwise skip (consumed by SEQUENCE expansion or unreferenced) + continue + + if node["type"] == "SEQUENCE": + # Expand any LIST inputs + new_inputs = [] + for inp in node.get("inputs", []): + if inp in list_nodes: + # Replace LIST with its contents + new_inputs.extend(list_nodes[inp].get("inputs", [])) + else: + new_inputs.append(inp) + + # Create updated node + result.append({ + **node, + "inputs": new_inputs, + }) + else: + result.append(node) + + return result + + +def _collapse_effect_chains(nodes: List[Dict], registry: Dict = None) -> List[Dict]: + """ + Collapse sequential effect chains into single COMPOUND nodes. + + A chain is a sequence of single-input collapsible nodes where: + - Each node has exactly one input + - No node in the chain is referenced by multiple other nodes + - The chain ends at a boundary or multi-ref node + - No node in the chain is marked as temporal + + Effects can declare :temporal true to prevent collapsing (e.g., reverse). + + Returns a new node list with chains collapsed. + """ + if not nodes: + return nodes + + registry = registry or {} + nodes_by_id = {n["id"]: n for n in nodes} + + # Build reference counts: how many nodes reference each node as input + ref_count = {n["id"]: 0 for n in nodes} + for node in nodes: + for inp in node.get("inputs", []): + if inp in ref_count: + ref_count[inp] += 1 + + # Track which nodes are consumed by chains + consumed = set() + compound_nodes = [] + + def is_temporal(node: Dict) -> bool: + """Check if a node is temporal (needs complete input).""" + config = node.get("config", {}) + # Check node-level temporal flag + if config.get("temporal"): + return True + # Check effect registry for temporal flag + if node["type"] == "EFFECT": + effect_name = config.get("effect") + if effect_name: + effect_meta = registry.get("effects", {}).get(effect_name, {}) + if effect_meta.get("temporal"): + return True + return False + + def is_collapsible(node_id: str) -> bool: + """Check if a node can be part of a chain.""" + if node_id in consumed: + return False + node = nodes_by_id.get(node_id) + if not node: + return False + if node["type"] not in COLLAPSIBLE_TYPES: + return False + # Temporal effects can't be collapsed + if is_temporal(node): + return False + # Effects CAN be collapsed if they have an FFmpeg mapping + # Only fall back to Python interpreter if no mapping exists + config = node.get("config", {}) + if node["type"] == "EFFECT": + effect_name = config.get("effect") + # Import here to avoid circular imports + from .ffmpeg_compiler import FFmpegCompiler + compiler = FFmpegCompiler() + if compiler.get_mapping(effect_name): + return True # Has FFmpeg mapping, can collapse + elif config.get("effect_path"): + return False # No FFmpeg mapping, has Python path, can't collapse + return True + + def is_chain_boundary(node_id: str) -> bool: + """Check if a node is a chain boundary (can't be collapsed into).""" + node = nodes_by_id.get(node_id) + if not node: + return True # Unknown node is a boundary + # Boundary if: it's a boundary type, or referenced by multiple nodes + return node["type"] in BOUNDARY_TYPES or ref_count.get(node_id, 0) > 1 + + def collect_chain(start_id: str) -> List[str]: + """Collect a chain of collapsible nodes starting from start_id.""" + chain = [start_id] + current = start_id + + while True: + node = nodes_by_id[current] + inputs = node.get("inputs", []) + + # Must have exactly one input + if len(inputs) != 1: + break + + next_id = inputs[0] + + # Stop if next is a boundary or already consumed + if is_chain_boundary(next_id) or not is_collapsible(next_id): + break + + # Stop if next is referenced by others besides current + if ref_count.get(next_id, 0) > 1: + break + + chain.append(next_id) + current = next_id + + return chain + + # Process nodes in reverse order (from outputs toward inputs) + # This ensures we find complete chains starting from their end + # First, topologically sort to get dependency order + sorted_ids = [] + visited = set() + + def topo_visit(node_id: str): + if node_id in visited: + return + visited.add(node_id) + node = nodes_by_id.get(node_id) + if node: + for inp in node.get("inputs", []): + topo_visit(inp) + sorted_ids.append(node_id) + + for node in nodes: + topo_visit(node["id"]) + + # Process in reverse topological order (outputs first) + result_nodes = [] + + for node_id in reversed(sorted_ids): + node = nodes_by_id[node_id] + + if node_id in consumed: + continue + + if not is_collapsible(node_id): + # Keep boundary nodes as-is + result_nodes.append(node) + continue + + # Check if this node is the start of a chain (output end) + # A node is a chain start if it's collapsible and either: + # - Referenced by a boundary node + # - Referenced by multiple nodes + # - Is the output node + # For now, collect chain going backwards from this node + + chain = collect_chain(node_id) + + if len(chain) == 1: + # Single node, no collapse needed + result_nodes.append(node) + continue + + # Collapse the chain into a COMPOUND node + # Chain is [end, ..., start] order (backwards from output) + # The compound node: + # - Has the same ID as the chain end (for reference stability) + # - Takes input from what the chain start originally took + # - Has a filter_chain config with all the nodes in order + + chain_start = chain[-1] # First to execute + chain_end = chain[0] # Last to execute + + start_node = nodes_by_id[chain_start] + end_node = nodes_by_id[chain_end] + + # Build filter chain config (in execution order: start to end) + filter_chain = [] + for chain_node_id in reversed(chain): + chain_node = nodes_by_id[chain_node_id] + filter_chain.append({ + "type": chain_node["type"], + "config": chain_node.get("config", {}), + }) + + compound_node = { + "id": chain_end, # Keep the end ID for reference stability + "type": "COMPOUND", + "config": { + "filter_chain": filter_chain, + # Include effects registry so executor can load only declared effects + "effects_registry": registry.get("effects", {}), + }, + "inputs": start_node.get("inputs", []), + "name": f"compound_{len(filter_chain)}_effects", + } + + result_nodes.append(compound_node) + + # Mark all chain nodes as consumed + for chain_node_id in chain: + consumed.add(chain_node_id) + + return result_nodes + + +def _expand_slice_on( + node: Dict, + analysis_data: Dict[str, Any], + registry: Dict, + sources: Dict[str, str] = None, + cluster_key: str = None, + encoding: Dict = None, + named_analysis: Dict = None, +) -> List[Dict]: + """ + Expand a SLICE_ON node into primitive SEGMENT + EFFECT + SEQUENCE nodes. + + Supports two modes: + 1. Legacy: :effect and :pattern parameters + 2. Lambda: :fn parameter with reducer function + + Lambda syntax: + (slice-on analysis + :times times + :init 0 + :fn (lambda [acc i start end] + {:source video + :effects (if (odd? i) [invert] []) + :acc (inc acc)})) + + When all beats produce composition-mode results (layers + compositor) + with the same layer structure, consecutive beats are automatically merged + into fewer compositions with time-varying parameter bindings. This can + reduce thousands of nodes to a handful. + + Args: + node: The SLICE_ON node to expand + analysis_data: Analysis results containing times array + registry: Recipe registry with effect definitions + sources: Map of source names to node IDs + cluster_key: Optional cluster key for hashing + named_analysis: Mutable dict to inject synthetic analysis tracks into + + Returns: + List of expanded nodes (segments, effects, sequence) + """ + from .evaluator import evaluate, EvalError + from .parser import Lambda, Symbol + + config = node.get("config", {}) + node_inputs = node.get("inputs", []) + sources = sources or {} + + # Extract times + times_path = config.get("times_path", "times") + times = analysis_data + for key in times_path.split("."): + times = times[key] + + if not times: + raise ValueError(f"No times found at path '{times_path}' in analysis") + + # Default video input (first input after analysis) + default_video = node_inputs[0] if node_inputs else None + + expanded_nodes = [] + sequence_inputs = [] + base_id = node["id"][:8] + + # Check for lambda-based reducer + reducer_fn = config.get("fn") + + if isinstance(reducer_fn, Lambda): + # Lambda mode - evaluate function for each slice + acc = config.get("init", 0) + slice_times = list(zip([0] + times[:-1], times)) + + # Frame-accurate timing calculation + # Align ALL times to frame boundaries to prevent accumulating drift + fps = (encoding or {}).get("fps", 30) + frame_duration = 1.0 / fps + + # Get total duration from analysis data (beats analyzer includes this) + # Falls back to config target_duration for backwards compatibility + total_duration = analysis_data.get("duration") or config.get("target_duration") + + # Pre-compute frame-aligned cumulative times + cumulative_frames = [0] # Start at frame 0 + for t in times: + # Round to nearest frame boundary + frames = round(t * fps) + cumulative_frames.append(frames) + + # If total duration known, ensure last segment extends to it exactly + if total_duration is not None: + target_frames = round(total_duration * fps) + if target_frames > cumulative_frames[-1]: + cumulative_frames[-1] = target_frames + + # Pre-compute frame-aligned start times and durations for each slice + frame_aligned_starts = [] + frame_aligned_durations = [] + for i in range(len(cumulative_frames) - 1): + start_frames = cumulative_frames[i] + end_frames = cumulative_frames[i + 1] + frame_aligned_starts.append(start_frames * frame_duration) + frame_aligned_durations.append((end_frames - start_frames) * frame_duration) + + # Phase 1: Evaluate all lambdas upfront + videos = config.get("videos", []) + all_results = [] + all_timings = [] # (seg_start, seg_duration) per valid beat + original_indices = [] # original beat index for each result + + for i, (start, end) in enumerate(slice_times): + if start >= end: + continue + + # Build environment with sources, effects, and builtins + env = dict(sources) + + # Add effect names so they can be referenced as symbols + for effect_name in registry.get("effects", {}): + env[effect_name] = effect_name + + # Make :videos list available to lambda + if videos: + env["videos"] = videos + + env["acc"] = acc + env["i"] = i + env["start"] = start + env["end"] = end + + # Evaluate the reducer + result = evaluate([reducer_fn, Symbol("acc"), Symbol("i"), + Symbol("start"), Symbol("end")], env) + + if not isinstance(result, dict): + raise ValueError(f"Reducer must return a dict, got {type(result)}") + + # Extract accumulator + acc = result.get("acc", acc) + + # Segment timing: use frame-aligned values to prevent drift + # Lambda can override with explicit start/duration/end + if result.get("start") is not None or result.get("duration") is not None or result.get("end") is not None: + # Explicit timing from lambda - use as-is + seg_start = result.get("start", start) + seg_duration = result.get("duration") + if seg_duration is None: + if result.get("end") is not None: + seg_duration = result["end"] - seg_start + else: + seg_duration = end - start + else: + # Default: use frame-aligned start and duration to prevent accumulated drift + seg_start = frame_aligned_starts[i] if i < len(frame_aligned_starts) else start + seg_duration = frame_aligned_durations[i] if i < len(frame_aligned_durations) else (end - start) + + all_results.append(result) + all_timings.append((seg_start, seg_duration)) + original_indices.append(i) + + # Phase 2: Merge or expand + all_composition = ( + len(all_results) > 1 + and all("layers" in r for r in all_results) + and named_analysis is not None + ) + + if all_composition: + # All beats are composition mode — try to merge consecutive + # beats with the same layer structure + _merge_composition_beats( + all_results, all_timings, base_id, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, + ) + else: + # Fallback: expand each beat individually + for idx, result in enumerate(all_results): + orig_i = original_indices[idx] + seg_start, seg_duration = all_timings[idx] + + if "layers" in result: + # COMPOSITION MODE — multi-source with per-layer effects + compositor + _expand_composition_beat( + result, orig_i, base_id, videos, registry, + seg_start, seg_duration, expanded_nodes, sequence_inputs, + ) + else: + # SINGLE-SOURCE MODE (existing behavior) + source_name = result.get("source") + effects = result.get("effects", []) + + # Resolve source to node ID + if isinstance(source_name, Symbol): + source_name = source_name.name + valid_node_ids = set(sources.values()) + if source_name in sources: + video_input = sources[source_name] + elif source_name in valid_node_ids: + video_input = source_name + else: + video_input = default_video + + # Create SEGMENT node + segment_id = f"{base_id}_seg_{orig_i:04d}" + segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": seg_start, + "duration": seg_duration, + }, + "inputs": [video_input], + } + expanded_nodes.append(segment_node) + + # Apply effects chain + current_input = segment_id + for j, effect in enumerate(effects): + effect_name, effect_params = _parse_effect_spec(effect) + if not effect_name: + continue + + effect_id = f"{base_id}_fx_{orig_i:04d}_{j}" + effect_entry = registry.get("effects", {}).get(effect_name, {}) + + effect_config = { + "effect": effect_name, + "effect_path": effect_entry.get("path"), + } + effect_config.update(effect_params) + + effect_node = { + "id": effect_id, + "type": "EFFECT", + "config": effect_config, + "inputs": [current_input], + } + expanded_nodes.append(effect_node) + current_input = effect_id + + sequence_inputs.append(current_input) + + else: + # Legacy mode - :effect and :pattern + effect_name = config.get("effect") + effect_path = config.get("effect_path") + pattern = config.get("pattern", "all") + video_input = default_video + + if not video_input: + raise ValueError("SLICE_ON requires video input") + + slice_times = list(zip([0] + times[:-1], times)) + + for i, (start, end) in enumerate(slice_times): + if start >= end: + continue + + # Determine if effect should be applied + apply_effect = False + if effect_name: + if pattern == "all": + apply_effect = True + elif pattern == "odd": + apply_effect = (i % 2 == 1) + elif pattern == "even": + apply_effect = (i % 2 == 0) + elif pattern == "alternate": + apply_effect = (i % 2 == 1) + + # Create SEGMENT node + segment_id = f"{base_id}_seg_{i:04d}" + segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": start, + "duration": end - start, + }, + "inputs": [video_input], + } + expanded_nodes.append(segment_node) + + if apply_effect: + effect_id = f"{base_id}_fx_{i:04d}" + effect_config = {"effect": effect_name} + if effect_path: + effect_config["effect_path"] = effect_path + + effect_node = { + "id": effect_id, + "type": "EFFECT", + "config": effect_config, + "inputs": [segment_id], + } + expanded_nodes.append(effect_node) + sequence_inputs.append(effect_id) + else: + sequence_inputs.append(segment_id) + # Create LIST node to hold all slices (user must explicitly sequence them) + list_node = { + "id": node["id"], # Keep original ID for reference stability + "type": "LIST", + "config": {}, + "inputs": sequence_inputs, + } + expanded_nodes.append(list_node) + + return expanded_nodes + + +def _parse_effect_spec(effect): + """Parse an effect spec into (name, params) from Symbol, string, or dict.""" + from .parser import Symbol + + effect_name = None + effect_params = {} + + if isinstance(effect, Symbol): + effect_name = effect.name + elif isinstance(effect, str): + effect_name = effect + elif isinstance(effect, dict): + effect_name = effect.get("effect") + if isinstance(effect_name, Symbol): + effect_name = effect_name.name + for k, v in effect.items(): + if k != "effect": + effect_params[k] = v + + return effect_name, effect_params + + +def _expand_composition_beat(result, beat_idx, base_id, videos, registry, + seg_start, seg_duration, expanded_nodes, sequence_inputs): + """ + Expand a composition-mode beat into per-layer SEGMENT + EFFECT nodes + and a single composition EFFECT node. + + Args: + result: Lambda result dict with 'layers' and optional 'compose' + beat_idx: Beat index for ID generation + base_id: Base ID prefix + videos: List of video node IDs from :videos config + registry: Recipe registry with effect definitions + seg_start: Segment start time + seg_duration: Segment duration + expanded_nodes: List to append generated nodes to + sequence_inputs: List to append final composition node ID to + """ + layers = result["layers"] + compose_spec = result.get("compose", {}) + + layer_outputs = [] + for layer_idx, layer in enumerate(layers): + # Resolve video: integer index into videos list, or node ID string + video_ref = layer.get("video") + if isinstance(video_ref, (int, float)): + video_input = videos[int(video_ref)] + else: + video_input = str(video_ref) + + # SEGMENT for this layer + segment_id = f"{base_id}_seg_{beat_idx:04d}_L{layer_idx}" + expanded_nodes.append({ + "id": segment_id, + "type": "SEGMENT", + "config": {"start": seg_start, "duration": seg_duration}, + "inputs": [video_input], + }) + + # Per-layer EFFECT chain + current = segment_id + for fx_idx, effect in enumerate(layer.get("effects", [])): + effect_name, effect_params = _parse_effect_spec(effect) + if not effect_name: + continue + effect_id = f"{base_id}_fx_{beat_idx:04d}_L{layer_idx}_{fx_idx}" + effect_entry = registry.get("effects", {}).get(effect_name, {}) + config = { + "effect": effect_name, + "effect_path": effect_entry.get("path"), + } + config.update(effect_params) + expanded_nodes.append({ + "id": effect_id, + "type": "EFFECT", + "config": config, + "inputs": [current], + }) + current = effect_id + layer_outputs.append(current) + + # Composition EFFECT node + compose_name = compose_spec.get("effect", "blend_multi") + compose_id = f"{base_id}_comp_{beat_idx:04d}" + compose_entry = registry.get("effects", {}).get(compose_name, {}) + compose_config = { + "effect": compose_name, + "effect_path": compose_entry.get("path"), + "multi_input": True, + } + for k, v in compose_spec.items(): + if k != "effect": + compose_config[k] = v + + expanded_nodes.append({ + "id": compose_id, + "type": "EFFECT", + "config": compose_config, + "inputs": layer_outputs, + }) + sequence_inputs.append(compose_id) + + +def _fingerprint_composition(result): + """Create a hashable fingerprint of a composition beat's layer structure. + + Beats with the same fingerprint have the same video refs, effect names, + and compositor type — only parameter values differ. Such beats can be + merged into a single composition with time-varying bindings. + """ + layers = result.get("layers", []) + compose = result.get("compose", {}) + + layer_fps = [] + for layer in layers: + video_ref = layer.get("video") + effect_names = tuple( + _parse_effect_spec(e)[0] for e in layer.get("effects", []) + ) + layer_fps.append((video_ref, effect_names)) + + compose_name = compose.get("effect", "blend_multi") + # Include static compose params (excluding list-valued params like weights) + static_compose = tuple(sorted( + (k, v) for k, v in compose.items() + if k not in ("effect", "weights") and isinstance(v, (str, int, float, bool)) + )) + + return (len(layers), tuple(layer_fps), compose_name, static_compose) + + +def _merge_composition_beats( + all_results, all_timings, base_id, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, +): + """Merge consecutive composition beats with the same layer structure. + + Groups consecutive beats by structural fingerprint. Groups of 2+ beats + get merged into a single composition with synthetic analysis tracks for + time-varying parameters. Single beats use standard per-beat expansion. + """ + import sys + + # Compute fingerprints + fingerprints = [_fingerprint_composition(r) for r in all_results] + + # Group consecutive beats with the same fingerprint + groups = [] # list of (start_idx, end_idx_exclusive) + group_start = 0 + for i in range(1, len(fingerprints)): + if fingerprints[i] != fingerprints[group_start]: + groups.append((group_start, i)) + group_start = i + groups.append((group_start, len(fingerprints))) + + print(f" Composition merging: {len(all_results)} beats -> {len(groups)} groups", file=sys.stderr) + + for group_idx, (g_start, g_end) in enumerate(groups): + group_size = g_end - g_start + + if group_size == 1: + # Single beat — use standard expansion + result = all_results[g_start] + seg_start, seg_duration = all_timings[g_start] + _expand_composition_beat( + result, g_start, base_id, videos, registry, + seg_start, seg_duration, expanded_nodes, sequence_inputs, + ) + else: + # Merge group into one composition with time-varying bindings + _merge_composition_group( + all_results, all_timings, + list(range(g_start, g_end)), + base_id, group_idx, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, + ) + + +def _merge_composition_group( + all_results, all_timings, group_indices, + base_id, group_idx, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, +): + """Merge a group of same-structure composition beats into one composition. + + Creates: + - One SEGMENT per layer (spanning full group duration) + - One EFFECT per layer with time-varying params via synthetic analysis tracks + - One compositor EFFECT with time-varying weights via synthetic tracks + """ + import sys + + first = all_results[group_indices[0]] + layers = first["layers"] + compose_spec = first.get("compose", {}) + num_layers = len(layers) + + # Group timing + first_start = all_timings[group_indices[0]][0] + last_start, last_dur = all_timings[group_indices[-1]] + group_duration = (last_start + last_dur) - first_start + + # Beat start times for synthetic tracks (absolute times) + beat_times = [float(all_timings[i][0]) for i in group_indices] + + print(f" Group {group_idx}: {len(group_indices)} beats, " + f"{first_start:.1f}s -> {first_start + group_duration:.1f}s " + f"({num_layers} layers)", file=sys.stderr) + + # --- Per-layer segments and effects --- + layer_outputs = [] + for layer_idx in range(num_layers): + layer = layers[layer_idx] + + # Resolve video input + video_ref = layer.get("video") + if isinstance(video_ref, (int, float)): + video_input = videos[int(video_ref)] + else: + video_input = str(video_ref) + + # SEGMENT for this layer (full group duration) + segment_id = f"{base_id}_seg_G{group_idx:03d}_L{layer_idx}" + expanded_nodes.append({ + "id": segment_id, + "type": "SEGMENT", + "config": {"start": first_start, "duration": group_duration}, + "inputs": [video_input], + }) + + # Per-layer EFFECT chain + current = segment_id + effects = layer.get("effects", []) + for fx_idx, effect in enumerate(effects): + effect_name, first_params = _parse_effect_spec(effect) + if not effect_name: + continue + + effect_id = f"{base_id}_fx_G{group_idx:03d}_L{layer_idx}_{fx_idx}" + effect_entry = registry.get("effects", {}).get(effect_name, {}) + fx_config = { + "effect": effect_name, + "effect_path": effect_entry.get("path"), + } + + # For each param, check if it varies across beats + for param_name, first_val in first_params.items(): + values = [] + for bi in group_indices: + beat_layer = all_results[bi]["layers"][layer_idx] + beat_effects = beat_layer.get("effects", []) + if fx_idx < len(beat_effects): + _, beat_params = _parse_effect_spec(beat_effects[fx_idx]) + values.append(float(beat_params.get(param_name, first_val))) + else: + values.append(float(first_val)) + + # Check if all values are identical + if all(v == values[0] for v in values): + fx_config[param_name] = values[0] + else: + # Create synthetic analysis track + # Prefix with 'syn_' to ensure valid S-expression symbol + # (base_id may start with digits, which the parser splits) + track_name = f"syn_{base_id}_L{layer_idx}_fx{fx_idx}_{param_name}" + named_analysis[track_name] = { + "times": beat_times, + "values": values, + } + fx_config[param_name] = { + "_binding": True, + "source": track_name, + "feature": "values", + "range": [0.0, 1.0], # pass-through + } + + expanded_nodes.append({ + "id": effect_id, + "type": "EFFECT", + "config": fx_config, + "inputs": [current], + }) + current = effect_id + + layer_outputs.append(current) + + # --- Compositor --- + compose_name = compose_spec.get("effect", "blend_multi") + compose_id = f"{base_id}_comp_G{group_idx:03d}" + compose_entry = registry.get("effects", {}).get(compose_name, {}) + compose_config = { + "effect": compose_name, + "effect_path": compose_entry.get("path"), + "multi_input": True, + } + + for k, v in compose_spec.items(): + if k == "effect": + continue + + if isinstance(v, list): + # List param (e.g., weights) — check each element + merged_list = [] + for elem_idx in range(len(v)): + elem_values = [] + for bi in group_indices: + beat_compose = all_results[bi].get("compose", {}) + beat_v = beat_compose.get(k, v) + if isinstance(beat_v, list) and elem_idx < len(beat_v): + elem_values.append(float(beat_v[elem_idx])) + else: + elem_values.append(float(v[elem_idx])) + + if all(ev == elem_values[0] for ev in elem_values): + merged_list.append(elem_values[0]) + else: + track_name = f"syn_{base_id}_comp_{k}_{elem_idx}" + named_analysis[track_name] = { + "times": beat_times, + "values": elem_values, + } + merged_list.append({ + "_binding": True, + "source": track_name, + "feature": "values", + "range": [0.0, 1.0], + }) + compose_config[k] = merged_list + elif isinstance(v, (int, float)): + # Scalar param — check if it varies + values = [] + for bi in group_indices: + beat_compose = all_results[bi].get("compose", {}) + values.append(float(beat_compose.get(k, v))) + + if all(val == values[0] for val in values): + compose_config[k] = values[0] + else: + track_name = f"syn_{base_id}_comp_{k}" + named_analysis[track_name] = { + "times": beat_times, + "values": values, + } + compose_config[k] = { + "_binding": True, + "source": track_name, + "feature": "values", + "range": [0.0, 1.0], + } + else: + # String or other — keep as-is + compose_config[k] = v + + expanded_nodes.append({ + "id": compose_id, + "type": "EFFECT", + "config": compose_config, + "inputs": layer_outputs, + }) + sequence_inputs.append(compose_id) + + +def _parse_construct_params(params_list: list) -> tuple: + """ + Parse :params block in a construct definition. + + Syntax: + ( + (param_name :type string :default "value" :desc "description") + ) + + Returns: + (param_names, param_defaults) where param_names is a list of strings + and param_defaults is a dict of param_name -> default_value + """ + param_names = [] + param_defaults = {} + + for param_def in params_list: + if not isinstance(param_def, list) or len(param_def) < 1: + continue + + # First element is the parameter name + first = param_def[0] + if isinstance(first, Symbol): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + continue + + param_names.append(param_name) + + # Parse keyword arguments + default = None + i = 1 + while i < len(param_def): + item = param_def[i] + if isinstance(item, Keyword): + if i + 1 >= len(param_def): + break + kw_value = param_def[i + 1] + + if item.name == "default": + default = kw_value + # We could also parse :type, :range, :choices, :desc here + i += 2 + else: + i += 1 + + param_defaults[param_name] = default + + return param_names, param_defaults + + +def _expand_construct( + node: Dict, + registry: Dict, + sources: Dict[str, str], + analysis_data: Dict[str, Dict], + recipe_dir: Path, + cluster_key: str = None, + encoding: Dict = None, +) -> List[Dict]: + """ + Expand a user-defined CONSTRUCT node. + + Loads the construct definition from .sexp file, evaluates it with + the provided arguments, and converts the result into segment nodes. + + Args: + node: The CONSTRUCT node to expand + registry: Recipe registry + sources: Map of source names to node IDs + analysis_data: Analysis results (analysis_id -> {times, values}) + recipe_dir: Recipe directory for resolving paths + cluster_key: Optional cluster key for hashing + encoding: Encoding config + + Returns: + List of expanded nodes (segments, effects, list) + """ + from .parser import parse_all, Symbol + from .evaluator import evaluate + + config = node.get("config", {}) + construct_name = config.get("construct_name") + construct_path = config.get("construct_path") + args = config.get("args", []) + + # Load construct definition + full_path = recipe_dir / construct_path + if not full_path.exists(): + raise ValueError(f"Construct file not found: {full_path}") + + print(f" Loading construct: {construct_name} from {construct_path}", file=sys.stderr) + + construct_text = full_path.read_text() + construct_sexp = parse_all(construct_text) + + # Parse define-construct: (define-construct name "desc" (params...) body) + if not isinstance(construct_sexp, list): + construct_sexp = [construct_sexp] + + # Process imports (effect, construct declarations) in the construct file + # These extend the registry for this construct's scope + local_registry = dict(registry) # Copy parent registry + construct_def = None + + for expr in construct_sexp: + if isinstance(expr, list) and expr and isinstance(expr[0], Symbol): + form_name = expr[0].name + + if form_name == "effect": + # (effect name :path "...") + effect_name = expr[1].name if isinstance(expr[1], Symbol) else expr[1] + # Parse kwargs + i = 2 + kwargs = {} + while i < len(expr): + if isinstance(expr[i], Keyword): + kwargs[expr[i].name] = expr[i + 1] if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + local_registry.setdefault("effects", {})[effect_name] = { + "path": kwargs.get("path"), + "cid": kwargs.get("cid"), + } + print(f" Construct imports effect: {effect_name}", file=sys.stderr) + + elif form_name == "define-construct": + construct_def = expr + + if not construct_def: + raise ValueError(f"No define-construct found in {construct_path}") + + # Use local_registry instead of registry from here + registry = local_registry + + # Parse define-construct - requires :params syntax: + # (define-construct name + # :params ( + # (param1 :type string :default "value" :desc "description") + # ) + # body) + # + # Legacy syntax (define-construct name "desc" (param1 param2) body) is not supported. + def_name = construct_def[1].name if isinstance(construct_def[1], Symbol) else construct_def[1] + + params = [] # List of param names + param_defaults = {} # param_name -> default value + body = None + found_params = False + + idx = 2 + while idx < len(construct_def): + item = construct_def[idx] + if isinstance(item, Keyword) and item.name == "params": + # :params syntax + if idx + 1 >= len(construct_def): + raise ValueError(f"Construct '{def_name}': Missing params list after :params keyword") + params_list = construct_def[idx + 1] + params, param_defaults = _parse_construct_params(params_list) + found_params = True + idx += 2 + elif isinstance(item, Keyword): + # Skip other keywords (like :desc) + idx += 2 + elif isinstance(item, str): + # Skip description strings (but warn about legacy format) + print(f" Warning: Description strings in define-construct are deprecated", file=sys.stderr) + idx += 1 + elif body is None: + # First non-keyword, non-string item is the body + if isinstance(item, list) and item: + first_elem = item[0] + # Check for legacy params syntax and reject it + if isinstance(first_elem, Symbol) and first_elem.name not in ("let", "let*", "if", "when", "do", "begin", "->", "map", "filter", "fn", "reduce", "nth"): + # Could be legacy params if all items are just symbols + if all(isinstance(p, Symbol) for p in item): + raise ValueError( + f"Construct '{def_name}': Legacy parameter syntax (param1 param2) is not supported. " + f"Use :params block instead." + ) + body = item + idx += 1 + else: + idx += 1 + + if body is None: + raise ValueError(f"No body found in define-construct {def_name}") + + # Build environment with sources and analysis data + env = dict(sources) + + # Add bindings from compiler (video-a, video-b, etc.) + if "bindings" in config: + env.update(config["bindings"]) + + # Add effect names so they can be referenced as symbols + for effect_name in registry.get("effects", {}): + env[effect_name] = effect_name + + # Map analysis node IDs to their data with :times and :values + for analysis_id, data in analysis_data.items(): + # Find the name this analysis was bound to + for name, node_id in sources.items(): + if node_id == analysis_id or name.endswith("-data"): + env[name] = data + env[analysis_id] = data + + # Apply param defaults first (for :params syntax) + for param_name, default_value in param_defaults.items(): + if default_value is not None: + env[param_name] = default_value + + # Bind positional args to params (overrides defaults) + param_names = [p.name if isinstance(p, Symbol) else p for p in params] + for i, param in enumerate(param_names): + if i < len(args): + arg = args[i] + # Resolve node IDs to their data if it's analysis + if isinstance(arg, str) and arg in analysis_data: + env[param] = analysis_data[arg] + else: + env[param] = arg + + # Helper to resolve node IDs to analysis data recursively + def resolve_value(val): + """Resolve node IDs and symbols in a value, including inside dicts/lists.""" + if isinstance(val, str) and val in analysis_data: + return analysis_data[val] + elif isinstance(val, str) and val in env: + return env[val] + elif isinstance(val, Symbol): + if val.name in env: + return env[val.name] + return val + elif isinstance(val, dict): + return {k: resolve_value(v) for k, v in val.items()} + elif isinstance(val, list): + return [resolve_value(v) for v in val] + return val + + # Validate and bind keyword arguments from the config (excluding internal keys) + # These may be S-expressions that need evaluation (e.g., lambdas) + # or Symbols that need resolution from bindings + internal_keys = {"construct_name", "construct_path", "args", "bindings"} + known_params = set(param_names) | set(param_defaults.keys()) + for key, value in config.items(): + if key not in internal_keys: + # Convert key to valid identifier (replace - with _) for checking + param_key = key.replace("-", "_") + if param_key not in known_params: + raise ValueError( + f"Construct '{def_name}': Unknown parameter '{key}'. " + f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}" + ) + # Evaluate if it's an expression (list starting with Symbol) + if isinstance(value, list) and value and isinstance(value[0], Symbol): + env[param_key] = evaluate(value, env) + elif isinstance(value, Symbol): + # Resolve Symbol from env/bindings, then resolve any node IDs in the value + if value.name in env: + env[param_key] = resolve_value(env[value.name]) + else: + raise ValueError(f"Undefined symbol in construct arg: {value.name}") + else: + # Resolve node IDs inside dicts/lists + env[param_key] = resolve_value(value) + + # Evaluate construct body + print(f" Evaluating construct with params: {param_names}", file=sys.stderr) + segments = evaluate(body, env) + + if not isinstance(segments, list): + raise ValueError(f"Construct must return a list of segments, got {type(segments)}") + + print(f" Construct produced {len(segments)} segments", file=sys.stderr) + + # Convert segment descriptors to plan nodes + expanded_nodes = [] + sequence_inputs = [] + base_id = node["id"][:8] + + for i, seg in enumerate(segments): + if not isinstance(seg, dict): + continue + + source_ref = seg.get("source") + start = seg.get("start", 0) + print(f" DEBUG segment {i}: source={str(source_ref)[:20]}... start={start}", file=sys.stderr) + end = seg.get("end") + duration = seg.get("duration") or (end - start if end else 1.0) + effects = seg.get("effects", []) + + # Resolve source reference to node ID + source_id = sources.get(source_ref, source_ref) if isinstance(source_ref, str) else source_ref + + # Create segment node + segment_id = f"{base_id}_seg_{i:04d}" + segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": start, + "duration": duration, + }, + "inputs": [source_id] if source_id else [], + } + expanded_nodes.append(segment_node) + + # Add effects if specified + if effects: + prev_id = segment_id + for j, eff in enumerate(effects): + effect_name = eff.get("effect") if isinstance(eff, dict) else eff + effect_id = f"{base_id}_fx_{i:04d}_{j:02d}" + # Look up effect_path from registry (prevents collapsing Python effects) + effect_entry = registry.get("effects", {}).get(effect_name, {}) + effect_config = { + "effect": effect_name, + **{k: v for k, v in (eff.items() if isinstance(eff, dict) else []) if k != "effect"}, + } + if effect_entry.get("path"): + effect_config["effect_path"] = effect_entry["path"] + effect_node = { + "id": effect_id, + "type": "EFFECT", + "config": effect_config, + "inputs": [prev_id], + } + expanded_nodes.append(effect_node) + prev_id = effect_id + sequence_inputs.append(prev_id) + else: + sequence_inputs.append(segment_id) + + # Create LIST node + list_node = { + "id": node["id"], + "type": "LIST", + "config": {}, + "inputs": sequence_inputs, + } + expanded_nodes.append(list_node) + + return expanded_nodes + + +def _expand_nodes( + nodes: List[Dict], + registry: Dict, + recipe_dir: Path, + source_paths: Dict[str, Path], + work_dir: Path = None, + cluster_key: str = None, + on_analysis: Callable[[str, Dict], None] = None, + encoding: Dict = None, + pre_analysis: Dict[str, Dict] = None, +) -> List[Dict]: + """ + Expand dynamic nodes (SLICE_ON) by running analyzers. + + Processes nodes in dependency order: + 1. SOURCE nodes: resolve file paths + 2. SEGMENT nodes: pre-execute if needed for analysis + 3. ANALYZE nodes: run analyzers (or use pre_analysis), store results + 4. SLICE_ON nodes: expand using analysis results + + Args: + nodes: List of compiled nodes + registry: Recipe registry + recipe_dir: Directory for resolving relative paths + source_paths: Resolved source paths (id -> path) + work_dir: Working directory for temporary files (created if None) + cluster_key: Optional cluster key + on_analysis: Callback when analysis completes (node_id, results) + pre_analysis: Pre-computed analysis data (name -> results) + + Returns: + Tuple of (expanded_nodes, named_analysis) where: + - expanded_nodes: List with SLICE_ON replaced by primitives + - named_analysis: Dict of analyzer_name -> {times, values} + """ + import tempfile + + nodes_by_id = {n["id"]: n for n in nodes} + sorted_ids = _topological_sort(nodes) + + # Create work directory if needed + if work_dir is None: + work_dir = Path(tempfile.mkdtemp(prefix="artdag_plan_")) + + # Track outputs and analysis results + outputs = {} # node_id -> output path or analysis data + analysis_results = {} # node_id -> analysis dict + named_analysis = {} # analyzer_name -> analysis dict (for effect bindings) + pre_executed = set() # nodes pre-executed during planning + expanded = [] + expanded_ids = set() + + for node_id in sorted_ids: + node = nodes_by_id[node_id] + node_type = node["type"] + + if node_type == "SOURCE": + # Resolve source path + config = node.get("config", {}) + if "path" in config: + path = recipe_dir / config["path"] + outputs[node_id] = path.resolve() + source_paths[node_id] = outputs[node_id] + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "SEGMENT": + # Check if this segment's input is resolved + inputs = node.get("inputs", []) + if inputs and inputs[0] in outputs: + input_path = outputs[inputs[0]] + if isinstance(input_path, Path): + # Skip pre-execution if config contains unresolved bindings + seg_config = node.get("config", {}) + has_binding = any( + isinstance(v, Binding) or (isinstance(v, dict) and v.get("_binding")) + for v in [seg_config.get("start"), seg_config.get("duration"), seg_config.get("end")] + if v is not None + ) + if not has_binding: + # Pre-execute segment to get output path + # This is needed if ANALYZE depends on this segment + import sys + print(f" Pre-executing segment: {node_id[:16]}...", file=sys.stderr) + output_path = _pre_execute_segment(node, input_path, work_dir) + outputs[node_id] = output_path + pre_executed.add(node_id) + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "ANALYZE": + # Get or run analysis + config = node.get("config", {}) + analysis_name = node.get("name") or config.get("analyzer") + + # Check for pre-computed analysis first + if pre_analysis and analysis_name and analysis_name in pre_analysis: + import sys + print(f" Using pre-computed analysis: {analysis_name}", file=sys.stderr) + results = pre_analysis[analysis_name] + else: + # Run analyzer to get concrete data + analyzer_path = config.get("analyzer_path") + node_inputs = node.get("inputs", []) + + if not node_inputs: + raise ValueError(f"ANALYZE node {node_id} has no inputs") + + # Get input path - could be SOURCE or pre-executed SEGMENT + input_id = node_inputs[0] + input_path = outputs.get(input_id) + + if input_path is None: + raise ValueError( + f"ANALYZE input {input_id} not resolved. " + "Check that input SOURCE or SEGMENT exists." + ) + + if not isinstance(input_path, Path): + raise ValueError( + f"ANALYZE input {input_id} is not a file path: {type(input_path)}" + ) + + if analyzer_path: + full_path = recipe_dir / analyzer_path + params = {k: v for k, v in config.items() + if k not in ("analyzer", "analyzer_path", "cid")} + import sys + print(f" Running analyzer: {config.get('analyzer', 'unknown')}", file=sys.stderr) + results = _run_analyzer(full_path, input_path, params) + else: + raise ValueError(f"ANALYZE node {node_id} missing analyzer_path") + + analysis_results[node_id] = results + outputs[node_id] = results + + # Store by name for effect binding resolution + if analysis_name: + named_analysis[analysis_name] = results + + if on_analysis: + on_analysis(node_id, results) + + # Keep ANALYZE node in plan (it produces a JSON artifact) + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "SLICE_ON": + # Expand into primitives using analysis results + inputs = node.get("inputs", []) + config = node.get("config", {}) + + # Lambda mode can have just 1 input (analysis), legacy needs 2 (video + analysis) + has_lambda = "fn" in config + if has_lambda: + if len(inputs) < 1: + raise ValueError(f"SLICE_ON {node_id} requires analysis input") + analysis_id = inputs[0] # First input is analysis + else: + if len(inputs) < 2: + raise ValueError(f"SLICE_ON {node_id} requires video and analysis inputs") + analysis_id = inputs[1] + + if analysis_id not in analysis_results: + raise ValueError( + f"SLICE_ON {node_id} analysis input {analysis_id} not found" + ) + + # Build sources map: name -> node_id + # This lets the lambda reference videos by name + sources = {} + for n in nodes: + if n.get("name"): + sources[n["name"]] = n["id"] + + analysis_data = analysis_results[analysis_id] + slice_nodes = _expand_slice_on(node, analysis_data, registry, sources, cluster_key, encoding, named_analysis) + + for sn in slice_nodes: + if sn["id"] not in expanded_ids: + expanded.append(sn) + expanded_ids.add(sn["id"]) + + elif node_type == "CONSTRUCT": + # Expand user-defined construct + config = node.get("config", {}) + construct_name = config.get("construct_name") + construct_path = config.get("construct_path") + + if not construct_path: + raise ValueError(f"CONSTRUCT {node_id} missing path") + + # Build sources map + sources = {} + for n in nodes: + if n.get("name"): + sources[n["name"]] = n["id"] + + # Get analysis data if referenced + inputs = node.get("inputs", []) + analysis_data = {} + for inp in inputs: + if inp in analysis_results: + analysis_data[inp] = analysis_results[inp] + + construct_nodes = _expand_construct( + node, registry, sources, analysis_data, recipe_dir, cluster_key, encoding + ) + + for cn in construct_nodes: + if cn["id"] not in expanded_ids: + expanded.append(cn) + expanded_ids.add(cn["id"]) + + else: + # Keep other nodes as-is + expanded.append(node) + expanded_ids.add(node_id) + + return expanded, named_analysis + + +def create_plan( + recipe: CompiledRecipe, + inputs: Dict[str, str] = None, + recipe_dir: Path = None, + cluster_key: str = None, + on_analysis: Callable[[str, Dict], None] = None, + pre_analysis: Dict[str, Dict] = None, +) -> ExecutionPlanSexp: + """ + Create an execution plan from a compiled recipe. + + Args: + recipe: Compiled S-expression recipe + inputs: Mapping of input names to content hashes + recipe_dir: Directory for resolving relative paths (required for analyzers) + cluster_key: Optional cluster key for cache isolation + on_analysis: Callback when analysis completes (node_id, results) + pre_analysis: Pre-computed analysis data (name -> results), skips running analyzers + + Returns: + ExecutionPlanSexp with all cache IDs computed + + Example: + >>> recipe = compile_string('(recipe "test" (-> (source cat) (effect identity)))') + >>> plan = create_plan(recipe, inputs={}, recipe_dir=Path(".")) + >>> print(plan.to_string()) + """ + inputs = inputs or {} + + # Compute source hash as CID (SHA256 of raw bytes) - this IS the content address + source_hash = hashlib.sha256(recipe.source_text.encode('utf-8')).hexdigest() if recipe.source_text else "" + + # Compute params hash (use JSON + SHA256 for consistency with cache.py) + if recipe.resolved_params: + import json + params_str = json.dumps(recipe.resolved_params, sort_keys=True, default=str) + params_hash = hashlib.sha256(params_str.encode()).hexdigest() + else: + params_hash = "" + + # Check if recipe has expandable nodes (SLICE_ON, etc.) + has_expandable = any(n["type"] in EXPANDABLE_TYPES for n in recipe.nodes) + named_analysis = {} + + if has_expandable: + if recipe_dir is None: + raise ValueError("recipe_dir required for recipes with SLICE_ON nodes") + + # Expand dynamic nodes (runs analyzers, expands SLICE_ON) + source_paths = {} + expanded_nodes, named_analysis = _expand_nodes( + recipe.nodes, + recipe.registry, + recipe_dir, + source_paths, + cluster_key=cluster_key, + on_analysis=on_analysis, + encoding=recipe.encoding, + pre_analysis=pre_analysis, + ) + # Expand LIST inputs in SEQUENCE nodes + expanded_nodes = _expand_list_inputs(expanded_nodes) + # Collapse effect chains after expansion + collapsed_nodes = _collapse_effect_chains(expanded_nodes, recipe.registry) + else: + # No expansion needed + collapsed_nodes = _collapse_effect_chains(recipe.nodes, recipe.registry) + + # Build node lookup from collapsed nodes + nodes_by_id = {node["id"]: node for node in collapsed_nodes} + + # Topological sort + sorted_ids = _topological_sort(collapsed_nodes) + + # Create steps with resolved hashes + steps = [] + cache_ids = {} # step_id -> cache_id + + for node_id in sorted_ids: + node = nodes_by_id[node_id] + step = _create_step( + node, + recipe.registry, + inputs, + cache_ids, + cluster_key, + ) + steps.append(step) + cache_ids[node_id] = step.cache_id + + # Compute levels + _compute_levels(steps, nodes_by_id) + + # Handle stage-aware planning if recipe has stages + stage_plans = [] + stage_order = [] + stage_levels = {} + + if recipe.stages: + # Build mapping from node_id to stage + node_to_stage = {} + for stage in recipe.stages: + for node_id in stage.node_ids: + node_to_stage[node_id] = stage.name + + # Compute stage levels (for parallel execution) + stage_levels = _compute_stage_levels(recipe.stages) + + # Tag each step with stage info + for step in steps: + if step.step_id in node_to_stage: + step.stage = node_to_stage[step.step_id] + + # Build stage plans + for stage_name in recipe.stage_order: + stage = next(s for s in recipe.stages if s.name == stage_name) + stage_steps = [s for s in steps if s.stage == stage_name] + + # Build output bindings with cache IDs + output_cache_ids = {} + for out_name, node_id in stage.output_bindings.items(): + if node_id in cache_ids: + output_cache_ids[out_name] = cache_ids[node_id] + + stage_plans.append(StagePlan( + stage_name=stage_name, + steps=stage_steps, + requires=stage.requires, + output_bindings=output_cache_ids, + level=stage_levels.get(stage_name, 0), + )) + + stage_order = recipe.stage_order + + # Compute plan ID from source CID + steps + plan_content = { + "source_cid": source_hash, + "steps": [{"id": s.step_id, "cache_id": s.cache_id} for s in steps], + "inputs": inputs, + } + plan_id = _stable_hash(plan_content, cluster_key) + + return ExecutionPlanSexp( + plan_id=plan_id, + source_hash=source_hash, + params=recipe.resolved_params, + params_hash=params_hash, + steps=steps, + output_step_id=recipe.output_node_id, + inputs=inputs, + analysis=named_analysis, + stage_plans=stage_plans, + stage_order=stage_order, + stage_levels=stage_levels, + effects_registry=recipe.registry.get("effects", {}), + minimal_primitives=recipe.minimal_primitives, + ) + + +def _topological_sort(nodes: List[Dict]) -> List[str]: + """Sort nodes in dependency order.""" + nodes_by_id = {n["id"]: n for n in nodes} + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + visited.add(node_id) + node = nodes_by_id.get(node_id) + if node: + for input_id in node.get("inputs", []): + visit(input_id) + order.append(node_id) + + for node in nodes: + visit(node["id"]) + + return order + + +def _create_step( + node: Dict, + registry: Dict, + inputs: Dict[str, str], + cache_ids: Dict[str, str], + cluster_key: str = None, +) -> PlanStep: + """Create a PlanStep from a node definition.""" + node_id = node["id"] + node_type = node["type"] + config = dict(node.get("config", {})) + node_inputs = node.get("inputs", []) + + # Resolve registry references + resolved_config = _resolve_config(config, registry, inputs) + + # Get input cache IDs (direct graph inputs) + input_cache_ids = [cache_ids[inp] for inp in node_inputs if inp in cache_ids] + + # Also include analysis_refs as dependencies (for binding resolution) + # These are implicit inputs that affect the computation result + analysis_refs = resolved_config.get("analysis_refs", []) + analysis_cache_ids = [cache_ids[ref] for ref in analysis_refs if ref in cache_ids] + + # Compute cache ID including both inputs and analysis dependencies + cache_content = { + "node_type": node_type, + "config": resolved_config, + "inputs": sorted(input_cache_ids + analysis_cache_ids), + } + cache_id = _stable_hash(cache_content, cluster_key) + + return PlanStep( + step_id=node_id, + node_type=node_type, + config=resolved_config, + inputs=node_inputs, + cache_id=cache_id, + ) + + +def _resolve_config( + config: Dict, + registry: Dict, + inputs: Dict[str, str], +) -> Dict: + """Resolve registry references in config to content hashes.""" + resolved = {} + + for key, value in config.items(): + if key == "filter_chain" and isinstance(value, list): + # Resolve each filter in the chain (for COMPOUND nodes) + resolved_chain = [] + for filter_item in value: + filter_config = filter_item.get("config", {}) + resolved_filter_config = _resolve_config(filter_config, registry, inputs) + resolved_chain.append({ + "type": filter_item["type"], + "config": resolved_filter_config, + }) + resolved["filter_chain"] = resolved_chain + + elif key == "asset" and isinstance(value, str): + # Resolve asset reference - use CID from registry + if value in registry.get("assets", {}): + resolved["cid"] = registry["assets"][value]["cid"] + else: + resolved["asset"] = value # Keep as-is if not in registry + + elif key == "effect" and isinstance(value, str): + # Resolve effect reference - keep name AND add CID/path + resolved["effect"] = value + if value in registry.get("effects", {}): + effect_entry = registry["effects"][value] + if effect_entry.get("cid"): + resolved["cid"] = effect_entry["cid"] + if effect_entry.get("path"): + resolved["effect_path"] = effect_entry["path"] + + elif key == "input" and value is True: + # Variable input - resolve from inputs dict + input_name = config.get("name", "input") + if input_name in inputs: + resolved["hash"] = inputs[input_name] + else: + resolved["input"] = True + resolved["name"] = input_name + + elif key == "path": + # Local file path - keep as-is for local execution + resolved["path"] = value + + else: + resolved[key] = value + + return resolved + + +def _compute_levels(steps: List[PlanStep], nodes_by_id: Dict) -> None: + """Compute dependency levels for steps. + + Considers both inputs (data dependencies) and analysis_refs (binding dependencies). + """ + levels = {} + + def compute_level(step_id: str) -> int: + if step_id in levels: + return levels[step_id] + + node = nodes_by_id.get(step_id) + if not node: + levels[step_id] = 0 + return 0 + + # Collect all dependencies: inputs + analysis_refs + deps = list(node.get("inputs", [])) + + # Add analysis_refs as dependencies (for bindings to analysis data) + config = node.get("config", {}) + analysis_refs = config.get("analysis_refs", []) + deps.extend(analysis_refs) + + if not deps: + levels[step_id] = 0 + return 0 + + max_dep = max(compute_level(dep) for dep in deps) + levels[step_id] = max_dep + 1 + return levels[step_id] + + for step in steps: + step.level = compute_level(step.step_id) + + +def _compute_stage_levels(stages: List) -> Dict[str, int]: + """ + Compute stage levels for parallel execution. + + Stages at the same level have no dependencies between them + and can run in parallel. + """ + from .compiler import CompiledStage + + levels = {} + + def compute_level(stage_name: str) -> int: + if stage_name in levels: + return levels[stage_name] + + stage = next((s for s in stages if s.name == stage_name), None) + if not stage or not stage.requires: + levels[stage_name] = 0 + return 0 + + max_req = max(compute_level(req) for req in stage.requires) + levels[stage_name] = max_req + 1 + return levels[stage_name] + + for stage in stages: + compute_level(stage.name) + + return levels + + +def step_to_task_sexp(step: PlanStep) -> List: + """ + Convert a step to a minimal S-expression for Celery task. + + This is the S-expression that gets sent to a worker. + The worker hashes this to verify cache_id. + """ + sexp = [Symbol(step.node_type.lower())] + + # Add resolved config + for key, value in step.config.items(): + sexp.extend([Keyword(key), value]) + + # Add input cache IDs (not step IDs) + if step.inputs: + sexp.extend([Keyword("inputs"), step.inputs]) + + return sexp + + +def task_cache_id(task_sexp: List, cluster_key: str = None) -> str: + """ + Compute cache ID from task S-expression. + + This allows workers to verify they're executing the right task. + """ + # Serialize S-expression to canonical form + canonical = serialize(task_sexp) + return _stable_hash({"sexp": canonical}, cluster_key) diff --git a/artdag/sexp/primitives.py b/artdag/sexp/primitives.py new file mode 100644 index 0000000..65bbcc0 --- /dev/null +++ b/artdag/sexp/primitives.py @@ -0,0 +1,620 @@ +""" +Frame processing primitives for sexp effects. + +These primitives are called by sexp effect definitions and operate on +numpy arrays (frames). They're used when falling back to Python rendering +instead of FFmpeg. + +Required: numpy, PIL +""" + +import math +from typing import Any, Dict, List, Optional, Tuple + +try: + import numpy as np + HAS_NUMPY = True +except ImportError: + HAS_NUMPY = False + np = None + +try: + from PIL import Image, ImageDraw, ImageFont + HAS_PIL = True +except ImportError: + HAS_PIL = False + + +# ASCII character sets for different styles +ASCII_ALPHABETS = { + "standard": " .:-=+*#%@", + "blocks": " ░▒▓█", + "simple": " .-:+=xX#", + "detailed": " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$", + "binary": " █", +} + + +def check_deps(): + """Check that required dependencies are available.""" + if not HAS_NUMPY: + raise ImportError("numpy required for frame primitives: pip install numpy") + if not HAS_PIL: + raise ImportError("PIL required for frame primitives: pip install Pillow") + + +def frame_to_image(frame: np.ndarray) -> Image.Image: + """Convert numpy frame to PIL Image.""" + check_deps() + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + return Image.fromarray(frame) + + +def image_to_frame(img: Image.Image) -> np.ndarray: + """Convert PIL Image to numpy frame.""" + check_deps() + return np.array(img) + + +# ============================================================================ +# ASCII Art Primitives +# ============================================================================ + +def cell_sample(frame: np.ndarray, cell_size: int = 8) -> Tuple[np.ndarray, np.ndarray]: + """ + Sample frame into cells, returning average colors and luminances. + + Args: + frame: Input frame (H, W, C) + cell_size: Size of each cell + + Returns: + (colors, luminances) - colors is (rows, cols, 3), luminances is (rows, cols) + """ + check_deps() + h, w = frame.shape[:2] + rows = h // cell_size + cols = w // cell_size + + colors = np.zeros((rows, cols, 3), dtype=np.float32) + luminances = np.zeros((rows, cols), dtype=np.float32) + + for r in range(rows): + for c in range(cols): + y0, y1 = r * cell_size, (r + 1) * cell_size + x0, x1 = c * cell_size, (c + 1) * cell_size + cell = frame[y0:y1, x0:x1] + + # Average color + avg_color = cell.mean(axis=(0, 1)) + colors[r, c] = avg_color[:3] # RGB only + + # Luminance (ITU-R BT.601) + lum = 0.299 * avg_color[0] + 0.587 * avg_color[1] + 0.114 * avg_color[2] + luminances[r, c] = lum + + return colors, luminances + + +def luminance_to_chars( + luminances: np.ndarray, + alphabet: str = "standard", + contrast: float = 1.0, +) -> List[List[str]]: + """ + Convert luminance values to ASCII characters. + + Args: + luminances: 2D array of luminance values (0-255) + alphabet: Name of character set or custom string + contrast: Contrast multiplier + + Returns: + 2D list of characters + """ + check_deps() + chars = ASCII_ALPHABETS.get(alphabet, alphabet) + n_chars = len(chars) + + rows, cols = luminances.shape + result = [] + + for r in range(rows): + row_chars = [] + for c in range(cols): + lum = luminances[r, c] + # Apply contrast around midpoint + lum = 128 + (lum - 128) * contrast + lum = np.clip(lum, 0, 255) + # Map to character index + idx = int(lum / 256 * n_chars) + idx = min(idx, n_chars - 1) + row_chars.append(chars[idx]) + result.append(row_chars) + + return result + + +def render_char_grid( + frame: np.ndarray, + chars: List[List[str]], + colors: np.ndarray, + char_size: int = 8, + color_mode: str = "color", + background: Tuple[int, int, int] = (0, 0, 0), +) -> np.ndarray: + """ + Render character grid to an image. + + Args: + frame: Original frame (for dimensions) + chars: 2D list of characters + colors: Color for each cell (rows, cols, 3) + char_size: Size of each character cell + color_mode: "color", "white", or "green" + background: Background RGB color + + Returns: + Rendered frame + """ + check_deps() + h, w = frame.shape[:2] + rows = len(chars) + cols = len(chars[0]) if chars else 0 + + # Create output image + img = Image.new("RGB", (w, h), background) + draw = ImageDraw.Draw(img) + + # Try to get a monospace font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size) + except (IOError, OSError): + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", char_size) + except (IOError, OSError): + font = ImageFont.load_default() + + for r in range(rows): + for c in range(cols): + char = chars[r][c] + if char == ' ': + continue + + x = c * char_size + y = r * char_size + + if color_mode == "color": + color = tuple(int(v) for v in colors[r, c]) + elif color_mode == "green": + color = (0, 255, 0) + else: # white + color = (255, 255, 255) + + draw.text((x, y), char, fill=color, font=font) + + return np.array(img) + + +def ascii_art_frame( + frame: np.ndarray, + char_size: int = 8, + alphabet: str = "standard", + color_mode: str = "color", + contrast: float = 1.5, + background: Tuple[int, int, int] = (0, 0, 0), +) -> np.ndarray: + """ + Apply ASCII art effect to a frame. + + This is the main entry point for the ascii_art effect. + """ + check_deps() + colors, luminances = cell_sample(frame, char_size) + chars = luminance_to_chars(luminances, alphabet, contrast) + return render_char_grid(frame, chars, colors, char_size, color_mode, background) + + +# ============================================================================ +# ASCII Zones Primitives +# ============================================================================ + +def ascii_zones_frame( + frame: np.ndarray, + char_size: int = 8, + zone_threshold: int = 128, + dark_chars: str = " .-:", + light_chars: str = "=+*#", +) -> np.ndarray: + """ + Apply zone-based ASCII art effect. + + Different character sets for dark vs light regions. + """ + check_deps() + colors, luminances = cell_sample(frame, char_size) + + rows, cols = luminances.shape + chars = [] + + for r in range(rows): + row_chars = [] + for c in range(cols): + lum = luminances[r, c] + if lum < zone_threshold: + # Dark zone + charset = dark_chars + local_lum = lum / zone_threshold # 0-1 within zone + else: + # Light zone + charset = light_chars + local_lum = (lum - zone_threshold) / (255 - zone_threshold) + + idx = int(local_lum * len(charset)) + idx = min(idx, len(charset) - 1) + row_chars.append(charset[idx]) + chars.append(row_chars) + + return render_char_grid(frame, chars, colors, char_size, "color", (0, 0, 0)) + + +# ============================================================================ +# Kaleidoscope Primitives (Python fallback) +# ============================================================================ + +def kaleidoscope_displace( + w: int, + h: int, + segments: int = 6, + rotation: float = 0, + cx: float = None, + cy: float = None, + zoom: float = 1.0, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute kaleidoscope displacement coordinates. + + Returns (x_coords, y_coords) arrays for remapping. + """ + check_deps() + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Create coordinate grids + y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32) + + # Center coordinates + x_centered = x_grid - cx + y_centered = y_grid - cy + + # Convert to polar + r = np.sqrt(x_centered**2 + y_centered**2) / zoom + theta = np.arctan2(y_centered, x_centered) + + # Apply rotation + theta = theta - np.radians(rotation) + + # Kaleidoscope: fold angle into segment + segment_angle = 2 * np.pi / segments + theta = np.abs(np.mod(theta, segment_angle) - segment_angle / 2) + + # Convert back to cartesian + x_new = r * np.cos(theta) + cx + y_new = r * np.sin(theta) + cy + + return x_new, y_new + + +def remap( + frame: np.ndarray, + x_coords: np.ndarray, + y_coords: np.ndarray, +) -> np.ndarray: + """ + Remap frame using coordinate arrays. + + Uses bilinear interpolation. + """ + check_deps() + from scipy import ndimage + + h, w = frame.shape[:2] + + # Clip coordinates + x_coords = np.clip(x_coords, 0, w - 1) + y_coords = np.clip(y_coords, 0, h - 1) + + # Remap each channel + if len(frame.shape) == 3: + result = np.zeros_like(frame) + for c in range(frame.shape[2]): + result[:, :, c] = ndimage.map_coordinates( + frame[:, :, c], + [y_coords, x_coords], + order=1, + mode='reflect', + ) + return result + else: + return ndimage.map_coordinates(frame, [y_coords, x_coords], order=1, mode='reflect') + + +def kaleidoscope_frame( + frame: np.ndarray, + segments: int = 6, + rotation: float = 0, + center_x: float = 0.5, + center_y: float = 0.5, + zoom: float = 1.0, +) -> np.ndarray: + """ + Apply kaleidoscope effect to a frame. + + This is a Python fallback - FFmpeg version is faster. + """ + check_deps() + h, w = frame.shape[:2] + cx = w * center_x + cy = h * center_y + + x_coords, y_coords = kaleidoscope_displace(w, h, segments, rotation, cx, cy, zoom) + return remap(frame, x_coords, y_coords) + + +# ============================================================================ +# Datamosh Primitives (simplified Python version) +# ============================================================================ + +def datamosh_frame( + frame: np.ndarray, + prev_frame: Optional[np.ndarray], + block_size: int = 32, + corruption: float = 0.3, + max_offset: int = 50, + color_corrupt: bool = True, +) -> np.ndarray: + """ + Simplified datamosh effect using block displacement. + + This is a basic approximation - real datamosh works on compressed video. + """ + check_deps() + if prev_frame is None: + return frame.copy() + + h, w = frame.shape[:2] + result = frame.copy() + + # Process in blocks + for y in range(0, h - block_size, block_size): + for x in range(0, w - block_size, block_size): + if np.random.random() < corruption: + # Random offset + ox = np.random.randint(-max_offset, max_offset + 1) + oy = np.random.randint(-max_offset, max_offset + 1) + + # Source from previous frame with offset + src_y = np.clip(y + oy, 0, h - block_size) + src_x = np.clip(x + ox, 0, w - block_size) + + block = prev_frame[src_y:src_y+block_size, src_x:src_x+block_size] + + # Color corruption + if color_corrupt and np.random.random() < 0.3: + # Swap or shift channels + block = np.roll(block, np.random.randint(1, 3), axis=2) + + result[y:y+block_size, x:x+block_size] = block + + return result + + +# ============================================================================ +# Pixelsort Primitives (Python version) +# ============================================================================ + +def pixelsort_frame( + frame: np.ndarray, + sort_by: str = "lightness", + threshold_low: float = 50, + threshold_high: float = 200, + angle: float = 0, + reverse: bool = False, +) -> np.ndarray: + """ + Apply pixel sorting effect to a frame. + """ + check_deps() + from scipy import ndimage + + # Rotate if needed + if angle != 0: + frame = ndimage.rotate(frame, -angle, reshape=False, mode='reflect') + + h, w = frame.shape[:2] + result = frame.copy() + + # Compute sort key + if sort_by == "lightness": + key = 0.299 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2] + elif sort_by == "hue": + # Simple hue approximation + key = np.arctan2( + np.sqrt(3) * (frame[:,:,1].astype(float) - frame[:,:,2]), + 2 * frame[:,:,0].astype(float) - frame[:,:,1] - frame[:,:,2] + ) + elif sort_by == "saturation": + mx = frame.max(axis=2).astype(float) + mn = frame.min(axis=2).astype(float) + key = np.where(mx > 0, (mx - mn) / mx, 0) + else: + key = frame[:,:,0] # Red channel + + # Sort each row + for y in range(h): + row = result[y] + row_key = key[y] + + # Find sortable intervals (pixels within threshold) + mask = (row_key >= threshold_low) & (row_key <= threshold_high) + + # Find runs of True in mask + runs = [] + start = None + for x in range(w): + if mask[x] and start is None: + start = x + elif not mask[x] and start is not None: + if x - start > 1: + runs.append((start, x)) + start = None + if start is not None and w - start > 1: + runs.append((start, w)) + + # Sort each run + for start, end in runs: + indices = np.argsort(row_key[start:end]) + if reverse: + indices = indices[::-1] + result[y, start:end] = row[start:end][indices] + + # Rotate back + if angle != 0: + result = ndimage.rotate(result, angle, reshape=False, mode='reflect') + + return result + + +# ============================================================================ +# Primitive Registry +# ============================================================================ + +def map_char_grid( + chars, + luminances, + fn, +): + """ + Map a function over each cell of a character grid. + + Args: + chars: 2D array/list of characters (rows, cols) + luminances: 2D array of luminance values + fn: Function or Lambda (row, col, char, luminance) -> new_char + + Returns: + New character grid with mapped values (list of lists) + """ + from .parser import Lambda + from .evaluator import evaluate + + # Handle both list and numpy array inputs + if isinstance(chars, np.ndarray): + rows, cols = chars.shape[:2] + else: + rows = len(chars) + cols = len(chars[0]) if rows > 0 and isinstance(chars[0], (list, tuple, str)) else 1 + + # Get luminances as 2D + if isinstance(luminances, np.ndarray): + lum_arr = luminances + else: + lum_arr = np.array(luminances) + + # Check if fn is a Lambda (from sexp) or a Python callable + is_lambda = isinstance(fn, Lambda) + + result = [] + for r in range(rows): + row_result = [] + for c in range(cols): + # Get character + if isinstance(chars, np.ndarray): + ch = chars[r, c] if len(chars.shape) > 1 else chars[r] + elif isinstance(chars[r], str): + ch = chars[r][c] if c < len(chars[r]) else ' ' + else: + ch = chars[r][c] if c < len(chars[r]) else ' ' + + # Get luminance + if len(lum_arr.shape) > 1: + lum = lum_arr[r, c] + else: + lum = lum_arr[r] + + # Call the function + if is_lambda: + # Evaluate the Lambda with arguments bound + call_env = dict(fn.closure) if fn.closure else {} + for param, val in zip(fn.params, [r, c, ch, float(lum)]): + call_env[param] = val + new_ch = evaluate(fn.body, call_env) + else: + new_ch = fn(r, c, ch, float(lum)) + + row_result.append(new_ch) + result.append(row_result) + + return result + + +def alphabet_char(alphabet: str, index: int) -> str: + """ + Get a character from an alphabet at a given index. + + Args: + alphabet: Alphabet name (from ASCII_ALPHABETS) or literal string + index: Index into the alphabet (clamped to valid range) + + Returns: + Character at the index + """ + # Get alphabet string + if alphabet in ASCII_ALPHABETS: + chars = ASCII_ALPHABETS[alphabet] + else: + chars = alphabet + + # Clamp index to valid range + index = int(index) + index = max(0, min(index, len(chars) - 1)) + + return chars[index] + + +PRIMITIVES = { + # ASCII + "cell-sample": cell_sample, + "luminance-to-chars": luminance_to_chars, + "render-char-grid": render_char_grid, + "map-char-grid": map_char_grid, + "alphabet-char": alphabet_char, + "ascii_art_frame": ascii_art_frame, + "ascii_zones_frame": ascii_zones_frame, + + # Kaleidoscope + "kaleidoscope-displace": kaleidoscope_displace, + "remap": remap, + "kaleidoscope_frame": kaleidoscope_frame, + + # Datamosh + "datamosh": datamosh_frame, + "datamosh_frame": datamosh_frame, + + # Pixelsort + "pixelsort": pixelsort_frame, + "pixelsort_frame": pixelsort_frame, +} + + +def get_primitive(name: str): + """Get a primitive function by name.""" + return PRIMITIVES.get(name) + + +def list_primitives() -> List[str]: + """List all available primitives.""" + return list(PRIMITIVES.keys()) diff --git a/artdag/sexp/scheduler.py b/artdag/sexp/scheduler.py new file mode 100644 index 0000000..65daf28 --- /dev/null +++ b/artdag/sexp/scheduler.py @@ -0,0 +1,779 @@ +""" +Celery scheduler for S-expression execution plans. + +Distributes plan steps to workers as S-expressions. +The S-expression is the canonical format - workers receive +serialized S-expressions and can verify cache_ids by hashing them. + +Usage: + from artdag.sexp import compile_string, create_plan + from artdag.sexp.scheduler import schedule_plan + + recipe = compile_string(sexp_content) + plan = create_plan(recipe, inputs={'video': 'abc123...'}) + result = schedule_plan(plan) +""" + +import hashlib +import json +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Callable + +from .parser import Symbol, Keyword, serialize, parse +from .planner import ExecutionPlanSexp, PlanStep + +logger = logging.getLogger(__name__) + + +@dataclass +class StepResult: + """Result from executing a step.""" + step_id: str + cache_id: str + status: str # 'completed', 'cached', 'failed', 'pending' + output_path: Optional[str] = None + error: Optional[str] = None + ipfs_cid: Optional[str] = None + + +@dataclass +class PlanResult: + """Result from executing a complete plan.""" + plan_id: str + status: str # 'completed', 'failed', 'partial' + steps_completed: int = 0 + steps_cached: int = 0 + steps_failed: int = 0 + output_cache_id: Optional[str] = None + output_path: Optional[str] = None + output_ipfs_cid: Optional[str] = None + step_results: Dict[str, StepResult] = field(default_factory=dict) + error: Optional[str] = None + + +def step_to_sexp(step: PlanStep) -> List: + """ + Convert a PlanStep to minimal S-expression for worker. + + This is the canonical form that workers receive. + Workers can verify cache_id by hashing this S-expression. + """ + sexp = [Symbol(step.node_type.lower())] + + # Add config as keywords + for key, value in step.config.items(): + sexp.extend([Keyword(key.replace('_', '-')), value]) + + # Add inputs as cache IDs (not step IDs) + if step.inputs: + sexp.extend([Keyword("inputs"), step.inputs]) + + return sexp + + +def step_sexp_to_string(step: PlanStep) -> str: + """Serialize step to S-expression string for Celery task.""" + return serialize(step_to_sexp(step)) + + +def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: str = None) -> bool: + """ + Verify that a step's cache_id matches its S-expression. + + Workers should call this to verify they're executing the correct task. + """ + data = {"sexp": step_sexp} + if cluster_key: + data = {"_cluster_key": cluster_key, "_data": data} + + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + computed = hashlib.sha3_256(json_str.encode()).hexdigest() + return computed == expected_cache_id + + +class PlanScheduler: + """ + Schedules execution of S-expression plans on Celery workers. + + The scheduler: + 1. Groups steps by dependency level + 2. Checks cache for already-computed results + 3. Dispatches uncached steps to workers + 4. Waits for completion before proceeding to next level + """ + + def __init__( + self, + cache_manager=None, + celery_app=None, + execute_task_name: str = 'tasks.execute_step_sexp', + ): + """ + Initialize the scheduler. + + Args: + cache_manager: L1 cache manager for checking cached results + celery_app: Celery application instance + execute_task_name: Name of the Celery task for step execution + """ + self.cache_manager = cache_manager + self.celery_app = celery_app + self.execute_task_name = execute_task_name + + def schedule( + self, + plan: ExecutionPlanSexp, + timeout: int = 3600, + ) -> PlanResult: + """ + Schedule and execute a plan. + + Args: + plan: The execution plan (S-expression format) + timeout: Timeout in seconds for the entire plan + + Returns: + PlanResult with execution results + """ + from celery import group + + logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)") + + # Build step lookup and group by level + steps_by_id = {s.step_id: s for s in plan.steps} + steps_by_level = self._group_by_level(plan.steps) + max_level = max(steps_by_level.keys()) if steps_by_level else 0 + + # Track results + result = PlanResult( + plan_id=plan.plan_id, + status="pending", + ) + + # Map step_id -> cache_id for resolving inputs + cache_ids = dict(plan.inputs) # Start with input hashes + for step in plan.steps: + cache_ids[step.step_id] = step.cache_id + + # Execute level by level + for level in range(max_level + 1): + level_steps = steps_by_level.get(level, []) + if not level_steps: + continue + + logger.info(f"Level {level}: {len(level_steps)} steps") + + # Check cache for each step + steps_to_run = [] + for step in level_steps: + if self._is_cached(step.cache_id): + result.steps_cached += 1 + result.step_results[step.step_id] = StepResult( + step_id=step.step_id, + cache_id=step.cache_id, + status="cached", + output_path=self._get_cached_path(step.cache_id), + ) + else: + steps_to_run.append(step) + + if not steps_to_run: + logger.info(f"Level {level}: all {len(level_steps)} steps cached") + continue + + # Dispatch uncached steps to workers + logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps") + + tasks = [] + for step in steps_to_run: + # Build task arguments + step_sexp = step_sexp_to_string(step) + input_cache_ids = { + inp: cache_ids.get(inp, inp) + for inp in step.inputs + } + + task = self._get_execute_task().s( + step_sexp=step_sexp, + step_id=step.step_id, + cache_id=step.cache_id, + plan_id=plan.plan_id, + input_cache_ids=input_cache_ids, + ) + tasks.append(task) + + # Execute in parallel + job = group(tasks) + async_result = job.apply_async() + + try: + step_results = async_result.get(timeout=timeout) + except Exception as e: + logger.error(f"Level {level} failed: {e}") + result.status = "failed" + result.error = f"Level {level} failed: {e}" + return result + + # Process results + for step_result in step_results: + step_id = step_result.get("step_id") + status = step_result.get("status") + + result.step_results[step_id] = StepResult( + step_id=step_id, + cache_id=step_result.get("cache_id"), + status=status, + output_path=step_result.get("output_path"), + error=step_result.get("error"), + ipfs_cid=step_result.get("ipfs_cid"), + ) + + if status in ("completed", "cached", "completed_by_other"): + result.steps_completed += 1 + elif status == "failed": + result.steps_failed += 1 + result.status = "failed" + result.error = step_result.get("error") + return result + + # Get final output + output_step = steps_by_id.get(plan.output_step_id) + if output_step: + output_result = result.step_results.get(output_step.step_id) + if output_result: + result.output_cache_id = output_step.cache_id + result.output_path = output_result.output_path + result.output_ipfs_cid = output_result.ipfs_cid + + result.status = "completed" + logger.info( + f"Plan {plan.plan_id[:16]}... completed: " + f"{result.steps_completed} executed, {result.steps_cached} cached" + ) + return result + + def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]: + """Group steps by dependency level.""" + by_level = {} + for step in steps: + by_level.setdefault(step.level, []).append(step) + return by_level + + def _is_cached(self, cache_id: str) -> bool: + """Check if a cache_id exists in cache.""" + if self.cache_manager is None: + return False + path = self.cache_manager.get_by_cid(cache_id) + return path is not None + + def _get_cached_path(self, cache_id: str) -> Optional[str]: + """Get the path for a cached item.""" + if self.cache_manager is None: + return None + path = self.cache_manager.get_by_cid(cache_id) + return str(path) if path else None + + def _get_execute_task(self): + """Get the Celery task for step execution.""" + if self.celery_app is None: + raise RuntimeError("Celery app not configured") + return self.celery_app.tasks[self.execute_task_name] + + +def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler: + """ + Create a scheduler with the given dependencies. + + If not provided, attempts to import from art-celery. + """ + if celery_app is None: + try: + from celery_app import app as celery_app + except ImportError: + pass + + if cache_manager is None: + try: + from cache_manager import get_cache_manager + cache_manager = get_cache_manager() + except ImportError: + pass + + return PlanScheduler( + cache_manager=cache_manager, + celery_app=celery_app, + ) + + +def schedule_plan( + plan: ExecutionPlanSexp, + cache_manager=None, + celery_app=None, + timeout: int = 3600, +) -> PlanResult: + """ + Convenience function to schedule a plan. + + Args: + plan: The execution plan + cache_manager: Optional cache manager + celery_app: Optional Celery app + timeout: Execution timeout + + Returns: + PlanResult + """ + scheduler = create_scheduler(cache_manager, celery_app) + return scheduler.schedule(plan, timeout=timeout) + + +# Stage-aware scheduling + +@dataclass +class StageResult: + """Result from executing a stage.""" + stage_name: str + cache_id: str + status: str # 'completed', 'cached', 'failed', 'pending' + step_results: Dict[str, StepResult] = field(default_factory=dict) + outputs: Dict[str, str] = field(default_factory=dict) # binding_name -> cache_id + error: Optional[str] = None + + +@dataclass +class StagePlanResult: + """Result from executing a plan with stages.""" + plan_id: str + status: str # 'completed', 'failed', 'partial' + stages_completed: int = 0 + stages_cached: int = 0 + stages_failed: int = 0 + steps_completed: int = 0 + steps_cached: int = 0 + steps_failed: int = 0 + stage_results: Dict[str, StageResult] = field(default_factory=dict) + output_cache_id: Optional[str] = None + output_path: Optional[str] = None + error: Optional[str] = None + + +class StagePlanScheduler: + """ + Stage-aware scheduler for S-expression plans. + + The scheduler: + 1. Groups stages by level (parallel groups) + 2. For each stage level: + - Check stage cache, skip entire stage if hit + - Execute stage steps (grouped by level within stage) + - Cache stage outputs + 3. Stages at same level can run in parallel + """ + + def __init__( + self, + cache_manager=None, + stage_cache=None, + celery_app=None, + execute_task_name: str = 'tasks.execute_step_sexp', + ): + """ + Initialize the stage-aware scheduler. + + Args: + cache_manager: L1 cache manager for step-level caching + stage_cache: StageCache instance for stage-level caching + celery_app: Celery application instance + execute_task_name: Name of the Celery task for step execution + """ + self.cache_manager = cache_manager + self.stage_cache = stage_cache + self.celery_app = celery_app + self.execute_task_name = execute_task_name + + def schedule( + self, + plan: ExecutionPlanSexp, + timeout: int = 3600, + ) -> StagePlanResult: + """ + Schedule and execute a plan with stage awareness. + + If the plan has stages, uses stage-level scheduling. + Otherwise, falls back to step-level scheduling. + + Args: + plan: The execution plan (S-expression format) + timeout: Timeout in seconds for the entire plan + + Returns: + StagePlanResult with execution results + """ + # If no stages, use regular scheduling + if not plan.stage_plans: + logger.info("Plan has no stages, using step-level scheduling") + regular_scheduler = PlanScheduler( + cache_manager=self.cache_manager, + celery_app=self.celery_app, + execute_task_name=self.execute_task_name, + ) + step_result = regular_scheduler.schedule(plan, timeout) + return StagePlanResult( + plan_id=step_result.plan_id, + status=step_result.status, + steps_completed=step_result.steps_completed, + steps_cached=step_result.steps_cached, + steps_failed=step_result.steps_failed, + output_cache_id=step_result.output_cache_id, + output_path=step_result.output_path, + error=step_result.error, + ) + + logger.info( + f"Scheduling plan {plan.plan_id[:16]}... " + f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)" + ) + + result = StagePlanResult( + plan_id=plan.plan_id, + status="pending", + ) + + # Group stages by level + stages_by_level = self._group_stages_by_level(plan.stage_plans) + max_level = max(stages_by_level.keys()) if stages_by_level else 0 + + # Track stage outputs for data flow + stage_outputs = {} # stage_name -> {binding_name -> cache_id} + + # Execute stage by stage level + for level in range(max_level + 1): + level_stages = stages_by_level.get(level, []) + if not level_stages: + continue + + logger.info(f"Stage level {level}: {len(level_stages)} stages") + + # Check stage cache for each stage + stages_to_run = [] + for stage_plan in level_stages: + if self._is_stage_cached(stage_plan.cache_id): + result.stages_cached += 1 + cached_entry = self._load_cached_stage(stage_plan.cache_id) + if cached_entry: + stage_outputs[stage_plan.stage_name] = { + name: out.cache_id + for name, out in cached_entry.outputs.items() + } + result.stage_results[stage_plan.stage_name] = StageResult( + stage_name=stage_plan.stage_name, + cache_id=stage_plan.cache_id, + status="cached", + outputs=stage_outputs[stage_plan.stage_name], + ) + logger.info(f"Stage {stage_plan.stage_name}: cached") + else: + stages_to_run.append(stage_plan) + + if not stages_to_run: + logger.info(f"Stage level {level}: all {len(level_stages)} stages cached") + continue + + # Execute uncached stages + # For now, execute sequentially; L1 Celery will add parallel execution + for stage_plan in stages_to_run: + logger.info(f"Executing stage: {stage_plan.stage_name}") + + stage_result = self._execute_stage( + stage_plan, + plan, + stage_outputs, + timeout, + ) + + result.stage_results[stage_plan.stage_name] = stage_result + + if stage_result.status == "completed": + result.stages_completed += 1 + stage_outputs[stage_plan.stage_name] = stage_result.outputs + + # Cache the stage result + self._cache_stage(stage_plan, stage_result) + elif stage_result.status == "failed": + result.stages_failed += 1 + result.status = "failed" + result.error = stage_result.error + return result + + # Accumulate step counts + for sr in stage_result.step_results.values(): + if sr.status == "completed": + result.steps_completed += 1 + elif sr.status == "cached": + result.steps_cached += 1 + elif sr.status == "failed": + result.steps_failed += 1 + + # Get final output + if plan.stage_plans: + last_stage = plan.stage_plans[-1] + if last_stage.stage_name in result.stage_results: + stage_res = result.stage_results[last_stage.stage_name] + result.output_cache_id = last_stage.cache_id + # Find the output step's path from step results + for step_res in stage_res.step_results.values(): + if step_res.output_path: + result.output_path = step_res.output_path + + result.status = "completed" + logger.info( + f"Plan {plan.plan_id[:16]}... completed: " + f"{result.stages_completed} stages executed, " + f"{result.stages_cached} stages cached" + ) + return result + + def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]: + """Group stage plans by their level.""" + by_level = {} + for stage_plan in stage_plans: + by_level.setdefault(stage_plan.level, []).append(stage_plan) + return by_level + + def _is_stage_cached(self, cache_id: str) -> bool: + """Check if a stage is cached.""" + if self.stage_cache is None: + return False + return self.stage_cache.has_stage(cache_id) + + def _load_cached_stage(self, cache_id: str): + """Load a cached stage entry.""" + if self.stage_cache is None: + return None + return self.stage_cache.load_stage(cache_id) + + def _cache_stage(self, stage_plan, stage_result: StageResult) -> None: + """Cache a stage result.""" + if self.stage_cache is None: + return + + from .stage_cache import StageCacheEntry, StageOutput + + outputs = {} + for name, cache_id in stage_result.outputs.items(): + outputs[name] = StageOutput( + cache_id=cache_id, + output_type="artifact", + ) + + entry = StageCacheEntry( + stage_name=stage_plan.stage_name, + cache_id=stage_plan.cache_id, + outputs=outputs, + ) + self.stage_cache.save_stage(entry) + + def _execute_stage( + self, + stage_plan, + plan: ExecutionPlanSexp, + stage_outputs: Dict, + timeout: int, + ) -> StageResult: + """ + Execute a single stage. + + Uses the PlanScheduler to execute the stage's steps. + """ + # Create a mini-plan with just this stage's steps + stage_steps = stage_plan.steps + + # Build step lookup + steps_by_id = {s.step_id: s for s in stage_steps} + steps_by_level = {} + for step in stage_steps: + steps_by_level.setdefault(step.level, []).append(step) + + max_level = max(steps_by_level.keys()) if steps_by_level else 0 + + # Track step results + step_results = {} + cache_ids = dict(plan.inputs) # Start with input hashes + for step in plan.steps: + cache_ids[step.step_id] = step.cache_id + + # Include outputs from previous stages + for stage_name, outputs in stage_outputs.items(): + for binding_name, binding_cache_id in outputs.items(): + cache_ids[binding_name] = binding_cache_id + + # Execute steps level by level + for level in range(max_level + 1): + level_steps = steps_by_level.get(level, []) + if not level_steps: + continue + + # Check cache for each step + steps_to_run = [] + for step in level_steps: + if self._is_step_cached(step.cache_id): + step_results[step.step_id] = StepResult( + step_id=step.step_id, + cache_id=step.cache_id, + status="cached", + output_path=self._get_cached_path(step.cache_id), + ) + else: + steps_to_run.append(step) + + if not steps_to_run: + continue + + # Execute steps (for now, sequentially - L1 will add Celery dispatch) + for step in steps_to_run: + # In a full implementation, this would dispatch to Celery + # For now, mark as pending + step_results[step.step_id] = StepResult( + step_id=step.step_id, + cache_id=step.cache_id, + status="pending", + ) + + # If Celery is configured, dispatch the task + if self.celery_app: + try: + task_result = self._dispatch_step(step, cache_ids, timeout) + step_results[step.step_id] = StepResult( + step_id=step.step_id, + cache_id=step.cache_id, + status=task_result.get("status", "completed"), + output_path=task_result.get("output_path"), + error=task_result.get("error"), + ipfs_cid=task_result.get("ipfs_cid"), + ) + except Exception as e: + step_results[step.step_id] = StepResult( + step_id=step.step_id, + cache_id=step.cache_id, + status="failed", + error=str(e), + ) + return StageResult( + stage_name=stage_plan.stage_name, + cache_id=stage_plan.cache_id, + status="failed", + step_results=step_results, + error=str(e), + ) + + # Build output bindings + outputs = {} + for out_name, node_id in stage_plan.output_bindings.items(): + outputs[out_name] = cache_ids.get(node_id, node_id) + + return StageResult( + stage_name=stage_plan.stage_name, + cache_id=stage_plan.cache_id, + status="completed", + step_results=step_results, + outputs=outputs, + ) + + def _is_step_cached(self, cache_id: str) -> bool: + """Check if a step is cached.""" + if self.cache_manager is None: + return False + path = self.cache_manager.get_by_cid(cache_id) + return path is not None + + def _get_cached_path(self, cache_id: str) -> Optional[str]: + """Get the path for a cached step.""" + if self.cache_manager is None: + return None + path = self.cache_manager.get_by_cid(cache_id) + return str(path) if path else None + + def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict: + """Dispatch a step to Celery for execution.""" + if self.celery_app is None: + raise RuntimeError("Celery app not configured") + + task = self.celery_app.tasks[self.execute_task_name] + + step_sexp = step_sexp_to_string(step) + input_cache_ids = { + inp: cache_ids.get(inp, inp) + for inp in step.inputs + } + + async_result = task.apply_async( + kwargs={ + "step_sexp": step_sexp, + "step_id": step.step_id, + "cache_id": step.cache_id, + "input_cache_ids": input_cache_ids, + } + ) + + return async_result.get(timeout=timeout) + + +def create_stage_scheduler( + cache_manager=None, + stage_cache=None, + celery_app=None, +) -> StagePlanScheduler: + """ + Create a stage-aware scheduler with the given dependencies. + + Args: + cache_manager: L1 cache manager for step-level caching + stage_cache: StageCache instance for stage-level caching + celery_app: Celery application instance + + Returns: + StagePlanScheduler + """ + if celery_app is None: + try: + from celery_app import app as celery_app + except ImportError: + pass + + if cache_manager is None: + try: + from cache_manager import get_cache_manager + cache_manager = get_cache_manager() + except ImportError: + pass + + return StagePlanScheduler( + cache_manager=cache_manager, + stage_cache=stage_cache, + celery_app=celery_app, + ) + + +def schedule_staged_plan( + plan: ExecutionPlanSexp, + cache_manager=None, + stage_cache=None, + celery_app=None, + timeout: int = 3600, +) -> StagePlanResult: + """ + Convenience function to schedule a plan with stage awareness. + + Args: + plan: The execution plan + cache_manager: Optional step-level cache manager + stage_cache: Optional stage-level cache + celery_app: Optional Celery app + timeout: Execution timeout + + Returns: + StagePlanResult + """ + scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app) + return scheduler.schedule(plan, timeout=timeout) diff --git a/artdag/sexp/stage_cache.py b/artdag/sexp/stage_cache.py new file mode 100644 index 0000000..44cbe4c --- /dev/null +++ b/artdag/sexp/stage_cache.py @@ -0,0 +1,412 @@ +""" +Stage-level cache layer using S-expression storage. + +Provides caching for stage outputs, enabling: +- Stage-level cache hits (skip entire stage execution) +- Analysis result persistence as sexp +- Cross-worker stage cache sharing (for L1 Celery integration) + +All cache files use .sexp extension - no JSON in the pipeline. +""" + +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from .parser import Symbol, Keyword, parse, serialize + + +@dataclass +class StageOutput: + """A single output from a stage.""" + cache_id: Optional[str] = None # For artifacts (files, analysis data) + value: Any = None # For scalar values + output_type: str = "artifact" # "artifact", "analysis", "scalar" + + def to_sexp(self) -> List: + """Convert to S-expression.""" + sexp = [] + if self.cache_id: + sexp.extend([Keyword("cache-id"), self.cache_id]) + if self.value is not None: + sexp.extend([Keyword("value"), self.value]) + sexp.extend([Keyword("type"), Keyword(self.output_type)]) + return sexp + + @classmethod + def from_sexp(cls, sexp: List) -> 'StageOutput': + """Parse from S-expression list.""" + cache_id = None + value = None + output_type = "artifact" + + i = 0 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name + if i + 1 < len(sexp): + val = sexp[i + 1] + if key == "cache-id": + cache_id = val + elif key == "value": + value = val + elif key == "type": + if isinstance(val, Keyword): + output_type = val.name + else: + output_type = str(val) + i += 2 + else: + i += 1 + else: + i += 1 + + return cls(cache_id=cache_id, value=value, output_type=output_type) + + +@dataclass +class StageCacheEntry: + """Cached result of a stage execution.""" + stage_name: str + cache_id: str + outputs: Dict[str, StageOutput] # binding_name -> output info + completed_at: float = field(default_factory=time.time) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_sexp(self) -> List: + """ + Convert to S-expression for storage. + + Format: + (stage-result + :name "analyze-a" + :cache-id "abc123..." + :completed-at 1705678900.123 + :outputs + ((beats-a :cache-id "def456..." :type :analysis) + (tempo :value 120.5 :type :scalar))) + """ + sexp = [Symbol("stage-result")] + sexp.extend([Keyword("name"), self.stage_name]) + sexp.extend([Keyword("cache-id"), self.cache_id]) + sexp.extend([Keyword("completed-at"), self.completed_at]) + + if self.outputs: + outputs_sexp = [] + for name, output in self.outputs.items(): + output_sexp = [Symbol(name)] + output.to_sexp() + outputs_sexp.append(output_sexp) + sexp.extend([Keyword("outputs"), outputs_sexp]) + + if self.metadata: + sexp.extend([Keyword("metadata"), self.metadata]) + + return sexp + + def to_string(self, pretty: bool = True) -> str: + """Serialize to S-expression string.""" + return serialize(self.to_sexp(), pretty=pretty) + + @classmethod + def from_sexp(cls, sexp: List) -> 'StageCacheEntry': + """Parse from S-expression.""" + if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "stage-result": + raise ValueError("Invalid stage-result sexp") + + stage_name = None + cache_id = None + completed_at = time.time() + outputs = {} + metadata = {} + + i = 1 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name + if i + 1 < len(sexp): + val = sexp[i + 1] + if key == "name": + stage_name = val + elif key == "cache-id": + cache_id = val + elif key == "completed-at": + completed_at = float(val) + elif key == "outputs": + if isinstance(val, list): + for output_sexp in val: + if isinstance(output_sexp, list) and output_sexp: + out_name = output_sexp[0] + if isinstance(out_name, Symbol): + out_name = out_name.name + outputs[out_name] = StageOutput.from_sexp(output_sexp[1:]) + elif key == "metadata": + metadata = val if isinstance(val, dict) else {} + i += 2 + else: + i += 1 + else: + i += 1 + + if not stage_name or not cache_id: + raise ValueError("stage-result missing required fields (name, cache-id)") + + return cls( + stage_name=stage_name, + cache_id=cache_id, + outputs=outputs, + completed_at=completed_at, + metadata=metadata, + ) + + @classmethod + def from_string(cls, text: str) -> 'StageCacheEntry': + """Parse from S-expression string.""" + sexp = parse(text) + return cls.from_sexp(sexp) + + +class StageCache: + """ + Stage-level cache manager using S-expression files. + + Cache structure: + cache_dir/ + _stages/ + {cache_id}.sexp <- Stage result files + """ + + def __init__(self, cache_dir: Union[str, Path]): + """ + Initialize stage cache. + + Args: + cache_dir: Base cache directory + """ + self.cache_dir = Path(cache_dir) + self.stages_dir = self.cache_dir / "_stages" + self.stages_dir.mkdir(parents=True, exist_ok=True) + + def get_cache_path(self, cache_id: str) -> Path: + """Get the path for a stage cache file.""" + return self.stages_dir / f"{cache_id}.sexp" + + def has_stage(self, cache_id: str) -> bool: + """Check if a stage result is cached.""" + return self.get_cache_path(cache_id).exists() + + def load_stage(self, cache_id: str) -> Optional[StageCacheEntry]: + """ + Load a cached stage result. + + Args: + cache_id: Stage cache ID + + Returns: + StageCacheEntry if found, None otherwise + """ + path = self.get_cache_path(cache_id) + if not path.exists(): + return None + + try: + content = path.read_text() + return StageCacheEntry.from_string(content) + except Exception as e: + # Corrupted cache file - log and return None + import sys + print(f"Warning: corrupted stage cache {cache_id}: {e}", file=sys.stderr) + return None + + def save_stage(self, entry: StageCacheEntry) -> Path: + """ + Save a stage result to cache. + + Args: + entry: Stage cache entry to save + + Returns: + Path to the saved cache file + """ + path = self.get_cache_path(entry.cache_id) + content = entry.to_string(pretty=True) + path.write_text(content) + return path + + def delete_stage(self, cache_id: str) -> bool: + """ + Delete a cached stage result. + + Args: + cache_id: Stage cache ID + + Returns: + True if deleted, False if not found + """ + path = self.get_cache_path(cache_id) + if path.exists(): + path.unlink() + return True + return False + + def list_stages(self) -> List[str]: + """List all cached stage IDs.""" + return [ + p.stem for p in self.stages_dir.glob("*.sexp") + ] + + def clear(self) -> int: + """ + Clear all cached stages. + + Returns: + Number of entries cleared + """ + count = 0 + for path in self.stages_dir.glob("*.sexp"): + path.unlink() + count += 1 + return count + + +@dataclass +class AnalysisResult: + """ + Analysis result stored as S-expression. + + Format: + (analysis-result + :analyzer "beats" + :input-hash "abc123..." + :duration 120.5 + :tempo 128.0 + :times (0.0 0.468 0.937 1.406 ...) + :values (0.8 0.9 0.7 0.85 ...)) + """ + analyzer: str + input_hash: str + data: Dict[str, Any] # Analysis data (times, values, duration, etc.) + computed_at: float = field(default_factory=time.time) + + def to_sexp(self) -> List: + """Convert to S-expression.""" + sexp = [Symbol("analysis-result")] + sexp.extend([Keyword("analyzer"), self.analyzer]) + sexp.extend([Keyword("input-hash"), self.input_hash]) + sexp.extend([Keyword("computed-at"), self.computed_at]) + + # Add all data fields + for key, value in self.data.items(): + # Convert key to keyword + sexp.extend([Keyword(key.replace("_", "-")), value]) + + return sexp + + def to_string(self, pretty: bool = True) -> str: + """Serialize to S-expression string.""" + return serialize(self.to_sexp(), pretty=pretty) + + @classmethod + def from_sexp(cls, sexp: List) -> 'AnalysisResult': + """Parse from S-expression.""" + if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis-result": + raise ValueError("Invalid analysis-result sexp") + + analyzer = None + input_hash = None + computed_at = time.time() + data = {} + + i = 1 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name + if i + 1 < len(sexp): + val = sexp[i + 1] + if key == "analyzer": + analyzer = val + elif key == "input-hash": + input_hash = val + elif key == "computed-at": + computed_at = float(val) + else: + # Convert kebab-case back to snake_case + data_key = key.replace("-", "_") + data[data_key] = val + i += 2 + else: + i += 1 + else: + i += 1 + + if not analyzer: + raise ValueError("analysis-result missing analyzer field") + + return cls( + analyzer=analyzer, + input_hash=input_hash or "", + data=data, + computed_at=computed_at, + ) + + @classmethod + def from_string(cls, text: str) -> 'AnalysisResult': + """Parse from S-expression string.""" + sexp = parse(text) + return cls.from_sexp(sexp) + + +def save_analysis_result( + cache_dir: Union[str, Path], + node_id: str, + result: AnalysisResult, +) -> Path: + """ + Save an analysis result as S-expression. + + Args: + cache_dir: Base cache directory + node_id: Node ID for the analysis + result: Analysis result to save + + Returns: + Path to the saved file + """ + cache_dir = Path(cache_dir) + node_dir = cache_dir / node_id + node_dir.mkdir(parents=True, exist_ok=True) + + path = node_dir / "analysis.sexp" + content = result.to_string(pretty=True) + path.write_text(content) + return path + + +def load_analysis_result( + cache_dir: Union[str, Path], + node_id: str, +) -> Optional[AnalysisResult]: + """ + Load an analysis result from cache. + + Args: + cache_dir: Base cache directory + node_id: Node ID for the analysis + + Returns: + AnalysisResult if found, None otherwise + """ + cache_dir = Path(cache_dir) + path = cache_dir / node_id / "analysis.sexp" + + if not path.exists(): + return None + + try: + content = path.read_text() + return AnalysisResult.from_string(content) + except Exception as e: + import sys + print(f"Warning: corrupted analysis cache {node_id}: {e}", file=sys.stderr) + return None diff --git a/artdag/sexp/test_ffmpeg_compiler.py b/artdag/sexp/test_ffmpeg_compiler.py new file mode 100644 index 0000000..1cfafe5 --- /dev/null +++ b/artdag/sexp/test_ffmpeg_compiler.py @@ -0,0 +1,146 @@ +""" +Tests for FFmpeg filter compilation. + +Validates that each filter mapping produces valid FFmpeg commands. +""" + +import subprocess +import tempfile +from pathlib import Path + +from .ffmpeg_compiler import FFmpegCompiler, EFFECT_MAPPINGS + + +def test_filter_syntax(filter_str: str, duration: float = 0.1, is_complex: bool = False) -> tuple[bool, str]: + """ + Test if an FFmpeg filter string is valid by running it on a test pattern. + + Args: + filter_str: The filter string to test + duration: Duration of test video + is_complex: If True, use -filter_complex instead of -vf + + Returns (success, error_message) + """ + with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: + output_path = f.name + + try: + if is_complex: + # Complex filter graph needs -filter_complex and explicit output mapping + cmd = [ + 'ffmpeg', '-y', + '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10', + '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', + '-filter_complex', f'[0:v]{filter_str}[out]', + '-map', '[out]', '-map', '1:a', + '-c:v', 'libx264', '-preset', 'ultrafast', + '-c:a', 'aac', + '-t', str(duration), + output_path + ] + else: + # Simple filter uses -vf + cmd = [ + 'ffmpeg', '-y', + '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10', + '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}', + '-vf', filter_str, + '-c:v', 'libx264', '-preset', 'ultrafast', + '-c:a', 'aac', + '-t', str(duration), + output_path + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=10) + + if result.returncode == 0: + return True, "" + else: + # Extract relevant error + stderr = result.stderr + for line in stderr.split('\n'): + if 'Error' in line or 'error' in line or 'Invalid' in line: + return False, line.strip() + return False, stderr[-500:] if len(stderr) > 500 else stderr + except subprocess.TimeoutExpired: + return False, "Timeout" + except Exception as e: + return False, str(e) + finally: + Path(output_path).unlink(missing_ok=True) + + +def run_all_tests(): + """Test all effect mappings.""" + compiler = FFmpegCompiler() + results = [] + + for effect_name, mapping in EFFECT_MAPPINGS.items(): + filter_name = mapping.get("filter") + + # Skip effects with no FFmpeg equivalent (external tools or python primitives) + if filter_name is None: + reason = "No FFmpeg equivalent" + if mapping.get("external_tool"): + reason = f"External tool: {mapping['external_tool']}" + elif mapping.get("python_primitive"): + reason = f"Python primitive: {mapping['python_primitive']}" + results.append((effect_name, "SKIP", reason)) + continue + + # Check if complex filter + is_complex = mapping.get("complex", False) + + # Build filter string + if "static" in mapping: + filter_str = f"{filter_name}={mapping['static']}" + else: + filter_str = filter_name + + # Test it + success, error = test_filter_syntax(filter_str, is_complex=is_complex) + + if success: + results.append((effect_name, "PASS", filter_str)) + else: + results.append((effect_name, "FAIL", f"{filter_str} -> {error}")) + + return results + + +def print_results(results): + """Print test results.""" + passed = sum(1 for _, status, _ in results if status == "PASS") + failed = sum(1 for _, status, _ in results if status == "FAIL") + skipped = sum(1 for _, status, _ in results if status == "SKIP") + + print(f"\n{'='*60}") + print(f"FFmpeg Filter Tests: {passed} passed, {failed} failed, {skipped} skipped") + print(f"{'='*60}\n") + + # Print failures first + if failed > 0: + print("FAILURES:") + for name, status, msg in results: + if status == "FAIL": + print(f" {name}: {msg}") + print() + + # Print passes + print("PASSED:") + for name, status, msg in results: + if status == "PASS": + print(f" {name}: {msg}") + + # Print skips + if skipped > 0: + print("\nSKIPPED (Python fallback):") + for name, status, msg in results: + if status == "SKIP": + print(f" {name}") + + +if __name__ == "__main__": + results = run_all_tests() + print_results(results) diff --git a/artdag/sexp/test_primitives.py b/artdag/sexp/test_primitives.py new file mode 100644 index 0000000..193c7fd --- /dev/null +++ b/artdag/sexp/test_primitives.py @@ -0,0 +1,201 @@ +""" +Tests for Python primitive effects. + +Tests that ascii_art, ascii_zones, and other Python primitives +can be executed via the EffectExecutor. +""" + +import subprocess +import tempfile +from pathlib import Path + +import pytest + +try: + import numpy as np + from PIL import Image + HAS_DEPS = True +except ImportError: + HAS_DEPS = False + +from .primitives import ( + ascii_art_frame, + ascii_zones_frame, + get_primitive, + list_primitives, +) +from .ffmpeg_compiler import FFmpegCompiler + + +def create_test_video(path: Path, duration: float = 0.5, size: str = "64x64") -> bool: + """Create a short test video using ffmpeg.""" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", f"testsrc=duration={duration}:size={size}:rate=10", + "-c:v", "libx264", "-preset", "ultrafast", + str(path) + ] + result = subprocess.run(cmd, capture_output=True) + return result.returncode == 0 + + +@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available") +class TestPrimitives: + """Test primitive functions directly.""" + + def test_ascii_art_frame_basic(self): + """Test ascii_art_frame produces output of same shape.""" + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = ascii_art_frame(frame, char_size=8) + assert result.shape == frame.shape + assert result.dtype == np.uint8 + + def test_ascii_zones_frame_basic(self): + """Test ascii_zones_frame produces output of same shape.""" + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = ascii_zones_frame(frame, char_size=8) + assert result.shape == frame.shape + assert result.dtype == np.uint8 + + def test_get_primitive(self): + """Test primitive lookup.""" + assert get_primitive("ascii_art_frame") is ascii_art_frame + assert get_primitive("ascii_zones_frame") is ascii_zones_frame + assert get_primitive("nonexistent") is None + + def test_list_primitives(self): + """Test listing primitives.""" + primitives = list_primitives() + assert "ascii_art_frame" in primitives + assert "ascii_zones_frame" in primitives + assert len(primitives) > 5 + + +class TestFFmpegCompilerPrimitives: + """Test FFmpegCompiler python_primitive mappings.""" + + def test_has_python_primitive_ascii_art(self): + """Test ascii_art has python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame" + + def test_has_python_primitive_ascii_zones(self): + """Test ascii_zones has python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame" + + def test_has_python_primitive_ffmpeg_effect(self): + """Test FFmpeg effects don't have python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("brightness") is None + assert compiler.has_python_primitive("blur") is None + + def test_compile_effect_returns_none_for_primitives(self): + """Test compile_effect returns None for primitive effects.""" + compiler = FFmpegCompiler() + assert compiler.compile_effect("ascii_art", {}) is None + assert compiler.compile_effect("ascii_zones", {}) is None + + +@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available") +class TestEffectExecutorPrimitives: + """Test EffectExecutor with Python primitives.""" + + def test_executor_loads_primitive(self): + """Test that executor finds primitive effects.""" + from ..nodes.effect import _get_python_primitive_effect + + effect_fn = _get_python_primitive_effect("ascii_art") + assert effect_fn is not None + + effect_fn = _get_python_primitive_effect("ascii_zones") + assert effect_fn is not None + + def test_executor_rejects_unknown_effect(self): + """Test that executor returns None for unknown effects.""" + from ..nodes.effect import _get_python_primitive_effect + + effect_fn = _get_python_primitive_effect("nonexistent_effect") + assert effect_fn is None + + def test_execute_ascii_art_effect(self, tmp_path): + """Test executing ascii_art effect on a video.""" + from ..nodes.effect import EffectExecutor + + # Create test video + input_path = tmp_path / "input.mp4" + if not create_test_video(input_path): + pytest.skip("Could not create test video") + + output_path = tmp_path / "output.mkv" + executor = EffectExecutor() + + result = executor.execute( + config={"effect": "ascii_art", "char_size": 8}, + inputs=[input_path], + output_path=output_path, + ) + + assert result.exists() + assert result.stat().st_size > 0 + + +def run_all_tests(): + """Run tests manually.""" + import sys + + # Check dependencies + if not HAS_DEPS: + print("SKIP: numpy/PIL not available") + return + + print("Testing primitives...") + + # Test primitive functions + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + print(" ascii_art_frame...", end=" ") + result = ascii_art_frame(frame, char_size=8) + assert result.shape == frame.shape + print("PASS") + + print(" ascii_zones_frame...", end=" ") + result = ascii_zones_frame(frame, char_size=8) + assert result.shape == frame.shape + print("PASS") + + # Test FFmpegCompiler mappings + print("\nTesting FFmpegCompiler mappings...") + compiler = FFmpegCompiler() + + print(" ascii_art python_primitive...", end=" ") + assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame" + print("PASS") + + print(" ascii_zones python_primitive...", end=" ") + assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame" + print("PASS") + + # Test executor lookup + print("\nTesting EffectExecutor...") + try: + from ..nodes.effect import _get_python_primitive_effect + + print(" _get_python_primitive_effect(ascii_art)...", end=" ") + effect_fn = _get_python_primitive_effect("ascii_art") + assert effect_fn is not None + print("PASS") + + print(" _get_python_primitive_effect(ascii_zones)...", end=" ") + effect_fn = _get_python_primitive_effect("ascii_zones") + assert effect_fn is not None + print("PASS") + + except ImportError as e: + print(f"SKIP: {e}") + + print("\n=== All tests passed ===") + + +if __name__ == "__main__": + run_all_tests() diff --git a/artdag/sexp/test_stage_cache.py b/artdag/sexp/test_stage_cache.py new file mode 100644 index 0000000..87daf3f --- /dev/null +++ b/artdag/sexp/test_stage_cache.py @@ -0,0 +1,324 @@ +""" +Tests for stage cache layer. + +Tests S-expression storage for stage results and analysis data. +""" + +import pytest +import tempfile +from pathlib import Path + +from .stage_cache import ( + StageCache, + StageCacheEntry, + StageOutput, + AnalysisResult, + save_analysis_result, + load_analysis_result, +) +from .parser import parse, serialize + + +class TestStageOutput: + """Test StageOutput dataclass and serialization.""" + + def test_stage_output_artifact(self): + """StageOutput can represent an artifact.""" + output = StageOutput( + cache_id="abc123", + output_type="artifact", + ) + assert output.cache_id == "abc123" + assert output.output_type == "artifact" + + def test_stage_output_scalar(self): + """StageOutput can represent a scalar value.""" + output = StageOutput( + value=120.5, + output_type="scalar", + ) + assert output.value == 120.5 + assert output.output_type == "scalar" + + def test_stage_output_to_sexp(self): + """StageOutput serializes to sexp.""" + output = StageOutput( + cache_id="abc123", + output_type="artifact", + ) + sexp = output.to_sexp() + sexp_str = serialize(sexp) + + assert "cache-id" in sexp_str + assert "abc123" in sexp_str + assert "type" in sexp_str + assert "artifact" in sexp_str + + def test_stage_output_from_sexp(self): + """StageOutput parses from sexp.""" + sexp = parse('(:cache-id "def456" :type :analysis)') + output = StageOutput.from_sexp(sexp) + + assert output.cache_id == "def456" + assert output.output_type == "analysis" + + +class TestStageCacheEntry: + """Test StageCacheEntry serialization.""" + + def test_stage_cache_entry_to_sexp(self): + """StageCacheEntry serializes to sexp.""" + entry = StageCacheEntry( + stage_name="analyze-a", + cache_id="stage_abc123", + outputs={ + "beats": StageOutput(cache_id="beats_def456", output_type="analysis"), + "tempo": StageOutput(value=120.5, output_type="scalar"), + }, + completed_at=1705678900.123, + ) + + sexp = entry.to_sexp() + sexp_str = serialize(sexp) + + assert "stage-result" in sexp_str + assert "analyze-a" in sexp_str + assert "stage_abc123" in sexp_str + assert "outputs" in sexp_str + assert "beats" in sexp_str + + def test_stage_cache_entry_roundtrip(self): + """save -> load produces identical data.""" + entry = StageCacheEntry( + stage_name="analyze-b", + cache_id="stage_xyz789", + outputs={ + "segments": StageOutput(cache_id="seg_123", output_type="artifact"), + }, + completed_at=1705678900.0, + ) + + sexp_str = entry.to_string() + loaded = StageCacheEntry.from_string(sexp_str) + + assert loaded.stage_name == entry.stage_name + assert loaded.cache_id == entry.cache_id + assert "segments" in loaded.outputs + assert loaded.outputs["segments"].cache_id == "seg_123" + + def test_stage_cache_entry_from_sexp(self): + """StageCacheEntry parses from sexp.""" + sexp_str = ''' + (stage-result + :name "test-stage" + :cache-id "cache123" + :completed-at 1705678900.0 + :outputs ((beats :cache-id "beats123" :type :analysis))) + ''' + entry = StageCacheEntry.from_string(sexp_str) + + assert entry.stage_name == "test-stage" + assert entry.cache_id == "cache123" + assert "beats" in entry.outputs + assert entry.outputs["beats"].cache_id == "beats123" + + +class TestStageCache: + """Test StageCache file operations.""" + + def test_save_and_load_stage(self): + """Save and load a stage result.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + entry = StageCacheEntry( + stage_name="analyze", + cache_id="test_cache_id", + outputs={ + "beats": StageOutput(cache_id="beats_out", output_type="analysis"), + }, + ) + + path = cache.save_stage(entry) + assert path.exists() + assert path.suffix == ".sexp" + + loaded = cache.load_stage("test_cache_id") + assert loaded is not None + assert loaded.stage_name == "analyze" + assert "beats" in loaded.outputs + + def test_has_stage(self): + """Check if stage is cached.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + assert not cache.has_stage("nonexistent") + + entry = StageCacheEntry( + stage_name="test", + cache_id="exists_cache_id", + outputs={}, + ) + cache.save_stage(entry) + + assert cache.has_stage("exists_cache_id") + + def test_delete_stage(self): + """Delete a cached stage.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + entry = StageCacheEntry( + stage_name="test", + cache_id="to_delete", + outputs={}, + ) + cache.save_stage(entry) + + assert cache.has_stage("to_delete") + result = cache.delete_stage("to_delete") + assert result is True + assert not cache.has_stage("to_delete") + + def test_list_stages(self): + """List all cached stages.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + for i in range(3): + entry = StageCacheEntry( + stage_name=f"stage{i}", + cache_id=f"cache_{i}", + outputs={}, + ) + cache.save_stage(entry) + + stages = cache.list_stages() + assert len(stages) == 3 + assert "cache_0" in stages + assert "cache_1" in stages + assert "cache_2" in stages + + def test_clear(self): + """Clear all cached stages.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + for i in range(3): + entry = StageCacheEntry( + stage_name=f"stage{i}", + cache_id=f"cache_{i}", + outputs={}, + ) + cache.save_stage(entry) + + count = cache.clear() + assert count == 3 + assert len(cache.list_stages()) == 0 + + def test_cache_file_extension(self): + """Cache files use .sexp extension.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + path = cache.get_cache_path("test_id") + assert path.suffix == ".sexp" + + def test_invalid_sexp_error_handling(self): + """Graceful error on corrupted cache file.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache = StageCache(tmpdir) + + # Write corrupted content + corrupt_path = cache.get_cache_path("corrupted") + corrupt_path.write_text("this is not valid sexp )()(") + + # Should return None, not raise + result = cache.load_stage("corrupted") + assert result is None + + +class TestAnalysisResult: + """Test AnalysisResult serialization.""" + + def test_analysis_result_to_sexp(self): + """AnalysisResult serializes to sexp.""" + result = AnalysisResult( + analyzer="beats", + input_hash="input_abc123", + data={ + "duration": 120.5, + "tempo": 128.0, + "times": [0.0, 0.468, 0.937, 1.406], + "values": [0.8, 0.9, 0.7, 0.85], + }, + ) + + sexp = result.to_sexp() + sexp_str = serialize(sexp) + + assert "analysis-result" in sexp_str + assert "beats" in sexp_str + assert "duration" in sexp_str + assert "tempo" in sexp_str + assert "times" in sexp_str + + def test_analysis_result_roundtrip(self): + """Analysis result round-trips through sexp.""" + original = AnalysisResult( + analyzer="scenes", + input_hash="video_xyz", + data={ + "scene_count": 5, + "scene_times": [0.0, 10.5, 25.0, 45.2, 60.0], + }, + ) + + sexp_str = original.to_string() + loaded = AnalysisResult.from_string(sexp_str) + + assert loaded.analyzer == original.analyzer + assert loaded.input_hash == original.input_hash + assert loaded.data["scene_count"] == 5 + + def test_save_and_load_analysis_result(self): + """Save and load analysis result from cache.""" + with tempfile.TemporaryDirectory() as tmpdir: + result = AnalysisResult( + analyzer="beats", + input_hash="audio_123", + data={ + "tempo": 120.0, + "times": [0.0, 0.5, 1.0], + }, + ) + + path = save_analysis_result(tmpdir, "node_abc", result) + assert path.exists() + assert path.name == "analysis.sexp" + + loaded = load_analysis_result(tmpdir, "node_abc") + assert loaded is not None + assert loaded.analyzer == "beats" + assert loaded.data["tempo"] == 120.0 + + def test_analysis_result_kebab_case(self): + """Keys convert between snake_case and kebab-case.""" + result = AnalysisResult( + analyzer="test", + input_hash="hash", + data={ + "scene_count": 5, + "beat_times": [1, 2, 3], + }, + ) + + sexp_str = result.to_string() + # Kebab case in sexp + assert "scene-count" in sexp_str + assert "beat-times" in sexp_str + + # Back to snake_case after parsing + loaded = AnalysisResult.from_string(sexp_str) + assert "scene_count" in loaded.data + assert "beat_times" in loaded.data diff --git a/artdag/sexp/test_stage_compiler.py b/artdag/sexp/test_stage_compiler.py new file mode 100644 index 0000000..c1d3cc2 --- /dev/null +++ b/artdag/sexp/test_stage_compiler.py @@ -0,0 +1,286 @@ +""" +Tests for stage compilation and scoping. + +Tests the CompiledStage dataclass, stage form parsing, +variable scoping, and dependency validation. +""" + +import pytest + +from .parser import parse, Symbol, Keyword +from .compiler import ( + compile_recipe, + CompileError, + CompiledStage, + CompilerContext, + _topological_sort_stages, +) + + +class TestStageCompilation: + """Test stage form compilation.""" + + def test_parse_stage_form_basic(self): + """Stage parses correctly with name and outputs.""" + recipe = ''' + (recipe "test-stage" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats))) + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 1 + assert compiled.stages[0].name == "analyze" + assert "beats" in compiled.stages[0].outputs + assert len(compiled.stages[0].node_ids) > 0 + + def test_parse_stage_with_requires(self): + """Stage parses correctly with requires and inputs.""" + recipe = ''' + (recipe "test-requires" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :process + :requires [:analyze] + :inputs [beats] + :outputs [segments] + (def segments (-> audio (segment :times beats))) + (-> segments (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 2 + process_stage = next(s for s in compiled.stages if s.name == "process") + assert process_stage.requires == ["analyze"] + assert "beats" in process_stage.inputs + assert "segments" in process_stage.outputs + + def test_stage_outputs_recorded(self): + """Stage outputs are tracked in CompiledStage.""" + recipe = ''' + (recipe "test-outputs" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats tempo] + (def beats (-> audio (analyze beats))) + (def tempo (-> audio (analyze tempo))) + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + stage = compiled.stages[0] + assert "beats" in stage.outputs + assert "tempo" in stage.outputs + assert "beats" in stage.output_bindings + assert "tempo" in stage.output_bindings + + def test_stage_order_topological(self): + """Stages are topologically sorted.""" + recipe = ''' + (recipe "test-order" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + # analyze should come before output + assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output") + + +class TestStageValidation: + """Test stage dependency and input validation.""" + + def test_stage_requires_validation(self): + """Error if requiring non-existent stage.""" + recipe = ''' + (recipe "test-bad-require" + (def audio (source :path "test.mp3")) + + (stage :process + :requires [:nonexistent] + :inputs [beats] + (def result audio))) + ''' + with pytest.raises(CompileError, match="requires undefined stage"): + compile_recipe(parse(recipe)) + + def test_stage_inputs_validation(self): + """Error if input not produced by required stage.""" + recipe = ''' + (recipe "test-bad-input" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :process + :requires [:analyze] + :inputs [nonexistent] + (def result audio))) + ''' + with pytest.raises(CompileError, match="not an output of any required stage"): + compile_recipe(parse(recipe)) + + def test_undeclared_output_error(self): + """Error if stage declares output not defined in body.""" + recipe = ''' + (recipe "test-missing-output" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats nonexistent] + (def beats (-> audio (analyze beats))))) + ''' + with pytest.raises(CompileError, match="not defined in the stage body"): + compile_recipe(parse(recipe)) + + def test_forward_reference_detection(self): + """Error when requiring a stage not yet defined.""" + # Forward references are not allowed - stages must be defined + # before they can be required + recipe = ''' + (recipe "test-forward" + (def audio (source :path "test.mp3")) + + (stage :a + :requires [:b] + :outputs [out-a] + (def out-a audio)) + + (stage :b + :outputs [out-b] + (def out-b audio) + audio)) + ''' + with pytest.raises(CompileError, match="requires undefined stage"): + compile_recipe(parse(recipe)) + + +class TestStageScoping: + """Test variable scoping between stages.""" + + def test_pre_stage_bindings_accessible(self): + """Sources defined before stages accessible to all stages.""" + recipe = ''' + (recipe "test-pre-stage" + (def audio (source :path "test.mp3")) + (def video (source :path "test.mp4")) + + (stage :analyze-audio + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :analyze-video + :outputs [scenes] + (def scenes (-> video (analyze scenes))) + (-> video (segment :times scenes) (sequence)))) + ''' + # Should compile without error - audio and video accessible to both stages + compiled = compile_recipe(parse(recipe)) + assert len(compiled.stages) == 2 + + def test_stage_bindings_flow_through_requires(self): + """Stage bindings accessible to dependent stages via :inputs.""" + recipe = ''' + (recipe "test-binding-flow" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :process + :requires [:analyze] + :inputs [beats] + :outputs [result] + (def result (-> audio (segment :times beats))) + (-> result (sequence)))) + ''' + # Should compile without error - beats flows from analyze to process + compiled = compile_recipe(parse(recipe)) + assert len(compiled.stages) == 2 + + +class TestTopologicalSort: + """Test stage topological sorting.""" + + def test_empty_stages(self): + """Empty stages returns empty list.""" + assert _topological_sort_stages({}) == [] + + def test_single_stage(self): + """Single stage returns single element.""" + stages = { + "a": CompiledStage( + name="a", + requires=[], + inputs=[], + outputs=["out"], + node_ids=["n1"], + output_bindings={"out": "n1"}, + ) + } + assert _topological_sort_stages(stages) == ["a"] + + def test_linear_chain(self): + """Linear chain sorted correctly.""" + stages = { + "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + "b": CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + "c": CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + } + result = _topological_sort_stages(stages) + assert result.index("a") < result.index("b") < result.index("c") + + def test_parallel_stages_same_level(self): + """Parallel stages are both valid orderings.""" + stages = { + "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + "b": CompiledStage(name="b", requires=[], inputs=[], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + } + result = _topological_sort_stages(stages) + # Both a and b should be in result (order doesn't matter) + assert set(result) == {"a", "b"} + + def test_diamond_dependency(self): + """Diamond pattern: A -> B, A -> C, B+C -> D.""" + stages = { + "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + "b": CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + "c": CompiledStage(name="c", requires=["a"], inputs=["x"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + "d": CompiledStage(name="d", requires=["b", "c"], inputs=["y", "z"], outputs=["out"], + node_ids=["n4"], output_bindings={"out": "n4"}), + } + result = _topological_sort_stages(stages) + # a must be first, d must be last + assert result[0] == "a" + assert result[-1] == "d" + # b and c must be before d + assert result.index("b") < result.index("d") + assert result.index("c") < result.index("d") diff --git a/artdag/sexp/test_stage_integration.py b/artdag/sexp/test_stage_integration.py new file mode 100644 index 0000000..f32aa46 --- /dev/null +++ b/artdag/sexp/test_stage_integration.py @@ -0,0 +1,739 @@ +""" +End-to-end integration tests for staged recipes. + +Tests the complete flow: compile -> plan -> execute +for recipes with stages. +""" + +import pytest +import tempfile +from pathlib import Path + +from .parser import parse, serialize +from .compiler import compile_recipe, CompileError +from .planner import ExecutionPlanSexp, StagePlan +from .stage_cache import StageCache, StageCacheEntry, StageOutput +from .scheduler import StagePlanScheduler, StagePlanResult + + +class TestSimpleTwoStageRecipe: + """Test basic two-stage recipe flow.""" + + def test_compile_two_stage_recipe(self): + """Compile a simple two-stage recipe.""" + recipe = ''' + (recipe "test-two-stages" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 2 + assert compiled.stage_order == ["analyze", "output"] + + analyze_stage = compiled.stages[0] + assert analyze_stage.name == "analyze" + assert "beats" in analyze_stage.outputs + + output_stage = compiled.stages[1] + assert output_stage.name == "output" + assert output_stage.requires == ["analyze"] + assert "beats" in output_stage.inputs + + +class TestParallelAnalysisStages: + """Test parallel analysis stages.""" + + def test_compile_parallel_stages(self): + """Two analysis stages can run in parallel.""" + recipe = ''' + (recipe "test-parallel" + (def audio-a (source :path "a.mp3")) + (def audio-b (source :path "b.mp3")) + + (stage :analyze-a + :outputs [beats-a] + (def beats-a (-> audio-a (analyze beats)))) + + (stage :analyze-b + :outputs [beats-b] + (def beats-b (-> audio-b (analyze beats)))) + + (stage :combine + :requires [:analyze-a :analyze-b] + :inputs [beats-a beats-b] + (-> audio-a (segment :times beats-a) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 3 + + # analyze-a and analyze-b should both be at level 0 (parallel) + analyze_a = next(s for s in compiled.stages if s.name == "analyze-a") + analyze_b = next(s for s in compiled.stages if s.name == "analyze-b") + combine = next(s for s in compiled.stages if s.name == "combine") + + assert analyze_a.requires == [] + assert analyze_b.requires == [] + assert set(combine.requires) == {"analyze-a", "analyze-b"} + + +class TestDiamondDependency: + """Test diamond dependency pattern: A -> B, A -> C, B+C -> D.""" + + def test_compile_diamond_pattern(self): + """Diamond pattern compiles correctly.""" + recipe = ''' + (recipe "test-diamond" + (def audio (source :path "test.mp3")) + + (stage :source-stage + :outputs [audio-ref] + (def audio-ref audio)) + + (stage :branch-b + :requires [:source-stage] + :inputs [audio-ref] + :outputs [result-b] + (def result-b (-> audio-ref (effect gain :amount 0.5)))) + + (stage :branch-c + :requires [:source-stage] + :inputs [audio-ref] + :outputs [result-c] + (def result-c (-> audio-ref (effect gain :amount 0.8)))) + + (stage :merge + :requires [:branch-b :branch-c] + :inputs [result-b result-c] + (-> result-b (blend result-c :mode "mix")))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 4 + + # Check dependencies + source = next(s for s in compiled.stages if s.name == "source-stage") + branch_b = next(s for s in compiled.stages if s.name == "branch-b") + branch_c = next(s for s in compiled.stages if s.name == "branch-c") + merge = next(s for s in compiled.stages if s.name == "merge") + + assert source.requires == [] + assert branch_b.requires == ["source-stage"] + assert branch_c.requires == ["source-stage"] + assert set(merge.requires) == {"branch-b", "branch-c"} + + # source-stage should come first in order + assert compiled.stage_order.index("source-stage") < compiled.stage_order.index("branch-b") + assert compiled.stage_order.index("source-stage") < compiled.stage_order.index("branch-c") + # merge should come last + assert compiled.stage_order.index("branch-b") < compiled.stage_order.index("merge") + assert compiled.stage_order.index("branch-c") < compiled.stage_order.index("merge") + + +class TestStageReuseOnRerun: + """Test that re-running recipe uses cached stages.""" + + def test_stage_reuse(self): + """Re-running recipe uses cached stages.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Simulate first run by caching a stage + entry = StageCacheEntry( + stage_name="analyze", + cache_id="fixed_cache_id", + outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")}, + ) + stage_cache.save_stage(entry) + + # Verify cache exists + assert stage_cache.has_stage("fixed_cache_id") + + # Second run should find cache + loaded = stage_cache.load_stage("fixed_cache_id") + assert loaded is not None + assert loaded.stage_name == "analyze" + + +class TestExplicitDataFlowEndToEnd: + """Test that analysis results flow through :inputs/:outputs.""" + + def test_data_flow_declaration(self): + """Explicit data flow is declared correctly.""" + recipe = ''' + (recipe "test-data-flow" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats tempo] + (def beats (-> audio (analyze beats))) + (def tempo (-> audio (analyze tempo)))) + + (stage :process + :requires [:analyze] + :inputs [beats tempo] + :outputs [result] + (def result (-> audio (segment :times beats) (effect speed :factor tempo))) + (-> result (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + process = next(s for s in compiled.stages if s.name == "process") + + # Analyze outputs + assert set(analyze.outputs) == {"beats", "tempo"} + assert "beats" in analyze.output_bindings + assert "tempo" in analyze.output_bindings + + # Process inputs + assert set(process.inputs) == {"beats", "tempo"} + assert process.requires == ["analyze"] + + +class TestRecipeFixtures: + """Test using recipe fixtures.""" + + @pytest.fixture + def test_recipe_two_stages(self): + return ''' + (recipe "test-two-stages" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + + @pytest.fixture + def test_recipe_parallel_stages(self): + return ''' + (recipe "test-parallel" + (def audio-a (source :path "a.mp3")) + (def audio-b (source :path "b.mp3")) + + (stage :analyze-a + :outputs [beats-a] + (def beats-a (-> audio-a (analyze beats)))) + + (stage :analyze-b + :outputs [beats-b] + (def beats-b (-> audio-b (analyze beats)))) + + (stage :combine + :requires [:analyze-a :analyze-b] + :inputs [beats-a beats-b] + (-> audio-a (blend audio-b :mode "mix")))) + ''' + + def test_two_stages_fixture(self, test_recipe_two_stages): + """Two-stage recipe fixture compiles.""" + compiled = compile_recipe(parse(test_recipe_two_stages)) + assert len(compiled.stages) == 2 + + def test_parallel_stages_fixture(self, test_recipe_parallel_stages): + """Parallel stages recipe fixture compiles.""" + compiled = compile_recipe(parse(test_recipe_parallel_stages)) + assert len(compiled.stages) == 3 + + +class TestStageValidationErrors: + """Test error handling for invalid stage recipes.""" + + def test_missing_output_declaration(self): + """Error when stage output not declared.""" + recipe = ''' + (recipe "test-missing-output" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats nonexistent] + (def beats (-> audio (analyze beats))))) + ''' + with pytest.raises(CompileError, match="not defined in the stage body"): + compile_recipe(parse(recipe)) + + def test_input_without_requires(self): + """Error when using input not from required stage.""" + recipe = ''' + (recipe "test-bad-input" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :process + :requires [] + :inputs [beats] + (def result audio))) + ''' + with pytest.raises(CompileError, match="not an output of any required stage"): + compile_recipe(parse(recipe)) + + def test_forward_reference(self): + """Error when requiring stage not yet defined (forward reference).""" + recipe = ''' + (recipe "test-forward-ref" + (def audio (source :path "test.mp3")) + + (stage :a + :requires [:b] + :outputs [out-a] + (def out-a audio) + audio) + + (stage :b + :outputs [out-b] + (def out-b audio) + audio)) + ''' + with pytest.raises(CompileError, match="requires undefined stage"): + compile_recipe(parse(recipe)) + + +class TestBeatSyncDemoRecipe: + """Test the beat-sync demo recipe from examples.""" + + BEAT_SYNC_RECIPE = ''' + ;; Simple staged recipe demo + (recipe "beat-sync-demo" + :version "1.0" + :description "Demo of staged beat-sync workflow" + + ;; Pre-stage definitions (available to all stages) + (def audio (source :path "input.mp3")) + + ;; Stage 1: Analysis (expensive, cached) + (stage :analyze + :outputs [beats tempo] + (def beats (-> audio (analyze beats))) + (def tempo (-> audio (analyze tempo)))) + + ;; Stage 2: Processing (uses analysis results) + (stage :process + :requires [:analyze] + :inputs [beats] + :outputs [segments] + (def segments (-> audio (segment :times beats))) + (-> segments (sequence)))) + ''' + + def test_compile_beat_sync_recipe(self): + """Beat-sync demo recipe compiles correctly.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + assert compiled.name == "beat-sync-demo" + assert compiled.version == "1.0" + assert compiled.description == "Demo of staged beat-sync workflow" + + def test_beat_sync_stage_count(self): + """Beat-sync has 2 stages in correct order.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + assert len(compiled.stages) == 2 + assert compiled.stage_order == ["analyze", "process"] + + def test_beat_sync_analyze_stage(self): + """Analyze stage has correct outputs.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + assert analyze.requires == [] + assert analyze.inputs == [] + assert set(analyze.outputs) == {"beats", "tempo"} + assert "beats" in analyze.output_bindings + assert "tempo" in analyze.output_bindings + + def test_beat_sync_process_stage(self): + """Process stage has correct dependencies and inputs.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + process = next(s for s in compiled.stages if s.name == "process") + assert process.requires == ["analyze"] + assert "beats" in process.inputs + assert "segments" in process.outputs + + def test_beat_sync_node_count(self): + """Beat-sync generates expected number of nodes.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + # 1 SOURCE + 2 ANALYZE + 1 SEGMENT + 1 SEQUENCE = 5 nodes + assert len(compiled.nodes) == 5 + + def test_beat_sync_node_types(self): + """Beat-sync generates correct node types.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + node_types = [n["type"] for n in compiled.nodes] + assert node_types.count("SOURCE") == 1 + assert node_types.count("ANALYZE") == 2 + assert node_types.count("SEGMENT") == 1 + assert node_types.count("SEQUENCE") == 1 + + def test_beat_sync_output_is_sequence(self): + """Beat-sync output node is the sequence node.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id) + assert output_node["type"] == "SEQUENCE" + + +class TestAsciiArtStagedRecipe: + """Test the ASCII art staged recipe.""" + + ASCII_ART_STAGED_RECIPE = ''' + ;; ASCII art effect with staged execution + (recipe "ascii_art_staged" + :version "1.0" + :description "ASCII art effect with staged execution" + :encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30) + + ;; Registry + (effect ascii_art :path "sexp_effects/effects/ascii_art.sexp") + (analyzer energy :path "../artdag-analyzers/energy/analyzer.py") + + ;; Pre-stage definitions + (def color_mode "color") + (def video (source :path "monday.webm")) + (def audio (source :path "dizzy.mp3")) + + ;; Stage 1: Analysis + (stage :analyze + :outputs [energy-data] + (def audio-clip (-> audio (segment :start 60 :duration 10))) + (def energy-data (-> audio-clip (analyze energy)))) + + ;; Stage 2: Process + (stage :process + :requires [:analyze] + :inputs [energy-data] + :outputs [result audio-clip] + (def clip (-> video (segment :start 0 :duration 10))) + (def audio-clip (-> audio (segment :start 60 :duration 10))) + (def result (-> clip + (effect ascii_art + :char_size (bind energy-data values :range [2 32]) + :color_mode color_mode)))) + + ;; Stage 3: Output + (stage :output + :requires [:process] + :inputs [result audio-clip] + (mux result audio-clip))) + ''' + + def test_compile_ascii_art_staged(self): + """ASCII art staged recipe compiles correctly.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + assert compiled.name == "ascii_art_staged" + assert compiled.version == "1.0" + + def test_ascii_art_stage_count(self): + """ASCII art has 3 stages in correct order.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + assert len(compiled.stages) == 3 + assert compiled.stage_order == ["analyze", "process", "output"] + + def test_ascii_art_analyze_stage(self): + """Analyze stage outputs energy-data.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + assert analyze.requires == [] + assert analyze.inputs == [] + assert "energy-data" in analyze.outputs + + def test_ascii_art_process_stage(self): + """Process stage requires analyze and outputs result.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + process = next(s for s in compiled.stages if s.name == "process") + assert process.requires == ["analyze"] + assert "energy-data" in process.inputs + assert "result" in process.outputs + assert "audio-clip" in process.outputs + + def test_ascii_art_output_stage(self): + """Output stage requires process and has mux.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + output = next(s for s in compiled.stages if s.name == "output") + assert output.requires == ["process"] + assert "result" in output.inputs + assert "audio-clip" in output.inputs + + def test_ascii_art_node_count(self): + """ASCII art generates expected nodes.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + # 2 SOURCE + 2 SEGMENT + 1 ANALYZE + 1 EFFECT + 1 MUX = 7+ nodes + assert len(compiled.nodes) >= 7 + + def test_ascii_art_has_mux_output(self): + """ASCII art output is MUX node.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id) + assert output_node["type"] == "MUX" + + +class TestMixedStagedAndNonStagedRecipes: + """Test that non-staged recipes still work.""" + + def test_recipe_without_stages(self): + """Non-staged recipe compiles normally.""" + recipe = ''' + (recipe "no-stages" + (-> (source :path "test.mp3") + (effect gain :amount 0.5))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert compiled.stages == [] + assert compiled.stage_order == [] + # Should still have nodes + assert len(compiled.nodes) > 0 + + def test_mixed_pre_stage_and_stages(self): + """Pre-stage definitions work with stages.""" + recipe = ''' + (recipe "mixed" + ;; Pre-stage definitions + (def audio (source :path "test.mp3")) + (def volume 0.8) + + ;; Stage using pre-stage definitions, ending with output expression + (stage :process + :outputs [result] + (def result (-> audio (effect gain :amount volume))) + result)) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 1 + # audio and volume should be accessible in stage + process = compiled.stages[0] + assert process.name == "process" + assert "result" in process.outputs + + +class TestEffectParamsBlock: + """Test :params block parsing in effect definitions.""" + + def test_parse_effect_with_params_block(self): + """Parse effect with new :params syntax.""" + from .effect_loader import load_sexp_effect + + effect_code = ''' + (define-effect test_effect + :params ( + (size :type int :default 10 :range [1 100] :desc "Size parameter") + (color :type string :default "red" :desc "Color parameter") + (enabled :type int :default 1 :range [0 1] :desc "Enable flag") + ) + frame) + ''' + name, process_fn, defaults, param_defs = load_sexp_effect(effect_code) + + assert name == "test_effect" + assert len(param_defs) == 3 + assert defaults["size"] == 10 + assert defaults["color"] == "red" + assert defaults["enabled"] == 1 + + # Check ParamDef objects + size_param = param_defs[0] + assert size_param.name == "size" + assert size_param.param_type == "int" + assert size_param.default == 10 + assert size_param.range_min == 1.0 + assert size_param.range_max == 100.0 + assert size_param.description == "Size parameter" + + color_param = param_defs[1] + assert color_param.name == "color" + assert color_param.param_type == "string" + assert color_param.default == "red" + + def test_parse_effect_with_choices(self): + """Parse effect with choices in :params.""" + from .effect_loader import load_sexp_effect + + effect_code = ''' + (define-effect mode_effect + :params ( + (mode :type string :default "fast" + :choices [fast slow medium] + :desc "Processing mode") + ) + frame) + ''' + name, _, defaults, param_defs = load_sexp_effect(effect_code) + + assert name == "mode_effect" + assert defaults["mode"] == "fast" + + mode_param = param_defs[0] + assert mode_param.choices == ["fast", "slow", "medium"] + + def test_legacy_effect_syntax_rejected(self): + """Legacy effect syntax should be rejected.""" + from .effect_loader import load_sexp_effect + import pytest + + effect_code = ''' + (define-effect legacy_effect + ((width 100) + (height 200) + (name "default")) + frame) + ''' + with pytest.raises(ValueError) as exc_info: + load_sexp_effect(effect_code) + + assert "Legacy parameter syntax" in str(exc_info.value) + assert ":params" in str(exc_info.value) + + def test_effect_params_introspection(self): + """Test that effect params are available for introspection.""" + from .effect_loader import load_sexp_effect_file + from pathlib import Path + + # Create a temp effect file + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write(''' + (define-effect introspect_test + :params ( + (alpha :type float :default 0.5 :range [0 1] :desc "Alpha value") + ) + frame) + ''') + temp_path = Path(f.name) + + try: + name, _, defaults, param_defs = load_sexp_effect_file(temp_path) + assert name == "introspect_test" + assert len(param_defs) == 1 + assert param_defs[0].name == "alpha" + assert param_defs[0].param_type == "float" + finally: + temp_path.unlink() + + +class TestConstructParamsBlock: + """Test :params block parsing in construct definitions.""" + + def test_parse_construct_params_helper(self): + """Test the _parse_construct_params helper function.""" + from .planner import _parse_construct_params + from .parser import Symbol, Keyword + + params_list = [ + [Symbol("duration"), Keyword("type"), Symbol("float"), + Keyword("default"), 5.0, Keyword("desc"), "Duration in seconds"], + [Symbol("count"), Keyword("type"), Symbol("int"), + Keyword("default"), 10], + ] + + param_names, param_defaults = _parse_construct_params(params_list) + + assert param_names == ["duration", "count"] + assert param_defaults["duration"] == 5.0 + assert param_defaults["count"] == 10 + + def test_construct_params_with_no_defaults(self): + """Test construct params where some have no default.""" + from .planner import _parse_construct_params + from .parser import Symbol, Keyword + + params_list = [ + [Symbol("required_param"), Keyword("type"), Symbol("string")], + [Symbol("optional_param"), Keyword("type"), Symbol("int"), + Keyword("default"), 42], + ] + + param_names, param_defaults = _parse_construct_params(params_list) + + assert param_names == ["required_param", "optional_param"] + assert param_defaults["required_param"] is None + assert param_defaults["optional_param"] == 42 + + +class TestParameterValidation: + """Test that unknown parameters are rejected.""" + + def test_effect_rejects_unknown_params(self): + """Effects should reject unknown parameters.""" + from .effect_loader import load_sexp_effect + import numpy as np + import pytest + + effect_code = ''' + (define-effect test_effect + :params ( + (brightness :type int :default 0 :desc "Brightness") + ) + frame) + ''' + name, process_frame, defaults, _ = load_sexp_effect(effect_code) + + # Create a test frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + state = {} + + # Valid param should work + result, _ = process_frame(frame, {"brightness": 10}, state) + assert isinstance(result, np.ndarray) + + # Unknown param should raise + with pytest.raises(ValueError) as exc_info: + process_frame(frame, {"unknown_param": 42}, state) + + assert "Unknown parameter 'unknown_param'" in str(exc_info.value) + assert "brightness" in str(exc_info.value) + + def test_effect_no_params_rejects_all(self): + """Effects with no params should reject any parameter.""" + from .effect_loader import load_sexp_effect + import numpy as np + import pytest + + effect_code = ''' + (define-effect no_params_effect + :params () + frame) + ''' + name, process_frame, defaults, _ = load_sexp_effect(effect_code) + + # Create a test frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + state = {} + + # Empty params should work + result, _ = process_frame(frame, {}, state) + assert isinstance(result, np.ndarray) + + # Any param should raise + with pytest.raises(ValueError) as exc_info: + process_frame(frame, {"any_param": 42}, state) + + assert "Unknown parameter 'any_param'" in str(exc_info.value) + assert "(none)" in str(exc_info.value) diff --git a/artdag/sexp/test_stage_planner.py b/artdag/sexp/test_stage_planner.py new file mode 100644 index 0000000..51d6d33 --- /dev/null +++ b/artdag/sexp/test_stage_planner.py @@ -0,0 +1,228 @@ +""" +Tests for stage-aware planning. + +Tests stage topological sorting, level computation, cache ID computation, +and plan metadata generation. +""" + +import pytest +from pathlib import Path + +from .parser import parse +from .compiler import compile_recipe, CompiledStage +from .planner import ( + create_plan, + StagePlan, + _compute_stage_levels, + _compute_stage_cache_id, +) + + +class TestStagePlanning: + """Test stage-aware plan creation.""" + + def test_stage_topological_sort_in_plan(self): + """Stages sorted by dependencies in plan.""" + recipe = ''' + (recipe "test-sort" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + # Note: create_plan needs recipe_dir for analysis, we'll test the ordering differently + assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output") + + def test_stage_level_computation(self): + """Independent stages get same level.""" + stages = [ + CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + CompiledStage(name="b", requires=[], inputs=[], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + CompiledStage(name="c", requires=["a", "b"], inputs=["x", "y"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + ] + levels = _compute_stage_levels(stages) + + assert levels["a"] == 0 + assert levels["b"] == 0 + assert levels["c"] == 1 # Depends on a and b + + def test_stage_level_chain(self): + """Chain stages get increasing levels.""" + stages = [ + CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + ] + levels = _compute_stage_levels(stages) + + assert levels["a"] == 0 + assert levels["b"] == 1 + assert levels["c"] == 2 + + def test_stage_cache_id_deterministic(self): + """Same stage = same cache ID.""" + stage = CompiledStage( + name="analyze", + requires=[], + inputs=[], + outputs=["beats"], + node_ids=["abc123"], + output_bindings={"beats": "abc123"}, + ) + + cache_id_1 = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + cache_id_2 = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + + assert cache_id_1 == cache_id_2 + + def test_stage_cache_id_includes_requires(self): + """Cache ID changes when required stage cache ID changes.""" + stage = CompiledStage( + name="process", + requires=["analyze"], + inputs=["beats"], + outputs=["result"], + node_ids=["def456"], + output_bindings={"result": "def456"}, + ) + + cache_id_1 = _compute_stage_cache_id( + stage, + stage_cache_ids={"analyze": "req_cache_a"}, + node_cache_ids={"def456": "node_def"}, + cluster_key=None, + ) + cache_id_2 = _compute_stage_cache_id( + stage, + stage_cache_ids={"analyze": "req_cache_b"}, + node_cache_ids={"def456": "node_def"}, + cluster_key=None, + ) + + # Different required stage cache IDs should produce different cache IDs + assert cache_id_1 != cache_id_2 + + def test_stage_cache_id_cluster_key(self): + """Cache ID changes with cluster key.""" + stage = CompiledStage( + name="analyze", + requires=[], + inputs=[], + outputs=["beats"], + node_ids=["abc123"], + output_bindings={"beats": "abc123"}, + ) + + cache_id_no_key = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + cache_id_with_key = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key="cluster123", + ) + + # Cluster key should change the cache ID + assert cache_id_no_key != cache_id_with_key + + +class TestStagePlanMetadata: + """Test stage metadata in execution plans.""" + + def test_plan_without_stages(self): + """Plan without stages has empty stage fields.""" + recipe = ''' + (recipe "no-stages" + (-> (source :path "test.mp3") (effect gain :amount 0.5))) + ''' + compiled = compile_recipe(parse(recipe)) + assert compiled.stages == [] + assert compiled.stage_order == [] + + +class TestStagePlanDataclass: + """Test StagePlan dataclass.""" + + def test_stage_plan_creation(self): + """StagePlan can be created with all fields.""" + from .planner import PlanStep + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={"analyzer": "beats"}, + inputs=["input1"], + cache_id="cache123", + level=0, + stage="analyze", + stage_cache_id="stage_cache_123", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="stage_cache_123", + steps=[step], + requires=[], + output_bindings={"beats": "cache123"}, + level=0, + ) + + assert stage_plan.stage_name == "analyze" + assert stage_plan.cache_id == "stage_cache_123" + assert len(stage_plan.steps) == 1 + assert stage_plan.level == 0 + + +class TestExplicitDataRouting: + """Test that plan includes explicit data routing.""" + + def test_plan_step_includes_stage_info(self): + """PlanStep includes stage and stage_cache_id.""" + from .planner import PlanStep + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache123", + level=0, + stage="analyze", + stage_cache_id="stage_cache_abc", + ) + + sexp = step.to_sexp() + # Convert to string to check for stage info + from .parser import serialize + sexp_str = serialize(sexp) + + assert "stage" in sexp_str + assert "analyze" in sexp_str + assert "stage-cache-id" in sexp_str diff --git a/artdag/sexp/test_stage_scheduler.py b/artdag/sexp/test_stage_scheduler.py new file mode 100644 index 0000000..c7bab64 --- /dev/null +++ b/artdag/sexp/test_stage_scheduler.py @@ -0,0 +1,323 @@ +""" +Tests for stage-aware scheduler. + +Tests stage cache hit/miss, stage execution ordering, +and parallel stage support. +""" + +import pytest +import tempfile +from unittest.mock import Mock, MagicMock, patch + +from .scheduler import ( + StagePlanScheduler, + StageResult, + StagePlanResult, + create_stage_scheduler, + schedule_staged_plan, +) +from .planner import ExecutionPlanSexp, PlanStep, StagePlan +from .stage_cache import StageCache, StageCacheEntry, StageOutput + + +class TestStagePlanScheduler: + """Test stage-aware scheduling.""" + + def test_plan_without_stages_uses_regular_scheduling(self): + """Plans without stages fall back to regular scheduling.""" + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[], + output_step_id="output", + stage_plans=[], # No stages + ) + + scheduler = StagePlanScheduler() + # This will use PlanScheduler internally + # Without Celery, it just returns completed status + result = scheduler.schedule(plan) + + assert isinstance(result, StagePlanResult) + + def test_stage_cache_hit_skips_execution(self): + """Cached stage not re-executed.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Pre-populate cache + entry = StageCacheEntry( + stage_name="analyze", + cache_id="stage_cache_123", + outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")}, + ) + stage_cache.save_stage(entry) + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="step_cache", + level=0, + stage="analyze", + stage_cache_id="stage_cache_123", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="stage_cache_123", + steps=[step], + requires=[], + output_bindings={"beats": "beats_out"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step], + output_step_id="step1", + stage_plans=[stage_plan], + stage_order=["analyze"], + stage_levels={"analyze": 0}, + stage_cache_ids={"analyze": "stage_cache_123"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + assert result.stages_cached == 1 + assert result.stages_completed == 0 + + def test_stage_inputs_loaded_from_cache(self): + """Stage receives inputs from required stage cache.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Pre-populate upstream stage cache + upstream_entry = StageCacheEntry( + stage_name="analyze", + cache_id="upstream_cache", + outputs={"beats": StageOutput(cache_id="beats_data", output_type="analysis")}, + ) + stage_cache.save_stage(upstream_entry) + + # Steps for stages + upstream_step = PlanStep( + step_id="analyze_step", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="analyze_cache", + level=0, + stage="analyze", + stage_cache_id="upstream_cache", + ) + + downstream_step = PlanStep( + step_id="process_step", + node_type="SEGMENT", + config={}, + inputs=["analyze_step"], + cache_id="process_cache", + level=1, + stage="process", + stage_cache_id="downstream_cache", + ) + + upstream_plan = StagePlan( + stage_name="analyze", + cache_id="upstream_cache", + steps=[upstream_step], + requires=[], + output_bindings={"beats": "beats_data"}, + level=0, + ) + + downstream_plan = StagePlan( + stage_name="process", + cache_id="downstream_cache", + steps=[downstream_step], + requires=["analyze"], + output_bindings={"result": "process_cache"}, + level=1, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[upstream_step, downstream_step], + output_step_id="process_step", + stage_plans=[upstream_plan, downstream_plan], + stage_order=["analyze", "process"], + stage_levels={"analyze": 0, "process": 1}, + stage_cache_ids={"analyze": "upstream_cache", "process": "downstream_cache"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + # Upstream should be cached, downstream executed + assert result.stages_cached == 1 + assert "analyze" in result.stage_results + assert result.stage_results["analyze"].status == "cached" + + def test_parallel_stages_same_level(self): + """Stages at same level can run in parallel.""" + step_a = PlanStep( + step_id="step_a", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache_a", + level=0, + stage="analyze-a", + stage_cache_id="stage_a", + ) + + step_b = PlanStep( + step_id="step_b", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache_b", + level=0, + stage="analyze-b", + stage_cache_id="stage_b", + ) + + stage_a = StagePlan( + stage_name="analyze-a", + cache_id="stage_a", + steps=[step_a], + requires=[], + output_bindings={"beats-a": "cache_a"}, + level=0, + ) + + stage_b = StagePlan( + stage_name="analyze-b", + cache_id="stage_b", + steps=[step_b], + requires=[], + output_bindings={"beats-b": "cache_b"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step_a, step_b], + output_step_id="step_b", + stage_plans=[stage_a, stage_b], + stage_order=["analyze-a", "analyze-b"], + stage_levels={"analyze-a": 0, "analyze-b": 0}, + stage_cache_ids={"analyze-a": "stage_a", "analyze-b": "stage_b"}, + ) + + scheduler = StagePlanScheduler() + # Group stages by level + stages_by_level = scheduler._group_stages_by_level(plan.stage_plans) + + # Both stages should be at level 0 + assert len(stages_by_level[0]) == 2 + + def test_stage_outputs_cached_after_execution(self): + """Stage outputs written to cache after completion.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="step_cache", + level=0, + stage="analyze", + stage_cache_id="new_stage_cache", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="new_stage_cache", + steps=[step], + requires=[], + output_bindings={"beats": "step_cache"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step], + output_step_id="step1", + stage_plans=[stage_plan], + stage_order=["analyze"], + stage_levels={"analyze": 0}, + stage_cache_ids={"analyze": "new_stage_cache"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + # Stage should now be cached + assert stage_cache.has_stage("new_stage_cache") + + +class TestStageResult: + """Test StageResult dataclass.""" + + def test_stage_result_creation(self): + """StageResult can be created with all fields.""" + result = StageResult( + stage_name="test", + cache_id="cache123", + status="completed", + step_results={}, + outputs={"out": "out_cache"}, + ) + + assert result.stage_name == "test" + assert result.status == "completed" + assert result.outputs["out"] == "out_cache" + + +class TestStagePlanResult: + """Test StagePlanResult dataclass.""" + + def test_stage_plan_result_creation(self): + """StagePlanResult can be created with all fields.""" + result = StagePlanResult( + plan_id="plan123", + status="completed", + stages_completed=2, + stages_cached=1, + stages_failed=0, + ) + + assert result.plan_id == "plan123" + assert result.stages_completed == 2 + assert result.stages_cached == 1 + + +class TestSchedulerFactory: + """Test scheduler factory functions.""" + + def test_create_stage_scheduler(self): + """create_stage_scheduler returns StagePlanScheduler.""" + scheduler = create_stage_scheduler() + assert isinstance(scheduler, StagePlanScheduler) + + def test_create_stage_scheduler_with_cache(self): + """create_stage_scheduler accepts stage_cache.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + scheduler = create_stage_scheduler(stage_cache=stage_cache) + assert scheduler.stage_cache is stage_cache diff --git a/docs/EXECUTION_MODEL.md b/docs/EXECUTION_MODEL.md new file mode 100644 index 0000000..6779721 --- /dev/null +++ b/docs/EXECUTION_MODEL.md @@ -0,0 +1,384 @@ +# Art DAG 3-Phase Execution Model + +## Overview + +The execution model separates DAG processing into three distinct phases: + +``` +Recipe + Inputs → ANALYZE → Analysis Results + ↓ +Analysis + Recipe → PLAN → Execution Plan (with cache IDs) + ↓ +Execution Plan → EXECUTE → Cached Results +``` + +This separation enables: +1. **Incremental development** - Re-run recipes without reprocessing unchanged steps +2. **Parallel execution** - Independent steps run concurrently via Celery +3. **Deterministic caching** - Same inputs always produce same cache IDs +4. **Cost estimation** - Plan phase can estimate work before executing + +## Phase 1: Analysis + +### Purpose +Extract features from input media that inform downstream processing decisions. + +### Inputs +- Recipe YAML with input references +- Input media files (by content hash) + +### Outputs +Analysis results stored as JSON, keyed by input hash: + +```python +@dataclass +class AnalysisResult: + input_hash: str + features: Dict[str, Any] + # Audio features + beats: Optional[List[float]] # Beat times in seconds + downbeats: Optional[List[float]] # Bar-start times + tempo: Optional[float] # BPM + energy: Optional[List[Tuple[float, float]]] # (time, value) envelope + spectrum: Optional[Dict[str, List[Tuple[float, float]]]] # band envelopes + # Video features + duration: float + frame_rate: float + dimensions: Tuple[int, int] + motion_tempo: Optional[float] # Estimated BPM from motion +``` + +### Implementation +```python +class Analyzer: + def analyze(self, input_hash: str, features: List[str]) -> AnalysisResult: + """Extract requested features from input.""" + + def analyze_audio(self, path: Path) -> AudioFeatures: + """Extract all audio features using librosa/essentia.""" + + def analyze_video(self, path: Path) -> VideoFeatures: + """Extract video metadata and motion analysis.""" +``` + +### Caching +Analysis results are cached by: +``` +analysis_cache_id = SHA3-256(input_hash + sorted(feature_names)) +``` + +## Phase 2: Planning + +### Purpose +Convert recipe + analysis into a complete execution plan with pre-computed cache IDs. + +### Inputs +- Recipe YAML (parsed) +- Analysis results for all inputs +- Recipe parameters (user-supplied values) + +### Outputs +An ExecutionPlan containing ordered steps, each with a pre-computed cache ID: + +```python +@dataclass +class ExecutionStep: + step_id: str # Unique identifier + node_type: str # Primitive type (SOURCE, SEQUENCE, etc.) + config: Dict[str, Any] # Node configuration + input_steps: List[str] # IDs of steps this depends on + cache_id: str # Pre-computed: hash(inputs + config) + estimated_duration: float # Optional: for progress reporting + +@dataclass +class ExecutionPlan: + plan_id: str # Hash of entire plan + recipe_id: str # Source recipe + steps: List[ExecutionStep] # Topologically sorted + analysis: Dict[str, AnalysisResult] + output_step: str # Final step ID + + def compute_cache_ids(self): + """Compute all cache IDs in dependency order.""" +``` + +### Cache ID Computation + +Cache IDs are computed in topological order so each step's cache ID +incorporates its inputs' cache IDs: + +```python +def compute_cache_id(step: ExecutionStep, resolved_inputs: Dict[str, str]) -> str: + """ + Cache ID = SHA3-256( + node_type + + canonical_json(config) + + sorted([input_cache_ids]) + ) + """ + components = [ + step.node_type, + json.dumps(step.config, sort_keys=True), + *sorted(resolved_inputs[s] for s in step.input_steps) + ] + return sha3_256('|'.join(components)) +``` + +### Plan Generation + +The planner expands recipe nodes into concrete steps: + +1. **SOURCE nodes** → Direct step with input hash as cache ID +2. **ANALYZE nodes** → Step that references analysis results +3. **TRANSFORM nodes** → Step with static config +4. **TRANSFORM_DYNAMIC nodes** → Expanded to per-frame steps (or use BIND output) +5. **SEQUENCE nodes** → Tree reduction for parallel composition +6. **MAP nodes** → Expanded to N parallel steps + reduction + +### Tree Reduction for Composition + +Instead of sequential pairwise composition: +``` +A → B → C → D (3 sequential steps) +``` + +Use parallel tree reduction: +``` +A ─┬─ AB ─┬─ ABCD +B ─┘ │ +C ─┬─ CD ─┘ +D ─┘ + +Level 0: [A, B, C, D] (4 parallel) +Level 1: [AB, CD] (2 parallel) +Level 2: [ABCD] (1 final) +``` + +This reduces O(N) to O(log N) levels. + +## Phase 3: Execution + +### Purpose +Execute the plan, skipping steps with cached results. + +### Inputs +- ExecutionPlan with pre-computed cache IDs +- Cache state (which IDs already exist) + +### Process + +1. **Claim Check**: For each step, atomically check if result is cached +2. **Task Dispatch**: Uncached steps dispatched to Celery workers +3. **Parallel Execution**: Independent steps run concurrently +4. **Result Storage**: Each step stores result with its cache ID +5. **Progress Tracking**: Real-time status updates + +### Hash-Based Task Claiming + +Prevents duplicate work when multiple workers process the same plan: + +```lua +-- Redis Lua script for atomic claim +local key = KEYS[1] +local data = redis.call('GET', key) +if data then + local status = cjson.decode(data) + if status.status == 'running' or + status.status == 'completed' or + status.status == 'cached' then + return 0 -- Already claimed/done + end +end +local claim_data = ARGV[1] +local ttl = tonumber(ARGV[2]) +redis.call('SETEX', key, ttl, claim_data) +return 1 -- Successfully claimed +``` + +### Celery Task Structure + +```python +@app.task(bind=True) +def execute_step(self, step_json: str, plan_id: str) -> dict: + """Execute a single step with caching.""" + step = ExecutionStep.from_json(step_json) + + # Check cache first + if cache.has(step.cache_id): + return {'status': 'cached', 'cache_id': step.cache_id} + + # Try to claim this work + if not claim_task(step.cache_id, self.request.id): + # Another worker is handling it, wait for result + return wait_for_result(step.cache_id) + + # Do the work + executor = get_executor(step.node_type) + input_paths = [cache.get(s) for s in step.input_steps] + output_path = cache.get_output_path(step.cache_id) + + result_path = executor.execute(step.config, input_paths, output_path) + cache.put(step.cache_id, result_path) + + return {'status': 'completed', 'cache_id': step.cache_id} +``` + +### Execution Orchestration + +```python +class PlanExecutor: + def execute(self, plan: ExecutionPlan) -> ExecutionResult: + """Execute plan with parallel Celery tasks.""" + + # Group steps by level (steps at same level can run in parallel) + levels = self.compute_dependency_levels(plan.steps) + + for level_steps in levels: + # Dispatch all steps at this level + tasks = [ + execute_step.delay(step.to_json(), plan.plan_id) + for step in level_steps + if not self.cache.has(step.cache_id) + ] + + # Wait for level completion + results = [task.get() for task in tasks] + + return self.collect_results(plan) +``` + +## Data Flow Example + +### Recipe: beat-cuts +```yaml +nodes: + - id: music + type: SOURCE + config: { input: true } + + - id: beats + type: ANALYZE + config: { feature: beats } + inputs: [music] + + - id: videos + type: SOURCE_LIST + config: { input: true } + + - id: slices + type: MAP + config: { operation: RANDOM_SLICE } + inputs: + items: videos + timing: beats + + - id: final + type: SEQUENCE + inputs: [slices] +``` + +### Phase 1: Analysis +```python +# Input: music file with hash abc123 +analysis = { + 'abc123': AnalysisResult( + beats=[0.0, 0.48, 0.96, 1.44, ...], + tempo=125.0, + duration=180.0 + ) +} +``` + +### Phase 2: Planning +```python +# Expands MAP into concrete steps +plan = ExecutionPlan( + steps=[ + # Source steps + ExecutionStep(id='music', cache_id='abc123', ...), + ExecutionStep(id='video_0', cache_id='def456', ...), + ExecutionStep(id='video_1', cache_id='ghi789', ...), + + # Slice steps (one per beat group) + ExecutionStep(id='slice_0', cache_id='hash(video_0+timing)', ...), + ExecutionStep(id='slice_1', cache_id='hash(video_1+timing)', ...), + ... + + # Tree reduction for sequence + ExecutionStep(id='seq_0_1', inputs=['slice_0', 'slice_1'], ...), + ExecutionStep(id='seq_2_3', inputs=['slice_2', 'slice_3'], ...), + ExecutionStep(id='seq_final', inputs=['seq_0_1', 'seq_2_3'], ...), + ] +) +``` + +### Phase 3: Execution +``` +Level 0: [music, video_0, video_1] → all cached (SOURCE) +Level 1: [slice_0, slice_1, slice_2, slice_3] → 4 parallel tasks +Level 2: [seq_0_1, seq_2_3] → 2 parallel SEQUENCE tasks +Level 3: [seq_final] → 1 final SEQUENCE task +``` + +## File Structure + +``` +artdag/ +├── artdag/ +│ ├── analysis/ +│ │ ├── __init__.py +│ │ ├── analyzer.py # Main Analyzer class +│ │ ├── audio.py # Audio feature extraction +│ │ └── video.py # Video feature extraction +│ ├── planning/ +│ │ ├── __init__.py +│ │ ├── planner.py # RecipePlanner class +│ │ ├── schema.py # ExecutionPlan, ExecutionStep +│ │ └── tree_reduction.py # Parallel composition optimizer +│ └── execution/ +│ ├── __init__.py +│ ├── executor.py # PlanExecutor class +│ └── claiming.py # Hash-based task claiming + +art-celery/ +├── tasks/ +│ ├── __init__.py +│ ├── analyze.py # analyze_inputs task +│ ├── plan.py # generate_plan task +│ ├── execute.py # execute_step task +│ └── orchestrate.py # run_plan (coordinates all) +├── claiming.py # Redis Lua scripts +└── ... +``` + +## CLI Interface + +```bash +# Full pipeline +artdag run-recipe recipes/beat-cuts/recipe.yaml \ + -i music:abc123 \ + -i videos:def456,ghi789 + +# Phase by phase +artdag analyze recipes/beat-cuts/recipe.yaml -i music:abc123 +# → outputs analysis.json + +artdag plan recipes/beat-cuts/recipe.yaml --analysis analysis.json +# → outputs plan.json + +artdag execute plan.json +# → runs with caching, skips completed steps + +# Dry run (show what would execute) +artdag execute plan.json --dry-run +# → shows which steps are cached vs need execution +``` + +## Benefits + +1. **Development Speed**: Change recipe, re-run → only affected steps execute +2. **Parallelism**: Independent steps run on multiple Celery workers +3. **Reproducibility**: Same inputs + recipe = same cache IDs = same output +4. **Visibility**: Plan shows exactly what will happen before execution +5. **Cost Control**: Estimate compute before committing resources +6. **Fault Tolerance**: Failed runs resume from last successful step diff --git a/docs/IPFS_PRIMARY_ARCHITECTURE.md b/docs/IPFS_PRIMARY_ARCHITECTURE.md new file mode 100644 index 0000000..2e53aaf --- /dev/null +++ b/docs/IPFS_PRIMARY_ARCHITECTURE.md @@ -0,0 +1,443 @@ +# IPFS-Primary Architecture (Sketch) + +A simplified L1 architecture for large-scale distributed rendering where IPFS is the primary data store. + +## Current vs Simplified + +| Component | Current | Simplified | +|-----------|---------|------------| +| Local cache | Custom, per-worker | IPFS node handles it | +| Redis content_index | content_hash → node_id | Eliminated | +| Redis ipfs_index | content_hash → ipfs_cid | Eliminated | +| Step inputs | File paths | IPFS CIDs | +| Step outputs | File path + CID | Just CID | +| Cache lookup | Local → Redis → IPFS | Just IPFS | + +## Core Principle + +**Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.** + +``` +Step input: [cid1, cid2, ...] +Step output: cid_out +``` + +## Worker Architecture + +Each worker runs: + +``` +┌─────────────────────────────────────┐ +│ Worker Node │ +│ │ +│ ┌───────────┐ ┌──────────────┐ │ +│ │ Celery │────│ IPFS Node │ │ +│ │ Worker │ │ (local) │ │ +│ └───────────┘ └──────────────┘ │ +│ │ │ │ +│ │ ┌─────┴─────┐ │ +│ │ │ Local │ │ +│ │ │ Blockstore│ │ +│ │ └───────────┘ │ +│ │ │ +│ ┌────┴────┐ │ +│ │ /tmp │ (ephemeral workspace) │ +│ └─────────┘ │ +└─────────────────────────────────────┘ + │ + │ IPFS libp2p + ▼ + ┌─────────────┐ + │ Other IPFS │ + │ Nodes │ + └─────────────┘ +``` + +## Execution Flow + +### 1. Plan Generation (unchanged) + +```python +plan = planner.plan(recipe, input_hashes) +# plan.steps[].cache_id = deterministic hash +``` + +### 2. Input Registration + +Before execution, register inputs with IPFS: + +```python +input_cids = {} +for name, path in inputs.items(): + cid = ipfs.add(path) + input_cids[name] = cid + +# Plan now carries CIDs +plan.input_cids = input_cids +``` + +### 3. Step Execution + +```python +@celery.task +def execute_step(step_json: str, input_cids: dict[str, str]) -> str: + """Execute step, return output CID.""" + step = ExecutionStep.from_json(step_json) + + # Check if already computed (by cache_id as IPNS key or DHT lookup) + existing_cid = ipfs.resolve(f"/ipns/{step.cache_id}") + if existing_cid: + return existing_cid + + # Fetch inputs from IPFS → local temp files + input_paths = [] + for input_step_id in step.input_steps: + cid = input_cids[input_step_id] + path = ipfs.get(cid, f"/tmp/{cid}") # IPFS node caches automatically + input_paths.append(path) + + # Execute + output_path = f"/tmp/{step.cache_id}.mkv" + executor = get_executor(step.node_type) + executor.execute(step.config, input_paths, output_path) + + # Add output to IPFS + output_cid = ipfs.add(output_path) + + # Publish cache_id → CID mapping (optional, for cache hits) + ipfs.name_publish(step.cache_id, output_cid) + + # Cleanup temp files + cleanup_temp(input_paths + [output_path]) + + return output_cid +``` + +### 4. Orchestration + +```python +@celery.task +def run_plan(plan_json: str) -> str: + """Execute plan, return final output CID.""" + plan = ExecutionPlan.from_json(plan_json) + + # CID results accumulate as steps complete + cid_results = dict(plan.input_cids) + + for level in plan.get_steps_by_level(): + # Parallel execution within level + tasks = [] + for step in level: + step_input_cids = { + sid: cid_results[sid] + for sid in step.input_steps + } + tasks.append(execute_step.s(step.to_json(), step_input_cids)) + + # Wait for level to complete + results = group(tasks).apply_async().get() + + # Record output CIDs + for step, cid in zip(level, results): + cid_results[step.step_id] = cid + + return cid_results[plan.output_step] +``` + +## What's Eliminated + +### No more Redis indexes + +```python +# BEFORE: Complex index management +self._set_content_index(content_hash, node_id) # Redis + local +self._set_ipfs_index(content_hash, ipfs_cid) # Redis + local +node_id = self._get_content_index(content_hash) # Check Redis, fallback local + +# AFTER: Just CIDs +output_cid = ipfs.add(output_path) +return output_cid +``` + +### No more local cache management + +```python +# BEFORE: Custom cache with entries, metadata, cleanup +cache.put(node_id, source_path, node_type, execution_time) +cache.get(node_id) +cache.has(node_id) +cache.cleanup_lru() + +# AFTER: IPFS handles it +ipfs.add(path) # Store +ipfs.get(cid) # Retrieve (cached by IPFS node) +ipfs.pin(cid) # Keep permanently +ipfs.gc() # Cleanup unpinned +``` + +### No more content_hash vs node_id confusion + +```python +# BEFORE: Two identifiers +content_hash = sha3_256(file_bytes) # What the file IS +node_id = cache_id # What computation produced it +# Need indexes to map between them + +# AFTER: One identifier +cid = ipfs.add(file) # Content-addressed, includes hash +# CID IS the identifier +``` + +## Cache Hit Detection + +Two options: + +### Option A: IPNS (mutable names) + +```python +# Publish: cache_id → CID +ipfs.name_publish(key=cache_id, value=output_cid) + +# Lookup before executing +existing = ipfs.name_resolve(cache_id) +if existing: + return existing # Cache hit +``` + +### Option B: DHT record + +```python +# Store in DHT: cache_id → CID +ipfs.dht_put(cache_id, output_cid) + +# Lookup +existing = ipfs.dht_get(cache_id) +``` + +### Option C: Redis (minimal) + +Keep Redis just for cache_id → CID mapping: + +```python +# Store +redis.hset("artdag:cache", cache_id, output_cid) + +# Lookup +existing = redis.hget("artdag:cache", cache_id) +``` + +This is simpler than current approach - one hash, one mapping, no content_hash/node_id confusion. + +## Claiming (Preventing Duplicate Work) + +Still need Redis for atomic claiming: + +```python +# Claim before executing +claimed = redis.set(f"artdag:claim:{cache_id}", worker_id, nx=True, ex=300) +if not claimed: + # Another worker is doing it - wait for result + return wait_for_result(cache_id) +``` + +Or use IPFS pubsub for coordination. + +## Data Flow Diagram + +``` + ┌─────────────┐ + │ Recipe │ + │ + Inputs │ + └──────┬──────┘ + │ + ▼ + ┌─────────────┐ + │ Planner │ + │ (compute │ + │ cache_ids) │ + └──────┬──────┘ + │ + ▼ + ┌─────────────────────────────────┐ + │ ExecutionPlan │ + │ - steps with cache_ids │ + │ - input_cids (from ipfs.add) │ + └─────────────────┬───────────────┘ + │ + ┌────────────┼────────────┐ + ▼ ▼ ▼ + ┌────────┐ ┌────────┐ ┌────────┐ + │Worker 1│ │Worker 2│ │Worker 3│ + │ │ │ │ │ │ + │ IPFS │◄──│ IPFS │◄──│ IPFS │ + │ Node │──►│ Node │──►│ Node │ + └───┬────┘ └───┬────┘ └───┬────┘ + │ │ │ + └────────────┼────────────┘ + │ + ▼ + ┌─────────────┐ + │ Final CID │ + │ (output) │ + └─────────────┘ +``` + +## Benefits + +1. **Simpler code** - No custom cache, no dual indexes +2. **Automatic distribution** - IPFS handles replication +3. **Content verification** - CIDs are self-verifying +4. **Scalable** - Add workers = add IPFS nodes = more cache capacity +5. **Resilient** - Any node can serve any content + +## Tradeoffs + +1. **IPFS dependency** - Every worker needs IPFS node +2. **Initial fetch latency** - First fetch may be slower than local disk +3. **IPNS latency** - Name resolution can be slow (Option C avoids this) + +## Trust Domains (Cluster Key) + +Systems can share work through IPFS, but how do you trust them? + +**Problem:** A malicious system could return wrong CIDs for computed steps. + +**Solution:** Cluster key creates isolated trust domains: + +```bash +export ARTDAG_CLUSTER_KEY="my-secret-shared-key" +``` + +**How it works:** +- The cluster key is mixed into all cache_id computations +- Systems with the same key produce the same cache_ids +- Systems with different keys have separate cache namespaces +- Only share the key with trusted partners + +``` +cache_id = SHA3-256(cluster_key + node_type + config + inputs) +``` + +**Trust model:** +| Scenario | Same Key? | Can Share Work? | +|----------|-----------|-----------------| +| Same organization | Yes | Yes | +| Trusted partner | Yes (shared) | Yes | +| Unknown system | No | No (different cache_ids) | + +**Configuration:** +```yaml +# docker-compose.yml +environment: + - ARTDAG_CLUSTER_KEY=your-secret-key-here +``` + +**Programmatic:** +```python +from artdag.planning.schema import set_cluster_key +set_cluster_key("my-secret-key") +``` + +## Implementation + +The simplified architecture is implemented in `art-celery/`: + +| File | Purpose | +|------|---------| +| `hybrid_state.py` | Hybrid state manager (Redis + IPNS) | +| `tasks/execute_cid.py` | Step execution with CIDs | +| `tasks/analyze_cid.py` | Analysis with CIDs | +| `tasks/orchestrate_cid.py` | Full pipeline orchestration | + +### Key Functions + +**Registration (local → IPFS):** +- `register_input_cid(path)` → `{cid, content_hash}` +- `register_recipe_cid(path)` → `{cid, name, version}` + +**Analysis:** +- `analyze_input_cid(input_cid, input_hash, features)` → `{analysis_cid}` + +**Planning:** +- `generate_plan_cid(recipe_cid, input_cids, input_hashes, analysis_cids)` → `{plan_cid}` + +**Execution:** +- `execute_step_cid(step_json, input_cids)` → `{cid}` +- `execute_plan_from_cid(plan_cid, input_cids)` → `{output_cid}` + +**Full Pipeline:** +- `run_recipe_cid(recipe_cid, input_cids, input_hashes)` → `{output_cid, all_cids}` +- `run_from_local(recipe_path, input_paths)` → registers + runs + +### Hybrid State Manager + +For distributed L1 coordination, use the `HybridStateManager` which provides: + +**Fast path (local Redis):** +- `get_cached_cid(cache_id)` / `set_cached_cid(cache_id, cid)` - microsecond lookups +- `try_claim(cache_id, worker_id)` / `release_claim(cache_id)` - atomic claiming +- `get_analysis_cid()` / `set_analysis_cid()` - analysis cache +- `get_plan_cid()` / `set_plan_cid()` - plan cache +- `get_run_cid()` / `set_run_cid()` - run cache + +**Slow path (background IPNS sync):** +- Periodically syncs local state with global IPNS state (default: every 30s) +- Pulls new entries from remote nodes +- Pushes local updates to IPNS + +**Configuration:** +```bash +# Enable IPNS sync +export ARTDAG_IPNS_SYNC=true +export ARTDAG_IPNS_SYNC_INTERVAL=30 # seconds +``` + +**Usage:** +```python +from hybrid_state import get_state_manager + +state = get_state_manager() + +# Fast local lookup +cid = state.get_cached_cid(cache_id) + +# Fast local write (synced in background) +state.set_cached_cid(cache_id, output_cid) + +# Atomic claim +if state.try_claim(cache_id, worker_id): + # We have the lock + ... +``` + +**Trade-offs:** +- Local Redis: Fast (microseconds), single node +- IPNS sync: Slow (seconds), eventually consistent across nodes +- Duplicate work: Accepted (idempotent - same inputs → same CID) + +### Redis Usage (minimal) + +| Key | Type | Purpose | +|-----|------|---------| +| `artdag:cid_cache` | Hash | cache_id → output CID | +| `artdag:analysis_cache` | Hash | input_hash:features → analysis CID | +| `artdag:plan_cache` | Hash | plan_id → plan CID | +| `artdag:run_cache` | Hash | run_id → output CID | +| `artdag:claim:{cache_id}` | String | worker_id (TTL 5 min) | + +## Migration Path + +1. Keep current system working ✓ +2. Add CID-based tasks ✓ + - `execute_cid.py` ✓ + - `analyze_cid.py` ✓ + - `orchestrate_cid.py` ✓ +3. Add `--ipfs-primary` flag to CLI ✓ +4. Add hybrid state manager for L1 coordination ✓ +5. Gradually deprecate local cache code +6. Remove old tasks when CID versions are stable + +## See Also + +- [L1_STORAGE.md](L1_STORAGE.md) - Current L1 architecture +- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase model diff --git a/docs/L1_STORAGE.md b/docs/L1_STORAGE.md new file mode 100644 index 0000000..c371329 --- /dev/null +++ b/docs/L1_STORAGE.md @@ -0,0 +1,181 @@ +# L1 Distributed Storage Architecture + +This document describes how data is stored when running artdag on L1 (the distributed rendering layer). + +## Overview + +L1 uses four storage systems working together: + +| System | Purpose | Data Stored | +|--------|---------|-------------| +| **Local Cache** | Hot storage (fast access) | Media files, plans, analysis | +| **IPFS** | Durable content-addressed storage | All media outputs | +| **Redis** | Coordination & indexes | Claims, mappings, run status | +| **PostgreSQL** | Metadata & ownership | User data, provenance | + +## Storage Flow + +When a step executes on L1: + +``` +1. Executor produces output file +2. Store in local cache (fast) +3. Compute content_hash = SHA3-256(file) +4. Upload to IPFS → get ipfs_cid +5. Update indexes: + - content_hash → node_id (Redis + local) + - content_hash → ipfs_cid (Redis + local) +``` + +Every intermediate step output (SEGMENT, SEQUENCE, etc.) gets its own IPFS CID. + +## Local Cache + +Hot storage on each worker node: + +``` +cache_dir/ + index.json # Cache metadata + content_index.json # content_hash → node_id + ipfs_index.json # content_hash → ipfs_cid + plans/ + {plan_id}.json # Cached execution plans + analysis/ + {hash}.json # Analysis results + {node_id}/ + output.mkv # Media output + metadata.json # CacheEntry metadata +``` + +## IPFS - Durable Media Storage + +All media files are stored in IPFS for durability and content-addressing. + +**Supported pinning providers:** +- Pinata +- web3.storage +- NFT.Storage +- Infura IPFS +- Filebase (S3-compatible) +- Storj (decentralized) +- Local IPFS node + +**Configuration:** +```bash +IPFS_API=/ip4/127.0.0.1/tcp/5001 # Local IPFS daemon +``` + +## Redis - Coordination + +Redis handles distributed coordination across workers. + +### Key Patterns + +| Key | Type | Purpose | +|-----|------|---------| +| `artdag:run:{run_id}` | String | Run status, timestamps, celery task ID | +| `artdag:content_index` | Hash | content_hash → node_id mapping | +| `artdag:ipfs_index` | Hash | content_hash → ipfs_cid mapping | +| `artdag:claim:{cache_id}` | String | Task claiming (prevents duplicate work) | + +### Task Claiming + +Lua scripts ensure atomic claiming across workers: + +``` +Status flow: PENDING → CLAIMED → RUNNING → COMPLETED/CACHED/FAILED +TTL: 5 minutes for claims, 1 hour for results +``` + +This prevents two workers from executing the same step. + +## PostgreSQL - Metadata + +Stores ownership, provenance, and sharing metadata. + +### Tables + +```sql +-- Core cache (shared) +cache_items (content_hash, ipfs_cid, created_at) + +-- Per-user ownership +item_types (content_hash, actor_id, type, metadata) + +-- Run cache (deterministic identity) +run_cache ( + run_id, -- SHA3-256(sorted_inputs + recipe) + output_hash, + ipfs_cid, + provenance_cid, + recipe, inputs, actor_id +) + +-- Storage backends +storage_backends (actor_id, provider_type, config, capacity_gb) + +-- What's stored where +storage_pins (content_hash, storage_id, ipfs_cid, pin_type) +``` + +## Cache Lookup Flow + +When a worker needs a file: + +``` +1. Check local cache by cache_id (fastest) +2. Check Redis content_index: content_hash → node_id +3. Check PostgreSQL cache_items +4. Retrieve from IPFS by CID +5. Store in local cache for next hit +``` + +## Local vs L1 Comparison + +| Feature | Local Testing | L1 Distributed | +|---------|---------------|----------------| +| Local cache | Yes | Yes | +| IPFS | No | Yes | +| Redis | No | Yes | +| PostgreSQL | No | Yes | +| Multi-worker | No | Yes | +| Task claiming | No | Yes (Lua scripts) | +| Durability | Filesystem only | IPFS + PostgreSQL | + +## Content Addressing + +All storage uses SHA3-256 (quantum-resistant): + +- **Files:** `content_hash = SHA3-256(file_bytes)` +- **Computation:** `cache_id = SHA3-256(type + config + input_hashes)` +- **Run identity:** `run_id = SHA3-256(sorted_inputs + recipe)` +- **Plans:** `plan_id = SHA3-256(recipe + inputs + analysis)` + +This ensures: +- Same inputs → same outputs (reproducibility) +- Automatic deduplication across workers +- Content verification (tamper detection) + +## Configuration + +Default locations: + +```bash +# Local cache +~/.artdag/cache # Default +/data/cache # Docker + +# Redis +redis://localhost:6379/5 + +# PostgreSQL +postgresql://user:pass@host/artdag + +# IPFS +/ip4/127.0.0.1/tcp/5001 +``` + +## See Also + +- [OFFLINE_TESTING.md](OFFLINE_TESTING.md) - Local testing without L1 +- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model diff --git a/docs/OFFLINE_TESTING.md b/docs/OFFLINE_TESTING.md new file mode 100644 index 0000000..68d1559 --- /dev/null +++ b/docs/OFFLINE_TESTING.md @@ -0,0 +1,211 @@ +# Offline Testing Strategy + +This document describes how to test artdag locally without requiring Redis, IPFS, Celery, or any external distributed infrastructure. + +## Overview + +The artdag system uses a **3-Phase Execution Model** that enables complete offline testing: + +1. **Analysis** - Extract features from input media +2. **Planning** - Generate deterministic execution plan with pre-computed cache IDs +3. **Execution** - Run plan steps, skipping cached results + +This separation allows testing each phase independently and running full pipelines locally. + +## Quick Start + +Run a full offline test with a video file: + +```bash +./examples/test_local.sh ../artdag-art-source/dog.mkv +``` + +This will: +1. Compute the SHA3-256 hash of the input video +2. Run the `simple_sequence` recipe +3. Store all outputs in `test_cache/` + +## Test Scripts + +### `test_local.sh` - Full Pipeline Test + +Location: `./examples/test_local.sh` + +Runs the complete artdag pipeline offline with a real video file. + +**Usage:** +```bash +./examples/test_local.sh +``` + +**Example:** +```bash +./examples/test_local.sh ../artdag-art-source/dog.mkv +``` + +**What it does:** +- Computes content hash of input video +- Runs `artdag run-recipe` with `simple_sequence.yaml` +- Stores outputs in `test_cache/` directory +- No external services required + +### `test_plan.py` - Planning Phase Test + +Location: `./examples/test_plan.py` + +Tests the planning phase without requiring any media files. + +**Usage:** +```bash +python3 examples/test_plan.py +``` + +**What it tests:** +- Recipe loading and YAML parsing +- Execution plan generation +- Cache ID computation (deterministic) +- Multi-level parallel step organization +- Human-readable step names +- Multi-output support + +**Output:** +- Prints plan structure to console +- Saves full plan to `test_plan_output.json` + +### `simple_sequence.yaml` - Sample Recipe + +Location: `./examples/simple_sequence.yaml` + +A simple recipe for testing that: +- Takes a video input +- Extracts two segments (0-2s and 5-7s) +- Concatenates them with SEQUENCE + +## Test Outputs + +All test outputs are stored locally and git-ignored: + +| Output | Description | +|--------|-------------| +| `test_cache/` | Cached execution results (media files, analysis, plans) | +| `test_cache/plans/` | Cached execution plans by plan_id | +| `test_cache/analysis/` | Cached analysis results by input hash | +| `test_plan_output.json` | Generated execution plan from `test_plan.py` | + +## Unit Tests + +The project includes a comprehensive pytest test suite in `tests/`: + +```bash +# Run all unit tests +pytest + +# Run specific test file +pytest tests/test_dag.py +pytest tests/test_engine.py +pytest tests/test_cache.py +``` + +## Testing Each Phase + +### Phase 1: Analysis Only + +Extract features without full execution: + +```bash +python3 -m artdag.cli analyze -i :@ --features beats,energy +``` + +### Phase 2: Planning Only + +Generate an execution plan (no media needed): + +```bash +python3 -m artdag.cli plan -i : +``` + +Or use the test script: + +```bash +python3 examples/test_plan.py +``` + +### Phase 3: Execution Only + +Execute a pre-generated plan: + +```bash +python3 -m artdag.cli execute plan.json +``` + +With dry-run to see what would execute: + +```bash +python3 -m artdag.cli execute plan.json --dry-run +``` + +## Key Testing Features + +### Content Addressing + +All nodes have deterministic IDs computed as: +``` +SHA3-256(type + config + sorted(input_IDs)) +``` + +Same inputs always produce same cache IDs, enabling: +- Reproducibility across runs +- Automatic deduplication +- Incremental execution (only changed steps run) + +### Local Caching + +The `test_cache/` directory stores: +- `plans/{plan_id}.json` - Execution plans (deterministic hash of recipe + inputs + analysis) +- `analysis/{hash}.json` - Analysis results (audio beats, tempo, energy) +- `{cache_id}/output.mkv` - Media outputs from each step + +Subsequent test runs automatically skip cached steps. Plans are cached by their `plan_id`, which is a SHA3-256 hash of the recipe, input hashes, and analysis results - so the same recipe with the same inputs always produces the same plan. + +### No External Dependencies + +Offline testing requires: +- Python 3.9+ +- ffmpeg (for media processing) +- No Redis, IPFS, Celery, or network access + +## Debugging Tips + +1. **Check cache contents:** + ```bash + ls -la test_cache/ + ls -la test_cache/plans/ + ``` + +2. **View cached plan:** + ```bash + cat test_cache/plans/*.json | python3 -m json.tool | head -50 + ``` + +3. **View execution plan structure:** + ```bash + cat test_plan_output.json | python3 -m json.tool + ``` + +4. **Run with verbose output:** + ```bash + python3 -m artdag.cli run-recipe examples/simple_sequence.yaml \ + -i "video:HASH@path" \ + --cache-dir test_cache \ + -v + ``` + +5. **Dry-run to see what would execute:** + ```bash + python3 -m artdag.cli execute plan.json --dry-run + ``` + +## See Also + +- [L1_STORAGE.md](L1_STORAGE.md) - Distributed storage on L1 (IPFS, Redis, PostgreSQL) +- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model diff --git a/effects/identity/README.md b/effects/identity/README.md new file mode 100644 index 0000000..afb6cb0 --- /dev/null +++ b/effects/identity/README.md @@ -0,0 +1,35 @@ +# Identity Effect + +The identity effect returns its input unchanged. It serves as the foundational primitive in the effects registry. + +## Purpose + +- **Testing**: Verify the effects pipeline is working correctly +- **No-op placeholder**: Use when an effect slot requires a value but no transformation is needed +- **Composition base**: The neutral element for effect composition + +## Signature + +``` +identity(input) → input +``` + +## Properties + +- **Idempotent**: `identity(identity(x)) = identity(x)` +- **Neutral**: For any effect `f`, `identity ∘ f = f ∘ identity = f` + +## Implementation + +```python +def identity(input): + return input +``` + +## Content Hash + +The identity effect is content-addressed by its behavior: given any input, the output hash equals the input hash. + +## Owner + +Registered by `@giles@artdag.rose-ash.com` diff --git a/effects/identity/requirements.txt b/effects/identity/requirements.txt new file mode 100644 index 0000000..805e561 --- /dev/null +++ b/effects/identity/requirements.txt @@ -0,0 +1,2 @@ +# Identity effect has no dependencies +# It's a pure function: identity(x) = x diff --git a/examples/simple_sequence.yaml b/examples/simple_sequence.yaml new file mode 100644 index 0000000..d4ce009 --- /dev/null +++ b/examples/simple_sequence.yaml @@ -0,0 +1,42 @@ +# Simple sequence recipe - concatenates segments from a single input video +name: simple_sequence +version: "1.0" +description: "Split input into segments and concatenate them" +owner: test@local + +dag: + nodes: + # Input source - variable (provided at runtime) + - id: video + type: SOURCE + config: + input: true + name: "Input Video" + description: "The video to process" + + # Extract first 2 seconds + - id: seg1 + type: SEGMENT + config: + start: 0.0 + end: 2.0 + inputs: + - video + + # Extract seconds 5-7 + - id: seg2 + type: SEGMENT + config: + start: 5.0 + end: 7.0 + inputs: + - video + + # Concatenate the segments + - id: output + type: SEQUENCE + inputs: + - seg1 + - seg2 + + output: output diff --git a/examples/test_local.sh b/examples/test_local.sh new file mode 100755 index 0000000..083f718 --- /dev/null +++ b/examples/test_local.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Local testing script for artdag +# Tests the 3-phase execution without Redis/IPFS + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ARTDAG_DIR="$(dirname "$SCRIPT_DIR")" +CACHE_DIR="${ARTDAG_DIR}/test_cache" +RECIPE="${SCRIPT_DIR}/simple_sequence.yaml" + +# Check for input video +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "" + echo "Example:" + echo " $0 /path/to/test_video.mp4" + exit 1 +fi + +VIDEO_PATH="$1" +if [ ! -f "$VIDEO_PATH" ]; then + echo "Error: Video file not found: $VIDEO_PATH" + exit 1 +fi + +# Compute content hash of input +echo "=== Computing input hash ===" +VIDEO_HASH=$(python3 -c " +import hashlib +with open('$VIDEO_PATH', 'rb') as f: + print(hashlib.sha3_256(f.read()).hexdigest()) +") +echo "Input hash: ${VIDEO_HASH:0:16}..." + +# Change to artdag directory +cd "$ARTDAG_DIR" + +# Run the full pipeline +echo "" +echo "=== Running artdag run-recipe ===" +echo "Recipe: $RECIPE" +echo "Input: video:${VIDEO_HASH:0:16}...@$VIDEO_PATH" +echo "Cache: $CACHE_DIR" +echo "" + +python3 -m artdag.cli run-recipe "$RECIPE" \ + -i "video:${VIDEO_HASH}@${VIDEO_PATH}" \ + --cache-dir "$CACHE_DIR" + +echo "" +echo "=== Done ===" +echo "Cache directory: $CACHE_DIR" +echo "Use 'ls -la $CACHE_DIR' to see cached outputs" diff --git a/examples/test_plan.py b/examples/test_plan.py new file mode 100755 index 0000000..9b3a257 --- /dev/null +++ b/examples/test_plan.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" +Test the planning phase locally. + +This tests the new human-readable names and multi-output support +without requiring actual video files or execution. +""" + +import hashlib +import json +import sys +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from artdag.planning import RecipePlanner, Recipe, ExecutionPlan + + +def main(): + # Load recipe + recipe_path = Path(__file__).parent / "simple_sequence.yaml" + if not recipe_path.exists(): + print(f"Recipe not found: {recipe_path}") + return 1 + + recipe = Recipe.from_file(recipe_path) + print(f"Recipe: {recipe.name} v{recipe.version}") + print(f"Nodes: {len(recipe.nodes)}") + print() + + # Fake input hash (would be real content hash in production) + fake_input_hash = hashlib.sha3_256(b"fake video content").hexdigest() + input_hashes = {"video": fake_input_hash} + + print(f"Input: video -> {fake_input_hash[:16]}...") + print() + + # Generate plan + planner = RecipePlanner(use_tree_reduction=True) + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + seed=42, # Optional seed for reproducibility + ) + + print("=== Generated Plan ===") + print(f"Plan ID: {plan.plan_id[:24]}...") + print(f"Plan Name: {plan.name}") + print(f"Recipe Name: {plan.recipe_name}") + print(f"Output: {plan.output_name}") + print(f"Steps: {len(plan.steps)}") + print() + + # Show steps by level + steps_by_level = plan.get_steps_by_level() + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"Level {level}: {len(steps)} step(s)") + for step in steps: + # Show human-readable name + name = step.name or step.step_id[:20] + print(f" - {name}") + print(f" Type: {step.node_type}") + print(f" Cache ID: {step.cache_id[:16]}...") + if step.outputs: + print(f" Outputs: {len(step.outputs)}") + for out in step.outputs: + print(f" - {out.name} ({out.media_type})") + if step.inputs: + print(f" Inputs: {[inp.name for inp in step.inputs]}") + print() + + # Save plan for inspection + plan_path = Path(__file__).parent.parent / "test_plan_output.json" + with open(plan_path, "w") as f: + f.write(plan.to_json()) + print(f"Plan saved to: {plan_path}") + + # Show plan JSON structure + print() + print("=== Plan JSON Preview ===") + plan_dict = json.loads(plan.to_json()) + # Show first step as example + if plan_dict.get("steps"): + first_step = plan_dict["steps"][0] + print(json.dumps(first_step, indent=2)[:500] + "...") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9ac24c7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "artdag" +version = "0.1.0" +description = "Content-addressed DAG execution engine with ActivityPub ownership" +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.10" +authors = [ + {name = "Giles", email = "giles@rose-ash.com"} +] +keywords = ["dag", "content-addressed", "activitypub", "video", "processing"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "cryptography>=41.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", +] +analysis = [ + "librosa>=0.10.0", + "numpy>=1.24.0", + "pyyaml>=6.0", +] +cv = [ + "opencv-python>=4.8.0", +] +all = [ + "librosa>=0.10.0", + "numpy>=1.24.0", + "pyyaml>=6.0", + "opencv-python>=4.8.0", +] + +[project.scripts] +artdag = "artdag.cli:main" + +[project.urls] +Homepage = "https://artdag.rose-ash.com" +Repository = "https://github.com/giles/artdag" + +[tool.setuptools.packages.find] +where = ["."] +include = ["artdag*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] diff --git a/scripts/compute_repo_hash.py b/scripts/compute_repo_hash.py new file mode 100644 index 0000000..8e841e1 --- /dev/null +++ b/scripts/compute_repo_hash.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Compute content hash of a git repository. + +Hashes all tracked files (respects .gitignore) in sorted order. +""" + +import hashlib +import subprocess +import sys +from pathlib import Path + + +def repo_hash(repo_path: Path) -> str: + """ + Compute SHA3-256 hash of all tracked files in a repo. + + Uses git ls-files to respect .gitignore. + Files are hashed in sorted order for determinism. + Each file contributes: relative_path + file_contents + """ + # Get list of tracked files + result = subprocess.run( + ["git", "ls-files"], + cwd=repo_path, + capture_output=True, + text=True, + check=True, + ) + + files = sorted(result.stdout.strip().split("\n")) + + hasher = hashlib.sha3_256() + + for rel_path in files: + if not rel_path: + continue + + file_path = repo_path / rel_path + if not file_path.is_file(): + continue + + # Include path in hash + hasher.update(rel_path.encode()) + + # Include contents + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + + return hasher.hexdigest() + + +def main(): + if len(sys.argv) > 1: + repo_path = Path(sys.argv[1]) + else: + repo_path = Path.cwd() + + h = repo_hash(repo_path) + print(f"Repository: {repo_path}") + print(f"Hash: {h}") + return h + + +if __name__ == "__main__": + main() diff --git a/scripts/install-ffglitch.sh b/scripts/install-ffglitch.sh new file mode 100755 index 0000000..d7301f2 --- /dev/null +++ b/scripts/install-ffglitch.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# Install ffglitch for datamosh effects +# Usage: ./install-ffglitch.sh [install_dir] + +set -e + +FFGLITCH_VERSION="0.10.2" +INSTALL_DIR="${1:-/usr/local/bin}" + +# Detect architecture +ARCH=$(uname -m) +case "$ARCH" in + x86_64) + URL="https://ffglitch.org/pub/bin/linux64/ffglitch-${FFGLITCH_VERSION}-linux-x86_64.zip" + ARCHIVE="ffglitch.zip" + ;; + aarch64) + URL="https://ffglitch.org/pub/bin/linux-aarch64/ffglitch-${FFGLITCH_VERSION}-linux-aarch64.7z" + ARCHIVE="ffglitch.7z" + ;; + *) + echo "Unsupported architecture: $ARCH" + exit 1 + ;; +esac + +echo "Installing ffglitch ${FFGLITCH_VERSION} for ${ARCH}..." + +# Create temp directory +TMPDIR=$(mktemp -d) +cd "$TMPDIR" + +# Download +echo "Downloading from ${URL}..." +curl -L -o "$ARCHIVE" "$URL" + +# Extract +echo "Extracting..." +if [[ "$ARCHIVE" == *.zip ]]; then + unzip -q "$ARCHIVE" +elif [[ "$ARCHIVE" == *.7z ]]; then + # Requires p7zip + if ! command -v 7z &> /dev/null; then + echo "7z not found. Install with: apt install p7zip-full" + exit 1 + fi + 7z x "$ARCHIVE" > /dev/null +fi + +# Find and install binaries +echo "Installing to ${INSTALL_DIR}..." +find . -name "ffgac" -o -name "ffedit" | while read bin; do + chmod +x "$bin" + if [ -w "$INSTALL_DIR" ]; then + cp "$bin" "$INSTALL_DIR/" + else + sudo cp "$bin" "$INSTALL_DIR/" + fi + echo " Installed: $(basename $bin)" +done + +# Cleanup +cd / +rm -rf "$TMPDIR" + +# Verify +echo "" +echo "Verifying installation..." +if command -v ffgac &> /dev/null; then + echo "ffgac: $(which ffgac)" +else + echo "Warning: ffgac not in PATH. Add ${INSTALL_DIR} to PATH." +fi + +if command -v ffedit &> /dev/null; then + echo "ffedit: $(which ffedit)" +else + echo "Warning: ffedit not in PATH. Add ${INSTALL_DIR} to PATH." +fi + +echo "" +echo "Done! ffglitch installed." diff --git a/scripts/register_identity_effect.py b/scripts/register_identity_effect.py new file mode 100644 index 0000000..0194698 --- /dev/null +++ b/scripts/register_identity_effect.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Register the identity effect owned by giles. +""" + +import hashlib +from pathlib import Path +import sys + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from artdag.activitypub.ownership import OwnershipManager + + +def folder_hash(folder: Path) -> str: + """ + Compute SHA3-256 hash of an entire folder. + + Hashes all files in sorted order for deterministic results. + Each file contributes: relative_path + file_contents + """ + hasher = hashlib.sha3_256() + + # Get all files sorted by relative path + files = sorted(folder.rglob("*")) + + for file_path in files: + if file_path.is_file(): + # Include relative path in hash for structure + rel_path = file_path.relative_to(folder) + hasher.update(str(rel_path).encode()) + + # Include file contents + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + + return hasher.hexdigest() + + +def main(): + # Use .cache as the ownership data directory + base_dir = Path(__file__).parent.parent / ".cache" / "ownership" + manager = OwnershipManager(base_dir) + + # Create or get giles actor + actor = manager.get_actor("giles") + if not actor: + actor = manager.create_actor("giles", "Giles Bradshaw") + print(f"Created actor: {actor.handle}") + else: + print(f"Using existing actor: {actor.handle}") + + # Register the identity effect folder + effect_path = Path(__file__).parent.parent / "effects" / "identity" + cid = folder_hash(effect_path) + + asset, activity = manager.register_asset( + actor=actor, + name="effect:identity", + cid=cid, + local_path=effect_path, + tags=["effect", "primitive", "identity"], + metadata={ + "type": "effect", + "description": "The identity effect - returns input unchanged", + "signature": "identity(input) → input", + }, + ) + + print(f"\nRegistered: {asset.name}") + print(f" Hash: {asset.cid}") + print(f" Path: {asset.local_path}") + print(f" Activity: {activity.activity_id}") + print(f" Owner: {actor.handle}") + + # Verify ownership + verified = manager.verify_ownership(asset.name, actor) + print(f" Ownership verified: {verified}") + +if __name__ == "__main__": + main() diff --git a/scripts/setup_actor.py b/scripts/setup_actor.py new file mode 100644 index 0000000..b1c80cf --- /dev/null +++ b/scripts/setup_actor.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Set up actor with keypair stored securely. + +Private key: ~/.artdag/keys/{username}.pem +Public key: exported for registry +""" + +import json +import os +import sys +from datetime import datetime, timezone +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend + + +def create_keypair(): + """Generate RSA-2048 keypair.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + return private_key + + +def save_private_key(private_key, path: Path): + """Save private key to PEM file.""" + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(pem) + os.chmod(path, 0o600) # Owner read/write only + return pem.decode() + + +def get_public_key_pem(private_key) -> str: + """Extract public key as PEM string.""" + public_key = private_key.public_key() + pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return pem.decode() + + +def create_actor_json(username: str, display_name: str, public_key_pem: str, domain: str = "artdag.rose-ash.com"): + """Create ActivityPub actor JSON.""" + return { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1" + ], + "type": "Person", + "id": f"https://{domain}/users/{username}", + "preferredUsername": username, + "name": display_name, + "inbox": f"https://{domain}/users/{username}/inbox", + "outbox": f"https://{domain}/users/{username}/outbox", + "publicKey": { + "id": f"https://{domain}/users/{username}#main-key", + "owner": f"https://{domain}/users/{username}", + "publicKeyPem": public_key_pem + } + } + + +def main(): + username = "giles" + display_name = "Giles Bradshaw" + domain = "artdag.rose-ash.com" + + keys_dir = Path.home() / ".artdag" / "keys" + private_key_path = keys_dir / f"{username}.pem" + + # Check if key already exists + if private_key_path.exists(): + print(f"Private key already exists: {private_key_path}") + print("Delete it first if you want to regenerate.") + sys.exit(1) + + # Create new keypair + print(f"Creating new keypair for @{username}@{domain}...") + private_key = create_keypair() + + # Save private key + save_private_key(private_key, private_key_path) + print(f"Private key saved: {private_key_path}") + print(f" Mode: 600 (owner read/write only)") + print(f" BACK THIS UP!") + + # Get public key + public_key_pem = get_public_key_pem(private_key) + + # Create actor JSON + actor = create_actor_json(username, display_name, public_key_pem, domain) + + # Output actor JSON + actor_json = json.dumps(actor, indent=2) + print(f"\nActor JSON (for registry/actors/{username}.json):") + print(actor_json) + + # Save to registry + registry_path = Path.home() / "artdag-registry" / "actors" / f"{username}.json" + registry_path.parent.mkdir(parents=True, exist_ok=True) + registry_path.write_text(actor_json) + print(f"\nSaved to: {registry_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sign_assets.py b/scripts/sign_assets.py new file mode 100644 index 0000000..8021f78 --- /dev/null +++ b/scripts/sign_assets.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Sign assets in the registry with giles's private key. + +Creates ActivityPub Create activities with RSA signatures. +""" + +import base64 +import hashlib +import json +import sys +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.backends import default_backend + + +def load_private_key(path: Path): + """Load private key from PEM file.""" + pem_data = path.read_bytes() + return serialization.load_pem_private_key(pem_data, password=None, backend=default_backend()) + + +def sign_data(private_key, data: str) -> str: + """Sign data with RSA private key, return base64 signature.""" + signature = private_key.sign( + data.encode(), + padding.PKCS1v15(), + hashes.SHA256(), + ) + return base64.b64encode(signature).decode() + + +def create_activity(actor_id: str, asset_name: str, cid: str, asset_type: str, domain: str = "artdag.rose-ash.com"): + """Create a Create activity for an asset.""" + now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + return { + "activity_id": str(uuid.uuid4()), + "activity_type": "Create", + "actor_id": actor_id, + "object_data": { + "type": asset_type_to_ap(asset_type), + "name": asset_name, + "id": f"https://{domain}/objects/{cid}", + "contentHash": { + "algorithm": "sha3-256", + "value": cid + }, + "attributedTo": actor_id + }, + "published": now, + } + + +def asset_type_to_ap(asset_type: str) -> str: + """Convert asset type to ActivityPub type.""" + type_map = { + "image": "Image", + "video": "Video", + "audio": "Audio", + "effect": "Application", + "infrastructure": "Application", + } + return type_map.get(asset_type, "Document") + + +def sign_activity(activity: dict, private_key, actor_id: str, domain: str = "artdag.rose-ash.com") -> dict: + """Add signature to activity.""" + # Create canonical string to sign + to_sign = json.dumps(activity["object_data"], sort_keys=True, separators=(",", ":")) + + signature_value = sign_data(private_key, to_sign) + + activity["signature"] = { + "type": "RsaSignature2017", + "creator": f"{actor_id}#main-key", + "created": activity["published"], + "signatureValue": signature_value + } + + return activity + + +def main(): + username = "giles" + domain = "artdag.rose-ash.com" + actor_id = f"https://{domain}/users/{username}" + + # Load private key + private_key_path = Path.home() / ".artdag" / "keys" / f"{username}.pem" + if not private_key_path.exists(): + print(f"Private key not found: {private_key_path}") + print("Run setup_actor.py first.") + sys.exit(1) + + private_key = load_private_key(private_key_path) + print(f"Loaded private key: {private_key_path}") + + # Load registry + registry_path = Path.home() / "artdag-registry" / "registry.json" + with open(registry_path) as f: + registry = json.load(f) + + # Create signed activities for each asset + activities = [] + + for asset_name, asset_data in registry["assets"].items(): + print(f"\nSigning: {asset_name}") + print(f" Hash: {asset_data['cid'][:16]}...") + + activity = create_activity( + actor_id=actor_id, + asset_name=asset_name, + cid=asset_data["cid"], + asset_type=asset_data["asset_type"], + domain=domain, + ) + + signed_activity = sign_activity(activity, private_key, actor_id, domain) + activities.append(signed_activity) + + print(f" Activity ID: {signed_activity['activity_id']}") + print(f" Signature: {signed_activity['signature']['signatureValue'][:32]}...") + + # Save activities + activities_path = Path.home() / "artdag-registry" / "activities.json" + activities_data = { + "version": "1.0", + "activities": activities + } + + with open(activities_path, "w") as f: + json.dump(activities_data, f, indent=2) + + print(f"\nSaved {len(activities)} signed activities to: {activities_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..f6aed20 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests for new standalone primitive engine diff --git a/tests/test_activities.py b/tests/test_activities.py new file mode 100644 index 0000000..36ba61d --- /dev/null +++ b/tests/test_activities.py @@ -0,0 +1,613 @@ +# tests/test_activities.py +"""Tests for the activity tracking and cache deletion system.""" + +import tempfile +import time +from pathlib import Path + +import pytest + +from artdag import Cache, DAG, Node, NodeType +from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + + +class MockActivityPubStore: + """Mock ActivityPub store for testing is_shared functionality.""" + + def __init__(self): + self._shared_hashes = set() + + def mark_shared(self, cid: str): + """Mark a content hash as shared (published).""" + self._shared_hashes.add(cid) + + def find_by_object_hash(self, cid: str): + """Return mock activities for shared hashes.""" + if cid in self._shared_hashes: + return [MockActivity("Create")] + return [] + + +class MockActivity: + """Mock ActivityPub activity.""" + def __init__(self, activity_type: str): + self.activity_type = activity_type + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def cache(temp_dir): + """Create a cache instance.""" + return Cache(temp_dir / "cache") + + +@pytest.fixture +def activity_store(temp_dir): + """Create an activity store instance.""" + return ActivityStore(temp_dir / "activities") + + +@pytest.fixture +def ap_store(): + """Create a mock ActivityPub store.""" + return MockActivityPubStore() + + +@pytest.fixture +def manager(cache, activity_store, ap_store): + """Create an ActivityManager instance.""" + return ActivityManager( + cache=cache, + activity_store=activity_store, + is_shared_fn=make_is_shared_fn(ap_store), + ) + + +def create_test_file(path: Path, content: str = "test content") -> Path: + """Create a test file with content.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + +class TestCacheEntryContentHash: + """Tests for cid in CacheEntry.""" + + def test_put_computes_cid(self, cache, temp_dir): + """put() should compute and store cid.""" + test_file = create_test_file(temp_dir / "input.txt", "hello world") + + cache.put("node1", test_file, "test") + entry = cache.get_entry("node1") + + assert entry is not None + assert entry.cid != "" + assert len(entry.cid) == 64 # SHA-3-256 hex + + def test_same_content_same_hash(self, cache, temp_dir): + """Same file content should produce same hash.""" + file1 = create_test_file(temp_dir / "file1.txt", "identical content") + file2 = create_test_file(temp_dir / "file2.txt", "identical content") + + cache.put("node1", file1, "test") + cache.put("node2", file2, "test") + + entry1 = cache.get_entry("node1") + entry2 = cache.get_entry("node2") + + assert entry1.cid == entry2.cid + + def test_different_content_different_hash(self, cache, temp_dir): + """Different file content should produce different hash.""" + file1 = create_test_file(temp_dir / "file1.txt", "content A") + file2 = create_test_file(temp_dir / "file2.txt", "content B") + + cache.put("node1", file1, "test") + cache.put("node2", file2, "test") + + entry1 = cache.get_entry("node1") + entry2 = cache.get_entry("node2") + + assert entry1.cid != entry2.cid + + def test_find_by_cid(self, cache, temp_dir): + """Should find entry by content hash.""" + test_file = create_test_file(temp_dir / "input.txt", "unique content") + cache.put("node1", test_file, "test") + + entry = cache.get_entry("node1") + found = cache.find_by_cid(entry.cid) + + assert found is not None + assert found.node_id == "node1" + + def test_cid_persists(self, temp_dir): + """cid should persist across cache reloads.""" + cache1 = Cache(temp_dir / "cache") + test_file = create_test_file(temp_dir / "input.txt", "persistent") + cache1.put("node1", test_file, "test") + original_hash = cache1.get_entry("node1").cid + + # Create new cache instance (reload from disk) + cache2 = Cache(temp_dir / "cache") + entry = cache2.get_entry("node1") + + assert entry.cid == original_hash + + +class TestActivity: + """Tests for Activity dataclass.""" + + def test_activity_from_dag(self): + """Activity.from_dag() should classify nodes correctly.""" + # Build a simple DAG: source -> transform -> output + dag = DAG() + source = Node(NodeType.SOURCE, {"path": "/test.mp4"}) + transform = Node(NodeType.TRANSFORM, {"effect": "blur"}, inputs=[source.node_id]) + output = Node(NodeType.RESIZE, {"width": 100}, inputs=[transform.node_id]) + + dag.add_node(source) + dag.add_node(transform) + dag.add_node(output) + dag.set_output(output.node_id) + + activity = Activity.from_dag(dag) + + assert source.node_id in activity.input_ids + assert activity.output_id == output.node_id + assert transform.node_id in activity.intermediate_ids + + def test_activity_with_multiple_inputs(self): + """Activity should handle DAGs with multiple source nodes.""" + dag = DAG() + source1 = Node(NodeType.SOURCE, {"path": "/a.mp4"}) + source2 = Node(NodeType.SOURCE, {"path": "/b.mp4"}) + sequence = Node(NodeType.SEQUENCE, {}, inputs=[source1.node_id, source2.node_id]) + + dag.add_node(source1) + dag.add_node(source2) + dag.add_node(sequence) + dag.set_output(sequence.node_id) + + activity = Activity.from_dag(dag) + + assert len(activity.input_ids) == 2 + assert source1.node_id in activity.input_ids + assert source2.node_id in activity.input_ids + assert activity.output_id == sequence.node_id + assert len(activity.intermediate_ids) == 0 + + def test_activity_serialization(self): + """Activity should serialize and deserialize correctly.""" + dag = DAG() + source = Node(NodeType.SOURCE, {"path": "/test.mp4"}) + dag.add_node(source) + dag.set_output(source.node_id) + + activity = Activity.from_dag(dag) + data = activity.to_dict() + restored = Activity.from_dict(data) + + assert restored.activity_id == activity.activity_id + assert restored.input_ids == activity.input_ids + assert restored.output_id == activity.output_id + assert restored.intermediate_ids == activity.intermediate_ids + + def test_all_node_ids(self): + """all_node_ids should return all nodes.""" + activity = Activity( + activity_id="test", + input_ids=["a", "b"], + output_id="c", + intermediate_ids=["d", "e"], + created_at=time.time(), + ) + + all_ids = activity.all_node_ids + assert set(all_ids) == {"a", "b", "c", "d", "e"} + + +class TestActivityStore: + """Tests for ActivityStore persistence.""" + + def test_add_and_get(self, activity_store): + """Should add and retrieve activities.""" + activity = Activity( + activity_id="test1", + input_ids=["input1"], + output_id="output1", + intermediate_ids=["inter1"], + created_at=time.time(), + ) + + activity_store.add(activity) + retrieved = activity_store.get("test1") + + assert retrieved is not None + assert retrieved.activity_id == "test1" + + def test_persistence(self, temp_dir): + """Activities should persist across store reloads.""" + store1 = ActivityStore(temp_dir / "activities") + activity = Activity( + activity_id="persist", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + store1.add(activity) + + # Reload + store2 = ActivityStore(temp_dir / "activities") + retrieved = store2.get("persist") + + assert retrieved is not None + assert retrieved.activity_id == "persist" + + def test_find_by_input_ids(self, activity_store): + """Should find activities with matching inputs.""" + activity1 = Activity( + activity_id="a1", + input_ids=["x", "y"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity2 = Activity( + activity_id="a2", + input_ids=["y", "x"], # Same inputs, different order + output_id="o2", + intermediate_ids=[], + created_at=time.time(), + ) + activity3 = Activity( + activity_id="a3", + input_ids=["z"], # Different inputs + output_id="o3", + intermediate_ids=[], + created_at=time.time(), + ) + + activity_store.add(activity1) + activity_store.add(activity2) + activity_store.add(activity3) + + found = activity_store.find_by_input_ids(["x", "y"]) + assert len(found) == 2 + assert {a.activity_id for a in found} == {"a1", "a2"} + + def test_find_using_node(self, activity_store): + """Should find activities referencing a node.""" + activity = Activity( + activity_id="a1", + input_ids=["input1"], + output_id="output1", + intermediate_ids=["inter1"], + created_at=time.time(), + ) + activity_store.add(activity) + + # Should find by input + found = activity_store.find_using_node("input1") + assert len(found) == 1 + + # Should find by intermediate + found = activity_store.find_using_node("inter1") + assert len(found) == 1 + + # Should find by output + found = activity_store.find_using_node("output1") + assert len(found) == 1 + + # Should not find unknown + found = activity_store.find_using_node("unknown") + assert len(found) == 0 + + def test_remove(self, activity_store): + """Should remove activities.""" + activity = Activity( + activity_id="to_remove", + input_ids=["i"], + output_id="o", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + assert activity_store.get("to_remove") is not None + + result = activity_store.remove("to_remove") + assert result is True + assert activity_store.get("to_remove") is None + + +class TestActivityManager: + """Tests for ActivityManager deletion rules.""" + + def test_can_delete_orphaned_entry(self, manager, cache, temp_dir): + """Orphaned entries (not in any activity) can be deleted.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cache.put("orphan_node", test_file, "test") + + assert manager.can_delete_cache_entry("orphan_node") is True + + def test_cannot_delete_shared_entry(self, manager, cache, temp_dir, ap_store): + """Shared entries (ActivityPub published) cannot be deleted.""" + test_file = create_test_file(temp_dir / "shared.txt", "shared content") + cache.put("shared_node", test_file, "test") + + # Mark as shared + entry = cache.get_entry("shared_node") + ap_store.mark_shared(entry.cid) + + assert manager.can_delete_cache_entry("shared_node") is False + + def test_cannot_delete_activity_input(self, manager, cache, activity_store, temp_dir): + """Activity inputs cannot be deleted.""" + test_file = create_test_file(temp_dir / "input.txt", "input") + cache.put("input_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("input_node") is False + + def test_cannot_delete_activity_output(self, manager, cache, activity_store, temp_dir): + """Activity outputs cannot be deleted.""" + test_file = create_test_file(temp_dir / "output.txt", "output") + cache.put("output_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("output_node") is False + + def test_can_delete_intermediate(self, manager, cache, activity_store, temp_dir): + """Intermediate entries can be deleted (they're reconstructible).""" + test_file = create_test_file(temp_dir / "inter.txt", "intermediate") + cache.put("inter_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=["inter_node"], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("inter_node") is True + + def test_can_discard_activity_no_shared(self, manager, activity_store): + """Activity can be discarded if nothing is shared.""" + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=["m1"], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_discard_activity("a1") is True + + def test_cannot_discard_activity_with_shared_output(self, manager, cache, activity_store, temp_dir, ap_store): + """Activity cannot be discarded if output is shared.""" + test_file = create_test_file(temp_dir / "output.txt", "output content") + cache.put("o1", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + # Mark output as shared + entry = cache.get_entry("o1") + ap_store.mark_shared(entry.cid) + + assert manager.can_discard_activity("a1") is False + + def test_cannot_discard_activity_with_shared_input(self, manager, cache, activity_store, temp_dir, ap_store): + """Activity cannot be discarded if input is shared.""" + test_file = create_test_file(temp_dir / "input.txt", "input content") + cache.put("i1", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + entry = cache.get_entry("i1") + ap_store.mark_shared(entry.cid) + + assert manager.can_discard_activity("a1") is False + + def test_discard_activity_deletes_intermediates(self, manager, cache, activity_store, temp_dir): + """Discarding activity should delete intermediate cache entries.""" + # Create cache entries + input_file = create_test_file(temp_dir / "input.txt", "input") + inter_file = create_test_file(temp_dir / "inter.txt", "intermediate") + output_file = create_test_file(temp_dir / "output.txt", "output") + + cache.put("i1", input_file, "test") + cache.put("m1", inter_file, "test") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=["m1"], + created_at=time.time(), + ) + activity_store.add(activity) + + # Discard + result = manager.discard_activity("a1") + + assert result is True + assert cache.has("m1") is False # Intermediate deleted + assert activity_store.get("a1") is None # Activity removed + + def test_discard_activity_deletes_orphaned_output(self, manager, cache, activity_store, temp_dir): + """Discarding activity should delete output if orphaned.""" + output_file = create_test_file(temp_dir / "output.txt", "output") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=[], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + manager.discard_activity("a1") + + assert cache.has("o1") is False # Orphaned output deleted + + def test_discard_activity_keeps_shared_output(self, manager, cache, activity_store, temp_dir, ap_store): + """Discarding should fail if output is shared.""" + output_file = create_test_file(temp_dir / "output.txt", "shared output") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=[], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + entry = cache.get_entry("o1") + ap_store.mark_shared(entry.cid) + + result = manager.discard_activity("a1") + + assert result is False # Cannot discard + assert cache.has("o1") is True # Output preserved + assert activity_store.get("a1") is not None # Activity preserved + + def test_discard_keeps_input_used_elsewhere(self, manager, cache, activity_store, temp_dir): + """Input used by another activity should not be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "shared input") + cache.put("shared_input", input_file, "test") + + activity1 = Activity( + activity_id="a1", + input_ids=["shared_input"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity2 = Activity( + activity_id="a2", + input_ids=["shared_input"], + output_id="o2", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity1) + activity_store.add(activity2) + + manager.discard_activity("a1") + + # Input still used by a2 + assert cache.has("shared_input") is True + + def test_get_deletable_entries(self, manager, cache, activity_store, temp_dir): + """Should list all deletable entries.""" + # Orphan (deletable) + orphan = create_test_file(temp_dir / "orphan.txt", "orphan") + cache.put("orphan", orphan, "test") + + # Intermediate (deletable) + inter = create_test_file(temp_dir / "inter.txt", "inter") + cache.put("inter", inter, "test") + + # Input (not deletable) + inp = create_test_file(temp_dir / "input.txt", "input") + cache.put("input", inp, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input"], + output_id="output", + intermediate_ids=["inter"], + created_at=time.time(), + ) + activity_store.add(activity) + + deletable = manager.get_deletable_entries() + deletable_ids = {e.node_id for e in deletable} + + assert "orphan" in deletable_ids + assert "inter" in deletable_ids + assert "input" not in deletable_ids + + def test_cleanup_intermediates(self, manager, cache, activity_store, temp_dir): + """cleanup_intermediates() should delete all intermediate entries.""" + inter1 = create_test_file(temp_dir / "i1.txt", "inter1") + inter2 = create_test_file(temp_dir / "i2.txt", "inter2") + cache.put("inter1", inter1, "test") + cache.put("inter2", inter2, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input"], + output_id="output", + intermediate_ids=["inter1", "inter2"], + created_at=time.time(), + ) + activity_store.add(activity) + + deleted = manager.cleanup_intermediates() + + assert deleted == 2 + assert cache.has("inter1") is False + assert cache.has("inter2") is False + + +class TestMakeIsSharedFn: + """Tests for make_is_shared_fn factory.""" + + def test_returns_true_for_shared(self, ap_store): + """Should return True for shared content.""" + is_shared = make_is_shared_fn(ap_store) + ap_store.mark_shared("hash123") + + assert is_shared("hash123") is True + + def test_returns_false_for_not_shared(self, ap_store): + """Should return False for non-shared content.""" + is_shared = make_is_shared_fn(ap_store) + + assert is_shared("unknown_hash") is False diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..2aac235 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,163 @@ +# tests/test_primitive_new/test_cache.py +"""Tests for primitive cache module.""" + +import pytest +import tempfile +from pathlib import Path + +from artdag.cache import Cache, CacheStats + + +@pytest.fixture +def cache_dir(): + """Create temporary cache directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def cache(cache_dir): + """Create cache instance.""" + return Cache(cache_dir) + + +@pytest.fixture +def sample_file(cache_dir): + """Create a sample file to cache.""" + file_path = cache_dir / "sample.txt" + file_path.write_text("test content") + return file_path + + +class TestCache: + """Test Cache class.""" + + def test_cache_creation(self, cache_dir): + """Test cache directory is created.""" + cache = Cache(cache_dir / "new_cache") + assert cache.cache_dir.exists() + + def test_cache_put_and_get(self, cache, sample_file): + """Test putting and getting from cache.""" + node_id = "abc123" + cached_path = cache.put(node_id, sample_file, "TEST") + + assert cached_path.exists() + assert cache.has(node_id) + + retrieved = cache.get(node_id) + assert retrieved == cached_path + + def test_cache_miss(self, cache): + """Test cache miss returns None.""" + result = cache.get("nonexistent") + assert result is None + + def test_cache_stats_hit_miss(self, cache, sample_file): + """Test cache hit/miss stats.""" + cache.put("abc123", sample_file, "TEST") + + # Miss + cache.get("nonexistent") + assert cache.stats.misses == 1 + + # Hit + cache.get("abc123") + assert cache.stats.hits == 1 + + assert cache.stats.hit_rate == 0.5 + + def test_cache_remove(self, cache, sample_file): + """Test removing from cache.""" + node_id = "abc123" + cache.put(node_id, sample_file, "TEST") + assert cache.has(node_id) + + cache.remove(node_id) + assert not cache.has(node_id) + + def test_cache_clear(self, cache, sample_file): + """Test clearing cache.""" + cache.put("node1", sample_file, "TEST") + cache.put("node2", sample_file, "TEST") + + assert cache.stats.total_entries == 2 + + cache.clear() + + assert cache.stats.total_entries == 0 + assert not cache.has("node1") + assert not cache.has("node2") + + def test_cache_preserves_extension(self, cache, cache_dir): + """Test that cache preserves file extension.""" + mp4_file = cache_dir / "video.mp4" + mp4_file.write_text("fake video") + + cached = cache.put("video_node", mp4_file, "SOURCE") + assert cached.suffix == ".mp4" + + def test_cache_list_entries(self, cache, sample_file): + """Test listing cache entries.""" + cache.put("node1", sample_file, "TYPE1") + cache.put("node2", sample_file, "TYPE2") + + entries = cache.list_entries() + assert len(entries) == 2 + + node_ids = {e.node_id for e in entries} + assert "node1" in node_ids + assert "node2" in node_ids + + def test_cache_persistence(self, cache_dir, sample_file): + """Test cache persists across instances.""" + # First instance + cache1 = Cache(cache_dir) + cache1.put("abc123", sample_file, "TEST") + + # Second instance loads from disk + cache2 = Cache(cache_dir) + assert cache2.has("abc123") + + def test_cache_prune_by_age(self, cache, sample_file): + """Test pruning by age.""" + import time + + cache.put("old_node", sample_file, "TEST") + + # Manually set old creation time + entry = cache._entries["old_node"] + entry.created_at = time.time() - 3600 # 1 hour ago + + removed = cache.prune(max_age_seconds=1800) # 30 minutes + + assert removed == 1 + assert not cache.has("old_node") + + def test_cache_output_path(self, cache): + """Test getting output path for node.""" + path = cache.get_output_path("abc123", ".mp4") + assert path.suffix == ".mp4" + assert "abc123" in str(path) + assert path.parent.exists() + + +class TestCacheStats: + """Test CacheStats class.""" + + def test_hit_rate_calculation(self): + """Test hit rate calculation.""" + stats = CacheStats() + + stats.record_hit() + stats.record_hit() + stats.record_miss() + + assert stats.hits == 2 + assert stats.misses == 1 + assert abs(stats.hit_rate - 0.666) < 0.01 + + def test_initial_hit_rate(self): + """Test hit rate with no requests.""" + stats = CacheStats() + assert stats.hit_rate == 0.0 diff --git a/tests/test_dag.py b/tests/test_dag.py new file mode 100644 index 0000000..48250c6 --- /dev/null +++ b/tests/test_dag.py @@ -0,0 +1,271 @@ +# tests/test_primitive_new/test_dag.py +"""Tests for primitive DAG data structures.""" + +import pytest +from artdag.dag import Node, NodeType, DAG, DAGBuilder + + +class TestNode: + """Test Node class.""" + + def test_node_creation(self): + """Test basic node creation.""" + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + assert node.node_type == NodeType.SOURCE + assert node.config == {"path": "/test.mp4"} + assert node.node_id is not None + + def test_node_id_is_content_addressed(self): + """Same content produces same node_id.""" + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + assert node1.node_id == node2.node_id + + def test_different_config_different_id(self): + """Different config produces different node_id.""" + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test1.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test2.mp4"}) + assert node1.node_id != node2.node_id + + def test_node_with_inputs(self): + """Node with inputs includes them in ID.""" + node1 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["abc123"]) + node2 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["abc123"]) + node3 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["def456"]) + + assert node1.node_id == node2.node_id + assert node1.node_id != node3.node_id + + def test_node_serialization(self): + """Test node to_dict and from_dict.""" + original = Node( + node_type=NodeType.SEGMENT, + config={"duration": 5.0, "offset": 10.0}, + inputs=["abc123"], + name="my_segment", + ) + data = original.to_dict() + restored = Node.from_dict(data) + + assert restored.node_type == original.node_type + assert restored.config == original.config + assert restored.inputs == original.inputs + assert restored.name == original.name + assert restored.node_id == original.node_id + + def test_custom_node_type(self): + """Test node with custom string type.""" + node = Node(node_type="CUSTOM_TYPE", config={"custom": True}) + assert node.node_type == "CUSTOM_TYPE" + assert node.node_id is not None + + +class TestDAG: + """Test DAG class.""" + + def test_dag_creation(self): + """Test basic DAG creation.""" + dag = DAG() + assert len(dag.nodes) == 0 + assert dag.output_id is None + + def test_add_node(self): + """Test adding nodes to DAG.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + + assert node_id in dag.nodes + assert dag.nodes[node_id] == node + + def test_node_deduplication(self): + """Same node added twice returns same ID.""" + dag = DAG() + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + + id1 = dag.add_node(node1) + id2 = dag.add_node(node2) + + assert id1 == id2 + assert len(dag.nodes) == 1 + + def test_set_output(self): + """Test setting output node.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + assert dag.output_id == node_id + + def test_set_output_invalid(self): + """Setting invalid output raises error.""" + dag = DAG() + with pytest.raises(ValueError): + dag.set_output("nonexistent") + + def test_topological_order(self): + """Test topological ordering.""" + dag = DAG() + + # Create simple chain: source -> segment -> output + source = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + source_id = dag.add_node(source) + + segment = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=[source_id]) + segment_id = dag.add_node(segment) + + dag.set_output(segment_id) + order = dag.topological_order() + + # Source must come before segment + assert order.index(source_id) < order.index(segment_id) + + def test_validate_valid_dag(self): + """Test validation of valid DAG.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + errors = dag.validate() + assert len(errors) == 0 + + def test_validate_no_output(self): + """DAG without output is invalid.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + dag.add_node(node) + + errors = dag.validate() + assert len(errors) > 0 + assert any("output" in e.lower() for e in errors) + + def test_validate_missing_input(self): + """DAG with missing input reference is invalid.""" + dag = DAG() + node = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["nonexistent"]) + node_id = dag.add_node(node) + dag.set_output(node_id) + + errors = dag.validate() + assert len(errors) > 0 + assert any("missing" in e.lower() for e in errors) + + def test_dag_serialization(self): + """Test DAG to_dict and from_dict.""" + dag = DAG(metadata={"name": "test_dag"}) + source = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + source_id = dag.add_node(source) + dag.set_output(source_id) + + data = dag.to_dict() + restored = DAG.from_dict(data) + + assert len(restored.nodes) == len(dag.nodes) + assert restored.output_id == dag.output_id + assert restored.metadata == dag.metadata + + def test_dag_json(self): + """Test DAG JSON serialization.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + json_str = dag.to_json() + restored = DAG.from_json(json_str) + + assert len(restored.nodes) == 1 + assert restored.output_id == node_id + + +class TestDAGBuilder: + """Test DAGBuilder class.""" + + def test_builder_source(self): + """Test building source node.""" + builder = DAGBuilder() + source_id = builder.source("/test.mp4") + + assert source_id in builder.dag.nodes + node = builder.dag.nodes[source_id] + assert node.node_type == NodeType.SOURCE + assert node.config["path"] == "/test.mp4" + + def test_builder_segment(self): + """Test building segment node.""" + builder = DAGBuilder() + source_id = builder.source("/test.mp4") + segment_id = builder.segment(source_id, duration=5.0, offset=10.0) + + node = builder.dag.nodes[segment_id] + assert node.node_type == NodeType.SEGMENT + assert node.config["duration"] == 5.0 + assert node.config["offset"] == 10.0 + assert source_id in node.inputs + + def test_builder_chain(self): + """Test building a chain of nodes.""" + builder = DAGBuilder() + source = builder.source("/test.mp4") + segment = builder.segment(source, duration=5.0) + resized = builder.resize(segment, width=1920, height=1080) + builder.set_output(resized) + + dag = builder.build() + + assert len(dag.nodes) == 3 + assert dag.output_id == resized + errors = dag.validate() + assert len(errors) == 0 + + def test_builder_sequence(self): + """Test building sequence node.""" + builder = DAGBuilder() + s1 = builder.source("/clip1.mp4") + s2 = builder.source("/clip2.mp4") + seq = builder.sequence([s1, s2], transition={"type": "crossfade", "duration": 0.5}) + builder.set_output(seq) + + dag = builder.build() + node = dag.nodes[seq] + assert node.node_type == NodeType.SEQUENCE + assert s1 in node.inputs + assert s2 in node.inputs + + def test_builder_mux(self): + """Test building mux node.""" + builder = DAGBuilder() + video = builder.source("/video.mp4") + audio = builder.source("/audio.mp3") + muxed = builder.mux(video, audio) + builder.set_output(muxed) + + dag = builder.build() + node = dag.nodes[muxed] + assert node.node_type == NodeType.MUX + assert video in node.inputs + assert audio in node.inputs + + def test_builder_transform(self): + """Test building transform node.""" + builder = DAGBuilder() + source = builder.source("/test.mp4") + transformed = builder.transform(source, effects={"saturation": 1.5, "contrast": 1.2}) + builder.set_output(transformed) + + dag = builder.build() + node = dag.nodes[transformed] + assert node.node_type == NodeType.TRANSFORM + assert node.config["effects"]["saturation"] == 1.5 + + def test_builder_validation_fails(self): + """Builder raises error for invalid DAG.""" + builder = DAGBuilder() + builder.source("/test.mp4") + # No output set + + with pytest.raises(ValueError): + builder.build() diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..b6e5a95 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,464 @@ +# tests/test_primitive_new/test_engine.py +"""Tests for primitive engine execution.""" + +import pytest +import subprocess +import tempfile +from pathlib import Path + +from artdag.dag import DAG, DAGBuilder, Node, NodeType +from artdag.engine import Engine +from artdag import nodes # Register executors + + +@pytest.fixture +def cache_dir(): + """Create temporary cache directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def engine(cache_dir): + """Create engine instance.""" + return Engine(cache_dir) + + +@pytest.fixture +def test_video(cache_dir): + """Create a test video file.""" + video_path = cache_dir / "test_video.mp4" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "testsrc=duration=5:size=320x240:rate=30", + "-f", "lavfi", "-i", "sine=frequency=440:duration=5", + "-c:v", "libx264", "-preset", "ultrafast", + "-c:a", "aac", + str(video_path) + ] + subprocess.run(cmd, capture_output=True, check=True) + return video_path + + +@pytest.fixture +def test_audio(cache_dir): + """Create a test audio file.""" + audio_path = cache_dir / "test_audio.mp3" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=5", + "-c:a", "libmp3lame", + str(audio_path) + ] + subprocess.run(cmd, capture_output=True, check=True) + return audio_path + + +class TestEngineBasic: + """Test basic engine functionality.""" + + def test_engine_creation(self, cache_dir): + """Test engine creation.""" + engine = Engine(cache_dir) + assert engine.cache is not None + + def test_invalid_dag(self, engine): + """Test executing invalid DAG.""" + dag = DAG() # No nodes, no output + result = engine.execute(dag) + + assert not result.success + assert "Invalid DAG" in result.error + + def test_missing_executor(self, engine): + """Test executing node with missing executor.""" + dag = DAG() + node = Node(node_type="UNKNOWN_TYPE", config={}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + result = engine.execute(dag) + + assert not result.success + assert "No executor" in result.error + + +class TestSourceExecutor: + """Test SOURCE node executor.""" + + def test_source_creates_symlink(self, engine, test_video): + """Test source node creates symlink.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + assert result.output_path.is_symlink() + + def test_source_missing_file(self, engine): + """Test source with missing file.""" + builder = DAGBuilder() + source = builder.source("/nonexistent/file.mp4") + builder.set_output(source) + dag = builder.build() + + result = engine.execute(dag) + + assert not result.success + assert "not found" in result.error.lower() + + +class TestSegmentExecutor: + """Test SEGMENT node executor.""" + + def test_segment_duration(self, engine, test_video): + """Test segment extracts correct duration.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + segment = builder.segment(source, duration=2.0) + builder.set_output(segment) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify duration + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + assert abs(duration - 2.0) < 0.1 + + def test_segment_with_offset(self, engine, test_video): + """Test segment with offset.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + segment = builder.segment(source, offset=1.0, duration=2.0) + builder.set_output(segment) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + + +class TestResizeExecutor: + """Test RESIZE node executor.""" + + def test_resize_dimensions(self, engine, test_video): + """Test resize to specific dimensions.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + resized = builder.resize(source, width=640, height=480, mode="fit") + builder.set_output(resized) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify dimensions + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "stream=width,height", + "-of", "csv=p=0:s=x", + str(result.output_path) + ], capture_output=True, text=True) + dimensions = probe.stdout.strip().split("\n")[0] + assert "640x480" in dimensions + + +class TestTransformExecutor: + """Test TRANSFORM node executor.""" + + def test_transform_saturation(self, engine, test_video): + """Test transform with saturation effect.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + transformed = builder.transform(source, effects={"saturation": 1.5}) + builder.set_output(transformed) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + assert result.output_path.exists() + + def test_transform_multiple_effects(self, engine, test_video): + """Test transform with multiple effects.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + transformed = builder.transform(source, effects={ + "saturation": 1.2, + "contrast": 1.1, + "brightness": 0.05, + }) + builder.set_output(transformed) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + + +class TestSequenceExecutor: + """Test SEQUENCE node executor.""" + + def test_sequence_cut(self, engine, test_video): + """Test sequence with cut transition.""" + builder = DAGBuilder() + s1 = builder.source(str(test_video)) + seg1 = builder.segment(s1, duration=2.0) + seg2 = builder.segment(s1, offset=2.0, duration=2.0) + seq = builder.sequence([seg1, seg2], transition={"type": "cut"}) + builder.set_output(seq) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify combined duration + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + assert abs(duration - 4.0) < 0.2 + + def test_sequence_crossfade(self, engine, test_video): + """Test sequence with crossfade transition.""" + builder = DAGBuilder() + s1 = builder.source(str(test_video)) + seg1 = builder.segment(s1, duration=3.0) + seg2 = builder.segment(s1, offset=1.0, duration=3.0) + seq = builder.sequence([seg1, seg2], transition={"type": "crossfade", "duration": 0.5}) + builder.set_output(seq) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Duration should be sum minus crossfade + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + # 3 + 3 - 0.5 = 5.5 + assert abs(duration - 5.5) < 0.3 + + +class TestMuxExecutor: + """Test MUX node executor.""" + + def test_mux_video_audio(self, engine, test_video, test_audio): + """Test muxing video and audio.""" + builder = DAGBuilder() + video = builder.source(str(test_video)) + audio = builder.source(str(test_audio)) + muxed = builder.mux(video, audio) + builder.set_output(muxed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + +class TestAudioMixExecutor: + """Test AUDIO_MIX node executor.""" + + def test_audio_mix_simple(self, engine, cache_dir): + """Test simple audio mixing.""" + # Create two test audio files with different frequencies + audio1_path = cache_dir / "audio1.mp3" + audio2_path = cache_dir / "audio2.mp3" + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=440:duration=3", + "-c:a", "libmp3lame", + str(audio1_path) + ], capture_output=True, check=True) + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=3", + "-c:a", "libmp3lame", + str(audio2_path) + ], capture_output=True, check=True) + + builder = DAGBuilder() + a1 = builder.source(str(audio1_path)) + a2 = builder.source(str(audio2_path)) + mixed = builder.audio_mix([a1, a2]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + def test_audio_mix_with_gains(self, engine, cache_dir): + """Test audio mixing with custom gains.""" + audio1_path = cache_dir / "audio1.mp3" + audio2_path = cache_dir / "audio2.mp3" + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=440:duration=3", + "-c:a", "libmp3lame", + str(audio1_path) + ], capture_output=True, check=True) + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=3", + "-c:a", "libmp3lame", + str(audio2_path) + ], capture_output=True, check=True) + + builder = DAGBuilder() + a1 = builder.source(str(audio1_path)) + a2 = builder.source(str(audio2_path)) + mixed = builder.audio_mix([a1, a2], gains=[1.0, 0.3]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + def test_audio_mix_three_inputs(self, engine, cache_dir): + """Test mixing three audio sources.""" + audio_paths = [] + for i, freq in enumerate([440, 660, 880]): + path = cache_dir / f"audio{i}.mp3" + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", f"sine=frequency={freq}:duration=2", + "-c:a", "libmp3lame", + str(path) + ], capture_output=True, check=True) + audio_paths.append(path) + + builder = DAGBuilder() + sources = [builder.source(str(p)) for p in audio_paths] + mixed = builder.audio_mix(sources, gains=[1.0, 0.5, 0.3]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + +class TestCaching: + """Test engine caching behavior.""" + + def test_cache_reuse(self, engine, test_video): + """Test that cached results are reused.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + # First execution + result1 = engine.execute(dag) + assert result1.success + assert result1.nodes_cached == 0 + assert result1.nodes_executed == 1 + + # Second execution should use cache + result2 = engine.execute(dag) + assert result2.success + assert result2.nodes_cached == 1 + assert result2.nodes_executed == 0 + + def test_clear_cache(self, engine, test_video): + """Test clearing cache.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + engine.execute(dag) + assert engine.cache.stats.total_entries == 1 + + engine.clear_cache() + assert engine.cache.stats.total_entries == 0 + + +class TestProgressCallback: + """Test progress callback functionality.""" + + def test_progress_callback(self, engine, test_video): + """Test that progress callback is called.""" + progress_updates = [] + + def callback(progress): + progress_updates.append((progress.node_id, progress.status)) + + engine.set_progress_callback(callback) + + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert len(progress_updates) > 0 + # Should have pending, running, completed + statuses = [p[1] for p in progress_updates] + assert "pending" in statuses + assert "completed" in statuses + + +class TestFullWorkflow: + """Test complete workflow.""" + + def test_full_pipeline(self, engine, test_video, test_audio): + """Test complete video processing pipeline.""" + builder = DAGBuilder() + + # Load sources + video = builder.source(str(test_video)) + audio = builder.source(str(test_audio)) + + # Extract segment + segment = builder.segment(video, duration=3.0) + + # Resize + resized = builder.resize(segment, width=640, height=480) + + # Apply effects + transformed = builder.transform(resized, effects={"saturation": 1.3}) + + # Mux with audio + final = builder.mux(transformed, audio) + builder.set_output(final) + + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + assert result.nodes_executed == 6 # source, source, segment, resize, transform, mux diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..5149554 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,110 @@ +# tests/test_primitive_new/test_executor.py +"""Tests for primitive executor module.""" + +import pytest +from pathlib import Path +from typing import Any, Dict, List + +from artdag.dag import NodeType +from artdag.executor import ( + Executor, + register_executor, + get_executor, + list_executors, + clear_executors, +) + + +class TestExecutorRegistry: + """Test executor registration.""" + + def setup_method(self): + """Clear registry before each test.""" + clear_executors() + + def teardown_method(self): + """Clear registry after each test.""" + clear_executors() + + def test_register_executor(self): + """Test registering an executor.""" + @register_executor(NodeType.SOURCE) + class TestSourceExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = get_executor(NodeType.SOURCE) + assert executor is not None + assert isinstance(executor, TestSourceExecutor) + + def test_register_custom_type(self): + """Test registering executor for custom type.""" + @register_executor("CUSTOM_NODE") + class CustomExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = get_executor("CUSTOM_NODE") + assert executor is not None + + def test_get_unregistered(self): + """Test getting unregistered executor.""" + executor = get_executor(NodeType.ANALYZE) + assert executor is None + + def test_list_executors(self): + """Test listing registered executors.""" + @register_executor(NodeType.SOURCE) + class SourceExec(Executor): + def execute(self, config, inputs, output_path): + return output_path + + @register_executor(NodeType.SEGMENT) + class SegmentExec(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executors = list_executors() + assert "SOURCE" in executors + assert "SEGMENT" in executors + + def test_overwrite_warning(self, caplog): + """Test warning when overwriting executor.""" + @register_executor(NodeType.SOURCE) + class FirstExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + # Register again - should warn + @register_executor(NodeType.SOURCE) + class SecondExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + # Second should be registered + executor = get_executor(NodeType.SOURCE) + assert isinstance(executor, SecondExecutor) + + +class TestExecutorBase: + """Test Executor base class.""" + + def test_validate_config_default(self): + """Test default validate_config returns empty list.""" + class TestExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = TestExecutor() + errors = executor.validate_config({"any": "config"}) + assert errors == [] + + def test_estimate_output_size(self): + """Test default output size estimation.""" + class TestExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = TestExecutor() + size = executor.estimate_output_size({}, [100, 200, 300]) + assert size == 600 diff --git a/tests/test_ipfs_access.py b/tests/test_ipfs_access.py new file mode 100644 index 0000000..33795cb --- /dev/null +++ b/tests/test_ipfs_access.py @@ -0,0 +1,301 @@ +""" +Tests for IPFS access consistency. + +All IPFS access should use IPFS_API (multiaddr format) for consistency +with art-celery's ipfs_client.py. This ensures Docker deployments work +correctly since IPFS_API is set to /dns/ipfs/tcp/5001. +""" + +import os +import re +from pathlib import Path +from typing import Optional +from unittest.mock import patch, MagicMock + +import pytest + + +def multiaddr_to_url(multiaddr: str) -> str: + """ + Convert IPFS multiaddr to HTTP URL. + + This is the canonical conversion used by ipfs_client.py. + """ + # Handle /dns/hostname/tcp/port format + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +class TestMultiaddrConversion: + """Tests for multiaddr to URL conversion.""" + + def test_dns_format(self) -> None: + """Docker DNS format should convert correctly.""" + result = multiaddr_to_url("/dns/ipfs/tcp/5001") + assert result == "http://ipfs:5001" + + def test_dns4_format(self) -> None: + """dns4 format should work.""" + result = multiaddr_to_url("/dns4/ipfs.example.com/tcp/5001") + assert result == "http://ipfs.example.com:5001" + + def test_ip4_format(self) -> None: + """IPv4 format should convert correctly.""" + result = multiaddr_to_url("/ip4/127.0.0.1/tcp/5001") + assert result == "http://127.0.0.1:5001" + + def test_already_url(self) -> None: + """HTTP URLs should pass through.""" + result = multiaddr_to_url("http://localhost:5001") + assert result == "http://localhost:5001" + + def test_fallback(self) -> None: + """Unknown format should fallback to localhost.""" + result = multiaddr_to_url("garbage") + assert result == "http://127.0.0.1:5001" + + +class TestIPFSConfigConsistency: + """ + Tests to ensure IPFS configuration is consistent. + + The effect executor should use IPFS_API (like ipfs_client.py) + rather than a separate IPFS_GATEWAY variable. + """ + + def test_effect_module_should_not_use_gateway_var(self) -> None: + """ + Regression test: Effect module should use IPFS_API, not IPFS_GATEWAY. + + Bug found 2026-01-12: artdag/nodes/effect.py used IPFS_GATEWAY which + defaulted to http://127.0.0.1:8080. This doesn't work in Docker where + the IPFS node is a separate container. The ipfs_client.py uses IPFS_API + which is correctly set in docker-compose. + """ + from artdag.nodes import effect + + # Check if the module still has the old IPFS_GATEWAY variable + # After the fix, this should use IPFS_API instead + has_gateway_var = hasattr(effect, 'IPFS_GATEWAY') + has_api_var = hasattr(effect, 'IPFS_API') or hasattr(effect, '_get_ipfs_base_url') + + # This test documents the current buggy state + # After fix: has_gateway_var should be False, has_api_var should be True + if has_gateway_var and not has_api_var: + pytest.fail( + "Effect module uses IPFS_GATEWAY instead of IPFS_API. " + "This breaks Docker deployments where IPFS_API=/dns/ipfs/tcp/5001 " + "but IPFS_GATEWAY defaults to localhost." + ) + + def test_ipfs_api_default_is_localhost(self) -> None: + """IPFS_API should default to localhost for local development.""" + default_api = "/ip4/127.0.0.1/tcp/5001" + url = multiaddr_to_url(default_api) + assert "127.0.0.1" in url + assert "5001" in url + + def test_docker_ipfs_api_uses_service_name(self) -> None: + """In Docker, IPFS_API should use the service name.""" + docker_api = "/dns/ipfs/tcp/5001" + url = multiaddr_to_url(docker_api) + assert url == "http://ipfs:5001" + assert "127.0.0.1" not in url + + +class TestEffectFetchURL: + """Tests for the URL used to fetch effects from IPFS.""" + + def test_fetch_should_use_api_cat_endpoint(self) -> None: + """ + Effect fetch should use /api/v0/cat endpoint (like ipfs_client.py). + + The IPFS API's cat endpoint works reliably in Docker. + The gateway endpoint (port 8080) requires separate configuration. + """ + # The correct way to fetch via API + base_url = "http://ipfs:5001" + cid = "QmTestCid123" + correct_url = f"{base_url}/api/v0/cat?arg={cid}" + + assert "/api/v0/cat" in correct_url + assert "arg=" in correct_url + + def test_gateway_url_is_different_from_api(self) -> None: + """ + Document the difference between gateway and API URLs. + + Gateway: http://ipfs:8080/ipfs/{cid} (requires IPFS_GATEWAY config) + API: http://ipfs:5001/api/v0/cat?arg={cid} (uses IPFS_API config) + + Using the API is more reliable since IPFS_API is already configured + correctly in docker-compose.yml. + """ + cid = "QmTestCid123" + + # Gateway style (the old broken way) + gateway_url = f"http://ipfs:8080/ipfs/{cid}" + + # API style (the correct way) + api_url = f"http://ipfs:5001/api/v0/cat?arg={cid}" + + # These are different approaches + assert gateway_url != api_url + assert ":8080" in gateway_url + assert ":5001" in api_url + + +class TestEffectDependencies: + """Tests for effect dependency handling.""" + + def test_parse_pep723_dependencies(self) -> None: + """Should parse PEP 723 dependencies from effect source.""" + source = ''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["numpy", "opencv-python"] +# /// +""" +@effect test_effect +""" + +def process_frame(frame, params, state): + return frame, state +''' + # Import the function after the fix is applied + from artdag.nodes.effect import _parse_pep723_dependencies + + deps = _parse_pep723_dependencies(source) + + assert deps == ["numpy", "opencv-python"] + + def test_parse_pep723_no_dependencies(self) -> None: + """Should return empty list if no dependencies block.""" + source = ''' +""" +@effect simple_effect +""" + +def process_frame(frame, params, state): + return frame, state +''' + from artdag.nodes.effect import _parse_pep723_dependencies + + deps = _parse_pep723_dependencies(source) + + assert deps == [] + + def test_ensure_dependencies_already_installed(self) -> None: + """Should return True if dependencies are already installed.""" + from artdag.nodes.effect import _ensure_dependencies + + # os is always available + result = _ensure_dependencies(["os"], "QmTest123") + + assert result is True + + def test_effect_with_missing_dependency_gives_clear_error(self, tmp_path: Path) -> None: + """ + Regression test: Missing dependencies should give clear error message. + + Bug found 2026-01-12: Effect with numpy dependency failed with + "No module named 'numpy'" but this was swallowed and reported as + "Unknown effect: invert" - very confusing. + """ + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffectWithDeps" + + # Create effect that imports a non-existent module + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["some_nonexistent_package_xyz"] +# /// +""" +@effect test_effect +""" +import some_nonexistent_package_xyz + +def process_frame(frame, params, state): + return frame, state +''') + + # The effect file exists + effect_path = effects_dir / effect_cid / "effect.py" + assert effect_path.exists() + + # When loading fails due to missing import, error should mention the dependency + with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}): + from artdag.nodes.effect import _load_cached_effect + + # This should return None but log a clear error about the missing module + result = _load_cached_effect(effect_cid) + + # Currently returns None, which causes "Unknown effect" error + # The real issue is the dependency isn't installed + assert result is None + + +class TestEffectCacheAndFetch: + """Integration tests for effect caching and fetching.""" + + def test_effect_loads_from_cache_without_ipfs(self, tmp_path: Path) -> None: + """When effect is in cache, IPFS should not be contacted.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffect123" + + # Create cached effect + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +def process_frame(frame, params, state): + return frame, state +''') + + # Patch environment and verify effect can be loaded + with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}): + from artdag.nodes.effect import _load_cached_effect + + # Should load without hitting IPFS + effect_fn = _load_cached_effect(effect_cid) + assert effect_fn is not None + + def test_effect_fetch_uses_correct_endpoint(self, tmp_path: Path) -> None: + """When fetching from IPFS, should use API endpoint.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + effect_cid = "QmNonExistentEffect" + + with patch.dict(os.environ, { + "CACHE_DIR": str(tmp_path), + "IPFS_API": "/dns/ipfs/tcp/5001" + }): + with patch('requests.post') as mock_post: + # Set up mock to return effect source + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'def process_frame(f, p, s): return f, s' + mock_post.return_value = mock_response + + from artdag.nodes.effect import _load_cached_effect + + # Try to load - should attempt IPFS fetch + _load_cached_effect(effect_cid) + + # After fix, this should use the API endpoint + # Check if requests.post was called (API style) + # or requests.get was called (gateway style) + # The fix should make it use POST to /api/v0/cat From 7784e6b2b0b6eed6e3ee677f260d6141d9b4c24c Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:09:47 +0000 Subject: [PATCH 06/24] Squashed 'client/' content from commit 4bb0841 git-subtree-dir: client git-subtree-split: 4bb084154a4eb4b4f580d52d936cab05ef313ebb --- .gitignore | 5 + README.md | 263 +++++ artdag.py | 2354 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 + test_gpu_effects.sexp | 38 + test_simple.sexp | 26 + 6 files changed, 2689 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100755 artdag.py create mode 100644 requirements.txt create mode 100644 test_gpu_effects.sexp create mode 100644 test_simple.sexp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b5e701c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +*.py[cod] +.venv/ +venv/ +.scripts diff --git a/README.md b/README.md new file mode 100644 index 0000000..80a8268 --- /dev/null +++ b/README.md @@ -0,0 +1,263 @@ +# Art DAG Client + +CLI for interacting with the Art DAG L1 rendering server. + +## Setup + +```bash +pip install -r requirements.txt +``` + +## Configuration + +```bash +# Set L1 server URL (default: http://localhost:8100) +export ARTDAG_SERVER=http://localhost:8100 + +# Set L2 server URL for auth (default: http://localhost:8200) +export ARTDAG_L2=https://artdag.rose-ash.com + +# Or pass with commands +./artdag.py --server http://localhost:8100 --l2 https://artdag.rose-ash.com +``` + +## Authentication + +Most commands require authentication. Login credentials are stored locally in `~/.artdag/token.json`. + +```bash +# Register a new account +artdag register [--email user@example.com] + +# Login +artdag login + +# Check current user +artdag whoami + +# Logout +artdag logout +``` + +## Commands Reference + +### Server & Stats + +```bash +# Show server info +artdag info + +# Show user stats (counts of runs, recipes, effects, media, storage) +artdag stats + +# List known named assets +artdag assets +``` + +### Runs + +```bash +# List runs (with pagination) +artdag runs [--limit N] [--offset N] + +# Start a run +artdag run [--name output_name] [--wait] + +# Get run status +artdag status + +# Get detailed run info +artdag status --plan # Show execution plan with steps +artdag status --artifacts # Show output artifacts +artdag status --analysis # Show audio analysis data + +# Delete a run +artdag delete-run [--force] +``` + +### Recipes + +```bash +# List recipes (with pagination) +artdag recipes [--limit N] [--offset N] + +# Show recipe details +artdag recipe + +# Upload a recipe (YAML or S-expression) +artdag upload-recipe + +# Run a recipe with inputs +artdag run-recipe -i node_id:cid [--wait] + +# Delete a recipe +artdag delete-recipe [--force] +``` + +### Effects + +```bash +# List effects (with pagination) +artdag effects [--limit N] [--offset N] + +# Show effect details +artdag effect + +# Show effect with source code +artdag effect --source + +# Upload an effect (.py file) +artdag upload-effect +``` + +### Media / Cache + +```bash +# List cached content (with pagination and type filter) +artdag cache [--limit N] [--offset N] [--type all|image|video|audio] + +# View/download cached content +artdag view # Show metadata (size, type, friendly name) +artdag view --raw # Get raw content info +artdag view -o output.mp4 # Download raw file +artdag view -o - | mpv - # Pipe raw content to player + +# Upload file to cache and IPFS +artdag upload + +# Import local file to cache (local server only) +artdag import + +# View/update metadata +artdag meta # View metadata +artdag meta -d "Description" # Set description +artdag meta -t "tag1,tag2" # Set tags +artdag meta --publish "my-video" # Publish to L2 + +# Delete cached content +artdag delete-cache [--force] +``` + +### Storage Providers + +```bash +# List storage providers +artdag storage list + +# Add a provider (interactive) +artdag storage add [--name friendly_name] [--capacity GB] +# Types: pinata, web3storage, nftstorage, infura, filebase, storj, local + +# Test provider connectivity +artdag storage test + +# Delete a provider +artdag storage delete [--force] +``` + +### Folders & Collections + +```bash +# Folders +artdag folder list +artdag folder create +artdag folder delete + +# Collections +artdag collection list +artdag collection create +artdag collection delete +``` + +### v2 API (3-Phase Execution) + +```bash +# Generate execution plan +artdag plan -i name:cid [--features beats,energy] [--output plan.json] + +# Execute a plan +artdag execute-plan [--wait] + +# Run recipe (plan + execute in one step) +artdag run-v2 -i name:cid [--wait] + +# Check v2 run status +artdag run-status +``` + +### Publishing to L2 + +```bash +# Publish a run output to L2 +artdag publish +``` + +### Data Management + +```bash +# Clear all user data (preserves storage configs) +artdag clear-data [--force] +``` + +## Example Workflows + +### Basic Rendering + +```bash +# Login +artdag login myuser + +# Check available assets +artdag assets + +# Run an effect on an input +artdag run dog cat --wait + +# View runs +artdag runs + +# Download result +artdag view -o result.mp4 +``` + +### Recipe-Based Processing + +```bash +# Upload a recipe +artdag upload-recipe my-recipe.yaml + +# View recipes +artdag recipes + +# Run with inputs +artdag run-recipe -i video:bafkrei... --wait + +# View run plan +artdag status --plan +``` + +### Managing Storage + +```bash +# Add Pinata storage +artdag storage add pinata --name "My Pinata" + +# Test connection +artdag storage test 1 + +# View all providers +artdag storage list +``` + +### Browsing Media + +```bash +# List all media +artdag cache + +# Filter by type +artdag cache --type video --limit 20 + +# View with pagination +artdag cache --offset 20 --limit 20 +``` diff --git a/artdag.py b/artdag.py new file mode 100755 index 0000000..d28df4c --- /dev/null +++ b/artdag.py @@ -0,0 +1,2354 @@ +#!/usr/bin/env python3 +""" +Art DAG Client + +CLI for interacting with the Art DAG L1 server. +""" + +import json +import os +import sys +import time +from pathlib import Path + +import click +import requests +import yaml + +CONFIG_DIR = Path.home() / ".artdag" +TOKEN_FILE = CONFIG_DIR / "token.json" +CONFIG_FILE = CONFIG_DIR / "config.json" + +# Defaults - can be overridden by env vars, config file, or CLI args +_DEFAULT_SERVER = "http://localhost:8100" +_DEFAULT_L2_SERVER = "http://localhost:8200" + +# Active server URLs (set during CLI init) +DEFAULT_SERVER = None +DEFAULT_L2_SERVER = None + + +def load_config() -> dict: + """Load saved config (server URLs, etc.).""" + if CONFIG_FILE.exists(): + try: + with open(CONFIG_FILE) as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + return {} + return {} + + +def save_config(config: dict): + """Save config to file.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + with open(CONFIG_FILE, "w") as f: + json.dump(config, f, indent=2) + + +def get_server(): + """Get server URL.""" + return DEFAULT_SERVER + + +def get_l2_server(): + """Get L2 server URL.""" + return DEFAULT_L2_SERVER + + +def load_token() -> dict: + """Load saved token from config.""" + if TOKEN_FILE.exists(): + with open(TOKEN_FILE) as f: + return json.load(f) + return {} + + +def save_token(token_data: dict): + """Save token to config.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + with open(TOKEN_FILE, "w") as f: + json.dump(token_data, f, indent=2) + TOKEN_FILE.chmod(0o600) + + +def clear_token(): + """Clear saved token.""" + if TOKEN_FILE.exists(): + TOKEN_FILE.unlink() + + +def get_auth_header(require_token: bool = False) -> dict: + """Get headers for API requests. Always includes Accept: application/json.""" + headers = {"Accept": "application/json"} + token_data = load_token() + token = token_data.get("access_token") + if token: + headers["Authorization"] = f"Bearer {token}" + elif require_token: + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + return headers + + +def api_get(path: str, auth: bool = False): + """GET request to server.""" + headers = get_auth_header(require_token=auth) + resp = requests.get(f"{get_server()}{path}", headers=headers) + resp.raise_for_status() + return resp.json() + + +def api_post(path: str, data: dict = None, params: dict = None, auth: bool = False): + """POST request to server.""" + headers = get_auth_header(require_token=auth) + resp = requests.post(f"{get_server()}{path}", json=data, params=params, headers=headers) + resp.raise_for_status() + return resp.json() + + +def _get_default_server(): + """Get default server from env, config, or builtin default.""" + if os.environ.get("ARTDAG_SERVER"): + return os.environ["ARTDAG_SERVER"] + config = load_config() + return config.get("server", _DEFAULT_SERVER) + + +def _get_default_l2(): + """Get default L2 server from env, config, or builtin default.""" + if os.environ.get("ARTDAG_L2"): + return os.environ["ARTDAG_L2"] + config = load_config() + return config.get("l2", _DEFAULT_L2_SERVER) + + +@click.group() +@click.option("--server", "-s", default=None, + help="L1 server URL (saved for future use)") +@click.option("--l2", default=None, + help="L2 server URL (saved for future use)") +@click.pass_context +def cli(ctx, server, l2): + """Art DAG Client - interact with L1 rendering server.""" + ctx.ensure_object(dict) + global DEFAULT_SERVER, DEFAULT_L2_SERVER + + config = load_config() + config_changed = False + + # Use provided value, or fall back to saved/default + if server: + DEFAULT_SERVER = server + if config.get("server") != server: + config["server"] = server + config_changed = True + else: + DEFAULT_SERVER = _get_default_server() + + if l2: + DEFAULT_L2_SERVER = l2 + if config.get("l2") != l2: + config["l2"] = l2 + config_changed = True + else: + DEFAULT_L2_SERVER = _get_default_l2() + + # Save config if changed + if config_changed: + save_config(config) + + ctx.obj["server"] = DEFAULT_SERVER + ctx.obj["l2"] = DEFAULT_L2_SERVER + + +# ============ Auth Commands ============ + +@cli.command() +@click.argument("username") +@click.option("--password", "-p", prompt=True, hide_input=True) +def login(username, password): + """Login to get access token.""" + try: + # Server expects form data, not JSON + resp = requests.post( + f"{get_l2_server()}/auth/login", + data={"username": username, "password": password} + ) + if resp.status_code == 200: + # Check if we got a token back in a cookie + if "auth_token" in resp.cookies: + token = resp.cookies["auth_token"] + # Decode token to get username and expiry + import base64 + try: + # JWT format: header.payload.signature + payload = token.split(".")[1] + # Add padding if needed + payload += "=" * (4 - len(payload) % 4) + decoded = json.loads(base64.urlsafe_b64decode(payload)) + token_data = { + "access_token": token, + "username": decoded.get("username", username), + "expires_at": decoded.get("exp", "") + } + save_token(token_data) + click.echo(f"Logged in as {token_data['username']}") + if token_data.get("expires_at"): + click.echo(f"Token expires: {token_data['expires_at']}") + except Exception: + # If we can't decode, just save the token + save_token({"access_token": token, "username": username}) + click.echo(f"Logged in as {username}") + else: + # HTML response - check for success/error + if "successful" in resp.text.lower(): + click.echo(f"Login successful but no token received. Try logging in via web browser.") + elif "invalid" in resp.text.lower(): + click.echo(f"Login failed: Invalid username or password", err=True) + sys.exit(1) + else: + click.echo(f"Login failed: {resp.text}", err=True) + sys.exit(1) + else: + click.echo(f"Login failed: {resp.text}", err=True) + sys.exit(1) + except requests.RequestException as e: + click.echo(f"Login failed: {e}", err=True) + sys.exit(1) + + +@cli.command() +@click.argument("username") +@click.option("--password", "-p", prompt=True, hide_input=True, confirmation_prompt=True) +@click.option("--email", "-e", default=None, help="Email (optional)") +def register(username, password, email): + """Register a new account.""" + try: + # Server expects form data, not JSON + form_data = { + "username": username, + "password": password, + "password2": password, + } + if email: + form_data["email"] = email + + resp = requests.post( + f"{get_l2_server()}/auth/register", + data=form_data + ) + if resp.status_code == 200: + # Check if we got a token back in a cookie + if "auth_token" in resp.cookies: + token = resp.cookies["auth_token"] + # Decode token to get username and expiry + import base64 + try: + # JWT format: header.payload.signature + payload = token.split(".")[1] + # Add padding if needed + payload += "=" * (4 - len(payload) % 4) + decoded = json.loads(base64.urlsafe_b64decode(payload)) + token_data = { + "access_token": token, + "username": decoded.get("username", username), + "expires_at": decoded.get("exp", "") + } + save_token(token_data) + click.echo(f"Registered and logged in as {token_data['username']}") + except Exception: + # If we can't decode, just save the token + save_token({"access_token": token, "username": username}) + click.echo(f"Registered and logged in as {username}") + else: + # HTML response - registration successful + if "successful" in resp.text.lower(): + click.echo(f"Registered as {username}. Please login to get a token.") + else: + click.echo(f"Registration failed: {resp.text}", err=True) + sys.exit(1) + else: + click.echo(f"Registration failed: {resp.text}", err=True) + sys.exit(1) + except requests.RequestException as e: + click.echo(f"Registration failed: {e}", err=True) + sys.exit(1) + + +@cli.command() +def logout(): + """Logout (clear saved token).""" + clear_token() + click.echo("Logged out") + + +@cli.command() +def whoami(): + """Show current logged-in user.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in") + return + + try: + resp = requests.get( + f"{get_l2_server()}/auth/me", + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 200: + user = resp.json() + click.echo(f"Username: {user['username']}") + click.echo(f"Created: {user['created_at']}") + if user.get('email'): + click.echo(f"Email: {user['email']}") + else: + click.echo("Token invalid or expired. Please login again.", err=True) + clear_token() + except requests.RequestException as e: + click.echo(f"Error: {e}", err=True) + + +@cli.command("config") +@click.option("--clear", is_flag=True, help="Clear saved server settings") +def show_config(clear): + """Show or clear saved configuration.""" + if clear: + if CONFIG_FILE.exists(): + CONFIG_FILE.unlink() + click.echo("Configuration cleared") + else: + click.echo("No configuration to clear") + return + + config = load_config() + click.echo(f"Config file: {CONFIG_FILE}") + click.echo() + click.echo(f"L1 Server: {DEFAULT_SERVER}") + if config.get("server"): + click.echo(f" (saved)") + elif os.environ.get("ARTDAG_SERVER"): + click.echo(f" (from ARTDAG_SERVER env)") + else: + click.echo(f" (default)") + + click.echo(f"L2 Server: {DEFAULT_L2_SERVER}") + if config.get("l2"): + click.echo(f" (saved)") + elif os.environ.get("ARTDAG_L2"): + click.echo(f" (from ARTDAG_L2 env)") + else: + click.echo(f" (default)") + + +# ============ Server Commands ============ + +@cli.command() +def info(): + """Show server info.""" + data = api_get("/") + click.echo(f"Server: {get_server()}") + click.echo(f"Name: {data['name']}") + click.echo(f"Version: {data['version']}") + click.echo(f"Cache: {data['cache_dir']}") + click.echo(f"Runs: {data['runs_count']}") + + +@cli.command() +def stats(): + """Show user stats (runs, recipes, effects, media, storage counts).""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.get(f"{get_server()}/api/stats", headers=headers) + resp.raise_for_status() + stats = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get stats: {e}", err=True) + sys.exit(1) + + click.echo("User Stats:") + click.echo(f" Runs: {stats.get('runs', 0)}") + click.echo(f" Recipes: {stats.get('recipes', 0)}") + click.echo(f" Effects: {stats.get('effects', 0)}") + click.echo(f" Media: {stats.get('media', 0)}") + click.echo(f" Storage: {stats.get('storage', 0)}") + + +@cli.command("clear-data") +@click.option("--force", "-f", is_flag=True, help="Skip confirmation") +def clear_data(force): + """Clear all user L1 data (runs, recipes, effects, media). + + Storage provider configurations are preserved. + This action cannot be undone! + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + # Show current stats first + try: + headers = get_auth_header(require_token=True) + resp = requests.get(f"{get_server()}/api/stats", headers=headers) + resp.raise_for_status() + stats = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get stats: {e}", err=True) + sys.exit(1) + + click.echo("This will delete:") + click.echo(f" Runs: {stats.get('runs', 0)}") + click.echo(f" Recipes: {stats.get('recipes', 0)}") + click.echo(f" Effects: {stats.get('effects', 0)}") + click.echo(f" Media: {stats.get('media', 0)}") + click.echo() + click.echo("Storage configurations will be preserved.") + click.echo() + + if not force: + if not click.confirm("Are you sure you want to delete all this data?"): + click.echo("Cancelled.") + return + + click.echo() + click.echo("Clearing data...") + + try: + resp = requests.delete(f"{get_server()}/api/clear-data", headers=headers) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to clear data: {e}", err=True) + sys.exit(1) + + deleted = result.get("deleted", {}) + click.echo() + click.echo("Deleted:") + click.echo(f" Runs: {deleted.get('runs', 0)}") + click.echo(f" Recipes: {deleted.get('recipes', 0)}") + click.echo(f" Effects: {deleted.get('effects', 0)}") + click.echo(f" Media: {deleted.get('media', 0)}") + + errors = result.get("errors", []) + if errors: + click.echo() + click.echo("Errors encountered:") + for err in errors[:5]: + click.echo(f" - {err}") + if len(errors) > 5: + click.echo(f" ... and {len(errors) - 5} more") + + +@cli.command() +@click.argument("recipe") +@click.argument("input_hash") +@click.option("--name", "-n", help="Output name") +@click.option("--wait", "-w", is_flag=True, help="Wait for completion") +def run(recipe, input_hash, name, wait): + """Start a rendering run. Requires login. + + RECIPE: Effect/recipe to apply (e.g., dog, identity) + INPUT_HASH: Content hash of input asset + """ + # Check auth + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Resolve named assets + assets = api_get("/assets") + if input_hash in assets: + input_hash = assets[input_hash] + click.echo(f"Resolved input to: {input_hash[:16]}...") + + data = { + "recipe": recipe, + "inputs": [input_hash], + } + if name: + data["output_name"] = name + + try: + result = api_post("/runs", data, auth=True) + except requests.HTTPError as e: + if e.response.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + raise + run_id = result["run_id"] + + click.echo(f"Run started: {run_id}") + click.echo(f"Status: {result['status']}") + + if wait: + click.echo("Waiting for completion...") + while True: + status = api_get(f"/runs/{run_id}") + if status["status"] in ("completed", "failed"): + break + time.sleep(1) + click.echo(".", nl=False) + click.echo() + + if status["status"] == "completed": + click.echo(f"Completed!") + click.echo(f"Output: {status['output_cid']}") + else: + click.echo(f"Failed: {status.get('error', 'Unknown error')}") + + +@cli.command("runs") +@click.option("--limit", "-l", default=10, help="Max runs to show") +@click.option("--offset", "-o", default=0, help="Offset for pagination") +def list_runs(limit, offset): + """List all runs with pagination.""" + headers = get_auth_header(require_token=True) + + try: + resp = requests.get(f"{get_server()}/runs?offset={offset}&limit={limit}", headers=headers) + resp.raise_for_status() + data = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to list runs: {e}", err=True) + sys.exit(1) + + runs = data.get("runs", []) + has_more = data.get("has_more", False) + + if not runs: + click.echo("No runs found.") + return + + start = offset + 1 + end = offset + len(runs) + click.echo(f"Showing {start}-{end}" + (" (more available)" if has_more else "")) + click.echo() + + for run in runs: + click.echo(f"Run ID: {run['run_id']}") + click.echo(f" Status: {run['status']}") + click.echo(f" Recipe: {run['recipe']}") + if run.get("recipe_name"): + click.echo(f" Recipe Name: {run['recipe_name']}") + if run.get("output_cid"): + click.echo(f" Output: {run['output_cid']}") + if run.get("created_at"): + click.echo(f" Created: {run['created_at']}") + click.echo() + + +@cli.command() +@click.argument("run_id") +@click.option("--plan", "-p", is_flag=True, help="Show execution plan with steps") +@click.option("--artifacts", "-a", is_flag=True, help="Show output artifacts") +@click.option("--analysis", is_flag=True, help="Show audio analysis data") +def status(run_id, plan, artifacts, analysis): + """Get status of a run with optional detailed views.""" + headers = get_auth_header() # Optional auth, always has Accept header + + try: + resp = requests.get(f"{get_server()}/runs/{run_id}", headers=headers) + if resp.status_code == 404: + click.echo(f"Run not found: {run_id}") + return + resp.raise_for_status() + run = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get run: {e}", err=True) + sys.exit(1) + + # Basic status + click.echo(f"Run ID: {run['run_id']}") + click.echo(f"Status: {run['status']}") + click.echo(f"Recipe: {run['recipe']}") + click.echo(f"Inputs: {', '.join(run.get('inputs', []))}") + click.echo(f"Output Name: {run.get('output_name', 'N/A')}") + click.echo(f"Created: {run['created_at']}") + + if run.get("completed_at"): + click.echo(f"Completed: {run['completed_at']}") + + if run.get("output_cid"): + click.echo(f"Output: {run['output_cid']}") + + if run.get("plan_cid"): + click.echo(f"Plan: {run['plan_cid']}") + + if run.get("error"): + click.echo(f"Error: {run['error']}") + + # Plan view + if plan: + click.echo() + click.echo("Execution Plan:") + click.echo("-" * 60) + try: + plan_resp = requests.get(f"{get_server()}/runs/{run_id}/plan", headers=headers) + if plan_resp.status_code == 200: + plan_data = plan_resp.json() + steps = plan_data.get("steps", []) + if steps: + for i, step in enumerate(steps, 1): + status_str = step.get("status", "pending") + if status_str == "cached": + status_badge = "[cached]" + elif status_str == "completed": + status_badge = "[done]" + elif status_str == "running": + status_badge = "[running]" + else: + status_badge = "[pending]" + + step_id = step.get("id", step.get("node_id", f"step_{i}")) + step_type = step.get("type", "unknown") + output_cid = step.get("output_cid", "") + + click.echo(f" {i}. {status_badge:<10} {step_id} ({step_type})") + if output_cid: + click.echo(f" Output: {output_cid}") + else: + click.echo(" No plan steps available.") + else: + click.echo(" Plan not available.") + except requests.RequestException: + click.echo(" Failed to fetch plan.") + + # Artifacts view + if artifacts: + click.echo() + click.echo("Artifacts:") + click.echo("-" * 60) + try: + art_resp = requests.get(f"{get_server()}/runs/{run_id}/artifacts", headers=headers) + if art_resp.status_code == 200: + art_data = art_resp.json() + artifact_list = art_data.get("artifacts", []) + if artifact_list: + for art in artifact_list: + cid = art.get("cid", art.get("output_cid", "unknown")) + name = art.get("name", art.get("step_id", "output")) + media_type = art.get("media_type", art.get("content_type", "")) + size = art.get("size", "") + click.echo(f" {name}:") + click.echo(f" CID: {cid}") + if media_type: + click.echo(f" Type: {media_type}") + if size: + click.echo(f" Size: {size}") + else: + click.echo(" No artifacts available.") + else: + click.echo(" Artifacts not available.") + except requests.RequestException: + click.echo(" Failed to fetch artifacts.") + + # Analysis view + if analysis: + click.echo() + click.echo("Analysis:") + click.echo("-" * 60) + try: + # Analysis is included in the detail view + detail_resp = requests.get(f"{get_server()}/runs/{run_id}/detail", headers=headers) + if detail_resp.status_code == 200: + detail_data = detail_resp.json() + analysis_data = detail_data.get("analysis", []) + if analysis_data: + for item in analysis_data: + input_name = item.get("input_name", item.get("name", "input")) + click.echo(f" {input_name}:") + if item.get("tempo"): + click.echo(f" Tempo: {item['tempo']} BPM") + if item.get("beat_count"): + click.echo(f" Beats: {item['beat_count']}") + if item.get("energy") is not None: + click.echo(f" Energy: {item['energy']}%") + if item.get("duration"): + click.echo(f" Duration: {item['duration']:.1f}s") + click.echo() + else: + click.echo(" No analysis data available.") + else: + click.echo(" Analysis not available.") + except requests.RequestException: + click.echo(" Failed to fetch analysis.") + + +@cli.command("delete-run") +@click.argument("run_id") +@click.option("--force", "-f", is_flag=True, help="Skip confirmation") +def delete_run(run_id, force): + """Delete a run. Requires login. + + RUN_ID: The run ID to delete + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Get run info first + try: + run = api_get(f"/runs/{run_id}") + except requests.HTTPError as e: + if e.response.status_code == 404: + click.echo(f"Run not found: {run_id}", err=True) + sys.exit(1) + raise + + if not force: + click.echo(f"Run: {run_id}") + click.echo(f"Status: {run['status']}") + click.echo(f"Recipe: {run['recipe']}") + if not click.confirm("Delete this run?"): + click.echo("Cancelled.") + return + + try: + headers = get_auth_header(require_token=True) + resp = requests.delete(f"{get_server()}/runs/{run_id}", headers=headers) + if resp.status_code == 400: + click.echo(f"Cannot delete: {resp.json().get('detail', 'Unknown error')}", err=True) + sys.exit(1) + if resp.status_code == 403: + click.echo("Access denied", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Run not found: {run_id}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to delete run: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted run: {run_id}") + + +@cli.command("delete-cache") +@click.argument("cid") +@click.option("--force", "-f", is_flag=True, help="Skip confirmation") +def delete_cache(cid, force): + """Delete a cached item. Requires login. + + CID: The content identifier (IPFS CID) to delete + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + if not force: + click.echo(f"CID: {cid}") + if not click.confirm("Delete this cached item?"): + click.echo("Cancelled.") + return + + try: + headers = get_auth_header(require_token=True) + resp = requests.delete(f"{get_server()}/cache/{cid}", headers=headers) + if resp.status_code == 400: + click.echo(f"Cannot delete: {resp.json().get('detail', 'Unknown error')}", err=True) + sys.exit(1) + if resp.status_code == 403: + click.echo("Access denied", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Content not found: {cid}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to delete cache item: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted: {cid}") + + +MEDIA_TYPE_EXTENSIONS = { + "image": ["jpg", "jpeg", "png", "gif", "webp", "bmp", "svg"], + "video": ["mp4", "mkv", "webm", "mov", "avi", "wmv"], + "audio": ["mp3", "wav", "flac", "ogg", "m4a", "aac"], +} + + +def matches_media_type(item: dict, media_type: str) -> bool: + """Check if item matches the requested media type.""" + if media_type == "all": + return True + + # Check content_type/media_type field + content_type = item.get("content_type", item.get("media_type", "")) + if content_type: + if media_type == "image" and content_type.startswith("image/"): + return True + if media_type == "video" and content_type.startswith("video/"): + return True + if media_type == "audio" and content_type.startswith("audio/"): + return True + + # Check filename extension + filename = item.get("filename", item.get("friendly_name", "")) + if filename: + ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" + if ext in MEDIA_TYPE_EXTENSIONS.get(media_type, []): + return True + + return False + + +@cli.command() +@click.option("--limit", "-l", default=20, help="Max items to show") +@click.option("--offset", "-o", default=0, help="Offset for pagination") +@click.option("--type", "-t", "media_type", type=click.Choice(["all", "image", "video", "audio"]), + default="all", help="Filter by media type") +def cache(limit, offset, media_type): + """List cached content with pagination and optional type filter.""" + headers = get_auth_header(require_token=True) + + # Fetch more items if filtering to ensure we get enough results + fetch_limit = limit * 3 if media_type != "all" else limit + + try: + resp = requests.get(f"{get_server()}/cache?offset={offset}&limit={fetch_limit}", headers=headers) + resp.raise_for_status() + data = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to list cache: {e}", err=True) + sys.exit(1) + + items = data.get("items", []) + has_more = data.get("has_more", False) + + # Filter by media type if requested + if media_type != "all": + items = [item for item in items if isinstance(item, dict) and matches_media_type(item, media_type)] + items = items[:limit] # Apply limit after filtering + + if not items: + if media_type != "all": + click.echo(f"No {media_type} files found in cache.") + else: + click.echo("Cache is empty.") + return + + start = offset + 1 + end = offset + len(items) + type_str = f" ({media_type})" if media_type != "all" else "" + click.echo(f"Showing {start}-{end}{type_str}" + (" (more available)" if has_more else "")) + click.echo() + + for item in items: + cid = item.get("cid", item) if isinstance(item, dict) else item + name = item.get("friendly_name") or item.get("filename") if isinstance(item, dict) else None + content_type = item.get("content_type", "") if isinstance(item, dict) else "" + type_badge = f"[{content_type.split('/')[0]}]" if content_type else "" + click.echo(f"CID: {cid}") + if name: + click.echo(f" Name: {name}") + if type_badge: + click.echo(f" Type: {type_badge}") + click.echo() + + +@cli.command() +@click.argument("cid") +@click.option("--output", "-o", type=click.Path(), help="Save to file (use - for stdout)") +@click.option("--raw", "-r", is_flag=True, help="Get raw content (use with -o to download)") +def view(cid, output, raw): + """View or download cached content. + + Use -o - to pipe to stdout, e.g.: artdag view -o - | mpv - + Use --raw to get the raw file content instead of metadata. + """ + # Use /raw endpoint if --raw flag or if outputting to file/stdout + if raw or output: + url = f"{get_server()}/cache/{cid}/raw" + else: + url = f"{get_server()}/cache/{cid}" + + try: + if output == "-" or (raw and not output): + # Stream to stdout for piping (--raw without -o also goes to stdout) + resp = requests.get(url, stream=True) + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=8192): + sys.stdout.buffer.write(chunk) + elif output: + # Download to file + resp = requests.get(url, stream=True) + resp.raise_for_status() + with open(output, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + click.echo(f"Saved to: {output}", err=True) + else: + # Get info - use JSON endpoint for metadata + headers = {"Accept": "application/json"} + resp = requests.get(f"{get_server()}/cache/{cid}", headers=headers) + resp.raise_for_status() + info = resp.json() + click.echo(f"CID: {cid}") + click.echo(f"Size: {info.get('size', 'unknown')} bytes") + click.echo(f"Type: {info.get('mime_type') or info.get('media_type', 'unknown')}") + if info.get('friendly_name'): + click.echo(f"Friendly Name: {info['friendly_name']}") + if info.get('title'): + click.echo(f"Title: {info['title']}") + if info.get('filename'): + click.echo(f"Filename: {info['filename']}") + click.echo(f"Raw URL: {get_server()}/cache/{cid}/raw") + except requests.HTTPError as e: + if e.response.status_code == 404: + click.echo(f"Not found: {cid}", err=True) + else: + raise + + +@cli.command("import") +@click.argument("filepath", type=click.Path(exists=True)) +def import_file(filepath): + """Import a local file to cache (local server only).""" + path = str(Path(filepath).resolve()) + result = api_post("/cache/import", params={"path": path}) + click.echo(f"Imported: {result['cid']}") + + +@cli.command() +@click.argument("filepath", type=click.Path(exists=True)) +@click.option("--name", "-n", help="Friendly name for the asset") +def upload(filepath, name): + """Upload a file to cache and IPFS. Requires login.""" + # Check auth + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + with open(filepath, "rb") as f: + files = {"file": (Path(filepath).name, f)} + data = {"display_name": name} if name else {} + headers = get_auth_header(require_token=True) + resp = requests.post(f"{get_server()}/cache/upload", files=files, data=data, headers=headers) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code >= 400: + try: + detail = resp.json().get("detail", resp.text) + except: + detail = resp.text + click.echo(f"Upload failed: {resp.status_code} - {detail}", err=True) + sys.exit(1) + result = resp.json() + click.echo(f"CID: {result['cid']}") + click.echo(f"Friendly name: {result.get('friendly_name', 'N/A')}") + click.echo(f"Size: {result['size']} bytes") + click.echo() + click.echo("Use in recipes:") + friendly = result.get('friendly_name', result['cid']) + click.echo(f' (streaming:make-video-source "{friendly}" 30)') + except requests.RequestException as e: + click.echo(f"Upload failed: {e}", err=True) + sys.exit(1) + + +@cli.command() +def assets(): + """List known assets.""" + data = api_get("/assets") + click.echo("Known assets:") + for name, hash in data.items(): + click.echo(f" {name}: {hash[:16]}...") + + +@cli.command() +@click.argument("run_id") +@click.argument("output_name") +def publish(run_id, output_name): + """Publish an L1 run to L2 (register ownership). Requires login. + + RUN_ID: The L1 run ID to publish + OUTPUT_NAME: Name for the registered asset + """ + # Check auth + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Post to L2 server with auth, including which L1 server has the run + try: + resp = requests.post( + f"{get_l2_server()}/registry/record-run", + json={"run_id": run_id, "output_name": output_name, "l1_server": get_server()}, + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to publish: {e}", err=True) + sys.exit(1) + + result = resp.json() + click.echo(f"Published to L2!") + click.echo(f"Asset: {result['asset']['name']}") + click.echo(f"CID: {result['asset']['cid']}") + click.echo(f"Activity: {result['activity']['activity_id']}") + + +# ============ Metadata Commands ============ + +@cli.command() +@click.argument("cid") +@click.option("--origin", type=click.Choice(["self", "external"]), help="Set origin type") +@click.option("--origin-url", help="Set external origin URL") +@click.option("--origin-note", help="Note about the origin") +@click.option("--description", "-d", help="Set description") +@click.option("--tags", "-t", help="Set tags (comma-separated)") +@click.option("--folder", "-f", help="Set folder path") +@click.option("--add-collection", help="Add to collection") +@click.option("--remove-collection", help="Remove from collection") +@click.option("--publish", "publish_name", help="Publish to L2 with given asset name") +@click.option("--publish-type", default="image", help="Asset type for publishing (image, video)") +@click.option("--republish", is_flag=True, help="Re-sync with L2 after metadata changes") +def meta(cid, origin, origin_url, origin_note, description, tags, folder, add_collection, remove_collection, publish_name, publish_type, republish): + """View or update metadata for a cached item. + + With no options, displays current metadata. + With options, updates the specified fields. + + Use --publish to publish to L2 (requires origin to be set). + Use --republish to sync metadata changes to L2. + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + headers = get_auth_header(require_token=True) + + # Handle publish action + if publish_name: + try: + resp = requests.post( + f"{get_server()}/cache/{cid}/publish", + json={"asset_name": publish_name, "asset_type": publish_type}, + headers=headers + ) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Content not found: {cid}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + click.echo(f"Published to L2!") + click.echo(f"Asset name: {result['asset_name']}") + click.echo(f"Activity: {result['l2_result']['activity']['activity_id']}") + except requests.RequestException as e: + click.echo(f"Failed to publish: {e}", err=True) + sys.exit(1) + return + + # Handle republish action + if republish: + try: + resp = requests.patch( + f"{get_server()}/cache/{cid}/republish", + headers=headers + ) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Content not found: {cid}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + click.echo(f"Re-synced with L2!") + click.echo(f"Asset name: {result['asset_name']}") + except requests.RequestException as e: + click.echo(f"Failed to republish: {e}", err=True) + sys.exit(1) + return + + # If no update options, just display current metadata + has_updates = any([origin, origin_url, origin_note, description, tags, folder, add_collection, remove_collection]) + + if not has_updates: + # GET metadata + try: + resp = requests.get(f"{get_server()}/cache/{cid}/meta", headers=headers) + if resp.status_code == 404: + click.echo(f"Content not found: {cid}", err=True) + sys.exit(1) + if resp.status_code == 403: + click.echo("Access denied", err=True) + sys.exit(1) + resp.raise_for_status() + meta = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get metadata: {e}", err=True) + sys.exit(1) + + click.echo(f"Content Hash: {cid}") + click.echo(f"Uploader: {meta.get('uploader', 'unknown')}") + click.echo(f"Uploaded: {meta.get('uploaded_at', 'unknown')}") + if meta.get("origin"): + origin_info = meta["origin"] + click.echo(f"Origin: {origin_info.get('type', 'unknown')}") + if origin_info.get("url"): + click.echo(f" URL: {origin_info['url']}") + if origin_info.get("note"): + click.echo(f" Note: {origin_info['note']}") + else: + click.echo("Origin: not set") + click.echo(f"Description: {meta.get('description', 'none')}") + click.echo(f"Tags: {', '.join(meta.get('tags', [])) or 'none'}") + click.echo(f"Folder: {meta.get('folder', '/')}") + click.echo(f"Collections: {', '.join(meta.get('collections', [])) or 'none'}") + if meta.get("published"): + pub = meta["published"] + click.echo(f"Published: {pub.get('asset_name')} ({pub.get('published_at')})") + return + + # Build update payload + update = {} + + if origin or origin_url or origin_note: + # Get current origin first + try: + resp = requests.get(f"{get_server()}/cache/{cid}/meta", headers=headers) + resp.raise_for_status() + current = resp.json() + current_origin = current.get("origin", {}) + except: + current_origin = {} + + update["origin"] = { + "type": origin or current_origin.get("type", "self"), + "url": origin_url if origin_url is not None else current_origin.get("url"), + "note": origin_note if origin_note is not None else current_origin.get("note") + } + + if description is not None: + update["description"] = description + + if tags is not None: + update["tags"] = [t.strip() for t in tags.split(",") if t.strip()] + + if folder is not None: + update["folder"] = folder + + if add_collection or remove_collection: + # Get current collections + try: + resp = requests.get(f"{get_server()}/cache/{cid}/meta", headers=headers) + resp.raise_for_status() + current = resp.json() + collections = set(current.get("collections", [])) + except: + collections = set() + + if add_collection: + collections.add(add_collection) + if remove_collection and remove_collection in collections: + collections.remove(remove_collection) + update["collections"] = list(collections) + + # PATCH metadata + try: + resp = requests.patch( + f"{get_server()}/cache/{cid}/meta", + json=update, + headers=headers + ) + if resp.status_code == 404: + click.echo(f"Content not found: {cid}", err=True) + sys.exit(1) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to update metadata: {e}", err=True) + sys.exit(1) + + click.echo("Metadata updated.") + + +# ============ Folder Commands ============ + +@cli.group() +def folder(): + """Manage folders for organizing cached items.""" + pass + + +@folder.command("list") +def folder_list(): + """List all folders.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.get( + f"{get_server()}/user/folders", + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + resp.raise_for_status() + folders = resp.json()["folders"] + except requests.RequestException as e: + click.echo(f"Failed to list folders: {e}", err=True) + sys.exit(1) + + click.echo("Folders:") + for f in folders: + click.echo(f" {f}") + + +@folder.command("create") +@click.argument("path") +def folder_create(path): + """Create a new folder.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.post( + f"{get_server()}/user/folders", + params={"folder_path": path}, + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to create folder: {e}", err=True) + sys.exit(1) + + click.echo(f"Created folder: {path}") + + +@folder.command("delete") +@click.argument("path") +def folder_delete(path): + """Delete a folder (must be empty).""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.delete( + f"{get_server()}/user/folders", + params={"folder_path": path}, + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Folder not found: {path}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to delete folder: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted folder: {path}") + + +# ============ Collection Commands ============ + +@cli.group() +def collection(): + """Manage collections for organizing cached items.""" + pass + + +@collection.command("list") +def collection_list(): + """List all collections.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.get( + f"{get_server()}/user/collections", + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + resp.raise_for_status() + collections = resp.json()["collections"] + except requests.RequestException as e: + click.echo(f"Failed to list collections: {e}", err=True) + sys.exit(1) + + click.echo("Collections:") + for c in collections: + click.echo(f" {c['name']} (created: {c['created_at'][:10]})") + + +@collection.command("create") +@click.argument("name") +def collection_create(name): + """Create a new collection.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.post( + f"{get_server()}/user/collections", + params={"name": name}, + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to create collection: {e}", err=True) + sys.exit(1) + + click.echo(f"Created collection: {name}") + + +@collection.command("delete") +@click.argument("name") +def collection_delete(name): + """Delete a collection.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + resp = requests.delete( + f"{get_server()}/user/collections", + params={"name": name}, + headers={"Authorization": f"Bearer {token_data['access_token']}"} + ) + if resp.status_code == 404: + click.echo(f"Collection not found: {name}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to delete collection: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted collection: {name}") + + +# ============ Storage Commands ============ + +STORAGE_PROVIDER_TYPES = ["pinata", "web3storage", "nftstorage", "infura", "filebase", "storj", "local"] + +STORAGE_CONFIG_FIELDS = { + "pinata": ["api_key", "secret_key"], + "web3storage": ["api_token"], + "nftstorage": ["api_token"], + "infura": ["project_id", "project_secret"], + "filebase": ["access_key", "secret_key", "bucket"], + "storj": ["access_key", "secret_key", "bucket"], + "local": ["path"], +} + + +@cli.group() +def storage(): + """Manage IPFS storage providers.""" + pass + + +@storage.command("list") +def storage_list(): + """List all storage providers.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.get(f"{get_server()}/storage", headers=headers) + resp.raise_for_status() + data = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to list storage providers: {e}", err=True) + sys.exit(1) + + storages = data.get("storages", []) + if not storages: + click.echo("No storage providers configured.") + click.echo(f"\nAvailable types: {', '.join(STORAGE_PROVIDER_TYPES)}") + click.echo("Use 'artdag storage add ' to add one.") + return + + click.echo("Storage Providers:") + click.echo() + for s in storages: + status = "Active" if s.get("is_active", True) else "Inactive" + click.echo(f" [{s['id']}] {s['provider_name'] or s['provider_type']} ({s['provider_type']})") + click.echo(f" Status: {status}") + click.echo(f" Capacity: {s.get('capacity_gb', 'N/A')} GB") + click.echo() + + +@storage.command("add") +@click.argument("provider_type", type=click.Choice(STORAGE_PROVIDER_TYPES)) +@click.option("--name", "-n", help="Friendly name for this provider") +@click.option("--capacity", "-c", type=int, default=5, help="Capacity in GB (default: 5)") +def storage_add(provider_type, name, capacity): + """Add a storage provider (interactive config).""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + # Get config fields for this provider type + fields = STORAGE_CONFIG_FIELDS.get(provider_type, []) + config = {} + + click.echo(f"Configuring {provider_type} storage provider...") + click.echo() + + for field in fields: + is_secret = "secret" in field.lower() or "key" in field.lower() or "token" in field.lower() + if is_secret: + value = click.prompt(f" {field}", hide_input=True) + else: + value = click.prompt(f" {field}") + config[field] = value + + # Send to server + try: + headers = get_auth_header(require_token=True) + payload = { + "provider_type": provider_type, + "config": config, + "capacity_gb": capacity, + } + if name: + payload["provider_name"] = name + + resp = requests.post(f"{get_server()}/storage", json=payload, headers=headers) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to add storage provider: {e}", err=True) + sys.exit(1) + + click.echo() + click.echo(f"Storage provider added (ID: {result.get('id')})") + + +@storage.command("test") +@click.argument("storage_id", type=int) +def storage_test(storage_id): + """Test storage provider connectivity.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.post(f"{get_server()}/storage/{storage_id}/test", headers=headers) + if resp.status_code == 404: + click.echo(f"Storage provider not found: {storage_id}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to test storage: {e}", err=True) + sys.exit(1) + + if result.get("success"): + click.echo(f"Success: {result.get('message', 'Connection OK')}") + else: + click.echo(f"Failed: {result.get('message', 'Unknown error')}", err=True) + sys.exit(1) + + +@storage.command("delete") +@click.argument("storage_id", type=int) +@click.option("--force", "-f", is_flag=True, help="Skip confirmation") +def storage_delete(storage_id, force): + """Delete a storage provider.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + if not force: + if not click.confirm(f"Delete storage provider {storage_id}?"): + click.echo("Cancelled.") + return + + try: + headers = get_auth_header(require_token=True) + resp = requests.delete(f"{get_server()}/storage/{storage_id}", headers=headers) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Storage provider not found: {storage_id}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Failed to delete storage provider: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted storage provider: {storage_id}") + + +# ============ Recipe Commands ============ + +def _is_sexp_file(filepath: str, content: str) -> bool: + """Detect if file is S-expression format.""" + # Check extension first + if filepath.endswith('.sexp'): + return True + # Check content - skip comments and whitespace + for line in content.split('\n'): + stripped = line.strip() + if not stripped or stripped.startswith(';'): + continue + return stripped.startswith('(') + return False + + +@cli.command("upload-recipe") +@click.argument("filepath", type=click.Path(exists=True)) +def upload_recipe(filepath): + """Upload a recipe file (YAML or S-expression). Requires login.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Read content + with open(filepath) as f: + content = f.read() + + # Detect format and validate + is_sexp = _is_sexp_file(filepath, content) + + if is_sexp: + # S-expression - basic syntax check (starts with paren after comments) + # Full validation happens on server + click.echo("Detected S-expression format") + else: + # Validate YAML locally + try: + recipe = yaml.safe_load(content) + except yaml.YAMLError as e: + click.echo(f"Invalid YAML: {e}", err=True) + sys.exit(1) + + # Check required fields for YAML + if not recipe.get("name"): + click.echo("Recipe must have a 'name' field", err=True) + sys.exit(1) + + # Upload + try: + with open(filepath, "rb") as f: + files = {"file": (Path(filepath).name, f)} + headers = get_auth_header(require_token=True) + resp = requests.post(f"{get_server()}/recipes/upload", files=files, headers=headers) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code >= 400: + click.echo(f"Error response: {resp.text}", err=True) + resp.raise_for_status() + result = resp.json() + + click.echo(f"Uploaded recipe: {result['name']} v{result.get('version', '1.0')}") + click.echo(f"Recipe ID: {result['recipe_id']}") + click.echo(f"Variable inputs: {result['variable_inputs']}") + click.echo(f"Fixed inputs: {result['fixed_inputs']}") + except requests.RequestException as e: + click.echo(f"Upload failed: {e}", err=True) + sys.exit(1) + + +@cli.command("upload-effect") +@click.argument("filepath", type=click.Path(exists=True)) +@click.option("--name", "-n", help="Friendly name for the effect") +def upload_effect(filepath, name): + """Upload an effect file to IPFS. Requires login. + + Effects are S-expression files (.sexp) with metadata in comments. + Returns the IPFS CID for use in recipes. + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Check it's a sexp or py file + if not filepath.endswith(".sexp") and not filepath.endswith(".py"): + click.echo("Effect must be a .sexp or .py file", err=True) + sys.exit(1) + + # Upload + try: + with open(filepath, "rb") as f: + files = {"file": (Path(filepath).name, f)} + data = {"display_name": name} if name else {} + headers = get_auth_header(require_token=True) + resp = requests.post(f"{get_server()}/effects/upload", files=files, data=data, headers=headers) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code >= 400: + click.echo(f"Error response: {resp.text}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + + click.echo(f"Uploaded effect: {result['name']} v{result.get('version', '1.0.0')}") + click.echo(f"CID: {result['cid']}") + click.echo(f"Friendly name: {result.get('friendly_name', 'N/A')}") + click.echo(f"Temporal: {result.get('temporal', False)}") + if result.get('params'): + click.echo(f"Parameters: {', '.join(p['name'] for p in result['params'])}") + click.echo() + click.echo("Use in recipes:") + click.echo(f' (effect {result["name"]} :name "{result.get("friendly_name", result["cid"])}")') + except requests.RequestException as e: + click.echo(f"Upload failed: {e}", err=True) + sys.exit(1) + + +@cli.command("effects") +@click.option("--limit", "-l", default=20, help="Max effects to show") +@click.option("--offset", "-o", default=0, help="Offset for pagination") +def list_effects(limit, offset): + """List uploaded effects with pagination.""" + headers = get_auth_header(require_token=True) + + try: + resp = requests.get(f"{get_server()}/effects?offset={offset}&limit={limit}", headers=headers) + resp.raise_for_status() + result = resp.json() + + effects = result.get("effects", []) + has_more = result.get("has_more", False) + + if not effects: + click.echo("No effects found") + return + + start = offset + 1 + end = offset + len(effects) + click.echo(f"Showing {start}-{end}" + (" (more available)" if has_more else "")) + click.echo() + + for effect in effects: + meta = effect.get("meta", {}) + click.echo(f"Name: {meta.get('name', 'unknown')} v{meta.get('version', '?')}") + click.echo(f" CID: {effect['cid']}") + if effect.get('friendly_name'): + click.echo(f" Friendly Name: {effect['friendly_name']}") + click.echo(f" Temporal: {meta.get('temporal', False)}") + if meta.get('params'): + click.echo(f" Params: {', '.join(p['name'] for p in meta['params'])}") + click.echo() + except requests.RequestException as e: + click.echo(f"Failed to list effects: {e}", err=True) + sys.exit(1) + + +@cli.command("effect") +@click.argument("cid") +@click.option("--source", "-s", is_flag=True, help="Show source code") +def show_effect(cid, source): + """Show details of an effect by CID.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.get(f"{get_server()}/effects/{cid}", headers=headers) + if resp.status_code == 404: + click.echo(f"Effect not found: {cid}", err=True) + sys.exit(1) + resp.raise_for_status() + effect = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get effect: {e}", err=True) + sys.exit(1) + + meta = effect.get("meta", effect) + name = meta.get("name", "Unnamed") + version = meta.get("version", "1.0.0") + author = meta.get("author", "Unknown") + description = meta.get("description", "No description") + + click.echo(f"Name: {name} (v{version})") + click.echo(f"Author: {author}") + click.echo(f"Description: {description}") + click.echo(f"CID: {effect.get('cid', cid)}") + if effect.get("uploaded_at"): + click.echo(f"Uploaded: {effect['uploaded_at']}") + if effect.get("uploader"): + click.echo(f"Uploader: {effect['uploader']}") + if meta.get("temporal"): + click.echo("Temporal: Yes") + + # Parameters + params = meta.get("params", []) + if params: + click.echo("\nParameters:") + for p in params: + param_type = p.get("type", "any") + param_desc = p.get("description", "") + param_range = "" + if "min" in p and "max" in p: + param_range = f" [{p['min']}-{p['max']}]" + param_default = f" default: {p['default']}" if "default" in p else "" + click.echo(f" - {p['name']} ({param_type}): {param_desc}{param_range}{param_default}") + + # Dependencies + deps = meta.get("dependencies", []) + if deps: + click.echo("\nDependencies:") + for dep in deps: + click.echo(f" - {dep}") + + # Source code + if source: + click.echo("\nSource Code:") + click.echo("-" * 40) + try: + source_resp = requests.get(f"{get_server()}/effects/{cid}/source", headers=headers) + if source_resp.status_code == 200: + click.echo(source_resp.text) + else: + click.echo("(Source not available)") + except requests.RequestException: + click.echo("(Failed to fetch source)") + + +@cli.command("recipes") +@click.option("--limit", "-l", default=10, help="Max recipes to show") +@click.option("--offset", "-o", default=0, help="Offset for pagination") +def list_recipes(limit, offset): + """List uploaded recipes for the current user with pagination.""" + headers = get_auth_header(require_token=True) + + try: + resp = requests.get(f"{get_server()}/recipes?offset={offset}&limit={limit}", headers=headers) + resp.raise_for_status() + data = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to list recipes: {e}", err=True) + sys.exit(1) + + recipes = data.get("recipes", []) + has_more = data.get("has_more", False) + + if not recipes: + click.echo("No recipes found.") + return + + start = offset + 1 + end = offset + len(recipes) + click.echo(f"Showing {start}-{end}" + (" (more available)" if has_more else "")) + click.echo() + + for recipe in recipes: + recipe_id = recipe["recipe_id"] + var_count = len(recipe.get("variable_inputs", [])) + friendly_name = recipe.get("friendly_name", "") + + click.echo(f"Name: {recipe['name']}") + click.echo(f" Version: {recipe.get('version', 'N/A')}") + if friendly_name: + click.echo(f" Friendly Name: {friendly_name}") + click.echo(f" Variables: {var_count}") + click.echo(f" Recipe ID: {recipe_id}") + click.echo() + + +@cli.command("recipe") +@click.argument("recipe_id") +def show_recipe(recipe_id): + """Show details of a recipe.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Use 'artdag login' first.", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.get(f"{get_server()}/recipes/{recipe_id}", headers=headers) + if resp.status_code == 404: + click.echo(f"Recipe not found: {recipe_id}", err=True) + sys.exit(1) + resp.raise_for_status() + recipe = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get recipe: {e}", err=True) + sys.exit(1) + + click.echo(f"Name: {recipe.get('name', 'Unnamed')}") + click.echo(f"Version: {recipe.get('version', 'N/A')}") + if recipe.get("friendly_name"): + click.echo(f"Friendly Name: {recipe['friendly_name']}") + click.echo(f"Description: {recipe.get('description', 'N/A')}") + click.echo(f"Recipe ID: {recipe['recipe_id']}") + click.echo(f"Owner: {recipe.get('owner', 'N/A')}") + if recipe.get("uploaded_at"): + click.echo(f"Uploaded: {recipe['uploaded_at']}") + + if recipe.get("variable_inputs"): + click.echo("\nVariable Inputs:") + for inp in recipe["variable_inputs"]: + req = "*" if inp.get("required", True) else "" + click.echo(f" - {inp['name']}{req}: {inp.get('description', 'No description')}") + + if recipe.get("fixed_inputs"): + click.echo("\nFixed Inputs:") + for inp in recipe["fixed_inputs"]: + click.echo(f" - {inp['asset']}: {inp['cid']}") + + +@cli.command("run-recipe") +@click.argument("recipe_id") +@click.option("--input", "-i", "inputs", multiple=True, help="Input as node_id:cid") +@click.option("--wait", "-w", is_flag=True, help="Wait for completion") +def run_recipe(recipe_id, inputs, wait): + """Run a recipe with variable inputs. Requires login. + + RECIPE_ID: The recipe ID (content hash) + + Example: artdag run-recipe abc123 -i source_image:def456 + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Parse inputs + input_dict = {} + for inp in inputs: + if ":" not in inp: + click.echo(f"Invalid input format: {inp} (expected node_id:cid)", err=True) + sys.exit(1) + node_id, cid = inp.split(":", 1) + input_dict[node_id] = cid + + # Run + try: + headers = get_auth_header(require_token=True) + resp = requests.post( + f"{get_server()}/recipes/{recipe_id}/run", + json={"inputs": input_dict}, + headers=headers + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + error = resp.json().get("detail", "Bad request") + click.echo(f"Error: {error}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Recipe not found: {recipe_id}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Run failed: {e}", err=True) + sys.exit(1) + + click.echo(f"Run started: {result['run_id']}") + if result.get('recipe'): + click.echo(f"Recipe: {result['recipe']}") + click.echo(f"Status: {result.get('status', 'pending')}") + + if wait: + click.echo("Waiting for completion...") + run_id = result["run_id"] + while True: + time.sleep(2) + try: + resp = requests.get(f"{get_server()}/runs/{run_id}") + resp.raise_for_status() + run = resp.json() + except requests.RequestException: + continue + + if run["status"] == "completed": + click.echo(f"Completed! Output: {run.get('output_cid', 'N/A')}") + break + elif run["status"] == "failed": + click.echo(f"Failed: {run.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + +@cli.command("delete-recipe") +@click.argument("recipe_id") +@click.option("--force", "-f", is_flag=True, help="Skip confirmation") +def delete_recipe(recipe_id, force): + """Delete a recipe. Requires login.""" + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + if not force: + if not click.confirm(f"Delete recipe {recipe_id[:16]}...?"): + click.echo("Cancelled.") + return + + try: + headers = get_auth_header(require_token=True) + resp = requests.delete(f"{get_server()}/recipes/{recipe_id}", headers=headers) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + error = resp.json().get("detail", "Cannot delete") + click.echo(f"Error: {error}", err=True) + sys.exit(1) + if resp.status_code == 404: + click.echo(f"Recipe not found: {recipe_id}", err=True) + sys.exit(1) + resp.raise_for_status() + except requests.RequestException as e: + click.echo(f"Delete failed: {e}", err=True) + sys.exit(1) + + click.echo(f"Deleted recipe: {recipe_id[:16]}...") + + +# ============ v2 API Commands (3-Phase Execution) ============ + +@cli.command("plan") +@click.argument("recipe_file", type=click.Path(exists=True)) +@click.option("--input", "-i", "inputs", multiple=True, help="Input as name:cid") +@click.option("--features", "-f", multiple=True, help="Features to extract (default: beats, energy)") +@click.option("--output", "-o", type=click.Path(), help="Save plan JSON to file") +def generate_plan(recipe_file, inputs, features, output): + """Generate an execution plan from a recipe YAML. Requires login. + + Preview what will be executed without actually running it. + + RECIPE_FILE: Path to recipe YAML file + + Example: artdag plan recipe.yaml -i source_video:abc123 + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Read recipe YAML + with open(recipe_file) as f: + recipe_yaml = f.read() + + # Parse inputs + input_hashes = {} + for inp in inputs: + if ":" not in inp: + click.echo(f"Invalid input format: {inp} (expected name:cid)", err=True) + sys.exit(1) + name, cid = inp.split(":", 1) + input_hashes[name] = cid + + # Build request + request_data = { + "recipe_yaml": recipe_yaml, + "input_hashes": input_hashes, + } + if features: + request_data["features"] = list(features) + + # Submit to API + try: + headers = get_auth_header(require_token=True) + resp = requests.post( + f"{get_server()}/api/v2/plan", + json=request_data, + headers=headers + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Plan generation failed: {e}", err=True) + sys.exit(1) + + # Display results + click.echo(f"Recipe: {result['recipe']}") + click.echo(f"Plan ID: {result['plan_id'][:16]}...") + click.echo(f"Total steps: {result['total_steps']}") + click.echo(f"Cached: {result['cached_steps']}") + click.echo(f"Pending: {result['pending_steps']}") + + if result.get("steps"): + click.echo("\nSteps:") + for step in result["steps"]: + status = "✓ cached" if step["cached"] else "○ pending" + click.echo(f" L{step['level']} {step['step_id']:<20} {step['node_type']:<10} {status}") + + # Save plan JSON if requested + if output: + with open(output, "w") as f: + f.write(result["plan_json"]) + click.echo(f"\nPlan saved to: {output}") + elif result.get("plan_json"): + click.echo(f"\nUse --output to save the plan JSON for later execution.") + + +@cli.command("execute-plan") +@click.argument("plan_file", type=click.Path(exists=True)) +@click.option("--wait", "-w", is_flag=True, help="Wait for completion") +def execute_plan(plan_file, wait): + """Execute a pre-generated execution plan. Requires login. + + PLAN_FILE: Path to plan JSON file (from 'artdag plan --output') + + Example: artdag execute-plan plan.json --wait + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Read plan JSON + with open(plan_file) as f: + plan_json = f.read() + + # Submit to API + try: + headers = get_auth_header(require_token=True) + resp = requests.post( + f"{get_server()}/api/v2/execute", + json={"plan_json": plan_json}, + headers=headers + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Execution failed: {e}", err=True) + sys.exit(1) + + run_id = result["run_id"] + click.echo(f"Run started: {run_id}") + click.echo(f"Status: {result['status']}") + + if wait: + _wait_for_v2_run(token_data, run_id) + + +@cli.command("run-v2") +@click.argument("recipe_file", type=click.Path(exists=True)) +@click.option("--input", "-i", "inputs", multiple=True, help="Input as name:cid") +@click.option("--features", "-f", multiple=True, help="Features to extract (default: beats, energy)") +@click.option("--wait", "-w", is_flag=True, help="Wait for completion") +def run_recipe_v2(recipe_file, inputs, features, wait): + """Run a recipe through 3-phase execution. Requires login. + + Runs the full pipeline: Analyze → Plan → Execute + + RECIPE_FILE: Path to recipe YAML file + + Example: artdag run-v2 recipe.yaml -i source_video:abc123 --wait + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Read recipe YAML + with open(recipe_file) as f: + recipe_yaml = f.read() + + # Parse recipe name for display + try: + recipe_data = yaml.safe_load(recipe_yaml) + recipe_name = recipe_data.get("name", "unknown") + except Exception: + recipe_name = "unknown" + + # Parse inputs + input_hashes = {} + for inp in inputs: + if ":" not in inp: + click.echo(f"Invalid input format: {inp} (expected name:cid)", err=True) + sys.exit(1) + name, cid = inp.split(":", 1) + input_hashes[name] = cid + + # Build request + request_data = { + "recipe_yaml": recipe_yaml, + "input_hashes": input_hashes, + } + if features: + request_data["features"] = list(features) + + # Submit to API + click.echo(f"Running recipe: {recipe_name}") + click.echo(f"Inputs: {len(input_hashes)}") + + try: + headers = get_auth_header(require_token=True) + resp = requests.post( + f"{get_server()}/api/v2/run-recipe", + json=request_data, + headers=headers + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + click.echo(f"Error: {resp.json().get('detail', 'Bad request')}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Run failed: {e}", err=True) + sys.exit(1) + + run_id = result["run_id"] + click.echo(f"Run ID: {run_id}") + click.echo(f"Status: {result['status']}") + + if result.get("output_cid"): + click.echo(f"Output: {result['output_cid']}") + if result.get("output_ipfs_cid"): + click.echo(f"IPFS CID: {result['output_ipfs_cid']}") + return + + if wait: + _wait_for_v2_run(token_data, run_id) + + +def _wait_for_v2_run(token_data: dict, run_id: str): + """Poll v2 run status until completion.""" + click.echo("Waiting for completion...") + headers = get_auth_header(require_token=True) + + while True: + time.sleep(2) + try: + resp = requests.get( + f"{get_server()}/api/v2/run/{run_id}", + headers=headers + ) + resp.raise_for_status() + run = resp.json() + except requests.RequestException as e: + click.echo(f".", nl=False) + continue + + status = run.get("status", "unknown") + + if status == "completed": + click.echo(f"\nCompleted!") + if run.get("output_cid"): + click.echo(f"Output: {run['output_cid']}") + if run.get("output_ipfs_cid"): + click.echo(f"IPFS CID: {run['output_ipfs_cid']}") + if run.get("cached"): + click.echo(f"Steps cached: {run['cached']}") + if run.get("executed"): + click.echo(f"Steps executed: {run['executed']}") + break + elif status == "failed": + click.echo(f"\nFailed: {run.get('error', 'Unknown error')}", err=True) + sys.exit(1) + else: + click.echo(".", nl=False) + + +@cli.command("run-status") +@click.argument("run_id") +def run_status_v2(run_id): + """Get status of a v2 run. Requires login. + + RUN_ID: The run ID from run-v2 or execute-plan + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + try: + headers = get_auth_header(require_token=True) + resp = requests.get( + f"{get_server()}/api/v2/run/{run_id}", + headers=headers + ) + if resp.status_code == 404: + click.echo(f"Run not found: {run_id}", err=True) + sys.exit(1) + resp.raise_for_status() + run = resp.json() + except requests.RequestException as e: + click.echo(f"Failed to get status: {e}", err=True) + sys.exit(1) + + click.echo(f"Run ID: {run_id}") + click.echo(f"Status: {run['status']}") + + if run.get("recipe"): + click.echo(f"Recipe: {run['recipe']}") + if run.get("plan_id"): + click.echo(f"Plan ID: {run['plan_id'][:16]}...") + if run.get("output_cid"): + click.echo(f"Output: {run['output_cid']}") + if run.get("output_ipfs_cid"): + click.echo(f"IPFS CID: {run['output_ipfs_cid']}") + if run.get("cached") is not None: + click.echo(f"Cached: {run['cached']}") + if run.get("executed") is not None: + click.echo(f"Executed: {run['executed']}") + if run.get("error"): + click.echo(f"Error: {run['error']}") + + +@cli.command("stream") +@click.argument("recipe_file", type=click.Path(exists=True)) +@click.option("--output", "-o", default="output.mp4", help="Output filename") +@click.option("--duration", "-d", type=float, help="Duration in seconds") +@click.option("--fps", type=float, help="FPS override") +@click.option("--sources", type=click.Path(exists=True), help="Sources config .sexp file") +@click.option("--audio", type=click.Path(exists=True), help="Audio config .sexp file") +@click.option("--wait", "-w", is_flag=True, help="Wait for completion") +def run_stream(recipe_file, output, duration, fps, sources, audio, wait): + """Run a streaming S-expression recipe. Requires login. + + RECIPE_FILE: Path to the recipe .sexp file + + Example: artdag stream effects/my_effect.sexp --duration 10 --fps 30 -w + """ + token_data = load_token() + if not token_data.get("access_token"): + click.echo("Not logged in. Please run: artdag login ", err=True) + sys.exit(1) + + # Read recipe file + recipe_path = Path(recipe_file) + recipe_sexp = recipe_path.read_text() + + # Read optional config files + sources_sexp = None + if sources: + sources_sexp = Path(sources).read_text() + + audio_sexp = None + if audio: + audio_sexp = Path(audio).read_text() + + # Build request + request_data = { + "recipe_sexp": recipe_sexp, + "output_name": output, + } + if duration: + request_data["duration"] = duration + if fps: + request_data["fps"] = fps + if sources_sexp: + request_data["sources_sexp"] = sources_sexp + if audio_sexp: + request_data["audio_sexp"] = audio_sexp + + # Submit + try: + headers = get_auth_header(require_token=True) + resp = requests.post( + f"{get_server()}/runs/stream", + json=request_data, + headers=headers + ) + if resp.status_code == 401: + click.echo("Authentication failed. Please login again.", err=True) + sys.exit(1) + if resp.status_code == 400: + error = resp.json().get("detail", "Bad request") + click.echo(f"Error: {error}", err=True) + sys.exit(1) + resp.raise_for_status() + result = resp.json() + except requests.RequestException as e: + click.echo(f"Stream failed: {e}", err=True) + sys.exit(1) + + run_id = result["run_id"] + click.echo(f"Stream started: {run_id}") + click.echo(f"Task ID: {result.get('celery_task_id', 'N/A')}") + click.echo(f"Status: {result.get('status', 'pending')}") + + if wait: + click.echo("Waiting for completion...") + while True: + time.sleep(2) + try: + resp = requests.get( + f"{get_server()}/runs/{run_id}", + headers=get_auth_header() + ) + resp.raise_for_status() + run = resp.json() + except requests.RequestException: + continue + + status = run.get("status") + if status == "completed": + click.echo(f"\nCompleted!") + if run.get("output_cid"): + click.echo(f"Output CID: {run['output_cid']}") + break + elif status == "failed": + click.echo(f"\nFailed: {run.get('error', 'Unknown error')}", err=True) + sys.exit(1) + else: + click.echo(".", nl=False) + + +if __name__ == "__main__": + cli() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3fb1204 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +click>=8.0.0 +requests>=2.31.0 +PyYAML>=6.0 diff --git a/test_gpu_effects.sexp b/test_gpu_effects.sexp new file mode 100644 index 0000000..42cea11 --- /dev/null +++ b/test_gpu_effects.sexp @@ -0,0 +1,38 @@ +;; GPU Effects Performance Test +;; Tests rotation, zoom, hue-shift, ripple + +(stream "gpu_effects_test" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load primitives + (require-primitives "geometry") + (require-primitives "core") + (require-primitives "math") + (require-primitives "image") + (require-primitives "color_ops") + + ;; Frame pipeline - test GPU effects + (frame + (let [;; Create a base gradient image + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + color [(* r 255) (* g 255) (* b 255)] + base (image:make-image 1920 1080 color) + + ;; Apply rotation (this is the main GPU bottleneck we optimized) + angle (* t 30) + rotated (geometry:rotate base angle) + + ;; Apply hue shift + hue-shift (* 180 (math:sin (* t 0.5))) + hued (color_ops:hue-shift rotated hue-shift) + + ;; Apply brightness based on time + brightness (+ 0.8 (* 0.4 (math:sin (* t 2)))) + bright (color_ops:brightness hued brightness)] + + bright))) diff --git a/test_simple.sexp b/test_simple.sexp new file mode 100644 index 0000000..c5a0b30 --- /dev/null +++ b/test_simple.sexp @@ -0,0 +1,26 @@ +;; Simple Test - No external assets required +;; Just generates a color gradient that changes over time + +(stream "simple_test" + :fps 30 + :width 720 + :height 720 + :seed 42 + + ;; Load standard primitives + (require-primitives "geometry") + (require-primitives "core") + (require-primitives "math") + (require-primitives "image") + (require-primitives "color_ops") + + ;; Frame pipeline - animated gradient + (frame + (let [;; Time-based color cycling (0-1 range) + r (+ 0.5 (* 0.5 (math:sin (* t 1)))) + g (+ 0.5 (* 0.5 (math:sin (* t 1.3)))) + b (+ 0.5 (* 0.5 (math:sin (* t 1.7)))) + ;; Convert to 0-255 range and create solid color frame + color [(* r 255) (* g 255) (* b 255)] + frame (image:make-image 720 720 color)] + frame))) From c590f2e0399296dc026b06d8e0e7b95d99de84da Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:10:04 +0000 Subject: [PATCH 07/24] Squashed 'test/' content from commit f2edc20 git-subtree-dir: test git-subtree-split: f2edc20cba865a6ef67ca807c2ed6cee8e6c2836 --- .gitignore | 19 + analyze.py | 223 ++ cache.py | 404 +++ configs/audio-dizzy.sexp | 17 + configs/audio-halleluwah.sexp | 17 + configs/sources-default.sexp | 38 + configs/sources-woods-half.sexp | 19 + configs/sources-woods.sexp | 39 + effects/quick_test_explicit.sexp | 150 ++ execute.py | 2368 +++++++++++++++++ plan.py | 415 +++ run-effect.sh | 122 + run-file.sh | 7 + run.py | 127 + run.sh | 7 + run_staged.py | 528 ++++ sexp_effects/__init__.py | 32 + sexp_effects/effects/ascii_art.sexp | 17 + sexp_effects/effects/ascii_art_fx.sexp | 52 + sexp_effects/effects/ascii_fx_zone.sexp | 102 + sexp_effects/effects/ascii_zones.sexp | 30 + sexp_effects/effects/blend.sexp | 31 + sexp_effects/effects/blend_multi.sexp | 58 + sexp_effects/effects/bloom.sexp | 16 + sexp_effects/effects/blur.sexp | 8 + sexp_effects/effects/brightness.sexp | 9 + sexp_effects/effects/color-adjust.sexp | 13 + sexp_effects/effects/color_cycle.sexp | 13 + sexp_effects/effects/contrast.sexp | 9 + sexp_effects/effects/crt.sexp | 30 + sexp_effects/effects/datamosh.sexp | 14 + sexp_effects/effects/echo.sexp | 19 + sexp_effects/effects/edge_detect.sexp | 9 + sexp_effects/effects/emboss.sexp | 13 + sexp_effects/effects/film_grain.sexp | 19 + sexp_effects/effects/fisheye.sexp | 16 + sexp_effects/effects/flip.sexp | 16 + sexp_effects/effects/grayscale.sexp | 7 + sexp_effects/effects/hue_shift.sexp | 12 + sexp_effects/effects/invert.sexp | 9 + sexp_effects/effects/kaleidoscope.sexp | 20 + sexp_effects/effects/layer.sexp | 36 + sexp_effects/effects/mirror.sexp | 33 + sexp_effects/effects/neon_glow.sexp | 23 + sexp_effects/effects/noise.sexp | 8 + sexp_effects/effects/outline.sexp | 24 + sexp_effects/effects/pixelate.sexp | 13 + sexp_effects/effects/pixelsort.sexp | 11 + sexp_effects/effects/posterize.sexp | 8 + sexp_effects/effects/resize-frame.sexp | 11 + sexp_effects/effects/rgb_split.sexp | 13 + sexp_effects/effects/ripple.sexp | 19 + sexp_effects/effects/rotate.sexp | 11 + sexp_effects/effects/saturation.sexp | 9 + sexp_effects/effects/scanlines.sexp | 15 + sexp_effects/effects/sepia.sexp | 7 + sexp_effects/effects/sharpen.sexp | 8 + sexp_effects/effects/strobe.sexp | 16 + sexp_effects/effects/swirl.sexp | 17 + sexp_effects/effects/threshold.sexp | 9 + sexp_effects/effects/tile_grid.sexp | 29 + sexp_effects/effects/trails.sexp | 20 + sexp_effects/effects/vignette.sexp | 23 + sexp_effects/effects/wave.sexp | 22 + sexp_effects/effects/zoom.sexp | 8 + sexp_effects/interpreter.py | 1016 ++++++++ sexp_effects/parser.py | 168 ++ sexp_effects/primitive_libs/__init__.py | 102 + sexp_effects/primitive_libs/arrays.py | 196 ++ sexp_effects/primitive_libs/ascii.py | 388 +++ sexp_effects/primitive_libs/blending.py | 116 + sexp_effects/primitive_libs/color.py | 137 + sexp_effects/primitive_libs/color_ops.py | 90 + sexp_effects/primitive_libs/core.py | 271 ++ sexp_effects/primitive_libs/drawing.py | 136 + sexp_effects/primitive_libs/filters.py | 119 + sexp_effects/primitive_libs/geometry.py | 143 + sexp_effects/primitive_libs/image.py | 144 + sexp_effects/primitive_libs/math.py | 164 ++ sexp_effects/primitive_libs/streaming.py | 462 ++++ sexp_effects/primitives.py | 3043 ++++++++++++++++++++++ sexp_effects/test_interpreter.py | 236 ++ streaming/__init__.py | 44 + streaming/audio.py | 486 ++++ streaming/backends.py | 308 +++ streaming/compositor.py | 595 +++++ streaming/demo.py | 125 + streaming/output.py | 369 +++ streaming/pipeline.py | 846 ++++++ streaming/recipe_adapter.py | 470 ++++ streaming/recipe_executor.py | 415 +++ streaming/sexp_executor.py | 678 +++++ streaming/sexp_interp.py | 376 +++ streaming/sources.py | 281 ++ streaming/stream_sexp.py | 1081 ++++++++ streaming/stream_sexp_generic.py | 859 ++++++ templates/crossfade-zoom.sexp | 25 + templates/cycle-crossfade.sexp | 65 + templates/process-pair.sexp | 112 + templates/scan-oscillating-spin.sexp | 28 + templates/scan-ripple-drops.sexp | 41 + templates/standard-effects.sexp | 22 + templates/standard-primitives.sexp | 14 + templates/stream-process-pair.sexp | 72 + test_effects_pipeline.py | 258 ++ 105 files changed, 19968 insertions(+) create mode 100644 .gitignore create mode 100644 analyze.py create mode 100644 cache.py create mode 100644 configs/audio-dizzy.sexp create mode 100644 configs/audio-halleluwah.sexp create mode 100644 configs/sources-default.sexp create mode 100644 configs/sources-woods-half.sexp create mode 100644 configs/sources-woods.sexp create mode 100644 effects/quick_test_explicit.sexp create mode 100644 execute.py create mode 100644 plan.py create mode 100644 run-effect.sh create mode 100755 run-file.sh create mode 100755 run.py create mode 100755 run.sh create mode 100644 run_staged.py create mode 100644 sexp_effects/__init__.py create mode 100644 sexp_effects/effects/ascii_art.sexp create mode 100644 sexp_effects/effects/ascii_art_fx.sexp create mode 100644 sexp_effects/effects/ascii_fx_zone.sexp create mode 100644 sexp_effects/effects/ascii_zones.sexp create mode 100644 sexp_effects/effects/blend.sexp create mode 100644 sexp_effects/effects/blend_multi.sexp create mode 100644 sexp_effects/effects/bloom.sexp create mode 100644 sexp_effects/effects/blur.sexp create mode 100644 sexp_effects/effects/brightness.sexp create mode 100644 sexp_effects/effects/color-adjust.sexp create mode 100644 sexp_effects/effects/color_cycle.sexp create mode 100644 sexp_effects/effects/contrast.sexp create mode 100644 sexp_effects/effects/crt.sexp create mode 100644 sexp_effects/effects/datamosh.sexp create mode 100644 sexp_effects/effects/echo.sexp create mode 100644 sexp_effects/effects/edge_detect.sexp create mode 100644 sexp_effects/effects/emboss.sexp create mode 100644 sexp_effects/effects/film_grain.sexp create mode 100644 sexp_effects/effects/fisheye.sexp create mode 100644 sexp_effects/effects/flip.sexp create mode 100644 sexp_effects/effects/grayscale.sexp create mode 100644 sexp_effects/effects/hue_shift.sexp create mode 100644 sexp_effects/effects/invert.sexp create mode 100644 sexp_effects/effects/kaleidoscope.sexp create mode 100644 sexp_effects/effects/layer.sexp create mode 100644 sexp_effects/effects/mirror.sexp create mode 100644 sexp_effects/effects/neon_glow.sexp create mode 100644 sexp_effects/effects/noise.sexp create mode 100644 sexp_effects/effects/outline.sexp create mode 100644 sexp_effects/effects/pixelate.sexp create mode 100644 sexp_effects/effects/pixelsort.sexp create mode 100644 sexp_effects/effects/posterize.sexp create mode 100644 sexp_effects/effects/resize-frame.sexp create mode 100644 sexp_effects/effects/rgb_split.sexp create mode 100644 sexp_effects/effects/ripple.sexp create mode 100644 sexp_effects/effects/rotate.sexp create mode 100644 sexp_effects/effects/saturation.sexp create mode 100644 sexp_effects/effects/scanlines.sexp create mode 100644 sexp_effects/effects/sepia.sexp create mode 100644 sexp_effects/effects/sharpen.sexp create mode 100644 sexp_effects/effects/strobe.sexp create mode 100644 sexp_effects/effects/swirl.sexp create mode 100644 sexp_effects/effects/threshold.sexp create mode 100644 sexp_effects/effects/tile_grid.sexp create mode 100644 sexp_effects/effects/trails.sexp create mode 100644 sexp_effects/effects/vignette.sexp create mode 100644 sexp_effects/effects/wave.sexp create mode 100644 sexp_effects/effects/zoom.sexp create mode 100644 sexp_effects/interpreter.py create mode 100644 sexp_effects/parser.py create mode 100644 sexp_effects/primitive_libs/__init__.py create mode 100644 sexp_effects/primitive_libs/arrays.py create mode 100644 sexp_effects/primitive_libs/ascii.py create mode 100644 sexp_effects/primitive_libs/blending.py create mode 100644 sexp_effects/primitive_libs/color.py create mode 100644 sexp_effects/primitive_libs/color_ops.py create mode 100644 sexp_effects/primitive_libs/core.py create mode 100644 sexp_effects/primitive_libs/drawing.py create mode 100644 sexp_effects/primitive_libs/filters.py create mode 100644 sexp_effects/primitive_libs/geometry.py create mode 100644 sexp_effects/primitive_libs/image.py create mode 100644 sexp_effects/primitive_libs/math.py create mode 100644 sexp_effects/primitive_libs/streaming.py create mode 100644 sexp_effects/primitives.py create mode 100644 sexp_effects/test_interpreter.py create mode 100644 streaming/__init__.py create mode 100644 streaming/audio.py create mode 100644 streaming/backends.py create mode 100644 streaming/compositor.py create mode 100644 streaming/demo.py create mode 100644 streaming/output.py create mode 100644 streaming/pipeline.py create mode 100644 streaming/recipe_adapter.py create mode 100644 streaming/recipe_executor.py create mode 100644 streaming/sexp_executor.py create mode 100644 streaming/sexp_interp.py create mode 100644 streaming/sources.py create mode 100644 streaming/stream_sexp.py create mode 100644 streaming/stream_sexp_generic.py create mode 100644 templates/crossfade-zoom.sexp create mode 100644 templates/cycle-crossfade.sexp create mode 100644 templates/process-pair.sexp create mode 100644 templates/scan-oscillating-spin.sexp create mode 100644 templates/scan-ripple-drops.sexp create mode 100644 templates/standard-effects.sexp create mode 100644 templates/standard-primitives.sexp create mode 100644 templates/stream-process-pair.sexp create mode 100644 test_effects_pipeline.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..16ad0a0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Python cache +__pycache__/ +*.pyc + +# Media files +*.mp4 +*.mkv +*.webm +*.mp3 + +# Output files +*.json + +# Cache directories +.cache/ +.stage_cache/ +effects/.stage_cache/ +local_server/.cache/ +local_server/.data/ diff --git a/analyze.py b/analyze.py new file mode 100644 index 0000000..1affa85 --- /dev/null +++ b/analyze.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Run analyzers from a recipe and output analysis data as S-expressions. + +Usage: + analyze.py recipe.sexp [-o analysis.sexp] + +Output format: + (analysis + (beats-data + :tempo 120.5 + :times (0.0 0.5 1.0 1.5 ...) + :duration 10.0) + (bass-data + :times (0.0 0.1 0.2 ...) + :values (0.5 0.8 0.3 ...))) +""" + +import sys +import tempfile +import subprocess +import importlib.util +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent / "artdag")) + +from artdag.sexp import compile_string, parse +from artdag.sexp.parser import Symbol, Keyword, serialize + + +def load_analyzer(analyzer_path: Path): + """Load an analyzer module from file path.""" + spec = importlib.util.spec_from_file_location("analyzer", analyzer_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def run_analyzer(analyzer_path: Path, input_path: Path, params: dict) -> dict: + """Run an analyzer and return results.""" + analyzer = load_analyzer(analyzer_path) + return analyzer.analyze(input_path, params) + + +def pre_execute_segment(source_path: Path, start: float, duration: float, work_dir: Path) -> Path: + """Pre-execute a segment to get audio for analysis.""" + suffix = source_path.suffix.lower() + is_audio = suffix in ('.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a') + + output_ext = ".m4a" if is_audio else ".mp4" + output_path = work_dir / f"segment{output_ext}" + + cmd = ["ffmpeg", "-y", "-i", str(source_path)] + if start: + cmd.extend(["-ss", str(start)]) + if duration: + cmd.extend(["-t", str(duration)]) + + if is_audio: + cmd.extend(["-c:a", "aac", str(output_path)]) + else: + cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18", + "-c:a", "aac", str(output_path)]) + + subprocess.run(cmd, check=True, capture_output=True) + return output_path + + +def to_sexp(value, indent=0): + """Convert a Python value to S-expression string.""" + if isinstance(value, dict): + if not value: + return "()" + items = [] + for k, v in value.items(): + key = k.replace('_', '-') + items.append(f":{key} {to_sexp(v)}") + return "(" + " ".join(items) + ")" + elif isinstance(value, list): + if not value: + return "()" + items = [to_sexp(v) for v in value] + return "(" + " ".join(items) + ")" + elif isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return "true" if value else "false" + elif value is None: + return "nil" + elif isinstance(value, float): + return f"{value:.6g}" + else: + return str(value) + + +def analyze_recipe(recipe_path: Path, output_file: Path = None): + """Run all analyzers in a recipe and output S-expression analysis data.""" + + recipe_text = recipe_path.read_text() + recipe_dir = recipe_path.parent + + print(f"Compiling: {recipe_path}", file=sys.stderr) + compiled = compile_string(recipe_text) + print(f"Recipe: {compiled.name} v{compiled.version}", file=sys.stderr) + + # Find all ANALYZE nodes and their dependencies + nodes_by_id = {n["id"]: n for n in compiled.nodes} + + # Track source paths and segment outputs + source_paths = {} + segment_outputs = {} + analysis_results = {} + + work_dir = Path(tempfile.mkdtemp(prefix="artdag_analyze_")) + + # Process nodes in dependency order + def get_input_path(node_id: str) -> Path: + """Resolve the input path for a node.""" + if node_id in segment_outputs: + return segment_outputs[node_id] + if node_id in source_paths: + return source_paths[node_id] + + node = nodes_by_id.get(node_id) + if not node: + return None + + if node["type"] == "SOURCE": + path = recipe_dir / node["config"].get("path", "") + source_paths[node_id] = path.resolve() + return source_paths[node_id] + + if node["type"] == "SEGMENT": + inputs = node.get("inputs", []) + if inputs: + input_path = get_input_path(inputs[0]) + if input_path: + config = node.get("config", {}) + start = config.get("start", 0) + duration = config.get("duration") + output = pre_execute_segment(input_path, start, duration, work_dir) + segment_outputs[node_id] = output + return output + + return None + + # Find and run all analyzers + for node in compiled.nodes: + if node["type"] == "ANALYZE": + config = node.get("config", {}) + analyzer_name = config.get("analyzer", "unknown") + analyzer_path = config.get("analyzer_path") + + if not analyzer_path: + print(f" Skipping {analyzer_name}: no path", file=sys.stderr) + continue + + # Get input + inputs = node.get("inputs", []) + if not inputs: + print(f" Skipping {analyzer_name}: no inputs", file=sys.stderr) + continue + + input_path = get_input_path(inputs[0]) + if not input_path or not input_path.exists(): + print(f" Skipping {analyzer_name}: input not found", file=sys.stderr) + continue + + # Run analyzer + full_path = recipe_dir / analyzer_path + params = {k: v for k, v in config.items() + if k not in ("analyzer", "analyzer_path", "cid")} + + print(f" Running analyzer: {analyzer_name}", file=sys.stderr) + results = run_analyzer(full_path, input_path, params) + + # Store by node ID for uniqueness (multiple analyzers may have same type) + node_id = node.get("id") + analysis_results[node_id] = results + + times = results.get("times", []) + print(f" {len(times)} times @ {results.get('tempo', 0):.1f} BPM", file=sys.stderr) + + # Generate S-expression output + lines = ["(analysis"] + + for name, data in analysis_results.items(): + # Quote node IDs to prevent parser treating hex like "0e42..." as scientific notation + lines.append(f' ("{name}"') + for key, value in data.items(): + sexp_key = key.replace('_', '-') + sexp_value = to_sexp(value) + lines.append(f" :{sexp_key} {sexp_value}") + lines.append(" )") + + lines.append(")") + + output = "\n".join(lines) + + if output_file: + output_file.write_text(output) + print(f"\nAnalysis written to: {output_file}", file=sys.stderr) + else: + print(output) + + print(f"Debug: temp files in {work_dir}", file=sys.stderr) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run analyzers from recipe") + parser.add_argument("recipe", type=Path, help="Recipe file (.sexp)") + parser.add_argument("-o", "--output", type=Path, help="Output file (default: stdout)") + + args = parser.parse_args() + + if not args.recipe.exists(): + print(f"Recipe not found: {args.recipe}", file=sys.stderr) + sys.exit(1) + + analyze_recipe(args.recipe, args.output) diff --git a/cache.py b/cache.py new file mode 100644 index 0000000..fdb9a7e --- /dev/null +++ b/cache.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +""" +Unified content cache for artdag. + +Design: + - IPNS (cache_id) = computation hash, known BEFORE execution + "What would be the result of running X with inputs Y?" + + - CID = content hash, known AFTER execution + "What is this actual content?" + +Structure: + .cache/ + refs/ # IPNS → CID mappings (computation → result) + {cache_id} # Text file containing the CID of the result + content/ # Content-addressed storage + {cid[:2]}/{cid} # Actual content by CID +""" + +import hashlib +import json +import os +from pathlib import Path +from typing import Optional, Dict, Any, Tuple + +# Default cache location - can be overridden via ARTDAG_CACHE env var +DEFAULT_CACHE_DIR = Path(__file__).parent / ".cache" + + +def get_cache_dir() -> Path: + """Get the cache directory, creating if needed.""" + cache_dir = Path(os.environ.get("ARTDAG_CACHE", DEFAULT_CACHE_DIR)) + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def get_refs_dir() -> Path: + """Get the refs directory (IPNS → CID mappings).""" + refs_dir = get_cache_dir() / "refs" + refs_dir.mkdir(parents=True, exist_ok=True) + return refs_dir + + +def get_content_dir() -> Path: + """Get the content directory (CID → content).""" + content_dir = get_cache_dir() / "content" + content_dir.mkdir(parents=True, exist_ok=True) + return content_dir + + +# ============================================================================= +# CID (Content Hash) Operations +# ============================================================================= + +def compute_cid(content: bytes) -> str: + """Compute content ID (SHA256 hash) for bytes.""" + return hashlib.sha256(content).hexdigest() + + +def compute_file_cid(file_path: Path) -> str: + """Compute content ID for a file.""" + with open(file_path, 'rb') as f: + return compute_cid(f.read()) + + +def compute_string_cid(text: str) -> str: + """Compute content ID for a string.""" + return compute_cid(text.encode('utf-8')) + + +# ============================================================================= +# Content Storage (by CID) +# ============================================================================= + +def _content_path(cid: str) -> Path: + """Get path for content by CID.""" + return get_content_dir() / cid[:2] / cid + + +def content_exists_by_cid(cid: str) -> Optional[Path]: + """Check if content exists by CID.""" + path = _content_path(cid) + if path.exists() and path.stat().st_size > 0: + return path + return None + + +def content_store_by_cid(cid: str, content: bytes) -> Path: + """Store content by its CID.""" + path = _content_path(cid) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(content) + return path + + +def content_store_file(file_path: Path) -> Tuple[str, Path]: + """Store a file by its content hash. Returns (cid, path).""" + content = file_path.read_bytes() + cid = compute_cid(content) + path = content_store_by_cid(cid, content) + return cid, path + + +def content_store_string(text: str) -> Tuple[str, Path]: + """Store a string by its content hash. Returns (cid, path).""" + content = text.encode('utf-8') + cid = compute_cid(content) + path = content_store_by_cid(cid, content) + return cid, path + + +def content_get(cid: str) -> Optional[bytes]: + """Get content by CID.""" + path = content_exists_by_cid(cid) + if path: + return path.read_bytes() + return None + + +def content_get_string(cid: str) -> Optional[str]: + """Get string content by CID.""" + content = content_get(cid) + if content: + return content.decode('utf-8') + return None + + +# ============================================================================= +# Refs (IPNS → CID mappings) +# ============================================================================= + +def _ref_path(cache_id: str) -> Path: + """Get path for a ref by cache_id.""" + return get_refs_dir() / cache_id + + +def ref_exists(cache_id: str) -> Optional[str]: + """Check if a ref exists. Returns CID if found.""" + path = _ref_path(cache_id) + if path.exists(): + return path.read_text().strip() + return None + + +def ref_set(cache_id: str, cid: str) -> Path: + """Set a ref (IPNS → CID mapping).""" + path = _ref_path(cache_id) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(cid) + return path + + +def ref_get_content(cache_id: str) -> Optional[bytes]: + """Get content by cache_id (looks up ref, then fetches content).""" + cid = ref_exists(cache_id) + if cid: + return content_get(cid) + return None + + +def ref_get_string(cache_id: str) -> Optional[str]: + """Get string content by cache_id.""" + content = ref_get_content(cache_id) + if content: + return content.decode('utf-8') + return None + + +# ============================================================================= +# High-level Cache Operations +# ============================================================================= + +def cache_store(cache_id: str, content: bytes) -> Tuple[str, Path]: + """ + Store content with IPNS→CID indirection. + + Args: + cache_id: Computation hash (IPNS address) + content: Content to store + + Returns: + (cid, path) tuple + """ + cid = compute_cid(content) + path = content_store_by_cid(cid, content) + ref_set(cache_id, cid) + return cid, path + + +def cache_store_file(cache_id: str, file_path: Path) -> Tuple[str, Path]: + """Store a file with IPNS→CID indirection.""" + content = file_path.read_bytes() + return cache_store(cache_id, content) + + +def cache_store_string(cache_id: str, text: str) -> Tuple[str, Path]: + """Store a string with IPNS→CID indirection.""" + return cache_store(cache_id, text.encode('utf-8')) + + +def cache_store_json(cache_id: str, data: Any) -> Tuple[str, Path]: + """Store JSON data with IPNS→CID indirection.""" + text = json.dumps(data, indent=2) + return cache_store_string(cache_id, text) + + +def cache_exists(cache_id: str) -> Optional[Path]: + """Check if cached content exists for a computation.""" + cid = ref_exists(cache_id) + if cid: + return content_exists_by_cid(cid) + return None + + +def cache_get(cache_id: str) -> Optional[bytes]: + """Get cached content by computation hash.""" + return ref_get_content(cache_id) + + +def cache_get_string(cache_id: str) -> Optional[str]: + """Get cached string by computation hash.""" + return ref_get_string(cache_id) + + +def cache_get_json(cache_id: str) -> Optional[Any]: + """Get cached JSON by computation hash.""" + text = cache_get_string(cache_id) + if text: + return json.loads(text) + return None + + +def cache_get_path(cache_id: str) -> Optional[Path]: + """Get path to cached content by computation hash.""" + cid = ref_exists(cache_id) + if cid: + return content_exists_by_cid(cid) + return None + + +# ============================================================================= +# Plan Cache (convenience wrappers) +# ============================================================================= + +def _stable_hash_params(params: Dict[str, Any]) -> str: + """Compute stable hash of params using JSON + SHA256 (consistent with CID).""" + params_str = json.dumps(params, sort_keys=True, default=str) + return hashlib.sha256(params_str.encode()).hexdigest() + + +def plan_cache_id(source_cid: str, params: Dict[str, Any] = None) -> str: + """ + Compute the cache_id (IPNS address) for a plan. + + Based on source CID + params. Name/version are just metadata. + """ + key = f"plan:{source_cid}" + if params: + params_hash = _stable_hash_params(params) + key = f"{key}:{params_hash}" + return hashlib.sha256(key.encode()).hexdigest() + + +def plan_exists(source_cid: str, params: Dict[str, Any] = None) -> Optional[str]: + """Check if a cached plan exists. Returns CID if found.""" + cache_id = plan_cache_id(source_cid, params) + return ref_exists(cache_id) + + +def plan_store(source_cid: str, params: Dict[str, Any], content: str) -> Tuple[str, str, Path]: + """ + Store a plan in the cache. + + Returns: + (cache_id, cid, path) tuple + """ + cache_id = plan_cache_id(source_cid, params) + cid, path = cache_store_string(cache_id, content) + return cache_id, cid, path + + +def plan_load(source_cid: str, params: Dict[str, Any] = None) -> Optional[str]: + """Load a plan from cache. Returns plan content string.""" + cache_id = plan_cache_id(source_cid, params) + return cache_get_string(cache_id) + + +def plan_get_path(source_cid: str, params: Dict[str, Any] = None) -> Optional[Path]: + """Get path to cached plan.""" + cache_id = plan_cache_id(source_cid, params) + return cache_get_path(cache_id) + + +# ============================================================================= +# Cache Listing +# ============================================================================= + +def list_cache(verbose: bool = False) -> Dict[str, Any]: + """List all cached items.""" + from datetime import datetime + + cache_dir = get_cache_dir() + refs_dir = get_refs_dir() + content_dir = get_content_dir() + + def format_size(size): + if size >= 1_000_000_000: + return f"{size / 1_000_000_000:.1f}GB" + elif size >= 1_000_000: + return f"{size / 1_000_000:.1f}MB" + elif size >= 1000: + return f"{size / 1000:.1f}KB" + else: + return f"{size}B" + + def get_file_info(path: Path) -> Dict: + stat = path.stat() + return { + "path": path, + "name": path.name, + "size": stat.st_size, + "size_str": format_size(stat.st_size), + "mtime": datetime.fromtimestamp(stat.st_mtime), + } + + result = { + "refs": [], + "content": [], + "summary": {"total_items": 0, "total_size": 0}, + } + + # Refs + if refs_dir.exists(): + for f in sorted(refs_dir.iterdir()): + if f.is_file(): + info = get_file_info(f) + info["cache_id"] = f.name + info["cid"] = f.read_text().strip() + # Try to determine type from content + cid = info["cid"] + content_path = content_exists_by_cid(cid) + if content_path: + info["content_size"] = content_path.stat().st_size + info["content_size_str"] = format_size(info["content_size"]) + result["refs"].append(info) + + # Content + if content_dir.exists(): + for subdir in sorted(content_dir.iterdir()): + if subdir.is_dir(): + for f in sorted(subdir.iterdir()): + if f.is_file(): + info = get_file_info(f) + info["cid"] = f.name + result["content"].append(info) + + # Summary + result["summary"]["total_refs"] = len(result["refs"]) + result["summary"]["total_content"] = len(result["content"]) + result["summary"]["total_size"] = sum(i["size"] for i in result["content"]) + result["summary"]["total_size_str"] = format_size(result["summary"]["total_size"]) + + return result + + +def print_cache_listing(verbose: bool = False): + """Print cache listing to stdout.""" + info = list_cache(verbose) + cache_dir = get_cache_dir() + + print(f"\nCache directory: {cache_dir}\n") + + # Refs summary + if info["refs"]: + print(f"=== Refs ({len(info['refs'])}) ===") + for ref in info["refs"][:20]: # Show first 20 + content_info = f" → {ref.get('content_size_str', '?')}" if 'content_size_str' in ref else "" + print(f" {ref['cache_id'][:16]}... → {ref['cid'][:16]}...{content_info}") + if len(info["refs"]) > 20: + print(f" ... and {len(info['refs']) - 20} more") + print() + + # Content by type + if info["content"]: + # Group by first 2 chars (subdirectory) + print(f"=== Content ({len(info['content'])} items, {info['summary']['total_size_str']}) ===") + for item in info["content"][:20]: + print(f" {item['cid'][:16]}... {item['size_str']:>8} {item['mtime'].strftime('%Y-%m-%d %H:%M')}") + if len(info["content"]) > 20: + print(f" ... and {len(info['content']) - 20} more") + print() + + print(f"=== Summary ===") + print(f" Refs: {info['summary']['total_refs']}") + print(f" Content: {info['summary']['total_content']} ({info['summary']['total_size_str']})") + + if verbose: + print(f"\nTo clear cache: rm -rf {cache_dir}/*") + + +if __name__ == "__main__": + import sys + verbose = "-v" in sys.argv or "--verbose" in sys.argv + print_cache_listing(verbose) diff --git a/configs/audio-dizzy.sexp b/configs/audio-dizzy.sexp new file mode 100644 index 0000000..dc16087 --- /dev/null +++ b/configs/audio-dizzy.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Paths relative to working directory (project root) +(def music (streaming:make-audio-analyzer "dizzy.mp3")) + +;; Audio playback path (for sync with video output) +(audio-playback "dizzy.mp3") diff --git a/configs/audio-halleluwah.sexp b/configs/audio-halleluwah.sexp new file mode 100644 index 0000000..5e4b812 --- /dev/null +++ b/configs/audio-halleluwah.sexp @@ -0,0 +1,17 @@ +;; Audio Configuration - dizzy.mp3 +;; +;; Defines audio analyzer and playback for a recipe. +;; Pass to recipe with: --audio configs/audio-dizzy.sexp +;; +;; Provides: +;; - music: audio analyzer for beat/energy detection +;; - audio-playback: path for synchronized playback + +(require-primitives "streaming") + +;; Audio analyzer (provides beat detection and energy levels) +;; Paths relative to working directory (project root) +(def music (streaming:make-audio-analyzer "woods_half/halleluwah.webm")) + +;; Audio playback path (for sync with video output) +(audio-playback "woods_half/halleluwah.webm") diff --git a/configs/sources-default.sexp b/configs/sources-default.sexp new file mode 100644 index 0000000..754bd92 --- /dev/null +++ b/configs/sources-default.sexp @@ -0,0 +1,38 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Paths relative to working directory (project root) +(def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 +]) diff --git a/configs/sources-woods-half.sexp b/configs/sources-woods-half.sexp new file mode 100644 index 0000000..d2feff8 --- /dev/null +++ b/configs/sources-woods-half.sexp @@ -0,0 +1,19 @@ +;; Half-resolution Woods Sources (960x540) +;; +;; Pass to recipe with: --sources configs/sources-woods-half.sexp + +(require-primitives "streaming") + +(def sources [ + (streaming:make-video-source "woods_half/1.webm" 30) + (streaming:make-video-source "woods_half/2.webm" 30) + (streaming:make-video-source "woods_half/3.webm" 30) + (streaming:make-video-source "woods_half/4.webm" 30) +]) + +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} +]) diff --git a/configs/sources-woods.sexp b/configs/sources-woods.sexp new file mode 100644 index 0000000..717bfd9 --- /dev/null +++ b/configs/sources-woods.sexp @@ -0,0 +1,39 @@ +;; Default Sources Configuration +;; +;; Defines video sources and per-pair effect configurations. +;; Pass to recipe with: --sources configs/sources-default.sexp +;; +;; Required by recipes using process-pair macro: +;; - sources: array of video sources +;; - pair-configs: array of effect configurations per source + +(require-primitives "streaming") + +;; Video sources array +;; Paths relative to working directory (project root) +(def sources [ + (streaming:make-video-source "woods/1.webm" 10) + (streaming:make-video-source "woods/2.webm" 10) + (streaming:make-video-source "woods/3.webm" 10) + (streaming:make-video-source "woods/4.webm" 10) + (streaming:make-video-source "woods/5.webm" 10) + (streaming:make-video-source "woods/6.webm" 10) + (streaming:make-video-source "woods/7.webm" 10) + (streaming:make-video-source "woods/8.webm" 10) +]) + +;; Per-pair effect config: rotation direction, rotation ranges, zoom ranges +;; :dir = rotation direction (1 or -1) +;; :rot-a, :rot-b = max rotation angles for clip A and B +;; :zoom-a, :zoom-b = max zoom amounts for clip A and B +(def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + +]) diff --git a/effects/quick_test_explicit.sexp b/effects/quick_test_explicit.sexp new file mode 100644 index 0000000..0a3698b --- /dev/null +++ b/effects/quick_test_explicit.sexp @@ -0,0 +1,150 @@ +;; Quick Test - Fully Explicit Streaming Version +;; +;; The interpreter is completely generic - knows nothing about video/audio. +;; All domain logic is explicit via primitives. +;; +;; Run with built-in sources/audio: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp --fps 30 +;; +;; Run with external config files: +;; python3 -m streaming.stream_sexp_generic effects/quick_test_explicit.sexp \ +;; --sources configs/sources-default.sexp \ +;; --audio configs/audio-dizzy.sexp \ +;; --fps 30 + +(stream "quick_test_explicit" + :fps 30 + :width 1920 + :height 1080 + :seed 42 + + ;; Load standard primitives and effects + (include :path "../templates/standard-primitives.sexp") + (include :path "../templates/standard-effects.sexp") + + ;; Load reusable templates + (include :path "../templates/stream-process-pair.sexp") + (include :path "../templates/crossfade-zoom.sexp") + + ;; === SOURCES AS ARRAY === + (def sources [ + (streaming:make-video-source "monday.webm" 30) + (streaming:make-video-source "escher.webm" 30) + (streaming:make-video-source "2.webm" 30) + (streaming:make-video-source "disruptors.webm" 30) + (streaming:make-video-source "4.mp4" 30) + (streaming:make-video-source "ecstacy.mp4" 30) + (streaming:make-video-source "dopple.webm" 30) + (streaming:make-video-source "5.mp4" 30) + ]) + + ;; Per-pair config: [rot-dir, rot-a-max, rot-b-max, zoom-a-max, zoom-b-max] + ;; Pairs 3,6: reversed (negative rot-a, positive rot-b, shrink zoom-a, grow zoom-b) + ;; Pair 5: smaller ranges + (def pair-configs [ + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 0: monday + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 1: escher + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 2: vid2 + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 3: disruptors (reversed) + {:dir -1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 4: vid4 + {:dir 1 :rot-a 30 :rot-b -30 :zoom-a 1.3 :zoom-b 0.7} ;; 5: ecstacy (smaller) + {:dir -1 :rot-a -45 :rot-b 45 :zoom-a 0.5 :zoom-b 1.5} ;; 6: dopple (reversed) + {:dir 1 :rot-a 45 :rot-b -45 :zoom-a 1.5 :zoom-b 0.5} ;; 7: vid5 + ]) + + ;; Audio analyzer + (def music (streaming:make-audio-analyzer "dizzy.mp3")) + + ;; Audio playback + (audio-playback "../dizzy.mp3") + + ;; === GLOBAL SCANS === + + ;; Cycle state: which source is active (recipe-specific) + ;; clen = beats per source (8-24 beats = ~4-12 seconds) + (scan cycle (streaming:audio-beat music t) + :init {:active 0 :beat 0 :clen 16} + :step (if (< (+ beat 1) clen) + (dict :active active :beat (+ beat 1) :clen clen) + (dict :active (mod (+ active 1) (len sources)) :beat 0 + :clen (+ 8 (mod (* (streaming:audio-beat-count music t) 7) 17))))) + + ;; Reusable scans from templates (require 'music' to be defined) + (include :path "../templates/scan-oscillating-spin.sexp") + (include :path "../templates/scan-ripple-drops.sexp") + + ;; === PER-PAIR STATE (dynamically sized based on sources) === + ;; Each pair has: inv-a, inv-b, hue-a, hue-b, mix, rot-angle + (scan pairs (streaming:audio-beat music t) + :init {:states (map (core:range (len sources)) (lambda (_) + {:inv-a 0 :inv-b 0 :hue-a 0 :hue-b 0 :hue-a-val 0 :hue-b-val 0 :mix 0.5 :mix-rem 5 :angle 0 :rot-beat 0 :rot-clen 25}))} + :step (dict :states (map states (lambda (p) + (let [;; Invert toggles (10% chance, lasts 1-4 beats) + new-inv-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-a) 1))) + new-inv-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- (get p :inv-b) 1))) + ;; Hue shifts (10% chance, lasts 1-4 beats) - use countdown like invert + old-hue-a (get p :hue-a) + old-hue-b (get p :hue-b) + new-hue-a (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-a 1))) + new-hue-b (if (< (core:rand) 0.1) (+ 1 (core:rand-int 1 4)) (core:max 0 (- old-hue-b 1))) + ;; Pick random hue value when triggering (stored separately) + new-hue-a-val (if (> new-hue-a old-hue-a) (+ 30 (* (core:rand) 300)) (get p :hue-a-val)) + new-hue-b-val (if (> new-hue-b old-hue-b) (+ 30 (* (core:rand) 300)) (get p :hue-b-val)) + ;; Mix (holds for 1-10 beats, then picks 0, 0.5, or 1) + mix-rem (get p :mix-rem) + old-mix (get p :mix) + new-mix-rem (if (> mix-rem 0) (- mix-rem 1) (+ 1 (core:rand-int 1 10))) + new-mix (if (> mix-rem 0) old-mix (* (core:rand-int 0 2) 0.5)) + ;; Rotation (accumulates, reverses direction when cycle completes) + rot-beat (get p :rot-beat) + rot-clen (get p :rot-clen) + old-angle (get p :angle) + ;; Note: dir comes from pair-configs, but we store rotation state here + new-rot-beat (if (< (+ rot-beat 1) rot-clen) (+ rot-beat 1) 0) + new-rot-clen (if (< (+ rot-beat 1) rot-clen) rot-clen (+ 20 (core:rand-int 0 10))) + new-angle (+ old-angle (/ 360 rot-clen))] + (dict :inv-a new-inv-a :inv-b new-inv-b + :hue-a new-hue-a :hue-b new-hue-b + :hue-a-val new-hue-a-val :hue-b-val new-hue-b-val + :mix new-mix :mix-rem new-mix-rem + :angle new-angle :rot-beat new-rot-beat :rot-clen new-rot-clen)))))) + + ;; === FRAME PIPELINE === + (frame + (let [now t + e (streaming:audio-energy music now) + + ;; Get cycle state + active (bind cycle :active) + beat-pos (bind cycle :beat) + clen (bind cycle :clen) + + ;; Transition logic: last third of cycle crossfades to next + phase3 (* beat-pos 3) + fading (and (>= phase3 (* clen 2)) (< phase3 (* clen 3))) + fade-amt (if fading (/ (- phase3 (* clen 2)) clen) 0) + next-idx (mod (+ active 1) (len sources)) + + ;; Get pair states array (required by process-pair macro) + pair-states (bind pairs :states) + + ;; Process active pair using macro from template + active-frame (process-pair active) + + ;; Crossfade with zoom during transition (using macro) + result (if fading + (crossfade-zoom active-frame (process-pair next-idx) fade-amt) + active-frame) + + ;; Final: global spin + ripple + spun (rotate result :angle (bind spin :angle)) + rip-gate (bind ripple-state :gate) + rip-amp (* rip-gate (core:map-range e 0 1 5 50))] + + (ripple spun + :amplitude rip-amp + :center_x (bind ripple-state :cx) + :center_y (bind ripple-state :cy) + :frequency 8 + :decay 2 + :speed 5)))) diff --git a/execute.py b/execute.py new file mode 100644 index 0000000..34d94a2 --- /dev/null +++ b/execute.py @@ -0,0 +1,2368 @@ +#!/usr/bin/env python3 +""" +Execute a pre-computed plan. + +Takes a plan file (S-expression) and executes primitive operations, +storing artifacts by their content hash. + +Usage: + analyze.py recipe.sexp > analysis.sexp + plan.py recipe.sexp --analysis analysis.sexp --sexp > plan.sexp + execute.py plan.sexp --analysis analysis.sexp +""" + +import json +import shutil +import subprocess +import sys +import tempfile +import importlib.util +from pathlib import Path +from typing import List + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent / "artdag")) + +from artdag.sexp import parse +from artdag.sexp.parser import Symbol, Keyword +import time +import os +import threading +import concurrent.futures +from itertools import groupby + + +# Limit concurrent raw-video pipelines to prevent memory exhaustion. +# Each pipeline holds raw frames in memory (e.g. ~6MB per 1080p frame) +# and spawns 2+ ffmpeg subprocesses. When the ThreadPoolExecutor runs +# many EFFECT steps in parallel the combined load can freeze the system. +# Default: 1 concurrent pipeline; override with ARTDAG_VIDEO_PIPELINES. +_MAX_VIDEO_PIPELINES = int(os.environ.get("ARTDAG_VIDEO_PIPELINES", 1)) +_video_pipeline_sem = threading.Semaphore(_MAX_VIDEO_PIPELINES) + + +def set_max_video_pipelines(n: int): + """Reconfigure the video-pipeline concurrency limit at runtime.""" + global _video_pipeline_sem, _MAX_VIDEO_PIPELINES + _MAX_VIDEO_PIPELINES = n + _video_pipeline_sem = threading.Semaphore(n) + + +def _video_pipeline_guard(fn): + """Decorator: acquire the video-pipeline semaphore for the call's duration.""" + from functools import wraps + @wraps(fn) + def _guarded(*args, **kwargs): + _video_pipeline_sem.acquire() + try: + return fn(*args, **kwargs) + finally: + _video_pipeline_sem.release() + return _guarded + + +class ProgressBar: + """Simple console progress bar with ETA.""" + + def __init__(self, total: int, desc: str = "", width: int = 30, update_interval: int = 30): + self.total = total + self.desc = desc + self.width = width + self.current = 0 + self.start_time = time.time() + self.update_interval = update_interval + self._last_render = 0 + + def update(self, n: int = 1): + self.current += n + if self.current - self._last_render >= self.update_interval: + self._render() + self._last_render = self.current + + def set(self, n: int): + self.current = n + if self.current - self._last_render >= self.update_interval: + self._render() + self._last_render = self.current + + def _render(self): + elapsed = time.time() - self.start_time + + if self.total == 0: + # Unknown total - just show count + line = f"\r {self.desc} {self.current} frames ({elapsed:.1f}s)" + print(line, end="", file=sys.stderr, flush=True) + return + + pct = self.current / self.total + filled = int(self.width * pct) + bar = "█" * filled + "░" * (self.width - filled) + + if self.current > 0 and pct < 1.0: + eta = elapsed / pct - elapsed + eta_str = f"ETA {eta:.0f}s" + elif pct >= 1.0: + eta_str = f"done in {elapsed:.1f}s" + else: + eta_str = "..." + + line = f"\r {self.desc} |{bar}| {self.current}/{self.total} ({pct*100:.0f}%) {eta_str}" + print(line, end="", file=sys.stderr, flush=True) + + def finish(self): + self._render() + print(file=sys.stderr) # newline + + +def check_cache(cache_dir: Path, cache_id: str, extensions: list) -> Path: + """Check if a cached result exists for a step using IPNS/CID lookup. + + Args: + cache_dir: Cache directory (used for unified cache) + cache_id: IPNS address (computation hash, known before execution) + extensions: List of possible file extensions (for legacy compatibility) + + Returns: + Path to cached content file if found, None otherwise + """ + import cache as unified_cache + + # Look up IPNS → CID mapping + cached_path = unified_cache.cache_exists(cache_id) + if cached_path: + return cached_path + return None + + +def save_to_cache(cache_dir: Path, cache_id: str, source_path: Path) -> Path: + """Save a result to cache using IPNS/CID structure. + + Args: + cache_dir: Cache directory (used for unified cache) + cache_id: IPNS address (computation hash, known before execution) + source_path: Path to the file to cache + + Returns: + Path to the cached content file + """ + import cache as unified_cache + + # Store content by CID, create IPNS → CID ref + cid, cached_path = unified_cache.cache_store_file(cache_id, source_path) + return cached_path + + +def extract_segment_with_loop(input_path: Path, output_path: Path, start: float, duration: float, encoding: dict) -> Path: + """Extract a segment from a video, looping the source if needed to reach requested duration. + + Args: + input_path: Source video file + output_path: Output segment file + start: Start time in seconds + duration: Requested duration in seconds + encoding: Encoding settings dict + + Returns: + Path to the output segment + """ + enc = encoding + fps = enc.get("fps", 30) + + # First attempt without looping + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if start: + cmd.extend(["-ss", str(start)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend(["-r", str(fps), + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-pix_fmt", "yuv420p", + "-c:a", enc.get("audio_codec", "aac"), + str(output_path)]) + + print(f" Extracting segment: start={start}, duration={duration}", file=sys.stderr) + result = subprocess.run(cmd, capture_output=True, text=True) + + # Check if we need to loop + needs_loop = False + if result.returncode == 0 and duration: + probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(output_path)] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + if probe_result.returncode == 0: + probe_data = json.loads(probe_result.stdout) + output_duration = float(probe_data.get("format", {}).get("duration", 0)) + if output_duration < duration - 1.0: # 1 second tolerance + needs_loop = True + print(f" Output {output_duration:.1f}s < requested {duration:.1f}s, will loop", file=sys.stderr) + + if needs_loop or result.returncode != 0: + # Get source duration for wrapping + probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(input_path)] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + src_duration = float(probe_data.get("format", {}).get("duration", 0)) + + if src_duration > 0: + wrapped_start = start % src_duration if start else 0 + print(f" Looping source ({src_duration:.1f}s) to reach {duration:.1f}s", file=sys.stderr) + + # Re-run with stream_loop + cmd = ["ffmpeg", "-y", "-stream_loop", "-1", "-i", str(input_path)] + cmd.extend(["-ss", str(wrapped_start)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend(["-r", str(fps), + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-pix_fmt", "yuv420p", + "-c:a", enc.get("audio_codec", "aac"), + str(output_path)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print(f" FFmpeg loop error: {result.stderr[:200]}", file=sys.stderr) + raise ValueError(f"FFmpeg segment extraction with loop failed") + + if not output_path.exists() or output_path.stat().st_size == 0: + raise ValueError(f"Segment output invalid: {output_path}") + + print(f" Segment: {output_path.stat().st_size / 1024 / 1024:.1f}MB", file=sys.stderr) + return output_path + + +def clean_nil_symbols(obj): + """Recursively convert Symbol('nil') to None and filter out None values from dicts.""" + if isinstance(obj, Symbol): + if obj.name == 'nil': + return None + return obj + elif isinstance(obj, dict): + result = {} + for k, v in obj.items(): + cleaned = clean_nil_symbols(v) + # Skip None values (they were nil) + if cleaned is not None: + result[k] = cleaned + return result + elif isinstance(obj, list): + return [clean_nil_symbols(v) for v in obj] + return obj + + +def parse_analysis_sexp(content: str) -> dict: + """Parse analysis S-expression into dict.""" + sexp = parse(content) + if isinstance(sexp, list) and len(sexp) == 1: + sexp = sexp[0] + + if not isinstance(sexp, list) or not sexp: + raise ValueError("Invalid analysis S-expression") + + # Should be (analysis (name ...) (name ...) ...) + if not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis": + raise ValueError("Expected (analysis ...) S-expression") + + result = {} + for item in sexp[1:]: + if isinstance(item, list) and item and isinstance(item[0], Symbol): + name = item[0].name + data = {} + + i = 1 + while i < len(item): + if isinstance(item[i], Keyword): + key = item[i].name.replace("-", "_") + i += 1 + if i < len(item): + data[key] = item[i] + i += 1 + else: + i += 1 + + result[name] = data + + return result + + +def sexp_to_plan(sexp) -> dict: + """Convert a parsed S-expression plan to a dict.""" + if not isinstance(sexp, list) or not sexp: + raise ValueError("Invalid plan S-expression") + + # Skip 'plan' symbol and name + plan = { + "steps": [], + "analysis": {}, + } + + i = 0 + if isinstance(sexp[0], Symbol) and sexp[0].name == "plan": + i = 1 + + # Parse keywords and steps + while i < len(sexp): + item = sexp[i] + + if isinstance(item, Keyword): + key = item.name.replace("-", "_") + i += 1 + if i < len(sexp): + value = sexp[i] + if key == "encoding" and isinstance(value, list): + # Parse encoding dict from sexp + plan["encoding"] = sexp_to_dict(value) + elif key == "output": + # Map :output to output_step_id + plan["output_step_id"] = value + elif key == "id": + # Map :id to plan_id + plan["plan_id"] = value + elif key == "source_cid": + # Map :source-cid to source_hash + plan["source_hash"] = value + else: + plan[key] = value + i += 1 + elif isinstance(item, list) and item and isinstance(item[0], Symbol): + if item[0].name == "step": + # Parse step + step = parse_step_sexp(item) + plan["steps"].append(step) + elif item[0].name == "analysis": + # Parse analysis data + plan["analysis"] = parse_analysis_sexp(item) + elif item[0].name == "effects-registry": + # Parse effects registry + plan["effects_registry"] = parse_effects_registry_sexp(item) + i += 1 + else: + i += 1 + + return plan + + +def parse_analysis_sexp(sexp) -> dict: + """Parse analysis S-expression: (analysis (bass :times [...] :values [...]) ...) + + Handles both inline data (:times [...] :values [...]) and cache-id refs (:cache-id "..."). + """ + analysis = {} + for item in sexp[1:]: # Skip 'analysis' symbol + if isinstance(item, list) and item and isinstance(item[0], Symbol): + name = item[0].name + data = {} + j = 1 + while j < len(item): + if isinstance(item[j], Keyword): + key = item[j].name + j += 1 + if j < len(item): + data[key] = item[j] + j += 1 + else: + j += 1 + # Normalize: parser gives "cache-id", internal code expects "_cache_id" + if "cache-id" in data: + data["_cache_id"] = data.pop("cache-id") + analysis[name] = data + return analysis + + +def parse_effects_registry_sexp(sexp) -> dict: + """Parse effects-registry S-expression: (effects-registry (rotate :path "...") (blur :path "..."))""" + registry = {} + for item in sexp[1:]: # Skip 'effects-registry' symbol + if isinstance(item, list) and item and isinstance(item[0], Symbol): + name = item[0].name + data = {} + j = 1 + while j < len(item): + if isinstance(item[j], Keyword): + key = item[j].name + j += 1 + if j < len(item): + data[key] = item[j] + j += 1 + else: + j += 1 + registry[name] = data + return registry + + +def parse_bind_sexp(sexp) -> dict: + """Parse a bind S-expression: (bind analysis-ref :range [min max] :offset 60 :transform sqrt)""" + if not isinstance(sexp, list) or len(sexp) < 2: + return None + if not isinstance(sexp[0], Symbol) or sexp[0].name != "bind": + return None + + bind = { + "_bind": sexp[1] if isinstance(sexp[1], str) else sexp[1].name if isinstance(sexp[1], Symbol) else str(sexp[1]), + "range_min": 0.0, + "range_max": 1.0, + "transform": None, + "offset": 0.0, + } + + i = 2 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + kw = sexp[i].name + if kw == "range": + i += 1 + if i < len(sexp) and isinstance(sexp[i], list) and len(sexp[i]) >= 2: + bind["range_min"] = float(sexp[i][0]) + bind["range_max"] = float(sexp[i][1]) + elif kw == "offset": + i += 1 + if i < len(sexp): + bind["offset"] = float(sexp[i]) + elif kw == "transform": + i += 1 + if i < len(sexp): + t = sexp[i] + if isinstance(t, Symbol): + bind["transform"] = t.name + elif isinstance(t, str): + bind["transform"] = t + i += 1 + + return bind + + +def sexp_to_dict(sexp) -> dict: + """Convert S-expression key-value pairs to dict.""" + result = {} + i = 0 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name.replace("-", "_") + i += 1 + if i < len(sexp): + value = sexp[i] + # Check for bind expression and convert to dict format + if isinstance(value, list) and value and isinstance(value[0], Symbol) and value[0].name == "bind": + value = parse_bind_sexp(value) + result[key] = value + i += 1 + else: + i += 1 + return result + + +def parse_step_sexp(sexp) -> dict: + """Parse a step S-expression. + + Supports two formats: + 1. (step "id" :cache-id "..." :type "SOURCE" :path "..." :inputs [...]) + 2. (step "id" :cache-id "..." :level 1 (source :path "..." :inputs [...])) + """ + step = { + "inputs": [], + "config": {}, + } + + i = 1 # Skip 'step' symbol + if i < len(sexp) and isinstance(sexp[i], str): + step["step_id"] = sexp[i] + i += 1 + + while i < len(sexp): + item = sexp[i] + + if isinstance(item, Keyword): + key = item.name.replace("-", "_") + i += 1 + if i < len(sexp): + value = sexp[i] + if key == "type": + step["node_type"] = value if isinstance(value, str) else value.name + elif key == "inputs": + step["inputs"] = value if isinstance(value, list) else [value] + elif key in ("level", "cache", "cache_id"): + if key == "cache": + key = "cache_id" + step[key] = value + else: + # Check for bind expression + if isinstance(value, list) and value and isinstance(value[0], Symbol) and value[0].name == "bind": + value = parse_bind_sexp(value) + # Config value + step["config"][key] = value + i += 1 + elif isinstance(item, list) and item and isinstance(item[0], Symbol): + # Nested node expression: (source :path "..." :inputs [...]) + node_type = item[0].name.upper() + step["node_type"] = node_type + + # Parse node config + j = 1 + while j < len(item): + if isinstance(item[j], Keyword): + key = item[j].name.replace("-", "_") + j += 1 + if j < len(item): + value = item[j] + if key == "inputs": + step["inputs"] = value if isinstance(value, list) else [value] + else: + # Check for bind expression + if isinstance(value, list) and value and isinstance(value[0], Symbol) and value[0].name == "bind": + value = parse_bind_sexp(value) + step["config"][key] = value + j += 1 + else: + j += 1 + i += 1 + else: + i += 1 + + return step + + +def parse_plan_input(content: str) -> dict: + """Parse plan from JSON or S-expression string.""" + content = content.strip() + if content.startswith("{"): + return json.loads(content) + elif content.startswith("("): + sexp = parse(content) + return sexp_to_plan(sexp[0] if isinstance(sexp, list) and len(sexp) == 1 else sexp) + else: + raise ValueError("Plan must be JSON (starting with '{') or S-expression (starting with '(')") + + +# Default encoding settings +DEFAULT_ENCODING = { + "codec": "libx264", + "preset": "fast", + "crf": 18, + "audio_codec": "aac", + "fps": 30, +} + + +def get_encoding(recipe_encoding: dict, step_config: dict) -> dict: + """Merge encoding settings: defaults < recipe < step overrides.""" + encoding = {**DEFAULT_ENCODING} + encoding.update(recipe_encoding) + if "encoding" in step_config: + encoding.update(step_config["encoding"]) + return encoding + + +class SexpEffectModule: + """Wrapper for S-expression effects to provide process_frame interface.""" + + def __init__(self, effect_path: Path, effects_registry: dict = None, recipe_dir: Path = None, minimal_primitives: bool = False): + from sexp_effects import get_interpreter + self.interp = get_interpreter(minimal_primitives=minimal_primitives) + + # Load only explicitly declared effects from the recipe's registry + # No auto-loading from directory - everything must be explicit + if effects_registry: + base_dir = recipe_dir or effect_path.parent.parent # Resolve relative paths + for effect_name, effect_info in effects_registry.items(): + effect_rel_path = effect_info.get("path") + if effect_rel_path: + full_path = (base_dir / effect_rel_path).resolve() + if full_path.exists() and effect_name not in self.interp.effects: + self.interp.load_effect(str(full_path)) + + # Load the specific effect if not already loaded + self.interp.load_effect(str(effect_path)) + self.effect_name = effect_path.stem + + def process_frame(self, frame, params, state): + return self.interp.run_effect(self.effect_name, frame, params, state or {}) + + +def load_effect(effect_path: Path, effects_registry: dict = None, recipe_dir: Path = None, minimal_primitives: bool = False): + """Load an effect module from a local path (.py or .sexp).""" + if effect_path.suffix == ".sexp": + return SexpEffectModule(effect_path, effects_registry, recipe_dir, minimal_primitives) + + spec = importlib.util.spec_from_file_location("effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def interpolate_analysis(times: list, values: list, t: float) -> float: + """Interpolate analysis value at time t.""" + if not times or not values: + return 0.0 + if t <= times[0]: + return values[0] + if t >= times[-1]: + return values[-1] + + # Binary search for surrounding times + lo, hi = 0, len(times) - 1 + while lo < hi - 1: + mid = (lo + hi) // 2 + if times[mid] <= t: + lo = mid + else: + hi = mid + + # Linear interpolation + t0, t1 = times[lo], times[hi] + v0, v1 = values[lo], values[hi] + if t1 == t0: + return v0 + alpha = (t - t0) / (t1 - t0) + return v0 + alpha * (v1 - v0) + + +def apply_transform(value: float, transform: str) -> float: + """Apply a transform function to a value (0-1 range).""" + if transform is None: + return value + if transform == "sqrt": + return value ** 0.5 + elif transform == "pow2": + return value ** 2 + elif transform == "pow3": + return value ** 3 + elif transform == "log": + # Logarithmic scale: log(1 + 9*x) / log(10) maps 0-1 to 0-1 with log curve + import math + return math.log(1 + 9 * value) / math.log(10) if value > 0 else 0 + elif transform == "exp": + # Exponential scale: (10^x - 1) / 9 maps 0-1 to 0-1 with exp curve + return (10 ** value - 1) / 9 + elif transform == "inv": + return 1 - value + else: + return value + + +def eval_expr(value, frame_time: float, frame_num: int, analysis_data: dict) -> float: + """ + Evaluate a runtime expression. + + Supports: + - Literals (int, float) + - Bindings: {"_binding": True, "source": ..., "feature": ...} + - Math expressions: {"_expr": True, "op": "+", "args": [...]} + - Time/frame: {"_expr": True, "op": "time"} or {"_expr": True, "op": "frame"} + """ + import math + + # Literal values + if isinstance(value, (int, float)): + return float(value) + + if not isinstance(value, dict): + return 0.0 # Unknown type + + # Handle bindings + if "_bind" in value or "_binding" in value: + if "_bind" in value: + ref = value["_bind"] + range_min = value.get("range_min", 0.0) + range_max = value.get("range_max", 1.0) + else: + ref = value.get("source", "") + range_val = value.get("range", [0.0, 1.0]) + range_min = range_val[0] if isinstance(range_val, list) else 0.0 + range_max = range_val[1] if isinstance(range_val, list) and len(range_val) > 1 else 1.0 + + transform = value.get("transform") + bind_offset = value.get("offset", 0.0) + + track = analysis_data.get(ref, {}) + times = track.get("times", []) + values = track.get("values", []) + + lookup_time = frame_time + bind_offset + raw = interpolate_analysis(times, values, lookup_time) + transformed = apply_transform(raw, transform) + + return range_min + transformed * (range_max - range_min) + + # Handle expressions + if "_expr" in value: + op = value.get("op") + args = value.get("args", []) + + # Special ops without args + if op == "time": + return frame_time + if op == "frame": + return float(frame_num) + + # Lazy-evaluated ops (don't evaluate all branches) + if op == "if": + cond = eval_expr(args[0], frame_time, frame_num, analysis_data) if args else 0.0 + if cond: + return eval_expr(args[1], frame_time, frame_num, analysis_data) if len(args) > 1 else 0.0 + return eval_expr(args[2], frame_time, frame_num, analysis_data) if len(args) > 2 else 0.0 + + # Evaluate arguments recursively + evaluated = [eval_expr(arg, frame_time, frame_num, analysis_data) for arg in args] + + # Comparison operations + if op == "<" and len(evaluated) >= 2: + return 1.0 if evaluated[0] < evaluated[1] else 0.0 + if op == ">" and len(evaluated) >= 2: + return 1.0 if evaluated[0] > evaluated[1] else 0.0 + if op == "<=" and len(evaluated) >= 2: + return 1.0 if evaluated[0] <= evaluated[1] else 0.0 + if op == ">=" and len(evaluated) >= 2: + return 1.0 if evaluated[0] >= evaluated[1] else 0.0 + if op == "=" and len(evaluated) >= 2: + return 1.0 if evaluated[0] == evaluated[1] else 0.0 + + # Math operations + if op == "+" and len(evaluated) >= 2: + return evaluated[0] + evaluated[1] + if op == "-" and len(evaluated) >= 2: + return evaluated[0] - evaluated[1] + if op == "*" and len(evaluated) >= 2: + return evaluated[0] * evaluated[1] + if op == "/" and len(evaluated) >= 2: + return evaluated[0] / evaluated[1] if evaluated[1] != 0 else 0.0 + if op == "mod" and len(evaluated) >= 2: + return evaluated[0] % evaluated[1] if evaluated[1] != 0 else 0.0 + if op == "min" and len(evaluated) >= 2: + return min(evaluated[0], evaluated[1]) + if op == "max" and len(evaluated) >= 2: + return max(evaluated[0], evaluated[1]) + if op == "abs" and len(evaluated) >= 1: + return abs(evaluated[0]) + if op == "sin" and len(evaluated) >= 1: + return math.sin(evaluated[0]) + if op == "cos" and len(evaluated) >= 1: + return math.cos(evaluated[0]) + if op == "floor" and len(evaluated) >= 1: + return float(math.floor(evaluated[0])) + if op == "ceil" and len(evaluated) >= 1: + return float(math.ceil(evaluated[0])) + + return 0.0 # Fallback + + +def eval_scan_expr(value, rng, variables): + """ + Evaluate a scan expression with seeded RNG and variable bindings. + + Args: + value: Compiled expression (literal, dict with _expr, etc.) + rng: random.Random instance (seeded, advances state per call) + variables: Dict of variable bindings (acc, rem, hue, etc.) + + Returns: + Evaluated value (number or dict) + """ + import math + + if isinstance(value, (int, float)): + return value + + if isinstance(value, str): + return value + + if not isinstance(value, dict) or "_expr" not in value: + return value + + op = value.get("op") + args = value.get("args", []) + + # Variable reference + if op == "var": + name = value.get("name", "") + return variables.get(name, 0) + + # Dict constructor + if op == "dict": + keys = value.get("keys", []) + vals = [eval_scan_expr(a, rng, variables) for a in args] + return dict(zip(keys, vals)) + + # Random ops (advance RNG state) + if op == "rand": + return rng.random() + if op == "rand-int": + lo = int(eval_scan_expr(args[0], rng, variables)) + hi = int(eval_scan_expr(args[1], rng, variables)) + return rng.randint(lo, hi) + if op == "rand-range": + lo = float(eval_scan_expr(args[0], rng, variables)) + hi = float(eval_scan_expr(args[1], rng, variables)) + return rng.uniform(lo, hi) + + # Conditional (lazy - only evaluate taken branch) + if op == "if": + cond = eval_scan_expr(args[0], rng, variables) if args else 0 + if cond: + return eval_scan_expr(args[1], rng, variables) if len(args) > 1 else 0 + return eval_scan_expr(args[2], rng, variables) if len(args) > 2 else 0 + + # Comparison ops + if op in ("<", ">", "<=", ">=", "="): + left = eval_scan_expr(args[0], rng, variables) if args else 0 + right = eval_scan_expr(args[1], rng, variables) if len(args) > 1 else 0 + if op == "<": + return 1 if left < right else 0 + if op == ">": + return 1 if left > right else 0 + if op == "<=": + return 1 if left <= right else 0 + if op == ">=": + return 1 if left >= right else 0 + if op == "=": + return 1 if left == right else 0 + + # Eagerly evaluate remaining args + evaluated = [eval_scan_expr(a, rng, variables) for a in args] + + # Arithmetic ops + if op == "+" and len(evaluated) >= 2: + return evaluated[0] + evaluated[1] + if op == "-" and len(evaluated) >= 2: + return evaluated[0] - evaluated[1] + if op == "-" and len(evaluated) == 1: + return -evaluated[0] + if op == "*" and len(evaluated) >= 2: + return evaluated[0] * evaluated[1] + if op == "/" and len(evaluated) >= 2: + return evaluated[0] / evaluated[1] if evaluated[1] != 0 else 0 + if op == "mod" and len(evaluated) >= 2: + return evaluated[0] % evaluated[1] if evaluated[1] != 0 else 0 + if op == "min" and len(evaluated) >= 2: + return min(evaluated[0], evaluated[1]) + if op == "max" and len(evaluated) >= 2: + return max(evaluated[0], evaluated[1]) + if op == "abs" and len(evaluated) >= 1: + return abs(evaluated[0]) + if op == "sin" and len(evaluated) >= 1: + return math.sin(evaluated[0]) + if op == "cos" and len(evaluated) >= 1: + return math.cos(evaluated[0]) + if op == "floor" and len(evaluated) >= 1: + return math.floor(evaluated[0]) + if op == "ceil" and len(evaluated) >= 1: + return math.ceil(evaluated[0]) + if op == "nth" and len(evaluated) >= 2: + collection = evaluated[0] + index = int(evaluated[1]) + if isinstance(collection, (list, tuple)) and 0 <= index < len(collection): + return collection[index] + return 0 + + return 0 # Fallback + + +def _is_binding(value): + """Check if a value is a binding/expression dict that needs per-frame resolution.""" + return isinstance(value, dict) and ("_bind" in value or "_binding" in value or "_expr" in value) + + +def _check_has_bindings(params: dict) -> bool: + """Check if any param value (including inside lists) contains bindings.""" + for v in params.values(): + if _is_binding(v): + return True + if isinstance(v, list) and any(_is_binding(item) for item in v): + return True + return False + + +def resolve_params(params: dict, frame_time: float, analysis_data: dict, frame_num: int = 0) -> dict: + """Resolve any binding/expression params using analysis data at frame_time. + + Handles bindings at the top level and inside lists (e.g. blend_multi weights). + """ + resolved = {} + for key, value in params.items(): + if _is_binding(value): + resolved[key] = eval_expr(value, frame_time, frame_num, analysis_data) + elif isinstance(value, list): + resolved[key] = [ + eval_expr(item, frame_time, frame_num, analysis_data) + if _is_binding(item) else item + for item in value + ] + else: + resolved[key] = value + return resolved + + +def resolve_scalar_binding(value, analysis_data: dict): + """Resolve a scalar binding (like duration) from analysis data. + + For scalar features like 'duration', retrieves the value directly from analysis data. + For time-varying features, this returns None (use resolve_params instead). + + Returns: + Resolved value (float) if binding can be resolved to scalar, None otherwise. + If value is not a binding, returns the value unchanged. + """ + if not isinstance(value, dict) or not ("_bind" in value or "_binding" in value): + return value + + # Get source reference and feature + if "_bind" in value: + ref = value["_bind"] + feature = "values" # old format defaults to values + else: + ref = value.get("source", "") + feature = value.get("feature", "values") + + # Look up analysis track + track = analysis_data.get(ref, {}) + + # For scalar features like 'duration', get directly + if feature == "duration": + duration = track.get("duration") + if duration is not None: + return float(duration) + return None + + # For time-varying features, can't resolve to scalar + # Return None to indicate this needs frame-by-frame resolution + return None + + +@_video_pipeline_guard +def run_effect(effect_module, input_path: Path, output_path: Path, params: dict, encoding: dict, analysis_data: dict = None, time_offset: float = 0.0, max_duration: float = None): + """Run an effect on a video file. + + Args: + time_offset: Time offset in seconds for resolving bindings (e.g., segment start time in audio) + max_duration: Maximum duration in seconds to process (stops after this many seconds of frames) + """ + import numpy as np + + # Clean nil Symbols from params + params = clean_nil_symbols(params) + + # Get video info including duration + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(input_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + + # Find video stream + video_stream = None + for stream in probe_data.get("streams", []): + if stream.get("codec_type") == "video": + video_stream = stream + break + + if not video_stream: + raise ValueError("No video stream found") + + in_width = int(video_stream["width"]) + in_height = int(video_stream["height"]) + + # Get framerate + fps_str = video_stream.get("r_frame_rate", "30/1") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) + else: + fps = float(fps_str) + + # Get duration for progress bar + duration = None + if "format" in probe_data and "duration" in probe_data["format"]: + duration = float(probe_data["format"]["duration"]) + + # Read frames with ffmpeg + read_cmd = [ + "ffmpeg", "-i", str(input_path), + "-f", "rawvideo", "-pix_fmt", "rgb24", "-" + ] + read_proc = subprocess.Popen(read_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Check if we have any bindings that need per-frame resolution + has_bindings = _check_has_bindings(params) + analysis_data = analysis_data or {} + + # Debug: print bindings and analysis info once + if has_bindings: + print(f" BINDINGS DEBUG: time_offset={time_offset:.2f}", file=sys.stderr) + for k, v in params.items(): + if isinstance(v, dict) and ("_bind" in v or "_binding" in v): + ref = v.get("_bind") or v.get("source") + bind_offset = float(v.get("offset", 0.0)) + track = analysis_data.get(ref, {}) + times = track.get("times", []) + values = track.get("values", []) + if times and values: + # Find first non-zero value + first_nonzero_idx = next((i for i, v in enumerate(values) if v > 0.01), -1) + first_nonzero_time = times[first_nonzero_idx] if first_nonzero_idx >= 0 else -1 + print(f" param {k}: ref='{ref}' bind_offset={bind_offset} time_range=[{min(times):.2f}, {max(times):.2f}]", file=sys.stderr) + print(f" first_nonzero at t={first_nonzero_time:.2f} max_value={max(values):.4f}", file=sys.stderr) + else: + raise ValueError(f"Binding for param '{k}' references '{ref}' but no analysis data found. Available: {list(analysis_data.keys())}") + + # Process first frame to detect output dimensions + in_frame_size = in_width * in_height * 3 + frame_data = read_proc.stdout.read(in_frame_size) + if len(frame_data) < in_frame_size: + read_proc.stdout.close() + read_proc.wait() + raise ValueError("No frames in input video") + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((in_height, in_width, 3)) + + # Resolve params for first frame + if has_bindings: + frame_params = resolve_params(params, time_offset, analysis_data, frame_num=0) + else: + frame_params = params + + # Apply single effect with mix bypass: mix=0 → passthrough, 0=1 → full + def apply_effect(frame, frame_params, state): + mix_val = float(frame_params.get('mix', 1.0)) + if mix_val <= 0: + return frame, state + result, state = effect_module.process_frame(frame, frame_params, state) + if mix_val < 1.0: + result = np.clip( + frame.astype(np.float32) * (1.0 - mix_val) + + result.astype(np.float32) * mix_val, + 0, 255 + ).astype(np.uint8) + return result, state + + state = None + processed, state = apply_effect(frame, frame_params, state) + + # Get output dimensions from processed frame + out_height, out_width = processed.shape[:2] + if out_width != in_width or out_height != in_height: + print(f" Effect resizes: {in_width}x{in_height} -> {out_width}x{out_height}", file=sys.stderr) + + # Now start write process with correct output dimensions + write_cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{out_width}x{out_height}", "-r", str(encoding.get("fps", 30)), + "-i", "-", + "-i", str(input_path), # For audio + "-map", "0:v", "-map", "1:a?", + "-c:v", encoding["codec"], "-preset", encoding["preset"], "-crf", str(encoding["crf"]), + "-pix_fmt", "yuv420p", + "-c:a", encoding["audio_codec"], + str(output_path) + ] + write_proc = subprocess.Popen(write_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Write first processed frame + write_proc.stdin.write(processed.tobytes()) + frame_count = 1 + + # Calculate max frames and total for progress bar + max_frames = None + total_frames = 0 + if max_duration: + max_frames = int(max_duration * fps) + total_frames = max_frames + elif duration: + total_frames = int(duration * fps) + + # Create progress bar + effect_name = getattr(effect_module, 'effect_name', 'effect') + pbar = ProgressBar(total_frames, desc=effect_name) + pbar.set(1) # First frame already processed + + # Process remaining frames + while True: + # Stop if we've reached the frame limit + if max_frames and frame_count >= max_frames: + break + + frame_data = read_proc.stdout.read(in_frame_size) + if len(frame_data) < in_frame_size: + break + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((in_height, in_width, 3)) + + # Resolve params for this frame + if has_bindings: + frame_time = time_offset + frame_count / fps + frame_params = resolve_params(params, frame_time, analysis_data, frame_num=frame_count) + else: + frame_params = params + + processed, state = apply_effect(frame, frame_params, state) + write_proc.stdin.write(processed.tobytes()) + frame_count += 1 + pbar.set(frame_count) + + read_proc.stdout.close() + write_proc.stdin.close() + read_proc.wait() + write_proc.wait() + + pbar.finish() + + +@_video_pipeline_guard +def run_multi_effect(effect_module, input_paths: List[Path], output_path: Path, params: dict, encoding: dict, analysis_data: dict = None, time_offset: float = 0.0, max_duration: float = None): + """Run a multi-input effect on multiple video files. + + Args: + time_offset: Time offset in seconds for resolving bindings (e.g., segment start time in audio) + max_duration: Maximum duration in seconds to process (stops after this many seconds of frames) + """ + import numpy as np + + # Clean nil Symbols from params + params = clean_nil_symbols(params) + + if len(input_paths) < 2: + raise ValueError("Multi-input effect requires at least 2 inputs") + + # Get video info for each input (preserve original dimensions) + input_infos = [] + for input_path in input_paths: + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(input_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + + video_stream = None + for stream in probe_data.get("streams", []): + if stream.get("codec_type") == "video": + video_stream = stream + break + + if not video_stream: + raise ValueError(f"No video stream found in {input_path}") + + w = int(video_stream["width"]) + h = int(video_stream["height"]) + input_infos.append({"width": w, "height": h, "path": input_path}) + print(f" Input: {input_path.name} ({w}x{h})", file=sys.stderr) + + # Get framerate and duration from first input + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(input_paths[0]) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + video_stream = next(s for s in probe_data.get("streams", []) if s.get("codec_type") == "video") + fps_str = video_stream.get("r_frame_rate", "30/1") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) + else: + fps = float(fps_str) + + # Get duration for progress bar + duration = None + if "format" in probe_data and "duration" in probe_data["format"]: + duration = float(probe_data["format"]["duration"]) + + # Open read processes for all inputs - preserve original dimensions + read_procs = [] + for info in input_infos: + read_cmd = [ + "ffmpeg", "-i", str(info["path"]), + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-" # Don't scale - keep original dimensions + ] + proc = subprocess.Popen(read_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + read_procs.append(proc) + + analysis_data = analysis_data or {} + state = None + + # Process first frame to detect output dimensions + frames = [] + for i, (proc, info) in enumerate(zip(read_procs, input_infos)): + frame_size = info["width"] * info["height"] * 3 + frame_data = proc.stdout.read(frame_size) + if len(frame_data) < frame_size: + # Cleanup + for p in read_procs: + p.stdout.close() + p.wait() + raise ValueError(f"No frames in input {i}") + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((info["height"], info["width"], 3)) + frames.append(frame) + + # Check if we have any bindings that need per-frame resolution + has_bindings = _check_has_bindings(params) + + # Resolve params for first frame + if has_bindings: + frame_params = resolve_params(params, time_offset, analysis_data, frame_num=0) + else: + frame_params = params + + processed, state = effect_module.process_frame(frames, frame_params, state) + out_height, out_width = processed.shape[:2] + print(f" Output dimensions: {out_width}x{out_height}", file=sys.stderr) + + # Now start write process with correct output dimensions + write_cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{out_width}x{out_height}", "-r", str(encoding.get("fps", 30)), + "-i", "-", + "-i", str(input_paths[0]), # For audio from first input + "-map", "0:v", "-map", "1:a?", + "-c:v", encoding["codec"], "-preset", encoding["preset"], "-crf", str(encoding["crf"]), + "-pix_fmt", "yuv420p", + "-c:a", encoding["audio_codec"], + str(output_path) + ] + write_proc = subprocess.Popen(write_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Write first processed frame + write_proc.stdin.write(processed.tobytes()) + frame_count = 1 + + # Calculate max frames and total for progress bar + max_frames = None + total_frames = 0 + if max_duration: + max_frames = int(max_duration * fps) + total_frames = max_frames + elif duration: + total_frames = int(duration * fps) + + # Create progress bar + effect_name = getattr(effect_module, 'effect_name', 'blend') + pbar = ProgressBar(total_frames, desc=effect_name) + pbar.set(1) # First frame already processed + + # Process remaining frames + while True: + # Stop if we've reached the frame limit + if max_frames and frame_count >= max_frames: + break + + # Read frame from each input (each may have different dimensions) + frames = [] + all_valid = True + for i, (proc, info) in enumerate(zip(read_procs, input_infos)): + frame_size = info["width"] * info["height"] * 3 + frame_data = proc.stdout.read(frame_size) + if len(frame_data) < frame_size: + all_valid = False + break + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((info["height"], info["width"], 3)) + frames.append(frame) + + if not all_valid: + break + + # Resolve params for this frame + if has_bindings: + frame_time = time_offset + frame_count / fps + frame_params = resolve_params(params, frame_time, analysis_data, frame_num=frame_count) + else: + frame_params = params + + # Pass list of frames to effect + processed, state = effect_module.process_frame(frames, frame_params, state) + write_proc.stdin.write(processed.tobytes()) + frame_count += 1 + pbar.set(frame_count) + + # Cleanup + for proc in read_procs: + proc.stdout.close() + proc.wait() + write_proc.stdin.close() + write_proc.wait() + + pbar.finish() + + +@_video_pipeline_guard +def run_effect_chain(effect_modules, input_path: Path, output_path: Path, + params_list: list, encoding: dict, + analysis_data=None, time_offset: float = 0.0, + max_duration: float = None): + """Run multiple effects as a single-pass fused chain: one decode, one encode, no intermediates. + + Args: + effect_modules: List of effect modules (each has process_frame) + input_path: Input video file + output_path: Output video file + params_list: List of param dicts, one per effect + encoding: Encoding settings + analysis_data: Optional analysis data for binding resolution + time_offset: Time offset for resolving bindings + max_duration: Maximum duration in seconds to process + """ + import numpy as np + + # Clean nil Symbols from each params dict + params_list = [clean_nil_symbols(p) for p in params_list] + + # Probe input for dimensions/fps/duration + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(input_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + + video_stream = None + for stream in probe_data.get("streams", []): + if stream.get("codec_type") == "video": + video_stream = stream + break + if not video_stream: + raise ValueError("No video stream found") + + in_width = int(video_stream["width"]) + in_height = int(video_stream["height"]) + + fps_str = video_stream.get("r_frame_rate", "30/1") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) + else: + fps = float(fps_str) + + duration = None + if "format" in probe_data and "duration" in probe_data["format"]: + duration = float(probe_data["format"]["duration"]) + + # Pre-compute per-effect binding flags + analysis_data = analysis_data or {} + bindings_flags = [] + for params in params_list: + has_b = any(isinstance(v, dict) and ("_bind" in v or "_binding" in v or "_expr" in v) + for v in params.values()) + bindings_flags.append(has_b) + + # Open single ffmpeg reader + read_cmd = [ + "ffmpeg", "-i", str(input_path), + "-f", "rawvideo", "-pix_fmt", "rgb24", "-" + ] + read_proc = subprocess.Popen(read_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Read first frame + in_frame_size = in_width * in_height * 3 + frame_data = read_proc.stdout.read(in_frame_size) + if len(frame_data) < in_frame_size: + read_proc.stdout.close() + read_proc.wait() + raise ValueError("No frames in input video") + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((in_height, in_width, 3)) + + # Apply effect chain to a frame, respecting per-effect mix bypass. + # mix=0 → skip (zero cost), 0=1 → full effect. + def apply_chain(frame, states, frame_num, frame_time): + processed = frame + for idx, (module, params, has_b) in enumerate(zip(effect_modules, params_list, bindings_flags)): + if has_b: + fp = resolve_params(params, frame_time, analysis_data, frame_num=frame_num) + else: + fp = params + mix_val = float(fp.get('mix', 1.0)) + if mix_val <= 0: + continue + result, states[idx] = module.process_frame(processed, fp, states[idx]) + if mix_val < 1.0: + processed = np.clip( + processed.astype(np.float32) * (1.0 - mix_val) + + result.astype(np.float32) * mix_val, + 0, 255 + ).astype(np.uint8) + else: + processed = result + return processed, states + + # Push first frame through all effects to discover final output dimensions + states = [None] * len(effect_modules) + processed, states = apply_chain(frame, states, 0, time_offset) + + out_height, out_width = processed.shape[:2] + if out_width != in_width or out_height != in_height: + print(f" Chain resizes: {in_width}x{in_height} -> {out_width}x{out_height}", file=sys.stderr) + + # Open single ffmpeg writer with final output dimensions + write_cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{out_width}x{out_height}", "-r", str(encoding.get("fps", 30)), + "-i", "-", + "-i", str(input_path), # For audio + "-map", "0:v", "-map", "1:a?", + "-c:v", encoding["codec"], "-preset", encoding["preset"], "-crf", str(encoding["crf"]), + "-pix_fmt", "yuv420p", + "-c:a", encoding["audio_codec"], + str(output_path) + ] + write_proc = subprocess.Popen(write_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Write first processed frame + write_proc.stdin.write(processed.tobytes()) + frame_count = 1 + + # Calculate max frames and total for progress bar + max_frames = None + total_frames = 0 + if max_duration: + max_frames = int(max_duration * fps) + total_frames = max_frames + elif duration: + total_frames = int(duration * fps) + + effect_names = [getattr(m, 'effect_name', '?') for m in effect_modules] + pbar = ProgressBar(total_frames, desc='+'.join(effect_names)) + pbar.set(1) + + # Frame loop: read -> apply chain -> write + while True: + if max_frames and frame_count >= max_frames: + break + + frame_data = read_proc.stdout.read(in_frame_size) + if len(frame_data) < in_frame_size: + break + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((in_height, in_width, 3)) + + frame_time = time_offset + frame_count / fps + processed, states = apply_chain(frame, states, frame_count, frame_time) + + write_proc.stdin.write(processed.tobytes()) + frame_count += 1 + pbar.set(frame_count) + + read_proc.stdout.close() + write_proc.stdin.close() + read_proc.wait() + write_proc.wait() + + pbar.finish() + + +def get_video_dimensions(file_path: Path) -> tuple: + """Get video dimensions using ffprobe.""" + cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(file_path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + data = json.loads(result.stdout) + + for stream in data.get("streams", []): + if stream.get("codec_type") == "video": + return int(stream["width"]), int(stream["height"]) + + return None, None + + +def normalize_video( + input_path: Path, + output_path: Path, + target_width: int, + target_height: int, + resize_mode: str, + priority: str = None, + pad_color: str = "black", + crop_gravity: str = "center", + encoding: dict = None, +) -> Path: + """ + Normalize video to target dimensions. + + resize_mode: + - stretch: force to exact size (distorts) + - crop: scale to fill, crop overflow + - fit: scale to fit, pad remainder + - cover: scale to cover, crop minimally + + priority: width | height (which dimension to match exactly for fit/crop) + """ + enc = encoding or {} + src_width, src_height = get_video_dimensions(input_path) + + if src_width is None: + # Can't determine dimensions, just copy + shutil.copy(input_path, output_path) + return output_path + + # Already correct size? + if src_width == target_width and src_height == target_height: + shutil.copy(input_path, output_path) + return output_path + + src_aspect = src_width / src_height + target_aspect = target_width / target_height + + if resize_mode == "stretch": + # Force exact size + vf = f"scale={target_width}:{target_height}" + + elif resize_mode == "fit": + # Scale to fit within bounds, pad remainder + if priority == "width": + # Match width exactly, pad height + vf = f"scale={target_width}:-1,pad={target_width}:{target_height}:(ow-iw)/2:(oh-ih)/2:{pad_color}" + elif priority == "height": + # Match height exactly, pad width + vf = f"scale=-1:{target_height},pad={target_width}:{target_height}:(ow-iw)/2:(oh-ih)/2:{pad_color}" + else: + # Auto: fit within bounds (may pad both) + if src_aspect > target_aspect: + # Source is wider, fit to width + vf = f"scale={target_width}:-1,pad={target_width}:{target_height}:(ow-iw)/2:(oh-ih)/2:{pad_color}" + else: + # Source is taller, fit to height + vf = f"scale=-1:{target_height},pad={target_width}:{target_height}:(ow-iw)/2:(oh-ih)/2:{pad_color}" + + elif resize_mode == "crop": + # Scale to fill, crop overflow + if priority == "width": + # Match width, crop height + vf = f"scale={target_width}:-1,crop={target_width}:{target_height}" + elif priority == "height": + # Match height, crop width + vf = f"scale=-1:{target_height},crop={target_width}:{target_height}" + else: + # Auto: fill bounds, crop minimally + if src_aspect > target_aspect: + # Source is wider, match height and crop width + vf = f"scale=-1:{target_height},crop={target_width}:{target_height}" + else: + # Source is taller, match width and crop height + vf = f"scale={target_width}:-1,crop={target_width}:{target_height}" + + elif resize_mode == "cover": + # Scale to cover target, crop to exact size + if src_aspect > target_aspect: + vf = f"scale=-1:{target_height},crop={target_width}:{target_height}" + else: + vf = f"scale={target_width}:-1,crop={target_width}:{target_height}" + + else: + # Unknown mode, just copy + shutil.copy(input_path, output_path) + return output_path + + cmd = [ + "ffmpeg", "-y", "-i", str(input_path), + "-vf", vf, + "-r", str(enc.get("fps", 30)), # Normalize framerate for concat compatibility + "-c:v", enc.get("codec", "libx264"), + "-preset", enc.get("preset", "fast"), + "-crf", str(enc.get("crf", 18)), + "-pix_fmt", "yuv420p", # Normalize pixel format for concat compatibility + "-c:a", enc.get("audio_codec", "aac"), + str(output_path) + ] + subprocess.run(cmd, check=True, capture_output=True) + return output_path + + +def tree_concat(files: list, work_dir: Path, prefix: str = "concat") -> Path: + """Concatenate files using a binary tree approach.""" + if len(files) == 1: + return files[0] + + level = 0 + current_files = list(files) + print(f" Tree concat: {len(current_files)} files", file=sys.stderr) + for i, f in enumerate(current_files): + print(f" [{i}] {f}", file=sys.stderr) + + while len(current_files) > 1: + next_files = [] + pairs = (len(current_files) + 1) // 2 + print(f" Level {level}: {len(current_files)} -> {pairs} pairs", file=sys.stderr) + + for i in range(0, len(current_files), 2): + if i + 1 < len(current_files): + concat_file = work_dir / f"{prefix}_L{level}_{i}.txt" + output_file = work_dir / f"{prefix}_L{level}_{i}.mp4" + + with open(concat_file, "w") as f: + f.write(f"file '{current_files[i]}'\n") + f.write(f"file '{current_files[i+1]}'\n") + + cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", + "-i", str(concat_file), "-c", "copy", str(output_file)] + subprocess.run(cmd, capture_output=True) + next_files.append(output_file) + else: + next_files.append(current_files[i]) + + current_files = next_files + level += 1 + + return current_files[0] + + +def execute_plan(plan_path: Path = None, output_path: Path = None, recipe_dir: Path = None, plan_data: dict = None, external_analysis: dict = None, cache_dir: Path = None): + """Execute a plan file (S-expression) or plan dict. + + Args: + cache_dir: Directory to cache intermediate results. If provided, steps will + check for cached outputs before recomputing. + """ + + # Load plan from file, stdin, or dict + if plan_data: + plan = plan_data + elif plan_path and str(plan_path) != "-": + content = plan_path.read_text() + plan = parse_plan_input(content) + else: + # Read from stdin + content = sys.stdin.read() + plan = parse_plan_input(content) + + print(f"Executing plan: {plan['plan_id'][:16]}...", file=sys.stderr) + print(f"Source CID: {plan.get('source_hash', 'unknown')[:16]}...", file=sys.stderr) + print(f"Steps: {len(plan['steps'])}", file=sys.stderr) + + recipe_encoding = plan.get("encoding", {}) + + # Merge plan's embedded analysis (includes synthetic tracks from composition + # merging) with external analysis (fresh ANALYZE step outputs). + # External analysis takes priority for tracks that exist in both. + analysis_data = dict(plan.get("analysis", {})) + if external_analysis: + analysis_data.update(external_analysis) + + # Resolve cache-id refs from plan + for name, data in list(analysis_data.items()): + if isinstance(data, dict) and "_cache_id" in data: + try: + from cache import cache_get_json + loaded = cache_get_json(data["_cache_id"]) + if loaded: + analysis_data[name] = loaded + except ImportError: + pass # standalone mode, no cache available + if recipe_dir is None: + recipe_dir = plan_path.parent if plan_path else Path(".") + + if analysis_data: + print(f"Analysis tracks: {list(analysis_data.keys())}", file=sys.stderr) + + # Get effects registry for loading explicitly declared effects + effects_registry = plan.get("effects_registry", {}) + if effects_registry: + print(f"Effects registry: {list(effects_registry.keys())}", file=sys.stderr) + + # Check for minimal primitives mode + minimal_primitives = plan.get("minimal_primitives", False) + if minimal_primitives: + print(f"Minimal primitives mode: enabled", file=sys.stderr) + + # Execute steps + results = {} # step_id -> output_path + work_dir = Path(tempfile.mkdtemp(prefix="artdag_exec_")) + + # Sort steps by level first (respecting dependencies), then by type within each level + # Type priority within same level: SOURCE/SEGMENT first, then ANALYZE, then EFFECT + steps = plan["steps"] + def step_sort_key(s): + node_type = s.get("node_type") or "UNKNOWN" + # Handle node_type being a Symbol + if hasattr(node_type, 'name'): + node_type = node_type.name + level = s.get("level", 0) + # Ensure level is an int (could be Symbol or None) + if not isinstance(level, int): + level = 0 + # Type priority (tiebreaker within same level): SOURCE=0, SEGMENT=1, ANALYZE=2, others=3 + if node_type == "SOURCE": + type_priority = 0 + elif node_type == "SEGMENT": + type_priority = 1 + elif node_type in ("ANALYZE", "SCAN"): + type_priority = 2 + else: + type_priority = 3 + # Sort by level FIRST, then type priority as tiebreaker + return (level, type_priority) + ordered_steps = sorted(steps, key=step_sort_key) + + try: + def _run_step(step): + step_id = step["step_id"] + node_type = step["node_type"] + config = step["config"] + inputs = step.get("inputs", []) + cache_id = step.get("cache_id", step_id) # IPNS address for caching + + print(f"\n[{step.get('level', 0)}] {node_type}: {step_id[:16]}...", file=sys.stderr) + + if node_type == "SOURCE": + if "path" in config: + src_path = (recipe_dir / config["path"]).resolve() + if not src_path.exists(): + raise FileNotFoundError(f"Source not found: {src_path}") + results[step_id] = src_path + print(f" -> {src_path}", file=sys.stderr) + + elif node_type == "SEGMENT": + is_audio = str(results[inputs[0]]).lower().endswith( + ('.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a')) + + input_path = results[inputs[0]] + start = config.get("start", 0) + duration = config.get("duration") + end = config.get("end") + + # Resolve any bindings to scalar values + start = resolve_scalar_binding(start, analysis_data) if start else 0 + duration = resolve_scalar_binding(duration, analysis_data) if duration else None + end = resolve_scalar_binding(end, analysis_data) if end else None + + # Check cache + cached = check_cache(cache_dir, cache_id, ['.m4a'] if is_audio else ['.mp4']) + if cached: + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + print(f" Resolved: start={start}, duration={duration}", file=sys.stderr) + + enc = get_encoding(recipe_encoding, config) + + if is_audio: + output_file = work_dir / f"segment_{step_id}.m4a" + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if start: + cmd.extend(["-ss", str(start)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend(["-c:a", enc["audio_codec"], str(output_file)]) + else: + output_file = work_dir / f"segment_{step_id}.mp4" + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if start: + cmd.extend(["-ss", str(start)]) + if duration: + cmd.extend(["-t", str(duration)]) + elif end: + cmd.extend(["-t", str(end - start)]) + cmd.extend(["-r", str(enc["fps"]), # Normalize frame rate + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-c:a", enc["audio_codec"], + str(output_file)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + + # Check if segment has video content AND correct duration, if not try with looping + needs_loop = False + if not is_audio and result.returncode == 0: + probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(output_file)] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + has_video = any(s.get("codec_type") == "video" for s in probe_data.get("streams", [])) + if not has_video: + needs_loop = True + # Also check if output duration matches requested duration + elif duration: + output_duration = float(probe_data.get("format", {}).get("duration", 0)) + # If output is significantly shorter than requested, need to loop + if output_duration < duration - 1.0: # 1 second tolerance + needs_loop = True + print(f" Output {output_duration:.1f}s < requested {duration:.1f}s, will loop", file=sys.stderr) + + if needs_loop or result.returncode != 0: + # Get source duration and loop the input + probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(input_path)] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + src_duration = float(probe_data.get("format", {}).get("duration", 0)) + + if src_duration > 0: + # Wrap start time to source duration + wrapped_start = start % src_duration if start else 0 + seg_duration = duration if duration else (end - start if end else None) + + print(f" Wrapping segment: {start:.2f}s -> {wrapped_start:.2f}s (source={src_duration:.2f}s)", file=sys.stderr) + + # Use stream_loop for seamless looping if segment spans wrap point + if wrapped_start + (seg_duration or 0) > src_duration: + # Need to loop - use concat filter + cmd = ["ffmpeg", "-y", "-stream_loop", "-1", "-i", str(input_path)] + cmd.extend(["-ss", str(wrapped_start)]) + if seg_duration: + cmd.extend(["-t", str(seg_duration)]) + cmd.extend(["-r", str(enc["fps"]), + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-c:a", enc["audio_codec"], + str(output_file)]) + else: + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + cmd.extend(["-ss", str(wrapped_start)]) + if seg_duration: + cmd.extend(["-t", str(seg_duration)]) + cmd.extend(["-r", str(enc["fps"]), + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-c:a", enc["audio_codec"], + str(output_file)]) + + subprocess.run(cmd, check=True, capture_output=True) + else: + raise ValueError(f"Cannot determine source duration for looping") + + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + print(f" -> {output_file}", file=sys.stderr) + + elif node_type == "EFFECT": + # Check cache + cached = check_cache(cache_dir, cache_id, ['.mp4']) + if cached: + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + effect_name = config.get("effect", "unknown") + effect_path = config.get("effect_path") + is_multi_input = config.get("multi_input", False) + + output_file = work_dir / f"effect_{step_id}.mp4" + enc = get_encoding(recipe_encoding, config) + + if effect_path: + full_path = recipe_dir / effect_path + effect_module = load_effect(full_path, effects_registry, recipe_dir, minimal_primitives) + params = {k: v for k, v in config.items() + if k not in ("effect", "effect_path", "cid", "encoding", "multi_input")} + print(f" Effect: {effect_name}", file=sys.stderr) + + # Get timing offset and duration for bindings + effect_time_offset = config.get("start", config.get("segment_start", 0)) + effect_duration = config.get("duration") + + if is_multi_input and len(inputs) > 1: + # Multi-input effect (blend, layer, etc.) + input_paths = [results[inp] for inp in inputs] + run_multi_effect(effect_module, input_paths, output_file, params, enc, analysis_data, time_offset=effect_time_offset, max_duration=effect_duration) + else: + # Single-input effect + input_path = results[inputs[0]] + run_effect(effect_module, input_path, output_file, params, enc, analysis_data, time_offset=effect_time_offset, max_duration=effect_duration) + else: + input_path = results[inputs[0]] + shutil.copy(input_path, output_file) + + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + print(f" -> {output_file}", file=sys.stderr) + + elif node_type == "SEQUENCE": + # Check cache first + cached = check_cache(cache_dir, cache_id, ['.mp4']) + if cached: + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + if len(inputs) < 2: + results[step_id] = results[inputs[0]] + return + + input_files = [results[inp] for inp in inputs] + enc = get_encoding(recipe_encoding, config) + + # Check for normalization config + resize_mode = config.get("resize_mode") + if resize_mode: + # Determine target dimensions + target_width = config.get("target_width") or enc.get("width") + target_height = config.get("target_height") or enc.get("height") + + # If no explicit target, use first input's dimensions + if not target_width or not target_height: + first_w, first_h = get_video_dimensions(input_files[0]) + target_width = target_width or first_w + target_height = target_height or first_h + + if target_width and target_height: + print(f" Normalizing {len(input_files)} inputs to {target_width}x{target_height} ({resize_mode})", file=sys.stderr) + normalized_files = [] + for i, inp_file in enumerate(input_files): + norm_file = work_dir / f"norm_{step_id[:8]}_{i:04d}.mp4" + normalize_video( + inp_file, norm_file, + target_width, target_height, + resize_mode, + priority=config.get("priority"), + pad_color=config.get("pad_color", "black"), + crop_gravity=config.get("crop_gravity", "center"), + encoding=enc, + ) + normalized_files.append(norm_file) + input_files = normalized_files + + # Use tree concat for efficiency + output_file = tree_concat(input_files, work_dir, f"seq_{step_id[:8]}") + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + print(f" -> {output_file}", file=sys.stderr) + + elif node_type == "MUX": + # Check cache + cached = check_cache(cache_dir, cache_id, ['.mp4']) + if cached: + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + video_path = results[inputs[0]] + audio_path = results[inputs[1]] + enc = get_encoding(recipe_encoding, config) + + output_file = work_dir / f"mux_{step_id}.mp4" + + # Get duration for progress bar + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(video_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + mux_duration = None + if probe_result.returncode == 0: + probe_data = json.loads(probe_result.stdout) + mux_duration = float(probe_data.get("format", {}).get("duration", 0)) + + cmd = ["ffmpeg", "-y", + "-i", str(video_path), "-i", str(audio_path), + "-map", "0:v", "-map", "1:a", + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-c:a", enc["audio_codec"], + "-shortest", str(output_file)] + + import re + mux_proc = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True) + pbar = ProgressBar(int(mux_duration * 1000) if mux_duration else 0, desc="mux") + for line in mux_proc.stderr: + m = re.search(r"time=(\d+):(\d+):(\d+)\.(\d+)", line) + if m: + h, mi, s, cs = int(m.group(1)), int(m.group(2)), int(m.group(3)), int(m.group(4)) + ms = h * 3600000 + mi * 60000 + s * 1000 + cs * 10 + pbar.set(ms) + pbar.finish() + mux_proc.wait() + if mux_proc.returncode != 0: + raise RuntimeError("MUX ffmpeg failed") + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + print(f" -> {output_file}", file=sys.stderr) + + elif node_type == "ANALYZE": + # Check cache first + cached = check_cache(cache_dir, cache_id, ['.json']) + if cached: + with open(cached) as f: + analysis_data[step_id] = json.load(f) + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + output_file = work_dir / f"analysis_{step_id}.json" + + if "analysis_results" in config: + # Analysis was done during planning + with open(output_file, "w") as f: + json.dump(config["analysis_results"], f) + analysis_data[step_id] = config["analysis_results"] + print(f" -> {output_file} (from plan)", file=sys.stderr) + else: + # Run analyzer now + analyzer_path = config.get("analyzer_path") + if analyzer_path: + analyzer_path = (recipe_dir / analyzer_path).resolve() + input_path = results[inputs[0]] + + # Load and run analyzer + import importlib.util + spec = importlib.util.spec_from_file_location("analyzer", analyzer_path) + analyzer_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(analyzer_module) + + # Run analysis + analyzer_params = {k: v for k, v in config.items() + if k not in ("analyzer", "analyzer_path", "cid")} + analysis_result = analyzer_module.analyze(input_path, analyzer_params) + + # Save and store results + with open(output_file, "w") as f: + json.dump(analysis_result, f) + analysis_data[step_id] = analysis_result + print(f" -> {output_file} (ran analyzer: {len(analysis_result.get('times', []))} pts)", file=sys.stderr) + else: + print(f" -> no analyzer path!", file=sys.stderr) + + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + + elif node_type == "SCAN": + # Check cache first + cached = check_cache(cache_dir, cache_id, ['.json']) + if cached: + with open(cached) as f: + scan_result = json.load(f) + analysis_data[step_id] = scan_result + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + import random + + # Load source analysis data + source_id = inputs[0] + source_data = analysis_data.get(source_id, {}) + event_times = source_data.get("times", []) + duration = source_data.get("duration", event_times[-1] if event_times else 0) + + seed = config.get("seed", 0) + init_expr = config.get("init", 0) + step_expr = config.get("step_expr") + emit_expr = config.get("emit_expr") + + # Initialize RNG and accumulator + rng = random.Random(seed) + acc = eval_scan_expr(init_expr, rng, {}) + + # Process each event + event_values = [] # (time, emitted_value) pairs + + for t in event_times: + # Build variable bindings from accumulator + if isinstance(acc, dict): + variables = dict(acc) + variables["acc"] = acc + else: + variables = {"acc": acc} + + # Step: update accumulator + acc = eval_scan_expr(step_expr, rng, variables) + + # Rebind after step + if isinstance(acc, dict): + variables = dict(acc) + variables["acc"] = acc + else: + variables = {"acc": acc} + + # Emit: produce output value + emit_val = eval_scan_expr(emit_expr, rng, variables) + if isinstance(emit_val, (int, float)): + event_values.append((t, float(emit_val))) + else: + event_values.append((t, 0.0)) + + # Generate high-resolution time-series with step-held interpolation + resolution = 100 # points per second + hi_res_times = [] + hi_res_values = [] + + current_val = 0.0 + event_idx = 0 + num_points = int(duration * resolution) + 1 + + for i in range(num_points): + t = i / resolution + + # Advance to the latest event at or before time t + while event_idx < len(event_values) and event_values[event_idx][0] <= t: + current_val = event_values[event_idx][1] + event_idx += 1 + + hi_res_times.append(round(t, 4)) + hi_res_values.append(current_val) + + scan_result = { + "times": hi_res_times, + "values": hi_res_values, + "duration": duration, + } + + analysis_data[step_id] = scan_result + + # Save to cache + output_file = work_dir / f"scan_{step_id}.json" + with open(output_file, "w") as f: + json.dump(scan_result, f) + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + + print(f" SCAN: {len(event_times)} events -> {len(hi_res_times)} points ({duration:.1f}s)", file=sys.stderr) + print(f" -> {output_file}", file=sys.stderr) + + elif node_type == "COMPOUND": + # Check cache first + cached = check_cache(cache_dir, cache_id, ['.mp4']) + if cached: + results[step_id] = cached + print(f" -> {cached} (cached)", file=sys.stderr) + return + + # Collapsed effect chains - compile to single FFmpeg command with sendcmd + filter_chain_raw = config.get("filter_chain", []) + if not filter_chain_raw: + raise ValueError("COMPOUND step has empty filter_chain") + + # Get effects registry for this compound step (use different name + # to avoid shadowing the outer effects_registry in nested function) + step_effects_registry = config.get("effects_registry", {}) + + # Convert filter_chain items from S-expression lists to dicts + # and clean nil Symbols from configs + filter_chain = [] + for item in filter_chain_raw: + if isinstance(item, dict): + # Clean nil Symbols from the config + cleaned_item = clean_nil_symbols(item) + filter_chain.append(cleaned_item) + elif isinstance(item, list) and item: + item_dict = sexp_to_dict(item) + ftype = item_dict.get("type", "UNKNOWN") + if isinstance(ftype, Symbol): + ftype = ftype.name + fconfig_raw = item_dict.get("config", {}) + if isinstance(fconfig_raw, list): + fconfig = sexp_to_dict(fconfig_raw) + elif isinstance(fconfig_raw, dict): + fconfig = fconfig_raw + else: + fconfig = {} + # Clean nil Symbols from config + fconfig = clean_nil_symbols(fconfig) + filter_chain.append({"type": ftype, "config": fconfig}) + else: + filter_chain.append({"type": "UNKNOWN", "config": {}}) + + input_path = results[inputs[0]] + # Debug: verify input exists and has content + if not input_path.exists(): + raise ValueError(f"COMPOUND input does not exist: {input_path}") + if input_path.stat().st_size == 0: + raise ValueError(f"COMPOUND input is empty: {input_path}") + print(f" COMPOUND input: {input_path} ({input_path.stat().st_size} bytes)", file=sys.stderr) + enc = get_encoding(recipe_encoding, config) + output_file = work_dir / f"compound_{step_id}.mp4" + + # Extract segment timing and effects + segment_start = 0 + segment_duration = None + effects = [] + + for filter_item in filter_chain: + filter_type = filter_item.get("type", "") + filter_config = filter_item.get("config", {}) + + if filter_type == "SEGMENT": + segment_start = filter_config.get("start", 0) + segment_duration = filter_config.get("duration") + if not segment_duration and filter_config.get("end"): + segment_duration = filter_config["end"] - segment_start + elif filter_type == "EFFECT": + effects.append(filter_config) + + # Try to compile effects to FFmpeg filters + from artdag.sexp.ffmpeg_compiler import FFmpegCompiler, generate_sendcmd_filter + compiler = FFmpegCompiler() + + # Check if any effect has bindings - these need Python path for per-frame resolution + any_has_bindings = any(_check_has_bindings(e) for e in effects) + + # Check if all effects have FFmpeg mappings + all_have_mappings = all( + compiler.get_mapping(e.get("effect", "")) is not None + for e in effects + ) + + # Use FFmpeg only for static effects (no bindings) + # Effects with bindings use Python path for proper per-frame binding resolution + if all_have_mappings and effects and not any_has_bindings: + # Compile to FFmpeg with sendcmd for dynamic params + ffmpeg_filters, sendcmd_path = generate_sendcmd_filter( + effects, + analysis_data, + segment_start, + segment_duration or 1.0, + ) + + # First extract segment with looping if needed + ffmpeg_input = input_path + if segment_start or segment_duration: + seg_temp = work_dir / f"compound_{step_id}_seg_temp.mp4" + extract_segment_with_loop(input_path, seg_temp, segment_start or 0, segment_duration, enc) + ffmpeg_input = seg_temp + + # Build FFmpeg command (segment already extracted, just apply filters) + cmd = ["ffmpeg", "-y", "-i", str(ffmpeg_input)] + + if ffmpeg_filters: + cmd.extend(["-vf", ffmpeg_filters]) + + cmd.extend(["-r", str(enc.get("fps", 30)), + "-c:v", enc["codec"], "-preset", enc["preset"], + "-crf", str(enc["crf"]), "-pix_fmt", "yuv420p", + "-c:a", enc["audio_codec"], + str(output_file)]) + + effect_names = [e.get("effect", "?") for e in effects] + print(f" COMPOUND (FFmpeg): {', '.join(effect_names)}", file=sys.stderr) + print(f" filters: {ffmpeg_filters[:80]}{'...' if len(ffmpeg_filters) > 80 else ''}", file=sys.stderr) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print(f" FFmpeg error: {result.stderr[:200]}", file=sys.stderr) + raise RuntimeError(f"FFmpeg failed: {result.stderr}") + + # Clean up sendcmd file + if sendcmd_path and sendcmd_path.exists(): + sendcmd_path.unlink() + else: + # Fall back to sequential processing for effects without FFmpeg mappings + current_input = input_path + + # First handle segment (with looping if source is shorter than requested) + for filter_item in filter_chain: + if filter_item.get("type") == "SEGMENT": + filter_config = filter_item.get("config", {}) + start = filter_config.get("start", 0) or 0 + duration = filter_config.get("duration") + + if start or duration: + seg_output = work_dir / f"compound_{step_id}_seg.mp4" + extract_segment_with_loop(current_input, seg_output, start, duration, enc) + current_input = seg_output + break + + # Load all effect modules and params for fused single-pass execution + effect_modules = [] + chain_params_list = [] + for effect_config in effects: + effect_name = effect_config.get("effect", "unknown") + effect_path = effect_config.get("effect_path") + + if not effect_path: + for effects_dir in ["effects", "sexp_effects/effects"]: + for ext in [".py", ".sexp"]: + candidate = recipe_dir / effects_dir / f"{effect_name}{ext}" + if candidate.exists(): + effect_path = str(candidate.relative_to(recipe_dir)) + break + if effect_path: + break + + if not effect_path: + raise ValueError(f"COMPOUND EFFECT '{effect_name}' has no effect_path or FFmpeg mapping") + + full_path = recipe_dir / effect_path + effect_modules.append(load_effect(full_path, step_effects_registry or effects_registry, recipe_dir, minimal_primitives)) + chain_params_list.append({k: v for k, v in effect_config.items() + if k not in ("effect", "effect_path", "cid", "encoding", "type")}) + + effect_names = [e.get("effect", "?") for e in effects] + print(f" COMPOUND (fused): {', '.join(effect_names)}", file=sys.stderr) + + run_effect_chain(effect_modules, current_input, output_file, + chain_params_list, enc, analysis_data, + time_offset=segment_start, + max_duration=segment_duration) + + results[step_id] = save_to_cache(cache_dir, cache_id, output_file) or output_file + print(f" -> {output_file}", file=sys.stderr) + + else: + raise ValueError(f"Unknown node type: {node_type}") + + # Group steps by level for parallel execution. + # Default to 4 workers to avoid overwhelming the system with + # CPU-intensive effects (ascii_art, ripple, etc.) running in parallel. + max_workers = int(os.environ.get("ARTDAG_WORKERS", 4)) + level_groups = [] + for k, g in groupby(ordered_steps, key=lambda s: s.get("level", 0)): + level_groups.append((k, list(g))) + + for level_num, level_steps in level_groups: + if len(level_steps) == 1: + _run_step(level_steps[0]) + else: + types = [s.get("node_type", "?") for s in level_steps] + types = [t.name if hasattr(t, 'name') else str(t) for t in types] + type_counts = {} + for t in types: + type_counts[t] = type_counts.get(t, 0) + 1 + type_summary = ", ".join(f"{v}x {k}" for k, v in type_counts.items()) + print(f"\n >> Level {level_num}: {len(level_steps)} steps in parallel ({type_summary})", file=sys.stderr) + with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(level_steps), max_workers)) as pool: + futures = [pool.submit(_run_step, s) for s in level_steps] + for f in concurrent.futures.as_completed(futures): + f.result() # re-raises exceptions from threads + + # Get final output + final_output = results[plan["output_step_id"]] + print(f"\n--- Output ---", file=sys.stderr) + print(f"Final: {final_output}", file=sys.stderr) + + if output_path: + # Handle stdout specially - remux to streamable format + if str(output_path) in ("/dev/stdout", "-"): + # MP4 isn't streamable, use matroska which is + cmd = [ + "ffmpeg", "-y", "-i", str(final_output), + "-c", "copy", "-f", "matroska", "pipe:1" + ] + subprocess.run(cmd, stdout=sys.stdout.buffer, stderr=subprocess.DEVNULL) + return output_path + else: + shutil.copy(final_output, output_path) + print(f"Copied to: {output_path}", file=sys.stderr) + # Print path to stdout for piping + print(output_path) + return output_path + else: + # Use truncated source CID for output filename + source_cid = plan.get('source_hash', 'output')[:16] + out = recipe_dir / f"{source_cid}-output.mp4" + shutil.copy(final_output, out) + print(f"Copied to: {out}", file=sys.stderr) + # Print path to stdout for piping + print(out) + return out + + finally: + print(f"Debug: temp files in {work_dir}", file=sys.stderr) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Execute a plan") + parser.add_argument("plan", nargs="?", default="-", help="Plan file (- for stdin)") + parser.add_argument("-o", "--output", type=Path, help="Output file") + parser.add_argument("-d", "--dir", type=Path, default=Path("."), help="Recipe directory for resolving paths") + parser.add_argument("-a", "--analysis", type=Path, help="Analysis file (.sexp)") + + args = parser.parse_args() + + plan_path = None if args.plan == "-" else Path(args.plan) + if plan_path and not plan_path.exists(): + print(f"Plan not found: {plan_path}") + sys.exit(1) + + # Load external analysis if provided + external_analysis = None + if args.analysis: + if not args.analysis.exists(): + print(f"Analysis file not found: {args.analysis}") + sys.exit(1) + external_analysis = parse_analysis_sexp(args.analysis.read_text()) + + execute_plan(plan_path, args.output, args.dir, external_analysis=external_analysis) diff --git a/plan.py b/plan.py new file mode 100644 index 0000000..6bdead3 --- /dev/null +++ b/plan.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +""" +Plan generator for S-expression recipes. + +Expands dynamic nodes (SLICE_ON) into primitives using analysis data. +Outputs a plan that can be executed by execute.py. + +Usage: + analyze.py recipe.sexp > analysis.sexp + plan.py recipe.sexp --analysis analysis.sexp --sexp > plan.sexp + execute.py plan.sexp --analysis analysis.sexp +""" + +import sys +import json +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent / "artdag")) + +from artdag.sexp import compile_string, parse +from artdag.sexp.planner import create_plan +from artdag.sexp.parser import Binding, serialize as sexp_serialize, Symbol, Keyword + + +def parse_analysis_sexp(content: str) -> dict: + """Parse analysis S-expression into dict.""" + sexp = parse(content) + if isinstance(sexp, list) and len(sexp) == 1: + sexp = sexp[0] + + if not isinstance(sexp, list) or not sexp: + raise ValueError("Invalid analysis S-expression") + + # Should be (analysis (name ...) (name ...) ...) + if not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis": + raise ValueError("Expected (analysis ...) S-expression") + + result = {} + for item in sexp[1:]: + if isinstance(item, list) and item: + # Handle both Symbol names and quoted string names (node IDs) + first = item[0] + if isinstance(first, Symbol): + name = first.name + elif isinstance(first, str): + name = first + else: + continue # Skip malformed entries + data = {} + + i = 1 + while i < len(item): + if isinstance(item[i], Keyword): + key = item[i].name.replace("-", "_") + i += 1 + if i < len(item): + data[key] = item[i] + i += 1 + else: + i += 1 + + result[name] = data + + return result + + +def to_sexp(value, indent=0): + """Convert a Python value to S-expression string.""" + from artdag.sexp.parser import Lambda + + # Handle Binding objects + if isinstance(value, Binding): + # analysis_ref can be a string, node ID, or dict - serialize it properly + if isinstance(value.analysis_ref, str): + ref_str = f'"{value.analysis_ref}"' + else: + ref_str = to_sexp(value.analysis_ref, 0) + s = f'(bind {ref_str} :range [{value.range_min} {value.range_max}]' + if value.transform: + s += f' :transform {value.transform}' + return s + ')' + + # Handle binding dicts from compiler (convert to bind sexp format) + if isinstance(value, dict) and value.get("_binding"): + source = value.get("source", "") + range_val = value.get("range", [0.0, 1.0]) + range_min = range_val[0] if isinstance(range_val, list) else 0.0 + range_max = range_val[1] if isinstance(range_val, list) and len(range_val) > 1 else 1.0 + transform = value.get("transform") + offset = value.get("offset") + s = f'(bind "{source}" :range [{range_min} {range_max}]' + if offset: + s += f' :offset {offset}' + if transform: + s += f' :transform {transform}' + return s + ')' + + # Handle Symbol - serialize as bare identifier + if isinstance(value, Symbol): + return value.name + + # Handle Keyword - serialize with colon prefix + if isinstance(value, Keyword): + return f':{value.name}' + + # Handle Lambda + if isinstance(value, Lambda): + params = " ".join(value.params) + body = to_sexp(value.body, 0) + return f'(fn [{params}] {body})' + + prefix = " " * indent + if isinstance(value, dict): + if not value: + return "()" + items = [] + for k, v in value.items(): + if isinstance(k, str): + # Keys starting with _ are internal markers - keep underscore to avoid :-foo + if k.startswith('_'): + key_str = k # Keep as-is: _binding -> :_binding + else: + key_str = k.replace('_', '-') + else: + key_str = str(k) + items.append(f":{key_str} {to_sexp(v, 0)}") + return "(" + " ".join(items) + ")" + elif isinstance(value, list): + if not value: + return "()" + items = [to_sexp(v, 0) for v in value] + return "(" + " ".join(items) + ")" + elif isinstance(value, str): + # Escape special characters in strings + escaped = value.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n') + return f'"{escaped}"' + elif isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, (int, float)): + return str(value) + elif value is None: + return "nil" + else: + # For any unknown type, convert to string and quote it + return f'"{str(value)}"' + + +def plan_recipe(recipe_path: Path, output_format: str = "text", output_file: Path = None, analysis_path: Path = None, params: dict = None): + """Compile recipe, expand dynamic nodes using analysis, output plan. + + Args: + recipe_path: Path to recipe file + output_format: Output format (text, json, sexp) + output_file: Optional output file path + analysis_path: Optional pre-computed analysis file + params: Optional dict of name -> value bindings to inject into compilation + """ + + recipe_text = recipe_path.read_text() + recipe_dir = recipe_path.parent + + print(f"Compiling: {recipe_path}", file=sys.stderr) + if params: + print(f"Parameters: {params}", file=sys.stderr) + compiled = compile_string(recipe_text, params) + print(f"Recipe: {compiled.name} v{compiled.version}", file=sys.stderr) + print(f"Nodes: {len(compiled.nodes)}", file=sys.stderr) + + # Load pre-computed analysis if provided (file or stdin with -) + pre_analysis = None + if analysis_path: + if str(analysis_path) == "-": + print(f"Loading analysis: stdin", file=sys.stderr) + analysis_text = sys.stdin.read() + else: + print(f"Loading analysis: {analysis_path}", file=sys.stderr) + analysis_text = analysis_path.read_text() + pre_analysis = parse_analysis_sexp(analysis_text) + print(f" Tracks: {list(pre_analysis.keys())}", file=sys.stderr) + + # Track analysis results for embedding in plan + analysis_data = {} + + def on_analysis(node_id, results): + analysis_data[node_id] = results + times = results.get("times", []) + print(f" Analysis complete: {len(times)} beat times", file=sys.stderr) + + # Create plan (uses pre_analysis or runs analyzers, expands SLICE_ON) + print("\n--- Planning ---", file=sys.stderr) + plan = create_plan( + compiled, + inputs={}, + recipe_dir=recipe_dir, + on_analysis=on_analysis, + pre_analysis=pre_analysis, + ) + + print(f"\nPlan ID: {plan.plan_id[:16]}...", file=sys.stderr) + print(f"Steps: {len(plan.steps)}", file=sys.stderr) + + # Generate output + if output_format == "sexp": + output = generate_sexp_output(compiled, plan, analysis_data) + elif output_format == "json": + output = generate_json_output(compiled, plan, analysis_data) + else: + output = generate_text_output(compiled, plan, analysis_data) + + # Write output + if output_file: + output_file.write_text(output) + print(f"\nPlan written to: {output_file}", file=sys.stderr) + else: + print(output) + + +class PlanJSONEncoder(json.JSONEncoder): + """Custom encoder for plan objects.""" + def default(self, obj): + if isinstance(obj, Binding): + return { + "_type": "binding", + "analysis_ref": obj.analysis_ref, + "track": obj.track, + "range_min": obj.range_min, + "range_max": obj.range_max, + "transform": obj.transform, + } + if isinstance(obj, Symbol): + return {"_type": "symbol", "name": obj.name} + if isinstance(obj, Keyword): + return {"_type": "keyword", "name": obj.name} + return super().default(obj) + + +def generate_json_output(compiled, plan, analysis_data): + """Generate JSON plan output.""" + output = { + "plan_id": plan.plan_id, + "recipe_id": compiled.name, + "recipe_hash": plan.recipe_hash, + "encoding": compiled.encoding, + "output_step_id": plan.output_step_id, + "steps": [], + } + + for step in plan.steps: + step_dict = { + "step_id": step.step_id, + "node_type": step.node_type, + "config": step.config, + "inputs": step.inputs, + "level": step.level, + "cache_id": step.cache_id, + } + # Embed analysis results for ANALYZE steps + if step.node_type == "ANALYZE" and step.step_id in analysis_data: + step_dict["config"]["analysis_results"] = analysis_data[step.step_id] + output["steps"].append(step_dict) + + return json.dumps(output, indent=2, cls=PlanJSONEncoder) + + +def generate_sexp_output(compiled, plan, analysis_data): + """Generate S-expression plan output.""" + lines = [ + f'(plan "{compiled.name}"', + f' :version "{compiled.version}"', + f' :plan-id "{plan.plan_id}"', + ] + + if compiled.encoding: + lines.append(f' :encoding {to_sexp(compiled.encoding)}') + + # Include analysis data for effect parameter bindings + if plan.analysis: + lines.append('') + lines.append(' (analysis') + for name, data in plan.analysis.items(): + times = data.get("times", []) + values = data.get("values", []) + # Truncate for display but include all data + times_str = " ".join(str(t) for t in times) + values_str = " ".join(str(v) for v in values) + lines.append(f' ({name}') + lines.append(f' :times ({times_str})') + lines.append(f' :values ({values_str}))') + lines.append(' )') + + lines.append('') + + for step in plan.steps: + lines.append(f' (step "{step.step_id}"') + lines.append(f' :type {step.node_type}') + lines.append(f' :level {step.level}') + lines.append(f' :cache "{step.cache_id}"') + if step.inputs: + inputs_str = " ".join(f'"{i}"' for i in step.inputs) + lines.append(f' :inputs ({inputs_str})') + for key, value in step.config.items(): + lines.append(f' :{key.replace("_", "-")} {to_sexp(value)}') + lines.append(' )') + + lines.append('') + lines.append(f' :output "{plan.output_step_id}")') + + return '\n'.join(lines) + + +def generate_text_output(compiled, plan, analysis_data): + """Generate human-readable text output.""" + lines = [ + f"Recipe: {compiled.name} v{compiled.version}", + ] + + if compiled.encoding: + lines.append(f"Encoding: {compiled.encoding}") + + lines.extend([ + f"\nPlan ID: {plan.plan_id}", + f"Output: {plan.output_step_id[:16]}...", + f"\nSteps ({len(plan.steps)}):", + "-" * 60, + ]) + + for step in plan.steps: + lines.append(f"\n[{step.level}] {step.node_type}") + lines.append(f" id: {step.step_id[:16]}...") + lines.append(f" cache: {step.cache_id[:16]}...") + if step.inputs: + lines.append(f" inputs: {[i[:16] + '...' for i in step.inputs]}") + for key, value in step.config.items(): + if key == "analysis_results": + lines.append(f" {key}: <{len(value.get('times', []))} times>") + else: + lines.append(f" {key}: {value}") + + return '\n'.join(lines) + + +def parse_param(param_str: str) -> tuple: + """Parse a key=value parameter string. + + Args: + param_str: String in format "key=value" + + Returns: + Tuple of (key, parsed_value) where value is converted to int/float if possible + """ + if "=" not in param_str: + raise ValueError(f"Invalid parameter format: {param_str} (expected key=value)") + + key, value = param_str.split("=", 1) + key = key.strip() + value = value.strip() + + # Try to parse as int + try: + return (key, int(value)) + except ValueError: + pass + + # Try to parse as float + try: + return (key, float(value)) + except ValueError: + pass + + # Return as string + return (key, value) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Generate execution plan from recipe") + parser.add_argument("recipe", type=Path, help="Recipe file (.sexp)") + parser.add_argument("-o", "--output", type=Path, help="Output file (default: stdout)") + parser.add_argument("-a", "--analysis", type=Path, help="Pre-computed analysis file (.sexp)") + parser.add_argument("-p", "--param", action="append", dest="params", metavar="KEY=VALUE", + help="Set recipe parameter (can be used multiple times)") + parser.add_argument("--json", action="store_true", help="Output JSON format") + parser.add_argument("--text", action="store_true", help="Output human-readable text format") + + args = parser.parse_args() + + if not args.recipe.exists(): + print(f"Recipe not found: {args.recipe}", file=sys.stderr) + sys.exit(1) + + if args.analysis and str(args.analysis) != "-" and not args.analysis.exists(): + print(f"Analysis file not found: {args.analysis}", file=sys.stderr) + sys.exit(1) + + # Parse parameters + params = {} + if args.params: + for param_str in args.params: + try: + key, value = parse_param(param_str) + params[key] = value + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + if args.json: + fmt = "json" + elif args.text: + fmt = "text" + else: + fmt = "sexp" + + plan_recipe(args.recipe, fmt, args.output, args.analysis, params or None) diff --git a/run-effect.sh b/run-effect.sh new file mode 100644 index 0000000..1c7114a --- /dev/null +++ b/run-effect.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# Run a single effect by number (0-42) +# Usage: ./run-effect.sh +# +# Note: For effects with simple numeric params, you can also use native params: +# python3 plan.py recipe-parametric.sexp -p strength=5 -p amount=30 | python3 execute.py - -d . -o output.mp4 + +EFFECT_NUM=${1:-0} + +# Effect definitions array +EFFECTS=( + "(effect invert)" + "(effect grayscale)" + "(effect sepia)" + "(effect brightness :amount 30)" + "(effect contrast :amount 1.5)" + "(effect saturation :amount 2.0)" + "(effect hue_shift :degrees 90)" + "(effect color_cycle :speed 2)" + "(effect threshold :level 128)" + "(effect posterize :levels 6)" + "(effect blur :radius 10)" + "(effect sharpen :amount 2)" + "(effect bloom :intensity 0.6 :radius 20)" + "(effect color-adjust :brightness 20 :contrast 1.2)" + "(effect swirl :strength 3)" + "(effect fisheye :strength 0.5)" + "(effect wave :amplitude 30 :wavelength 60)" + "(effect ripple :amplitude 20 :frequency 6)" + "(effect kaleidoscope :segments 6 :rotation_speed 30)" + "(effect zoom :factor 1.2)" + "(effect rotate :angle 15)" + "(effect mirror :direction \"horizontal\")" + "(effect pixelate :block_size 16)" + "(effect ascii_art :char_size 8 :color_mode \"color\")" + "(effect ascii_zones :char_size 10)" + "(effect edge_detect :low 50 :high 150)" + "(effect emboss :strength 1.5)" + "(effect outline :thickness 2)" + "(effect neon_glow :glow_radius 20 :glow_intensity 2)" + "(effect crt :line_spacing 3 :vignette_amount 0.3)" + "(effect scanlines :spacing 3 :intensity 0.4)" + "(effect film_grain :intensity 0.25)" + "(effect vignette :strength 0.6)" + "(effect noise :amount 40)" + "(effect rgb_split :offset_x 20)" + "(effect echo :num_echoes 4 :decay 0.5)" + "(effect trails :persistence 0.7)" + "(effect strobe :frequency 4)" + "(effect flip :direction \"horizontal\")" + "(effect tile_grid :rows 2 :cols 2)" + "(effect pixelsort :threshold_low 30 :threshold_high 220)" + "(effect datamosh :corruption 0.5 :block_size 24)" +) + +if [ "$EFFECT_NUM" -lt 0 ] || [ "$EFFECT_NUM" -ge ${#EFFECTS[@]} ]; then + echo "Effect number must be 0-$((${#EFFECTS[@]}-1))" + exit 1 +fi + +EFFECT="${EFFECTS[$EFFECT_NUM]}" +echo "Running effect $EFFECT_NUM: $EFFECT" + +# Create temp recipe with selected effect +cat > /tmp/recipe-temp.sexp << EOF +(recipe "effect-test" + :version "1.0" + :encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30) + + (effect ascii_art :path "sexp_effects/effects/ascii_art.sexp") + (effect ascii_zones :path "sexp_effects/effects/ascii_zones.sexp") + (effect bloom :path "sexp_effects/effects/bloom.sexp") + (effect blur :path "sexp_effects/effects/blur.sexp") + (effect brightness :path "sexp_effects/effects/brightness.sexp") + (effect color-adjust :path "sexp_effects/effects/color-adjust.sexp") + (effect color_cycle :path "sexp_effects/effects/color_cycle.sexp") + (effect contrast :path "sexp_effects/effects/contrast.sexp") + (effect crt :path "sexp_effects/effects/crt.sexp") + (effect datamosh :path "sexp_effects/effects/datamosh.sexp") + (effect echo :path "sexp_effects/effects/echo.sexp") + (effect edge_detect :path "sexp_effects/effects/edge_detect.sexp") + (effect emboss :path "sexp_effects/effects/emboss.sexp") + (effect film_grain :path "sexp_effects/effects/film_grain.sexp") + (effect fisheye :path "sexp_effects/effects/fisheye.sexp") + (effect flip :path "sexp_effects/effects/flip.sexp") + (effect grayscale :path "sexp_effects/effects/grayscale.sexp") + (effect hue_shift :path "sexp_effects/effects/hue_shift.sexp") + (effect invert :path "sexp_effects/effects/invert.sexp") + (effect kaleidoscope :path "sexp_effects/effects/kaleidoscope.sexp") + (effect mirror :path "sexp_effects/effects/mirror.sexp") + (effect neon_glow :path "sexp_effects/effects/neon_glow.sexp") + (effect noise :path "sexp_effects/effects/noise.sexp") + (effect outline :path "sexp_effects/effects/outline.sexp") + (effect pixelate :path "sexp_effects/effects/pixelate.sexp") + (effect pixelsort :path "sexp_effects/effects/pixelsort.sexp") + (effect posterize :path "sexp_effects/effects/posterize.sexp") + (effect rgb_split :path "sexp_effects/effects/rgb_split.sexp") + (effect ripple :path "sexp_effects/effects/ripple.sexp") + (effect rotate :path "sexp_effects/effects/rotate.sexp") + (effect saturation :path "sexp_effects/effects/saturation.sexp") + (effect scanlines :path "sexp_effects/effects/scanlines.sexp") + (effect sepia :path "sexp_effects/effects/sepia.sexp") + (effect sharpen :path "sexp_effects/effects/sharpen.sexp") + (effect strobe :path "sexp_effects/effects/strobe.sexp") + (effect swirl :path "sexp_effects/effects/swirl.sexp") + (effect threshold :path "sexp_effects/effects/threshold.sexp") + (effect tile_grid :path "sexp_effects/effects/tile_grid.sexp") + (effect trails :path "sexp_effects/effects/trails.sexp") + (effect vignette :path "sexp_effects/effects/vignette.sexp") + (effect wave :path "sexp_effects/effects/wave.sexp") + (effect zoom :path "sexp_effects/effects/zoom.sexp") + + (def video (source :path "monday.webm")) + (def audio (source :path "dizzy.mp3")) + (def clip (-> video (segment :start 0 :duration 10))) + (def audio-clip (-> audio (segment :start 0 :duration 10))) + (def result (-> clip $EFFECT)) + (mux result audio-clip)) +EOF + +python3 plan.py /tmp/recipe-temp.sexp | python3 execute.py - -d . -o "effect-${EFFECT_NUM}.mp4" +echo "Output: effect-${EFFECT_NUM}.mp4" diff --git a/run-file.sh b/run-file.sh new file mode 100755 index 0000000..adacb4b --- /dev/null +++ b/run-file.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Run the full pipeline: analyze -> plan -> execute -> play +# Usage: ./run.sh recipe.sexp + +RECIPE="${1:-recipe-bound.sexp}" + +python analyze.py "$RECIPE" | python plan.py "$RECIPE" -a - | python execute.py - -d "$(dirname "$RECIPE")" -o output.mp4 diff --git a/run.py b/run.py new file mode 100755 index 0000000..23703c7 --- /dev/null +++ b/run.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Run a recipe: plan then execute. + +This is a convenience wrapper that: +1. Generates a plan (runs analyzers, expands SLICE_ON) +2. Executes the plan (produces video output) +""" + +import json +import sys +import tempfile +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent / "artdag")) + +from artdag.sexp import compile_string +from artdag.sexp.planner import create_plan +from artdag.sexp.parser import Binding + +# Import execute functionality +from execute import execute_plan + + +class PlanEncoder(json.JSONEncoder): + """JSON encoder that handles Binding objects.""" + def default(self, obj): + if isinstance(obj, Binding): + return { + "_bind": obj.analysis_ref, + "range_min": obj.range_min, + "range_max": obj.range_max, + } + return super().default(obj) + + +def run_recipe(recipe_path: Path, output_path: Path = None): + """Run a recipe file: plan then execute.""" + + recipe_text = recipe_path.read_text() + recipe_dir = recipe_path.parent + + print(f"=== COMPILE ===") + print(f"Recipe: {recipe_path}") + compiled = compile_string(recipe_text) + print(f"Name: {compiled.name} v{compiled.version}") + print(f"Nodes: {len(compiled.nodes)}") + + # Track analysis results + analysis_data = {} + + def on_analysis(node_id, results): + analysis_data[node_id] = results + times = results.get("times", []) + print(f" Analysis: {len(times)} beat times @ {results.get('tempo', 0):.1f} BPM") + + # Generate plan + print(f"\n=== PLAN ===") + plan = create_plan( + compiled, + inputs={}, + recipe_dir=recipe_dir, + on_analysis=on_analysis, + ) + + print(f"Plan ID: {plan.plan_id[:16]}...") + print(f"Steps: {len(plan.steps)}") + + # Write plan to temp file for execute + plan_dict = { + "plan_id": plan.plan_id, + "recipe_id": compiled.name, + "recipe_hash": plan.recipe_hash, + "encoding": compiled.encoding, + "output_step_id": plan.output_step_id, + "steps": [], + } + + for step in plan.steps: + step_dict = { + "step_id": step.step_id, + "node_type": step.node_type, + "config": step.config, + "inputs": step.inputs, + "level": step.level, + "cache_id": step.cache_id, + } + if step.node_type == "ANALYZE" and step.step_id in analysis_data: + step_dict["config"]["analysis_results"] = analysis_data[step.step_id] + plan_dict["steps"].append(step_dict) + + # Save plan + work_dir = Path(tempfile.mkdtemp(prefix="artdag_run_")) + plan_file = work_dir / "plan.json" + with open(plan_file, "w") as f: + json.dump(plan_dict, f, indent=2, cls=PlanEncoder) + + print(f"Plan saved: {plan_file}") + + # Execute plan + print(f"\n=== EXECUTE ===") + result = execute_plan(plan_file, output_path, recipe_dir) + + print(f"\n=== DONE ===") + print(f"Output: {result}") + return result + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: run.py [output.mp4]") + print() + print("Commands:") + print(" run.py - Plan and execute recipe") + print(" plan.py - Generate plan only") + print(" execute.py - Execute pre-generated plan") + sys.exit(1) + + recipe_path = Path(sys.argv[1]) + output_path = Path(sys.argv[2]) if len(sys.argv) > 2 else None + + if not recipe_path.exists(): + print(f"Recipe not found: {recipe_path}") + sys.exit(1) + + run_recipe(recipe_path, output_path) diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..b65e5f4 --- /dev/null +++ b/run.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Run the full pipeline: analyze -> plan -> execute -> play +# Usage: ./run.sh recipe.sexp + +RECIPE="${1:-recipe-bound.sexp}" + +python3 analyze.py "$RECIPE" | python plan.py "$RECIPE" -a - | python execute.py - -d "$(dirname "$RECIPE")" | xargs mpv --fs diff --git a/run_staged.py b/run_staged.py new file mode 100644 index 0000000..597aacb --- /dev/null +++ b/run_staged.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python3 +""" +Run a staged recipe through analyze -> plan -> execute pipeline. + +This script demonstrates stage-level caching: analysis stages can be +skipped on re-run if the inputs haven't changed. + +Usage: + python3 run_staged.py recipe.sexp [-o output.mp4] + python3 run_staged.py effects/ascii_art_staged.sexp -o ascii_out.mp4 + +The script: +1. Compiles the recipe and extracts stage information +2. For each stage in topological order: + - Check stage cache (skip if hit) + - Run stage (analyze, plan, execute) + - Cache stage outputs +3. Produce final output +""" + +import os +import sys +import json +import tempfile +import shutil +import subprocess +from pathlib import Path +from typing import Dict, List, Optional, Any + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent / "artdag")) + +from artdag.sexp import compile_string, parse +from artdag.sexp.parser import Symbol, Keyword, serialize +from artdag.sexp.planner import create_plan + +# Import unified cache +import cache as unified_cache + +import hashlib + + +def _cache_analysis_tracks(plan): + """Cache each analysis track individually, replace data with cache-id refs.""" + import json as _json + for name, data in plan.analysis.items(): + json_str = _json.dumps(data, sort_keys=True) + content_cid = hashlib.sha256(json_str.encode()).hexdigest() + unified_cache.cache_store_json(content_cid, data) + plan.analysis[name] = {"_cache_id": content_cid} + + +def _resolve_analysis_refs(analysis_dict): + """Resolve cache-id refs back to full analysis data.""" + resolved = {} + for name, data in analysis_dict.items(): + if isinstance(data, dict) and "_cache_id" in data: + loaded = unified_cache.cache_get_json(data["_cache_id"]) + if loaded: + resolved[name] = loaded + else: + resolved[name] = data + return resolved + + +def run_staged_recipe( + recipe_path: Path, + output_path: Optional[Path] = None, + cache_dir: Optional[Path] = None, + params: Optional[Dict[str, Any]] = None, + verbose: bool = True, + force_replan: bool = False, +) -> Path: + """ + Run a staged recipe with stage-level caching. + + Args: + recipe_path: Path to the .sexp recipe file + output_path: Optional output file path + cache_dir: Optional cache directory for stage results + params: Optional parameter overrides + verbose: Print progress information + + Returns: + Path to the final output file + """ + recipe_text = recipe_path.read_text() + recipe_dir = recipe_path.parent + + # Use unified cache + content_cache_dir = unified_cache.get_content_dir() + + def log(msg: str): + if verbose: + print(msg, file=sys.stderr) + + # Store recipe source by CID + recipe_cid, _ = unified_cache.content_store_string(recipe_text) + log(f"Recipe CID: {recipe_cid[:16]}...") + + # Compile recipe + log(f"Compiling: {recipe_path}") + compiled = compile_string(recipe_text, params, recipe_dir=recipe_dir) + log(f"Recipe: {compiled.name} v{compiled.version}") + log(f"Nodes: {len(compiled.nodes)}") + + # Store effects by CID + for effect_name, effect_info in compiled.registry.get("effects", {}).items(): + effect_path = effect_info.get("path") + effect_cid = effect_info.get("cid") + if effect_path and effect_cid: + effect_file = Path(effect_path) + if effect_file.exists(): + stored_cid, _ = unified_cache.content_store_file(effect_file) + if stored_cid == effect_cid: + log(f"Effect '{effect_name}' CID: {effect_cid[:16]}...") + else: + log(f"Warning: Effect '{effect_name}' CID mismatch") + + # Store analyzers by CID + for analyzer_name, analyzer_info in compiled.registry.get("analyzers", {}).items(): + analyzer_path = analyzer_info.get("path") + analyzer_cid = analyzer_info.get("cid") + if analyzer_path: + analyzer_file = Path(analyzer_path) if Path(analyzer_path).is_absolute() else recipe_dir / analyzer_path + if analyzer_file.exists(): + stored_cid, _ = unified_cache.content_store_file(analyzer_file) + log(f"Analyzer '{analyzer_name}' CID: {stored_cid[:16]}...") + + # Store included files by CID + for include_path, include_cid in compiled.registry.get("includes", {}).items(): + include_file = Path(include_path) + if include_file.exists(): + stored_cid, _ = unified_cache.content_store_file(include_file) + if stored_cid == include_cid: + log(f"Include '{include_file.name}' CID: {include_cid[:16]}...") + else: + log(f"Warning: Include '{include_file.name}' CID mismatch") + + # Check for stages + if not compiled.stages: + log("No stages found - running as regular recipe") + return _run_non_staged(compiled, recipe_dir, output_path, verbose) + + log(f"\nStages: {len(compiled.stages)}") + log(f"Stage order: {compiled.stage_order}") + + # Display stage info + for stage in compiled.stages: + log(f"\n Stage: {stage.name}") + log(f" Requires: {stage.requires or '(none)'}") + log(f" Inputs: {stage.inputs or '(none)'}") + log(f" Outputs: {stage.outputs}") + + # Create plan with analysis + log("\n--- Planning ---") + analysis_data = {} + + def on_analysis(node_id: str, results: dict): + analysis_data[node_id] = results + times = results.get("times", []) + log(f" Analysis complete: {node_id[:16]}... ({len(times)} times)") + + # Check for cached plan using unified cache + plan_cid = unified_cache.plan_exists(recipe_cid, params) + + if plan_cid and not force_replan: + plan_cache_path = unified_cache.plan_get_path(recipe_cid, params) + log(f"\nFound cached plan: {plan_cid[:16]}...") + plan_sexp_str = unified_cache.plan_load(recipe_cid, params) + + # Parse the cached plan + from execute import parse_plan_input + plan_dict = parse_plan_input(plan_sexp_str) + + # Resolve cache-id refs in plan's embedded analysis + if "analysis" in plan_dict: + plan_dict["analysis"] = _resolve_analysis_refs(plan_dict["analysis"]) + + # Load analysis data from unified cache + analysis_data = {} + for step in plan_dict.get("steps", []): + if step.get("node_type") == "ANALYZE": + step_id = step.get("step_id") + cached_analysis = unified_cache.cache_get_json(step_id) + if cached_analysis: + analysis_data[step_id] = cached_analysis + log(f" Loaded analysis: {step_id[:16]}...") + + log(f"Plan ID: {plan_dict.get('plan_id', 'unknown')[:16]}...") + log(f"Steps: {len(plan_dict.get('steps', []))}") + log(f"Analysis tracks: {list(analysis_data.keys())}") + + # Execute directly from cached plan + log("\n--- Execution (from cached plan) ---") + from execute import execute_plan + + result_path = execute_plan( + plan_path=plan_cache_path, + output_path=output_path, + recipe_dir=recipe_dir, + external_analysis=analysis_data, + cache_dir=content_cache_dir, + ) + + log(f"\n--- Complete ---") + log(f"Output: {result_path}") + return result_path + + # No cached plan - create new one + plan = create_plan( + compiled, + inputs={}, + recipe_dir=recipe_dir, + on_analysis=on_analysis, + ) + + log(f"\nPlan ID: {plan.plan_id[:16]}...") + log(f"Steps: {len(plan.steps)}") + log(f"Analysis tracks: {list(analysis_data.keys())}") + + # Cache analysis tracks individually and replace with cache-id refs + _cache_analysis_tracks(plan) + + # Save plan to unified cache + plan_sexp_str = plan.to_string(pretty=True) + plan_cache_id, plan_cid, plan_cache_path = unified_cache.plan_store(recipe_cid, params, plan_sexp_str) + log(f"Saved plan: {plan_cache_id[:16]}... → {plan_cid[:16]}...") + + # Execute the plan using execute.py logic + log("\n--- Execution ---") + from execute import execute_plan + + # Resolve cache-id refs back to full data for execution + resolved_analysis = _resolve_analysis_refs(plan.analysis) + + plan_dict = { + "plan_id": plan.plan_id, + "source_hash": plan.source_hash, + "encoding": compiled.encoding, + "output_step_id": plan.output_step_id, + "analysis": {**resolved_analysis, **analysis_data}, + "effects_registry": plan.effects_registry, + "minimal_primitives": plan.minimal_primitives, + "steps": [], + } + + for step in plan.steps: + step_dict = { + "step_id": step.step_id, + "node_type": step.node_type, + "config": step.config, + "inputs": step.inputs, + "level": step.level, + "cache_id": step.cache_id, + } + # Tag with stage info if present + if step.stage: + step_dict["stage"] = step.stage + plan_dict["steps"].append(step_dict) + + # Execute using unified cache + result_path = execute_plan( + plan_path=None, + output_path=output_path, + recipe_dir=recipe_dir, + plan_data=plan_dict, + external_analysis=analysis_data, + cache_dir=content_cache_dir, + ) + + log(f"\n--- Complete ---") + log(f"Output: {result_path}") + + return result_path + + +def _run_non_staged(compiled, recipe_dir: Path, output_path: Optional[Path], verbose: bool) -> Path: + """Run a non-staged recipe using the standard pipeline.""" + from execute import execute_plan + from plan import plan_recipe + + # This is a fallback for recipes without stages + # Just run through regular plan -> execute + raise NotImplementedError("Non-staged recipes should use plan.py | execute.py") + + +def list_cache(verbose: bool = False): + """List all cached items using the unified cache.""" + unified_cache.print_cache_listing(verbose) + + +def list_params(recipe_path: Path): + """List available parameters for a recipe and its effects.""" + from artdag.sexp import parse + from artdag.sexp.parser import Symbol, Keyword + from artdag.sexp.compiler import _parse_params + from artdag.sexp.effect_loader import load_sexp_effect_file + + recipe_text = recipe_path.read_text() + sexp = parse(recipe_text) + + if isinstance(sexp, list) and len(sexp) == 1: + sexp = sexp[0] + + # Find recipe name + recipe_name = sexp[1] if len(sexp) > 1 and isinstance(sexp[1], str) else recipe_path.stem + + # Find :params block and effect declarations + recipe_params = [] + effect_declarations = {} # name -> path + + i = 2 + while i < len(sexp): + item = sexp[i] + if isinstance(item, Keyword) and item.name == "params": + if i + 1 < len(sexp): + recipe_params = _parse_params(sexp[i + 1]) + i += 2 + elif isinstance(item, list) and item: + # Check for effect declaration: (effect name :path "...") + if isinstance(item[0], Symbol) and item[0].name == "effect": + if len(item) >= 2: + effect_name = item[1].name if isinstance(item[1], Symbol) else item[1] + # Find :path + j = 2 + while j < len(item): + if isinstance(item[j], Keyword) and item[j].name == "path": + if j + 1 < len(item): + effect_declarations[effect_name] = item[j + 1] + break + j += 1 + i += 1 + else: + i += 1 + + # Load effect params + effect_params = {} # effect_name -> list of ParamDef + recipe_dir = recipe_path.parent + + for effect_name, effect_rel_path in effect_declarations.items(): + effect_path = recipe_dir / effect_rel_path + if effect_path.exists() and effect_path.suffix == ".sexp": + try: + _, _, _, param_defs = load_sexp_effect_file(effect_path) + if param_defs: + effect_params[effect_name] = param_defs + except Exception as e: + print(f"Warning: Could not load params from effect {effect_name}: {e}", file=sys.stderr) + + # Print results + def print_params(params, header_prefix=""): + print(f"{header_prefix}{'Name':<20} {'Type':<8} {'Default':<12} {'Range/Choices':<20} Description") + print(f"{header_prefix}{'-' * 88}") + for p in params: + range_str = "" + if p.range_min is not None and p.range_max is not None: + range_str = f"[{p.range_min}, {p.range_max}]" + elif p.choices: + range_str = ", ".join(p.choices[:3]) + if len(p.choices) > 3: + range_str += "..." + + default_str = str(p.default) if p.default is not None else "-" + if len(default_str) > 10: + default_str = default_str[:9] + "…" + + print(f"{header_prefix}{p.name:<20} {p.param_type:<8} {default_str:<12} {range_str:<20} {p.description}") + + if recipe_params: + print(f"\nRecipe parameters for '{recipe_name}':\n") + print_params(recipe_params) + else: + print(f"\nRecipe '{recipe_name}' has no declared parameters.") + + if effect_params: + for effect_name, params in effect_params.items(): + print(f"\n\nEffect '{effect_name}' parameters:\n") + print_params(params) + + if not recipe_params and not effect_params: + print("\nParameters can be declared using :params block:") + print(""" + :params ( + (color_mode :type string :default "color" :desc "Character color") + (char_size :type int :default 12 :range [4 32] :desc "Cell size") + ) +""") + return + + print("\n\nUsage:") + print(f" python3 run_staged.py {recipe_path} -p = [-p = ...]") + print(f"\nExample:") + all_params = recipe_params + [p for params in effect_params.values() for p in params] + if all_params: + p = all_params[0] + example_val = p.default if p.default else ("value" if p.param_type == "string" else "1") + print(f" python3 run_staged.py {recipe_path} -p {p.name}={example_val}") + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Run a staged recipe with stage-level caching", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 run_staged.py effects/ascii_art_fx_staged.sexp --list-params + python3 run_staged.py effects/ascii_art_fx_staged.sexp -o output.mp4 + python3 run_staged.py recipe.sexp -p color_mode=lime -p char_jitter=5 + """ + ) + parser.add_argument("recipe", type=Path, nargs="?", help="Recipe file (.sexp)") + parser.add_argument("-o", "--output", type=Path, help="Output file path") + parser.add_argument("-p", "--param", action="append", dest="params", + metavar="KEY=VALUE", help="Set recipe parameter") + parser.add_argument("-q", "--quiet", action="store_true", help="Suppress progress output") + parser.add_argument("--list-params", action="store_true", help="List available parameters and exit") + parser.add_argument("--list-cache", action="store_true", help="List cached items and exit") + parser.add_argument("--no-cache", action="store_true", help="Ignore cached plan, force re-planning") + parser.add_argument("--show-plan", action="store_true", help="Show the plan S-expression and exit (don't execute)") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument("-j", "--jobs", type=int, default=None, + help="Max parallel workers (default: 4, or ARTDAG_WORKERS env)") + parser.add_argument("--pipelines", type=int, default=None, + help="Max concurrent video pipelines (default: 1, or ARTDAG_VIDEO_PIPELINES env)") + + args = parser.parse_args() + + # Apply concurrency limits before any execution + if args.jobs is not None: + os.environ["ARTDAG_WORKERS"] = str(args.jobs) + if args.pipelines is not None: + os.environ["ARTDAG_VIDEO_PIPELINES"] = str(args.pipelines) + from execute import set_max_video_pipelines + set_max_video_pipelines(args.pipelines) + + # List cache mode - doesn't require recipe + if args.list_cache: + list_cache(verbose=args.verbose) + sys.exit(0) + + # All other modes require a recipe + if not args.recipe: + print("Error: recipe file required", file=sys.stderr) + sys.exit(1) + + if not args.recipe.exists(): + print(f"Recipe not found: {args.recipe}", file=sys.stderr) + sys.exit(1) + + # List params mode + if args.list_params: + list_params(args.recipe) + sys.exit(0) + + # Parse parameters + params = {} + if args.params: + for param_str in args.params: + if "=" not in param_str: + print(f"Invalid parameter format: {param_str}", file=sys.stderr) + sys.exit(1) + key, value = param_str.split("=", 1) + # Try to parse as number + try: + value = int(value) + except ValueError: + try: + value = float(value) + except ValueError: + pass # Keep as string + params[key] = value + + # Show plan mode - generate plan and display without executing + if args.show_plan: + recipe_text = args.recipe.read_text() + recipe_dir = args.recipe.parent + + # Compute recipe CID (content hash) + recipe_cid, _ = unified_cache.content_store_string(recipe_text) + + compiled = compile_string(recipe_text, params if params else None, recipe_dir=recipe_dir) + + # Check for cached plan using unified cache (keyed by source CID + params) + plan_cid = unified_cache.plan_exists(recipe_cid, params if params else None) + + if plan_cid and not args.no_cache: + print(f";; Cached plan CID: {plan_cid}", file=sys.stderr) + plan_sexp_str = unified_cache.plan_load(recipe_cid, params if params else None) + print(plan_sexp_str) + else: + print(f";; Generating new plan...", file=sys.stderr) + analysis_data = {} + def on_analysis(node_id: str, results: dict): + analysis_data[node_id] = results + + plan = create_plan( + compiled, + inputs={}, + recipe_dir=recipe_dir, + on_analysis=on_analysis, + ) + # Cache analysis tracks individually before serialization + _cache_analysis_tracks(plan) + plan_sexp_str = plan.to_string(pretty=True) + + # Save to unified cache + cache_id, plan_cid, plan_path = unified_cache.plan_store(recipe_cid, params if params else None, plan_sexp_str) + print(f";; Saved: {cache_id[:16]}... → {plan_cid}", file=sys.stderr) + print(plan_sexp_str) + sys.exit(0) + + result = run_staged_recipe( + recipe_path=args.recipe, + output_path=args.output, + params=params if params else None, + verbose=not args.quiet, + force_replan=args.no_cache, + ) + + # Print final output path + print(result) + + +if __name__ == "__main__": + main() diff --git a/sexp_effects/__init__.py b/sexp_effects/__init__.py new file mode 100644 index 0000000..b001c71 --- /dev/null +++ b/sexp_effects/__init__.py @@ -0,0 +1,32 @@ +""" +S-Expression Effects System + +Safe, shareable effects defined in S-expressions. +""" + +from .parser import parse, parse_file, Symbol, Keyword +from .interpreter import ( + Interpreter, + get_interpreter, + load_effect, + load_effects_dir, + run_effect, + list_effects, + make_process_frame, +) +from .primitives import PRIMITIVES + +__all__ = [ + 'parse', + 'parse_file', + 'Symbol', + 'Keyword', + 'Interpreter', + 'get_interpreter', + 'load_effect', + 'load_effects_dir', + 'run_effect', + 'list_effects', + 'make_process_frame', + 'PRIMITIVES', +] diff --git a/sexp_effects/effects/ascii_art.sexp b/sexp_effects/effects/ascii_art.sexp new file mode 100644 index 0000000..5565872 --- /dev/null +++ b/sexp_effects/effects/ascii_art.sexp @@ -0,0 +1,17 @@ +;; ASCII Art effect - converts image to ASCII characters +(require-primitives "ascii") + +(define-effect ascii_art + :params ( + (char_size :type int :default 8 :range [4 32]) + (alphabet :type string :default "standard") + (color_mode :type string :default "color" :desc ""color", "mono", "invert", or any color name/hex") + (background_color :type string :default "black" :desc "background color name/hex") + (invert_colors :type int :default 0 :desc "swap foreground and background colors") + (contrast :type float :default 1.5 :range [1 3]) + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid frame chars colors char_size color_mode background_color invert_colors))) diff --git a/sexp_effects/effects/ascii_art_fx.sexp b/sexp_effects/effects/ascii_art_fx.sexp new file mode 100644 index 0000000..2bb14be --- /dev/null +++ b/sexp_effects/effects/ascii_art_fx.sexp @@ -0,0 +1,52 @@ +;; ASCII Art FX - converts image to ASCII characters with per-character effects +(require-primitives "ascii") + +(define-effect ascii_art_fx + :params ( + ;; Basic parameters + (char_size :type int :default 8 :range [4 32] + :desc "Size of each character cell in pixels") + (alphabet :type string :default "standard" + :desc "Character set to use") + (color_mode :type string :default "color" + :choices [color mono invert] + :desc "Color mode: color, mono, invert, or any color name/hex") + (background_color :type string :default "black" + :desc "Background color name or hex value") + (invert_colors :type int :default 0 :range [0 1] + :desc "Swap foreground and background colors (0/1)") + (contrast :type float :default 1.5 :range [1 3] + :desc "Character selection contrast") + + ;; Per-character effects + (char_jitter :type float :default 0 :range [0 20] + :desc "Position jitter amount in pixels") + (char_scale :type float :default 1.0 :range [0.5 2.0] + :desc "Character scale factor") + (char_rotation :type float :default 0 :range [0 180] + :desc "Rotation amount in degrees") + (char_hue_shift :type float :default 0 :range [0 360] + :desc "Hue shift in degrees") + + ;; Modulation sources + (jitter_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives jitter modulation") + (scale_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives scale modulation") + (rotation_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives rotation modulation") + (hue_source :type string :default "none" + :choices [none luminance inv_luminance saturation position_x position_y position_diag random center_dist] + :desc "What drives hue shift modulation") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + (chars (luminance-to-chars luminances alphabet contrast))) + (render-char-grid-fx frame chars colors luminances char_size + color_mode background_color invert_colors + char_jitter char_scale char_rotation char_hue_shift + jitter_source scale_source rotation_source hue_source))) diff --git a/sexp_effects/effects/ascii_fx_zone.sexp b/sexp_effects/effects/ascii_fx_zone.sexp new file mode 100644 index 0000000..69e5340 --- /dev/null +++ b/sexp_effects/effects/ascii_fx_zone.sexp @@ -0,0 +1,102 @@ +;; Composable ASCII Art with Per-Zone Expression-Driven Effects +;; Requires ascii primitive library for the ascii-fx-zone primitive + +(require-primitives "ascii") + +;; Two modes of operation: +;; +;; 1. EXPRESSION MODE: Use zone-* variables in expression parameters +;; Zone variables available: +;; zone-row, zone-col: Grid position (integers) +;; zone-row-norm, zone-col-norm: Normalized position (0-1) +;; zone-lum: Cell luminance (0-1) +;; zone-sat: Cell saturation (0-1) +;; zone-hue: Cell hue (0-360) +;; zone-r, zone-g, zone-b: RGB components (0-1) +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 80 +;; :char_hue (* zone-lum 180) +;; :char_rotation (* zone-col-norm 30)) +;; +;; 2. CELL EFFECT MODE: Pass a lambda to apply arbitrary effects per-cell +;; The lambda receives (cell-image zone-dict) and returns modified cell. +;; Zone dict contains: row, col, row-norm, col-norm, lum, sat, hue, r, g, b, +;; char, color, cell_size, plus any bound analysis values. +;; +;; Any loaded sexp effect can be called on cells - each cell is just a small frame: +;; (blur cell radius) - Gaussian blur +;; (rotate cell angle) - Rotate by angle degrees +;; (brightness cell factor) - Adjust brightness +;; (contrast cell factor) - Adjust contrast +;; (saturation cell factor) - Adjust saturation +;; (hue_shift cell degrees) - Shift hue +;; (rgb_split cell offset_x offset_y) - RGB channel split +;; (invert cell) - Invert colors +;; (pixelate cell block_size) - Pixelate +;; (wave cell amplitude freq) - Wave distortion +;; ... and any other loaded effect +;; +;; Example: +;; (ascii-fx-zone frame +;; :cols 60 +;; :cell_effect (lambda [cell zone] +;; (blur (rotate cell (* (get zone "energy") 45)) +;; (if (> (get zone "lum") 0.5) 3 0)))) + +(define-effect ascii_fx_zone + :params ( + (cols :type int :default 80 :range [20 200] + :desc "Number of character columns") + (char_size :type int :default nil :range [4 32] + :desc "Character cell size in pixels (overrides cols if set)") + (alphabet :type string :default "standard" + :desc "Character set: standard, blocks, simple, digits, or custom string") + (color_mode :type string :default "color" + :desc "Color mode: color, mono, invert, or any color name/hex") + (background :type string :default "black" + :desc "Background color name or hex value") + (contrast :type float :default 1.5 :range [0.5 3.0] + :desc "Contrast for character selection") + (char_hue :type any :default nil + :desc "Hue shift expression (evaluated per-zone with zone-* vars)") + (char_saturation :type any :default nil + :desc "Saturation multiplier expression (1.0 = unchanged)") + (char_brightness :type any :default nil + :desc "Brightness multiplier expression (1.0 = unchanged)") + (char_scale :type any :default nil + :desc "Character scale expression (1.0 = normal size)") + (char_rotation :type any :default nil + :desc "Character rotation expression (degrees)") + (char_jitter :type any :default nil + :desc "Position jitter expression (pixels)") + (cell_effect :type any :default nil + :desc "Lambda (cell zone) -> cell for arbitrary per-cell effects") + ;; Convenience params for staged recipes (avoids compile-time expression issues) + (energy :type float :default nil + :desc "Energy multiplier (0-1) from audio analysis bind") + (rotation_scale :type float :default 0 + :desc "Max rotation at top-right when energy=1 (degrees)") + ) + ;; The ascii-fx-zone special form handles expression params + ;; If energy + rotation_scale provided, it builds: energy * scale * position_factor + ;; where position_factor = 0 at bottom-left, 3 at top-right + ;; If cell_effect provided, each character is rendered to a cell image, + ;; passed to the lambda, and the result composited back + (ascii-fx-zone frame + :cols cols + :char_size char_size + :alphabet alphabet + :color_mode color_mode + :background background + :contrast contrast + :char_hue char_hue + :char_saturation char_saturation + :char_brightness char_brightness + :char_scale char_scale + :char_rotation char_rotation + :char_jitter char_jitter + :cell_effect cell_effect + :energy energy + :rotation_scale rotation_scale)) diff --git a/sexp_effects/effects/ascii_zones.sexp b/sexp_effects/effects/ascii_zones.sexp new file mode 100644 index 0000000..6bc441c --- /dev/null +++ b/sexp_effects/effects/ascii_zones.sexp @@ -0,0 +1,30 @@ +;; ASCII Zones effect - different character sets for different brightness zones +;; Dark areas use simple chars, mid uses standard, bright uses blocks +(require-primitives "ascii") + +(define-effect ascii_zones + :params ( + (char_size :type int :default 8 :range [4 32]) + (dark_threshold :type int :default 80 :range [0 128]) + (bright_threshold :type int :default 180 :range [128 255]) + (color_mode :type string :default "color") + ) + (let* ((sample (cell-sample frame char_size)) + (colors (nth sample 0)) + (luminances (nth sample 1)) + ;; Start with simple chars as base + (base-chars (luminance-to-chars luminances "simple" 1.2)) + ;; Map each cell to appropriate alphabet based on brightness zone + (zoned-chars (map-char-grid base-chars luminances + (lambda (r c ch lum) + (cond + ;; Bright zones: use block characters + ((> lum bright_threshold) + (alphabet-char "blocks" (floor (/ (- lum bright_threshold) 15)))) + ;; Dark zones: use simple sparse chars + ((< lum dark_threshold) + (alphabet-char " .-" (floor (/ lum 30)))) + ;; Mid zones: use standard ASCII + (else + (alphabet-char "standard" (floor (/ lum 4))))))))) + (render-char-grid frame zoned-chars colors char_size color_mode (list 0 0 0)))) diff --git a/sexp_effects/effects/blend.sexp b/sexp_effects/effects/blend.sexp new file mode 100644 index 0000000..bf7fefd --- /dev/null +++ b/sexp_effects/effects/blend.sexp @@ -0,0 +1,31 @@ +;; Blend effect - combines two video frames +;; Streaming-compatible: frame is background, overlay is second frame +;; Usage: (blend background overlay :opacity 0.5 :mode "alpha") +;; +;; Params: +;; mode - blend mode (add, multiply, screen, overlay, difference, lighten, darken, alpha) +;; opacity - blend amount (0-1) + +(require-primitives "image" "blending" "core") + +(define-effect blend + :params ( + (overlay :type frame :default nil) + (mode :type string :default "alpha") + (opacity :type float :default 0.5) + ) + (if (core:is-nil overlay) + frame + (let [a frame + b overlay + a-h (image:height a) + a-w (image:width a) + b-h (image:height b) + b-w (image:width b) + ;; Resize b to match a if needed + b-sized (if (and (= a-w b-w) (= a-h b-h)) + b + (image:resize b a-w a-h "linear"))] + (if (= mode "alpha") + (blending:blend-images a b-sized opacity) + (blending:blend-images a (blending:blend-mode a b-sized mode) opacity))))) diff --git a/sexp_effects/effects/blend_multi.sexp b/sexp_effects/effects/blend_multi.sexp new file mode 100644 index 0000000..1ee160f --- /dev/null +++ b/sexp_effects/effects/blend_multi.sexp @@ -0,0 +1,58 @@ +;; N-way weighted blend effect +;; Streaming-compatible: pass inputs as a list of frames +;; Usage: (blend_multi :inputs [(read a) (read b) (read c)] :weights [0.3 0.4 0.3]) +;; +;; Parameters: +;; inputs - list of N frames to blend +;; weights - list of N floats, one per input (resolved per-frame) +;; mode - blend mode applied when folding each frame in: +;; "alpha" — pure weighted average (default) +;; "multiply" — darken by multiplication +;; "screen" — lighten (inverse multiply) +;; "overlay" — contrast-boosting midtone blend +;; "soft-light" — gentle dodge/burn +;; "hard-light" — strong dodge/burn +;; "color-dodge" — brightens towards white +;; "color-burn" — darkens towards black +;; "difference" — absolute pixel difference +;; "exclusion" — softer difference +;; "add" — additive (clamped) +;; "subtract" — subtractive (clamped) +;; "darken" — per-pixel minimum +;; "lighten" — per-pixel maximum +;; resize_mode - how to match frame dimensions (fit, crop, stretch) +;; +;; Uses a left-fold over inputs[1..N-1]. At each step the running +;; opacity is: w[i] / (w[0] + w[1] + ... + w[i]) +;; which produces the correct normalised weighted result. + +(require-primitives "image" "blending") + +(define-effect blend_multi + :params ( + (inputs :type list :default []) + (weights :type list :default []) + (mode :type string :default "alpha") + (resize_mode :type string :default "fit") + ) + (let [n (len inputs) + ;; Target dimensions from first frame + target-w (image:width (nth inputs 0)) + target-h (image:height (nth inputs 0)) + ;; Fold over indices 1..n-1 + ;; Accumulator is (list blended-frame running-weight-sum) + seed (list (nth inputs 0) (nth weights 0)) + result (reduce (range 1 n) seed + (lambda (pair i) + (let [acc (nth pair 0) + running (nth pair 1) + w (nth weights i) + new-running (+ running w) + opacity (/ w (max new-running 0.001)) + f (image:resize (nth inputs i) target-w target-h "linear") + ;; Apply blend mode then mix with opacity + blended (if (= mode "alpha") + (blending:blend-images acc f opacity) + (blending:blend-images acc (blending:blend-mode acc f mode) opacity))] + (list blended new-running))))] + (nth result 0))) diff --git a/sexp_effects/effects/bloom.sexp b/sexp_effects/effects/bloom.sexp new file mode 100644 index 0000000..3524d01 --- /dev/null +++ b/sexp_effects/effects/bloom.sexp @@ -0,0 +1,16 @@ +;; Bloom effect - glow on bright areas +(require-primitives "image" "blending") + +(define-effect bloom + :params ( + (intensity :type float :default 0.5 :range [0 2]) + (threshold :type int :default 200 :range [0 255]) + (radius :type int :default 15 :range [1 50]) + ) + (let* ((bright (map-pixels frame + (lambda (x y c) + (if (> (luminance c) threshold) + c + (rgb 0 0 0))))) + (blurred (image:blur bright radius))) + (blending:blend-mode frame blurred "add"))) diff --git a/sexp_effects/effects/blur.sexp b/sexp_effects/effects/blur.sexp new file mode 100644 index 0000000..b71a55a --- /dev/null +++ b/sexp_effects/effects/blur.sexp @@ -0,0 +1,8 @@ +;; Blur effect - gaussian blur +(require-primitives "image") + +(define-effect blur + :params ( + (radius :type int :default 5 :range [1 50]) + ) + (image:blur frame (max 1 radius))) diff --git a/sexp_effects/effects/brightness.sexp b/sexp_effects/effects/brightness.sexp new file mode 100644 index 0000000..4af53a7 --- /dev/null +++ b/sexp_effects/effects/brightness.sexp @@ -0,0 +1,9 @@ +;; Brightness effect - adjusts overall brightness +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect brightness + :params ( + (amount :type int :default 0 :range [-255 255]) + ) + (color_ops:adjust-brightness frame amount)) diff --git a/sexp_effects/effects/color-adjust.sexp b/sexp_effects/effects/color-adjust.sexp new file mode 100644 index 0000000..5318bdd --- /dev/null +++ b/sexp_effects/effects/color-adjust.sexp @@ -0,0 +1,13 @@ +;; Color adjustment effect - replaces TRANSFORM node +(require-primitives "color_ops") + +(define-effect color-adjust + :params ( + (brightness :type int :default 0 :range [-255 255] :desc "Brightness adjustment") + (contrast :type float :default 1 :range [0 3] :desc "Contrast multiplier") + (saturation :type float :default 1 :range [0 2] :desc "Saturation multiplier") + ) + (-> frame + (color_ops:adjust-brightness brightness) + (color_ops:adjust-contrast contrast) + (color_ops:adjust-saturation saturation))) diff --git a/sexp_effects/effects/color_cycle.sexp b/sexp_effects/effects/color_cycle.sexp new file mode 100644 index 0000000..e08dbb6 --- /dev/null +++ b/sexp_effects/effects/color_cycle.sexp @@ -0,0 +1,13 @@ +;; Color Cycle effect - animated hue rotation +(require-primitives "color_ops") + +(define-effect color_cycle + :params ( + (speed :type int :default 1 :range [0 10]) + ) + (let ((shift (* t speed 360))) + (map-pixels frame + (lambda (x y c) + (let* ((hsv (rgb->hsv c)) + (new-h (mod (+ (first hsv) shift) 360))) + (hsv->rgb (list new-h (nth hsv 1) (nth hsv 2)))))))) diff --git a/sexp_effects/effects/contrast.sexp b/sexp_effects/effects/contrast.sexp new file mode 100644 index 0000000..660661d --- /dev/null +++ b/sexp_effects/effects/contrast.sexp @@ -0,0 +1,9 @@ +;; Contrast effect - adjusts image contrast +;; Uses vectorized adjust primitive for fast processing +(require-primitives "color_ops") + +(define-effect contrast + :params ( + (amount :type int :default 1 :range [0.5 3]) + ) + (color_ops:adjust-contrast frame amount)) diff --git a/sexp_effects/effects/crt.sexp b/sexp_effects/effects/crt.sexp new file mode 100644 index 0000000..097eaf9 --- /dev/null +++ b/sexp_effects/effects/crt.sexp @@ -0,0 +1,30 @@ +;; CRT effect - old monitor simulation +(require-primitives "image") + +(define-effect crt + :params ( + (line_spacing :type int :default 2 :range [1 10]) + (line_opacity :type float :default 0.3 :range [0 1]) + (vignette_amount :type float :default 0.2) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (sqrt (+ (* cx cx) (* cy cy))))) + (map-pixels frame + (lambda (x y c) + (let* (;; Scanline darkening + (scanline-factor (if (= 0 (mod y line_spacing)) + (- 1 line_opacity) + 1)) + ;; Vignette + (dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (vignette-factor (- 1 (* (/ dist max-dist) vignette_amount))) + ;; Combined + (factor (* scanline-factor vignette-factor))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/sexp_effects/effects/datamosh.sexp b/sexp_effects/effects/datamosh.sexp new file mode 100644 index 0000000..60cec66 --- /dev/null +++ b/sexp_effects/effects/datamosh.sexp @@ -0,0 +1,14 @@ +;; Datamosh effect - glitch block corruption + +(define-effect datamosh + :params ( + (block_size :type int :default 32 :range [8 128]) + (corruption :type float :default 0.3 :range [0 1]) + (max_offset :type int :default 50 :range [0 200]) + (color_corrupt :type bool :default true) + ) + ;; Get previous frame from state, or use current frame if none + (let ((prev (state-get "prev_frame" frame))) + (begin + (state-set "prev_frame" (copy frame)) + (datamosh frame prev block_size corruption max_offset color_corrupt)))) diff --git a/sexp_effects/effects/echo.sexp b/sexp_effects/effects/echo.sexp new file mode 100644 index 0000000..2aa2287 --- /dev/null +++ b/sexp_effects/effects/echo.sexp @@ -0,0 +1,19 @@ +;; Echo effect - motion trails using frame buffer +(require-primitives "blending") + +(define-effect echo + :params ( + (num_echoes :type int :default 4 :range [1 20]) + (decay :type float :default 0.5 :range [0 1]) + ) + (let* ((buffer (state-get 'buffer (list))) + (new-buffer (take (cons frame buffer) (+ num_echoes 1)))) + (begin + (state-set 'buffer new-buffer) + ;; Blend frames with decay + (if (< (length new-buffer) 2) + frame + (let ((result (copy frame))) + ;; Simple blend of first two frames for now + ;; Full version would fold over all frames + (blending:blend-images frame (nth new-buffer 1) (* decay 0.5))))))) diff --git a/sexp_effects/effects/edge_detect.sexp b/sexp_effects/effects/edge_detect.sexp new file mode 100644 index 0000000..170befb --- /dev/null +++ b/sexp_effects/effects/edge_detect.sexp @@ -0,0 +1,9 @@ +;; Edge detection effect - highlights edges +(require-primitives "image") + +(define-effect edge_detect + :params ( + (low :type int :default 50 :range [10 100]) + (high :type int :default 150 :range [50 300]) + ) + (image:edge-detect frame low high)) diff --git a/sexp_effects/effects/emboss.sexp b/sexp_effects/effects/emboss.sexp new file mode 100644 index 0000000..1eac3ce --- /dev/null +++ b/sexp_effects/effects/emboss.sexp @@ -0,0 +1,13 @@ +;; Emboss effect - creates raised/3D appearance +(require-primitives "blending") + +(define-effect emboss + :params ( + (strength :type int :default 1 :range [0.5 3]) + (blend :type float :default 0.3 :range [0 1]) + ) + (let* ((kernel (list (list (- strength) (- strength) 0) + (list (- strength) 1 strength) + (list 0 strength strength))) + (embossed (convolve frame kernel))) + (blending:blend-images embossed frame blend))) diff --git a/sexp_effects/effects/film_grain.sexp b/sexp_effects/effects/film_grain.sexp new file mode 100644 index 0000000..29bdd75 --- /dev/null +++ b/sexp_effects/effects/film_grain.sexp @@ -0,0 +1,19 @@ +;; Film Grain effect - adds film grain texture +(require-primitives "core") + +(define-effect film_grain + :params ( + (intensity :type float :default 0.2 :range [0 1]) + (colored :type bool :default false) + ) + (let ((grain-amount (* intensity 50))) + (map-pixels frame + (lambda (x y c) + (if colored + (rgb (clamp (+ (red c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (green c) (gaussian 0 grain-amount)) 0 255) + (clamp (+ (blue c) (gaussian 0 grain-amount)) 0 255)) + (let ((n (gaussian 0 grain-amount))) + (rgb (clamp (+ (red c) n) 0 255) + (clamp (+ (green c) n) 0 255) + (clamp (+ (blue c) n) 0 255)))))))) diff --git a/sexp_effects/effects/fisheye.sexp b/sexp_effects/effects/fisheye.sexp new file mode 100644 index 0000000..37750a7 --- /dev/null +++ b/sexp_effects/effects/fisheye.sexp @@ -0,0 +1,16 @@ +;; Fisheye effect - barrel/pincushion lens distortion +(require-primitives "geometry" "image") + +(define-effect fisheye + :params ( + (strength :type float :default 0.3 :range [-1 1]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom_correct :type bool :default true) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:fisheye-coords w h strength cx cy zoom_correct))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/flip.sexp b/sexp_effects/effects/flip.sexp new file mode 100644 index 0000000..977e1e1 --- /dev/null +++ b/sexp_effects/effects/flip.sexp @@ -0,0 +1,16 @@ +;; Flip effect - flips image horizontally or vertically +(require-primitives "geometry") + +(define-effect flip + :params ( + (horizontal :type bool :default true) + (vertical :type bool :default false) + ) + (let ((result frame)) + (if horizontal + (set! result (geometry:flip-img result "horizontal")) + nil) + (if vertical + (set! result (geometry:flip-img result "vertical")) + nil) + result)) diff --git a/sexp_effects/effects/grayscale.sexp b/sexp_effects/effects/grayscale.sexp new file mode 100644 index 0000000..848f8a7 --- /dev/null +++ b/sexp_effects/effects/grayscale.sexp @@ -0,0 +1,7 @@ +;; Grayscale effect - converts to grayscale +;; Uses vectorized mix-gray primitive for fast processing +(require-primitives "image") + +(define-effect grayscale + :params () + (image:grayscale frame)) diff --git a/sexp_effects/effects/hue_shift.sexp b/sexp_effects/effects/hue_shift.sexp new file mode 100644 index 0000000..ab61bd6 --- /dev/null +++ b/sexp_effects/effects/hue_shift.sexp @@ -0,0 +1,12 @@ +;; Hue shift effect - rotates hue values +;; Uses vectorized shift-hsv primitive for fast processing + +(require-primitives "color_ops") + +(define-effect hue_shift + :params ( + (degrees :type int :default 0 :range [0 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((shift (+ degrees (* speed t)))) + (color_ops:shift-hsv frame shift 1 1))) diff --git a/sexp_effects/effects/invert.sexp b/sexp_effects/effects/invert.sexp new file mode 100644 index 0000000..34936da --- /dev/null +++ b/sexp_effects/effects/invert.sexp @@ -0,0 +1,9 @@ +;; Invert effect - inverts all colors +;; Uses vectorized invert-img primitive for fast processing +;; amount param: 0 = no invert, 1 = full invert (threshold at 0.5) + +(require-primitives "color_ops") + +(define-effect invert + :params ((amount :type float :default 1 :range [0 1])) + (if (> amount 0.5) (color_ops:invert-img frame) frame)) diff --git a/sexp_effects/effects/kaleidoscope.sexp b/sexp_effects/effects/kaleidoscope.sexp new file mode 100644 index 0000000..9487ae2 --- /dev/null +++ b/sexp_effects/effects/kaleidoscope.sexp @@ -0,0 +1,20 @@ +;; Kaleidoscope effect - mandala-like symmetry patterns +(require-primitives "geometry" "image") + +(define-effect kaleidoscope + :params ( + (segments :type int :default 6 :range [3 16]) + (rotation :type int :default 0 :range [0 360]) + (rotation_speed :type int :default 0 :range [-180 180]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (zoom :type int :default 1 :range [0.5 3]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + ;; Total rotation including time-based animation + (total_rot (+ rotation (* rotation_speed (or _time 0)))) + (coords (geometry:kaleidoscope-coords w h segments total_rot cx cy zoom))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/layer.sexp b/sexp_effects/effects/layer.sexp new file mode 100644 index 0000000..e57d627 --- /dev/null +++ b/sexp_effects/effects/layer.sexp @@ -0,0 +1,36 @@ +;; Layer effect - composite overlay over background at position +;; Streaming-compatible: frame is background, overlay is foreground +;; Usage: (layer background overlay :x 10 :y 20 :opacity 0.8) +;; +;; Params: +;; overlay - frame to composite on top +;; x, y - position to place overlay +;; opacity - blend amount (0-1) +;; mode - blend mode (alpha, multiply, screen, etc.) + +(require-primitives "image" "blending" "core") + +(define-effect layer + :params ( + (overlay :type frame :default nil) + (x :type int :default 0) + (y :type int :default 0) + (opacity :type float :default 1.0) + (mode :type string :default "alpha") + ) + (if (core:is-nil overlay) + frame + (let [bg (copy frame) + fg overlay + fg-w (image:width fg) + fg-h (image:height fg)] + (if (= opacity 1.0) + ;; Simple paste + (paste bg fg x y) + ;; Blend with opacity + (let [blended (if (= mode "alpha") + (blending:blend-images (image:crop bg x y fg-w fg-h) fg opacity) + (blending:blend-images (image:crop bg x y fg-w fg-h) + (blending:blend-mode (image:crop bg x y fg-w fg-h) fg mode) + opacity))] + (paste bg blended x y)))))) diff --git a/sexp_effects/effects/mirror.sexp b/sexp_effects/effects/mirror.sexp new file mode 100644 index 0000000..a450cb6 --- /dev/null +++ b/sexp_effects/effects/mirror.sexp @@ -0,0 +1,33 @@ +;; Mirror effect - mirrors half of image +(require-primitives "geometry" "image") + +(define-effect mirror + :params ( + (mode :type string :default "left_right") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (hw (floor (/ w 2))) + (hh (floor (/ h 2)))) + (cond + ((= mode "left_right") + (let ((left (image:crop frame 0 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img left "horizontal") hw 0))) + + ((= mode "right_left") + (let ((right (image:crop frame hw 0 hw h)) + (result (copy frame))) + (paste result (geometry:flip-img right "horizontal") 0 0))) + + ((= mode "top_bottom") + (let ((top (image:crop frame 0 0 w hh)) + (result (copy frame))) + (paste result (geometry:flip-img top "vertical") 0 hh))) + + ((= mode "bottom_top") + (let ((bottom (image:crop frame 0 hh w hh)) + (result (copy frame))) + (paste result (geometry:flip-img bottom "vertical") 0 0))) + + (else frame)))) diff --git a/sexp_effects/effects/neon_glow.sexp b/sexp_effects/effects/neon_glow.sexp new file mode 100644 index 0000000..39245ab --- /dev/null +++ b/sexp_effects/effects/neon_glow.sexp @@ -0,0 +1,23 @@ +;; Neon Glow effect - glowing edge effect +(require-primitives "image" "blending") + +(define-effect neon_glow + :params ( + (edge_low :type int :default 50 :range [10 200]) + (edge_high :type int :default 150 :range [50 300]) + (glow_radius :type int :default 15 :range [1 50]) + (glow_intensity :type int :default 2 :range [0.5 5]) + (background :type float :default 0.3 :range [0 1]) + ) + (let* ((edge-img (image:edge-detect frame edge_low edge_high)) + (glow (image:blur edge-img glow_radius)) + ;; Intensify the glow + (bright-glow (map-pixels glow + (lambda (x y c) + (rgb (clamp (* (red c) glow_intensity) 0 255) + (clamp (* (green c) glow_intensity) 0 255) + (clamp (* (blue c) glow_intensity) 0 255)))))) + (blending:blend-mode (blending:blend-images frame (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 background)) + bright-glow + "screen"))) diff --git a/sexp_effects/effects/noise.sexp b/sexp_effects/effects/noise.sexp new file mode 100644 index 0000000..4da8298 --- /dev/null +++ b/sexp_effects/effects/noise.sexp @@ -0,0 +1,8 @@ +;; Noise effect - adds random noise +;; Uses vectorized add-noise primitive for fast processing + +(define-effect noise + :params ( + (amount :type int :default 20 :range [0 100]) + ) + (add-noise frame amount)) diff --git a/sexp_effects/effects/outline.sexp b/sexp_effects/effects/outline.sexp new file mode 100644 index 0000000..276f891 --- /dev/null +++ b/sexp_effects/effects/outline.sexp @@ -0,0 +1,24 @@ +;; Outline effect - shows only edges +(require-primitives "image") + +(define-effect outline + :params ( + (thickness :type int :default 2 :range [1 10]) + (threshold :type int :default 100 :range [20 300]) + (color :type list :default (list 0 0 0) + ) + (fill_mode "original")) + (let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold)) + (dilated (if (> thickness 1) + (dilate edge-img thickness) + edge-img)) + (base (cond + ((= fill_mode "original") (copy frame)) + ((= fill_mode "white") (make-image (image:width frame) (image:height frame) (list 255 255 255))) + (else (make-image (image:width frame) (image:height frame) (list 0 0 0)))))) + (map-pixels base + (lambda (x y c) + (let ((edge-val (luminance (pixel dilated x y)))) + (if (> edge-val 128) + color + c)))))) diff --git a/sexp_effects/effects/pixelate.sexp b/sexp_effects/effects/pixelate.sexp new file mode 100644 index 0000000..3d28ce1 --- /dev/null +++ b/sexp_effects/effects/pixelate.sexp @@ -0,0 +1,13 @@ +;; Pixelate effect - creates blocky pixels +(require-primitives "image") + +(define-effect pixelate + :params ( + (block_size :type int :default 8 :range [2 64]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (small-w (max 1 (floor (/ w block_size)))) + (small-h (max 1 (floor (/ h block_size)))) + (small (image:resize frame small-w small-h "area"))) + (image:resize small w h "nearest"))) diff --git a/sexp_effects/effects/pixelsort.sexp b/sexp_effects/effects/pixelsort.sexp new file mode 100644 index 0000000..155ac13 --- /dev/null +++ b/sexp_effects/effects/pixelsort.sexp @@ -0,0 +1,11 @@ +;; Pixelsort effect - glitch art pixel sorting + +(define-effect pixelsort + :params ( + (sort_by :type string :default "lightness") + (threshold_low :type int :default 50 :range [0 255]) + (threshold_high :type int :default 200 :range [0 255]) + (angle :type int :default 0 :range [0 180]) + (reverse :type bool :default false) + ) + (pixelsort frame sort_by threshold_low threshold_high angle reverse)) diff --git a/sexp_effects/effects/posterize.sexp b/sexp_effects/effects/posterize.sexp new file mode 100644 index 0000000..7052ed3 --- /dev/null +++ b/sexp_effects/effects/posterize.sexp @@ -0,0 +1,8 @@ +;; Posterize effect - reduces color levels +(require-primitives "color_ops") + +(define-effect posterize + :params ( + (levels :type int :default 8 :range [2 32]) + ) + (color_ops:posterize frame levels)) diff --git a/sexp_effects/effects/resize-frame.sexp b/sexp_effects/effects/resize-frame.sexp new file mode 100644 index 0000000..a1cce27 --- /dev/null +++ b/sexp_effects/effects/resize-frame.sexp @@ -0,0 +1,11 @@ +;; Resize effect - replaces RESIZE node +;; Note: uses target-w/target-h to avoid conflict with width/height primitives +(require-primitives "image") + +(define-effect resize-frame + :params ( + (target-w :type int :default 640 :desc "Target width in pixels") + (target-h :type int :default 480 :desc "Target height in pixels") + (mode :type string :default "linear" :choices [linear nearest area] :desc "Interpolation mode") + ) + (image:resize frame target-w target-h mode)) diff --git a/sexp_effects/effects/rgb_split.sexp b/sexp_effects/effects/rgb_split.sexp new file mode 100644 index 0000000..4582701 --- /dev/null +++ b/sexp_effects/effects/rgb_split.sexp @@ -0,0 +1,13 @@ +;; RGB Split effect - chromatic aberration + +(define-effect rgb_split + :params ( + (offset_x :type int :default 10 :range [-50 50]) + (offset_y :type int :default 0 :range [-50 50]) + ) + (let* ((r (channel frame 0)) + (g (channel frame 1)) + (b (channel frame 2)) + (r-shifted (translate (merge-channels r r r) offset_x offset_y)) + (b-shifted (translate (merge-channels b b b) (- offset_x) (- offset_y)))) + (merge-channels (channel r-shifted 0) g (channel b-shifted 0)))) diff --git a/sexp_effects/effects/ripple.sexp b/sexp_effects/effects/ripple.sexp new file mode 100644 index 0000000..0bb7a8d --- /dev/null +++ b/sexp_effects/effects/ripple.sexp @@ -0,0 +1,19 @@ +;; Ripple effect - radial wave distortion from center +(require-primitives "geometry" "image" "math") + +(define-effect ripple + :params ( + (frequency :type int :default 5 :range [1 20]) + (amplitude :type int :default 10 :range [0 50]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (decay :type int :default 1 :range [0 5]) + (speed :type int :default 1 :range [0 10]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (phase (* (or t 0) speed 2 pi)) + (coords (geometry:ripple-displace w h frequency amplitude cx cy decay phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/rotate.sexp b/sexp_effects/effects/rotate.sexp new file mode 100644 index 0000000..d06c2f7 --- /dev/null +++ b/sexp_effects/effects/rotate.sexp @@ -0,0 +1,11 @@ +;; Rotate effect - rotates image + +(require-primitives "geometry") + +(define-effect rotate + :params ( + (angle :type int :default 0 :range [-360 360]) + (speed :type int :default 0 :desc "rotation per second") + ) + (let ((total-angle (+ angle (* speed t)))) + (geometry:rotate-img frame total-angle))) diff --git a/sexp_effects/effects/saturation.sexp b/sexp_effects/effects/saturation.sexp new file mode 100644 index 0000000..9852dc7 --- /dev/null +++ b/sexp_effects/effects/saturation.sexp @@ -0,0 +1,9 @@ +;; Saturation effect - adjusts color saturation +;; Uses vectorized shift-hsv primitive for fast processing +(require-primitives "color_ops") + +(define-effect saturation + :params ( + (amount :type int :default 1 :range [0 3]) + ) + (color_ops:adjust-saturation frame amount)) diff --git a/sexp_effects/effects/scanlines.sexp b/sexp_effects/effects/scanlines.sexp new file mode 100644 index 0000000..ddfcf44 --- /dev/null +++ b/sexp_effects/effects/scanlines.sexp @@ -0,0 +1,15 @@ +;; Scanlines effect - VHS-style horizontal line shifting +(require-primitives "core") + +(define-effect scanlines + :params ( + (amplitude :type int :default 10 :range [0 100]) + (frequency :type int :default 10 :range [1 100]) + (randomness :type float :default 0.5 :range [0 1]) + ) + (map-rows frame + (lambda (y row) + (let* ((sine-shift (* amplitude (sin (/ (* y 6.28) (max 1 frequency))))) + (rand-shift (core:rand-range (- amplitude) amplitude)) + (shift (floor (lerp sine-shift rand-shift randomness)))) + (roll row shift 0))))) diff --git a/sexp_effects/effects/sepia.sexp b/sexp_effects/effects/sepia.sexp new file mode 100644 index 0000000..e3a5875 --- /dev/null +++ b/sexp_effects/effects/sepia.sexp @@ -0,0 +1,7 @@ +;; Sepia effect - applies sepia tone +;; Classic warm vintage look +(require-primitives "color_ops") + +(define-effect sepia + :params () + (color_ops:sepia frame)) diff --git a/sexp_effects/effects/sharpen.sexp b/sexp_effects/effects/sharpen.sexp new file mode 100644 index 0000000..538bd7f --- /dev/null +++ b/sexp_effects/effects/sharpen.sexp @@ -0,0 +1,8 @@ +;; Sharpen effect - sharpens edges +(require-primitives "image") + +(define-effect sharpen + :params ( + (amount :type int :default 1 :range [0 5]) + ) + (image:sharpen frame amount)) diff --git a/sexp_effects/effects/strobe.sexp b/sexp_effects/effects/strobe.sexp new file mode 100644 index 0000000..e51ba30 --- /dev/null +++ b/sexp_effects/effects/strobe.sexp @@ -0,0 +1,16 @@ +;; Strobe effect - holds frames for choppy look +(require-primitives "core") + +(define-effect strobe + :params ( + (frame_rate :type int :default 12 :range [1 60]) + ) + (let* ((held (state-get 'held nil)) + (held-until (state-get 'held-until 0)) + (frame-duration (/ 1 frame_rate))) + (if (or (core:is-nil held) (>= t held-until)) + (begin + (state-set 'held (copy frame)) + (state-set 'held-until (+ t frame-duration)) + frame) + held))) diff --git a/sexp_effects/effects/swirl.sexp b/sexp_effects/effects/swirl.sexp new file mode 100644 index 0000000..ba9cf57 --- /dev/null +++ b/sexp_effects/effects/swirl.sexp @@ -0,0 +1,17 @@ +;; Swirl effect - spiral vortex distortion +(require-primitives "geometry" "image") + +(define-effect swirl + :params ( + (strength :type int :default 1 :range [-10 10]) + (radius :type float :default 0.5 :range [0.1 2]) + (center_x :type float :default 0.5 :range [0 1]) + (center_y :type float :default 0.5 :range [0 1]) + (falloff :type string :default "quadratic") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (* w center_x)) + (cy (* h center_y)) + (coords (geometry:swirl-coords w h strength radius cx cy falloff))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/threshold.sexp b/sexp_effects/effects/threshold.sexp new file mode 100644 index 0000000..50d3bc5 --- /dev/null +++ b/sexp_effects/effects/threshold.sexp @@ -0,0 +1,9 @@ +;; Threshold effect - converts to black and white +(require-primitives "color_ops") + +(define-effect threshold + :params ( + (level :type int :default 128 :range [0 255]) + (invert :type bool :default false) + ) + (color_ops:threshold frame level invert)) diff --git a/sexp_effects/effects/tile_grid.sexp b/sexp_effects/effects/tile_grid.sexp new file mode 100644 index 0000000..44487a9 --- /dev/null +++ b/sexp_effects/effects/tile_grid.sexp @@ -0,0 +1,29 @@ +;; Tile Grid effect - tiles image in grid +(require-primitives "geometry" "image") + +(define-effect tile_grid + :params ( + (rows :type int :default 2 :range [1 10]) + (cols :type int :default 2 :range [1 10]) + (gap :type int :default 0 :range [0 50]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (tile-w (floor (/ (- w (* gap (- cols 1))) cols))) + (tile-h (floor (/ (- h (* gap (- rows 1))) rows))) + (tile (image:resize frame tile-w tile-h "area")) + (result (make-image w h (list 0 0 0)))) + (begin + ;; Manually place tiles using nested iteration + ;; This is a simplified version - full version would loop + (paste result tile 0 0) + (if (> cols 1) + (paste result tile (+ tile-w gap) 0) + nil) + (if (> rows 1) + (paste result tile 0 (+ tile-h gap)) + nil) + (if (and (> cols 1) (> rows 1)) + (paste result tile (+ tile-w gap) (+ tile-h gap)) + nil) + result))) diff --git a/sexp_effects/effects/trails.sexp b/sexp_effects/effects/trails.sexp new file mode 100644 index 0000000..f16c302 --- /dev/null +++ b/sexp_effects/effects/trails.sexp @@ -0,0 +1,20 @@ +;; Trails effect - persistent motion trails +(require-primitives "image" "blending") + +(define-effect trails + :params ( + (persistence :type float :default 0.8 :range [0 0.99]) + ) + (let* ((buffer (state-get 'buffer nil)) + (current frame)) + (if (= buffer nil) + (begin + (state-set 'buffer (copy frame)) + frame) + (let* ((faded (blending:blend-images buffer + (make-image (image:width frame) (image:height frame) (list 0 0 0)) + (- 1 persistence))) + (result (blending:blend-mode faded current "lighten"))) + (begin + (state-set 'buffer result) + result))))) diff --git a/sexp_effects/effects/vignette.sexp b/sexp_effects/effects/vignette.sexp new file mode 100644 index 0000000..46e63ee --- /dev/null +++ b/sexp_effects/effects/vignette.sexp @@ -0,0 +1,23 @@ +;; Vignette effect - darkens corners +(require-primitives "image") + +(define-effect vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type int :default 1 :range [0.5 2]) + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + (cx (/ w 2)) + (cy (/ h 2)) + (max-dist (* (sqrt (+ (* cx cx) (* cy cy))) radius))) + (map-pixels frame + (lambda (x y c) + (let* ((dx (- x cx)) + (dy (- y cy)) + (dist (sqrt (+ (* dx dx) (* dy dy)))) + (factor (- 1 (* (/ dist max-dist) strength))) + (factor (clamp factor 0 1))) + (rgb (* (red c) factor) + (* (green c) factor) + (* (blue c) factor))))))) diff --git a/sexp_effects/effects/wave.sexp b/sexp_effects/effects/wave.sexp new file mode 100644 index 0000000..98b03c2 --- /dev/null +++ b/sexp_effects/effects/wave.sexp @@ -0,0 +1,22 @@ +;; Wave effect - sine wave displacement distortion +(require-primitives "geometry" "image") + +(define-effect wave + :params ( + (amplitude :type int :default 10 :range [0 100]) + (wavelength :type int :default 50 :range [10 500]) + (speed :type int :default 1 :range [0 10]) + (direction :type string :default "horizontal") + ) + (let* ((w (image:width frame)) + (h (image:height frame)) + ;; Use _time for animation phase + (phase (* (or _time 0) speed 2 pi)) + ;; Calculate frequency: waves per dimension + (freq (/ (if (= direction "vertical") w h) wavelength)) + (axis (cond + ((= direction "horizontal") "x") + ((= direction "vertical") "y") + (else "both"))) + (coords (geometry:wave-coords w h axis freq amplitude phase))) + (geometry:remap frame (geometry:coords-x coords) (geometry:coords-y coords)))) diff --git a/sexp_effects/effects/zoom.sexp b/sexp_effects/effects/zoom.sexp new file mode 100644 index 0000000..6e4b9ff --- /dev/null +++ b/sexp_effects/effects/zoom.sexp @@ -0,0 +1,8 @@ +;; Zoom effect - zooms in/out from center +(require-primitives "geometry") + +(define-effect zoom + :params ( + (amount :type int :default 1 :range [0.1 5]) + ) + (geometry:scale-img frame amount amount)) diff --git a/sexp_effects/interpreter.py b/sexp_effects/interpreter.py new file mode 100644 index 0000000..830904a --- /dev/null +++ b/sexp_effects/interpreter.py @@ -0,0 +1,1016 @@ +""" +S-Expression Effect Interpreter + +Interprets effect definitions written in S-expressions. +Only allows safe primitives - no arbitrary code execution. +""" + +import numpy as np +from typing import Any, Dict, List, Optional, Callable +from pathlib import Path + +from .parser import Symbol, Keyword, parse, parse_file +from .primitives import PRIMITIVES, reset_rng + + +def _is_symbol(x) -> bool: + """Check if x is a Symbol (duck typing to support multiple Symbol classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Symbol' + + +def _is_keyword(x) -> bool: + """Check if x is a Keyword (duck typing to support multiple Keyword classes).""" + return hasattr(x, 'name') and type(x).__name__ == 'Keyword' + + +def _symbol_name(x) -> str: + """Get the name from a Symbol.""" + return x.name if hasattr(x, 'name') else str(x) + + +class Environment: + """Lexical environment for variable bindings.""" + + def __init__(self, parent: 'Environment' = None): + self.bindings: Dict[str, Any] = {} + self.parent = parent + + def get(self, name: str) -> Any: + if name in self.bindings: + return self.bindings[name] + if self.parent: + return self.parent.get(name) + raise NameError(f"Undefined variable: {name}") + + def set(self, name: str, value: Any): + self.bindings[name] = value + + def has(self, name: str) -> bool: + if name in self.bindings: + return True + if self.parent: + return self.parent.has(name) + return False + + +class Lambda: + """A user-defined function (lambda).""" + + def __init__(self, params: List[str], body: Any, env: Environment): + self.params = params + self.body = body + self.env = env # Closure environment + + def __repr__(self): + return f"" + + +class EffectDefinition: + """A parsed effect definition.""" + + def __init__(self, name: str, params: Dict[str, Any], body: Any): + self.name = name + self.params = params # {name: (type, default)} + self.body = body + + def __repr__(self): + return f"" + + +class Interpreter: + """ + S-Expression interpreter for effects. + + Provides a safe execution environment where only + whitelisted primitives can be called. + + Args: + minimal_primitives: If True, only load core primitives (arithmetic, comparison, + basic data access). Additional primitives must be loaded with + (require-primitives) or (with-primitives). + If False (default), load all legacy primitives for backward compatibility. + """ + + def __init__(self, minimal_primitives: bool = False): + # Base environment with primitives + self.global_env = Environment() + self.minimal_primitives = minimal_primitives + + if minimal_primitives: + # Load only core primitives + from .primitive_libs.core import PRIMITIVES as CORE_PRIMITIVES + for name, fn in CORE_PRIMITIVES.items(): + self.global_env.set(name, fn) + else: + # Load all legacy primitives for backward compatibility + for name, fn in PRIMITIVES.items(): + self.global_env.set(name, fn) + + # Special values + self.global_env.set('true', True) + self.global_env.set('false', False) + self.global_env.set('nil', None) + + # Loaded effect definitions + self.effects: Dict[str, EffectDefinition] = {} + + def eval(self, expr: Any, env: Environment = None) -> Any: + """Evaluate an S-expression.""" + if env is None: + env = self.global_env + + # Atoms + if isinstance(expr, (int, float, str, bool)): + return expr + + if expr is None: + return None + + # Handle Symbol (duck typing to support both sexp_effects.parser.Symbol and artdag.sexp.parser.Symbol) + if _is_symbol(expr): + return env.get(expr.name) + + # Handle Keyword (duck typing) + if _is_keyword(expr): + return expr # Keywords evaluate to themselves + + if isinstance(expr, np.ndarray): + return expr # Images pass through + + # Lists (function calls / special forms) + if isinstance(expr, list): + if not expr: + return [] + + head = expr[0] + + # Special forms + if _is_symbol(head): + form = head.name + + # Quote + if form == 'quote': + return expr[1] + + # Define + if form == 'define': + name = expr[1] + if _is_symbol(name): + value = self.eval(expr[2], env) + self.global_env.set(name.name, value) + return value + else: + raise SyntaxError(f"define requires symbol, got {name}") + + # Define-effect + if form == 'define-effect': + return self._define_effect(expr, env) + + # Lambda + if form == 'lambda' or form == 'λ': + params = [p.name if _is_symbol(p) else p for p in expr[1]] + body = expr[2] + return Lambda(params, body, env) + + # Let + if form == 'let': + return self._eval_let(expr, env) + + # Let* + if form == 'let*': + return self._eval_let_star(expr, env) + + # If + if form == 'if': + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + # Cond + if form == 'cond': + return self._eval_cond(expr, env) + + # And + if form == 'and': + result = True + for e in expr[1:]: + result = self.eval(e, env) + if not result: + return False + return result + + # Or + if form == 'or': + for e in expr[1:]: + result = self.eval(e, env) + if result: + return result + return False + + # Not + if form == 'not': + return not self.eval(expr[1], env) + + # Begin (sequence) + if form == 'begin': + result = None + for e in expr[1:]: + result = self.eval(e, env) + return result + + # Thread-first macro: (-> x (f a) (g b)) => (g (f x a) b) + if form == '->': + result = self.eval(expr[1], env) + for form_expr in expr[2:]: + if isinstance(form_expr, list): + # Insert result as first arg: (f a b) => (f result a b) + result = self.eval([form_expr[0], result] + form_expr[1:], env) + else: + # Just a symbol: f => (f result) + result = self.eval([form_expr, result], env) + return result + + # Set! (mutation) + if form == 'set!': + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + value = self.eval(expr[2], env) + # Find and update in appropriate scope + scope = env + while scope: + if name in scope.bindings: + scope.bindings[name] = value + return value + scope = scope.parent + raise NameError(f"Cannot set undefined variable: {name}") + + # State-get / state-set (for effect state) + if form == 'state-get': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + default = self.eval(expr[2], env) if len(expr) > 2 else None + return state.get(key, default) + + if form == 'state-set': + state = env.get('__state__') + key = self.eval(expr[1], env) + if _is_symbol(key): + key = key.name + value = self.eval(expr[2], env) + state[key] = value + return value + + # ascii-fx-zone special form - delays evaluation of expression parameters + if form == 'ascii-fx-zone': + return self._eval_ascii_fx_zone(expr, env) + + # with-primitives - load primitive library and scope to body + if form == 'with-primitives': + return self._eval_with_primitives(expr, env) + + # require-primitives - load primitive library into current scope + if form == 'require-primitives': + return self._eval_require_primitives(expr, env) + + # Function call + fn = self.eval(head, env) + args = [self.eval(arg, env) for arg in expr[1:]] + + # Handle keyword arguments + pos_args = [] + kw_args = {} + i = 0 + while i < len(args): + if _is_keyword(args[i]): + kw_args[args[i].name] = args[i + 1] if i + 1 < len(args) else None + i += 2 + else: + pos_args.append(args[i]) + i += 1 + + return self._apply(fn, pos_args, kw_args, env) + + raise TypeError(f"Cannot evaluate: {expr}") + + def _wrap_lambda(self, lam: 'Lambda') -> Callable: + """Wrap a Lambda in a Python callable for use by primitives.""" + def wrapper(*args): + new_env = Environment(lam.env) + for i, param in enumerate(lam.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(lam.body, new_env) + return wrapper + + def _apply(self, fn: Any, args: List[Any], kwargs: Dict[str, Any], env: Environment) -> Any: + """Apply a function to arguments.""" + if isinstance(fn, Lambda): + # User-defined function + new_env = Environment(fn.env) + for i, param in enumerate(fn.params): + if i < len(args): + new_env.set(param, args[i]) + else: + new_env.set(param, None) + return self.eval(fn.body, new_env) + + elif callable(fn): + # Wrap any Lambda arguments so primitives can call them + wrapped_args = [] + for arg in args: + if isinstance(arg, Lambda): + wrapped_args.append(self._wrap_lambda(arg)) + else: + wrapped_args.append(arg) + + # Inject _interp and _env for primitives that need them + import inspect + try: + sig = inspect.signature(fn) + params = sig.parameters + if '_interp' in params and '_interp' not in kwargs: + kwargs['_interp'] = self + if '_env' in params and '_env' not in kwargs: + kwargs['_env'] = env + except (ValueError, TypeError): + # Some built-in functions don't have inspectable signatures + pass + + # Primitive function + if kwargs: + return fn(*wrapped_args, **kwargs) + return fn(*wrapped_args) + + else: + raise TypeError(f"Cannot call: {fn}") + + def _parse_bindings(self, bindings: list) -> list: + """Parse bindings in either Scheme or Clojure style. + + Scheme: ((x 1) (y 2)) -> [(x, 1), (y, 2)] + Clojure: [x 1 y 2] -> [(x, 1), (y, 2)] + """ + if not bindings: + return [] + + # Check if Clojure style (flat list with symbols and values alternating) + if _is_symbol(bindings[0]): + # Clojure style: [x 1 y 2] + pairs = [] + i = 0 + while i < len(bindings) - 1: + name = bindings[i].name if _is_symbol(bindings[i]) else bindings[i] + value = bindings[i + 1] + pairs.append((name, value)) + i += 2 + return pairs + else: + # Scheme style: ((x 1) (y 2)) + pairs = [] + for binding in bindings: + name = binding[0].name if _is_symbol(binding[0]) else binding[0] + value = binding[1] + pairs.append((name, value)) + return pairs + + def _eval_let(self, expr: Any, env: Environment) -> Any: + """Evaluate let expression: (let ((x 1) (y 2)) body) or (let [x 1 y 2] body) + + Note: Uses sequential binding (like Clojure let / Scheme let*) so each + binding can reference previous bindings. + """ + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Sequential: can see previous bindings + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_let_star(self, expr: Any, env: Environment) -> Any: + """Evaluate let* expression: sequential bindings.""" + bindings = expr[1] + body = expr[2] + + new_env = Environment(env) + for name, value_expr in self._parse_bindings(bindings): + value = self.eval(value_expr, new_env) # Evaluate in current env + new_env.set(name, value) + + return self.eval(body, new_env) + + def _eval_cond(self, expr: Any, env: Environment) -> Any: + """Evaluate cond expression.""" + for clause in expr[1:]: + test = clause[0] + if _is_symbol(test) and test.name == 'else': + return self.eval(clause[1], env) + if self.eval(test, env): + return self.eval(clause[1], env) + return None + + def _eval_with_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate with-primitives: scoped primitive library loading. + + Syntax: + (with-primitives "math" + (sin (* x pi))) + + (with-primitives "math" :path "custom/math.py" + body) + + The primitives from the library are only available within the body. + """ + # Parse library name and optional path + lib_name = expr[1] + if _is_symbol(lib_name): + lib_name = lib_name.name + + path = None + body_start = 2 + + # Check for :path keyword + if len(expr) > 2 and _is_keyword(expr[2]) and expr[2].name == 'path': + path = expr[3] + body_start = 4 + + # Load the primitive library + primitives = self.load_primitive_library(lib_name, path) + + # Create new environment with primitives + new_env = Environment(env) + for name, fn in primitives.items(): + new_env.set(name, fn) + + # Evaluate body in new environment + result = None + for e in expr[body_start:]: + result = self.eval(e, new_env) + return result + + def _eval_require_primitives(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require-primitives: load primitives into current scope. + + Syntax: + (require-primitives "math" "color" "filters") + + Unlike with-primitives, this loads into the current environment + (typically used at top-level to set up an effect's dependencies). + """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + primitives = self.load_primitive_library(lib_name) + for name, fn in primitives.items(): + env.set(name, fn) + + return None + + def load_primitive_library(self, name: str, path: str = None) -> dict: + """ + Load a primitive library by name or path. + + Returns dict of {name: function}. + """ + from .primitive_libs import load_primitive_library + return load_primitive_library(name, path) + + def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any: + """ + Evaluate ascii-fx-zone special form. + + Syntax: + (ascii-fx-zone frame + :cols 80 + :alphabet "standard" + :color_mode "color" + :background "black" + :contrast 1.5 + :char_hue ;; NOT evaluated - passed to primitive + :char_saturation + :char_brightness + :char_scale + :char_rotation + :char_jitter ) + + The expression parameters (:char_hue, etc.) are NOT pre-evaluated. + They are passed as raw S-expressions to the primitive which + evaluates them per-zone with zone context variables injected. + + Requires: (require-primitives "ascii") + """ + # Look up ascii-fx-zone primitive from environment + # It must be loaded via (require-primitives "ascii") + try: + prim_ascii_fx_zone = env.get('ascii-fx-zone') + except NameError: + raise NameError( + "ascii-fx-zone primitive not found. " + "Add (require-primitives \"ascii\") to your effect file." + ) + + # Expression parameter names that should NOT be evaluated + expr_params = {'char_hue', 'char_saturation', 'char_brightness', + 'char_scale', 'char_rotation', 'char_jitter', 'cell_effect'} + + # Parse arguments + frame = self.eval(expr[1], env) # First arg is always the frame + + # Defaults + cols = 80 + char_size = None # If set, overrides cols + alphabet = "standard" + color_mode = "color" + background = "black" + contrast = 1.5 + char_hue = None + char_saturation = None + char_brightness = None + char_scale = None + char_rotation = None + char_jitter = None + cell_effect = None # Lambda for arbitrary per-cell effects + # Convenience params for staged recipes + energy = None + rotation_scale = 0 + # Extra params to pass to zone dict for lambdas + extra_params = {} + + # Parse keyword arguments + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item): + if i + 1 >= len(expr): + break + value_expr = expr[i + 1] + kw_name = item.name + + if kw_name in expr_params: + # Resolve symbol references but don't evaluate expressions + # This handles the case where effect definition passes a param like :char_hue char_hue + resolved = value_expr + if _is_symbol(value_expr): + try: + resolved = env.get(value_expr.name) + except NameError: + resolved = value_expr # Keep as symbol if not found + + if kw_name == 'char_hue': + char_hue = resolved + elif kw_name == 'char_saturation': + char_saturation = resolved + elif kw_name == 'char_brightness': + char_brightness = resolved + elif kw_name == 'char_scale': + char_scale = resolved + elif kw_name == 'char_rotation': + char_rotation = resolved + elif kw_name == 'char_jitter': + char_jitter = resolved + elif kw_name == 'cell_effect': + cell_effect = resolved + else: + # Evaluate normally + value = self.eval(value_expr, env) + if kw_name == 'cols': + cols = int(value) + elif kw_name == 'char_size': + # Handle nil/None values + if value is None or (_is_symbol(value) and value.name == 'nil'): + char_size = None + else: + char_size = int(value) + elif kw_name == 'alphabet': + alphabet = str(value) + elif kw_name == 'color_mode': + color_mode = str(value) + elif kw_name == 'background': + background = str(value) + elif kw_name == 'contrast': + contrast = float(value) + elif kw_name == 'energy': + if value is None or (_is_symbol(value) and value.name == 'nil'): + energy = None + else: + energy = float(value) + elif kw_name == 'rotation_scale': + rotation_scale = float(value) + else: + # Store any other params for lambdas to access + extra_params[kw_name] = value + i += 2 + else: + i += 1 + + # If energy and rotation_scale provided, build rotation expression + # rotation = energy * rotation_scale * position_factor + # position_factor: bottom-left=0, top-right=3 + # Formula: 1.5 * (zone-col-norm + (1 - zone-row-norm)) + if energy is not None and rotation_scale > 0: + # Build expression as S-expression list that will be evaluated per-zone + # (* (* energy rotation_scale) (* 1.5 (+ zone-col-norm (- 1 zone-row-norm)))) + energy_times_scale = energy * rotation_scale + # The position part uses zone variables, so we build it as an expression + char_rotation = [ + Symbol('*'), + energy_times_scale, + [Symbol('*'), 1.5, + [Symbol('+'), Symbol('zone-col-norm'), + [Symbol('-'), 1, Symbol('zone-row-norm')]]] + ] + + # Pull any extra params from environment that aren't standard params + # These are typically passed from recipes for use in cell_effect lambdas + standard_params = { + 'cols', 'char_size', 'alphabet', 'color_mode', 'background', 'contrast', + 'char_hue', 'char_saturation', 'char_brightness', 'char_scale', + 'char_rotation', 'char_jitter', 'cell_effect', 'energy', 'rotation_scale', + 'frame', 't', '_time', '__state__', '__interp__', 'true', 'false', 'nil' + } + # Check environment for extra bindings + current_env = env + while current_env is not None: + for k, v in current_env.bindings.items(): + if k not in standard_params and k not in extra_params and not callable(v): + # Add non-standard, non-callable bindings to extra_params + if isinstance(v, (int, float, str, bool)) or v is None: + extra_params[k] = v + current_env = current_env.parent + + # Call the primitive with interpreter and env for expression evaluation + return prim_ascii_fx_zone( + frame, + cols=cols, + char_size=char_size, + alphabet=alphabet, + color_mode=color_mode, + background=background, + contrast=contrast, + char_hue=char_hue, + char_saturation=char_saturation, + char_brightness=char_brightness, + char_scale=char_scale, + char_rotation=char_rotation, + char_jitter=char_jitter, + cell_effect=cell_effect, + energy=energy, + rotation_scale=rotation_scale, + _interp=self, + _env=env, + **extra_params + ) + + def _define_effect(self, expr: Any, env: Environment) -> EffectDefinition: + """ + Parse effect definition. + + Required syntax: + (define-effect name + :params ( + (param1 :type int :default 8 :desc "description") + ) + body) + + Effects MUST use :params syntax. Legacy ((param default) ...) is not supported. + """ + name = expr[1].name if _is_symbol(expr[1]) else expr[1] + + params = {} + body = None + found_params = False + + # Parse :params and body + i = 2 + while i < len(expr): + item = expr[i] + if _is_keyword(item) and item.name == "params": + # :params syntax + if i + 1 >= len(expr): + raise SyntaxError(f"Effect '{name}': Missing params list after :params keyword") + params_list = expr[i + 1] + params = self._parse_params_block(params_list) + found_params = True + i += 2 + elif _is_keyword(item): + # Skip other keywords (like :desc) + i += 2 + elif body is None: + # First non-keyword item is the body + if isinstance(item, list) and item: + first_elem = item[0] + # Check for legacy syntax and reject it + if isinstance(first_elem, list) and len(first_elem) >= 2: + raise SyntaxError( + f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. " + f"Use :params block instead." + ) + body = item + i += 1 + else: + i += 1 + + if body is None: + raise SyntaxError(f"Effect '{name}': No body found") + + if not found_params: + raise SyntaxError( + f"Effect '{name}': Missing :params block. " + f"For effects with no parameters, use empty :params ()" + ) + + effect = EffectDefinition(name, params, body) + self.effects[name] = effect + return effect + + def _parse_params_block(self, params_list: list) -> Dict[str, Any]: + """ + Parse :params block syntax: + ( + (param_name :type int :default 8 :range [4 32] :desc "description") + ) + """ + params = {} + for param_def in params_list: + if not isinstance(param_def, list) or len(param_def) < 1: + continue + + # First element is the parameter name + first = param_def[0] + if _is_symbol(first): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + continue + + # Parse keyword arguments + default = None + i = 1 + while i < len(param_def): + item = param_def[i] + if _is_keyword(item): + if i + 1 >= len(param_def): + break + kw_value = param_def[i + 1] + + if item.name == "default": + default = kw_value + i += 2 + else: + i += 1 + + params[param_name] = default + + return params + + def load_effect(self, path: str) -> EffectDefinition: + """Load an effect definition from a .sexp file.""" + expr = parse_file(path) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the last defined effect + if self.effects: + return list(self.effects.values())[-1] + return None + + def load_effect_from_string(self, sexp_content: str, effect_name: str = None) -> EffectDefinition: + """Load an effect definition from an S-expression string. + + Args: + sexp_content: The S-expression content as a string + effect_name: Optional name hint (used if effect doesn't define its own name) + + Returns: + The loaded EffectDefinition + """ + expr = parse(sexp_content) + + # Handle multiple top-level expressions + if isinstance(expr, list) and expr and isinstance(expr[0], list): + for e in expr: + self.eval(e) + else: + self.eval(expr) + + # Return the effect if we can find it by name + if effect_name and effect_name in self.effects: + return self.effects[effect_name] + + # Return the most recently loaded effect + if self.effects: + return list(self.effects.values())[-1] + + return None + + def run_effect(self, name: str, frame, params: Dict[str, Any], + state: Dict[str, Any]) -> tuple: + """ + Run an effect on frame(s). + + Args: + name: Effect name + frame: Input frame (H, W, 3) RGB uint8, or list of frames for multi-input + params: Effect parameters (overrides defaults) + state: Persistent state dict + + Returns: + (output_frame, new_state) + """ + if name not in self.effects: + raise ValueError(f"Unknown effect: {name}") + + effect = self.effects[name] + + # Create environment for this run + env = Environment(self.global_env) + + # Bind frame(s) - support both single frame and list of frames + if isinstance(frame, list): + # Multi-input effect + frames = frame + env.set('frame', frames[0] if frames else None) # Backwards compat + env.set('inputs', frames) + # Named frame bindings + for i, f in enumerate(frames): + env.set(f'frame-{chr(ord("a") + i)}', f) # frame-a, frame-b, etc. + else: + # Single-input effect + env.set('frame', frame) + + # Bind state + if state is None: + state = {} + env.set('__state__', state) + + # Validate that all provided params are known (except internal params) + # Extra params are allowed and will be passed through to cell_effect lambdas + known_params = set(effect.params.keys()) + internal_params = {'_time', 'seed', '_binding', 'effect', 'cid', 'hash', 'effect_path'} + extra_effect_params = {} # Unknown params passed through for cell_effect lambdas + for k in params.keys(): + if k not in known_params and k not in internal_params: + # Allow unknown params - they'll be passed to cell_effect lambdas via zone dict + extra_effect_params[k] = params[k] + + # Bind parameters (defaults + overrides) + for pname, pdefault in effect.params.items(): + value = params.get(pname) + if value is None: + # Evaluate default if it's an expression (list) + if isinstance(pdefault, list): + value = self.eval(pdefault, env) + else: + value = pdefault + env.set(pname, value) + + # Bind extra params (unknown params passed through for cell_effect lambdas) + for k, v in extra_effect_params.items(): + env.set(k, v) + + # Reset RNG with seed if provided + seed = params.get('seed', 42) + reset_rng(int(seed)) + + # Bind time if provided + time_val = params.get('_time', 0) + env.set('t', time_val) + env.set('_time', time_val) + + # Evaluate body + result = self.eval(effect.body, env) + + # Ensure result is an image + if not isinstance(result, np.ndarray): + result = frame + + return result, state + + def eval_with_zone(self, expr, env: Environment, zone) -> Any: + """ + Evaluate expression with zone-* variables injected. + + Args: + expr: Expression to evaluate (S-expression) + env: Parent environment with bound values + zone: ZoneContext object with cell data + + Zone variables injected: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: + Evaluated result (typically a number) + """ + # Create child environment with zone variables + zone_env = Environment(env) + zone_env.set('zone-row', zone.row) + zone_env.set('zone-col', zone.col) + zone_env.set('zone-row-norm', zone.row_norm) + zone_env.set('zone-col-norm', zone.col_norm) + zone_env.set('zone-lum', zone.luminance) + zone_env.set('zone-sat', zone.saturation) + zone_env.set('zone-hue', zone.hue) + zone_env.set('zone-r', zone.r) + zone_env.set('zone-g', zone.g) + zone_env.set('zone-b', zone.b) + + return self.eval(expr, zone_env) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + +_interpreter = None +_interpreter_minimal = None + + +def get_interpreter(minimal_primitives: bool = False) -> Interpreter: + """Get or create the global interpreter. + + Args: + minimal_primitives: If True, return interpreter with only core primitives. + Additional primitives must be loaded with require-primitives or with-primitives. + """ + global _interpreter, _interpreter_minimal + + if minimal_primitives: + if _interpreter_minimal is None: + _interpreter_minimal = Interpreter(minimal_primitives=True) + return _interpreter_minimal + else: + if _interpreter is None: + _interpreter = Interpreter(minimal_primitives=False) + return _interpreter + + +def load_effect(path: str) -> EffectDefinition: + """Load an effect from a .sexp file.""" + return get_interpreter().load_effect(path) + + +def load_effects_dir(directory: str): + """Load all .sexp effects from a directory.""" + interp = get_interpreter() + dir_path = Path(directory) + for path in dir_path.glob('*.sexp'): + try: + interp.load_effect(str(path)) + except Exception as e: + print(f"Warning: Failed to load {path}: {e}") + + +def run_effect(name: str, frame: np.ndarray, params: Dict[str, Any], + state: Dict[str, Any] = None) -> tuple: + """Run an effect.""" + return get_interpreter().run_effect(name, frame, params, state or {}) + + +def list_effects() -> List[str]: + """List loaded effect names.""" + return list(get_interpreter().effects.keys()) + + +# ============================================================================= +# Adapter for existing effect system +# ============================================================================= + +def make_process_frame(effect_path: str) -> Callable: + """ + Create a process_frame function from a .sexp effect. + + This allows S-expression effects to be used with the existing + effect system. + """ + interp = get_interpreter() + interp.load_effect(effect_path) + effect_name = Path(effect_path).stem + + def process_frame(frame: np.ndarray, params: dict, state: dict) -> tuple: + return interp.run_effect(effect_name, frame, params, state) + + return process_frame diff --git a/sexp_effects/parser.py b/sexp_effects/parser.py new file mode 100644 index 0000000..12bedfd --- /dev/null +++ b/sexp_effects/parser.py @@ -0,0 +1,168 @@ +""" +S-Expression Parser + +Parses S-expressions into Python data structures: +- Lists become Python lists +- Symbols become Symbol objects +- Numbers become int/float +- Strings become str +- Keywords (:foo) become Keyword objects +""" + +import re +from dataclasses import dataclass +from typing import Any, List, Union + + +@dataclass(frozen=True) +class Symbol: + """A symbol (identifier) in the S-expression.""" + name: str + + def __repr__(self): + return self.name + + +@dataclass(frozen=True) +class Keyword: + """A keyword like :foo in the S-expression.""" + name: str + + def __repr__(self): + return f":{self.name}" + + +# Token patterns +TOKEN_PATTERNS = [ + (r'\s+', None), # Whitespace (skip) + (r';[^\n]*', None), # Comments (skip) + (r'\(', 'LPAREN'), + (r'\)', 'RPAREN'), + (r'\[', 'LBRACKET'), + (r'\]', 'RBRACKET'), + (r"'", 'QUOTE'), + (r'"([^"\\]|\\.)*"', 'STRING'), + (r':[a-zA-Z_][a-zA-Z0-9_\-]*', 'KEYWORD'), + (r'-?[0-9]+\.[0-9]+', 'FLOAT'), + (r'-?[0-9]+', 'INT'), + (r'#t|#f|true|false', 'BOOL'), + (r'[a-zA-Z_+\-*/<>=!?][a-zA-Z0-9_+\-*/<>=!?]*', 'SYMBOL'), +] + +TOKEN_REGEX = '|'.join(f'(?P<{name}>{pattern})' if name else f'(?:{pattern})' + for pattern, name in TOKEN_PATTERNS) + + +def tokenize(source: str) -> List[tuple]: + """Tokenize S-expression source code.""" + tokens = [] + for match in re.finditer(TOKEN_REGEX, source): + kind = match.lastgroup + value = match.group() + if kind: + tokens.append((kind, value)) + return tokens + + +def parse(source: str) -> Any: + """Parse S-expression source into Python data structures.""" + tokens = tokenize(source) + pos = [0] # Use list for mutability in nested function + + def parse_expr(): + if pos[0] >= len(tokens): + raise SyntaxError("Unexpected end of input") + + kind, value = tokens[pos[0]] + + if kind == 'LPAREN': + pos[0] += 1 + items = [] + while pos[0] < len(tokens) and tokens[pos[0]][0] != 'RPAREN': + items.append(parse_expr()) + if pos[0] >= len(tokens): + raise SyntaxError("Missing closing parenthesis") + pos[0] += 1 # Skip RPAREN + return items + + if kind == 'LBRACKET': + pos[0] += 1 + items = [] + while pos[0] < len(tokens) and tokens[pos[0]][0] != 'RBRACKET': + items.append(parse_expr()) + if pos[0] >= len(tokens): + raise SyntaxError("Missing closing bracket") + pos[0] += 1 # Skip RBRACKET + return items + + elif kind == 'RPAREN': + raise SyntaxError("Unexpected closing parenthesis") + + elif kind == 'QUOTE': + pos[0] += 1 + return [Symbol('quote'), parse_expr()] + + elif kind == 'STRING': + pos[0] += 1 + # Remove quotes and unescape + return value[1:-1].replace('\\"', '"').replace('\\n', '\n') + + elif kind == 'INT': + pos[0] += 1 + return int(value) + + elif kind == 'FLOAT': + pos[0] += 1 + return float(value) + + elif kind == 'BOOL': + pos[0] += 1 + return value in ('#t', 'true') + + elif kind == 'KEYWORD': + pos[0] += 1 + return Keyword(value[1:]) # Remove leading : + + elif kind == 'SYMBOL': + pos[0] += 1 + return Symbol(value) + + else: + raise SyntaxError(f"Unknown token: {kind} {value}") + + result = parse_expr() + + # Check for multiple top-level expressions + if pos[0] < len(tokens): + # Allow multiple top-level expressions, return as list + results = [result] + while pos[0] < len(tokens): + results.append(parse_expr()) + return results + + return result + + +def parse_file(path: str) -> Any: + """Parse an S-expression file.""" + with open(path, 'r') as f: + return parse(f.read()) + + +# Convenience for pretty-printing +def to_sexp(obj: Any) -> str: + """Convert Python object back to S-expression string.""" + if isinstance(obj, list): + return '(' + ' '.join(to_sexp(x) for x in obj) + ')' + elif isinstance(obj, Symbol): + return obj.name + elif isinstance(obj, Keyword): + return f':{obj.name}' + elif isinstance(obj, str): + return f'"{obj}"' + elif isinstance(obj, bool): + return '#t' if obj else '#f' + elif isinstance(obj, (int, float)): + return str(obj) + else: + return repr(obj) diff --git a/sexp_effects/primitive_libs/__init__.py b/sexp_effects/primitive_libs/__init__.py new file mode 100644 index 0000000..47ee174 --- /dev/null +++ b/sexp_effects/primitive_libs/__init__.py @@ -0,0 +1,102 @@ +""" +Primitive Libraries System + +Provides modular loading of primitives. Core primitives are always available, +additional primitive libraries can be loaded on-demand with scoped availability. + +Usage in sexp: + ;; Load at recipe level - available throughout + (primitives math :path "primitive_libs/math.py") + + ;; Or use with-primitives for scoped access + (with-primitives "image" + (blur frame 3)) ;; blur only available inside + + ;; Nested scopes work + (with-primitives "math" + (with-primitives "color" + (hue-shift frame (* (sin t) 30)))) + +Library file format (primitive_libs/math.py): + import math + + def prim_sin(x): return math.sin(x) + def prim_cos(x): return math.cos(x) + + PRIMITIVES = { + 'sin': prim_sin, + 'cos': prim_cos, + } +""" + +import importlib.util +from pathlib import Path +from typing import Dict, Callable, Any, Optional + +# Cache of loaded primitive libraries +_library_cache: Dict[str, Dict[str, Any]] = {} + +# Core primitives - always available, cannot be overridden +CORE_PRIMITIVES: Dict[str, Any] = {} + + +def register_core_primitive(name: str, fn: Callable): + """Register a core primitive that's always available.""" + CORE_PRIMITIVES[name] = fn + + +def load_primitive_library(name: str, path: Optional[str] = None) -> Dict[str, Any]: + """ + Load a primitive library by name or path. + + Args: + name: Library name (e.g., "math", "image", "color") + path: Optional explicit path to library file + + Returns: + Dict of primitive name -> function + """ + # Check cache first + cache_key = path or name + if cache_key in _library_cache: + return _library_cache[cache_key] + + # Find library file + if path: + lib_path = Path(path) + else: + # Look in standard locations + lib_dir = Path(__file__).parent + lib_path = lib_dir / f"{name}.py" + + if not lib_path.exists(): + raise ValueError(f"Primitive library '{name}' not found at {lib_path}") + + if not lib_path.exists(): + raise ValueError(f"Primitive library file not found: {lib_path}") + + # Load the module + spec = importlib.util.spec_from_file_location(f"prim_lib_{name}", lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get PRIMITIVES dict from module + if not hasattr(module, 'PRIMITIVES'): + raise ValueError(f"Primitive library '{name}' missing PRIMITIVES dict") + + primitives = module.PRIMITIVES + + # Cache and return + _library_cache[cache_key] = primitives + return primitives + + +def get_library_names() -> list: + """Get names of available primitive libraries.""" + lib_dir = Path(__file__).parent + return [p.stem for p in lib_dir.glob("*.py") if p.stem != "__init__"] + + +def clear_cache(): + """Clear the library cache (useful for testing).""" + _library_cache.clear() diff --git a/sexp_effects/primitive_libs/arrays.py b/sexp_effects/primitive_libs/arrays.py new file mode 100644 index 0000000..61da196 --- /dev/null +++ b/sexp_effects/primitive_libs/arrays.py @@ -0,0 +1,196 @@ +""" +Array Primitives Library + +Vectorized operations on numpy arrays for coordinate transformations. +""" +import numpy as np + + +# Arithmetic +def prim_arr_add(a, b): + return np.add(a, b) + + +def prim_arr_sub(a, b): + return np.subtract(a, b) + + +def prim_arr_mul(a, b): + return np.multiply(a, b) + + +def prim_arr_div(a, b): + return np.divide(a, b) + + +def prim_arr_mod(a, b): + return np.mod(a, b) + + +def prim_arr_neg(a): + return np.negative(a) + + +# Math functions +def prim_arr_sin(a): + return np.sin(a) + + +def prim_arr_cos(a): + return np.cos(a) + + +def prim_arr_tan(a): + return np.tan(a) + + +def prim_arr_sqrt(a): + return np.sqrt(np.maximum(a, 0)) + + +def prim_arr_pow(a, b): + return np.power(a, b) + + +def prim_arr_abs(a): + return np.abs(a) + + +def prim_arr_exp(a): + return np.exp(a) + + +def prim_arr_log(a): + return np.log(np.maximum(a, 1e-10)) + + +def prim_arr_atan2(y, x): + return np.arctan2(y, x) + + +# Comparison / selection +def prim_arr_min(a, b): + return np.minimum(a, b) + + +def prim_arr_max(a, b): + return np.maximum(a, b) + + +def prim_arr_clip(a, lo, hi): + return np.clip(a, lo, hi) + + +def prim_arr_where(cond, a, b): + return np.where(cond, a, b) + + +def prim_arr_floor(a): + return np.floor(a) + + +def prim_arr_ceil(a): + return np.ceil(a) + + +def prim_arr_round(a): + return np.round(a) + + +# Interpolation +def prim_arr_lerp(a, b, t): + return a + (b - a) * t + + +def prim_arr_smoothstep(edge0, edge1, x): + t = prim_arr_clip((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +# Creation +def prim_arr_zeros(shape): + return np.zeros(shape, dtype=np.float32) + + +def prim_arr_ones(shape): + return np.ones(shape, dtype=np.float32) + + +def prim_arr_full(shape, value): + return np.full(shape, value, dtype=np.float32) + + +def prim_arr_arange(start, stop, step=1): + return np.arange(start, stop, step, dtype=np.float32) + + +def prim_arr_linspace(start, stop, num): + return np.linspace(start, stop, num, dtype=np.float32) + + +def prim_arr_meshgrid(x, y): + return np.meshgrid(x, y) + + +# Coordinate transforms +def prim_polar_from_center(map_x, map_y, cx, cy): + """Convert Cartesian to polar coordinates centered at (cx, cy).""" + dx = map_x - cx + dy = map_y - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + return (r, theta) + + +def prim_cart_from_polar(r, theta, cx, cy): + """Convert polar to Cartesian, adding center offset.""" + x = r * np.cos(theta) + cx + y = r * np.sin(theta) + cy + return (x, y) + + +PRIMITIVES = { + # Arithmetic + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-neg': prim_arr_neg, + + # Math + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-exp': prim_arr_exp, + 'arr-log': prim_arr_log, + 'arr-atan2': prim_arr_atan2, + + # Selection + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-ceil': prim_arr_ceil, + 'arr-round': prim_arr_round, + + # Interpolation + 'arr-lerp': prim_arr_lerp, + 'arr-smoothstep': prim_arr_smoothstep, + + # Creation + 'arr-zeros': prim_arr_zeros, + 'arr-ones': prim_arr_ones, + 'arr-full': prim_arr_full, + 'arr-arange': prim_arr_arange, + 'arr-linspace': prim_arr_linspace, + 'arr-meshgrid': prim_arr_meshgrid, + + # Coordinates + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, +} diff --git a/sexp_effects/primitive_libs/ascii.py b/sexp_effects/primitive_libs/ascii.py new file mode 100644 index 0000000..858f010 --- /dev/null +++ b/sexp_effects/primitive_libs/ascii.py @@ -0,0 +1,388 @@ +""" +ASCII Art Primitives Library + +ASCII art rendering with per-zone expression evaluation and cell effects. +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont +from typing import Any, Dict, List, Optional, Callable +import colorsys + + +# Character sets +CHAR_SETS = { + "standard": " .:-=+*#%@", + "blocks": " ░▒▓█", + "simple": " .:oO@", + "digits": "0123456789", + "binary": "01", + "ascii": " `.-':_,^=;><+!rc*/z?sLTv)J7(|Fi{C}fI31tlu[neoZ5Yxjya]2ESwqkP6h9d4VpOGbUAKXHm8RD#$Bg0MNWQ%&@", +} + +# Default font +_default_font = None + + +def _get_font(size: int): + """Get monospace font at given size.""" + global _default_font + try: + return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) + except: + return ImageFont.load_default() + + +def _parse_color(color_str: str) -> tuple: + """Parse color string to RGB tuple.""" + if color_str.startswith('#'): + hex_color = color_str[1:] + if len(hex_color) == 3: + hex_color = ''.join(c*2 for c in hex_color) + return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + + colors = { + 'black': (0, 0, 0), 'white': (255, 255, 255), + 'red': (255, 0, 0), 'green': (0, 255, 0), 'blue': (0, 0, 255), + 'yellow': (255, 255, 0), 'cyan': (0, 255, 255), 'magenta': (255, 0, 255), + 'gray': (128, 128, 128), 'grey': (128, 128, 128), + } + return colors.get(color_str.lower(), (0, 0, 0)) + + +def _cell_sample(frame: np.ndarray, cell_size: int): + """Sample frame into cells, returning colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + """ + h, w = frame.shape[:2] + rows = h // cell_size + cols = w // cell_size + + # Crop to exact grid then block-average via cv2 area interpolation. + cropped = frame[:rows * cell_size, :cols * cell_size] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + luminances = ((0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]) / 255.0).astype(np.float32) + + return colors, luminances + + +def _luminance_to_char(lum: float, alphabet: str, contrast: float) -> str: + """Map luminance to character.""" + chars = CHAR_SETS.get(alphabet, alphabet) + lum = ((lum - 0.5) * contrast + 0.5) + lum = max(0, min(1, lum)) + idx = int(lum * (len(chars) - 1)) + return chars[idx] + + +def _render_char_cell(char: str, cell_size: int, color: tuple, bg_color: tuple) -> np.ndarray: + """Render a single character to a cell image.""" + img = Image.new('RGB', (cell_size, cell_size), bg_color) + draw = ImageDraw.Draw(img) + font = _get_font(cell_size) + + # Center the character + bbox = draw.textbbox((0, 0), char, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + x = (cell_size - text_w) // 2 + y = (cell_size - text_h) // 2 - bbox[1] + + draw.text((x, y), char, fill=color, font=font) + return np.array(img) + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int = 80, + char_size: int = None, + alphabet: str = "standard", + color_mode: str = "color", + background: str = "black", + contrast: float = 1.5, + char_hue = None, + char_saturation = None, + char_brightness = None, + char_scale = None, + char_rotation = None, + char_jitter = None, + cell_effect = None, + energy: float = None, + rotation_scale: float = 0, + _interp = None, + _env = None, + **extra_params +) -> np.ndarray: + """ + Render frame as ASCII art with per-zone effects. + + Args: + frame: Input image + cols: Number of character columns + char_size: Cell size in pixels (overrides cols if set) + alphabet: Character set name or custom string + color_mode: "color", "mono", "invert", or color name + background: Background color name or hex + contrast: Contrast for character selection + char_hue/saturation/brightness/scale/rotation/jitter: Per-zone expressions + cell_effect: Lambda (cell, zone) -> cell for per-cell effects + energy: Energy value from audio analysis + rotation_scale: Max rotation degrees + _interp: Interpreter (auto-injected) + _env: Environment (auto-injected) + **extra_params: Additional params passed to zone dict + """ + h, w = frame.shape[:2] + + # Calculate cell size + if char_size is None or char_size == 0: + cell_size = max(4, w // cols) + else: + cell_size = max(4, int(char_size)) + + # Sample cells + colors, luminances = _cell_sample(frame, cell_size) + rows, cols_actual = luminances.shape + + # Parse background color + bg_color = _parse_color(background) + + # Create output image + out_h = rows * cell_size + out_w = cols_actual * cell_size + output = np.full((out_h, out_w, 3), bg_color, dtype=np.uint8) + + # Check if we have cell_effect + has_cell_effect = cell_effect is not None + + # Process each cell + for r in range(rows): + for c in range(cols_actual): + lum = luminances[r, c] + cell_color = tuple(colors[r, c]) + + # Build zone context + zone = { + 'row': r, + 'col': c, + 'row-norm': r / max(1, rows - 1), + 'col-norm': c / max(1, cols_actual - 1), + 'lum': float(lum), + 'r': cell_color[0] / 255, + 'g': cell_color[1] / 255, + 'b': cell_color[2] / 255, + 'cell_size': cell_size, + } + + # Add HSV + r_f, g_f, b_f = cell_color[0]/255, cell_color[1]/255, cell_color[2]/255 + hsv = colorsys.rgb_to_hsv(r_f, g_f, b_f) + zone['hue'] = hsv[0] * 360 + zone['sat'] = hsv[1] + + # Add energy and rotation_scale + if energy is not None: + zone['energy'] = energy + zone['rotation_scale'] = rotation_scale + + # Add extra params + for k, v in extra_params.items(): + if isinstance(v, (int, float, str, bool)) or v is None: + zone[k] = v + + # Get character + char = _luminance_to_char(lum, alphabet, contrast) + zone['char'] = char + + # Determine cell color based on mode + if color_mode == "mono": + render_color = (255, 255, 255) + elif color_mode == "invert": + render_color = tuple(255 - c for c in cell_color) + elif color_mode == "color": + render_color = cell_color + else: + render_color = _parse_color(color_mode) + + zone['color'] = render_color + + # Render character to cell + cell_img = _render_char_cell(char, cell_size, render_color, bg_color) + + # Apply cell_effect if provided + if has_cell_effect and _interp is not None: + cell_img = _apply_cell_effect(cell_img, zone, cell_effect, _interp, _env, extra_params) + + # Paste cell to output + y1, y2 = r * cell_size, (r + 1) * cell_size + x1, x2 = c * cell_size, (c + 1) * cell_size + output[y1:y2, x1:x2] = cell_img + + # Resize to match input dimensions + if output.shape[:2] != frame.shape[:2]: + output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LINEAR) + + return output + + +def _apply_cell_effect(cell_img, zone, cell_effect, interp, env, extra_params): + """Apply cell_effect lambda to a cell image. + + cell_effect is a Lambda object with params and body. + We create a child environment with zone variables and cell, + then evaluate the lambda body. + """ + # Get Environment class from the interpreter's module + Environment = type(env) + + # Create child environment with zone variables + cell_env = Environment(env) + + # Bind zone variables + for k, v in zone.items(): + cell_env.set(k, v) + + # Also bind with zone- prefix for consistency + cell_env.set('zone-row', zone.get('row', 0)) + cell_env.set('zone-col', zone.get('col', 0)) + cell_env.set('zone-row-norm', zone.get('row-norm', 0)) + cell_env.set('zone-col-norm', zone.get('col-norm', 0)) + cell_env.set('zone-lum', zone.get('lum', 0)) + cell_env.set('zone-sat', zone.get('sat', 0)) + cell_env.set('zone-hue', zone.get('hue', 0)) + cell_env.set('zone-r', zone.get('r', 0)) + cell_env.set('zone-g', zone.get('g', 0)) + cell_env.set('zone-b', zone.get('b', 0)) + + # Inject loaded effects as callable functions + if hasattr(interp, 'effects'): + for effect_name in interp.effects: + def make_effect_fn(name): + def effect_fn(frame, *args): + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 2: + params['offset_x'] = args[0] + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['size'] = args[0] + elif name == 'invert': + pass + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + cell_env.set(effect_name, make_effect_fn(effect_name)) + + # Bind cell image and zone dict + cell_env.set('cell', cell_img) + cell_env.set('zone', zone) + + # Evaluate the cell_effect lambda + # Lambda has params and body - we need to bind the params then evaluate + if hasattr(cell_effect, 'params') and hasattr(cell_effect, 'body'): + # Bind lambda parameters: (lambda [cell zone] body) + if len(cell_effect.params) >= 1: + cell_env.set(cell_effect.params[0], cell_img) + if len(cell_effect.params) >= 2: + cell_env.set(cell_effect.params[1], zone) + + result = interp.eval(cell_effect.body, cell_env) + elif isinstance(cell_effect, list): + # Raw S-expression lambda like (lambda [cell zone] body) or (fn [cell zone] body) + # Check if it's a lambda expression + head = cell_effect[0] if cell_effect else None + head_name = head.name if head and hasattr(head, 'name') else str(head) if head else None + is_lambda = head_name in ('lambda', 'fn') + + if is_lambda: + # (lambda [params...] body) + params = cell_effect[1] if len(cell_effect) > 1 else [] + body = cell_effect[2] if len(cell_effect) > 2 else None + + # Bind lambda parameters + if isinstance(params, list) and len(params) >= 1: + param_name = params[0].name if hasattr(params[0], 'name') else str(params[0]) + cell_env.set(param_name, cell_img) + if isinstance(params, list) and len(params) >= 2: + param_name = params[1].name if hasattr(params[1], 'name') else str(params[1]) + cell_env.set(param_name, zone) + + result = interp.eval(body, cell_env) if body else cell_img + else: + # Some other expression - just evaluate it + result = interp.eval(cell_effect, cell_env) + elif callable(cell_effect): + # It's a callable + result = cell_effect(cell_img, zone) + else: + raise ValueError(f"cell_effect must be a Lambda, list, or callable, got {type(cell_effect)}") + + if isinstance(result, np.ndarray) and result.shape == cell_img.shape: + return result + elif isinstance(result, np.ndarray): + # Shape mismatch - resize to fit + result = cv2.resize(result, (cell_img.shape[1], cell_img.shape[0])) + return result + + raise ValueError(f"cell_effect must return an image array, got {type(result)}") + + +def _get_legacy_ascii_primitives(): + """Import ASCII primitives from legacy primitives module. + + These are loaded lazily to avoid import issues during module loading. + By the time a primitive library is loaded, sexp_effects.primitives + is already in sys.modules (imported by sexp_effects.__init__). + """ + from sexp_effects.primitives import ( + prim_cell_sample, + prim_luminance_to_chars, + prim_render_char_grid, + prim_render_char_grid_fx, + prim_alphabet_char, + prim_alphabet_length, + prim_map_char_grid, + prim_map_colors, + prim_make_char_grid, + prim_set_char, + prim_get_char, + prim_char_grid_dimensions, + cell_sample_extended, + ) + return { + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + } + + +PRIMITIVES = { + 'ascii-fx-zone': prim_ascii_fx_zone, + **_get_legacy_ascii_primitives(), +} diff --git a/sexp_effects/primitive_libs/blending.py b/sexp_effects/primitive_libs/blending.py new file mode 100644 index 0000000..0bf345d --- /dev/null +++ b/sexp_effects/primitive_libs/blending.py @@ -0,0 +1,116 @@ +""" +Blending Primitives Library + +Image blending and compositing operations. +""" +import numpy as np + + +def prim_blend_images(a, b, alpha): + """Blend two images: a * (1-alpha) + b * alpha.""" + alpha = max(0.0, min(1.0, alpha)) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a, b, mode): + """Blend using Photoshop-style blend modes.""" + a = a.astype(float) / 255 + b = b.astype(float) / 255 + + if mode == "multiply": + result = a * b + elif mode == "screen": + result = 1 - (1 - a) * (1 - b) + elif mode == "overlay": + mask = a < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "soft-light": + mask = b < 0.5 + result = np.where(mask, + a - (1 - 2 * b) * a * (1 - a), + a + (2 * b - 1) * (np.sqrt(a) - a)) + elif mode == "hard-light": + mask = b < 0.5 + result = np.where(mask, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)) + elif mode == "color-dodge": + result = np.clip(a / (1 - b + 0.001), 0, 1) + elif mode == "color-burn": + result = 1 - np.clip((1 - a) / (b + 0.001), 0, 1) + elif mode == "difference": + result = np.abs(a - b) + elif mode == "exclusion": + result = a + b - 2 * a * b + elif mode == "add": + result = np.clip(a + b, 0, 1) + elif mode == "subtract": + result = np.clip(a - b, 0, 1) + elif mode == "darken": + result = np.minimum(a, b) + elif mode == "lighten": + result = np.maximum(a, b) + else: + # Default to normal (just return b) + result = b + + return (result * 255).astype(np.uint8) + + +def prim_mask(img, mask_img): + """Apply grayscale mask to image (white=opaque, black=transparent).""" + if len(mask_img.shape) == 3: + mask = mask_img[:, :, 0].astype(float) / 255 + else: + mask = mask_img.astype(float) / 255 + + mask = mask[:, :, np.newaxis] + return (img.astype(float) * mask).astype(np.uint8) + + +def prim_alpha_composite(base, overlay, alpha_channel): + """Composite overlay onto base using alpha channel.""" + if len(alpha_channel.shape) == 3: + alpha = alpha_channel[:, :, 0].astype(float) / 255 + else: + alpha = alpha_channel.astype(float) / 255 + + alpha = alpha[:, :, np.newaxis] + result = base.astype(float) * (1 - alpha) + overlay.astype(float) * alpha + return result.astype(np.uint8) + + +def prim_overlay(base, overlay, x, y, alpha=1.0): + """Overlay image at position (x, y) with optional alpha.""" + result = base.copy() + x, y = int(x), int(y) + oh, ow = overlay.shape[:2] + bh, bw = base.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(ow, bw - x) + sy2 = min(oh, bh - y) + + if sx2 > sx1 and sy2 > sy1: + src = overlay[sy1:sy2, sx1:sx2] + dst = result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] + blended = (dst.astype(float) * (1 - alpha) + src.astype(float) * alpha) + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = blended.astype(np.uint8) + + return result + + +PRIMITIVES = { + # Basic blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + + # Masking + 'mask': prim_mask, + 'alpha-composite': prim_alpha_composite, + + # Overlay + 'overlay': prim_overlay, +} diff --git a/sexp_effects/primitive_libs/color.py b/sexp_effects/primitive_libs/color.py new file mode 100644 index 0000000..0b6854b --- /dev/null +++ b/sexp_effects/primitive_libs/color.py @@ -0,0 +1,137 @@ +""" +Color Primitives Library + +Color manipulation: RGB, HSV, blending, luminance. +""" +import numpy as np +import colorsys + + +def prim_rgb(r, g, b): + """Create RGB color as [r, g, b] (0-255).""" + return [int(max(0, min(255, r))), + int(max(0, min(255, g))), + int(max(0, min(255, b)))] + + +def prim_red(c): + return c[0] + + +def prim_green(c): + return c[1] + + +def prim_blue(c): + return c[2] + + +def prim_luminance(c): + """Perceived luminance (0-1) using standard weights.""" + return (0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) / 255 + + +def prim_rgb_to_hsv(c): + """Convert RGB [0-255] to HSV [h:0-360, s:0-1, v:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, s, v = colorsys.rgb_to_hsv(r, g, b) + return [h * 360, s, v] + + +def prim_hsv_to_rgb(hsv): + """Convert HSV [h:0-360, s:0-1, v:0-1] to RGB [0-255].""" + h, s, v = hsv[0] / 360, hsv[1], hsv[2] + r, g, b = colorsys.hsv_to_rgb(h, s, v) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_rgb_to_hsl(c): + """Convert RGB [0-255] to HSL [h:0-360, s:0-1, l:0-1].""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + h, l, s = colorsys.rgb_to_hls(r, g, b) + return [h * 360, s, l] + + +def prim_hsl_to_rgb(hsl): + """Convert HSL [h:0-360, s:0-1, l:0-1] to RGB [0-255].""" + h, s, l = hsl[0] / 360, hsl[1], hsl[2] + r, g, b = colorsys.hls_to_rgb(h, l, s) + return [int(r * 255), int(g * 255), int(b * 255)] + + +def prim_blend_color(c1, c2, alpha): + """Blend two colors: c1 * (1-alpha) + c2 * alpha.""" + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img): + """Get average color of an image.""" + mean = np.mean(img, axis=(0, 1)) + return [int(mean[0]), int(mean[1]), int(mean[2])] + + +def prim_dominant_color(img, k=1): + """Get dominant color using k-means (simplified: just average for now).""" + return prim_average_color(img) + + +def prim_invert_color(c): + """Invert a color.""" + return [255 - c[0], 255 - c[1], 255 - c[2]] + + +def prim_grayscale_color(c): + """Convert color to grayscale.""" + gray = int(0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2]) + return [gray, gray, gray] + + +def prim_saturate(c, amount): + """Adjust saturation of color. amount=0 is grayscale, 1 is unchanged, >1 is more saturated.""" + hsv = prim_rgb_to_hsv(c) + hsv[1] = max(0, min(1, hsv[1] * amount)) + return prim_hsv_to_rgb(hsv) + + +def prim_brighten(c, amount): + """Adjust brightness. amount=0 is black, 1 is unchanged, >1 is brighter.""" + return [int(max(0, min(255, c[i] * amount))) for i in range(3)] + + +def prim_shift_hue(c, degrees): + """Shift hue by degrees.""" + hsv = prim_rgb_to_hsv(c) + hsv[0] = (hsv[0] + degrees) % 360 + return prim_hsv_to_rgb(hsv) + + +PRIMITIVES = { + # Construction + 'rgb': prim_rgb, + + # Component access + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + + # Color space conversion + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'rgb->hsl': prim_rgb_to_hsl, + 'hsl->rgb': prim_hsl_to_rgb, + + # Blending + 'blend-color': prim_blend_color, + + # Analysis + 'average-color': prim_average_color, + 'dominant-color': prim_dominant_color, + + # Manipulation + 'invert-color': prim_invert_color, + 'grayscale-color': prim_grayscale_color, + 'saturate': prim_saturate, + 'brighten': prim_brighten, + 'shift-hue': prim_shift_hue, +} diff --git a/sexp_effects/primitive_libs/color_ops.py b/sexp_effects/primitive_libs/color_ops.py new file mode 100644 index 0000000..dd9076c --- /dev/null +++ b/sexp_effects/primitive_libs/color_ops.py @@ -0,0 +1,90 @@ +""" +Color Operations Primitives Library + +Vectorized color adjustments: brightness, contrast, saturation, invert, HSV. +These operate on entire images for fast processing. +""" +import numpy as np +import cv2 + + +def prim_adjust(img, brightness=0, contrast=1): + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img, amount): + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img): + """Invert all pixel values.""" + return (255 - img).astype(np.uint8) + + +def prim_shift_hsv(img, h=0, s=1, v=1): + """Shift HSV: h=degrees offset, s/v=multipliers.""" + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +def prim_add_noise(img, amount): + """Add gaussian noise to image.""" + noise = np.random.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img, levels): + """Reduce to N color levels per channel.""" + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_sepia(img, intensity=1.0): + """Apply sepia tone effect.""" + sepia_matrix = np.array([ + [0.393, 0.769, 0.189], + [0.349, 0.686, 0.168], + [0.272, 0.534, 0.131] + ]) + sepia = np.dot(img, sepia_matrix.T) + result = img.astype(np.float32) * (1 - intensity) + sepia * intensity + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_grayscale(img): + """Convert to grayscale (still RGB output).""" + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + return np.stack([gray, gray, gray], axis=-1).astype(np.uint8) + + +PRIMITIVES = { + # Brightness/Contrast + 'adjust': prim_adjust, + + # Saturation + 'mix-gray': prim_mix_gray, + 'grayscale': prim_grayscale, + + # HSV manipulation + 'shift-hsv': prim_shift_hsv, + + # Inversion + 'invert-img': prim_invert_img, + + # Effects + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'sepia': prim_sepia, +} diff --git a/sexp_effects/primitive_libs/core.py b/sexp_effects/primitive_libs/core.py new file mode 100644 index 0000000..352cbd3 --- /dev/null +++ b/sexp_effects/primitive_libs/core.py @@ -0,0 +1,271 @@ +""" +Core Primitives - Always available, minimal essential set. + +These are the primitives that form the foundation of the language. +They cannot be overridden by libraries. +""" + + +# Arithmetic +def prim_add(*args): + if len(args) == 0: + return 0 + result = args[0] + for arg in args[1:]: + result = result + arg + return result + + +def prim_sub(a, b=None): + if b is None: + return -a + return a - b + + +def prim_mul(*args): + if len(args) == 0: + return 1 + result = args[0] + for arg in args[1:]: + result = result * arg + return result + + +def prim_div(a, b): + return a / b + + +def prim_mod(a, b): + return a % b + + +def prim_abs(x): + return abs(x) + + +def prim_min(*args): + return min(args) + + +def prim_max(*args): + return max(args) + + +def prim_round(x): + return round(x) + + +def prim_floor(x): + import math + return math.floor(x) + + +def prim_ceil(x): + import math + return math.ceil(x) + + +# Comparison +def prim_lt(a, b): + return a < b + + +def prim_gt(a, b): + return a > b + + +def prim_le(a, b): + return a <= b + + +def prim_ge(a, b): + return a >= b + + +def prim_eq(a, b): + if isinstance(a, float) or isinstance(b, float): + return abs(a - b) < 1e-9 + return a == b + + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# Logic +def prim_not(x): + return not x + + +def prim_and(*args): + for a in args: + if not a: + return False + return True + + +def prim_or(*args): + for a in args: + if a: + return True + return False + + +# Basic data access +def prim_get(obj, key, default=None): + """Get value from dict or list.""" + if isinstance(obj, dict): + return obj.get(key, default) + elif isinstance(obj, (list, tuple)): + try: + return obj[int(key)] + except (IndexError, ValueError): + return default + return default + + +def prim_nth(seq, i): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_length(seq): + return len(seq) + + +def prim_list(*args): + return list(args) + + +# Type checking +def prim_is_number(x): + return isinstance(x, (int, float)) + + +def prim_is_string(x): + return isinstance(x, str) + + +def prim_is_list(x): + return isinstance(x, (list, tuple)) + + +def prim_is_dict(x): + return isinstance(x, dict) + + +def prim_is_nil(x): + return x is None + + +# Higher-order / iteration +def prim_reduce(seq, init, fn): + """(reduce seq init fn) — fold left: fn(fn(fn(init, s0), s1), s2) ...""" + acc = init + for item in seq: + acc = fn(acc, item) + return acc + + +def prim_map(seq, fn): + """(map seq fn) — apply fn to each element, return new list.""" + return [fn(item) for item in seq] + + +def prim_range(*args): + """(range end), (range start end), or (range start end step) — integer range.""" + if len(args) == 1: + return list(range(int(args[0]))) + elif len(args) == 2: + return list(range(int(args[0]), int(args[1]))) + elif len(args) >= 3: + return list(range(int(args[0]), int(args[1]), int(args[2]))) + return [] + + +# Random +import random +_rng = random.Random() + +def prim_rand(): + """Return random float in [0, 1).""" + return _rng.random() + +def prim_rand_int(lo, hi): + """Return random integer in [lo, hi].""" + return _rng.randint(int(lo), int(hi)) + +def prim_rand_range(lo, hi): + """Return random float in [lo, hi).""" + return lo + _rng.random() * (hi - lo) + +def prim_map_range(val, from_lo, from_hi, to_lo, to_hi): + """Map value from one range to another.""" + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + +# Core primitives dict +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + 'mod': prim_mod, + 'abs': prim_abs, + 'min': prim_min, + 'max': prim_max, + 'round': prim_round, + 'floor': prim_floor, + 'ceil': prim_ceil, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Logic + 'not': prim_not, + 'and': prim_and, + 'or': prim_or, + + # Data access + 'get': prim_get, + 'nth': prim_nth, + 'first': prim_first, + 'length': prim_length, + 'len': prim_length, + 'list': prim_list, + + # Type predicates + 'number?': prim_is_number, + 'string?': prim_is_string, + 'list?': prim_is_list, + 'dict?': prim_is_dict, + 'nil?': prim_is_nil, + 'is-nil': prim_is_nil, + + # Higher-order / iteration + 'reduce': prim_reduce, + 'fold': prim_reduce, + 'map': prim_map, + 'range': prim_range, + + # Random + 'rand': prim_rand, + 'rand-int': prim_rand_int, + 'rand-range': prim_rand_range, + 'map-range': prim_map_range, +} diff --git a/sexp_effects/primitive_libs/drawing.py b/sexp_effects/primitive_libs/drawing.py new file mode 100644 index 0000000..ddd1a01 --- /dev/null +++ b/sexp_effects/primitive_libs/drawing.py @@ -0,0 +1,136 @@ +""" +Drawing Primitives Library + +Draw shapes, text, and characters on images. +""" +import numpy as np +import cv2 +from PIL import Image, ImageDraw, ImageFont + + +# Default font (will be loaded lazily) +_default_font = None + + +def _get_default_font(size=16): + """Get default font, creating if needed.""" + global _default_font + if _default_font is None or _default_font.size != size: + try: + _default_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) + except: + _default_font = ImageFont.load_default() + return _default_font + + +def prim_draw_char(img, char, x, y, font_size=16, color=None): + """Draw a single character at (x, y).""" + if color is None: + color = [255, 255, 255] + + pil_img = Image.fromarray(img) + draw = ImageDraw.Draw(pil_img) + font = _get_default_font(font_size) + draw.text((x, y), char, fill=tuple(color), font=font) + return np.array(pil_img) + + +def prim_draw_text(img, text, x, y, font_size=16, color=None): + """Draw text string at (x, y).""" + if color is None: + color = [255, 255, 255] + + pil_img = Image.fromarray(img) + draw = ImageDraw.Draw(pil_img) + font = _get_default_font(font_size) + draw.text((x, y), text, fill=tuple(color), font=font) + return np.array(pil_img) + + +def prim_fill_rect(img, x, y, w, h, color=None): + """Fill a rectangle with color.""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + x, y, w, h = int(x), int(y), int(w), int(h) + result[y:y+h, x:x+w] = color + return result + + +def prim_draw_rect(img, x, y, w, h, color=None, thickness=1): + """Draw rectangle outline.""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), + tuple(color), thickness) + return result + + +def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1): + """Draw a line from (x1, y1) to (x2, y2).""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), + tuple(color), thickness) + return result + + +def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, fill=False): + """Draw a circle.""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + t = -1 if fill else thickness + cv2.circle(result, (int(cx), int(cy)), int(radius), tuple(color), t) + return result + + +def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, thickness=1, fill=False): + """Draw an ellipse.""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + t = -1 if fill else thickness + cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), + angle, 0, 360, tuple(color), t) + return result + + +def prim_draw_polygon(img, points, color=None, thickness=1, fill=False): + """Draw a polygon from list of [x, y] points.""" + if color is None: + color = [255, 255, 255] + + result = img.copy() + pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + + if fill: + cv2.fillPoly(result, [pts], tuple(color)) + else: + cv2.polylines(result, [pts], True, tuple(color), thickness) + + return result + + +PRIMITIVES = { + # Text + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + + # Rectangles + 'fill-rect': prim_fill_rect, + 'draw-rect': prim_draw_rect, + + # Lines and shapes + 'draw-line': prim_draw_line, + 'draw-circle': prim_draw_circle, + 'draw-ellipse': prim_draw_ellipse, + 'draw-polygon': prim_draw_polygon, +} diff --git a/sexp_effects/primitive_libs/filters.py b/sexp_effects/primitive_libs/filters.py new file mode 100644 index 0000000..a66f107 --- /dev/null +++ b/sexp_effects/primitive_libs/filters.py @@ -0,0 +1,119 @@ +""" +Filters Primitives Library + +Image filters: blur, sharpen, edges, convolution. +""" +import numpy as np +import cv2 + + +def prim_blur(img, radius): + """Gaussian blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img, radius): + """Box blur with given radius.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_median_blur(img, radius): + """Median blur (good for noise removal).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.medianBlur(img, ksize) + + +def prim_bilateral(img, d=9, sigma_color=75, sigma_space=75): + """Bilateral filter (edge-preserving blur).""" + return cv2.bilateralFilter(img, d, sigma_color, sigma_space) + + +def prim_sharpen(img, amount=1.0): + """Sharpen image using unsharp mask.""" + blurred = cv2.GaussianBlur(img, (0, 0), 3) + return cv2.addWeighted(img, 1.0 + amount, blurred, -amount, 0) + + +def prim_edges(img, low=50, high=150): + """Canny edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, low, high) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img, ksize=3): + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=ksize) + sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=ksize) + mag = np.sqrt(sobelx**2 + sobely**2) + mag = np.clip(mag, 0, 255).astype(np.uint8) + return cv2.cvtColor(mag, cv2.COLOR_GRAY2RGB) + + +def prim_laplacian(img, ksize=3): + """Laplacian edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + lap = cv2.Laplacian(gray, cv2.CV_64F, ksize=ksize) + lap = np.abs(lap) + lap = np.clip(lap, 0, 255).astype(np.uint8) + return cv2.cvtColor(lap, cv2.COLOR_GRAY2RGB) + + +def prim_emboss(img): + """Emboss effect.""" + kernel = np.array([[-2, -1, 0], + [-1, 1, 1], + [ 0, 1, 2]]) + result = cv2.filter2D(img, -1, kernel) + return np.clip(result + 128, 0, 255).astype(np.uint8) + + +def prim_dilate(img, size=1): + """Morphological dilation.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.dilate(img, kernel) + + +def prim_erode(img, size=1): + """Morphological erosion.""" + kernel = np.ones((size * 2 + 1, size * 2 + 1), np.uint8) + return cv2.erode(img, kernel) + + +def prim_convolve(img, kernel): + """Apply custom convolution kernel.""" + kernel = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, kernel) + + +PRIMITIVES = { + # Blur + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'median-blur': prim_median_blur, + 'bilateral': prim_bilateral, + + # Sharpen + 'sharpen': prim_sharpen, + + # Edges + 'edges': prim_edges, + 'sobel': prim_sobel, + 'laplacian': prim_laplacian, + + # Effects + 'emboss': prim_emboss, + + # Morphology + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Custom + 'convolve': prim_convolve, +} diff --git a/sexp_effects/primitive_libs/geometry.py b/sexp_effects/primitive_libs/geometry.py new file mode 100644 index 0000000..5b385a4 --- /dev/null +++ b/sexp_effects/primitive_libs/geometry.py @@ -0,0 +1,143 @@ +""" +Geometry Primitives Library + +Geometric transforms: rotate, scale, flip, translate, remap. +""" +import numpy as np +import cv2 + + +def prim_translate(img, dx, dy): + """Translate image by (dx, dy) pixels.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_rotate(img, angle, cx=None, cy=None): + """Rotate image by angle degrees around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_scale(img, sx, sy, cx=None, cy=None): + """Scale image by (sx, sy) around center (cx, cy).""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + # Build transform matrix + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h)) + + +def prim_flip_h(img): + """Flip image horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img): + """Flip image vertically.""" + return cv2.flip(img, 0) + + +def prim_flip(img, direction="horizontal"): + """Flip image in given direction.""" + if direction in ("horizontal", "h"): + return prim_flip_h(img) + elif direction in ("vertical", "v"): + return prim_flip_v(img) + elif direction in ("both", "hv", "vh"): + return cv2.flip(img, -1) + return img + + +def prim_transpose(img): + """Transpose image (swap x and y).""" + return np.transpose(img, (1, 0, 2)) + + +def prim_remap(img, map_x, map_y): + """Remap image using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), + map_y.astype(np.float32), + cv2.INTER_LINEAR) + + +def prim_make_coords(w, h): + """Create coordinate grids for remapping.""" + x = np.arange(w, dtype=np.float32) + y = np.arange(h, dtype=np.float32) + map_x, map_y = np.meshgrid(x, y) + return (map_x, map_y) + + +def prim_perspective(img, src_pts, dst_pts): + """Apply perspective transform.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getPerspectiveTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpPerspective(img, M, (w, h)) + + +def prim_affine(img, src_pts, dst_pts): + """Apply affine transform using 3 point pairs.""" + src = np.float32(src_pts) + dst = np.float32(dst_pts) + M = cv2.getAffineTransform(src, dst) + h, w = img.shape[:2] + return cv2.warpAffine(img, M, (w, h)) + + +def _get_legacy_geometry_primitives(): + """Import geometry primitives from legacy primitives module.""" + from sexp_effects.primitives import ( + prim_coords_x, + prim_coords_y, + prim_ripple_displace, + prim_fisheye_displace, + prim_kaleidoscope_displace, + ) + return { + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'ripple-displace': prim_ripple_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + } + + +PRIMITIVES = { + # Basic transforms + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + + # Flips + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'flip': prim_flip, + 'transpose': prim_transpose, + + # Remapping + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Advanced transforms + 'perspective': prim_perspective, + 'affine': prim_affine, + + # Displace / coordinate ops (from legacy primitives) + **_get_legacy_geometry_primitives(), +} diff --git a/sexp_effects/primitive_libs/image.py b/sexp_effects/primitive_libs/image.py new file mode 100644 index 0000000..beae3ce --- /dev/null +++ b/sexp_effects/primitive_libs/image.py @@ -0,0 +1,144 @@ +""" +Image Primitives Library + +Basic image operations: dimensions, pixels, resize, crop, paste. +""" +import numpy as np +import cv2 + + +def prim_width(img): + return img.shape[1] + + +def prim_height(img): + return img.shape[0] + + +def prim_make_image(w, h, color=None): + """Create a new image filled with color (default black).""" + if color is None: + color = [0, 0, 0] + img = np.zeros((h, w, 3), dtype=np.uint8) + img[:] = color + return img + + +def prim_copy(img): + return img.copy() + + +def prim_pixel(img, x, y): + """Get pixel color at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + if 0 <= x < w and 0 <= y < h: + return list(img[int(y), int(x)]) + return [0, 0, 0] + + +def prim_set_pixel(img, x, y, color): + """Set pixel at (x, y) to color, returns modified image.""" + result = img.copy() + h, w = result.shape[:2] + if 0 <= x < w and 0 <= y < h: + result[int(y), int(x)] = color + return result + + +def prim_sample(img, x, y): + """Bilinear sample at float coordinates, returns [r, g, b] as floats.""" + h, w = img.shape[:2] + x = max(0, min(w - 1.001, x)) + y = max(0, min(h - 1.001, y)) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + top = c00 * (1 - fx) + c10 * fx + bottom = c01 * (1 - fx) + c11 * fx + return list(top * (1 - fy) + bottom * fy) + + +def prim_channel(img, c): + """Extract single channel (0=R, 1=G, 2=B).""" + return img[:, :, c] + + +def prim_merge_channels(r, g, b): + """Merge three single-channel arrays into RGB image.""" + return np.stack([r, g, b], axis=2).astype(np.uint8) + + +def prim_resize(img, w, h, mode="linear"): + """Resize image to w x h.""" + interp = cv2.INTER_LINEAR + if mode == "nearest": + interp = cv2.INTER_NEAREST + elif mode == "cubic": + interp = cv2.INTER_CUBIC + elif mode == "area": + interp = cv2.INTER_AREA + return cv2.resize(img, (int(w), int(h)), interpolation=interp) + + +def prim_crop(img, x, y, w, h): + """Crop rectangle from image.""" + x, y, w, h = int(x), int(y), int(w), int(h) + ih, iw = img.shape[:2] + x = max(0, min(x, iw - 1)) + y = max(0, min(y, ih - 1)) + w = min(w, iw - x) + h = min(h, ih - y) + return img[y:y+h, x:x+w].copy() + + +def prim_paste(dst, src, x, y): + """Paste src onto dst at position (x, y).""" + result = dst.copy() + x, y = int(x), int(y) + sh, sw = src.shape[:2] + dh, dw = dst.shape[:2] + + # Clip to bounds + sx1 = max(0, -x) + sy1 = max(0, -y) + dx1 = max(0, x) + dy1 = max(0, y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + if sx2 > sx1 and sy2 > sy1: + result[dy1:dy1+(sy2-sy1), dx1:dx1+(sx2-sx1)] = src[sy1:sy2, sx1:sx2] + + return result + + +PRIMITIVES = { + # Dimensions + 'width': prim_width, + 'height': prim_height, + + # Creation + 'make-image': prim_make_image, + 'copy': prim_copy, + + # Pixel access + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + + # Channels + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + + # Geometry + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, +} diff --git a/sexp_effects/primitive_libs/math.py b/sexp_effects/primitive_libs/math.py new file mode 100644 index 0000000..140ad3e --- /dev/null +++ b/sexp_effects/primitive_libs/math.py @@ -0,0 +1,164 @@ +""" +Math Primitives Library + +Trigonometry, rounding, clamping, random numbers, etc. +""" +import math +import random as rand_module + + +def prim_sin(x): + return math.sin(x) + + +def prim_cos(x): + return math.cos(x) + + +def prim_tan(x): + return math.tan(x) + + +def prim_asin(x): + return math.asin(x) + + +def prim_acos(x): + return math.acos(x) + + +def prim_atan(x): + return math.atan(x) + + +def prim_atan2(y, x): + return math.atan2(y, x) + + +def prim_sqrt(x): + return math.sqrt(x) + + +def prim_pow(x, y): + return math.pow(x, y) + + +def prim_exp(x): + return math.exp(x) + + +def prim_log(x, base=None): + if base is None: + return math.log(x) + return math.log(x, base) + + +def prim_abs(x): + return abs(x) + + +def prim_floor(x): + return math.floor(x) + + +def prim_ceil(x): + return math.ceil(x) + + +def prim_round(x): + return round(x) + + +def prim_min(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return min(args[0]) + return min(args) + + +def prim_max(*args): + if len(args) == 1 and hasattr(args[0], '__iter__'): + return max(args[0]) + return max(args) + + +def prim_clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +def prim_lerp(a, b, t): + """Linear interpolation: a + (b - a) * t""" + return a + (b - a) * t + + +def prim_smoothstep(edge0, edge1, x): + """Smooth interpolation between 0 and 1.""" + t = prim_clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0) + return t * t * (3 - 2 * t) + + +def prim_random(lo=0.0, hi=1.0): + return rand_module.uniform(lo, hi) + + +def prim_randint(lo, hi): + return rand_module.randint(lo, hi) + + +def prim_gaussian(mean=0.0, std=1.0): + return rand_module.gauss(mean, std) + + +def prim_sign(x): + if x > 0: + return 1 + elif x < 0: + return -1 + return 0 + + +def prim_fract(x): + """Fractional part of x.""" + return x - math.floor(x) + + +PRIMITIVES = { + # Trigonometry + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'asin': prim_asin, + 'acos': prim_acos, + 'atan': prim_atan, + 'atan2': prim_atan2, + + # Powers and roots + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'exp': prim_exp, + 'log': prim_log, + + # Rounding + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'sign': prim_sign, + 'fract': prim_fract, + + # Min/max/clamp + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'smoothstep': prim_smoothstep, + + # Random + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + + # Constants + 'pi': math.pi, + 'tau': math.tau, + 'e': math.e, +} diff --git a/sexp_effects/primitive_libs/streaming.py b/sexp_effects/primitive_libs/streaming.py new file mode 100644 index 0000000..9092087 --- /dev/null +++ b/sexp_effects/primitive_libs/streaming.py @@ -0,0 +1,462 @@ +""" +Streaming primitives for video/audio processing. + +These primitives handle video source reading and audio analysis, +keeping the interpreter completely generic. + +GPU Acceleration: +- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU) +- Hardware video decoding (NVDEC) is used when available +- Dramatically improves performance on GPU nodes +""" + +import os +import numpy as np +import subprocess +import json +from pathlib import Path + +# Try to import CuPy for GPU acceleration +try: + import cupy as cp + CUPY_AVAILABLE = True +except ImportError: + cp = None + CUPY_AVAILABLE = False + +# GPU persistence mode - output CuPy arrays instead of numpy +# Disabled by default until all primitives support GPU frames +GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE + +# Check for hardware decode support (cached) +_HWDEC_AVAILABLE = None + + +def _check_hwdec(): + """Check if NVIDIA hardware decode is available.""" + global _HWDEC_AVAILABLE + if _HWDEC_AVAILABLE is not None: + return _HWDEC_AVAILABLE + + try: + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2) + if result.returncode != 0: + _HWDEC_AVAILABLE = False + return False + result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5) + _HWDEC_AVAILABLE = "cuda" in result.stdout + except Exception: + _HWDEC_AVAILABLE = False + + return _HWDEC_AVAILABLE + + +class VideoSource: + """Video source with persistent streaming pipe for fast sequential reads.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + self.fps = fps # Output fps for the stream + self._frame_size = None + self._duration = None + self._proc = None # Persistent ffmpeg process + self._stream_time = 0.0 # Current position in stream + self._frame_time = 1.0 / fps # Time per frame at output fps + self._last_read_time = -1 + self._cached_frame = None + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + # Get video info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}") + try: + info = json.loads(result.stdout) + except json.JSONDecodeError: + raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}") + + for stream in info.get("streams", []): + if stream.get("codec_type") == "video": + self._frame_size = (stream.get("width", 720), stream.get("height", 720)) + # Try direct duration field first + if "duration" in stream: + self._duration = float(stream["duration"]) + # Fall back to tags.DURATION (webm format: "00:01:00.124000000") + elif "tags" in stream and "DURATION" in stream["tags"]: + dur_str = stream["tags"]["DURATION"] + parts = dur_str.split(":") + if len(parts) == 3: + h, m, s = parts + self._duration = int(h) * 3600 + int(m) * 60 + float(s) + break + + # Fallback: check format duration if stream duration not found + if self._duration is None and "format" in info and "duration" in info["format"]: + self._duration = float(info["format"]["duration"]) + + if not self._frame_size: + self._frame_size = (720, 720) + + import sys + print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr) + + def _start_stream(self, seek_time: float = 0): + """Start or restart the ffmpeg streaming process. + + Uses NVIDIA hardware decoding (NVDEC) when available for better performance. + """ + if self._proc: + self._proc.kill() + self._proc = None + + # Check file exists before trying to open + if not self.path.exists(): + raise FileNotFoundError(f"Video file not found: {self.path}") + + w, h = self._frame_size + + # Build ffmpeg command with optional hardware decode + cmd = ["ffmpeg", "-v", "error"] + + # Use hardware decode if available (significantly faster) + if _check_hwdec(): + cmd.extend(["-hwaccel", "cuda"]) + + cmd.extend([ + "-ss", f"{seek_time:.3f}", + "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(self.fps), # Output at specified fps + "-" + ]) + + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self._stream_time = seek_time + + # Check if process started successfully by reading first bit of stderr + import select + import sys + readable, _, _ = select.select([self._proc.stderr], [], [], 0.5) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr) + + def _read_frame_from_stream(self): + """Read one frame from the stream. + + Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise. + """ + w, h = self._frame_size + frame_size = w * h * 3 + + if not self._proc or self._proc.poll() is not None: + return None + + data = self._proc.stdout.read(frame_size) + if len(data) < frame_size: + return None + + frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() + + # Transfer to GPU if persistence mode enabled + if GPU_PERSIST: + return cp.asarray(frame) + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + if self._cached_frame is not None: + return self._cached_frame + return self.read_at(0) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time using streaming with smart seeking.""" + # Cache check - return same frame for same time + if t == self._last_read_time and self._cached_frame is not None: + return self._cached_frame + + w, h = self._frame_size + + # Loop time if video is shorter + seek_time = t + if self._duration and self._duration > 0: + seek_time = t % self._duration + # If we're within 0.1s of the end, wrap to beginning to avoid EOF issues + if seek_time > self._duration - 0.1: + seek_time = 0.0 + + # Decide whether to seek or continue streaming + # Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead + # Allow small backward tolerance to handle floating point and timing jitter + need_seek = ( + self._proc is None or + self._proc.poll() is not None or + seek_time < self._stream_time - self._frame_time or # More than 1 frame backward + seek_time > self._stream_time + 2.0 + ) + + if need_seek: + import sys + reason = "no proc" if self._proc is None else "proc dead" if self._proc.poll() is not None else "backward" if seek_time < self._stream_time else "jump" + print(f"SEEK {self.path.name}: t={t:.4f} seek={seek_time:.4f} stream={self._stream_time:.4f} ({reason})", file=sys.stderr) + self._start_stream(seek_time) + + # Skip frames to reach target time + skip_retries = 0 + while self._stream_time + self._frame_time <= seek_time: + frame = self._read_frame_from_stream() + if frame is None: + # Stream ended or failed - restart from seek point + import time + skip_retries += 1 + if skip_retries > 3: + # Give up skipping, just start fresh at seek_time + self._start_stream(seek_time) + time.sleep(0.1) + break + self._start_stream(seek_time) + time.sleep(0.05) + continue + self._stream_time += self._frame_time + skip_retries = 0 # Reset on successful read + + # Read the target frame with retry logic + frame = None + max_retries = 3 + for attempt in range(max_retries): + frame = self._read_frame_from_stream() + if frame is not None: + break + + # Stream failed - try restarting + import sys + import time + print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr) + + # Check for ffmpeg errors + if self._proc and self._proc.stderr: + try: + import select + readable, _, _ = select.select([self._proc.stderr], [], [], 0.1) + if readable: + err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore') + if err: + print(f"ffmpeg error: {err}", file=sys.stderr) + except: + pass + + # Wait a bit and restart + time.sleep(0.1) + self._start_stream(seek_time) + + # Give ffmpeg time to start + time.sleep(0.1) + + if frame is None: + import sys + raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries") + else: + self._stream_time += self._frame_time + + self._last_read_time = t + self._cached_frame = frame + return frame + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._frame_size + + def close(self): + if self._proc: + self._proc.kill() + self._proc = None + + +class AudioAnalyzer: + """Audio analyzer for energy and beat detection.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + self.sample_rate = sample_rate + + # Check if file exists + if not self.path.exists(): + raise FileNotFoundError(f"Audio file not found: {self.path}") + + # Load audio via ffmpeg + cmd = ["ffmpeg", "-v", "error", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}") + self._audio = np.frombuffer(result.stdout, dtype=np.float32) + if len(self._audio) == 0: + raise RuntimeError(f"Audio file is empty or invalid: {self.path}") + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}") + info = json.loads(result.stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + # Beat detection state + self._flux_history = [] + self._last_beat_time = -1 + self._beat_count = 0 + self._last_beat_check_time = -1 + # Cache beat result for current time (so multiple scans see same result) + self._beat_cache_time = -1 + self._beat_cache_result = False + + def get_energy(self, t: float) -> float: + """Get energy level at time t (0-1).""" + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + """Check if there's a beat at time t.""" + # Return cached result if same time (multiple scans query same frame) + if t == self._beat_cache_time: + return self._beat_cache_result + + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size/2: + self._beat_cache_time = t + self._beat_cache_result = False + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + self._beat_cache_time = t + self._beat_cache_result = False + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + if len(self._flux_history) > 50: + self._flux_history = self._flux_history[-50:] + + if len(self._flux_history) < 5: + self._beat_cache_time = t + self._beat_cache_result = False + return False + + recent = [f for _, f in self._flux_history[-20:]] + threshold = np.mean(recent) + 1.5 * np.std(recent) + + is_beat = flux > threshold and (t - self._last_beat_time) > 0.1 + if is_beat: + self._last_beat_time = t + if t > self._last_beat_check_time: + self._beat_count += 1 + self._last_beat_check_time = t + + # Cache result for this time + self._beat_cache_time = t + self._beat_cache_result = is_beat + return is_beat + + def get_beat_count(self, t: float) -> int: + """Get cumulative beat count up to time t.""" + # Ensure beat detection has run up to this time + self.get_beat(t) + return self._beat_count + + +# === Primitives === + +def prim_make_video_source(path: str, fps: float = 30): + """Create a video source from a file path.""" + return VideoSource(path, fps) + + +def prim_source_read(source: VideoSource, t: float = None): + """Read a frame from a video source.""" + import sys + if t is not None: + frame = source.read_at(t) + # Debug: show source and time + if int(t * 10) % 10 == 0: # Every second + print(f"READ {source.path.name}: t={t:.2f} stream={source._stream_time:.2f}", file=sys.stderr) + return frame + return source.read() + + +def prim_source_skip(source: VideoSource): + """Skip a frame (keep pipe in sync).""" + source.skip() + + +def prim_source_size(source: VideoSource): + """Get (width, height) of source.""" + return source.size + + +def prim_make_audio_analyzer(path: str): + """Create an audio analyzer from a file path.""" + return AudioAnalyzer(path) + + +def prim_audio_energy(analyzer: AudioAnalyzer, t: float) -> float: + """Get energy level (0-1) at time t.""" + return analyzer.get_energy(t) + + +def prim_audio_beat(analyzer: AudioAnalyzer, t: float) -> bool: + """Check if there's a beat at time t.""" + return analyzer.get_beat(t) + + +def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int: + """Get cumulative beat count up to time t.""" + return analyzer.get_beat_count(t) + + +def prim_audio_duration(analyzer: AudioAnalyzer) -> float: + """Get audio duration in seconds.""" + return analyzer.duration + + +# Export primitives +PRIMITIVES = { + # Video source + 'make-video-source': prim_make_video_source, + 'source-read': prim_source_read, + 'source-skip': prim_source_skip, + 'source-size': prim_source_size, + + # Audio analyzer + 'make-audio-analyzer': prim_make_audio_analyzer, + 'audio-energy': prim_audio_energy, + 'audio-beat': prim_audio_beat, + 'audio-beat-count': prim_audio_beat_count, + 'audio-duration': prim_audio_duration, +} diff --git a/sexp_effects/primitives.py b/sexp_effects/primitives.py new file mode 100644 index 0000000..8bdca5c --- /dev/null +++ b/sexp_effects/primitives.py @@ -0,0 +1,3043 @@ +""" +Safe Primitives for S-Expression Effects + +These are the building blocks that user-defined effects can use. +All primitives operate only on image data - no filesystem, network, etc. +""" + +import numpy as np +import cv2 +from typing import Any, Callable, Dict, List, Tuple, Optional +from dataclasses import dataclass +import math + + +@dataclass +class ZoneContext: + """Context for a single cell/zone in ASCII art grid.""" + row: int + col: int + row_norm: float # Normalized row position 0-1 + col_norm: float # Normalized col position 0-1 + luminance: float # Cell luminance 0-1 + saturation: float # Cell saturation 0-1 + hue: float # Cell hue 0-360 + r: float # Red component 0-1 + g: float # Green component 0-1 + b: float # Blue component 0-1 + + +class DeterministicRNG: + """Seeded RNG for reproducible effects.""" + + def __init__(self, seed: int = 42): + self._rng = np.random.RandomState(seed) + + def random(self, low: float = 0, high: float = 1) -> float: + return self._rng.uniform(low, high) + + def randint(self, low: int, high: int) -> int: + return self._rng.randint(low, high + 1) + + def gaussian(self, mean: float = 0, std: float = 1) -> float: + return self._rng.normal(mean, std) + + +# Global RNG instance (reset per frame with seed param) +_rng = DeterministicRNG() + + +def reset_rng(seed: int): + """Reset the global RNG with a new seed.""" + global _rng + _rng = DeterministicRNG(seed) + + +# ============================================================================= +# Color Names (FFmpeg/X11 compatible) +# ============================================================================= + +NAMED_COLORS = { + # Basic colors + "black": (0, 0, 0), + "white": (255, 255, 255), + "red": (255, 0, 0), + "green": (0, 128, 0), + "blue": (0, 0, 255), + "yellow": (255, 255, 0), + "cyan": (0, 255, 255), + "magenta": (255, 0, 255), + + # Grays + "gray": (128, 128, 128), + "grey": (128, 128, 128), + "darkgray": (169, 169, 169), + "darkgrey": (169, 169, 169), + "lightgray": (211, 211, 211), + "lightgrey": (211, 211, 211), + "dimgray": (105, 105, 105), + "dimgrey": (105, 105, 105), + "silver": (192, 192, 192), + + # Reds + "darkred": (139, 0, 0), + "firebrick": (178, 34, 34), + "crimson": (220, 20, 60), + "indianred": (205, 92, 92), + "lightcoral": (240, 128, 128), + "salmon": (250, 128, 114), + "darksalmon": (233, 150, 122), + "lightsalmon": (255, 160, 122), + "tomato": (255, 99, 71), + "orangered": (255, 69, 0), + "coral": (255, 127, 80), + + # Oranges + "orange": (255, 165, 0), + "darkorange": (255, 140, 0), + + # Yellows + "gold": (255, 215, 0), + "lightyellow": (255, 255, 224), + "lemonchiffon": (255, 250, 205), + "papayawhip": (255, 239, 213), + "moccasin": (255, 228, 181), + "peachpuff": (255, 218, 185), + "palegoldenrod": (238, 232, 170), + "khaki": (240, 230, 140), + "darkkhaki": (189, 183, 107), + + # Greens + "lime": (0, 255, 0), + "limegreen": (50, 205, 50), + "forestgreen": (34, 139, 34), + "darkgreen": (0, 100, 0), + "seagreen": (46, 139, 87), + "mediumseagreen": (60, 179, 113), + "springgreen": (0, 255, 127), + "mediumspringgreen": (0, 250, 154), + "lightgreen": (144, 238, 144), + "palegreen": (152, 251, 152), + "darkseagreen": (143, 188, 143), + "greenyellow": (173, 255, 47), + "chartreuse": (127, 255, 0), + "lawngreen": (124, 252, 0), + "olivedrab": (107, 142, 35), + "olive": (128, 128, 0), + "darkolivegreen": (85, 107, 47), + "yellowgreen": (154, 205, 50), + + # Cyans/Teals + "aqua": (0, 255, 255), + "teal": (0, 128, 128), + "darkcyan": (0, 139, 139), + "lightcyan": (224, 255, 255), + "aquamarine": (127, 255, 212), + "mediumaquamarine": (102, 205, 170), + "paleturquoise": (175, 238, 238), + "turquoise": (64, 224, 208), + "mediumturquoise": (72, 209, 204), + "darkturquoise": (0, 206, 209), + "cadetblue": (95, 158, 160), + + # Blues + "navy": (0, 0, 128), + "darkblue": (0, 0, 139), + "mediumblue": (0, 0, 205), + "royalblue": (65, 105, 225), + "cornflowerblue": (100, 149, 237), + "steelblue": (70, 130, 180), + "dodgerblue": (30, 144, 255), + "deepskyblue": (0, 191, 255), + "lightskyblue": (135, 206, 250), + "skyblue": (135, 206, 235), + "lightsteelblue": (176, 196, 222), + "lightblue": (173, 216, 230), + "powderblue": (176, 224, 230), + "slateblue": (106, 90, 205), + "mediumslateblue": (123, 104, 238), + "darkslateblue": (72, 61, 139), + "midnightblue": (25, 25, 112), + + # Purples/Violets + "purple": (128, 0, 128), + "darkmagenta": (139, 0, 139), + "darkviolet": (148, 0, 211), + "blueviolet": (138, 43, 226), + "darkorchid": (153, 50, 204), + "mediumorchid": (186, 85, 211), + "orchid": (218, 112, 214), + "violet": (238, 130, 238), + "plum": (221, 160, 221), + "thistle": (216, 191, 216), + "lavender": (230, 230, 250), + "indigo": (75, 0, 130), + "mediumpurple": (147, 112, 219), + "fuchsia": (255, 0, 255), + "hotpink": (255, 105, 180), + "deeppink": (255, 20, 147), + "mediumvioletred": (199, 21, 133), + "palevioletred": (219, 112, 147), + + # Pinks + "pink": (255, 192, 203), + "lightpink": (255, 182, 193), + "mistyrose": (255, 228, 225), + + # Browns + "brown": (165, 42, 42), + "maroon": (128, 0, 0), + "saddlebrown": (139, 69, 19), + "sienna": (160, 82, 45), + "chocolate": (210, 105, 30), + "peru": (205, 133, 63), + "sandybrown": (244, 164, 96), + "burlywood": (222, 184, 135), + "tan": (210, 180, 140), + "rosybrown": (188, 143, 143), + "goldenrod": (218, 165, 32), + "darkgoldenrod": (184, 134, 11), + + # Whites + "snow": (255, 250, 250), + "honeydew": (240, 255, 240), + "mintcream": (245, 255, 250), + "azure": (240, 255, 255), + "aliceblue": (240, 248, 255), + "ghostwhite": (248, 248, 255), + "whitesmoke": (245, 245, 245), + "seashell": (255, 245, 238), + "beige": (245, 245, 220), + "oldlace": (253, 245, 230), + "floralwhite": (255, 250, 240), + "ivory": (255, 255, 240), + "antiquewhite": (250, 235, 215), + "linen": (250, 240, 230), + "lavenderblush": (255, 240, 245), + "wheat": (245, 222, 179), + "cornsilk": (255, 248, 220), + "blanchedalmond": (255, 235, 205), + "bisque": (255, 228, 196), + "navajowhite": (255, 222, 173), + + # Special + "transparent": (0, 0, 0), # Note: no alpha support, just black +} + + +def parse_color(color_spec: str) -> Optional[Tuple[int, int, int]]: + """ + Parse a color specification into RGB tuple. + + Supports: + - Named colors: "red", "green", "lime", "navy", etc. + - Hex colors: "#FF0000", "#f00", "0xFF0000" + - Special modes: "color", "mono", "invert" return None (handled separately) + + Returns: + RGB tuple (r, g, b) or None for special modes + """ + if color_spec is None: + return None + + color_spec = str(color_spec).strip().lower() + + # Special modes handled elsewhere + if color_spec in ("color", "mono", "invert"): + return None + + # Check named colors + if color_spec in NAMED_COLORS: + return NAMED_COLORS[color_spec] + + # Handle hex colors + hex_str = None + if color_spec.startswith("#"): + hex_str = color_spec[1:] + elif color_spec.startswith("0x"): + hex_str = color_spec[2:] + elif all(c in "0123456789abcdef" for c in color_spec) and len(color_spec) in (3, 6): + hex_str = color_spec + + if hex_str: + try: + if len(hex_str) == 3: + # Short form: #RGB -> #RRGGBB + r = int(hex_str[0] * 2, 16) + g = int(hex_str[1] * 2, 16) + b = int(hex_str[2] * 2, 16) + return (r, g, b) + elif len(hex_str) == 6: + r = int(hex_str[0:2], 16) + g = int(hex_str[2:4], 16) + b = int(hex_str[4:6], 16) + return (r, g, b) + except ValueError: + pass + + # Unknown color - default to None (will use original colors) + return None + + +# ============================================================================= +# Image Primitives +# ============================================================================= + +def prim_width(img: np.ndarray) -> int: + """Get image width.""" + return img.shape[1] + + +def prim_height(img: np.ndarray) -> int: + """Get image height.""" + return img.shape[0] + + +def prim_make_image(w: int, h: int, color: List[int]) -> np.ndarray: + """Create a new image filled with color.""" + img = np.zeros((int(h), int(w), 3), dtype=np.uint8) + if color: + img[:, :] = color[:3] + return img + + +def prim_copy(img: np.ndarray) -> np.ndarray: + """Copy an image.""" + return img.copy() + + +def prim_pixel(img: np.ndarray, x: int, y: int) -> List[int]: + """Get pixel at (x, y) as [r, g, b].""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + return list(img[y, x]) + return [0, 0, 0] + + +def prim_set_pixel(img: np.ndarray, x: int, y: int, color: List[int]) -> np.ndarray: + """Set pixel at (x, y). Returns modified image.""" + h, w = img.shape[:2] + x, y = int(x), int(y) + if 0 <= x < w and 0 <= y < h: + img[y, x] = color[:3] + return img + + +def prim_sample(img: np.ndarray, x: float, y: float) -> List[float]: + """Bilinear sample at float coordinates.""" + h, w = img.shape[:2] + x = np.clip(x, 0, w - 1) + y = np.clip(y, 0, h - 1) + + x0, y0 = int(x), int(y) + x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1) + fx, fy = x - x0, y - y0 + + c00 = img[y0, x0].astype(float) + c10 = img[y0, x1].astype(float) + c01 = img[y1, x0].astype(float) + c11 = img[y1, x1].astype(float) + + c = (c00 * (1 - fx) * (1 - fy) + + c10 * fx * (1 - fy) + + c01 * (1 - fx) * fy + + c11 * fx * fy) + + return list(c) + + +def prim_channel(img: np.ndarray, c: int) -> np.ndarray: + """Extract a single channel as 2D array.""" + return img[:, :, int(c)].copy() + + +def prim_merge_channels(r: np.ndarray, g: np.ndarray, b: np.ndarray) -> np.ndarray: + """Merge three channels into RGB image.""" + return np.stack([r, g, b], axis=-1).astype(np.uint8) + + +def prim_resize(img: np.ndarray, w: int, h: int, mode: str = "linear") -> np.ndarray: + """Resize image. Mode: linear, nearest, area.""" + w, h = int(w), int(h) + if w < 1 or h < 1: + return img + interp = { + "linear": cv2.INTER_LINEAR, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + }.get(mode, cv2.INTER_LINEAR) + return cv2.resize(img, (w, h), interpolation=interp) + + +def prim_crop(img: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray: + """Crop a region from image.""" + ih, iw = img.shape[:2] + x, y, w, h = int(x), int(y), int(w), int(h) + x = max(0, min(x, iw)) + y = max(0, min(y, ih)) + w = max(0, min(w, iw - x)) + h = max(0, min(h, ih - y)) + return img[y:y + h, x:x + w].copy() + + +def prim_paste(dst: np.ndarray, src: np.ndarray, x: int, y: int) -> np.ndarray: + """Paste src onto dst at position (x, y).""" + dh, dw = dst.shape[:2] + sh, sw = src.shape[:2] + x, y = int(x), int(y) + + # Calculate valid regions + sx1 = max(0, -x) + sy1 = max(0, -y) + sx2 = min(sw, dw - x) + sy2 = min(sh, dh - y) + + dx1 = max(0, x) + dy1 = max(0, y) + dx2 = dx1 + (sx2 - sx1) + dy2 = dy1 + (sy2 - sy1) + + if dx2 > dx1 and dy2 > dy1: + dst[dy1:dy2, dx1:dx2] = src[sy1:sy2, sx1:sx2] + + return dst + + +# ============================================================================= +# Color Primitives +# ============================================================================= + +def prim_rgb(r: float, g: float, b: float) -> List[int]: + """Create RGB color.""" + return [int(np.clip(r, 0, 255)), + int(np.clip(g, 0, 255)), + int(np.clip(b, 0, 255))] + + +def prim_red(c: List[int]) -> int: + return c[0] if c else 0 + + +def prim_green(c: List[int]) -> int: + return c[1] if len(c) > 1 else 0 + + +def prim_blue(c: List[int]) -> int: + return c[2] if len(c) > 2 else 0 + + +def prim_luminance(c: List[int]) -> float: + """Calculate luminance (grayscale value).""" + if not c: + return 0 + return 0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2] + + +def prim_rgb_to_hsv(c: List[int]) -> List[float]: + """Convert RGB to HSV.""" + r, g, b = c[0] / 255, c[1] / 255, c[2] / 255 + mx, mn = max(r, g, b), min(r, g, b) + diff = mx - mn + + if diff == 0: + h = 0 + elif mx == r: + h = (60 * ((g - b) / diff) + 360) % 360 + elif mx == g: + h = (60 * ((b - r) / diff) + 120) % 360 + else: + h = (60 * ((r - g) / diff) + 240) % 360 + + s = 0 if mx == 0 else diff / mx + v = mx + + return [h, s * 100, v * 100] + + +def prim_hsv_to_rgb(hsv: List[float]) -> List[int]: + """Convert HSV to RGB.""" + h, s, v = hsv[0], hsv[1] / 100, hsv[2] / 100 + c = v * s + x = c * (1 - abs((h / 60) % 2 - 1)) + m = v - c + + if h < 60: + r, g, b = c, x, 0 + elif h < 120: + r, g, b = x, c, 0 + elif h < 180: + r, g, b = 0, c, x + elif h < 240: + r, g, b = 0, x, c + elif h < 300: + r, g, b = x, 0, c + else: + r, g, b = c, 0, x + + return [int((r + m) * 255), int((g + m) * 255), int((b + m) * 255)] + + +def prim_blend_color(c1: List[int], c2: List[int], alpha: float) -> List[int]: + """Blend two colors.""" + alpha = np.clip(alpha, 0, 1) + return [int(c1[i] * (1 - alpha) + c2[i] * alpha) for i in range(3)] + + +def prim_average_color(img: np.ndarray) -> List[int]: + """Get average color of image/region.""" + return [int(x) for x in img.mean(axis=(0, 1))] + + +# ============================================================================= +# Image Operations (Bulk) +# ============================================================================= + +def prim_map_pixels(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each pixel: fn(x, y, [r,g,b]) -> [r,g,b].""" + result = img.copy() + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + new_color = fn(x, y, color) + if new_color is not None: + result[y, x] = new_color[:3] + return result + + +def prim_map_rows(img: np.ndarray, fn: Callable) -> np.ndarray: + """Apply function to each row: fn(y, row) -> row.""" + result = img.copy() + h = img.shape[0] + for y in range(h): + row = img[y].copy() + new_row = fn(y, row) + if new_row is not None: + result[y] = new_row + return result + + +def prim_for_grid(img: np.ndarray, cell_size: int, fn: Callable) -> np.ndarray: + """Iterate over grid cells: fn(gx, gy, cell_img) for side effects.""" + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + for gy in range(rows): + for gx in range(cols): + y, x = gy * cell_size, gx * cell_size + cell = img[y:y + cell_size, x:x + cell_size] + fn(gx, gy, cell) + + return img + + +def prim_fold_pixels(img: np.ndarray, init: Any, fn: Callable) -> Any: + """Fold over pixels: fn(acc, x, y, color) -> acc.""" + acc = init + h, w = img.shape[:2] + for y in range(h): + for x in range(w): + color = list(img[y, x]) + acc = fn(acc, x, y, color) + return acc + + +# ============================================================================= +# Convolution / Filters +# ============================================================================= + +def prim_convolve(img: np.ndarray, kernel: List[List[float]]) -> np.ndarray: + """Apply convolution kernel.""" + k = np.array(kernel, dtype=np.float32) + return cv2.filter2D(img, -1, k) + + +def prim_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Gaussian blur.""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.GaussianBlur(img, (ksize, ksize), 0) + + +def prim_box_blur(img: np.ndarray, radius: int) -> np.ndarray: + """Box blur (faster than Gaussian).""" + radius = max(1, int(radius)) + ksize = radius * 2 + 1 + return cv2.blur(img, (ksize, ksize)) + + +def prim_edges(img: np.ndarray, low: int = 50, high: int = 150) -> np.ndarray: + """Canny edge detection, returns grayscale edges.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + edges = cv2.Canny(gray, int(low), int(high)) + return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + + +def prim_sobel(img: np.ndarray) -> np.ndarray: + """Sobel edge detection.""" + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.float32) + sx = cv2.Sobel(gray, cv2.CV_32F, 1, 0) + sy = cv2.Sobel(gray, cv2.CV_32F, 0, 1) + magnitude = np.sqrt(sx ** 2 + sy ** 2) + magnitude = np.clip(magnitude, 0, 255).astype(np.uint8) + return cv2.cvtColor(magnitude, cv2.COLOR_GRAY2RGB) + + +def prim_dilate(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological dilation.""" + kernel = np.ones((size, size), np.uint8) + return cv2.dilate(img, kernel, iterations=1) + + +def prim_erode(img: np.ndarray, size: int = 1) -> np.ndarray: + """Morphological erosion.""" + kernel = np.ones((size, size), np.uint8) + return cv2.erode(img, kernel, iterations=1) + + +# ============================================================================= +# Geometric Transforms +# ============================================================================= + +def prim_translate(img: np.ndarray, dx: float, dy: float) -> np.ndarray: + """Translate image.""" + h, w = img.shape[:2] + M = np.float32([[1, 0, dx], [0, 1, dy]]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_rotate(img: np.ndarray, angle: float, cx: float = None, cy: float = None) -> np.ndarray: + """Rotate image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_scale(img: np.ndarray, sx: float, sy: float, cx: float = None, cy: float = None) -> np.ndarray: + """Scale image around center.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + M = np.float32([ + [sx, 0, cx * (1 - sx)], + [0, sy, cy * (1 - sy)] + ]) + return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT) + + +def prim_flip_h(img: np.ndarray) -> np.ndarray: + """Flip horizontally.""" + return cv2.flip(img, 1) + + +def prim_flip_v(img: np.ndarray) -> np.ndarray: + """Flip vertically.""" + return cv2.flip(img, 0) + + +def prim_remap(img: np.ndarray, map_x: np.ndarray, map_y: np.ndarray) -> np.ndarray: + """Remap using coordinate maps.""" + return cv2.remap(img, map_x.astype(np.float32), map_y.astype(np.float32), + cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + +def prim_make_coords(w: int, h: int) -> Tuple[np.ndarray, np.ndarray]: + """Create coordinate grid (map_x, map_y).""" + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + return map_x, map_y + + +# ============================================================================= +# Blending +# ============================================================================= + +def prim_blend_images(a: np.ndarray, b: np.ndarray, alpha: float) -> np.ndarray: + """Blend two images. Auto-resizes b to match a if sizes differ.""" + alpha = np.clip(alpha, 0, 1) + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + return (a.astype(float) * (1 - alpha) + b.astype(float) * alpha).astype(np.uint8) + + +def prim_blend_mode(a: np.ndarray, b: np.ndarray, mode: str) -> np.ndarray: + """Blend with various modes: add, multiply, screen, overlay, difference. + Auto-resizes b to match a if sizes differ.""" + # Auto-resize b to match a if different sizes + if a.shape[:2] != b.shape[:2]: + b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_LINEAR) + af = a.astype(float) / 255 + bf = b.astype(float) / 255 + + if mode == "add": + result = af + bf + elif mode == "multiply": + result = af * bf + elif mode == "screen": + result = 1 - (1 - af) * (1 - bf) + elif mode == "overlay": + mask = af < 0.5 + result = np.where(mask, 2 * af * bf, 1 - 2 * (1 - af) * (1 - bf)) + elif mode == "difference": + result = np.abs(af - bf) + elif mode == "lighten": + result = np.maximum(af, bf) + elif mode == "darken": + result = np.minimum(af, bf) + else: + result = af + + return (np.clip(result, 0, 1) * 255).astype(np.uint8) + + +def prim_mask(img: np.ndarray, mask_img: np.ndarray) -> np.ndarray: + """Apply grayscale mask to image.""" + if len(mask_img.shape) == 3: + mask = cv2.cvtColor(mask_img, cv2.COLOR_RGB2GRAY) + else: + mask = mask_img + mask_f = mask.astype(float) / 255 + result = img.astype(float) * mask_f[:, :, np.newaxis] + return result.astype(np.uint8) + + +# ============================================================================= +# Drawing +# ============================================================================= + +# Simple font (5x7 bitmap characters) +FONT_5X7 = { + ' ': [0, 0, 0, 0, 0, 0, 0], + '.': [0, 0, 0, 0, 0, 0, 4], + ':': [0, 0, 4, 0, 4, 0, 0], + '-': [0, 0, 0, 14, 0, 0, 0], + '=': [0, 0, 14, 0, 14, 0, 0], + '+': [0, 4, 4, 31, 4, 4, 0], + '*': [0, 4, 21, 14, 21, 4, 0], + '#': [10, 31, 10, 10, 31, 10, 0], + '%': [19, 19, 4, 8, 25, 25, 0], + '@': [14, 17, 23, 21, 23, 16, 14], + '0': [14, 17, 19, 21, 25, 17, 14], + '1': [4, 12, 4, 4, 4, 4, 14], + '2': [14, 17, 1, 2, 4, 8, 31], + '3': [31, 2, 4, 2, 1, 17, 14], + '4': [2, 6, 10, 18, 31, 2, 2], + '5': [31, 16, 30, 1, 1, 17, 14], + '6': [6, 8, 16, 30, 17, 17, 14], + '7': [31, 1, 2, 4, 8, 8, 8], + '8': [14, 17, 17, 14, 17, 17, 14], + '9': [14, 17, 17, 15, 1, 2, 12], +} + +# Add uppercase letters +for i, c in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ'): + FONT_5X7[c] = [0] * 7 # Placeholder + + +def prim_draw_char(img: np.ndarray, char: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw a character at position.""" + # Use OpenCV's built-in font for simplicity + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, char, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_draw_text(img: np.ndarray, text: str, x: int, y: int, + size: int, color: List[int]) -> np.ndarray: + """Draw text at position.""" + font = cv2.FONT_HERSHEY_SIMPLEX + scale = size / 20.0 + thickness = max(1, int(size / 10)) + cv2.putText(img, text, (int(x), int(y + size)), font, scale, tuple(color[:3]), thickness) + return img + + +def prim_fill_rect(img: np.ndarray, x: int, y: int, w: int, h: int, + color: List[int]) -> np.ndarray: + """Fill rectangle.""" + x, y, w, h = int(x), int(y), int(w), int(h) + img[y:y + h, x:x + w] = color[:3] + return img + + +def prim_draw_line(img: np.ndarray, x1: int, y1: int, x2: int, y2: int, + color: List[int], thickness: int = 1) -> np.ndarray: + """Draw line.""" + cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), tuple(color[:3]), int(thickness)) + return img + + +# ============================================================================= +# Math Primitives +# ============================================================================= + +def prim_sin(x: float) -> float: + return math.sin(x) + + +def prim_cos(x: float) -> float: + return math.cos(x) + + +def prim_tan(x: float) -> float: + return math.tan(x) + + +def prim_atan2(y: float, x: float) -> float: + return math.atan2(y, x) + + +def prim_sqrt(x: float) -> float: + return math.sqrt(max(0, x)) + + +def prim_pow(x: float, y: float) -> float: + return math.pow(x, y) + + +def prim_abs(x: float) -> float: + return abs(x) + + +def prim_floor(x: float) -> int: + return int(math.floor(x)) + + +def prim_ceil(x: float) -> int: + return int(math.ceil(x)) + + +def prim_round(x: float) -> int: + return int(round(x)) + + +def prim_min(*args) -> float: + return min(args) + + +def prim_max(*args) -> float: + return max(args) + + +def prim_clamp(x: float, lo: float, hi: float) -> float: + return max(lo, min(hi, x)) + + +def prim_lerp(a: float, b: float, t: float) -> float: + """Linear interpolation.""" + return a + (b - a) * t + + +def prim_mod(a: float, b: float) -> float: + return a % b + + +def prim_random(lo: float = 0, hi: float = 1) -> float: + """Random number from global RNG.""" + return _rng.random(lo, hi) + + +def prim_randint(lo: int, hi: int) -> int: + """Random integer from global RNG.""" + return _rng.randint(lo, hi) + + +def prim_gaussian(mean: float = 0, std: float = 1) -> float: + """Gaussian random from global RNG.""" + return _rng.gaussian(mean, std) + + +def prim_assert(condition, message: str = "Assertion failed"): + """Assert that condition is true, raise error with message if false.""" + if not condition: + raise RuntimeError(f"Assertion error: {message}") + return True + + +# ============================================================================= +# Array/List Primitives +# ============================================================================= + +def prim_length(seq) -> int: + return len(seq) + + +def prim_nth(seq, i: int): + i = int(i) + if 0 <= i < len(seq): + return seq[i] + return None + + +def prim_first(seq): + return seq[0] if seq else None + + +def prim_rest(seq): + return seq[1:] if seq else [] + + +def prim_take(seq, n: int): + return seq[:int(n)] + + +def prim_drop(seq, n: int): + return seq[int(n):] + + +def prim_cons(x, seq): + return [x] + list(seq) + + +def prim_append(*seqs): + result = [] + for s in seqs: + result.extend(s) + return result + + +def prim_reverse(seq): + return list(reversed(seq)) + + +def prim_range(start: int, end: int, step: int = 1) -> List[int]: + return list(range(int(start), int(end), int(step))) + + +def prim_roll(arr: np.ndarray, shift: int, axis: int = 0) -> np.ndarray: + """Circular roll of array.""" + return np.roll(arr, int(shift), axis=int(axis)) + + +def prim_list(*args) -> list: + """Create a list.""" + return list(args) + + +# ============================================================================= +# Primitive Registry +# ============================================================================= + +def prim_add(*args): + return sum(args) + +def prim_sub(a, b=None): + if b is None: + return -a # Unary negation + return a - b + +def prim_mul(*args): + result = 1 + for x in args: + result *= x + return result + +def prim_div(a, b): + return a / b if b != 0 else 0 + +def prim_lt(a, b): + return a < b + +def prim_gt(a, b): + return a > b + +def prim_le(a, b): + return a <= b + +def prim_ge(a, b): + return a >= b + +def prim_eq(a, b): + # Handle None/nil comparisons with numpy arrays + if a is None: + return b is None + if b is None: + return a is None + if isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.array_equal(a, b) + return False # array vs non-array + return a == b + +def prim_ne(a, b): + return not prim_eq(a, b) + + +# ============================================================================= +# Vectorized Bulk Operations (true primitives for composing effects) +# ============================================================================= + +def prim_color_matrix(img: np.ndarray, matrix: List[List[float]]) -> np.ndarray: + """Apply a 3x3 color transformation matrix to all pixels.""" + m = np.array(matrix, dtype=np.float32) + result = img.astype(np.float32) @ m.T + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_adjust(img: np.ndarray, brightness: float = 0, contrast: float = 1) -> np.ndarray: + """Adjust brightness and contrast. Brightness: -255 to 255, Contrast: 0 to 3+.""" + result = (img.astype(np.float32) - 128) * contrast + 128 + brightness + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_mix_gray(img: np.ndarray, amount: float) -> np.ndarray: + """Mix image with its grayscale version. 0=original, 1=grayscale.""" + gray = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2] + gray_rgb = np.stack([gray, gray, gray], axis=-1) + result = img.astype(np.float32) * (1 - amount) + gray_rgb * amount + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_invert_img(img: np.ndarray) -> np.ndarray: + """Invert all pixel values.""" + return (255 - img).astype(np.uint8) + + +def prim_add_noise(img: np.ndarray, amount: float) -> np.ndarray: + """Add gaussian noise to image.""" + noise = _rng._rng.normal(0, amount, img.shape) + result = img.astype(np.float32) + noise + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_quantize(img: np.ndarray, levels: int) -> np.ndarray: + """Reduce to N color levels per channel.""" + levels = max(2, int(levels)) + factor = 256 / levels + result = (img // factor) * factor + factor // 2 + return np.clip(result, 0, 255).astype(np.uint8) + + +def prim_shift_hsv(img: np.ndarray, h: float = 0, s: float = 1, v: float = 1) -> np.ndarray: + """Shift HSV: h=degrees offset, s/v=multipliers.""" + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) + hsv[:, :, 0] = (hsv[:, :, 0] + h / 2) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] * s, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] * v, 0, 255) + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + +# ============================================================================= +# Array Math Primitives (vectorized operations on coordinate arrays) +# ============================================================================= + +def prim_arr_add(a: np.ndarray, b) -> np.ndarray: + """Element-wise addition. b can be array or scalar.""" + return (np.asarray(a) + np.asarray(b)).astype(np.float32) + + +def prim_arr_sub(a: np.ndarray, b) -> np.ndarray: + """Element-wise subtraction. b can be array or scalar.""" + return (np.asarray(a) - np.asarray(b)).astype(np.float32) + + +def prim_arr_mul(a: np.ndarray, b) -> np.ndarray: + """Element-wise multiplication. b can be array or scalar.""" + return (np.asarray(a) * np.asarray(b)).astype(np.float32) + + +def prim_arr_div(a: np.ndarray, b) -> np.ndarray: + """Element-wise division. b can be array or scalar.""" + b = np.asarray(b) + # Avoid division by zero + with np.errstate(divide='ignore', invalid='ignore'): + result = np.asarray(a) / np.where(b == 0, 1e-10, b) + return result.astype(np.float32) + + +def prim_arr_mod(a: np.ndarray, b) -> np.ndarray: + """Element-wise modulo.""" + return (np.asarray(a) % np.asarray(b)).astype(np.float32) + + +def prim_arr_sin(a: np.ndarray) -> np.ndarray: + """Element-wise sine.""" + return np.sin(np.asarray(a)).astype(np.float32) + + +def prim_arr_cos(a: np.ndarray) -> np.ndarray: + """Element-wise cosine.""" + return np.cos(np.asarray(a)).astype(np.float32) + + +def prim_arr_tan(a: np.ndarray) -> np.ndarray: + """Element-wise tangent.""" + return np.tan(np.asarray(a)).astype(np.float32) + + +def prim_arr_sqrt(a: np.ndarray) -> np.ndarray: + """Element-wise square root.""" + return np.sqrt(np.maximum(0, np.asarray(a))).astype(np.float32) + + +def prim_arr_pow(a: np.ndarray, b) -> np.ndarray: + """Element-wise power.""" + return np.power(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_abs(a: np.ndarray) -> np.ndarray: + """Element-wise absolute value.""" + return np.abs(np.asarray(a)).astype(np.float32) + + +def prim_arr_neg(a: np.ndarray) -> np.ndarray: + """Element-wise negation.""" + return (-np.asarray(a)).astype(np.float32) + + +def prim_arr_exp(a: np.ndarray) -> np.ndarray: + """Element-wise exponential.""" + return np.exp(np.asarray(a)).astype(np.float32) + + +def prim_arr_atan2(y: np.ndarray, x: np.ndarray) -> np.ndarray: + """Element-wise atan2(y, x).""" + return np.arctan2(np.asarray(y), np.asarray(x)).astype(np.float32) + + +def prim_arr_min(a: np.ndarray, b) -> np.ndarray: + """Element-wise minimum.""" + return np.minimum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_max(a: np.ndarray, b) -> np.ndarray: + """Element-wise maximum.""" + return np.maximum(np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_clip(a: np.ndarray, lo, hi) -> np.ndarray: + """Element-wise clip to range.""" + return np.clip(np.asarray(a), lo, hi).astype(np.float32) + + +def prim_arr_where(cond: np.ndarray, a, b) -> np.ndarray: + """Element-wise conditional: where cond is true, use a, else b.""" + return np.where(np.asarray(cond), np.asarray(a), np.asarray(b)).astype(np.float32) + + +def prim_arr_floor(a: np.ndarray) -> np.ndarray: + """Element-wise floor.""" + return np.floor(np.asarray(a)).astype(np.float32) + + +def prim_arr_lerp(a: np.ndarray, b: np.ndarray, t) -> np.ndarray: + """Element-wise linear interpolation.""" + a, b = np.asarray(a), np.asarray(b) + return (a + (b - a) * t).astype(np.float32) + + +# ============================================================================= +# Coordinate Transformation Primitives +# ============================================================================= + +def prim_polar_from_center(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create polar coordinates (r, theta) from image center. + + Usage: + (polar-from-center img) ; center of image + (polar-from-center img cx cy) ; custom center + (polar-from-center w h cx cy) ; explicit dimensions + + Returns: (r, theta) tuple of arrays + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + r = np.sqrt(dx**2 + dy**2) + theta = np.arctan2(dy, dx) + + return (r, theta) + + +def prim_cart_from_polar(r: np.ndarray, theta: np.ndarray, cx: float, cy: float) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert polar coordinates back to Cartesian. + + Args: + r: radius array + theta: angle array + cx, cy: center point + + Returns: (x, y) tuple of coordinate arrays + """ + x = (cx + r * np.cos(theta)).astype(np.float32) + y = (cy + r * np.sin(theta)).astype(np.float32) + return (x, y) + + +def prim_normalize_coords(img_or_w, h_or_cx=None, cx=None, cy=None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create normalized coordinates (-1 to 1) from center. + + Returns: (x_norm, y_norm) tuple of arrays where center is (0,0) + """ + if isinstance(img_or_w, np.ndarray): + h, w = img_or_w.shape[:2] + if h_or_cx is None: + cx, cy = w / 2, h / 2 + else: + cx, cy = h_or_cx, cx if cx is not None else h / 2 + else: + w = int(img_or_w) + h = int(h_or_cx) + cx = cx if cx is not None else w / 2 + cy = cy if cy is not None else h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + + return (x_norm, y_norm) + + +def prim_coords_x(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get x/first component from coordinate tuple.""" + return coords[0] + + +def prim_coords_y(coords: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + """Get y/second component from coordinate tuple.""" + return coords[1] + + +def prim_make_coords_centered(w: int, h: int, cx: float = None, cy: float = None) -> Tuple[np.ndarray, np.ndarray]: + """ + Create coordinate grids centered at (cx, cy). + Like make-coords but returns coordinates relative to center. + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + return (x_coords - cx, y_coords - cy) + + +# ============================================================================= +# Specialized Distortion Primitives +# ============================================================================= + +def prim_wave_displace(w: int, h: int, axis: str, freq: float, amp: float, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create wave displacement maps. + + Args: + w, h: dimensions + axis: "x" (horizontal waves) or "y" (vertical waves) + freq: wave frequency (waves per image width/height) + amp: wave amplitude in pixels + phase: phase offset in radians + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + map_y = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) + + if axis == "x" or axis == "horizontal": + # Horizontal waves: displace x based on y + wave = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + map_x = map_x + wave + elif axis == "y" or axis == "vertical": + # Vertical waves: displace y based on x + wave = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_y = map_y + wave + elif axis == "both": + wave_x = np.sin(2 * np.pi * freq * map_y / h + phase) * amp + wave_y = np.sin(2 * np.pi * freq * map_x / w + phase) * amp + map_x = map_x + wave_x + map_y = map_y + wave_y + + return (map_x, map_y) + + +def prim_swirl_displace(w: int, h: int, strength: float, radius: float = 0.5, + cx: float = None, cy: float = None, falloff: str = "quadratic") -> Tuple[np.ndarray, np.ndarray]: + """ + Create swirl displacement maps. + + Args: + w, h: dimensions + strength: swirl strength in radians + radius: effect radius as fraction of max dimension + cx, cy: center (defaults to image center) + falloff: "linear", "quadratic", or "gaussian" + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + radius_px = max(w, h) * radius + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + angle = np.arctan2(dy, dx) + + # Normalized distance for falloff + norm_dist = dist / radius_px + + # Calculate falloff factor + if falloff == "linear": + factor = np.maximum(0, 1 - norm_dist) + elif falloff == "gaussian": + factor = np.exp(-norm_dist**2 * 2) + else: # quadratic + factor = np.maximum(0, 1 - norm_dist**2) + + # Apply swirl rotation + new_angle = angle + strength * factor + + # Calculate new coordinates + map_x = (cx + dist * np.cos(new_angle)).astype(np.float32) + map_y = (cy + dist * np.sin(new_angle)).astype(np.float32) + + return (map_x, map_y) + + +def prim_fisheye_displace(w: int, h: int, strength: float, cx: float = None, cy: float = None, + zoom_correct: bool = True) -> Tuple[np.ndarray, np.ndarray]: + """ + Create fisheye/barrel distortion displacement maps. + + Args: + w, h: dimensions + strength: distortion strength (-1 to 1, positive=bulge, negative=pinch) + cx, cy: center (defaults to image center) + zoom_correct: auto-zoom to hide black edges + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize coordinates + x_norm = (x_coords - cx) / (w / 2) + y_norm = (y_coords - cy) / (h / 2) + r = np.sqrt(x_norm**2 + y_norm**2) + + # Apply barrel/pincushion distortion + if strength > 0: + r_distorted = r * (1 + strength * r**2) + else: + r_distorted = r / (1 - strength * r**2 + 0.001) + + # Calculate scale factor + with np.errstate(divide='ignore', invalid='ignore'): + scale = np.where(r > 0, r_distorted / r, 1) + + # Apply zoom correction + if zoom_correct and strength > 0: + zoom = 1 + strength * 0.5 + scale = scale / zoom + + # Calculate new coordinates + map_x = (x_norm * scale * (w / 2) + cx).astype(np.float32) + map_y = (y_norm * scale * (h / 2) + cy).astype(np.float32) + + return (map_x, map_y) + + +def prim_kaleidoscope_displace(w: int, h: int, segments: int, rotation: float = 0, + cx: float = None, cy: float = None, zoom: float = 1.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create kaleidoscope displacement maps. + + Args: + w, h: dimensions + segments: number of symmetry segments (3-16) + rotation: rotation angle in degrees + cx, cy: center (defaults to image center) + zoom: zoom factor + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + segments = max(3, min(int(segments), 16)) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + segment_angle = 2 * np.pi / segments + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Translate to center + x_centered = x_coords - cx + y_centered = y_coords - cy + + # Convert to polar + r = np.sqrt(x_centered**2 + y_centered**2) + theta = np.arctan2(y_centered, x_centered) + + # Apply rotation + theta = theta - np.deg2rad(rotation) + + # Fold angle into first segment and mirror + theta_normalized = theta % (2 * np.pi) + segment_idx = (theta_normalized / segment_angle).astype(int) + theta_in_segment = theta_normalized - segment_idx * segment_angle + + # Mirror alternating segments + mirror_mask = (segment_idx % 2) == 1 + theta_in_segment = np.where(mirror_mask, segment_angle - theta_in_segment, theta_in_segment) + + # Apply zoom + r = r / zoom + + # Convert back to Cartesian + map_x = (r * np.cos(theta_in_segment) + cx).astype(np.float32) + map_y = (r * np.sin(theta_in_segment) + cy).astype(np.float32) + + return (map_x, map_y) + + +# ============================================================================= +# Character/ASCII Art Primitives +# ============================================================================= + +# Character sets ordered by visual density (light to dark) +CHAR_ALPHABETS = { + "standard": " .`'^\",:;Il!i><~+_-?][}{1)(|/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$", + "blocks": " ░▒▓█", + "simple": " .-:=+*#%@", + "digits": " 0123456789", +} + +# Global atlas cache: keyed on (frozenset(chars), cell_size) -> +# (atlas_array, char_to_idx) where atlas_array is (N, cell_size, cell_size) uint8. +_char_atlas_cache = {} +_CHAR_ATLAS_CACHE_MAX = 32 + + +def _get_char_atlas(alphabet: str, cell_size: int) -> dict: + """Get or create character atlas for alphabet (legacy dict version).""" + atlas_arr, char_to_idx = _get_render_atlas(alphabet, cell_size) + # Build legacy dict from array + idx_to_char = {v: k for k, v in char_to_idx.items()} + return {idx_to_char[i]: atlas_arr[i] for i in range(len(atlas_arr))} + + +def _get_render_atlas(unique_chars_or_alphabet, cell_size: int): + """Get or build a stacked numpy atlas for vectorised rendering. + + Args: + unique_chars_or_alphabet: Either an alphabet name (str looked up in + CHAR_ALPHABETS), a literal character string, or a set/frozenset + of characters. + cell_size: Pixel size of each cell. + + Returns: + (atlas_array, char_to_idx) where + atlas_array: (num_chars, cell_size, cell_size) uint8 masks + char_to_idx: dict mapping character -> index in atlas_array + """ + if isinstance(unique_chars_or_alphabet, (set, frozenset)): + chars_tuple = tuple(sorted(unique_chars_or_alphabet)) + else: + resolved = CHAR_ALPHABETS.get(unique_chars_or_alphabet, unique_chars_or_alphabet) + chars_tuple = tuple(resolved) + + cache_key = (chars_tuple, cell_size) + cached = _char_atlas_cache.get(cache_key) + if cached is not None: + return cached + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + n = len(chars_tuple) + atlas = np.zeros((n, cell_size, cell_size), dtype=np.uint8) + char_to_idx = {} + + for i, char in enumerate(chars_tuple): + char_to_idx[char] = i + if char and char != ' ': + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + cv2.putText(atlas[i], char, (text_x, text_y), + font, font_scale, 255, thickness, cv2.LINE_AA) + except Exception: + pass + + # Evict oldest entry if cache is full + if len(_char_atlas_cache) >= _CHAR_ATLAS_CACHE_MAX: + _char_atlas_cache.pop(next(iter(_char_atlas_cache))) + + _char_atlas_cache[cache_key] = (atlas, char_to_idx) + return atlas, char_to_idx + + +def prim_cell_sample(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Sample image into cell grid, returning average colors and luminances. + + Uses cv2.resize with INTER_AREA (pixel-area averaging) which is + ~25x faster than numpy reshape+mean for block downsampling. + + Args: + img: source image + cell_size: size of each cell in pixels + + Returns: (colors, luminances) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32)) + + # Crop to exact grid then block-average via cv2 area interpolation. + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + return (colors, luminances) + + +def cell_sample_extended(img: np.ndarray, cell_size: int) -> Tuple[np.ndarray, np.ndarray, List[List[ZoneContext]]]: + """ + Sample image into cell grid, returning colors, luminances, and full zone contexts. + + Args: + img: source image (RGB) + cell_size: size of each cell in pixels + + Returns: (colors, luminances, zone_contexts) tuple + - colors: (rows, cols, 3) array of average RGB per cell + - luminances: (rows, cols) array of average brightness 0-255 + - zone_contexts: 2D list of ZoneContext objects with full cell data + """ + cell_size = max(1, int(cell_size)) + h, w = img.shape[:2] + rows = h // cell_size + cols = w // cell_size + + if rows < 1 or cols < 1: + return (np.zeros((1, 1, 3), dtype=np.uint8), + np.zeros((1, 1), dtype=np.float32), + [[ZoneContext(0, 0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]]) + + # Crop to grid + grid_h, grid_w = rows * cell_size, cols * cell_size + cropped = img[:grid_h, :grid_w] + + # Reshape and average + reshaped = cropped.reshape(rows, cell_size, cols, cell_size, 3) + colors = reshaped.mean(axis=(1, 3)).astype(np.uint8) + + # Compute luminance (0-255) + luminances = (0.299 * colors[:, :, 0] + + 0.587 * colors[:, :, 1] + + 0.114 * colors[:, :, 2]).astype(np.float32) + + # Normalize colors to 0-1 for HSV/saturation calculations + colors_float = colors.astype(np.float32) / 255.0 + + # Compute HSV values for each cell + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + diff = max_c - min_c + + # Saturation + saturation = np.where(max_c > 0, diff / max_c, 0) + + # Hue (0-360) + hue = np.zeros((rows, cols), dtype=np.float32) + # Avoid division by zero + mask = diff > 0 + r, g, b = colors_float[:, :, 0], colors_float[:, :, 1], colors_float[:, :, 2] + + # Red is max + red_max = mask & (max_c == r) + hue[red_max] = 60 * (((g[red_max] - b[red_max]) / diff[red_max]) % 6) + + # Green is max + green_max = mask & (max_c == g) + hue[green_max] = 60 * ((b[green_max] - r[green_max]) / diff[green_max] + 2) + + # Blue is max + blue_max = mask & (max_c == b) + hue[blue_max] = 60 * ((r[blue_max] - g[blue_max]) / diff[blue_max] + 4) + + # Ensure hue is in 0-360 range + hue = hue % 360 + + # Build zone contexts + zone_contexts = [] + for row in range(rows): + row_contexts = [] + for col in range(cols): + ctx = ZoneContext( + row=row, + col=col, + row_norm=row / max(1, rows - 1) if rows > 1 else 0.5, + col_norm=col / max(1, cols - 1) if cols > 1 else 0.5, + luminance=luminances[row, col] / 255.0, # Normalize to 0-1 + saturation=float(saturation[row, col]), + hue=float(hue[row, col]), + r=float(colors_float[row, col, 0]), + g=float(colors_float[row, col, 1]), + b=float(colors_float[row, col, 2]), + ) + row_contexts.append(ctx) + zone_contexts.append(row_contexts) + + return (colors, luminances, zone_contexts) + + +def prim_luminance_to_chars(luminances: np.ndarray, alphabet: str, contrast: float = 1.0) -> List[List[str]]: + """ + Map luminance values to characters from alphabet. + + Args: + luminances: (rows, cols) array of brightness values 0-255 + alphabet: character set name or literal string (light to dark) + contrast: contrast boost factor + + Returns: 2D list of single-character strings + """ + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + num_chars = len(chars) + + # Apply contrast + lum = luminances.astype(np.float32) + if contrast != 1.0: + lum = (lum - 128) * contrast + 128 + lum = np.clip(lum, 0, 255) + + # Map to indices + indices = ((lum / 255) * (num_chars - 1)).astype(np.int32) + indices = np.clip(indices, 0, num_chars - 1) + + # Vectorised conversion via numpy char array lookup + chars_arr = np.array(list(chars)) + char_grid = chars_arr[indices.ravel()].reshape(indices.shape) + + return char_grid.tolist() + + +def prim_render_char_grid(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + cell_size: int, color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False) -> np.ndarray: + """ + Render a grid of characters onto an image. + + Uses vectorised numpy operations instead of per-cell Python loops: + the character atlas is looked up via fancy indexing and the full + mask + colour image are assembled in bulk. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + cell_size: size of each cell + color_mode: "color" (original colors), "mono" (white), "invert", + or any color name/hex value ("green", "lime", "#00ff00") + background_color: background color name/hex ("black", "navy", "#001100") + invert_colors: if True, swap foreground and background colors + + Returns: rendered image + """ + # Parse color_mode - may be a named color or hex value + fg_color = parse_color(color_mode) + + # Parse background_color + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + # Handle invert_colors - swap fg and bg + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + + # --- Build atlas & index grid --- + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + atlas, char_to_idx = _get_render_atlas(unique_chars, cell_size) + + # Convert 2D char list to index array using ordinal lookup table + # (avoids per-cell Python dict lookup). + space_idx = char_to_idx.get(' ', 0) + max_ord = max(ord(ch) for ch in char_to_idx) + 1 + ord_lookup = np.full(max_ord, space_idx, dtype=np.int32) + for ch, idx in char_to_idx.items(): + if ch: + ord_lookup[ord(ch)] = idx + + flat = [ch for row in chars for ch in row] + ords = np.frombuffer(np.array(flat, dtype='U1'), dtype=np.uint32) + char_indices = ord_lookup[ords].reshape(rows, cols) + + # --- Vectorised mask assembly --- + # atlas[char_indices] -> (rows, cols, cell_size, cell_size) + # Transpose to (rows, cell_size, cols, cell_size) then reshape to full image. + all_masks = atlas[char_indices] + full_mask = all_masks.transpose(0, 2, 1, 3).reshape(h, w) + + # Expand per-cell colours to per-pixel (only when needed). + need_color_full = (color_mode in ("color", "invert") + or (fg_color is None and color_mode != "mono")) + + if need_color_full: + color_full = np.repeat( + np.repeat(colors[:rows, :cols], cell_size, axis=0), + cell_size, axis=1) + + # --- Vectorised colour composite --- + # Use element-wise multiply/np.where instead of boolean-indexed scatter + # for much better memory access patterns. + mask_u8 = (full_mask > 0).astype(np.uint8)[:, :, np.newaxis] + + if color_mode == "invert": + # Background is source colour; characters are black. + # result = color_full * (1 - mask) + result = color_full * (1 - mask_u8) + elif fg_color is not None: + # Fixed foreground colour on background. + fg = np.array(fg_color, dtype=np.uint8) + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, fg, bg_arr).astype(np.uint8) + elif color_mode == "mono": + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, np.uint8(255), bg_arr).astype(np.uint8) + else: + # "color" mode – each cell uses its source colour on bg. + if bg == [0, 0, 0]: + result = color_full * mask_u8 + else: + bg_arr = np.array(bg, dtype=np.uint8) + result = np.where(mask_u8, color_full, bg_arr).astype(np.uint8) + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_render_char_grid_fx(img: np.ndarray, chars: List[List[str]], colors: np.ndarray, + luminances: np.ndarray, cell_size: int, + color_mode: str = "color", + background_color: str = "black", + invert_colors: bool = False, + char_jitter: float = 0.0, + char_scale: float = 1.0, + char_rotation: float = 0.0, + char_hue_shift: float = 0.0, + jitter_source: str = "none", + scale_source: str = "none", + rotation_source: str = "none", + hue_source: str = "none") -> np.ndarray: + """ + Render a grid of characters with per-character effects. + + Args: + img: source image (for dimensions) + chars: 2D list of single characters + colors: (rows, cols, 3) array of colors per cell + luminances: (rows, cols) array of luminance values (0-255) + cell_size: size of each cell + color_mode: "color", "mono", "invert", or any color name/hex + background_color: background color name/hex + invert_colors: if True, swap foreground and background colors + char_jitter: base jitter amount in pixels + char_scale: base scale factor (1.0 = normal) + char_rotation: base rotation in degrees + char_hue_shift: base hue shift in degrees (0-360) + jitter_source: source for jitter modulation ("none", "luminance", "position", "random") + scale_source: source for scale modulation + rotation_source: source for rotation modulation + hue_source: source for hue modulation + + Per-character effect sources: + "none" - use base value only + "luminance" - modulate by cell luminance (0-1) + "inv_luminance" - modulate by inverse luminance (dark = high) + "saturation" - modulate by cell color saturation + "position_x" - modulate by horizontal position (0-1) + "position_y" - modulate by vertical position (0-1) + "position_diag" - modulate by diagonal position + "random" - random per-cell value (deterministic from position) + "center_dist" - distance from center (0=center, 1=corner) + + Returns: rendered image + """ + # Parse colors + fg_color = parse_color(color_mode) + + if isinstance(background_color, (list, tuple)): + bg_color = tuple(int(c) for c in background_color[:3]) + else: + bg_color = parse_color(background_color) + if bg_color is None: + bg_color = (0, 0, 0) + + if invert_colors and fg_color is not None: + fg_color, bg_color = bg_color, fg_color + + cell_size = max(1, int(cell_size)) + + if not chars or not chars[0]: + return img.copy() + + rows = len(chars) + cols = len(chars[0]) + h, w = rows * cell_size, cols * cell_size + + bg = list(bg_color) + result = np.full((h, w, 3), bg, dtype=np.uint8) + + # Normalize luminances to 0-1 + lum_normalized = luminances.astype(np.float32) / 255.0 + + # Compute saturation from colors + colors_float = colors.astype(np.float32) / 255.0 + max_c = colors_float.max(axis=2) + min_c = colors_float.min(axis=2) + saturation = np.where(max_c > 0, (max_c - min_c) / max_c, 0) + + # Helper to get modulation value for a cell + def get_mod_value(source: str, r: int, c: int) -> float: + if source == "none": + return 1.0 + elif source == "luminance": + return lum_normalized[r, c] + elif source == "inv_luminance": + return 1.0 - lum_normalized[r, c] + elif source == "saturation": + return saturation[r, c] + elif source == "position_x": + return c / max(1, cols - 1) if cols > 1 else 0.5 + elif source == "position_y": + return r / max(1, rows - 1) if rows > 1 else 0.5 + elif source == "position_diag": + px = c / max(1, cols - 1) if cols > 1 else 0.5 + py = r / max(1, rows - 1) if rows > 1 else 0.5 + return (px + py) / 2.0 + elif source == "random": + # Deterministic random based on position + seed = (r * 1000 + c) % 10000 + return ((seed * 9301 + 49297) % 233280) / 233280.0 + elif source == "center_dist": + cx, cy = (cols - 1) / 2.0, (rows - 1) / 2.0 + dx = (c - cx) / max(1, cx) if cx > 0 else 0 + dy = (r - cy) / max(1, cy) if cy > 0 else 0 + return min(1.0, math.sqrt(dx*dx + dy*dy)) + else: + return 1.0 + + # Build character atlas at base size + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, we need to render characters larger then transform + max_scale = max(1.0, char_scale * 1.5) # Allow headroom for scaling + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with effects + for r in range(rows): + for c in range(cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell modulation values + jitter_mod = get_mod_value(jitter_source, r, c) + scale_mod = get_mod_value(scale_source, r, c) + rot_mod = get_mod_value(rotation_source, r, c) + hue_mod = get_mod_value(hue_source, r, c) + + # Compute effective values + eff_jitter = char_jitter * jitter_mod + eff_scale = char_scale * (0.5 + 0.5 * scale_mod) if scale_source != "none" else char_scale + eff_rotation = char_rotation * (rot_mod * 2 - 1) # -1 to 1 range + eff_hue_shift = char_hue_shift * hue_mod + + # Apply transformations + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + # Fill cell with source color first + cy1 = max(0, r * cell_size) + cy2 = min(h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + # Convert to HSV, shift hue, convert back + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + # Cast to int to avoid uint8 overflow, then back to uint8 + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < h and 0 <= px < w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = img.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(h, orig_h) + copy_w = min(w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def _render_with_cell_effect( + frame: np.ndarray, + chars: List[List[str]], + colors: np.ndarray, + luminances: np.ndarray, + zone_contexts: List[List['ZoneContext']], + cell_size: int, + bg_color: tuple, + fg_color: tuple, + color_mode: str, + cell_effect, # Lambda or callable: (cell_image, zone_dict) -> cell_image + extra_params: dict, + interp, + env, + result: np.ndarray, +) -> np.ndarray: + """ + Render ASCII art using a cell_effect lambda for arbitrary per-cell transforms. + + Each character is rendered to a cell image, the cell_effect is called with + (cell_image, zone_dict), and the returned cell is composited into result. + + This allows arbitrary effects (rotate, blur, etc.) to be applied per-character. + """ + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + out_h, out_w = result.shape[:2] + + # Build character atlas (cell-sized colored characters on transparent bg) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + # Helper to render a single character cell + def render_char_cell(char: str, color: np.ndarray) -> np.ndarray: + """Render a character onto a cell-sized RGB image.""" + cell = np.full((cell_size, cell_size, 3), bg_color, dtype=np.uint8) + if not char or char == ' ': + return cell + + try: + (text_w, text_h), _ = cv2.getTextSize(char, font, font_scale, thickness) + text_x = max(0, (cell_size - text_w) // 2) + text_y = (cell_size + text_h) // 2 + + # Render character in white on mask, then apply color + mask = np.zeros((cell_size, cell_size), dtype=np.uint8) + cv2.putText(mask, char, (text_x, text_y), font, font_scale, 255, thickness, cv2.LINE_AA) + + # Apply color where mask is set + for ch in range(3): + cell[:, :, ch] = np.where(mask > 0, color[ch], bg_color[ch]) + except: + pass + + return cell + + # Helper to evaluate cell_effect (handles artdag Lambda objects) + def eval_cell_effect(cell_img: np.ndarray, zone_dict: dict) -> np.ndarray: + """Call cell_effect with (cell_image, zone_dict), handle Lambda objects.""" + if callable(cell_effect): + return cell_effect(cell_img, zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(cell_effect, ArtdagLambda): + # Build env with closure values + eval_env = dict(cell_effect.closure) if cell_effect.closure else {} + # Bind lambda parameters + if len(cell_effect.params) >= 2: + eval_env[cell_effect.params[0]] = cell_img + eval_env[cell_effect.params[1]] = zone_dict + elif len(cell_effect.params) == 1: + # Single param gets zone_dict with cell as 'cell' key + zone_dict['cell'] = cell_img + eval_env[cell_effect.params[0]] = zone_dict + + # Add primitives to eval env + eval_env.update(PRIMITIVES) + + # Add effect runner - allows calling any loaded sexp effect on a cell + # Usage: (apply-effect "effect_name" cell {"param" value ...}) + # Or: (apply-effect "effect_name" cell) for defaults + def apply_effect_fn(effect_name, frame, params=None): + """Run a loaded sexp effect on a frame (cell).""" + if interp and hasattr(interp, 'run_effect'): + if params is None: + params = {} + result, _ = interp.run_effect(effect_name, frame, params, {}) + return result + return frame + eval_env['apply-effect'] = apply_effect_fn + + # Also inject loaded effects directly as callable functions + # These wrappers take positional args in common order for each effect + # Usage: (blur cell 5) or (rotate cell 45) etc. + if interp and hasattr(interp, 'effects'): + for effect_name in interp.effects: + # Create a wrapper that calls run_effect with positional-to-named mapping + def make_effect_fn(name): + def effect_fn(frame, *args): + # Map common positional args to named params + params = {} + if name == 'blur' and len(args) >= 1: + params['radius'] = args[0] + elif name == 'rotate' and len(args) >= 1: + params['angle'] = args[0] + elif name == 'brightness' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'contrast' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'saturation' and len(args) >= 1: + params['factor'] = args[0] + elif name == 'hue_shift' and len(args) >= 1: + params['degrees'] = args[0] + elif name == 'rgb_split' and len(args) >= 1: + params['offset_x'] = args[0] + if len(args) >= 2: + params['offset_y'] = args[1] + elif name == 'pixelate' and len(args) >= 1: + params['block_size'] = args[0] + elif name == 'wave' and len(args) >= 1: + params['amplitude'] = args[0] + if len(args) >= 2: + params['frequency'] = args[1] + elif name == 'noise' and len(args) >= 1: + params['amount'] = args[0] + elif name == 'posterize' and len(args) >= 1: + params['levels'] = args[0] + elif name == 'threshold' and len(args) >= 1: + params['level'] = args[0] + elif name == 'sharpen' and len(args) >= 1: + params['amount'] = args[0] + elif len(args) == 1 and isinstance(args[0], dict): + # Accept dict as single arg + params = args[0] + result, _ = interp.run_effect(name, frame, params, {}) + return result + return effect_fn + eval_env[effect_name] = make_effect_fn(effect_name) + + result = artdag_evaluate(cell_effect.body, eval_env) + if isinstance(result, np.ndarray): + return result + return cell_img + except ImportError: + pass + + # Fallback: return cell unchanged + return cell_img + + # Render each cell + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + zone = zone_contexts[r][c] + + # Determine character color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + color = np.array([0, 0, 0], dtype=np.uint8) + else: + color = colors[r, c].copy() + + # Render character to cell image + cell_img = render_char_cell(char, color) + + # Build zone dict + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + 'color': color.tolist(), + 'cell_size': cell_size, + } + # Add extra params (energy, rotation_scale, etc.) + if extra_params: + zone_dict.update(extra_params) + + # Call cell_effect + modified_cell = eval_cell_effect(cell_img, zone_dict) + + # Ensure result is valid + if modified_cell is None or not isinstance(modified_cell, np.ndarray): + modified_cell = cell_img + if modified_cell.shape[:2] != (cell_size, cell_size): + # Resize if cell size changed + modified_cell = cv2.resize(modified_cell, (cell_size, cell_size)) + if len(modified_cell.shape) == 2: + # Convert grayscale to RGB + modified_cell = cv2.cvtColor(modified_cell, cv2.COLOR_GRAY2RGB) + + # Composite into result + y1 = r * cell_size + x1 = c * cell_size + y2 = min(y1 + cell_size, out_h) + x2 = min(x1 + cell_size, out_w) + ch = y2 - y1 + cw = x2 - x1 + result[y1:y2, x1:x2] = modified_cell[:ch, :cw] + + # Resize to match original frame if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + bg = list(bg_color) + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_ascii_fx_zone( + frame: np.ndarray, + cols: int, + char_size_override: int, # If set, overrides cols-based calculation + alphabet: str, + color_mode: str, + background: str, + contrast: float, + char_hue_expr, # Expression, literal, or None + char_sat_expr, # Expression, literal, or None + char_bright_expr, # Expression, literal, or None + char_scale_expr, # Expression, literal, or None + char_rotation_expr, # Expression, literal, or None + char_jitter_expr, # Expression, literal, or None + interp, # Interpreter for expression evaluation + env, # Environment with bound values + extra_params=None, # Extra params to include in zone dict for lambdas + cell_effect=None, # Lambda (cell_image, zone_dict) -> cell_image for arbitrary cell effects +) -> np.ndarray: + """ + Render ASCII art with per-zone expression-driven transforms. + + Args: + frame: Source image (H, W, 3) RGB uint8 + cols: Number of character columns + char_size_override: If set, use this cell size instead of cols-based + alphabet: Character set name or literal string + color_mode: "color", "mono", "invert", or color name/hex + background: Background color name or hex + contrast: Contrast boost for character selection + char_hue_expr: Expression for hue shift (evaluated per zone) + char_sat_expr: Expression for saturation adjustment (evaluated per zone) + char_bright_expr: Expression for brightness adjustment (evaluated per zone) + char_scale_expr: Expression for scale factor (evaluated per zone) + char_rotation_expr: Expression for rotation degrees (evaluated per zone) + char_jitter_expr: Expression for position jitter (evaluated per zone) + interp: Interpreter instance for expression evaluation + env: Environment with bound variables + cell_effect: Optional lambda that receives (cell_image, zone_dict) and returns + a modified cell_image. When provided, each character is rendered + to a cell image, passed to this lambda, and the result composited. + This allows arbitrary effects to be applied per-character. + + Zone variables available in expressions: + zone-row, zone-col: Grid position (integers) + zone-row-norm, zone-col-norm: Normalized position (0-1) + zone-lum: Cell luminance (0-1) + zone-sat: Cell saturation (0-1) + zone-hue: Cell hue (0-360) + zone-r, zone-g, zone-b: RGB components (0-1) + + Returns: Rendered image + """ + h, w = frame.shape[:2] + # Use char_size if provided, otherwise calculate from cols + if char_size_override is not None: + cell_size = max(4, int(char_size_override)) + else: + cell_size = max(4, w // cols) + + # Get zone data using extended sampling + colors, luminances, zone_contexts = cell_sample_extended(frame, cell_size) + + # Convert luminances to characters + chars = prim_luminance_to_chars(luminances, alphabet, contrast) + + grid_rows = len(chars) + grid_cols = len(chars[0]) if chars else 0 + + # Parse colors + fg_color = parse_color(color_mode) + if isinstance(background, (list, tuple)): + bg_color = tuple(int(c) for c in background[:3]) + else: + bg_color = parse_color(background) + if bg_color is None: + bg_color = (0, 0, 0) + + # Arrays for per-zone transform values + hue_shifts = np.zeros((grid_rows, grid_cols), dtype=np.float32) + saturations = np.ones((grid_rows, grid_cols), dtype=np.float32) + brightness = np.ones((grid_rows, grid_cols), dtype=np.float32) + scales = np.ones((grid_rows, grid_cols), dtype=np.float32) + rotations = np.zeros((grid_rows, grid_cols), dtype=np.float32) + jitters = np.zeros((grid_rows, grid_cols), dtype=np.float32) + + # Helper to evaluate expression or return literal value + def eval_expr(expr, zone, char): + if expr is None: + return None + if isinstance(expr, (int, float)): + return expr + + # Build zone dict for lambda calls + zone_dict = { + 'row': zone.row, + 'col': zone.col, + 'row-norm': zone.row_norm, + 'col-norm': zone.col_norm, + 'lum': zone.luminance, + 'sat': zone.saturation, + 'hue': zone.hue, + 'r': zone.r, + 'g': zone.g, + 'b': zone.b, + 'char': char, + } + # Add extra params (energy, rotation_scale, etc.) for lambdas to access + if extra_params: + zone_dict.update(extra_params) + + # Check if it's a Python callable + if callable(expr): + return expr(zone_dict) + + # Check if it's an artdag Lambda object + try: + from artdag.sexp.parser import Lambda as ArtdagLambda + from artdag.sexp.evaluator import evaluate as artdag_evaluate + if isinstance(expr, ArtdagLambda): + # Build env with zone dict and any closure values + eval_env = dict(expr.closure) if expr.closure else {} + # Bind the lambda parameter to zone_dict + if expr.params: + eval_env[expr.params[0]] = zone_dict + return artdag_evaluate(expr.body, eval_env) + except ImportError: + pass + + # It's an expression - evaluate with zone context (sexp_effects style) + return interp.eval_with_zone(expr, env, zone) + + # Evaluate expressions for each zone + for r in range(grid_rows): + for c in range(grid_cols): + zone = zone_contexts[r][c] + char = chars[r][c] + + val = eval_expr(char_hue_expr, zone, char) + if val is not None: + hue_shifts[r, c] = float(val) + + val = eval_expr(char_sat_expr, zone, char) + if val is not None: + saturations[r, c] = float(val) + + val = eval_expr(char_bright_expr, zone, char) + if val is not None: + brightness[r, c] = float(val) + + val = eval_expr(char_scale_expr, zone, char) + if val is not None: + scales[r, c] = float(val) + + val = eval_expr(char_rotation_expr, zone, char) + if val is not None: + rotations[r, c] = float(val) + + val = eval_expr(char_jitter_expr, zone, char) + if val is not None: + jitters[r, c] = float(val) + + # Now render with computed transform arrays + out_h, out_w = grid_rows * cell_size, grid_cols * cell_size + bg = list(bg_color) + result = np.full((out_h, out_w, 3), bg, dtype=np.uint8) + + # If cell_effect is provided, use the cell-mapper rendering path + if cell_effect is not None: + return _render_with_cell_effect( + frame, chars, colors, luminances, zone_contexts, + cell_size, bg_color, fg_color, color_mode, + cell_effect, extra_params, interp, env, result + ) + + # Build character atlas + font = cv2.FONT_HERSHEY_SIMPLEX + base_font_scale = cell_size / 20.0 + thickness = max(1, int(cell_size / 10)) + + unique_chars = set() + for row in chars: + for ch in row: + unique_chars.add(ch) + + # For rotation/scale, render characters larger then transform + max_scale = max(1.0, np.max(scales) * 1.5) + atlas_size = int(cell_size * max_scale * 1.5) + + atlas = {} + for char in unique_chars: + if char and char != ' ': + try: + char_img = np.zeros((atlas_size, atlas_size), dtype=np.uint8) + scaled_font = base_font_scale * max_scale + (text_w, text_h), _ = cv2.getTextSize(char, font, scaled_font, thickness) + text_x = max(0, (atlas_size - text_w) // 2) + text_y = (atlas_size + text_h) // 2 + cv2.putText(char_img, char, (text_x, text_y), font, scaled_font, 255, thickness, cv2.LINE_AA) + atlas[char] = char_img + except: + atlas[char] = None + else: + atlas[char] = None + + # Render characters with per-zone effects + for r in range(grid_rows): + for c in range(grid_cols): + char = chars[r][c] + if not char or char == ' ': + continue + + char_img = atlas.get(char) + if char_img is None: + continue + + # Get per-cell values + eff_scale = scales[r, c] + eff_rotation = rotations[r, c] + eff_jitter = jitters[r, c] + eff_hue_shift = hue_shifts[r, c] + eff_brightness = brightness[r, c] + eff_saturation = saturations[r, c] + + # Apply transformations to character + transformed = char_img.copy() + + # Rotation + if abs(eff_rotation) > 0.5: + center = (atlas_size // 2, atlas_size // 2) + rot_matrix = cv2.getRotationMatrix2D(center, eff_rotation, 1.0) + transformed = cv2.warpAffine(transformed, rot_matrix, (atlas_size, atlas_size)) + + # Scale - resize to target size + target_size = max(1, int(cell_size * eff_scale)) + if target_size != atlas_size: + transformed = cv2.resize(transformed, (target_size, target_size), interpolation=cv2.INTER_LINEAR) + + # Compute position with jitter + base_y = r * cell_size + base_x = c * cell_size + + if eff_jitter > 0: + # Deterministic jitter based on position + jx = ((r * 7 + c * 13) % 100) / 100.0 - 0.5 + jy = ((r * 11 + c * 17) % 100) / 100.0 - 0.5 + base_x += int(jx * eff_jitter * 2) + base_y += int(jy * eff_jitter * 2) + + # Center the character in the cell + offset = (target_size - cell_size) // 2 + y1 = base_y - offset + x1 = base_x - offset + + # Determine color + if fg_color is not None: + color = np.array(fg_color, dtype=np.uint8) + elif color_mode == "mono": + color = np.array([255, 255, 255], dtype=np.uint8) + elif color_mode == "invert": + cy1 = max(0, r * cell_size) + cy2 = min(out_h, (r + 1) * cell_size) + cx1 = max(0, c * cell_size) + cx2 = min(out_w, (c + 1) * cell_size) + result[cy1:cy2, cx1:cx2] = colors[r, c] + color = np.array([0, 0, 0], dtype=np.uint8) + else: # color mode - use source colors + color = colors[r, c].copy() + + # Apply hue shift + if abs(eff_hue_shift) > 0.5 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_hue = (int(color_hsv[0, 0, 0]) + int(eff_hue_shift * 180 / 360)) % 180 + color_hsv[0, 0, 0] = np.uint8(new_hue) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply saturation adjustment + if abs(eff_saturation - 1.0) > 0.01 and color_mode not in ("mono", "invert") and fg_color is None: + color_hsv = cv2.cvtColor(color.reshape(1, 1, 3), cv2.COLOR_RGB2HSV) + new_sat = np.clip(int(color_hsv[0, 0, 1] * eff_saturation), 0, 255) + color_hsv[0, 0, 1] = np.uint8(new_sat) + color = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB).flatten() + + # Apply brightness adjustment + if abs(eff_brightness - 1.0) > 0.01: + color = np.clip(color.astype(np.float32) * eff_brightness, 0, 255).astype(np.uint8) + + # Blit character to result + mask = transformed > 0 + th, tw = transformed.shape[:2] + + for dy in range(th): + for dx in range(tw): + py = y1 + dy + px = x1 + dx + if 0 <= py < out_h and 0 <= px < out_w and mask[dy, dx]: + result[py, px] = color + + # Resize to match original if needed + orig_h, orig_w = frame.shape[:2] + if result.shape[0] != orig_h or result.shape[1] != orig_w: + padded = np.full((orig_h, orig_w, 3), bg, dtype=np.uint8) + copy_h = min(out_h, orig_h) + copy_w = min(out_w, orig_w) + padded[:copy_h, :copy_w] = result[:copy_h, :copy_w] + result = padded + + return result + + +def prim_make_char_grid(rows: int, cols: int, fill_char: str = " ") -> List[List[str]]: + """Create a character grid filled with a character.""" + return [[fill_char for _ in range(cols)] for _ in range(rows)] + + +def prim_set_char(chars: List[List[str]], row: int, col: int, char: str) -> List[List[str]]: + """Set a character at position (returns modified copy).""" + result = [r[:] for r in chars] # shallow copy rows + if 0 <= row < len(result) and 0 <= col < len(result[0]): + result[row][col] = char + return result + + +def prim_get_char(chars: List[List[str]], row: int, col: int) -> str: + """Get character at position.""" + if 0 <= row < len(chars) and 0 <= col < len(chars[0]): + return chars[row][col] + return " " + + +def prim_char_grid_dimensions(chars: List[List[str]]) -> Tuple[int, int]: + """Get (rows, cols) of character grid.""" + if not chars: + return (0, 0) + return (len(chars), len(chars[0]) if chars[0] else 0) + + +def prim_alphabet_char(alphabet: str, index: int) -> str: + """Get character at index from alphabet (wraps around).""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + if not chars: + return " " + return chars[int(index) % len(chars)] + + +def prim_alphabet_length(alphabet: str) -> int: + """Get length of alphabet.""" + chars = CHAR_ALPHABETS.get(alphabet, alphabet) + return len(chars) + + +def prim_map_char_grid(chars: List[List[str]], luminances: np.ndarray, fn: Callable) -> List[List[str]]: + """ + Map a function over character grid. + + fn receives (row, col, char, luminance) and returns new character. + This allows per-cell character selection based on position, brightness, etc. + + Example: + (map-char-grid chars luminances + (lambda (r c ch lum) + (if (> lum 128) + (alphabet-char "blocks" (floor (/ lum 50))) + ch))) + """ + if not chars or not chars[0]: + return chars + + rows = len(chars) + cols = len(chars[0]) + result = [] + + for r in range(rows): + row = [] + for c in range(cols): + ch = chars[r][c] + lum = float(luminances[r, c]) if r < luminances.shape[0] and c < luminances.shape[1] else 0 + new_ch = fn(r, c, ch, lum) + row.append(str(new_ch) if new_ch else " ") + result.append(row) + + return result + + +def prim_map_colors(colors: np.ndarray, fn: Callable) -> np.ndarray: + """ + Map a function over color grid. + + fn receives (row, col, color) and returns new [r, g, b]. + Color is a list [r, g, b]. + """ + if colors.size == 0: + return colors + + rows, cols = colors.shape[:2] + result = colors.copy() + + for r in range(rows): + for c in range(cols): + color = list(colors[r, c]) + new_color = fn(r, c, color) + if new_color is not None: + result[r, c] = new_color[:3] + + return result + + +# ============================================================================= +# Glitch Art Primitives +# ============================================================================= + +def prim_pixelsort(img: np.ndarray, sort_by: str = "lightness", + threshold_low: float = 50, threshold_high: float = 200, + angle: float = 0, reverse: bool = False) -> np.ndarray: + """ + Pixel sorting glitch effect. + + Args: + img: source image + sort_by: "lightness", "hue", "saturation", "red", "green", "blue" + threshold_low: pixels below this aren't sorted + threshold_high: pixels above this aren't sorted + angle: 0 = horizontal, 90 = vertical + reverse: reverse sort order + """ + h, w = img.shape[:2] + + # Rotate for vertical sorting + if 45 <= (angle % 180) <= 135: + frame = np.transpose(img, (1, 0, 2)) + h, w = frame.shape[:2] + rotated = True + else: + frame = img + rotated = False + + result = frame.copy() + + # Get sort values + if sort_by == "lightness": + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + elif sort_by == "hue": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 0].astype(np.float32) + elif sort_by == "saturation": + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + sort_values = hsv[:, :, 1].astype(np.float32) + elif sort_by == "red": + sort_values = frame[:, :, 0].astype(np.float32) + elif sort_by == "green": + sort_values = frame[:, :, 1].astype(np.float32) + elif sort_by == "blue": + sort_values = frame[:, :, 2].astype(np.float32) + else: + sort_values = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32) + + # Create mask + mask = (sort_values >= threshold_low) & (sort_values <= threshold_high) + + # Sort each row + for y in range(h): + row = result[y].copy() + row_mask = mask[y] + row_values = sort_values[y] + + # Find contiguous segments + segments = [] + start = None + for i, val in enumerate(row_mask): + if val and start is None: + start = i + elif not val and start is not None: + segments.append((start, i)) + start = None + if start is not None: + segments.append((start, len(row_mask))) + + # Sort each segment + for seg_start, seg_end in segments: + if seg_end - seg_start > 1: + segment_values = row_values[seg_start:seg_end] + sort_indices = np.argsort(segment_values) + if reverse: + sort_indices = sort_indices[::-1] + row[seg_start:seg_end] = row[seg_start:seg_end][sort_indices] + + result[y] = row + + # Rotate back + if rotated: + result = np.transpose(result, (1, 0, 2)) + + return np.ascontiguousarray(result) + + +def prim_datamosh(img: np.ndarray, prev_frame: np.ndarray, + block_size: int = 32, corruption: float = 0.3, + max_offset: int = 50, color_corrupt: bool = True) -> np.ndarray: + """ + Datamosh/glitch block corruption effect. + + Args: + img: current frame + prev_frame: previous frame (or None) + block_size: size of corruption blocks + corruption: probability 0-1 of corrupting each block + max_offset: maximum pixel shift + color_corrupt: also apply color channel shifts + """ + if corruption <= 0: + return img.copy() + + block_size = max(8, min(int(block_size), 128)) + h, w = img.shape[:2] + result = img.copy() + + for by in range(0, h, block_size): + for bx in range(0, w, block_size): + bh = min(block_size, h - by) + bw = min(block_size, w - bx) + + if _rng.random() < corruption: + corruption_type = _rng.randint(0, 3) + + if corruption_type == 0 and max_offset > 0: + # Shift + ox = _rng.randint(-max_offset, max_offset) + oy = _rng.randint(-max_offset, max_offset) + src_x = max(0, min(bx + ox, w - bw)) + src_y = max(0, min(by + oy, h - bh)) + result[by:by+bh, bx:bx+bw] = img[src_y:src_y+bh, src_x:src_x+bw] + + elif corruption_type == 1 and prev_frame is not None: + # Duplicate from previous frame + if prev_frame.shape == img.shape: + result[by:by+bh, bx:bx+bw] = prev_frame[by:by+bh, bx:bx+bw] + + elif corruption_type == 2 and color_corrupt: + # Color channel shift + block = result[by:by+bh, bx:bx+bw].copy() + shift = _rng.randint(1, 3) + channel = _rng.randint(0, 2) + block[:, :, channel] = np.roll(block[:, :, channel], shift, axis=0) + result[by:by+bh, bx:bx+bw] = block + + else: + # Swap with another block + other_bx = _rng.randint(0, max(0, w - bw)) + other_by = _rng.randint(0, max(0, h - bh)) + temp = result[by:by+bh, bx:bx+bw].copy() + result[by:by+bh, bx:bx+bw] = img[other_by:other_by+bh, other_bx:other_bx+bw] + result[other_by:other_by+bh, other_bx:other_bx+bw] = temp + + return result + + +def prim_ripple_displace(w: int, h: int, freq: float, amp: float, cx: float = None, cy: float = None, + decay: float = 0, phase: float = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Create radial ripple displacement maps. + + Args: + w, h: dimensions + freq: ripple frequency + amp: ripple amplitude in pixels + cx, cy: center + decay: how fast ripples decay with distance (0 = no decay) + phase: phase offset + + Returns: (map_x, map_y) for use with remap + """ + w, h = int(w), int(h) + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = np.sqrt(dx**2 + dy**2) + + # Calculate ripple displacement (radial) + ripple = np.sin(2 * np.pi * freq * dist / max(w, h) + phase) * amp + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max(w, h)) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return (map_x, map_y) + + +PRIMITIVES = { + # Arithmetic + '+': prim_add, + '-': prim_sub, + '*': prim_mul, + '/': prim_div, + + # Comparison + '<': prim_lt, + '>': prim_gt, + '<=': prim_le, + '>=': prim_ge, + '=': prim_eq, + '!=': prim_ne, + + # Image + 'width': prim_width, + 'height': prim_height, + 'make-image': prim_make_image, + 'copy': prim_copy, + 'pixel': prim_pixel, + 'set-pixel': prim_set_pixel, + 'sample': prim_sample, + 'channel': prim_channel, + 'merge-channels': prim_merge_channels, + 'resize': prim_resize, + 'crop': prim_crop, + 'paste': prim_paste, + + # Color + 'rgb': prim_rgb, + 'red': prim_red, + 'green': prim_green, + 'blue': prim_blue, + 'luminance': prim_luminance, + 'rgb->hsv': prim_rgb_to_hsv, + 'hsv->rgb': prim_hsv_to_rgb, + 'blend-color': prim_blend_color, + 'average-color': prim_average_color, + + # Vectorized bulk operations + 'color-matrix': prim_color_matrix, + 'adjust': prim_adjust, + 'mix-gray': prim_mix_gray, + 'invert-img': prim_invert_img, + 'add-noise': prim_add_noise, + 'quantize': prim_quantize, + 'shift-hsv': prim_shift_hsv, + + # Bulk operations + 'map-pixels': prim_map_pixels, + 'map-rows': prim_map_rows, + 'for-grid': prim_for_grid, + 'fold-pixels': prim_fold_pixels, + + # Filters + 'convolve': prim_convolve, + 'blur': prim_blur, + 'box-blur': prim_box_blur, + 'edges': prim_edges, + 'sobel': prim_sobel, + 'dilate': prim_dilate, + 'erode': prim_erode, + + # Geometry + 'translate': prim_translate, + 'rotate-img': prim_rotate, + 'scale-img': prim_scale, + 'flip-h': prim_flip_h, + 'flip-v': prim_flip_v, + 'remap': prim_remap, + 'make-coords': prim_make_coords, + + # Blending + 'blend-images': prim_blend_images, + 'blend-mode': prim_blend_mode, + 'mask': prim_mask, + + # Drawing + 'draw-char': prim_draw_char, + 'draw-text': prim_draw_text, + 'fill-rect': prim_fill_rect, + 'draw-line': prim_draw_line, + + # Math + 'sin': prim_sin, + 'cos': prim_cos, + 'tan': prim_tan, + 'atan2': prim_atan2, + 'sqrt': prim_sqrt, + 'pow': prim_pow, + 'abs': prim_abs, + 'floor': prim_floor, + 'ceil': prim_ceil, + 'round': prim_round, + 'min': prim_min, + 'max': prim_max, + 'clamp': prim_clamp, + 'lerp': prim_lerp, + 'mod': prim_mod, + 'random': prim_random, + 'randint': prim_randint, + 'gaussian': prim_gaussian, + 'assert': prim_assert, + 'pi': math.pi, + 'tau': math.tau, + + # Array + 'length': prim_length, + 'len': prim_length, # alias + 'nth': prim_nth, + 'first': prim_first, + 'rest': prim_rest, + 'take': prim_take, + 'drop': prim_drop, + 'cons': prim_cons, + 'append': prim_append, + 'reverse': prim_reverse, + 'range': prim_range, + 'roll': prim_roll, + 'list': prim_list, + + # Array math (vectorized operations on coordinate arrays) + 'arr+': prim_arr_add, + 'arr-': prim_arr_sub, + 'arr*': prim_arr_mul, + 'arr/': prim_arr_div, + 'arr-mod': prim_arr_mod, + 'arr-sin': prim_arr_sin, + 'arr-cos': prim_arr_cos, + 'arr-tan': prim_arr_tan, + 'arr-sqrt': prim_arr_sqrt, + 'arr-pow': prim_arr_pow, + 'arr-abs': prim_arr_abs, + 'arr-neg': prim_arr_neg, + 'arr-exp': prim_arr_exp, + 'arr-atan2': prim_arr_atan2, + 'arr-min': prim_arr_min, + 'arr-max': prim_arr_max, + 'arr-clip': prim_arr_clip, + 'arr-where': prim_arr_where, + 'arr-floor': prim_arr_floor, + 'arr-lerp': prim_arr_lerp, + + # Coordinate transformations + 'polar-from-center': prim_polar_from_center, + 'cart-from-polar': prim_cart_from_polar, + 'normalize-coords': prim_normalize_coords, + 'coords-x': prim_coords_x, + 'coords-y': prim_coords_y, + 'make-coords-centered': prim_make_coords_centered, + + # Specialized distortion maps + 'wave-displace': prim_wave_displace, + 'swirl-displace': prim_swirl_displace, + 'fisheye-displace': prim_fisheye_displace, + 'kaleidoscope-displace': prim_kaleidoscope_displace, + 'ripple-displace': prim_ripple_displace, + + # Character/ASCII art + 'cell-sample': prim_cell_sample, + 'cell-sample-extended': cell_sample_extended, + 'luminance-to-chars': prim_luminance_to_chars, + 'render-char-grid': prim_render_char_grid, + 'render-char-grid-fx': prim_render_char_grid_fx, + 'ascii-fx-zone': prim_ascii_fx_zone, + 'make-char-grid': prim_make_char_grid, + 'set-char': prim_set_char, + 'get-char': prim_get_char, + 'char-grid-dimensions': prim_char_grid_dimensions, + 'alphabet-char': prim_alphabet_char, + 'alphabet-length': prim_alphabet_length, + 'map-char-grid': prim_map_char_grid, + 'map-colors': prim_map_colors, + + # Glitch art + 'pixelsort': prim_pixelsort, + 'datamosh': prim_datamosh, + +} diff --git a/sexp_effects/test_interpreter.py b/sexp_effects/test_interpreter.py new file mode 100644 index 0000000..550b21a --- /dev/null +++ b/sexp_effects/test_interpreter.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +""" +Test the S-expression effect interpreter. +""" + +import numpy as np +import sys +from pathlib import Path + +# Add parent to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sexp_effects import ( + get_interpreter, + load_effects_dir, + run_effect, + list_effects, + parse, +) + + +def test_parser(): + """Test S-expression parser.""" + print("Testing parser...") + + # Simple expressions + assert parse("42") == 42 + assert parse("3.14") == 3.14 + assert parse('"hello"') == "hello" + assert parse("true") == True + + # Lists + assert parse("(+ 1 2)")[0].name == "+" + assert parse("(+ 1 2)")[1] == 1 + + # Nested + expr = parse("(define x (+ 1 2))") + assert expr[0].name == "define" + + print(" Parser OK") + + +def test_interpreter_basics(): + """Test basic interpreter operations.""" + print("Testing interpreter basics...") + + interp = get_interpreter() + + # Math + assert interp.eval(parse("(+ 1 2)")) == 3 + assert interp.eval(parse("(* 3 4)")) == 12 + assert interp.eval(parse("(- 10 3)")) == 7 + + # Comparison + assert interp.eval(parse("(< 1 2)")) == True + assert interp.eval(parse("(> 1 2)")) == False + + # Let binding + assert interp.eval(parse("(let ((x 5)) x)")) == 5 + assert interp.eval(parse("(let ((x 5) (y 3)) (+ x y))")) == 8 + + # Lambda + result = interp.eval(parse("((lambda (x) (* x 2)) 5)")) + assert result == 10 + + # If + assert interp.eval(parse("(if true 1 2)")) == 1 + assert interp.eval(parse("(if false 1 2)")) == 2 + + print(" Interpreter basics OK") + + +def test_primitives(): + """Test image primitives.""" + print("Testing primitives...") + + interp = get_interpreter() + + # Create test image + img = np.zeros((100, 100, 3), dtype=np.uint8) + img[50, 50] = [255, 128, 64] + + interp.global_env.set('test_img', img) + + # Width/height + assert interp.eval(parse("(width test_img)")) == 100 + assert interp.eval(parse("(height test_img)")) == 100 + + # Pixel + pixel = interp.eval(parse("(pixel test_img 50 50)")) + assert pixel == [255, 128, 64] + + # RGB + color = interp.eval(parse("(rgb 100 150 200)")) + assert color == [100, 150, 200] + + # Luminance + lum = interp.eval(parse("(luminance (rgb 100 100 100))")) + assert abs(lum - 100) < 1 + + print(" Primitives OK") + + +def test_effect_loading(): + """Test loading effects from .sexp files.""" + print("Testing effect loading...") + + # Load all effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + effects = list_effects() + print(f" Loaded {len(effects)} effects: {', '.join(sorted(effects))}") + + assert len(effects) > 0 + print(" Effect loading OK") + + +def test_effect_execution(): + """Test running effects on images.""" + print("Testing effect execution...") + + # Create test image + img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + + # Load effects + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Test each effect + effects = list_effects() + passed = 0 + failed = [] + + for name in sorted(effects): + try: + result, state = run_effect(name, img.copy(), {'_time': 0.5}, {}) + assert isinstance(result, np.ndarray) + assert result.shape == img.shape + passed += 1 + print(f" {name}: OK") + except Exception as e: + failed.append((name, str(e))) + print(f" {name}: FAILED - {e}") + + print(f" Passed: {passed}/{len(effects)}") + if failed: + print(f" Failed: {[f[0] for f in failed]}") + + return passed, failed + + +def test_ascii_fx_zone(): + """Test ascii_fx_zone effect with zone expressions.""" + print("Testing ascii_fx_zone...") + + interp = get_interpreter() + + # Load the effect + effects_dir = Path(__file__).parent / "effects" + load_effects_dir(str(effects_dir)) + + # Create gradient test frame + frame = np.zeros((120, 160, 3), dtype=np.uint8) + for x in range(160): + frame[:, x] = int(x / 160 * 255) + frame = np.stack([frame[:,:,0]]*3, axis=2) + + # Test 1: Basic without expressions + result, _ = run_effect('ascii_fx_zone', frame, {'cols': 20}, {}) + assert result.shape == frame.shape + print(" Basic run: OK") + + # Test 2: With zone-lum expression + expr = parse('(* zone-lum 180)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': expr + }, {}) + assert result.shape == frame.shape + print(" Zone-lum expression: OK") + + # Test 3: With multiple expressions + scale_expr = parse('(+ 0.5 (* zone-lum 0.5))') + rot_expr = parse('(* zone-row-norm 30)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_scale': scale_expr, + 'char_rotation': rot_expr + }, {}) + assert result.shape == frame.shape + print(" Multiple expressions: OK") + + # Test 4: With numeric literals + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': 90, + 'char_scale': 1.2 + }, {}) + assert result.shape == frame.shape + print(" Numeric literals: OK") + + # Test 5: Zone position expressions + col_expr = parse('(* zone-col-norm 360)') + result, _ = run_effect('ascii_fx_zone', frame, { + 'cols': 20, + 'char_hue': col_expr + }, {}) + assert result.shape == frame.shape + print(" Zone position expression: OK") + + print(" ascii_fx_zone OK") + + +def main(): + print("=" * 60) + print("S-Expression Effect Interpreter Tests") + print("=" * 60) + + test_parser() + test_interpreter_basics() + test_primitives() + test_effect_loading() + test_ascii_fx_zone() + passed, failed = test_effect_execution() + + print("=" * 60) + if not failed: + print("All tests passed!") + else: + print(f"Tests completed with {len(failed)} failures") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/streaming/__init__.py b/streaming/__init__.py new file mode 100644 index 0000000..2c007cc --- /dev/null +++ b/streaming/__init__.py @@ -0,0 +1,44 @@ +""" +Streaming video compositor for real-time effect processing. + +This module provides a frame-by-frame streaming architecture that: +- Reads from multiple video sources with automatic looping +- Applies effects inline (no intermediate files) +- Composites layers with time-varying weights +- Outputs to display, file, or stream + +Usage: + from streaming import StreamingCompositor, VideoSource, AudioAnalyzer + + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[...], + compositor_config={...}, + ) + + # With live audio + audio = AudioAnalyzer(device=0) + compositor.run(output="output.mp4", duration=60, audio=audio) + + # With preview window + compositor.run(output="preview", duration=60) + +Backends: + - numpy: Works everywhere, ~3-5 fps (default) + - glsl: Requires GPU, 30+ fps real-time (future) +""" + +from .sources import VideoSource, ImageSource +from .compositor import StreamingCompositor +from .backends import NumpyBackend, get_backend +from .output import DisplayOutput, FileOutput + +__all__ = [ + "StreamingCompositor", + "VideoSource", + "ImageSource", + "NumpyBackend", + "get_backend", + "DisplayOutput", + "FileOutput", +] diff --git a/streaming/audio.py b/streaming/audio.py new file mode 100644 index 0000000..9d20937 --- /dev/null +++ b/streaming/audio.py @@ -0,0 +1,486 @@ +""" +Live audio analysis for reactive effects. + +Provides real-time audio features: +- Energy (RMS amplitude) +- Beat detection +- Frequency bands (bass, mid, high) +""" + +import numpy as np +from typing import Optional +import threading +import time + + +class AudioAnalyzer: + """ + Real-time audio analyzer using sounddevice. + + Captures audio from microphone/line-in and computes + features in real-time for effect parameter bindings. + + Example: + analyzer = AudioAnalyzer(device=0) + analyzer.start() + + # In compositor loop: + energy = analyzer.get_energy() + beat = analyzer.get_beat() + + analyzer.stop() + """ + + def __init__( + self, + device: int = None, + sample_rate: int = 44100, + block_size: int = 1024, + buffer_seconds: float = 0.5, + ): + """ + Initialize audio analyzer. + + Args: + device: Audio input device index (None = default) + sample_rate: Audio sample rate + block_size: Samples per block + buffer_seconds: Ring buffer duration + """ + self.sample_rate = sample_rate + self.block_size = block_size + self.device = device + + # Ring buffer for recent audio + buffer_size = int(sample_rate * buffer_seconds) + self._buffer = np.zeros(buffer_size, dtype=np.float32) + self._buffer_pos = 0 + self._lock = threading.Lock() + + # Beat detection state + self._last_energy = 0 + self._energy_history = [] + self._last_beat_time = 0 + self._beat_threshold = 1.5 # Energy ratio for beat detection + self._min_beat_interval = 0.1 # Min seconds between beats + + # Stream state + self._stream = None + self._running = False + + def _audio_callback(self, indata, frames, time_info, status): + """Called by sounddevice for each audio block.""" + with self._lock: + # Add to ring buffer + data = indata[:, 0] if len(indata.shape) > 1 else indata + n = len(data) + if self._buffer_pos + n <= len(self._buffer): + self._buffer[self._buffer_pos:self._buffer_pos + n] = data + else: + # Wrap around + first = len(self._buffer) - self._buffer_pos + self._buffer[self._buffer_pos:] = data[:first] + self._buffer[:n - first] = data[first:] + self._buffer_pos = (self._buffer_pos + n) % len(self._buffer) + + def start(self): + """Start audio capture.""" + try: + import sounddevice as sd + except ImportError: + print("Warning: sounddevice not installed. Audio analysis disabled.") + print("Install with: pip install sounddevice") + return + + self._stream = sd.InputStream( + device=self.device, + channels=1, + samplerate=self.sample_rate, + blocksize=self.block_size, + callback=self._audio_callback, + ) + self._stream.start() + self._running = True + + def stop(self): + """Stop audio capture.""" + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + self._running = False + + def get_energy(self) -> float: + """ + Get current audio energy (RMS amplitude). + + Returns: + Energy value normalized to 0-1 range (approximately) + """ + with self._lock: + # Use recent samples + recent = 2048 + if self._buffer_pos >= recent: + data = self._buffer[self._buffer_pos - recent:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(recent - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # RMS energy + rms = np.sqrt(np.mean(data ** 2)) + + # Normalize (typical mic input is quite low) + normalized = min(1.0, rms * 10) + + return normalized + + def get_beat(self) -> bool: + """ + Detect if current moment is a beat. + + Simple onset detection based on energy spikes. + + Returns: + True if beat detected, False otherwise + """ + current_energy = self.get_energy() + now = time.time() + + # Update energy history + self._energy_history.append(current_energy) + if len(self._energy_history) > 20: + self._energy_history.pop(0) + + # Need enough history + if len(self._energy_history) < 5: + self._last_energy = current_energy + return False + + # Average recent energy + avg_energy = np.mean(self._energy_history[:-1]) + + # Beat if current energy is significantly above average + is_beat = ( + current_energy > avg_energy * self._beat_threshold and + now - self._last_beat_time > self._min_beat_interval and + current_energy > self._last_energy # Rising edge + ) + + if is_beat: + self._last_beat_time = now + + self._last_energy = current_energy + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """ + Get frequency spectrum divided into bands. + + Args: + bands: Number of frequency bands (default 3: bass, mid, high) + + Returns: + Array of band energies, normalized to 0-1 + """ + with self._lock: + # Use recent samples for FFT + n = 2048 + if self._buffer_pos >= n: + data = self._buffer[self._buffer_pos - n:self._buffer_pos] + else: + data = np.concatenate([ + self._buffer[-(n - self._buffer_pos):], + self._buffer[:self._buffer_pos] + ]) + + # FFT + fft = np.abs(np.fft.rfft(data * np.hanning(len(data)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + start = i * band_size + end = start + band_size + result[i] = np.mean(fft[start:end]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return self._running + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() + + +class FileAudioAnalyzer: + """ + Audio analyzer that reads from a file (for testing/development). + + Pre-computes analysis and plays back in sync with video. + """ + + def __init__(self, path: str, analysis_data: dict = None): + """ + Initialize from audio file. + + Args: + path: Path to audio file + analysis_data: Pre-computed analysis (times, values, etc.) + """ + self.path = path + self.analysis_data = analysis_data or {} + self._current_time = 0 + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Get energy at current time from pre-computed data.""" + track = self.analysis_data.get("energy", {}) + return self._interpolate(track, self._current_time) + + def get_beat(self) -> bool: + """Check if current time is near a beat.""" + track = self.analysis_data.get("beats", {}) + times = track.get("times", []) + + # Check if we're within 50ms of a beat + for beat_time in times: + if abs(beat_time - self._current_time) < 0.05: + return True + return False + + def _interpolate(self, track: dict, t: float) -> float: + """Interpolate value at time t.""" + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + if t <= times[0]: + return values[0] + if t >= times[-1]: + return values[-1] + + # Find bracket and interpolate + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + return values[i] * (1 - alpha) + values[i + 1] * alpha + + return values[-1] + + @property + def is_running(self) -> bool: + return True + + +class StreamingAudioAnalyzer: + """ + Real-time audio analyzer that streams from a file. + + Reads audio in sync with video time and computes features on-the-fly. + No pre-computation needed - analysis happens as frames are processed. + """ + + def __init__(self, path: str, sample_rate: int = 22050, hop_length: int = 512): + """ + Initialize streaming audio analyzer. + + Args: + path: Path to audio file + sample_rate: Sample rate for analysis + hop_length: Hop length for feature extraction + """ + import subprocess + import json + + self.path = path + self.sample_rate = sample_rate + self.hop_length = hop_length + self._current_time = 0.0 + + # Get audio duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(path)] + result = subprocess.run(cmd, capture_output=True, text=True) + info = json.loads(result.stdout) + self.duration = float(info["format"]["duration"]) + + # Audio buffer and state + self._audio_data = None + self._energy_history = [] + self._last_energy = 0 + self._last_beat_time = -1 + self._beat_threshold = 1.5 + self._min_beat_interval = 0.15 + + # Load audio lazily + self._loaded = False + + def _load_audio(self): + """Load audio data on first use.""" + if self._loaded: + return + + import subprocess + + # Use ffmpeg to decode audio to raw PCM + cmd = [ + "ffmpeg", "-v", "quiet", + "-i", str(self.path), + "-f", "f32le", # 32-bit float, little-endian + "-ac", "1", # mono + "-ar", str(self.sample_rate), + "-" + ] + result = subprocess.run(cmd, capture_output=True) + self._audio_data = np.frombuffer(result.stdout, dtype=np.float32) + self._loaded = True + + def set_time(self, t: float): + """Set current playback time.""" + self._current_time = t + + def get_energy(self) -> float: + """Compute energy at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return 0.0 + + # Get sample index for current time + sample_idx = int(self._current_time * self.sample_rate) + window_size = self.hop_length * 2 + + start = max(0, sample_idx - window_size // 2) + end = min(len(self._audio_data), sample_idx + window_size // 2) + + if start >= end: + return 0.0 + + # RMS energy + chunk = self._audio_data[start:end] + rms = np.sqrt(np.mean(chunk ** 2)) + + # Normalize to 0-1 range (approximate) + energy = min(1.0, rms * 3.0) + + self._last_energy = energy + return energy + + def get_beat(self) -> bool: + """Detect beat using spectral flux (change in frequency content).""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return False + + # Get audio chunks for current and previous frame + sample_idx = int(self._current_time * self.sample_rate) + chunk_size = self.hop_length * 2 + + # Current chunk + start = max(0, sample_idx - chunk_size // 2) + end = min(len(self._audio_data), sample_idx + chunk_size // 2) + if end - start < chunk_size // 2: + return False + current_chunk = self._audio_data[start:end] + + # Previous chunk (one hop back) + prev_start = max(0, start - self.hop_length) + prev_end = max(0, end - self.hop_length) + if prev_end <= prev_start: + return False + prev_chunk = self._audio_data[prev_start:prev_end] + + # Compute spectra + current_spec = np.abs(np.fft.rfft(current_chunk * np.hanning(len(current_chunk)))) + prev_spec = np.abs(np.fft.rfft(prev_chunk * np.hanning(len(prev_chunk)))) + + # Spectral flux: sum of positive differences (onset = new frequencies appearing) + min_len = min(len(current_spec), len(prev_spec)) + diff = current_spec[:min_len] - prev_spec[:min_len] + flux = np.sum(np.maximum(0, diff)) # Only count increases + + # Normalize by spectrum size + flux = flux / (min_len + 1) + + # Update flux history + self._energy_history.append((self._current_time, flux)) + while self._energy_history and self._energy_history[0][0] < self._current_time - 1.5: + self._energy_history.pop(0) + + if len(self._energy_history) < 3: + return False + + # Adaptive threshold based on recent flux values + flux_values = [f for t, f in self._energy_history] + mean_flux = np.mean(flux_values) + std_flux = np.std(flux_values) + 0.001 # Avoid division by zero + + # Beat if flux is above mean (more sensitive threshold) + threshold = mean_flux + std_flux * 0.3 # Lower = more sensitive + min_interval = 0.1 # Allow up to 600 BPM + time_ok = self._current_time - self._last_beat_time > min_interval + + is_beat = flux > threshold and time_ok + + if is_beat: + self._last_beat_time = self._current_time + + return is_beat + + def get_spectrum(self, bands: int = 3) -> np.ndarray: + """Get frequency spectrum at current time.""" + self._load_audio() + + if self._audio_data is None or len(self._audio_data) == 0: + return np.zeros(bands) + + sample_idx = int(self._current_time * self.sample_rate) + n = 2048 + + start = max(0, sample_idx - n // 2) + end = min(len(self._audio_data), sample_idx + n // 2) + + if end - start < n // 2: + return np.zeros(bands) + + chunk = self._audio_data[start:end] + + # FFT + fft = np.abs(np.fft.rfft(chunk * np.hanning(len(chunk)))) + + # Divide into bands + band_size = len(fft) // bands + result = np.zeros(bands) + for i in range(bands): + s, e = i * band_size, (i + 1) * band_size + result[i] = np.mean(fft[s:e]) + + # Normalize + max_val = np.max(result) + if max_val > 0: + result = result / max_val + + return result + + @property + def is_running(self) -> bool: + return True diff --git a/streaming/backends.py b/streaming/backends.py new file mode 100644 index 0000000..bc695d6 --- /dev/null +++ b/streaming/backends.py @@ -0,0 +1,308 @@ +""" +Effect processing backends. + +Provides abstraction over different rendering backends: +- numpy: CPU-based, works everywhere, ~3-5 fps +- glsl: GPU-based, requires OpenGL, 30+ fps (future) +""" + +import numpy as np +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from pathlib import Path + + +class Backend(ABC): + """Abstract base class for effect processing backends.""" + + @abstractmethod + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process multiple input frames through effects and composite. + + Args: + frames: List of input frames (one per source) + effects_per_frame: List of effect chains (one per source) + compositor_config: How to blend the layers + t: Current time in seconds + analysis_data: Analysis data for binding resolution + + Returns: + Composited output frame + """ + pass + + @abstractmethod + def load_effect(self, effect_path: Path) -> Any: + """Load an effect definition.""" + pass + + +class NumpyBackend(Backend): + """ + CPU-based effect processing using NumPy. + + Uses existing sexp_effects interpreter for effect execution. + Works on any system, but limited to ~3-5 fps for complex effects. + """ + + def __init__(self, recipe_dir: Path = None, minimal_primitives: bool = True): + self.recipe_dir = recipe_dir or Path(".") + self.minimal_primitives = minimal_primitives + self._interpreter = None + self._loaded_effects = {} + + def _get_interpreter(self): + """Lazy-load the sexp interpreter.""" + if self._interpreter is None: + from sexp_effects import get_interpreter + self._interpreter = get_interpreter(minimal_primitives=self.minimal_primitives) + return self._interpreter + + def load_effect(self, effect_path: Path) -> Any: + """Load an effect from sexp file.""" + effect_key = str(effect_path) + if effect_key not in self._loaded_effects: + interp = self._get_interpreter() + interp.load_effect(str(effect_path)) + self._loaded_effects[effect_key] = effect_path.stem + return self._loaded_effects[effect_key] + + def _resolve_binding(self, value: Any, t: float, analysis_data: Dict) -> Any: + """Resolve a parameter binding to its value at time t.""" + if not isinstance(value, dict): + return value + + if "_binding" in value or "_bind" in value: + source = value.get("source") or value.get("_bind") + feature = value.get("feature", "values") + range_map = value.get("range") + + track = analysis_data.get(source, {}) + times = track.get("times", []) + values = track.get("values", []) + + if not times or not values: + return 0.0 + + # Find value at time t (linear interpolation) + if t <= times[0]: + val = values[0] + elif t >= times[-1]: + val = values[-1] + else: + # Binary search for bracket + for i in range(len(times) - 1): + if times[i] <= t <= times[i + 1]: + alpha = (t - times[i]) / (times[i + 1] - times[i]) + val = values[i] * (1 - alpha) + values[i + 1] * alpha + break + else: + val = values[-1] + + # Apply range mapping + if range_map and len(range_map) == 2: + val = range_map[0] + val * (range_map[1] - range_map[0]) + + return val + + return value + + def _apply_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Apply a single effect to a frame.""" + # Resolve bindings in params + resolved_params = {"_time": t} + for key, value in params.items(): + if key in ("effect", "effect_path", "cid", "analysis_refs"): + continue + resolved_params[key] = self._resolve_binding(value, t, analysis_data) + + # Try fast native effects first + result = self._apply_native_effect(frame, effect_name, resolved_params) + if result is not None: + return result + + # Fall back to sexp interpreter for complex effects + interp = self._get_interpreter() + if effect_name in interp.effects: + result, _ = interp.run_effect(effect_name, frame, resolved_params, {}) + return result + + # Unknown effect - pass through + return frame + + def _apply_native_effect( + self, + frame: np.ndarray, + effect_name: str, + params: Dict, + ) -> Optional[np.ndarray]: + """Fast native numpy effects for real-time streaming.""" + import cv2 + + if effect_name == "zoom": + amount = float(params.get("amount", 1.0)) + if abs(amount - 1.0) < 0.01: + return frame + h, w = frame.shape[:2] + # Crop center and resize + new_w, new_h = int(w / amount), int(h / amount) + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + + elif effect_name == "rotate": + angle = float(params.get("angle", 0)) + if abs(angle) < 0.5: + return frame + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + return cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == "brightness": + amount = float(params.get("amount", 1.0)) + return np.clip(frame * amount, 0, 255).astype(np.uint8) + + elif effect_name == "invert": + amount = float(params.get("amount", 1.0)) + if amount < 0.5: + return frame + return 255 - frame + + # Not a native effect + return None + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """ + Process frames through effects and composite. + """ + if not frames: + return np.zeros((720, 1280, 3), dtype=np.uint8) + + processed = [] + + # Apply effects to each input frame + for i, (frame, effects) in enumerate(zip(frames, effects_per_frame)): + result = frame.copy() + for effect_config in effects: + effect_name = effect_config.get("effect", "") + if effect_name: + result = self._apply_effect( + result, effect_name, effect_config, t, analysis_data + ) + processed.append(result) + + # Composite layers + if len(processed) == 1: + return processed[0] + + return self._composite(processed, compositor_config, t, analysis_data) + + def _composite( + self, + frames: List[np.ndarray], + config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + """Composite multiple frames into one.""" + mode = config.get("mode", "alpha") + weights = config.get("weights", [1.0 / len(frames)] * len(frames)) + + # Resolve weight bindings + resolved_weights = [] + for w in weights: + resolved_weights.append(self._resolve_binding(w, t, analysis_data)) + + # Normalize weights + total = sum(resolved_weights) + if total > 0: + resolved_weights = [w / total for w in resolved_weights] + else: + resolved_weights = [1.0 / len(frames)] * len(frames) + + # Resize frames to match first frame + target_h, target_w = frames[0].shape[:2] + resized = [] + for frame in frames: + if frame.shape[:2] != (target_h, target_w): + import cv2 + frame = cv2.resize(frame, (target_w, target_h)) + resized.append(frame.astype(np.float32)) + + # Weighted blend + result = np.zeros_like(resized[0]) + for frame, weight in zip(resized, resolved_weights): + result += frame * weight + + return np.clip(result, 0, 255).astype(np.uint8) + + +class GLSLBackend(Backend): + """ + GPU-based effect processing using OpenGL/GLSL. + + Requires GPU with OpenGL 3.3+ support (or Mesa software renderer). + Achieves 30+ fps real-time processing. + + TODO: Implement when ready for GPU acceleration. + """ + + def __init__(self): + raise NotImplementedError( + "GLSL backend not yet implemented. Use NumpyBackend for now." + ) + + def load_effect(self, effect_path: Path) -> Any: + pass + + def process_frame( + self, + frames: List[np.ndarray], + effects_per_frame: List[List[Dict]], + compositor_config: Dict, + t: float, + analysis_data: Dict, + ) -> np.ndarray: + pass + + +def get_backend(name: str = "numpy", **kwargs) -> Backend: + """ + Get a backend by name. + + Args: + name: "numpy" or "glsl" + **kwargs: Backend-specific options + + Returns: + Backend instance + """ + if name == "numpy": + return NumpyBackend(**kwargs) + elif name == "glsl": + return GLSLBackend(**kwargs) + else: + raise ValueError(f"Unknown backend: {name}") diff --git a/streaming/compositor.py b/streaming/compositor.py new file mode 100644 index 0000000..477128f --- /dev/null +++ b/streaming/compositor.py @@ -0,0 +1,595 @@ +""" +Streaming video compositor. + +Main entry point for the streaming pipeline. Combines: +- Multiple video sources (with looping) +- Per-source effect chains +- Layer compositing +- Optional live audio analysis +- Output to display/file/stream +""" + +import time +import sys +import numpy as np +from typing import List, Dict, Any, Optional, Union +from pathlib import Path + +from .sources import Source, VideoSource +from .backends import Backend, NumpyBackend, get_backend +from .output import Output, DisplayOutput, FileOutput, MultiOutput + + +class StreamingCompositor: + """ + Real-time streaming video compositor. + + Reads frames from multiple sources, applies effects, composites layers, + and outputs the result - all frame-by-frame without intermediate files. + + Example: + compositor = StreamingCompositor( + sources=["video1.mp4", "video2.mp4"], + effects_per_source=[ + [{"effect": "rotate", "angle": 45}], + [{"effect": "zoom", "amount": 1.5}], + ], + compositor_config={"mode": "alpha", "weights": [0.5, 0.5]}, + ) + compositor.run(output="preview", duration=60) + """ + + def __init__( + self, + sources: List[Union[str, Source]], + effects_per_source: List[List[Dict]] = None, + compositor_config: Dict = None, + analysis_data: Dict = None, + backend: str = "numpy", + recipe_dir: Path = None, + fps: float = 30, + audio_source: str = None, + ): + """ + Initialize the streaming compositor. + + Args: + sources: List of video paths or Source objects + effects_per_source: List of effect chains, one per source + compositor_config: How to blend layers (mode, weights) + analysis_data: Pre-computed analysis data for bindings + backend: "numpy" or "glsl" + recipe_dir: Directory for resolving relative effect paths + fps: Output frame rate + audio_source: Path to audio file for streaming analysis + """ + self.fps = fps + self.recipe_dir = recipe_dir or Path(".") + self.analysis_data = analysis_data or {} + + # Initialize streaming audio analyzer if audio source provided + self._audio_analyzer = None + self._audio_source = audio_source + if audio_source: + from .audio import StreamingAudioAnalyzer + self._audio_analyzer = StreamingAudioAnalyzer(audio_source) + print(f"Streaming audio: {audio_source}", file=sys.stderr) + + # Initialize sources + self.sources: List[Source] = [] + for src in sources: + if isinstance(src, Source): + self.sources.append(src) + elif isinstance(src, (str, Path)): + self.sources.append(VideoSource(str(src), target_fps=fps)) + else: + raise ValueError(f"Unknown source type: {type(src)}") + + # Effect chains (default: no effects) + self.effects_per_source = effects_per_source or [[] for _ in self.sources] + if len(self.effects_per_source) != len(self.sources): + raise ValueError( + f"effects_per_source length ({len(self.effects_per_source)}) " + f"must match sources length ({len(self.sources)})" + ) + + # Compositor config (default: equal blend) + self.compositor_config = compositor_config or { + "mode": "alpha", + "weights": [1.0 / len(self.sources)] * len(self.sources), + } + + # Initialize backend + self.backend: Backend = get_backend( + backend, + recipe_dir=self.recipe_dir, + ) + + # Load effects + self._load_effects() + + def _load_effects(self): + """Pre-load all effect definitions.""" + for effects in self.effects_per_source: + for effect_config in effects: + effect_path = effect_config.get("effect_path") + if effect_path: + full_path = self.recipe_dir / effect_path + if full_path.exists(): + self.backend.load_effect(full_path) + + def _create_output( + self, + output: Union[str, Output], + size: tuple, + ) -> Output: + """Create output target from string or Output object.""" + if isinstance(output, Output): + return output + + if output == "preview": + return DisplayOutput("Streaming Preview", size, + audio_source=self._audio_source, fps=self.fps) + elif output == "null": + from .output import NullOutput + return NullOutput() + elif isinstance(output, str): + return FileOutput(output, size, fps=self.fps, audio_source=self._audio_source) + else: + raise ValueError(f"Unknown output type: {output}") + + def run( + self, + output: Union[str, Output] = "preview", + duration: float = None, + audio_analyzer=None, + show_fps: bool = True, + recipe_executor=None, + ): + """ + Run the streaming compositor. + + Args: + output: Output target - "preview", filename, or Output object + duration: Duration in seconds (None = run until quit) + audio_analyzer: Optional AudioAnalyzer for live audio reactivity + show_fps: Show FPS counter in console + recipe_executor: Optional StreamingRecipeExecutor for full recipe logic + """ + # Determine output size from first source + output_size = self.sources[0].size + + # Create output + out = self._create_output(output, output_size) + + # Determine duration + if duration is None: + # Run until stopped (or min source duration if not looping) + duration = min(s.duration for s in self.sources) + if duration == float('inf'): + duration = 3600 # 1 hour max for live sources + + total_frames = int(duration * self.fps) + frame_time = 1.0 / self.fps + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {total_frames} frames @ {self.fps}fps", file=sys.stderr) + print(f"Output size: {output_size[0]}x{output_size[1]}", file=sys.stderr) + print(f"Press 'q' to quit (if preview)", file=sys.stderr) + + # Frame loop + start_time = time.time() + frame_count = 0 + fps_update_interval = 30 # Update FPS display every N frames + last_fps_time = start_time + last_fps_count = 0 + + try: + for frame_num in range(total_frames): + if not out.is_open: + print(f"\nOutput closed at frame {frame_num}", file=sys.stderr) + break + + t = frame_num * frame_time + + try: + # Update analysis data from streaming audio (file-based) + energy = 0.0 + is_beat = False + if self._audio_analyzer: + self._update_from_audio(self._audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + elif audio_analyzer: + self._update_from_audio(audio_analyzer, t) + energy = self.analysis_data.get("live_energy", {}).get("values", [0])[0] + is_beat = self.analysis_data.get("live_beat", {}).get("values", [0])[0] > 0.5 + + # Read frames from all sources + frames = [src.read_frame(t) for src in self.sources] + + # Process through recipe executor if provided + if recipe_executor: + result = self._process_with_executor( + frames, recipe_executor, energy, is_beat, t + ) + else: + # Simple backend processing + result = self.backend.process_frame( + frames, + self.effects_per_source, + self.compositor_config, + t, + self.analysis_data, + ) + + # Output + out.write(result, t) + frame_count += 1 + + # FPS display + if show_fps and frame_count % fps_update_interval == 0: + now = time.time() + elapsed = now - last_fps_time + if elapsed > 0: + current_fps = (frame_count - last_fps_count) / elapsed + progress = frame_num / total_frames * 100 + print( + f"\r {progress:5.1f}% | {current_fps:5.1f} fps | " + f"frame {frame_num}/{total_frames}", + end="", file=sys.stderr + ) + last_fps_time = now + last_fps_count = frame_count + + except Exception as e: + print(f"\nError at frame {frame_num}, t={t:.1f}s: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + break + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources: + if hasattr(src, 'close'): + src.close() + + # Final stats + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", file=sys.stderr) + + def _process_with_executor( + self, + frames: List[np.ndarray], + executor, + energy: float, + is_beat: bool, + t: float, + ) -> np.ndarray: + """ + Process frames using the recipe executor for full pipeline. + + Implements: + 1. process-pair: two clips per source with effects, blended + 2. cycle-crossfade: dynamic composition with zoom and weights + 3. Final effects: whole-spin, ripple + """ + import cv2 + + # Target size from first source + target_h, target_w = frames[0].shape[:2] + + # Resize all frames to target size (letterbox to preserve aspect ratio) + resized_frames = [] + for frame in frames: + fh, fw = frame.shape[:2] + if (fh, fw) != (target_h, target_w): + # Calculate scale to fit while preserving aspect ratio + scale = min(target_w / fw, target_h / fh) + new_w, new_h = int(fw * scale), int(fh * scale) + resized = cv2.resize(frame, (new_w, new_h)) + # Center on black canvas + canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8) + x_off = (target_w - new_w) // 2 + y_off = (target_h - new_h) // 2 + canvas[y_off:y_off+new_h, x_off:x_off+new_w] = resized + resized_frames.append(canvas) + else: + resized_frames.append(frame) + frames = resized_frames + + # Update executor state + executor.on_frame(energy, is_beat, t) + + # Get weights to know which sources are active + weights = executor.get_cycle_weights() + + # Process each source as a "pair" (clip A and B with different effects) + processed_pairs = [] + + for i, frame in enumerate(frames): + # Skip sources with zero weight (but still need placeholder) + if i < len(weights) and weights[i] < 0.001: + processed_pairs.append(None) + continue + # Get effect params for clip A and B + params_a = executor.get_effect_params(i, "a", energy) + params_b = executor.get_effect_params(i, "b", energy) + pair_params = executor.get_pair_params(i) + + # Process clip A + clip_a = self._apply_clip_effects(frame.copy(), params_a, t) + + # Process clip B + clip_b = self._apply_clip_effects(frame.copy(), params_b, t) + + # Blend A and B using pair_mix opacity + opacity = pair_params["blend_opacity"] + blended = cv2.addWeighted( + clip_a, 1 - opacity, + clip_b, opacity, + 0 + ) + + # Apply pair rotation + h, w = blended.shape[:2] + center = (w // 2, h // 2) + angle = pair_params["pair_rotation"] + if abs(angle) > 0.5: + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + blended = cv2.warpAffine(blended, matrix, (w, h)) + + processed_pairs.append(blended) + + # Cycle-crossfade composition + weights = executor.get_cycle_weights() + zooms = executor.get_cycle_zooms() + + # Apply zoom per pair and composite + h, w = target_h, target_w + result = np.zeros((h, w, 3), dtype=np.float32) + + for idx, (pair, weight, zoom) in enumerate(zip(processed_pairs, weights, zooms)): + # Skip zero-weight sources + if pair is None or weight < 0.001: + continue + + orig_shape = pair.shape + + # Apply zoom + if zoom > 1.01: + # Zoom in: crop center and resize up + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = pair[y1:y1+new_h, x1:x1+new_w] + pair = cv2.resize(cropped, (w, h)) + elif zoom < 0.99: + # Zoom out: shrink video and center on black + scaled_w, scaled_h = int(w * zoom), int(h * zoom) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(pair, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + pair = canvas.copy() + + # Draw colored border - size indicates zoom level + border_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)] + color = border_colors[idx % 4] + thickness = max(3, int(10 * weight)) # Thicker border = higher weight + pair = np.ascontiguousarray(pair) + pair[:thickness, :] = color + pair[-thickness:, :] = color + pair[:, :thickness] = color + pair[:, -thickness:] = color + + result += pair.astype(np.float32) * weight + + result = np.clip(result, 0, 255).astype(np.uint8) + + # Apply final effects (whole-spin, ripple) + final_params = executor.get_final_effects(energy) + + # Whole spin + spin_angle = final_params["whole_spin_angle"] + if abs(spin_angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, spin_angle, 1.0) + result = cv2.warpAffine(result, matrix, (w, h)) + + # Ripple effect + amp = final_params["ripple_amplitude"] + if amp > 1: + result = self._apply_ripple(result, amp, + final_params["ripple_cx"], + final_params["ripple_cy"], + t) + + return result + + def _apply_clip_effects(self, frame: np.ndarray, params: dict, t: float) -> np.ndarray: + """Apply per-clip effects: rotate, zoom, invert, hue_shift, ascii.""" + import cv2 + + h, w = frame.shape[:2] + + # Rotate + angle = params["rotate_angle"] + if abs(angle) > 0.5: + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + # Zoom + zoom = params["zoom_amount"] + if abs(zoom - 1.0) > 0.01: + new_w, new_h = int(w / zoom), int(h / zoom) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x1 + new_w), min(h, y1 + new_h) + if x2 > x1 and y2 > y1: + cropped = frame[y1:y2, x1:x2] + frame = cv2.resize(cropped, (w, h)) + + # Invert + if params["invert_amount"] > 0.5: + frame = 255 - frame + + # Hue shift + hue_deg = params["hue_degrees"] + if abs(hue_deg) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(np.int32) + int(hue_deg / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + # ASCII art + if params["ascii_mix"] > 0.5: + char_size = max(4, int(params["ascii_char_size"])) + frame = self._apply_ascii(frame, char_size) + + return frame + + def _apply_ascii(self, frame: np.ndarray, char_size: int) -> np.ndarray: + """Apply ASCII art effect.""" + import cv2 + from PIL import Image, ImageDraw, ImageFont + + h, w = frame.shape[:2] + chars = " .:-=+*#%@" + + # Get font + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size) + except: + font = ImageFont.load_default() + + # Sample cells using area interpolation (fast block average) + rows = h // char_size + cols = w // char_size + if rows < 1 or cols < 1: + return frame + + # Crop to exact grid and downsample + cropped = frame[:rows * char_size, :cols * char_size] + cell_colors = cv2.resize(cropped, (cols, rows), interpolation=cv2.INTER_AREA) + + # Compute luminance + luminances = (0.299 * cell_colors[:, :, 0] + + 0.587 * cell_colors[:, :, 1] + + 0.114 * cell_colors[:, :, 2]) / 255.0 + + # Create output image + out_h = rows * char_size + out_w = cols * char_size + output = Image.new('RGB', (out_w, out_h), (0, 0, 0)) + draw = ImageDraw.Draw(output) + + # Draw characters + for r in range(rows): + for c in range(cols): + lum = luminances[r, c] + color = tuple(cell_colors[r, c]) + + # Map luminance to character + idx = int(lum * (len(chars) - 1)) + char = chars[idx] + + # Draw character + x = c * char_size + y = r * char_size + draw.text((x, y), char, fill=color, font=font) + + # Convert back to numpy and resize to original + result = np.array(output) + if result.shape[:2] != (h, w): + result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LINEAR) + + return result + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, t: float = 0) -> np.ndarray: + """Apply ripple distortion effect.""" + import cv2 + + h, w = frame.shape[:2] + center_x, center_y = cx * w, cy * h + max_dim = max(w, h) + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Distance from center + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx*dx + dy*dy) + + # Ripple parameters (matching recipe: frequency=8, decay=2, speed=5) + freq = 8 + decay = 2 + speed = 5 + phase = t * speed * 2 * np.pi + + # Ripple displacement (matching original formula) + ripple = np.sin(2 * np.pi * freq * dist / max_dim + phase) * amplitude + + # Apply decay + if decay > 0: + ripple = ripple * np.exp(-dist * decay / max_dim) + + # Displace along radial direction + with np.errstate(divide='ignore', invalid='ignore'): + norm_dx = np.where(dist > 0, dx / dist, 0) + norm_dy = np.where(dist > 0, dy / dist, 0) + + map_x = (x_coords + ripple * norm_dx).astype(np.float32) + map_y = (y_coords + ripple * norm_dy).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, + borderMode=cv2.BORDER_REFLECT) + + def _update_from_audio(self, analyzer, t: float): + """Update analysis data from audio analyzer (streaming or live).""" + # Set time for file-based streaming analyzers + if hasattr(analyzer, 'set_time'): + analyzer.set_time(t) + + # Get current audio features + energy = analyzer.get_energy() if hasattr(analyzer, 'get_energy') else 0 + beat = analyzer.get_beat() if hasattr(analyzer, 'get_beat') else False + + # Update analysis tracks - these can be referenced by effect bindings + self.analysis_data["live_energy"] = { + "times": [t], + "values": [energy], + "duration": float('inf'), + } + self.analysis_data["live_beat"] = { + "times": [t], + "values": [1.0 if beat else 0.0], + "duration": float('inf'), + } + + +def quick_preview( + sources: List[str], + effects: List[List[Dict]] = None, + duration: float = 10, + fps: float = 30, +): + """ + Quick preview helper - show sources with optional effects. + + Example: + quick_preview(["video1.mp4", "video2.mp4"], duration=30) + """ + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects, + fps=fps, + ) + compositor.run(output="preview", duration=duration) diff --git a/streaming/demo.py b/streaming/demo.py new file mode 100644 index 0000000..0b1899f --- /dev/null +++ b/streaming/demo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Demo script for streaming compositor. + +Usage: + # Preview two videos blended + python -m streaming.demo preview video1.mp4 video2.mp4 + + # Record output to file + python -m streaming.demo record video1.mp4 video2.mp4 -o output.mp4 + + # Benchmark (no output) + python -m streaming.demo benchmark video1.mp4 --duration 10 +""" + +import argparse +import sys +from pathlib import Path + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from streaming import StreamingCompositor, VideoSource +from streaming.output import NullOutput + + +def demo_preview(sources: list, duration: float, effects: bool = False): + """Preview sources with optional simple effects.""" + effects_config = None + if effects: + effects_config = [ + [{"effect": "rotate", "angle": 15}], + [{"effect": "zoom", "amount": 1.2}], + ][:len(sources)] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="preview", duration=duration) + + +def demo_record(sources: list, output_path: str, duration: float): + """Record blended output to file.""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output=output_path, duration=duration) + + +def demo_benchmark(sources: list, duration: float): + """Benchmark processing speed (no output).""" + compositor = StreamingCompositor( + sources=sources, + recipe_dir=Path(__file__).parent.parent, + ) + compositor.run(output="null", duration=duration) + + +def demo_audio_reactive(sources: list, duration: float): + """Preview with live audio reactivity.""" + from streaming.audio import AudioAnalyzer + + # Create compositor with energy-reactive effects + effects_config = [ + [{ + "effect": "zoom", + "amount": {"_binding": True, "source": "live_energy", "feature": "values", "range": [1.0, 1.5]}, + }] + for _ in sources + ] + + compositor = StreamingCompositor( + sources=sources, + effects_per_source=effects_config, + recipe_dir=Path(__file__).parent.parent, + ) + + # Start audio analyzer + try: + with AudioAnalyzer() as audio: + print("Audio analyzer started. Make some noise!", file=sys.stderr) + compositor.run(output="preview", duration=duration, audio_analyzer=audio) + except Exception as e: + print(f"Audio not available: {e}", file=sys.stderr) + print("Running without audio...", file=sys.stderr) + compositor.run(output="preview", duration=duration) + + +def main(): + parser = argparse.ArgumentParser(description="Streaming compositor demo") + parser.add_argument("mode", choices=["preview", "record", "benchmark", "audio"], + help="Demo mode") + parser.add_argument("sources", nargs="+", help="Video source files") + parser.add_argument("-o", "--output", help="Output file (for record mode)") + parser.add_argument("-d", "--duration", type=float, default=30, + help="Duration in seconds") + parser.add_argument("--effects", action="store_true", + help="Apply simple effects (for preview)") + + args = parser.parse_args() + + # Verify sources exist + for src in args.sources: + if not Path(src).exists(): + print(f"Error: Source not found: {src}", file=sys.stderr) + sys.exit(1) + + if args.mode == "preview": + demo_preview(args.sources, args.duration, args.effects) + elif args.mode == "record": + if not args.output: + print("Error: --output required for record mode", file=sys.stderr) + sys.exit(1) + demo_record(args.sources, args.output, args.duration) + elif args.mode == "benchmark": + demo_benchmark(args.sources, args.duration) + elif args.mode == "audio": + demo_audio_reactive(args.sources, args.duration) + + +if __name__ == "__main__": + main() diff --git a/streaming/output.py b/streaming/output.py new file mode 100644 index 0000000..c273bd1 --- /dev/null +++ b/streaming/output.py @@ -0,0 +1,369 @@ +""" +Output targets for streaming compositor. + +Supports: +- Display window (preview) +- File output (recording) +- Stream output (RTMP, etc.) - future +""" + +import numpy as np +import subprocess +from abc import ABC, abstractmethod +from typing import Tuple, Optional +from pathlib import Path + + +class Output(ABC): + """Abstract base class for output targets.""" + + @abstractmethod + def write(self, frame: np.ndarray, t: float): + """Write a frame to the output.""" + pass + + @abstractmethod + def close(self): + """Close the output and clean up resources.""" + pass + + @property + @abstractmethod + def is_open(self) -> bool: + """Check if output is still open/valid.""" + pass + + +class DisplayOutput(Output): + """ + Display frames using mpv (handles Wayland properly). + + Useful for live preview. Press 'q' to quit. + """ + + def __init__(self, title: str = "Streaming Preview", size: Tuple[int, int] = None, + audio_source: str = None, fps: float = 30): + self.title = title + self.size = size + self.audio_source = audio_source + self.fps = fps + self._is_open = True + self._process = None + self._audio_process = None + + def _start_mpv(self, frame_size: Tuple[int, int]): + """Start mpv process for display.""" + import sys + w, h = frame_size + cmd = [ + "mpv", + "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + f"--title={self.title}", + "-", + ] + print(f"Starting mpv: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Start audio playback if we have an audio source + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Display frame.""" + if not self._is_open: + return + + # Ensure frame is correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + # Start mpv on first frame + if self._process is None: + self._start_mpv((frame.shape[1], frame.shape[0])) + + # Check if mpv is still running + if self._process.poll() is not None: + self._is_open = False + return + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() # Prevent buffering + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the display and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open + + +class FileOutput(Output): + """ + Write frames to a video file using ffmpeg. + """ + + def __init__( + self, + path: str, + size: Tuple[int, int], + fps: float = 30, + codec: str = "libx264", + crf: int = 18, + preset: str = "fast", + audio_source: str = None, + ): + self.path = Path(path) + self.size = size + self.fps = fps + self._is_open = True + + # Build ffmpeg command + cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{size[0]}x{size[1]}", + "-r", str(fps), + "-i", "-", + ] + + # Add audio input if provided + if audio_source: + cmd.extend(["-i", str(audio_source)]) + # Explicitly map: video from input 0 (rawvideo), audio from input 1 + cmd.extend(["-map", "0:v", "-map", "1:a"]) + + cmd.extend([ + "-c:v", codec, + "-preset", preset, + "-crf", str(crf), + "-pix_fmt", "yuv420p", + ]) + + # Add audio codec if we have audio + if audio_source: + cmd.extend(["-c:a", "aac", "-b:a", "192k", "-shortest"]) + + cmd.append(str(self.path)) + + import sys + print(f"FileOutput cmd: {' '.join(cmd)}", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=None, # Show errors for debugging + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to video file.""" + if not self._is_open or self._process.poll() is not None: + self._is_open = False + return + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + try: + self._process.stdin.write(frame.tobytes()) + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close the video file.""" + if self._process: + self._process.stdin.close() + self._process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open and self._process.poll() is None + + +class MultiOutput(Output): + """ + Write to multiple outputs simultaneously. + + Useful for recording while showing preview. + """ + + def __init__(self, outputs: list): + self.outputs = outputs + + def write(self, frame: np.ndarray, t: float): + for output in self.outputs: + if output.is_open: + output.write(frame, t) + + def close(self): + for output in self.outputs: + output.close() + + @property + def is_open(self) -> bool: + return any(o.is_open for o in self.outputs) + + +class NullOutput(Output): + """ + Discard frames (for benchmarking). + """ + + def __init__(self): + self._is_open = True + self.frame_count = 0 + + def write(self, frame: np.ndarray, t: float): + self.frame_count += 1 + + def close(self): + self._is_open = False + + @property + def is_open(self) -> bool: + return self._is_open + + +class PipeOutput(Output): + """ + Pipe frames directly to mpv. + + Launches mpv with rawvideo demuxer and writes frames to stdin. + """ + + def __init__(self, size: Tuple[int, int], fps: float = 30, audio_source: str = None): + self.size = size + self.fps = fps + self.audio_source = audio_source + self._is_open = True + self._process = None + self._audio_process = None + self._started = False + + def _start(self): + """Start mpv and audio on first frame.""" + if self._started: + return + self._started = True + + import sys + w, h = self.size + + # Start mpv + cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={self.fps}", + "--title=Streaming", + "-" + ] + print(f"Starting mpv: {w}x{h} @ {self.fps}fps", file=sys.stderr) + self._process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + + # Start audio + if self.audio_source: + audio_cmd = [ + "ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + str(self.audio_source) + ] + print(f"Starting audio: {self.audio_source}", file=sys.stderr) + self._audio_process = subprocess.Popen( + audio_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def write(self, frame: np.ndarray, t: float): + """Write frame to mpv.""" + if not self._is_open: + return + + self._start() + + # Check mpv still running + if self._process.poll() is not None: + self._is_open = False + return + + # Resize if needed + if frame.shape[1] != self.size[0] or frame.shape[0] != self.size[1]: + import cv2 + frame = cv2.resize(frame, self.size) + + # Ensure correct format + if frame.dtype != np.uint8: + frame = np.clip(frame, 0, 255).astype(np.uint8) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + + try: + self._process.stdin.write(frame.tobytes()) + self._process.stdin.flush() + except BrokenPipeError: + self._is_open = False + + def close(self): + """Close mpv and audio.""" + if self._process: + try: + self._process.stdin.close() + except: + pass + self._process.terminate() + self._process.wait() + if self._audio_process: + self._audio_process.terminate() + self._audio_process.wait() + self._is_open = False + + @property + def is_open(self) -> bool: + if self._process and self._process.poll() is not None: + self._is_open = False + return self._is_open diff --git a/streaming/pipeline.py b/streaming/pipeline.py new file mode 100644 index 0000000..29dd7e1 --- /dev/null +++ b/streaming/pipeline.py @@ -0,0 +1,846 @@ +""" +Streaming pipeline executor. + +Directly executes compiled sexp recipes frame-by-frame. +No adapter layer - frames and analysis flow through the DAG. +""" + +import sys +import time +import numpy as np +from pathlib import Path +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sources import VideoSource +from .audio import StreamingAudioAnalyzer +from .output import DisplayOutput, FileOutput +from .sexp_interp import SexpInterpreter + + +@dataclass +class FrameContext: + """Context passed through the pipeline for each frame.""" + t: float # Current time + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + analysis: Dict[str, Any] = field(default_factory=dict) + + +class StreamingPipeline: + """ + Executes a compiled sexp recipe as a streaming pipeline. + + Frames flow through the DAG directly - no adapter needed. + Each node is evaluated lazily when its output is requested. + """ + + def __init__(self, compiled_recipe, recipe_dir: Path = None, fps: float = 30, seed: int = 42, + output_size: tuple = None): + self.recipe = compiled_recipe + self.recipe_dir = recipe_dir or Path(".") + self.fps = fps + self.seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # Runtime state + self.sources: Dict[str, VideoSource] = {} + self.audio_analyzer: Optional[StreamingAudioAnalyzer] = None + self.audio_source_path: Optional[str] = None + + # Sexp interpreter for expressions + self.interp = SexpInterpreter() + + # Scan state (node_id -> current value) + self.scan_state: Dict[str, Any] = {} + self.scan_emit: Dict[str, Any] = {} + + # SLICE_ON state + self.slice_on_acc: Dict[str, Any] = {} + self.slice_on_result: Dict[str, Any] = {} + + # Frame cache for current timestep (cleared each frame) + self._frame_cache: Dict[str, np.ndarray] = {} + + # Context for current frame + self.ctx = FrameContext(t=0.0) + + # Output size (w, h) - set after sources are initialized + self._output_size = output_size + + # Initialize + self._init_sources() + self._init_scans() + self._init_slice_on() + + # Set output size from first source if not specified + if self._output_size is None and self.sources: + first_source = next(iter(self.sources.values())) + self._output_size = first_source._size + + def _init_sources(self): + """Initialize video and audio sources.""" + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + config = node.get('config', {}) + path = config.get('path') + if path: + full_path = (self.recipe_dir / path).resolve() + suffix = full_path.suffix.lower() + + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + if not full_path.exists(): + print(f"Warning: video not found: {full_path}", file=sys.stderr) + continue + self.sources[node['id']] = VideoSource( + str(full_path), + target_fps=self.fps + ) + elif suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + if not full_path.exists(): + print(f"Warning: audio not found: {full_path}", file=sys.stderr) + continue + self.audio_source_path = str(full_path) + self.audio_analyzer = StreamingAudioAnalyzer(str(full_path)) + + def _init_scans(self): + """Initialize scan nodes with their initial state.""" + import random + seed_offset = 0 + + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + config = node.get('config', {}) + + # Create RNG for this scan + scan_seed = config.get('seed', self.seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + init_value = self.interp.eval(init_expr, {}) + + self.scan_state[node['id']] = { + 'value': init_value, + 'rng': rng, + 'config': config, + } + + # Compute initial emit + self._update_scan_emit(node['id']) + + def _update_scan_emit(self, node_id: str): + """Update the emit value for a scan.""" + state = self.scan_state[node_id] + config = state['config'] + emit_expr = config.get('emit_expr', config.get('emit', None)) + + if emit_expr is None: + # No emit expression - emit the value directly + self.scan_emit[node_id] = state['value'] + return + + # Build environment from state + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG for interpreter + self.interp.rng = state['rng'] + + self.scan_emit[node_id] = self.interp.eval(emit_expr, env) + + def _step_scan(self, node_id: str): + """Step a scan forward on beat.""" + state = self.scan_state[node_id] + config = state['config'] + step_expr = config.get('step_expr', config.get('step', None)) + + if step_expr is None: + return + + # Build environment + env = {} + if isinstance(state['value'], dict): + env.update(state['value']) + else: + env['acc'] = state['value'] + + env['beat_count'] = self.ctx.beat_count + env['time'] = self.ctx.t + + # Set RNG + self.interp.rng = state['rng'] + + # Evaluate step + new_value = self.interp.eval(step_expr, env) + state['value'] = new_value + + # Update emit + self._update_scan_emit(node_id) + + def _init_slice_on(self): + """Initialize SLICE_ON nodes.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + init = config.get('init', {}) + self.slice_on_acc[node['id']] = dict(init) + + # Evaluate initial state + self._eval_slice_on(node['id']) + + def _eval_slice_on(self, node_id: str): + """Evaluate a SLICE_ON node's Lambda.""" + node = self.nodes[node_id] + config = node.get('config', {}) + fn = config.get('fn') + videos = config.get('videos', []) + + if not fn: + return + + acc = self.slice_on_acc[node_id] + n_videos = len(videos) + + # Set up environment + self.interp.globals['videos'] = list(range(n_videos)) + + try: + from .sexp_interp import eval_slice_on_lambda + result = eval_slice_on_lambda( + fn, acc, self.ctx.beat_count, 0, 1, + list(range(n_videos)), self.interp + ) + self.slice_on_result[node_id] = result + + # Update accumulator + if 'acc' in result: + self.slice_on_acc[node_id] = result['acc'] + except Exception as e: + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _on_beat(self): + """Called when a beat is detected.""" + self.ctx.beat_count += 1 + + # Step all scans + for node_id in self.scan_state: + self._step_scan(node_id) + + # Step all SLICE_ON nodes + for node_id in self.slice_on_acc: + self._eval_slice_on(node_id) + + def _get_frame(self, node_id: str) -> Optional[np.ndarray]: + """ + Get the output frame for a node at current time. + + Recursively evaluates inputs as needed. + Results are cached for the current timestep. + """ + if node_id in self._frame_cache: + return self._frame_cache[node_id] + + node = self.nodes.get(node_id) + if not node: + return None + + node_type = node.get('type') + + if node_type == 'SOURCE': + frame = self._eval_source(node) + elif node_type == 'SEGMENT': + frame = self._eval_segment(node) + elif node_type == 'EFFECT': + frame = self._eval_effect(node) + elif node_type == 'SLICE_ON': + frame = self._eval_slice_on_frame(node) + else: + # Unknown node type - try to pass through input + inputs = node.get('inputs', []) + frame = self._get_frame(inputs[0]) if inputs else None + + self._frame_cache[node_id] = frame + return frame + + def _eval_source(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SOURCE node.""" + source = self.sources.get(node['id']) + if source: + return source.read_frame(self.ctx.t) + return None + + def _eval_segment(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SEGMENT node (time segment of source).""" + inputs = node.get('inputs', []) + if not inputs: + return None + + config = node.get('config', {}) + start = config.get('start', 0) + duration = config.get('duration') + + # Resolve any bindings + if isinstance(start, dict): + start = self._resolve_binding(start) if start.get('_binding') else 0 + if isinstance(duration, dict): + duration = self._resolve_binding(duration) if duration.get('_binding') else None + + # Adjust time for segment + t_local = self.ctx.t + (start if isinstance(start, (int, float)) else 0) + if duration and isinstance(duration, (int, float)): + t_local = t_local % duration # Loop within segment + + # Get source frame at adjusted time + source_id = inputs[0] + source = self.sources.get(source_id) + if source: + return source.read_frame(t_local) + + return self._get_frame(source_id) + + def _eval_effect(self, node: dict) -> Optional[np.ndarray]: + """Evaluate an EFFECT node.""" + import cv2 + + inputs = node.get('inputs', []) + config = node.get('config', {}) + effect_name = config.get('effect') + + # Get input frame(s) + input_frames = [self._get_frame(inp) for inp in inputs] + input_frames = [f for f in input_frames if f is not None] + + if not input_frames: + return None + + frame = input_frames[0] + + # Resolve bindings in config + params = self._resolve_config(config) + + # Apply effect based on name + if effect_name == 'rotate': + angle = params.get('angle', 0) + if abs(angle) > 0.5: + h, w = frame.shape[:2] + center = (w // 2, h // 2) + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + frame = cv2.warpAffine(frame, matrix, (w, h)) + + elif effect_name == 'zoom': + amount = params.get('amount', 1.0) + if abs(amount - 1.0) > 0.01: + frame = self._apply_zoom(frame, amount) + + elif effect_name == 'invert': + amount = params.get('amount', 0) + if amount > 0.01: + inverted = 255 - frame + frame = cv2.addWeighted(frame, 1 - amount, inverted, amount, 0) + + elif effect_name == 'hue_shift': + degrees = params.get('degrees', 0) + if abs(degrees) > 1: + hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV) + hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + int(degrees / 2)) % 180 + frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + + elif effect_name == 'blend': + if len(input_frames) >= 2: + opacity = params.get('opacity', 0.5) + frame = cv2.addWeighted(input_frames[0], 1 - opacity, + input_frames[1], opacity, 0) + + elif effect_name == 'blend_multi': + weights = params.get('weights', []) + if len(input_frames) > 1 and weights: + h, w = input_frames[0].shape[:2] + result = np.zeros((h, w, 3), dtype=np.float32) + for f, wt in zip(input_frames, weights): + if f is not None and wt > 0.001: + if f.shape[:2] != (h, w): + f = cv2.resize(f, (w, h)) + result += f.astype(np.float32) * wt + frame = np.clip(result, 0, 255).astype(np.uint8) + + elif effect_name == 'ripple': + amp = params.get('amplitude', 0) + if amp > 1: + frame = self._apply_ripple(frame, amp, + params.get('center_x', 0.5), + params.get('center_y', 0.5), + params.get('frequency', 8), + params.get('decay', 2), + params.get('speed', 5)) + + return frame + + def _eval_slice_on_frame(self, node: dict) -> Optional[np.ndarray]: + """Evaluate a SLICE_ON node - returns composited frame.""" + import cv2 + + config = node.get('config', {}) + video_ids = config.get('videos', []) + result = self.slice_on_result.get(node['id'], {}) + + if not result: + # No result yet - return first video + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get layers and compose info + layers = result.get('layers', []) + compose = result.get('compose', {}) + weights = compose.get('weights', []) + + if not layers or not weights: + if video_ids: + return self._get_frame(video_ids[0]) + return None + + # Get frames for each layer + frames = [] + for i, layer in enumerate(layers): + video_idx = layer.get('video', i) + if video_idx < len(video_ids): + frame = self._get_frame(video_ids[video_idx]) + + # Apply layer effects (zoom) + effects = layer.get('effects', []) + for eff in effects: + eff_name = eff.get('effect') + if hasattr(eff_name, 'name'): + eff_name = eff_name.name + if eff_name == 'zoom': + zoom_amt = eff.get('amount', 1.0) + if frame is not None: + frame = self._apply_zoom(frame, zoom_amt) + + frames.append(frame) + else: + frames.append(None) + + # Composite with weights - use consistent output size + if self._output_size: + w, h = self._output_size + else: + # Fallback to first non-None frame size + for f in frames: + if f is not None: + h, w = f.shape[:2] + break + else: + return None + + output = np.zeros((h, w, 3), dtype=np.float32) + + for frame, weight in zip(frames, weights): + if frame is None or weight < 0.001: + continue + + # Resize to output size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + + output += frame.astype(np.float32) * weight + + # Normalize weights + total_weight = sum(wt for wt in weights if wt > 0.001) + if total_weight > 0 and abs(total_weight - 1.0) > 0.01: + output /= total_weight + + return np.clip(output, 0, 255).astype(np.uint8) + + def _resolve_config(self, config: dict) -> dict: + """Resolve bindings in effect config to actual values.""" + resolved = {} + + for key, value in config.items(): + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', + 'analysis_refs', 'inputs', 'cid'): + continue + + if isinstance(value, dict) and value.get('_binding'): + resolved[key] = self._resolve_binding(value) + elif isinstance(value, dict) and value.get('_expr'): + resolved[key] = self._resolve_expr(value) + else: + resolved[key] = value + + return resolved + + def _resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to its current value.""" + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get raw value from scan or analysis + if source_id in self.scan_emit: + value = self.scan_emit[source_id] + elif source_id in self.ctx.analysis: + data = self.ctx.analysis[source_id] + value = data.get(feature, data.get('values', [0]))[0] if isinstance(data, dict) else data + else: + # Fallback to energy + value = self.ctx.energy + + # Extract feature from dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def _resolve_expr(self, expr: dict) -> Any: + """Resolve a compiled expression.""" + env = { + 'energy': self.ctx.energy, + 'beat_count': self.ctx.beat_count, + 't': self.ctx.t, + } + + # Add scan values + for scan_id, value in self.scan_emit.items(): + # Use short form if available + env[scan_id] = value + + # Extract the actual expression from _expr wrapper + actual_expr = expr.get('_expr', expr) + return self.interp.eval(actual_expr, env) + + def _apply_zoom(self, frame: np.ndarray, amount: float) -> np.ndarray: + """Apply zoom to frame.""" + import cv2 + h, w = frame.shape[:2] + + if amount > 1.01: + # Zoom in: crop center + new_w, new_h = int(w / amount), int(h / amount) + if new_w > 0 and new_h > 0: + x1, y1 = (w - new_w) // 2, (h - new_h) // 2 + cropped = frame[y1:y1+new_h, x1:x1+new_w] + return cv2.resize(cropped, (w, h)) + elif amount < 0.99: + # Zoom out: shrink and center + scaled_w, scaled_h = int(w * amount), int(h * amount) + if scaled_w > 0 and scaled_h > 0: + shrunk = cv2.resize(frame, (scaled_w, scaled_h)) + canvas = np.zeros((h, w, 3), dtype=np.uint8) + x_off, y_off = (w - scaled_w) // 2, (h - scaled_h) // 2 + canvas[y_off:y_off+scaled_h, x_off:x_off+scaled_w] = shrunk + return canvas + + return frame + + def _apply_ripple(self, frame: np.ndarray, amplitude: float, + cx: float, cy: float, frequency: float, + decay: float, speed: float) -> np.ndarray: + """Apply ripple effect.""" + import cv2 + h, w = frame.shape[:2] + + # Create coordinate grids + y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32) + + # Normalize to center + center_x, center_y = w * cx, h * cy + dx = x_coords - center_x + dy = y_coords - center_y + dist = np.sqrt(dx**2 + dy**2) + + # Ripple displacement + phase = self.ctx.t * speed + ripple = amplitude * np.sin(dist / frequency - phase) * np.exp(-dist * decay / max(w, h)) + + # Displace coordinates + angle = np.arctan2(dy, dx) + map_x = (x_coords + ripple * np.cos(angle)).astype(np.float32) + map_y = (y_coords + ripple * np.sin(angle)).astype(np.float32) + + return cv2.remap(frame, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + + def _find_output_node(self) -> Optional[str]: + """Find the final output node (MUX or last EFFECT).""" + # Look for MUX node + for node in self.recipe.nodes: + if node.get('type') == 'MUX': + return node['id'] + + # Otherwise find last EFFECT after SLICE_ON + last_effect = None + found_slice_on = False + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + found_slice_on = True + elif node.get('type') == 'EFFECT' and found_slice_on: + last_effect = node['id'] + + return last_effect + + def render_frame(self, t: float) -> Optional[np.ndarray]: + """Render a single frame at time t.""" + # Clear frame cache + self._frame_cache.clear() + + # Update context + self.ctx.t = t + + # Update audio analysis + if self.audio_analyzer: + self.audio_analyzer.set_time(t) + energy = self.audio_analyzer.get_energy() + is_beat = self.audio_analyzer.get_beat() + + # Beat edge detection + was_beat = self.ctx.is_beat + self.ctx.energy = energy + self.ctx.is_beat = is_beat + + if is_beat and not was_beat: + self._on_beat() + + # Store in analysis dict + self.ctx.analysis['live_energy'] = {'values': [energy]} + self.ctx.analysis['live_beat'] = {'values': [1.0 if is_beat else 0.0]} + + # Find output node and render + output_node = self._find_output_node() + if output_node: + frame = self._get_frame(output_node) + # Normalize to output size + if frame is not None and self._output_size: + w, h = self._output_size + if frame.shape[1] != w or frame.shape[0] != h: + import cv2 + frame = cv2.resize(frame, (w, h)) + return frame + + return None + + def run(self, output: str = "preview", duration: float = None): + """ + Run the pipeline. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration or 60s) + """ + # Determine duration + if duration is None: + if self.audio_analyzer: + duration = self.audio_analyzer.duration + else: + duration = 60.0 + + # Create output + if output == "preview": + # Get frame size from first source + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = DisplayOutput(size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + elif isinstance(output, str): + first_source = next(iter(self.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + out = FileOutput(output, size=(w, h), fps=self.fps, audio_source=self.audio_source_path) + else: + out = output + + frame_time = 1.0 / self.fps + n_frames = int(duration * self.fps) + + print(f"Streaming: {len(self.sources)} sources -> {output}", file=sys.stderr) + print(f"Duration: {duration:.1f}s, {n_frames} frames @ {self.fps}fps", file=sys.stderr) + + start_time = time.time() + frame_count = 0 + + try: + for frame_num in range(n_frames): + t = frame_num * frame_time + + frame = self.render_frame(t) + + if frame is not None: + out.write(frame, t) + frame_count += 1 + + # Progress + if frame_num % 50 == 0: + elapsed = time.time() - start_time + fps = frame_count / elapsed if elapsed > 0 else 0 + pct = 100 * frame_num / n_frames + print(f"\r{pct:5.1f}% | {fps:5.1f} fps | frame {frame_num}/{n_frames}", + end="", file=sys.stderr) + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + finally: + out.close() + for src in self.sources.values(): + src.close() + + elapsed = time.time() - start_time + avg_fps = frame_count / elapsed if elapsed > 0 else 0 + print(f"\nCompleted: {frame_count} frames in {elapsed:.1f}s ({avg_fps:.1f} fps avg)", + file=sys.stderr) + + +def run_pipeline(recipe_path: str, output: str = "preview", + duration: float = None, fps: float = None): + """ + Run a recipe through the streaming pipeline. + + No adapter layer - directly executes the compiled recipe. + """ + from pathlib import Path + + # Add artdag to path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + pipeline.run(output=output, duration=duration) + + +def run_pipeline_piped(recipe_path: str, duration: float = None, fps: float = None): + """ + Run pipeline and pipe directly to mpv with audio. + """ + import subprocess + from pathlib import Path + import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + from artdag.sexp.compiler import compile_string + + recipe_path = Path(recipe_path) + recipe_text = recipe_path.read_text() + compiled = compile_string(recipe_text, {}, recipe_dir=recipe_path.parent) + + pipeline = StreamingPipeline( + compiled, + recipe_dir=recipe_path.parent, + fps=fps or compiled.encoding.get('fps', 30), + ) + + # Get frame info + first_source = next(iter(pipeline.sources.values()), None) + if first_source: + w, h = first_source._size + else: + w, h = 720, 720 + + # Determine duration + if duration is None: + if pipeline.audio_analyzer: + duration = pipeline.audio_analyzer.duration + else: + duration = 60.0 + + actual_fps = fps or compiled.encoding.get('fps', 30) + n_frames = int(duration * actual_fps) + frame_time = 1.0 / actual_fps + + print(f"Streaming {n_frames} frames @ {actual_fps}fps to mpv", file=sys.stderr) + + # Start mpv + mpv_cmd = [ + "mpv", "--no-cache", + "--demuxer=rawvideo", + f"--demuxer-rawvideo-w={w}", + f"--demuxer-rawvideo-h={h}", + "--demuxer-rawvideo-mp-format=rgb24", + f"--demuxer-rawvideo-fps={actual_fps}", + "--title=Streaming Pipeline", + "-" + ] + mpv = subprocess.Popen(mpv_cmd, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) + + # Start audio if available + audio_proc = None + if pipeline.audio_source_path: + audio_cmd = ["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", + pipeline.audio_source_path] + audio_proc = subprocess.Popen(audio_cmd, stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + + try: + import cv2 + for frame_num in range(n_frames): + if mpv.poll() is not None: + break # mpv closed + + t = frame_num * frame_time + frame = pipeline.render_frame(t) + if frame is not None: + # Ensure consistent frame size + if frame.shape[1] != w or frame.shape[0] != h: + frame = cv2.resize(frame, (w, h)) + if not frame.flags['C_CONTIGUOUS']: + frame = np.ascontiguousarray(frame) + try: + mpv.stdin.write(frame.tobytes()) + mpv.stdin.flush() + except BrokenPipeError: + break + except KeyboardInterrupt: + pass + finally: + if mpv.stdin: + mpv.stdin.close() + mpv.terminate() + if audio_proc: + audio_proc.terminate() + for src in pipeline.sources.values(): + src.close() + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe through streaming pipeline") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_pipeline_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_pipeline(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/streaming/recipe_adapter.py b/streaming/recipe_adapter.py new file mode 100644 index 0000000..2133919 --- /dev/null +++ b/streaming/recipe_adapter.py @@ -0,0 +1,470 @@ +""" +Adapter to run sexp recipes through the streaming compositor. + +Bridges the gap between: +- Existing recipe format (sexp files with stages, effects, analysis) +- Streaming compositor (sources, effect chains, compositor config) +""" + +import sys +from pathlib import Path +from typing import Dict, List, Any, Optional + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) + +from .compositor import StreamingCompositor +from .sources import VideoSource +from .audio import FileAudioAnalyzer + + +class RecipeAdapter: + """ + Adapts a compiled sexp recipe to run through the streaming compositor. + + Example: + adapter = RecipeAdapter("effects/quick_test.sexp") + adapter.run(output="preview", duration=60) + """ + + def __init__( + self, + recipe_path: str, + params: Dict[str, Any] = None, + backend: str = "numpy", + ): + """ + Load and prepare a recipe for streaming. + + Args: + recipe_path: Path to .sexp recipe file + params: Parameter overrides + backend: "numpy" or "glsl" + """ + self.recipe_path = Path(recipe_path) + self.recipe_dir = self.recipe_path.parent + self.params = params or {} + self.backend = backend + + # Compile recipe + self._compile() + + def _compile(self): + """Compile the recipe and extract structure.""" + from artdag.sexp.compiler import compile_string + + recipe_text = self.recipe_path.read_text() + self.compiled = compile_string(recipe_text, self.params, recipe_dir=self.recipe_dir) + + # Extract key info + self.sources = {} # name -> path + self.effects_registry = {} # effect_name -> path + self.analyzers = {} # name -> analyzer info + + # Walk nodes to find sources and structure + # nodes is a list in CompiledRecipe + for node in self.compiled.nodes: + node_type = node.get("type", "") + + if node_type == "SOURCE": + config = node.get("config", {}) + path = config.get("path") + if path: + self.sources[node["id"]] = self.recipe_dir / path + + elif node_type == "ANALYZE": + config = node.get("config", {}) + self.analyzers[node["id"]] = { + "analyzer": config.get("analyzer"), + "path": config.get("analyzer_path"), + } + + # Get effects registry from compiled recipe + # registry has 'effects' sub-dict + effects_dict = self.compiled.registry.get("effects", {}) + for name, info in effects_dict.items(): + if info.get("path"): + self.effects_registry[name] = Path(info["path"]) + + def run_analysis(self) -> Dict[str, Any]: + """ + Run analysis phase (energy, beats, etc.). + + Returns: + Dict of analysis track name -> {times, values, duration} + """ + print(f"Running analysis...", file=sys.stderr) + + # Use existing planner's analysis execution + from artdag.sexp.planner import create_plan + + analysis_data = {} + + def on_analysis(node_id: str, results: dict): + analysis_data[node_id] = results + print(f" {node_id[:16]}...: {len(results.get('times', []))} samples", file=sys.stderr) + + # Create plan (runs analysis as side effect) + plan = create_plan( + self.compiled, + inputs={}, + recipe_dir=self.recipe_dir, + on_analysis=on_analysis, + ) + + # Also store named analysis tracks + for name, data in plan.analysis.items(): + analysis_data[name] = data + + return analysis_data + + def build_compositor( + self, + analysis_data: Dict[str, Any] = None, + fps: float = None, + ) -> StreamingCompositor: + """ + Build a streaming compositor from the recipe. + + This is a simplified version that handles common patterns. + Complex recipes may need manual configuration. + + Args: + analysis_data: Pre-computed analysis data + + Returns: + Configured StreamingCompositor + """ + # Extract video and audio sources in SLICE_ON input order + video_sources = [] + audio_source = None + + # Find audio source first + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp3', '.wav', '.flac', '.ogg', '.m4a', '.aac'): + audio_source = str(path) + break + + # Find SLICE_ON node to get correct video order + slice_on_inputs = None + for node in self.compiled.nodes: + if node.get('type') == 'SLICE_ON': + # Use 'videos' config key which has the correct order + config = node.get('config', {}) + slice_on_inputs = config.get('videos', []) + break + + if slice_on_inputs: + # Trace each SLICE_ON input back to its SOURCE + node_lookup = {n['id']: n for n in self.compiled.nodes} + + def trace_to_source(node_id, visited=None): + """Trace a node back to its SOURCE, return source path.""" + if visited is None: + visited = set() + if node_id in visited: + return None + visited.add(node_id) + + node = node_lookup.get(node_id) + if not node: + return None + if node.get('type') == 'SOURCE': + return self.sources.get(node_id) + # Recurse through inputs + for inp in node.get('inputs', []): + result = trace_to_source(inp, visited) + if result: + return result + return None + + # Build video_sources in SLICE_ON input order + for inp_id in slice_on_inputs: + source_path = trace_to_source(inp_id) + if source_path: + suffix = source_path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(source_path)) + + # Fallback to definition order if no SLICE_ON + if not video_sources: + for node_id, path in self.sources.items(): + suffix = path.suffix.lower() + if suffix in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + video_sources.append(str(path)) + + if not video_sources: + raise ValueError("No video sources found in recipe") + + # Build effect chains - use live audio bindings (matching video_sources count) + effects_per_source = self._build_streaming_effects(n_sources=len(video_sources)) + + # Build compositor config from recipe + compositor_config = self._extract_compositor_config(analysis_data) + + return StreamingCompositor( + sources=video_sources, + effects_per_source=effects_per_source, + compositor_config=compositor_config, + analysis_data=analysis_data or {}, + backend=self.backend, + recipe_dir=self.recipe_dir, + fps=fps or self.compiled.encoding.get("fps", 30), + audio_source=audio_source, + ) + + def _build_streaming_effects(self, n_sources: int = None) -> List[List[Dict]]: + """ + Build effect chains for streaming with live audio bindings. + + Replicates the recipe's effect pipeline: + - Per source: rotate, zoom, invert, hue_shift, ascii_art + - All driven by live_energy and live_beat + """ + if n_sources is None: + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + + effects_per_source = [] + + for i in range(n_sources): + # Alternate rotation direction per source + rot_dir = 1 if i % 2 == 0 else -1 + + effects = [ + # Rotate - energy drives angle + { + "effect": "rotate", + "effect_path": str(self.effects_registry.get("rotate", "")), + "angle": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 45 * rot_dir], + }, + }, + # Zoom - energy drives amount + { + "effect": "zoom", + "effect_path": str(self.effects_registry.get("zoom", "")), + "amount": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [1.0, 1.5] if i % 2 == 0 else [1.0, 0.7], + }, + }, + # Invert - beat triggers + { + "effect": "invert", + "effect_path": str(self.effects_registry.get("invert", "")), + "amount": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + # Hue shift - energy drives hue + { + "effect": "hue_shift", + "effect_path": str(self.effects_registry.get("hue_shift", "")), + "degrees": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [0, 180], + }, + }, + # ASCII art - energy drives char size, beat triggers mix + { + "effect": "ascii_art", + "effect_path": str(self.effects_registry.get("ascii_art", "")), + "char_size": { + "_binding": True, + "source": "live_energy", + "feature": "values", + "range": [4, 32], + }, + "mix": { + "_binding": True, + "source": "live_beat", + "feature": "values", + "range": [0, 1], + }, + }, + ] + effects_per_source.append(effects) + + return effects_per_source + + def _extract_effects(self) -> List[List[Dict]]: + """Extract effect chains for each source (legacy, pre-computed analysis).""" + # Simplified: find EFFECT nodes and their configs + effects_per_source = [] + + for node_id, path in self.sources.items(): + if path.suffix.lower() not in ('.mp4', '.webm', '.mov', '.avi', '.mkv'): + continue + + # Find effects that depend on this source + # This is simplified - real implementation would trace the DAG + effects = [] + + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + effect_name = config.get("effect") + if effect_name and effect_name in self.effects_registry: + effect_config = { + "effect": effect_name, + "effect_path": str(self.effects_registry[effect_name]), + } + # Copy only effect params (filter out internal fields) + internal_fields = ( + "effect", "effect_path", "cid", "effect_cid", + "effects_registry", "analysis_refs", "inputs", + ) + for k, v in config.items(): + if k not in internal_fields: + effect_config[k] = v + effects.append(effect_config) + break # One effect per source for now + + effects_per_source.append(effects) + + return effects_per_source + + def _extract_compositor_config(self, analysis_data: Dict) -> Dict: + """Extract compositor configuration.""" + # Look for blend_multi or similar composition nodes + for node in self.compiled.nodes: + if node.get("type") == "EFFECT": + config = node.get("config", {}) + if config.get("effect") == "blend_multi": + return { + "mode": config.get("mode", "alpha"), + "weights": config.get("weights", []), + } + + # Default: equal blend + n_sources = len([p for p in self.sources.values() + if p.suffix.lower() in ('.mp4', '.webm', '.mov', '.avi', '.mkv')]) + return { + "mode": "alpha", + "weights": [1.0 / n_sources] * n_sources if n_sources > 0 else [1.0], + } + + def run( + self, + output: str = "preview", + duration: float = None, + fps: float = None, + ): + """ + Run the recipe through streaming compositor. + + Everything streams: video frames read on-demand, audio analyzed in real-time. + No pre-computation. + + Args: + output: "preview", filename, or Output object + duration: Duration in seconds (default: audio duration) + fps: Frame rate (default from recipe, or 30) + """ + # Build compositor with recipe executor for full pipeline + from .recipe_executor import StreamingRecipeExecutor + + compositor = self.build_compositor(analysis_data={}, fps=fps) + + # Use audio duration if not specified + if duration is None: + if compositor._audio_analyzer: + duration = compositor._audio_analyzer.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + else: + # Live mode - run until quit + print("Live mode - press 'q' to quit", file=sys.stderr) + + # Create sexp executor that interprets the recipe + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(self.compiled, seed=42) + + compositor.run(output=output, duration=duration, recipe_executor=executor) + + +def run_recipe( + recipe_path: str, + output: str = "preview", + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run a recipe through streaming compositor. + + Everything streams in real-time: video frames, audio analysis. + No pre-computation - starts immediately. + + Example: + run_recipe("effects/quick_test.sexp", output="preview", duration=30) + run_recipe("effects/quick_test.sexp", output="preview", fps=5) # Lower fps for slow systems + """ + adapter = RecipeAdapter(recipe_path, params=params) + adapter.run(output=output, duration=duration, fps=fps) + + +def run_recipe_piped( + recipe_path: str, + duration: float = None, + params: Dict = None, + fps: float = None, +): + """ + Run recipe and pipe directly to mpv. + """ + from .output import PipeOutput + + adapter = RecipeAdapter(recipe_path, params=params) + compositor = adapter.build_compositor(analysis_data={}, fps=fps) + + # Get frame size + if compositor.sources: + first_source = compositor.sources[0] + w, h = first_source._size + else: + w, h = 720, 720 + + actual_fps = fps or adapter.compiled.encoding.get('fps', 30) + + # Create pipe output + pipe_out = PipeOutput( + size=(w, h), + fps=actual_fps, + audio_source=compositor._audio_source + ) + + # Create executor + from .sexp_executor import SexpStreamingExecutor + executor = SexpStreamingExecutor(adapter.compiled, seed=42) + + # Run with pipe output + compositor.run(output=pipe_out, duration=duration, recipe_executor=executor) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run sexp recipe with streaming compositor") + parser.add_argument("recipe", help="Path to .sexp recipe file") + parser.add_argument("-o", "--output", default="pipe", + help="Output: 'pipe' (mpv), 'preview', or filename (default: pipe)") + parser.add_argument("-d", "--duration", type=float, default=None, + help="Duration in seconds (default: audio duration)") + parser.add_argument("--fps", type=float, default=None, + help="Frame rate (default: from recipe)") + args = parser.parse_args() + + if args.output == "pipe": + run_recipe_piped(args.recipe, duration=args.duration, fps=args.fps) + else: + run_recipe(args.recipe, output=args.output, duration=args.duration, fps=args.fps) diff --git a/streaming/recipe_executor.py b/streaming/recipe_executor.py new file mode 100644 index 0000000..678d9f6 --- /dev/null +++ b/streaming/recipe_executor.py @@ -0,0 +1,415 @@ +""" +Streaming recipe executor. + +Implements the full recipe logic for real-time streaming: +- Scans (state machines that evolve on beats) +- Process-pair template (two clips with sporadic effects, blended) +- Cycle-crossfade (dynamic composition cycling through video pairs) +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + + +@dataclass +class ScanState: + """State for a scan (beat-driven state machine).""" + value: Any = 0 + rng: random.Random = field(default_factory=random.Random) + + +class StreamingScans: + """ + Real-time scan executor. + + Scans are state machines that evolve on each beat. + They drive effect parameters like invert triggers, hue shifts, etc. + """ + + def __init__(self, seed: int = 42, n_sources: int = 4): + self.master_seed = seed + self.n_sources = n_sources + self.scans: Dict[str, ScanState] = {} + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self._init_scans() + + def _init_scans(self): + """Initialize all scans with their own RNG seeds.""" + scan_names = [] + + # Per-pair scans (dynamic based on n_sources) + for i in range(self.n_sources): + scan_names.extend([ + f"inv_a_{i}", f"inv_b_{i}", f"hue_a_{i}", f"hue_b_{i}", + f"ascii_a_{i}", f"ascii_b_{i}", f"pair_mix_{i}", f"pair_rot_{i}", + ]) + + # Global scans + scan_names.extend(["whole_spin", "ripple_gate", "cycle"]) + + for i, name in enumerate(scan_names): + rng = random.Random(self.master_seed + i) + self.scans[name] = ScanState(value=self._init_value(name), rng=rng) + + def _init_value(self, name: str) -> Any: + """Get initial value for a scan.""" + if name.startswith("inv_") or name.startswith("ascii_"): + return 0 # Counter for remaining beats + elif name.startswith("hue_"): + return {"rem": 0, "hue": 0} + elif name.startswith("pair_mix"): + return {"rem": 0, "opacity": 0.5} + elif name.startswith("pair_rot"): + pair_idx = int(name.split("_")[-1]) + rot_dir = 1 if pair_idx % 2 == 0 else -1 + return {"beat": 0, "clen": 25, "dir": rot_dir, "angle": 0} + elif name == "whole_spin": + return { + "phase": 0, # 0 = waiting, 1 = spinning + "beat": 0, # beats into current phase + "plen": 20, # beats in this phase + "dir": 1, # spin direction + "total_angle": 0.0, # cumulative angle after all spins + "spin_start_angle": 0.0, # angle when current spin started + "spin_start_time": 0.0, # time when current spin started + "spin_end_time": 0.0, # estimated time when spin ends + } + elif name == "ripple_gate": + return {"rem": 0, "cx": 0.5, "cy": 0.5} + elif name == "cycle": + return {"cycle": 0, "beat": 0, "clen": 60} + return 0 + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + # Estimate beat interval from last two beats + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + for name, state in self.scans.items(): + state.value = self._step_scan(name, state.value, state.rng, beat_interval) + + def _step_scan(self, name: str, value: Any, rng: random.Random, beat_interval: float = 0.5) -> Any: + """Step a scan forward by one beat.""" + + # Invert scan: 10% chance, lasts 1-5 beats + if name.startswith("inv_"): + if value > 0: + return value - 1 + elif rng.random() < 0.1: + return rng.randint(1, 5) + return 0 + + # Hue scan: 10% chance, random hue 30-330, lasts 1-5 beats + elif name.startswith("hue_"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "hue": value["hue"]} + elif rng.random() < 0.1: + return {"rem": rng.randint(1, 5), "hue": rng.uniform(30, 330)} + return {"rem": 0, "hue": 0} + + # ASCII scan: 5% chance, lasts 1-3 beats + elif name.startswith("ascii_"): + if value > 0: + return value - 1 + elif rng.random() < 0.05: + return rng.randint(1, 3) + return 0 + + # Pair mix: changes every 1-11 beats + elif name.startswith("pair_mix"): + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "opacity": value["opacity"]} + return {"rem": rng.randint(1, 11), "opacity": rng.choice([0, 0.5, 1.0])} + + # Pair rotation: full rotation every 20-30 beats + elif name.startswith("pair_rot"): + beat = value["beat"] + clen = value["clen"] + dir_ = value["dir"] + angle = value["angle"] + + if beat + 1 < clen: + new_angle = angle + dir_ * (360 / clen) + return {"beat": beat + 1, "clen": clen, "dir": dir_, "angle": new_angle} + else: + return {"beat": 0, "clen": rng.randint(20, 30), "dir": -dir_, "angle": angle} + + # Whole spin: sporadic 720 degree spins (cumulative - stays rotated) + elif name == "whole_spin": + phase = value["phase"] + beat = value["beat"] + plen = value["plen"] + dir_ = value["dir"] + total_angle = value.get("total_angle", 0.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", 0.0) + + if phase == 1: + # Currently spinning + if beat + 1 < plen: + return { + "phase": 1, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Spin complete - update total_angle with final spin + new_total = spin_start_angle + dir_ * 720.0 + return { + "phase": 0, "beat": 0, "plen": rng.randint(20, 40), "dir": dir_, + "total_angle": new_total, + "spin_start_angle": new_total, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time, + } + else: + # Waiting phase + if beat + 1 < plen: + return { + "phase": 0, "beat": beat + 1, "plen": plen, "dir": dir_, + "total_angle": total_angle, + "spin_start_angle": spin_start_angle, + "spin_start_time": spin_start_time, + "spin_end_time": spin_end_time, + } + else: + # Start new spin + new_dir = 1 if rng.random() < 0.5 else -1 + new_plen = rng.randint(10, 25) + spin_duration = new_plen * beat_interval + return { + "phase": 1, "beat": 0, "plen": new_plen, "dir": new_dir, + "total_angle": total_angle, + "spin_start_angle": total_angle, + "spin_start_time": self.current_time, + "spin_end_time": self.current_time + spin_duration, + } + + # Ripple gate: 5% chance, lasts 1-20 beats + elif name == "ripple_gate": + if value["rem"] > 0: + return {"rem": value["rem"] - 1, "cx": value["cx"], "cy": value["cy"]} + elif rng.random() < 0.05: + return {"rem": rng.randint(1, 20), + "cx": rng.uniform(0.1, 0.9), + "cy": rng.uniform(0.1, 0.9)} + return {"rem": 0, "cx": 0.5, "cy": 0.5} + + # Cycle: track which video pair is active + elif name == "cycle": + beat = value["beat"] + clen = value["clen"] + cycle = value["cycle"] + + if beat + 1 < clen: + return {"cycle": cycle, "beat": beat + 1, "clen": clen} + else: + # Move to next pair, vary cycle length + return {"cycle": (cycle + 1) % 4, "beat": 0, + "clen": 40 + (self.beat_count * 7) % 41} + + return value + + def get_emit(self, name: str) -> float: + """Get emitted value for a scan.""" + value = self.scans[name].value + + if name.startswith("inv_") or name.startswith("ascii_"): + return 1.0 if value > 0 else 0.0 + + elif name.startswith("hue_"): + return value["hue"] if value["rem"] > 0 else 0.0 + + elif name.startswith("pair_mix"): + return value["opacity"] + + elif name.startswith("pair_rot"): + return value["angle"] + + elif name == "whole_spin": + # Smooth time-based interpolation during spin + phase = value.get("phase", 0) + if phase == 1: + # Currently spinning - interpolate based on time + spin_start_time = value.get("spin_start_time", 0.0) + spin_end_time = value.get("spin_end_time", spin_start_time + 1.0) + spin_start_angle = value.get("spin_start_angle", 0.0) + dir_ = value.get("dir", 1) + + duration = spin_end_time - spin_start_time + if duration > 0: + progress = (self.current_time - spin_start_time) / duration + progress = max(0.0, min(1.0, progress)) # clamp to 0-1 + else: + progress = 1.0 + + return spin_start_angle + progress * 720.0 * dir_ + else: + # Not spinning - return cumulative angle + return value.get("total_angle", 0.0) + + elif name == "ripple_gate": + return 1.0 if value["rem"] > 0 else 0.0 + + elif name == "cycle": + return value + + return 0.0 + + +class StreamingRecipeExecutor: + """ + Executes a recipe in streaming mode. + + Implements: + - process-pair: two video clips with opposite effects, blended + - cycle-crossfade: dynamic cycling through video pairs + - Final effects: whole-spin rotation, ripple + """ + + def __init__(self, n_sources: int = 4, seed: int = 42): + self.n_sources = n_sources + self.scans = StreamingScans(seed, n_sources=n_sources) + self.last_beat_detected = False + self.current_time = 0.0 + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with current audio analysis.""" + self.current_time = t + self.scans.current_time = t + # Update scans on beat + if is_beat and not self.last_beat_detected: + self.scans.on_beat() + self.last_beat_detected = is_beat + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """ + Get effect parameters for a source clip. + + Args: + source_idx: Which video source (0-3) + clip: "a" or "b" (each source has two clips) + energy: Current audio energy (0-1) + """ + suffix = f"_{source_idx}" + + # Rotation ranges alternate + if source_idx % 2 == 0: + rot_range = [0, 45] if clip == "a" else [0, -45] + zoom_range = [1, 1.5] if clip == "a" else [1, 0.5] + else: + rot_range = [0, -45] if clip == "a" else [0, 45] + zoom_range = [1, 0.5] if clip == "a" else [1, 1.5] + + return { + "rotate_angle": rot_range[0] + energy * (rot_range[1] - rot_range[0]), + "zoom_amount": zoom_range[0] + energy * (zoom_range[1] - zoom_range[0]), + "invert_amount": self.scans.get_emit(f"inv_{clip}{suffix}"), + "hue_degrees": self.scans.get_emit(f"hue_{clip}{suffix}"), + "ascii_mix": 0, # Disabled - too slow without GPU + "ascii_char_size": 4 + energy * 28, # 4-32 + } + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair.""" + suffix = f"_{source_idx}" + return { + "blend_opacity": self.scans.get_emit(f"pair_mix{suffix}"), + "pair_rotation": self.scans.get_emit(f"pair_rot{suffix}"), + } + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade composition.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + weights = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + w = 0.9 + elif phase3 < clen * 2: + w = 0.9 - ((phase3 - clen) / clen) * 0.85 + else: + w = 0.05 + elif active == prev: + if phase3 < clen: + w = 0.05 + elif phase3 < clen * 2: + w = 0.05 + ((phase3 - clen) / clen) * 0.85 + else: + w = 0.9 + else: + w = 0.05 + + weights.append(w) + + # Normalize + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + + return weights + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade.""" + cycle_state = self.scans.get_emit("cycle") + active = cycle_state["cycle"] + beat = cycle_state["beat"] + clen = cycle_state["clen"] + n = self.n_sources + + phase3 = beat * 3 + zooms = [] + + for p in range(n): + prev = (p + n - 1) % n + + if active == p: + if phase3 < clen: + z = 1.0 + elif phase3 < clen * 2: + z = 1.0 + ((phase3 - clen) / clen) * 1.0 + else: + z = 0.1 + elif active == prev: + if phase3 < clen: + z = 3.0 # Start big + elif phase3 < clen * 2: + z = 3.0 - ((phase3 - clen) / clen) * 2.0 # Shrink to 1.0 + else: + z = 1.0 + else: + z = 0.1 + + zooms.append(z) + + return zooms + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (whole-spin, ripple).""" + ripple_gate = self.scans.get_emit("ripple_gate") + ripple_state = self.scans.scans["ripple_gate"].value + + return { + "whole_spin_angle": self.scans.get_emit("whole_spin"), + "ripple_amplitude": ripple_gate * (5 + energy * 45), # 5-50 + "ripple_cx": ripple_state["cx"], + "ripple_cy": ripple_state["cy"], + } diff --git a/streaming/sexp_executor.py b/streaming/sexp_executor.py new file mode 100644 index 0000000..0151853 --- /dev/null +++ b/streaming/sexp_executor.py @@ -0,0 +1,678 @@ +""" +Streaming S-expression executor. + +Executes compiled sexp recipes in real-time by: +- Evaluating scan expressions on each beat +- Resolving bindings to get effect parameter values +- Applying effects frame-by-frame +- Evaluating SLICE_ON Lambda for cycle crossfade +""" + +import random +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, field + +from .sexp_interp import SexpInterpreter, eval_slice_on_lambda + + +@dataclass +class ScanState: + """Runtime state for a scan.""" + node_id: str + name: Optional[str] + value: Any + rng: random.Random + init_expr: dict + step_expr: dict + emit_expr: dict + + +class ExprEvaluator: + """ + Evaluates compiled expression ASTs. + + Expressions are dicts with: + - _expr: True (marks as expression) + - op: operation name + - args: list of arguments + - name: for 'var' ops + - keys: for 'dict' ops + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + + def eval(self, expr: Any, env: Dict[str, Any]) -> Any: + """Evaluate an expression in the given environment.""" + # Literal values + if not isinstance(expr, dict): + return expr + + # Check if it's an expression + if not expr.get('_expr'): + # It's a plain dict - return as-is + return expr + + op = expr.get('op') + args = expr.get('args', []) + + # Evaluate based on operation + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + raise KeyError(f"Unknown variable: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison ops + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic ops + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random ops + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.randint(lo, hi) + elif op == 'rand-range': + lo = self.eval(args[0], env) + hi = self.eval(args[1], env) + return self.rng.uniform(lo, hi) + + # Logic ops + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown operation: {op}") + + +class SexpStreamingExecutor: + """ + Executes a compiled sexp recipe in streaming mode. + + Reads scan definitions, effect chains, and bindings from the + compiled recipe and executes them frame-by-frame. + """ + + def __init__(self, compiled_recipe, seed: int = 42): + self.recipe = compiled_recipe + self.master_seed = seed + + # Build node lookup + self.nodes = {n['id']: n for n in compiled_recipe.nodes} + + # State (must be initialized before _init_scans) + self.beat_count = 0 + self.current_time = 0.0 + self.last_beat_time = 0.0 + self.last_beat_detected = False + self.energy = 0.0 + + # Initialize scans + self.scans: Dict[str, ScanState] = {} + self.scan_outputs: Dict[str, Any] = {} # Current emit values by node_id + self._init_scans() + + # Initialize SLICE_ON interpreter + self.sexp_interp = SexpInterpreter(random.Random(seed)) + self._slice_on_lambda = None + self._slice_on_acc = None + self._slice_on_result = None # Last evaluation result {layers, compose, acc} + self._init_slice_on() + + def _init_slice_on(self): + """Initialize SLICE_ON Lambda for cycle crossfade.""" + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + config = node.get('config', {}) + self._slice_on_lambda = config.get('fn') + init = config.get('init', {}) + self._slice_on_acc = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + # Evaluate initial state + self._eval_slice_on() + break + + def _eval_slice_on(self): + """Evaluate the SLICE_ON Lambda with current state.""" + if not self._slice_on_lambda: + return + + n = len(self._get_video_sources()) + videos = list(range(n)) # Placeholder video indices + + try: + result = eval_slice_on_lambda( + self._slice_on_lambda, + self._slice_on_acc, + self.beat_count, + 0.0, # start time (not used for weights) + 1.0, # end time (not used for weights) + videos, + self.sexp_interp, + ) + self._slice_on_result = result + # Update accumulator for next beat + if 'acc' in result: + self._slice_on_acc = result['acc'] + except Exception as e: + import sys + print(f"SLICE_ON eval error: {e}", file=sys.stderr) + + def _init_scans(self): + """Initialize all scan nodes from the recipe.""" + seed_offset = 0 + for node in self.recipe.nodes: + if node.get('type') == 'SCAN': + node_id = node['id'] + config = node.get('config', {}) + + # Create RNG with unique seed + scan_seed = config.get('seed', self.master_seed + seed_offset) + rng = random.Random(scan_seed) + seed_offset += 1 + + # Evaluate initial value + init_expr = config.get('init', 0) + evaluator = ExprEvaluator(rng) + init_value = evaluator.eval(init_expr, {}) + + self.scans[node_id] = ScanState( + node_id=node_id, + name=node.get('name'), + value=init_value, + rng=rng, + init_expr=init_expr, + step_expr=config.get('step_expr', {}), + emit_expr=config.get('emit_expr', {}), + ) + + # Compute initial emit + self._update_emit(node_id) + + def _update_emit(self, node_id: str): + """Update the emit value for a scan.""" + scan = self.scans[node_id] + evaluator = ExprEvaluator(scan.rng) + + # Build environment from current state + env = self._build_scan_env(scan) + + # Evaluate emit expression + emit_value = evaluator.eval(scan.emit_expr, env) + self.scan_outputs[node_id] = emit_value + + def _build_scan_env(self, scan: ScanState) -> Dict[str, Any]: + """Build environment for scan expression evaluation.""" + env = {} + + # Add state variables + if isinstance(scan.value, dict): + env.update(scan.value) + else: + env['acc'] = scan.value + + # Add beat count + env['beat_count'] = self.beat_count + env['time'] = self.current_time + + return env + + def on_beat(self): + """Update all scans on a beat.""" + self.beat_count += 1 + + # Estimate beat interval + beat_interval = self.current_time - self.last_beat_time if self.last_beat_time > 0 else 0.5 + self.last_beat_time = self.current_time + + # Step each scan + for node_id, scan in self.scans.items(): + evaluator = ExprEvaluator(scan.rng) + env = self._build_scan_env(scan) + + # Evaluate step expression + new_value = evaluator.eval(scan.step_expr, env) + scan.value = new_value + + # Update emit + self._update_emit(node_id) + + # Step the cycle state + self._step_cycle() + + def on_frame(self, energy: float, is_beat: bool, t: float = 0.0): + """Called each frame with audio analysis.""" + self.current_time = t + self.energy = energy + + # Update scans on beat (edge detection) + if is_beat and not self.last_beat_detected: + self.on_beat() + self.last_beat_detected = is_beat + + def resolve_binding(self, binding: dict) -> Any: + """Resolve a binding to get the current value.""" + if not isinstance(binding, dict) or not binding.get('_binding'): + return binding + + source_id = binding.get('source') + feature = binding.get('feature', 'values') + range_map = binding.get('range') + + # Get the raw value + if source_id in self.scan_outputs: + value = self.scan_outputs[source_id] + else: + # Might be an analyzer reference - use energy as fallback + value = self.energy + + # Extract feature if value is a dict + if isinstance(value, dict) and feature in value: + value = value[feature] + + # Apply range mapping + if range_map and isinstance(value, (int, float)): + lo, hi = range_map + value = lo + value * (hi - lo) + + return value + + def get_effect_params(self, effect_node: dict) -> Dict[str, Any]: + """Get resolved parameters for an effect node.""" + config = effect_node.get('config', {}) + params = {} + + for key, value in config.items(): + # Skip internal fields + if key in ('effect', 'effect_path', 'effect_cid', 'effects_registry', 'analysis_refs'): + continue + + # Resolve bindings + params[key] = self.resolve_binding(value) + + return params + + def get_scan_value(self, name: str) -> Any: + """Get scan output by name.""" + for node_id, scan in self.scans.items(): + if scan.name == name: + return self.scan_outputs.get(node_id) + return None + + def get_all_scan_values(self) -> Dict[str, Any]: + """Get all named scan outputs.""" + result = {} + for node_id, scan in self.scans.items(): + if scan.name: + result[scan.name] = self.scan_outputs.get(node_id) + return result + + # === Compositor interface methods === + + def _get_video_sources(self) -> List[str]: + """Get list of video source node IDs.""" + sources = [] + for node in self.recipe.nodes: + if node.get('type') == 'SOURCE': + sources.append(node['id']) + # Filter to video only (exclude audio - last one is usually audio) + # Look at file extensions in the paths + return sources[:-1] if len(sources) > 1 else sources + + def _trace_effect_chain(self, start_id: str, stop_at_blend: bool = True) -> List[dict]: + """Trace effect chain from a node, returning effects in order.""" + chain = [] + current_id = start_id + + for _ in range(20): # Max depth + # Find node that uses current as input + next_node = None + for node in self.recipe.nodes: + if current_id in node.get('inputs', []): + if node.get('type') == 'EFFECT': + effect_type = node.get('config', {}).get('effect') + chain.append(node) + if stop_at_blend and effect_type == 'blend': + return chain + next_node = node + break + elif node.get('type') == 'SEGMENT': + next_node = node + break + + if next_node is None: + break + current_id = next_node['id'] + + return chain + + def _find_clip_chains(self, source_idx: int) -> tuple: + """Find effect chains for clip A and B from a source.""" + sources = self._get_video_sources() + if source_idx >= len(sources): + return [], [] + + source_id = sources[source_idx] + + # Find SEGMENT node + segment_id = None + for node in self.recipe.nodes: + if node.get('type') == 'SEGMENT' and source_id in node.get('inputs', []): + segment_id = node['id'] + break + + if not segment_id: + return [], [] + + # Find the two effect chains from segment (clip A and clip B) + chains = [] + for node in self.recipe.nodes: + if segment_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + chain = self._trace_effect_chain(segment_id) + # Get chain starting from this specific branch + branch_chain = [node] + current = node['id'] + for _ in range(10): + found = False + for n in self.recipe.nodes: + if current in n.get('inputs', []) and n.get('type') == 'EFFECT': + branch_chain.append(n) + if n.get('config', {}).get('effect') == 'blend': + break + current = n['id'] + found = True + break + if not found: + break + chains.append(branch_chain) + + # Return first two chains as A and B + chain_a = chains[0] if len(chains) > 0 else [] + chain_b = chains[1] if len(chains) > 1 else [] + return chain_a, chain_b + + def get_effect_params(self, source_idx: int, clip: str, energy: float) -> Dict: + """Get effect parameters for a source clip (compositor interface).""" + # Get the correct chain for this clip + chain_a, chain_b = self._find_clip_chains(source_idx) + chain = chain_a if clip == 'a' else chain_b + + # Default params + params = { + "rotate_angle": 0, + "zoom_amount": 1.0, + "invert_amount": 0, + "hue_degrees": 0, + "ascii_mix": 0, + "ascii_char_size": 8, + } + + # Resolve from effects in chain + for eff in chain: + config = eff.get('config', {}) + effect_type = config.get('effect') + + if effect_type == 'rotate': + angle_binding = config.get('angle') + if angle_binding: + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + # Bound to analyzer - use energy with range + range_map = angle_binding.get('range') + if range_map: + lo, hi = range_map + params["rotate_angle"] = lo + energy * (hi - lo) + else: + params["rotate_angle"] = self.resolve_binding(angle_binding) + else: + params["rotate_angle"] = angle_binding if isinstance(angle_binding, (int, float)) else 0 + + elif effect_type == 'zoom': + amount_binding = config.get('amount') + if amount_binding: + if isinstance(amount_binding, dict) and amount_binding.get('_binding'): + range_map = amount_binding.get('range') + if range_map: + lo, hi = range_map + params["zoom_amount"] = lo + energy * (hi - lo) + else: + params["zoom_amount"] = self.resolve_binding(amount_binding) + else: + params["zoom_amount"] = amount_binding if isinstance(amount_binding, (int, float)) else 1.0 + + elif effect_type == 'invert': + amount_binding = config.get('amount') + if amount_binding: + val = self.resolve_binding(amount_binding) + params["invert_amount"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'hue_shift': + deg_binding = config.get('degrees') + if deg_binding: + val = self.resolve_binding(deg_binding) + params["hue_degrees"] = val if isinstance(val, (int, float)) else 0 + + elif effect_type == 'ascii_art': + mix_binding = config.get('mix') + if mix_binding: + val = self.resolve_binding(mix_binding) + params["ascii_mix"] = val if isinstance(val, (int, float)) else 0 + size_binding = config.get('char_size') + if size_binding: + if isinstance(size_binding, dict) and size_binding.get('_binding'): + range_map = size_binding.get('range') + if range_map: + lo, hi = range_map + params["ascii_char_size"] = lo + energy * (hi - lo) + + return params + + def get_pair_params(self, source_idx: int) -> Dict: + """Get blend and rotation params for a video pair (compositor interface).""" + params = { + "blend_opacity": 0.5, + "pair_rotation": 0, + } + + # Find the blend node for this source + chain_a, _ = self._find_clip_chains(source_idx) + + # The last effect in chain_a should be the blend + blend_node = None + for eff in reversed(chain_a): + if eff.get('config', {}).get('effect') == 'blend': + blend_node = eff + break + + if blend_node: + config = blend_node.get('config', {}) + opacity_binding = config.get('opacity') + if opacity_binding: + val = self.resolve_binding(opacity_binding) + if isinstance(val, (int, float)): + params["blend_opacity"] = val + + # Find rotate after blend (pair rotation) + blend_id = blend_node['id'] + for node in self.recipe.nodes: + if blend_id in node.get('inputs', []) and node.get('type') == 'EFFECT': + if node.get('config', {}).get('effect') == 'rotate': + angle_binding = node.get('config', {}).get('angle') + if angle_binding: + val = self.resolve_binding(angle_binding) + if isinstance(val, (int, float)): + params["pair_rotation"] = val + break + + return params + + def _get_cycle_state(self) -> dict: + """Get current cycle state from SLICE_ON or internal tracking.""" + if not hasattr(self, '_cycle_state'): + # Initialize from SLICE_ON node + for node in self.recipe.nodes: + if node.get('type') == 'SLICE_ON': + init = node.get('config', {}).get('init', {}) + self._cycle_state = { + 'cycle': init.get('cycle', 0), + 'beat': init.get('beat', 0), + 'clen': init.get('clen', 60), + } + break + else: + self._cycle_state = {'cycle': 0, 'beat': 0, 'clen': 60} + + return self._cycle_state + + def _step_cycle(self): + """Step the cycle state forward on beat by evaluating SLICE_ON Lambda.""" + # Use interpreter to evaluate the Lambda + self._eval_slice_on() + + def get_cycle_weights(self) -> List[float]: + """Get blend weights for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get weights from interpreted result + if self._slice_on_result: + compose = self._slice_on_result.get('compose', {}) + weights = compose.get('weights', []) + if weights and len(weights) == n: + # Normalize + total = sum(weights) + if total > 0: + return [w / total for w in weights] + + # Fallback: equal weights + return [1.0 / n] * n + + def get_cycle_zooms(self) -> List[float]: + """Get zoom amounts for cycle-crossfade from SLICE_ON result.""" + n = len(self._get_video_sources()) + if n == 0: + return [1.0] + + # Get zooms from interpreted result (layers -> effects -> zoom amount) + if self._slice_on_result: + layers = self._slice_on_result.get('layers', []) + if layers and len(layers) == n: + zooms = [] + for layer in layers: + effects = layer.get('effects', []) + zoom_amt = 1.0 + for eff in effects: + if eff.get('effect') == 'zoom' or (hasattr(eff.get('effect'), 'name') and eff.get('effect').name == 'zoom'): + zoom_amt = eff.get('amount', 1.0) + break + zooms.append(zoom_amt) + return zooms + + # Fallback + return [1.0] * n + + def _get_final_rotate_scan_id(self) -> str: + """Find the scan ID that drives the final rotation (after SLICE_ON).""" + if hasattr(self, '_final_rotate_scan_id'): + return self._final_rotate_scan_id + + # Find SLICE_ON node index + slice_on_idx = None + for i, node in enumerate(self.recipe.nodes): + if node.get('type') == 'SLICE_ON': + slice_on_idx = i + break + + # Find rotate effect after SLICE_ON + if slice_on_idx is not None: + for node in self.recipe.nodes[slice_on_idx + 1:]: + if node.get('type') == 'EFFECT': + config = node.get('config', {}) + if config.get('effect') == 'rotate': + angle_binding = config.get('angle', {}) + if isinstance(angle_binding, dict) and angle_binding.get('_binding'): + self._final_rotate_scan_id = angle_binding.get('source') + return self._final_rotate_scan_id + + self._final_rotate_scan_id = None + return None + + def get_final_effects(self, energy: float) -> Dict: + """Get final composition effects (compositor interface).""" + # Get named scans + scan_values = self.get_all_scan_values() + + # Whole spin - get from the specific scan bound to final rotate effect + whole_spin = 0 + final_rotate_scan_id = self._get_final_rotate_scan_id() + if final_rotate_scan_id and final_rotate_scan_id in self.scan_outputs: + val = self.scan_outputs[final_rotate_scan_id] + if isinstance(val, dict) and 'angle' in val: + whole_spin = val['angle'] + elif isinstance(val, (int, float)): + whole_spin = val + + # Ripple + ripple_gate = scan_values.get('ripple-gate', 0) + ripple_cx = scan_values.get('ripple-cx', 0.5) + ripple_cy = scan_values.get('ripple-cy', 0.5) + + if isinstance(ripple_gate, dict): + ripple_gate = ripple_gate.get('gate', 0) if 'gate' in ripple_gate else 1 + + return { + "whole_spin_angle": whole_spin, + "ripple_amplitude": ripple_gate * (5 + energy * 45), + "ripple_cx": ripple_cx if isinstance(ripple_cx, (int, float)) else 0.5, + "ripple_cy": ripple_cy if isinstance(ripple_cy, (int, float)) else 0.5, + } diff --git a/streaming/sexp_interp.py b/streaming/sexp_interp.py new file mode 100644 index 0000000..e3433b2 --- /dev/null +++ b/streaming/sexp_interp.py @@ -0,0 +1,376 @@ +""" +S-expression interpreter for streaming execution. + +Evaluates sexp expressions including: +- let bindings +- lambda definitions and calls +- Arithmetic, comparison, logic operators +- dict/list operations +- Random number generation +""" + +import random +from typing import Any, Dict, List, Callable +from dataclasses import dataclass + + +@dataclass +class Lambda: + """Runtime lambda value.""" + params: List[str] + body: Any + closure: Dict[str, Any] + + +class Symbol: + """Symbol reference.""" + def __init__(self, name: str): + self.name = name + + def __repr__(self): + return f"Symbol({self.name})" + + +class SexpInterpreter: + """ + Interprets S-expressions in real-time. + + Handles the full sexp language used in recipes. + """ + + def __init__(self, rng: random.Random = None): + self.rng = rng or random.Random() + self.globals: Dict[str, Any] = {} + + def eval(self, expr: Any, env: Dict[str, Any] = None) -> Any: + """Evaluate an expression in the given environment.""" + if env is None: + env = {} + + # Literals + if isinstance(expr, (int, float, str, bool)) or expr is None: + return expr + + # Symbol lookup + if isinstance(expr, Symbol) or (hasattr(expr, 'name') and hasattr(expr, '__class__') and expr.__class__.__name__ == 'Symbol'): + name = expr.name if hasattr(expr, 'name') else str(expr) + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined symbol: {name}") + + # Compiled expression dict (from compiler) + if isinstance(expr, dict): + if expr.get('_expr'): + return self._eval_compiled_expr(expr, env) + # Plain dict - evaluate values that might be expressions + result = {} + for k, v in expr.items(): + # Some keys should keep Symbol values as strings (effect names, modes) + if k in ('effect', 'mode') and hasattr(v, 'name'): + result[k] = v.name + else: + result[k] = self.eval(v, env) + return result + + # List expression (sexp) + if isinstance(expr, (list, tuple)) and len(expr) > 0: + return self._eval_list(expr, env) + + # Empty list + if isinstance(expr, (list, tuple)): + return [] + + return expr + + def _eval_compiled_expr(self, expr: dict, env: Dict[str, Any]) -> Any: + """Evaluate a compiled expression dict.""" + op = expr.get('op') + args = expr.get('args', []) + + if op == 'var': + name = expr.get('name') + if name in env: + return env[name] + if name in self.globals: + return self.globals[name] + raise NameError(f"Undefined: {name}") + + elif op == 'dict': + keys = expr.get('keys', []) + values = [self.eval(a, env) for a in args] + return dict(zip(keys, values)) + + elif op == 'get': + obj = self.eval(args[0], env) + key = args[1] + return obj.get(key) if isinstance(obj, dict) else obj[key] + + elif op == 'if': + cond = self.eval(args[0], env) + if cond: + return self.eval(args[1], env) + elif len(args) > 2: + return self.eval(args[2], env) + return None + + # Comparison + elif op == '<': + return self.eval(args[0], env) < self.eval(args[1], env) + elif op == '>': + return self.eval(args[0], env) > self.eval(args[1], env) + elif op == '<=': + return self.eval(args[0], env) <= self.eval(args[1], env) + elif op == '>=': + return self.eval(args[0], env) >= self.eval(args[1], env) + elif op == '=': + return self.eval(args[0], env) == self.eval(args[1], env) + elif op == '!=': + return self.eval(args[0], env) != self.eval(args[1], env) + + # Arithmetic + elif op == '+': + return self.eval(args[0], env) + self.eval(args[1], env) + elif op == '-': + return self.eval(args[0], env) - self.eval(args[1], env) + elif op == '*': + return self.eval(args[0], env) * self.eval(args[1], env) + elif op == '/': + return self.eval(args[0], env) / self.eval(args[1], env) + elif op == 'mod': + return self.eval(args[0], env) % self.eval(args[1], env) + + # Random + elif op == 'rand': + return self.rng.random() + elif op == 'rand-int': + return self.rng.randint(self.eval(args[0], env), self.eval(args[1], env)) + elif op == 'rand-range': + return self.rng.uniform(self.eval(args[0], env), self.eval(args[1], env)) + + # Logic + elif op == 'and': + return all(self.eval(a, env) for a in args) + elif op == 'or': + return any(self.eval(a, env) for a in args) + elif op == 'not': + return not self.eval(args[0], env) + + else: + raise ValueError(f"Unknown op: {op}") + + def _eval_list(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate a list expression (sexp form).""" + if len(expr) == 0: + return [] + + head = expr[0] + + # Get head name + if isinstance(head, Symbol) or (hasattr(head, 'name') and hasattr(head, '__class__')): + head_name = head.name if hasattr(head, 'name') else str(head) + elif isinstance(head, str): + head_name = head + else: + # Not a symbol - check if it's a data list or function call + if isinstance(head, dict): + # List of dicts - evaluate each element as data + return [self.eval(item, env) for item in expr] + # Otherwise evaluate as function call + fn = self.eval(head, env) + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + # Special forms + if head_name == 'let': + return self._eval_let(expr, env) + elif head_name in ('lambda', 'fn'): + return self._eval_lambda(expr, env) + elif head_name == 'if': + return self._eval_if(expr, env) + elif head_name == 'dict': + return self._eval_dict(expr, env) + elif head_name == 'get': + obj = self.eval(expr[1], env) + key = self.eval(expr[2], env) if len(expr) > 2 else expr[2] + if isinstance(key, str): + return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None) + return obj[key] + elif head_name == 'len': + return len(self.eval(expr[1], env)) + elif head_name == 'range': + start = self.eval(expr[1], env) + end = self.eval(expr[2], env) if len(expr) > 2 else start + if len(expr) == 2: + return list(range(end)) + return list(range(start, end)) + elif head_name == 'map': + fn = self.eval(expr[1], env) + lst = self.eval(expr[2], env) + return [self._call(fn, [x], env) for x in lst] + elif head_name == 'mod': + return self.eval(expr[1], env) % self.eval(expr[2], env) + + # Arithmetic + elif head_name == '+': + return self.eval(expr[1], env) + self.eval(expr[2], env) + elif head_name == '-': + if len(expr) == 2: + return -self.eval(expr[1], env) + return self.eval(expr[1], env) - self.eval(expr[2], env) + elif head_name == '*': + return self.eval(expr[1], env) * self.eval(expr[2], env) + elif head_name == '/': + return self.eval(expr[1], env) / self.eval(expr[2], env) + + # Comparison + elif head_name == '<': + return self.eval(expr[1], env) < self.eval(expr[2], env) + elif head_name == '>': + return self.eval(expr[1], env) > self.eval(expr[2], env) + elif head_name == '<=': + return self.eval(expr[1], env) <= self.eval(expr[2], env) + elif head_name == '>=': + return self.eval(expr[1], env) >= self.eval(expr[2], env) + elif head_name == '=': + return self.eval(expr[1], env) == self.eval(expr[2], env) + + # Logic + elif head_name == 'and': + return all(self.eval(a, env) for a in expr[1:]) + elif head_name == 'or': + return any(self.eval(a, env) for a in expr[1:]) + elif head_name == 'not': + return not self.eval(expr[1], env) + + # Function call + else: + fn = env.get(head_name) or self.globals.get(head_name) + if fn is None: + raise NameError(f"Undefined function: {head_name}") + args = [self.eval(a, env) for a in expr[1:]] + return self._call(fn, args, env) + + def _eval_let(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (let [bindings...] body).""" + bindings = expr[1] + body = expr[2] + + # Create new environment with bindings + new_env = dict(env) + + # Process bindings in pairs + i = 0 + while i < len(bindings): + name = bindings[i] + if isinstance(name, Symbol) or hasattr(name, 'name'): + name = name.name if hasattr(name, 'name') else str(name) + value = self.eval(bindings[i + 1], new_env) + new_env[name] = value + i += 2 + + return self.eval(body, new_env) + + def _eval_lambda(self, expr: list, env: Dict[str, Any]) -> Lambda: + """Evaluate (lambda [params] body).""" + params_expr = expr[1] + body = expr[2] + + # Extract parameter names + params = [] + for p in params_expr: + if isinstance(p, Symbol) or hasattr(p, 'name'): + params.append(p.name if hasattr(p, 'name') else str(p)) + else: + params.append(str(p)) + + return Lambda(params=params, body=body, closure=dict(env)) + + def _eval_if(self, expr: list, env: Dict[str, Any]) -> Any: + """Evaluate (if cond then else).""" + cond = self.eval(expr[1], env) + if cond: + return self.eval(expr[2], env) + elif len(expr) > 3: + return self.eval(expr[3], env) + return None + + def _eval_dict(self, expr: list, env: Dict[str, Any]) -> dict: + """Evaluate (dict :key val ...).""" + result = {} + i = 1 + while i < len(expr): + key = expr[i] + # Handle keyword syntax (:key) and Keyword objects + if hasattr(key, 'name'): + key = key.name + elif hasattr(key, '__class__') and key.__class__.__name__ == 'Keyword': + key = str(key).lstrip(':') + elif isinstance(key, str) and key.startswith(':'): + key = key[1:] + value = self.eval(expr[i + 1], env) + result[key] = value + i += 2 + return result + + def _call(self, fn: Any, args: List[Any], env: Dict[str, Any]) -> Any: + """Call a function with arguments.""" + if isinstance(fn, Lambda): + # Our own Lambda type + call_env = dict(fn.closure) + for param, arg in zip(fn.params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif hasattr(fn, 'params') and hasattr(fn, 'body'): + # Lambda from parser (artdag.sexp.parser.Lambda) + call_env = dict(env) + if hasattr(fn, 'closure') and fn.closure: + call_env.update(fn.closure) + # Get param names + params = [] + for p in fn.params: + if hasattr(p, 'name'): + params.append(p.name) + else: + params.append(str(p)) + for param, arg in zip(params, args): + call_env[param] = arg + return self.eval(fn.body, call_env) + elif callable(fn): + return fn(*args) + else: + raise TypeError(f"Not callable: {type(fn).__name__}") + + +def eval_slice_on_lambda(lambda_obj, acc: dict, i: int, start: float, end: float, + videos: list, interp: SexpInterpreter = None) -> dict: + """ + Evaluate a SLICE_ON lambda function. + + Args: + lambda_obj: The Lambda object from the compiled recipe + acc: Current accumulator state + i: Beat index + start: Slice start time + end: Slice end time + videos: List of video inputs + interp: Interpreter to use + + Returns: + Dict with 'layers', 'compose', 'acc' keys + """ + if interp is None: + interp = SexpInterpreter() + + # Set up global 'videos' for (len videos) to work + interp.globals['videos'] = videos + + # Build initial environment with lambda parameters + env = dict(lambda_obj.closure) if hasattr(lambda_obj, 'closure') and lambda_obj.closure else {} + env['videos'] = videos + + # Call the lambda + result = interp._call(lambda_obj, [acc, i, start, end], env) + + return result diff --git a/streaming/sources.py b/streaming/sources.py new file mode 100644 index 0000000..71e7e53 --- /dev/null +++ b/streaming/sources.py @@ -0,0 +1,281 @@ +""" +Video and image sources with looping support. +""" + +import numpy as np +import subprocess +import json +from pathlib import Path +from typing import Optional, Tuple +from abc import ABC, abstractmethod + + +class Source(ABC): + """Abstract base class for frame sources.""" + + @abstractmethod + def read_frame(self, t: float) -> np.ndarray: + """Read frame at time t (with looping if needed).""" + pass + + @property + @abstractmethod + def duration(self) -> float: + """Source duration in seconds.""" + pass + + @property + @abstractmethod + def size(self) -> Tuple[int, int]: + """Frame size as (width, height).""" + pass + + @property + @abstractmethod + def fps(self) -> float: + """Frames per second.""" + pass + + +class VideoSource(Source): + """ + Video file source with automatic looping. + + Reads frames on-demand, seeking as needed. When time exceeds + duration, wraps around (loops). + """ + + def __init__(self, path: str, target_fps: float = 30): + self.path = Path(path) + self.target_fps = target_fps + + # Initialize decode state first (before _probe which could fail) + self._process: Optional[subprocess.Popen] = None + self._current_start: Optional[float] = None + self._frame_buffer: Optional[np.ndarray] = None + self._buffer_time: Optional[float] = None + + self._duration = None + self._size = None + self._fps = None + + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + self._probe() + + def _probe(self): + """Get video metadata.""" + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_format", "-show_streams", + str(self.path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + data = json.loads(result.stdout) + + # Get duration + self._duration = float(data["format"]["duration"]) + + # Get video stream info + for stream in data["streams"]: + if stream["codec_type"] == "video": + self._size = (int(stream["width"]), int(stream["height"])) + # Parse fps from r_frame_rate (e.g., "30/1" or "30000/1001") + fps_parts = stream.get("r_frame_rate", "30/1").split("/") + self._fps = float(fps_parts[0]) / float(fps_parts[1]) + break + + @property + def duration(self) -> float: + return self._duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def _start_decode(self, start_time: float): + """Start ffmpeg decode process from given time.""" + if self._process: + try: + self._process.stdout.close() + except: + pass + self._process.terminate() + try: + self._process.wait(timeout=1) + except: + self._process.kill() + self._process.wait() + + w, h = self._size + cmd = [ + "ffmpeg", "-v", "quiet", + "-ss", str(start_time), + "-i", str(self.path), + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-r", str(self.target_fps), + "-" + ] + self._process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + bufsize=w * h * 3 * 4, # Buffer a few frames + ) + self._current_start = start_time + self._buffer_time = start_time + + def read_frame(self, t: float) -> np.ndarray: + """ + Read frame at time t. + + If t exceeds duration, wraps around (loops). + Seeks if needed, otherwise reads sequentially. + """ + # Wrap time for looping + t_wrapped = t % self._duration + + # Check if we need to seek (loop point or large time jump) + need_seek = ( + self._process is None or + self._buffer_time is None or + abs(t_wrapped - self._buffer_time) > 1.0 / self.target_fps * 2 + ) + + if need_seek: + self._start_decode(t_wrapped) + + # Read frame + w, h = self._size + frame_size = w * h * 3 + + # Try to read with retries for seek settling + for attempt in range(3): + raw = self._process.stdout.read(frame_size) + if len(raw) == frame_size: + break + # End of stream or seek not ready - restart from beginning + self._start_decode(0) + + if len(raw) < frame_size: + # Still no data - return last frame or black + if self._frame_buffer is not None: + return self._frame_buffer.copy() + return np.zeros((h, w, 3), dtype=np.uint8) + + frame = np.frombuffer(raw, dtype=np.uint8).reshape((h, w, 3)) + self._frame_buffer = frame # Cache for fallback + self._buffer_time = t_wrapped + 1.0 / self.target_fps + + return frame + + def close(self): + """Clean up resources.""" + if self._process: + self._process.terminate() + self._process.wait() + self._process = None + + def __del__(self): + self.close() + + def __repr__(self): + return f"VideoSource({self.path.name}, {self._size[0]}x{self._size[1]}, {self._duration:.1f}s)" + + +class ImageSource(Source): + """ + Static image source (returns same frame for any time). + + Useful for backgrounds, overlays, etc. + """ + + def __init__(self, path: str): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Image not found: {path}") + + # Load image + import cv2 + self._frame = cv2.imread(str(self.path)) + self._frame = cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB) + self._size = (self._frame.shape[1], self._frame.shape[0]) + + @property + def duration(self) -> float: + return float('inf') # Images last forever + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return 30.0 # Arbitrary + + def read_frame(self, t: float) -> np.ndarray: + return self._frame.copy() + + def __repr__(self): + return f"ImageSource({self.path.name}, {self._size[0]}x{self._size[1]})" + + +class LiveSource(Source): + """ + Live video capture source (webcam, capture card, etc.). + + Time parameter is ignored - always returns latest frame. + """ + + def __init__(self, device: int = 0, size: Tuple[int, int] = (1280, 720), fps: float = 30): + import cv2 + self._cap = cv2.VideoCapture(device) + self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, size[0]) + self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, size[1]) + self._cap.set(cv2.CAP_PROP_FPS, fps) + + # Get actual settings + self._size = ( + int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + ) + self._fps = self._cap.get(cv2.CAP_PROP_FPS) + + if not self._cap.isOpened(): + raise RuntimeError(f"Could not open video device {device}") + + @property + def duration(self) -> float: + return float('inf') # Live - no duration + + @property + def size(self) -> Tuple[int, int]: + return self._size + + @property + def fps(self) -> float: + return self._fps + + def read_frame(self, t: float) -> np.ndarray: + """Read latest frame (t is ignored for live sources).""" + import cv2 + ret, frame = self._cap.read() + if not ret: + return np.zeros((self._size[1], self._size[0], 3), dtype=np.uint8) + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + def close(self): + self._cap.release() + + def __del__(self): + self.close() + + def __repr__(self): + return f"LiveSource({self._size[0]}x{self._size[1]}, {self._fps}fps)" diff --git a/streaming/stream_sexp.py b/streaming/stream_sexp.py new file mode 100644 index 0000000..b36dabf --- /dev/null +++ b/streaming/stream_sexp.py @@ -0,0 +1,1081 @@ +""" +Generic Streaming S-expression Interpreter. + +Executes streaming sexp recipes frame-by-frame. +The sexp defines the pipeline logic - interpreter just provides primitives. + +Primitives: + (read source-name) - read frame from source + (rotate frame :angle N) - rotate frame + (zoom frame :amount N) - zoom frame + (invert frame :amount N) - invert colors + (hue-shift frame :degrees N) - shift hue + (blend frame1 frame2 :opacity N) - blend two frames + (blend-weighted [frames...] [weights...]) - weighted blend + (ripple frame :amplitude N :cx N :cy N ...) - ripple effect + + (bind scan-name :field) - get scan state field + (map value [lo hi]) - map 0-1 value to range + energy - current energy (0-1) + beat - 1 if beat, 0 otherwise + t - current time + beat-count - total beats so far + +Example sexp: + (stream "test" + :fps 30 + (source vid "video.mp4") + (audio aud "music.mp3") + + (scan spin beat + :init {:angle 0 :dir 1} + :step (dict :angle (+ angle (* dir 10)) :dir dir)) + + (frame + (-> (read vid) + (rotate :angle (bind spin :angle)) + (zoom :amount (map energy [1 1.5]))))) +""" + +import sys +import time +import json +import hashlib +import numpy as np +import subprocess +from pathlib import Path +from dataclasses import dataclass, field +from typing import Dict, List, Any, Optional, Tuple, Union + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) +from artdag.sexp.parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class StreamContext: + """Runtime context for streaming.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + energy: float = 0.0 + is_beat: bool = False + beat_count: int = 0 + output_size: Tuple[int, int] = (720, 720) + + +class StreamCache: + """Cache for streaming data.""" + + def __init__(self, cache_dir: Path, recipe_hash: str): + self.cache_dir = cache_dir / recipe_hash + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.analysis_buffer: Dict[str, List] = {} + self.scan_states: Dict[str, List] = {} + self.keyframe_interval = 5.0 + + def record_analysis(self, name: str, t: float, value: float): + if name not in self.analysis_buffer: + self.analysis_buffer[name] = [] + t = float(t) if hasattr(t, 'item') else t + value = float(value) if hasattr(value, 'item') else value + self.analysis_buffer[name].append((t, value)) + + def record_scan_state(self, name: str, t: float, state: dict): + if name not in self.scan_states: + self.scan_states[name] = [] + states = self.scan_states[name] + if not states or t - states[-1][0] >= self.keyframe_interval: + t = float(t) if hasattr(t, 'item') else t + clean = {k: (float(v) if hasattr(v, 'item') else v) for k, v in state.items()} + self.scan_states[name].append((t, clean)) + + def flush(self): + for name, data in self.analysis_buffer.items(): + path = self.cache_dir / f"analysis_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(data) + path.write_text(json.dumps(existing)) + self.analysis_buffer.clear() + + for name, states in self.scan_states.items(): + path = self.cache_dir / f"scan_{name}.json" + existing = json.loads(path.read_text()) if path.exists() else [] + existing.extend(states) + path.write_text(json.dumps(existing)) + self.scan_states.clear() + + +class VideoSource: + """Video source - reads frames sequentially.""" + + def __init__(self, path: str, fps: float = 30): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"Video not found: {path}") + + # Get info + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + + for s in info.get("streams", []): + if s.get("codec_type") == "video": + self.width = s.get("width", 720) + self.height = s.get("height", 720) + break + else: + self.width, self.height = 720, 720 + + self.duration = float(info.get("format", {}).get("duration", 60)) + self.size = (self.width, self.height) + + # Start decoder + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "rawvideo", "-pix_fmt", "rgb24", "-r", str(fps), "-"] + self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) + self._frame_size = self.width * self.height * 3 + self._current_frame = None + + def read(self) -> Optional[np.ndarray]: + """Read next frame.""" + data = self._proc.stdout.read(self._frame_size) + if len(data) < self._frame_size: + return self._current_frame # Return last frame if stream ends + self._current_frame = np.frombuffer(data, dtype=np.uint8).reshape( + self.height, self.width, 3).copy() + return self._current_frame + + def skip(self): + """Read and discard frame (keep pipe in sync).""" + self._proc.stdout.read(self._frame_size) + + def close(self): + if self._proc: + self._proc.terminate() + self._proc.wait() + + +class AudioAnalyzer: + """Real-time audio analysis.""" + + def __init__(self, path: str, sample_rate: int = 22050): + self.path = Path(path) + + # Load audio + cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), + "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] + self._audio = np.frombuffer( + subprocess.run(cmd, capture_output=True).stdout, dtype=np.float32) + self.sample_rate = sample_rate + + # Get duration + cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", + "-show_format", str(self.path)] + info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) + self.duration = float(info.get("format", {}).get("duration", 60)) + + self._flux_history = [] + self._last_beat_time = -1 + + def get_energy(self, t: float) -> float: + idx = int(t * self.sample_rate) + start = max(0, idx - 512) + end = min(len(self._audio), idx + 512) + if start >= end: + return 0.0 + return min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0) + + def get_beat(self, t: float) -> bool: + idx = int(t * self.sample_rate) + size = 2048 + + start, end = max(0, idx - size//2), min(len(self._audio), idx + size//2) + if end - start < size//2: + return False + curr = self._audio[start:end] + + pstart, pend = max(0, start - 512), max(0, end - 512) + if pend <= pstart: + return False + prev = self._audio[pstart:pend] + + curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr)))) + prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev)))) + + n = min(len(curr_spec), len(prev_spec)) + flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1) + + self._flux_history.append((t, flux)) + while self._flux_history and self._flux_history[0][0] < t - 1.5: + self._flux_history.pop(0) + + if len(self._flux_history) < 3: + return False + + vals = [f for _, f in self._flux_history] + threshold = np.mean(vals) + np.std(vals) * 0.3 + 0.001 + + is_beat = flux > threshold and t - self._last_beat_time > 0.1 + if is_beat: + self._last_beat_time = t + return is_beat + + +class StreamInterpreter: + """ + Generic streaming sexp interpreter. + + Evaluates the frame pipeline expression each frame. + """ + + def __init__(self, sexp_path: str, cache_dir: str = None): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + recipe_hash = hashlib.sha256(text.encode()).hexdigest()[:16] + cache_path = Path(cache_dir) if cache_dir else self.sexp_dir / ".stream_cache" + self.cache = StreamCache(cache_path, recipe_hash) + + self.ctx = StreamContext(fps=self.config.get('fps', 30)) + self.sources: Dict[str, VideoSource] = {} + self.frames: Dict[str, np.ndarray] = {} # Current frame per source + self._sources_read: set = set() # Track which sources read this frame + self.audios: Dict[str, AudioAnalyzer] = {} # Multiple named audio sources + self.audio_paths: Dict[str, str] = {} + self.audio_state: Dict[str, dict] = {} # Per-audio: {energy, is_beat, beat_count, last_beat} + self.scans: Dict[str, dict] = {} + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} # name -> Python function + self.effects: Dict[str, dict] = {} # name -> {params, body} + self.macros: Dict[str, dict] = {} # name -> {params, body} + self.primitive_lib_dir = self.sexp_dir.parent / "sexp_effects" / "primitive_libs" + + self.frame_pipeline = None # The (frame ...) expression + + import random + self.rng = random.Random(self.config.get('seed', 42)) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file.""" + import importlib.util + + # Try multiple paths + lib_paths = [ + self.primitive_lib_dir / f"{lib_name}.py", + self.sexp_dir / "primitive_libs" / f"{lib_name}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{lib_name}.py", + ] + + lib_path = None + for p in lib_paths: + if p.exists(): + lib_path = p + break + + if not lib_path: + print(f"Warning: primitive library '{lib_name}' not found", file=sys.stderr) + return + + spec = importlib.util.spec_from_file_location(lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Extract all prim_* functions + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] # Remove 'prim_' prefix + self.primitives[prim_name] = func + # Also register with dashes instead of underscores + dash_name = prim_name.replace('_', '-') + self.primitives[dash_name] = func + # Also register with -img suffix (sexp convention) + self.primitives[dash_name + '-img'] = func + count += 1 + + # Also check for PRIMITIVES dict (some modules use this for additional exports) + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + self.primitives[name] = func + # Also register underscore version + underscore_name = name.replace('-', '_') + self.primitives[underscore_name] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions)", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + print(f"Warning: effect file not found: {effect_path}", file=sys.stderr) + return + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + # Parse params list + params_list = form[i + 1] + for p in params_list: + if isinstance(p, list) and p: + pname = p[0].name if isinstance(p[0], Symbol) else str(p[0]) + pdef = {'default': 0} + j = 1 + while j < len(p): + if isinstance(p[j], Keyword): + pdef[p[j].name] = p[j + 1] if j + 1 < len(p) else None + j += 2 + else: + j += 1 + params[pname] = pdef + i += 2 + else: + i += 2 + else: + # Body expression + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body, 'path': str(effect_path)} + print(f"Effect: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [] + body = None + + if len(form) > 2 and isinstance(form[2], list): + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + if len(form) > 3: + body = form[3] + + self.macros[name] = {'params': params, 'body': body} + print(f"Macro: {name}", file=sys.stderr) + + def _init(self): + """Initialize sources, scans, and pipeline from sexp.""" + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + # === External loading === + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + # (effect name :path "...") + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'include': + # (include :path "...") + i = 1 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) # Reuse effect loader for includes + i += 2 + else: + i += 1 + + # === Sources === + + elif cmd == 'source': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.sources[name] = VideoSource(str(full), self.ctx.fps) + print(f"Source: {name} -> {full}", file=sys.stderr) + else: + print(f"Warning: {full} not found", file=sys.stderr) + + elif cmd == 'audio': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + path = str(form[2]).strip('"') + full = (self.sexp_dir / path).resolve() + if full.exists(): + self.audios[name] = AudioAnalyzer(str(full)) + self.audio_paths[name] = str(full) + self.audio_state[name] = {'energy': 0.0, 'is_beat': False, 'beat_count': 0, 'last_beat': False} + print(f"Audio: {name} -> {full}", file=sys.stderr) + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + # Trigger can be: + # (beat audio-name) - trigger on beat from specific audio + # beat - legacy: trigger on beat from first audio + trigger_expr = form[2] + if isinstance(trigger_expr, list) and len(trigger_expr) >= 2: + # (beat audio-name) + trigger_type = trigger_expr[0].name if isinstance(trigger_expr[0], Symbol) else str(trigger_expr[0]) + trigger_audio = trigger_expr[1].name if isinstance(trigger_expr[1], Symbol) else str(trigger_expr[1]) + trigger = (trigger_type, trigger_audio) + else: + # Legacy bare symbol + trigger = trigger_expr.name if isinstance(trigger_expr, Symbol) else str(trigger_expr) + + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], {}) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger, + } + trigger_str = f"{trigger[0]} {trigger[1]}" if isinstance(trigger, tuple) else trigger + print(f"Scan: {name} (on {trigger_str})", file=sys.stderr) + + elif cmd == 'frame': + # (frame expr) - the pipeline to evaluate each frame + self.frame_pipeline = form[1] if len(form) > 1 else None + + # Set output size from first source + if self.sources: + first = next(iter(self.sources.values())) + self.ctx.output_size = first.size + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + import cv2 + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, Symbol): + name = expr.name + # Built-in values + if name == 't' or name == '_time': + return self.ctx.t + if name == 'pi': + import math + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + return 0 + + if isinstance(expr, Keyword): + return expr.name + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check if op is a closure in environment + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + # Invoke closure + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + # Insert result as first arg + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Audio analysis (explicit) === + + if op == 'energy': + # (energy audio-name) - get current energy from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['energy'] + return 0.0 + + if op == 'beat': + # (beat audio-name) - 1 if beat this frame, 0 otherwise + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return 1.0 if self.audio_state[audio_name]['is_beat'] else 0.0 + return 0.0 + + if op == 'beat-count': + # (beat-count audio-name) - total beats from named audio + audio_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if audio_name in self.audio_state: + return self.audio_state[audio_name]['beat_count'] + return 0 + + # === Frame operations === + + if op == 'read': + # (read source-name) - get current frame from source (lazy read) + name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if name not in self.frames: + if name in self.sources: + self.frames[name] = self.sources[name].read() + self._sources_read.add(name) + return self.frames.get(name) + + # === Binding and mapping === + + if op == 'bind': + # (bind scan-name :field) or (bind scan-name) + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + field = None + if len(args) > 1 and isinstance(args[1], Keyword): + field = args[1].name + + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if field: + return state.get(field, 0) + return state + return 0 + + if op == 'map': + # (map value [lo hi]) + val = self._eval(args[0], env) + range_list = self._eval(args[1], env) if len(args) > 1 else [0, 1] + if isinstance(range_list, list) and len(range_list) >= 2: + lo, hi = range_list[0], range_list[1] + return lo + val * (hi - lo) + return val + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + if op == 'map-range': + # (map-range val from-lo from-hi to-lo to-hi) + val = self._eval(args[0], env) + from_lo = self._eval(args[1], env) + from_hi = self._eval(args[2], env) + to_lo = self._eval(args[3], env) + to_hi = self._eval(args[4], env) + # Normalize val to 0-1 in source range, then scale to target range + if from_hi == from_lo: + return to_lo + t = (val - from_lo) / (from_hi - from_lo) + return to_lo + t * (to_hi - to_lo) + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + # Lisp-style or: returns first truthy value, or last value if none truthy + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + # (cond pred1 expr1 pred2 expr2 ... true else-expr) + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + # (lambda (params...) body) - create a closure + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + # Return a closure dict that captures the current env + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + # Support both formats: + # (let [name val name val ...] body) - flat vector + # (let ((name val) (name val) ...) body) - nested list + # Note: our let already evaluates sequentially like let* + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + # Nested format: ((name val) (name val) ...) + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + # Flat format: [name val name val ...] + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Random === + + if op == 'rand': + return self.rng.random() + if op == 'rand-int': + lo = int(self._eval(args[0], env)) + hi = int(self._eval(args[1], env)) + return self.rng.randint(lo, hi) + if op == 'rand-range': + lo = self._eval(args[0], env) + hi = self._eval(args[1], env) + return lo + self.rng.random() * (hi - lo) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + result[args[i].name] = self._eval(args[i + 1], env) if i + 1 < len(args) else None + i += 2 + else: + i += 1 + return result + + if op == 'get': + d = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(d, dict): + return d.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, list) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + lst = self._eval(args[0], env) + return len(lst) if isinstance(lst, (list, dict, str)) else 0 + + # === External effects === + if op in self.effects: + effect = self.effects[op] + effect_env = dict(env) + effect_env['t'] = self.ctx.t + + # Set defaults for all params + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + # Parse args: first is frame, then positional params, then kwargs + positional_idx = 0 + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + # Keyword arg + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + # Positional arg + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + return self._eval(effect['body'], effect_env) + + # === External primitives === + if op in self.primitives: + prim_func = self.primitives[op] + # Evaluate all args + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + # Call primitive + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # === Macros === + if op in self.macros: + macro = self.macros[op] + # Bind macro params to args (unevaluated) + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + macro_env[pname] = args[i] if i < len(args) else None + # Expand and evaluate + return self._eval(macro['body'], macro_env) + + # === Primitive-style call (name-with-dashes -> prim_name_with_underscores) === + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # Unknown - return as-is + return expr + + def _step_scans(self): + """Step scans on beat from specific audio.""" + for name, scan in self.scans.items(): + trigger = scan['trigger'] + + # Check if this scan should step + should_step = False + audio_name = None + + if isinstance(trigger, tuple) and trigger[0] == 'beat': + # Explicit: (beat audio-name) + audio_name = trigger[1] + if audio_name in self.audio_state: + should_step = self.audio_state[audio_name]['is_beat'] + elif trigger == 'beat': + # Legacy: use first audio + if self.audio_state: + audio_name = next(iter(self.audio_state)) + should_step = self.audio_state[audio_name]['is_beat'] + + if should_step and audio_name: + state = self.audio_state[audio_name] + env = dict(scan['state']) + env['beat_count'] = state['beat_count'] + env['t'] = self.ctx.t + env['energy'] = state['energy'] + + if scan['step']: + new_state = self._eval(scan['step'], env) + if isinstance(new_state, dict): + scan['state'] = new_state + elif new_state is not None: + scan['state'] = {'acc': new_state} + + self.cache.record_scan_state(name, self.ctx.t, scan['state']) + + def run(self, duration: float = None, output: str = "pipe"): + """Run the streaming pipeline.""" + from .output import PipeOutput, DisplayOutput, FileOutput + + self._init() + + if not self.sources: + print("Error: no sources", file=sys.stderr) + return + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w, h = self.ctx.output_size + + # Duration from first audio or default + if duration is None: + if self.audios: + first_audio = next(iter(self.audios.values())) + duration = first_audio.duration + else: + duration = 60.0 + + n_frames = int(duration * self.ctx.fps) + frame_time = 1.0 / self.ctx.fps + + print(f"Streaming {n_frames} frames @ {self.ctx.fps}fps", file=sys.stderr) + + # Use first audio for playback sync + first_audio_path = next(iter(self.audio_paths.values())) if self.audio_paths else None + + # Output + if output == "pipe": + out = PipeOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + else: + out = FileOutput(output, size=(w, h), fps=self.ctx.fps, + audio_source=first_audio_path) + + try: + for frame_num in range(n_frames): + if not out.is_open: + print(f"\nOutput closed at {frame_num}", file=sys.stderr) + break + + self.ctx.t = frame_num * frame_time + self.ctx.frame_num = frame_num + + # Update all audio states + for audio_name, analyzer in self.audios.items(): + state = self.audio_state[audio_name] + energy = analyzer.get_energy(self.ctx.t) + is_beat_raw = analyzer.get_beat(self.ctx.t) + is_beat = is_beat_raw and not state['last_beat'] + state['last_beat'] = is_beat_raw + + state['energy'] = energy + state['is_beat'] = is_beat + if is_beat: + state['beat_count'] += 1 + + self.cache.record_analysis(f'{audio_name}_energy', self.ctx.t, energy) + self.cache.record_analysis(f'{audio_name}_beat', self.ctx.t, 1.0 if is_beat else 0.0) + + # Step scans + self._step_scans() + + # Clear frames - will be read lazily + self.frames.clear() + self._sources_read = set() + + # Evaluate pipeline (reads happen on-demand) + result = self._eval(self.frame_pipeline, {}) + + # Skip unread sources to keep pipes in sync + for name, src in self.sources.items(): + if name not in self._sources_read: + src.skip() + + # Ensure output size + if result is not None: + import cv2 + if result.shape[:2] != (h, w): + result = cv2.resize(result, (w, h)) + out.write(result, self.ctx.t) + + # Progress + if frame_num % 30 == 0: + pct = 100 * frame_num / n_frames + # Show beats from first audio + total_beats = 0 + if self.audio_state: + first_state = next(iter(self.audio_state.values())) + total_beats = first_state['beat_count'] + print(f"\r{pct:5.1f}% | beats:{total_beats}", + end="", file=sys.stderr) + sys.stderr.flush() + + if frame_num % 300 == 0: + self.cache.flush() + + except KeyboardInterrupt: + print("\nInterrupted", file=sys.stderr) + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + finally: + out.close() + for src in self.sources.values(): + src.close() + self.cache.flush() + + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path) + if fps: + interp.ctx.fps = fps + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None, help="Override fps (default: from sexp)") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps) diff --git a/streaming/stream_sexp_generic.py b/streaming/stream_sexp_generic.py new file mode 100644 index 0000000..2f8d4b2 --- /dev/null +++ b/streaming/stream_sexp_generic.py @@ -0,0 +1,859 @@ +""" +Fully Generic Streaming S-expression Interpreter. + +The interpreter knows NOTHING about video, audio, or any domain. +All domain logic comes from primitives loaded via (require-primitives ...). + +Built-in forms: + - Control: if, cond, let, let*, lambda, -> + - Arithmetic: +, -, *, /, mod, map-range + - Comparison: <, >, =, <=, >=, and, or, not + - Data: dict, get, list, nth, len, quote + - Random: rand, rand-int, rand-range + - Scan: bind (access scan state) + +Everything else comes from primitives or effects. + +Context (ctx) is passed explicitly to frame evaluation: + - ctx.t: current time + - ctx.frame-num: current frame number + - ctx.fps: frames per second +""" + +import sys +import time +import json +import hashlib +import math +import numpy as np +from pathlib import Path +from dataclasses import dataclass +from typing import Dict, List, Any, Optional, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "artdag")) +from artdag.sexp.parser import parse, parse_all, Symbol, Keyword + + +@dataclass +class Context: + """Runtime context passed to frame evaluation.""" + t: float = 0.0 + frame_num: int = 0 + fps: float = 30.0 + + +class StreamInterpreter: + """ + Fully generic streaming sexp interpreter. + + No domain-specific knowledge - just evaluates expressions + and calls primitives. + """ + + def __init__(self, sexp_path: str): + self.sexp_path = Path(sexp_path) + self.sexp_dir = self.sexp_path.parent + + text = self.sexp_path.read_text() + self.ast = parse(text) + + self.config = self._parse_config() + + # Global environment for def bindings + self.globals: Dict[str, Any] = {} + + # Scans + self.scans: Dict[str, dict] = {} + + # Audio playback path (for syncing output) + self.audio_playback: Optional[str] = None + + # Registries for external definitions + self.primitives: Dict[str, Any] = {} + self.effects: Dict[str, dict] = {} + self.macros: Dict[str, dict] = {} + self.primitive_lib_dir = self.sexp_dir.parent / "sexp_effects" / "primitive_libs" + + self.frame_pipeline = None + + # External config files (set before run()) + self.sources_config: Optional[Path] = None + self.audio_config: Optional[Path] = None + + import random + self.rng = random.Random(self.config.get('seed', 42)) + + def _load_config_file(self, config_path): + """Load a config file and process its definitions.""" + config_path = Path(config_path) # Accept str or Path + if not config_path.exists(): + print(f"Warning: config file not found: {config_path}", file=sys.stderr) + return + + text = config_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'def': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Config: {name}", file=sys.stderr) + + elif cmd == 'audio-playback': + # Path relative to working directory (consistent with other paths) + path = str(form[1]).strip('"') + self.audio_playback = str(Path(path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + def _parse_config(self) -> dict: + """Parse config from (stream name :key val ...).""" + config = {'fps': 30, 'seed': 42, 'width': 720, 'height': 720} + if not self.ast or not isinstance(self.ast[0], Symbol): + return config + if self.ast[0].name != 'stream': + return config + + i = 2 + while i < len(self.ast): + if isinstance(self.ast[i], Keyword): + config[self.ast[i].name] = self.ast[i + 1] if i + 1 < len(self.ast) else None + i += 2 + elif isinstance(self.ast[i], list): + break + else: + i += 1 + return config + + def _load_primitives(self, lib_name: str): + """Load primitives from a Python library file.""" + import importlib.util + + lib_paths = [ + self.primitive_lib_dir / f"{lib_name}.py", + self.sexp_dir / "primitive_libs" / f"{lib_name}.py", + self.sexp_dir.parent / "sexp_effects" / "primitive_libs" / f"{lib_name}.py", + ] + + lib_path = None + for p in lib_paths: + if p.exists(): + lib_path = p + break + + if not lib_path: + print(f"Warning: primitive library '{lib_name}' not found", file=sys.stderr) + return + + spec = importlib.util.spec_from_file_location(lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + count = 0 + for name in dir(module): + if name.startswith('prim_'): + func = getattr(module, name) + prim_name = name[5:] + dash_name = prim_name.replace('_', '-') + # Register ONLY with namespace (geometry:ripple-displace) + self.primitives[f"{lib_name}:{dash_name}"] = func + count += 1 + + if hasattr(module, 'PRIMITIVES'): + prims = getattr(module, 'PRIMITIVES') + if isinstance(prims, dict): + for name, func in prims.items(): + # Register ONLY with namespace + dash_name = name.replace('_', '-') + self.primitives[f"{lib_name}:{dash_name}"] = func + count += 1 + + print(f"Loaded primitives: {lib_name} ({count} functions)", file=sys.stderr) + + def _load_effect(self, effect_path: Path): + """Load and register an effect from a .sexp file.""" + if not effect_path.exists(): + print(f"Warning: effect file not found: {effect_path}", file=sys.stderr) + return + + text = effect_path.read_text() + ast = parse_all(text) + + for form in ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'define-effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = {} + body = None + i = 2 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'params' and i + 1 < len(form): + for pdef in form[i + 1]: + if isinstance(pdef, list) and pdef: + pname = pdef[0].name if isinstance(pdef[0], Symbol) else str(pdef[0]) + pinfo = {'default': 0} + j = 1 + while j < len(pdef): + if isinstance(pdef[j], Keyword) and j + 1 < len(pdef): + pinfo[pdef[j].name] = pdef[j + 1] + j += 2 + else: + j += 1 + params[pname] = pinfo + i += 2 + else: + body = form[i] + i += 1 + + self.effects[name] = {'params': params, 'body': body} + print(f"Effect: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'effect': + # Handle (effect name :path "...") in included files - recursive + i = 2 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + # Resolve relative to the file being loaded + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'include': + # Handle (include :path "...") in included files - recursive + i = 1 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (effect_path.parent / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'scan': + # Handle scans from included files + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + def _init(self): + """Initialize from sexp - load primitives, effects, defs, scans.""" + # Load external config files first (they can override recipe definitions) + if self.sources_config: + self._load_config_file(self.sources_config) + if self.audio_config: + self._load_config_file(self.audio_config) + + for form in self.ast: + if not isinstance(form, list) or not form: + continue + if not isinstance(form[0], Symbol): + continue + + cmd = form[0].name + + if cmd == 'require-primitives': + lib_name = form[1] if isinstance(form[1], str) else str(form[1]).strip('"') + self._load_primitives(lib_name) + + elif cmd == 'effect': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + i = 2 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'include': + i = 1 + while i < len(form): + if isinstance(form[i], Keyword) and form[i].name == 'path': + path = str(form[i + 1]).strip('"') + full = (self.sexp_dir / path).resolve() + self._load_effect(full) + i += 2 + else: + i += 1 + + elif cmd == 'audio-playback': + # (audio-playback "path") - set audio file for playback sync + # Skip if already set by config file + if self.audio_playback is None: + path = str(form[1]).strip('"') + self.audio_playback = str((self.sexp_dir / path).resolve()) + print(f"Audio playback: {self.audio_playback}", file=sys.stderr) + + elif cmd == 'def': + # (def name expr) - evaluate and store in globals + # Skip if already defined by config file + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + if name in self.globals: + print(f"Def: {name} (from config, skipped)", file=sys.stderr) + continue + value = self._eval(form[2], self.globals) + self.globals[name] = value + print(f"Def: {name}", file=sys.stderr) + + elif cmd == 'defmacro': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + params = [p.name if isinstance(p, Symbol) else str(p) for p in form[2]] + body = form[3] + self.macros[name] = {'params': params, 'body': body} + + elif cmd == 'scan': + name = form[1].name if isinstance(form[1], Symbol) else str(form[1]) + trigger_expr = form[2] + init_val, step_expr = {}, None + i = 3 + while i < len(form): + if isinstance(form[i], Keyword): + if form[i].name == 'init' and i + 1 < len(form): + init_val = self._eval(form[i + 1], self.globals) + elif form[i].name == 'step' and i + 1 < len(form): + step_expr = form[i + 1] + i += 2 + else: + i += 1 + + self.scans[name] = { + 'state': dict(init_val) if isinstance(init_val, dict) else {'acc': init_val}, + 'init': init_val, + 'step': step_expr, + 'trigger': trigger_expr, + } + print(f"Scan: {name}", file=sys.stderr) + + elif cmd == 'frame': + self.frame_pipeline = form[1] if len(form) > 1 else None + + def _eval(self, expr, env: dict) -> Any: + """Evaluate an expression.""" + + # Primitives + if isinstance(expr, (int, float)): + return expr + if isinstance(expr, str): + return expr + if isinstance(expr, bool): + return expr + + if isinstance(expr, Symbol): + name = expr.name + # Built-in constants + if name == 'pi': + return math.pi + if name == 'true': + return True + if name == 'false': + return False + if name == 'nil': + return None + # Environment lookup + if name in env: + return env[name] + # Global lookup + if name in self.globals: + return self.globals[name] + # Scan state lookup + if name in self.scans: + return self.scans[name]['state'] + raise NameError(f"Undefined variable: {name}") + + if isinstance(expr, Keyword): + return expr.name + + if not isinstance(expr, list) or not expr: + return expr + + # Dict literal {:key val ...} + if isinstance(expr[0], Keyword): + result = {} + i = 0 + while i < len(expr): + if isinstance(expr[i], Keyword): + result[expr[i].name] = self._eval(expr[i + 1], env) if i + 1 < len(expr) else None + i += 2 + else: + i += 1 + return result + + head = expr[0] + if not isinstance(head, Symbol): + return [self._eval(e, env) for e in expr] + + op = head.name + args = expr[1:] + + # Check for closure call + if op in env: + val = env[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + if op in self.globals: + val = self.globals[op] + if isinstance(val, dict) and val.get('_type') == 'closure': + closure = val + closure_env = dict(closure['env']) + for i, pname in enumerate(closure['params']): + closure_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(closure['body'], closure_env) + + # Threading macro + if op == '->': + result = self._eval(args[0], env) + for form in args[1:]: + if isinstance(form, list) and form: + new_form = [form[0], result] + form[1:] + result = self._eval(new_form, env) + else: + result = self._eval([form, result], env) + return result + + # === Binding === + + if op == 'bind': + scan_name = args[0].name if isinstance(args[0], Symbol) else str(args[0]) + if scan_name in self.scans: + state = self.scans[scan_name]['state'] + if len(args) > 1: + key = args[1].name if isinstance(args[1], Keyword) else str(args[1]) + return state.get(key, 0) + return state + return 0 + + # === Arithmetic === + + if op == '+': + return sum(self._eval(a, env) for a in args) + if op == '-': + vals = [self._eval(a, env) for a in args] + return vals[0] - sum(vals[1:]) if len(vals) > 1 else -vals[0] + if op == '*': + result = 1 + for a in args: + result *= self._eval(a, env) + return result + if op == '/': + vals = [self._eval(a, env) for a in args] + return vals[0] / vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + if op == 'mod': + vals = [self._eval(a, env) for a in args] + return vals[0] % vals[1] if len(vals) > 1 and vals[1] != 0 else 0 + + # === Comparison === + + if op == '<': + return self._eval(args[0], env) < self._eval(args[1], env) + if op == '>': + return self._eval(args[0], env) > self._eval(args[1], env) + if op == '=': + return self._eval(args[0], env) == self._eval(args[1], env) + if op == '<=': + return self._eval(args[0], env) <= self._eval(args[1], env) + if op == '>=': + return self._eval(args[0], env) >= self._eval(args[1], env) + + if op == 'and': + for arg in args: + if not self._eval(arg, env): + return False + return True + + if op == 'or': + result = False + for arg in args: + result = self._eval(arg, env) + if result: + return result + return result + + if op == 'not': + return not self._eval(args[0], env) + + # === Logic === + + if op == 'if': + cond = self._eval(args[0], env) + if cond: + return self._eval(args[1], env) + return self._eval(args[2], env) if len(args) > 2 else None + + if op == 'cond': + i = 0 + while i < len(args) - 1: + pred = self._eval(args[i], env) + if pred: + return self._eval(args[i + 1], env) + i += 2 + return None + + if op == 'lambda': + params = args[0] + body = args[1] + param_names = [p.name if isinstance(p, Symbol) else str(p) for p in params] + return {'_type': 'closure', 'params': param_names, 'body': body, 'env': dict(env)} + + if op == 'let' or op == 'let*': + bindings = args[0] + body = args[1] + new_env = dict(env) + + if bindings and isinstance(bindings[0], list): + for binding in bindings: + if isinstance(binding, list) and len(binding) >= 2: + name = binding[0].name if isinstance(binding[0], Symbol) else str(binding[0]) + val = self._eval(binding[1], new_env) + new_env[name] = val + else: + i = 0 + while i < len(bindings): + name = bindings[i].name if isinstance(bindings[i], Symbol) else str(bindings[i]) + val = self._eval(bindings[i + 1], new_env) + new_env[name] = val + i += 2 + return self._eval(body, new_env) + + # === Dict === + + if op == 'dict': + result = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + key = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + result[key] = val + i += 2 + else: + i += 1 + return result + + if op == 'get': + obj = self._eval(args[0], env) + key = args[1].name if isinstance(args[1], Keyword) else self._eval(args[1], env) + if isinstance(obj, dict): + return obj.get(key, 0) + return 0 + + # === List === + + if op == 'list': + return [self._eval(a, env) for a in args] + + if op == 'quote': + return args[0] if args else None + + if op == 'nth': + lst = self._eval(args[0], env) + idx = int(self._eval(args[1], env)) + if isinstance(lst, (list, tuple)) and 0 <= idx < len(lst): + return lst[idx] + return None + + if op == 'len': + val = self._eval(args[0], env) + return len(val) if hasattr(val, '__len__') else 0 + + if op == 'map': + seq = self._eval(args[0], env) + fn = self._eval(args[1], env) + if not isinstance(seq, (list, tuple)): + return [] + # Handle closure (lambda from sexp) + if isinstance(fn, dict) and fn.get('_type') == 'closure': + results = [] + for item in seq: + closure_env = dict(fn['env']) + if fn['params']: + closure_env[fn['params'][0]] = item + results.append(self._eval(fn['body'], closure_env)) + return results + # Handle Python callable + if callable(fn): + return [fn(item) for item in seq] + return [] + + # === Effects === + + if op in self.effects: + effect = self.effects[op] + effect_env = dict(env) + + param_names = list(effect['params'].keys()) + for pname, pdef in effect['params'].items(): + effect_env[pname] = pdef.get('default', 0) + + positional_idx = 0 + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + pname = args[i].name + if pname in effect['params'] and i + 1 < len(args): + effect_env[pname] = self._eval(args[i + 1], env) + i += 2 + else: + val = self._eval(args[i], env) + if positional_idx == 0: + effect_env['frame'] = val + elif positional_idx - 1 < len(param_names): + effect_env[param_names[positional_idx - 1]] = val + positional_idx += 1 + i += 1 + + return self._eval(effect['body'], effect_env) + + # === Primitives === + + if op in self.primitives: + prim_func = self.primitives[op] + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # === Macros (function-like: args evaluated before binding) === + + if op in self.macros: + macro = self.macros[op] + macro_env = dict(env) + for i, pname in enumerate(macro['params']): + # Evaluate args in calling environment before binding + macro_env[pname] = self._eval(args[i], env) if i < len(args) else None + return self._eval(macro['body'], macro_env) + + # Underscore variant lookup + prim_name = op.replace('-', '_') + if prim_name in self.primitives: + prim_func = self.primitives[prim_name] + evaluated_args = [] + kwargs = {} + i = 0 + while i < len(args): + if isinstance(args[i], Keyword): + k = args[i].name.replace('-', '_') + v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + kwargs[k] = v + i += 2 + else: + evaluated_args.append(self._eval(args[i], env)) + i += 1 + + try: + if kwargs: + return prim_func(*evaluated_args, **kwargs) + return prim_func(*evaluated_args) + except Exception as e: + print(f"Primitive {op} error: {e}", file=sys.stderr) + return None + + # Unknown - return as-is + return expr + + def _step_scans(self, ctx: Context, env: dict): + """Step scans based on trigger evaluation.""" + for name, scan in self.scans.items(): + trigger_expr = scan['trigger'] + + # Evaluate trigger in context + should_step = self._eval(trigger_expr, env) + + if should_step: + state = scan['state'] + step_env = dict(state) + step_env.update(env) + + new_state = self._eval(scan['step'], step_env) + if isinstance(new_state, dict): + scan['state'] = new_state + else: + scan['state'] = {'acc': new_state} + + def run(self, duration: float = None, output: str = "pipe"): + """Run the streaming pipeline.""" + # Import output classes - handle both package and direct execution + try: + from .output import PipeOutput, DisplayOutput, FileOutput + except ImportError: + from output import PipeOutput, DisplayOutput, FileOutput + + self._init() + + if not self.frame_pipeline: + print("Error: no (frame ...) pipeline defined", file=sys.stderr) + return + + w = self.config.get('width', 720) + h = self.config.get('height', 720) + fps = self.config.get('fps', 30) + + if duration is None: + # Try to get duration from audio if available + for name, val in self.globals.items(): + if hasattr(val, 'duration'): + duration = val.duration + print(f"Using audio duration: {duration:.1f}s", file=sys.stderr) + break + else: + duration = 60.0 + + n_frames = int(duration * fps) + frame_time = 1.0 / fps + + print(f"Streaming {n_frames} frames @ {fps}fps", file=sys.stderr) + + # Create context + ctx = Context(fps=fps) + + # Output (with optional audio sync) + audio = self.audio_playback + if output == "pipe": + out = PipeOutput(size=(w, h), fps=fps, audio_source=audio) + elif output == "preview": + out = DisplayOutput(size=(w, h), fps=fps, audio_source=audio) + else: + out = FileOutput(output, size=(w, h), fps=fps, audio_source=audio) + + try: + frame_times = [] + for frame_num in range(n_frames): + if not out.is_open: + break + + frame_start = time.time() + ctx.t = frame_num * frame_time + ctx.frame_num = frame_num + + # Build frame environment with context + frame_env = { + 'ctx': { + 't': ctx.t, + 'frame-num': ctx.frame_num, + 'fps': ctx.fps, + }, + 't': ctx.t, # Also expose t directly for convenience + 'frame-num': ctx.frame_num, + } + + # Step scans + self._step_scans(ctx, frame_env) + + # Evaluate pipeline + result = self._eval(self.frame_pipeline, frame_env) + + if result is not None and hasattr(result, 'shape'): + out.write(result, ctx.t) + + frame_elapsed = time.time() - frame_start + frame_times.append(frame_elapsed) + + # Progress with timing + if frame_num % 30 == 0: + pct = 100 * frame_num / n_frames + avg_ms = 1000 * sum(frame_times[-30:]) / max(1, len(frame_times[-30:])) + target_ms = 1000 * frame_time + print(f"\r{pct:5.1f}% [{avg_ms:.0f}ms/frame, target {target_ms:.0f}ms]", end="", file=sys.stderr, flush=True) + + finally: + out.close() + print("\nDone", file=sys.stderr) + + +def run_stream(sexp_path: str, duration: float = None, output: str = "pipe", fps: float = None, + sources_config: str = None, audio_config: str = None): + """Run a streaming sexp.""" + interp = StreamInterpreter(sexp_path) + if fps: + interp.config['fps'] = fps + if sources_config: + interp.sources_config = Path(sources_config) + if audio_config: + interp.audio_config = Path(audio_config) + interp.run(duration=duration, output=output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Run streaming sexp (generic interpreter)") + parser.add_argument("sexp", help="Path to .sexp file") + parser.add_argument("-d", "--duration", type=float, default=None) + parser.add_argument("-o", "--output", default="pipe") + parser.add_argument("--fps", type=float, default=None) + parser.add_argument("--sources", dest="sources_config", help="Path to sources config .sexp file") + parser.add_argument("--audio", dest="audio_config", help="Path to audio config .sexp file") + args = parser.parse_args() + + run_stream(args.sexp, duration=args.duration, output=args.output, fps=args.fps, + sources_config=args.sources_config, audio_config=args.audio_config) diff --git a/templates/crossfade-zoom.sexp b/templates/crossfade-zoom.sexp new file mode 100644 index 0000000..fc6d9ad --- /dev/null +++ b/templates/crossfade-zoom.sexp @@ -0,0 +1,25 @@ +;; Crossfade with Zoom Transition +;; +;; Macro for transitioning between two frames with a zoom effect. +;; Active frame zooms out while next frame zooms in. +;; +;; Required context: +;; - zoom effect must be loaded +;; - blend effect must be loaded +;; +;; Parameters: +;; active-frame: current frame +;; next-frame: frame to transition to +;; fade-amt: transition progress (0 = all active, 1 = all next) +;; +;; Usage: +;; (include :path "../templates/crossfade-zoom.sexp") +;; ... +;; (crossfade-zoom active-frame next-frame 0.5) + +(defmacro crossfade-zoom (active-frame next-frame fade-amt) + (let [active-zoom (+ 1.0 fade-amt) + active-zoomed (zoom active-frame :amount active-zoom) + next-zoom (+ 0.1 (* fade-amt 0.9)) + next-zoomed (zoom next-frame :amount next-zoom)] + (blend active-zoomed next-zoomed :opacity fade-amt))) diff --git a/templates/cycle-crossfade.sexp b/templates/cycle-crossfade.sexp new file mode 100644 index 0000000..40a87ca --- /dev/null +++ b/templates/cycle-crossfade.sexp @@ -0,0 +1,65 @@ +;; cycle-crossfade template +;; +;; Generalized cycling zoom-crossfade for any number of video layers. +;; Cycles through videos with smooth zoom-based crossfade transitions. +;; +;; Parameters: +;; beat-data - beat analysis node (drives timing) +;; input-videos - list of video nodes to cycle through +;; init-clen - initial cycle length in beats +;; +;; NOTE: The parameter is named "input-videos" (not "videos") because +;; template substitution replaces param names everywhere in the AST. +;; The planner's _expand_slice_on injects env["videos"] at plan time, +;; so (len videos) inside the lambda references that injected value. + +(deftemplate cycle-crossfade + (beat-data input-videos init-clen) + + (slice-on beat-data + :videos input-videos + :init {:cycle 0 :beat 0 :clen init-clen} + :fn (lambda [acc i start end] + (let [beat (get acc "beat") + clen (get acc "clen") + active (get acc "cycle") + n (len videos) + phase3 (* beat 3) + wt (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (- 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.0)) + (if (= active prev) + (if (< phase3 clen) 0.0 + (if (< phase3 (* clen 2)) + (* (/ (- phase3 clen) clen) 1.0) + 1.0)) + 0.0)))) + zm (lambda [p] + (let [prev (mod (+ p (- n 1)) n)] + (if (= active p) + ;; Active video: normal -> zoom out during transition -> tiny + (if (< phase3 clen) 1.0 + (if (< phase3 (* clen 2)) + (+ 1.0 (* (/ (- phase3 clen) clen) 1.0)) + 0.1)) + (if (= active prev) + ;; Incoming video: tiny -> zoom in during transition -> normal + (if (< phase3 clen) 0.1 + (if (< phase3 (* clen 2)) + (+ 0.1 (* (/ (- phase3 clen) clen) 0.9)) + 1.0)) + 0.1)))) + new-acc (if (< (+ beat 1) clen) + (dict :cycle active :beat (+ beat 1) :clen clen) + (dict :cycle (mod (+ active 1) n) :beat 0 + :clen (+ 40 (mod (* i 7) 41))))] + {:layers (map (lambda [p] + {:video p :effects [{:effect zoom :amount (zm p)}]}) + (range 0 n)) + :compose {:effect blend_multi :mode "alpha" + :weights (map (lambda [p] (wt p)) (range 0 n))} + :acc new-acc})))) diff --git a/templates/process-pair.sexp b/templates/process-pair.sexp new file mode 100644 index 0000000..6720cd2 --- /dev/null +++ b/templates/process-pair.sexp @@ -0,0 +1,112 @@ +;; process-pair template +;; +;; Reusable video-pair processor: takes a single video source, creates two +;; clips (A and B) with opposite rotations and sporadic effects, blends them, +;; and applies a per-pair slow rotation driven by a beat scan. +;; +;; All sporadic triggers (invert, hue-shift, ascii) and pair-level controls +;; (blend opacity, rotation) are defined internally using seed offsets. +;; +;; Parameters: +;; video - source video node +;; energy - energy analysis node (drives rotation/zoom amounts) +;; beat-data - beat analysis node (drives sporadic triggers) +;; rng - RNG object from (make-rng seed) for auto-derived seeds +;; rot-dir - initial rotation direction: 1 (clockwise) or -1 (anti-clockwise) +;; rot-a/b - rotation ranges for clip A/B (e.g. [0 45]) +;; zoom-a/b - zoom ranges for clip A/B (e.g. [1 1.5]) + +(deftemplate process-pair + (video energy beat-data rng rot-dir rot-a rot-b zoom-a zoom-b) + + ;; --- Sporadic triggers for clip A --- + + ;; Invert: 10% chance per beat, lasts 1-5 beats + (def inv-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + ;; Hue shift: 10% chance, random hue 30-330 deg, lasts 1-5 beats + (def hue-a (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + ;; ASCII art: 5% chance, lasts 1-3 beats + (def ascii-a (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Sporadic triggers for clip B (offset seeds) --- + + (def inv-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0)) + :emit (if (> acc 0) 1 0))) + + (def hue-b (scan beat-data :rng rng :init (dict :rem 0 :hue 0) + :step (if (> rem 0) + (dict :rem (- rem 1) :hue hue) + (if (< (rand) 0.1) + (dict :rem (rand-int 1 5) :hue (rand-range 30 330)) + (dict :rem 0 :hue 0))) + :emit (if (> rem 0) hue 0))) + + (def ascii-b (scan beat-data :rng rng :init 0 + :step (if (> acc 0) (- acc 1) (if (< (rand) 0.05) (rand-int 1 3) 0)) + :emit (if (> acc 0) 1 0))) + + ;; --- Pair-level controls --- + + ;; Internal A/B blend: randomly show A (0), both (0.5), or B (1), every 1-11 beats + (def pair-mix (scan beat-data :rng rng + :init (dict :rem 0 :opacity 0.5) + :step (if (> rem 0) + (dict :rem (- rem 1) :opacity opacity) + (dict :rem (rand-int 1 11) :opacity (* (rand-int 0 2) 0.5))) + :emit opacity)) + + ;; Per-pair rotation: one full rotation every 20-30 beats, alternating direction + (def pair-rot (scan beat-data :rng rng + :init (dict :beat 0 :clen 25 :dir rot-dir :angle 0) + :step (if (< (+ beat 1) clen) + (dict :beat (+ beat 1) :clen clen :dir dir + :angle (+ angle (* dir (/ 360 clen)))) + (dict :beat 0 :clen (rand-int 20 30) :dir (* dir -1) + :angle angle)) + :emit angle)) + + ;; --- Clip A processing --- + (def clip-a (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-a (-> clip-a + (effect rotate :angle (bind energy values :range rot-a)) + (effect zoom :amount (bind energy values :range zoom-a)) + (effect invert :amount (bind inv-a values)) + (effect hue_shift :degrees (bind hue-a values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-a values)) + )) + + ;; --- Clip B processing --- + (def clip-b (-> video (segment :start 0 :duration (bind energy duration)))) + (def rotated-b (-> clip-b + (effect rotate :angle (bind energy values :range rot-b)) + (effect zoom :amount (bind energy values :range zoom-b)) + (effect invert :amount (bind inv-b values)) + (effect hue_shift :degrees (bind hue-b values)) + ;; ASCII disabled - too slow without GPU + ;; (effect ascii_art + ;; :char_size (bind energy values :range [4 32]) + ;; :mix (bind ascii-b values)) + )) + + ;; --- Blend A+B and apply pair rotation --- + (-> rotated-a + (effect blend rotated-b + :mode "alpha" :opacity (bind pair-mix values) :resize_mode "fit") + (effect rotate + :angle (bind pair-rot values)))) diff --git a/templates/scan-oscillating-spin.sexp b/templates/scan-oscillating-spin.sexp new file mode 100644 index 0000000..051f079 --- /dev/null +++ b/templates/scan-oscillating-spin.sexp @@ -0,0 +1,28 @@ +;; Oscillating Spin Scan +;; +;; Accumulates rotation angle on each beat, reversing direction +;; periodically for an oscillating effect. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; +;; Provides scan: spin +;; Bind with: (bind spin :angle) ;; cumulative rotation angle +;; +;; Behavior: +;; - Rotates 14.4 degrees per beat (completes 360 in 25 beats) +;; - After 20-30 beats, reverses direction +;; - Creates a swinging/oscillating rotation effect +;; +;; Usage: +;; (include :path "../templates/scan-oscillating-spin.sexp") +;; +;; In frame: +;; (rotate frame :angle (bind spin :angle)) + +(scan spin (streaming:audio-beat music t) + :init {:angle 0 :dir 1 :left 25} + :step (if (> left 0) + (dict :angle (+ angle (* dir 14.4)) :dir dir :left (- left 1)) + (dict :angle angle :dir (* dir -1) + :left (+ 20 (mod (streaming:audio-beat-count music t) 11))))) diff --git a/templates/scan-ripple-drops.sexp b/templates/scan-ripple-drops.sexp new file mode 100644 index 0000000..7caf720 --- /dev/null +++ b/templates/scan-ripple-drops.sexp @@ -0,0 +1,41 @@ +;; Beat-Triggered Ripple Drops Scan +;; +;; Creates random ripple drops triggered by audio beats. +;; Each drop has a random center position and duration. +;; +;; Required context: +;; - music: audio analyzer from (streaming:make-audio-analyzer ...) +;; - core primitives loaded +;; +;; Provides scan: ripple-state +;; Bind with: (bind ripple-state :gate) ;; 0 or 1 +;; (bind ripple-state :cx) ;; center x (0-1) +;; (bind ripple-state :cy) ;; center y (0-1) +;; +;; Parameters: +;; trigger-chance: probability per beat (default 0.15) +;; min-duration: minimum beats (default 1) +;; max-duration: maximum beats (default 15) +;; +;; Usage: +;; (include :path "../templates/scan-ripple-drops.sexp") +;; ;; Uses default: 15% chance, 1-15 beat duration +;; +;; In frame: +;; (let [rip-gate (bind ripple-state :gate) +;; rip-amp (* rip-gate (core:map-range e 0 1 5 50))] +;; (ripple frame +;; :amplitude rip-amp +;; :center_x (bind ripple-state :cx) +;; :center_y (bind ripple-state :cy))) + +(scan ripple-state (streaming:audio-beat music t) + :init {:gate 0 :cx 0.5 :cy 0.5 :left 0} + :step (if (> left 0) + (dict :gate 1 :cx cx :cy cy :left (- left 1)) + (if (< (core:rand) 0.15) + (dict :gate 1 + :cx (+ 0.2 (* (core:rand) 0.6)) + :cy (+ 0.2 (* (core:rand) 0.6)) + :left (+ 1 (mod (streaming:audio-beat-count music t) 15))) + (dict :gate 0 :cx 0.5 :cy 0.5 :left 0)))) diff --git a/templates/standard-effects.sexp b/templates/standard-effects.sexp new file mode 100644 index 0000000..9e97f34 --- /dev/null +++ b/templates/standard-effects.sexp @@ -0,0 +1,22 @@ +;; Standard Effects Bundle +;; +;; Loads commonly-used video effects. +;; Include after primitives are loaded. +;; +;; Effects provided: +;; - rotate: rotation by angle +;; - zoom: scale in/out +;; - blend: alpha blend two frames +;; - ripple: water ripple distortion +;; - invert: color inversion +;; - hue_shift: hue rotation +;; +;; Usage: +;; (include :path "../templates/standard-effects.sexp") + +(effect rotate :path "../sexp_effects/effects/rotate.sexp") +(effect zoom :path "../sexp_effects/effects/zoom.sexp") +(effect blend :path "../sexp_effects/effects/blend.sexp") +(effect ripple :path "../sexp_effects/effects/ripple.sexp") +(effect invert :path "../sexp_effects/effects/invert.sexp") +(effect hue_shift :path "../sexp_effects/effects/hue_shift.sexp") diff --git a/templates/standard-primitives.sexp b/templates/standard-primitives.sexp new file mode 100644 index 0000000..6e2c62d --- /dev/null +++ b/templates/standard-primitives.sexp @@ -0,0 +1,14 @@ +;; Standard Primitives Bundle +;; +;; Loads all commonly-used primitive libraries. +;; Include this at the top of streaming recipes. +;; +;; Usage: +;; (include :path "../templates/standard-primitives.sexp") + +(require-primitives "geometry") +(require-primitives "core") +(require-primitives "image") +(require-primitives "blending") +(require-primitives "color_ops") +(require-primitives "streaming") diff --git a/templates/stream-process-pair.sexp b/templates/stream-process-pair.sexp new file mode 100644 index 0000000..55f408e --- /dev/null +++ b/templates/stream-process-pair.sexp @@ -0,0 +1,72 @@ +;; stream-process-pair template (streaming-compatible) +;; +;; Macro for processing a video source pair with full effects. +;; Reads source, applies A/B effects (rotate, zoom, invert, hue), blends, +;; and applies pair-level rotation. +;; +;; Required context (must be defined in calling scope): +;; - sources: array of video sources +;; - pair-configs: array of {:dir :rot-a :rot-b :zoom-a :zoom-b} configs +;; - pair-states: array from (bind pairs :states) +;; - now: current time (t) +;; - e: audio energy (0-1) +;; +;; Required effects (must be loaded): +;; - rotate, zoom, invert, hue_shift, blend +;; +;; Usage: +;; (include :path "../templates/stream-process-pair.sexp") +;; ...in frame pipeline... +;; (let [pair-states (bind pairs :states) +;; now t +;; e (streaming:audio-energy music now)] +;; (process-pair 0)) ;; process source at index 0 + +(require-primitives "core") + +(defmacro process-pair (src-idx) + (let [src (nth sources src-idx) + frame (streaming:source-read src now) + cfg (nth pair-configs src-idx) + state (nth pair-states src-idx) + + ;; Get state values (invert uses countdown > 0) + inv-a-active (if (> (get state :inv-a) 0) 1 0) + inv-b-active (if (> (get state :inv-b) 0) 1 0) + ;; Hue is active only when countdown > 0 + hue-a-val (if (> (get state :hue-a) 0) (get state :hue-a-val) 0) + hue-b-val (if (> (get state :hue-b) 0) (get state :hue-b-val) 0) + mix-opacity (get state :mix) + pair-rot-angle (* (get state :angle) (get cfg :dir)) + + ;; Get config values for energy-mapped ranges + rot-a-max (get cfg :rot-a) + rot-b-max (get cfg :rot-b) + zoom-a-max (get cfg :zoom-a) + zoom-b-max (get cfg :zoom-b) + + ;; Energy-driven rotation and zoom + rot-a (core:map-range e 0 1 0 rot-a-max) + rot-b (core:map-range e 0 1 0 rot-b-max) + zoom-a (core:map-range e 0 1 1 zoom-a-max) + zoom-b (core:map-range e 0 1 1 zoom-b-max) + + ;; Apply effects to clip A + clip-a (-> frame + (rotate :angle rot-a) + (zoom :amount zoom-a) + (invert :amount inv-a-active) + (hue_shift :degrees hue-a-val)) + + ;; Apply effects to clip B + clip-b (-> frame + (rotate :angle rot-b) + (zoom :amount zoom-b) + (invert :amount inv-b-active) + (hue_shift :degrees hue-b-val)) + + ;; Blend A+B + blended (blend clip-a clip-b :opacity mix-opacity)] + + ;; Apply pair-level rotation + (rotate blended :angle pair-rot-angle))) diff --git a/test_effects_pipeline.py b/test_effects_pipeline.py new file mode 100644 index 0000000..d1c8870 --- /dev/null +++ b/test_effects_pipeline.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +""" +Test the full effects pipeline: segment -> effect -> output + +This tests that effects can be applied to video segments without +producing "No video stream found" errors. +""" + +import subprocess +import tempfile +import sys +from pathlib import Path + +# Add parent to path +sys.path.insert(0, str(Path(__file__).parent)) + +import numpy as np +from sexp_effects import ( + get_interpreter, + load_effects_dir, + run_effect, + list_effects, +) + + +def create_test_video(path: Path, duration: float = 1.0, size: str = "64x64") -> bool: + """Create a short test video using ffmpeg.""" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", f"testsrc=duration={duration}:size={size}:rate=10", + "-c:v", "libx264", "-preset", "ultrafast", + str(path) + ] + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + print(f"Failed to create test video: {result.stderr.decode()}") + return False + return True + + +def segment_video(input_path: Path, output_path: Path, start: float, duration: float) -> bool: + """Segment a video file.""" + cmd = [ + "ffmpeg", "-y", "-i", str(input_path), + "-ss", str(start), "-t", str(duration), + "-c:v", "libx264", "-preset", "ultrafast", + "-c:a", "aac", + str(output_path) + ] + result = subprocess.run(cmd, capture_output=True) + if result.returncode != 0: + print(f"Failed to segment video: {result.stderr.decode()}") + return False + + # Verify output has video stream + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(output_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + import json + probe_data = json.loads(probe_result.stdout) + + has_video = any( + s.get("codec_type") == "video" + for s in probe_data.get("streams", []) + ) + if not has_video: + print(f"Segment has no video stream!") + return False + + return True + + +def run_effect_on_video(effect_name: str, input_path: Path, output_path: Path) -> bool: + """Run a sexp effect on a video file using frame processing.""" + import json + + # Get video info + probe_cmd = [ + "ffprobe", "-v", "quiet", "-print_format", "json", + "-show_streams", str(input_path) + ] + probe_result = subprocess.run(probe_cmd, capture_output=True, text=True) + probe_data = json.loads(probe_result.stdout) + + video_stream = None + for stream in probe_data.get("streams", []): + if stream.get("codec_type") == "video": + video_stream = stream + break + + if not video_stream: + print(f" Input has no video stream: {input_path}") + return False + + width = int(video_stream["width"]) + height = int(video_stream["height"]) + fps_str = video_stream.get("r_frame_rate", "10/1") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) + else: + fps = float(fps_str) + + # Read frames, process, write + read_cmd = ["ffmpeg", "-i", str(input_path), "-f", "rawvideo", "-pix_fmt", "rgb24", "-"] + write_cmd = [ + "ffmpeg", "-y", + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{width}x{height}", "-r", str(fps), + "-i", "-", + "-c:v", "libx264", "-preset", "ultrafast", + str(output_path) + ] + + read_proc = subprocess.Popen(read_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + write_proc = subprocess.Popen(write_cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE) + + frame_size = width * height * 3 + frame_count = 0 + state = {} + + while True: + frame_data = read_proc.stdout.read(frame_size) + if len(frame_data) < frame_size: + break + + frame = np.frombuffer(frame_data, dtype=np.uint8).reshape((height, width, 3)) + processed, state = run_effect(effect_name, frame, {'_time': frame_count / fps}, state) + write_proc.stdin.write(processed.tobytes()) + frame_count += 1 + + read_proc.stdout.close() + write_proc.stdin.close() + read_proc.wait() + write_proc.wait() + + if write_proc.returncode != 0: + print(f" FFmpeg encode failed: {write_proc.stderr.read().decode()}") + return False + + return frame_count > 0 + + +def test_effect_pipeline(effect_name: str, tmpdir: Path) -> tuple: + """ + Test full pipeline: create video -> segment -> apply effect + + Returns (success, error_message) + """ + # Create test video + source_video = tmpdir / "source.mp4" + if not create_test_video(source_video, duration=1.0, size="64x64"): + return False, "Failed to create source video" + + # Segment it (simulate what the recipe does) + segment_video_path = tmpdir / "segment.mp4" + if not segment_video(source_video, segment_video_path, start=0.2, duration=0.5): + return False, "Failed to segment video" + + # Check segment file exists and has content + if not segment_video_path.exists(): + return False, "Segment file doesn't exist" + if segment_video_path.stat().st_size < 100: + return False, f"Segment file too small: {segment_video_path.stat().st_size} bytes" + + # Apply effect + output_video = tmpdir / "output.mp4" + try: + if not run_effect_on_video(effect_name, segment_video_path, output_video): + return False, "Effect processing failed" + except Exception as e: + return False, str(e) + + # Verify output + if not output_video.exists(): + return False, "Output file doesn't exist" + if output_video.stat().st_size < 100: + return False, f"Output file too small: {output_video.stat().st_size} bytes" + + return True, None + + +def main(): + print("=" * 60) + print("Effects Pipeline Test") + print("=" * 60) + + # Load effects + effects_dir = Path(__file__).parent / "sexp_effects" / "effects" + load_effects_dir(str(effects_dir)) + + effects = list_effects() + print(f"Testing {len(effects)} effects through segment->effect pipeline\n") + + passed = [] + failed = [] + + # Test multi-input effects separately + multi_input_effects = ("blend", "layer") + print("\nTesting multi-input effects...") + from sexp_effects.interpreter import get_interpreter + interp = get_interpreter() + frame_a = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + frame_b = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + for name in multi_input_effects: + try: + interp.global_env.set('frame-a', frame_a.copy()) + interp.global_env.set('frame-b', frame_b.copy()) + interp.global_env.set('frame', frame_a.copy()) + result, state = interp.run_effect(name, frame_a.copy(), {'_time': 0.5}, {}) + if isinstance(result, np.ndarray) and result.shape == frame_a.shape: + passed.append(name) + print(f" {name}: OK") + else: + failed.append((name, f"Bad output shape: {result.shape if hasattr(result, 'shape') else type(result)}")) + print(f" {name}: FAILED - bad shape") + except Exception as e: + failed.append((name, str(e))) + print(f" {name}: FAILED - {e}") + + print("\nTesting single-input effects through pipeline...") + + # Test each effect + for effect_name in sorted(effects): + # Skip multi-input effects (already tested above) + if effect_name in multi_input_effects: + continue + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + success, error = test_effect_pipeline(effect_name, tmpdir) + + if success: + passed.append(effect_name) + print(f" {effect_name}: OK") + else: + failed.append((effect_name, error)) + print(f" {effect_name}: FAILED - {error}") + + print() + print("=" * 60) + print(f"Pipeline test: {len(passed)} passed, {len(failed)} failed") + if failed: + print("\nFailed effects:") + for name, error in failed: + print(f" {name}: {error}") + print("=" * 60) + + return len(failed) == 0 + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) From a13c361dee79c3b9a76cdce3d5818becbcc42880 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:16:23 +0000 Subject: [PATCH 08/24] Configure monorepo build: unified CI, local deps, .dockerignore - Dockerfiles use monorepo root as build context - common/ and core/ installed as local packages (no git+https) - Client tarball built from local client/ dir - Unified CI with change detection: common/core -> rebuild both - Per-repo CI workflows removed Co-Authored-By: Claude Opus 4.6 --- .dockerignore | 8 +++ .gitea/workflows/ci.yml | 114 +++++++++++++++++++++++++++++++++++++ l1/.gitea/workflows/ci.yml | 63 -------------------- l1/Dockerfile | 26 +++++---- l1/Dockerfile.gpu | 9 ++- l1/build-client.sh | 37 ------------ l1/requirements.txt | 5 +- l2/.gitea/workflows/ci.yml | 62 -------------------- l2/Dockerfile | 16 +++--- l2/requirements.txt | 3 +- 10 files changed, 153 insertions(+), 190 deletions(-) create mode 100644 .dockerignore create mode 100644 .gitea/workflows/ci.yml delete mode 100644 l1/.gitea/workflows/ci.yml delete mode 100755 l1/build-client.sh delete mode 100644 l2/.gitea/workflows/ci.yml diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..01c3870 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git +.gitea +**/.env +**/.env.gpu +**/__pycache__ +**/.pytest_cache +**/*.pyc +test/ diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..c2bf29e --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,114 @@ +name: Build and Deploy + +on: + push: + branches: [main] + +env: + REGISTRY: registry.rose-ash.com:5000 + ARTDAG_DIR: /root/art-dag-mono + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install tools + run: | + apt-get update && apt-get install -y --no-install-recommends openssh-client + + - name: Set up SSH + env: + SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + mkdir -p ~/.ssh + echo "$SSH_KEY" > ~/.ssh/id_rsa + chmod 600 ~/.ssh/id_rsa + ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true + + - name: Build and deploy + env: + DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} + run: | + ssh "root@$DEPLOY_HOST" " + cd ${{ env.ARTDAG_DIR }} + + OLD_HEAD=\$(git rev-parse HEAD 2>/dev/null || echo none) + + git fetch origin main + git reset --hard origin/main + + NEW_HEAD=\$(git rev-parse HEAD) + + # Change detection + BUILD_L1=false + BUILD_L2=false + if [ \"\$OLD_HEAD\" = \"none\" ] || [ \"\$OLD_HEAD\" = \"\$NEW_HEAD\" ]; then + BUILD_L1=true + BUILD_L2=true + else + CHANGED=\$(git diff --name-only \$OLD_HEAD \$NEW_HEAD) + # common/ or core/ change -> rebuild both + if echo \"\$CHANGED\" | grep -qE '^(common|core)/'; then + BUILD_L1=true + BUILD_L2=true + fi + if echo \"\$CHANGED\" | grep -q '^l1/'; then + BUILD_L1=true + fi + if echo \"\$CHANGED\" | grep -q '^l2/'; then + BUILD_L2=true + fi + if echo \"\$CHANGED\" | grep -q '^client/'; then + BUILD_L1=true + fi + fi + + # Build L1 + if [ \"\$BUILD_L1\" = true ]; then + echo 'Building L1...' + docker build \ + --build-arg CACHEBUST=\$(date +%s) \ + -f l1/Dockerfile \ + -t ${{ env.REGISTRY }}/celery-l1-server:latest \ + -t ${{ env.REGISTRY }}/celery-l1-server:${{ github.sha }} \ + . + docker push ${{ env.REGISTRY }}/celery-l1-server:latest + docker push ${{ env.REGISTRY }}/celery-l1-server:${{ github.sha }} + else + echo 'Skipping L1 (no changes)' + fi + + # Build L2 + if [ \"\$BUILD_L2\" = true ]; then + echo 'Building L2...' + docker build \ + --build-arg CACHEBUST=\$(date +%s) \ + -f l2/Dockerfile \ + -t ${{ env.REGISTRY }}/l2-server:latest \ + -t ${{ env.REGISTRY }}/l2-server:${{ github.sha }} \ + . + docker push ${{ env.REGISTRY }}/l2-server:latest + docker push ${{ env.REGISTRY }}/l2-server:${{ github.sha }} + else + echo 'Skipping L2 (no changes)' + fi + + # Deploy stacks + if [ \"\$BUILD_L1\" = true ]; then + cd l1 && source .env && docker stack deploy -c docker-compose.yml celery && cd .. + echo 'L1 stack deployed' + fi + if [ \"\$BUILD_L2\" = true ]; then + cd l2 && source .env && docker stack deploy -c docker-compose.yml activitypub && cd .. + echo 'L2 stack deployed' + fi + + sleep 10 + echo '=== L1 Services ===' + docker stack services celery + echo '=== L2 Services ===' + docker stack services activitypub + " diff --git a/l1/.gitea/workflows/ci.yml b/l1/.gitea/workflows/ci.yml deleted file mode 100644 index a79f66e..0000000 --- a/l1/.gitea/workflows/ci.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Build and Deploy - -on: - push: - -env: - REGISTRY: registry.rose-ash.com:5000 - IMAGE_CPU: celery-l1-server - -jobs: - build-and-deploy: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install tools - run: | - apt-get update && apt-get install -y --no-install-recommends openssh-client - - - name: Set up SSH - env: - SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - mkdir -p ~/.ssh - echo "$SSH_KEY" > ~/.ssh/id_rsa - chmod 600 ~/.ssh/id_rsa - ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true - - - name: Pull latest code on server - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - BRANCH: ${{ github.ref_name }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/celery - git fetch origin $BRANCH - git checkout $BRANCH - git reset --hard origin/$BRANCH - " - - - name: Build and push image - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/celery - docker build --build-arg CACHEBUST=\$(date +%s) -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest -t ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} . - docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:latest - docker push ${{ env.REGISTRY }}/${{ env.IMAGE_CPU }}:${{ github.sha }} - " - - - name: Deploy stack - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/celery - docker stack deploy -c docker-compose.yml celery - echo 'Waiting for services to update...' - sleep 10 - docker stack services celery - " diff --git a/l1/Dockerfile b/l1/Dockerfile index 90a770d..25de9a8 100644 --- a/l1/Dockerfile +++ b/l1/Dockerfile @@ -1,25 +1,28 @@ FROM python:3.11-slim - WORKDIR /app -# Install git and ffmpeg (for video transcoding) RUN apt-get update && apt-get install -y --no-install-recommends git ffmpeg && rm -rf /var/lib/apt/lists/* -# Install dependencies -COPY requirements.txt . +# Install common + core as local packages (no more git+https) +COPY common/ /tmp/common/ +COPY core/ /tmp/core/ +RUN pip install --no-cache-dir /tmp/common/ /tmp/core/ && rm -rf /tmp/common /tmp/core + +# Install L1 dependencies +COPY l1/requirements.txt . ARG CACHEBUST=1 RUN pip install --no-cache-dir -r requirements.txt -# Copy application -COPY . . +# Copy L1 application +COPY l1/ . -# Clone effects repo +# Build client tarball from local client/ dir +COPY client/ /tmp/artdag-client/ +RUN cd /tmp && tar -czf /app/artdag-client.tar.gz artdag-client && rm -rf /tmp/artdag-client + +# Clone effects repo (still external) RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects -# Build client tarball for download -RUN ./build-client.sh - -# Create cache directory RUN mkdir -p /data/cache ENV PYTHONUNBUFFERED=1 @@ -27,5 +30,4 @@ ENV PYTHONDONTWRITEBYTECODE=1 ENV EFFECTS_PATH=/app/artdag-effects ENV PYTHONPATH=/app -# Default command runs the server CMD ["python", "server.py"] diff --git a/l1/Dockerfile.gpu b/l1/Dockerfile.gpu index 967f788..8f764da 100644 --- a/l1/Dockerfile.gpu +++ b/l1/Dockerfile.gpu @@ -60,7 +60,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ RUN python3 -m pip install --upgrade pip # Install CPU dependencies first -COPY requirements.txt . +COPY l1/requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Install GPU-specific dependencies (CuPy for CUDA 12.x) @@ -74,11 +74,16 @@ COPY --from=builder /decord-install /usr/local/lib/python3.11/dist-packages/ COPY --from=builder /tmp/decord/build/libdecord.so /usr/local/lib/ RUN ldconfig +# Install monorepo shared packages +COPY common/ /tmp/common/ +COPY core/ /tmp/core/ +RUN pip install --no-cache-dir /tmp/common/ /tmp/core/ && rm -rf /tmp/common /tmp/core + # Clone effects repo (before COPY so it gets cached) RUN git clone https://git.rose-ash.com/art-dag/effects.git /app/artdag-effects # Copy application (this invalidates cache for any code change) -COPY . . +COPY l1/ . # Create cache directory RUN mkdir -p /data/cache diff --git a/l1/build-client.sh b/l1/build-client.sh deleted file mode 100755 index c9443b6..0000000 --- a/l1/build-client.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# Build the artdag-client tarball -# This script is run during deployment to create the downloadable client package - -set -e - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -CLIENT_REPO="https://git.rose-ash.com/art-dag/client.git" -TEMP_DIR=$(mktemp -d) -OUTPUT_FILE="$SCRIPT_DIR/artdag-client.tar.gz" - -echo "Building artdag-client.tar.gz..." - -# Clone the client repo -git clone --depth 1 "$CLIENT_REPO" "$TEMP_DIR/artdag-client" 2>/dev/null || { - echo "Failed to clone client repo, trying alternative..." - # Try GitHub if internal git fails - git clone --depth 1 "https://github.com/gilesbradshaw/art-client.git" "$TEMP_DIR/artdag-client" 2>/dev/null || { - echo "Error: Could not clone client repository" - rm -rf "$TEMP_DIR" - exit 1 - } -} - -# Remove .git directory -rm -rf "$TEMP_DIR/artdag-client/.git" -rm -rf "$TEMP_DIR/artdag-client/__pycache__" - -# Create tarball -cd "$TEMP_DIR" -tar -czf "$OUTPUT_FILE" artdag-client - -# Cleanup -rm -rf "$TEMP_DIR" - -echo "Created: $OUTPUT_FILE" -ls -lh "$OUTPUT_FILE" diff --git a/l1/requirements.txt b/l1/requirements.txt index deab545..be6950f 100644 --- a/l1/requirements.txt +++ b/l1/requirements.txt @@ -13,9 +13,6 @@ markdown>=3.5.0 # Common effect dependencies (used by uploaded effects) numpy>=1.24.0 opencv-python-headless>=4.8.0 -# Core artdag from GitHub (tracks main branch) -git+https://github.com/gilesbradshaw/art-dag.git@main -# Shared components (tracks master branch) -git+https://git.rose-ash.com/art-dag/common.git@master +# core (artdag) and common (artdag_common) installed from local dirs in Dockerfile psycopg2-binary nest_asyncio diff --git a/l2/.gitea/workflows/ci.yml b/l2/.gitea/workflows/ci.yml deleted file mode 100644 index 30d34ea..0000000 --- a/l2/.gitea/workflows/ci.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: Build and Deploy - -on: - push: - branches: [main] - -env: - REGISTRY: registry.rose-ash.com:5000 - IMAGE: l2-server - -jobs: - build-and-deploy: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install tools - run: | - apt-get update && apt-get install -y --no-install-recommends openssh-client - - - name: Set up SSH - env: - SSH_KEY: ${{ secrets.DEPLOY_SSH_KEY }} - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - mkdir -p ~/.ssh - echo "$SSH_KEY" > ~/.ssh/id_rsa - chmod 600 ~/.ssh/id_rsa - ssh-keyscan -H "$DEPLOY_HOST" >> ~/.ssh/known_hosts 2>/dev/null || true - - - name: Pull latest code on server - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/activity-pub - git fetch origin main - git reset --hard origin/main - " - - - name: Build and push image - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/activity-pub - docker build --build-arg CACHEBUST=\$(date +%s) -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:latest -t ${{ env.REGISTRY }}/${{ env.IMAGE }}:${{ github.sha }} . - docker push ${{ env.REGISTRY }}/${{ env.IMAGE }}:latest - docker push ${{ env.REGISTRY }}/${{ env.IMAGE }}:${{ github.sha }} - " - - - name: Deploy stack - env: - DEPLOY_HOST: ${{ secrets.DEPLOY_HOST }} - run: | - ssh "root@$DEPLOY_HOST" " - cd /root/art-dag/activity-pub - docker stack deploy -c docker-compose.yml activitypub - echo 'Waiting for services to update...' - sleep 10 - docker stack services activitypub - " diff --git a/l2/Dockerfile b/l2/Dockerfile index 409aadf..085a695 100644 --- a/l2/Dockerfile +++ b/l2/Dockerfile @@ -1,23 +1,23 @@ FROM python:3.11-slim - WORKDIR /app -# Install git for pip to clone dependencies RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* -# Install dependencies -COPY requirements.txt . +# Install common as local package (no more git+https) +COPY common/ /tmp/common/ +RUN pip install --no-cache-dir /tmp/common/ && rm -rf /tmp/common + +# Install L2 dependencies +COPY l2/requirements.txt . ARG CACHEBUST=1 RUN pip install --no-cache-dir -r requirements.txt -# Copy application -COPY . . +# Copy L2 application +COPY l2/ . -# Create data directory RUN mkdir -p /data/l2 ENV PYTHONUNBUFFERED=1 ENV ARTDAG_DATA=/data/l2 -# Default command runs the server CMD ["python", "server.py"] diff --git a/l2/requirements.txt b/l2/requirements.txt index 94d1e5a..5d228c5 100644 --- a/l2/requirements.txt +++ b/l2/requirements.txt @@ -9,5 +9,4 @@ markdown>=3.5.0 python-multipart>=0.0.6 asyncpg>=0.29.0 boto3>=1.34.0 -# Shared components -git+https://git.rose-ash.com/art-dag/common.git@889ea98 +# common (artdag_common) installed from local dir in Dockerfile From e58def135d142da03fcbd0783cfc84b484b55254 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:26:39 +0000 Subject: [PATCH 09/24] Add deploy.sh and zap.sh scripts for manual deploys Ported from old art-dag root, updated for monorepo paths. Co-Authored-By: Claude Opus 4.6 --- deploy.sh | 31 +++++++++++++++++++++++++++++++ zap.sh | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100755 deploy.sh create mode 100755 zap.sh diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000..6086ed2 --- /dev/null +++ b/deploy.sh @@ -0,0 +1,31 @@ +set -e + +cd "$(dirname "$0")" + +echo "=== Building L1 ===" +docker build --build-arg CACHEBUST=$(date +%s) -f l1/Dockerfile -t registry.rose-ash.com:5000/celery-l1-server:latest . +docker push registry.rose-ash.com:5000/celery-l1-server:latest + +echo "=== Building L2 ===" +docker build --build-arg CACHEBUST=$(date +%s) -f l2/Dockerfile -t registry.rose-ash.com:5000/l2-server:latest . +docker push registry.rose-ash.com:5000/l2-server:latest + +echo "=== Deploying celery stack ===" +cd l1 && source .env && docker stack deploy -c docker-compose.yml celery && cd .. + +echo "=== Deploying activitypub stack ===" +cd l2 && source .env && docker stack deploy -c docker-compose.yml activitypub && cd .. + +sleep 30 + +docker service update --force celery_l1-worker +docker service update --force celery_l1-server +docker service update --force celery_flower +docker service update --force celery_ipfs +docker stack services celery + +docker service update --force activitypub_l2-server +docker stack services activitypub + +echo "=== Restarting proxy nginx ===" +docker service update --force proxy_nginx diff --git a/zap.sh b/zap.sh new file mode 100755 index 0000000..7e32bb2 --- /dev/null +++ b/zap.sh @@ -0,0 +1,43 @@ +set -e + +cd "$(dirname "$0")" + +echo "=== Building L1 ===" +docker build --build-arg CACHEBUST=$(date +%s) -f l1/Dockerfile -t registry.rose-ash.com:5000/celery-l1-server:latest . + +echo "=== Building L2 ===" +docker build --build-arg CACHEBUST=$(date +%s) -f l2/Dockerfile -t registry.rose-ash.com:5000/l2-server:latest . + +echo "=== Removing stacks ===" +docker stack rm celery +docker stack rm activitypub + +sleep 30 + +echo "=== Removing volumes ===" +docker volume rm activitypub_l2_data +docker volume rm activitypub_postgres_data +docker volume rm activitypub_ipfs_data +docker volume rm celery_l1_cache +docker volume rm celery_redis_data +docker volume rm celery_ipfs_data +docker volume rm celery_postgres_data + +echo "=== Redeploying celery stack ===" +cd l1 && source .env && docker stack deploy -c docker-compose.yml celery && cd .. + +echo "=== Redeploying activitypub stack ===" +cd l2 && source .env && docker stack deploy -c docker-compose.yml activitypub && cd .. + +sleep 30 + +docker service update --force celery_l1-worker +docker service update --force celery_l1-server +docker service update --force celery_flower +docker stack services celery + +docker service update --force activitypub_l2-server +docker stack services activitypub + +echo "=== Restarting proxy nginx ===" +docker service update --force proxy_nginx From e7610bed7c845d63bfd370e7ec7725d2ba30c4c9 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:34:56 +0000 Subject: [PATCH 10/24] Dark content area beneath coop header Wrap content block in bg-dark-800 so all existing dark-themed templates render correctly without per-file migration. Co-Authored-By: Claude Opus 4.6 --- common/artdag_common/templates/_base.html | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/artdag_common/templates/_base.html b/common/artdag_common/templates/_base.html index deeb67b..77e5ef8 100644 --- a/common/artdag_common/templates/_base.html +++ b/common/artdag_common/templates/_base.html @@ -81,8 +81,10 @@ {# App-specific sub-nav (Runs, Recipes, Effects, etc.) #} {% block sub_nav %}{% endblock %} -
+
+
{% block content %}{% endblock %} +
{% block footer %}{% endblock %} From 1b4e51c48c6aff43d9f073d13a229d14a3d52956 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:38:14 +0000 Subject: [PATCH 11/24] Add max-width gutters to match coop layout Wrap page in max-w-screen-2xl mx-auto py-1 px-1 like blog. Co-Authored-By: Claude Opus 4.6 --- common/artdag_common/templates/_base.html | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/artdag_common/templates/_base.html b/common/artdag_common/templates/_base.html index 77e5ef8..4bdc3c6 100644 --- a/common/artdag_common/templates/_base.html +++ b/common/artdag_common/templates/_base.html @@ -49,6 +49,7 @@ +
{% block header %} {# Coop-style header: sky banner with title, nav-tree, auth-menu, cart-mini #}
@@ -88,6 +89,7 @@
{% block footer %}{% endblock %} + {% block scripts %}{% endblock %} From a5717ec4d4da9ef0ad3be7c36014a1e29b9a188f Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:39:19 +0000 Subject: [PATCH 12/24] Fall back to username for auth-menu email param Existing sessions have email=None since the field was just added. Username IS the email in Art-DAG (OAuth returns user.email as username). Co-Authored-By: Claude Opus 4.6 --- l1/app/__init__.py | 2 +- l2/app/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/l1/app/__init__.py b/l1/app/__init__.py index 408983b..3b945b3 100644 --- a/l1/app/__init__.py +++ b/l1/app/__init__.py @@ -170,7 +170,7 @@ def create_app() -> FastAPI: from artdag_common.fragments import fetch_fragments as _fetch_frags user = get_user_from_cookie(request) - auth_params = {"email": user.email} if user and user.email else {} + auth_params = {"email": user.email or user.username} if user else {} nav_params = {"app_name": "artdag", "path": path} try: diff --git a/l2/app/__init__.py b/l2/app/__init__.py index 1062a13..add533b 100644 --- a/l2/app/__init__.py +++ b/l2/app/__init__.py @@ -58,7 +58,7 @@ def create_app() -> FastAPI: from artdag_common.fragments import fetch_fragments as _fetch_frags user = get_user_from_cookie(request) - auth_params = {"email": user.email} if user and user.email else {} + auth_params = {"email": user.email or user.username} if user else {} nav_params = {"app_name": "artdag", "path": path} try: From d8206c7b3ba7312777913aea6fec42e25065ad28 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:45:57 +0000 Subject: [PATCH 13/24] Fix gutter width: close header wrapper before dark main area The max-w-screen-2xl wrapper now only constrains the header/nav, matching blog layout. Dark content area goes full-width with its own inner max-w constraint. Co-Authored-By: Claude Opus 4.6 --- common/artdag_common/templates/_base.html | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/artdag_common/templates/_base.html b/common/artdag_common/templates/_base.html index 4bdc3c6..d32b572 100644 --- a/common/artdag_common/templates/_base.html +++ b/common/artdag_common/templates/_base.html @@ -82,6 +82,8 @@ {# App-specific sub-nav (Runs, Recipes, Effects, etc.) #} {% block sub_nav %}{% endblock %} + {# close max-w-screen-2xl wrapper #} +
{% block content %}{% endblock %} @@ -89,7 +91,6 @@
{% block footer %}{% endblock %} - {% block scripts %}{% endblock %} From 3dde4e79ab9703f07c8275751afdac3adb7881b6 Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:50:31 +0000 Subject: [PATCH 14/24] Add OAuth SSO, device ID, and silent auth to L2 - Replace L2's username/password auth with OAuth SSO via account.rose-ash.com - Add device_id middleware (artdag_did cookie) - Add silent auth check (prompt=none with 5-min cooldown) - Add OAuth config settings and itsdangerous dependency Co-Authored-By: Claude Opus 4.6 --- l2/app/__init__.py | 82 ++++++++++- l2/app/config.py | 8 ++ l2/app/routers/auth.py | 319 +++++++++++++++++------------------------ l2/docker-compose.yml | 8 +- l2/requirements.txt | 1 + 5 files changed, 227 insertions(+), 191 deletions(-) diff --git a/l2/app/__init__.py b/l2/app/__init__.py index add533b..e4938e2 100644 --- a/l2/app/__init__.py +++ b/l2/app/__init__.py @@ -4,16 +4,38 @@ Art-DAG L2 Server Application Factory. Creates and configures the FastAPI application with all routers and middleware. """ +import secrets +import time from pathlib import Path from contextlib import asynccontextmanager +from urllib.parse import quote + from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, HTMLResponse +from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse from artdag_common import create_jinja_env from artdag_common.middleware.auth import get_user_from_cookie from .config import settings +# Paths that should never trigger a silent auth check +_SKIP_PREFIXES = ("/auth/", "/.well-known/", "/health", + "/internal/", "/static/", "/inbox") +_SILENT_CHECK_COOLDOWN = 300 # 5 minutes +_DEVICE_COOKIE = "artdag_did" +_DEVICE_COOKIE_MAX_AGE = 30 * 24 * 3600 # 30 days + +# Derive external base URL from oauth_redirect_uri (e.g. https://artdag.rose-ash.com) +_EXTERNAL_BASE = settings.oauth_redirect_uri.rsplit("/auth/callback", 1)[0] + + +def _external_url(request: Request) -> str: + """Build external URL from request path + query, using configured base domain.""" + url = f"{_EXTERNAL_BASE}{request.url.path}" + if request.url.query: + url += f"?{request.url.query}" + return url + @asynccontextmanager async def lifespan(app: FastAPI): @@ -38,6 +60,64 @@ def create_app() -> FastAPI: lifespan=lifespan, ) + # Silent auth check — auto-login via prompt=none OAuth + # NOTE: registered BEFORE device_id so device_id is outermost (runs first) + @app.middleware("http") + async def silent_auth_check(request: Request, call_next): + path = request.url.path + if ( + request.method != "GET" + or any(path.startswith(p) for p in _SKIP_PREFIXES) + or request.headers.get("hx-request") # skip HTMX + ): + return await call_next(request) + + # Already logged in — pass through + if get_user_from_cookie(request): + return await call_next(request) + + # Check cooldown — don't re-check within 5 minutes + pnone_at = request.cookies.get("pnone_at") + if pnone_at: + try: + pnone_ts = float(pnone_at) + if (time.time() - pnone_ts) < _SILENT_CHECK_COOLDOWN: + return await call_next(request) + except (ValueError, TypeError): + pass + + # Redirect to silent OAuth check + current_url = _external_url(request) + return RedirectResponse( + url=f"/auth/login?prompt=none&next={quote(current_url, safe='')}", + status_code=302, + ) + + # Device ID middleware — track browser identity across domains + # Registered AFTER silent_auth_check so it's outermost (always runs) + @app.middleware("http") + async def device_id_middleware(request: Request, call_next): + did = request.cookies.get(_DEVICE_COOKIE) + if did: + request.state.device_id = did + request.state._new_device_id = False + else: + request.state.device_id = secrets.token_urlsafe(32) + request.state._new_device_id = True + + response = await call_next(request) + + if getattr(request.state, "_new_device_id", False): + response.set_cookie( + key=_DEVICE_COOKIE, + value=request.state.device_id, + max_age=_DEVICE_COOKIE_MAX_AGE, + httponly=True, + samesite="lax", + secure=True, + ) + return response + # Coop fragment pre-fetch — inject nav-tree, auth-menu, cart-mini _FRAG_SKIP = ("/auth/", "/.well-known/", "/health", "/internal/", "/static/", "/inbox") diff --git a/l2/app/config.py b/l2/app/config.py index d88d435..d08b3ed 100644 --- a/l2/app/config.py +++ b/l2/app/config.py @@ -33,6 +33,14 @@ class Settings: jwt_algorithm: str = "HS256" access_token_expire_minutes: int = 60 * 24 * 30 # 30 days + # OAuth SSO (via account.rose-ash.com) + oauth_authorize_url: str = os.environ.get("OAUTH_AUTHORIZE_URL", "https://account.rose-ash.com/auth/oauth/authorize") + oauth_token_url: str = os.environ.get("OAUTH_TOKEN_URL", "https://account.rose-ash.com/auth/oauth/token") + oauth_client_id: str = os.environ.get("OAUTH_CLIENT_ID", "artdag_l2") + oauth_redirect_uri: str = os.environ.get("OAUTH_REDIRECT_URI", "https://artdag.rose-ash.com/auth/callback") + oauth_logout_url: str = os.environ.get("OAUTH_LOGOUT_URL", "https://account.rose-ash.com/auth/sso-logout/") + secret_key: str = os.environ.get("SECRET_KEY", "change-me-in-production") + def __post_init__(self): # Parse L1 servers l1_str = os.environ.get("L1_SERVERS", "https://celery-artdag.rose-ash.com") diff --git a/l2/app/routers/auth.py b/l2/app/routers/auth.py index 4691caf..98fea5a 100644 --- a/l2/app/routers/auth.py +++ b/l2/app/routers/auth.py @@ -1,223 +1,164 @@ """ -Authentication routes for L2 server. +Authentication routes — OAuth2 authorization code flow via account.rose-ash.com. -Handles login, registration, logout, and token verification. +GET /auth/login — redirect to account OAuth authorize +GET /auth/callback — exchange code for user info, set session cookie +GET /auth/logout — clear cookie, redirect through account SSO logout """ -import hashlib -from datetime import datetime, timezone +import secrets +import time -from fastapi import APIRouter, Request, Form, HTTPException, Depends -from fastapi.responses import HTMLResponse, RedirectResponse -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +import httpx +from fastapi import APIRouter, Request +from fastapi.responses import RedirectResponse +from itsdangerous import URLSafeSerializer -from artdag_common import render -from artdag_common.middleware import wants_html +from artdag_common.middleware.auth import UserContext, set_auth_cookie, clear_auth_cookie from ..config import settings -from ..dependencies import get_templates, get_user_from_cookie router = APIRouter() -security = HTTPBearer(auto_error=False) + +_signer = None -@router.get("/login", response_class=HTMLResponse) -async def login_page(request: Request, return_to: str = None): - """Login page.""" - username = get_user_from_cookie(request) +def _get_signer() -> URLSafeSerializer: + global _signer + if _signer is None: + _signer = URLSafeSerializer(settings.secret_key, salt="oauth-state") + return _signer - if username: - templates = get_templates(request) - return render(templates, "auth/already_logged_in.html", request, - user={"username": username}, - ) - templates = get_templates(request) - return render(templates, "auth/login.html", request, - return_to=return_to, +@router.get("/login") +async def login(request: Request): + """Store state + next in signed cookie, redirect to account OAuth authorize.""" + next_url = request.query_params.get("next", "/") + prompt = request.query_params.get("prompt", "") + state = secrets.token_urlsafe(32) + + signer = _get_signer() + state_payload = signer.dumps({"state": state, "next": next_url, "prompt": prompt}) + + device_id = getattr(request.state, "device_id", "") + authorize_url = ( + f"{settings.oauth_authorize_url}" + f"?client_id={settings.oauth_client_id}" + f"&redirect_uri={settings.oauth_redirect_uri}" + f"&device_id={device_id}" + f"&state={state}" ) + if prompt: + authorize_url += f"&prompt={prompt}" - -@router.post("/login", response_class=HTMLResponse) -async def login_submit( - request: Request, - username: str = Form(...), - password: str = Form(...), - return_to: str = Form(None), -): - """Handle login form submission.""" - from auth import authenticate_user, create_access_token - - if not username or not password: - return HTMLResponse( - '
Username and password are required
' - ) - - user = await authenticate_user(settings.data_dir, username.strip(), password) - if not user: - return HTMLResponse( - '
Invalid username or password
' - ) - - token = create_access_token(user.username, l2_server=f"https://{settings.domain}") - - # Handle return_to redirect - if return_to and return_to.startswith("http"): - separator = "&" if "?" in return_to else "?" - redirect_url = f"{return_to}{separator}auth_token={token.access_token}" - response = HTMLResponse(f''' -
Login successful! Redirecting...
- - ''') - else: - response = HTMLResponse(''' -
Login successful! Redirecting...
- - ''') - + response = RedirectResponse(url=authorize_url, status_code=302) response.set_cookie( - key="auth_token", - value=token.access_token, + key="oauth_state", + value=state_payload, + max_age=600, # 10 minutes httponly=True, - max_age=60 * 60 * 24 * 30, samesite="lax", secure=True, ) return response -@router.get("/register", response_class=HTMLResponse) -async def register_page(request: Request): - """Registration page.""" - username = get_user_from_cookie(request) +@router.get("/callback") +async def callback(request: Request): + """Validate state, exchange code via token endpoint, set session cookie.""" + code = request.query_params.get("code", "") + state = request.query_params.get("state", "") + error = request.query_params.get("error", "") + account_did = request.query_params.get("account_did", "") - if username: - templates = get_templates(request) - return render(templates, "auth/already_logged_in.html", request, - user={"username": username}, - ) - - templates = get_templates(request) - return render(templates, "auth/register.html", request) - - -@router.post("/register", response_class=HTMLResponse) -async def register_submit( - request: Request, - username: str = Form(...), - password: str = Form(...), - password2: str = Form(...), - email: str = Form(None), -): - """Handle registration form submission.""" - from auth import create_user, create_access_token - - if not username or not password: - return HTMLResponse('
Username and password are required
') - - if password != password2: - return HTMLResponse('
Passwords do not match
') - - if len(password) < 6: - return HTMLResponse('
Password must be at least 6 characters
') + # Adopt account's device ID as our own (one identity across all apps) + if account_did: + request.state.device_id = account_did + request.state._new_device_id = True # device_id middleware will set cookie + # Recover state from signed cookie + state_cookie = request.cookies.get("oauth_state", "") + signer = _get_signer() try: - user = await create_user(settings.data_dir, username.strip(), password, email) - except ValueError as e: - return HTMLResponse(f'
{str(e)}
') + payload = signer.loads(state_cookie) if state_cookie else {} + except Exception: + payload = {} - token = create_access_token(user.username, l2_server=f"https://{settings.domain}") + next_url = payload.get("next", "/") - response = HTMLResponse(''' -
Registration successful! Redirecting...
- - ''') - response.set_cookie( - key="auth_token", - value=token.access_token, - httponly=True, - max_age=60 * 60 * 24 * 30, - samesite="lax", - secure=True, - ) + # Handle prompt=none rejection (user not logged in on account) + if error == "login_required": + response = RedirectResponse(url=next_url, status_code=302) + response.delete_cookie("oauth_state") + # Set cooldown cookie — don't re-check for 5 minutes + response.set_cookie( + key="pnone_at", + value=str(time.time()), + max_age=300, + httponly=True, + samesite="lax", + secure=True, + ) + # Set device cookie if adopted + if account_did: + response.set_cookie( + key="artdag_did", + value=account_did, + max_age=30 * 24 * 3600, + httponly=True, + samesite="lax", + secure=True, + ) + return response + + # Normal callback — validate state + code + if not state_cookie or not code or not state: + return RedirectResponse(url="/", status_code=302) + + if payload.get("state") != state: + return RedirectResponse(url="/", status_code=302) + + # Exchange code for user info via account's token endpoint + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + settings.oauth_token_url, + json={ + "code": code, + "client_id": settings.oauth_client_id, + "redirect_uri": settings.oauth_redirect_uri, + }, + ) + except httpx.HTTPError: + return RedirectResponse(url="/", status_code=302) + + if resp.status_code != 200: + return RedirectResponse(url="/", status_code=302) + + data = resp.json() + if "error" in data: + return RedirectResponse(url="/", status_code=302) + + # Map OAuth response to artdag UserContext + # Note: account token endpoint returns user.email as "username" + username = data.get("username", "") + email = username # OAuth response "username" is the user's email + actor_id = f"@{username}" + + user = UserContext(username=username, actor_id=actor_id, email=email) + + response = RedirectResponse(url=next_url, status_code=302) + set_auth_cookie(response, user) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") return response @router.get("/logout") -async def logout(request: Request): - """Handle logout.""" - import db - import requests - from auth import get_token_claims - - token = request.cookies.get("auth_token") - claims = get_token_claims(token) if token else None - username = claims.get("sub") if claims else None - - if username and token and claims: - # Revoke token in database - token_hash = hashlib.sha256(token.encode()).hexdigest() - expires_at = datetime.fromtimestamp(claims.get("exp", 0), tz=timezone.utc) - await db.revoke_token(token_hash, username, expires_at) - - # Revoke on attached L1 servers - attached = await db.get_user_renderers(username) - for l1_url in attached: - try: - requests.post( - f"{l1_url}/auth/revoke-user", - json={"username": username, "l2_server": f"https://{settings.domain}"}, - timeout=5, - ) - except Exception: - pass - - response = RedirectResponse(url="/", status_code=302) - response.delete_cookie("auth_token") +async def logout(): + """Clear session cookie, redirect through account SSO logout.""" + response = RedirectResponse(url=settings.oauth_logout_url, status_code=302) + clear_auth_cookie(response) + response.delete_cookie("oauth_state") + response.delete_cookie("pnone_at") return response - - -@router.get("/verify") -async def verify_token( - request: Request, - credentials: HTTPAuthorizationCredentials = Depends(security), -): - """ - Verify a token is valid. - - Called by L1 servers to verify tokens during auth callback. - Returns user info if valid, 401 if not. - """ - import db - from auth import verify_token as verify_jwt, get_token_claims - - # Get token from Authorization header or query param - token = None - if credentials: - token = credentials.credentials - else: - # Try Authorization header manually (for clients that don't use Bearer format) - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - token = auth_header[7:] - - if not token: - raise HTTPException(401, "No token provided") - - # Verify JWT signature and expiry - username = verify_jwt(token) - if not username: - raise HTTPException(401, "Invalid or expired token") - - # Check if token is revoked - claims = get_token_claims(token) - if claims: - token_hash = hashlib.sha256(token.encode()).hexdigest() - if await db.is_token_revoked(token_hash): - raise HTTPException(401, "Token has been revoked") - - return { - "valid": True, - "username": username, - "claims": claims, - } diff --git a/l2/docker-compose.yml b/l2/docker-compose.yml index 0f67e81..9c91ea4 100644 --- a/l2/docker-compose.yml +++ b/l2/docker-compose.yml @@ -52,7 +52,13 @@ services: - INTERNAL_URL_BLOG=http://blog:8000 - INTERNAL_URL_CART=http://cart:8000 - INTERNAL_URL_ACCOUNT=http://account:8000 - # DATABASE_URL, ARTDAG_DOMAIN, ARTDAG_USER, JWT_SECRET from .env file + # OAuth SSO + - OAUTH_AUTHORIZE_URL=https://account.rose-ash.com/auth/oauth/authorize + - OAUTH_TOKEN_URL=https://account.rose-ash.com/auth/oauth/token + - OAUTH_CLIENT_ID=artdag_l2 + - OAUTH_REDIRECT_URI=https://artdag.rose-ash.com/auth/callback + - OAUTH_LOGOUT_URL=https://account.rose-ash.com/auth/sso-logout/ + # DATABASE_URL, ARTDAG_DOMAIN, ARTDAG_USER, JWT_SECRET, SECRET_KEY from .env file healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8200/')"] interval: 10s diff --git a/l2/requirements.txt b/l2/requirements.txt index 5d228c5..83b95d2 100644 --- a/l2/requirements.txt +++ b/l2/requirements.txt @@ -7,6 +7,7 @@ bcrypt>=4.0.0 python-jose[cryptography]>=3.3.0 markdown>=3.5.0 python-multipart>=0.0.6 +itsdangerous>=2.1.0 asyncpg>=0.29.0 boto3>=1.34.0 # common (artdag_common) installed from local dir in Dockerfile From b45a2b6c10d98f2f5a9a7511519b10c5d94975cf Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 01:20:41 +0000 Subject: [PATCH 15/24] Fix OAuth token exchange: use internal URL, add error logging The server-to-server token exchange was hitting the external URL (https://account.rose-ash.com/...) which can fail from inside Docker due to DNS/hairpin NAT. Now uses INTERNAL_URL_ACCOUNT (already set in both docker-compose files) for the POST. Adds logging at all three failure points so silent redirects are diagnosable. Co-Authored-By: Claude Opus 4.6 --- l1/app/config.py | 5 +++++ l1/app/routers/auth.py | 14 ++++++++++++-- l2/app/config.py | 3 +++ l2/app/routers/auth.py | 14 ++++++++++++-- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/l1/app/config.py b/l1/app/config.py index 8aa94d7..6e4c005 100644 --- a/l1/app/config.py +++ b/l1/app/config.py @@ -64,6 +64,11 @@ class Settings: default_factory=lambda: os.environ.get("SECRET_KEY", "change-me-in-production") ) + # Internal account URL for server-to-server token exchange (avoids external DNS/TLS) + internal_account_url: str = field( + default_factory=lambda: os.environ.get("INTERNAL_URL_ACCOUNT", "") + ) + # GPU/Streaming settings streaming_gpu_persist: bool = field( default_factory=lambda: os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" diff --git a/l1/app/routers/auth.py b/l1/app/routers/auth.py index c447f3d..3f85658 100644 --- a/l1/app/routers/auth.py +++ b/l1/app/routers/auth.py @@ -6,6 +6,7 @@ GET /auth/callback — exchange code for user info, set session cookie GET /auth/logout — clear cookie, redirect through account SSO logout """ +import logging import secrets import time @@ -18,6 +19,7 @@ from artdag_common.middleware.auth import UserContext, set_auth_cookie, clear_au from ..config import settings +logger = logging.getLogger(__name__) router = APIRouter() _signer = None @@ -119,24 +121,32 @@ async def callback(request: Request): return RedirectResponse(url="/", status_code=302) # Exchange code for user info via account's token endpoint + # Prefer internal URL (Docker overlay) to avoid external DNS/TLS issues + token_url = settings.oauth_token_url + if settings.internal_account_url: + token_url = f"{settings.internal_account_url.rstrip('/')}/auth/oauth/token" + async with httpx.AsyncClient(timeout=10) as client: try: resp = await client.post( - settings.oauth_token_url, + token_url, json={ "code": code, "client_id": settings.oauth_client_id, "redirect_uri": settings.oauth_redirect_uri, }, ) - except httpx.HTTPError: + except httpx.HTTPError as exc: + logger.error("OAuth token exchange failed: %s %s", type(exc).__name__, exc) return RedirectResponse(url="/", status_code=302) if resp.status_code != 200: + logger.error("OAuth token exchange returned %s: %s", resp.status_code, resp.text[:200]) return RedirectResponse(url="/", status_code=302) data = resp.json() if "error" in data: + logger.error("OAuth token exchange error: %s", data["error"]) return RedirectResponse(url="/", status_code=302) # Map OAuth response to artdag UserContext diff --git a/l2/app/config.py b/l2/app/config.py index d08b3ed..d2c1437 100644 --- a/l2/app/config.py +++ b/l2/app/config.py @@ -41,6 +41,9 @@ class Settings: oauth_logout_url: str = os.environ.get("OAUTH_LOGOUT_URL", "https://account.rose-ash.com/auth/sso-logout/") secret_key: str = os.environ.get("SECRET_KEY", "change-me-in-production") + # Internal account URL for server-to-server token exchange (avoids external DNS/TLS) + internal_account_url: str = os.environ.get("INTERNAL_URL_ACCOUNT", "") + def __post_init__(self): # Parse L1 servers l1_str = os.environ.get("L1_SERVERS", "https://celery-artdag.rose-ash.com") diff --git a/l2/app/routers/auth.py b/l2/app/routers/auth.py index 98fea5a..be7715c 100644 --- a/l2/app/routers/auth.py +++ b/l2/app/routers/auth.py @@ -6,6 +6,7 @@ GET /auth/callback — exchange code for user info, set session cookie GET /auth/logout — clear cookie, redirect through account SSO logout """ +import logging import secrets import time @@ -18,6 +19,7 @@ from artdag_common.middleware.auth import UserContext, set_auth_cookie, clear_au from ..config import settings +logger = logging.getLogger(__name__) router = APIRouter() _signer = None @@ -119,24 +121,32 @@ async def callback(request: Request): return RedirectResponse(url="/", status_code=302) # Exchange code for user info via account's token endpoint + # Prefer internal URL (Docker overlay) to avoid external DNS/TLS issues + token_url = settings.oauth_token_url + if settings.internal_account_url: + token_url = f"{settings.internal_account_url.rstrip('/')}/auth/oauth/token" + async with httpx.AsyncClient(timeout=10) as client: try: resp = await client.post( - settings.oauth_token_url, + token_url, json={ "code": code, "client_id": settings.oauth_client_id, "redirect_uri": settings.oauth_redirect_uri, }, ) - except httpx.HTTPError: + except httpx.HTTPError as exc: + logger.error("OAuth token exchange failed: %s %s", type(exc).__name__, exc) return RedirectResponse(url="/", status_code=302) if resp.status_code != 200: + logger.error("OAuth token exchange returned %s: %s", resp.status_code, resp.text[:200]) return RedirectResponse(url="/", status_code=302) data = resp.json() if "error" in data: + logger.error("OAuth token exchange error: %s", data["error"]) return RedirectResponse(url="/", status_code=302) # Map OAuth response to artdag UserContext From 0e14d2761a0ca4296548104099735fb9c0d53ef8 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 01:35:11 +0000 Subject: [PATCH 16/24] Fix L2 deployment: healthcheck, DB deadlock, CI image resolution - Add /health endpoint (returns 200, skips auth middleware) - Healthcheck now hits /health instead of / (which 302s to OAuth) - Advisory lock in db.init_pool() prevents deadlock when 4 uvicorn workers race to run schema DDL - CI: --resolve-image always on docker stack deploy to force re-pull Co-Authored-By: Claude Opus 4.6 --- .gitea/workflows/ci.yml | 6 +++--- l2/app/__init__.py | 5 +++++ l2/db.py | 11 +++++++++-- l2/docker-compose.yml | 2 +- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index c2bf29e..d0fa8a7 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -96,13 +96,13 @@ jobs: echo 'Skipping L2 (no changes)' fi - # Deploy stacks + # Deploy stacks (--resolve-image always forces re-pull of :latest) if [ \"\$BUILD_L1\" = true ]; then - cd l1 && source .env && docker stack deploy -c docker-compose.yml celery && cd .. + cd l1 && source .env && docker stack deploy --resolve-image always -c docker-compose.yml celery && cd .. echo 'L1 stack deployed' fi if [ \"\$BUILD_L2\" = true ]; then - cd l2 && source .env && docker stack deploy -c docker-compose.yml activitypub && cd .. + cd l2 && source .env && docker stack deploy --resolve-image always -c docker-compose.yml activitypub && cd .. echo 'L2 stack deployed' fi diff --git a/l2/app/__init__.py b/l2/app/__init__.py index e4938e2..fbc713a 100644 --- a/l2/app/__init__.py +++ b/l2/app/__init__.py @@ -160,6 +160,11 @@ def create_app() -> FastAPI: template_dir = Path(__file__).parent / "templates" app.state.templates = create_jinja_env(template_dir) + # Health check (skips auth middleware via _SKIP_PREFIXES) + @app.get("/health") + async def health(): + return JSONResponse({"status": "ok"}) + # Custom 404 handler @app.exception_handler(404) async def not_found_handler(request: Request, exc): diff --git a/l2/db.py b/l2/db.py index 205271d..465826c 100644 --- a/l2/db.py +++ b/l2/db.py @@ -187,9 +187,16 @@ async def init_pool(): max_size=10, command_timeout=60 ) - # Create tables if they don't exist + # Create tables if they don't exist (advisory lock prevents deadlock + # when multiple uvicorn workers start simultaneously) async with _pool.acquire() as conn: - await conn.execute(SCHEMA) + acquired = await conn.fetchval("SELECT pg_try_advisory_lock(42)") + if acquired: + try: + await conn.execute(SCHEMA) + finally: + await conn.execute("SELECT pg_advisory_unlock(42)") + # If another worker holds the lock, schema is being created — skip async def close_pool(): diff --git a/l2/docker-compose.yml b/l2/docker-compose.yml index 9c91ea4..afb5644 100644 --- a/l2/docker-compose.yml +++ b/l2/docker-compose.yml @@ -60,7 +60,7 @@ services: - OAUTH_LOGOUT_URL=https://account.rose-ash.com/auth/sso-logout/ # DATABASE_URL, ARTDAG_DOMAIN, ARTDAG_USER, JWT_SECRET, SECRET_KEY from .env file healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8200/')"] + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8200/health')"] interval: 10s timeout: 5s retries: 3 From f1d80a1777b66fe124cfa374ad23d532aace5667 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 01:42:09 +0000 Subject: [PATCH 17/24] L2: verify auth state with account on each request When user has artdag_session cookie, periodically (every 30s) check account's /auth/internal/check-device endpoint. If account says the device is no longer active (SSO logout), clear the cookie immediately. Prevents stale sign-in after logging out from another app. Co-Authored-By: Claude Opus 4.6 --- l2/app/__init__.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/l2/app/__init__.py b/l2/app/__init__.py index fbc713a..982dfac 100644 --- a/l2/app/__init__.py +++ b/l2/app/__init__.py @@ -10,6 +10,7 @@ from pathlib import Path from contextlib import asynccontextmanager from urllib.parse import quote +import httpx from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse @@ -72,8 +73,43 @@ def create_app() -> FastAPI: ): return await call_next(request) - # Already logged in — pass through + # Already logged in — verify account hasn't revoked auth if get_user_from_cookie(request): + device_id = getattr(request.state, "device_id", None) + if device_id: + # Check every 30s whether account still considers this device active + check_at = request.cookies.get("auth_check_at") + now = time.time() + stale = True + if check_at: + try: + stale = (now - float(check_at)) > 30 + except (ValueError, TypeError): + pass + if stale and settings.internal_account_url: + try: + async with httpx.AsyncClient(timeout=3) as client: + resp = await client.get( + f"{settings.internal_account_url.rstrip('/')}" + f"/auth/internal/check-device" + f"?device_id={device_id}&app=artdag_l2" + ) + if resp.status_code == 200 and not resp.json().get("active"): + # Account revoked — clear cookie + response = await call_next(request) + response.delete_cookie("artdag_session") + response.delete_cookie("pnone_at") + response.delete_cookie("auth_check_at") + return response + except Exception: + pass + # Update check timestamp + response = await call_next(request) + response.set_cookie( + "auth_check_at", str(now), max_age=60, + httponly=True, samesite="lax", secure=True, + ) + return response return await call_next(request) # Check cooldown — don't re-check within 5 minutes From 84e3ff3a91a1f8591bd1809ca7149df060e7ebf8 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 15:10:43 +0000 Subject: [PATCH 18/24] Route GPU queue to CPU workers for CPU-based job processing Co-Authored-By: Claude Opus 4.6 --- l1/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/l1/docker-compose.yml b/l1/docker-compose.yml index 301e439..6cd912b 100644 --- a/l1/docker-compose.yml +++ b/l1/docker-compose.yml @@ -106,7 +106,7 @@ services: l1-worker: image: registry.rose-ash.com:5000/celery-l1-server:latest - command: sh -c "find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; celery -A celery_app worker --loglevel=info -E" + command: sh -c "find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; celery -A celery_app worker --loglevel=info -E -Q celery,gpu" env_file: - .env environment: From 3bffb97ca18fb716875026bea14e4b7271b85af1 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 15:24:59 +0000 Subject: [PATCH 19/24] Add JAX CPU to L1 worker image, CUDA JAX to GPU image CPU workers can now run GPU-queue rendering tasks via JAX on CPU. GPU image overrides with jax[cuda12] for full CUDA acceleration. Co-Authored-By: Claude Opus 4.6 --- l1/Dockerfile.gpu | 4 ++-- l1/requirements.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/l1/Dockerfile.gpu b/l1/Dockerfile.gpu index 8f764da..ad33eb8 100644 --- a/l1/Dockerfile.gpu +++ b/l1/Dockerfile.gpu @@ -63,8 +63,8 @@ RUN python3 -m pip install --upgrade pip COPY l1/requirements.txt . RUN pip install --no-cache-dir -r requirements.txt -# Install GPU-specific dependencies (CuPy for CUDA 12.x) -RUN pip install --no-cache-dir cupy-cuda12x +# Install GPU-specific dependencies (CuPy for CUDA 12.x, JAX with CUDA) +RUN pip install --no-cache-dir cupy-cuda12x jax[cuda12] # Install PyNvVideoCodec for zero-copy GPU encoding RUN pip install --no-cache-dir PyNvVideoCodec diff --git a/l1/requirements.txt b/l1/requirements.txt index be6950f..8dd85aa 100644 --- a/l1/requirements.txt +++ b/l1/requirements.txt @@ -13,6 +13,7 @@ markdown>=3.5.0 # Common effect dependencies (used by uploaded effects) numpy>=1.24.0 opencv-python-headless>=4.8.0 +jax[cpu]>=0.4.20 # core (artdag) and common (artdag_common) installed from local dirs in Dockerfile psycopg2-binary nest_asyncio From c53227d991ccb960b4547b85736015f7080ee9f7 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 15:30:35 +0000 Subject: [PATCH 20/24] Fix multi-res HLS encoding on CPU: fall back to libx264 when NVENC unavailable The hardcoded h264_nvenc encoder fails on CPU-only workers. Now uses check_nvenc_available() to auto-detect and falls back to libx264. Co-Authored-By: Claude Opus 4.6 --- l1/streaming/multi_res_output.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/l1/streaming/multi_res_output.py b/l1/streaming/multi_res_output.py index 40c661a..33e4413 100644 --- a/l1/streaming/multi_res_output.py +++ b/l1/streaming/multi_res_output.py @@ -21,6 +21,8 @@ from dataclasses import dataclass, field import numpy as np +from streaming.output import check_nvenc_available + # Try GPU imports try: import cupy as cp @@ -222,14 +224,26 @@ class MultiResolutionHLSOutput: "-vf", f"scale={quality.width}:{quality.height}:flags=lanczos", ]) - # NVENC encoding with quality settings + # Encoding settings - use NVENC if available, fall back to libx264 + use_nvenc = check_nvenc_available() + if use_nvenc: + cmd.extend([ + "-c:v", "h264_nvenc", + "-preset", "p4", # Balanced speed/quality + "-tune", "hq", + "-b:v", f"{quality.bitrate}k", + "-maxrate", f"{int(quality.bitrate * 1.5)}k", + "-bufsize", f"{quality.bitrate * 2}k", + ]) + else: + cmd.extend([ + "-c:v", "libx264", + "-preset", "fast", + "-b:v", f"{quality.bitrate}k", + "-maxrate", f"{int(quality.bitrate * 1.5)}k", + "-bufsize", f"{quality.bitrate * 2}k", + ]) cmd.extend([ - "-c:v", "h264_nvenc", - "-preset", "p4", # Balanced speed/quality - "-tune", "hq", - "-b:v", f"{quality.bitrate}k", - "-maxrate", f"{int(quality.bitrate * 1.5)}k", - "-bufsize", f"{quality.bitrate * 2}k", "-g", str(int(self.fps * self.segment_duration)), # Keyframe interval = segment duration "-keyint_min", str(int(self.fps * self.segment_duration)), "-sc_threshold", "0", # Disable scene change detection for consistent segments From 07cae101ad1af511fe063e7441ca64495a2a271d Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 15:35:13 +0000 Subject: [PATCH 21/24] Use JAX for fused pipeline fallback on CPU instead of GPUFrame path When CUDA fused kernels aren't available, the fused-pipeline primitive now uses JAX ops (jax_rotate, jax_scale, jax_shift_hue, etc.) instead of falling back to one-by-one CuPy/GPUFrame operations. Legacy GPUFrame path retained as last resort when JAX is also unavailable. Co-Authored-By: Claude Opus 4.6 --- .../primitive_libs/streaming_gpu.py | 111 ++++++++++++++++-- 1 file changed, 102 insertions(+), 9 deletions(-) diff --git a/l1/sexp_effects/primitive_libs/streaming_gpu.py b/l1/sexp_effects/primitive_libs/streaming_gpu.py index f2aa7ea..a2374f5 100644 --- a/l1/sexp_effects/primitive_libs/streaming_gpu.py +++ b/l1/sexp_effects/primitive_libs/streaming_gpu.py @@ -842,8 +842,9 @@ def _get_cpu_primitives(): PRIMITIVES = _get_cpu_primitives().copy() -# Try to import fused kernel compiler +# Try to import fused kernel compiler (CUDA first, then JAX fallback) _FUSED_KERNELS_AVAILABLE = False +_FUSED_JAX_AVAILABLE = False _compile_frame_pipeline = None _compile_autonomous_pipeline = None try: @@ -853,7 +854,56 @@ try: _FUSED_KERNELS_AVAILABLE = True print("[streaming_gpu] Fused CUDA kernel compiler loaded", file=sys.stderr) except ImportError as e: - print(f"[streaming_gpu] Fused kernels not available: {e}", file=sys.stderr) + print(f"[streaming_gpu] Fused CUDA kernels not available: {e}", file=sys.stderr) + +# JAX fallback for fused pipeline on CPU +_jax_fused_fns = {} +try: + from streaming.sexp_to_jax import ( + jax_rotate, jax_scale, jax_shift_hue, jax_invert, + jax_adjust_brightness, jax_adjust_contrast, jax_resize, + ) + import jax.numpy as jnp + _FUSED_JAX_AVAILABLE = True + from streaming.sexp_to_jax import jax_sample + + def _jax_ripple(img, amplitude=10, frequency=8, decay=2, phase=0, cx=None, cy=None): + """JAX ripple displacement matching the CUDA fused pipeline.""" + h, w = img.shape[:2] + if cx is None: + cx = w / 2 + if cy is None: + cy = h / 2 + y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) + x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) + dx = x_coords - cx + dy = y_coords - cy + dist = jnp.sqrt(dx*dx + dy*dy) + max_dim = jnp.maximum(w, h).astype(jnp.float32) + ripple = jnp.sin(2 * jnp.pi * frequency * dist / max_dim + phase) * amplitude + decay_factor = jnp.exp(-decay * dist / max_dim) + ripple = ripple * decay_factor + angle = jnp.arctan2(dy, dx) + src_x = x_coords + ripple * jnp.cos(angle) + src_y = y_coords + ripple * jnp.sin(angle) + r, g, b = jax_sample(img, src_x.flatten(), src_y.flatten()) + return jnp.stack([ + jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + _jax_fused_fns = { + 'rotate': lambda img, **kw: jax_rotate(img, kw.get('angle', 0)), + 'zoom': lambda img, **kw: jax_scale(img, kw.get('amount', 1.0)), + 'hue_shift': lambda img, **kw: jax_shift_hue(img, kw.get('degrees', 0)), + 'invert': lambda img, **kw: jax_invert(img), + 'brightness': lambda img, **kw: jax_adjust_contrast(img, kw.get('factor', 1.0)), + 'ripple': lambda img, **kw: _jax_ripple(img, **kw), + } + print("[streaming_gpu] JAX fused fallback loaded", file=sys.stderr) +except ImportError as e: + print(f"[streaming_gpu] JAX fallback not available: {e}", file=sys.stderr) # Fused pipeline cache @@ -930,9 +980,53 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params): effects_list = other_effects if not _FUSED_KERNELS_AVAILABLE: - # Fallback: apply effects one by one - print(f"[FUSED FALLBACK] Using fallback path for {len(effects_list)} effects", file=sys.stderr) - # Wrap in GPUFrame if needed (GPU functions expect GPUFrame objects) + if _FUSED_JAX_AVAILABLE: + # JAX path: convert to JAX array, apply effects, convert back to numpy + if _FUSED_CALL_COUNT <= 3: + print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr) + # Extract numpy array from GPUFrame if needed + if isinstance(img, GPUFrame): + arr = img.cpu if not img.is_on_gpu else img.gpu.get() + elif hasattr(img, 'get'): + arr = img.get() # CuPy to numpy + else: + arr = np.asarray(img) + result = jnp.array(arr) + for effect in effects_list: + op = effect['op'] + if op == 'rotate': + angle = dynamic_params.get('rotate_angle', effect.get('angle', 0)) + result = _jax_fused_fns['rotate'](result, angle=angle) + elif op == 'zoom': + amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0)) + result = _jax_fused_fns['zoom'](result, amount=amount) + elif op == 'hue_shift': + degrees = effect.get('degrees', 0) + if abs(degrees) > 0.1: + result = _jax_fused_fns['hue_shift'](result, degrees=degrees) + elif op == 'ripple': + amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) + if amplitude > 0.1: + result = _jax_fused_fns['ripple'](result, + amplitude=amplitude, + frequency=effect.get('frequency', 8), + decay=effect.get('decay', 2), + phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)), + cx=effect.get('center_x'), + cy=effect.get('center_y')) + elif op == 'brightness': + factor = effect.get('factor', 1.0) + result = _jax_fused_fns['brightness'](result, factor=factor) + elif op == 'invert': + amount = effect.get('amount', 0) + if amount > 0.5: + result = _jax_fused_fns['invert'](result) + else: + raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") + return np.asarray(result) + + # Legacy CuPy/GPUFrame fallback + print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr) if isinstance(img, GPUFrame): result = img else: @@ -948,11 +1042,11 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params): result = gpu_zoom(result, amount) elif op == 'hue_shift': degrees = effect.get('degrees', 0) - if abs(degrees) > 0.1: # Only apply if significant shift + if abs(degrees) > 0.1: result = gpu_hue_shift(result, degrees) elif op == 'ripple': amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) - if amplitude > 0.1: # Only apply if amplitude is significant + if amplitude > 0.1: result = gpu_ripple(result, amplitude=amplitude, frequency=effect.get('frequency', 8), @@ -965,11 +1059,10 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params): result = gpu_contrast(result, factor, 0) elif op == 'invert': amount = effect.get('amount', 0) - if amount > 0.5: # Only invert if amount > 0.5 + if amount > 0.5: result = gpu_invert(result) else: raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") - # Return raw array, not GPUFrame (downstream expects arrays with .flags attribute) if isinstance(result, GPUFrame): return result.gpu if result.is_on_gpu else result.cpu return result From 4f49985cd57453a22f7fa333a03677533b38fa67 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 15:39:29 +0000 Subject: [PATCH 22/24] Enable JAX compilation for streaming tasks The StreamInterpreter was created without use_jax=True, so the JAX compiler was never activated for production rendering. Desktop testing had this enabled but the celery task path did not. Co-Authored-By: Claude Opus 4.6 --- l1/tasks/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/l1/tasks/streaming.py b/l1/tasks/streaming.py index 7ac6057..3195921 100644 --- a/l1/tasks/streaming.py +++ b/l1/tasks/streaming.py @@ -354,7 +354,7 @@ def run_stream( checkpoint = None # Create interpreter (pass actor_id for friendly name resolution) - interp = StreamInterpreter(str(recipe_path), actor_id=actor_id) + interp = StreamInterpreter(str(recipe_path), actor_id=actor_id, use_jax=True) # Set primitive library directory explicitly interp.primitive_lib_dir = sexp_effects_dir / "primitive_libs" From b788f1f778280e901e43c2dadc91325688ac7eea Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 18:33:53 +0000 Subject: [PATCH 23/24] Fix CPU HLS streaming (yuv420p) and opt-in middleware for fragments - Add -pix_fmt yuv420p to multi_res_output.py libx264 path so browsers can decode CPU-encoded segments (was producing yuv444p / High 4:4:4). - Switch silent auth check and coop fragment middlewares from opt-out blocklists to opt-in: only run for GET requests with Accept: text/html. Prevents unnecessary nav-tree/auth-menu HTTP calls on every HLS segment, IPFS proxy, and API request. - Add opaque grant token verification to L1/L2 dependencies. - Migrate client CLI to device authorization flow. Co-Authored-By: Claude Opus 4.6 --- client/artdag.py | 282 +++++++++++++------------------ l1/app/__init__.py | 22 +-- l1/app/dependencies.py | 81 ++++++++- l1/streaming/multi_res_output.py | 1 + l2/app/dependencies.py | 54 ++++-- 5 files changed, 252 insertions(+), 188 deletions(-) diff --git a/client/artdag.py b/client/artdag.py index d28df4c..bbcbc0f 100755 --- a/client/artdag.py +++ b/client/artdag.py @@ -21,11 +21,11 @@ CONFIG_FILE = CONFIG_DIR / "config.json" # Defaults - can be overridden by env vars, config file, or CLI args _DEFAULT_SERVER = "http://localhost:8100" -_DEFAULT_L2_SERVER = "http://localhost:8200" +_DEFAULT_ACCOUNT_SERVER = "https://account.rose-ash.com" # Active server URLs (set during CLI init) DEFAULT_SERVER = None -DEFAULT_L2_SERVER = None +DEFAULT_ACCOUNT_SERVER = None def load_config() -> dict: @@ -51,9 +51,9 @@ def get_server(): return DEFAULT_SERVER -def get_l2_server(): - """Get L2 server URL.""" - return DEFAULT_L2_SERVER +def get_account_server(): + """Get account server URL.""" + return DEFAULT_ACCOUNT_SERVER def load_token() -> dict: @@ -115,24 +115,24 @@ def _get_default_server(): return config.get("server", _DEFAULT_SERVER) -def _get_default_l2(): - """Get default L2 server from env, config, or builtin default.""" - if os.environ.get("ARTDAG_L2"): - return os.environ["ARTDAG_L2"] +def _get_default_account(): + """Get default account server from env, config, or builtin default.""" + if os.environ.get("ARTDAG_ACCOUNT"): + return os.environ["ARTDAG_ACCOUNT"] config = load_config() - return config.get("l2", _DEFAULT_L2_SERVER) + return config.get("account", _DEFAULT_ACCOUNT_SERVER) @click.group() @click.option("--server", "-s", default=None, help="L1 server URL (saved for future use)") -@click.option("--l2", default=None, - help="L2 server URL (saved for future use)") +@click.option("--account", default=None, + help="Account server URL (saved for future use)") @click.pass_context -def cli(ctx, server, l2): +def cli(ctx, server, account): """Art DAG Client - interact with L1 rendering server.""" ctx.ensure_object(dict) - global DEFAULT_SERVER, DEFAULT_L2_SERVER + global DEFAULT_SERVER, DEFAULT_ACCOUNT_SERVER config = load_config() config_changed = False @@ -146,134 +146,106 @@ def cli(ctx, server, l2): else: DEFAULT_SERVER = _get_default_server() - if l2: - DEFAULT_L2_SERVER = l2 - if config.get("l2") != l2: - config["l2"] = l2 + if account: + DEFAULT_ACCOUNT_SERVER = account + if config.get("account") != account: + config["account"] = account config_changed = True else: - DEFAULT_L2_SERVER = _get_default_l2() + DEFAULT_ACCOUNT_SERVER = _get_default_account() # Save config if changed if config_changed: save_config(config) ctx.obj["server"] = DEFAULT_SERVER - ctx.obj["l2"] = DEFAULT_L2_SERVER + ctx.obj["account"] = DEFAULT_ACCOUNT_SERVER # ============ Auth Commands ============ @cli.command() -@click.argument("username") -@click.option("--password", "-p", prompt=True, hide_input=True) -def login(username, password): - """Login to get access token.""" +def login(): + """Login via device authorization flow.""" + import webbrowser + + account = get_account_server() + + # Request device code try: - # Server expects form data, not JSON resp = requests.post( - f"{get_l2_server()}/auth/login", - data={"username": username, "password": password} + f"{account}/auth/device/authorize", + json={"client_id": "artdag"}, ) - if resp.status_code == 200: - # Check if we got a token back in a cookie - if "auth_token" in resp.cookies: - token = resp.cookies["auth_token"] - # Decode token to get username and expiry - import base64 - try: - # JWT format: header.payload.signature - payload = token.split(".")[1] - # Add padding if needed - payload += "=" * (4 - len(payload) % 4) - decoded = json.loads(base64.urlsafe_b64decode(payload)) - token_data = { - "access_token": token, - "username": decoded.get("username", username), - "expires_at": decoded.get("exp", "") - } - save_token(token_data) - click.echo(f"Logged in as {token_data['username']}") - if token_data.get("expires_at"): - click.echo(f"Token expires: {token_data['expires_at']}") - except Exception: - # If we can't decode, just save the token - save_token({"access_token": token, "username": username}) - click.echo(f"Logged in as {username}") - else: - # HTML response - check for success/error - if "successful" in resp.text.lower(): - click.echo(f"Login successful but no token received. Try logging in via web browser.") - elif "invalid" in resp.text.lower(): - click.echo(f"Login failed: Invalid username or password", err=True) - sys.exit(1) - else: - click.echo(f"Login failed: {resp.text}", err=True) - sys.exit(1) - else: - click.echo(f"Login failed: {resp.text}", err=True) - sys.exit(1) + resp.raise_for_status() + data = resp.json() except requests.RequestException as e: click.echo(f"Login failed: {e}", err=True) sys.exit(1) + device_code = data["device_code"] + user_code = data["user_code"] + verification_uri = data["verification_uri"] + expires_in = data.get("expires_in", 900) + interval = data.get("interval", 5) -@cli.command() -@click.argument("username") -@click.option("--password", "-p", prompt=True, hide_input=True, confirmation_prompt=True) -@click.option("--email", "-e", default=None, help="Email (optional)") -def register(username, password, email): - """Register a new account.""" + click.echo("To sign in, open this URL in your browser:") + click.echo(f" {verification_uri}") + click.echo(f" and enter code: {user_code}") + click.echo() + + # Try to open browser automatically try: - # Server expects form data, not JSON - form_data = { - "username": username, - "password": password, - "password2": password, - } - if email: - form_data["email"] = email + webbrowser.open(verification_uri) + except Exception: + pass - resp = requests.post( - f"{get_l2_server()}/auth/register", - data=form_data - ) - if resp.status_code == 200: - # Check if we got a token back in a cookie - if "auth_token" in resp.cookies: - token = resp.cookies["auth_token"] - # Decode token to get username and expiry - import base64 - try: - # JWT format: header.payload.signature - payload = token.split(".")[1] - # Add padding if needed - payload += "=" * (4 - len(payload) % 4) - decoded = json.loads(base64.urlsafe_b64decode(payload)) - token_data = { - "access_token": token, - "username": decoded.get("username", username), - "expires_at": decoded.get("exp", "") - } - save_token(token_data) - click.echo(f"Registered and logged in as {token_data['username']}") - except Exception: - # If we can't decode, just save the token - save_token({"access_token": token, "username": username}) - click.echo(f"Registered and logged in as {username}") - else: - # HTML response - registration successful - if "successful" in resp.text.lower(): - click.echo(f"Registered as {username}. Please login to get a token.") - else: - click.echo(f"Registration failed: {resp.text}", err=True) - sys.exit(1) - else: - click.echo(f"Registration failed: {resp.text}", err=True) + # Poll for approval + click.echo("Waiting for authorization", nl=False) + deadline = time.time() + expires_in + while time.time() < deadline: + time.sleep(interval) + click.echo(".", nl=False) + + try: + resp = requests.post( + f"{account}/auth/device/token", + json={"device_code": device_code, "client_id": "artdag"}, + ) + data = resp.json() + except requests.RequestException: + continue + + error = data.get("error") + if error == "authorization_pending": + continue + elif error == "expired_token": + click.echo() + click.echo("Code expired. Please try again.", err=True) sys.exit(1) - except requests.RequestException as e: - click.echo(f"Registration failed: {e}", err=True) - sys.exit(1) + elif error == "access_denied": + click.echo() + click.echo("Authorization denied.", err=True) + sys.exit(1) + elif error: + click.echo() + click.echo(f"Login failed: {error}", err=True) + sys.exit(1) + + # Success + token_data = { + "access_token": data["access_token"], + "username": data.get("username", ""), + "display_name": data.get("display_name", ""), + } + save_token(token_data) + click.echo() + click.echo(f"Logged in as {token_data['username'] or token_data['display_name']}") + return + + click.echo() + click.echo("Timed out waiting for authorization.", err=True) + sys.exit(1) @cli.command() @@ -291,22 +263,12 @@ def whoami(): click.echo("Not logged in") return - try: - resp = requests.get( - f"{get_l2_server()}/auth/me", - headers={"Authorization": f"Bearer {token_data['access_token']}"} - ) - if resp.status_code == 200: - user = resp.json() - click.echo(f"Username: {user['username']}") - click.echo(f"Created: {user['created_at']}") - if user.get('email'): - click.echo(f"Email: {user['email']}") - else: - click.echo("Token invalid or expired. Please login again.", err=True) - clear_token() - except requests.RequestException as e: - click.echo(f"Error: {e}", err=True) + username = token_data.get("username", "") + display_name = token_data.get("display_name", "") + if username: + click.echo(f"Username: {username}") + if display_name: + click.echo(f"Name: {display_name}") @cli.command("config") @@ -332,11 +294,11 @@ def show_config(clear): else: click.echo(f" (default)") - click.echo(f"L2 Server: {DEFAULT_L2_SERVER}") - if config.get("l2"): + click.echo(f"Account Server: {DEFAULT_ACCOUNT_SERVER}") + if config.get("account"): click.echo(f" (saved)") - elif os.environ.get("ARTDAG_L2"): - click.echo(f" (from ARTDAG_L2 env)") + elif os.environ.get("ARTDAG_ACCOUNT"): + click.echo(f" (from ARTDAG_ACCOUNT env)") else: click.echo(f" (default)") @@ -459,7 +421,7 @@ def run(recipe, input_hash, name, wait): # Check auth token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Resolve named assets @@ -691,7 +653,7 @@ def delete_run(run_id, force): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Get run info first @@ -741,7 +703,7 @@ def delete_cache(cid, force): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) if not force: @@ -927,7 +889,7 @@ def upload(filepath, name): # Check auth token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -980,13 +942,13 @@ def publish(run_id, output_name): # Check auth token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Post to L2 server with auth, including which L1 server has the run try: resp = requests.post( - f"{get_l2_server()}/registry/record-run", + f"{get_server()}/registry/record-run", json={"run_id": run_id, "output_name": output_name, "l1_server": get_server()}, headers={"Authorization": f"Bearer {token_data['access_token']}"} ) @@ -1031,7 +993,7 @@ def meta(cid, origin, origin_url, origin_note, description, tags, folder, add_co """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) headers = get_auth_header(require_token=True) @@ -1200,7 +1162,7 @@ def folder_list(): """List all folders.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1225,7 +1187,7 @@ def folder_create(path): """Create a new folder.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1251,7 +1213,7 @@ def folder_delete(path): """Delete a folder (must be empty).""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1287,7 +1249,7 @@ def collection_list(): """List all collections.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1312,7 +1274,7 @@ def collection_create(name): """Create a new collection.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1338,7 +1300,7 @@ def collection_delete(name): """Delete a collection.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -1546,7 +1508,7 @@ def upload_recipe(filepath): """Upload a recipe file (YAML or S-expression). Requires login.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Read content @@ -1607,7 +1569,7 @@ def upload_effect(filepath, name): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Check it's a sexp or py file @@ -1854,7 +1816,7 @@ def run_recipe(recipe_id, inputs, wait): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Parse inputs @@ -1922,7 +1884,7 @@ def delete_recipe(recipe_id, force): """Delete a recipe. Requires login.""" token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) if not force: @@ -1969,7 +1931,7 @@ def generate_plan(recipe_file, inputs, features, output): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Read recipe YAML @@ -2047,7 +2009,7 @@ def execute_plan(plan_file, wait): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Read plan JSON @@ -2098,7 +2060,7 @@ def run_recipe_v2(recipe_file, inputs, features, wait): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Read recipe YAML @@ -2213,7 +2175,7 @@ def run_status_v2(run_id): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) try: @@ -2267,7 +2229,7 @@ def run_stream(recipe_file, output, duration, fps, sources, audio, wait): """ token_data = load_token() if not token_data.get("access_token"): - click.echo("Not logged in. Please run: artdag login ", err=True) + click.echo("Not logged in. Please run: artdag login", err=True) sys.exit(1) # Read recipe file diff --git a/l1/app/__init__.py b/l1/app/__init__.py index 3b945b3..15617da 100644 --- a/l1/app/__init__.py +++ b/l1/app/__init__.py @@ -18,8 +18,6 @@ from artdag_common.middleware.auth import get_user_from_cookie from .config import settings -# Paths that should never trigger a silent auth check -_SKIP_PREFIXES = ("/auth/", "/static/", "/api/", "/ipfs/", "/download/", "/inbox", "/health", "/internal/", "/oembed") _SILENT_CHECK_COOLDOWN = 300 # 5 minutes _DEVICE_COOKIE = "artdag_did" _DEVICE_COOKIE_MAX_AGE = 30 * 24 * 3600 # 30 days @@ -60,14 +58,15 @@ def create_app() -> FastAPI: async def shutdown(): await close_db() - # Silent auth check — auto-login via prompt=none OAuth + # Silent auth check — auto-login via prompt=none OAuth. + # Only runs for browser page loads (Accept: text/html). # NOTE: registered BEFORE device_id so device_id is outermost (runs first) @app.middleware("http") async def silent_auth_check(request: Request, call_next): - path = request.url.path + accept = request.headers.get("accept", "") if ( request.method != "GET" - or any(path.startswith(p) for p in _SKIP_PREFIXES) + or "text/html" not in accept or request.headers.get("hx-request") # skip HTMX ): return await call_next(request) @@ -148,17 +147,14 @@ def create_app() -> FastAPI: return response # Coop fragment pre-fetch — inject nav-tree, auth-menu, cart-mini into - # request.state for full-page HTML renders. Skips HTMX, API, and - # internal paths. Failures are silent (fragments default to ""). - _FRAG_SKIP = ("/auth/", "/api/", "/internal/", "/health", "/oembed", - "/ipfs/", "/download/", "/inbox", "/static/") - + # request.state for full-page HTML renders. Opt-in: only fetches for + # browser page loads (Accept: text/html, non-HTMX GET requests). @app.middleware("http") async def coop_fragments_middleware(request: Request, call_next): - path = request.url.path + accept = request.headers.get("accept", "") if ( request.method != "GET" - or any(path.startswith(p) for p in _FRAG_SKIP) + or "text/html" not in accept or request.headers.get("hx-request") or request.headers.get(fragments.FRAGMENT_HEADER) ): @@ -171,7 +167,7 @@ def create_app() -> FastAPI: user = get_user_from_cookie(request) auth_params = {"email": user.email or user.username} if user else {} - nav_params = {"app_name": "artdag", "path": path} + nav_params = {"app_name": "artdag", "path": request.url.path} try: nav_tree_html, auth_menu_html, cart_mini_html = await _fetch_frags([ diff --git a/l1/app/dependencies.py b/l1/app/dependencies.py index fc59947..1e11831 100644 --- a/l1/app/dependencies.py +++ b/l1/app/dependencies.py @@ -54,6 +54,77 @@ def get_templates(request: Request) -> Environment: return request.app.state.templates +async def _verify_opaque_grant(token: str) -> Optional[UserContext]: + """Verify an opaque grant token via account server, with Redis cache.""" + import httpx + import json + + if not settings.internal_account_url: + return None + + # Check L1 Redis cache first + cache_key = f"grant_verify:{token[:16]}" + try: + r = get_redis_client() + cached = r.get(cache_key) + if cached is not None: + if cached == "__invalid__": + return None + data = json.loads(cached) + return UserContext( + username=data["username"], + actor_id=data["actor_id"], + token=token, + email=data.get("email", ""), + ) + except Exception: + pass + + # Call account server + verify_url = f"{settings.internal_account_url.rstrip('/')}/auth/internal/verify-grant" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(verify_url, params={"token": token}) + if resp.status_code != 200: + return None + data = resp.json() + if not data.get("valid"): + # Cache negative result briefly + try: + r = get_redis_client() + r.set(cache_key, "__invalid__", ex=60) + except Exception: + pass + return None + except Exception: + return None + + username = data.get("username", "") + display_name = data.get("display_name", "") + actor_id = f"@{username}" if username else "" + ctx = UserContext( + username=username, + actor_id=actor_id, + token=token, + email=username, + ) + + # Cache positive result for 5 minutes + try: + r = get_redis_client() + cache_data = json.dumps({ + "username": username, + "actor_id": actor_id, + "email": username, + "display_name": display_name, + }) + r.set(cache_key, cache_data, ex=300) + except Exception: + pass + + return ctx + + async def get_current_user(request: Request) -> Optional[UserContext]: """ Get the current user from request (cookie or header). @@ -61,11 +132,19 @@ async def get_current_user(request: Request) -> Optional[UserContext]: This is a permissive dependency - returns None if not authenticated. Use require_auth for routes that require authentication. """ - # Try header first (API clients) + # Try header first (API clients — JWT tokens) ctx = get_user_from_header(request) if ctx: return ctx + # Try opaque grant token (device flow / CLI tokens) + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + ctx = await _verify_opaque_grant(token) + if ctx: + return ctx + # Fall back to cookie (browser) return get_user_from_cookie(request) diff --git a/l1/streaming/multi_res_output.py b/l1/streaming/multi_res_output.py index 33e4413..dfa7d8d 100644 --- a/l1/streaming/multi_res_output.py +++ b/l1/streaming/multi_res_output.py @@ -244,6 +244,7 @@ class MultiResolutionHLSOutput: "-bufsize", f"{quality.bitrate * 2}k", ]) cmd.extend([ + "-pix_fmt", "yuv420p", # Required for browser MSE compatibility "-g", str(int(self.fps * self.segment_duration)), # Keyframe interval = segment duration "-keyint_min", str(int(self.fps * self.segment_duration)), "-sc_threshold", "0", # Disable scene change detection for consistent segments diff --git a/l2/app/dependencies.py b/l2/app/dependencies.py index d10d063..e3c0e1e 100644 --- a/l2/app/dependencies.py +++ b/l2/app/dependencies.py @@ -19,6 +19,34 @@ def get_templates(request: Request): return request.app.state.templates +async def _verify_opaque_grant(token: str) -> Optional[dict]: + """Verify an opaque grant token via account server.""" + import httpx + + if not settings.internal_account_url: + return None + + verify_url = f"{settings.internal_account_url.rstrip('/')}/auth/internal/verify-grant" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(verify_url, params={"token": token}) + if resp.status_code != 200: + return None + data = resp.json() + if not data.get("valid"): + return None + except Exception: + return None + + username = data.get("username", "") + return { + "username": username, + "actor_id": f"https://{settings.domain}/users/{username}", + "token": token, + "sub": username, + } + + async def get_current_user(request: Request) -> Optional[dict]: """ Get current user from cookie or header. @@ -39,22 +67,20 @@ async def get_current_user(request: Request) -> Optional[dict]: if not token: return None - # Verify token + # Verify JWT token username = verify_token(token) - if not username: - return None + if username: + claims = get_token_claims(token) + if claims: + return { + "username": username, + "actor_id": f"https://{settings.domain}/users/{username}", + "token": token, + **claims, + } - # Get full claims - claims = get_token_claims(token) - if not claims: - return None - - return { - "username": username, - "actor_id": f"https://{settings.domain}/users/{username}", - "token": token, - **claims, - } + # JWT failed — try as opaque grant token + return await _verify_opaque_grant(token) async def require_auth(request: Request) -> dict: From 4c2e7165585c70a68239ad60df08d1e8bc6a7594 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 25 Feb 2026 19:31:53 +0000 Subject: [PATCH 24/24] Make JAX the primary fused-pipeline path for CPU/GPU parity JAX via XLA produces identical output on CPU and GPU. Previously CUDA hand-written kernels were preferred on GPU, causing visual differences vs the JAX CPU fallback. Now JAX is always used first, with legacy CuPy/GPUFrame as fallback only when JAX is unavailable. Also adds comprehensive CLAUDE.md for the monorepo. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 74 +++++++++++++++ .../primitive_libs/streaming_gpu.py | 89 ++++++++++--------- 2 files changed, 119 insertions(+), 44 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..afb00d7 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,74 @@ +# Art DAG Monorepo + +Federated content-addressed DAG execution engine for distributed media processing with ActivityPub ownership and provenance tracking. + +## Project Structure + +``` +core/ # DAG engine (artdag package) - nodes, effects, analysis, planning +l1/ # L1 Celery rendering server (FastAPI + Celery + Redis + PostgreSQL) +l2/ # L2 ActivityPub registry (FastAPI + PostgreSQL) +common/ # Shared templates, middleware, models (artdag_common package) +client/ # CLI client +test/ # Integration & e2e tests +``` + +## Tech Stack + +Python 3.11+, FastAPI, Celery, Redis, PostgreSQL (asyncpg for L1), SQLAlchemy, Pydantic, JAX (CPU/GPU), IPFS/Kubo, Docker Swarm, HTMX + Jinja2 for web UI. + +## Key Commands + +### Testing +```bash +cd l1 && pytest tests/ # L1 unit tests +cd core && pytest tests/ # Core unit tests +cd test && python run.py # Full integration pipeline +``` +- pytest uses `asyncio_mode = "auto"` for async tests +- Test files: `test_*.py`, fixtures in `conftest.py` + +### Linting & Type Checking (L1) +```bash +cd l1 && ruff check . # Lint (E, F, I, UP rules) +cd l1 && mypy app/types.py app/routers/recipes.py tests/ +``` +- Line length: 100 chars (E501 ignored) +- Mypy: strict on `app/types.py`, `app/routers/recipes.py`, `tests/`; gradual elsewhere +- Mypy ignores imports for: celery, redis, artdag, artdag_common, ipfs_client + +### Docker +```bash +docker build -f l1/Dockerfile -t celery-l1-server:latest . +docker build -f l1/Dockerfile.gpu -t celery-l1-gpu:latest . +docker build -f l2/Dockerfile -t l2-server:latest . +./deploy.sh # Build, push, deploy stacks +``` + +## Architecture Patterns + +- **3-Phase Execution**: Analyze -> Plan -> Execute (tasks in `l1/tasks/`) +- **Content-Addressed**: All data identified by SHA3-256 hashes or IPFS CIDs +- **Services Pattern**: Business logic in `app/services/`, API endpoints in `app/routers/` +- **Types Module**: Pydantic models and TypedDicts in `app/types.py` +- **Celery Tasks**: In `l1/tasks/`, decorated with `@app.task` +- **S-Expression Effects**: Composable effect language in `l1/sexp_effects/` +- **Storage**: Local filesystem, S3, or IPFS backends (`storage_providers.py`) + +## Auth + +- L1 <-> L2: scoped JWT tokens (no shared secrets) +- L2: password + OAuth SSO, token revocation in Redis (30-day expiry) +- Federation: ActivityPub RSA signatures (`core/artdag/activitypub/`) + +## Key Config Files + +- `l1/pyproject.toml` - mypy, pytest, ruff config for L1 +- `l1/celery_app.py` - Celery initialization +- `l1/database.py` / `l2/db.py` - SQLAlchemy models +- `l1/docker-compose.yml` / `l2/docker-compose.yml` - Swarm stacks + +## Tools + +- Use Context7 MCP for up-to-date library documentation +- Playwright MCP is available for browser automation/testing diff --git a/l1/sexp_effects/primitive_libs/streaming_gpu.py b/l1/sexp_effects/primitive_libs/streaming_gpu.py index a2374f5..145441a 100644 --- a/l1/sexp_effects/primitive_libs/streaming_gpu.py +++ b/l1/sexp_effects/primitive_libs/streaming_gpu.py @@ -979,51 +979,52 @@ def prim_fused_pipeline(img, effects_list, **dynamic_params): # Update effects list to exclude resize ops effects_list = other_effects - if not _FUSED_KERNELS_AVAILABLE: - if _FUSED_JAX_AVAILABLE: - # JAX path: convert to JAX array, apply effects, convert back to numpy - if _FUSED_CALL_COUNT <= 3: - print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr) - # Extract numpy array from GPUFrame if needed - if isinstance(img, GPUFrame): - arr = img.cpu if not img.is_on_gpu else img.gpu.get() - elif hasattr(img, 'get'): - arr = img.get() # CuPy to numpy + # JAX is the primary path — same code on CPU and GPU, XLA handles dispatch + if _FUSED_JAX_AVAILABLE: + if _FUSED_CALL_COUNT <= 3: + print(f"[FUSED JAX] Using JAX path for {len(effects_list)} effects", file=sys.stderr) + # Extract numpy array from GPUFrame if needed + if isinstance(img, GPUFrame): + arr = img.cpu if not img.is_on_gpu else img.gpu.get() + elif hasattr(img, 'get'): + arr = img.get() # CuPy to numpy + else: + arr = np.asarray(img) + result = jnp.array(arr) + for effect in effects_list: + op = effect['op'] + if op == 'rotate': + angle = dynamic_params.get('rotate_angle', effect.get('angle', 0)) + result = _jax_fused_fns['rotate'](result, angle=angle) + elif op == 'zoom': + amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0)) + result = _jax_fused_fns['zoom'](result, amount=amount) + elif op == 'hue_shift': + degrees = effect.get('degrees', 0) + if abs(degrees) > 0.1: + result = _jax_fused_fns['hue_shift'](result, degrees=degrees) + elif op == 'ripple': + amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) + if amplitude > 0.1: + result = _jax_fused_fns['ripple'](result, + amplitude=amplitude, + frequency=effect.get('frequency', 8), + decay=effect.get('decay', 2), + phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)), + cx=effect.get('center_x'), + cy=effect.get('center_y')) + elif op == 'brightness': + factor = effect.get('factor', 1.0) + result = _jax_fused_fns['brightness'](result, factor=factor) + elif op == 'invert': + amount = effect.get('amount', 0) + if amount > 0.5: + result = _jax_fused_fns['invert'](result) else: - arr = np.asarray(img) - result = jnp.array(arr) - for effect in effects_list: - op = effect['op'] - if op == 'rotate': - angle = dynamic_params.get('rotate_angle', effect.get('angle', 0)) - result = _jax_fused_fns['rotate'](result, angle=angle) - elif op == 'zoom': - amount = dynamic_params.get('zoom_amount', effect.get('amount', 1.0)) - result = _jax_fused_fns['zoom'](result, amount=amount) - elif op == 'hue_shift': - degrees = effect.get('degrees', 0) - if abs(degrees) > 0.1: - result = _jax_fused_fns['hue_shift'](result, degrees=degrees) - elif op == 'ripple': - amplitude = dynamic_params.get('ripple_amplitude', effect.get('amplitude', 10)) - if amplitude > 0.1: - result = _jax_fused_fns['ripple'](result, - amplitude=amplitude, - frequency=effect.get('frequency', 8), - decay=effect.get('decay', 2), - phase=dynamic_params.get('ripple_phase', effect.get('phase', 0)), - cx=effect.get('center_x'), - cy=effect.get('center_y')) - elif op == 'brightness': - factor = effect.get('factor', 1.0) - result = _jax_fused_fns['brightness'](result, factor=factor) - elif op == 'invert': - amount = effect.get('amount', 0) - if amount > 0.5: - result = _jax_fused_fns['invert'](result) - else: - raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") - return np.asarray(result) + raise ValueError(f"Unsupported fused pipeline operation: '{op}'. Supported ops: rotate, zoom, hue_shift, ripple, brightness, invert, resize") + return np.asarray(result) + + if not _FUSED_KERNELS_AVAILABLE: # Legacy CuPy/GPUFrame fallback print(f"[FUSED FALLBACK] Using legacy GPUFrame path for {len(effects_list)} effects", file=sys.stderr)