Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s

- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 15:12:54 +00:00
parent dbc4ece2cc
commit fc9597456f
30 changed files with 7749 additions and 165 deletions

477
path_registry.py Normal file
View File

@@ -0,0 +1,477 @@
"""
Path Registry - Maps human-friendly paths to content-addressed IDs.
This module provides a bidirectional mapping between:
- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp")
- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes)
The registry is useful for:
- Looking up effects by their friendly path name
- Resolving cids back to the original path for debugging
- Maintaining a stable naming scheme across cache updates
Storage:
- Uses the existing item_types table in the database (path column)
- Caches in Redis for fast lookups across distributed workers
The registry uses a system actor (@system@local) for global path mappings,
allowing effects to be resolved by path without requiring user context.
"""
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
# System actor for global path mappings (effects, recipes, analyzers)
SYSTEM_ACTOR = "@system@local"
@dataclass
class PathEntry:
"""A registered path with its content-addressed ID."""
path: str # Human-friendly path (relative or normalized)
cid: str # Content-addressed ID (IPFS CID or hash)
content_type: str # Type: "effect", "recipe", "analyzer", etc.
actor_id: str = SYSTEM_ACTOR # Owner (system for global)
description: Optional[str] = None
created_at: float = 0.0
class PathRegistry:
"""
Registry for mapping paths to content-addressed IDs.
Uses the existing item_types table for persistence and Redis
for fast lookups in distributed Celery workers.
"""
def __init__(self, redis_client=None):
self._redis = redis_client
self._redis_path_to_cid_key = "artdag:path_to_cid"
self._redis_cid_to_path_key = "artdag:cid_to_path"
def _run_async(self, coro):
"""Run async coroutine from sync context."""
import asyncio
try:
loop = asyncio.get_running_loop()
import threading
result = [None]
error = [None]
def run_in_thread():
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
result[0] = new_loop.run_until_complete(coro)
finally:
new_loop.close()
except Exception as e:
error[0] = e
thread = threading.Thread(target=run_in_thread)
thread.start()
thread.join(timeout=30)
if error[0]:
raise error[0]
return result[0]
except RuntimeError:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
def _normalize_path(self, path: str) -> str:
"""Normalize a path for consistent storage."""
# Remove leading ./ or /
path = path.lstrip('./')
# Normalize separators
path = path.replace('\\', '/')
# Remove duplicate slashes
while '//' in path:
path = path.replace('//', '/')
return path
def register(
self,
path: str,
cid: str,
content_type: str = "effect",
actor_id: str = SYSTEM_ACTOR,
description: Optional[str] = None,
) -> PathEntry:
"""
Register a path -> cid mapping.
Args:
path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp")
cid: Content-addressed ID (IPFS CID or hash)
content_type: Type of content ("effect", "recipe", "analyzer")
actor_id: Owner (default: system for global mappings)
description: Optional description
Returns:
The created PathEntry
"""
norm_path = self._normalize_path(path)
now = datetime.now(timezone.utc).timestamp()
entry = PathEntry(
path=norm_path,
cid=cid,
content_type=content_type,
actor_id=actor_id,
description=description,
created_at=now,
)
# Store in database (item_types table)
self._save_to_db(entry)
# Update Redis cache
self._update_redis_cache(norm_path, cid)
logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...")
return entry
def _save_to_db(self, entry: PathEntry):
"""Save entry to database using item_types table."""
import database
async def save():
import asyncpg
conn = await asyncpg.connect(database.DATABASE_URL)
try:
# Ensure cache_item exists
await conn.execute(
"INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING",
entry.cid
)
# Insert or update item_type with path
await conn.execute(
"""
INSERT INTO item_types (cid, actor_id, type, path, description)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET
description = COALESCE(EXCLUDED.description, item_types.description)
""",
entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description
)
finally:
await conn.close()
try:
self._run_async(save())
except Exception as e:
logger.warning(f"Failed to save path registry to DB: {e}")
def _update_redis_cache(self, path: str, cid: str):
"""Update Redis cache with mapping."""
if self._redis:
try:
self._redis.hset(self._redis_path_to_cid_key, path, cid)
self._redis.hset(self._redis_cid_to_path_key, cid, path)
except Exception as e:
logger.warning(f"Failed to update Redis cache: {e}")
def get_cid(self, path: str, content_type: str = None) -> Optional[str]:
"""
Get the cid for a path.
Args:
path: Human-friendly path
content_type: Optional type filter
Returns:
The cid, or None if not found
"""
norm_path = self._normalize_path(path)
# Try Redis first (fast path)
if self._redis:
try:
val = self._redis.hget(self._redis_path_to_cid_key, norm_path)
if val:
return val.decode() if isinstance(val, bytes) else val
except Exception as e:
logger.warning(f"Redis lookup failed: {e}")
# Fall back to database
return self._get_cid_from_db(norm_path, content_type)
def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]:
"""Get cid from database using item_types table."""
import database
async def get():
import asyncpg
conn = await asyncpg.connect(database.DATABASE_URL)
try:
if content_type:
row = await conn.fetchrow(
"SELECT cid FROM item_types WHERE path = $1 AND type = $2",
path, content_type
)
else:
row = await conn.fetchrow(
"SELECT cid FROM item_types WHERE path = $1",
path
)
return row["cid"] if row else None
finally:
await conn.close()
try:
result = self._run_async(get())
# Update Redis cache if found
if result and self._redis:
self._update_redis_cache(path, result)
return result
except Exception as e:
logger.warning(f"Failed to get from DB: {e}")
return None
def get_path(self, cid: str) -> Optional[str]:
"""
Get the path for a cid.
Args:
cid: Content-addressed ID
Returns:
The path, or None if not found
"""
# Try Redis first
if self._redis:
try:
val = self._redis.hget(self._redis_cid_to_path_key, cid)
if val:
return val.decode() if isinstance(val, bytes) else val
except Exception as e:
logger.warning(f"Redis lookup failed: {e}")
# Fall back to database
return self._get_path_from_db(cid)
def _get_path_from_db(self, cid: str) -> Optional[str]:
"""Get path from database using item_types table."""
import database
async def get():
import asyncpg
conn = await asyncpg.connect(database.DATABASE_URL)
try:
row = await conn.fetchrow(
"SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1",
cid
)
return row["path"] if row else None
finally:
await conn.close()
try:
result = self._run_async(get())
# Update Redis cache if found
if result and self._redis:
self._update_redis_cache(result, cid)
return result
except Exception as e:
logger.warning(f"Failed to get from DB: {e}")
return None
def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]:
"""
List all entries of a given type.
Args:
content_type: Type to filter by ("effect", "recipe", etc.)
actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global)
Returns:
List of PathEntry objects
"""
import database
async def list_entries():
import asyncpg
conn = await asyncpg.connect(database.DATABASE_URL)
try:
if actor_id:
rows = await conn.fetch(
"""
SELECT cid, path, type, actor_id, description,
EXTRACT(EPOCH FROM created_at) as created_at
FROM item_types
WHERE type = $1 AND actor_id = $2 AND path IS NOT NULL
ORDER BY path
""",
content_type, actor_id
)
else:
rows = await conn.fetch(
"""
SELECT cid, path, type, actor_id, description,
EXTRACT(EPOCH FROM created_at) as created_at
FROM item_types
WHERE type = $1 AND path IS NOT NULL
ORDER BY path
""",
content_type
)
return [
PathEntry(
path=row["path"],
cid=row["cid"],
content_type=row["type"],
actor_id=row["actor_id"],
description=row["description"],
created_at=row["created_at"] or 0,
)
for row in rows
]
finally:
await conn.close()
try:
return self._run_async(list_entries())
except Exception as e:
logger.warning(f"Failed to list from DB: {e}")
return []
def delete(self, path: str, content_type: str = None) -> bool:
"""
Delete a path registration.
Args:
path: The path to delete
content_type: Optional type filter
Returns:
True if deleted, False if not found
"""
norm_path = self._normalize_path(path)
# Get cid for Redis cleanup
cid = self.get_cid(norm_path, content_type)
# Delete from database
deleted = self._delete_from_db(norm_path, content_type)
# Clean up Redis
if deleted and cid and self._redis:
try:
self._redis.hdel(self._redis_path_to_cid_key, norm_path)
self._redis.hdel(self._redis_cid_to_path_key, cid)
except Exception as e:
logger.warning(f"Failed to clean up Redis: {e}")
return deleted
def _delete_from_db(self, path: str, content_type: str = None) -> bool:
"""Delete from database."""
import database
async def delete():
import asyncpg
conn = await asyncpg.connect(database.DATABASE_URL)
try:
if content_type:
result = await conn.execute(
"DELETE FROM item_types WHERE path = $1 AND type = $2",
path, content_type
)
else:
result = await conn.execute(
"DELETE FROM item_types WHERE path = $1",
path
)
return "DELETE" in result
finally:
await conn.close()
try:
return self._run_async(delete())
except Exception as e:
logger.warning(f"Failed to delete from DB: {e}")
return False
def register_effect(
self,
path: str,
cid: str,
description: Optional[str] = None,
) -> PathEntry:
"""
Convenience method to register an effect.
Args:
path: Effect path (e.g., "effects/ascii_fx_zone.sexp")
cid: IPFS CID of the effect file
description: Optional description
Returns:
The created PathEntry
"""
return self.register(
path=path,
cid=cid,
content_type="effect",
actor_id=SYSTEM_ACTOR,
description=description,
)
def get_effect_cid(self, path: str) -> Optional[str]:
"""
Get CID for an effect by path.
Args:
path: Effect path
Returns:
IPFS CID or None
"""
return self.get_cid(path, content_type="effect")
def list_effects(self) -> List[PathEntry]:
"""List all registered effects."""
return self.list_by_type("effect", actor_id=SYSTEM_ACTOR)
# Singleton instance
_registry: Optional[PathRegistry] = None
def get_path_registry() -> PathRegistry:
"""Get the singleton path registry instance."""
global _registry
if _registry is None:
import redis
from urllib.parse import urlparse
redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
parsed = urlparse(redis_url)
redis_client = redis.Redis(
host=parsed.hostname or 'localhost',
port=parsed.port or 6379,
db=int(parsed.path.lstrip('/') or 0),
socket_timeout=5,
socket_connect_timeout=5
)
_registry = PathRegistry(redis_client=redis_client)
return _registry
def reset_path_registry():
"""Reset the singleton (for testing)."""
global _registry
_registry = None

206
sexp_effects/derived.sexp Normal file
View File

@@ -0,0 +1,206 @@
;; Derived Operations
;;
;; These are built from true primitives using S-expressions.
;; Load with: (require "derived")
;; =============================================================================
;; Math Helpers (derivable from where + basic ops)
;; =============================================================================
;; Absolute value
(define (abs x) (where (< x 0) (- x) x))
;; Minimum of two values
(define (min2 a b) (where (< a b) a b))
;; Maximum of two values
(define (max2 a b) (where (> a b) a b))
;; Clamp x to range [lo, hi]
(define (clamp x lo hi) (max2 lo (min2 hi x)))
;; Square of x
(define (sq x) (* x x))
;; Linear interpolation: a*(1-t) + b*t
(define (lerp a b t) (+ (* a (- 1 t)) (* b t)))
;; Smooth interpolation between edges
(define (smoothstep edge0 edge1 x)
(let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1)))
(* t (* t (- 3 (* 2 t))))))
;; =============================================================================
;; Channel Shortcuts (derivable from channel primitive)
;; =============================================================================
;; Extract red channel as xector
(define (red frame) (channel frame 0))
;; Extract green channel as xector
(define (green frame) (channel frame 1))
;; Extract blue channel as xector
(define (blue frame) (channel frame 2))
;; Convert to grayscale xector (ITU-R BT.601)
(define (gray frame)
(+ (* (red frame) 0.299)
(* (green frame) 0.587)
(* (blue frame) 0.114)))
;; Alias for gray
(define (luminance frame) (gray frame))
;; =============================================================================
;; Coordinate Generators (derivable from iota + repeat/tile)
;; =============================================================================
;; X coordinate for each pixel [0, width)
(define (x-coords frame) (tile (iota (width frame)) (height frame)))
;; Y coordinate for each pixel [0, height)
(define (y-coords frame) (repeat (iota (height frame)) (width frame)))
;; Normalized X coordinate [0, 1]
(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1))))
;; Normalized Y coordinate [0, 1]
(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1))))
;; Distance from frame center for each pixel
(define (dist-from-center frame)
(let* ((cx (/ (width frame) 2))
(cy (/ (height frame) 2))
(dx (- (x-coords frame) cx))
(dy (- (y-coords frame) cy)))
(sqrt (+ (sq dx) (sq dy)))))
;; Normalized distance from center [0, ~1]
(define (dist-norm frame)
(let ((d (dist-from-center frame)))
(/ d (max2 1 (βmax d)))))
;; =============================================================================
;; Cell/Grid Operations (derivable from floor + basic math)
;; =============================================================================
;; Cell row index for each pixel
(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size)))
;; Cell column index for each pixel
(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size)))
;; Number of cell rows
(define (num-rows frame cell-size) (floor (/ (height frame) cell-size)))
;; Number of cell columns
(define (num-cols frame cell-size) (floor (/ (width frame) cell-size)))
;; Flat cell index for each pixel
(define (cell-indices frame cell-size)
(+ (* (cell-row frame cell-size) (num-cols frame cell-size))
(cell-col frame cell-size)))
;; Total number of cells
(define (num-cells frame cell-size)
(* (num-rows frame cell-size) (num-cols frame cell-size)))
;; X position within cell [0, cell-size)
(define (local-x frame cell-size) (mod (x-coords frame) cell-size))
;; Y position within cell [0, cell-size)
(define (local-y frame cell-size) (mod (y-coords frame) cell-size))
;; Normalized X within cell [0, 1]
(define (local-x-norm frame cell-size)
(/ (local-x frame cell-size) (max2 1 (- cell-size 1))))
;; Normalized Y within cell [0, 1]
(define (local-y-norm frame cell-size)
(/ (local-y frame cell-size) (max2 1 (- cell-size 1))))
;; =============================================================================
;; Fill Operations (derivable from iota)
;; =============================================================================
;; Xector of n zeros
(define (zeros n) (* (iota n) 0))
;; Xector of n ones
(define (ones n) (+ (zeros n) 1))
;; Xector of n copies of val
(define (fill val n) (+ (zeros n) val))
;; Xector of zeros matching x's length
(define (zeros-like x) (* x 0))
;; Xector of ones matching x's length
(define (ones-like x) (+ (zeros-like x) 1))
;; =============================================================================
;; Pooling (derivable from group-reduce)
;; =============================================================================
;; Pool a channel by cell index
(define (pool-channel chan cell-idx num-cells)
(group-reduce chan cell-idx num-cells "mean"))
;; Pool red channel to cells
(define (pool-red frame cell-size)
(pool-channel (red frame)
(cell-indices frame cell-size)
(num-cells frame cell-size)))
;; Pool green channel to cells
(define (pool-green frame cell-size)
(pool-channel (green frame)
(cell-indices frame cell-size)
(num-cells frame cell-size)))
;; Pool blue channel to cells
(define (pool-blue frame cell-size)
(pool-channel (blue frame)
(cell-indices frame cell-size)
(num-cells frame cell-size)))
;; Pool grayscale to cells
(define (pool-gray frame cell-size)
(pool-channel (gray frame)
(cell-indices frame cell-size)
(num-cells frame cell-size)))
;; =============================================================================
;; Blending (derivable from math)
;; =============================================================================
;; Additive blend
(define (blend-add a b) (clamp (+ a b) 0 255))
;; Multiply blend (normalized)
(define (blend-multiply a b) (* (/ a 255) b))
;; Screen blend
(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 255 b))))
;; Overlay blend
(define (blend-overlay a b)
(where (< a 128)
(* 2 (/ (* a b) 255))
(- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255)))))
;; =============================================================================
;; Simple Effects (derivable from primitives)
;; =============================================================================
;; Invert a channel (255 - c)
(define (invert-channel c) (- 255 c))
;; Binary threshold
(define (threshold-channel c thresh) (where (> c thresh) 255 0))
;; Reduce to n levels
(define (posterize-channel c levels)
(let ((step (/ 255 (- levels 1))))
(* (round (/ c step)) step)))

View File

