# diff --git a/path_registry.py b/path_registry.py new file mode 100644 index 0000000..985be18 --- /dev/null +++ b/path_registry.py @@ -0,0 +1,477 @@
"""
Path Registry - Maps human-friendly paths to content-addressed IDs.

This module provides a bidirectional mapping between:
- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp")
- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes)

The registry is useful for:
- Looking up effects by their friendly path name
- Resolving cids back to the original path for debugging
- Maintaining a stable naming scheme across cache updates

Storage:
- Uses the existing item_types table in the database (path column)
- Caches in Redis for fast lookups across distributed workers

The registry uses a system actor (@system@local) for global path mappings,
allowing effects to be resolved by path without requiring user context.
"""

import logging
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List, Optional

logger = logging.getLogger(__name__)

# System actor for global path mappings (effects, recipes, analyzers)
SYSTEM_ACTOR = "@system@local"


@dataclass
class PathEntry:
    """A registered path with its content-addressed ID."""
    path: str                     # Human-friendly path (relative or normalized)
    cid: str                      # Content-addressed ID (IPFS CID or hash)
    content_type: str             # Type: "effect", "recipe", "analyzer", etc.
    actor_id: str = SYSTEM_ACTOR  # Owner (system for global)
    description: Optional[str] = None
    created_at: float = 0.0


class PathRegistry:
    """
    Registry for mapping paths to content-addressed IDs.

    Uses the existing item_types table for persistence and Redis
    for fast lookups in distributed Celery workers.
    """

    def __init__(self, redis_client=None):
        self._redis = redis_client
        self._redis_path_to_cid_key = "artdag:path_to_cid"
        self._redis_cid_to_path_key = "artdag:cid_to_path"

    def _run_async(self, coro):
        """Run an async coroutine from a sync context.

        If this thread already has a running event loop (e.g. we were called
        from async code), the coroutine is executed on a fresh loop in a
        helper thread to avoid "loop already running" errors. Otherwise it
        runs on this thread's (possibly newly created) loop.
        """
        import asyncio

        try:
            # Only used to detect whether a loop is running; raises
            # RuntimeError when there is none.
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: safe to drive the coroutine directly.
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)

        # A loop is already running in this thread; run the coroutine on a
        # dedicated loop in a worker thread instead.
        import threading
        result = [None]
        error = [None]

        def run_in_thread():
            try:
                new_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(new_loop)
                try:
                    result[0] = new_loop.run_until_complete(coro)
                finally:
                    new_loop.close()
            except Exception as e:
                error[0] = e

        thread = threading.Thread(target=run_in_thread)
        thread.start()
        thread.join(timeout=30)
        if error[0]:
            raise error[0]
        return result[0]

    def _normalize_path(self, path: str) -> str:
        """Normalize a path for consistent storage.

        Converts backslashes to forward slashes, collapses duplicate
        slashes, and removes leading "./" / "/" prefixes.

        Note: an earlier version used ``path.lstrip('./')``, which strips
        *characters* rather than a prefix and corrupted names such as
        ".config/x" (-> "config/x"). Prefixes are removed explicitly here.
        """
        # Normalize separators first so ".\\effects" is also handled.
        path = path.replace('\\', '/')
        # Collapse duplicate slashes.
        while '//' in path:
            path = path.replace('//', '/')
        # Strip leading "./" and "/" prefixes (not arbitrary dots).
        while path.startswith('./'):
            path = path[2:]
        while path.startswith('/'):
            path = path[1:]
        return path

    def register(
        self,
        path: str,
        cid: str,
        content_type: str = "effect",
        actor_id: str = SYSTEM_ACTOR,
        description: Optional[str] = None,
    ) -> PathEntry:
        """
        Register a path -> cid mapping.

        Args:
            path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp")
            cid: Content-addressed ID (IPFS CID or hash)
            content_type: Type of content ("effect", "recipe", "analyzer")
            actor_id: Owner (default: system for global mappings)
            description: Optional description

        Returns:
            The created PathEntry
        """
        norm_path = self._normalize_path(path)
        now = datetime.now(timezone.utc).timestamp()

        entry = PathEntry(
            path=norm_path,
            cid=cid,
            content_type=content_type,
            actor_id=actor_id,
            description=description,
            created_at=now,
        )

        # Store in database (item_types table)
        self._save_to_db(entry)

        # Update Redis cache
        self._update_redis_cache(norm_path, cid)

        logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...")
        return entry

    def _save_to_db(self, entry: PathEntry):
        """Save entry to database using item_types table.

        Best-effort: DB failures are logged, not raised, so registration
        still updates the Redis cache.
        """
        import database

        async def save():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                # Ensure cache_item exists (FK target for item_types)
                await conn.execute(
                    "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING",
                    entry.cid
                )
                # Insert or update item_type with path; keep an existing
                # description when the new one is NULL.
                await conn.execute(
                    """
                    INSERT INTO item_types (cid, actor_id, type, path, description)
                    VALUES ($1, $2, $3, $4, $5)
                    ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET
                        description = COALESCE(EXCLUDED.description, item_types.description)
                    """,
                    entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description
                )
            finally:
                await conn.close()

        try:
            self._run_async(save())
        except Exception as e:
            logger.warning(f"Failed to save path registry to DB: {e}")

    def _update_redis_cache(self, path: str, cid: str):
        """Update both direction hashes in the Redis cache (best-effort)."""
        if self._redis:
            try:
                self._redis.hset(self._redis_path_to_cid_key, path, cid)
                self._redis.hset(self._redis_cid_to_path_key, cid, path)
            except Exception as e:
                logger.warning(f"Failed to update Redis cache: {e}")

    def get_cid(self, path: str, content_type: str = None) -> Optional[str]:
        """
        Get the cid for a path.

        Args:
            path: Human-friendly path
            content_type: Optional type filter

        Returns:
            The cid, or None if not found
        """
        norm_path = self._normalize_path(path)

        # The Redis hash stores no type information, so it can only serve
        # unfiltered lookups; a content_type filter must hit the database.
        if self._redis and content_type is None:
            try:
                val = self._redis.hget(self._redis_path_to_cid_key, norm_path)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Redis lookup failed: {e}")

        # Fall back to database
        return self._get_cid_from_db(norm_path, content_type)

    def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]:
        """Get cid from database using item_types table."""
        import database

        async def get():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                if content_type:
                    row = await conn.fetchrow(
                        "SELECT cid FROM item_types WHERE path = $1 AND type = $2",
                        path, content_type
                    )
                else:
                    row = await conn.fetchrow(
                        "SELECT cid FROM item_types WHERE path = $1",
                        path
                    )
                return row["cid"] if row else None
            finally:
                await conn.close()

        try:
            result = self._run_async(get())
            # Warm the Redis cache on a hit
            if result and self._redis:
                self._update_redis_cache(path, result)
            return result
        except Exception as e:
            logger.warning(f"Failed to get from DB: {e}")
            return None

    def get_path(self, cid: str) -> Optional[str]:
        """
        Get the path for a cid.

        Args:
            cid: Content-addressed ID

        Returns:
            The path, or None if not found
        """
        # Try Redis first
        if self._redis:
            try:
                val = self._redis.hget(self._redis_cid_to_path_key, cid)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Redis lookup failed: {e}")

        # Fall back to database
        return self._get_path_from_db(cid)

    def _get_path_from_db(self, cid: str) -> Optional[str]:
        """Get path from database using item_types table.

        When a cid has multiple registered paths, the oldest one wins
        (ORDER BY created_at LIMIT 1).
        """
        import database

        async def get():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                row = await conn.fetchrow(
                    "SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1",
                    cid
                )
                return row["path"] if row else None
            finally:
                await conn.close()

        try:
            result = self._run_async(get())
            # Warm the Redis cache on a hit
            if result and self._redis:
                self._update_redis_cache(result, cid)
            return result
        except Exception as e:
            logger.warning(f"Failed to get from DB: {e}")
            return None

    def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]:
        """
        List all entries of a given type.

        Args:
            content_type: Type to filter by ("effect", "recipe", etc.)
            actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global)

        Returns:
            List of PathEntry objects (empty list on DB failure)
        """
        import database

        async def list_entries():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                # Build the query once; the actor filter is optional.
                query = """
                    SELECT cid, path, type, actor_id, description,
                           EXTRACT(EPOCH FROM created_at) as created_at
                    FROM item_types
                    WHERE type = $1 AND path IS NOT NULL
                """
                args = [content_type]
                if actor_id:
                    query += " AND actor_id = $2"
                    args.append(actor_id)
                query += " ORDER BY path"
                rows = await conn.fetch(query, *args)
                return [
                    PathEntry(
                        path=row["path"],
                        cid=row["cid"],
                        content_type=row["type"],
                        actor_id=row["actor_id"],
                        description=row["description"],
                        created_at=row["created_at"] or 0,
                    )
                    for row in rows
                ]
            finally:
                await conn.close()

        try:
            return self._run_async(list_entries())
        except Exception as e:
            logger.warning(f"Failed to list from DB: {e}")
            return []

    def delete(self, path: str, content_type: str = None) -> bool:
        """
        Delete a path registration.

        Args:
            path: The path to delete
            content_type: Optional type filter

        Returns:
            True if deleted, False if not found
        """
        norm_path = self._normalize_path(path)

        # Resolve the cid first so the reverse mapping can be evicted too.
        cid = self.get_cid(norm_path, content_type)

        # Delete from database
        deleted = self._delete_from_db(norm_path, content_type)

        # Clean up Redis
        if deleted and cid and self._redis:
            try:
                self._redis.hdel(self._redis_path_to_cid_key, norm_path)
                self._redis.hdel(self._redis_cid_to_path_key, cid)
            except Exception as e:
                logger.warning(f"Failed to clean up Redis: {e}")

        return deleted

    def _delete_from_db(self, path: str, content_type: str = None) -> bool:
        """Delete from database. Returns True if any row was removed."""
        import database

        async def delete():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                if content_type:
                    result = await conn.execute(
                        "DELETE FROM item_types WHERE path = $1 AND type = $2",
                        path, content_type
                    )
                else:
                    result = await conn.execute(
                        "DELETE FROM item_types WHERE path = $1",
                        path
                    )
                # asyncpg returns a status tag like "DELETE 1"
                return "DELETE" in result
            finally:
                await conn.close()

        try:
            return self._run_async(delete())
        except Exception as e:
            logger.warning(f"Failed to delete from DB: {e}")
            return False

    def register_effect(
        self,
        path: str,
        cid: str,
        description: Optional[str] = None,
    ) -> PathEntry:
        """
        Convenience method to register an effect.

        Args:
            path: Effect path (e.g., "effects/ascii_fx_zone.sexp")
            cid: IPFS CID of the effect file
            description: Optional description

        Returns:
            The created PathEntry
        """
        return self.register(
            path=path,
            cid=cid,
            content_type="effect",
            actor_id=SYSTEM_ACTOR,
            description=description,
        )

    def get_effect_cid(self, path: str) -> Optional[str]:
        """
        Get CID for an effect by path.

        Args:
            path: Effect path

        Returns:
            IPFS CID or None
        """
        return self.get_cid(path, content_type="effect")

    def list_effects(self) -> List[PathEntry]:
        """List all registered effects."""
        return self.list_by_type("effect", actor_id=SYSTEM_ACTOR)


# Singleton instance
_registry: Optional[PathRegistry] = None


def get_path_registry() -> PathRegistry:
    """Get the singleton path registry instance.

    Lazily creates a Redis client from REDIS_URL (db 5 by default) on
    first use.
    """
    global _registry
    if _registry is None:
        import redis
        from urllib.parse import urlparse

        redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
        parsed = urlparse(redis_url)
        redis_client = redis.Redis(
            host=parsed.hostname or 'localhost',
            port=parsed.port or 6379,
            db=int(parsed.path.lstrip('/') or 0),
            socket_timeout=5,
            socket_connect_timeout=5
        )

        _registry = PathRegistry(redis_client=redis_client)
    return _registry


def reset_path_registry():
    """Reset the singleton (for testing)."""
    global _registry
    _registry = None