@@ -5,7 +5,7 @@
:params ( :params (
(char_size :type int :default 8 :range [4 32]) (char_size :type int :default 8 :range [4 32])
(alphabet :type string :default "standard") (alphabet :type string :default "standard")
(color_mode :type string :default "color" :desc ""color", "mono", "invert", or any color name/hex") (color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex")
(background_color :type string :default "black" :desc "background color name/hex") (background_color :type string :default "black" :desc "background color name/hex")
(invert_colors :type int :default 0 :desc "swap foreground and background colors") (invert_colors :type int :default 0 :desc "swap foreground and background colors")
(contrast :type float :default 1.5 :range [1 3]) (contrast :type float :default 1.5 :range [1 3])

View File

@@ -0,0 +1,65 @@
;; Cell Pattern effect - custom patterns within cells
;;
;; Demonstrates building arbitrary per-cell visuals from primitives.
;; Uses local coordinates within cells to draw patterns scaled by luminance.
(require-primitives "xector")
(define-effect cell_pattern
:params (
(cell-size :type int :default 16 :range [8 48] :desc "Cell size")
(pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring")
)
(let* (
;; Pool to get cell colors
(pooled (pool-frame frame cell-size))
(cell-r (nth pooled 0))
(cell-g (nth pooled 1))
(cell-b (nth pooled 2))
(cell-lum (α/ (nth pooled 3) 255))
;; Cell indices for each pixel
(cell-idx (cell-indices frame cell-size))
;; Look up cell values for each pixel
(pix-r (gather cell-r cell-idx))
(pix-g (gather cell-g cell-idx))
(pix-b (gather cell-b cell-idx))
(pix-lum (gather cell-lum cell-idx))
;; Local position within cell [0, 1]
(lx (local-x-norm frame cell-size))
(ly (local-y-norm frame cell-size))
;; Pattern mask based on pattern type
(mask
(cond
;; Diagonal lines - thickness based on luminance
((= pattern "diagonal")
(let* ((diag (αmod (α+ lx ly) 0.25))
(thickness (α* pix-lum 0.125)))
(α< diag thickness)))
;; Cross pattern
((= pattern "cross")
(let* ((cx (αabs (α- lx 0.5)))
(cy (αabs (α- ly 0.5)))
(thickness (α* pix-lum 0.25)))
(αor (α< cx thickness) (α< cy thickness))))
;; Ring pattern
((= pattern "ring")
(let* ((dx (α- lx 0.5))
(dy (α- ly 0.5))
(dist (αsqrt (α+ (α² dx) (α² dy))))
(target (α* pix-lum 0.4))
(thickness 0.05))
(α< (αabs (α- dist target)) thickness)))
;; Default: solid
(else (α> pix-lum 0)))))
;; Apply mask: show cell color where mask is true, black elsewhere
(rgb (where mask pix-r 0)
(where mask pix-g 0)
(where mask pix-b 0))))

View File

@@ -6,10 +6,10 @@
(num_echoes :type int :default 4 :range [1 20]) (num_echoes :type int :default 4 :range [1 20])
(decay :type float :default 0.5 :range [0 1]) (decay :type float :default 0.5 :range [0 1])
) )
(let* ((buffer (state-get 'buffer (list))) (let* ((buffer (state-get "buffer" (list)))
(new-buffer (take (cons frame buffer) (+ num_echoes 1)))) (new-buffer (take (cons frame buffer) (+ num_echoes 1))))
(begin (begin
(state-set 'buffer new-buffer) (state-set "buffer" new-buffer)
;; Blend frames with decay ;; Blend frames with decay
(if (< (length new-buffer) 2) (if (< (length new-buffer) 2)
frame frame

View File

@@ -0,0 +1,49 @@
;; Halftone/dot effect - built from primitive xector operations
;;
;; Uses:
;; pool-frame - downsample to cell luminances
;; cell-indices - which cell each pixel belongs to
;; gather - look up cell value for each pixel
;; local-x/y-norm - position within cell [0,1]
;; where - conditional per-pixel
(require-primitives "xector")
(define-effect halftone
:params (
(cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells")
(dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius")
(invert :type bool :default false :desc "Invert (white dots on black)")
)
(let* (
;; Pool frame to get luminance per cell
(pooled (pool-frame frame cell-size))
(cell-lum (nth pooled 3)) ; luminance is 4th element
;; For each output pixel, get its cell index
(cell-idx (cell-indices frame cell-size))
;; Get cell luminance for each pixel
(pixel-lum (α/ (gather cell-lum cell-idx) 255))
;; Position within cell, normalized to [-0.5, 0.5]
(lx (α- (local-x-norm frame cell-size) 0.5))
(ly (α- (local-y-norm frame cell-size) 0.5))
;; Distance from cell center (0 at center, ~0.7 at corners)
(dist (αsqrt (α+ (α² lx) (α² ly))))
;; Radius based on luminance (brighter = bigger dot)
(radius (α* (if invert (α- 1 pixel-lum) pixel-lum)
(α* dot-scale 0.5)))
;; Is this pixel inside the dot?
(inside (α< dist radius))
;; Output color
(fg (if invert 255 0))
(bg (if invert 0 255))
(out (where inside fg bg)))
;; Grayscale output
(rgb out out out)))

View File

@@ -0,0 +1,30 @@
;; Mosaic effect - built from primitive xector operations
;;
;; Uses:
;; pool-frame - downsample to cell averages
;; cell-indices - which cell each pixel belongs to
;; gather - look up cell value for each pixel
(require-primitives "xector")
(define-effect mosaic
:params (
(cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells")
)
(let* (
;; Pool frame to get average color per cell (returns r,g,b,lum xectors)
(pooled (pool-frame frame cell-size))
(cell-r (nth pooled 0))
(cell-g (nth pooled 1))
(cell-b (nth pooled 2))
;; For each output pixel, get its cell index
(cell-idx (cell-indices frame cell-size))
;; Gather: look up cell color for each pixel
(out-r (gather cell-r cell-idx))
(out-g (gather cell-g cell-idx))
(out-b (gather cell-b cell-idx)))
;; Reconstruct frame
(rgb out-r out-g out-b)))

View File

@@ -5,9 +5,9 @@
:params ( :params (
(thickness :type int :default 2 :range [1 10]) (thickness :type int :default 2 :range [1 10])
(threshold :type int :default 100 :range [20 300]) (threshold :type int :default 100 :range [20 300])
(color :type list :default (list 0 0 0) (color :type list :default (list 0 0 0))
(fill_mode :type string :default "original")
) )
(fill_mode "original"))
(let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold)) (let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold))
(dilated (if (> thickness 1) (dilated (if (> thickness 1)
(dilate edge-img thickness) (dilate edge-img thickness)

View File

@@ -5,12 +5,12 @@
:params ( :params (
(frame_rate :type int :default 12 :range [1 60]) (frame_rate :type int :default 12 :range [1 60])
) )
(let* ((held (state-get 'held nil)) (let* ((held (state-get "held" nil))
(held-until (state-get 'held-until 0)) (held-until (state-get "held-until" 0))
(frame-duration (/ 1 frame_rate))) (frame-duration (/ 1 frame_rate)))
(if (or (core:is-nil held) (>= t held-until)) (if (or (core:is-nil held) (>= t held-until))
(begin (begin
(state-set 'held (copy frame)) (state-set "held" (copy frame))
(state-set 'held-until (+ t frame-duration)) (state-set "held-until" (+ t frame-duration))
frame) frame)
held))) held)))

View File

@@ -5,16 +5,16 @@
:params ( :params (
(persistence :type float :default 0.8 :range [0 0.99]) (persistence :type float :default 0.8 :range [0 0.99])
) )
(let* ((buffer (state-get 'buffer nil)) (let* ((buffer (state-get "buffer" nil))
(current frame)) (current frame))
(if (= buffer nil) (if (= buffer nil)
(begin (begin
(state-set 'buffer (copy frame)) (state-set "buffer" (copy frame))
frame) frame)
(let* ((faded (blending:blend-images buffer (let* ((faded (blending:blend-images buffer
(make-image (image:width frame) (image:height frame) (list 0 0 0)) (make-image (image:width frame) (image:height frame) (list 0 0 0))
(- 1 persistence))) (- 1 persistence)))
(result (blending:blend-mode faded current "lighten"))) (result (blending:blend-mode faded current "lighten")))
(begin (begin
(state-set 'buffer result) (state-set "buffer" result)
result))))) result)))))

View File

@@ -0,0 +1,44 @@
;; Feathered blend - blend two same-size frames with distance-based falloff
;; Center shows overlay, edges show background, with smooth transition
(require-primitives "xector")
(define-effect xector_feathered_blend
:params (
(inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)")
(fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)")
(overlay :type frame :default nil :desc "Frame to blend in center")
)
(let* (
;; Get normalized distance from center (0 at center, ~1 at corners)
(dist (dist-from-center frame))
(max-dist (βmax dist))
(dist-norm (α/ dist max-dist))
;; Calculate blend factor:
;; - 1.0 when dist-norm < inner-radius (fully overlay)
;; - 0.0 when dist-norm > inner-radius + fade-width (fully background)
;; - linear ramp between
(t (α/ (α- dist-norm inner-radius) fade-width))
(blend (α- 1 (αclamp t 0 1)))
(inv-blend (α- 1 blend))
;; Background channels
(bg-r (red frame))
(bg-g (green frame))
(bg-b (blue frame)))
(if (nil? overlay)
;; No overlay - visualize the blend mask
(let ((vis (α* blend 255)))
(rgb vis vis vis))
;; Blend overlay with background using the mask
(let* ((ov-r (red overlay))
(ov-g (green overlay))
(ov-b (blue overlay))
;; lerp: bg * (1-blend) + overlay * blend
(r-out (α+ (α* bg-r inv-blend) (α* ov-r blend)))
(g-out (α+ (α* bg-g inv-blend) (α* ov-g blend)))
(b-out (α+ (α* bg-b inv-blend) (α* ov-b blend))))
(rgb r-out g-out b-out)))))

View File

@@ -0,0 +1,34 @@
;; Film grain effect using xector operations
;; Demonstrates random xectors and mixing scalar/xector math
(require-primitives "xector")
(define-effect xector_grain
:params (
(intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity")
(colored :type bool :default false :desc "Use colored grain")
)
(let* (
;; Extract channels
(r (red frame))
(g (green frame))
(b (blue frame))
;; Generate noise xector(s)
;; randn-x generates normal distribution noise
(grain-amount (* intensity 50)))
(if colored
;; Colored grain: different noise per channel
(let* ((nr (randn-x frame 0 grain-amount))
(ng (randn-x frame 0 grain-amount))
(nb (randn-x frame 0 grain-amount)))
(rgb (αclamp (α+ r nr) 0 255)
(αclamp (α+ g ng) 0 255)
(αclamp (α+ b nb) 0 255)))
;; Monochrome grain: same noise for all channels
(let ((n (randn-x frame 0 grain-amount)))
(rgb (αclamp (α+ r n) 0 255)
(αclamp (α+ g n) 0 255)
(αclamp (α+ b n) 0 255))))))

View File

@@ -0,0 +1,57 @@
;; Inset blend - fade a smaller frame into a larger background
;; Uses distance-based alpha for smooth transition (no hard edges)
(require-primitives "xector")
(define-effect xector_inset_blend
:params (
(x :type int :default 0 :desc "X position of inset")
(y :type int :default 0 :desc "Y position of inset")
(fade-width :type int :default 50 :desc "Width of fade region in pixels")
(overlay :type frame :default nil :desc "The smaller frame to inset")
)
(let* (
;; Get dimensions
(bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height
(bg-w bg-h) ;; placeholder
;; For now, create a simple centered circular blend
;; Distance from center of overlay position
(cx (+ x (/ (- bg-w (* 2 x)) 2)))
(cy (+ y (/ (- bg-h (* 2 y)) 2)))
;; Get coordinates as xectors
(px (x-coords frame))
(py (y-coords frame))
;; Distance from center
(dx (α- px cx))
(dy (α- py cy))
(dist (αsqrt (α+ (α* dx dx) (α* dy dy))))
;; Inner radius (fully overlay) and outer radius (fully background)
(inner-r (- (/ bg-w 2) x fade-width))
(outer-r (- (/ bg-w 2) x))
;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between
(t (α/ (α- dist inner-r) fade-width))
(blend (α- 1 (αclamp t 0 1)))
;; Extract channels from both frames
(bg-r (red frame))
(bg-g (green frame))
(bg-b (blue frame)))
;; If overlay provided, blend it
(if overlay
(let* ((ov-r (red overlay))
(ov-g (green overlay))
(ov-b (blue overlay))
;; Linear blend: result = bg * (1-blend) + overlay * blend
(r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend)))
(g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend)))
(b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend))))
(rgb r-out g-out b-out))
;; No overlay - just show the blend mask for debugging
(let ((mask-vis (α* blend 255)))
(rgb mask-vis mask-vis mask-vis)))))

View File

@@ -0,0 +1,27 @@
;; Threshold effect using xector operations
;; Demonstrates where (conditional select) and β (reduction) for normalization
(require-primitives "xector")
(define-effect xector_threshold
:params (
(threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)")
(invert :type bool :default false :desc "Invert the threshold")
)
(let* (
;; Get grayscale luminance as xector
(luma (gray frame))
;; Normalize to 0-1 range
(luma-norm (α/ luma 255))
;; Create boolean mask: pixels above threshold
(mask (if invert
(α< luma-norm threshold)
(α>= luma-norm threshold)))
;; Use where to select: white (255) if above threshold, black (0) if below
(out (where mask 255 0)))
;; Output as grayscale (same value for R, G, B)
(rgb out out out)))

View File

@@ -0,0 +1,36 @@
;; Vignette effect using xector operations
;; Demonstrates α (element-wise) and β (reduction) patterns
(require-primitives "xector")
(define-effect xector_vignette
:params (
(strength :type float :default 0.5 :range [0 1])
(radius :type float :default 1.0 :range [0.5 2])
)
(let* (
;; Get normalized distance from center for each pixel
(dist (dist-from-center frame))
;; Calculate max distance (corner distance)
(max-dist (* (βmax dist) radius))
;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength)
;; Using explicit α operators
(factor (α- 1 (α* (α/ dist max-dist) strength)))
;; Clamp factor to [0, 1]
(factor (αclamp factor 0 1))
;; Extract channels as xectors
(r (red frame))
(g (green frame))
(b (blue frame))
;; Apply factor to each channel (implicit element-wise via Xector operators)
(r-out (* r factor))
(g-out (* g factor))
(b-out (* b factor)))
;; Combine back to frame
(rgb r-out g-out b-out)))

View File

@@ -156,11 +156,21 @@ class Interpreter:
if form == 'define': if form == 'define':
name = expr[1] name = expr[1]
if _is_symbol(name): if _is_symbol(name):
# Simple define: (define name value)
value = self.eval(expr[2], env) value = self.eval(expr[2], env)
self.global_env.set(name.name, value) self.global_env.set(name.name, value)
return value return value
elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]):
# Function define: (define (fn-name args...) body)
# Desugars to: (define fn-name (lambda (args...) body))
fn_name = name[0].name
params = [p.name if _is_symbol(p) else p for p in name[1:]]
body = expr[2]
fn = Lambda(params, body, env)
self.global_env.set(fn_name, fn)
return fn
else: else:
raise SyntaxError(f"define requires symbol, got {name}") raise SyntaxError(f"define requires symbol or (name args...), got {name}")
# Define-effect # Define-effect
if form == 'define-effect': if form == 'define-effect':
@@ -276,6 +286,10 @@ class Interpreter:
if form == 'require-primitives': if form == 'require-primitives':
return self._eval_require_primitives(expr, env) return self._eval_require_primitives(expr, env)
# require - load .sexp file into current scope
if form == 'require':
return self._eval_require(expr, env)
# Function call # Function call
fn = self.eval(head, env) fn = self.eval(head, env)
args = [self.eval(arg, env) for arg in expr[1:]] args = [self.eval(arg, env) for arg in expr[1:]]
@@ -488,6 +502,61 @@ class Interpreter:
from .primitive_libs import load_primitive_library from .primitive_libs import load_primitive_library
return load_primitive_library(name, path) return load_primitive_library(name, path)
def _eval_require(self, expr: Any, env: Environment) -> Any:
"""
Evaluate require: load a .sexp file and evaluate its definitions.
Syntax:
(require "derived") ; loads derived.sexp from sexp_effects/
(require "path/to/file.sexp") ; loads from explicit path
Definitions from the file are added to the current environment.
"""
for lib_expr in expr[1:]:
if _is_symbol(lib_expr):
lib_name = lib_expr.name
else:
lib_name = lib_expr
# Find the .sexp file
sexp_path = self._find_sexp_file(lib_name)
if sexp_path is None:
raise ValueError(f"Cannot find sexp file: {lib_name}")
# Parse and evaluate the file
content = parse_file(sexp_path)
# Evaluate all top-level expressions
if isinstance(content, list) and content and isinstance(content[0], list):
for e in content:
self.eval(e, env)
else:
self.eval(content, env)
return None
def _find_sexp_file(self, name: str) -> Optional[str]:
"""Find a .sexp file by name."""
# Try various locations
candidates = [
# Explicit path
name,
name + '.sexp',
# In sexp_effects directory
Path(__file__).parent / f'{name}.sexp',
Path(__file__).parent / name,
# In effects directory
Path(__file__).parent / 'effects' / f'{name}.sexp',
Path(__file__).parent / 'effects' / name,
]
for path in candidates:
p = Path(path) if not isinstance(path, Path) else path
if p.exists() and p.is_file():
return str(p)
return None
def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any: def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any:
""" """
Evaluate ascii-fx-zone special form. Evaluate ascii-fx-zone special form.
@@ -876,8 +945,8 @@ class Interpreter:
for pname, pdefault in effect.params.items(): for pname, pdefault in effect.params.items():
value = params.get(pname) value = params.get(pname)
if value is None: if value is None:
# Evaluate default if it's an expression (list) # Evaluate default if it's an expression (list) or a symbol (like 'nil')
if isinstance(pdefault, list): if isinstance(pdefault, list) or _is_symbol(pdefault):
value = self.eval(pdefault, env) value = self.eval(pdefault, env)
else: else:
value = pdefault value = pdefault

View File

@@ -71,7 +71,8 @@ class Tokenizer:
STRING = re.compile(r'"(?:[^"\\]|\\.)*"') STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?') NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*') KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*') # Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*')
def __init__(self, text: str): def __init__(self, text: str):
self.text = text self.text = text

View File

@@ -1,126 +1,680 @@
""" """
Drawing Primitives Library Drawing Primitives Library
Draw shapes, text, and characters on images. Draw shapes, text, and characters on images with sophisticated text handling.
Text Features:
- Font loading from files or system fonts
- Text measurement and fitting
- Alignment (left/center/right, top/middle/bottom)
- Opacity for fade effects
- Multi-line text support
- Shadow and outline effects
""" """
import numpy as np import numpy as np
import cv2 import cv2
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import os
import glob as glob_module
from typing import Optional, Tuple, List, Union
# Default font (will be loaded lazily) # =============================================================================
_default_font = None # Font Management
# =============================================================================
# Font cache: (path, size) -> font object
_font_cache = {}
# Common system font directories
FONT_DIRS = [
"/usr/share/fonts",
"/usr/local/share/fonts",
"~/.fonts",
"~/.local/share/fonts",
"/System/Library/Fonts", # macOS
"/Library/Fonts", # macOS
"C:/Windows/Fonts", # Windows
]
# Default fonts to try (in order of preference)
DEFAULT_FONTS = [
"DejaVuSans.ttf",
"DejaVuSansMono.ttf",
"Arial.ttf",
"Helvetica.ttf",
"FreeSans.ttf",
"LiberationSans-Regular.ttf",
]
def _get_default_font(size=16): def _find_font_file(name: str) -> Optional[str]:
"""Get default font, creating if needed.""" """Find a font file by name in system directories."""
global _default_font # If it's already a full path
if _default_font is None or _default_font.size != size: if os.path.isfile(name):
return name
# Expand user paths
expanded = os.path.expanduser(name)
if os.path.isfile(expanded):
return expanded
# Search in font directories
for font_dir in FONT_DIRS:
font_dir = os.path.expanduser(font_dir)
if not os.path.isdir(font_dir):
continue
# Direct match
direct = os.path.join(font_dir, name)
if os.path.isfile(direct):
return direct
# Recursive search
for root, dirs, files in os.walk(font_dir):
for f in files:
if f.lower() == name.lower():
return os.path.join(root, f)
# Also match without extension
base = os.path.splitext(f)[0]
if base.lower() == name.lower():
return os.path.join(root, f)
return None
def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont:
"""Get a default font at the given size."""
for font_name in DEFAULT_FONTS:
path = _find_font_file(font_name)
if path:
try: try:
_default_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size) return ImageFont.truetype(path, size)
except: except:
_default_font = ImageFont.load_default() continue
return _default_font
# Last resort: PIL default
return ImageFont.load_default()
def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont:
"""
Load a font by name or path.
(make-font "Arial" 32) ; system font by name
(make-font "/path/to/font.ttf" 24) ; font file path
(make-font "DejaVuSans" 48) ; searches common locations
Returns a font object for use with text primitives.
"""
size = int(size)
# Check cache
cache_key = (name_or_path, size)
if cache_key in _font_cache:
return _font_cache[cache_key]
# Find the font file
path = _find_font_file(name_or_path)
if not path:
raise FileNotFoundError(f"Font not found: {name_or_path}")
# Load and cache
font = ImageFont.truetype(path, size)
_font_cache[cache_key] = font
return font
def prim_list_fonts() -> List[str]:
"""
List available system fonts.
(list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...)
Returns list of font filenames found in system directories.
"""
fonts = set()
for font_dir in FONT_DIRS:
font_dir = os.path.expanduser(font_dir)
if not os.path.isdir(font_dir):
continue
for root, dirs, files in os.walk(font_dir):
for f in files:
if f.lower().endswith(('.ttf', '.otf', '.ttc')):
fonts.add(f)
return sorted(fonts)
def prim_font_size(font: ImageFont.FreeTypeFont) -> int:
"""
Get the size of a font.
(font-size my-font) ; -> 24
"""
return font.size
# =============================================================================
# Text Measurement
# =============================================================================
def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]:
"""
Measure text dimensions.
(text-size "Hello" my-font) ; -> (width height)
(text-size "Hello" :font-size 32) ; -> (width height) with default font
For multi-line text, returns total bounding box.
"""
if font is None:
font = _get_default_font(int(font_size))
elif isinstance(font, (int, float)):
font = _get_default_font(int(font))
# Create temporary image for measurement
img = Image.new('RGB', (1, 1))
draw = ImageDraw.Draw(img)
bbox = draw.textbbox((0, 0), str(text), font=font)
width = bbox[2] - bbox[0]
height = bbox[3] - bbox[1]
return (width, height)
def prim_text_metrics(font=None, font_size: int = 24) -> dict:
"""
Get font metrics.
(text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25}
Useful for precise text layout.
"""
if font is None:
font = _get_default_font(int(font_size))
elif isinstance(font, (int, float)):
font = _get_default_font(int(font))
ascent, descent = font.getmetrics()
return {
'ascent': ascent,
'descent': descent,
'height': ascent + descent,
'size': font.size,
}
def prim_fit_text_size(text: str, max_width: int, max_height: int,
font_name: str = None, min_size: int = 8,
max_size: int = 500) -> int:
"""
Calculate font size to fit text within bounds.
(fit-text-size "Hello World" 400 100) ; -> 48
(fit-text-size "Title" 800 200 :font-name "Arial")
Returns the largest font size that fits within max_width x max_height.
"""
max_width = int(max_width)
max_height = int(max_height)
min_size = int(min_size)
max_size = int(max_size)
text = str(text)
# Binary search for optimal size
best_size = min_size
low, high = min_size, max_size
while low <= high:
mid = (low + high) // 2
if font_name:
try:
font = prim_make_font(font_name, mid)
except:
font = _get_default_font(mid)
else:
font = _get_default_font(mid)
w, h = prim_text_size(text, font)
if w <= max_width and h <= max_height:
best_size = mid
low = mid + 1
else:
high = mid - 1
return best_size
def prim_fit_font(text: str, max_width: int, max_height: int,
font_name: str = None, min_size: int = 8,
max_size: int = 500) -> ImageFont.FreeTypeFont:
"""
Create a font sized to fit text within bounds.
(fit-font "Hello World" 400 100) ; -> font object
(fit-font "Title" 800 200 :font-name "Arial")
Returns a font object at the optimal size.
"""
size = prim_fit_text_size(text, max_width, max_height,
font_name, min_size, max_size)
if font_name:
try:
return prim_make_font(font_name, size)
except:
pass
return _get_default_font(size)
# =============================================================================
# Text Drawing
# =============================================================================
def prim_text(img: np.ndarray, text: str,
x: int = None, y: int = None,
width: int = None, height: int = None,
font=None, font_size: int = 24, font_name: str = None,
color=None, opacity: float = 1.0,
align: str = "left", valign: str = "top",
fit: bool = False,
shadow: bool = False, shadow_color=None, shadow_offset: int = 2,
outline: bool = False, outline_color=None, outline_width: int = 1,
line_spacing: float = 1.2) -> np.ndarray:
"""
Draw text with alignment, opacity, and effects.
Basic usage:
(text frame "Hello" :x 100 :y 50)
Centered in frame:
(text frame "Title" :align "center" :valign "middle")
Fit to box:
(text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true)
With fade (for animations):
(text frame "Fading" :x 100 :y 100 :opacity 0.5)
With effects:
(text frame "Shadow" :x 100 :y 100 :shadow true)
(text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0))
Args:
img: Input frame
text: Text to draw
x, y: Position (if not specified, uses alignment in full frame)
width, height: Bounding box (for fit and alignment within box)
font: Font object from make-font
font_size: Size if no font specified
font_name: Font name to load
color: RGB tuple (default white)
opacity: 0.0 (invisible) to 1.0 (opaque) for fading
align: "left", "center", "right"
valign: "top", "middle", "bottom"
fit: If true, auto-size font to fit in box
shadow: Draw drop shadow
shadow_color: Shadow color (default black)
shadow_offset: Shadow offset in pixels
outline: Draw text outline
outline_color: Outline color (default black)
outline_width: Outline thickness
line_spacing: Multiplier for line height (for multi-line)
Returns:
Frame with text drawn
"""
h, w = img.shape[:2]
text = str(text)
# Default colors
if color is None:
color = (255, 255, 255)
else:
color = tuple(int(c) for c in color)
if shadow_color is None:
shadow_color = (0, 0, 0)
else:
shadow_color = tuple(int(c) for c in shadow_color)
if outline_color is None:
outline_color = (0, 0, 0)
else:
outline_color = tuple(int(c) for c in outline_color)
# Determine bounding box
if x is None:
x = 0
if width is None:
width = w
if y is None:
y = 0
if height is None:
height = h
x, y = int(x), int(y)
box_width = int(width) if width else w - x
box_height = int(height) if height else h - y
# Get or create font
if font is None:
if fit:
font = prim_fit_font(text, box_width, box_height, font_name)
elif font_name:
try:
font = prim_make_font(font_name, int(font_size))
except:
font = _get_default_font(int(font_size))
else:
font = _get_default_font(int(font_size))
# Measure text
text_w, text_h = prim_text_size(text, font)
# Calculate position based on alignment
if align == "center":
draw_x = x + (box_width - text_w) // 2
elif align == "right":
draw_x = x + box_width - text_w
else: # left
draw_x = x
if valign == "middle":
draw_y = y + (box_height - text_h) // 2
elif valign == "bottom":
draw_y = y + box_height - text_h
else: # top
draw_y = y
# Create RGBA image for compositing with opacity
pil_img = Image.fromarray(img).convert('RGBA')
# Create text layer with transparency
text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
draw = ImageDraw.Draw(text_layer)
# Draw shadow first (if enabled)
if shadow:
shadow_x = draw_x + shadow_offset
shadow_y = draw_y + shadow_offset
shadow_rgba = shadow_color + (int(255 * opacity * 0.5),)
draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font)
# Draw outline (if enabled)
if outline:
outline_rgba = outline_color + (int(255 * opacity),)
ow = int(outline_width)
for dx in range(-ow, ow + 1):
for dy in range(-ow, ow + 1):
if dx != 0 or dy != 0:
draw.text((draw_x + dx, draw_y + dy), text,
fill=outline_rgba, font=font)
# Draw main text
text_rgba = color + (int(255 * opacity),)
draw.text((draw_x, draw_y), text, fill=text_rgba, font=font)
# Composite
result = Image.alpha_composite(pil_img, text_layer)
return np.array(result.convert('RGB'))
def prim_text_box(img: np.ndarray, text: str,
x: int, y: int, width: int, height: int,
font=None, font_size: int = 24, font_name: str = None,
color=None, opacity: float = 1.0,
align: str = "center", valign: str = "middle",
fit: bool = True,
padding: int = 0,
background=None, background_opacity: float = 0.5,
**kwargs) -> np.ndarray:
"""
Draw text fitted within a box, optionally with background.
(text-box frame "Title" 50 50 400 100)
(text-box frame "Subtitle" 50 160 400 50
:background (0 0 0) :background-opacity 0.7)
Convenience wrapper around text() for common box-with-text pattern.
"""
x, y = int(x), int(y)
width, height = int(width), int(height)
padding = int(padding)
result = img.copy()
# Draw background if specified
if background is not None:
bg_color = tuple(int(c) for c in background)
# Create background with opacity
pil_img = Image.fromarray(result).convert('RGBA')
bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
bg_draw = ImageDraw.Draw(bg_layer)
bg_rgba = bg_color + (int(255 * background_opacity),)
bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba)
result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB'))
# Draw text within padded box
return prim_text(result, text,
x=x + padding, y=y + padding,
width=width - 2 * padding, height=height - 2 * padding,
font=font, font_size=font_size, font_name=font_name,
color=color, opacity=opacity,
align=align, valign=valign, fit=fit,
**kwargs)
# =============================================================================
# Legacy text functions (keep for compatibility)
# =============================================================================
def prim_draw_char(img, char, x, y, font_size=16, color=None): def prim_draw_char(img, char, x, y, font_size=16, color=None):
"""Draw a single character at (x, y).""" """Draw a single character at (x, y). Legacy function."""
if color is None: return prim_text(img, str(char), x=int(x), y=int(y),
color = [255, 255, 255] font_size=int(font_size), color=color)
pil_img = Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
font = _get_default_font(font_size)
draw.text((x, y), char, fill=tuple(color), font=font)
return np.array(pil_img)
def prim_draw_text(img, text, x, y, font_size=16, color=None): def prim_draw_text(img, text, x, y, font_size=16, color=None):
"""Draw text string at (x, y).""" """Draw text string at (x, y). Legacy function."""
return prim_text(img, str(text), x=int(x), y=int(y),
font_size=int(font_size), color=color)
# =============================================================================
# Shape Drawing
# =============================================================================
def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0):
"""
Fill a rectangle with color.
(fill-rect frame 10 10 100 50 (255 0 0))
(fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5)
"""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
pil_img = Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
font = _get_default_font(font_size)
draw.text((x, y), text, fill=tuple(color), font=font)
return np.array(pil_img)
def prim_fill_rect(img, x, y, w, h, color=None):
"""Fill a rectangle with color."""
if color is None:
color = [255, 255, 255]
result = img.copy()
x, y, w, h = int(x), int(y), int(w), int(h) x, y, w, h = int(x), int(y), int(w), int(h)
if opacity >= 1.0:
result = img.copy()
result[y:y+h, x:x+w] = color result[y:y+h, x:x+w] = color
return result return result
# With opacity, use alpha compositing
pil_img = Image.fromarray(img).convert('RGBA')
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
draw.rectangle([x, y, x + w, y + h], fill=fill_rgba)
result = Image.alpha_composite(pil_img, layer)
return np.array(result.convert('RGB'))
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1):
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0):
"""Draw rectangle outline.""" """Draw rectangle outline."""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
if opacity >= 1.0:
result = img.copy() result = img.copy()
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)), cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)),
tuple(color), thickness) tuple(int(c) for c in color), int(thickness))
return result return result
# With opacity
pil_img = Image.fromarray(img).convert('RGBA')
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
draw.rectangle([int(x), int(y), int(x+w), int(y+h)],
outline=outline_rgba, width=int(thickness))
result = Image.alpha_composite(pil_img, layer)
return np.array(result.convert('RGB'))
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1):
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0):
"""Draw a line from (x1, y1) to (x2, y2).""" """Draw a line from (x1, y1) to (x2, y2)."""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
if opacity >= 1.0:
result = img.copy() result = img.copy()
cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)), cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)),
tuple(color), thickness) tuple(int(c) for c in color), int(thickness))
return result return result
# With opacity
pil_img = Image.fromarray(img).convert('RGBA')
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
draw.line([(int(x1), int(y1)), (int(x2), int(y2))],
fill=line_rgba, width=int(thickness))
result = Image.alpha_composite(pil_img, layer)
return np.array(result.convert('RGB'))
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, fill=False):
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1,
fill=False, opacity: float = 1.0):
"""Draw a circle.""" """Draw a circle."""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
if opacity >= 1.0:
result = img.copy() result = img.copy()
t = -1 if fill else thickness t = -1 if fill else int(thickness)
cv2.circle(result, (int(cx), int(cy)), int(radius), tuple(color), t) cv2.circle(result, (int(cx), int(cy)), int(radius),
tuple(int(c) for c in color), t)
return result return result
# With opacity
pil_img = Image.fromarray(img).convert('RGBA')
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
cx, cy, r = int(cx), int(cy), int(radius)
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, thickness=1, fill=False): if fill:
draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba)
else:
draw.ellipse([cx - r, cy - r, cx + r, cy + r],
outline=rgba, width=int(thickness))
result = Image.alpha_composite(pil_img, layer)
return np.array(result.convert('RGB'))
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None,
thickness=1, fill=False, opacity: float = 1.0):
"""Draw an ellipse.""" """Draw an ellipse."""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
if opacity >= 1.0:
result = img.copy() result = img.copy()
t = -1 if fill else thickness t = -1 if fill else int(thickness)
cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)), cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)),
angle, 0, 360, tuple(color), t) float(angle), 0, 360, tuple(int(c) for c in color), t)
return result return result
# With opacity (note: PIL doesn't support rotated ellipses easily)
# Fall back to cv2 on a separate layer
layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
t = -1 if fill else int(thickness)
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)),
float(angle), 0, 360, rgba, t)
def prim_draw_polygon(img, points, color=None, thickness=1, fill=False): pil_img = Image.fromarray(img).convert('RGBA')
pil_layer = Image.fromarray(layer)
result = Image.alpha_composite(pil_img, pil_layer)
return np.array(result.convert('RGB'))
def prim_draw_polygon(img, points, color=None, thickness=1,
fill=False, opacity: float = 1.0):
"""Draw a polygon from list of [x, y] points.""" """Draw a polygon from list of [x, y] points."""
if color is None: if color is None:
color = [255, 255, 255] color = [255, 255, 255]
if opacity >= 1.0:
result = img.copy() result = img.copy()
pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
if fill: if fill:
cv2.fillPoly(result, [pts], tuple(color)) cv2.fillPoly(result, [pts], tuple(int(c) for c in color))
else: else:
cv2.polylines(result, [pts], True, tuple(color), thickness) cv2.polylines(result, [pts], True,
tuple(int(c) for c in color), int(thickness))
return result return result
# With opacity
pil_img = Image.fromarray(img).convert('RGBA')
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
pts_flat = [(int(p[0]), int(p[1])) for p in points]
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
if fill:
draw.polygon(pts_flat, fill=rgba)
else:
draw.polygon(pts_flat, outline=rgba, width=int(thickness))
result = Image.alpha_composite(pil_img, layer)
return np.array(result.convert('RGB'))
# =============================================================================
# PRIMITIVES Export
# =============================================================================
PRIMITIVES = { PRIMITIVES = {
# Text # Font management
'make-font': prim_make_font,
'list-fonts': prim_list_fonts,
'font-size': prim_font_size,
# Text measurement
'text-size': prim_text_size,
'text-metrics': prim_text_metrics,
'fit-text-size': prim_fit_text_size,
'fit-font': prim_fit_font,
# Text drawing
'text': prim_text,
'text-box': prim_text_box,
# Legacy text (compatibility)
'draw-char': prim_draw_char, 'draw-char': prim_draw_char,
'draw-text': prim_draw_text, 'draw-text': prim_draw_text,

View File

@@ -8,12 +8,18 @@ GPU Acceleration:
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU) - Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
- Hardware video decoding (NVDEC) is used when available - Hardware video decoding (NVDEC) is used when available
- Dramatically improves performance on GPU nodes - Dramatically improves performance on GPU nodes
Async Prefetching:
- Set STREAMING_PREFETCH=1 to enable background frame prefetching
- Decodes upcoming frames while current frame is being processed
""" """
import os import os
import numpy as np import numpy as np
import subprocess import subprocess
import json import json
import threading
from collections import deque
from pathlib import Path from pathlib import Path
# Try to import CuPy for GPU acceleration # Try to import CuPy for GPU acceleration
@@ -28,6 +34,10 @@ except ImportError:
# Disabled by default until all primitives support GPU frames # Disabled by default until all primitives support GPU frames
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
# Async prefetch mode - decode frames in background thread
PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1"
PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10"))
# Check for hardware decode support (cached) # Check for hardware decode support (cached)
_HWDEC_AVAILABLE = None _HWDEC_AVAILABLE = None
@@ -283,6 +293,122 @@ class VideoSource:
self._proc = None self._proc = None
class PrefetchingVideoSource:
"""
Video source with background prefetching for improved performance.
Wraps VideoSource and adds a background thread that pre-decodes
upcoming frames while the main thread processes the current frame.
"""
def __init__(self, path: str, fps: float = 30, buffer_size: int = None):
self._source = VideoSource(path, fps)
self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE
self._buffer = {} # time -> frame
self._buffer_lock = threading.Lock()
self._prefetch_time = 0.0
self._frame_time = 1.0 / fps
self._stop_event = threading.Event()
self._request_event = threading.Event()
self._target_time = 0.0
# Start prefetch thread
self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
self._thread.start()
import sys
print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr)
def _prefetch_loop(self):
"""Background thread that pre-reads frames."""
while not self._stop_event.is_set():
# Wait for work or timeout
self._request_event.wait(timeout=0.01)
self._request_event.clear()
if self._stop_event.is_set():
break
# Prefetch frames ahead of target time
target = self._target_time
with self._buffer_lock:
# Clean old frames (more than 1 second behind)
old_times = [t for t in self._buffer.keys() if t < target - 1.0]
for t in old_times:
del self._buffer[t]
# Count how many frames we have buffered ahead
buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target)
# Prefetch if buffer not full
if buffered_ahead < self._buffer_size:
# Find next time to prefetch
prefetch_t = target
with self._buffer_lock:
existing_times = set(self._buffer.keys())
for _ in range(self._buffer_size):
if prefetch_t not in existing_times:
break
prefetch_t += self._frame_time
# Read the frame (this is the slow part)
try:
frame = self._source.read_at(prefetch_t)
with self._buffer_lock:
self._buffer[prefetch_t] = frame
except Exception as e:
import sys
print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr)
def read_at(self, t: float) -> np.ndarray:
"""Read frame at specific time, using prefetch buffer if available."""
self._target_time = t
self._request_event.set() # Wake up prefetch thread
# Round to frame time for buffer lookup
t_key = round(t / self._frame_time) * self._frame_time
# Check buffer first
with self._buffer_lock:
if t_key in self._buffer:
return self._buffer[t_key]
# Also check for close matches (within half frame time)
for buf_t, frame in self._buffer.items():
if abs(buf_t - t) < self._frame_time * 0.5:
return frame
# Not in buffer - read directly (blocking)
frame = self._source.read_at(t)
# Store in buffer
with self._buffer_lock:
self._buffer[t_key] = frame
return frame
def read(self) -> np.ndarray:
"""Read frame (uses last cached or t=0)."""
return self.read_at(0)
def skip(self):
"""No-op for seek-based reading."""
pass
@property
def size(self):
return self._source.size
@property
def path(self):
return self._source.path
def close(self):
self._stop_event.set()
self._request_event.set() # Wake up thread to exit
self._thread.join(timeout=1.0)
self._source.close()
class AudioAnalyzer: class AudioAnalyzer:
"""Audio analyzer for energy and beat detection.""" """Audio analyzer for energy and beat detection."""
@@ -394,7 +520,12 @@ class AudioAnalyzer:
# === Primitives === # === Primitives ===
def prim_make_video_source(path: str, fps: float = 30): def prim_make_video_source(path: str, fps: float = 30):
"""Create a video source from a file path.""" """Create a video source from a file path.
Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default).
"""
if PREFETCH_ENABLED:
return PrefetchingVideoSource(path, fps)
return VideoSource(path, fps) return VideoSource(path, fps)

File diff suppressed because it is too large Load Diff

View File

@@ -797,31 +797,63 @@ def prim_tan(x: float) -> float:
return math.tan(x) return math.tan(x)
def prim_atan2(y: float, x: float) -> float: def prim_atan2(y, x):
if hasattr(y, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape)
return math.atan2(y, x) return math.atan2(y, x)
def prim_sqrt(x: float) -> float: def prim_sqrt(x):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.sqrt(np.maximum(0, x._data)), x._shape)
if isinstance(x, np.ndarray):
return np.sqrt(np.maximum(0, x))
return math.sqrt(max(0, x)) return math.sqrt(max(0, x))
def prim_pow(x: float, y: float) -> float: def prim_pow(x, y):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
y_data = y._data if hasattr(y, '_data') else y
return Xector(np.power(x._data, y_data), x._shape)
return math.pow(x, y) return math.pow(x, y)
def prim_abs(x: float) -> float: def prim_abs(x):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.abs(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.abs(x)
return abs(x) return abs(x)
def prim_floor(x: float) -> int: def prim_floor(x):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.floor(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.floor(x)
return int(math.floor(x)) return int(math.floor(x))
def prim_ceil(x: float) -> int: def prim_ceil(x):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.ceil(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.ceil(x)
return int(math.ceil(x)) return int(math.ceil(x))
def prim_round(x: float) -> int: def prim_round(x):
if hasattr(x, '_data'): # Xector
from sexp_effects.primitive_libs.xector import Xector
return Xector(np.round(x._data), x._shape)
if isinstance(x, np.ndarray):
return np.round(x)
return int(round(x)) return int(round(x))

860
streaming/jax_typography.py Normal file
View File

@@ -0,0 +1,860 @@
"""
JAX Typography Primitives
Two approaches for text rendering, both compile to JAX/GPU:
## 1. TextStrip - Pixel-perfect static text
Pre-render entire strings at compile time using PIL.
Perfect sub-pixel anti-aliasing, exact match with PIL.
Use for: static titles, labels, any text without per-character effects.
S-expression:
(let ((strip (render-text-strip "Hello World" 48)))
(place-text-strip frame strip x y :color white))
## 2. Glyph-by-glyph - Dynamic text effects
Individual glyph placement for wave, arc, audio-reactive effects.
Each character can have independent position, color, opacity.
Note: slight anti-aliasing differences vs PIL due to integer positioning.
S-expression:
; Wave text - y oscillates with character index
(let ((glyphs (text-glyphs "Wavy" 48)))
(first
(fold glyphs (list frame 0)
(lambda (acc g)
(let ((frm (first acc))
(cursor (second acc))
(i (length acc))) ; approximate index
(list
(place-glyph frm (glyph-image g)
(+ x cursor)
(+ y (* amplitude (sin (* i frequency))))
(glyph-bearing-x g) (glyph-bearing-y g)
white 1.0)
(+ cursor (glyph-advance g))))))))
; Audio-reactive spacing
(let ((glyphs (text-glyphs "Bass" 48))
(bass (audio-band 0 200)))
(first
(fold glyphs (list frame 0)
(lambda (acc g)
(let ((frm (first acc))
(cursor (second acc)))
(list
(place-glyph frm (glyph-image g)
(+ x cursor) y
(glyph-bearing-x g) (glyph-bearing-y g)
white 1.0)
(+ cursor (glyph-advance g) (* bass 20))))))))
Kerning support:
; With kerning adjustment
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
"""
import numpy as np
import jax.numpy as jnp
from typing import Tuple, Dict, Any
from dataclasses import dataclass
# =============================================================================
# Glyph Data (computed at compile time)
# =============================================================================
@dataclass
class GlyphData:
"""Glyph data computed at compile time.
Attributes:
char: The character
image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
advance: Horizontal advance (distance to next glyph origin)
bearing_x: Left side bearing (x offset from origin to first pixel)
bearing_y: Top bearing (y offset from baseline to top of glyph)
width: Image width
height: Image height
"""
char: str
image: np.ndarray # (H, W, 4) RGBA uint8
advance: float
bearing_x: float
bearing_y: float
width: int
height: int
# Font cache: (font_name, font_size) -> {char: GlyphData}
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}
# Font metrics cache: (font_name, font_size) -> (ascent, descent)
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}
# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
# Kerning adjustment is added to advance: new_advance = advance + kerning
# Typically negative (characters move closer together)
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}
def _load_font(font_name: str = None, font_size: int = 32):
"""Load a font. Called at compile time."""
from PIL import ImageFont
candidates = [
font_name,
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
'/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
'/usr/share/fonts/truetype/freefont/FreeSans.ttf',
]
for path in candidates:
if path is None:
continue
try:
return ImageFont.truetype(path, font_size)
except (IOError, OSError):
continue
return ImageFont.load_default()
def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
"""Get or create glyph cache for a font. Called at compile time."""
cache_key = (font_name, font_size)
if cache_key in _GLYPH_CACHE:
return _GLYPH_CACHE[cache_key]
from PIL import Image, ImageDraw
font = _load_font(font_name, font_size)
ascent, descent = font.getmetrics()
_METRICS_CACHE[cache_key] = (ascent, descent)
glyphs = {}
charset = ''.join(chr(i) for i in range(32, 127))
temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0))
temp_draw = ImageDraw.Draw(temp_img)
for char in charset:
# Get metrics
bbox = temp_draw.textbbox((0, 0), char, font=font)
advance = font.getlength(char)
x_min, y_min, x_max, y_max = bbox
# Create glyph image with padding
padding = 2
img_w = max(int(x_max - x_min) + padding * 2, 1)
img_h = max(int(y_max - y_min) + padding * 2, 1)
glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
glyph_draw = ImageDraw.Draw(glyph_img)
# Draw at position accounting for bbox offset
draw_x = padding - x_min
draw_y = padding - y_min
glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)
glyphs[char] = GlyphData(
char=char,
image=np.array(glyph_img, dtype=np.uint8),
advance=float(advance),
bearing_x=float(x_min),
bearing_y=float(-y_min), # Distance from baseline to top
width=img_w,
height=img_h,
)
_GLYPH_CACHE[cache_key] = glyphs
return glyphs
def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
"""Get or create kerning cache for a font. Called at compile time.
Kerning is computed as:
kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)
This gives the adjustment needed when placing 'b' after 'a'.
Typically negative (characters move closer together).
"""
cache_key = (font_name, font_size)
if cache_key in _KERNING_CACHE:
return _KERNING_CACHE[cache_key]
font = _load_font(font_name, font_size)
kerning = {}
# Compute kerning for all printable ASCII pairs
charset = ''.join(chr(i) for i in range(32, 127))
# Pre-compute individual character lengths
char_lengths = {c: font.getlength(c) for c in charset}
# Compute kerning for each pair
for c1 in charset:
for c2 in charset:
pair_length = font.getlength(c1 + c2)
individual_sum = char_lengths[c1] + char_lengths[c2]
kern = pair_length - individual_sum
# Only store non-zero kerning to save memory
if abs(kern) > 0.01:
kerning[(c1, c2)] = kern
_KERNING_CACHE[cache_key] = kerning
return kerning
def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
"""Get kerning adjustment between two characters. Compile-time.
Returns the adjustment to add to char1's advance when char2 follows.
Typically negative (characters move closer).
Usage in S-expression:
(+ (glyph-advance g1) (kerning g1 g2))
"""
kerning_cache = _get_kerning_cache(font_name, font_size)
return kerning_cache.get((char1, char2), 0.0)
@dataclass
class TextStrip:
"""Pre-rendered text strip with proper sub-pixel anti-aliasing.
Rendered at compile time using PIL for exact matching.
At runtime, just composite onto frame at integer positions.
Attributes:
text: The original text
image: RGBA image as numpy array (H, W, 4)
width: Strip width
height: Strip height
baseline_y: Y position of baseline within the strip
bearing_x: Left side bearing of first character
anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right)
anchor_y: Y offset for anchor point (depends on anchor type)
stroke_width: Stroke width used when rendering
"""
text: str
image: np.ndarray
width: int
height: int
baseline_y: int
bearing_x: float
anchor_x: float = 0.0
anchor_y: float = 0.0
stroke_width: int = 0
# Text strip cache: cache_key -> TextStrip
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}
def render_text_strip(
text: str,
font_name: str = None,
font_size: int = 32,
stroke_width: int = 0,
stroke_fill: tuple = None,
anchor: str = "la", # left-ascender (PIL default is "la")
multiline: bool = False,
line_spacing: int = 4,
align: str = "left",
) -> TextStrip:
"""Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.
Args:
text: Text to render
font_name: Path to font file (None for default)
font_size: Font size in pixels
stroke_width: Outline width in pixels (0 for no outline)
stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black
anchor: PIL anchor code - first char: h=left, m=middle, r=right
second char: a=ascender, t=top, m=middle, s=baseline, d=descender
multiline: If True, handle newlines in text
line_spacing: Extra pixels between lines (for multiline)
align: 'left', 'center', 'right' (for multiline)
Returns:
TextStrip with pre-rendered text
"""
# Build cache key from all parameters
cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align)
if cache_key in _TEXT_STRIP_CACHE:
return _TEXT_STRIP_CACHE[cache_key]
from PIL import Image, ImageDraw
font = _load_font(font_name, font_size)
ascent, descent = font.getmetrics()
# Default stroke fill to black
if stroke_fill is None:
stroke_fill = (0, 0, 0, 255)
elif len(stroke_fill) == 3:
stroke_fill = (*stroke_fill, 255)
# Get text bbox (accounting for stroke)
temp = Image.new('RGBA', (1, 1))
temp_draw = ImageDraw.Draw(temp)
if multiline:
bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing,
stroke_width=stroke_width)
else:
bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width)
# bbox is (left, top, right, bottom) relative to origin
x_min, y_min, x_max, y_max = bbox
# Create image with padding (extra for stroke)
padding = 2 + stroke_width
img_width = max(int(x_max - x_min) + padding * 2, 1)
img_height = max(int(y_max - y_min) + padding * 2, 1)
# Create RGBA image
img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
# Draw text at position that puts it in the image
draw_x = padding - x_min
draw_y = padding - y_min
if multiline:
draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
spacing=line_spacing, align=align,
stroke_width=stroke_width, stroke_fill=stroke_fill)
else:
draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
stroke_width=stroke_width, stroke_fill=stroke_fill)
# Baseline is at y=0 in text coordinates, which is at draw_y in image
baseline_y = draw_y
# Convert to numpy for pixel analysis
img_array = np.array(img, dtype=np.uint8)
# Calculate anchor offsets
# For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching
h_anchor = anchor[0] if len(anchor) > 0 else 'l'
v_anchor = anchor[1] if len(anchor) > 1 else 'a'
# Find actual pixel bounds (for middle anchor calculations)
alpha = img_array[:, :, 3]
nonzero_cols = np.where(alpha.max(axis=0) > 0)[0]
nonzero_rows = np.where(alpha.max(axis=1) > 0)[0]
if len(nonzero_cols) > 0:
pixel_x_min = nonzero_cols.min()
pixel_x_max = nonzero_cols.max()
pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0
else:
pixel_x_center = img_width / 2.0
if len(nonzero_rows) > 0:
pixel_y_min = nonzero_rows.min()
pixel_y_max = nonzero_rows.max()
pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0
else:
pixel_y_center = img_height / 2.0
# Horizontal offset
text_width = x_max - x_min
if h_anchor == 'l': # left edge of text
anchor_x = float(draw_x)
elif h_anchor == 'm': # middle - use actual pixel center for perfect matching
anchor_x = pixel_x_center
elif h_anchor == 'r': # right edge of text
anchor_x = float(draw_x + text_width)
else:
anchor_x = float(draw_x)
# Vertical offset
# PIL anchor positions are based on font metrics (ascent/descent):
# - 'a' (ascender): at the ascender line → draw_y in strip
# - 't' (top): at top of text bounding box → padding in strip
# - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender
# - 's' (baseline): at baseline = ascent below ascender
# - 'd' (descender): at descender line = ascent + descent below ascender
if v_anchor == 'a': # ascender
anchor_y = float(draw_y)
elif v_anchor == 't': # top of bbox
anchor_y = float(padding)
elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation)
anchor_y = float(draw_y + (ascent + descent) / 2.0)
elif v_anchor == 's': # baseline
anchor_y = float(draw_y + ascent)
elif v_anchor == 'd': # descender
anchor_y = float(draw_y + ascent + descent)
else:
anchor_y = float(draw_y) # default to ascender
strip = TextStrip(
text=text,
image=img_array,
width=img_width,
height=img_height,
baseline_y=baseline_y,
bearing_x=float(x_min),
anchor_x=anchor_x,
anchor_y=anchor_y,
stroke_width=stroke_width,
)
_TEXT_STRIP_CACHE[cache_key] = strip
return strip
# =============================================================================
# Compile-time functions (called during S-expression compilation)
# =============================================================================
def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
"""Get glyph data for a single character. Compile-time."""
cache = _get_glyph_cache(font_name, font_size)
return cache.get(char, cache.get(' '))
def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
"""Get glyph data for a string. Compile-time."""
cache = _get_glyph_cache(font_name, font_size)
space = cache.get(' ')
return [cache.get(c, space) for c in text]
def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
"""Get font ascent. Compile-time."""
_get_glyph_cache(font_name, font_size) # Ensure cache exists
return _METRICS_CACHE[(font_name, font_size)][0]
def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
"""Get font descent. Compile-time."""
_get_glyph_cache(font_name, font_size)
return _METRICS_CACHE[(font_name, font_size)][1]
# =============================================================================
# JAX Runtime Primitives
# =============================================================================
def place_glyph_jax(
frame: jnp.ndarray,
glyph_image: jnp.ndarray, # (H, W, 4) RGBA
x: float,
y: float,
bearing_x: float,
bearing_y: float,
color: jnp.ndarray, # (3,) RGB 0-255
opacity: float = 1.0,
) -> jnp.ndarray:
"""
Place a glyph onto a frame. This is the core JAX primitive.
All positioning math can use traced values (x, y from audio, time, etc.)
The glyph_image is static (determined at compile time).
Args:
frame: (H, W, 3) RGB frame
glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array)
x: X position of glyph origin (baseline point)
y: Y position of baseline
bearing_x: Left side bearing
bearing_y: Top bearing (from baseline to top)
color: RGB color array
opacity: Opacity 0-1
Returns:
Frame with glyph composited
"""
h, w = frame.shape[:2]
gh, gw = glyph_image.shape[:2]
# Calculate destination position
# bearing_x: how far right of origin the glyph starts (can be negative)
# bearing_y: how far up from baseline the glyph extends
padding = 2 # Must match padding used in glyph creation
dst_x = x + bearing_x - padding
dst_y = y - bearing_y - padding
# Extract glyph RGB and alpha
glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
opacity_int = jnp.round(opacity * 255)
glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32)
glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0
# Apply color tint (glyph is white, multiply by color)
color_normalized = color.astype(jnp.float32) / 255.0
tinted = glyph_rgb * color_normalized
from jax.lax import dynamic_update_slice
# Use padded buffer to avoid XLA's dynamic_update_slice clamping
buf_h = h + 2 * gh
buf_w = w + 2 * gw
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
dst_x_int = dst_x.astype(jnp.int32)
dst_y_int = dst_y.astype(jnp.int32)
place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32)
place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32)
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0))
rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :]
alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :]
# Alpha composite using PIL-compatible integer arithmetic
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
dst_int = frame.astype(jnp.int32)
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
def place_text_strip_jax(
frame: jnp.ndarray,
strip_image: jnp.ndarray, # (H, W, 4) RGBA
x: float,
y: float,
baseline_y: int,
bearing_x: float,
color: jnp.ndarray,
opacity: float = 1.0,
anchor_x: float = 0.0,
anchor_y: float = 0.0,
stroke_width: int = 0,
) -> jnp.ndarray:
"""
Place a pre-rendered text strip onto a frame.
The strip was rendered at compile time with proper sub-pixel anti-aliasing.
This just composites it at the specified position.
Args:
frame: (H, W, 3) RGB frame
strip_image: (sh, sw, 4) RGBA text strip
x: X position for anchor point
y: Y position for anchor point
baseline_y: Y position of baseline within the strip
bearing_x: Left side bearing
color: RGB color
opacity: Opacity 0-1
anchor_x: X offset of anchor point within strip
anchor_y: Y offset of anchor point within strip
stroke_width: Stroke width used when rendering (affects padding)
Returns:
Frame with text composited
"""
h, w = frame.shape[:2]
sh, sw = strip_image.shape[:2]
# Calculate destination position
# Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame
# anchor_x/anchor_y already account for the anchor position within the strip
# Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding)
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
# Extract strip RGB and alpha
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
# Use jnp.round (banker's rounding) to match Python's round() used by PIL
opacity_int = jnp.round(opacity * 255)
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
# Apply color tint
color_normalized = color.astype(jnp.float32) / 255.0
tinted = strip_rgb * color_normalized
from jax.lax import dynamic_update_slice
# Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior.
# XLA clamps indices so the update fits, which silently shifts the strip.
# By placing into a buffer padded by strip dimensions, then extracting the
# frame-sized region, we get correct clipping for both overflow and underflow.
buf_h = h + 2 * sh
buf_w = w + 2 * sw
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
# Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
# Extract frame-sized region (sh, sw are compile-time constants from strip shape)
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
# Alpha composite using PIL-compatible integer arithmetic:
# result = (src * alpha + dst * (255 - alpha) + 127) // 255
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
dst_int = frame.astype(jnp.int32)
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
return jnp.clip(result, 0, 255).astype(jnp.uint8)
def place_glyph_simple(
frame: jnp.ndarray,
glyph: GlyphData,
x: float,
y: float,
color: tuple = (255, 255, 255),
opacity: float = 1.0,
) -> jnp.ndarray:
"""
Convenience wrapper that takes GlyphData directly.
Converts glyph image to JAX array.
For S-expression use, prefer place_glyph_jax with pre-converted arrays.
"""
glyph_jax = jnp.asarray(glyph.image)
color_jax = jnp.array(color, dtype=jnp.float32)
return place_glyph_jax(
frame, glyph_jax, x, y,
glyph.bearing_x, glyph.bearing_y,
color_jax, opacity
)
# =============================================================================
# S-Expression Primitive Bindings
# =============================================================================
def bind_typography_primitives(env: dict) -> dict:
"""
Add typography primitives to an S-expression environment.
Primitives added:
(text-glyphs text font-size) -> list of glyph data
(glyph-image g) -> JAX array (H, W, 4)
(glyph-advance g) -> float
(glyph-bearing-x g) -> float
(glyph-bearing-y g) -> float
(glyph-width g) -> int
(glyph-height g) -> int
(font-ascent font-size) -> float
(font-descent font-size) -> float
(place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame
"""
def prim_text_glyphs(text, font_size=32, font_name=None):
"""Get list of glyph data for text. Compile-time."""
return get_glyphs(str(text), font_name, int(font_size))
def prim_glyph_image(glyph):
"""Get glyph image as JAX array."""
return jnp.asarray(glyph.image)
def prim_glyph_advance(glyph):
"""Get glyph advance width."""
return glyph.advance
def prim_glyph_bearing_x(glyph):
"""Get glyph left side bearing."""
return glyph.bearing_x
def prim_glyph_bearing_y(glyph):
"""Get glyph top bearing."""
return glyph.bearing_y
def prim_glyph_width(glyph):
"""Get glyph image width."""
return glyph.width
def prim_glyph_height(glyph):
"""Get glyph image height."""
return glyph.height
def prim_font_ascent(font_size=32, font_name=None):
"""Get font ascent."""
return get_font_ascent(font_name, int(font_size))
def prim_font_descent(font_size=32, font_name=None):
"""Get font descent."""
return get_font_descent(font_name, int(font_size))
def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
color=(255, 255, 255), opacity=1.0):
"""Place glyph on frame. Runtime JAX operation."""
color_arr = jnp.array(color, dtype=jnp.float32)
return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y,
color_arr, opacity)
def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
"""Get kerning adjustment between two glyphs. Compile-time.
Returns adjustment to add to glyph1's advance when glyph2 follows.
Typically negative (characters move closer).
Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
"""
return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))
def prim_char_kerning(char1, char2, font_size=32, font_name=None):
"""Get kerning adjustment between two characters. Compile-time."""
return get_kerning(str(char1), str(char2), font_name, int(font_size))
# TextStrip primitives for pre-rendered text with proper anti-aliasing
def prim_render_text_strip(text, font_size=32, font_name=None):
"""Render text to a strip at compile time. Perfect anti-aliasing."""
return render_text_strip(str(text), font_name, int(font_size))
def prim_render_text_strip_styled(
text, font_size=32, font_name=None,
stroke_width=0, stroke_fill=None,
anchor="la", multiline=False, line_spacing=4, align="left"
):
"""Render styled text to a strip. Supports stroke, anchors, multiline.
Args:
text: Text to render
font_size: Size in pixels
font_name: Path to font file
stroke_width: Outline width (0 = no outline)
stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
multiline: If True, handle newlines
line_spacing: Extra pixels between lines
align: "left", "center", "right" for multiline
"""
return render_text_strip(
str(text), font_name, int(font_size),
stroke_width=int(stroke_width),
stroke_fill=stroke_fill,
anchor=str(anchor),
multiline=bool(multiline),
line_spacing=int(line_spacing),
align=str(align),
)
def prim_text_strip_image(strip):
"""Get text strip image as JAX array."""
return jnp.asarray(strip.image)
def prim_text_strip_width(strip):
"""Get text strip width."""
return strip.width
def prim_text_strip_height(strip):
"""Get text strip height."""
return strip.height
def prim_text_strip_baseline_y(strip):
"""Get text strip baseline Y position."""
return strip.baseline_y
def prim_text_strip_bearing_x(strip):
"""Get text strip left bearing."""
return strip.bearing_x
def prim_text_strip_anchor_x(strip):
"""Get text strip anchor X offset."""
return strip.anchor_x
def prim_text_strip_anchor_y(strip):
"""Get text strip anchor Y offset."""
return strip.anchor_y
def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
"""Place pre-rendered text strip on frame. Runtime JAX operation."""
strip_img = jnp.asarray(strip.image)
color_arr = jnp.array(color, dtype=jnp.float32)
return place_text_strip_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
color_arr, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width
)
# Add to environment
env.update({
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
'text-glyphs': prim_text_glyphs,
'glyph-image': prim_glyph_image,
'glyph-advance': prim_glyph_advance,
'glyph-bearing-x': prim_glyph_bearing_x,
'glyph-bearing-y': prim_glyph_bearing_y,
'glyph-width': prim_glyph_width,
'glyph-height': prim_glyph_height,
'glyph-kerning': prim_glyph_kerning,
'char-kerning': prim_char_kerning,
'font-ascent': prim_font_ascent,
'font-descent': prim_font_descent,
'place-glyph': prim_place_glyph,
# TextStrip primitives (for pixel-perfect static text)
'render-text-strip': prim_render_text_strip,
'render-text-strip-styled': prim_render_text_strip_styled,
'text-strip-image': prim_text_strip_image,
'text-strip-width': prim_text_strip_width,
'text-strip-height': prim_text_strip_height,
'text-strip-baseline-y': prim_text_strip_baseline_y,
'text-strip-bearing-x': prim_text_strip_bearing_x,
'text-strip-anchor-x': prim_text_strip_anchor_x,
'text-strip-anchor-y': prim_text_strip_anchor_y,
'place-text-strip': prim_place_text_strip,
})
return env
# =============================================================================
# Example: Render text using primitives (for testing)
# =============================================================================
def render_text_primitives(
frame: jnp.ndarray,
text: str,
x: float,
y: float,
font_size: int = 32,
color: tuple = (255, 255, 255),
opacity: float = 1.0,
use_kerning: bool = True,
) -> jnp.ndarray:
"""
Render text using the primitives.
This is what an S-expression would compile to.
Args:
use_kerning: If True, apply kerning adjustments between characters
"""
glyphs = get_glyphs(text, None, font_size)
color_arr = jnp.array(color, dtype=jnp.float32)
cursor = x
for i, g in enumerate(glyphs):
glyph_jax = jnp.asarray(g.image)
frame = place_glyph_jax(
frame, glyph_jax, cursor, y,
g.bearing_x, g.bearing_y,
color_arr, opacity
)
# Advance cursor with optional kerning
advance = g.advance
if use_kerning and i + 1 < len(glyphs):
advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size)
cursor = cursor + advance
return frame

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,7 @@ Context (ctx) is passed explicitly to frame evaluation:
""" """
import sys import sys
import os
import time import time
import json import json
import hashlib import hashlib
@@ -62,6 +63,38 @@ class Context:
fps: float = 30.0 fps: float = 30.0
class DeferredEffectChain:
"""
Represents a chain of JAX effects that haven't been executed yet.
Allows effects to be accumulated through let bindings and fused
into a single JIT-compiled function when the result is needed.
"""
__slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')
def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
self.effects = effects # List of effect names, innermost first
self.params_list = params_list # List of param dicts, matching effects
self.base_frame = base_frame # The actual frame array at the start
self.t = t
self.frame_num = frame_num
def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
"""Add another effect to the chain (outermost)."""
return DeferredEffectChain(
self.effects + [effect_name],
self.params_list + [params],
self.base_frame,
self.t,
self.frame_num
)
@property
def shape(self):
"""Allow shape check without forcing execution."""
return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None
class StreamInterpreter: class StreamInterpreter:
""" """
Fully generic streaming sexp interpreter. Fully generic streaming sexp interpreter.
@@ -98,6 +131,9 @@ class StreamInterpreter:
self.use_jax = use_jax self.use_jax = use_jax
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains
self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains
self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env
if use_jax: if use_jax:
if _init_jax(): if _init_jax():
print("JAX acceleration enabled", file=sys.stderr) print("JAX acceleration enabled", file=sys.stderr)
@@ -238,6 +274,8 @@ class StreamInterpreter:
"""Load primitives from a Python library file. """Load primitives from a Python library file.
Prefers GPU-accelerated versions (*_gpu.py) when available. Prefers GPU-accelerated versions (*_gpu.py) when available.
Uses cached modules from sys.modules to ensure consistent state
(e.g., same RNG instance for all interpreters).
""" """
import importlib.util import importlib.util
@@ -264,9 +302,26 @@ class StreamInterpreter:
if not lib_path: if not lib_path:
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}") raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
# Use cached module if already imported to preserve state (e.g., RNG)
# This is critical for deterministic random number sequences
# Check multiple possible module keys (standard import paths and our cache)
possible_keys = [
f"sexp_effects.primitive_libs.{actual_lib_name}",
f"sexp_primitives.{actual_lib_name}",
]
module = None
for key in possible_keys:
if key in sys.modules:
module = sys.modules[key]
break
if module is None:
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path) spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
# Cache for future use under our key
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
# Check if this is a GPU-accelerated module # Check if this is a GPU-accelerated module
is_gpu = actual_lib_name.endswith('_gpu') is_gpu = actual_lib_name.endswith('_gpu')
@@ -452,30 +507,353 @@ class StreamInterpreter:
try: try:
jax_fn = self.jax_effects[name] jax_fn = self.jax_effects[name]
# Ensure frame is numpy array # Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
# JAX handles numpy and JAX arrays natively, no conversion needed
if hasattr(frame, 'cpu'): if hasattr(frame, 'cpu'):
frame = frame.cpu frame = frame.cpu
elif hasattr(frame, 'get'): elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get() frame = frame.get() # CuPy array -> numpy
# Get seed from config for deterministic random # Get seed from config for deterministic random
seed = self.config.get('seed', 42) seed = self.config.get('seed', 42)
# Call JAX function with parameters # Call JAX function with parameters
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params) # Return JAX array directly - don't block or convert per-effect
# Conversion to numpy happens once at frame write time
# Convert result back to numpy if needed return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
if hasattr(result, 'block_until_ready'):
result.block_until_ready() # Ensure computation is complete
if hasattr(result, '__array__'):
result = np.asarray(result)
return result
except Exception as e: except Exception as e:
# Fall back to interpreter on error # Fall back to interpreter on error
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr) print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
return None return None
def _is_jax_effect_expr(self, expr) -> bool:
"""Check if an expression is a JAX-compiled effect call."""
if not isinstance(expr, list) or not expr:
return False
head = expr[0]
if not isinstance(head, Symbol):
return False
return head.name in self.jax_effects
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
"""
Extract a chain of JAX effects from an expression.
Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
effect_names and params_list are in execution order (innermost first).
For (effect1 (effect2 frame :p1 v1) :p2 v2):
Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
"""
if not self._is_jax_effect_expr(expr):
return None
chain = []
params_list = []
current = expr
while self._is_jax_effect_expr(current):
head = current[0]
effect_name = head.name
args = current[1:]
# Extract params for this effect
effect = self.effects[effect_name]
effect_params = {}
for pname, pdef in effect['params'].items():
effect_params[pname] = pdef.get('default', 0)
# Find the frame argument (first positional) and other params
frame_arg = None
i = 0
while i < len(args):
if isinstance(args[i], Keyword):
pname = args[i].name
if pname in effect['params'] and i + 1 < len(args):
effect_params[pname] = self._eval(args[i + 1], env)
i += 2
else:
if frame_arg is None:
frame_arg = args[i] # First positional is frame
i += 1
chain.append(effect_name)
params_list.append(effect_params)
if frame_arg is None:
return None # No frame argument found
# Check if frame_arg is another effect call
if self._is_jax_effect_expr(frame_arg):
current = frame_arg
else:
# End of chain - frame_arg is the base frame
# Reverse to get innermost-first execution order
chain.reverse()
params_list.reverse()
return (chain, params_list, frame_arg)
return None
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
"""Generate a cache key for an effect chain.
Includes static param values in the key since they affect compilation.
"""
parts = []
for name, params in zip(effect_names, params_list):
param_parts = []
for pname in sorted(params.keys()):
pval = params[pname]
# Include static values in key (strings, bools)
if isinstance(pval, (str, bool)):
param_parts.append(f"{pname}={pval}")
else:
param_parts.append(pname)
parts.append(f"{name}:{','.join(param_parts)}")
return '|'.join(parts)
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
"""
Compile a chain of effects into a single fused JAX function.
Args:
effect_names: List of effect names in order [innermost, ..., outermost]
params_list: List of param dicts for each effect (used to detect static types)
Returns:
JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame
"""
if not _JAX_AVAILABLE:
return None
try:
import jax
# Get the individual JAX functions
jax_fns = [self.jax_effects[name] for name in effect_names]
# Pre-extract param names and identify static params from actual values
effect_param_names = []
static_params = ['seed'] # seed is always static
for i, (name, params) in enumerate(zip(effect_names, params_list)):
param_names = list(params.keys())
effect_param_names.append(param_names)
# Check actual values to identify static types
for pname, pval in params.items():
if isinstance(pval, (str, bool)):
static_params.append(f"_p{i}_{pname}")
def fused_fn(frame, t, frame_num, seed, **kwargs):
result = frame
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
# Extract params for this effect from kwargs
effect_kwargs = {}
for pname in param_names:
key = f"_p{i}_{pname}"
if key in kwargs:
effect_kwargs[pname] = kwargs[key]
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
return result
# JIT with static params (seed + any string/bool params)
return jax.jit(fused_fn, static_argnames=static_params)
except Exception as e:
print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
return None
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
"""Apply a chain of effects, using fused compilation if available."""
chain_key = self._get_chain_key(effect_names, params_list)
# Try to get or compile fused chain
if chain_key not in self.jax_fused_chains:
fused_fn = self._compile_effect_chain(effect_names, params_list)
self.jax_fused_chains[chain_key] = fused_fn
if fused_fn:
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
fused_fn = self.jax_fused_chains.get(chain_key)
if fused_fn is not None:
# Build kwargs with prefixed param names
kwargs = {}
for i, params in enumerate(params_list):
for pname, pval in params.items():
kwargs[f"_p{i}_{pname}"] = pval
seed = self.config.get('seed', 42)
# Handle GPU frames
if hasattr(frame, 'cpu'):
frame = frame.cpu
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
frame = frame.get()
try:
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
except Exception as e:
print(f"Fused chain error: {e}", file=sys.stderr)
# Fall back to sequential application
result = frame
for name, params in zip(effect_names, params_list):
result = self._apply_jax_effect(name, result, params, t, frame_num)
if result is None:
return None
return result
def _force_deferred(self, deferred: DeferredEffectChain):
"""Execute a deferred effect chain and return the actual array."""
if len(deferred.effects) == 0:
return deferred.base_frame
return self._apply_effect_chain(
deferred.effects,
deferred.params_list,
deferred.base_frame,
deferred.t,
deferred.frame_num
)
def _maybe_force(self, value):
"""Force a deferred chain if needed, otherwise return as-is."""
if isinstance(value, DeferredEffectChain):
return self._force_deferred(value)
return value
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
"""
Compile a vmapped version of an effect chain for batch processing.
Args:
effect_names: List of effect names in order [innermost, ..., outermost]
params_list: List of param dicts (used to detect static types)
Returns:
Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames
Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar
"""
if not _JAX_AVAILABLE:
return None
try:
import jax
import jax.numpy as jnp
# Get the individual JAX functions
jax_fns = [self.jax_effects[name] for name in effect_names]
# Pre-extract param info
effect_param_names = []
static_params = set()
for i, (name, params) in enumerate(zip(effect_names, params_list)):
param_names = list(params.keys())
effect_param_names.append(param_names)
for pname, pval in params.items():
if isinstance(pval, (str, bool)):
static_params.add(f"_p{i}_{pname}")
# Single-frame function (will be vmapped)
def single_frame_fn(frame, t, frame_num, seed, **kwargs):
result = frame
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
effect_kwargs = {}
for pname in param_names:
key = f"_p{i}_{pname}"
if key in kwargs:
effect_kwargs[pname] = kwargs[key]
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
return result
# Return unbatched function - we'll vmap at call time with proper in_axes
return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params))
except Exception as e:
print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
return None
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
frames: list, ts: list, frame_nums: list) -> Optional[list]:
"""
Apply an effect chain to a batch of frames using vmap.
Args:
effect_names: List of effect names
params_list_batch: List of params_list for each frame in batch
frames: List of input frames
ts: List of time values
frame_nums: List of frame numbers
Returns:
List of output frames, or None on failure
"""
if not frames:
return []
# Use first frame's params for chain key (assume same structure)
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
batch_key = f"batch:{chain_key}"
# Compile batched version if needed
if batch_key not in self.jax_batched_chains:
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
self.jax_batched_chains[batch_key] = batched_fn
if batched_fn:
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
batched_fn = self.jax_batched_chains.get(batch_key)
if batched_fn is not None:
try:
import jax
import jax.numpy as jnp
# Stack frames into batch array
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
ts_array = jnp.array(ts)
frame_nums_array = jnp.array(frame_nums)
# Build kwargs - all numeric params as arrays for vmap
kwargs = {}
static_kwargs = {} # Non-vmapped (strings, bools)
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
for j, pname in enumerate(params_list_batch[0][i].keys()):
key = f"_p{i}_{pname}"
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
first = values[0]
if isinstance(first, (str, bool)):
# Static params - not vmapped
static_kwargs[key] = first
elif isinstance(first, (int, float)):
# Always batch numeric params for simplicity
kwargs[key] = jnp.array(values)
elif hasattr(first, 'shape'):
kwargs[key] = jnp.stack(values)
else:
kwargs[key] = jnp.array(values)
seed = self.config.get('seed', 42)
# Create wrapper that unpacks the params dict
def single_call(frame, t, frame_num, params_dict):
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
# vmap over frame, t, frame_num, and the params dict (as pytree)
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
# Stack kwargs into a dict of arrays (pytree with matching structure)
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
# Unstack results
return [results[i] for i in range(len(frames))]
except Exception as e:
print(f"Batched chain error: {e}", file=sys.stderr)
# Fall back to sequential
return None
def _init(self): def _init(self):
"""Initialize from sexp - load primitives, effects, defs, scans.""" """Initialize from sexp - load primitives, effects, defs, scans."""
# Set random seed for deterministic output # Set random seed for deterministic output
@@ -869,6 +1247,22 @@ class StreamInterpreter:
# === Effects === # === Effects ===
if op in self.effects: if op in self.effects:
# Try to detect and fuse effect chains for JAX acceleration
if self.use_jax and op in self.jax_effects:
chain_info = self._extract_effect_chain(expr, env)
if chain_info is not None:
effect_names, params_list, base_frame_expr = chain_info
# Only use chain if we have 2+ effects (worth fusing)
if len(effect_names) >= 2:
base_frame = self._eval(base_frame_expr, env)
if base_frame is not None and hasattr(base_frame, 'shape'):
t = env.get('t', 0.0)
frame_num = env.get('frame-num', 0)
result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
if result is not None:
return result
# Fall through if chain application fails
effect = self.effects[op] effect = self.effects[op]
effect_env = dict(env) effect_env = dict(env)
@@ -895,17 +1289,28 @@ class StreamInterpreter:
positional_idx += 1 positional_idx += 1
i += 1 i += 1
# Try JAX-accelerated execution first # Try JAX-accelerated execution with deferred chaining
if self.use_jax and op in self.jax_effects and frame_val is not None: if self.use_jax and op in self.jax_effects and frame_val is not None:
# Build params dict for JAX (exclude 'frame') # Build params dict for JAX (exclude 'frame')
jax_params = {k: v for k, v in effect_env.items() jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
if k != 'frame' and k in effect['params']} if k != 'frame' and k in effect['params']}
t = env.get('t', 0.0) t = env.get('t', 0.0)
frame_num = env.get('frame-num', 0) frame_num = env.get('frame-num', 0)
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
if result is not None: # Check if input is a deferred chain - if so, extend it
return result if isinstance(frame_val, DeferredEffectChain):
# Fall through to interpreter if JAX fails return frame_val.extend(op, jax_params)
# Check if input is a valid frame - create new deferred chain
if hasattr(frame_val, 'shape'):
return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
# Fall through to interpreter if not a valid frame
# Force any deferred frame before interpreter evaluation
if isinstance(frame_val, DeferredEffectChain):
frame_val = self._force_deferred(frame_val)
effect_env['frame'] = frame_val
return self._eval(effect['body'], effect_env) return self._eval(effect['body'], effect_env)
@@ -922,10 +1327,15 @@ class StreamInterpreter:
if isinstance(args[i], Keyword): if isinstance(args[i], Keyword):
k = args[i].name k = args[i].name
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
# Force deferred chains before passing to primitives
v = self._maybe_force(v)
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim) kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
i += 2 i += 2
else: else:
evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim)) val = self._eval(args[i], env)
# Force deferred chains before passing to primitives
val = self._maybe_force(val)
evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
i += 1 i += 1
try: try:
if kwargs: if kwargs:
@@ -1152,6 +1562,61 @@ class StreamInterpreter:
eval_times = [] eval_times = []
write_times = [] write_times = []
# Batch accumulation for JAX
batch_deferred = [] # Accumulated DeferredEffectChains
batch_times = [] # Corresponding time values
batch_start_frame = 0
def flush_batch():
"""Execute accumulated batch and write results."""
nonlocal batch_deferred, batch_times
if not batch_deferred:
return
t_flush = time.time()
# Check if all chains have same structure (can batch)
first = batch_deferred[0]
can_batch = (
self.use_jax and
len(batch_deferred) >= 2 and
all(d.effects == first.effects for d in batch_deferred)
)
if can_batch:
# Try batched execution
frames = [d.base_frame for d in batch_deferred]
ts = [d.t for d in batch_deferred]
frame_nums = [d.frame_num for d in batch_deferred]
params_batch = [d.params_list for d in batch_deferred]
results = self._apply_batched_chain(
first.effects, params_batch, frames, ts, frame_nums
)
if results is not None:
# Write batched results
for result, t in zip(results, batch_times):
if hasattr(result, 'block_until_ready'):
result.block_until_ready()
result = np.asarray(result)
out.write(result, t)
batch_deferred = []
batch_times = []
return
# Fall back to sequential execution
for deferred, t in zip(batch_deferred, batch_times):
result = self._force_deferred(deferred)
if result is not None and hasattr(result, 'shape'):
if hasattr(result, 'block_until_ready'):
result.block_until_ready()
result = np.asarray(result)
out.write(result, t)
batch_deferred = []
batch_times = []
for frame_num in range(start_frame, n_frames): for frame_num in range(start_frame, n_frames):
if not out.is_open: if not out.is_open:
break break
@@ -1182,7 +1647,22 @@ class StreamInterpreter:
eval_times.append(time.time() - t1) eval_times.append(time.time() - t1)
t2 = time.time() t2 = time.time()
if result is not None and hasattr(result, 'shape'): if result is not None:
if isinstance(result, DeferredEffectChain):
# Accumulate for batching
batch_deferred.append(result)
batch_times.append(ctx.t)
# Flush when batch is full
if len(batch_deferred) >= self.jax_batch_size:
flush_batch()
else:
# Not deferred - flush any pending batch first, then write
flush_batch()
if hasattr(result, 'shape'):
if hasattr(result, 'block_until_ready'):
result.block_until_ready()
result = np.asarray(result)
out.write(result, ctx.t) out.write(result, ctx.t)
write_times.append(time.time() - t2) write_times.append(time.time() - t2)
@@ -1219,6 +1699,9 @@ class StreamInterpreter:
except Exception as e: except Exception as e:
print(f"Warning: progress callback failed: {e}", file=sys.stderr) print(f"Warning: progress callback failed: {e}", file=sys.stderr)
# Flush any remaining batch
flush_batch()
finally: finally:
out.close() out.close()
# Store output for access to properties like playlist_cid # Store output for access to properties like playlist_cid