# diff --git a/sexp_effects/derived.sexp b/sexp_effects/derived.sexp new file mode 100644 index 0000000..7e1aae3 --- /dev/null +++ b/sexp_effects/derived.sexp @@ -0,0 +1,206 @@
# ;; Derived Operations
# ;;
# ;; These are built from true primitives using S-expressions.
+;; Load with: (require "derived") + +;; ============================================================================= +;; Math Helpers (derivable from where + basic ops) +;; ============================================================================= + +;; Absolute value +(define (abs x) (where (< x 0) (- x) x)) + +;; Minimum of two values +(define (min2 a b) (where (< a b) a b)) + +;; Maximum of two values +(define (max2 a b) (where (> a b) a b)) + +;; Clamp x to range [lo, hi] +(define (clamp x lo hi) (max2 lo (min2 hi x))) + +;; Square of x +(define (sq x) (* x x)) + +;; Linear interpolation: a*(1-t) + b*t +(define (lerp a b t) (+ (* a (- 1 t)) (* b t))) + +;; Smooth interpolation between edges +(define (smoothstep edge0 edge1 x) + (let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1))) + (* t (* t (- 3 (* 2 t)))))) + +;; ============================================================================= +;; Channel Shortcuts (derivable from channel primitive) +;; ============================================================================= + +;; Extract red channel as xector +(define (red frame) (channel frame 0)) + +;; Extract green channel as xector +(define (green frame) (channel frame 1)) + +;; Extract blue channel as xector +(define (blue frame) (channel frame 2)) + +;; Convert to grayscale xector (ITU-R BT.601) +(define (gray frame) + (+ (* (red frame) 0.299) + (* (green frame) 0.587) + (* (blue frame) 0.114))) + +;; Alias for gray +(define (luminance frame) (gray frame)) + +;; ============================================================================= +;; Coordinate Generators (derivable from iota + repeat/tile) +;; ============================================================================= + +;; X coordinate for each pixel [0, width) +(define (x-coords frame) (tile (iota (width frame)) (height frame))) + +;; Y coordinate for each pixel [0, height) +(define (y-coords frame) (repeat (iota (height frame)) (width frame))) + +;; Normalized X coordinate [0, 1] 
+(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1)))) + +;; Normalized Y coordinate [0, 1] +(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1)))) + +;; Distance from frame center for each pixel +(define (dist-from-center frame) + (let* ((cx (/ (width frame) 2)) + (cy (/ (height frame) 2)) + (dx (- (x-coords frame) cx)) + (dy (- (y-coords frame) cy))) + (sqrt (+ (sq dx) (sq dy))))) + +;; Normalized distance from center [0, ~1] +(define (dist-norm frame) + (let ((d (dist-from-center frame))) + (/ d (max2 1 (βmax d))))) + +;; ============================================================================= +;; Cell/Grid Operations (derivable from floor + basic math) +;; ============================================================================= + +;; Cell row index for each pixel +(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size))) + +;; Cell column index for each pixel +(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size))) + +;; Number of cell rows +(define (num-rows frame cell-size) (floor (/ (height frame) cell-size))) + +;; Number of cell columns +(define (num-cols frame cell-size) (floor (/ (width frame) cell-size))) + +;; Flat cell index for each pixel +(define (cell-indices frame cell-size) + (+ (* (cell-row frame cell-size) (num-cols frame cell-size)) + (cell-col frame cell-size))) + +;; Total number of cells +(define (num-cells frame cell-size) + (* (num-rows frame cell-size) (num-cols frame cell-size))) + +;; X position within cell [0, cell-size) +(define (local-x frame cell-size) (mod (x-coords frame) cell-size)) + +;; Y position within cell [0, cell-size) +(define (local-y frame cell-size) (mod (y-coords frame) cell-size)) + +;; Normalized X within cell [0, 1] +(define (local-x-norm frame cell-size) + (/ (local-x frame cell-size) (max2 1 (- cell-size 1)))) + +;; Normalized Y within cell [0, 1] +(define (local-y-norm frame cell-size) + (/ (local-y frame cell-size) (max2 1 (- 
cell-size 1)))) + +;; ============================================================================= +;; Fill Operations (derivable from iota) +;; ============================================================================= + +;; Xector of n zeros +(define (zeros n) (* (iota n) 0)) + +;; Xector of n ones +(define (ones n) (+ (zeros n) 1)) + +;; Xector of n copies of val +(define (fill val n) (+ (zeros n) val)) + +;; Xector of zeros matching x's length +(define (zeros-like x) (* x 0)) + +;; Xector of ones matching x's length +(define (ones-like x) (+ (zeros-like x) 1)) + +;; ============================================================================= +;; Pooling (derivable from group-reduce) +;; ============================================================================= + +;; Pool a channel by cell index +(define (pool-channel chan cell-idx num-cells) + (group-reduce chan cell-idx num-cells "mean")) + +;; Pool red channel to cells +(define (pool-red frame cell-size) + (pool-channel (red frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool green channel to cells +(define (pool-green frame cell-size) + (pool-channel (green frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool blue channel to cells +(define (pool-blue frame cell-size) + (pool-channel (blue frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; Pool grayscale to cells +(define (pool-gray frame cell-size) + (pool-channel (gray frame) + (cell-indices frame cell-size) + (num-cells frame cell-size))) + +;; ============================================================================= +;; Blending (derivable from math) +;; ============================================================================= + +;; Additive blend +(define (blend-add a b) (clamp (+ a b) 0 255)) + +;; Multiply blend (normalized) +(define (blend-multiply a b) (* (/ a 255) b)) + +;; Screen blend +(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 
255 b)))) + +;; Overlay blend +(define (blend-overlay a b) + (where (< a 128) + (* 2 (/ (* a b) 255)) + (- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255))))) + +;; ============================================================================= +;; Simple Effects (derivable from primitives) +;; ============================================================================= + +;; Invert a channel (255 - c) +(define (invert-channel c) (- 255 c)) + +;; Binary threshold +(define (threshold-channel c thresh) (where (> c thresh) 255 0)) + +;; Reduce to n levels +(define (posterize-channel c levels) + (let ((step (/ 255 (- levels 1)))) + (* (round (/ c step)) step))) diff --git a/sexp_effects/effects/ascii_art.sexp b/sexp_effects/effects/ascii_art.sexp index 5565872..0504768 100644 --- a/sexp_effects/effects/ascii_art.sexp +++ b/sexp_effects/effects/ascii_art.sexp @@ -5,7 +5,7 @@ :params ( (char_size :type int :default 8 :range [4 32]) (alphabet :type string :default "standard") - (color_mode :type string :default "color" :desc ""color", "mono", "invert", or any color name/hex") + (color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex") (background_color :type string :default "black" :desc "background color name/hex") (invert_colors :type int :default 0 :desc "swap foreground and background colors") (contrast :type float :default 1.5 :range [1 3]) diff --git a/sexp_effects/effects/cell_pattern.sexp b/sexp_effects/effects/cell_pattern.sexp new file mode 100644 index 0000000..bc503bb --- /dev/null +++ b/sexp_effects/effects/cell_pattern.sexp @@ -0,0 +1,65 @@ +;; Cell Pattern effect - custom patterns within cells +;; +;; Demonstrates building arbitrary per-cell visuals from primitives. +;; Uses local coordinates within cells to draw patterns scaled by luminance. 
+ +(require-primitives "xector") + +(define-effect cell_pattern + :params ( + (cell-size :type int :default 16 :range [8 48] :desc "Cell size") + (pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring") + ) + (let* ( + ;; Pool to get cell colors + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + (cell-lum (α/ (nth pooled 3) 255)) + + ;; Cell indices for each pixel + (cell-idx (cell-indices frame cell-size)) + + ;; Look up cell values for each pixel + (pix-r (gather cell-r cell-idx)) + (pix-g (gather cell-g cell-idx)) + (pix-b (gather cell-b cell-idx)) + (pix-lum (gather cell-lum cell-idx)) + + ;; Local position within cell [0, 1] + (lx (local-x-norm frame cell-size)) + (ly (local-y-norm frame cell-size)) + + ;; Pattern mask based on pattern type + (mask + (cond + ;; Diagonal lines - thickness based on luminance + ((= pattern "diagonal") + (let* ((diag (αmod (α+ lx ly) 0.25)) + (thickness (α* pix-lum 0.125))) + (α< diag thickness))) + + ;; Cross pattern + ((= pattern "cross") + (let* ((cx (αabs (α- lx 0.5))) + (cy (αabs (α- ly 0.5))) + (thickness (α* pix-lum 0.25))) + (αor (α< cx thickness) (α< cy thickness)))) + + ;; Ring pattern + ((= pattern "ring") + (let* ((dx (α- lx 0.5)) + (dy (α- ly 0.5)) + (dist (αsqrt (α+ (α² dx) (α² dy)))) + (target (α* pix-lum 0.4)) + (thickness 0.05)) + (α< (αabs (α- dist target)) thickness))) + + ;; Default: solid + (else (α> pix-lum 0))))) + + ;; Apply mask: show cell color where mask is true, black elsewhere + (rgb (where mask pix-r 0) + (where mask pix-g 0) + (where mask pix-b 0)))) diff --git a/sexp_effects/effects/echo.sexp b/sexp_effects/effects/echo.sexp index 2aa2287..599a1d6 100644 --- a/sexp_effects/effects/echo.sexp +++ b/sexp_effects/effects/echo.sexp @@ -6,10 +6,10 @@ (num_echoes :type int :default 4 :range [1 20]) (decay :type float :default 0.5 :range [0 1]) ) - (let* ((buffer (state-get 'buffer (list))) + (let* ((buffer 
(state-get "buffer" (list))) (new-buffer (take (cons frame buffer) (+ num_echoes 1)))) (begin - (state-set 'buffer new-buffer) + (state-set "buffer" new-buffer) ;; Blend frames with decay (if (< (length new-buffer) 2) frame diff --git a/sexp_effects/effects/halftone.sexp b/sexp_effects/effects/halftone.sexp new file mode 100644 index 0000000..2190a4a --- /dev/null +++ b/sexp_effects/effects/halftone.sexp @@ -0,0 +1,49 @@ +;; Halftone/dot effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell luminances +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel +;; local-x/y-norm - position within cell [0,1] +;; where - conditional per-pixel + +(require-primitives "xector") + +(define-effect halftone + :params ( + (cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells") + (dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius") + (invert :type bool :default false :desc "Invert (white dots on black)") + ) + (let* ( + ;; Pool frame to get luminance per cell + (pooled (pool-frame frame cell-size)) + (cell-lum (nth pooled 3)) ; luminance is 4th element + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Get cell luminance for each pixel + (pixel-lum (α/ (gather cell-lum cell-idx) 255)) + + ;; Position within cell, normalized to [-0.5, 0.5] + (lx (α- (local-x-norm frame cell-size) 0.5)) + (ly (α- (local-y-norm frame cell-size) 0.5)) + + ;; Distance from cell center (0 at center, ~0.7 at corners) + (dist (αsqrt (α+ (α² lx) (α² ly)))) + + ;; Radius based on luminance (brighter = bigger dot) + (radius (α* (if invert (α- 1 pixel-lum) pixel-lum) + (α* dot-scale 0.5))) + + ;; Is this pixel inside the dot? 
+ (inside (α< dist radius)) + + ;; Output color + (fg (if invert 255 0)) + (bg (if invert 0 255)) + (out (where inside fg bg))) + + ;; Grayscale output + (rgb out out out))) diff --git a/sexp_effects/effects/mosaic.sexp b/sexp_effects/effects/mosaic.sexp new file mode 100644 index 0000000..5de07de --- /dev/null +++ b/sexp_effects/effects/mosaic.sexp @@ -0,0 +1,30 @@ +;; Mosaic effect - built from primitive xector operations +;; +;; Uses: +;; pool-frame - downsample to cell averages +;; cell-indices - which cell each pixel belongs to +;; gather - look up cell value for each pixel + +(require-primitives "xector") + +(define-effect mosaic + :params ( + (cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells") + ) + (let* ( + ;; Pool frame to get average color per cell (returns r,g,b,lum xectors) + (pooled (pool-frame frame cell-size)) + (cell-r (nth pooled 0)) + (cell-g (nth pooled 1)) + (cell-b (nth pooled 2)) + + ;; For each output pixel, get its cell index + (cell-idx (cell-indices frame cell-size)) + + ;; Gather: look up cell color for each pixel + (out-r (gather cell-r cell-idx)) + (out-g (gather cell-g cell-idx)) + (out-b (gather cell-b cell-idx))) + + ;; Reconstruct frame + (rgb out-r out-g out-b))) diff --git a/sexp_effects/effects/outline.sexp b/sexp_effects/effects/outline.sexp index 276f891..921a0b8 100644 --- a/sexp_effects/effects/outline.sexp +++ b/sexp_effects/effects/outline.sexp @@ -5,9 +5,9 @@ :params ( (thickness :type int :default 2 :range [1 10]) (threshold :type int :default 100 :range [20 300]) - (color :type list :default (list 0 0 0) + (color :type list :default (list 0 0 0)) + (fill_mode :type string :default "original") ) - (fill_mode "original")) (let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold)) (dilated (if (> thickness 1) (dilate edge-img thickness) diff --git a/sexp_effects/effects/strobe.sexp b/sexp_effects/effects/strobe.sexp index e51ba30..2bf80b4 100644 --- a/sexp_effects/effects/strobe.sexp +++ 
b/sexp_effects/effects/strobe.sexp @@ -5,12 +5,12 @@ :params ( (frame_rate :type int :default 12 :range [1 60]) ) - (let* ((held (state-get 'held nil)) - (held-until (state-get 'held-until 0)) + (let* ((held (state-get "held" nil)) + (held-until (state-get "held-until" 0)) (frame-duration (/ 1 frame_rate))) (if (or (core:is-nil held) (>= t held-until)) (begin - (state-set 'held (copy frame)) - (state-set 'held-until (+ t frame-duration)) + (state-set "held" (copy frame)) + (state-set "held-until" (+ t frame-duration)) frame) held))) diff --git a/sexp_effects/effects/trails.sexp b/sexp_effects/effects/trails.sexp index f16c302..5c0fc7c 100644 --- a/sexp_effects/effects/trails.sexp +++ b/sexp_effects/effects/trails.sexp @@ -5,16 +5,16 @@ :params ( (persistence :type float :default 0.8 :range [0 0.99]) ) - (let* ((buffer (state-get 'buffer nil)) + (let* ((buffer (state-get "buffer" nil)) (current frame)) (if (= buffer nil) (begin - (state-set 'buffer (copy frame)) + (state-set "buffer" (copy frame)) frame) (let* ((faded (blending:blend-images buffer (make-image (image:width frame) (image:height frame) (list 0 0 0)) (- 1 persistence))) (result (blending:blend-mode faded current "lighten"))) (begin - (state-set 'buffer result) + (state-set "buffer" result) result))))) diff --git a/sexp_effects/effects/xector_feathered_blend.sexp b/sexp_effects/effects/xector_feathered_blend.sexp new file mode 100644 index 0000000..96224fb --- /dev/null +++ b/sexp_effects/effects/xector_feathered_blend.sexp @@ -0,0 +1,44 @@ +;; Feathered blend - blend two same-size frames with distance-based falloff +;; Center shows overlay, edges show background, with smooth transition + +(require-primitives "xector") + +(define-effect xector_feathered_blend + :params ( + (inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)") + (fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)") + (overlay :type frame 
:default nil :desc "Frame to blend in center") + ) + (let* ( + ;; Get normalized distance from center (0 at center, ~1 at corners) + (dist (dist-from-center frame)) + (max-dist (βmax dist)) + (dist-norm (α/ dist max-dist)) + + ;; Calculate blend factor: + ;; - 1.0 when dist-norm < inner-radius (fully overlay) + ;; - 0.0 when dist-norm > inner-radius + fade-width (fully background) + ;; - linear ramp between + (t (α/ (α- dist-norm inner-radius) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + (inv-blend (α- 1 blend)) + + ;; Background channels + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + (if (nil? overlay) + ;; No overlay - visualize the blend mask + (let ((vis (α* blend 255))) + (rgb vis vis vis)) + + ;; Blend overlay with background using the mask + (let* ((ov-r (red overlay)) + (ov-g (green overlay)) + (ov-b (blue overlay)) + ;; lerp: bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r inv-blend) (α* ov-r blend))) + (g-out (α+ (α* bg-g inv-blend) (α* ov-g blend))) + (b-out (α+ (α* bg-b inv-blend) (α* ov-b blend)))) + (rgb r-out g-out b-out))))) diff --git a/sexp_effects/effects/xector_grain.sexp b/sexp_effects/effects/xector_grain.sexp new file mode 100644 index 0000000..64ebfa6 --- /dev/null +++ b/sexp_effects/effects/xector_grain.sexp @@ -0,0 +1,34 @@ +;; Film grain effect using xector operations +;; Demonstrates random xectors and mixing scalar/xector math + +(require-primitives "xector") + +(define-effect xector_grain + :params ( + (intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity") + (colored :type bool :default false :desc "Use colored grain") + ) + (let* ( + ;; Extract channels + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Generate noise xector(s) + ;; randn-x generates normal distribution noise + (grain-amount (* intensity 50))) + + (if colored + ;; Colored grain: different noise per channel + (let* ((nr (randn-x frame 0 grain-amount)) + (ng (randn-x frame 0 grain-amount)) + (nb (randn-x 
frame 0 grain-amount))) + (rgb (αclamp (α+ r nr) 0 255) + (αclamp (α+ g ng) 0 255) + (αclamp (α+ b nb) 0 255))) + + ;; Monochrome grain: same noise for all channels + (let ((n (randn-x frame 0 grain-amount))) + (rgb (αclamp (α+ r n) 0 255) + (αclamp (α+ g n) 0 255) + (αclamp (α+ b n) 0 255)))))) diff --git a/sexp_effects/effects/xector_inset_blend.sexp b/sexp_effects/effects/xector_inset_blend.sexp new file mode 100644 index 0000000..597e23a --- /dev/null +++ b/sexp_effects/effects/xector_inset_blend.sexp @@ -0,0 +1,57 @@ +;; Inset blend - fade a smaller frame into a larger background +;; Uses distance-based alpha for smooth transition (no hard edges) + +(require-primitives "xector") + +(define-effect xector_inset_blend + :params ( + (x :type int :default 0 :desc "X position of inset") + (y :type int :default 0 :desc "Y position of inset") + (fade-width :type int :default 50 :desc "Width of fade region in pixels") + (overlay :type frame :default nil :desc "The smaller frame to inset") + ) + (let* ( + ;; Get dimensions + (bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height + (bg-w bg-h) ;; placeholder + + ;; For now, create a simple centered circular blend + ;; Distance from center of overlay position + (cx (+ x (/ (- bg-w (* 2 x)) 2))) + (cy (+ y (/ (- bg-h (* 2 y)) 2))) + + ;; Get coordinates as xectors + (px (x-coords frame)) + (py (y-coords frame)) + + ;; Distance from center + (dx (α- px cx)) + (dy (α- py cy)) + (dist (αsqrt (α+ (α* dx dx) (α* dy dy)))) + + ;; Inner radius (fully overlay) and outer radius (fully background) + (inner-r (- (/ bg-w 2) x fade-width)) + (outer-r (- (/ bg-w 2) x)) + + ;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between + (t (α/ (α- dist inner-r) fade-width)) + (blend (α- 1 (αclamp t 0 1))) + + ;; Extract channels from both frames + (bg-r (red frame)) + (bg-g (green frame)) + (bg-b (blue frame))) + + ;; If overlay provided, blend it + (if overlay + (let* ((ov-r (red overlay)) + (ov-g (green 
overlay)) + (ov-b (blue overlay)) + ;; Linear blend: result = bg * (1-blend) + overlay * blend + (r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend))) + (g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend))) + (b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend)))) + (rgb r-out g-out b-out)) + ;; No overlay - just show the blend mask for debugging + (let ((mask-vis (α* blend 255))) + (rgb mask-vis mask-vis mask-vis))))) diff --git a/sexp_effects/effects/xector_threshold.sexp b/sexp_effects/effects/xector_threshold.sexp new file mode 100644 index 0000000..c571468 --- /dev/null +++ b/sexp_effects/effects/xector_threshold.sexp @@ -0,0 +1,27 @@ +;; Threshold effect using xector operations +;; Demonstrates where (conditional select) and β (reduction) for normalization + +(require-primitives "xector") + +(define-effect xector_threshold + :params ( + (threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)") + (invert :type bool :default false :desc "Invert the threshold") + ) + (let* ( + ;; Get grayscale luminance as xector + (luma (gray frame)) + + ;; Normalize to 0-1 range + (luma-norm (α/ luma 255)) + + ;; Create boolean mask: pixels above threshold + (mask (if invert + (α< luma-norm threshold) + (α>= luma-norm threshold))) + + ;; Use where to select: white (255) if above threshold, black (0) if below + (out (where mask 255 0))) + + ;; Output as grayscale (same value for R, G, B) + (rgb out out out))) diff --git a/sexp_effects/effects/xector_vignette.sexp b/sexp_effects/effects/xector_vignette.sexp new file mode 100644 index 0000000..d654ca7 --- /dev/null +++ b/sexp_effects/effects/xector_vignette.sexp @@ -0,0 +1,36 @@ +;; Vignette effect using xector operations +;; Demonstrates α (element-wise) and β (reduction) patterns + +(require-primitives "xector") + +(define-effect xector_vignette + :params ( + (strength :type float :default 0.5 :range [0 1]) + (radius :type float :default 1.0 :range [0.5 2]) + ) + (let* ( + ;; Get normalized distance from 
center for each pixel + (dist (dist-from-center frame)) + + ;; Calculate max distance (corner distance) + (max-dist (* (βmax dist) radius)) + + ;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength) + ;; Using explicit α operators + (factor (α- 1 (α* (α/ dist max-dist) strength))) + + ;; Clamp factor to [0, 1] + (factor (αclamp factor 0 1)) + + ;; Extract channels as xectors + (r (red frame)) + (g (green frame)) + (b (blue frame)) + + ;; Apply factor to each channel (implicit element-wise via Xector operators) + (r-out (* r factor)) + (g-out (* g factor)) + (b-out (* b factor))) + + ;; Combine back to frame + (rgb r-out g-out b-out))) diff --git a/sexp_effects/interpreter.py b/sexp_effects/interpreter.py index 830904a..406f6da 100644 --- a/sexp_effects/interpreter.py +++ b/sexp_effects/interpreter.py @@ -156,11 +156,21 @@ class Interpreter: if form == 'define': name = expr[1] if _is_symbol(name): + # Simple define: (define name value) value = self.eval(expr[2], env) self.global_env.set(name.name, value) return value + elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]): + # Function define: (define (fn-name args...) body) + # Desugars to: (define fn-name (lambda (args...) 
body)) + fn_name = name[0].name + params = [p.name if _is_symbol(p) else p for p in name[1:]] + body = expr[2] + fn = Lambda(params, body, env) + self.global_env.set(fn_name, fn) + return fn else: - raise SyntaxError(f"define requires symbol, got {name}") + raise SyntaxError(f"define requires symbol or (name args...), got {name}") # Define-effect if form == 'define-effect': @@ -276,6 +286,10 @@ class Interpreter: if form == 'require-primitives': return self._eval_require_primitives(expr, env) + # require - load .sexp file into current scope + if form == 'require': + return self._eval_require(expr, env) + # Function call fn = self.eval(head, env) args = [self.eval(arg, env) for arg in expr[1:]] @@ -488,6 +502,61 @@ class Interpreter: from .primitive_libs import load_primitive_library return load_primitive_library(name, path) + def _eval_require(self, expr: Any, env: Environment) -> Any: + """ + Evaluate require: load a .sexp file and evaluate its definitions. + + Syntax: + (require "derived") ; loads derived.sexp from sexp_effects/ + (require "path/to/file.sexp") ; loads from explicit path + + Definitions from the file are added to the current environment. 
+ """ + for lib_expr in expr[1:]: + if _is_symbol(lib_expr): + lib_name = lib_expr.name + else: + lib_name = lib_expr + + # Find the .sexp file + sexp_path = self._find_sexp_file(lib_name) + if sexp_path is None: + raise ValueError(f"Cannot find sexp file: {lib_name}") + + # Parse and evaluate the file + content = parse_file(sexp_path) + + # Evaluate all top-level expressions + if isinstance(content, list) and content and isinstance(content[0], list): + for e in content: + self.eval(e, env) + else: + self.eval(content, env) + + return None + + def _find_sexp_file(self, name: str) -> Optional[str]: + """Find a .sexp file by name.""" + # Try various locations + candidates = [ + # Explicit path + name, + name + '.sexp', + # In sexp_effects directory + Path(__file__).parent / f'{name}.sexp', + Path(__file__).parent / name, + # In effects directory + Path(__file__).parent / 'effects' / f'{name}.sexp', + Path(__file__).parent / 'effects' / name, + ] + + for path in candidates: + p = Path(path) if not isinstance(path, Path) else path + if p.exists() and p.is_file(): + return str(p) + + return None + def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any: """ Evaluate ascii-fx-zone special form. 
@@ -876,8 +945,8 @@ class Interpreter: for pname, pdefault in effect.params.items(): value = params.get(pname) if value is None: - # Evaluate default if it's an expression (list) - if isinstance(pdefault, list): + # Evaluate default if it's an expression (list) or a symbol (like 'nil') + if isinstance(pdefault, list) or _is_symbol(pdefault): value = self.eval(pdefault, env) else: value = pdefault diff --git a/sexp_effects/parser.py b/sexp_effects/parser.py index 215d714..5e17565 100644 --- a/sexp_effects/parser.py +++ b/sexp_effects/parser.py @@ -71,7 +71,8 @@ class Tokenizer: STRING = re.compile(r'"(?:[^"\\]|\\.)*"') NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?') KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*') - SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*') + # Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations + SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*') def __init__(self, text: str): self.text = text diff --git a/sexp_effects/primitive_libs/drawing.py b/sexp_effects/primitive_libs/drawing.py index ddd1a01..50e0c45 100644 --- a/sexp_effects/primitive_libs/drawing.py +++ b/sexp_effects/primitive_libs/drawing.py @@ -1,126 +1,680 @@ """ Drawing Primitives Library -Draw shapes, text, and characters on images. +Draw shapes, text, and characters on images with sophisticated text handling. 
+ +Text Features: +- Font loading from files or system fonts +- Text measurement and fitting +- Alignment (left/center/right, top/middle/bottom) +- Opacity for fade effects +- Multi-line text support +- Shadow and outline effects """ import numpy as np import cv2 from PIL import Image, ImageDraw, ImageFont +import os +import glob as glob_module +from typing import Optional, Tuple, List, Union -# Default font (will be loaded lazily) -_default_font = None +# ============================================================================= +# Font Management +# ============================================================================= + +# Font cache: (path, size) -> font object +_font_cache = {} + +# Common system font directories +FONT_DIRS = [ + "/usr/share/fonts", + "/usr/local/share/fonts", + "~/.fonts", + "~/.local/share/fonts", + "/System/Library/Fonts", # macOS + "/Library/Fonts", # macOS + "C:/Windows/Fonts", # Windows +] + +# Default fonts to try (in order of preference) +DEFAULT_FONTS = [ + "DejaVuSans.ttf", + "DejaVuSansMono.ttf", + "Arial.ttf", + "Helvetica.ttf", + "FreeSans.ttf", + "LiberationSans-Regular.ttf", +] -def _get_default_font(size=16): - """Get default font, creating if needed.""" - global _default_font - if _default_font is None or _default_font.size != size: +def _find_font_file(name: str) -> Optional[str]: + """Find a font file by name in system directories.""" + # If it's already a full path + if os.path.isfile(name): + return name + + # Expand user paths + expanded = os.path.expanduser(name) + if os.path.isfile(expanded): + return expanded + + # Search in font directories + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + # Direct match + direct = os.path.join(font_dir, name) + if os.path.isfile(direct): + return direct + + # Recursive search + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower() == name.lower(): + return os.path.join(root, f) + # 
Also match without extension + base = os.path.splitext(f)[0] + if base.lower() == name.lower(): + return os.path.join(root, f) + + return None + + +def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont: + """Get a default font at the given size.""" + for font_name in DEFAULT_FONTS: + path = _find_font_file(font_name) + if path: + try: + return ImageFont.truetype(path, size) + except: + continue + + # Last resort: PIL default + return ImageFont.load_default() + + +def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont: + """ + Load a font by name or path. + + (make-font "Arial" 32) ; system font by name + (make-font "/path/to/font.ttf" 24) ; font file path + (make-font "DejaVuSans" 48) ; searches common locations + + Returns a font object for use with text primitives. + """ + size = int(size) + + # Check cache + cache_key = (name_or_path, size) + if cache_key in _font_cache: + return _font_cache[cache_key] + + # Find the font file + path = _find_font_file(name_or_path) + if not path: + raise FileNotFoundError(f"Font not found: {name_or_path}") + + # Load and cache + font = ImageFont.truetype(path, size) + _font_cache[cache_key] = font + return font + + +def prim_list_fonts() -> List[str]: + """ + List available system fonts. + + (list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...) + + Returns list of font filenames found in system directories. + """ + fonts = set() + + for font_dir in FONT_DIRS: + font_dir = os.path.expanduser(font_dir) + if not os.path.isdir(font_dir): + continue + + for root, dirs, files in os.walk(font_dir): + for f in files: + if f.lower().endswith(('.ttf', '.otf', '.ttc')): + fonts.add(f) + + return sorted(fonts) + + +def prim_font_size(font: ImageFont.FreeTypeFont) -> int: + """ + Get the size of a font. 
+ + (font-size my-font) ; -> 24 + """ + return font.size + + +# ============================================================================= +# Text Measurement +# ============================================================================= + +def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]: + """ + Measure text dimensions. + + (text-size "Hello" my-font) ; -> (width height) + (text-size "Hello" :font-size 32) ; -> (width height) with default font + + For multi-line text, returns total bounding box. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + # Create temporary image for measurement + img = Image.new('RGB', (1, 1)) + draw = ImageDraw.Draw(img) + + bbox = draw.textbbox((0, 0), str(text), font=font) + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + + return (width, height) + + +def prim_text_metrics(font=None, font_size: int = 24) -> dict: + """ + Get font metrics. + + (text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25} + + Useful for precise text layout. + """ + if font is None: + font = _get_default_font(int(font_size)) + elif isinstance(font, (int, float)): + font = _get_default_font(int(font)) + + ascent, descent = font.getmetrics() + return { + 'ascent': ascent, + 'descent': descent, + 'height': ascent + descent, + 'size': font.size, + } + + +def prim_fit_text_size(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> int: + """ + Calculate font size to fit text within bounds. + + (fit-text-size "Hello World" 400 100) ; -> 48 + (fit-text-size "Title" 800 200 :font-name "Arial") + + Returns the largest font size that fits within max_width x max_height. 
+ """ + max_width = int(max_width) + max_height = int(max_height) + min_size = int(min_size) + max_size = int(max_size) + text = str(text) + + # Binary search for optimal size + best_size = min_size + low, high = min_size, max_size + + while low <= high: + mid = (low + high) // 2 + + if font_name: + try: + font = prim_make_font(font_name, mid) + except: + font = _get_default_font(mid) + else: + font = _get_default_font(mid) + + w, h = prim_text_size(text, font) + + if w <= max_width and h <= max_height: + best_size = mid + low = mid + 1 + else: + high = mid - 1 + + return best_size + + +def prim_fit_font(text: str, max_width: int, max_height: int, + font_name: str = None, min_size: int = 8, + max_size: int = 500) -> ImageFont.FreeTypeFont: + """ + Create a font sized to fit text within bounds. + + (fit-font "Hello World" 400 100) ; -> font object + (fit-font "Title" 800 200 :font-name "Arial") + + Returns a font object at the optimal size. + """ + size = prim_fit_text_size(text, max_width, max_height, + font_name, min_size, max_size) + + if font_name: try: - _default_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) + return prim_make_font(font_name, size) except: - _default_font = ImageFont.load_default() - return _default_font + pass + return _get_default_font(size) + + +# ============================================================================= +# Text Drawing +# ============================================================================= + +def prim_text(img: np.ndarray, text: str, + x: int = None, y: int = None, + width: int = None, height: int = None, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "left", valign: str = "top", + fit: bool = False, + shadow: bool = False, shadow_color=None, shadow_offset: int = 2, + outline: bool = False, outline_color=None, outline_width: int = 1, + line_spacing: float = 1.2) -> np.ndarray: + """ + Draw text with alignment, 
opacity, and effects. + + Basic usage: + (text frame "Hello" :x 100 :y 50) + + Centered in frame: + (text frame "Title" :align "center" :valign "middle") + + Fit to box: + (text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true) + + With fade (for animations): + (text frame "Fading" :x 100 :y 100 :opacity 0.5) + + With effects: + (text frame "Shadow" :x 100 :y 100 :shadow true) + (text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0)) + + Args: + img: Input frame + text: Text to draw + x, y: Position (if not specified, uses alignment in full frame) + width, height: Bounding box (for fit and alignment within box) + font: Font object from make-font + font_size: Size if no font specified + font_name: Font name to load + color: RGB tuple (default white) + opacity: 0.0 (invisible) to 1.0 (opaque) for fading + align: "left", "center", "right" + valign: "top", "middle", "bottom" + fit: If true, auto-size font to fit in box + shadow: Draw drop shadow + shadow_color: Shadow color (default black) + shadow_offset: Shadow offset in pixels + outline: Draw text outline + outline_color: Outline color (default black) + outline_width: Outline thickness + line_spacing: Multiplier for line height (for multi-line) + + Returns: + Frame with text drawn + """ + h, w = img.shape[:2] + text = str(text) + + # Default colors + if color is None: + color = (255, 255, 255) + else: + color = tuple(int(c) for c in color) + + if shadow_color is None: + shadow_color = (0, 0, 0) + else: + shadow_color = tuple(int(c) for c in shadow_color) + + if outline_color is None: + outline_color = (0, 0, 0) + else: + outline_color = tuple(int(c) for c in outline_color) + + # Determine bounding box + if x is None: + x = 0 + if width is None: + width = w + if y is None: + y = 0 + if height is None: + height = h + + x, y = int(x), int(y) + box_width = int(width) if width else w - x + box_height = int(height) if height else h - y + + # Get or create font + if font is None: + if fit: + 
font = prim_fit_font(text, box_width, box_height, font_name) + elif font_name: + try: + font = prim_make_font(font_name, int(font_size)) + except: + font = _get_default_font(int(font_size)) + else: + font = _get_default_font(int(font_size)) + + # Measure text + text_w, text_h = prim_text_size(text, font) + + # Calculate position based on alignment + if align == "center": + draw_x = x + (box_width - text_w) // 2 + elif align == "right": + draw_x = x + box_width - text_w + else: # left + draw_x = x + + if valign == "middle": + draw_y = y + (box_height - text_h) // 2 + elif valign == "bottom": + draw_y = y + box_height - text_h + else: # top + draw_y = y + + # Create RGBA image for compositing with opacity + pil_img = Image.fromarray(img).convert('RGBA') + + # Create text layer with transparency + text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(text_layer) + + # Draw shadow first (if enabled) + if shadow: + shadow_x = draw_x + shadow_offset + shadow_y = draw_y + shadow_offset + shadow_rgba = shadow_color + (int(255 * opacity * 0.5),) + draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font) + + # Draw outline (if enabled) + if outline: + outline_rgba = outline_color + (int(255 * opacity),) + ow = int(outline_width) + for dx in range(-ow, ow + 1): + for dy in range(-ow, ow + 1): + if dx != 0 or dy != 0: + draw.text((draw_x + dx, draw_y + dy), text, + fill=outline_rgba, font=font) + + # Draw main text + text_rgba = color + (int(255 * opacity),) + draw.text((draw_x, draw_y), text, fill=text_rgba, font=font) + + # Composite + result = Image.alpha_composite(pil_img, text_layer) + return np.array(result.convert('RGB')) + + +def prim_text_box(img: np.ndarray, text: str, + x: int, y: int, width: int, height: int, + font=None, font_size: int = 24, font_name: str = None, + color=None, opacity: float = 1.0, + align: str = "center", valign: str = "middle", + fit: bool = True, + padding: int = 0, + background=None, background_opacity: float 
= 0.5, + **kwargs) -> np.ndarray: + """ + Draw text fitted within a box, optionally with background. + + (text-box frame "Title" 50 50 400 100) + (text-box frame "Subtitle" 50 160 400 50 + :background (0 0 0) :background-opacity 0.7) + + Convenience wrapper around text() for common box-with-text pattern. + """ + x, y = int(x), int(y) + width, height = int(width), int(height) + padding = int(padding) + + result = img.copy() + + # Draw background if specified + if background is not None: + bg_color = tuple(int(c) for c in background) + + # Create background with opacity + pil_img = Image.fromarray(result).convert('RGBA') + bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + bg_draw = ImageDraw.Draw(bg_layer) + bg_rgba = bg_color + (int(255 * background_opacity),) + bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba) + result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB')) + + # Draw text within padded box + return prim_text(result, text, + x=x + padding, y=y + padding, + width=width - 2 * padding, height=height - 2 * padding, + font=font, font_size=font_size, font_name=font_name, + color=color, opacity=opacity, + align=align, valign=valign, fit=fit, + **kwargs) + + +# ============================================================================= +# Legacy text functions (keep for compatibility) +# ============================================================================= def prim_draw_char(img, char, x, y, font_size=16, color=None): - """Draw a single character at (x, y).""" - if color is None: - color = [255, 255, 255] - - pil_img = Image.fromarray(img) - draw = ImageDraw.Draw(pil_img) - font = _get_default_font(font_size) - draw.text((x, y), char, fill=tuple(color), font=font) - return np.array(pil_img) + """Draw a single character at (x, y). 
Legacy function.""" + return prim_text(img, str(char), x=int(x), y=int(y), + font_size=int(font_size), color=color) def prim_draw_text(img, text, x, y, font_size=16, color=None): - """Draw text string at (x, y).""" + """Draw text string at (x, y). Legacy function.""" + return prim_text(img, str(text), x=int(x), y=int(y), + font_size=int(font_size), color=color) + + +# ============================================================================= +# Shape Drawing +# ============================================================================= + +def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0): + """ + Fill a rectangle with color. + + (fill-rect frame 10 10 100 50 (255 0 0)) + (fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5) + """ if color is None: color = [255, 255, 255] - pil_img = Image.fromarray(img) - draw = ImageDraw.Draw(pil_img) - font = _get_default_font(font_size) - draw.text((x, y), text, fill=tuple(color), font=font) - return np.array(pil_img) - - -def prim_fill_rect(img, x, y, w, h, color=None): - """Fill a rectangle with color.""" - if color is None: - color = [255, 255, 255] - - result = img.copy() x, y, w, h = int(x), int(y), int(w), int(h) - result[y:y+h, x:x+w] = color - return result + + if opacity >= 1.0: + result = img.copy() + result[y:y+h, x:x+w] = color + return result + + # With opacity, use alpha compositing + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([x, y, x + w, y + h], fill=fill_rgba) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) -def prim_draw_rect(img, x, y, w, h, color=None, thickness=1): +def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0): """Draw rectangle outline.""" if color is None: color = [255, 255, 255] - result = img.copy() - 
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), - tuple(color), thickness) - return result + if opacity >= 1.0: + result = img.copy() + cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.rectangle([int(x), int(y), int(x+w), int(y+h)], + outline=outline_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) -def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1): +def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0): """Draw a line from (x1, y1) to (x2, y2).""" if color is None: color = [255, 255, 255] - result = img.copy() - cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), - tuple(color), thickness) - return result + if opacity >= 1.0: + result = img.copy() + cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + draw.line([(int(x1), int(y1)), (int(x2), int(y2))], + fill=line_rgba, width=int(thickness)) + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) -def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, fill=False): +def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, + fill=False, opacity: float = 1.0): """Draw a circle.""" if color is None: color = [255, 255, 255] - result = img.copy() - t = -1 if fill else thickness - cv2.circle(result, 
(int(cx), int(cy)), int(radius), tuple(color), t) - return result + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.circle(result, (int(cx), int(cy)), int(radius), + tuple(int(c) for c in color), t) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + cx, cy, r = int(cx), int(cy), int(radius) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + + if fill: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba) + else: + draw.ellipse([cx - r, cy - r, cx + r, cy + r], + outline=rgba, width=int(thickness)) + + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) -def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, thickness=1, fill=False): +def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, + thickness=1, fill=False, opacity: float = 1.0): """Draw an ellipse.""" if color is None: color = [255, 255, 255] - result = img.copy() - t = -1 if fill else thickness - cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), - angle, 0, 360, tuple(color), t) - return result + if opacity >= 1.0: + result = img.copy() + t = -1 if fill else int(thickness) + cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, tuple(int(c) for c in color), t) + return result + + # With opacity (note: PIL doesn't support rotated ellipses easily) + # Fall back to cv2 on a separate layer + layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8) + t = -1 if fill else int(thickness) + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) + cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)), + float(angle), 0, 360, rgba, t) + + pil_img = Image.fromarray(img).convert('RGBA') + pil_layer = Image.fromarray(layer) + result = Image.alpha_composite(pil_img, pil_layer) + return np.array(result.convert('RGB')) 
-def prim_draw_polygon(img, points, color=None, thickness=1, fill=False): +def prim_draw_polygon(img, points, color=None, thickness=1, + fill=False, opacity: float = 1.0): """Draw a polygon from list of [x, y] points.""" if color is None: color = [255, 255, 255] - result = img.copy() - pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + if opacity >= 1.0: + result = img.copy() + pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + if fill: + cv2.fillPoly(result, [pts], tuple(int(c) for c in color)) + else: + cv2.polylines(result, [pts], True, + tuple(int(c) for c in color), int(thickness)) + return result + + # With opacity + pil_img = Image.fromarray(img).convert('RGBA') + layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + + pts_flat = [(int(p[0]), int(p[1])) for p in points] + rgba = tuple(int(c) for c in color) + (int(255 * opacity),) if fill: - cv2.fillPoly(result, [pts], tuple(color)) + draw.polygon(pts_flat, fill=rgba) else: - cv2.polylines(result, [pts], True, tuple(color), thickness) + draw.polygon(pts_flat, outline=rgba, width=int(thickness)) - return result + result = Image.alpha_composite(pil_img, layer) + return np.array(result.convert('RGB')) +# ============================================================================= +# PRIMITIVES Export +# ============================================================================= + PRIMITIVES = { - # Text + # Font management + 'make-font': prim_make_font, + 'list-fonts': prim_list_fonts, + 'font-size': prim_font_size, + + # Text measurement + 'text-size': prim_text_size, + 'text-metrics': prim_text_metrics, + 'fit-text-size': prim_fit_text_size, + 'fit-font': prim_fit_font, + + # Text drawing + 'text': prim_text, + 'text-box': prim_text_box, + + # Legacy text (compatibility) 'draw-char': prim_draw_char, 'draw-text': prim_draw_text, diff --git a/sexp_effects/primitive_libs/streaming.py b/sexp_effects/primitive_libs/streaming.py index 
9092087..ccb6056 100644 --- a/sexp_effects/primitive_libs/streaming.py +++ b/sexp_effects/primitive_libs/streaming.py @@ -8,12 +8,18 @@ GPU Acceleration: - Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU) - Hardware video decoding (NVDEC) is used when available - Dramatically improves performance on GPU nodes + +Async Prefetching: +- Set STREAMING_PREFETCH=1 to enable background frame prefetching +- Decodes upcoming frames while current frame is being processed """ import os import numpy as np import subprocess import json +import threading +from collections import deque from pathlib import Path # Try to import CuPy for GPU acceleration @@ -28,6 +34,10 @@ except ImportError: # Disabled by default until all primitives support GPU frames GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE +# Async prefetch mode - decode frames in background thread +PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1" +PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10")) + # Check for hardware decode support (cached) _HWDEC_AVAILABLE = None @@ -283,6 +293,122 @@ class VideoSource: self._proc = None +class PrefetchingVideoSource: + """ + Video source with background prefetching for improved performance. + + Wraps VideoSource and adds a background thread that pre-decodes + upcoming frames while the main thread processes the current frame. 
+ """ + + def __init__(self, path: str, fps: float = 30, buffer_size: int = None): + self._source = VideoSource(path, fps) + self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE + self._buffer = {} # time -> frame + self._buffer_lock = threading.Lock() + self._prefetch_time = 0.0 + self._frame_time = 1.0 / fps + self._stop_event = threading.Event() + self._request_event = threading.Event() + self._target_time = 0.0 + + # Start prefetch thread + self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) + self._thread.start() + + import sys + print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr) + + def _prefetch_loop(self): + """Background thread that pre-reads frames.""" + while not self._stop_event.is_set(): + # Wait for work or timeout + self._request_event.wait(timeout=0.01) + self._request_event.clear() + + if self._stop_event.is_set(): + break + + # Prefetch frames ahead of target time + target = self._target_time + with self._buffer_lock: + # Clean old frames (more than 1 second behind) + old_times = [t for t in self._buffer.keys() if t < target - 1.0] + for t in old_times: + del self._buffer[t] + + # Count how many frames we have buffered ahead + buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target) + + # Prefetch if buffer not full + if buffered_ahead < self._buffer_size: + # Find next time to prefetch + prefetch_t = target + with self._buffer_lock: + existing_times = set(self._buffer.keys()) + for _ in range(self._buffer_size): + if prefetch_t not in existing_times: + break + prefetch_t += self._frame_time + + # Read the frame (this is the slow part) + try: + frame = self._source.read_at(prefetch_t) + with self._buffer_lock: + self._buffer[prefetch_t] = frame + except Exception as e: + import sys + print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr) + + def read_at(self, t: float) -> np.ndarray: + """Read frame at specific time, using prefetch buffer if available.""" + 
self._target_time = t + self._request_event.set() # Wake up prefetch thread + + # Round to frame time for buffer lookup + t_key = round(t / self._frame_time) * self._frame_time + + # Check buffer first + with self._buffer_lock: + if t_key in self._buffer: + return self._buffer[t_key] + # Also check for close matches (within half frame time) + for buf_t, frame in self._buffer.items(): + if abs(buf_t - t) < self._frame_time * 0.5: + return frame + + # Not in buffer - read directly (blocking) + frame = self._source.read_at(t) + + # Store in buffer + with self._buffer_lock: + self._buffer[t_key] = frame + + return frame + + def read(self) -> np.ndarray: + """Read frame (uses last cached or t=0).""" + return self.read_at(0) + + def skip(self): + """No-op for seek-based reading.""" + pass + + @property + def size(self): + return self._source.size + + @property + def path(self): + return self._source.path + + def close(self): + self._stop_event.set() + self._request_event.set() # Wake up thread to exit + self._thread.join(timeout=1.0) + self._source.close() + + class AudioAnalyzer: """Audio analyzer for energy and beat detection.""" @@ -394,7 +520,12 @@ class AudioAnalyzer: # === Primitives === def prim_make_video_source(path: str, fps: float = 30): - """Create a video source from a file path.""" + """Create a video source from a file path. + + Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default). + """ + if PREFETCH_ENABLED: + return PrefetchingVideoSource(path, fps) return VideoSource(path, fps) diff --git a/sexp_effects/primitive_libs/xector.py b/sexp_effects/primitive_libs/xector.py new file mode 100644 index 0000000..fb95dfd --- /dev/null +++ b/sexp_effects/primitive_libs/xector.py @@ -0,0 +1,1382 @@ +""" +Xector Primitives - Parallel array operations for GPU-style data parallelism. + +Inspired by Connection Machine Lisp and hillisp. Xectors are parallel arrays +where operations automatically apply element-wise. 
+ +Usage in sexp: + (require-primitives "xector") + + ;; Extract channels as xectors + (let* ((r (red frame)) + (g (green frame)) + (b (blue frame)) + ;; Operations are element-wise on xectors + (brightness (α+ (α* r 0.299) (α* g 0.587) (α* b 0.114)))) + ;; Reduce to scalar + (βmax brightness)) + + ;; Explicit α for element-wise, implicit also works + (α+ r 10) ;; explicit: add 10 to every element + (+ r 10) ;; implicit: same thing when r is a xector + + ;; β for reductions + (β+ r) ;; sum all elements + (βmax r) ;; maximum element + (βmean r) ;; average + +Operators: + α (alpha) - element-wise: (α+ x y) adds corresponding elements + β (beta) - reduce: (β+ x) sums all elements +""" + +import numpy as np +from typing import Union, Callable, Any + +# Try to use CuPy for GPU acceleration if available +try: + import cupy as cp + HAS_CUPY = True +except ImportError: + cp = None + HAS_CUPY = False + + +class Xector: + """ + Parallel array type for element-wise operations. + + Wraps a numpy/cupy array and provides automatic broadcasting + and element-wise operation semantics. + """ + + def __init__(self, data, shape=None): + """ + Create a Xector from data. 
+ + Args: + data: numpy array, cupy array, scalar, or list + shape: optional shape tuple (for coordinate xectors) + """ + if isinstance(data, Xector): + self._data = data._data + self._shape = data._shape + elif isinstance(data, np.ndarray): + self._data = data.astype(np.float32) + self._shape = shape or data.shape + elif HAS_CUPY and isinstance(data, cp.ndarray): + self._data = data.astype(cp.float32) + self._shape = shape or data.shape + elif isinstance(data, (list, tuple)): + self._data = np.array(data, dtype=np.float32) + self._shape = shape or self._data.shape + else: + # Scalar - will broadcast + self._data = np.float32(data) + self._shape = shape or () + + @property + def data(self): + return self._data + + @property + def shape(self): + return self._shape + + def __len__(self): + return self._data.size + + def __repr__(self): + if self._data.size <= 10: + return f"Xector({self._data})" + return f"Xector(shape={self._shape}, size={self._data.size})" + + def to_numpy(self): + """Convert to numpy array.""" + if HAS_CUPY and isinstance(self._data, cp.ndarray): + return cp.asnumpy(self._data) + return self._data + + def to_gpu(self): + """Move to GPU if available.""" + if HAS_CUPY and not isinstance(self._data, cp.ndarray): + self._data = cp.asarray(self._data) + return self + + # Arithmetic operators - enable implicit element-wise ops + def __add__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data + other_data, self._shape) + + def __radd__(self, other): + return Xector(other + self._data, self._shape) + + def __sub__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data - other_data, self._shape) + + def __rsub__(self, other): + return Xector(other - self._data, self._shape) + + def __mul__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data * other_data, self._shape) + + def __rmul__(self, 
other): + return Xector(other * self._data, self._shape) + + def __truediv__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data / other_data, self._shape) + + def __rtruediv__(self, other): + return Xector(other / self._data, self._shape) + + def __pow__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data ** other_data, self._shape) + + def __neg__(self): + return Xector(-self._data, self._shape) + + def __abs__(self): + return Xector(np.abs(self._data), self._shape) + + # Comparison operators - return boolean xectors + def __lt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data < other_data, self._shape) + + def __le__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data <= other_data, self._shape) + + def __gt__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data > other_data, self._shape) + + def __ge__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data >= other_data, self._shape) + + def __eq__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data == other_data, self._shape) + + def __ne__(self, other): + other_data = other._data if isinstance(other, Xector) else other + return Xector(self._data != other_data, self._shape) + + +def _unwrap(x): + """Unwrap Xector to underlying data, or return as-is.""" + if isinstance(x, Xector): + return x._data + return x + + +def _wrap(data, shape=None): + """Wrap result in Xector if it's an array.""" + if isinstance(data, (np.ndarray,)) or (HAS_CUPY and isinstance(data, cp.ndarray)): + return Xector(data, shape) + return data + + +# ============================================================================= +# Frame/Xector Conversion +# 
============================================================================= +# NOTE: red, green, blue, gray are derived in derived.sexp using (channel frame n) + +def xector_from_frame(frame): + """Convert entire frame to xector (flattened RGB). (xector frame) -> Xector""" + if isinstance(frame, np.ndarray): + return Xector(frame.flatten().astype(np.float32), frame.shape) + raise TypeError(f"Expected frame array, got {type(frame)}") + + +def xector_to_frame(x, shape=None): + """Convert xector back to frame. (to-frame x) or (to-frame x shape) -> frame""" + data = _unwrap(x) + if shape is None and isinstance(x, Xector): + shape = x._shape + if shape is None: + raise ValueError("Shape required to convert xector to frame") + return np.clip(data, 0, 255).reshape(shape).astype(np.uint8) + + +# ============================================================================= +# Coordinate Generators +# ============================================================================= +# NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived +# in derived.sexp using iota, tile, repeat primitives + + +# ============================================================================= +# Alpha (α) - Element-wise Operations +# ============================================================================= + +def alpha_lift(fn): + """Lift a scalar function to work element-wise on xectors.""" + def lifted(*args): + # Check if any arg is a Xector + has_xector = any(isinstance(a, Xector) for a in args) + if not has_xector: + return fn(*args) + + # Get shape from first xector + shape = None + for a in args: + if isinstance(a, Xector): + shape = a._shape + break + + # Unwrap all args + unwrapped = [_unwrap(a) for a in args] + + # Apply function + result = fn(*unwrapped) + + return _wrap(result, shape) + + return lifted + + +# Element-wise math operations +def alpha_add(*args): + """Element-wise addition. (α+ a b ...) 
-> Xector""" + if len(args) == 0: + return 0 + result = _unwrap(args[0]) + for a in args[1:]: + result = result + _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_sub(a, b=None): + """Element-wise subtraction. (α- a b) -> Xector""" + if b is None: + return Xector(-_unwrap(a)) if isinstance(a, Xector) else -a + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) - _unwrap(b), shape) + + +def alpha_mul(*args): + """Element-wise multiplication. (α* a b ...) -> Xector""" + if len(args) == 0: + return 1 + result = _unwrap(args[0]) + for a in args[1:]: + result = result * _unwrap(a) + return _wrap(result, args[0]._shape if isinstance(args[0], Xector) else None) + + +def alpha_div(a, b): + """Element-wise division. (α/ a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) / _unwrap(b), shape) + + +def alpha_pow(a, b): + """Element-wise power. (α** a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) ** _unwrap(b), shape) + + +def alpha_sqrt(x): + """Element-wise square root. (αsqrt x) -> Xector""" + return _wrap(np.sqrt(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_abs(x): + """Element-wise absolute value. (αabs x) -> Xector""" + return _wrap(np.abs(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_sin(x): + """Element-wise sine. (αsin x) -> Xector""" + return _wrap(np.sin(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_cos(x): + """Element-wise cosine. (αcos x) -> Xector""" + return _wrap(np.cos(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_exp(x): + """Element-wise exponential. 
(αexp x) -> Xector""" + return _wrap(np.exp(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_log(x): + """Element-wise natural log. (αlog x) -> Xector""" + return _wrap(np.log(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_clamp is derived in derived.sexp as (max2 lo (min2 hi x)) + +def alpha_min(a, b): + """Element-wise minimum. (αmin a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.minimum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_max(a, b): + """Element-wise maximum. (αmax a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.maximum(_unwrap(a), _unwrap(b)), shape) + + +def alpha_mod(a, b): + """Element-wise modulo. (αmod a b) -> Xector""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) % _unwrap(b), shape) + + +def alpha_floor(x): + """Element-wise floor. (αfloor x) -> Xector""" + return _wrap(np.floor(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_ceil(x): + """Element-wise ceiling. (αceil x) -> Xector""" + return _wrap(np.ceil(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +def alpha_round(x): + """Element-wise round. (αround x) -> Xector""" + return _wrap(np.round(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# NOTE: alpha_sq is derived in derived.sexp as (* x x) + +# Comparison operators (return boolean xectors) +def alpha_lt(a, b): + """Element-wise less than. (α< a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) < _unwrap(b), shape) + + +def alpha_le(a, b): + """Element-wise less-or-equal. 
(α<= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) <= _unwrap(b), shape) + + +def alpha_gt(a, b): + """Element-wise greater than. (α> a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) > _unwrap(b), shape) + + +def alpha_ge(a, b): + """Element-wise greater-or-equal. (α>= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) >= _unwrap(b), shape) + + +def alpha_eq(a, b): + """Element-wise equality. (α= a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(_unwrap(a) == _unwrap(b), shape) + + +# Logical operators +def alpha_and(a, b): + """Element-wise logical and. (αand a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_and(_unwrap(a), _unwrap(b)), shape) + + +def alpha_or(a, b): + """Element-wise logical or. (αor a b) -> Xector[bool]""" + shape = a._shape if isinstance(a, Xector) else (b._shape if isinstance(b, Xector) else None) + return _wrap(np.logical_or(_unwrap(a), _unwrap(b)), shape) + + +def alpha_not(x): + """Element-wise logical not. (αnot x) -> Xector[bool]""" + return _wrap(np.logical_not(_unwrap(x)), x._shape if isinstance(x, Xector) else None) + + +# ============================================================================= +# Beta (β) - Reduction Operations +# ============================================================================= + +def beta_add(x): + """Sum all elements. (β+ x) -> scalar""" + return float(np.sum(_unwrap(x))) + + +def beta_mul(x): + """Product of all elements. (β* x) -> scalar""" + return float(np.prod(_unwrap(x))) + + +def beta_min(x): + """Minimum element. 
(βmin x) -> scalar""" + return float(np.min(_unwrap(x))) + + +def beta_max(x): + """Maximum element. (βmax x) -> scalar""" + return float(np.max(_unwrap(x))) + + +def beta_mean(x): + """Mean of all elements. (βmean x) -> scalar""" + return float(np.mean(_unwrap(x))) + + +def beta_std(x): + """Standard deviation. (βstd x) -> scalar""" + return float(np.std(_unwrap(x))) + + +def beta_count(x): + """Count of elements. (βcount x) -> scalar""" + return int(np.size(_unwrap(x))) + + +def beta_any(x): + """True if any element is truthy. (βany x) -> bool""" + return bool(np.any(_unwrap(x))) + + +def beta_all(x): + """True if all elements are truthy. (βall x) -> bool""" + return bool(np.all(_unwrap(x))) + + +# ============================================================================= +# Conditional / Selection +# ============================================================================= + +def xector_where(cond, true_val, false_val): + """ + Conditional select. (where cond true-val false-val) -> Xector + + Like numpy.where - selects elements based on condition. + """ + cond_data = _unwrap(cond) + true_data = _unwrap(true_val) + false_data = _unwrap(false_val) + + # Get shape from condition or values + shape = None + for x in [cond, true_val, false_val]: + if isinstance(x, Xector): + shape = x._shape + break + + result = np.where(cond_data, true_data, false_data) + return _wrap(result, shape) + + +# NOTE: fill, zeros, ones are derived in derived.sexp using iota + +def xector_rand(size_or_frame): + """Create xector of random values [0,1). 
(rand-x frame) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector(np.random.random(size).astype(np.float32), shape) + + +def xector_randn(size_or_frame, mean=0, std=1): + """Create xector of normal random values. (randn-x frame) or (randn-x frame mean std) -> Xector""" + if isinstance(size_or_frame, np.ndarray): + h, w = size_or_frame.shape[:2] + size = h * w + shape = (h, w) + elif isinstance(size_or_frame, Xector): + size = len(size_or_frame) + shape = size_or_frame._shape + else: + size = int(size_or_frame) + shape = (size,) + + return Xector((np.random.randn(size) * std + mean).astype(np.float32), shape) + + +# ============================================================================= +# Type checking +# ============================================================================= + +def is_xector(x): + """Check if x is a Xector. (xector? x) -> bool""" + return isinstance(x, Xector) + + +# ============================================================================= +# CORE PRIMITIVES: gather, scatter, group-reduce, reshape +# These are the fundamental operations everything else builds on. +# ============================================================================= + +def xector_gather(data, indices): + """ + Parallel index lookup. (gather data indices) -> Xector + + For each index in indices, look up the corresponding value in data. + This is the fundamental operation for remapping/resampling. 
+ + Example: + (gather [10 20 30 40] [2 0 1 2]) ; -> [30 10 20 30] + """ + data_arr = _unwrap(data) + idx_arr = _unwrap(indices).astype(np.int32) + + # Flatten data for 1D indexing + flat_data = data_arr.flatten() + + # Clip indices to valid range + idx_clipped = np.clip(idx_arr, 0, len(flat_data) - 1) + + result = flat_data[idx_clipped] + shape = indices._shape if isinstance(indices, Xector) else None + return Xector(result, shape) + + +def xector_gather_2d(data, row_indices, col_indices): + """ + 2D parallel index lookup. (gather-2d data rows cols) -> Xector + + For each (row, col) pair, look up the value in 2D data. + Essential for grid/cell operations. + + Example: + (gather-2d image-lum cell-rows cell-cols) + """ + data_arr = _unwrap(data) + row_arr = _unwrap(row_indices).astype(np.int32) + col_arr = _unwrap(col_indices).astype(np.int32) + + # Get data shape + if isinstance(data, Xector) and data._shape and len(data._shape) >= 2: + h, w = data._shape[:2] + data_2d = data_arr.reshape(h, w) + elif len(data_arr.shape) >= 2: + h, w = data_arr.shape[:2] + data_2d = data_arr.reshape(h, w) if data_arr.ndim == 1 else data_arr + else: + # Assume square + size = int(np.sqrt(len(data_arr))) + h, w = size, size + data_2d = data_arr.reshape(h, w) + + # Clip indices + row_clipped = np.clip(row_arr, 0, h - 1) + col_clipped = np.clip(col_arr, 0, w - 1) + + result = data_2d[row_clipped.flatten(), col_clipped.flatten()] + shape = row_indices._shape if isinstance(row_indices, Xector) else None + return Xector(result, shape) + + +def xector_scatter(indices, values, size): + """ + Parallel index write. (scatter indices values size) -> Xector + + Create a new xector of given size, writing values at indices. + Later writes overwrite earlier ones at same index. 
+ + Example: + (scatter [0 2 4] [10 20 30] 5) ; -> [10 0 20 0 30] + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + idx_clipped = np.clip(idx_arr, 0, int(size) - 1) + result[idx_clipped] = val_arr + + return Xector(result, (int(size),)) + + +def xector_scatter_add(indices, values, size): + """ + Parallel index accumulate. (scatter-add indices values size) -> Xector + + Like scatter, but adds to existing values instead of overwriting. + Useful for histograms, pooling reductions. + + Example: + (scatter-add [0 0 1] [1 2 3] 3) ; -> [3 3 0] (1+2 at index 0) + """ + idx_arr = _unwrap(indices).astype(np.int32) + val_arr = _unwrap(values) + + result = np.zeros(int(size), dtype=np.float32) + np.add.at(result, np.clip(idx_arr, 0, int(size) - 1), val_arr) + + return Xector(result, (int(size),)) + + +def xector_group_reduce(values, group_indices, num_groups, op='mean'): + """ + Reduce values by group. (group-reduce values groups num-groups op) -> Xector + + Groups values by group_indices and reduces each group. + This is the primitive for pooling operations. 
+ + Args: + values: Xector of values to reduce + group_indices: Xector of group assignments (integers) + num_groups: Number of groups (output size) + op: 'mean', 'sum', 'max', 'min' + + Example: + ; Pool 4 values into 2 groups + (group-reduce [1 2 3 4] [0 0 1 1] 2 "mean") ; -> [1.5 3.5] + """ + val_arr = _unwrap(values).flatten() + grp_arr = _unwrap(group_indices).astype(np.int32).flatten() + n = int(num_groups) + + if op == 'sum': + result = np.zeros(n, dtype=np.float32) + np.add.at(result, grp_arr, val_arr) + elif op == 'mean': + sums = np.zeros(n, dtype=np.float32) + counts = np.zeros(n, dtype=np.float32) + np.add.at(sums, grp_arr, val_arr) + np.add.at(counts, grp_arr, 1) + result = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + elif op == 'max': + result = np.full(n, -np.inf, dtype=np.float32) + np.maximum.at(result, grp_arr, val_arr) + result[result == -np.inf] = 0 + elif op == 'min': + result = np.full(n, np.inf, dtype=np.float32) + np.minimum.at(result, grp_arr, val_arr) + result[result == np.inf] = 0 + else: + raise ValueError(f"Unknown reduce op: {op}") + + return Xector(result, (n,)) + + +def xector_reshape(x, *dims): + """ + Reshape xector. (reshape x h w) or (reshape x n) -> Xector + + Changes the logical shape of the xector without changing data. + """ + data = _unwrap(x) + if len(dims) == 1: + new_shape = (int(dims[0]),) + else: + new_shape = tuple(int(d) for d in dims) + + return Xector(data.reshape(-1), new_shape) + + +def xector_shape(x): + """Get shape of xector. (shape x) -> list""" + if isinstance(x, Xector): + return list(x._shape) if x._shape else [len(x)] + if isinstance(x, np.ndarray): + return list(x.shape) + return [] + + +def xector_len(x): + """Get length of xector. (xlen x) -> int""" + return len(_unwrap(x).flatten()) + + +def xector_iota(n): + """ + Generate indices 0 to n-1. (iota n) -> Xector + + Fundamental for generating coordinate xectors. 
+ + Example: + (iota 5) ; -> [0 1 2 3 4] + """ + return Xector(np.arange(int(n), dtype=np.float32), (int(n),)) + + +def xector_repeat(x, n): + """ + Repeat each element n times. (repeat x n) -> Xector + + Example: + (repeat [1 2 3] 2) ; -> [1 1 2 2 3 3] + """ + data = _unwrap(x) + result = np.repeat(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +def xector_tile(x, n): + """ + Tile entire xector n times. (tile x n) -> Xector + + Example: + (tile [1 2 3] 2) ; -> [1 2 3 1 2 3] + """ + data = _unwrap(x) + result = np.tile(data.flatten(), int(n)) + return Xector(result, (len(result),)) + + +# ============================================================================= +# 2D Grid Helpers (built on primitives above) +# ============================================================================= + +def xector_cell_indices(frame, cell_size): + """ + Compute cell index for each pixel. (cell-indices frame cell-size) -> Xector + + Returns flat index of which cell each pixel belongs to. + This is the bridge between pixel-space and cell-space. + """ + h, w = frame.shape[:2] + cell_size = int(cell_size) + + rows = h // cell_size + cols = w // cell_size + + # For each pixel, compute its cell index + y = np.repeat(np.arange(h), w) # [0,0,0..., 1,1,1..., ...] + x = np.tile(np.arange(w), h) # [0,1,2..., 0,1,2..., ...] + + cell_row = y // cell_size + cell_col = x // cell_size + cell_idx = cell_row * cols + cell_col + + # Clip to valid range + cell_idx = np.clip(cell_idx, 0, rows * cols - 1) + + return Xector(cell_idx.astype(np.float32), (h, w)) + + +def xector_local_x(frame, cell_size): + """ + X position within each cell [0, cell_size). (local-x frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + x = np.tile(np.arange(w), h) + local = (x % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y(frame, cell_size): + """ + Y position within each cell [0, cell_size). 
(local-y frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + y = np.repeat(np.arange(h), w) + local = (y % int(cell_size)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_x_norm(frame, cell_size): + """ + Normalized X within cell [0, 1]. (local-x-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + x = np.tile(np.arange(w), h) + local = ((x % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_local_y_norm(frame, cell_size): + """ + Normalized Y within cell [0, 1]. (local-y-norm frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + y = np.repeat(np.arange(h), w) + local = ((y % cs) / max(1, cs - 1)).astype(np.float32) + return Xector(local, (h, w)) + + +def xector_pool_frame(frame, cell_size, op='mean'): + """ + Pool frame to cell values. (pool-frame frame cell-size) -> (r, g, b, lum) Xectors + + Returns tuple of xectors: (red, green, blue, luminance) for cells. 
+ """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + cols = w // cs + num_cells = rows * cols + + # Compute cell indices for each pixel + y = np.repeat(np.arange(h), w) + x = np.tile(np.arange(w), h) + cell_row = np.clip(y // cs, 0, rows - 1) + cell_col = np.clip(x // cs, 0, cols - 1) + cell_idx = cell_row * cols + cell_col + + # Extract channels + r_flat = frame[:, :, 0].flatten().astype(np.float32) + g_flat = frame[:, :, 1].flatten().astype(np.float32) + b_flat = frame[:, :, 2].flatten().astype(np.float32) + + # Pool each channel + def pool_channel(data): + sums = np.zeros(num_cells, dtype=np.float32) + counts = np.zeros(num_cells, dtype=np.float32) + np.add.at(sums, cell_idx, data) + np.add.at(counts, cell_idx, 1) + return np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0) + + r_pooled = pool_channel(r_flat) + g_pooled = pool_channel(g_flat) + b_pooled = pool_channel(b_flat) + lum = 0.299 * r_pooled + 0.587 * g_pooled + 0.114 * b_pooled + + shape = (rows, cols) + return (Xector(r_pooled, shape), + Xector(g_pooled, shape), + Xector(b_pooled, shape), + Xector(lum, shape)) + + +def xector_cell_row(frame, cell_size): + """ + Cell row index for each pixel. (cell-row frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + rows = h // cs + + y = np.repeat(np.arange(h), w) + cell_row = np.clip(y // cs, 0, rows - 1).astype(np.float32) + return Xector(cell_row, (h, w)) + + +def xector_cell_col(frame, cell_size): + """ + Cell column index for each pixel. (cell-col frame cell-size) -> Xector + """ + h, w = frame.shape[:2] + cs = int(cell_size) + cols = w // cs + + x = np.tile(np.arange(w), h) + cell_col = np.clip(x // cs, 0, cols - 1).astype(np.float32) + return Xector(cell_col, (h, w)) + + +def xector_num_cells(frame, cell_size): + """Number of cells. 
(num-cells frame cell-size) -> (rows, cols, total)"""
+    h, w = frame.shape[:2]
+    cs = int(cell_size)
+    rows = h // cs
+    cols = w // cs
+    return (rows, cols, rows * cols)
+
+
+# =============================================================================
+# Scan (Prefix Operations) - cumulative reductions
+# =============================================================================
+
+def xector_scan_add(x, axis=None):
+    """
+    Cumulative sum (prefix sum). (scan+ x) or (scan+ x :axis 0)
+
+    Returns array where each element is the inclusive prefix sum
+    (itself plus all previous elements). Useful for integral images,
+    cumulative effects.
+    """
+    data = _unwrap(x)
+    shape = x._shape if isinstance(x, Xector) else None
+
+    if axis is not None:
+        # Reshape to 2D for axis operation
+        if shape and len(shape) == 2:
+            result = np.cumsum(data.reshape(shape), axis=int(axis)).flatten()
+        else:
+            result = np.cumsum(data, axis=int(axis))
+    else:
+        result = np.cumsum(data)
+
+    return _wrap(result, shape)
+
+
+def xector_scan_mul(x, axis=None):
+    """Cumulative product. (scan* x) -> Xector"""
+    data = _unwrap(x)
+    shape = x._shape if isinstance(x, Xector) else None
+
+    if axis is not None and shape and len(shape) == 2:
+        result = np.cumprod(data.reshape(shape), axis=int(axis)).flatten()
+    else:
+        result = np.cumprod(data)
+
+    return _wrap(result, shape)
+
+
+def xector_scan_max(x, axis=None):
+    """Cumulative maximum. (scan-max x) -> Xector"""
+    data = _unwrap(x)
+    shape = x._shape if isinstance(x, Xector) else None
+
+    if axis is not None and shape and len(shape) == 2:
+        result = np.maximum.accumulate(data.reshape(shape), axis=int(axis)).flatten()
+    else:
+        result = np.maximum.accumulate(data)
+
+    return _wrap(result, shape)
+
+
+def xector_scan_min(x, axis=None):
+    """Cumulative minimum. 
(scan-min x) -> Xector""" + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if axis is not None and shape and len(shape) == 2: + result = np.minimum.accumulate(data.reshape(shape), axis=int(axis)).flatten() + else: + result = np.minimum.accumulate(data) + + return _wrap(result, shape) + + +# ============================================================================= +# Outer Product - Cartesian operations +# ============================================================================= + +def xector_outer(x, y, op='*'): + """ + Outer product. (outer x y) or (outer x y :op '+') + + Creates 2D result where result[i,j] = op(x[i], y[j]). + Default is multiplication (*). + + Useful for generating 2D patterns from 1D vectors. + """ + x_data = _unwrap(x) + y_data = _unwrap(y) + + ops = { + '*': np.multiply, + '+': np.add, + '-': np.subtract, + '/': np.divide, + 'max': np.maximum, + 'min': np.minimum, + 'and': np.logical_and, + 'or': np.logical_or, + 'xor': np.logical_xor, + } + + op_fn = ops.get(op, np.multiply) + result = op_fn.outer(x_data.flatten(), y_data.flatten()) + + # Return as xector with 2D shape + h, w = len(x_data.flatten()), len(y_data.flatten()) + return _wrap(result.flatten(), (h, w)) + + +def xector_outer_add(x, y): + """Outer sum. (outer+ x y) -> result[i,j] = x[i] + y[j]""" + return xector_outer(x, y, '+') + + +def xector_outer_mul(x, y): + """Outer product. (outer* x y) -> result[i,j] = x[i] * y[j]""" + return xector_outer(x, y, '*') + + +def xector_outer_max(x, y): + """Outer max. (outer-max x y) -> result[i,j] = max(x[i], y[j])""" + return xector_outer(x, y, 'max') + + +def xector_outer_min(x, y): + """Outer min. 
(outer-min x y) -> result[i,j] = min(x[i], y[j])""" + return xector_outer(x, y, 'min') + + +# ============================================================================= +# Reduce with Axis - dimensional reductions +# ============================================================================= + +def xector_reduce_axis(x, op='sum', axis=0): + """ + Reduce along an axis. (reduce-axis x :op 'sum' :axis 0) + + ops: 'sum', 'mean', 'max', 'min', 'prod', 'std' + axis: 0 (rows), 1 (columns) + + For a frame-sized xector (H*W): + axis=0: reduce across rows -> W values (one per column) + axis=1: reduce across columns -> H values (one per row) + """ + data = _unwrap(x) + shape = x._shape if isinstance(x, Xector) else None + + if shape is None or len(shape) != 2: + # Can't do axis reduction without 2D shape + raise ValueError("reduce-axis requires 2D xector (with shape)") + + h, w = shape + data_2d = data.reshape(h, w) + axis = int(axis) + + ops = { + 'sum': lambda d, a: np.sum(d, axis=a), + '+': lambda d, a: np.sum(d, axis=a), + 'mean': lambda d, a: np.mean(d, axis=a), + 'max': lambda d, a: np.max(d, axis=a), + 'min': lambda d, a: np.min(d, axis=a), + 'prod': lambda d, a: np.prod(d, axis=a), + '*': lambda d, a: np.prod(d, axis=a), + 'std': lambda d, a: np.std(d, axis=a), + } + + op_fn = ops.get(op, ops['sum']) + result = op_fn(data_2d, axis) + + # Result shape: if axis=0, shape is (w,); if axis=1, shape is (h,) + new_shape = (w,) if axis == 0 else (h,) + return _wrap(result.flatten(), new_shape) + + +def xector_sum_axis(x, axis=0): + """Sum along axis. (sum-axis x :axis 0)""" + return xector_reduce_axis(x, 'sum', axis) + + +def xector_mean_axis(x, axis=0): + """Mean along axis. (mean-axis x :axis 0)""" + return xector_reduce_axis(x, 'mean', axis) + + +def xector_max_axis(x, axis=0): + """Max along axis. (max-axis x :axis 0)""" + return xector_reduce_axis(x, 'max', axis) + + +def xector_min_axis(x, axis=0): + """Min along axis. 
(min-axis x :axis 0)"""
+    return xector_reduce_axis(x, 'min', axis)
+
+
+# =============================================================================
+# Windowed Operations - sliding window computations
+# =============================================================================
+
+def xector_window(x, size, op='mean', stride=1):
+    """
+    Sliding window operation. (window x size :op 'mean' :stride 1)
+
+    Applies reduction over sliding windows of given size.
+    ops: 'sum', 'mean', 'max', 'min', 'std'
+
+    For 1D: windows slide along the array
+    For 2D (with shape): windows are size x size squares
+    """
+    data = _unwrap(x)
+    shape = x._shape if isinstance(x, Xector) else None
+    size = int(size)
+    stride = int(stride)
+
+    ops = {
+        'sum': np.sum,
+        'mean': np.mean,
+        'max': np.max,
+        'min': np.min,
+        'std': np.std,
+    }
+    op_fn = ops.get(op, np.mean)
+
+    if shape and len(shape) == 2:
+        # 2D sliding window
+        h, w = shape
+        data_2d = data.reshape(h, w)
+
+        # Naive nested-loop windowing (one reduction per output position)
+        out_h = (h - size) // stride + 1
+        out_w = (w - size) // stride + 1
+
+        result = np.zeros((out_h, out_w))
+        for i in range(out_h):
+            for j in range(out_w):
+                window = data_2d[i*stride:i*stride+size, j*stride:j*stride+size]
+                result[i, j] = op_fn(window)
+
+        return _wrap(result.flatten(), (out_h, out_w))
+    else:
+        # 1D sliding window
+        n = len(data)
+        out_n = (n - size) // stride + 1
+        result = np.array([op_fn(data[i*stride:i*stride+size]) for i in range(out_n)])
+        return _wrap(result, (out_n,))
+
+
+def xector_window_sum(x, size, stride=1):
+    """Sliding window sum. (window-sum x size)"""
+    return xector_window(x, size, 'sum', stride)
+
+
+def xector_window_mean(x, size, stride=1):
+    """Sliding window mean. (window-mean x size)"""
+    return xector_window(x, size, 'mean', stride)
+
+
+def xector_window_max(x, size, stride=1):
+    """Sliding window max. 
(window-max x size)"""
+    return xector_window(x, size, 'max', stride)
+
+
+def xector_window_min(x, size, stride=1):
+    """Sliding window min. (window-min x size)"""
+    return xector_window(x, size, 'min', stride)
+
+
+def xector_integral_image(frame):
+    """
+    Compute integral image (summed area table). (integral-image frame)
+
+    Each pixel contains the sum of all pixels above and to the left,
+    inclusive of itself. Enables O(1) box blur at any radius.
+
+    Returns xector with same shape as frame's luminance.
+    """
+    if hasattr(frame, 'shape') and len(frame.shape) == 3:
+        # Convert frame to grayscale
+        gray = np.mean(frame, axis=2)
+    else:
+        data = _unwrap(frame)
+        shape = frame._shape if isinstance(frame, Xector) else None
+        if shape and len(shape) == 2:
+            gray = data.reshape(shape)
+        else:
+            gray = data
+
+    integral = np.cumsum(np.cumsum(gray, axis=0), axis=1)
+    h, w = integral.shape
+    return _wrap(integral.flatten(), (h, w))
+
+
+def xector_box_blur_fast(integral, x, y, radius, width, height):
+    """
+    Fast box blur using integral image. (box-blur-fast integral x y radius w h)
+
+    Given pre-computed integral image, compute average in box centered at (x,y).
+    O(1) regardless of radius.
+    
    """
    integral_data = _unwrap(integral)
    shape = integral._shape if isinstance(integral, Xector) else None

    if shape is None or len(shape) != 2:
        raise ValueError("box-blur-fast requires 2D integral image")

    h, w = shape
    integral_2d = integral_data.reshape(h, w)

    radius = int(radius)
    x, y = int(x), int(y)

    # Clamp coordinates
    x1 = max(0, x - radius)
    y1 = max(0, y - radius)
    x2 = min(w - 1, x + radius)
    y2 = min(h - 1, y + radius)

    # Sum in rectangle using integral image (4-corner inclusion/exclusion)
    total = integral_2d[y2, x2]
    if x1 > 0:
        total -= integral_2d[y2, x1 - 1]
    if y1 > 0:
        total -= integral_2d[y1 - 1, x2]
    if x1 > 0 and y1 > 0:
        total += integral_2d[y1 - 1, x1 - 1]

    # NOTE(review): the `width`/`height` parameters are ignored; the shape
    # stored on the integral Xector is used instead — confirm callers agree.
    count = (x2 - x1 + 1) * (y2 - y1 + 1)
    return total / max(count, 1)


# =============================================================================
# PRIMITIVES Export
# =============================================================================

PRIMITIVES = {
    # Frame/Xector conversion
    # NOTE: red, green, blue, gray, rgb are derived in derived.sexp using (channel frame n)
    'xector': xector_from_frame,
    'to-frame': xector_to_frame,

    # Coordinate generators
    # NOTE: x-coords, y-coords, x-norm, y-norm, dist-from-center are derived
    # in derived.sexp using iota, tile, repeat primitives

    # Alpha (α) - element-wise operations
    'α+': alpha_add,
    'α-': alpha_sub,
    'α*': alpha_mul,
    'α/': alpha_div,
    'α**': alpha_pow,
    'αsqrt': alpha_sqrt,
    'αabs': alpha_abs,
    'αsin': alpha_sin,
    'αcos': alpha_cos,
    'αexp': alpha_exp,
    'αlog': alpha_log,
    # NOTE: αclamp is derived in derived.sexp as (max2 lo (min2 hi x))
    'αmin': alpha_min,
    'αmax': alpha_max,
    'αmod': alpha_mod,
    'αfloor': alpha_floor,
    'αceil': alpha_ceil,
    'αround': alpha_round,
    # NOTE: α² / αsq is derived in derived.sexp as (* x x)

    # Alpha comparison
    'α<': alpha_lt,
    'α<=': alpha_le,
    'α>': alpha_gt,
    'α>=': alpha_ge,
    'α=': alpha_eq,

    # Alpha logical
    'αand': alpha_and,
    'αor': alpha_or,
    'αnot': alpha_not,

    # ASCII fallbacks for α
    'alpha+': alpha_add,
    'alpha-': alpha_sub,
    'alpha*': alpha_mul,
    'alpha/': alpha_div,
    'alpha**': alpha_pow,
    'alpha-sqrt': alpha_sqrt,
    'alpha-abs': alpha_abs,
    'alpha-sin': alpha_sin,
    'alpha-cos': alpha_cos,
    'alpha-exp': alpha_exp,
    'alpha-log': alpha_log,
    'alpha-min': alpha_min,
    'alpha-max': alpha_max,
    'alpha-mod': alpha_mod,
    'alpha-floor': alpha_floor,
    'alpha-ceil': alpha_ceil,
    'alpha-round': alpha_round,
    'alpha<': alpha_lt,
    'alpha<=': alpha_le,
    'alpha>': alpha_gt,
    'alpha>=': alpha_ge,
    'alpha=': alpha_eq,
    'alpha-and': alpha_and,
    'alpha-or': alpha_or,
    'alpha-not': alpha_not,

    # Beta (β) - reduction operations
    'β+': beta_add,
    'β*': beta_mul,
    'βmin': beta_min,
    'βmax': beta_max,
    'βmean': beta_mean,
    'βstd': beta_std,
    'βcount': beta_count,
    'βany': beta_any,
    'βall': beta_all,

    # ASCII fallbacks for β
    'beta+': beta_add,
    'beta*': beta_mul,
    'beta-min': beta_min,
    'beta-max': beta_max,
    'beta-mean': beta_mean,
    'beta-std': beta_std,
    'beta-count': beta_count,
    'beta-any': beta_any,
    'beta-all': beta_all,

    # Convenience aliases
    'sum': beta_add,
    'product': beta_mul,
    'mean': beta_mean,

    # Conditional / Selection
    'where': xector_where,
    # NOTE: fill, zeros, ones are derived in derived.sexp using iota
    'rand-x': xector_rand,
    'randn-x': xector_randn,

    # Type checking
    'xector?': is_xector,

    # ===========================================
    # CORE PRIMITIVES - fundamental operations
    # ===========================================

    # Gather/Scatter - parallel indexing
    'gather': xector_gather,
    'gather-2d': xector_gather_2d,
    'scatter': xector_scatter,
    'scatter-add': xector_scatter_add,

    # Group reduce - pooling primitive
    'group-reduce': xector_group_reduce,

    # Shape operations
    'reshape': xector_reshape,
    'shape': xector_shape,
    'xlen': xector_len,

    # Index generation
    'iota': xector_iota,
    'repeat': xector_repeat,
    'tile': xector_tile,

    # Cell/Grid helpers (built on primitives)
    'cell-indices': xector_cell_indices,
    'cell-row': xector_cell_row,
    'cell-col': xector_cell_col,
    'local-x': xector_local_x,
    'local-y': xector_local_y,
    'local-x-norm': xector_local_x_norm,
    'local-y-norm': xector_local_y_norm,
    'pool-frame': xector_pool_frame,
    'num-cells': xector_num_cells,

    # Scan (prefix) operations - cumulative reductions
    'scan+': xector_scan_add,
    'scan*': xector_scan_mul,
    'scan-max': xector_scan_max,
    'scan-min': xector_scan_min,
    'scan-add': xector_scan_add,
    'scan-mul': xector_scan_mul,

    # Outer product - Cartesian operations
    'outer': xector_outer,
    'outer+': xector_outer_add,
    'outer*': xector_outer_mul,
    'outer-add': xector_outer_add,
    'outer-mul': xector_outer_mul,
    'outer-max': xector_outer_max,
    'outer-min': xector_outer_min,

    # Reduce with axis - dimensional reductions
    'reduce-axis': xector_reduce_axis,
    'sum-axis': xector_sum_axis,
    'mean-axis': xector_mean_axis,
    'max-axis': xector_max_axis,
    'min-axis': xector_min_axis,

    # Windowed operations - sliding window computations
    'window': xector_window,
    'window-sum': xector_window_sum,
    'window-mean': xector_window_mean,
    'window-max': xector_window_max,
    'window-min': xector_window_min,

    # Integral image - for fast box blur
    'integral-image': xector_integral_image,
    'box-blur-fast': xector_box_blur_fast,
}
diff --git a/sexp_effects/primitives.py b/sexp_effects/primitives.py
index 8bdca5c..9a50356 100644
--- a/sexp_effects/primitives.py
+++ b/sexp_effects/primitives.py
@@ -797,31 +797,63 @@ def prim_tan(x: float) -> float:
     return math.tan(x)
 
 
-def prim_atan2(y: float, x: float) -> float:
+def prim_atan2(y, x):
+    if hasattr(y, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape)
     return math.atan2(y, x)
 
 
-def prim_sqrt(x: float) ->
float:
+def prim_sqrt(x):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.sqrt(np.maximum(0, x._data)), x._shape)
+    if isinstance(x, np.ndarray):
+        return np.sqrt(np.maximum(0, x))
     return math.sqrt(max(0, x))
 
 
-def prim_pow(x: float, y: float) -> float:
+def prim_pow(x, y):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        y_data = y._data if hasattr(y, '_data') else y
+        return Xector(np.power(x._data, y_data), x._shape)
     return math.pow(x, y)
 
 
-def prim_abs(x: float) -> float:
+def prim_abs(x):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.abs(x._data), x._shape)
+    if isinstance(x, np.ndarray):
+        return np.abs(x)
     return abs(x)
 
 
-def prim_floor(x: float) -> int:
+def prim_floor(x):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.floor(x._data), x._shape)
+    if isinstance(x, np.ndarray):
+        return np.floor(x)
     return int(math.floor(x))
 
 
-def prim_ceil(x: float) -> int:
+def prim_ceil(x):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.ceil(x._data), x._shape)
+    if isinstance(x, np.ndarray):
+        return np.ceil(x)
     return int(math.ceil(x))
 
 
-def prim_round(x: float) -> int:
+def prim_round(x):
+    if hasattr(x, '_data'):  # Xector
+        from sexp_effects.primitive_libs.xector import Xector
+        return Xector(np.round(x._data), x._shape)
+    if isinstance(x, np.ndarray):
+        return np.round(x)
     return int(round(x))
diff --git a/streaming/jax_typography.py b/streaming/jax_typography.py
new file mode 100644
index 0000000..f976b6d
--- /dev/null
+++ b/streaming/jax_typography.py
@@ -0,0 +1,860 @@
"""
JAX Typography Primitives

Two approaches for text rendering, both compile to JAX/GPU:

## 1. TextStrip - Pixel-perfect static text
   Pre-render entire strings at compile time using PIL.
   Perfect sub-pixel anti-aliasing, exact match with PIL.
   Use for: static titles, labels, any text without per-character effects.

   S-expression:
     (let ((strip (render-text-strip "Hello World" 48)))
       (place-text-strip frame strip x y :color white))

## 2. Glyph-by-glyph - Dynamic text effects
   Individual glyph placement for wave, arc, audio-reactive effects.
   Each character can have independent position, color, opacity.
   Note: slight anti-aliasing differences vs PIL due to integer positioning.

   S-expression:
     ; Wave text - y oscillates with character index
     (let ((glyphs (text-glyphs "Wavy" 48)))
       (first
         (fold glyphs (list frame 0)
           (lambda (acc g)
             (let ((frm (first acc))
                   (cursor (second acc))
                   (i (length acc)))  ; approximate index
               (list
                 (place-glyph frm (glyph-image g)
                   (+ x cursor)
                   (+ y (* amplitude (sin (* i frequency))))
                   (glyph-bearing-x g) (glyph-bearing-y g)
                   white 1.0)
                 (+ cursor (glyph-advance g))))))))

     ; Audio-reactive spacing
     (let ((glyphs (text-glyphs "Bass" 48))
           (bass (audio-band 0 200)))
       (first
         (fold glyphs (list frame 0)
           (lambda (acc g)
             (let ((frm (first acc))
                   (cursor (second acc)))
               (list
                 (place-glyph frm (glyph-image g)
                   (+ x cursor) y
                   (glyph-bearing-x g) (glyph-bearing-y g)
                   white 1.0)
                 (+ cursor (glyph-advance g) (* bass 20))))))))

Kerning support:
     ; With kerning adjustment
     (+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
"""

import numpy as np
import jax.numpy as jnp
from typing import Tuple, Dict, Any
from dataclasses import dataclass


# =============================================================================
# Glyph Data (computed at compile time)
# =============================================================================

@dataclass
class GlyphData:
    """Glyph data computed at compile time.

    Attributes:
        char: The character
        image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
        advance: Horizontal advance (distance to next glyph origin)
        bearing_x: Left side bearing (x offset from origin to first pixel)
        bearing_y: Top bearing (y offset from baseline to top of glyph)
        width: Image width
        height: Image height
    """
    char: str
    image: np.ndarray  # (H, W, 4) RGBA uint8
    advance: float
    bearing_x: float
    bearing_y: float
    width: int
    height: int


# Font cache: (font_name, font_size) -> {char: GlyphData}
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}

# Font metrics cache: (font_name, font_size) -> (ascent, descent)
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}

# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
# Kerning adjustment is added to advance: new_advance = advance + kerning
# Typically negative (characters move closer together)
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}


def _load_font(font_name: str = None, font_size: int = 32):
    """Load a font. Called at compile time."""
    from PIL import ImageFont

    # First usable candidate wins; font_name (if given) takes priority.
    candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
    ]

    for path in candidates:
        if path is None:
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except (IOError, OSError):
            continue

    # Last resort: PIL's built-in bitmap font (ignores font_size).
    return ImageFont.load_default()


def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
    """Get or create glyph cache for a font.
    Called at compile time."""
    cache_key = (font_name, font_size)

    if cache_key in _GLYPH_CACHE:
        return _GLYPH_CACHE[cache_key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()
    _METRICS_CACHE[cache_key] = (ascent, descent)

    glyphs = {}
    # Printable ASCII: space (32) through '~' (126).
    charset = ''.join(chr(i) for i in range(32, 127))

    # Scratch surface used only for metric queries (textbbox).
    temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0))
    temp_draw = ImageDraw.Draw(temp_img)

    for char in charset:
        # Get metrics
        bbox = temp_draw.textbbox((0, 0), char, font=font)
        advance = font.getlength(char)

        x_min, y_min, x_max, y_max = bbox

        # Create glyph image with padding
        padding = 2
        img_w = max(int(x_max - x_min) + padding * 2, 1)
        img_h = max(int(y_max - y_min) + padding * 2, 1)

        glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
        glyph_draw = ImageDraw.Draw(glyph_img)

        # Draw at position accounting for bbox offset
        draw_x = padding - x_min
        draw_y = padding - y_min
        glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)

        glyphs[char] = GlyphData(
            char=char,
            image=np.array(glyph_img, dtype=np.uint8),
            advance=float(advance),
            bearing_x=float(x_min),
            bearing_y=float(-y_min),  # Distance from baseline to top
            width=img_w,
            height=img_h,
        )

    _GLYPH_CACHE[cache_key] = glyphs
    return glyphs


def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
    """Get or create kerning cache for a font. Called at compile time.

    Kerning is computed as:
        kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)

    This gives the adjustment needed when placing 'b' after 'a'.
    Typically negative (characters move closer together).
    """
    cache_key = (font_name, font_size)

    if cache_key in _KERNING_CACHE:
        return _KERNING_CACHE[cache_key]

    font = _load_font(font_name, font_size)
    kerning = {}

    # Compute kerning for all printable ASCII pairs
    charset = ''.join(chr(i) for i in range(32, 127))

    # Pre-compute individual character lengths
    char_lengths = {c: font.getlength(c) for c in charset}

    # Compute kerning for each pair (95*95 getlength calls, done once per font)
    for c1 in charset:
        for c2 in charset:
            pair_length = font.getlength(c1 + c2)
            individual_sum = char_lengths[c1] + char_lengths[c2]
            kern = pair_length - individual_sum

            # Only store non-zero kerning to save memory
            if abs(kern) > 0.01:
                kerning[(c1, c2)] = kern

    _KERNING_CACHE[cache_key] = kerning
    return kerning


def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
    """Get kerning adjustment between two characters. Compile-time.

    Returns the adjustment to add to char1's advance when char2 follows.
    Typically negative (characters move closer).

    Usage in S-expression:
        (+ (glyph-advance g1) (kerning g1 g2))
    """
    kerning_cache = _get_kerning_cache(font_name, font_size)
    return kerning_cache.get((char1, char2), 0.0)


@dataclass
class TextStrip:
    """Pre-rendered text strip with proper sub-pixel anti-aliasing.

    Rendered at compile time using PIL for exact matching.
    At runtime, just composite onto frame at integer positions.
+ + Attributes: + text: The original text + image: RGBA image as numpy array (H, W, 4) + width: Strip width + height: Strip height + baseline_y: Y position of baseline within the strip + bearing_x: Left side bearing of first character + anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right) + anchor_y: Y offset for anchor point (depends on anchor type) + stroke_width: Stroke width used when rendering + """ + text: str + image: np.ndarray + width: int + height: int + baseline_y: int + bearing_x: float + anchor_x: float = 0.0 + anchor_y: float = 0.0 + stroke_width: int = 0 + + +# Text strip cache: cache_key -> TextStrip +_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {} + + +def render_text_strip( + text: str, + font_name: str = None, + font_size: int = 32, + stroke_width: int = 0, + stroke_fill: tuple = None, + anchor: str = "la", # left-ascender (PIL default is "la") + multiline: bool = False, + line_spacing: int = 4, + align: str = "left", +) -> TextStrip: + """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing. 
+ + Args: + text: Text to render + font_name: Path to font file (None for default) + font_size: Font size in pixels + stroke_width: Outline width in pixels (0 for no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black + anchor: PIL anchor code - first char: h=left, m=middle, r=right + second char: a=ascender, t=top, m=middle, s=baseline, d=descender + multiline: If True, handle newlines in text + line_spacing: Extra pixels between lines (for multiline) + align: 'left', 'center', 'right' (for multiline) + + Returns: + TextStrip with pre-rendered text + """ + # Build cache key from all parameters + cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align) + if cache_key in _TEXT_STRIP_CACHE: + return _TEXT_STRIP_CACHE[cache_key] + + from PIL import Image, ImageDraw + + font = _load_font(font_name, font_size) + ascent, descent = font.getmetrics() + + # Default stroke fill to black + if stroke_fill is None: + stroke_fill = (0, 0, 0, 255) + elif len(stroke_fill) == 3: + stroke_fill = (*stroke_fill, 255) + + # Get text bbox (accounting for stroke) + temp = Image.new('RGBA', (1, 1)) + temp_draw = ImageDraw.Draw(temp) + + if multiline: + bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing, + stroke_width=stroke_width) + else: + bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width) + + # bbox is (left, top, right, bottom) relative to origin + x_min, y_min, x_max, y_max = bbox + + # Create image with padding (extra for stroke) + padding = 2 + stroke_width + img_width = max(int(x_max - x_min) + padding * 2, 1) + img_height = max(int(y_max - y_min) + padding * 2, 1) + + # Create RGBA image + img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Draw text at position that puts it in the image + draw_x = padding - x_min + draw_y = padding - y_min + + if multiline: + draw.multiline_text((draw_x, draw_y), text, 
fill=(255, 255, 255, 255), font=font, + spacing=line_spacing, align=align, + stroke_width=stroke_width, stroke_fill=stroke_fill) + else: + draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font, + stroke_width=stroke_width, stroke_fill=stroke_fill) + + # Baseline is at y=0 in text coordinates, which is at draw_y in image + baseline_y = draw_y + + # Convert to numpy for pixel analysis + img_array = np.array(img, dtype=np.uint8) + + # Calculate anchor offsets + # For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching + h_anchor = anchor[0] if len(anchor) > 0 else 'l' + v_anchor = anchor[1] if len(anchor) > 1 else 'a' + + # Find actual pixel bounds (for middle anchor calculations) + alpha = img_array[:, :, 3] + nonzero_cols = np.where(alpha.max(axis=0) > 0)[0] + nonzero_rows = np.where(alpha.max(axis=1) > 0)[0] + + if len(nonzero_cols) > 0: + pixel_x_min = nonzero_cols.min() + pixel_x_max = nonzero_cols.max() + pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0 + else: + pixel_x_center = img_width / 2.0 + + if len(nonzero_rows) > 0: + pixel_y_min = nonzero_rows.min() + pixel_y_max = nonzero_rows.max() + pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0 + else: + pixel_y_center = img_height / 2.0 + + # Horizontal offset + text_width = x_max - x_min + if h_anchor == 'l': # left edge of text + anchor_x = float(draw_x) + elif h_anchor == 'm': # middle - use actual pixel center for perfect matching + anchor_x = pixel_x_center + elif h_anchor == 'r': # right edge of text + anchor_x = float(draw_x + text_width) + else: + anchor_x = float(draw_x) + + # Vertical offset + # PIL anchor positions are based on font metrics (ascent/descent): + # - 'a' (ascender): at the ascender line → draw_y in strip + # - 't' (top): at top of text bounding box → padding in strip + # - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender + # - 's' (baseline): at baseline = ascent below ascender + # - 'd' (descender): at 
descender line = ascent + descent below ascender + + if v_anchor == 'a': # ascender + anchor_y = float(draw_y) + elif v_anchor == 't': # top of bbox + anchor_y = float(padding) + elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation) + anchor_y = float(draw_y + (ascent + descent) / 2.0) + elif v_anchor == 's': # baseline + anchor_y = float(draw_y + ascent) + elif v_anchor == 'd': # descender + anchor_y = float(draw_y + ascent + descent) + else: + anchor_y = float(draw_y) # default to ascender + + strip = TextStrip( + text=text, + image=img_array, + width=img_width, + height=img_height, + baseline_y=baseline_y, + bearing_x=float(x_min), + anchor_x=anchor_x, + anchor_y=anchor_y, + stroke_width=stroke_width, + ) + + _TEXT_STRIP_CACHE[cache_key] = strip + return strip + + +# ============================================================================= +# Compile-time functions (called during S-expression compilation) +# ============================================================================= + +def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData: + """Get glyph data for a single character. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + return cache.get(char, cache.get(' ')) + + +def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list: + """Get glyph data for a string. Compile-time.""" + cache = _get_glyph_cache(font_name, font_size) + space = cache.get(' ') + return [cache.get(c, space) for c in text] + + +def get_font_ascent(font_name: str = None, font_size: int = 32) -> float: + """Get font ascent. Compile-time.""" + _get_glyph_cache(font_name, font_size) # Ensure cache exists + return _METRICS_CACHE[(font_name, font_size)][0] + + +def get_font_descent(font_name: str = None, font_size: int = 32) -> float: + """Get font descent. 
Compile-time.""" + _get_glyph_cache(font_name, font_size) + return _METRICS_CACHE[(font_name, font_size)][1] + + +# ============================================================================= +# JAX Runtime Primitives +# ============================================================================= + +def place_glyph_jax( + frame: jnp.ndarray, + glyph_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + bearing_x: float, + bearing_y: float, + color: jnp.ndarray, # (3,) RGB 0-255 + opacity: float = 1.0, +) -> jnp.ndarray: + """ + Place a glyph onto a frame. This is the core JAX primitive. + + All positioning math can use traced values (x, y from audio, time, etc.) + The glyph_image is static (determined at compile time). + + Args: + frame: (H, W, 3) RGB frame + glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array) + x: X position of glyph origin (baseline point) + y: Y position of baseline + bearing_x: Left side bearing + bearing_y: Top bearing (from baseline to top) + color: RGB color array + opacity: Opacity 0-1 + + Returns: + Frame with glyph composited + """ + h, w = frame.shape[:2] + gh, gw = glyph_image.shape[:2] + + # Calculate destination position + # bearing_x: how far right of origin the glyph starts (can be negative) + # bearing_y: how far up from baseline the glyph extends + padding = 2 # Must match padding used in glyph creation + dst_x = x + bearing_x - padding + dst_y = y - bearing_y - padding + + # Extract glyph RGB and alpha + glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0 + # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255 + opacity_int = jnp.round(opacity * 255) + glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32) + glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0 + + # Apply color tint (glyph is white, multiply by color) + color_normalized = color.astype(jnp.float32) / 255.0 + tinted = glyph_rgb * color_normalized + + from jax.lax import 
dynamic_update_slice + + # Use padded buffer to avoid XLA's dynamic_update_slice clamping + buf_h = h + 2 * gh + buf_w = w + 2 * gw + rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32) + alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32) + + dst_x_int = dst_x.astype(jnp.int32) + dst_y_int = dst_y.astype(jnp.int32) + place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32) + place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32) + + rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0)) + alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0)) + + rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :] + alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :] + + # Alpha composite using PIL-compatible integer arithmetic + src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32) + alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32) + dst_int = frame.astype(jnp.int32) + result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255 + + return jnp.clip(result, 0, 255).astype(jnp.uint8) + + +def place_text_strip_jax( + frame: jnp.ndarray, + strip_image: jnp.ndarray, # (H, W, 4) RGBA + x: float, + y: float, + baseline_y: int, + bearing_x: float, + color: jnp.ndarray, + opacity: float = 1.0, + anchor_x: float = 0.0, + anchor_y: float = 0.0, + stroke_width: int = 0, +) -> jnp.ndarray: + """ + Place a pre-rendered text strip onto a frame. + + The strip was rendered at compile time with proper sub-pixel anti-aliasing. + This just composites it at the specified position. 
    Args:
        frame: (H, W, 3) RGB frame
        strip_image: (sh, sw, 4) RGBA text strip
        x: X position for anchor point
        y: Y position for anchor point
        baseline_y: Y position of baseline within the strip
        bearing_x: Left side bearing
        color: RGB color
        opacity: Opacity 0-1
        anchor_x: X offset of anchor point within strip
        anchor_y: Y offset of anchor point within strip
        stroke_width: Stroke width used when rendering (affects padding)

    Returns:
        Frame with text composited
    """
    h, w = frame.shape[:2]
    sh, sw = strip_image.shape[:2]

    # Calculate destination position
    # Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame
    # anchor_x/anchor_y already account for the anchor position within the strip
    # Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding)
    dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)

    # Extract strip RGB and alpha
    strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
    # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
    # Use jnp.round (banker's rounding) to match Python's round() used by PIL
    opacity_int = jnp.round(opacity * 255)
    strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
    strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0

    # Apply color tint
    color_normalized = color.astype(jnp.float32) / 255.0
    tinted = strip_rgb * color_normalized

    from jax.lax import dynamic_update_slice

    # Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior.
    # XLA clamps indices so the update fits, which silently shifts the strip.
    # By placing into a buffer padded by strip dimensions, then extracting the
    # frame-sized region, we get correct clipping for both overflow and underflow.
    buf_h = h + 2 * sh
    buf_w = w + 2 * sw
    rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
    alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)

    # Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer
    place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
    place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)

    rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
    alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))

    # Extract frame-sized region (sh, sw are compile-time constants from strip shape)
    rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
    alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]

    # Alpha composite using PIL-compatible integer arithmetic:
    # result = (src * alpha + dst * (255 - alpha) + 127) // 255
    src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    dst_int = frame.astype(jnp.int32)
    result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255

    return jnp.clip(result, 0, 255).astype(jnp.uint8)


def place_glyph_simple(
    frame: jnp.ndarray,
    glyph: GlyphData,
    x: float,
    y: float,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Convenience wrapper that takes GlyphData directly.
    Converts glyph image to JAX array.

    For S-expression use, prefer place_glyph_jax with pre-converted arrays.
    """
    glyph_jax = jnp.asarray(glyph.image)
    color_jax = jnp.array(color, dtype=jnp.float32)

    return place_glyph_jax(
        frame, glyph_jax, x, y,
        glyph.bearing_x, glyph.bearing_y,
        color_jax, opacity
    )


# =============================================================================
# S-Expression Primitive Bindings
# =============================================================================

def bind_typography_primitives(env: dict) -> dict:
    """
    Add typography primitives to an S-expression environment.
+ + Primitives added: + (text-glyphs text font-size) -> list of glyph data + (glyph-image g) -> JAX array (H, W, 4) + (glyph-advance g) -> float + (glyph-bearing-x g) -> float + (glyph-bearing-y g) -> float + (glyph-width g) -> int + (glyph-height g) -> int + (font-ascent font-size) -> float + (font-descent font-size) -> float + (place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame + """ + + def prim_text_glyphs(text, font_size=32, font_name=None): + """Get list of glyph data for text. Compile-time.""" + return get_glyphs(str(text), font_name, int(font_size)) + + def prim_glyph_image(glyph): + """Get glyph image as JAX array.""" + return jnp.asarray(glyph.image) + + def prim_glyph_advance(glyph): + """Get glyph advance width.""" + return glyph.advance + + def prim_glyph_bearing_x(glyph): + """Get glyph left side bearing.""" + return glyph.bearing_x + + def prim_glyph_bearing_y(glyph): + """Get glyph top bearing.""" + return glyph.bearing_y + + def prim_glyph_width(glyph): + """Get glyph image width.""" + return glyph.width + + def prim_glyph_height(glyph): + """Get glyph image height.""" + return glyph.height + + def prim_font_ascent(font_size=32, font_name=None): + """Get font ascent.""" + return get_font_ascent(font_name, int(font_size)) + + def prim_font_descent(font_size=32, font_name=None): + """Get font descent.""" + return get_font_descent(font_name, int(font_size)) + + def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y, + color=(255, 255, 255), opacity=1.0): + """Place glyph on frame. Runtime JAX operation.""" + color_arr = jnp.array(color, dtype=jnp.float32) + return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y, + color_arr, opacity) + + def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None): + """Get kerning adjustment between two glyphs. Compile-time. + + Returns adjustment to add to glyph1's advance when glyph2 follows. + Typically negative (characters move closer). 
+ + Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size)) + """ + return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size)) + + def prim_char_kerning(char1, char2, font_size=32, font_name=None): + """Get kerning adjustment between two characters. Compile-time.""" + return get_kerning(str(char1), str(char2), font_name, int(font_size)) + + # TextStrip primitives for pre-rendered text with proper anti-aliasing + def prim_render_text_strip(text, font_size=32, font_name=None): + """Render text to a strip at compile time. Perfect anti-aliasing.""" + return render_text_strip(str(text), font_name, int(font_size)) + + def prim_render_text_strip_styled( + text, font_size=32, font_name=None, + stroke_width=0, stroke_fill=None, + anchor="la", multiline=False, line_spacing=4, align="left" + ): + """Render styled text to a strip. Supports stroke, anchors, multiline. + + Args: + text: Text to render + font_size: Size in pixels + font_name: Path to font file + stroke_width: Outline width (0 = no outline) + stroke_fill: Outline color as (R,G,B) or (R,G,B,A) + anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender) + multiline: If True, handle newlines + line_spacing: Extra pixels between lines + align: "left", "center", "right" for multiline + """ + return render_text_strip( + str(text), font_name, int(font_size), + stroke_width=int(stroke_width), + stroke_fill=stroke_fill, + anchor=str(anchor), + multiline=bool(multiline), + line_spacing=int(line_spacing), + align=str(align), + ) + + def prim_text_strip_image(strip): + """Get text strip image as JAX array.""" + return jnp.asarray(strip.image) + + def prim_text_strip_width(strip): + """Get text strip width.""" + return strip.width + + def prim_text_strip_height(strip): + """Get text strip height.""" + return strip.height + + def prim_text_strip_baseline_y(strip): + """Get text strip baseline Y position.""" + return strip.baseline_y + + def prim_text_strip_bearing_x(strip): + """Get text 
strip left bearing.""" + return strip.bearing_x + + def prim_text_strip_anchor_x(strip): + """Get text strip anchor X offset.""" + return strip.anchor_x + + def prim_text_strip_anchor_y(strip): + """Get text strip anchor Y offset.""" + return strip.anchor_y + + def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0): + """Place pre-rendered text strip on frame. Runtime JAX operation.""" + strip_img = jnp.asarray(strip.image) + color_arr = jnp.array(color, dtype=jnp.float32) + return place_text_strip_jax( + frame, strip_img, x, y, + strip.baseline_y, strip.bearing_x, + color_arr, opacity, + anchor_x=strip.anchor_x, anchor_y=strip.anchor_y, + stroke_width=strip.stroke_width + ) + + # Add to environment + env.update({ + # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects) + 'text-glyphs': prim_text_glyphs, + 'glyph-image': prim_glyph_image, + 'glyph-advance': prim_glyph_advance, + 'glyph-bearing-x': prim_glyph_bearing_x, + 'glyph-bearing-y': prim_glyph_bearing_y, + 'glyph-width': prim_glyph_width, + 'glyph-height': prim_glyph_height, + 'glyph-kerning': prim_glyph_kerning, + 'char-kerning': prim_char_kerning, + 'font-ascent': prim_font_ascent, + 'font-descent': prim_font_descent, + 'place-glyph': prim_place_glyph, + # TextStrip primitives (for pixel-perfect static text) + 'render-text-strip': prim_render_text_strip, + 'render-text-strip-styled': prim_render_text_strip_styled, + 'text-strip-image': prim_text_strip_image, + 'text-strip-width': prim_text_strip_width, + 'text-strip-height': prim_text_strip_height, + 'text-strip-baseline-y': prim_text_strip_baseline_y, + 'text-strip-bearing-x': prim_text_strip_bearing_x, + 'text-strip-anchor-x': prim_text_strip_anchor_x, + 'text-strip-anchor-y': prim_text_strip_anchor_y, + 'place-text-strip': prim_place_text_strip, + }) + + return env + + +# ============================================================================= +# Example: Render text using primitives (for testing) +# 
def render_text_primitives(
    frame: jnp.ndarray,
    text: str,
    x: float,
    y: float,
    font_size: int = 32,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
    use_kerning: bool = True,
) -> jnp.ndarray:
    """
    Render text glyph-by-glyph using the typography primitives.

    This mirrors what a compiled S-expression would do: fetch glyph data at
    compile time, then place each glyph on the frame with a running pen
    position.

    Args:
        frame: Target frame (JAX array).
        text: Text string to render.
        x, y: Pen start position (x advances; y is the baseline reference
            passed through to place_glyph_jax).
        font_size: Font size in pixels.
        color: RGB tuple (0-255).
        opacity: 0.0 to 1.0.
        use_kerning: If True, apply per-pair kerning adjustments between
            consecutive characters.

    Returns:
        Frame with the text composited onto it.
    """
    glyph_list = get_glyphs(text, None, font_size)
    tint = jnp.array(color, dtype=jnp.float32)

    pen_x = x
    last = len(glyph_list) - 1
    for idx, glyph in enumerate(glyph_list):
        frame = place_glyph_jax(
            frame, jnp.asarray(glyph.image), pen_x, y,
            glyph.bearing_x, glyph.bearing_y,
            tint, opacity,
        )
        # Advance the pen; kerning (typically negative) pulls the next
        # character closer when a pair table entry exists.
        step = glyph.advance
        if use_kerning and idx < last:
            step += get_kerning(glyph.char, glyph_list[idx + 1].char, None, font_size)
        pen_x = pen_x + step

    return frame
# Cache for text font atlases:
# (font_name, font_size) -> (atlas, char_to_idx, char_widths, cell_height,
#                            baseline_y, draw_x, char_left_bearings)
_TEXT_ATLAS_CACHE: Dict[tuple, tuple] = {}


def _create_text_atlas(font_name: str = None, font_size: int = 32) -> tuple:
    """
    Create a font atlas for general text rendering with proper baseline alignment.

    Font metrics (from typography):
    - Ascender: distance from baseline to top of tallest glyph (b, d, h, k, l)
    - Descender: distance from baseline to bottom of lowest glyph (g, j, p, q, y)
    - Baseline: the line text "sits" on - all characters align to this

    Returns:
        (atlas, char_to_idx, char_widths, cell_height, baseline_y,
         draw_x, char_left_bearings)
        - atlas: np.uint8 array (num_chars, cell_height, cell_width, 4) RGBA
        - char_to_idx: dict mapping character to atlas index
        - char_widths: np.int32 array of advance width per character
        - cell_height: height of character cells (ascent + descent + padding)
        - baseline_y: pixels from top of cell to baseline
        - draw_x: x offset of the glyph origin inside each tile
        - char_left_bearings: np.int32 left bearing per character

    Raises:
        ImportError: if Pillow is not installed.
    """
    cache_key = (font_name, font_size)
    if cache_key in _TEXT_ATLAS_CACHE:
        return _TEXT_ATLAS_CACHE[cache_key]

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        raise ImportError("PIL/Pillow required for text rendering")

    # Load font - match drawing.py's font order for consistency
    font = None
    font_candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',  # Same order as drawing.py
        '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
        '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
        '/usr/share/fonts/truetype/freefont/FreeSans.ttf',
        '/System/Library/Fonts/Helvetica.ttc',
        '/System/Library/Fonts/Arial.ttf',
        'C:\\Windows\\Fonts\\arial.ttf',
    ]
    for font_path in font_candidates:
        if font_path is None:
            continue
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (IOError, OSError):
            continue
    if font is None:
        font = ImageFont.load_default()

    # getmetrics() returns (ascent, descent) in pixels relative to the baseline.
    ascent, descent = font.getmetrics()
    cell_height = ascent + descent + 2  # +2 for padding
    baseline_y = ascent + 1             # baseline within cell (1px top padding)

    # Advance widths (horizontal spacing) per character.
    # FIX: removed an unused 200x200 temp image/draw pair that was dead code
    # after the switch to font.getlength().
    max_width = 0
    char_widths_dict = {}
    for char in TEXT_CHARSET:
        try:
            advance = font.getlength(char)
            char_widths_dict[char] = int(advance)
            max_width = max(max_width, int(advance))
        except Exception:
            # FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. Fall back to a half-em advance
            # for glyphs the font cannot measure.
            char_widths_dict[char] = font_size // 2
            max_width = max(max_width, font_size // 2)

    cell_width = max_width + 2  # +2 for padding

    # Draw each character the same way prim_text does, for pixel-perfect match.
    char_to_idx = {}
    char_widths = []         # advance widths
    char_left_bearings = []  # x offset from origin to first glyph pixel
    atlas = []

    draw_x = 5  # margin inside each tile for glyphs with negative left bearing
    draw_y = 0  # top of cell (PIL default without anchor)

    for i, char in enumerate(TEXT_CHARSET):
        char_to_idx[char] = i
        char_widths.append(char_widths_dict.get(char, cell_width // 2))

        img = Image.new('RGBA', (cell_width, cell_height), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        # Draw at (draw_x, draw_y) with no anchor; glyphs may extend
        # left/right of the origin.
        draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)

        # bbox[0] - draw_x = how far left of the origin the glyph extends.
        bbox = draw.textbbox((draw_x, draw_y), char, font=font)
        char_left_bearings.append(bbox[0] - draw_x)

        atlas.append(np.array(img, dtype=np.uint8))

    atlas = np.stack(atlas, axis=0)  # (num_chars, cell_height, cell_width, 4)
    char_widths = np.array(char_widths, dtype=np.int32)
    char_left_bearings = np.array(char_left_bearings, dtype=np.int32)

    result = (atlas, char_to_idx, char_widths, cell_height, baseline_y,
              draw_x, char_left_bearings)
    _TEXT_ATLAS_CACHE[cache_key] = result
    return result
def jax_text_render(frame, text: str, x: int, y: int,
                    font_name: str = None, font_size: int = 32,
                    color=(255, 255, 255), opacity: float = 1.0,
                    align: str = "left", valign: str = "baseline",
                    shadow: bool = False, shadow_color=(0, 0, 0),
                    shadow_offset: int = 2):
    """
    Render text onto frame using a font atlas (JAX-compatible).

    The atlas and the text strip are built at trace time (numpy/PIL);
    only the final compositing uses JAX ops, so this is safe inside JIT.

    Typography notes:
    - Baseline: the line text "sits" on; most characters rest on it.
    - Ascender: top of tall letters (b, d, h, k, l) - above baseline.
    - Descender: bottom of letters like g, j, p, q, y - below baseline.
    - For normal text, use valign="baseline" and y = the baseline position.

    Args:
        frame: Input frame (H, W, 3).
        text: Text string to render.
        x, y: Position reference point (interpreted via align/valign).
        font_name: Font to use (None = default).
        font_size: Font size in pixels.
        color: RGB tuple (0-255).
        opacity: 0.0 to 1.0.
        align: "left" (text starts at x), "center", or "right" (ends at x).
        valign: "baseline" (default), "top", "middle", or "bottom".
        shadow: Whether to draw a drop shadow (at half the text opacity).
        shadow_color: Shadow RGB color.
        shadow_offset: Shadow offset in pixels.

    Returns:
        Frame with text rendered (uint8).
    """
    if not text:
        return frame

    # Atlas creation happens at trace time (numpy/PIL, cached).
    # FIX: removed dead locals from the original (`indices`/`num_chars` were
    # computed but never used; the jnp copy of the atlas was used only for
    # its shape; `h, w` of the frame were never read).
    (atlas_np, char_to_idx, char_widths_np, char_height,
     baseline_offset, origin_x, _left_bearings_np) = _create_text_atlas(font_name, font_size)

    cell_width = atlas_np.shape[2]

    # Per-character x offsets with proportional advance widths; unknown
    # characters fall back to the space advance.
    space_idx = char_to_idx.get(' ', 0)
    char_x_offsets = [0]  # starting x position for each character
    total_width = 0
    for ch in text:
        idx = char_to_idx.get(ch, space_idx)
        total_width += int(char_widths_np[idx])
        char_x_offsets.append(total_width)

    text_width = total_width
    text_height = char_height

    # Horizontal alignment relative to x.
    if align == "center":
        x = x - text_width // 2
    elif align == "right":
        x = x - text_width

    # Vertical alignment relative to y.
    # baseline_offset = pixels from top of cell to baseline.
    if valign == "baseline":
        y = y - baseline_offset
    elif valign == "middle":
        y = y - text_height // 2
    elif valign == "bottom":
        y = y - text_height
    # valign == "top" needs no adjustment

    x, y = int(x), int(y)

    # Build the text strip at trace time (numpy). Each atlas tile has its
    # glyph origin at origin_x; blitting at (cursor - origin_x) aligns the
    # origins. Extra front padding absorbs negative left bearings.
    strip_padding = origin_x
    text_strip_np = np.zeros(
        (char_height, strip_padding + text_width + cell_width, 4), dtype=np.uint8)

    for i, ch in enumerate(text):
        if ch not in char_to_idx:
            continue  # unknown chars contribute advance but no pixels
        char_tile = atlas_np[char_to_idx[ch]]  # (char_height, cell_width, 4)
        cx = char_x_offsets[i]
        strip_x = strip_padding + cx - origin_x
        if strip_x >= 0:
            end_x = min(strip_x + cell_width, text_strip_np.shape[1])
            tile_end = end_x - strip_x
            # max-combine so overlapping glyph tiles don't erase each other
            text_strip_np[:, strip_x:end_x] = np.maximum(
                text_strip_np[:, strip_x:end_x], char_tile[:, :tile_end])

    # Trim: left to first visible pixel (handles negative left bearing),
    # right to the logical advance width (preserves trailing spacing).
    alpha = text_strip_np[:, :, 3]
    cols_with_content = np.any(alpha > 0, axis=0)
    if cols_with_content.any():
        first_col = np.argmax(cols_with_content)
        right_col = strip_padding + text_width
        x = x + first_col - strip_padding  # compensate for the left trim
        text_strip_np = text_strip_np[:, first_col:right_col]
    else:
        # No visible content, return original frame
        return frame

    text_strip = jnp.asarray(text_strip_np)
    color = jnp.array(color, dtype=jnp.float32)
    shadow_color = jnp.array(shadow_color, dtype=jnp.float32)

    # Tint the (white) strip and premultiply opacity into alpha.
    text_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * color
    text_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity

    result = frame.astype(jnp.float32)

    # Shadow underneath at half opacity.
    if shadow:
        sx, sy = x + shadow_offset, y + shadow_offset
        shadow_rgb = text_strip[:, :, :3].astype(jnp.float32) / 255.0 * shadow_color
        shadow_alpha = text_strip[:, :, 3].astype(jnp.float32) / 255.0 * opacity * 0.5
        result = _composite_text_strip(result, shadow_rgb, shadow_alpha, sx, sy)

    result = _composite_text_strip(result, text_rgb, text_alpha, x, y)

    return jnp.clip(result, 0, 255).astype(jnp.uint8)
def _composite_text_strip(frame, text_rgb, text_alpha, x, y):
    """
    Composite a text strip onto ``frame`` at position (x, y).

    Alpha blending: result = text * alpha + frame * (1 - alpha).

    (x, y) are Python ints at trace time, so the bounds arithmetic below is
    concrete and the lax.dynamic_update_slice placement is JIT-safe. The
    strip is clipped to the frame; fully off-frame strips blend a zero layer
    (i.e. the frame is returned unchanged, as float32).

    Args:
        frame: (H, W, 3) float32 frame.
        text_rgb: (th, tw, 3) pre-tinted float RGB in [0, 1] * color scale.
        text_alpha: (th, tw) float alpha in [0, 1].
        x, y: Top-left placement (may be negative / beyond the frame).

    Returns:
        (H, W, 3) float32 blended frame.
    """
    # FIX: removed dead code from the original - jnp-based src/dst bounds
    # that were immediately shadowed by the Python-int versions, and two
    # `text_layer_*` arrays that were allocated but never used. Also dropped
    # the redundant `src_*` terms from the guard (they always equal the
    # corresponding dst extents).
    h, w = frame.shape[:2]
    th, tw = text_rgb.shape[:2]

    # Full-frame layer approach: place the strip into zero-filled buffers,
    # then blend once. Less efficient than in-place slicing but JIT-friendly.
    padded_rgb = jnp.zeros((h, w, 3), dtype=jnp.float32)
    padded_alpha = jnp.zeros((h, w), dtype=jnp.float32)

    # Destination region clipped to the frame.
    y_start = int(max(0, y))
    y_end = int(min(h, y + th))
    x_start = int(max(0, x))
    x_end = int(min(w, x + tw))

    # Matching source region inside the strip.
    src_y_start = int(max(0, -y))
    src_y_end = src_y_start + (y_end - y_start)
    src_x_start = int(max(0, -x))
    src_x_end = src_x_start + (x_end - x_start)

    if y_end > y_start and x_end > x_start:
        valid_rgb = text_rgb[src_y_start:src_y_end, src_x_start:src_x_end]
        valid_alpha = text_alpha[src_y_start:src_y_end, src_x_start:src_x_end]
        padded_rgb = lax.dynamic_update_slice(padded_rgb, valid_rgb, (y_start, x_start, 0))
        padded_alpha = lax.dynamic_update_slice(padded_alpha, valid_alpha, (y_start, x_start))

    # Alpha composite: result = text * alpha + frame * (1 - alpha)
    alpha_3d = padded_alpha[:, :, jnp.newaxis]
    return padded_rgb * alpha_3d + frame * (1.0 - alpha_3d)
def jax_text_size(text: str, font_name: str = None, font_size: int = 32) -> tuple:
    """
    Measure rendered text dimensions at compile time.

    Unknown characters count as a space advance, matching jax_text_render.

    Returns:
        (width, height) tuple in pixels.
    """
    _, char_to_idx, char_widths, char_height, _, _, _ = _create_text_atlas(font_name, font_size)
    fallback = char_to_idx.get(' ', 0)
    width = sum(int(char_widths[char_to_idx.get(c, fallback)]) for c in text)
    return (width, char_height)


def jax_font_metrics(font_name: str = None, font_size: int = 32) -> dict:
    """
    Get font metrics for layout calculations.

    Typography terms:
    - ascent: pixels from baseline to top of tallest glyph (b, d, h, etc.)
    - descent: pixels from baseline to bottom of lowest glyph (g, j, p, etc.)
    - height: total cell height including padding
    - baseline: position of baseline from top of text cell

    Returns:
        dict with keys: ascent, descent, height, baseline.
    """
    _, _, _, char_height, baseline_offset, _, _ = _create_text_atlas(font_name, font_size)
    # The atlas cell carries 1px padding top and bottom; strip it from the
    # ascent/descent so they reflect raw font metrics.
    return {
        'ascent': baseline_offset - 1,
        'descent': char_height - baseline_offset - 1,
        'height': char_height,
        'baseline': baseline_offset,
    }


def jax_fit_text_size(text: str, max_width: int, max_height: int,
                      font_name: str = None, min_size: int = 8, max_size: int = 200) -> int:
    """
    Largest font size (binary search) whose rendered text fits within
    max_width x max_height. Returns min_size if nothing fits.
    """
    best = min_size
    lo, hi = min_size, max_size
    while lo <= hi:
        mid = (lo + hi) // 2
        w, h = jax_text_size(text, font_name, mid)
        if w <= max_width and h <= max_height:
            best = mid
            lo = mid + 1
        else:
            hi = mid - 1
    return best
def jax_scan_add(x, axis=None):
    """Cumulative (prefix) sum; flattens the input when no axis is given."""
    if axis is None:
        return jnp.cumsum(x.flatten())
    return jnp.cumsum(x, axis=int(axis))


def jax_scan_mul(x, axis=None):
    """Cumulative (prefix) product; flattens the input when no axis is given."""
    if axis is None:
        return jnp.cumprod(x.flatten())
    return jnp.cumprod(x, axis=int(axis))


def jax_scan_max(x, axis=None):
    """Running maximum; flattens the input when no axis is given."""
    if axis is None:
        return lax.cummax(x.flatten(), axis=0)
    return lax.cummax(x, axis=int(axis))


def jax_scan_min(x, axis=None):
    """Running minimum; flattens the input when no axis is given."""
    if axis is None:
        return lax.cummin(x.flatten(), axis=0)
    return lax.cummin(x, axis=int(axis))


def jax_outer(x, y, op='*'):
    """
    Generalized outer product of the flattened inputs.

    out[i, j] = op(x[i], y[j]) for op in {'*', '+', '-', '/', 'max', 'min'};
    any unrecognized op falls back to multiplication.
    """
    a = x.flatten()
    b = y.flatten()
    if op == '+':
        return a[:, None] + b[None, :]
    if op == '-':
        return a[:, None] - b[None, :]
    if op == '/':
        return a[:, None] / b[None, :]
    if op == 'max':
        return jnp.maximum(a[:, None], b[None, :])
    if op == 'min':
        return jnp.minimum(a[:, None], b[None, :])
    return jnp.outer(a, b)  # '*' and unknown ops


def jax_outer_add(x, y):
    """Outer sum."""
    return jax_outer(x, y, '+')


def jax_outer_mul(x, y):
    """Outer product."""
    return jax_outer(x, y, '*')


def jax_outer_max(x, y):
    """Outer max."""
    return jax_outer(x, y, 'max')


def jax_outer_min(x, y):
    """Outer min."""
    return jax_outer(x, y, 'min')
def jax_reduce_axis(x, op='sum', axis=0):
    """
    Reduce ``x`` along ``axis`` with the named reduction.

    Supported ops: 'sum'/'+', 'mean', 'max', 'min', 'prod'/'*', 'std'.
    Unknown ops fall back to sum.
    """
    axis = int(axis)
    dispatch = {
        'sum': jnp.sum, '+': jnp.sum,
        'mean': jnp.mean,
        'max': jnp.max,
        'min': jnp.min,
        'prod': jnp.prod, '*': jnp.prod,
        'std': jnp.std,
    }
    reducer = dispatch.get(op, jnp.sum)
    return reducer(x, axis=axis)


def jax_sum_axis(x, axis=0):
    """Sum along axis."""
    return jnp.sum(x, axis=int(axis))


def jax_mean_axis(x, axis=0):
    """Mean along axis."""
    return jnp.mean(x, axis=int(axis))


def jax_max_axis(x, axis=0):
    """Max along axis."""
    return jnp.max(x, axis=int(axis))


def jax_min_axis(x, axis=0):
    """Min along axis."""
    return jnp.min(x, axis=int(axis))
def jax_window(x, size, op='mean', stride=1):
    """
    Sliding-window reduction.

    1D input: windows of length ``size`` along the array (sum/mean computed
    via convolution; max/min via explicit window slices). 2D input: square
    ``size`` x ``size`` windows, returning an (out_h, out_w) grid.
    Unrecognized ops fall back to mean.
    """
    size = int(size)
    stride = int(stride)

    if x.ndim == 1:
        if op in ('sum', 'mean'):
            # Convolution trick: a ones kernel sums each window;
            # dividing by size turns it into a mean.
            kernel = jnp.ones(size)
            if op == 'mean':
                kernel = kernel / size
            return jnp.convolve(x, kernel, mode='valid')[::stride]

        # max/min (and fallback): gather each window explicitly.
        count = (len(x) - size) // stride + 1
        starts = jnp.arange(count) * stride
        windows = jax.vmap(lambda s: lax.dynamic_slice(x, (s,), (size,)))(starts)
        if op == 'max':
            return jnp.max(windows, axis=1)
        if op == 'min':
            return jnp.min(windows, axis=1)
        return jnp.mean(windows, axis=1)

    # 2D sliding window
    h, w = x.shape[:2]
    out_h = (h - size) // stride + 1
    out_w = (w - size) // stride + 1

    def pull(flat_idx):
        # Recover (row, col) window coordinates from the flat index.
        r, c = flat_idx // out_w, flat_idx % out_w
        return lax.dynamic_slice(x, (r * stride, c * stride), (size, size))

    windows = jax.vmap(pull)(jnp.arange(out_h * out_w))

    reducers = {'sum': jnp.sum, 'mean': jnp.mean, 'max': jnp.max, 'min': jnp.min}
    reduced = reducers.get(op, jnp.mean)(windows, axis=(1, 2))
    return reduced.reshape(out_h, out_w)


def jax_window_sum(x, size, stride=1):
    """Sliding window sum."""
    return jax_window(x, size, 'sum', stride)


def jax_window_mean(x, size, stride=1):
    """Sliding window mean."""
    return jax_window(x, size, 'mean', stride)


def jax_window_max(x, size, stride=1):
    """Sliding window max."""
    return jax_window(x, size, 'max', stride)


def jax_window_min(x, size, stride=1):
    """Sliding window min."""
    return jax_window(x, size, 'min', stride)


def jax_integral_image(frame):
    """
    Integral image (summed area table) of the grayscale frame.
    Enables O(1) box sums at any radius. 3-channel input is averaged
    to grayscale first.
    """
    as_float = frame.astype(jnp.float32)
    gray = jnp.mean(as_float, axis=2) if frame.ndim == 3 else as_float
    return jnp.cumsum(jnp.cumsum(gray, axis=0), axis=1)
+ """ h, w = frame.shape[:2] - # Clamp coordinates - x = jnp.clip(x, 0, w - 1) - y = jnp.clip(y, 0, h - 1) - - # Get integer and fractional parts + # Get integer coords for the 4 sample points x0 = jnp.floor(x).astype(jnp.int32) y0 = jnp.floor(y).astype(jnp.int32) - x1 = jnp.clip(x0 + 1, 0, w - 1) - y1 = jnp.clip(y0 + 1, 0, h - 1) + x1 = x0 + 1 + y1 = y0 + 1 fx = x - x0.astype(jnp.float32) fy = y - y0.astype(jnp.float32) + # Check which sample points are in bounds + valid00 = (x0 >= 0) & (x0 < w) & (y0 >= 0) & (y0 < h) + valid10 = (x1 >= 0) & (x1 < w) & (y0 >= 0) & (y0 < h) + valid01 = (x0 >= 0) & (x0 < w) & (y1 >= 0) & (y1 < h) + valid11 = (x1 >= 0) & (x1 < w) & (y1 >= 0) & (y1 < h) + + # Clamp indices for safe array access (values will be masked anyway) + x0_safe = jnp.clip(x0, 0, w - 1) + x1_safe = jnp.clip(x1, 0, w - 1) + y0_safe = jnp.clip(y0, 0, h - 1) + y1_safe = jnp.clip(y1, 0, h - 1) + # Bilinear interpolation for each channel def interp_channel(c): - c00 = frame[y0, x0, c].astype(jnp.float32) - c10 = frame[y0, x1, c].astype(jnp.float32) - c01 = frame[y1, x0, c].astype(jnp.float32) - c11 = frame[y1, x1, c].astype(jnp.float32) + # Sample with 0 for out-of-bounds (BORDER_CONSTANT) + c00 = jnp.where(valid00, frame[y0_safe, x0_safe, c].astype(jnp.float32), 0.0) + c10 = jnp.where(valid10, frame[y0_safe, x1_safe, c].astype(jnp.float32), 0.0) + c01 = jnp.where(valid01, frame[y1_safe, x0_safe, c].astype(jnp.float32), 0.0) + c11 = jnp.where(valid11, frame[y1_safe, x1_safe, c].astype(jnp.float32), 0.0) return (c00 * (1 - fx) * (1 - fy) + c10 * fx * (1 - fy) + @@ -691,18 +1365,10 @@ def jax_rotate(frame, angle, center_x=None, center_y=None): src_x = cos_t * x_centered - sin_t * y_centered + center_x src_y = sin_t * x_centered + cos_t * y_centered + center_y - # Mask for valid coordinates (out-of-bounds -> black, matching OpenCV) - valid = (src_x >= 0) & (src_x < w - 1) & (src_y >= 0) & (src_y < h - 1) - valid_flat = valid.flatten() - # Sample using bilinear 
interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) - # Zero out-of-bounds pixels (matching OpenCV warpAffine behavior) - r = jnp.where(valid_flat, r, 0) - g = jnp.where(valid_flat, g, 0) - b = jnp.where(valid_flat, b, 0) - return jnp.stack([ jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), @@ -726,17 +1392,10 @@ def jax_scale(frame, scale_x, scale_y=None): src_x = (x_coords - center_x) / scale_x + center_x src_y = (y_coords - center_y) / scale_y + center_y - # Mask for valid coordinates (out-of-bounds -> black) - valid = (src_x >= 0) & (src_x < w - 1) & (src_y >= 0) & (src_y < h - 1) - valid_flat = valid.flatten() - + # Sample using bilinear interpolation + # jax_sample handles BORDER_CONSTANT (returns 0 for out-of-bounds samples) r, g, b = jax_sample(frame, src_x.flatten(), src_y.flatten()) - # Zero out-of-bounds pixels - r = jnp.where(valid_flat, r, 0) - g = jnp.where(valid_flat, g, 0) - b = jnp.where(valid_flat, b, 0) - return jnp.stack([ jnp.clip(r, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g, 0, 255).reshape(h, w).astype(jnp.uint8), @@ -1000,6 +1659,9 @@ class JaxCompiler: # Add derived functions env.update(derived_fns) + # Add typography primitives + bind_typography_primitives(env) + # Add parameters with defaults for pname, pdefault in param_info.items(): if pname in kwargs: @@ -1142,6 +1804,25 @@ class JaxCompiler: raise ValueError(f"Cannot evaluate: {expr}") + def _eval_kwarg(self, args, key: str, default, env: Dict[str, Any]): + """Extract a keyword argument from args list. + + Looks for :key value pattern in args and evaluates the value. + Returns default if not found. 
+ """ + i = 0 + while i < len(args): + if isinstance(args[i], Keyword) and args[i].name == key: + if i + 1 < len(args): + val = self._eval(args[i + 1], env) + # Handle Symbol values (e.g., :op 'sum -> 'sum') + if isinstance(val, Symbol): + return val.name + return val + return default + i += 1 + return default + def _eval_let(self, args, env: Dict[str, Any]) -> Any: """Evaluate (let ((var val) ...) body) or (let* ...) or (let [var val ...] body).""" if len(args) < 2: @@ -1359,18 +2040,29 @@ class JaxCompiler: return result if op == 'or': - vals = [self._eval(a, env) for a in args] - # Use Python or for concrete Python bools - if all(isinstance(v, (bool, np.bool_)) for v in vals): - result = False - for v in vals: - result = result or bool(v) - return result - # Otherwise use JAX logical_or - result = vals[0] - for v in vals[1:]: - result = jnp.logical_or(result, v) - return result + # Lisp-style or: returns first truthy value, not boolean + # (or a b c) returns a if a is truthy, else b if b is truthy, else c + for arg in args: + val = self._eval(arg, env) + # Check if value is truthy + if val is None: + continue + if isinstance(val, (bool, np.bool_)): + if val: + return val + continue + if isinstance(val, (int, float)): + if val: + return val + continue + if hasattr(val, 'shape'): + # JAX/numpy array - return it (considered truthy) + return val + # For other types, check truthiness + if val: + return val + # All values were falsy, return the last one + return self._eval(args[-1], env) if args else None if op == 'not': val = self._eval(args[0], env) @@ -1530,6 +2222,103 @@ class JaxCompiler: if op in ('βall', 'beta-all'): return jnp.all(self._eval(args[0], env)) + # Scan (prefix) operations + if op in ('scan+', 'scan-add'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_add(x, axis) + if op in ('scan*', 'scan-mul'): + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return 
jax_scan_mul(x, axis) + if op == 'scan-max': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_max(x, axis) + if op == 'scan-min': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', None, env) + return jax_scan_min(x, axis) + + # Outer product operations + if op == 'outer': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + op_type = self._eval_kwarg(args, 'op', '*', env) + return jax_outer(x, y, op_type) + if op in ('outer+', 'outer-add'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_add(x, y) + if op in ('outer*', 'outer-mul'): + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_mul(x, y) + if op == 'outer-max': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_max(x, y) + if op == 'outer-min': + x = self._eval(args[0], env) + y = self._eval(args[1], env) + return jax_outer_min(x, y) + + # Reduce with axis operations + if op == 'reduce-axis': + x = self._eval(args[0], env) + reduce_op = self._eval_kwarg(args, 'op', 'sum', env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_reduce_axis(x, reduce_op, axis) + if op == 'sum-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_sum_axis(x, axis) + if op == 'mean-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_mean_axis(x, axis) + if op == 'max-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_max_axis(x, axis) + if op == 'min-axis': + x = self._eval(args[0], env) + axis = self._eval_kwarg(args, 'axis', 0, env) + return jax_min_axis(x, axis) + + # Windowed operations + if op == 'window': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + win_op = self._eval_kwarg(args, 'op', 'mean', env) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window(x, 
size, win_op, stride) + if op == 'window-sum': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_sum(x, size, stride) + if op == 'window-mean': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_mean(x, size, stride) + if op == 'window-max': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_max(x, size, stride) + if op == 'window-min': + x = self._eval(args[0], env) + size = int(self._eval(args[1], env)) + stride = int(self._eval_kwarg(args, 'stride', 1, env)) + return jax_window_min(x, size, stride) + + # Integral image + if op == 'integral-image': + frame = self._eval(args[0], env) + return jax_integral_image(frame) + # Convenience - min/max of two values (handle both scalars and arrays) if op == 'min' or op == 'min2': a = self._eval(args[0], env) @@ -1789,6 +2578,87 @@ class JaxCompiler: results.append(result_row) return jnp.stack(results, axis=0) + # ===================================================================== + # Text rendering operations + # ===================================================================== + if op == 'text': + frame = self._eval(args[0], env) + text_str = self._eval(args[1], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + + # Extract keyword arguments + x = self._eval_kwarg(args, 'x', None, env) + y = self._eval_kwarg(args, 'y', None, env) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + color = self._eval_kwarg(args, 'color', (255, 255, 255), env) + opacity = self._eval_kwarg(args, 'opacity', 1.0, env) + align = self._eval_kwarg(args, 'align', 'left', env) + valign = self._eval_kwarg(args, 'valign', 'top', env) + shadow = 
self._eval_kwarg(args, 'shadow', False, env) + shadow_color = self._eval_kwarg(args, 'shadow-color', (0, 0, 0), env) + shadow_offset = self._eval_kwarg(args, 'shadow-offset', 2, env) + fit = self._eval_kwarg(args, 'fit', False, env) + width = self._eval_kwarg(args, 'width', None, env) + height = self._eval_kwarg(args, 'height', None, env) + + # Handle color as list/tuple + if isinstance(color, (list, tuple)): + color = tuple(int(c) for c in color[:3]) + if isinstance(shadow_color, (list, tuple)): + shadow_color = tuple(int(c) for c in shadow_color[:3]) + + h, w_frame = frame.shape[:2] + + # Default position to 0,0 or center based on alignment + if x is None: + if align == 'center': + x = w_frame // 2 + elif align == 'right': + x = w_frame + else: + x = 0 + if y is None: + if valign == 'middle': + y = h // 2 + elif valign == 'bottom': + y = h + else: + y = 0 + + # Auto-fit text to bounds + if fit and width is not None and height is not None: + font_size = jax_fit_text_size(text_str, int(width), int(height), + font_name, min_size=8, max_size=200) + + return jax_text_render(frame, text_str, int(x), int(y), + font_name=font_name, font_size=int(font_size), + color=color, opacity=float(opacity), + align=str(align), valign=str(valign), + shadow=bool(shadow), shadow_color=shadow_color, + shadow_offset=int(shadow_offset)) + + if op == 'text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + font_size = self._eval_kwarg(args, 'font-size', 32, env) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return jax_text_size(text_str, font_name, int(font_size)) + + if op == 'fit-text-size': + text_str = self._eval(args[0], env) + if isinstance(text_str, Symbol): + text_str = text_str.name + text_str = str(text_str) + max_width = int(self._eval(args[1], env)) + max_height = int(self._eval(args[2], env)) + font_name = self._eval_kwarg(args, 'font-name', None, env) + return 
jax_fit_text_size(text_str, max_width, max_height, font_name) + # ===================================================================== # Color operations # ===================================================================== @@ -1867,14 +2737,14 @@ class JaxCompiler: # ===================================================================== # Geometry operations # ===================================================================== - if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-img': + if op == 'flip-horizontal' or op == 'flip-h' or op == 'geometry:flip-h' or op == 'geometry:flip-img': frame = self._eval(args[0], env) direction = self._eval(args[1], env) if len(args) > 1 else 'horizontal' if direction == 'vertical' or direction == 'v': return jax_flip_vertical(frame) return jax_flip_horizontal(frame) - if op == 'flip-vertical' or op == 'flip-v': + if op == 'flip-vertical' or op == 'flip-v' or op == 'geometry:flip-v': frame = self._eval(args[0], env) return jax_flip_vertical(frame) @@ -2004,20 +2874,55 @@ class JaxCompiler: ], axis=2) if op == 'geometry:ripple-displace' or op == 'ripple': + # Match Python prim_ripple_displace signature: + # (w h freq amp cx cy decay phase) or (frame ...) first_arg = self._eval(args[0], env) if not hasattr(first_arg, 'shape'): + # Coordinate-only mode: (w h freq amp cx cy decay phase) w = int(first_arg) h = int(self._eval(args[1], env)) - amplitude = self._eval(args[2], env) if len(args) > 2 else 10.0 - frequency = self._eval(args[3], env) if len(args) > 3 else 0.05 + freq = self._eval(args[2], env) if len(args) > 2 else 5.0 + amp = self._eval(args[3], env) if len(args) > 3 else 10.0 + cx = self._eval(args[4], env) if len(args) > 4 else w / 2 + cy = self._eval(args[5], env) if len(args) > 5 else h / 2 + decay = self._eval(args[6], env) if len(args) > 6 else 0.0 + phase = self._eval(args[7], env) if len(args) > 7 else 0.0 frame = None else: + # Frame mode: (frame :amplitude A :frequency F :center_x CX ...) 
frame = first_arg - amplitude = self._eval(args[1], env) if len(args) > 1 else 10.0 - frequency = self._eval(args[2], env) if len(args) > 2 else 0.05 h, w = frame.shape[:2] - - cx, cy = w / 2, h / 2 + # Parse keyword args + amp = 10.0 + freq = 5.0 + cx = w / 2 + cy = h / 2 + decay = 0.0 + phase = 0.0 + i = 1 + while i < len(args): + if isinstance(args[i], Keyword): + kw = args[i].name + val = self._eval(args[i + 1], env) if i + 1 < len(args) else None + if kw == 'amplitude': + amp = val + elif kw == 'frequency': + freq = val + elif kw == 'center_x': + cx = val * w if val <= 1 else val # normalized or absolute + elif kw == 'center_y': + cy = val * h if val <= 1 else val + elif kw == 'decay': + decay = val + elif kw == 'speed': + # speed affects phase via time + t = env.get('t', 0) + phase = t * val * 2 * jnp.pi + elif kw == 'phase': + phase = val + i += 2 + else: + i += 1 y_coords = jnp.repeat(jnp.arange(h), w).reshape(h, w).astype(jnp.float32) x_coords = jnp.tile(jnp.arange(w), h).reshape(h, w).astype(jnp.float32) @@ -2026,16 +2931,67 @@ class JaxCompiler: dy = y_coords - cy dist = jnp.sqrt(dx*dx + dy*dy) - displacement = amplitude * jnp.sin(dist * frequency) - angle = jnp.arctan2(dy, dx) + # Match Python formula: sin(2*pi*freq*dist/max(w,h) + phase) * amp + max_dim = jnp.maximum(w, h) + ripple = jnp.sin(2 * jnp.pi * freq * dist / max_dim + phase) * amp - src_x = x_coords + displacement * jnp.cos(angle) - src_y = y_coords + displacement * jnp.sin(angle) + # Apply decay (when decay=0, exp(0)=1 so no effect) + decay_factor = jnp.exp(-decay * dist / max_dim) + ripple = ripple * decay_factor + + # Radial displacement - use ADDITION to match Python prim_ripple_displace + # Python (primitives.py line 2890-2891): + # map_x = x_coords + ripple * norm_dx + # map_y = y_coords + ripple * norm_dy + # where norm_dx = dx/dist = cos(angle), norm_dy = dy/dist = sin(angle) + angle = jnp.arctan2(dy, dx) + src_x = x_coords + ripple * jnp.cos(angle) + src_y = y_coords + ripple * 
jnp.sin(angle) if frame is None: return {'x': src_x, 'y': src_y} + # Sample using bilinear interpolation (jax_sample clamps coords, + # matching OpenCV's default remap behavior) r_out, g_out, b_out = jax_sample(frame, src_x.flatten(), src_y.flatten()) + + return jnp.stack([ + jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), + jnp.clip(b_out, 0, 255).reshape(h, w).astype(jnp.uint8) + ], axis=2) + + if op == 'geometry:coords-x' or op == 'coords-x': + # Extract x coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('x', coords.get('map_x')) + return coords[0] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:coords-y' or op == 'coords-y': + # Extract y coordinates from coord dict + coords = self._eval(args[0], env) + if isinstance(coords, dict): + return coords.get('y', coords.get('map_y')) + return coords[1] if isinstance(coords, (list, tuple)) else coords + + if op == 'geometry:remap' or op == 'remap': + # Remap image using coordinate maps: (frame map_x map_y) + # OpenCV cv2.remap with INTER_LINEAR clamps out-of-bounds coords + frame = self._eval(args[0], env) + map_x = self._eval(args[1], env) + map_y = self._eval(args[2], env) + + h, w = frame.shape[:2] + + # Flatten coordinate maps + src_x = map_x.flatten() + src_y = map_y.flatten() + + # Sample using bilinear interpolation (jax_sample clamps coords internally, + # matching OpenCV's default behavior) + r_out, g_out, b_out = jax_sample(frame, src_x, src_y) + return jnp.stack([ jnp.clip(r_out, 0, 255).reshape(h, w).astype(jnp.uint8), jnp.clip(g_out, 0, 255).reshape(h, w).astype(jnp.uint8), @@ -2435,6 +3391,40 @@ class JaxCompiler: acc = fn_eval(acc, item) return acc + if op == 'fold-indexed': + # (fold-indexed seq init fn) - fold with index + # fn takes (acc item index) or (acc item index cursor) for typography + seq = self._eval(args[0], env) + acc = 
class DeferredEffectChain:
    """
    A lazily evaluated pipeline of JAX effects.

    Instead of executing each effect immediately, effect applications are
    recorded here (name + params) so the whole pipeline can be fused into
    a single JIT-compiled function when the result is actually needed.
    """
    __slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')

    def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
        # Effect names in execution order (innermost first), with one
        # params dict per effect; base_frame is the concrete input array.
        self.effects = effects
        self.params_list = params_list
        self.base_frame = base_frame
        self.t = t
        self.frame_num = frame_num

    def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
        """Return a new chain with *effect_name* appended as the outermost step."""
        return DeferredEffectChain(
            [*self.effects, effect_name],
            [*self.params_list, params],
            self.base_frame,
            self.t,
            self.frame_num,
        )

    @property
    def shape(self):
        """Shape of the underlying frame (or None) without forcing execution."""
        return getattr(self.base_frame, 'shape', None)
Searched paths: {lib_paths}") - spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + # Use cached module if already imported to preserve state (e.g., RNG) + # This is critical for deterministic random number sequences + # Check multiple possible module keys (standard import paths and our cache) + possible_keys = [ + f"sexp_effects.primitive_libs.{actual_lib_name}", + f"sexp_primitives.{actual_lib_name}", + ] + + module = None + for key in possible_keys: + if key in sys.modules: + module = sys.modules[key] + break + + if module is None: + spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Cache for future use under our key + sys.modules[f"sexp_primitives.{actual_lib_name}"] = module # Check if this is a GPU-accelerated module is_gpu = actual_lib_name.endswith('_gpu') @@ -452,30 +507,353 @@ class StreamInterpreter: try: jax_fn = self.jax_effects[name] - # Ensure frame is numpy array + # Handle GPU frames (CuPy) - need to move to CPU for CPU JAX + # JAX handles numpy and JAX arrays natively, no conversion needed if hasattr(frame, 'cpu'): frame = frame.cpu - elif hasattr(frame, 'get'): - frame = frame.get() + elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'): + frame = frame.get() # CuPy array -> numpy # Get seed from config for deterministic random seed = self.config.get('seed', 42) # Call JAX function with parameters - result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params) - - # Convert result back to numpy if needed - if hasattr(result, 'block_until_ready'): - result.block_until_ready() # Ensure computation is complete - if hasattr(result, '__array__'): - result = np.asarray(result) - - return result + # Return JAX array directly - don't block or convert per-effect + # Conversion to numpy happens once at frame write 
    def _is_jax_effect_expr(self, expr) -> bool:
        """Check if an expression is a JAX-compiled effect call.

        An effect call is a non-empty list whose head is a Symbol naming
        an entry in self.jax_effects.
        """
        if not isinstance(expr, list) or not expr:
            return False
        head = expr[0]
        if not isinstance(head, Symbol):
            return False
        return head.name in self.jax_effects

    def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
        """
        Extract a chain of JAX effects from an expression.

        Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
        effect_names and params_list are in execution order (innermost first).

        For (effect1 (effect2 frame :p1 v1) :p2 v2):
        Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
        """
        if not self._is_jax_effect_expr(expr):
            return None

        # Collected outermost-first while walking inward; reversed at the end.
        chain = []
        params_list = []
        current = expr

        while self._is_jax_effect_expr(current):
            head = current[0]
            effect_name = head.name
            args = current[1:]

            # Start from the effect's declared defaults, then overlay any
            # keyword args supplied at the call site.
            effect = self.effects[effect_name]
            effect_params = {}
            for pname, pdef in effect['params'].items():
                effect_params[pname] = pdef.get('default', 0)

            # Find the frame argument (first positional) and other params.
            # Keyword args are evaluated eagerly here; only recognized
            # param names (and their values) are consumed as pairs.
            frame_arg = None
            i = 0
            while i < len(args):
                if isinstance(args[i], Keyword):
                    pname = args[i].name
                    if pname in effect['params'] and i + 1 < len(args):
                        effect_params[pname] = self._eval(args[i + 1], env)
                    i += 2
                else:
                    if frame_arg is None:
                        frame_arg = args[i]  # First positional is frame
                    i += 1

            chain.append(effect_name)
            params_list.append(effect_params)

            if frame_arg is None:
                return None  # No frame argument found

            # Check if frame_arg is another effect call
            if self._is_jax_effect_expr(frame_arg):
                current = frame_arg
            else:
                # End of chain - frame_arg is the base frame.
                # Reverse to get innermost-first execution order.
                chain.reverse()
                params_list.reverse()
                return (chain, params_list, frame_arg)

        return None

    def _get_chain_key(self, effect_names: list, params_list: list) -> str:
        """Generate a cache key for an effect chain.

        Includes static param values in the key since they affect compilation:
        string/bool params become static_argnames in the JIT'd function, so
        different values require different compiled artifacts. Numeric params
        contribute only their (sorted) names.
        """
        parts = []
        for name, params in zip(effect_names, params_list):
            param_parts = []
            for pname in sorted(params.keys()):
                pval = params[pname]
                # Include static values in key (strings, bools)
                if isinstance(pval, (str, bool)):
                    param_parts.append(f"{pname}={pval}")
                else:
                    param_parts.append(pname)
            parts.append(f"{name}:{','.join(param_parts)}")
        return '|'.join(parts)
    def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
        """Apply a chain of effects, using fused compilation if available.

        Compiles (and caches) a single fused JIT function per unique chain
        key; on any fused-execution error, falls back to applying the
        effects one at a time. Returns the resulting frame, or None if the
        sequential fallback also fails.
        """
        chain_key = self._get_chain_key(effect_names, params_list)

        # Try to get or compile fused chain. Note a failed compile caches
        # None, so it is not retried on every frame.
        if chain_key not in self.jax_fused_chains:
            fused_fn = self._compile_effect_chain(effect_names, params_list)
            self.jax_fused_chains[chain_key] = fused_fn
            if fused_fn:
                print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)

        fused_fn = self.jax_fused_chains.get(chain_key)

        if fused_fn is not None:
            # Build kwargs with prefixed param names (_p{i}_{name}) so the
            # fused function can route each value to the right effect.
            kwargs = {}
            for i, params in enumerate(params_list):
                for pname, pval in params.items():
                    kwargs[f"_p{i}_{pname}"] = pval

            seed = self.config.get('seed', 42)

            # Handle GPU frames: move to host memory before calling JAX.
            if hasattr(frame, 'cpu'):
                frame = frame.cpu
            elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
                frame = frame.get()  # CuPy array -> numpy

            try:
                return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
            except Exception as e:
                print(f"Fused chain error: {e}", file=sys.stderr)

        # Fall back to sequential application
        result = frame
        for name, params in zip(effect_names, params_list):
            result = self._apply_jax_effect(name, result, params, t, frame_num)
            if result is None:
                return None
        return result

    def _force_deferred(self, deferred: DeferredEffectChain):
        """Execute a deferred effect chain and return the actual array.

        An empty chain simply yields the base frame unchanged.
        """
        if len(deferred.effects) == 0:
            return deferred.base_frame

        return self._apply_effect_chain(
            deferred.effects,
            deferred.params_list,
            deferred.base_frame,
            deferred.t,
            deferred.frame_num
        )

    def _maybe_force(self, value):
        """Force a deferred chain if needed, otherwise return as-is.

        Used wherever a concrete array is required (primitive calls,
        keyword-arg evaluation) but a DeferredEffectChain may have leaked
        through lazy evaluation.
        """
        if isinstance(value, DeferredEffectChain):
            return self._force_deferred(value)
        return value
list(static_params)) + except Exception as e: + print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr) + return None + + def _apply_batched_chain(self, effect_names: list, params_list_batch: list, + frames: list, ts: list, frame_nums: list) -> Optional[list]: + """ + Apply an effect chain to a batch of frames using vmap. + + Args: + effect_names: List of effect names + params_list_batch: List of params_list for each frame in batch + frames: List of input frames + ts: List of time values + frame_nums: List of frame numbers + + Returns: + List of output frames, or None on failure + """ + if not frames: + return [] + + # Use first frame's params for chain key (assume same structure) + chain_key = self._get_chain_key(effect_names, params_list_batch[0]) + batch_key = f"batch:{chain_key}" + + # Compile batched version if needed + if batch_key not in self.jax_batched_chains: + batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0]) + self.jax_batched_chains[batch_key] = batched_fn + if batched_fn: + print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr) + + batched_fn = self.jax_batched_chains.get(batch_key) + + if batched_fn is not None: + try: + import jax + import jax.numpy as jnp + + # Stack frames into batch array + frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames]) + ts_array = jnp.array(ts) + frame_nums_array = jnp.array(frame_nums) + + # Build kwargs - all numeric params as arrays for vmap + kwargs = {} + static_kwargs = {} # Non-vmapped (strings, bools) + + for i, plist in enumerate(zip(*[p for p in params_list_batch])): + for j, pname in enumerate(params_list_batch[0][i].keys()): + key = f"_p{i}_{pname}" + values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]] + + first = values[0] + if isinstance(first, (str, bool)): + # Static params - not vmapped + static_kwargs[key] = first + elif isinstance(first, (int, float)): + # 
Always batch numeric params for simplicity + kwargs[key] = jnp.array(values) + elif hasattr(first, 'shape'): + kwargs[key] = jnp.stack(values) + else: + kwargs[key] = jnp.array(values) + + seed = self.config.get('seed', 42) + + # Create wrapper that unpacks the params dict + def single_call(frame, t, frame_num, params_dict): + return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs) + + # vmap over frame, t, frame_num, and the params dict (as pytree) + vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0)) + + # Stack kwargs into a dict of arrays (pytree with matching structure) + results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs) + + # Unstack results + return [results[i] for i in range(len(frames))] + except Exception as e: + print(f"Batched chain error: {e}", file=sys.stderr) + + # Fall back to sequential + return None + def _init(self): """Initialize from sexp - load primitives, effects, defs, scans.""" # Set random seed for deterministic output @@ -869,6 +1247,22 @@ class StreamInterpreter: # === Effects === if op in self.effects: + # Try to detect and fuse effect chains for JAX acceleration + if self.use_jax and op in self.jax_effects: + chain_info = self._extract_effect_chain(expr, env) + if chain_info is not None: + effect_names, params_list, base_frame_expr = chain_info + # Only use chain if we have 2+ effects (worth fusing) + if len(effect_names) >= 2: + base_frame = self._eval(base_frame_expr, env) + if base_frame is not None and hasattr(base_frame, 'shape'): + t = env.get('t', 0.0) + frame_num = env.get('frame-num', 0) + result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num) + if result is not None: + return result + # Fall through if chain application fails + effect = self.effects[op] effect_env = dict(env) @@ -895,17 +1289,28 @@ class StreamInterpreter: positional_idx += 1 i += 1 - # Try JAX-accelerated execution first + # Try JAX-accelerated execution with deferred chaining if 
self.use_jax and op in self.jax_effects and frame_val is not None: # Build params dict for JAX (exclude 'frame') - jax_params = {k: v for k, v in effect_env.items() + jax_params = {k: self._maybe_force(v) for k, v in effect_env.items() if k != 'frame' and k in effect['params']} t = env.get('t', 0.0) frame_num = env.get('frame-num', 0) - result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num) - if result is not None: - return result - # Fall through to interpreter if JAX fails + + # Check if input is a deferred chain - if so, extend it + if isinstance(frame_val, DeferredEffectChain): + return frame_val.extend(op, jax_params) + + # Check if input is a valid frame - create new deferred chain + if hasattr(frame_val, 'shape'): + return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num) + + # Fall through to interpreter if not a valid frame + + # Force any deferred frame before interpreter evaluation + if isinstance(frame_val, DeferredEffectChain): + frame_val = self._force_deferred(frame_val) + effect_env['frame'] = frame_val return self._eval(effect['body'], effect_env) @@ -922,10 +1327,15 @@ class StreamInterpreter: if isinstance(args[i], Keyword): k = args[i].name v = self._eval(args[i + 1], env) if i + 1 < len(args) else None + # Force deferred chains before passing to primitives + v = self._maybe_force(v) kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) i += 2 else: - evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim)) + val = self._eval(args[i], env) + # Force deferred chains before passing to primitives + val = self._maybe_force(val) + evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim)) i += 1 try: if kwargs: @@ -1152,6 +1562,61 @@ class StreamInterpreter: eval_times = [] write_times = [] + # Batch accumulation for JAX + batch_deferred = [] # Accumulated DeferredEffectChains + batch_times = [] # Corresponding time values + 
batch_start_frame = 0 + + def flush_batch(): + """Execute accumulated batch and write results.""" + nonlocal batch_deferred, batch_times + if not batch_deferred: + return + + t_flush = time.time() + + # Check if all chains have same structure (can batch) + first = batch_deferred[0] + can_batch = ( + self.use_jax and + len(batch_deferred) >= 2 and + all(d.effects == first.effects for d in batch_deferred) + ) + + if can_batch: + # Try batched execution + frames = [d.base_frame for d in batch_deferred] + ts = [d.t for d in batch_deferred] + frame_nums = [d.frame_num for d in batch_deferred] + params_batch = [d.params_list for d in batch_deferred] + + results = self._apply_batched_chain( + first.effects, params_batch, frames, ts, frame_nums + ) + + if results is not None: + # Write batched results + for result, t in zip(results, batch_times): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + batch_deferred = [] + batch_times = [] + return + + # Fall back to sequential execution + for deferred, t in zip(batch_deferred, batch_times): + result = self._force_deferred(deferred) + if result is not None and hasattr(result, 'shape'): + if hasattr(result, 'block_until_ready'): + result.block_until_ready() + result = np.asarray(result) + out.write(result, t) + + batch_deferred = [] + batch_times = [] + for frame_num in range(start_frame, n_frames): if not out.is_open: break @@ -1182,8 +1647,23 @@ class StreamInterpreter: eval_times.append(time.time() - t1) t2 = time.time() - if result is not None and hasattr(result, 'shape'): - out.write(result, ctx.t) + if result is not None: + if isinstance(result, DeferredEffectChain): + # Accumulate for batching + batch_deferred.append(result) + batch_times.append(ctx.t) + + # Flush when batch is full + if len(batch_deferred) >= self.jax_batch_size: + flush_batch() + else: + # Not deferred - flush any pending batch first, then write + flush_batch() + if 
def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
               fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
               stroke_width=0, stroke_fill=None, anchor="la",
               multiline=False, line_spacing=4, align="left"):
    """Reference rendering via PIL, honouring fill colour and opacity."""
    background = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)

    # Text goes onto a transparent RGBA layer so opacity composites correctly.
    overlay = Image.new('RGBA', frame_size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(overlay)
    font = _load_font(font_path, font_size)

    if stroke_fill is None:
        stroke_fill = (0, 0, 0)

    # Encode opacity in the alpha channel of the fill/stroke colours.
    alpha_int = int(round(opacity * 255))
    fill_rgba = (*fill, alpha_int)
    stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None

    if multiline:
        pen.multiline_text((x, y), text, fill=fill_rgba, font=font,
                           stroke_width=stroke_width, stroke_fill=stroke_rgba,
                           spacing=line_spacing, align=align, anchor=anchor)
    else:
        pen.text((x, y), text, fill=fill_rgba, font=font,
                 stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor)

    # Composite the text layer over the opaque background.
    composed = Image.alpha_composite(Image.fromarray(background).convert('RGBA'),
                                     overlay)
    return np.array(composed.convert('RGB'))


def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
                 fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
                 stroke_width=0, stroke_fill=None, anchor="la",
                 multiline=False, line_spacing=4, align="left"):
    """Render the same text through the TextStrip/JAX placement path."""
    canvas = jnp.full((frame_size[1], frame_size[0], 3),
                      jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)

    strip = render_text_strip(
        text, font_path, font_size,
        stroke_width=stroke_width, stroke_fill=stroke_fill,
        anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
    )
    placed = place_text_strip_jax(
        canvas, jnp.asarray(strip.image), x, y,
        strip.baseline_y, strip.bearing_x,
        jnp.array(fill, dtype=jnp.float32), opacity,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        stroke_width=strip.stroke_width
    )
    return np.array(placed)


def compare(name, tolerance=0, **kwargs):
    """Render via both back ends; report PASS/FAIL and dump debug PNGs on FAIL."""
    pil = render_pil(**kwargs)
    strip = render_strip(**kwargs)

    diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
    max_diff = diff.max()
    pixels_diff = (diff > 0).any(axis=2).sum()

    if max_diff == 0:
        print(f"  PASS: {name}")
        return True

    if tolerance > 0:
        # Allow a small integer position shift: take the per-pixel minimum
        # diff over all offsets within the tolerance window.
        best_diff = diff.copy()
        for dy in range(-tolerance, tolerance + 1):
            for dx in range(-tolerance, tolerance + 1):
                if dy == 0 and dx == 0:
                    continue
                shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
                sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
                best_diff = np.minimum(best_diff, sdiff)
        if best_diff.max() == 0:
            print(f"  PASS: {name} (within {tolerance}px tolerance)")
            return True

    print(f"  FAIL: {name}")
    print(f"    Max diff: {max_diff}, Pixels different: {pixels_diff}")
    Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
    Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
    diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
    Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png")
    print(f"    Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png")
    return False
def test_opacity():
    """Test different opacity levels."""
    print("\n--- Opacity ---")
    cases = [
        ("opacity_100", dict(text="Full", x=20, y=30, opacity=1.0)),
        ("opacity_75", dict(text="75%", x=20, y=30, opacity=0.75)),
        ("opacity_50", dict(text="Half", x=20, y=30, opacity=0.5)),
        ("opacity_25", dict(text="Quarter", x=20, y=30, opacity=0.25)),
        ("opacity_10", dict(text="Ghost", x=20, y=30, opacity=0.1)),
        # Opacity on colored background
        ("opacity_on_colored_bg", dict(text="Overlay", x=20, y=30,
                                       fill=(255, 255, 255), bg=(100, 0, 0),
                                       opacity=0.5)),
        # Color + opacity combo
        ("opacity_red_on_green", dict(text="Blend", x=20, y=30,
                                      fill=(255, 0, 0), bg=(0, 100, 0),
                                      opacity=0.6)),
    ]
    return [compare(name, **kw) for name, kw in cases]


def test_fonts():
    """Test different fonts and sizes."""
    print("\n--- Fonts & Sizes ---")
    results = [compare(f"font_{label}",
                       text="Quick Fox", x=20, y=30, font_path=path, font_size=28,
                       frame_size=(300, 80))
               for label, path in FONTS.items()]

    # Tiny / big / huge sizes
    results.append(compare("size_tiny",
                           text="Tiny text at 12px", x=10, y=15, font_size=12,
                           frame_size=(200, 40)))
    results.append(compare("size_big",
                           text="BIG", x=20, y=10, font_size=72,
                           frame_size=(300, 100)))
    results.append(compare("size_huge",
                           text="XL", x=10, y=10, font_size=120,
                           frame_size=(300, 160)))
    return results


def test_anchors():
    """Test all anchor combinations."""
    print("\n--- Anchors ---")
    # All horizontal x vertical combos
    return [compare(f"anchor_{h}{v}",
                    text="Anchor", x=200, y=50, anchor=h + v,
                    frame_size=(400, 100))
            for h in ['l', 'm', 'r'] for v in ['a', 'm', 's', 'd']]


def test_strokes():
    """Test various stroke widths and colors."""
    print("\n--- Strokes ---")
    results = [compare(f"stroke_w{sw}",
                       text="Stroke", x=30, y=20, font_size=40,
                       stroke_width=sw, stroke_fill=(0, 0, 0),
                       frame_size=(300, 80))
               for sw in [1, 2, 3, 4, 6, 8]]

    extra = [
        # Colored strokes
        ("stroke_red", dict(text="Red outline", x=20, y=20, font_size=36,
                            stroke_width=3, stroke_fill=(255, 0, 0),
                            frame_size=(350, 80))),
        ("stroke_white_on_black", dict(text="Glow", x=20, y=20, font_size=40,
                                       fill=(255, 255, 255), stroke_width=4,
                                       stroke_fill=(0, 0, 255),
                                       frame_size=(250, 80))),
        # Stroke with bold font
        ("stroke_bold", dict(text="Bold+Stroke", x=20, y=20,
                             font_path=FONTS['bold'], font_size=36,
                             stroke_width=3, stroke_fill=(0, 0, 0),
                             frame_size=(400, 80))),
        # Stroke + colored text on colored bg
        ("stroke_colored_on_bg", dict(text="Party", x=20, y=20, font_size=48,
                                      fill=(255, 255, 0), bg=(50, 0, 80),
                                      stroke_width=3, stroke_fill=(255, 0, 0),
                                      frame_size=(300, 80))),
    ]
    results.extend(compare(name, **kw) for name, kw in extra)
    return results
def test_multiline_fancy():
    """Test multiline with various styles."""
    print("\n--- Multiline Fancy ---")
    cases = [
        # Right-aligned (1px tolerance: same sub-pixel issue as center alignment)
        ("multi_right", dict(text="Right\nAligned\nText", x=380, y=20,
                             frame_size=(400, 150),
                             multiline=True, anchor="ra", align="right",
                             tolerance=1)),
        # Center + stroke
        ("multi_center_stroke", dict(text="Title\nSubtitle", x=200, y=20,
                                     font_size=32, frame_size=(400, 120),
                                     multiline=True, anchor="ma", align="center",
                                     stroke_width=2, stroke_fill=(0, 0, 0),
                                     tolerance=1)),
        # Wide line spacing
        ("multi_wide_spacing", dict(text="Line A\nLine B\nLine C", x=20, y=10,
                                    frame_size=(300, 200),
                                    multiline=True, line_spacing=20)),
        # Zero extra spacing
        ("multi_tight_spacing", dict(text="Tight\nPacked\nLines", x=20, y=10,
                                     frame_size=(300, 150),
                                     multiline=True, line_spacing=0)),
        # Many lines
        ("multi_many_lines", dict(text="One\nTwo\nThree\nFour\nFive\nSix",
                                  x=20, y=5, font_size=20, frame_size=(200, 200),
                                  multiline=True, line_spacing=4)),
        # Bold multiline with stroke
        ("multi_bold_stroke", dict(text="BOLD\nSTROKE", x=20, y=10,
                                   font_path=FONTS['bold'], font_size=48,
                                   stroke_width=3, stroke_fill=(200, 0, 0),
                                   frame_size=(350, 150), multiline=True)),
        # Multiline on colored bg with opacity
        ("multi_opacity_on_bg", dict(text="Semi\nTransparent", x=20, y=10,
                                     fill=(255, 255, 0), bg=(0, 50, 100),
                                     opacity=0.7,
                                     frame_size=(300, 120), multiline=True)),
    ]
    return [compare(name, **kw) for name, kw in cases]


def test_special_chars():
    """Test special characters and edge cases."""
    print("\n--- Special Characters ---")
    cases = [
        # Numbers and symbols
        ("chars_numbers", dict(text="0123456789", x=20, y=30,
                               frame_size=(300, 80))),
        ("chars_punctuation", dict(text="Hello, World! (v2.0)", x=10, y=30,
                                   frame_size=(350, 80))),
        ("chars_symbols", dict(text="@#$%^&*+=", x=20, y=30,
                               frame_size=(300, 80))),
        # Single character
        ("chars_single", dict(text="X", x=50, y=30, font_size=48,
                              frame_size=(100, 80))),
        # Very long text (clipped)
        ("chars_long", dict(text="The quick brown fox jumps over the lazy dog",
                            x=10, y=30, font_size=24, frame_size=(400, 80))),
        # Mixed case
        ("chars_mixed_case", dict(text="AaBbCcDdEeFf", x=10, y=30,
                                  frame_size=(350, 80))),
    ]
    return [compare(name, **kw) for name, kw in cases]
def test_multi_strip_overlay():
    """Test placing multiple strips on the same frame."""
    print("\n--- Multi-Strip Overlay ---")
    results = []

    frame_size = (400, 150)
    bg = (20, 20, 40)

    # --- PIL reference: three draw calls onto one transparent layer ---
    font1 = _load_font(FONTS['bold'], 48)
    font2 = _load_font(None, 24)
    font3 = _load_font(FONTS['mono'], 18)

    pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
    txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(txt_layer)
    pen.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1,
             stroke_width=2, stroke_fill=(0, 0, 0, 255))
    pen.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2)
    pen.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3)
    pil_result = np.array(Image.alpha_composite(
        Image.fromarray(pil_frame).convert('RGBA'), txt_layer).convert('RGB'))

    # --- Strip version: sequential placements via a small local helper ---
    def place(canvas, s, x, y, rgb, alpha):
        return place_text_strip_jax(
            canvas, jnp.asarray(s.image), x, y,
            s.baseline_y, s.bearing_x,
            jnp.array(rgb, dtype=jnp.float32), alpha,
            anchor_x=s.anchor_x, anchor_y=s.anchor_y,
            stroke_width=s.stroke_width)

    frame = jnp.full((frame_size[1], frame_size[0], 3),
                     jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
    s1 = render_text_strip("TITLE", FONTS['bold'], 48,
                           stroke_width=2, stroke_fill=(0, 0, 0))
    s2 = render_text_strip("Subtitle here", None, 24)
    s3 = render_text_strip("code_snippet()", FONTS['mono'], 18)

    frame = place(frame, s1, 20, 10, [255, 255, 0], 1.0)
    frame = place(frame, s2, 20, 70, [200, 200, 200], 1.0)
    frame = place(frame, s3, 20, 110, [0, 255, 128], 200 / 255)
    strip_result = np.array(frame)

    diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16))
    max_diff = diff.max()
    pixels_diff = (diff > 0).any(axis=2).sum()

    if max_diff <= 1:
        print(f"  PASS: multi_overlay (max_diff={max_diff})")
        results.append(True)
    else:
        print(f"  FAIL: multi_overlay")
        print(f"    Max diff: {max_diff}, Pixels different: {pixels_diff}")
        Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png")
        Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png")
        diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
        Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png")
        print(f"    Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png")
        results.append(False)

    return results


def main():
    print("=" * 60)
    print("Funky TextStrip vs PIL Comparison")
    print("=" * 60)

    # Suites run in a fixed order; each returns a list of booleans.
    suites = (test_colors, test_opacity, test_fonts, test_anchors,
              test_strokes, test_edge_clipping, test_multiline_fancy,
              test_special_chars, test_combos, test_multi_strip_overlay)
    all_results = []
    for suite in suites:
        all_results.extend(suite())

    print("\n" + "=" * 60)
    passed = sum(all_results)
    total = len(all_results)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        failed = [i for i, r in enumerate(all_results) if not r]
        print(f"FAILED: {total - passed} tests (indices: {failed})")
    print("=" * 60)


if __name__ == "__main__":
    main()
def load_font(font_name=None, font_size=32):
    """Load a TrueType font at *font_size*, trying *font_name* first and then
    a list of DejaVu fallbacks; returns PIL's built-in bitmap font if none load.
    """
    candidates = [
        font_name,
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
        '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
        '/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf',
    ]
    for path in candidates:
        if path is None:
            continue
        try:
            return ImageFont.truetype(path, font_size)
        except (IOError, OSError):
            continue
    return ImageFont.load_default()


def test_pil_options():
    """Render a gallery of PIL text-option demos to /tmp/pil_test_*.png and
    print a summary of draw.text()'s supported parameters.

    Each inner function produces one RGBA demo image; they are collected in
    `tests` and executed at the end.
    """
    # Create a test frame
    frame_size = (600, 400)

    font = load_font(None, 36)
    font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36)
    font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36)

    tests = []

    # Test 1: Basic text
    def basic_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font)
        return img, "basic"
    tests.append(basic_text)

    # Test 2: Stroke/outline
    def stroke_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font,
                  stroke_width=2, stroke_fill=(255, 0, 0, 255))
        return img, "stroke"
    tests.append(stroke_text)

    # Test 3: Bold font
    def bold_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold)
        return img, "bold"
    tests.append(bold_text)

    # Test 4: Italic font
    def italic_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic)
        return img, "italic"
    tests.append(italic_text)

    # Test 5: Different anchors
    def anchor_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        # Draw crosshairs at anchor points
        for x in [100, 300, 500]:
            draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255))
            draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255))

        draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm")
        draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm")
        draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm")
        return img, "anchor"
    tests.append(anchor_text)

    # Test 6: Multiline text
    def multiline_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three",
                            fill=(255, 255, 255, 255), font=font, spacing=10)
        return img, "multiline"
    tests.append(multiline_text)

    # Test 7: Semi-transparent text
    def alpha_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font)
        draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font)
        draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font)
        return img, "alpha"
    tests.append(alpha_text)

    # Test 8: Colored text
    def colored_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font)
        draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font)
        draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font)
        draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font)
        return img, "colored"
    tests.append(colored_text)

    # Test 9: Large stroke
    def large_stroke():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font,
                  stroke_width=5, stroke_fill=(0, 0, 0, 255))
        return img, "large_stroke"
    tests.append(large_stroke)

    # Test 10: Emoji (if supported)
    def emoji_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        try:
            # Try to find an emoji font
            emoji_font = None
            emoji_paths = [
                '/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf',
                '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf',
            ]
            for p in emoji_paths:
                try:
                    emoji_font = ImageFont.truetype(p, 36)
                    break
                # Only swallow font-loading failures (was a bare `except`,
                # which also hid KeyboardInterrupt and real bugs).
                except (IOError, OSError):
                    pass
            if emoji_font:
                draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font)
            else:
                draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font)
        except Exception as e:
            # Best-effort: report the failure inside the demo image itself.
            draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font)
        return img, "emoji"
    tests.append(emoji_text)

    # Run all tests
    print("PIL Text Options Test")
    print("=" * 60)

    for test_fn in tests:
        img, name = test_fn()
        fname = f"/tmp/pil_test_{name}.png"
        img.save(fname)
        print(f"Saved: {fname}")

    print("\nCheck /tmp/pil_test_*.png for results")

    # Print available parameters
    print("\n" + "=" * 60)
    print("PIL draw.text() parameters:")
    print("  - xy: position tuple")
    print("  - text: string to draw")
    print("  - fill: color (R,G,B) or (R,G,B,A)")
    print("  - font: ImageFont object")
    print("  - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)")
    print("  - spacing: line spacing for multiline")
    print("  - align: 'left', 'center', 'right' for multiline")
    print("  - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)")
    print("  - features: OpenType features list")
    print("  - language: language code for shaping")
    print("  - stroke_width: outline width in pixels")
    print("  - stroke_fill: outline color")
    print("  - embedded_color: use embedded color glyphs (emoji)")
def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
               stroke_width=0, stroke_fill=None, anchor="la",
               multiline=False, line_spacing=4, align="left"):
    """Reference rendering: white text drawn straight onto a black frame with PIL."""
    canvas = Image.fromarray(
        np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8))
    pen = ImageDraw.Draw(canvas)
    font = _load_font(None, font_size)

    # Default stroke fill
    if stroke_fill is None:
        stroke_fill = (0, 0, 0)

    if multiline:
        pen.multiline_text((x, y), text, fill=(255, 255, 255), font=font,
                           stroke_width=stroke_width, stroke_fill=stroke_fill,
                           spacing=line_spacing, align=align, anchor=anchor)
    else:
        pen.text((x, y), text, fill=(255, 255, 255), font=font,
                 stroke_width=stroke_width, stroke_fill=stroke_fill, anchor=anchor)

    return np.array(canvas)


def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
                 stroke_width=0, stroke_fill=None, anchor="la",
                 multiline=False, line_spacing=4, align="left"):
    """Render the same white-on-black text via the TextStrip/JAX path."""
    canvas = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)

    strip = render_text_strip(
        text, None, font_size,
        stroke_width=stroke_width, stroke_fill=stroke_fill,
        anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
    )
    placed = place_text_strip_jax(
        canvas, jnp.asarray(strip.image), x, y,
        strip.baseline_y, strip.bearing_x,
        jnp.array([255, 255, 255], dtype=jnp.float32), 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        stroke_width=strip.stroke_width
    )
    return np.array(placed)


def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
            tolerance=0, **kwargs):
    """Compare PIL and TextStrip rendering.

    tolerance=0: exact pixel match required
    tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences
                 in center-aligned multiline text where the strip is pre-rendered
                 at a different base position than the final placement)
    """
    pil = render_pil(text, x, y, font_size, frame_size, **kwargs)
    strip = render_strip(text, x, y, font_size, frame_size, **kwargs)

    diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
    max_diff = diff.max()
    pixels_diff = (diff > 0).any(axis=2).sum()

    if max_diff == 0:
        print(f"PASS: {name}")
        print(f"  Max diff: 0, Pixels different: 0")
        return True

    if tolerance > 0:
        # Check if the difference is just a sub-pixel position shift:
        # take the per-pixel minimum diff over every offset in the window.
        best_diff = diff.copy()
        for dy in range(-tolerance, tolerance + 1):
            for dx in range(-tolerance, tolerance + 1):
                if dy == 0 and dx == 0:
                    continue
                shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
                sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
                best_diff = np.minimum(best_diff, sdiff)
        if best_diff.max() == 0:
            print(f"PASS: {name} (within {tolerance}px position tolerance)")
            print(f"  Raw diff: {max_diff}, After shift tolerance: 0")
            return True

    print(f"FAIL: {name}")
    print(f"  Max diff: {max_diff}, Pixels different: {pixels_diff}")

    # Save debug images
    Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
    Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
    diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8)
    Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png")
    print(f"  Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png")

    return False


def main():
    print("=" * 60)
    print("Styled TextStrip vs PIL Comparison")
    print("=" * 60)

    results = [
        # Basic text
        compare("basic", "Hello World", 20, 50),
        # Stroke/outline
        compare("stroke_2", "Outlined", 20, 50,
                stroke_width=2, stroke_fill=(255, 0, 0)),
        compare("stroke_5", "Big Outline", 30, 60, font_size=48,
                frame_size=(500, 120), stroke_width=5, stroke_fill=(0, 0, 0)),
        # Anchors - center / right
        compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100),
                anchor="mm"),
        compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100),
                anchor="rm"),
        # Multiline
        compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20,
                frame_size=(400, 150), multiline=True, line_spacing=8),
        # Multiline centered (1px tolerance: sub-pixel rendering differs because
        # the strip is pre-rendered at an integer position while PIL's center
        # alignment uses fractional getlength values for the 'm' anchor shift)
        compare("multiline_center", "Short\nMedium Length\nX", 200, 20,
                frame_size=(400, 150), multiline=True, anchor="ma",
                align="center", tolerance=1),
        # Stroke + multiline
        compare("stroke_multiline", "Line A\nLine B", 20, 20,
                frame_size=(400, 120), stroke_width=2, stroke_fill=(0, 0, 255),
                multiline=True),
    ]

    print("=" * 60)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")

    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print("=" * 60)


if __name__ == "__main__":
    main()
import os
import sys
import pytest
import numpy as np
import shutil
from pathlib import Path

# Ensure the art-celery module is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sexp_effects.primitive_libs import core as core_mod


# Path to test resources
TEST_DIR = Path('/home/giles/art/test')
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
TEMPLATES_DIR = TEST_DIR / 'templates'


def create_test_image(h=96, w=128):
    """Create an (h, w, 3) uint8 test image with distinct patterns.

    Channel 0 ramps left->right, channel 1 ramps top->bottom, channel 2 is a
    constant 128; a filled circle at the centre and a filled square near the
    top-left corner add high-frequency features for effect comparisons.
    """
    import cv2

    # Vectorized gradient build. The original filled the image with two
    # nested per-pixel Python loops (interpreter-bound O(h*w)); these
    # broadcasts compute the same int(255*i/n) truncated values in NumPy.
    col_ramp = (255 * np.arange(w) / w).astype(np.uint8)   # R: horizontal gradient
    row_ramp = (255 * np.arange(h) / h).astype(np.uint8)   # G: vertical gradient
    img = np.empty((h, w, 3), dtype=np.uint8)
    img[:, :, 0] = col_ramp[np.newaxis, :]
    img[:, :, 1] = row_ramp[:, np.newaxis]
    img[:, :, 2] = 128                                     # B: constant

    # Add features
    cv2.circle(img, (w // 2, h // 2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)

    return img


@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up test environment with sexp files and test media.

    Yields a dict with the sandbox directory, the path of the generated
    test image, and the raw image array; the working directory is changed
    into the sandbox for the duration of the module and restored after.
    """
    test_dir = tmp_path_factory.mktemp('sexp_test')
    original_dir = os.getcwd()
    os.chdir(test_dir)

    # Create directories
    (test_dir / 'effects').mkdir()
    (test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)

    # Create test image
    import cv2
    test_img = create_test_image()
    cv2.imwrite(str(test_dir / 'test_image.png'), test_img)

    # Copy required effect files (missing ones are skipped silently)
    for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
        src = EFFECTS_DIR / f'{effect}.sexp'
        dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
        if src.exists():
            shutil.copy(src, dst)

    yield {
        'dir': test_dir,
        'image_path': test_dir / 'test_image.png',
        'test_img': test_img,
    }

    os.chdir(original_dir)