542
test_funky_text.py Normal file
View File

@@ -0,0 +1,542 @@
#!/usr/bin/env python3
"""
Funky comparison tests: PIL vs TextStrip system.
Tests colors, opacity, fonts, sizes, edge positions, clipping, overlaps, etc.
"""
import numpy as np
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont
from streaming.jax_typography import (
render_text_strip, place_text_strip_jax, _load_font
)
FONTS = {
'sans': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
'bold': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
'serif': '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf',
'serif_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf',
'mono': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
'mono_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf',
'narrow': '/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf',
'italic': '/usr/share/fonts/truetype/freefont/FreeSansOblique.ttf',
}
def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
stroke_width=0, stroke_fill=None, anchor="la",
multiline=False, line_spacing=4, align="left"):
"""Render with PIL directly, including color/opacity."""
frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
# For opacity, render to RGBA then composite
txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(txt_layer)
font = _load_font(font_path, font_size)
if stroke_fill is None:
stroke_fill = (0, 0, 0)
# PIL fill with alpha for opacity
alpha_int = int(round(opacity * 255))
fill_rgba = (*fill, alpha_int)
stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None
if multiline:
draw.multiline_text((x, y), text, fill=fill_rgba, font=font,
stroke_width=stroke_width, stroke_fill=stroke_rgba,
spacing=line_spacing, align=align, anchor=anchor)
else:
draw.text((x, y), text, fill=fill_rgba, font=font,
stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor)
# Composite onto background
bg_img = Image.fromarray(frame).convert('RGBA')
result = Image.alpha_composite(bg_img, txt_layer)
return np.array(result.convert('RGB'))
def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
stroke_width=0, stroke_fill=None, anchor="la",
multiline=False, line_spacing=4, align="left"):
"""Render with TextStrip system."""
frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
strip = render_text_strip(
text, font_path, font_size,
stroke_width=stroke_width, stroke_fill=stroke_fill,
anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
)
strip_img = jnp.asarray(strip.image)
color = jnp.array(fill, dtype=jnp.float32)
result = place_text_strip_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
color, opacity,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width
)
return np.array(result)
def compare(name, tolerance=0, **kwargs):
"""Compare PIL and TextStrip rendering."""
pil = render_pil(**kwargs)
strip = render_strip(**kwargs)
diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
max_diff = diff.max()
pixels_diff = (diff > 0).any(axis=2).sum()
if max_diff == 0:
print(f" PASS: {name}")
return True
if tolerance > 0:
best_diff = diff.copy()
for dy in range(-tolerance, tolerance + 1):
for dx in range(-tolerance, tolerance + 1):
if dy == 0 and dx == 0:
continue
shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
best_diff = np.minimum(best_diff, sdiff)
if best_diff.max() == 0:
print(f" PASS: {name} (within {tolerance}px tolerance)")
return True
print(f" FAIL: {name}")
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png")
print(f" Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png")
return False
def test_colors():
"""Test various text colors on various backgrounds."""
print("\n--- Colors ---")
results = []
# White on black (baseline)
results.append(compare("color_white_on_black",
text="Hello", x=20, y=30, fill=(255, 255, 255), bg=(0, 0, 0)))
# Red on black
results.append(compare("color_red",
text="Red Text", x=20, y=30, fill=(255, 0, 0), bg=(0, 0, 0)))
# Green on black
results.append(compare("color_green",
text="Green!", x=20, y=30, fill=(0, 255, 0), bg=(0, 0, 0)))
# Blue on black
results.append(compare("color_blue",
text="Blue sky", x=20, y=30, fill=(0, 100, 255), bg=(0, 0, 0)))
# Yellow on dark gray
results.append(compare("color_yellow_on_gray",
text="Yellow", x=20, y=30, fill=(255, 255, 0), bg=(40, 40, 40)))
# Magenta on white
results.append(compare("color_magenta_on_white",
text="Magenta", x=20, y=30, fill=(255, 0, 255), bg=(255, 255, 255)))
# Subtle: gray text on slightly lighter gray
results.append(compare("color_subtle_gray",
text="Subtle", x=20, y=30, fill=(128, 128, 128), bg=(64, 64, 64)))
# Orange on deep blue
results.append(compare("color_orange_on_blue",
text="Warm", x=20, y=30, fill=(255, 165, 0), bg=(0, 0, 80)))
return results
def test_opacity():
"""Test different opacity levels."""
print("\n--- Opacity ---")
results = []
results.append(compare("opacity_100",
text="Full", x=20, y=30, opacity=1.0))
results.append(compare("opacity_75",
text="75%", x=20, y=30, opacity=0.75))
results.append(compare("opacity_50",
text="Half", x=20, y=30, opacity=0.5))
results.append(compare("opacity_25",
text="Quarter", x=20, y=30, opacity=0.25))
results.append(compare("opacity_10",
text="Ghost", x=20, y=30, opacity=0.1))
# Opacity on colored background
results.append(compare("opacity_on_colored_bg",
text="Overlay", x=20, y=30, fill=(255, 255, 255), bg=(100, 0, 0),
opacity=0.5))
# Color + opacity combo
results.append(compare("opacity_red_on_green",
text="Blend", x=20, y=30, fill=(255, 0, 0), bg=(0, 100, 0),
opacity=0.6))
return results
def test_fonts():
"""Test different fonts and sizes."""
print("\n--- Fonts & Sizes ---")
results = []
for label, path in FONTS.items():
results.append(compare(f"font_{label}",
text="Quick Fox", x=20, y=30, font_path=path, font_size=28,
frame_size=(300, 80)))
# Tiny text
results.append(compare("size_tiny",
text="Tiny text at 12px", x=10, y=15, font_size=12,
frame_size=(200, 40)))
# Big text
results.append(compare("size_big",
text="BIG", x=20, y=10, font_size=72,
frame_size=(300, 100)))
# Huge text
results.append(compare("size_huge",
text="XL", x=10, y=10, font_size=120,
frame_size=(300, 160)))
return results
def test_anchors():
"""Test all anchor combinations."""
print("\n--- Anchors ---")
results = []
# All horizontal x vertical combos
for h in ['l', 'm', 'r']:
for v in ['a', 'm', 's', 'd']:
anchor = h + v
results.append(compare(f"anchor_{anchor}",
text="Anchor", x=200, y=50, anchor=anchor,
frame_size=(400, 100)))
return results
def test_strokes():
"""Test various stroke widths and colors."""
print("\n--- Strokes ---")
results = []
for sw in [1, 2, 3, 4, 6, 8]:
results.append(compare(f"stroke_w{sw}",
text="Stroke", x=30, y=20, font_size=40,
stroke_width=sw, stroke_fill=(0, 0, 0),
frame_size=(300, 80)))
# Colored strokes
results.append(compare("stroke_red",
text="Red outline", x=20, y=20, font_size=36,
stroke_width=3, stroke_fill=(255, 0, 0),
frame_size=(350, 80)))
results.append(compare("stroke_white_on_black",
text="Glow", x=20, y=20, font_size=40,
fill=(255, 255, 255), stroke_width=4, stroke_fill=(0, 0, 255),
frame_size=(250, 80)))
# Stroke with bold font
results.append(compare("stroke_bold",
text="Bold+Stroke", x=20, y=20,
font_path=FONTS['bold'], font_size=36,
stroke_width=3, stroke_fill=(0, 0, 0),
frame_size=(400, 80)))
# Stroke + colored text on colored bg
results.append(compare("stroke_colored_on_bg",
text="Party", x=20, y=20, font_size=48,
fill=(255, 255, 0), bg=(50, 0, 80),
stroke_width=3, stroke_fill=(255, 0, 0),
frame_size=(300, 80)))
return results
def test_edge_clipping():
"""Test text at frame edges - clipping behavior."""
print("\n--- Edge Clipping ---")
results = []
# Text at very left edge
results.append(compare("clip_left_edge",
text="LEFT", x=0, y=30, frame_size=(200, 80)))
# Text partially off right edge
results.append(compare("clip_right_edge",
text="RIGHT SIDE", x=150, y=30, frame_size=(200, 80)))
# Text at top edge
results.append(compare("clip_top",
text="TOP", x=20, y=0, frame_size=(200, 80)))
# Text at bottom edge - partially clipped
results.append(compare("clip_bottom",
text="BOTTOM", x=20, y=55, font_size=40,
frame_size=(200, 80)))
# Large text overflowing all sides from center
results.append(compare("clip_overflow_center",
text="HUGE", x=75, y=40, font_size=100, anchor="mm",
frame_size=(150, 80)))
# Corner placement
results.append(compare("clip_corner_br",
text="Corner", x=350, y=70, font_size=36,
frame_size=(400, 100)))
return results
def test_multiline_fancy():
"""Test multiline with various styles."""
print("\n--- Multiline Fancy ---")
results = []
# Right-aligned (1px tolerance: same sub-pixel issue as center alignment)
results.append(compare("multi_right",
text="Right\nAligned\nText", x=380, y=20,
frame_size=(400, 150),
multiline=True, anchor="ra", align="right",
tolerance=1))
# Center + stroke
results.append(compare("multi_center_stroke",
text="Title\nSubtitle", x=200, y=20,
font_size=32, frame_size=(400, 120),
multiline=True, anchor="ma", align="center",
stroke_width=2, stroke_fill=(0, 0, 0),
tolerance=1))
# Wide line spacing
results.append(compare("multi_wide_spacing",
text="Line A\nLine B\nLine C", x=20, y=10,
frame_size=(300, 200),
multiline=True, line_spacing=20))
# Zero extra spacing
results.append(compare("multi_tight_spacing",
text="Tight\nPacked\nLines", x=20, y=10,
frame_size=(300, 150),
multiline=True, line_spacing=0))
# Many lines
results.append(compare("multi_many_lines",
text="One\nTwo\nThree\nFour\nFive\nSix", x=20, y=5,
font_size=20, frame_size=(200, 200),
multiline=True, line_spacing=4))
# Bold multiline with stroke
results.append(compare("multi_bold_stroke",
text="BOLD\nSTROKE", x=20, y=10,
font_path=FONTS['bold'], font_size=48,
stroke_width=3, stroke_fill=(200, 0, 0),
frame_size=(350, 150), multiline=True))
# Multiline on colored bg with opacity
results.append(compare("multi_opacity_on_bg",
text="Semi\nTransparent", x=20, y=10,
fill=(255, 255, 0), bg=(0, 50, 100), opacity=0.7,
frame_size=(300, 120), multiline=True))
return results
def test_special_chars():
"""Test special characters and edge cases."""
print("\n--- Special Characters ---")
results = []
# Numbers and symbols
results.append(compare("chars_numbers",
text="0123456789", x=20, y=30, frame_size=(300, 80)))
results.append(compare("chars_punctuation",
text="Hello, World! (v2.0)", x=10, y=30, frame_size=(350, 80)))
results.append(compare("chars_symbols",
text="@#$%^&*+=", x=20, y=30, frame_size=(300, 80)))
# Single character
results.append(compare("chars_single",
text="X", x=50, y=30, font_size=48, frame_size=(100, 80)))
# Very long text (clipped)
results.append(compare("chars_long",
text="The quick brown fox jumps over the lazy dog", x=10, y=30,
font_size=24, frame_size=(400, 80)))
# Mixed case
results.append(compare("chars_mixed_case",
text="AaBbCcDdEeFf", x=10, y=30, frame_size=(350, 80)))
return results
def test_combos():
"""Complex combinations of features."""
print("\n--- Combos ---")
results = []
# Big bold stroke + color + opacity
results.append(compare("combo_all_features",
text="EPIC", x=20, y=10,
font_path=FONTS['bold'], font_size=64,
fill=(255, 200, 0), bg=(20, 0, 40), opacity=0.85,
stroke_width=4, stroke_fill=(180, 0, 0),
frame_size=(350, 100)))
# Small mono on busy background
results.append(compare("combo_mono_code",
text="fn main() {}", x=10, y=15,
font_path=FONTS['mono'], font_size=16,
fill=(0, 255, 100), bg=(30, 30, 30),
frame_size=(250, 50)))
# Serif italic multiline with stroke
results.append(compare("combo_serif_italic_multi",
text="Once upon\na time...", x=20, y=10,
font_path=FONTS['italic'], font_size=28,
stroke_width=1, stroke_fill=(80, 80, 80),
frame_size=(300, 120), multiline=True))
# Narrow font, big stroke, center anchored
results.append(compare("combo_narrow_stroke_center",
text="NARROW", x=150, y=40,
font_path=FONTS['narrow'], font_size=44,
stroke_width=5, stroke_fill=(0, 0, 0),
anchor="mm", frame_size=(300, 80)))
# Multiple strips on same frame (simulated by sequential placement)
results.append(compare("combo_opacity_blend",
text="Layered", x=20, y=30,
fill=(255, 0, 0), bg=(0, 0, 255), opacity=0.5,
font_size=48, frame_size=(300, 80)))
return results
def test_multi_strip_overlay():
"""Test placing multiple strips on the same frame."""
print("\n--- Multi-Strip Overlay ---")
results = []
frame_size = (400, 150)
bg = (20, 20, 40)
# PIL version - multiple draw calls
font1 = _load_font(FONTS['bold'], 48)
font2 = _load_font(None, 24)
font3 = _load_font(FONTS['mono'], 18)
pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(txt_layer)
draw.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1,
stroke_width=2, stroke_fill=(0, 0, 0, 255))
draw.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2)
draw.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3)
bg_img = Image.fromarray(pil_frame).convert('RGBA')
pil_result = np.array(Image.alpha_composite(bg_img, txt_layer).convert('RGB'))
# Strip version - multiple placements
frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
s1 = render_text_strip("TITLE", FONTS['bold'], 48, stroke_width=2, stroke_fill=(0, 0, 0))
s2 = render_text_strip("Subtitle here", None, 24)
s3 = render_text_strip("code_snippet()", FONTS['mono'], 18)
frame = place_text_strip_jax(
frame, jnp.asarray(s1.image), 20, 10,
s1.baseline_y, s1.bearing_x,
jnp.array([255, 255, 0], dtype=jnp.float32), 1.0,
anchor_x=s1.anchor_x, anchor_y=s1.anchor_y,
stroke_width=s1.stroke_width)
frame = place_text_strip_jax(
frame, jnp.asarray(s2.image), 20, 70,
s2.baseline_y, s2.bearing_x,
jnp.array([200, 200, 200], dtype=jnp.float32), 1.0,
anchor_x=s2.anchor_x, anchor_y=s2.anchor_y,
stroke_width=s2.stroke_width)
frame = place_text_strip_jax(
frame, jnp.asarray(s3.image), 20, 110,
s3.baseline_y, s3.bearing_x,
jnp.array([0, 255, 128], dtype=jnp.float32), 200/255,
anchor_x=s3.anchor_x, anchor_y=s3.anchor_y,
stroke_width=s3.stroke_width)
strip_result = np.array(frame)
diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16))
max_diff = diff.max()
pixels_diff = (diff > 0).any(axis=2).sum()
if max_diff <= 1:
print(f" PASS: multi_overlay (max_diff={max_diff})")
results.append(True)
else:
print(f" FAIL: multi_overlay")
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png")
Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png")
diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png")
print(f" Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png")
results.append(False)
return results
def main():
print("=" * 60)
print("Funky TextStrip vs PIL Comparison")
print("=" * 60)
all_results = []
all_results.extend(test_colors())
all_results.extend(test_opacity())
all_results.extend(test_fonts())
all_results.extend(test_anchors())
all_results.extend(test_strokes())
all_results.extend(test_edge_clipping())
all_results.extend(test_multiline_fancy())
all_results.extend(test_special_chars())
all_results.extend(test_combos())
all_results.extend(test_multi_strip_overlay())
print("\n" + "=" * 60)
passed = sum(all_results)
total = len(all_results)
print(f"Results: {passed}/{total} passed")
if passed == total:
print("ALL TESTS PASSED!")
else:
failed = [i for i, r in enumerate(all_results) if not r]
print(f"FAILED: {total - passed} tests (indices: {failed})")
print("=" * 60)
if __name__ == "__main__":
main()

183
test_pil_options.py Normal file
View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
"""
Explore PIL text options and test if we can match them.
"""
import numpy as np
from PIL import Image, ImageDraw, ImageFont
def load_font(font_name=None, font_size=32):
"""Load a font."""
candidates = [
font_name,
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
'/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
'/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf',
]
for path in candidates:
if path is None:
continue
try:
return ImageFont.truetype(path, font_size)
except (IOError, OSError):
continue
return ImageFont.load_default()
def test_pil_options():
"""Test various PIL text options."""
# Create a test frame
frame_size = (600, 400)
font = load_font(None, 36)
font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36)
font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36)
tests = []
# Test 1: Basic text
def basic_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font)
return img, "basic"
tests.append(basic_text)
# Test 2: Stroke/outline
def stroke_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font,
stroke_width=2, stroke_fill=(255, 0, 0, 255))
return img, "stroke"
tests.append(stroke_text)
# Test 3: Bold font
def bold_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold)
return img, "bold"
tests.append(bold_text)
# Test 4: Italic font
def italic_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic)
return img, "italic"
tests.append(italic_text)
# Test 5: Different anchors
def anchor_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
# Draw crosshairs at anchor points
for x in [100, 300, 500]:
draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255))
draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255))
draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm")
draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm")
draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm")
return img, "anchor"
tests.append(anchor_text)
# Test 6: Multiline text
def multiline_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three",
fill=(255, 255, 255, 255), font=font, spacing=10)
return img, "multiline"
tests.append(multiline_text)
# Test 7: Semi-transparent text
def alpha_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font)
draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font)
draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font)
return img, "alpha"
tests.append(alpha_text)
# Test 8: Colored text
def colored_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font)
draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font)
draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font)
draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font)
return img, "colored"
tests.append(colored_text)
# Test 9: Large stroke
def large_stroke():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font,
stroke_width=5, stroke_fill=(0, 0, 0, 255))
return img, "large_stroke"
tests.append(large_stroke)
# Test 10: Emoji (if supported)
def emoji_text():
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
try:
# Try to find an emoji font
emoji_font = None
emoji_paths = [
'/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf',
'/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf',
]
for p in emoji_paths:
try:
emoji_font = ImageFont.truetype(p, 36)
break
except:
pass
if emoji_font:
draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font)
else:
draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font)
except Exception as e:
draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font)
return img, "emoji"
tests.append(emoji_text)
# Run all tests
print("PIL Text Options Test")
print("=" * 60)
for test_fn in tests:
img, name = test_fn()
fname = f"/tmp/pil_test_{name}.png"
img.save(fname)
print(f"Saved: {fname}")
print("\nCheck /tmp/pil_test_*.png for results")
# Print available parameters
print("\n" + "=" * 60)
print("PIL draw.text() parameters:")
print(" - xy: position tuple")
print(" - text: string to draw")
print(" - fill: color (R,G,B) or (R,G,B,A)")
print(" - font: ImageFont object")
print(" - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)")
print(" - spacing: line spacing for multiline")
print(" - align: 'left', 'center', 'right' for multiline")
print(" - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)")
print(" - features: OpenType features list")
print(" - language: language code for shaping")
print(" - stroke_width: outline width in pixels")
print(" - stroke_fill: outline color")
print(" - embedded_color: use embedded color glyphs (emoji)")
if __name__ == "__main__":
test_pil_options()