def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write *content* to <test_dir>/effects/<filename> and return its path."""
    path = test_dir / 'effects' / filename
    with open(path, 'w') as f:
        f.write(content)
    return str(path)


class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output."""
< 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold" + + def test_rotate_then_zoom(self, test_env): + """Test rotate followed by zoom - tests effect chain fusion.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (-> (rotate frame :angle 15 :speed 0) + (zoom :amount 0.95 :speed 0))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold" + + def test_zoom_shrink_boundary_handling(self, test_env): + """Test zoom with shrinking factor - critical for boundary handling.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame (zoom frame :amount 0.8 :speed 0)) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import 
StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + # Check corners specifically - these are most affected by boundary handling + h, w = test_img.shape[:2] + corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)] + for y, x in corners: + py_val = py_result[y, x] + jax_val = jax_result[y, x] + corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max() + assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}" + + def test_blend_two_clips(self, test_env): + """Test blending two effect chains - the core bug scenario.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from 
streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + + # Check edge region specifically + edge_diff = diff[0, :].mean() + + assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold" + assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold" + + def test_blend_with_invert(self, test_env): + """Test blending with invert - matches the problematic recipe pattern.""" + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (require-primitives "core") + (require-primitives "image") + (require-primitives "blending") + (require-primitives "color_ops") + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + (effect blend :path "../sexp_effects/effects/blend.sexp") + (effect invert :path "../sexp_effects/effects/invert.sexp") + + (frame + (let [clip_a (-> (rotate frame :angle 5 :speed 0) + (zoom :amount 1.05 :speed 0) + (invert :amount 1)) + clip_b (-> (rotate frame :angle -5 :speed 0) + (zoom :amount 0.95 :speed 0))] + (blend :base clip_a :overlay 
clip_b :opacity 0.5))) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + from streaming.stream_sexp_generic import StreamInterpreter, Context + import cv2 + + test_img = cv2.imread(str(test_env['image_path'])) + + core_mod.set_random_seed(42) + py_interp = StreamInterpreter(sexp_path, use_jax=False) + py_interp._init() + + core_mod.set_random_seed(42) + jax_interp = StreamInterpreter(sexp_path, use_jax=True) + jax_interp._init() + + ctx = Context(fps=10) + ctx.t = 0.5 + ctx.frame_num = 5 + + frame_env = { + 'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10}, + 't': ctx.t, 'frame-num': ctx.frame_num, + } + + py_interp.globals['frame'] = test_img + jax_interp.globals['frame'] = test_img + + py_result = np.asarray(py_interp._maybe_force( + py_interp._eval(py_interp.frame_pipeline, frame_env))) + jax_result = np.asarray(jax_interp._maybe_force( + jax_interp._eval(jax_interp.frame_pipeline, frame_env))) + + diff = np.abs(py_result.astype(float) - jax_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold" + + +class TestDeferredEffectChainFusion: + """Test the DeferredEffectChain fusion mechanism specifically.""" + + def test_manual_vs_fused_chain(self, test_env): + """Compare manual application vs fused DeferredEffectChain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + # Create minimal sexp to load effects + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # 
Manual step-by-step application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + rot_angle = -5.0 + zoom_amount = 0.95 + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain application + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + diff = np.abs(manual_result.astype(float) - fused_result.astype(float)) + mean_diff = np.mean(diff) + + assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}" + + # Check specific pixels + h, w = test_img.shape[:2] + for y in [0, h//2, h-1]: + for x in [0, w//2, w-1]: + manual_val = manual_result[y, x] + fused_val = fused_result[y, x] + pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max() + assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}" + + def test_chain_with_shrink_zoom_boundary(self, test_env): + """Test that shrinking zoom handles boundaries correctly in chain.""" + import jax.numpy as jnp + from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain + + sexp_content = '''(stream "test" + :width 128 + :height 96 + :seed 42 + + (effect rotate :path "../sexp_effects/effects/rotate.sexp") + (effect zoom :path "../sexp_effects/effects/zoom.sexp") + + (frame frame) +) +''' + sexp_path = create_sexp_file(test_env['dir'], sexp_content) + + core_mod.set_random_seed(42) + interp = StreamInterpreter(sexp_path, use_jax=True) + interp._init() + + test_img = test_env['test_img'] + jax_frame = jnp.array(test_img) + t = 0.5 + frame_num = 5 + + # Parameters that shrink the image (zoom < 1.0) + rot_angle = 
-4.555 + zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries + + # Manual application + rotate_fn = interp.jax_effects['rotate'] + zoom_fn = interp.jax_effects['zoom'] + + manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42, + angle=rot_angle, speed=0) + manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42, + amount=zoom_amount, speed=0) + manual_result = np.asarray(manual_result) + + # Fused chain + chain = DeferredEffectChain( + ['rotate'], + [{'angle': rot_angle, 'speed': 0}], + jax_frame, t, frame_num + ) + chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0}) + + fused_result = np.asarray(interp._force_deferred(chain)) + + # Check top edge specifically - this is where boundary issues manifest + top_edge_manual = manual_result[0, :] + top_edge_fused = fused_result[0, :] + + edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float)) + mean_edge_diff = np.mean(edge_diff) + + assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}" + + # Check for zeros at edges that shouldn't be there + manual_edge_sum = np.sum(top_edge_manual) + fused_edge_sum = np.sum(top_edge_fused) + + if manual_edge_sum > 100: # If manual has significant values + assert fused_edge_sum > manual_edge_sum * 0.5, \ + f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}" + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_jax_primitives.py b/tests/test_jax_primitives.py new file mode 100644 index 0000000..5fad678 --- /dev/null +++ b/tests/test_jax_primitives.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +""" +Test framework to verify JAX primitives match Python primitives. + +Compares output of each primitive through: +1. Python/NumPy path +2. JAX path (CPU) +3. JAX path (GPU) - if available + +Reports any mismatches with detailed diffs. 
+""" +import sys +sys.path.insert(0, '/home/giles/art/art-celery') + +import numpy as np +import json +from pathlib import Path +from typing import Dict, List, Tuple, Any, Optional +from dataclasses import dataclass, field + +# Test configuration +TEST_WIDTH = 64 +TEST_HEIGHT = 48 +TOLERANCE_MEAN = 1.0 # Max allowed mean difference +TOLERANCE_MAX = 10.0 # Max allowed single pixel difference +TOLERANCE_PCT = 0.95 # Min % of pixels within ±1 + + +@dataclass +class TestResult: + primitive: str + passed: bool + python_mean: float = 0.0 + jax_mean: float = 0.0 + diff_mean: float = 0.0 + diff_max: float = 0.0 + pct_within_1: float = 0.0 + error: str = "" + + +def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'): + """Create a test frame with known pattern.""" + if pattern == 'gradient': + # Diagonal gradient + y, x = np.mgrid[0:h, 0:w] + r = (x * 255 / w).astype(np.uint8) + g = (y * 255 / h).astype(np.uint8) + b = ((x + y) * 127 / (w + h)).astype(np.uint8) + return np.stack([r, g, b], axis=2) + elif pattern == 'checkerboard': + y, x = np.mgrid[0:h, 0:w] + check = ((x // 8) + (y // 8)) % 2 + v = (check * 255).astype(np.uint8) + return np.stack([v, v, v], axis=2) + elif pattern == 'solid': + return np.full((h, w, 3), 128, dtype=np.uint8) + else: + return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + + +def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]: + """Compare two outputs, return (mean_diff, max_diff, pct_within_1).""" + if py_out is None or jax_out is None: + return float('inf'), float('inf'), 0.0 + + if isinstance(py_out, dict) and isinstance(jax_out, dict): + # Compare coordinate maps + diffs = [] + for k in py_out: + if k in jax_out: + py_arr = np.asarray(py_out[k]) + jax_arr = np.asarray(jax_out[k]) + if py_arr.shape == jax_arr.shape: + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + diffs.append(diff) + if diffs: + all_diff = np.concatenate([d.flatten() for d in diffs]) + return float(np.mean(all_diff)), 
float(np.max(all_diff)), float(np.mean(all_diff <= 1)) + return float('inf'), float('inf'), 0.0 + + py_arr = np.asarray(py_out) + jax_arr = np.asarray(jax_out) + + if py_arr.shape != jax_arr.shape: + return float('inf'), float('inf'), 0.0 + + diff = np.abs(py_arr.astype(float) - jax_arr.astype(float)) + return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1)) + + +# ============================================================================ +# Primitive Test Definitions +# ============================================================================ + +PRIMITIVE_TESTS = { + # Geometry primitives + 'geometry:ripple-displace': { + 'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5], + 'returns': 'coords', + }, + 'geometry:rotate-img': { + 'args': ['frame', 45], + 'returns': 'frame', + }, + 'geometry:scale-img': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'geometry:flip-h': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'geometry:flip-v': { + 'args': ['frame'], + 'returns': 'frame', + }, + + # Color operations + 'color_ops:invert': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:grayscale': { + 'args': ['frame'], + 'returns': 'frame', + }, + 'color_ops:brightness': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:contrast': { + 'args': ['frame', 1.5], + 'returns': 'frame', + }, + 'color_ops:hue-shift': { + 'args': ['frame', 90], + 'returns': 'frame', + }, + + # Image operations + 'image:width': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:height': { + 'args': ['frame'], + 'returns': 'scalar', + }, + 'image:channel': { + 'args': ['frame', 0], + 'returns': 'array', + }, + + # Blending + 'blending:blend': { + 'args': ['frame', 'frame2', 0.5], + 'returns': 'frame', + }, + 'blending:blend-add': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, + 'blending:blend-multiply': { + 'args': ['frame', 'frame2'], + 'returns': 'frame', + }, +} + + +def 
run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the Python interpreter.""" + if prim_name not in interp.primitives: + return None + + func = interp.primitives[prim_name] + args = [] + for a in test_def['args']: + if a == 'frame': + args.append(frame.copy()) + elif a == 'frame2': + args.append(frame2.copy()) + else: + args.append(a) + + try: + return func(*args) + except Exception as e: + return None + + +def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any: + """Run a primitive through the JAX compiler.""" + try: + from streaming.sexp_to_jax import JaxCompiler + import jax.numpy as jnp + + compiler = JaxCompiler() + + # Build a simple expression to test the primitive + from sexp_effects.parser import Symbol, Keyword + + args = [] + env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)} + + for a in test_def['args']: + if a == 'frame': + args.append(Symbol('frame')) + elif a == 'frame2': + args.append(Symbol('frame2')) + else: + args.append(a) + + # Create expression: (prim_name arg1 arg2 ...) 
+ expr = [Symbol(prim_name)] + args + + result = compiler._eval(expr, env) + + if hasattr(result, '__array__'): + return np.asarray(result) + return result + + except Exception as e: + return None + + +def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult: + """Test a single primitive.""" + frame = create_test_frame(pattern='gradient') + frame2 = create_test_frame(pattern='checkerboard') + + result = TestResult(primitive=prim_name, passed=False) + + # Run Python version + try: + py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2) + if py_out is not None and hasattr(py_out, 'shape'): + result.python_mean = float(np.mean(py_out)) + except Exception as e: + result.error = f"Python error: {e}" + return result + + # Run JAX version + try: + jax_out = run_jax_primitive(prim_name, test_def, frame, frame2) + if jax_out is not None and hasattr(jax_out, 'shape'): + result.jax_mean = float(np.mean(jax_out)) + except Exception as e: + result.error = f"JAX error: {e}" + return result + + if py_out is None: + result.error = "Python returned None" + return result + if jax_out is None: + result.error = "JAX returned None" + return result + + # Compare + diff_mean, diff_max, pct = compare_outputs(py_out, jax_out) + result.diff_mean = diff_mean + result.diff_max = diff_max + result.pct_within_1 = pct + + # Check pass/fail + result.passed = ( + diff_mean <= TOLERANCE_MEAN and + diff_max <= TOLERANCE_MAX and + pct >= TOLERANCE_PCT + ) + + if not result.passed: + result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}" + + return result + + +def discover_primitives(interp) -> List[str]: + """Discover all primitives available in the interpreter.""" + return sorted(interp.primitives.keys()) + + +def run_all_tests(verbose=True): + """Run all primitive tests.""" + import warnings + warnings.filterwarnings('ignore') + + import os + os.chdir('/home/giles/art/test') + + from streaming.stream_sexp_generic import 
StreamInterpreter + from sexp_effects.primitive_libs import core as core_mod + + core_mod.set_random_seed(42) + + # Create interpreter to get Python primitives + interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False) + interp._init() + + results = [] + + print("=" * 60) + print("JAX PRIMITIVE TEST SUITE") + print("=" * 60) + + # Test defined primitives + for prim_name, test_def in PRIMITIVE_TESTS.items(): + result = test_primitive(interp, prim_name, test_def) + results.append(result) + + status = "✓ PASS" if result.passed else "✗ FAIL" + if verbose: + print(f"{status} {prim_name}") + if not result.passed: + print(f" {result.error}") + + # Summary + passed = sum(1 for r in results if r.passed) + failed = sum(1 for r in results if not r.passed) + + print("\n" + "=" * 60) + print(f"SUMMARY: {passed} passed, {failed} failed") + print("=" * 60) + + if failed > 0: + print("\nFailed primitives:") + for r in results: + if not r.passed: + print(f" - {r.primitive}: {r.error}") + + return results + + +if __name__ == '__main__': + run_all_tests() diff --git a/tests/test_xector.py b/tests/test_xector.py new file mode 100644 index 0000000..0d006e5 --- /dev/null +++ b/tests/test_xector.py @@ -0,0 +1,305 @@ +""" +Tests for xector primitives - parallel array operations. 
+""" + +import pytest +import numpy as np +from sexp_effects.primitive_libs.xector import ( + Xector, + xector_red, xector_green, xector_blue, xector_rgb, + xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm, + xector_dist_from_center, + alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp, + alpha_sin, alpha_cos, alpha_sq, + alpha_lt, alpha_gt, alpha_eq, + beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count, + xector_where, xector_fill, xector_zeros, xector_ones, + is_xector, +) + + +class TestXectorBasics: + """Test Xector class basic operations.""" + + def test_create_from_list(self): + x = Xector([1, 2, 3]) + assert len(x) == 3 + assert is_xector(x) + + def test_create_from_numpy(self): + arr = np.array([1.0, 2.0, 3.0]) + x = Xector(arr) + assert len(x) == 3 + np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32)) + + def test_implicit_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = a + b + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_implicit_mul(self): + a = Xector([1, 2, 3]) + b = Xector([2, 2, 2]) + c = a * b + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_scalar_broadcast(self): + a = Xector([1, 2, 3]) + c = a + 10 + np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13]) + + def test_scalar_broadcast_rmul(self): + a = Xector([1, 2, 3]) + c = 2 * a + np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + +class TestAlphaOperations: + """Test α (element-wise) operations.""" + + def test_alpha_add(self): + a = Xector([1, 2, 3]) + b = Xector([4, 5, 6]) + c = alpha_add(a, b) + np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9]) + + def test_alpha_add_multi(self): + a = Xector([1, 2, 3]) + b = Xector([1, 1, 1]) + c = Xector([10, 10, 10]) + d = alpha_add(a, b, c) + np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14]) + + def test_alpha_mul_scalar(self): + a = Xector([1, 2, 3]) + c = alpha_mul(a, 2) + 
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6]) + + def test_alpha_sqrt(self): + a = Xector([1, 4, 9, 16]) + c = alpha_sqrt(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4]) + + def test_alpha_clamp(self): + a = Xector([-5, 0, 5, 10, 15]) + c = alpha_clamp(a, 0, 10) + np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10]) + + def test_alpha_sin_cos(self): + a = Xector([0, np.pi/2, np.pi]) + s = alpha_sin(a) + c = alpha_cos(a) + np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5) + np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5) + + def test_alpha_sq(self): + a = Xector([1, 2, 3, 4]) + c = alpha_sq(a) + np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16]) + + def test_alpha_comparison(self): + a = Xector([1, 2, 3, 4]) + b = Xector([2, 2, 2, 2]) + lt = alpha_lt(a, b) + gt = alpha_gt(a, b) + eq = alpha_eq(a, b) + np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False]) + np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True]) + np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False]) + + +class TestBetaOperations: + """Test β (reduction) operations.""" + + def test_beta_add(self): + a = Xector([1, 2, 3, 4]) + assert beta_add(a) == 10 + + def test_beta_mul(self): + a = Xector([1, 2, 3, 4]) + assert beta_mul(a) == 24 + + def test_beta_min_max(self): + a = Xector([3, 1, 4, 1, 5, 9, 2, 6]) + assert beta_min(a) == 1 + assert beta_max(a) == 9 + + def test_beta_mean(self): + a = Xector([1, 2, 3, 4]) + assert beta_mean(a) == 2.5 + + def test_beta_count(self): + a = Xector([1, 2, 3, 4, 5]) + assert beta_count(a) == 5 + + +class TestFrameConversion: + """Test frame/xector conversion.""" + + def test_extract_channels(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[255, 0, 0], [0, 255, 0]], + [[0, 0, 255], [128, 128, 128]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + assert len(r) == 
4 + np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128]) + np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128]) + np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128]) + + def test_rgb_roundtrip(self): + # Create a 2x2 RGB frame + frame = np.array([ + [[100, 150, 200], [50, 75, 100]], + [[200, 100, 50], [25, 50, 75]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + reconstructed = xector_rgb(r, g, b) + np.testing.assert_array_equal(reconstructed, frame) + + def test_modify_and_reconstruct(self): + frame = np.array([ + [[100, 100, 100], [100, 100, 100]], + [[100, 100, 100], [100, 100, 100]] + ], dtype=np.uint8) + + r = xector_red(frame) + g = xector_green(frame) + b = xector_blue(frame) + + # Double red channel + r_doubled = r * 2 + + result = xector_rgb(r_doubled, g, b) + + # Red should be 200, others unchanged + assert result[0, 0, 0] == 200 + assert result[0, 0, 1] == 100 + assert result[0, 0, 2] == 100 + + +class TestCoordinates: + """Test coordinate generation.""" + + def test_x_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + x = xector_x_coords(frame) + # Should be [0,1,2, 0,1,2] (x coords repeated for each row) + np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2]) + + def test_y_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols + y = xector_y_coords(frame) + # Should be [0,0,0, 1,1,1] (y coords for each pixel) + np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1]) + + def test_normalized_coords(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_x_norm(frame) + y = xector_y_norm(frame) + + # x should go 0 to 1 across width + assert x.to_numpy()[0] == 0 + assert x.to_numpy()[2] == 1 + + # y should go 0 to 1 down height + assert y.to_numpy()[0] == 0 + assert y.to_numpy()[3] == 1 + + +class TestConditional: + """Test conditional operations.""" + + def test_where(self): + cond = Xector([True, 
False, True, False]) + true_val = Xector([1, 1, 1, 1]) + false_val = Xector([0, 0, 0, 0]) + + result = xector_where(cond, true_val, false_val) + np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0]) + + def test_where_with_comparison(self): + a = Xector([1, 5, 3, 7]) + threshold = 4 + + # Elements > 4 become 255, others become 0 + result = xector_where(alpha_gt(a, threshold), 255, 0) + np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255]) + + def test_fill(self): + frame = np.zeros((2, 3, 3), dtype=np.uint8) + x = xector_fill(42, frame) + assert len(x) == 6 + assert all(v == 42 for v in x.to_numpy()) + + def test_zeros_ones(self): + frame = np.zeros((2, 2, 3), dtype=np.uint8) + z = xector_zeros(frame) + o = xector_ones(frame) + + assert all(v == 0 for v in z.to_numpy()) + assert all(v == 1 for v in o.to_numpy()) + + +class TestInterpreterIntegration: + """Test xector operations through the interpreter.""" + + def test_xector_vignette_effect(self): + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + + # Load the xector vignette effect + interp.load_effect('sexp_effects/effects/xector_vignette.sexp') + + # Create a test frame (white) + frame = np.full((100, 100, 3), 255, dtype=np.uint8) + + # Run effect + result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {}) + + # Center should be brighter than corners + center = result[50, 50] + corner = result[0, 0] + + assert center.mean() > corner.mean(), "Center should be brighter than corners" + # Corners should be darkened + assert corner.mean() < 255, "Corners should be darkened" + + def test_implicit_elementwise(self): + """Test that regular + works element-wise on xectors.""" + from sexp_effects.interpreter import Interpreter + + interp = Interpreter(minimal_primitives=True) + # Load xector primitives + from sexp_effects.primitive_libs.xector import PRIMITIVES + for name, fn in PRIMITIVES.items(): + 
interp.global_env.set(name, fn) + + # Parse and eval a simple xector expression + from sexp_effects.parser import parse + expr = parse('(+ (red frame) 10)') + + # Create test frame + frame = np.full((2, 2, 3), 100, dtype=np.uint8) + interp.global_env.set('frame', frame) + + result = interp.eval(expr) + + # Should be a xector with values 110 + assert is_xector(result) + np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110]) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])