176
test_styled_text.py Normal file
View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
"""
Test styled TextStrip rendering against PIL.
"""
import numpy as np
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont
from streaming.jax_typography import (
render_text_strip, place_text_strip_jax, _load_font
)
def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
stroke_width=0, stroke_fill=None, anchor="la",
multiline=False, line_spacing=4, align="left"):
"""Render with PIL directly."""
frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)
img = Image.fromarray(frame)
draw = ImageDraw.Draw(img)
font = _load_font(None, font_size)
# Default stroke fill
if stroke_fill is None:
stroke_fill = (0, 0, 0)
if multiline:
draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font,
stroke_width=stroke_width, stroke_fill=stroke_fill,
spacing=line_spacing, align=align, anchor=anchor)
else:
draw.text((x, y), text, fill=(255, 255, 255), font=font,
stroke_width=stroke_width, stroke_fill=stroke_fill, anchor=anchor)
return np.array(img)
def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
stroke_width=0, stroke_fill=None, anchor="la",
multiline=False, line_spacing=4, align="left"):
"""Render with TextStrip."""
frame = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)
strip = render_text_strip(
text, None, font_size,
stroke_width=stroke_width, stroke_fill=stroke_fill,
anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
)
strip_img = jnp.asarray(strip.image)
color = jnp.array([255, 255, 255], dtype=jnp.float32)
result = place_text_strip_jax(
frame, strip_img, x, y,
strip.baseline_y, strip.bearing_x,
color, 1.0,
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
stroke_width=strip.stroke_width
)
return np.array(result)
def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
tolerance=0, **kwargs):
"""Compare PIL and TextStrip rendering.
tolerance=0: exact pixel match required
tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences
in center-aligned multiline text where the strip is pre-rendered
at a different base position than the final placement)
"""
pil = render_pil(text, x, y, font_size, frame_size, **kwargs)
strip = render_strip(text, x, y, font_size, frame_size, **kwargs)
diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
max_diff = diff.max()
pixels_diff = (diff > 0).any(axis=2).sum()
if max_diff == 0:
print(f"PASS: {name}")
print(f" Max diff: 0, Pixels different: 0")
return True
if tolerance > 0:
# Check if the difference is just a sub-pixel position shift:
# for each shifted version, compute the minimum diff
best_diff = diff.copy()
for dy in range(-tolerance, tolerance + 1):
for dx in range(-tolerance, tolerance + 1):
if dy == 0 and dx == 0:
continue
shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
best_diff = np.minimum(best_diff, sdiff)
max_shift_diff = best_diff.max()
pixels_shift_diff = (best_diff > 0).any(axis=2).sum()
if max_shift_diff == 0:
print(f"PASS: {name} (within {tolerance}px position tolerance)")
print(f" Raw diff: {max_diff}, After shift tolerance: 0")
return True
status = "FAIL"
print(f"{status}: {name}")
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
# Save debug images
Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8)
Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png")
print(f" Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png")
return False
def main():
print("=" * 60)
print("Styled TextStrip vs PIL Comparison")
print("=" * 60)
results = []
# Basic text
results.append(compare("basic", "Hello World", 20, 50))
# Stroke/outline
results.append(compare("stroke_2", "Outlined", 20, 50,
stroke_width=2, stroke_fill=(255, 0, 0)))
results.append(compare("stroke_5", "Big Outline", 30, 60, font_size=48,
frame_size=(500, 120),
stroke_width=5, stroke_fill=(0, 0, 0)))
# Anchors - center
results.append(compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100),
anchor="mm"))
# Anchors - right
results.append(compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100),
anchor="rm"))
# Multiline
results.append(compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20,
frame_size=(400, 150),
multiline=True, line_spacing=8))
# Multiline centered (1px tolerance: sub-pixel rendering differs because
# the strip is pre-rendered at an integer position while PIL's center
# alignment uses fractional getlength values for the 'm' anchor shift)
results.append(compare("multiline_center", "Short\nMedium Length\nX", 200, 20,
frame_size=(400, 150),
multiline=True, anchor="ma", align="center",
tolerance=1))
# Stroke + multiline
results.append(compare("stroke_multiline", "Line A\nLine B", 20, 20,
frame_size=(400, 120),
stroke_width=2, stroke_fill=(0, 0, 255),
multiline=True))
print("=" * 60)
passed = sum(results)
total = len(results)
print(f"Results: {passed}/{total} passed")
if passed == total:
print("ALL TESTS PASSED!")
else:
print(f"FAILED: {total - passed} tests")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,517 @@
#!/usr/bin/env python3
"""Integration tests comparing JAX and Python rendering pipelines.
These tests ensure the JAX-compiled effect chains produce identical output
to the Python/NumPy path. They test:
1. Full effect pipelines through both interpreters
2. Multi-frame sequences (to catch phase-dependent bugs)
3. Compiled effect chain fusion
4. Edge cases like shrinking/zooming that affect boundary handling
"""
import os
import sys
import pytest
import numpy as np
import shutil
from pathlib import Path
# Ensure the art-celery module is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sexp_effects.primitive_libs import core as core_mod
# Path to test resources
TEST_DIR = Path('/home/giles/art/test')
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
TEMPLATES_DIR = TEST_DIR / 'templates'
def create_test_image(h=96, w=128):
"""Create a test image with distinct patterns."""
import cv2
img = np.zeros((h, w, 3), dtype=np.uint8)
# Create gradient background
for y in range(h):
for x in range(w):
img[y, x] = [
int(255 * x / w), # R: horizontal gradient
int(255 * y / h), # G: vertical gradient
128 # B: constant
]
# Add features
cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1)
cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
return img
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
"""Set up test environment with sexp files and test media."""
test_dir = tmp_path_factory.mktemp('sexp_test')
original_dir = os.getcwd()
os.chdir(test_dir)
# Create directories
(test_dir / 'effects').mkdir()
(test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)
# Create test image
import cv2
test_img = create_test_image()
cv2.imwrite(str(test_dir / 'test_image.png'), test_img)
# Copy required effect files
for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
src = EFFECTS_DIR / f'{effect}.sexp'
dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
if src.exists():
shutil.copy(src, dst)
yield {
'dir': test_dir,
'image_path': test_dir / 'test_image.png',
'test_img': test_img,
}
os.chdir(original_dir)
def create_sexp_file(test_dir, content, filename='test.sexp'):
"""Create a test sexp file."""
path = test_dir / 'effects' / filename
with open(path, 'w') as f:
f.write(content)
return str(path)
class TestJaxPythonPipelineEquivalence:
"""Test that JAX and Python pipelines produce equivalent output."""
def test_single_rotate_effect(self, test_env):
"""Test that a single rotate effect matches between Python and JAX."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(frame (rotate frame :angle 15 :speed 0))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
# Python path
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
# JAX path
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
# Inject test image into globals
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = py_interp._eval(py_interp.frame_pipeline, frame_env)
jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env)
# Force deferred if needed
py_result = np.asarray(py_interp._maybe_force(py_result))
jax_result = np.asarray(jax_interp._maybe_force(jax_result))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"
def test_rotate_then_zoom(self, test_env):
"""Test rotate followed by zoom - tests effect chain fusion."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (-> (rotate frame :angle 15 :speed 0)
(zoom :amount 0.95 :speed 0)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"
def test_zoom_shrink_boundary_handling(self, test_env):
"""Test zoom with shrinking factor - critical for boundary handling."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame (zoom frame :amount 0.8 :speed 0))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
# Check corners specifically - these are most affected by boundary handling
h, w = test_img.shape[:2]
corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)]
for y, x in corners:
py_val = py_result[y, x]
jax_val = jax_result[y, x]
corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"
def test_blend_two_clips(self, test_env):
"""Test blending two effect chains - the core bug scenario."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
max_diff = np.max(diff)
# Check edge region specifically
edge_diff = diff[0, :].mean()
assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"
def test_blend_with_invert(self, test_env):
"""Test blending with invert - matches the problematic recipe pattern."""
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(require-primitives "core")
(require-primitives "image")
(require-primitives "blending")
(require-primitives "color_ops")
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(effect blend :path "../sexp_effects/effects/blend.sexp")
(effect invert :path "../sexp_effects/effects/invert.sexp")
(frame
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
(zoom :amount 1.05 :speed 0)
(invert :amount 1))
clip_b (-> (rotate frame :angle -5 :speed 0)
(zoom :amount 0.95 :speed 0))]
(blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
from streaming.stream_sexp_generic import StreamInterpreter, Context
import cv2
test_img = cv2.imread(str(test_env['image_path']))
core_mod.set_random_seed(42)
py_interp = StreamInterpreter(sexp_path, use_jax=False)
py_interp._init()
core_mod.set_random_seed(42)
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
jax_interp._init()
ctx = Context(fps=10)
ctx.t = 0.5
ctx.frame_num = 5
frame_env = {
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
't': ctx.t, 'frame-num': ctx.frame_num,
}
py_interp.globals['frame'] = test_img
jax_interp.globals['frame'] = test_img
py_result = np.asarray(py_interp._maybe_force(
py_interp._eval(py_interp.frame_pipeline, frame_env)))
jax_result = np.asarray(jax_interp._maybe_force(
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
class TestDeferredEffectChainFusion:
"""Test the DeferredEffectChain fusion mechanism specifically."""
def test_manual_vs_fused_chain(self, test_env):
"""Compare manual application vs fused DeferredEffectChain."""
import jax.numpy as jnp
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
# Create minimal sexp to load effects
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame frame)
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
core_mod.set_random_seed(42)
interp = StreamInterpreter(sexp_path, use_jax=True)
interp._init()
test_img = test_env['test_img']
jax_frame = jnp.array(test_img)
t = 0.5
frame_num = 5
# Manual step-by-step application
rotate_fn = interp.jax_effects['rotate']
zoom_fn = interp.jax_effects['zoom']
rot_angle = -5.0
zoom_amount = 0.95
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
angle=rot_angle, speed=0)
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
amount=zoom_amount, speed=0)
manual_result = np.asarray(manual_result)
# Fused chain application
chain = DeferredEffectChain(
['rotate'],
[{'angle': rot_angle, 'speed': 0}],
jax_frame, t, frame_num
)
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
fused_result = np.asarray(interp._force_deferred(chain))
diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
mean_diff = np.mean(diff)
assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"
# Check specific pixels
h, w = test_img.shape[:2]
for y in [0, h//2, h-1]:
for x in [0, w//2, w-1]:
manual_val = manual_result[y, x]
fused_val = fused_result[y, x]
pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"
def test_chain_with_shrink_zoom_boundary(self, test_env):
"""Test that shrinking zoom handles boundaries correctly in chain."""
import jax.numpy as jnp
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
sexp_content = '''(stream "test"
:width 128
:height 96
:seed 42
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
(frame frame)
)
'''
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
core_mod.set_random_seed(42)
interp = StreamInterpreter(sexp_path, use_jax=True)
interp._init()
test_img = test_env['test_img']
jax_frame = jnp.array(test_img)
t = 0.5
frame_num = 5
# Parameters that shrink the image (zoom < 1.0)
rot_angle = -4.555
zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries
# Manual application
rotate_fn = interp.jax_effects['rotate']
zoom_fn = interp.jax_effects['zoom']
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
angle=rot_angle, speed=0)
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
amount=zoom_amount, speed=0)
manual_result = np.asarray(manual_result)
# Fused chain
chain = DeferredEffectChain(
['rotate'],
[{'angle': rot_angle, 'speed': 0}],
jax_frame, t, frame_num
)
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
fused_result = np.asarray(interp._force_deferred(chain))
# Check top edge specifically - this is where boundary issues manifest
top_edge_manual = manual_result[0, :]
top_edge_fused = fused_result[0, :]
edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
mean_edge_diff = np.mean(edge_diff)
assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"
# Check for zeros at edges that shouldn't be there
manual_edge_sum = np.sum(top_edge_manual)
fused_edge_sum = np.sum(top_edge_fused)
if manual_edge_sum > 100: # If manual has significant values
assert fused_edge_sum > manual_edge_sum * 0.5, \
f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
Test framework to verify JAX primitives match Python primitives.
Compares output of each primitive through:
1. Python/NumPy path
2. JAX path (CPU)
3. JAX path (GPU) - if available
Reports any mismatches with detailed diffs.
"""
import sys
sys.path.insert(0, '/home/giles/art/art-celery')
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass, field
# Test configuration
TEST_WIDTH = 64
TEST_HEIGHT = 48
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
@dataclass
class TestResult:
primitive: str
passed: bool
python_mean: float = 0.0
jax_mean: float = 0.0
diff_mean: float = 0.0
diff_max: float = 0.0
pct_within_1: float = 0.0
error: str = ""
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
"""Create a test frame with known pattern."""
if pattern == 'gradient':
# Diagonal gradient
y, x = np.mgrid[0:h, 0:w]
r = (x * 255 / w).astype(np.uint8)
g = (y * 255 / h).astype(np.uint8)
b = ((x + y) * 127 / (w + h)).astype(np.uint8)
return np.stack([r, g, b], axis=2)
elif pattern == 'checkerboard':
y, x = np.mgrid[0:h, 0:w]
check = ((x // 8) + (y // 8)) % 2
v = (check * 255).astype(np.uint8)
return np.stack([v, v, v], axis=2)
elif pattern == 'solid':
return np.full((h, w, 3), 128, dtype=np.uint8)
else:
return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
"""Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
if py_out is None or jax_out is None:
return float('inf'), float('inf'), 0.0
if isinstance(py_out, dict) and isinstance(jax_out, dict):
# Compare coordinate maps
diffs = []
for k in py_out:
if k in jax_out:
py_arr = np.asarray(py_out[k])
jax_arr = np.asarray(jax_out[k])
if py_arr.shape == jax_arr.shape:
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
diffs.append(diff)
if diffs:
all_diff = np.concatenate([d.flatten() for d in diffs])
return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))
return float('inf'), float('inf'), 0.0
py_arr = np.asarray(py_out)
jax_arr = np.asarray(jax_out)
if py_arr.shape != jax_arr.shape:
return float('inf'), float('inf'), 0.0
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))
# ============================================================================
# Primitive Test Definitions
# ============================================================================
PRIMITIVE_TESTS = {
# Geometry primitives
'geometry:ripple-displace': {
'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
'returns': 'coords',
},
'geometry:rotate-img': {
'args': ['frame', 45],
'returns': 'frame',
},
'geometry:scale-img': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'geometry:flip-h': {
'args': ['frame'],
'returns': 'frame',
},
'geometry:flip-v': {
'args': ['frame'],
'returns': 'frame',
},
# Color operations
'color_ops:invert': {
'args': ['frame'],
'returns': 'frame',
},
'color_ops:grayscale': {
'args': ['frame'],
'returns': 'frame',
},
'color_ops:brightness': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'color_ops:contrast': {
'args': ['frame', 1.5],
'returns': 'frame',
},
'color_ops:hue-shift': {
'args': ['frame', 90],
'returns': 'frame',
},
# Image operations
'image:width': {
'args': ['frame'],
'returns': 'scalar',
},
'image:height': {
'args': ['frame'],
'returns': 'scalar',
},
'image:channel': {
'args': ['frame', 0],
'returns': 'array',
},
# Blending
'blending:blend': {
'args': ['frame', 'frame2', 0.5],
'returns': 'frame',
},
'blending:blend-add': {
'args': ['frame', 'frame2'],
'returns': 'frame',
},
'blending:blend-multiply': {
'args': ['frame', 'frame2'],
'returns': 'frame',
},
}
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
"""Run a primitive through the Python interpreter."""
if prim_name not in interp.primitives:
return None
func = interp.primitives[prim_name]
args = []
for a in test_def['args']:
if a == 'frame':
args.append(frame.copy())
elif a == 'frame2':
args.append(frame2.copy())
else:
args.append(a)
try:
return func(*args)
except Exception as e:
return None
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
"""Run a primitive through the JAX compiler."""
try:
from streaming.sexp_to_jax import JaxCompiler
import jax.numpy as jnp
compiler = JaxCompiler()
# Build a simple expression to test the primitive
from sexp_effects.parser import Symbol, Keyword
args = []
env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
for a in test_def['args']:
if a == 'frame':
args.append(Symbol('frame'))
elif a == 'frame2':
args.append(Symbol('frame2'))
else:
args.append(a)
# Create expression: (prim_name arg1 arg2 ...)
expr = [Symbol(prim_name)] + args
result = compiler._eval(expr, env)
if hasattr(result, '__array__'):
return np.asarray(result)
return result
except Exception as e:
return None
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
"""Test a single primitive."""
frame = create_test_frame(pattern='gradient')
frame2 = create_test_frame(pattern='checkerboard')
result = TestResult(primitive=prim_name, passed=False)
# Run Python version
try:
py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
if py_out is not None and hasattr(py_out, 'shape'):
result.python_mean = float(np.mean(py_out))
except Exception as e:
result.error = f"Python error: {e}"
return result
# Run JAX version
try:
jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
if jax_out is not None and hasattr(jax_out, 'shape'):
result.jax_mean = float(np.mean(jax_out))
except Exception as e:
result.error = f"JAX error: {e}"
return result
if py_out is None:
result.error = "Python returned None"
return result
if jax_out is None:
result.error = "JAX returned None"
return result
# Compare
diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
result.diff_mean = diff_mean
result.diff_max = diff_max
result.pct_within_1 = pct
# Check pass/fail
result.passed = (
diff_mean <= TOLERANCE_MEAN and
diff_max <= TOLERANCE_MAX and
pct >= TOLERANCE_PCT
)
if not result.passed:
result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
return result
def discover_primitives(interp) -> List[str]:
"""Discover all primitives available in the interpreter."""
return sorted(interp.primitives.keys())
def run_all_tests(verbose=True):
"""Run all primitive tests."""
import warnings
warnings.filterwarnings('ignore')
import os
os.chdir('/home/giles/art/test')
from streaming.stream_sexp_generic import StreamInterpreter
from sexp_effects.primitive_libs import core as core_mod
core_mod.set_random_seed(42)
# Create interpreter to get Python primitives
interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
interp._init()
results = []
print("=" * 60)
print("JAX PRIMITIVE TEST SUITE")
print("=" * 60)
# Test defined primitives
for prim_name, test_def in PRIMITIVE_TESTS.items():
result = test_primitive(interp, prim_name, test_def)
results.append(result)
status = "✓ PASS" if result.passed else "✗ FAIL"
if verbose:
print(f"{status} {prim_name}")
if not result.passed:
print(f" {result.error}")
# Summary
passed = sum(1 for r in results if r.passed)
failed = sum(1 for r in results if not r.passed)
print("\n" + "=" * 60)
print(f"SUMMARY: {passed} passed, {failed} failed")
print("=" * 60)
if failed > 0:
print("\nFailed primitives:")
for r in results:
if not r.passed:
print(f" - {r.primitive}: {r.error}")
return results
if __name__ == '__main__':
run_all_tests()

305
tests/test_xector.py Normal file
View File

@@ -0,0 +1,305 @@
"""
Tests for xector primitives - parallel array operations.
"""
import pytest
import numpy as np
from sexp_effects.primitive_libs.xector import (
Xector,
xector_red, xector_green, xector_blue, xector_rgb,
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
xector_dist_from_center,
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
alpha_sin, alpha_cos, alpha_sq,
alpha_lt, alpha_gt, alpha_eq,
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
xector_where, xector_fill, xector_zeros, xector_ones,
is_xector,
)
class TestXectorBasics:
"""Test Xector class basic operations."""
def test_create_from_list(self):
x = Xector([1, 2, 3])
assert len(x) == 3
assert is_xector(x)
def test_create_from_numpy(self):
arr = np.array([1.0, 2.0, 3.0])
x = Xector(arr)
assert len(x) == 3
np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))
def test_implicit_add(self):
a = Xector([1, 2, 3])
b = Xector([4, 5, 6])
c = a + b
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
def test_implicit_mul(self):
a = Xector([1, 2, 3])
b = Xector([2, 2, 2])
c = a * b
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
def test_scalar_broadcast(self):
a = Xector([1, 2, 3])
c = a + 10
np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])
def test_scalar_broadcast_rmul(self):
a = Xector([1, 2, 3])
c = 2 * a
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
class TestAlphaOperations:
"""Test α (element-wise) operations."""
def test_alpha_add(self):
a = Xector([1, 2, 3])
b = Xector([4, 5, 6])
c = alpha_add(a, b)
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
def test_alpha_add_multi(self):
a = Xector([1, 2, 3])
b = Xector([1, 1, 1])
c = Xector([10, 10, 10])
d = alpha_add(a, b, c)
np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])
def test_alpha_mul_scalar(self):
a = Xector([1, 2, 3])
c = alpha_mul(a, 2)
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
def test_alpha_sqrt(self):
a = Xector([1, 4, 9, 16])
c = alpha_sqrt(a)
np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])
def test_alpha_clamp(self):
a = Xector([-5, 0, 5, 10, 15])
c = alpha_clamp(a, 0, 10)
np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])
def test_alpha_sin_cos(self):
a = Xector([0, np.pi/2, np.pi])
s = alpha_sin(a)
c = alpha_cos(a)
np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)
def test_alpha_sq(self):
a = Xector([1, 2, 3, 4])
c = alpha_sq(a)
np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])
def test_alpha_comparison(self):
a = Xector([1, 2, 3, 4])
b = Xector([2, 2, 2, 2])
lt = alpha_lt(a, b)
gt = alpha_gt(a, b)
eq = alpha_eq(a, b)
np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])
class TestBetaOperations:
"""Test β (reduction) operations."""
def test_beta_add(self):
a = Xector([1, 2, 3, 4])
assert beta_add(a) == 10
def test_beta_mul(self):
a = Xector([1, 2, 3, 4])
assert beta_mul(a) == 24
def test_beta_min_max(self):
a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
assert beta_min(a) == 1
assert beta_max(a) == 9
def test_beta_mean(self):
a = Xector([1, 2, 3, 4])
assert beta_mean(a) == 2.5
def test_beta_count(self):
a = Xector([1, 2, 3, 4, 5])
assert beta_count(a) == 5
class TestFrameConversion:
"""Test frame/xector conversion."""
def test_extract_channels(self):
# Create a 2x2 RGB frame
frame = np.array([
[[255, 0, 0], [0, 255, 0]],
[[0, 0, 255], [128, 128, 128]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
assert len(r) == 4
np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])
def test_rgb_roundtrip(self):
# Create a 2x2 RGB frame
frame = np.array([
[[100, 150, 200], [50, 75, 100]],
[[200, 100, 50], [25, 50, 75]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
reconstructed = xector_rgb(r, g, b)
np.testing.assert_array_equal(reconstructed, frame)
def test_modify_and_reconstruct(self):
frame = np.array([
[[100, 100, 100], [100, 100, 100]],
[[100, 100, 100], [100, 100, 100]]
], dtype=np.uint8)
r = xector_red(frame)
g = xector_green(frame)
b = xector_blue(frame)
# Double red channel
r_doubled = r * 2
result = xector_rgb(r_doubled, g, b)
# Red should be 200, others unchanged
assert result[0, 0, 0] == 200
assert result[0, 0, 1] == 100
assert result[0, 0, 2] == 100
class TestCoordinates:
"""Test coordinate generation."""
def test_x_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
x = xector_x_coords(frame)
# Should be [0,1,2, 0,1,2] (x coords repeated for each row)
np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])
def test_y_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
y = xector_y_coords(frame)
# Should be [0,0,0, 1,1,1] (y coords for each pixel)
np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])
def test_normalized_coords(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8)
x = xector_x_norm(frame)
y = xector_y_norm(frame)
# x should go 0 to 1 across width
assert x.to_numpy()[0] == 0
assert x.to_numpy()[2] == 1
# y should go 0 to 1 down height
assert y.to_numpy()[0] == 0
assert y.to_numpy()[3] == 1
class TestConditional:
"""Test conditional operations."""
def test_where(self):
cond = Xector([True, False, True, False])
true_val = Xector([1, 1, 1, 1])
false_val = Xector([0, 0, 0, 0])
result = xector_where(cond, true_val, false_val)
np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])
def test_where_with_comparison(self):
a = Xector([1, 5, 3, 7])
threshold = 4
# Elements > 4 become 255, others become 0
result = xector_where(alpha_gt(a, threshold), 255, 0)
np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])
def test_fill(self):
frame = np.zeros((2, 3, 3), dtype=np.uint8)
x = xector_fill(42, frame)
assert len(x) == 6
assert all(v == 42 for v in x.to_numpy())
def test_zeros_ones(self):
frame = np.zeros((2, 2, 3), dtype=np.uint8)
z = xector_zeros(frame)
o = xector_ones(frame)
assert all(v == 0 for v in z.to_numpy())
assert all(v == 1 for v in o.to_numpy())
class TestInterpreterIntegration:
"""Test xector operations through the interpreter."""
def test_xector_vignette_effect(self):
from sexp_effects.interpreter import Interpreter
interp = Interpreter(minimal_primitives=True)
# Load the xector vignette effect
interp.load_effect('sexp_effects/effects/xector_vignette.sexp')
# Create a test frame (white)
frame = np.full((100, 100, 3), 255, dtype=np.uint8)
# Run effect
result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})
# Center should be brighter than corners
center = result[50, 50]
corner = result[0, 0]
assert center.mean() > corner.mean(), "Center should be brighter than corners"
# Corners should be darkened
assert corner.mean() < 255, "Corners should be darkened"
def test_implicit_elementwise(self):
"""Test that regular + works element-wise on xectors."""
from sexp_effects.interpreter import Interpreter
interp = Interpreter(minimal_primitives=True)
# Load xector primitives
from sexp_effects.primitive_libs.xector import PRIMITIVES
for name, fn in PRIMITIVES.items():
interp.global_env.set(name, fn)
# Parse and eval a simple xector expression
from sexp_effects.parser import parse
expr = parse('(+ (red frame) 10)')
# Create test frame
frame = np.full((2, 2, 3), 100, dtype=np.uint8)
interp.global_env.set('frame', frame)
result = interp.eval(expr)
# Should be a xector with values 110
assert is_xector(result)
np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
if __name__ == '__main__':
pytest.main([__file__, '-v'])