Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
477
path_registry.py
Normal file
477
path_registry.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Path Registry - Maps human-friendly paths to content-addressed IDs.
|
||||
|
||||
This module provides a bidirectional mapping between:
|
||||
- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp")
|
||||
- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes)
|
||||
|
||||
The registry is useful for:
|
||||
- Looking up effects by their friendly path name
|
||||
- Resolving cids back to the original path for debugging
|
||||
- Maintaining a stable naming scheme across cache updates
|
||||
|
||||
Storage:
|
||||
- Uses the existing item_types table in the database (path column)
|
||||
- Caches in Redis for fast lookups across distributed workers
|
||||
|
||||
The registry uses a system actor (@system@local) for global path mappings,
|
||||
allowing effects to be resolved by path without requiring user context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System actor for global path mappings (effects, recipes, analyzers)
|
||||
SYSTEM_ACTOR = "@system@local"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathEntry:
|
||||
"""A registered path with its content-addressed ID."""
|
||||
path: str # Human-friendly path (relative or normalized)
|
||||
cid: str # Content-addressed ID (IPFS CID or hash)
|
||||
content_type: str # Type: "effect", "recipe", "analyzer", etc.
|
||||
actor_id: str = SYSTEM_ACTOR # Owner (system for global)
|
||||
description: Optional[str] = None
|
||||
created_at: float = 0.0
|
||||
|
||||
|
||||
class PathRegistry:
|
||||
"""
|
||||
Registry for mapping paths to content-addressed IDs.
|
||||
|
||||
Uses the existing item_types table for persistence and Redis
|
||||
for fast lookups in distributed Celery workers.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client=None):
|
||||
self._redis = redis_client
|
||||
self._redis_path_to_cid_key = "artdag:path_to_cid"
|
||||
self._redis_cid_to_path_key = "artdag:cid_to_path"
|
||||
|
||||
def _run_async(self, coro):
|
||||
"""Run async coroutine from sync context."""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import threading
|
||||
result = [None]
|
||||
error = [None]
|
||||
|
||||
def run_in_thread():
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
result[0] = new_loop.run_until_complete(coro)
|
||||
finally:
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
error[0] = e
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join(timeout=30)
|
||||
if error[0]:
|
||||
raise error[0]
|
||||
return result[0]
|
||||
except RuntimeError:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
def _normalize_path(self, path: str) -> str:
|
||||
"""Normalize a path for consistent storage."""
|
||||
# Remove leading ./ or /
|
||||
path = path.lstrip('./')
|
||||
# Normalize separators
|
||||
path = path.replace('\\', '/')
|
||||
# Remove duplicate slashes
|
||||
while '//' in path:
|
||||
path = path.replace('//', '/')
|
||||
return path
|
||||
|
||||
def register(
|
||||
self,
|
||||
path: str,
|
||||
cid: str,
|
||||
content_type: str = "effect",
|
||||
actor_id: str = SYSTEM_ACTOR,
|
||||
description: Optional[str] = None,
|
||||
) -> PathEntry:
|
||||
"""
|
||||
Register a path -> cid mapping.
|
||||
|
||||
Args:
|
||||
path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp")
|
||||
cid: Content-addressed ID (IPFS CID or hash)
|
||||
content_type: Type of content ("effect", "recipe", "analyzer")
|
||||
actor_id: Owner (default: system for global mappings)
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
The created PathEntry
|
||||
"""
|
||||
norm_path = self._normalize_path(path)
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
entry = PathEntry(
|
||||
path=norm_path,
|
||||
cid=cid,
|
||||
content_type=content_type,
|
||||
actor_id=actor_id,
|
||||
description=description,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
# Store in database (item_types table)
|
||||
self._save_to_db(entry)
|
||||
|
||||
# Update Redis cache
|
||||
self._update_redis_cache(norm_path, cid)
|
||||
|
||||
logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...")
|
||||
return entry
|
||||
|
||||
def _save_to_db(self, entry: PathEntry):
|
||||
"""Save entry to database using item_types table."""
|
||||
import database
|
||||
|
||||
async def save():
|
||||
import asyncpg
|
||||
conn = await asyncpg.connect(database.DATABASE_URL)
|
||||
try:
|
||||
# Ensure cache_item exists
|
||||
await conn.execute(
|
||||
"INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING",
|
||||
entry.cid
|
||||
)
|
||||
# Insert or update item_type with path
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO item_types (cid, actor_id, type, path, description)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET
|
||||
description = COALESCE(EXCLUDED.description, item_types.description)
|
||||
""",
|
||||
entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description
|
||||
)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
try:
|
||||
self._run_async(save())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save path registry to DB: {e}")
|
||||
|
||||
def _update_redis_cache(self, path: str, cid: str):
|
||||
"""Update Redis cache with mapping."""
|
||||
if self._redis:
|
||||
try:
|
||||
self._redis.hset(self._redis_path_to_cid_key, path, cid)
|
||||
self._redis.hset(self._redis_cid_to_path_key, cid, path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update Redis cache: {e}")
|
||||
|
||||
def get_cid(self, path: str, content_type: str = None) -> Optional[str]:
|
||||
"""
|
||||
Get the cid for a path.
|
||||
|
||||
Args:
|
||||
path: Human-friendly path
|
||||
content_type: Optional type filter
|
||||
|
||||
Returns:
|
||||
The cid, or None if not found
|
||||
"""
|
||||
norm_path = self._normalize_path(path)
|
||||
|
||||
# Try Redis first (fast path)
|
||||
if self._redis:
|
||||
try:
|
||||
val = self._redis.hget(self._redis_path_to_cid_key, norm_path)
|
||||
if val:
|
||||
return val.decode() if isinstance(val, bytes) else val
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis lookup failed: {e}")
|
||||
|
||||
# Fall back to database
|
||||
return self._get_cid_from_db(norm_path, content_type)
|
||||
|
||||
def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]:
|
||||
"""Get cid from database using item_types table."""
|
||||
import database
|
||||
|
||||
async def get():
|
||||
import asyncpg
|
||||
conn = await asyncpg.connect(database.DATABASE_URL)
|
||||
try:
|
||||
if content_type:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT cid FROM item_types WHERE path = $1 AND type = $2",
|
||||
path, content_type
|
||||
)
|
||||
else:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT cid FROM item_types WHERE path = $1",
|
||||
path
|
||||
)
|
||||
return row["cid"] if row else None
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
try:
|
||||
result = self._run_async(get())
|
||||
# Update Redis cache if found
|
||||
if result and self._redis:
|
||||
self._update_redis_cache(path, result)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get from DB: {e}")
|
||||
return None
|
||||
|
||||
def get_path(self, cid: str) -> Optional[str]:
|
||||
"""
|
||||
Get the path for a cid.
|
||||
|
||||
Args:
|
||||
cid: Content-addressed ID
|
||||
|
||||
Returns:
|
||||
The path, or None if not found
|
||||
"""
|
||||
# Try Redis first
|
||||
if self._redis:
|
||||
try:
|
||||
val = self._redis.hget(self._redis_cid_to_path_key, cid)
|
||||
if val:
|
||||
return val.decode() if isinstance(val, bytes) else val
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis lookup failed: {e}")
|
||||
|
||||
# Fall back to database
|
||||
return self._get_path_from_db(cid)
|
||||
|
||||
def _get_path_from_db(self, cid: str) -> Optional[str]:
|
||||
"""Get path from database using item_types table."""
|
||||
import database
|
||||
|
||||
async def get():
|
||||
import asyncpg
|
||||
conn = await asyncpg.connect(database.DATABASE_URL)
|
||||
try:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1",
|
||||
cid
|
||||
)
|
||||
return row["path"] if row else None
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
try:
|
||||
result = self._run_async(get())
|
||||
# Update Redis cache if found
|
||||
if result and self._redis:
|
||||
self._update_redis_cache(result, cid)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get from DB: {e}")
|
||||
return None
|
||||
|
||||
def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]:
|
||||
"""
|
||||
List all entries of a given type.
|
||||
|
||||
Args:
|
||||
content_type: Type to filter by ("effect", "recipe", etc.)
|
||||
actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global)
|
||||
|
||||
Returns:
|
||||
List of PathEntry objects
|
||||
"""
|
||||
import database
|
||||
|
||||
async def list_entries():
|
||||
import asyncpg
|
||||
conn = await asyncpg.connect(database.DATABASE_URL)
|
||||
try:
|
||||
if actor_id:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT cid, path, type, actor_id, description,
|
||||
EXTRACT(EPOCH FROM created_at) as created_at
|
||||
FROM item_types
|
||||
WHERE type = $1 AND actor_id = $2 AND path IS NOT NULL
|
||||
ORDER BY path
|
||||
""",
|
||||
content_type, actor_id
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT cid, path, type, actor_id, description,
|
||||
EXTRACT(EPOCH FROM created_at) as created_at
|
||||
FROM item_types
|
||||
WHERE type = $1 AND path IS NOT NULL
|
||||
ORDER BY path
|
||||
""",
|
||||
content_type
|
||||
)
|
||||
return [
|
||||
PathEntry(
|
||||
path=row["path"],
|
||||
cid=row["cid"],
|
||||
content_type=row["type"],
|
||||
actor_id=row["actor_id"],
|
||||
description=row["description"],
|
||||
created_at=row["created_at"] or 0,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
try:
|
||||
return self._run_async(list_entries())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list from DB: {e}")
|
||||
return []
|
||||
|
||||
def delete(self, path: str, content_type: str = None) -> bool:
|
||||
"""
|
||||
Delete a path registration.
|
||||
|
||||
Args:
|
||||
path: The path to delete
|
||||
content_type: Optional type filter
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
norm_path = self._normalize_path(path)
|
||||
|
||||
# Get cid for Redis cleanup
|
||||
cid = self.get_cid(norm_path, content_type)
|
||||
|
||||
# Delete from database
|
||||
deleted = self._delete_from_db(norm_path, content_type)
|
||||
|
||||
# Clean up Redis
|
||||
if deleted and cid and self._redis:
|
||||
try:
|
||||
self._redis.hdel(self._redis_path_to_cid_key, norm_path)
|
||||
self._redis.hdel(self._redis_cid_to_path_key, cid)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up Redis: {e}")
|
||||
|
||||
return deleted
|
||||
|
||||
def _delete_from_db(self, path: str, content_type: str = None) -> bool:
|
||||
"""Delete from database."""
|
||||
import database
|
||||
|
||||
async def delete():
|
||||
import asyncpg
|
||||
conn = await asyncpg.connect(database.DATABASE_URL)
|
||||
try:
|
||||
if content_type:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM item_types WHERE path = $1 AND type = $2",
|
||||
path, content_type
|
||||
)
|
||||
else:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM item_types WHERE path = $1",
|
||||
path
|
||||
)
|
||||
return "DELETE" in result
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
try:
|
||||
return self._run_async(delete())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete from DB: {e}")
|
||||
return False
|
||||
|
||||
def register_effect(
|
||||
self,
|
||||
path: str,
|
||||
cid: str,
|
||||
description: Optional[str] = None,
|
||||
) -> PathEntry:
|
||||
"""
|
||||
Convenience method to register an effect.
|
||||
|
||||
Args:
|
||||
path: Effect path (e.g., "effects/ascii_fx_zone.sexp")
|
||||
cid: IPFS CID of the effect file
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
The created PathEntry
|
||||
"""
|
||||
return self.register(
|
||||
path=path,
|
||||
cid=cid,
|
||||
content_type="effect",
|
||||
actor_id=SYSTEM_ACTOR,
|
||||
description=description,
|
||||
)
|
||||
|
||||
def get_effect_cid(self, path: str) -> Optional[str]:
|
||||
"""
|
||||
Get CID for an effect by path.
|
||||
|
||||
Args:
|
||||
path: Effect path
|
||||
|
||||
Returns:
|
||||
IPFS CID or None
|
||||
"""
|
||||
return self.get_cid(path, content_type="effect")
|
||||
|
||||
def list_effects(self) -> List[PathEntry]:
|
||||
"""List all registered effects."""
|
||||
return self.list_by_type("effect", actor_id=SYSTEM_ACTOR)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_registry: Optional[PathRegistry] = None
|
||||
|
||||
|
||||
def get_path_registry() -> PathRegistry:
|
||||
"""Get the singleton path registry instance."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
import redis
|
||||
from urllib.parse import urlparse
|
||||
|
||||
redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
|
||||
parsed = urlparse(redis_url)
|
||||
redis_client = redis.Redis(
|
||||
host=parsed.hostname or 'localhost',
|
||||
port=parsed.port or 6379,
|
||||
db=int(parsed.path.lstrip('/') or 0),
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5
|
||||
)
|
||||
|
||||
_registry = PathRegistry(redis_client=redis_client)
|
||||
return _registry
|
||||
|
||||
|
||||
def reset_path_registry():
|
||||
"""Reset the singleton (for testing)."""
|
||||
global _registry
|
||||
_registry = None
|
||||
206
sexp_effects/derived.sexp
Normal file
206
sexp_effects/derived.sexp
Normal file
@@ -0,0 +1,206 @@
|
||||
;; Derived Operations
|
||||
;;
|
||||
;; These are built from true primitives using S-expressions.
|
||||
;; Load with: (require "derived")
|
||||
|
||||
;; =============================================================================
|
||||
;; Math Helpers (derivable from where + basic ops)
|
||||
;; =============================================================================
|
||||
|
||||
;; Absolute value
|
||||
(define (abs x) (where (< x 0) (- x) x))
|
||||
|
||||
;; Minimum of two values
|
||||
(define (min2 a b) (where (< a b) a b))
|
||||
|
||||
;; Maximum of two values
|
||||
(define (max2 a b) (where (> a b) a b))
|
||||
|
||||
;; Clamp x to range [lo, hi]
|
||||
(define (clamp x lo hi) (max2 lo (min2 hi x)))
|
||||
|
||||
;; Square of x
|
||||
(define (sq x) (* x x))
|
||||
|
||||
;; Linear interpolation: a*(1-t) + b*t
|
||||
(define (lerp a b t) (+ (* a (- 1 t)) (* b t)))
|
||||
|
||||
;; Smooth interpolation between edges
|
||||
(define (smoothstep edge0 edge1 x)
|
||||
(let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1)))
|
||||
(* t (* t (- 3 (* 2 t))))))
|
||||
|
||||
;; =============================================================================
|
||||
;; Channel Shortcuts (derivable from channel primitive)
|
||||
;; =============================================================================
|
||||
|
||||
;; Extract red channel as xector
|
||||
(define (red frame) (channel frame 0))
|
||||
|
||||
;; Extract green channel as xector
|
||||
(define (green frame) (channel frame 1))
|
||||
|
||||
;; Extract blue channel as xector
|
||||
(define (blue frame) (channel frame 2))
|
||||
|
||||
;; Convert to grayscale xector (ITU-R BT.601)
|
||||
(define (gray frame)
|
||||
(+ (* (red frame) 0.299)
|
||||
(* (green frame) 0.587)
|
||||
(* (blue frame) 0.114)))
|
||||
|
||||
;; Alias for gray
|
||||
(define (luminance frame) (gray frame))
|
||||
|
||||
;; =============================================================================
|
||||
;; Coordinate Generators (derivable from iota + repeat/tile)
|
||||
;; =============================================================================
|
||||
|
||||
;; X coordinate for each pixel [0, width)
|
||||
(define (x-coords frame) (tile (iota (width frame)) (height frame)))
|
||||
|
||||
;; Y coordinate for each pixel [0, height)
|
||||
(define (y-coords frame) (repeat (iota (height frame)) (width frame)))
|
||||
|
||||
;; Normalized X coordinate [0, 1]
|
||||
(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1))))
|
||||
|
||||
;; Normalized Y coordinate [0, 1]
|
||||
(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1))))
|
||||
|
||||
;; Distance from frame center for each pixel
|
||||
(define (dist-from-center frame)
|
||||
(let* ((cx (/ (width frame) 2))
|
||||
(cy (/ (height frame) 2))
|
||||
(dx (- (x-coords frame) cx))
|
||||
(dy (- (y-coords frame) cy)))
|
||||
(sqrt (+ (sq dx) (sq dy)))))
|
||||
|
||||
;; Normalized distance from center [0, ~1]
|
||||
(define (dist-norm frame)
|
||||
(let ((d (dist-from-center frame)))
|
||||
(/ d (max2 1 (βmax d)))))
|
||||
|
||||
;; =============================================================================
|
||||
;; Cell/Grid Operations (derivable from floor + basic math)
|
||||
;; =============================================================================
|
||||
|
||||
;; Cell row index for each pixel
|
||||
(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size)))
|
||||
|
||||
;; Cell column index for each pixel
|
||||
(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size)))
|
||||
|
||||
;; Number of cell rows
|
||||
(define (num-rows frame cell-size) (floor (/ (height frame) cell-size)))
|
||||
|
||||
;; Number of cell columns
|
||||
(define (num-cols frame cell-size) (floor (/ (width frame) cell-size)))
|
||||
|
||||
;; Flat cell index for each pixel
|
||||
(define (cell-indices frame cell-size)
|
||||
(+ (* (cell-row frame cell-size) (num-cols frame cell-size))
|
||||
(cell-col frame cell-size)))
|
||||
|
||||
;; Total number of cells
|
||||
(define (num-cells frame cell-size)
|
||||
(* (num-rows frame cell-size) (num-cols frame cell-size)))
|
||||
|
||||
;; X position within cell [0, cell-size)
|
||||
(define (local-x frame cell-size) (mod (x-coords frame) cell-size))
|
||||
|
||||
;; Y position within cell [0, cell-size)
|
||||
(define (local-y frame cell-size) (mod (y-coords frame) cell-size))
|
||||
|
||||
;; Normalized X within cell [0, 1]
|
||||
(define (local-x-norm frame cell-size)
|
||||
(/ (local-x frame cell-size) (max2 1 (- cell-size 1))))
|
||||
|
||||
;; Normalized Y within cell [0, 1]
|
||||
(define (local-y-norm frame cell-size)
|
||||
(/ (local-y frame cell-size) (max2 1 (- cell-size 1))))
|
||||
|
||||
;; =============================================================================
|
||||
;; Fill Operations (derivable from iota)
|
||||
;; =============================================================================
|
||||
|
||||
;; Xector of n zeros
|
||||
(define (zeros n) (* (iota n) 0))
|
||||
|
||||
;; Xector of n ones
|
||||
(define (ones n) (+ (zeros n) 1))
|
||||
|
||||
;; Xector of n copies of val
|
||||
(define (fill val n) (+ (zeros n) val))
|
||||
|
||||
;; Xector of zeros matching x's length
|
||||
(define (zeros-like x) (* x 0))
|
||||
|
||||
;; Xector of ones matching x's length
|
||||
(define (ones-like x) (+ (zeros-like x) 1))
|
||||
|
||||
;; =============================================================================
|
||||
;; Pooling (derivable from group-reduce)
|
||||
;; =============================================================================
|
||||
|
||||
;; Pool a channel by cell index
|
||||
(define (pool-channel chan cell-idx num-cells)
|
||||
(group-reduce chan cell-idx num-cells "mean"))
|
||||
|
||||
;; Pool red channel to cells
|
||||
(define (pool-red frame cell-size)
|
||||
(pool-channel (red frame)
|
||||
(cell-indices frame cell-size)
|
||||
(num-cells frame cell-size)))
|
||||
|
||||
;; Pool green channel to cells
|
||||
(define (pool-green frame cell-size)
|
||||
(pool-channel (green frame)
|
||||
(cell-indices frame cell-size)
|
||||
(num-cells frame cell-size)))
|
||||
|
||||
;; Pool blue channel to cells
|
||||
(define (pool-blue frame cell-size)
|
||||
(pool-channel (blue frame)
|
||||
(cell-indices frame cell-size)
|
||||
(num-cells frame cell-size)))
|
||||
|
||||
;; Pool grayscale to cells
|
||||
(define (pool-gray frame cell-size)
|
||||
(pool-channel (gray frame)
|
||||
(cell-indices frame cell-size)
|
||||
(num-cells frame cell-size)))
|
||||
|
||||
;; =============================================================================
|
||||
;; Blending (derivable from math)
|
||||
;; =============================================================================
|
||||
|
||||
;; Additive blend
|
||||
(define (blend-add a b) (clamp (+ a b) 0 255))
|
||||
|
||||
;; Multiply blend (normalized)
|
||||
(define (blend-multiply a b) (* (/ a 255) b))
|
||||
|
||||
;; Screen blend
|
||||
(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 255 b))))
|
||||
|
||||
;; Overlay blend
|
||||
(define (blend-overlay a b)
|
||||
(where (< a 128)
|
||||
(* 2 (/ (* a b) 255))
|
||||
(- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255)))))
|
||||
|
||||
;; =============================================================================
|
||||
;; Simple Effects (derivable from primitives)
|
||||
;; =============================================================================
|
||||
|
||||
;; Invert a channel (255 - c)
|
||||
(define (invert-channel c) (- 255 c))
|
||||
|
||||
;; Binary threshold
|
||||
(define (threshold-channel c thresh) (where (> c thresh) 255 0))
|
||||
|
||||
;; Reduce to n levels
|
||||
(define (posterize-channel c levels)
|
||||
(let ((step (/ 255 (- levels 1))))
|
||||
(* (round (/ c step)) step)))
|
||||
@@ -5,7 +5,7 @@
|
||||
:params (
|
||||
(char_size :type int :default 8 :range [4 32])
|
||||
(alphabet :type string :default "standard")
|
||||
(color_mode :type string :default "color" :desc ""color", "mono", "invert", or any color name/hex")
|
||||
(color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex")
|
||||
(background_color :type string :default "black" :desc "background color name/hex")
|
||||
(invert_colors :type int :default 0 :desc "swap foreground and background colors")
|
||||
(contrast :type float :default 1.5 :range [1 3])
|
||||
|
||||
65
sexp_effects/effects/cell_pattern.sexp
Normal file
65
sexp_effects/effects/cell_pattern.sexp
Normal file
@@ -0,0 +1,65 @@
|
||||
;; Cell Pattern effect - custom patterns within cells
|
||||
;;
|
||||
;; Demonstrates building arbitrary per-cell visuals from primitives.
|
||||
;; Uses local coordinates within cells to draw patterns scaled by luminance.
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect cell_pattern
|
||||
:params (
|
||||
(cell-size :type int :default 16 :range [8 48] :desc "Cell size")
|
||||
(pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring")
|
||||
)
|
||||
(let* (
|
||||
;; Pool to get cell colors
|
||||
(pooled (pool-frame frame cell-size))
|
||||
(cell-r (nth pooled 0))
|
||||
(cell-g (nth pooled 1))
|
||||
(cell-b (nth pooled 2))
|
||||
(cell-lum (α/ (nth pooled 3) 255))
|
||||
|
||||
;; Cell indices for each pixel
|
||||
(cell-idx (cell-indices frame cell-size))
|
||||
|
||||
;; Look up cell values for each pixel
|
||||
(pix-r (gather cell-r cell-idx))
|
||||
(pix-g (gather cell-g cell-idx))
|
||||
(pix-b (gather cell-b cell-idx))
|
||||
(pix-lum (gather cell-lum cell-idx))
|
||||
|
||||
;; Local position within cell [0, 1]
|
||||
(lx (local-x-norm frame cell-size))
|
||||
(ly (local-y-norm frame cell-size))
|
||||
|
||||
;; Pattern mask based on pattern type
|
||||
(mask
|
||||
(cond
|
||||
;; Diagonal lines - thickness based on luminance
|
||||
((= pattern "diagonal")
|
||||
(let* ((diag (αmod (α+ lx ly) 0.25))
|
||||
(thickness (α* pix-lum 0.125)))
|
||||
(α< diag thickness)))
|
||||
|
||||
;; Cross pattern
|
||||
((= pattern "cross")
|
||||
(let* ((cx (αabs (α- lx 0.5)))
|
||||
(cy (αabs (α- ly 0.5)))
|
||||
(thickness (α* pix-lum 0.25)))
|
||||
(αor (α< cx thickness) (α< cy thickness))))
|
||||
|
||||
;; Ring pattern
|
||||
((= pattern "ring")
|
||||
(let* ((dx (α- lx 0.5))
|
||||
(dy (α- ly 0.5))
|
||||
(dist (αsqrt (α+ (α² dx) (α² dy))))
|
||||
(target (α* pix-lum 0.4))
|
||||
(thickness 0.05))
|
||||
(α< (αabs (α- dist target)) thickness)))
|
||||
|
||||
;; Default: solid
|
||||
(else (α> pix-lum 0)))))
|
||||
|
||||
;; Apply mask: show cell color where mask is true, black elsewhere
|
||||
(rgb (where mask pix-r 0)
|
||||
(where mask pix-g 0)
|
||||
(where mask pix-b 0))))
|
||||
@@ -6,10 +6,10 @@
|
||||
(num_echoes :type int :default 4 :range [1 20])
|
||||
(decay :type float :default 0.5 :range [0 1])
|
||||
)
|
||||
(let* ((buffer (state-get 'buffer (list)))
|
||||
(let* ((buffer (state-get "buffer" (list)))
|
||||
(new-buffer (take (cons frame buffer) (+ num_echoes 1))))
|
||||
(begin
|
||||
(state-set 'buffer new-buffer)
|
||||
(state-set "buffer" new-buffer)
|
||||
;; Blend frames with decay
|
||||
(if (< (length new-buffer) 2)
|
||||
frame
|
||||
|
||||
49
sexp_effects/effects/halftone.sexp
Normal file
49
sexp_effects/effects/halftone.sexp
Normal file
@@ -0,0 +1,49 @@
|
||||
;; Halftone/dot effect - built from primitive xector operations
|
||||
;;
|
||||
;; Uses:
|
||||
;; pool-frame - downsample to cell luminances
|
||||
;; cell-indices - which cell each pixel belongs to
|
||||
;; gather - look up cell value for each pixel
|
||||
;; local-x/y-norm - position within cell [0,1]
|
||||
;; where - conditional per-pixel
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect halftone
|
||||
:params (
|
||||
(cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells")
|
||||
(dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius")
|
||||
(invert :type bool :default false :desc "Invert (white dots on black)")
|
||||
)
|
||||
(let* (
|
||||
;; Pool frame to get luminance per cell
|
||||
(pooled (pool-frame frame cell-size))
|
||||
(cell-lum (nth pooled 3)) ; luminance is 4th element
|
||||
|
||||
;; For each output pixel, get its cell index
|
||||
(cell-idx (cell-indices frame cell-size))
|
||||
|
||||
;; Get cell luminance for each pixel
|
||||
(pixel-lum (α/ (gather cell-lum cell-idx) 255))
|
||||
|
||||
;; Position within cell, normalized to [-0.5, 0.5]
|
||||
(lx (α- (local-x-norm frame cell-size) 0.5))
|
||||
(ly (α- (local-y-norm frame cell-size) 0.5))
|
||||
|
||||
;; Distance from cell center (0 at center, ~0.7 at corners)
|
||||
(dist (αsqrt (α+ (α² lx) (α² ly))))
|
||||
|
||||
;; Radius based on luminance (brighter = bigger dot)
|
||||
(radius (α* (if invert (α- 1 pixel-lum) pixel-lum)
|
||||
(α* dot-scale 0.5)))
|
||||
|
||||
;; Is this pixel inside the dot?
|
||||
(inside (α< dist radius))
|
||||
|
||||
;; Output color
|
||||
(fg (if invert 255 0))
|
||||
(bg (if invert 0 255))
|
||||
(out (where inside fg bg)))
|
||||
|
||||
;; Grayscale output
|
||||
(rgb out out out)))
|
||||
30
sexp_effects/effects/mosaic.sexp
Normal file
30
sexp_effects/effects/mosaic.sexp
Normal file
@@ -0,0 +1,30 @@
|
||||
;; Mosaic effect - built from primitive xector operations
|
||||
;;
|
||||
;; Uses:
|
||||
;; pool-frame - downsample to cell averages
|
||||
;; cell-indices - which cell each pixel belongs to
|
||||
;; gather - look up cell value for each pixel
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect mosaic
|
||||
:params (
|
||||
(cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells")
|
||||
)
|
||||
(let* (
|
||||
;; Pool frame to get average color per cell (returns r,g,b,lum xectors)
|
||||
(pooled (pool-frame frame cell-size))
|
||||
(cell-r (nth pooled 0))
|
||||
(cell-g (nth pooled 1))
|
||||
(cell-b (nth pooled 2))
|
||||
|
||||
;; For each output pixel, get its cell index
|
||||
(cell-idx (cell-indices frame cell-size))
|
||||
|
||||
;; Gather: look up cell color for each pixel
|
||||
(out-r (gather cell-r cell-idx))
|
||||
(out-g (gather cell-g cell-idx))
|
||||
(out-b (gather cell-b cell-idx)))
|
||||
|
||||
;; Reconstruct frame
|
||||
(rgb out-r out-g out-b)))
|
||||
@@ -5,9 +5,9 @@
|
||||
:params (
|
||||
(thickness :type int :default 2 :range [1 10])
|
||||
(threshold :type int :default 100 :range [20 300])
|
||||
(color :type list :default (list 0 0 0)
|
||||
(color :type list :default (list 0 0 0))
|
||||
(fill_mode :type string :default "original")
|
||||
)
|
||||
(fill_mode "original"))
|
||||
(let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold))
|
||||
(dilated (if (> thickness 1)
|
||||
(dilate edge-img thickness)
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
:params (
|
||||
(frame_rate :type int :default 12 :range [1 60])
|
||||
)
|
||||
(let* ((held (state-get 'held nil))
|
||||
(held-until (state-get 'held-until 0))
|
||||
(let* ((held (state-get "held" nil))
|
||||
(held-until (state-get "held-until" 0))
|
||||
(frame-duration (/ 1 frame_rate)))
|
||||
(if (or (core:is-nil held) (>= t held-until))
|
||||
(begin
|
||||
(state-set 'held (copy frame))
|
||||
(state-set 'held-until (+ t frame-duration))
|
||||
(state-set "held" (copy frame))
|
||||
(state-set "held-until" (+ t frame-duration))
|
||||
frame)
|
||||
held)))
|
||||
|
||||
@@ -5,16 +5,16 @@
|
||||
:params (
|
||||
(persistence :type float :default 0.8 :range [0 0.99])
|
||||
)
|
||||
(let* ((buffer (state-get 'buffer nil))
|
||||
(let* ((buffer (state-get "buffer" nil))
|
||||
(current frame))
|
||||
(if (= buffer nil)
|
||||
(begin
|
||||
(state-set 'buffer (copy frame))
|
||||
(state-set "buffer" (copy frame))
|
||||
frame)
|
||||
(let* ((faded (blending:blend-images buffer
|
||||
(make-image (image:width frame) (image:height frame) (list 0 0 0))
|
||||
(- 1 persistence)))
|
||||
(result (blending:blend-mode faded current "lighten")))
|
||||
(begin
|
||||
(state-set 'buffer result)
|
||||
(state-set "buffer" result)
|
||||
result)))))
|
||||
|
||||
44
sexp_effects/effects/xector_feathered_blend.sexp
Normal file
44
sexp_effects/effects/xector_feathered_blend.sexp
Normal file
@@ -0,0 +1,44 @@
|
||||
;; Feathered blend - blend two same-size frames with distance-based falloff
|
||||
;; Center shows overlay, edges show background, with smooth transition
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect xector_feathered_blend
|
||||
:params (
|
||||
(inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)")
|
||||
(fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)")
|
||||
(overlay :type frame :default nil :desc "Frame to blend in center")
|
||||
)
|
||||
(let* (
|
||||
;; Get normalized distance from center (0 at center, ~1 at corners)
|
||||
(dist (dist-from-center frame))
|
||||
(max-dist (βmax dist))
|
||||
(dist-norm (α/ dist max-dist))
|
||||
|
||||
;; Calculate blend factor:
|
||||
;; - 1.0 when dist-norm < inner-radius (fully overlay)
|
||||
;; - 0.0 when dist-norm > inner-radius + fade-width (fully background)
|
||||
;; - linear ramp between
|
||||
(t (α/ (α- dist-norm inner-radius) fade-width))
|
||||
(blend (α- 1 (αclamp t 0 1)))
|
||||
(inv-blend (α- 1 blend))
|
||||
|
||||
;; Background channels
|
||||
(bg-r (red frame))
|
||||
(bg-g (green frame))
|
||||
(bg-b (blue frame)))
|
||||
|
||||
(if (nil? overlay)
|
||||
;; No overlay - visualize the blend mask
|
||||
(let ((vis (α* blend 255)))
|
||||
(rgb vis vis vis))
|
||||
|
||||
;; Blend overlay with background using the mask
|
||||
(let* ((ov-r (red overlay))
|
||||
(ov-g (green overlay))
|
||||
(ov-b (blue overlay))
|
||||
;; lerp: bg * (1-blend) + overlay * blend
|
||||
(r-out (α+ (α* bg-r inv-blend) (α* ov-r blend)))
|
||||
(g-out (α+ (α* bg-g inv-blend) (α* ov-g blend)))
|
||||
(b-out (α+ (α* bg-b inv-blend) (α* ov-b blend))))
|
||||
(rgb r-out g-out b-out)))))
|
||||
34
sexp_effects/effects/xector_grain.sexp
Normal file
34
sexp_effects/effects/xector_grain.sexp
Normal file
@@ -0,0 +1,34 @@
|
||||
;; Film grain effect using xector operations
|
||||
;; Demonstrates random xectors and mixing scalar/xector math
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect xector_grain
|
||||
:params (
|
||||
(intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity")
|
||||
(colored :type bool :default false :desc "Use colored grain")
|
||||
)
|
||||
(let* (
|
||||
;; Extract channels
|
||||
(r (red frame))
|
||||
(g (green frame))
|
||||
(b (blue frame))
|
||||
|
||||
;; Generate noise xector(s)
|
||||
;; randn-x generates normal distribution noise
|
||||
(grain-amount (* intensity 50)))
|
||||
|
||||
(if colored
|
||||
;; Colored grain: different noise per channel
|
||||
(let* ((nr (randn-x frame 0 grain-amount))
|
||||
(ng (randn-x frame 0 grain-amount))
|
||||
(nb (randn-x frame 0 grain-amount)))
|
||||
(rgb (αclamp (α+ r nr) 0 255)
|
||||
(αclamp (α+ g ng) 0 255)
|
||||
(αclamp (α+ b nb) 0 255)))
|
||||
|
||||
;; Monochrome grain: same noise for all channels
|
||||
(let ((n (randn-x frame 0 grain-amount)))
|
||||
(rgb (αclamp (α+ r n) 0 255)
|
||||
(αclamp (α+ g n) 0 255)
|
||||
(αclamp (α+ b n) 0 255))))))
|
||||
57
sexp_effects/effects/xector_inset_blend.sexp
Normal file
57
sexp_effects/effects/xector_inset_blend.sexp
Normal file
@@ -0,0 +1,57 @@
|
||||
;; Inset blend - fade a smaller frame into a larger background
|
||||
;; Uses distance-based alpha for smooth transition (no hard edges)
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect xector_inset_blend
|
||||
:params (
|
||||
(x :type int :default 0 :desc "X position of inset")
|
||||
(y :type int :default 0 :desc "Y position of inset")
|
||||
(fade-width :type int :default 50 :desc "Width of fade region in pixels")
|
||||
(overlay :type frame :default nil :desc "The smaller frame to inset")
|
||||
)
|
||||
(let* (
|
||||
;; Get dimensions
|
||||
(bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height
|
||||
(bg-w bg-h) ;; placeholder
|
||||
|
||||
;; For now, create a simple centered circular blend
|
||||
;; Distance from center of overlay position
|
||||
(cx (+ x (/ (- bg-w (* 2 x)) 2)))
|
||||
(cy (+ y (/ (- bg-h (* 2 y)) 2)))
|
||||
|
||||
;; Get coordinates as xectors
|
||||
(px (x-coords frame))
|
||||
(py (y-coords frame))
|
||||
|
||||
;; Distance from center
|
||||
(dx (α- px cx))
|
||||
(dy (α- py cy))
|
||||
(dist (αsqrt (α+ (α* dx dx) (α* dy dy))))
|
||||
|
||||
;; Inner radius (fully overlay) and outer radius (fully background)
|
||||
(inner-r (- (/ bg-w 2) x fade-width))
|
||||
(outer-r (- (/ bg-w 2) x))
|
||||
|
||||
;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between
|
||||
(t (α/ (α- dist inner-r) fade-width))
|
||||
(blend (α- 1 (αclamp t 0 1)))
|
||||
|
||||
;; Extract channels from both frames
|
||||
(bg-r (red frame))
|
||||
(bg-g (green frame))
|
||||
(bg-b (blue frame)))
|
||||
|
||||
;; If overlay provided, blend it
|
||||
(if overlay
|
||||
(let* ((ov-r (red overlay))
|
||||
(ov-g (green overlay))
|
||||
(ov-b (blue overlay))
|
||||
;; Linear blend: result = bg * (1-blend) + overlay * blend
|
||||
(r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend)))
|
||||
(g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend)))
|
||||
(b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend))))
|
||||
(rgb r-out g-out b-out))
|
||||
;; No overlay - just show the blend mask for debugging
|
||||
(let ((mask-vis (α* blend 255)))
|
||||
(rgb mask-vis mask-vis mask-vis)))))
|
||||
27
sexp_effects/effects/xector_threshold.sexp
Normal file
27
sexp_effects/effects/xector_threshold.sexp
Normal file
@@ -0,0 +1,27 @@
|
||||
;; Threshold effect using xector operations
|
||||
;; Demonstrates where (conditional select) and β (reduction) for normalization
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect xector_threshold
|
||||
:params (
|
||||
(threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)")
|
||||
(invert :type bool :default false :desc "Invert the threshold")
|
||||
)
|
||||
(let* (
|
||||
;; Get grayscale luminance as xector
|
||||
(luma (gray frame))
|
||||
|
||||
;; Normalize to 0-1 range
|
||||
(luma-norm (α/ luma 255))
|
||||
|
||||
;; Create boolean mask: pixels above threshold
|
||||
(mask (if invert
|
||||
(α< luma-norm threshold)
|
||||
(α>= luma-norm threshold)))
|
||||
|
||||
;; Use where to select: white (255) if above threshold, black (0) if below
|
||||
(out (where mask 255 0)))
|
||||
|
||||
;; Output as grayscale (same value for R, G, B)
|
||||
(rgb out out out)))
|
||||
36
sexp_effects/effects/xector_vignette.sexp
Normal file
36
sexp_effects/effects/xector_vignette.sexp
Normal file
@@ -0,0 +1,36 @@
|
||||
;; Vignette effect using xector operations
|
||||
;; Demonstrates α (element-wise) and β (reduction) patterns
|
||||
|
||||
(require-primitives "xector")
|
||||
|
||||
(define-effect xector_vignette
|
||||
:params (
|
||||
(strength :type float :default 0.5 :range [0 1])
|
||||
(radius :type float :default 1.0 :range [0.5 2])
|
||||
)
|
||||
(let* (
|
||||
;; Get normalized distance from center for each pixel
|
||||
(dist (dist-from-center frame))
|
||||
|
||||
;; Calculate max distance (corner distance)
|
||||
(max-dist (* (βmax dist) radius))
|
||||
|
||||
;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength)
|
||||
;; Using explicit α operators
|
||||
(factor (α- 1 (α* (α/ dist max-dist) strength)))
|
||||
|
||||
;; Clamp factor to [0, 1]
|
||||
(factor (αclamp factor 0 1))
|
||||
|
||||
;; Extract channels as xectors
|
||||
(r (red frame))
|
||||
(g (green frame))
|
||||
(b (blue frame))
|
||||
|
||||
;; Apply factor to each channel (implicit element-wise via Xector operators)
|
||||
(r-out (* r factor))
|
||||
(g-out (* g factor))
|
||||
(b-out (* b factor)))
|
||||
|
||||
;; Combine back to frame
|
||||
(rgb r-out g-out b-out)))
|
||||
@@ -156,11 +156,21 @@ class Interpreter:
|
||||
if form == 'define':
|
||||
name = expr[1]
|
||||
if _is_symbol(name):
|
||||
# Simple define: (define name value)
|
||||
value = self.eval(expr[2], env)
|
||||
self.global_env.set(name.name, value)
|
||||
return value
|
||||
elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]):
|
||||
# Function define: (define (fn-name args...) body)
|
||||
# Desugars to: (define fn-name (lambda (args...) body))
|
||||
fn_name = name[0].name
|
||||
params = [p.name if _is_symbol(p) else p for p in name[1:]]
|
||||
body = expr[2]
|
||||
fn = Lambda(params, body, env)
|
||||
self.global_env.set(fn_name, fn)
|
||||
return fn
|
||||
else:
|
||||
raise SyntaxError(f"define requires symbol, got {name}")
|
||||
raise SyntaxError(f"define requires symbol or (name args...), got {name}")
|
||||
|
||||
# Define-effect
|
||||
if form == 'define-effect':
|
||||
@@ -276,6 +286,10 @@ class Interpreter:
|
||||
if form == 'require-primitives':
|
||||
return self._eval_require_primitives(expr, env)
|
||||
|
||||
# require - load .sexp file into current scope
|
||||
if form == 'require':
|
||||
return self._eval_require(expr, env)
|
||||
|
||||
# Function call
|
||||
fn = self.eval(head, env)
|
||||
args = [self.eval(arg, env) for arg in expr[1:]]
|
||||
@@ -488,6 +502,61 @@ class Interpreter:
|
||||
from .primitive_libs import load_primitive_library
|
||||
return load_primitive_library(name, path)
|
||||
|
||||
def _eval_require(self, expr: Any, env: Environment) -> Any:
|
||||
"""
|
||||
Evaluate require: load a .sexp file and evaluate its definitions.
|
||||
|
||||
Syntax:
|
||||
(require "derived") ; loads derived.sexp from sexp_effects/
|
||||
(require "path/to/file.sexp") ; loads from explicit path
|
||||
|
||||
Definitions from the file are added to the current environment.
|
||||
"""
|
||||
for lib_expr in expr[1:]:
|
||||
if _is_symbol(lib_expr):
|
||||
lib_name = lib_expr.name
|
||||
else:
|
||||
lib_name = lib_expr
|
||||
|
||||
# Find the .sexp file
|
||||
sexp_path = self._find_sexp_file(lib_name)
|
||||
if sexp_path is None:
|
||||
raise ValueError(f"Cannot find sexp file: {lib_name}")
|
||||
|
||||
# Parse and evaluate the file
|
||||
content = parse_file(sexp_path)
|
||||
|
||||
# Evaluate all top-level expressions
|
||||
if isinstance(content, list) and content and isinstance(content[0], list):
|
||||
for e in content:
|
||||
self.eval(e, env)
|
||||
else:
|
||||
self.eval(content, env)
|
||||
|
||||
return None
|
||||
|
||||
def _find_sexp_file(self, name: str) -> Optional[str]:
|
||||
"""Find a .sexp file by name."""
|
||||
# Try various locations
|
||||
candidates = [
|
||||
# Explicit path
|
||||
name,
|
||||
name + '.sexp',
|
||||
# In sexp_effects directory
|
||||
Path(__file__).parent / f'{name}.sexp',
|
||||
Path(__file__).parent / name,
|
||||
# In effects directory
|
||||
Path(__file__).parent / 'effects' / f'{name}.sexp',
|
||||
Path(__file__).parent / 'effects' / name,
|
||||
]
|
||||
|
||||
for path in candidates:
|
||||
p = Path(path) if not isinstance(path, Path) else path
|
||||
if p.exists() and p.is_file():
|
||||
return str(p)
|
||||
|
||||
return None
|
||||
|
||||
def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any:
|
||||
"""
|
||||
Evaluate ascii-fx-zone special form.
|
||||
@@ -876,8 +945,8 @@ class Interpreter:
|
||||
for pname, pdefault in effect.params.items():
|
||||
value = params.get(pname)
|
||||
if value is None:
|
||||
# Evaluate default if it's an expression (list)
|
||||
if isinstance(pdefault, list):
|
||||
# Evaluate default if it's an expression (list) or a symbol (like 'nil')
|
||||
if isinstance(pdefault, list) or _is_symbol(pdefault):
|
||||
value = self.eval(pdefault, env)
|
||||
else:
|
||||
value = pdefault
|
||||
|
||||
@@ -71,7 +71,8 @@ class Tokenizer:
|
||||
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
|
||||
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
|
||||
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
|
||||
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*')
|
||||
# Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations
|
||||
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*')
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
@@ -1,126 +1,680 @@
|
||||
"""
|
||||
Drawing Primitives Library
|
||||
|
||||
Draw shapes, text, and characters on images.
|
||||
Draw shapes, text, and characters on images with sophisticated text handling.
|
||||
|
||||
Text Features:
|
||||
- Font loading from files or system fonts
|
||||
- Text measurement and fitting
|
||||
- Alignment (left/center/right, top/middle/bottom)
|
||||
- Opacity for fade effects
|
||||
- Multi-line text support
|
||||
- Shadow and outline effects
|
||||
"""
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
import glob as glob_module
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
|
||||
# Default font (will be loaded lazily)
|
||||
_default_font = None
|
||||
# =============================================================================
|
||||
# Font Management
|
||||
# =============================================================================
|
||||
|
||||
# Font cache: (path, size) -> font object
|
||||
_font_cache = {}
|
||||
|
||||
# Common system font directories
|
||||
FONT_DIRS = [
|
||||
"/usr/share/fonts",
|
||||
"/usr/local/share/fonts",
|
||||
"~/.fonts",
|
||||
"~/.local/share/fonts",
|
||||
"/System/Library/Fonts", # macOS
|
||||
"/Library/Fonts", # macOS
|
||||
"C:/Windows/Fonts", # Windows
|
||||
]
|
||||
|
||||
# Default fonts to try (in order of preference)
|
||||
DEFAULT_FONTS = [
|
||||
"DejaVuSans.ttf",
|
||||
"DejaVuSansMono.ttf",
|
||||
"Arial.ttf",
|
||||
"Helvetica.ttf",
|
||||
"FreeSans.ttf",
|
||||
"LiberationSans-Regular.ttf",
|
||||
]
|
||||
|
||||
|
||||
def _get_default_font(size=16):
|
||||
"""Get default font, creating if needed."""
|
||||
global _default_font
|
||||
if _default_font is None or _default_font.size != size:
|
||||
def _find_font_file(name: str) -> Optional[str]:
|
||||
"""Find a font file by name in system directories."""
|
||||
# If it's already a full path
|
||||
if os.path.isfile(name):
|
||||
return name
|
||||
|
||||
# Expand user paths
|
||||
expanded = os.path.expanduser(name)
|
||||
if os.path.isfile(expanded):
|
||||
return expanded
|
||||
|
||||
# Search in font directories
|
||||
for font_dir in FONT_DIRS:
|
||||
font_dir = os.path.expanduser(font_dir)
|
||||
if not os.path.isdir(font_dir):
|
||||
continue
|
||||
|
||||
# Direct match
|
||||
direct = os.path.join(font_dir, name)
|
||||
if os.path.isfile(direct):
|
||||
return direct
|
||||
|
||||
# Recursive search
|
||||
for root, dirs, files in os.walk(font_dir):
|
||||
for f in files:
|
||||
if f.lower() == name.lower():
|
||||
return os.path.join(root, f)
|
||||
# Also match without extension
|
||||
base = os.path.splitext(f)[0]
|
||||
if base.lower() == name.lower():
|
||||
return os.path.join(root, f)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont:
|
||||
"""Get a default font at the given size."""
|
||||
for font_name in DEFAULT_FONTS:
|
||||
path = _find_font_file(font_name)
|
||||
if path:
|
||||
try:
|
||||
return ImageFont.truetype(path, size)
|
||||
except:
|
||||
continue
|
||||
|
||||
# Last resort: PIL default
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont:
|
||||
"""
|
||||
Load a font by name or path.
|
||||
|
||||
(make-font "Arial" 32) ; system font by name
|
||||
(make-font "/path/to/font.ttf" 24) ; font file path
|
||||
(make-font "DejaVuSans" 48) ; searches common locations
|
||||
|
||||
Returns a font object for use with text primitives.
|
||||
"""
|
||||
size = int(size)
|
||||
|
||||
# Check cache
|
||||
cache_key = (name_or_path, size)
|
||||
if cache_key in _font_cache:
|
||||
return _font_cache[cache_key]
|
||||
|
||||
# Find the font file
|
||||
path = _find_font_file(name_or_path)
|
||||
if not path:
|
||||
raise FileNotFoundError(f"Font not found: {name_or_path}")
|
||||
|
||||
# Load and cache
|
||||
font = ImageFont.truetype(path, size)
|
||||
_font_cache[cache_key] = font
|
||||
return font
|
||||
|
||||
|
||||
def prim_list_fonts() -> List[str]:
|
||||
"""
|
||||
List available system fonts.
|
||||
|
||||
(list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...)
|
||||
|
||||
Returns list of font filenames found in system directories.
|
||||
"""
|
||||
fonts = set()
|
||||
|
||||
for font_dir in FONT_DIRS:
|
||||
font_dir = os.path.expanduser(font_dir)
|
||||
if not os.path.isdir(font_dir):
|
||||
continue
|
||||
|
||||
for root, dirs, files in os.walk(font_dir):
|
||||
for f in files:
|
||||
if f.lower().endswith(('.ttf', '.otf', '.ttc')):
|
||||
fonts.add(f)
|
||||
|
||||
return sorted(fonts)
|
||||
|
||||
|
||||
def prim_font_size(font: ImageFont.FreeTypeFont) -> int:
|
||||
"""
|
||||
Get the size of a font.
|
||||
|
||||
(font-size my-font) ; -> 24
|
||||
"""
|
||||
return font.size
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Text Measurement
|
||||
# =============================================================================
|
||||
|
||||
def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]:
|
||||
"""
|
||||
Measure text dimensions.
|
||||
|
||||
(text-size "Hello" my-font) ; -> (width height)
|
||||
(text-size "Hello" :font-size 32) ; -> (width height) with default font
|
||||
|
||||
For multi-line text, returns total bounding box.
|
||||
"""
|
||||
if font is None:
|
||||
font = _get_default_font(int(font_size))
|
||||
elif isinstance(font, (int, float)):
|
||||
font = _get_default_font(int(font))
|
||||
|
||||
# Create temporary image for measurement
|
||||
img = Image.new('RGB', (1, 1))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
bbox = draw.textbbox((0, 0), str(text), font=font)
|
||||
width = bbox[2] - bbox[0]
|
||||
height = bbox[3] - bbox[1]
|
||||
|
||||
return (width, height)
|
||||
|
||||
|
||||
def prim_text_metrics(font=None, font_size: int = 24) -> dict:
|
||||
"""
|
||||
Get font metrics.
|
||||
|
||||
(text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25}
|
||||
|
||||
Useful for precise text layout.
|
||||
"""
|
||||
if font is None:
|
||||
font = _get_default_font(int(font_size))
|
||||
elif isinstance(font, (int, float)):
|
||||
font = _get_default_font(int(font))
|
||||
|
||||
ascent, descent = font.getmetrics()
|
||||
return {
|
||||
'ascent': ascent,
|
||||
'descent': descent,
|
||||
'height': ascent + descent,
|
||||
'size': font.size,
|
||||
}
|
||||
|
||||
|
||||
def prim_fit_text_size(text: str, max_width: int, max_height: int,
|
||||
font_name: str = None, min_size: int = 8,
|
||||
max_size: int = 500) -> int:
|
||||
"""
|
||||
Calculate font size to fit text within bounds.
|
||||
|
||||
(fit-text-size "Hello World" 400 100) ; -> 48
|
||||
(fit-text-size "Title" 800 200 :font-name "Arial")
|
||||
|
||||
Returns the largest font size that fits within max_width x max_height.
|
||||
"""
|
||||
max_width = int(max_width)
|
||||
max_height = int(max_height)
|
||||
min_size = int(min_size)
|
||||
max_size = int(max_size)
|
||||
text = str(text)
|
||||
|
||||
# Binary search for optimal size
|
||||
best_size = min_size
|
||||
low, high = min_size, max_size
|
||||
|
||||
while low <= high:
|
||||
mid = (low + high) // 2
|
||||
|
||||
if font_name:
|
||||
try:
|
||||
font = prim_make_font(font_name, mid)
|
||||
except:
|
||||
font = _get_default_font(mid)
|
||||
else:
|
||||
font = _get_default_font(mid)
|
||||
|
||||
w, h = prim_text_size(text, font)
|
||||
|
||||
if w <= max_width and h <= max_height:
|
||||
best_size = mid
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid - 1
|
||||
|
||||
return best_size
|
||||
|
||||
|
||||
def prim_fit_font(text: str, max_width: int, max_height: int,
|
||||
font_name: str = None, min_size: int = 8,
|
||||
max_size: int = 500) -> ImageFont.FreeTypeFont:
|
||||
"""
|
||||
Create a font sized to fit text within bounds.
|
||||
|
||||
(fit-font "Hello World" 400 100) ; -> font object
|
||||
(fit-font "Title" 800 200 :font-name "Arial")
|
||||
|
||||
Returns a font object at the optimal size.
|
||||
"""
|
||||
size = prim_fit_text_size(text, max_width, max_height,
|
||||
font_name, min_size, max_size)
|
||||
|
||||
if font_name:
|
||||
try:
|
||||
_default_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size)
|
||||
return prim_make_font(font_name, size)
|
||||
except:
|
||||
_default_font = ImageFont.load_default()
|
||||
return _default_font
|
||||
pass
|
||||
|
||||
return _get_default_font(size)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Text Drawing
|
||||
# =============================================================================
|
||||
|
||||
def prim_text(img: np.ndarray, text: str,
|
||||
x: int = None, y: int = None,
|
||||
width: int = None, height: int = None,
|
||||
font=None, font_size: int = 24, font_name: str = None,
|
||||
color=None, opacity: float = 1.0,
|
||||
align: str = "left", valign: str = "top",
|
||||
fit: bool = False,
|
||||
shadow: bool = False, shadow_color=None, shadow_offset: int = 2,
|
||||
outline: bool = False, outline_color=None, outline_width: int = 1,
|
||||
line_spacing: float = 1.2) -> np.ndarray:
|
||||
"""
|
||||
Draw text with alignment, opacity, and effects.
|
||||
|
||||
Basic usage:
|
||||
(text frame "Hello" :x 100 :y 50)
|
||||
|
||||
Centered in frame:
|
||||
(text frame "Title" :align "center" :valign "middle")
|
||||
|
||||
Fit to box:
|
||||
(text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true)
|
||||
|
||||
With fade (for animations):
|
||||
(text frame "Fading" :x 100 :y 100 :opacity 0.5)
|
||||
|
||||
With effects:
|
||||
(text frame "Shadow" :x 100 :y 100 :shadow true)
|
||||
(text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0))
|
||||
|
||||
Args:
|
||||
img: Input frame
|
||||
text: Text to draw
|
||||
x, y: Position (if not specified, uses alignment in full frame)
|
||||
width, height: Bounding box (for fit and alignment within box)
|
||||
font: Font object from make-font
|
||||
font_size: Size if no font specified
|
||||
font_name: Font name to load
|
||||
color: RGB tuple (default white)
|
||||
opacity: 0.0 (invisible) to 1.0 (opaque) for fading
|
||||
align: "left", "center", "right"
|
||||
valign: "top", "middle", "bottom"
|
||||
fit: If true, auto-size font to fit in box
|
||||
shadow: Draw drop shadow
|
||||
shadow_color: Shadow color (default black)
|
||||
shadow_offset: Shadow offset in pixels
|
||||
outline: Draw text outline
|
||||
outline_color: Outline color (default black)
|
||||
outline_width: Outline thickness
|
||||
line_spacing: Multiplier for line height (for multi-line)
|
||||
|
||||
Returns:
|
||||
Frame with text drawn
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
text = str(text)
|
||||
|
||||
# Default colors
|
||||
if color is None:
|
||||
color = (255, 255, 255)
|
||||
else:
|
||||
color = tuple(int(c) for c in color)
|
||||
|
||||
if shadow_color is None:
|
||||
shadow_color = (0, 0, 0)
|
||||
else:
|
||||
shadow_color = tuple(int(c) for c in shadow_color)
|
||||
|
||||
if outline_color is None:
|
||||
outline_color = (0, 0, 0)
|
||||
else:
|
||||
outline_color = tuple(int(c) for c in outline_color)
|
||||
|
||||
# Determine bounding box
|
||||
if x is None:
|
||||
x = 0
|
||||
if width is None:
|
||||
width = w
|
||||
if y is None:
|
||||
y = 0
|
||||
if height is None:
|
||||
height = h
|
||||
|
||||
x, y = int(x), int(y)
|
||||
box_width = int(width) if width else w - x
|
||||
box_height = int(height) if height else h - y
|
||||
|
||||
# Get or create font
|
||||
if font is None:
|
||||
if fit:
|
||||
font = prim_fit_font(text, box_width, box_height, font_name)
|
||||
elif font_name:
|
||||
try:
|
||||
font = prim_make_font(font_name, int(font_size))
|
||||
except:
|
||||
font = _get_default_font(int(font_size))
|
||||
else:
|
||||
font = _get_default_font(int(font_size))
|
||||
|
||||
# Measure text
|
||||
text_w, text_h = prim_text_size(text, font)
|
||||
|
||||
# Calculate position based on alignment
|
||||
if align == "center":
|
||||
draw_x = x + (box_width - text_w) // 2
|
||||
elif align == "right":
|
||||
draw_x = x + box_width - text_w
|
||||
else: # left
|
||||
draw_x = x
|
||||
|
||||
if valign == "middle":
|
||||
draw_y = y + (box_height - text_h) // 2
|
||||
elif valign == "bottom":
|
||||
draw_y = y + box_height - text_h
|
||||
else: # top
|
||||
draw_y = y
|
||||
|
||||
# Create RGBA image for compositing with opacity
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
|
||||
# Create text layer with transparency
|
||||
text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(text_layer)
|
||||
|
||||
# Draw shadow first (if enabled)
|
||||
if shadow:
|
||||
shadow_x = draw_x + shadow_offset
|
||||
shadow_y = draw_y + shadow_offset
|
||||
shadow_rgba = shadow_color + (int(255 * opacity * 0.5),)
|
||||
draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font)
|
||||
|
||||
# Draw outline (if enabled)
|
||||
if outline:
|
||||
outline_rgba = outline_color + (int(255 * opacity),)
|
||||
ow = int(outline_width)
|
||||
for dx in range(-ow, ow + 1):
|
||||
for dy in range(-ow, ow + 1):
|
||||
if dx != 0 or dy != 0:
|
||||
draw.text((draw_x + dx, draw_y + dy), text,
|
||||
fill=outline_rgba, font=font)
|
||||
|
||||
# Draw main text
|
||||
text_rgba = color + (int(255 * opacity),)
|
||||
draw.text((draw_x, draw_y), text, fill=text_rgba, font=font)
|
||||
|
||||
# Composite
|
||||
result = Image.alpha_composite(pil_img, text_layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_text_box(img: np.ndarray, text: str,
|
||||
x: int, y: int, width: int, height: int,
|
||||
font=None, font_size: int = 24, font_name: str = None,
|
||||
color=None, opacity: float = 1.0,
|
||||
align: str = "center", valign: str = "middle",
|
||||
fit: bool = True,
|
||||
padding: int = 0,
|
||||
background=None, background_opacity: float = 0.5,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""
|
||||
Draw text fitted within a box, optionally with background.
|
||||
|
||||
(text-box frame "Title" 50 50 400 100)
|
||||
(text-box frame "Subtitle" 50 160 400 50
|
||||
:background (0 0 0) :background-opacity 0.7)
|
||||
|
||||
Convenience wrapper around text() for common box-with-text pattern.
|
||||
"""
|
||||
x, y = int(x), int(y)
|
||||
width, height = int(width), int(height)
|
||||
padding = int(padding)
|
||||
|
||||
result = img.copy()
|
||||
|
||||
# Draw background if specified
|
||||
if background is not None:
|
||||
bg_color = tuple(int(c) for c in background)
|
||||
|
||||
# Create background with opacity
|
||||
pil_img = Image.fromarray(result).convert('RGBA')
|
||||
bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
bg_draw = ImageDraw.Draw(bg_layer)
|
||||
bg_rgba = bg_color + (int(255 * background_opacity),)
|
||||
bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba)
|
||||
result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB'))
|
||||
|
||||
# Draw text within padded box
|
||||
return prim_text(result, text,
|
||||
x=x + padding, y=y + padding,
|
||||
width=width - 2 * padding, height=height - 2 * padding,
|
||||
font=font, font_size=font_size, font_name=font_name,
|
||||
color=color, opacity=opacity,
|
||||
align=align, valign=valign, fit=fit,
|
||||
**kwargs)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Legacy text functions (keep for compatibility)
|
||||
# =============================================================================
|
||||
|
||||
def prim_draw_char(img, char, x, y, font_size=16, color=None):
|
||||
"""Draw a single character at (x, y)."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
pil_img = Image.fromarray(img)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
font = _get_default_font(font_size)
|
||||
draw.text((x, y), char, fill=tuple(color), font=font)
|
||||
return np.array(pil_img)
|
||||
"""Draw a single character at (x, y). Legacy function."""
|
||||
return prim_text(img, str(char), x=int(x), y=int(y),
|
||||
font_size=int(font_size), color=color)
|
||||
|
||||
|
||||
def prim_draw_text(img, text, x, y, font_size=16, color=None):
|
||||
"""Draw text string at (x, y)."""
|
||||
"""Draw text string at (x, y). Legacy function."""
|
||||
return prim_text(img, str(text), x=int(x), y=int(y),
|
||||
font_size=int(font_size), color=color)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shape Drawing
|
||||
# =============================================================================
|
||||
|
||||
def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0):
|
||||
"""
|
||||
Fill a rectangle with color.
|
||||
|
||||
(fill-rect frame 10 10 100 50 (255 0 0))
|
||||
(fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5)
|
||||
"""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
pil_img = Image.fromarray(img)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
font = _get_default_font(font_size)
|
||||
draw.text((x, y), text, fill=tuple(color), font=font)
|
||||
return np.array(pil_img)
|
||||
|
||||
|
||||
def prim_fill_rect(img, x, y, w, h, color=None):
|
||||
"""Fill a rectangle with color."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
result[y:y+h, x:x+w] = color
|
||||
return result
|
||||
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
result[y:y+h, x:x+w] = color
|
||||
return result
|
||||
|
||||
# With opacity, use alpha compositing
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
draw.rectangle([x, y, x + w, y + h], fill=fill_rgba)
|
||||
result = Image.alpha_composite(pil_img, layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1):
|
||||
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0):
|
||||
"""Draw rectangle outline."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)),
|
||||
tuple(color), thickness)
|
||||
return result
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)),
|
||||
tuple(int(c) for c in color), int(thickness))
|
||||
return result
|
||||
|
||||
# With opacity
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
draw.rectangle([int(x), int(y), int(x+w), int(y+h)],
|
||||
outline=outline_rgba, width=int(thickness))
|
||||
result = Image.alpha_composite(pil_img, layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1):
|
||||
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0):
|
||||
"""Draw a line from (x1, y1) to (x2, y2)."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)),
|
||||
tuple(color), thickness)
|
||||
return result
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)),
|
||||
tuple(int(c) for c in color), int(thickness))
|
||||
return result
|
||||
|
||||
# With opacity
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
draw.line([(int(x1), int(y1)), (int(x2), int(y2))],
|
||||
fill=line_rgba, width=int(thickness))
|
||||
result = Image.alpha_composite(pil_img, layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, fill=False):
|
||||
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1,
|
||||
fill=False, opacity: float = 1.0):
|
||||
"""Draw a circle."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
t = -1 if fill else thickness
|
||||
cv2.circle(result, (int(cx), int(cy)), int(radius), tuple(color), t)
|
||||
return result
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
t = -1 if fill else int(thickness)
|
||||
cv2.circle(result, (int(cx), int(cy)), int(radius),
|
||||
tuple(int(c) for c in color), t)
|
||||
return result
|
||||
|
||||
# With opacity
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
cx, cy, r = int(cx), int(cy), int(radius)
|
||||
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
|
||||
if fill:
|
||||
draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba)
|
||||
else:
|
||||
draw.ellipse([cx - r, cy - r, cx + r, cy + r],
|
||||
outline=rgba, width=int(thickness))
|
||||
|
||||
result = Image.alpha_composite(pil_img, layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, thickness=1, fill=False):
|
||||
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None,
|
||||
thickness=1, fill=False, opacity: float = 1.0):
|
||||
"""Draw an ellipse."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
t = -1 if fill else thickness
|
||||
cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)),
|
||||
angle, 0, 360, tuple(color), t)
|
||||
return result
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
t = -1 if fill else int(thickness)
|
||||
cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)),
|
||||
float(angle), 0, 360, tuple(int(c) for c in color), t)
|
||||
return result
|
||||
|
||||
# With opacity (note: PIL doesn't support rotated ellipses easily)
|
||||
# Fall back to cv2 on a separate layer
|
||||
layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
|
||||
t = -1 if fill else int(thickness)
|
||||
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)),
|
||||
float(angle), 0, 360, rgba, t)
|
||||
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
pil_layer = Image.fromarray(layer)
|
||||
result = Image.alpha_composite(pil_img, pil_layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def prim_draw_polygon(img, points, color=None, thickness=1, fill=False):
|
||||
def prim_draw_polygon(img, points, color=None, thickness=1,
|
||||
fill=False, opacity: float = 1.0):
|
||||
"""Draw a polygon from list of [x, y] points."""
|
||||
if color is None:
|
||||
color = [255, 255, 255]
|
||||
|
||||
result = img.copy()
|
||||
pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
|
||||
if opacity >= 1.0:
|
||||
result = img.copy()
|
||||
pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
|
||||
if fill:
|
||||
cv2.fillPoly(result, [pts], tuple(int(c) for c in color))
|
||||
else:
|
||||
cv2.polylines(result, [pts], True,
|
||||
tuple(int(c) for c in color), int(thickness))
|
||||
return result
|
||||
|
||||
# With opacity
|
||||
pil_img = Image.fromarray(img).convert('RGBA')
|
||||
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
|
||||
pts_flat = [(int(p[0]), int(p[1])) for p in points]
|
||||
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||
|
||||
if fill:
|
||||
cv2.fillPoly(result, [pts], tuple(color))
|
||||
draw.polygon(pts_flat, fill=rgba)
|
||||
else:
|
||||
cv2.polylines(result, [pts], True, tuple(color), thickness)
|
||||
draw.polygon(pts_flat, outline=rgba, width=int(thickness))
|
||||
|
||||
return result
|
||||
result = Image.alpha_composite(pil_img, layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PRIMITIVES Export
|
||||
# =============================================================================
|
||||
|
||||
PRIMITIVES = {
|
||||
# Text
|
||||
# Font management
|
||||
'make-font': prim_make_font,
|
||||
'list-fonts': prim_list_fonts,
|
||||
'font-size': prim_font_size,
|
||||
|
||||
# Text measurement
|
||||
'text-size': prim_text_size,
|
||||
'text-metrics': prim_text_metrics,
|
||||
'fit-text-size': prim_fit_text_size,
|
||||
'fit-font': prim_fit_font,
|
||||
|
||||
# Text drawing
|
||||
'text': prim_text,
|
||||
'text-box': prim_text_box,
|
||||
|
||||
# Legacy text (compatibility)
|
||||
'draw-char': prim_draw_char,
|
||||
'draw-text': prim_draw_text,
|
||||
|
||||
|
||||
@@ -8,12 +8,18 @@ GPU Acceleration:
|
||||
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
||||
- Hardware video decoding (NVDEC) is used when available
|
||||
- Dramatically improves performance on GPU nodes
|
||||
|
||||
Async Prefetching:
|
||||
- Set STREAMING_PREFETCH=1 to enable background frame prefetching
|
||||
- Decodes upcoming frames while current frame is being processed
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import json
|
||||
import threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
@@ -28,6 +34,10 @@ except ImportError:
|
||||
# Disabled by default until all primitives support GPU frames
|
||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
|
||||
|
||||
# Async prefetch mode - decode frames in background thread
|
||||
PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1"
|
||||
PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10"))
|
||||
|
||||
# Check for hardware decode support (cached)
|
||||
_HWDEC_AVAILABLE = None
|
||||
|
||||
@@ -283,6 +293,122 @@ class VideoSource:
|
||||
self._proc = None
|
||||
|
||||
|
||||
class PrefetchingVideoSource:
|
||||
"""
|
||||
Video source with background prefetching for improved performance.
|
||||
|
||||
Wraps VideoSource and adds a background thread that pre-decodes
|
||||
upcoming frames while the main thread processes the current frame.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, fps: float = 30, buffer_size: int = None):
|
||||
self._source = VideoSource(path, fps)
|
||||
self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE
|
||||
self._buffer = {} # time -> frame
|
||||
self._buffer_lock = threading.Lock()
|
||||
self._prefetch_time = 0.0
|
||||
self._frame_time = 1.0 / fps
|
||||
self._stop_event = threading.Event()
|
||||
self._request_event = threading.Event()
|
||||
self._target_time = 0.0
|
||||
|
||||
# Start prefetch thread
|
||||
self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
import sys
|
||||
print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr)
|
||||
|
||||
def _prefetch_loop(self):
|
||||
"""Background thread that pre-reads frames."""
|
||||
while not self._stop_event.is_set():
|
||||
# Wait for work or timeout
|
||||
self._request_event.wait(timeout=0.01)
|
||||
self._request_event.clear()
|
||||
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
# Prefetch frames ahead of target time
|
||||
target = self._target_time
|
||||
with self._buffer_lock:
|
||||
# Clean old frames (more than 1 second behind)
|
||||
old_times = [t for t in self._buffer.keys() if t < target - 1.0]
|
||||
for t in old_times:
|
||||
del self._buffer[t]
|
||||
|
||||
# Count how many frames we have buffered ahead
|
||||
buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target)
|
||||
|
||||
# Prefetch if buffer not full
|
||||
if buffered_ahead < self._buffer_size:
|
||||
# Find next time to prefetch
|
||||
prefetch_t = target
|
||||
with self._buffer_lock:
|
||||
existing_times = set(self._buffer.keys())
|
||||
for _ in range(self._buffer_size):
|
||||
if prefetch_t not in existing_times:
|
||||
break
|
||||
prefetch_t += self._frame_time
|
||||
|
||||
# Read the frame (this is the slow part)
|
||||
try:
|
||||
frame = self._source.read_at(prefetch_t)
|
||||
with self._buffer_lock:
|
||||
self._buffer[prefetch_t] = frame
|
||||
except Exception as e:
|
||||
import sys
|
||||
print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr)
|
||||
|
||||
def read_at(self, t: float) -> np.ndarray:
|
||||
"""Read frame at specific time, using prefetch buffer if available."""
|
||||
self._target_time = t
|
||||
self._request_event.set() # Wake up prefetch thread
|
||||
|
||||
# Round to frame time for buffer lookup
|
||||
t_key = round(t / self._frame_time) * self._frame_time
|
||||
|
||||
# Check buffer first
|
||||
with self._buffer_lock:
|
||||
if t_key in self._buffer:
|
||||
return self._buffer[t_key]
|
||||
# Also check for close matches (within half frame time)
|
||||
for buf_t, frame in self._buffer.items():
|
||||
if abs(buf_t - t) < self._frame_time * 0.5:
|
||||
return frame
|
||||
|
||||
# Not in buffer - read directly (blocking)
|
||||
frame = self._source.read_at(t)
|
||||
|
||||
# Store in buffer
|
||||
with self._buffer_lock:
|
||||
self._buffer[t_key] = frame
|
||||
|
||||
return frame
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Read frame (uses last cached or t=0)."""
|
||||
return self.read_at(0)
|
||||
|
||||
def skip(self):
|
||||
"""No-op for seek-based reading."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self._source.size
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._source.path
|
||||
|
||||
def close(self):
|
||||
self._stop_event.set()
|
||||
self._request_event.set() # Wake up thread to exit
|
||||
self._thread.join(timeout=1.0)
|
||||
self._source.close()
|
||||
|
||||
|
||||
class AudioAnalyzer:
|
||||
"""Audio analyzer for energy and beat detection."""
|
||||
|
||||
@@ -394,7 +520,12 @@ class AudioAnalyzer:
|
||||
# === Primitives ===
|
||||
|
||||
def prim_make_video_source(path: str, fps: float = 30):
|
||||
"""Create a video source from a file path."""
|
||||
"""Create a video source from a file path.
|
||||
|
||||
Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default).
|
||||
"""
|
||||
if PREFETCH_ENABLED:
|
||||
return PrefetchingVideoSource(path, fps)
|
||||
return VideoSource(path, fps)
|
||||
|
||||
|
||||
|
||||
1382
sexp_effects/primitive_libs/xector.py
Normal file
1382
sexp_effects/primitive_libs/xector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -797,31 +797,63 @@ def prim_tan(x: float) -> float:
|
||||
return math.tan(x)
|
||||
|
||||
|
||||
def prim_atan2(y: float, x: float) -> float:
|
||||
def prim_atan2(y, x):
|
||||
if hasattr(y, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape)
|
||||
return math.atan2(y, x)
|
||||
|
||||
|
||||
def prim_sqrt(x: float) -> float:
|
||||
def prim_sqrt(x):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.sqrt(np.maximum(0, x._data)), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.sqrt(np.maximum(0, x))
|
||||
return math.sqrt(max(0, x))
|
||||
|
||||
|
||||
def prim_pow(x: float, y: float) -> float:
|
||||
def prim_pow(x, y):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
y_data = y._data if hasattr(y, '_data') else y
|
||||
return Xector(np.power(x._data, y_data), x._shape)
|
||||
return math.pow(x, y)
|
||||
|
||||
|
||||
def prim_abs(x: float) -> float:
|
||||
def prim_abs(x):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.abs(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.abs(x)
|
||||
return abs(x)
|
||||
|
||||
|
||||
def prim_floor(x: float) -> int:
|
||||
def prim_floor(x):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.floor(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.floor(x)
|
||||
return int(math.floor(x))
|
||||
|
||||
|
||||
def prim_ceil(x: float) -> int:
|
||||
def prim_ceil(x):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.ceil(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.ceil(x)
|
||||
return int(math.ceil(x))
|
||||
|
||||
|
||||
def prim_round(x: float) -> int:
|
||||
def prim_round(x):
|
||||
if hasattr(x, '_data'): # Xector
|
||||
from sexp_effects.primitive_libs.xector import Xector
|
||||
return Xector(np.round(x._data), x._shape)
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.round(x)
|
||||
return int(round(x))
|
||||
|
||||
|
||||
|
||||
860
streaming/jax_typography.py
Normal file
860
streaming/jax_typography.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
JAX Typography Primitives
|
||||
|
||||
Two approaches for text rendering, both compile to JAX/GPU:
|
||||
|
||||
## 1. TextStrip - Pixel-perfect static text
|
||||
Pre-render entire strings at compile time using PIL.
|
||||
Perfect sub-pixel anti-aliasing, exact match with PIL.
|
||||
Use for: static titles, labels, any text without per-character effects.
|
||||
|
||||
S-expression:
|
||||
(let ((strip (render-text-strip "Hello World" 48)))
|
||||
(place-text-strip frame strip x y :color white))
|
||||
|
||||
## 2. Glyph-by-glyph - Dynamic text effects
|
||||
Individual glyph placement for wave, arc, audio-reactive effects.
|
||||
Each character can have independent position, color, opacity.
|
||||
Note: slight anti-aliasing differences vs PIL due to integer positioning.
|
||||
|
||||
S-expression:
|
||||
; Wave text - y oscillates with character index
|
||||
(let ((glyphs (text-glyphs "Wavy" 48)))
|
||||
(first
|
||||
(fold glyphs (list frame 0)
|
||||
(lambda (acc g)
|
||||
(let ((frm (first acc))
|
||||
(cursor (second acc))
|
||||
(i (length acc))) ; approximate index
|
||||
(list
|
||||
(place-glyph frm (glyph-image g)
|
||||
(+ x cursor)
|
||||
(+ y (* amplitude (sin (* i frequency))))
|
||||
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||
white 1.0)
|
||||
(+ cursor (glyph-advance g))))))))
|
||||
|
||||
; Audio-reactive spacing
|
||||
(let ((glyphs (text-glyphs "Bass" 48))
|
||||
(bass (audio-band 0 200)))
|
||||
(first
|
||||
(fold glyphs (list frame 0)
|
||||
(lambda (acc g)
|
||||
(let ((frm (first acc))
|
||||
(cursor (second acc)))
|
||||
(list
|
||||
(place-glyph frm (glyph-image g)
|
||||
(+ x cursor) y
|
||||
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||
white 1.0)
|
||||
(+ cursor (glyph-advance g) (* bass 20))))))))
|
||||
|
||||
Kerning support:
|
||||
; With kerning adjustment
|
||||
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from typing import Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Glyph Data (computed at compile time)
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class GlyphData:
|
||||
"""Glyph data computed at compile time.
|
||||
|
||||
Attributes:
|
||||
char: The character
|
||||
image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
|
||||
advance: Horizontal advance (distance to next glyph origin)
|
||||
bearing_x: Left side bearing (x offset from origin to first pixel)
|
||||
bearing_y: Top bearing (y offset from baseline to top of glyph)
|
||||
width: Image width
|
||||
height: Image height
|
||||
"""
|
||||
char: str
|
||||
image: np.ndarray # (H, W, 4) RGBA uint8
|
||||
advance: float
|
||||
bearing_x: float
|
||||
bearing_y: float
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
# Font cache: (font_name, font_size) -> {char: GlyphData}
|
||||
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}
|
||||
|
||||
# Font metrics cache: (font_name, font_size) -> (ascent, descent)
|
||||
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}
|
||||
|
||||
# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
|
||||
# Kerning adjustment is added to advance: new_advance = advance + kerning
|
||||
# Typically negative (characters move closer together)
|
||||
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}
|
||||
|
||||
|
||||
def _load_font(font_name: str = None, font_size: int = 32):
|
||||
"""Load a font. Called at compile time."""
|
||||
from PIL import ImageFont
|
||||
|
||||
candidates = [
|
||||
font_name,
|
||||
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||
'/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
|
||||
'/usr/share/fonts/truetype/freefont/FreeSans.ttf',
|
||||
]
|
||||
|
||||
for path in candidates:
|
||||
if path is None:
|
||||
continue
|
||||
try:
|
||||
return ImageFont.truetype(path, font_size)
|
||||
except (IOError, OSError):
|
||||
continue
|
||||
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
|
||||
"""Get or create glyph cache for a font. Called at compile time."""
|
||||
cache_key = (font_name, font_size)
|
||||
|
||||
if cache_key in _GLYPH_CACHE:
|
||||
return _GLYPH_CACHE[cache_key]
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
ascent, descent = font.getmetrics()
|
||||
_METRICS_CACHE[cache_key] = (ascent, descent)
|
||||
|
||||
glyphs = {}
|
||||
charset = ''.join(chr(i) for i in range(32, 127))
|
||||
|
||||
temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0))
|
||||
temp_draw = ImageDraw.Draw(temp_img)
|
||||
|
||||
for char in charset:
|
||||
# Get metrics
|
||||
bbox = temp_draw.textbbox((0, 0), char, font=font)
|
||||
advance = font.getlength(char)
|
||||
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
|
||||
# Create glyph image with padding
|
||||
padding = 2
|
||||
img_w = max(int(x_max - x_min) + padding * 2, 1)
|
||||
img_h = max(int(y_max - y_min) + padding * 2, 1)
|
||||
|
||||
glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
|
||||
glyph_draw = ImageDraw.Draw(glyph_img)
|
||||
|
||||
# Draw at position accounting for bbox offset
|
||||
draw_x = padding - x_min
|
||||
draw_y = padding - y_min
|
||||
glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
glyphs[char] = GlyphData(
|
||||
char=char,
|
||||
image=np.array(glyph_img, dtype=np.uint8),
|
||||
advance=float(advance),
|
||||
bearing_x=float(x_min),
|
||||
bearing_y=float(-y_min), # Distance from baseline to top
|
||||
width=img_w,
|
||||
height=img_h,
|
||||
)
|
||||
|
||||
_GLYPH_CACHE[cache_key] = glyphs
|
||||
return glyphs
|
||||
|
||||
|
||||
def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
|
||||
"""Get or create kerning cache for a font. Called at compile time.
|
||||
|
||||
Kerning is computed as:
|
||||
kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)
|
||||
|
||||
This gives the adjustment needed when placing 'b' after 'a'.
|
||||
Typically negative (characters move closer together).
|
||||
"""
|
||||
cache_key = (font_name, font_size)
|
||||
|
||||
if cache_key in _KERNING_CACHE:
|
||||
return _KERNING_CACHE[cache_key]
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
kerning = {}
|
||||
|
||||
# Compute kerning for all printable ASCII pairs
|
||||
charset = ''.join(chr(i) for i in range(32, 127))
|
||||
|
||||
# Pre-compute individual character lengths
|
||||
char_lengths = {c: font.getlength(c) for c in charset}
|
||||
|
||||
# Compute kerning for each pair
|
||||
for c1 in charset:
|
||||
for c2 in charset:
|
||||
pair_length = font.getlength(c1 + c2)
|
||||
individual_sum = char_lengths[c1] + char_lengths[c2]
|
||||
kern = pair_length - individual_sum
|
||||
|
||||
# Only store non-zero kerning to save memory
|
||||
if abs(kern) > 0.01:
|
||||
kerning[(c1, c2)] = kern
|
||||
|
||||
_KERNING_CACHE[cache_key] = kerning
|
||||
return kerning
|
||||
|
||||
|
||||
def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get kerning adjustment between two characters. Compile-time.
|
||||
|
||||
Returns the adjustment to add to char1's advance when char2 follows.
|
||||
Typically negative (characters move closer).
|
||||
|
||||
Usage in S-expression:
|
||||
(+ (glyph-advance g1) (kerning g1 g2))
|
||||
"""
|
||||
kerning_cache = _get_kerning_cache(font_name, font_size)
|
||||
return kerning_cache.get((char1, char2), 0.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextStrip:
|
||||
"""Pre-rendered text strip with proper sub-pixel anti-aliasing.
|
||||
|
||||
Rendered at compile time using PIL for exact matching.
|
||||
At runtime, just composite onto frame at integer positions.
|
||||
|
||||
Attributes:
|
||||
text: The original text
|
||||
image: RGBA image as numpy array (H, W, 4)
|
||||
width: Strip width
|
||||
height: Strip height
|
||||
baseline_y: Y position of baseline within the strip
|
||||
bearing_x: Left side bearing of first character
|
||||
anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right)
|
||||
anchor_y: Y offset for anchor point (depends on anchor type)
|
||||
stroke_width: Stroke width used when rendering
|
||||
"""
|
||||
text: str
|
||||
image: np.ndarray
|
||||
width: int
|
||||
height: int
|
||||
baseline_y: int
|
||||
bearing_x: float
|
||||
anchor_x: float = 0.0
|
||||
anchor_y: float = 0.0
|
||||
stroke_width: int = 0
|
||||
|
||||
|
||||
# Text strip cache: cache_key -> TextStrip
|
||||
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}
|
||||
|
||||
|
||||
def render_text_strip(
|
||||
text: str,
|
||||
font_name: str = None,
|
||||
font_size: int = 32,
|
||||
stroke_width: int = 0,
|
||||
stroke_fill: tuple = None,
|
||||
anchor: str = "la", # left-ascender (PIL default is "la")
|
||||
multiline: bool = False,
|
||||
line_spacing: int = 4,
|
||||
align: str = "left",
|
||||
) -> TextStrip:
|
||||
"""Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.
|
||||
|
||||
Args:
|
||||
text: Text to render
|
||||
font_name: Path to font file (None for default)
|
||||
font_size: Font size in pixels
|
||||
stroke_width: Outline width in pixels (0 for no outline)
|
||||
stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black
|
||||
anchor: PIL anchor code - first char: h=left, m=middle, r=right
|
||||
second char: a=ascender, t=top, m=middle, s=baseline, d=descender
|
||||
multiline: If True, handle newlines in text
|
||||
line_spacing: Extra pixels between lines (for multiline)
|
||||
align: 'left', 'center', 'right' (for multiline)
|
||||
|
||||
Returns:
|
||||
TextStrip with pre-rendered text
|
||||
"""
|
||||
# Build cache key from all parameters
|
||||
cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align)
|
||||
if cache_key in _TEXT_STRIP_CACHE:
|
||||
return _TEXT_STRIP_CACHE[cache_key]
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
font = _load_font(font_name, font_size)
|
||||
ascent, descent = font.getmetrics()
|
||||
|
||||
# Default stroke fill to black
|
||||
if stroke_fill is None:
|
||||
stroke_fill = (0, 0, 0, 255)
|
||||
elif len(stroke_fill) == 3:
|
||||
stroke_fill = (*stroke_fill, 255)
|
||||
|
||||
# Get text bbox (accounting for stroke)
|
||||
temp = Image.new('RGBA', (1, 1))
|
||||
temp_draw = ImageDraw.Draw(temp)
|
||||
|
||||
if multiline:
|
||||
bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing,
|
||||
stroke_width=stroke_width)
|
||||
else:
|
||||
bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width)
|
||||
|
||||
# bbox is (left, top, right, bottom) relative to origin
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
|
||||
# Create image with padding (extra for stroke)
|
||||
padding = 2 + stroke_width
|
||||
img_width = max(int(x_max - x_min) + padding * 2, 1)
|
||||
img_height = max(int(y_max - y_min) + padding * 2, 1)
|
||||
|
||||
# Create RGBA image
|
||||
img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Draw text at position that puts it in the image
|
||||
draw_x = padding - x_min
|
||||
draw_y = padding - y_min
|
||||
|
||||
if multiline:
|
||||
draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
|
||||
spacing=line_spacing, align=align,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill)
|
||||
else:
|
||||
draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill)
|
||||
|
||||
# Baseline is at y=0 in text coordinates, which is at draw_y in image
|
||||
baseline_y = draw_y
|
||||
|
||||
# Convert to numpy for pixel analysis
|
||||
img_array = np.array(img, dtype=np.uint8)
|
||||
|
||||
# Calculate anchor offsets
|
||||
# For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching
|
||||
h_anchor = anchor[0] if len(anchor) > 0 else 'l'
|
||||
v_anchor = anchor[1] if len(anchor) > 1 else 'a'
|
||||
|
||||
# Find actual pixel bounds (for middle anchor calculations)
|
||||
alpha = img_array[:, :, 3]
|
||||
nonzero_cols = np.where(alpha.max(axis=0) > 0)[0]
|
||||
nonzero_rows = np.where(alpha.max(axis=1) > 0)[0]
|
||||
|
||||
if len(nonzero_cols) > 0:
|
||||
pixel_x_min = nonzero_cols.min()
|
||||
pixel_x_max = nonzero_cols.max()
|
||||
pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0
|
||||
else:
|
||||
pixel_x_center = img_width / 2.0
|
||||
|
||||
if len(nonzero_rows) > 0:
|
||||
pixel_y_min = nonzero_rows.min()
|
||||
pixel_y_max = nonzero_rows.max()
|
||||
pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0
|
||||
else:
|
||||
pixel_y_center = img_height / 2.0
|
||||
|
||||
# Horizontal offset
|
||||
text_width = x_max - x_min
|
||||
if h_anchor == 'l': # left edge of text
|
||||
anchor_x = float(draw_x)
|
||||
elif h_anchor == 'm': # middle - use actual pixel center for perfect matching
|
||||
anchor_x = pixel_x_center
|
||||
elif h_anchor == 'r': # right edge of text
|
||||
anchor_x = float(draw_x + text_width)
|
||||
else:
|
||||
anchor_x = float(draw_x)
|
||||
|
||||
# Vertical offset
|
||||
# PIL anchor positions are based on font metrics (ascent/descent):
|
||||
# - 'a' (ascender): at the ascender line → draw_y in strip
|
||||
# - 't' (top): at top of text bounding box → padding in strip
|
||||
# - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender
|
||||
# - 's' (baseline): at baseline = ascent below ascender
|
||||
# - 'd' (descender): at descender line = ascent + descent below ascender
|
||||
|
||||
if v_anchor == 'a': # ascender
|
||||
anchor_y = float(draw_y)
|
||||
elif v_anchor == 't': # top of bbox
|
||||
anchor_y = float(padding)
|
||||
elif v_anchor == 'm': # middle (center of em-square, per PIL's calculation)
|
||||
anchor_y = float(draw_y + (ascent + descent) / 2.0)
|
||||
elif v_anchor == 's': # baseline
|
||||
anchor_y = float(draw_y + ascent)
|
||||
elif v_anchor == 'd': # descender
|
||||
anchor_y = float(draw_y + ascent + descent)
|
||||
else:
|
||||
anchor_y = float(draw_y) # default to ascender
|
||||
|
||||
strip = TextStrip(
|
||||
text=text,
|
||||
image=img_array,
|
||||
width=img_width,
|
||||
height=img_height,
|
||||
baseline_y=baseline_y,
|
||||
bearing_x=float(x_min),
|
||||
anchor_x=anchor_x,
|
||||
anchor_y=anchor_y,
|
||||
stroke_width=stroke_width,
|
||||
)
|
||||
|
||||
_TEXT_STRIP_CACHE[cache_key] = strip
|
||||
return strip
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Compile-time functions (called during S-expression compilation)
|
||||
# =============================================================================
|
||||
|
||||
def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
|
||||
"""Get glyph data for a single character. Compile-time."""
|
||||
cache = _get_glyph_cache(font_name, font_size)
|
||||
return cache.get(char, cache.get(' '))
|
||||
|
||||
|
||||
def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
|
||||
"""Get glyph data for a string. Compile-time."""
|
||||
cache = _get_glyph_cache(font_name, font_size)
|
||||
space = cache.get(' ')
|
||||
return [cache.get(c, space) for c in text]
|
||||
|
||||
|
||||
def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get font ascent. Compile-time."""
|
||||
_get_glyph_cache(font_name, font_size) # Ensure cache exists
|
||||
return _METRICS_CACHE[(font_name, font_size)][0]
|
||||
|
||||
|
||||
def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
|
||||
"""Get font descent. Compile-time."""
|
||||
_get_glyph_cache(font_name, font_size)
|
||||
return _METRICS_CACHE[(font_name, font_size)][1]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JAX Runtime Primitives
|
||||
# =============================================================================
|
||||
|
||||
def place_glyph_jax(
|
||||
frame: jnp.ndarray,
|
||||
glyph_image: jnp.ndarray, # (H, W, 4) RGBA
|
||||
x: float,
|
||||
y: float,
|
||||
bearing_x: float,
|
||||
bearing_y: float,
|
||||
color: jnp.ndarray, # (3,) RGB 0-255
|
||||
opacity: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Place a glyph onto a frame. This is the core JAX primitive.
|
||||
|
||||
All positioning math can use traced values (x, y from audio, time, etc.)
|
||||
The glyph_image is static (determined at compile time).
|
||||
|
||||
Args:
|
||||
frame: (H, W, 3) RGB frame
|
||||
glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array)
|
||||
x: X position of glyph origin (baseline point)
|
||||
y: Y position of baseline
|
||||
bearing_x: Left side bearing
|
||||
bearing_y: Top bearing (from baseline to top)
|
||||
color: RGB color array
|
||||
opacity: Opacity 0-1
|
||||
|
||||
Returns:
|
||||
Frame with glyph composited
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
gh, gw = glyph_image.shape[:2]
|
||||
|
||||
# Calculate destination position
|
||||
# bearing_x: how far right of origin the glyph starts (can be negative)
|
||||
# bearing_y: how far up from baseline the glyph extends
|
||||
padding = 2 # Must match padding used in glyph creation
|
||||
dst_x = x + bearing_x - padding
|
||||
dst_y = y - bearing_y - padding
|
||||
|
||||
# Extract glyph RGB and alpha
|
||||
glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
|
||||
opacity_int = jnp.round(opacity * 255)
|
||||
glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32)
|
||||
glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||
|
||||
# Apply color tint (glyph is white, multiply by color)
|
||||
color_normalized = color.astype(jnp.float32) / 255.0
|
||||
tinted = glyph_rgb * color_normalized
|
||||
|
||||
from jax.lax import dynamic_update_slice
|
||||
|
||||
# Use padded buffer to avoid XLA's dynamic_update_slice clamping
|
||||
buf_h = h + 2 * gh
|
||||
buf_w = w + 2 * gw
|
||||
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
|
||||
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
|
||||
|
||||
dst_x_int = dst_x.astype(jnp.int32)
|
||||
dst_y_int = dst_y.astype(jnp.int32)
|
||||
place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32)
|
||||
place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32)
|
||||
|
||||
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
|
||||
alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0))
|
||||
|
||||
rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :]
|
||||
alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :]
|
||||
|
||||
# Alpha composite using PIL-compatible integer arithmetic
|
||||
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
|
||||
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
|
||||
dst_int = frame.astype(jnp.int32)
|
||||
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
|
||||
|
||||
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||
|
||||
|
||||
def place_text_strip_jax(
|
||||
frame: jnp.ndarray,
|
||||
strip_image: jnp.ndarray, # (H, W, 4) RGBA
|
||||
x: float,
|
||||
y: float,
|
||||
baseline_y: int,
|
||||
bearing_x: float,
|
||||
color: jnp.ndarray,
|
||||
opacity: float = 1.0,
|
||||
anchor_x: float = 0.0,
|
||||
anchor_y: float = 0.0,
|
||||
stroke_width: int = 0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Place a pre-rendered text strip onto a frame.
|
||||
|
||||
The strip was rendered at compile time with proper sub-pixel anti-aliasing.
|
||||
This just composites it at the specified position.
|
||||
|
||||
Args:
|
||||
frame: (H, W, 3) RGB frame
|
||||
strip_image: (sh, sw, 4) RGBA text strip
|
||||
x: X position for anchor point
|
||||
y: Y position for anchor point
|
||||
baseline_y: Y position of baseline within the strip
|
||||
bearing_x: Left side bearing
|
||||
color: RGB color
|
||||
opacity: Opacity 0-1
|
||||
anchor_x: X offset of anchor point within strip
|
||||
anchor_y: Y offset of anchor point within strip
|
||||
stroke_width: Stroke width used when rendering (affects padding)
|
||||
|
||||
Returns:
|
||||
Frame with text composited
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
sh, sw = strip_image.shape[:2]
|
||||
|
||||
# Calculate destination position
|
||||
# Anchor point (anchor_x, anchor_y) in strip should be at (x, y) in frame
|
||||
# anchor_x/anchor_y already account for the anchor position within the strip
|
||||
# Use floor(x + 0.5) for consistent rounding (jnp.round uses banker's rounding)
|
||||
dst_x = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
|
||||
dst_y = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)
|
||||
|
||||
# Extract strip RGB and alpha
|
||||
strip_rgb = strip_image[:, :, :3].astype(jnp.float32) / 255.0
|
||||
# Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
|
||||
# Use jnp.round (banker's rounding) to match Python's round() used by PIL
|
||||
opacity_int = jnp.round(opacity * 255)
|
||||
strip_a_raw = strip_image[:, :, 3:4].astype(jnp.float32)
|
||||
strip_alpha = jnp.floor(strip_a_raw * opacity_int / 255.0 + 0.5) / 255.0
|
||||
|
||||
# Apply color tint
|
||||
color_normalized = color.astype(jnp.float32) / 255.0
|
||||
tinted = strip_rgb * color_normalized
|
||||
|
||||
from jax.lax import dynamic_update_slice
|
||||
|
||||
# Use a padded buffer to avoid XLA's dynamic_update_slice clamping behavior.
|
||||
# XLA clamps indices so the update fits, which silently shifts the strip.
|
||||
# By placing into a buffer padded by strip dimensions, then extracting the
|
||||
# frame-sized region, we get correct clipping for both overflow and underflow.
|
||||
buf_h = h + 2 * sh
|
||||
buf_w = w + 2 * sw
|
||||
rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
|
||||
alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)
|
||||
|
||||
# Offset by (sh, sw) so dst=0 maps to (sh, sw) in buffer
|
||||
place_y = jnp.maximum(dst_y + sh, 0).astype(jnp.int32)
|
||||
place_x = jnp.maximum(dst_x + sw, 0).astype(jnp.int32)
|
||||
|
||||
rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
|
||||
alpha_buf = dynamic_update_slice(alpha_buf, strip_alpha, (place_y, place_x, 0))
|
||||
|
||||
# Extract frame-sized region (sh, sw are compile-time constants from strip shape)
|
||||
rgb_layer = rgb_buf[sh:sh + h, sw:sw + w, :]
|
||||
alpha_layer = alpha_buf[sh:sh + h, sw:sw + w, :]
|
||||
|
||||
# Alpha composite using PIL-compatible integer arithmetic:
|
||||
# result = (src * alpha + dst * (255 - alpha) + 127) // 255
|
||||
src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
|
||||
alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
|
||||
dst_int = frame.astype(jnp.int32)
|
||||
result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255
|
||||
|
||||
return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||
|
||||
|
||||
def place_glyph_simple(
|
||||
frame: jnp.ndarray,
|
||||
glyph: GlyphData,
|
||||
x: float,
|
||||
y: float,
|
||||
color: tuple = (255, 255, 255),
|
||||
opacity: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Convenience wrapper that takes GlyphData directly.
|
||||
Converts glyph image to JAX array.
|
||||
|
||||
For S-expression use, prefer place_glyph_jax with pre-converted arrays.
|
||||
"""
|
||||
glyph_jax = jnp.asarray(glyph.image)
|
||||
color_jax = jnp.array(color, dtype=jnp.float32)
|
||||
|
||||
return place_glyph_jax(
|
||||
frame, glyph_jax, x, y,
|
||||
glyph.bearing_x, glyph.bearing_y,
|
||||
color_jax, opacity
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# S-Expression Primitive Bindings
|
||||
# =============================================================================
|
||||
|
||||
def bind_typography_primitives(env: dict) -> dict:
|
||||
"""
|
||||
Add typography primitives to an S-expression environment.
|
||||
|
||||
Primitives added:
|
||||
(text-glyphs text font-size) -> list of glyph data
|
||||
(glyph-image g) -> JAX array (H, W, 4)
|
||||
(glyph-advance g) -> float
|
||||
(glyph-bearing-x g) -> float
|
||||
(glyph-bearing-y g) -> float
|
||||
(glyph-width g) -> int
|
||||
(glyph-height g) -> int
|
||||
(font-ascent font-size) -> float
|
||||
(font-descent font-size) -> float
|
||||
(place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame
|
||||
"""
|
||||
|
||||
def prim_text_glyphs(text, font_size=32, font_name=None):
|
||||
"""Get list of glyph data for text. Compile-time."""
|
||||
return get_glyphs(str(text), font_name, int(font_size))
|
||||
|
||||
def prim_glyph_image(glyph):
|
||||
"""Get glyph image as JAX array."""
|
||||
return jnp.asarray(glyph.image)
|
||||
|
||||
def prim_glyph_advance(glyph):
|
||||
"""Get glyph advance width."""
|
||||
return glyph.advance
|
||||
|
||||
def prim_glyph_bearing_x(glyph):
|
||||
"""Get glyph left side bearing."""
|
||||
return glyph.bearing_x
|
||||
|
||||
def prim_glyph_bearing_y(glyph):
|
||||
"""Get glyph top bearing."""
|
||||
return glyph.bearing_y
|
||||
|
||||
def prim_glyph_width(glyph):
|
||||
"""Get glyph image width."""
|
||||
return glyph.width
|
||||
|
||||
def prim_glyph_height(glyph):
|
||||
"""Get glyph image height."""
|
||||
return glyph.height
|
||||
|
||||
def prim_font_ascent(font_size=32, font_name=None):
|
||||
"""Get font ascent."""
|
||||
return get_font_ascent(font_name, int(font_size))
|
||||
|
||||
def prim_font_descent(font_size=32, font_name=None):
|
||||
"""Get font descent."""
|
||||
return get_font_descent(font_name, int(font_size))
|
||||
|
||||
def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
|
||||
color=(255, 255, 255), opacity=1.0):
|
||||
"""Place glyph on frame. Runtime JAX operation."""
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y,
|
||||
color_arr, opacity)
|
||||
|
||||
def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
|
||||
"""Get kerning adjustment between two glyphs. Compile-time.
|
||||
|
||||
Returns adjustment to add to glyph1's advance when glyph2 follows.
|
||||
Typically negative (characters move closer).
|
||||
|
||||
Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||
"""
|
||||
return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))
|
||||
|
||||
def prim_char_kerning(char1, char2, font_size=32, font_name=None):
|
||||
"""Get kerning adjustment between two characters. Compile-time."""
|
||||
return get_kerning(str(char1), str(char2), font_name, int(font_size))
|
||||
|
||||
# TextStrip primitives for pre-rendered text with proper anti-aliasing
|
||||
def prim_render_text_strip(text, font_size=32, font_name=None):
|
||||
"""Render text to a strip at compile time. Perfect anti-aliasing."""
|
||||
return render_text_strip(str(text), font_name, int(font_size))
|
||||
|
||||
def prim_render_text_strip_styled(
|
||||
text, font_size=32, font_name=None,
|
||||
stroke_width=0, stroke_fill=None,
|
||||
anchor="la", multiline=False, line_spacing=4, align="left"
|
||||
):
|
||||
"""Render styled text to a strip. Supports stroke, anchors, multiline.
|
||||
|
||||
Args:
|
||||
text: Text to render
|
||||
font_size: Size in pixels
|
||||
font_name: Path to font file
|
||||
stroke_width: Outline width (0 = no outline)
|
||||
stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
|
||||
anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
|
||||
multiline: If True, handle newlines
|
||||
line_spacing: Extra pixels between lines
|
||||
align: "left", "center", "right" for multiline
|
||||
"""
|
||||
return render_text_strip(
|
||||
str(text), font_name, int(font_size),
|
||||
stroke_width=int(stroke_width),
|
||||
stroke_fill=stroke_fill,
|
||||
anchor=str(anchor),
|
||||
multiline=bool(multiline),
|
||||
line_spacing=int(line_spacing),
|
||||
align=str(align),
|
||||
)
|
||||
|
||||
def prim_text_strip_image(strip):
|
||||
"""Get text strip image as JAX array."""
|
||||
return jnp.asarray(strip.image)
|
||||
|
||||
def prim_text_strip_width(strip):
|
||||
"""Get text strip width."""
|
||||
return strip.width
|
||||
|
||||
def prim_text_strip_height(strip):
|
||||
"""Get text strip height."""
|
||||
return strip.height
|
||||
|
||||
def prim_text_strip_baseline_y(strip):
|
||||
"""Get text strip baseline Y position."""
|
||||
return strip.baseline_y
|
||||
|
||||
def prim_text_strip_bearing_x(strip):
|
||||
"""Get text strip left bearing."""
|
||||
return strip.bearing_x
|
||||
|
||||
def prim_text_strip_anchor_x(strip):
|
||||
"""Get text strip anchor X offset."""
|
||||
return strip.anchor_x
|
||||
|
||||
def prim_text_strip_anchor_y(strip):
|
||||
"""Get text strip anchor Y offset."""
|
||||
return strip.anchor_y
|
||||
|
||||
def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
|
||||
"""Place pre-rendered text strip on frame. Runtime JAX operation."""
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
return place_text_strip_jax(
|
||||
frame, strip_img, x, y,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color_arr, opacity,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
stroke_width=strip.stroke_width
|
||||
)
|
||||
|
||||
# Add to environment
|
||||
env.update({
|
||||
# Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
|
||||
'text-glyphs': prim_text_glyphs,
|
||||
'glyph-image': prim_glyph_image,
|
||||
'glyph-advance': prim_glyph_advance,
|
||||
'glyph-bearing-x': prim_glyph_bearing_x,
|
||||
'glyph-bearing-y': prim_glyph_bearing_y,
|
||||
'glyph-width': prim_glyph_width,
|
||||
'glyph-height': prim_glyph_height,
|
||||
'glyph-kerning': prim_glyph_kerning,
|
||||
'char-kerning': prim_char_kerning,
|
||||
'font-ascent': prim_font_ascent,
|
||||
'font-descent': prim_font_descent,
|
||||
'place-glyph': prim_place_glyph,
|
||||
# TextStrip primitives (for pixel-perfect static text)
|
||||
'render-text-strip': prim_render_text_strip,
|
||||
'render-text-strip-styled': prim_render_text_strip_styled,
|
||||
'text-strip-image': prim_text_strip_image,
|
||||
'text-strip-width': prim_text_strip_width,
|
||||
'text-strip-height': prim_text_strip_height,
|
||||
'text-strip-baseline-y': prim_text_strip_baseline_y,
|
||||
'text-strip-bearing-x': prim_text_strip_bearing_x,
|
||||
'text-strip-anchor-x': prim_text_strip_anchor_x,
|
||||
'text-strip-anchor-y': prim_text_strip_anchor_y,
|
||||
'place-text-strip': prim_place_text_strip,
|
||||
})
|
||||
|
||||
return env
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Example: Render text using primitives (for testing)
|
||||
# =============================================================================
|
||||
|
||||
def render_text_primitives(
|
||||
frame: jnp.ndarray,
|
||||
text: str,
|
||||
x: float,
|
||||
y: float,
|
||||
font_size: int = 32,
|
||||
color: tuple = (255, 255, 255),
|
||||
opacity: float = 1.0,
|
||||
use_kerning: bool = True,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Render text using the primitives.
|
||||
This is what an S-expression would compile to.
|
||||
|
||||
Args:
|
||||
use_kerning: If True, apply kerning adjustments between characters
|
||||
"""
|
||||
glyphs = get_glyphs(text, None, font_size)
|
||||
color_arr = jnp.array(color, dtype=jnp.float32)
|
||||
|
||||
cursor = x
|
||||
for i, g in enumerate(glyphs):
|
||||
glyph_jax = jnp.asarray(g.image)
|
||||
frame = place_glyph_jax(
|
||||
frame, glyph_jax, cursor, y,
|
||||
g.bearing_x, g.bearing_y,
|
||||
color_arr, opacity
|
||||
)
|
||||
# Advance cursor with optional kerning
|
||||
advance = g.advance
|
||||
if use_kerning and i + 1 < len(glyphs):
|
||||
advance += get_kerning(g.char, glyphs[i + 1].char, None, font_size)
|
||||
cursor = cursor + advance
|
||||
|
||||
return frame
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,7 @@ Context (ctx) is passed explicitly to frame evaluation:
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import hashlib
|
||||
@@ -62,6 +63,38 @@ class Context:
|
||||
fps: float = 30.0
|
||||
|
||||
|
||||
class DeferredEffectChain:
|
||||
"""
|
||||
Represents a chain of JAX effects that haven't been executed yet.
|
||||
|
||||
Allows effects to be accumulated through let bindings and fused
|
||||
into a single JIT-compiled function when the result is needed.
|
||||
"""
|
||||
__slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')
|
||||
|
||||
def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
|
||||
self.effects = effects # List of effect names, innermost first
|
||||
self.params_list = params_list # List of param dicts, matching effects
|
||||
self.base_frame = base_frame # The actual frame array at the start
|
||||
self.t = t
|
||||
self.frame_num = frame_num
|
||||
|
||||
def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
|
||||
"""Add another effect to the chain (outermost)."""
|
||||
return DeferredEffectChain(
|
||||
self.effects + [effect_name],
|
||||
self.params_list + [params],
|
||||
self.base_frame,
|
||||
self.t,
|
||||
self.frame_num
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Allow shape check without forcing execution."""
|
||||
return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None
|
||||
|
||||
|
||||
class StreamInterpreter:
|
||||
"""
|
||||
Fully generic streaming sexp interpreter.
|
||||
@@ -98,6 +131,9 @@ class StreamInterpreter:
|
||||
self.use_jax = use_jax
|
||||
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
||||
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
||||
self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains
|
||||
self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains
|
||||
self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env
|
||||
if use_jax:
|
||||
if _init_jax():
|
||||
print("JAX acceleration enabled", file=sys.stderr)
|
||||
@@ -238,6 +274,8 @@ class StreamInterpreter:
|
||||
"""Load primitives from a Python library file.
|
||||
|
||||
Prefers GPU-accelerated versions (*_gpu.py) when available.
|
||||
Uses cached modules from sys.modules to ensure consistent state
|
||||
(e.g., same RNG instance for all interpreters).
|
||||
"""
|
||||
import importlib.util
|
||||
|
||||
@@ -264,9 +302,26 @@ class StreamInterpreter:
|
||||
if not lib_path:
|
||||
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
# Use cached module if already imported to preserve state (e.g., RNG)
|
||||
# This is critical for deterministic random number sequences
|
||||
# Check multiple possible module keys (standard import paths and our cache)
|
||||
possible_keys = [
|
||||
f"sexp_effects.primitive_libs.{actual_lib_name}",
|
||||
f"sexp_primitives.{actual_lib_name}",
|
||||
]
|
||||
|
||||
module = None
|
||||
for key in possible_keys:
|
||||
if key in sys.modules:
|
||||
module = sys.modules[key]
|
||||
break
|
||||
|
||||
if module is None:
|
||||
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
# Cache for future use under our key
|
||||
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
|
||||
|
||||
# Check if this is a GPU-accelerated module
|
||||
is_gpu = actual_lib_name.endswith('_gpu')
|
||||
@@ -452,30 +507,353 @@ class StreamInterpreter:
|
||||
|
||||
try:
|
||||
jax_fn = self.jax_effects[name]
|
||||
# Ensure frame is numpy array
|
||||
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
|
||||
# JAX handles numpy and JAX arrays natively, no conversion needed
|
||||
if hasattr(frame, 'cpu'):
|
||||
frame = frame.cpu
|
||||
elif hasattr(frame, 'get'):
|
||||
frame = frame.get()
|
||||
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||
frame = frame.get() # CuPy array -> numpy
|
||||
|
||||
# Get seed from config for deterministic random
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Call JAX function with parameters
|
||||
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||
|
||||
# Convert result back to numpy if needed
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready() # Ensure computation is complete
|
||||
if hasattr(result, '__array__'):
|
||||
result = np.asarray(result)
|
||||
|
||||
return result
|
||||
# Return JAX array directly - don't block or convert per-effect
|
||||
# Conversion to numpy happens once at frame write time
|
||||
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||
except Exception as e:
|
||||
# Fall back to interpreter on error
|
||||
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _is_jax_effect_expr(self, expr) -> bool:
|
||||
"""Check if an expression is a JAX-compiled effect call."""
|
||||
if not isinstance(expr, list) or not expr:
|
||||
return False
|
||||
head = expr[0]
|
||||
if not isinstance(head, Symbol):
|
||||
return False
|
||||
return head.name in self.jax_effects
|
||||
|
||||
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
|
||||
"""
|
||||
Extract a chain of JAX effects from an expression.
|
||||
|
||||
Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
|
||||
effect_names and params_list are in execution order (innermost first).
|
||||
|
||||
For (effect1 (effect2 frame :p1 v1) :p2 v2):
|
||||
Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
|
||||
"""
|
||||
if not self._is_jax_effect_expr(expr):
|
||||
return None
|
||||
|
||||
chain = []
|
||||
params_list = []
|
||||
current = expr
|
||||
|
||||
while self._is_jax_effect_expr(current):
|
||||
head = current[0]
|
||||
effect_name = head.name
|
||||
args = current[1:]
|
||||
|
||||
# Extract params for this effect
|
||||
effect = self.effects[effect_name]
|
||||
effect_params = {}
|
||||
for pname, pdef in effect['params'].items():
|
||||
effect_params[pname] = pdef.get('default', 0)
|
||||
|
||||
# Find the frame argument (first positional) and other params
|
||||
frame_arg = None
|
||||
i = 0
|
||||
while i < len(args):
|
||||
if isinstance(args[i], Keyword):
|
||||
pname = args[i].name
|
||||
if pname in effect['params'] and i + 1 < len(args):
|
||||
effect_params[pname] = self._eval(args[i + 1], env)
|
||||
i += 2
|
||||
else:
|
||||
if frame_arg is None:
|
||||
frame_arg = args[i] # First positional is frame
|
||||
i += 1
|
||||
|
||||
chain.append(effect_name)
|
||||
params_list.append(effect_params)
|
||||
|
||||
if frame_arg is None:
|
||||
return None # No frame argument found
|
||||
|
||||
# Check if frame_arg is another effect call
|
||||
if self._is_jax_effect_expr(frame_arg):
|
||||
current = frame_arg
|
||||
else:
|
||||
# End of chain - frame_arg is the base frame
|
||||
# Reverse to get innermost-first execution order
|
||||
chain.reverse()
|
||||
params_list.reverse()
|
||||
return (chain, params_list, frame_arg)
|
||||
|
||||
return None
|
||||
|
||||
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
|
||||
"""Generate a cache key for an effect chain.
|
||||
|
||||
Includes static param values in the key since they affect compilation.
|
||||
"""
|
||||
parts = []
|
||||
for name, params in zip(effect_names, params_list):
|
||||
param_parts = []
|
||||
for pname in sorted(params.keys()):
|
||||
pval = params[pname]
|
||||
# Include static values in key (strings, bools)
|
||||
if isinstance(pval, (str, bool)):
|
||||
param_parts.append(f"{pname}={pval}")
|
||||
else:
|
||||
param_parts.append(pname)
|
||||
parts.append(f"{name}:{','.join(param_parts)}")
|
||||
return '|'.join(parts)
|
||||
|
||||
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||
"""
|
||||
Compile a chain of effects into a single fused JAX function.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||
params_list: List of param dicts for each effect (used to detect static types)
|
||||
|
||||
Returns:
|
||||
JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame
|
||||
"""
|
||||
if not _JAX_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
import jax
|
||||
|
||||
# Get the individual JAX functions
|
||||
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||
|
||||
# Pre-extract param names and identify static params from actual values
|
||||
effect_param_names = []
|
||||
static_params = ['seed'] # seed is always static
|
||||
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||
param_names = list(params.keys())
|
||||
effect_param_names.append(param_names)
|
||||
# Check actual values to identify static types
|
||||
for pname, pval in params.items():
|
||||
if isinstance(pval, (str, bool)):
|
||||
static_params.append(f"_p{i}_{pname}")
|
||||
|
||||
def fused_fn(frame, t, frame_num, seed, **kwargs):
|
||||
result = frame
|
||||
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||
# Extract params for this effect from kwargs
|
||||
effect_kwargs = {}
|
||||
for pname in param_names:
|
||||
key = f"_p{i}_{pname}"
|
||||
if key in kwargs:
|
||||
effect_kwargs[pname] = kwargs[key]
|
||||
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||
return result
|
||||
|
||||
# JIT with static params (seed + any string/bool params)
|
||||
return jax.jit(fused_fn, static_argnames=static_params)
|
||||
except Exception as e:
|
||||
print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
|
||||
"""Apply a chain of effects, using fused compilation if available."""
|
||||
chain_key = self._get_chain_key(effect_names, params_list)
|
||||
|
||||
# Try to get or compile fused chain
|
||||
if chain_key not in self.jax_fused_chains:
|
||||
fused_fn = self._compile_effect_chain(effect_names, params_list)
|
||||
self.jax_fused_chains[chain_key] = fused_fn
|
||||
if fused_fn:
|
||||
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
|
||||
|
||||
fused_fn = self.jax_fused_chains.get(chain_key)
|
||||
|
||||
if fused_fn is not None:
|
||||
# Build kwargs with prefixed param names
|
||||
kwargs = {}
|
||||
for i, params in enumerate(params_list):
|
||||
for pname, pval in params.items():
|
||||
kwargs[f"_p{i}_{pname}"] = pval
|
||||
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Handle GPU frames
|
||||
if hasattr(frame, 'cpu'):
|
||||
frame = frame.cpu
|
||||
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||
frame = frame.get()
|
||||
|
||||
try:
|
||||
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"Fused chain error: {e}", file=sys.stderr)
|
||||
|
||||
# Fall back to sequential application
|
||||
result = frame
|
||||
for name, params in zip(effect_names, params_list):
|
||||
result = self._apply_jax_effect(name, result, params, t, frame_num)
|
||||
if result is None:
|
||||
return None
|
||||
return result
|
||||
|
||||
def _force_deferred(self, deferred: DeferredEffectChain):
|
||||
"""Execute a deferred effect chain and return the actual array."""
|
||||
if len(deferred.effects) == 0:
|
||||
return deferred.base_frame
|
||||
|
||||
return self._apply_effect_chain(
|
||||
deferred.effects,
|
||||
deferred.params_list,
|
||||
deferred.base_frame,
|
||||
deferred.t,
|
||||
deferred.frame_num
|
||||
)
|
||||
|
||||
def _maybe_force(self, value):
|
||||
"""Force a deferred chain if needed, otherwise return as-is."""
|
||||
if isinstance(value, DeferredEffectChain):
|
||||
return self._force_deferred(value)
|
||||
return value
|
||||
|
||||
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||
"""
|
||||
Compile a vmapped version of an effect chain for batch processing.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||
params_list: List of param dicts (used to detect static types)
|
||||
|
||||
Returns:
|
||||
Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames
|
||||
Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar
|
||||
"""
|
||||
if not _JAX_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Get the individual JAX functions
|
||||
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||
|
||||
# Pre-extract param info
|
||||
effect_param_names = []
|
||||
static_params = set()
|
||||
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||
param_names = list(params.keys())
|
||||
effect_param_names.append(param_names)
|
||||
for pname, pval in params.items():
|
||||
if isinstance(pval, (str, bool)):
|
||||
static_params.add(f"_p{i}_{pname}")
|
||||
|
||||
# Single-frame function (will be vmapped)
|
||||
def single_frame_fn(frame, t, frame_num, seed, **kwargs):
|
||||
result = frame
|
||||
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||
effect_kwargs = {}
|
||||
for pname in param_names:
|
||||
key = f"_p{i}_{pname}"
|
||||
if key in kwargs:
|
||||
effect_kwargs[pname] = kwargs[key]
|
||||
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||
return result
|
||||
|
||||
# Return unbatched function - we'll vmap at call time with proper in_axes
|
||||
return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params))
|
||||
except Exception as e:
|
||||
print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
|
||||
frames: list, ts: list, frame_nums: list) -> Optional[list]:
|
||||
"""
|
||||
Apply an effect chain to a batch of frames using vmap.
|
||||
|
||||
Args:
|
||||
effect_names: List of effect names
|
||||
params_list_batch: List of params_list for each frame in batch
|
||||
frames: List of input frames
|
||||
ts: List of time values
|
||||
frame_nums: List of frame numbers
|
||||
|
||||
Returns:
|
||||
List of output frames, or None on failure
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
# Use first frame's params for chain key (assume same structure)
|
||||
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
|
||||
batch_key = f"batch:{chain_key}"
|
||||
|
||||
# Compile batched version if needed
|
||||
if batch_key not in self.jax_batched_chains:
|
||||
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
|
||||
self.jax_batched_chains[batch_key] = batched_fn
|
||||
if batched_fn:
|
||||
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
|
||||
|
||||
batched_fn = self.jax_batched_chains.get(batch_key)
|
||||
|
||||
if batched_fn is not None:
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Stack frames into batch array
|
||||
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
|
||||
ts_array = jnp.array(ts)
|
||||
frame_nums_array = jnp.array(frame_nums)
|
||||
|
||||
# Build kwargs - all numeric params as arrays for vmap
|
||||
kwargs = {}
|
||||
static_kwargs = {} # Non-vmapped (strings, bools)
|
||||
|
||||
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
|
||||
for j, pname in enumerate(params_list_batch[0][i].keys()):
|
||||
key = f"_p{i}_{pname}"
|
||||
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
|
||||
|
||||
first = values[0]
|
||||
if isinstance(first, (str, bool)):
|
||||
# Static params - not vmapped
|
||||
static_kwargs[key] = first
|
||||
elif isinstance(first, (int, float)):
|
||||
# Always batch numeric params for simplicity
|
||||
kwargs[key] = jnp.array(values)
|
||||
elif hasattr(first, 'shape'):
|
||||
kwargs[key] = jnp.stack(values)
|
||||
else:
|
||||
kwargs[key] = jnp.array(values)
|
||||
|
||||
seed = self.config.get('seed', 42)
|
||||
|
||||
# Create wrapper that unpacks the params dict
|
||||
def single_call(frame, t, frame_num, params_dict):
|
||||
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
|
||||
|
||||
# vmap over frame, t, frame_num, and the params dict (as pytree)
|
||||
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# Stack kwargs into a dict of arrays (pytree with matching structure)
|
||||
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
|
||||
|
||||
# Unstack results
|
||||
return [results[i] for i in range(len(frames))]
|
||||
except Exception as e:
|
||||
print(f"Batched chain error: {e}", file=sys.stderr)
|
||||
|
||||
# Fall back to sequential
|
||||
return None
|
||||
|
||||
def _init(self):
|
||||
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
||||
# Set random seed for deterministic output
|
||||
@@ -869,6 +1247,22 @@ class StreamInterpreter:
|
||||
# === Effects ===
|
||||
|
||||
if op in self.effects:
|
||||
# Try to detect and fuse effect chains for JAX acceleration
|
||||
if self.use_jax and op in self.jax_effects:
|
||||
chain_info = self._extract_effect_chain(expr, env)
|
||||
if chain_info is not None:
|
||||
effect_names, params_list, base_frame_expr = chain_info
|
||||
# Only use chain if we have 2+ effects (worth fusing)
|
||||
if len(effect_names) >= 2:
|
||||
base_frame = self._eval(base_frame_expr, env)
|
||||
if base_frame is not None and hasattr(base_frame, 'shape'):
|
||||
t = env.get('t', 0.0)
|
||||
frame_num = env.get('frame-num', 0)
|
||||
result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
|
||||
if result is not None:
|
||||
return result
|
||||
# Fall through if chain application fails
|
||||
|
||||
effect = self.effects[op]
|
||||
effect_env = dict(env)
|
||||
|
||||
@@ -895,17 +1289,28 @@ class StreamInterpreter:
|
||||
positional_idx += 1
|
||||
i += 1
|
||||
|
||||
# Try JAX-accelerated execution first
|
||||
# Try JAX-accelerated execution with deferred chaining
|
||||
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
||||
# Build params dict for JAX (exclude 'frame')
|
||||
jax_params = {k: v for k, v in effect_env.items()
|
||||
jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
|
||||
if k != 'frame' and k in effect['params']}
|
||||
t = env.get('t', 0.0)
|
||||
frame_num = env.get('frame-num', 0)
|
||||
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
|
||||
if result is not None:
|
||||
return result
|
||||
# Fall through to interpreter if JAX fails
|
||||
|
||||
# Check if input is a deferred chain - if so, extend it
|
||||
if isinstance(frame_val, DeferredEffectChain):
|
||||
return frame_val.extend(op, jax_params)
|
||||
|
||||
# Check if input is a valid frame - create new deferred chain
|
||||
if hasattr(frame_val, 'shape'):
|
||||
return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
|
||||
|
||||
# Fall through to interpreter if not a valid frame
|
||||
|
||||
# Force any deferred frame before interpreter evaluation
|
||||
if isinstance(frame_val, DeferredEffectChain):
|
||||
frame_val = self._force_deferred(frame_val)
|
||||
effect_env['frame'] = frame_val
|
||||
|
||||
return self._eval(effect['body'], effect_env)
|
||||
|
||||
@@ -922,10 +1327,15 @@ class StreamInterpreter:
|
||||
if isinstance(args[i], Keyword):
|
||||
k = args[i].name
|
||||
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
|
||||
# Force deferred chains before passing to primitives
|
||||
v = self._maybe_force(v)
|
||||
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
|
||||
i += 2
|
||||
else:
|
||||
evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim))
|
||||
val = self._eval(args[i], env)
|
||||
# Force deferred chains before passing to primitives
|
||||
val = self._maybe_force(val)
|
||||
evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
|
||||
i += 1
|
||||
try:
|
||||
if kwargs:
|
||||
@@ -1152,6 +1562,61 @@ class StreamInterpreter:
|
||||
eval_times = []
|
||||
write_times = []
|
||||
|
||||
# Batch accumulation for JAX
|
||||
batch_deferred = [] # Accumulated DeferredEffectChains
|
||||
batch_times = [] # Corresponding time values
|
||||
batch_start_frame = 0
|
||||
|
||||
def flush_batch():
|
||||
"""Execute accumulated batch and write results."""
|
||||
nonlocal batch_deferred, batch_times
|
||||
if not batch_deferred:
|
||||
return
|
||||
|
||||
t_flush = time.time()
|
||||
|
||||
# Check if all chains have same structure (can batch)
|
||||
first = batch_deferred[0]
|
||||
can_batch = (
|
||||
self.use_jax and
|
||||
len(batch_deferred) >= 2 and
|
||||
all(d.effects == first.effects for d in batch_deferred)
|
||||
)
|
||||
|
||||
if can_batch:
|
||||
# Try batched execution
|
||||
frames = [d.base_frame for d in batch_deferred]
|
||||
ts = [d.t for d in batch_deferred]
|
||||
frame_nums = [d.frame_num for d in batch_deferred]
|
||||
params_batch = [d.params_list for d in batch_deferred]
|
||||
|
||||
results = self._apply_batched_chain(
|
||||
first.effects, params_batch, frames, ts, frame_nums
|
||||
)
|
||||
|
||||
if results is not None:
|
||||
# Write batched results
|
||||
for result, t in zip(results, batch_times):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, t)
|
||||
batch_deferred = []
|
||||
batch_times = []
|
||||
return
|
||||
|
||||
# Fall back to sequential execution
|
||||
for deferred, t in zip(batch_deferred, batch_times):
|
||||
result = self._force_deferred(deferred)
|
||||
if result is not None and hasattr(result, 'shape'):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, t)
|
||||
|
||||
batch_deferred = []
|
||||
batch_times = []
|
||||
|
||||
for frame_num in range(start_frame, n_frames):
|
||||
if not out.is_open:
|
||||
break
|
||||
@@ -1182,8 +1647,23 @@ class StreamInterpreter:
|
||||
eval_times.append(time.time() - t1)
|
||||
|
||||
t2 = time.time()
|
||||
if result is not None and hasattr(result, 'shape'):
|
||||
out.write(result, ctx.t)
|
||||
if result is not None:
|
||||
if isinstance(result, DeferredEffectChain):
|
||||
# Accumulate for batching
|
||||
batch_deferred.append(result)
|
||||
batch_times.append(ctx.t)
|
||||
|
||||
# Flush when batch is full
|
||||
if len(batch_deferred) >= self.jax_batch_size:
|
||||
flush_batch()
|
||||
else:
|
||||
# Not deferred - flush any pending batch first, then write
|
||||
flush_batch()
|
||||
if hasattr(result, 'shape'):
|
||||
if hasattr(result, 'block_until_ready'):
|
||||
result.block_until_ready()
|
||||
result = np.asarray(result)
|
||||
out.write(result, ctx.t)
|
||||
write_times.append(time.time() - t2)
|
||||
|
||||
frame_elapsed = time.time() - frame_start
|
||||
@@ -1219,6 +1699,9 @@ class StreamInterpreter:
|
||||
except Exception as e:
|
||||
print(f"Warning: progress callback failed: {e}", file=sys.stderr)
|
||||
|
||||
# Flush any remaining batch
|
||||
flush_batch()
|
||||
|
||||
finally:
|
||||
out.close()
|
||||
# Store output for access to properties like playlist_cid
|
||||
|
||||
542
test_funky_text.py
Normal file
542
test_funky_text.py
Normal file
@@ -0,0 +1,542 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Funky comparison tests: PIL vs TextStrip system.
|
||||
Tests colors, opacity, fonts, sizes, edge positions, clipping, overlaps, etc.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from streaming.jax_typography import (
|
||||
render_text_strip, place_text_strip_jax, _load_font
|
||||
)
|
||||
|
||||
FONTS = {
|
||||
'sans': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||
'bold': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
|
||||
'serif': '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf',
|
||||
'serif_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf',
|
||||
'mono': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
|
||||
'mono_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf',
|
||||
'narrow': '/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf',
|
||||
'italic': '/usr/share/fonts/truetype/freefont/FreeSansOblique.ttf',
|
||||
}
|
||||
|
||||
|
||||
def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
|
||||
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
|
||||
stroke_width=0, stroke_fill=None, anchor="la",
|
||||
multiline=False, line_spacing=4, align="left"):
|
||||
"""Render with PIL directly, including color/opacity."""
|
||||
frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
|
||||
# For opacity, render to RGBA then composite
|
||||
txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(txt_layer)
|
||||
font = _load_font(font_path, font_size)
|
||||
|
||||
if stroke_fill is None:
|
||||
stroke_fill = (0, 0, 0)
|
||||
|
||||
# PIL fill with alpha for opacity
|
||||
alpha_int = int(round(opacity * 255))
|
||||
fill_rgba = (*fill, alpha_int)
|
||||
stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None
|
||||
|
||||
if multiline:
|
||||
draw.multiline_text((x, y), text, fill=fill_rgba, font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_rgba,
|
||||
spacing=line_spacing, align=align, anchor=anchor)
|
||||
else:
|
||||
draw.text((x, y), text, fill=fill_rgba, font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor)
|
||||
|
||||
# Composite onto background
|
||||
bg_img = Image.fromarray(frame).convert('RGBA')
|
||||
result = Image.alpha_composite(bg_img, txt_layer)
|
||||
return np.array(result.convert('RGB'))
|
||||
|
||||
|
||||
def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
|
||||
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
|
||||
stroke_width=0, stroke_fill=None, anchor="la",
|
||||
multiline=False, line_spacing=4, align="left"):
|
||||
"""Render with TextStrip system."""
|
||||
frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
|
||||
|
||||
strip = render_text_strip(
|
||||
text, font_path, font_size,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill,
|
||||
anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
|
||||
)
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array(fill, dtype=jnp.float32)
|
||||
|
||||
result = place_text_strip_jax(
|
||||
frame, strip_img, x, y,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color, opacity,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
stroke_width=strip.stroke_width
|
||||
)
|
||||
return np.array(result)
|
||||
|
||||
|
||||
def compare(name, tolerance=0, **kwargs):
|
||||
"""Compare PIL and TextStrip rendering."""
|
||||
pil = render_pil(**kwargs)
|
||||
strip = render_strip(**kwargs)
|
||||
|
||||
diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
|
||||
max_diff = diff.max()
|
||||
pixels_diff = (diff > 0).any(axis=2).sum()
|
||||
|
||||
if max_diff == 0:
|
||||
print(f" PASS: {name}")
|
||||
return True
|
||||
|
||||
if tolerance > 0:
|
||||
best_diff = diff.copy()
|
||||
for dy in range(-tolerance, tolerance + 1):
|
||||
for dx in range(-tolerance, tolerance + 1):
|
||||
if dy == 0 and dx == 0:
|
||||
continue
|
||||
shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
|
||||
sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
|
||||
best_diff = np.minimum(best_diff, sdiff)
|
||||
if best_diff.max() == 0:
|
||||
print(f" PASS: {name} (within {tolerance}px tolerance)")
|
||||
return True
|
||||
|
||||
print(f" FAIL: {name}")
|
||||
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
|
||||
Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
|
||||
Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
|
||||
diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
|
||||
Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png")
|
||||
print(f" Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png")
|
||||
return False
|
||||
|
||||
|
||||
def test_colors():
|
||||
"""Test various text colors on various backgrounds."""
|
||||
print("\n--- Colors ---")
|
||||
results = []
|
||||
|
||||
# White on black (baseline)
|
||||
results.append(compare("color_white_on_black",
|
||||
text="Hello", x=20, y=30, fill=(255, 255, 255), bg=(0, 0, 0)))
|
||||
|
||||
# Red on black
|
||||
results.append(compare("color_red",
|
||||
text="Red Text", x=20, y=30, fill=(255, 0, 0), bg=(0, 0, 0)))
|
||||
|
||||
# Green on black
|
||||
results.append(compare("color_green",
|
||||
text="Green!", x=20, y=30, fill=(0, 255, 0), bg=(0, 0, 0)))
|
||||
|
||||
# Blue on black
|
||||
results.append(compare("color_blue",
|
||||
text="Blue sky", x=20, y=30, fill=(0, 100, 255), bg=(0, 0, 0)))
|
||||
|
||||
# Yellow on dark gray
|
||||
results.append(compare("color_yellow_on_gray",
|
||||
text="Yellow", x=20, y=30, fill=(255, 255, 0), bg=(40, 40, 40)))
|
||||
|
||||
# Magenta on white
|
||||
results.append(compare("color_magenta_on_white",
|
||||
text="Magenta", x=20, y=30, fill=(255, 0, 255), bg=(255, 255, 255)))
|
||||
|
||||
# Subtle: gray text on slightly lighter gray
|
||||
results.append(compare("color_subtle_gray",
|
||||
text="Subtle", x=20, y=30, fill=(128, 128, 128), bg=(64, 64, 64)))
|
||||
|
||||
# Orange on deep blue
|
||||
results.append(compare("color_orange_on_blue",
|
||||
text="Warm", x=20, y=30, fill=(255, 165, 0), bg=(0, 0, 80)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_opacity():
|
||||
"""Test different opacity levels."""
|
||||
print("\n--- Opacity ---")
|
||||
results = []
|
||||
|
||||
results.append(compare("opacity_100",
|
||||
text="Full", x=20, y=30, opacity=1.0))
|
||||
|
||||
results.append(compare("opacity_75",
|
||||
text="75%", x=20, y=30, opacity=0.75))
|
||||
|
||||
results.append(compare("opacity_50",
|
||||
text="Half", x=20, y=30, opacity=0.5))
|
||||
|
||||
results.append(compare("opacity_25",
|
||||
text="Quarter", x=20, y=30, opacity=0.25))
|
||||
|
||||
results.append(compare("opacity_10",
|
||||
text="Ghost", x=20, y=30, opacity=0.1))
|
||||
|
||||
# Opacity on colored background
|
||||
results.append(compare("opacity_on_colored_bg",
|
||||
text="Overlay", x=20, y=30, fill=(255, 255, 255), bg=(100, 0, 0),
|
||||
opacity=0.5))
|
||||
|
||||
# Color + opacity combo
|
||||
results.append(compare("opacity_red_on_green",
|
||||
text="Blend", x=20, y=30, fill=(255, 0, 0), bg=(0, 100, 0),
|
||||
opacity=0.6))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_fonts():
|
||||
"""Test different fonts and sizes."""
|
||||
print("\n--- Fonts & Sizes ---")
|
||||
results = []
|
||||
|
||||
for label, path in FONTS.items():
|
||||
results.append(compare(f"font_{label}",
|
||||
text="Quick Fox", x=20, y=30, font_path=path, font_size=28,
|
||||
frame_size=(300, 80)))
|
||||
|
||||
# Tiny text
|
||||
results.append(compare("size_tiny",
|
||||
text="Tiny text at 12px", x=10, y=15, font_size=12,
|
||||
frame_size=(200, 40)))
|
||||
|
||||
# Big text
|
||||
results.append(compare("size_big",
|
||||
text="BIG", x=20, y=10, font_size=72,
|
||||
frame_size=(300, 100)))
|
||||
|
||||
# Huge text
|
||||
results.append(compare("size_huge",
|
||||
text="XL", x=10, y=10, font_size=120,
|
||||
frame_size=(300, 160)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_anchors():
|
||||
"""Test all anchor combinations."""
|
||||
print("\n--- Anchors ---")
|
||||
results = []
|
||||
|
||||
# All horizontal x vertical combos
|
||||
for h in ['l', 'm', 'r']:
|
||||
for v in ['a', 'm', 's', 'd']:
|
||||
anchor = h + v
|
||||
results.append(compare(f"anchor_{anchor}",
|
||||
text="Anchor", x=200, y=50, anchor=anchor,
|
||||
frame_size=(400, 100)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_strokes():
|
||||
"""Test various stroke widths and colors."""
|
||||
print("\n--- Strokes ---")
|
||||
results = []
|
||||
|
||||
for sw in [1, 2, 3, 4, 6, 8]:
|
||||
results.append(compare(f"stroke_w{sw}",
|
||||
text="Stroke", x=30, y=20, font_size=40,
|
||||
stroke_width=sw, stroke_fill=(0, 0, 0),
|
||||
frame_size=(300, 80)))
|
||||
|
||||
# Colored strokes
|
||||
results.append(compare("stroke_red",
|
||||
text="Red outline", x=20, y=20, font_size=36,
|
||||
stroke_width=3, stroke_fill=(255, 0, 0),
|
||||
frame_size=(350, 80)))
|
||||
|
||||
results.append(compare("stroke_white_on_black",
|
||||
text="Glow", x=20, y=20, font_size=40,
|
||||
fill=(255, 255, 255), stroke_width=4, stroke_fill=(0, 0, 255),
|
||||
frame_size=(250, 80)))
|
||||
|
||||
# Stroke with bold font
|
||||
results.append(compare("stroke_bold",
|
||||
text="Bold+Stroke", x=20, y=20,
|
||||
font_path=FONTS['bold'], font_size=36,
|
||||
stroke_width=3, stroke_fill=(0, 0, 0),
|
||||
frame_size=(400, 80)))
|
||||
|
||||
# Stroke + colored text on colored bg
|
||||
results.append(compare("stroke_colored_on_bg",
|
||||
text="Party", x=20, y=20, font_size=48,
|
||||
fill=(255, 255, 0), bg=(50, 0, 80),
|
||||
stroke_width=3, stroke_fill=(255, 0, 0),
|
||||
frame_size=(300, 80)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_edge_clipping():
|
||||
"""Test text at frame edges - clipping behavior."""
|
||||
print("\n--- Edge Clipping ---")
|
||||
results = []
|
||||
|
||||
# Text at very left edge
|
||||
results.append(compare("clip_left_edge",
|
||||
text="LEFT", x=0, y=30, frame_size=(200, 80)))
|
||||
|
||||
# Text partially off right edge
|
||||
results.append(compare("clip_right_edge",
|
||||
text="RIGHT SIDE", x=150, y=30, frame_size=(200, 80)))
|
||||
|
||||
# Text at top edge
|
||||
results.append(compare("clip_top",
|
||||
text="TOP", x=20, y=0, frame_size=(200, 80)))
|
||||
|
||||
# Text at bottom edge - partially clipped
|
||||
results.append(compare("clip_bottom",
|
||||
text="BOTTOM", x=20, y=55, font_size=40,
|
||||
frame_size=(200, 80)))
|
||||
|
||||
# Large text overflowing all sides from center
|
||||
results.append(compare("clip_overflow_center",
|
||||
text="HUGE", x=75, y=40, font_size=100, anchor="mm",
|
||||
frame_size=(150, 80)))
|
||||
|
||||
# Corner placement
|
||||
results.append(compare("clip_corner_br",
|
||||
text="Corner", x=350, y=70, font_size=36,
|
||||
frame_size=(400, 100)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_multiline_fancy():
|
||||
"""Test multiline with various styles."""
|
||||
print("\n--- Multiline Fancy ---")
|
||||
results = []
|
||||
|
||||
# Right-aligned (1px tolerance: same sub-pixel issue as center alignment)
|
||||
results.append(compare("multi_right",
|
||||
text="Right\nAligned\nText", x=380, y=20,
|
||||
frame_size=(400, 150),
|
||||
multiline=True, anchor="ra", align="right",
|
||||
tolerance=1))
|
||||
|
||||
# Center + stroke
|
||||
results.append(compare("multi_center_stroke",
|
||||
text="Title\nSubtitle", x=200, y=20,
|
||||
font_size=32, frame_size=(400, 120),
|
||||
multiline=True, anchor="ma", align="center",
|
||||
stroke_width=2, stroke_fill=(0, 0, 0),
|
||||
tolerance=1))
|
||||
|
||||
# Wide line spacing
|
||||
results.append(compare("multi_wide_spacing",
|
||||
text="Line A\nLine B\nLine C", x=20, y=10,
|
||||
frame_size=(300, 200),
|
||||
multiline=True, line_spacing=20))
|
||||
|
||||
# Zero extra spacing
|
||||
results.append(compare("multi_tight_spacing",
|
||||
text="Tight\nPacked\nLines", x=20, y=10,
|
||||
frame_size=(300, 150),
|
||||
multiline=True, line_spacing=0))
|
||||
|
||||
# Many lines
|
||||
results.append(compare("multi_many_lines",
|
||||
text="One\nTwo\nThree\nFour\nFive\nSix", x=20, y=5,
|
||||
font_size=20, frame_size=(200, 200),
|
||||
multiline=True, line_spacing=4))
|
||||
|
||||
# Bold multiline with stroke
|
||||
results.append(compare("multi_bold_stroke",
|
||||
text="BOLD\nSTROKE", x=20, y=10,
|
||||
font_path=FONTS['bold'], font_size=48,
|
||||
stroke_width=3, stroke_fill=(200, 0, 0),
|
||||
frame_size=(350, 150), multiline=True))
|
||||
|
||||
# Multiline on colored bg with opacity
|
||||
results.append(compare("multi_opacity_on_bg",
|
||||
text="Semi\nTransparent", x=20, y=10,
|
||||
fill=(255, 255, 0), bg=(0, 50, 100), opacity=0.7,
|
||||
frame_size=(300, 120), multiline=True))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_special_chars():
|
||||
"""Test special characters and edge cases."""
|
||||
print("\n--- Special Characters ---")
|
||||
results = []
|
||||
|
||||
# Numbers and symbols
|
||||
results.append(compare("chars_numbers",
|
||||
text="0123456789", x=20, y=30, frame_size=(300, 80)))
|
||||
|
||||
results.append(compare("chars_punctuation",
|
||||
text="Hello, World! (v2.0)", x=10, y=30, frame_size=(350, 80)))
|
||||
|
||||
results.append(compare("chars_symbols",
|
||||
text="@#$%^&*+=", x=20, y=30, frame_size=(300, 80)))
|
||||
|
||||
# Single character
|
||||
results.append(compare("chars_single",
|
||||
text="X", x=50, y=30, font_size=48, frame_size=(100, 80)))
|
||||
|
||||
# Very long text (clipped)
|
||||
results.append(compare("chars_long",
|
||||
text="The quick brown fox jumps over the lazy dog", x=10, y=30,
|
||||
font_size=24, frame_size=(400, 80)))
|
||||
|
||||
# Mixed case
|
||||
results.append(compare("chars_mixed_case",
|
||||
text="AaBbCcDdEeFf", x=10, y=30, frame_size=(350, 80)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_combos():
|
||||
"""Complex combinations of features."""
|
||||
print("\n--- Combos ---")
|
||||
results = []
|
||||
|
||||
# Big bold stroke + color + opacity
|
||||
results.append(compare("combo_all_features",
|
||||
text="EPIC", x=20, y=10,
|
||||
font_path=FONTS['bold'], font_size=64,
|
||||
fill=(255, 200, 0), bg=(20, 0, 40), opacity=0.85,
|
||||
stroke_width=4, stroke_fill=(180, 0, 0),
|
||||
frame_size=(350, 100)))
|
||||
|
||||
# Small mono on busy background
|
||||
results.append(compare("combo_mono_code",
|
||||
text="fn main() {}", x=10, y=15,
|
||||
font_path=FONTS['mono'], font_size=16,
|
||||
fill=(0, 255, 100), bg=(30, 30, 30),
|
||||
frame_size=(250, 50)))
|
||||
|
||||
# Serif italic multiline with stroke
|
||||
results.append(compare("combo_serif_italic_multi",
|
||||
text="Once upon\na time...", x=20, y=10,
|
||||
font_path=FONTS['italic'], font_size=28,
|
||||
stroke_width=1, stroke_fill=(80, 80, 80),
|
||||
frame_size=(300, 120), multiline=True))
|
||||
|
||||
# Narrow font, big stroke, center anchored
|
||||
results.append(compare("combo_narrow_stroke_center",
|
||||
text="NARROW", x=150, y=40,
|
||||
font_path=FONTS['narrow'], font_size=44,
|
||||
stroke_width=5, stroke_fill=(0, 0, 0),
|
||||
anchor="mm", frame_size=(300, 80)))
|
||||
|
||||
# Multiple strips on same frame (simulated by sequential placement)
|
||||
results.append(compare("combo_opacity_blend",
|
||||
text="Layered", x=20, y=30,
|
||||
fill=(255, 0, 0), bg=(0, 0, 255), opacity=0.5,
|
||||
font_size=48, frame_size=(300, 80)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_multi_strip_overlay():
|
||||
"""Test placing multiple strips on the same frame."""
|
||||
print("\n--- Multi-Strip Overlay ---")
|
||||
results = []
|
||||
|
||||
frame_size = (400, 150)
|
||||
bg = (20, 20, 40)
|
||||
|
||||
# PIL version - multiple draw calls
|
||||
font1 = _load_font(FONTS['bold'], 48)
|
||||
font2 = _load_font(None, 24)
|
||||
font3 = _load_font(FONTS['mono'], 18)
|
||||
|
||||
pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
|
||||
txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(txt_layer)
|
||||
draw.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1,
|
||||
stroke_width=2, stroke_fill=(0, 0, 0, 255))
|
||||
draw.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2)
|
||||
draw.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3)
|
||||
bg_img = Image.fromarray(pil_frame).convert('RGBA')
|
||||
pil_result = np.array(Image.alpha_composite(bg_img, txt_layer).convert('RGB'))
|
||||
|
||||
# Strip version - multiple placements
|
||||
frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
|
||||
|
||||
s1 = render_text_strip("TITLE", FONTS['bold'], 48, stroke_width=2, stroke_fill=(0, 0, 0))
|
||||
s2 = render_text_strip("Subtitle here", None, 24)
|
||||
s3 = render_text_strip("code_snippet()", FONTS['mono'], 18)
|
||||
|
||||
frame = place_text_strip_jax(
|
||||
frame, jnp.asarray(s1.image), 20, 10,
|
||||
s1.baseline_y, s1.bearing_x,
|
||||
jnp.array([255, 255, 0], dtype=jnp.float32), 1.0,
|
||||
anchor_x=s1.anchor_x, anchor_y=s1.anchor_y,
|
||||
stroke_width=s1.stroke_width)
|
||||
|
||||
frame = place_text_strip_jax(
|
||||
frame, jnp.asarray(s2.image), 20, 70,
|
||||
s2.baseline_y, s2.bearing_x,
|
||||
jnp.array([200, 200, 200], dtype=jnp.float32), 1.0,
|
||||
anchor_x=s2.anchor_x, anchor_y=s2.anchor_y,
|
||||
stroke_width=s2.stroke_width)
|
||||
|
||||
frame = place_text_strip_jax(
|
||||
frame, jnp.asarray(s3.image), 20, 110,
|
||||
s3.baseline_y, s3.bearing_x,
|
||||
jnp.array([0, 255, 128], dtype=jnp.float32), 200/255,
|
||||
anchor_x=s3.anchor_x, anchor_y=s3.anchor_y,
|
||||
stroke_width=s3.stroke_width)
|
||||
|
||||
strip_result = np.array(frame)
|
||||
|
||||
diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16))
|
||||
max_diff = diff.max()
|
||||
pixels_diff = (diff > 0).any(axis=2).sum()
|
||||
|
||||
if max_diff <= 1:
|
||||
print(f" PASS: multi_overlay (max_diff={max_diff})")
|
||||
results.append(True)
|
||||
else:
|
||||
print(f" FAIL: multi_overlay")
|
||||
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
|
||||
Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png")
|
||||
Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png")
|
||||
diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
|
||||
Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png")
|
||||
print(f" Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png")
|
||||
results.append(False)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Funky TextStrip vs PIL Comparison")
|
||||
print("=" * 60)
|
||||
|
||||
all_results = []
|
||||
all_results.extend(test_colors())
|
||||
all_results.extend(test_opacity())
|
||||
all_results.extend(test_fonts())
|
||||
all_results.extend(test_anchors())
|
||||
all_results.extend(test_strokes())
|
||||
all_results.extend(test_edge_clipping())
|
||||
all_results.extend(test_multiline_fancy())
|
||||
all_results.extend(test_special_chars())
|
||||
all_results.extend(test_combos())
|
||||
all_results.extend(test_multi_strip_overlay())
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
passed = sum(all_results)
|
||||
total = len(all_results)
|
||||
print(f"Results: {passed}/{total} passed")
|
||||
if passed == total:
|
||||
print("ALL TESTS PASSED!")
|
||||
else:
|
||||
failed = [i for i, r in enumerate(all_results) if not r]
|
||||
print(f"FAILED: {total - passed} tests (indices: {failed})")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
test_pil_options.py
Normal file
183
test_pil_options.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Explore PIL text options and test if we can match them.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
def load_font(font_name=None, font_size=32):
|
||||
"""Load a font."""
|
||||
candidates = [
|
||||
font_name,
|
||||
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||
'/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
|
||||
'/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf',
|
||||
]
|
||||
for path in candidates:
|
||||
if path is None:
|
||||
continue
|
||||
try:
|
||||
return ImageFont.truetype(path, font_size)
|
||||
except (IOError, OSError):
|
||||
continue
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def test_pil_options():
|
||||
"""Test various PIL text options."""
|
||||
|
||||
# Create a test frame
|
||||
frame_size = (600, 400)
|
||||
|
||||
font = load_font(None, 36)
|
||||
font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36)
|
||||
font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36)
|
||||
|
||||
tests = []
|
||||
|
||||
# Test 1: Basic text
|
||||
def basic_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font)
|
||||
return img, "basic"
|
||||
tests.append(basic_text)
|
||||
|
||||
# Test 2: Stroke/outline
|
||||
def stroke_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font,
|
||||
stroke_width=2, stroke_fill=(255, 0, 0, 255))
|
||||
return img, "stroke"
|
||||
tests.append(stroke_text)
|
||||
|
||||
# Test 3: Bold font
|
||||
def bold_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold)
|
||||
return img, "bold"
|
||||
tests.append(bold_text)
|
||||
|
||||
# Test 4: Italic font
|
||||
def italic_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic)
|
||||
return img, "italic"
|
||||
tests.append(italic_text)
|
||||
|
||||
# Test 5: Different anchors
|
||||
def anchor_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
# Draw crosshairs at anchor points
|
||||
for x in [100, 300, 500]:
|
||||
draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255))
|
||||
draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255))
|
||||
|
||||
draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm")
|
||||
draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm")
|
||||
draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm")
|
||||
return img, "anchor"
|
||||
tests.append(anchor_text)
|
||||
|
||||
# Test 6: Multiline text
|
||||
def multiline_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three",
|
||||
fill=(255, 255, 255, 255), font=font, spacing=10)
|
||||
return img, "multiline"
|
||||
tests.append(multiline_text)
|
||||
|
||||
# Test 7: Semi-transparent text
|
||||
def alpha_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font)
|
||||
draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font)
|
||||
draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font)
|
||||
return img, "alpha"
|
||||
tests.append(alpha_text)
|
||||
|
||||
# Test 8: Colored text
|
||||
def colored_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font)
|
||||
draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font)
|
||||
draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font)
|
||||
draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font)
|
||||
return img, "colored"
|
||||
tests.append(colored_text)
|
||||
|
||||
# Test 9: Large stroke
|
||||
def large_stroke():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font,
|
||||
stroke_width=5, stroke_fill=(0, 0, 0, 255))
|
||||
return img, "large_stroke"
|
||||
tests.append(large_stroke)
|
||||
|
||||
# Test 10: Emoji (if supported)
|
||||
def emoji_text():
|
||||
img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
try:
|
||||
# Try to find an emoji font
|
||||
emoji_font = None
|
||||
emoji_paths = [
|
||||
'/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf',
|
||||
'/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf',
|
||||
]
|
||||
for p in emoji_paths:
|
||||
try:
|
||||
emoji_font = ImageFont.truetype(p, 36)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
if emoji_font:
|
||||
draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font)
|
||||
else:
|
||||
draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font)
|
||||
except Exception as e:
|
||||
draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font)
|
||||
return img, "emoji"
|
||||
tests.append(emoji_text)
|
||||
|
||||
# Run all tests
|
||||
print("PIL Text Options Test")
|
||||
print("=" * 60)
|
||||
|
||||
for test_fn in tests:
|
||||
img, name = test_fn()
|
||||
fname = f"/tmp/pil_test_{name}.png"
|
||||
img.save(fname)
|
||||
print(f"Saved: {fname}")
|
||||
|
||||
print("\nCheck /tmp/pil_test_*.png for results")
|
||||
|
||||
# Print available parameters
|
||||
print("\n" + "=" * 60)
|
||||
print("PIL draw.text() parameters:")
|
||||
print(" - xy: position tuple")
|
||||
print(" - text: string to draw")
|
||||
print(" - fill: color (R,G,B) or (R,G,B,A)")
|
||||
print(" - font: ImageFont object")
|
||||
print(" - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)")
|
||||
print(" - spacing: line spacing for multiline")
|
||||
print(" - align: 'left', 'center', 'right' for multiline")
|
||||
print(" - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)")
|
||||
print(" - features: OpenType features list")
|
||||
print(" - language: language code for shaping")
|
||||
print(" - stroke_width: outline width in pixels")
|
||||
print(" - stroke_fill: outline color")
|
||||
print(" - embedded_color: use embedded color glyphs (emoji)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pil_options()
|
||||
176
test_styled_text.py
Normal file
176
test_styled_text.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test styled TextStrip rendering against PIL.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from streaming.jax_typography import (
|
||||
render_text_strip, place_text_strip_jax, _load_font
|
||||
)
|
||||
|
||||
|
||||
def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
|
||||
stroke_width=0, stroke_fill=None, anchor="la",
|
||||
multiline=False, line_spacing=4, align="left"):
|
||||
"""Render with PIL directly."""
|
||||
frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)
|
||||
img = Image.fromarray(frame)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
font = _load_font(None, font_size)
|
||||
|
||||
# Default stroke fill
|
||||
if stroke_fill is None:
|
||||
stroke_fill = (0, 0, 0)
|
||||
|
||||
if multiline:
|
||||
draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill,
|
||||
spacing=line_spacing, align=align, anchor=anchor)
|
||||
else:
|
||||
draw.text((x, y), text, fill=(255, 255, 255), font=font,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill, anchor=anchor)
|
||||
|
||||
return np.array(img)
|
||||
|
||||
|
||||
def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
|
||||
stroke_width=0, stroke_fill=None, anchor="la",
|
||||
multiline=False, line_spacing=4, align="left"):
|
||||
"""Render with TextStrip."""
|
||||
frame = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)
|
||||
|
||||
strip = render_text_strip(
|
||||
text, None, font_size,
|
||||
stroke_width=stroke_width, stroke_fill=stroke_fill,
|
||||
anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
|
||||
)
|
||||
strip_img = jnp.asarray(strip.image)
|
||||
color = jnp.array([255, 255, 255], dtype=jnp.float32)
|
||||
|
||||
result = place_text_strip_jax(
|
||||
frame, strip_img, x, y,
|
||||
strip.baseline_y, strip.bearing_x,
|
||||
color, 1.0,
|
||||
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||
stroke_width=strip.stroke_width
|
||||
)
|
||||
|
||||
return np.array(result)
|
||||
|
||||
|
||||
def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
|
||||
tolerance=0, **kwargs):
|
||||
"""Compare PIL and TextStrip rendering.
|
||||
|
||||
tolerance=0: exact pixel match required
|
||||
tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences
|
||||
in center-aligned multiline text where the strip is pre-rendered
|
||||
at a different base position than the final placement)
|
||||
"""
|
||||
pil = render_pil(text, x, y, font_size, frame_size, **kwargs)
|
||||
strip = render_strip(text, x, y, font_size, frame_size, **kwargs)
|
||||
|
||||
diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
|
||||
max_diff = diff.max()
|
||||
pixels_diff = (diff > 0).any(axis=2).sum()
|
||||
|
||||
if max_diff == 0:
|
||||
print(f"PASS: {name}")
|
||||
print(f" Max diff: 0, Pixels different: 0")
|
||||
return True
|
||||
|
||||
if tolerance > 0:
|
||||
# Check if the difference is just a sub-pixel position shift:
|
||||
# for each shifted version, compute the minimum diff
|
||||
best_diff = diff.copy()
|
||||
for dy in range(-tolerance, tolerance + 1):
|
||||
for dx in range(-tolerance, tolerance + 1):
|
||||
if dy == 0 and dx == 0:
|
||||
continue
|
||||
shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
|
||||
sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
|
||||
best_diff = np.minimum(best_diff, sdiff)
|
||||
max_shift_diff = best_diff.max()
|
||||
pixels_shift_diff = (best_diff > 0).any(axis=2).sum()
|
||||
if max_shift_diff == 0:
|
||||
print(f"PASS: {name} (within {tolerance}px position tolerance)")
|
||||
print(f" Raw diff: {max_diff}, After shift tolerance: 0")
|
||||
return True
|
||||
|
||||
status = "FAIL"
|
||||
print(f"{status}: {name}")
|
||||
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
|
||||
|
||||
# Save debug images
|
||||
Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
|
||||
Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
|
||||
diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8)
|
||||
Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png")
|
||||
print(f" Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Styled TextStrip vs PIL Comparison")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Basic text
|
||||
results.append(compare("basic", "Hello World", 20, 50))
|
||||
|
||||
# Stroke/outline
|
||||
results.append(compare("stroke_2", "Outlined", 20, 50,
|
||||
stroke_width=2, stroke_fill=(255, 0, 0)))
|
||||
|
||||
results.append(compare("stroke_5", "Big Outline", 30, 60, font_size=48,
|
||||
frame_size=(500, 120),
|
||||
stroke_width=5, stroke_fill=(0, 0, 0)))
|
||||
|
||||
# Anchors - center
|
||||
results.append(compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100),
|
||||
anchor="mm"))
|
||||
|
||||
# Anchors - right
|
||||
results.append(compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100),
|
||||
anchor="rm"))
|
||||
|
||||
# Multiline
|
||||
results.append(compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20,
|
||||
frame_size=(400, 150),
|
||||
multiline=True, line_spacing=8))
|
||||
|
||||
# Multiline centered (1px tolerance: sub-pixel rendering differs because
|
||||
# the strip is pre-rendered at an integer position while PIL's center
|
||||
# alignment uses fractional getlength values for the 'm' anchor shift)
|
||||
results.append(compare("multiline_center", "Short\nMedium Length\nX", 200, 20,
|
||||
frame_size=(400, 150),
|
||||
multiline=True, anchor="ma", align="center",
|
||||
tolerance=1))
|
||||
|
||||
# Stroke + multiline
|
||||
results.append(compare("stroke_multiline", "Line A\nLine B", 20, 20,
|
||||
frame_size=(400, 120),
|
||||
stroke_width=2, stroke_fill=(0, 0, 255),
|
||||
multiline=True))
|
||||
|
||||
print("=" * 60)
|
||||
passed = sum(results)
|
||||
total = len(results)
|
||||
print(f"Results: {passed}/{total} passed")
|
||||
|
||||
if passed == total:
|
||||
print("ALL TESTS PASSED!")
|
||||
else:
|
||||
print(f"FAILED: {total - passed} tests")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
517
tests/test_jax_pipeline_integration.py
Normal file
517
tests/test_jax_pipeline_integration.py
Normal file
@@ -0,0 +1,517 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Integration tests comparing JAX and Python rendering pipelines.
|
||||
|
||||
These tests ensure the JAX-compiled effect chains produce identical output
|
||||
to the Python/NumPy path. They test:
|
||||
1. Full effect pipelines through both interpreters
|
||||
2. Multi-frame sequences (to catch phase-dependent bugs)
|
||||
3. Compiled effect chain fusion
|
||||
4. Edge cases like shrinking/zooming that affect boundary handling
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import numpy as np
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the art-celery module is importable
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
|
||||
# Path to test resources
|
||||
TEST_DIR = Path('/home/giles/art/test')
|
||||
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
|
||||
TEMPLATES_DIR = TEST_DIR / 'templates'
|
||||
|
||||
|
||||
def create_test_image(h=96, w=128):
|
||||
"""Create a test image with distinct patterns."""
|
||||
import cv2
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
|
||||
# Create gradient background
|
||||
for y in range(h):
|
||||
for x in range(w):
|
||||
img[y, x] = [
|
||||
int(255 * x / w), # R: horizontal gradient
|
||||
int(255 * y / h), # G: vertical gradient
|
||||
128 # B: constant
|
||||
]
|
||||
|
||||
# Add features
|
||||
cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1)
|
||||
cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def test_env(tmp_path_factory):
|
||||
"""Set up test environment with sexp files and test media."""
|
||||
test_dir = tmp_path_factory.mktemp('sexp_test')
|
||||
original_dir = os.getcwd()
|
||||
os.chdir(test_dir)
|
||||
|
||||
# Create directories
|
||||
(test_dir / 'effects').mkdir()
|
||||
(test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)
|
||||
|
||||
# Create test image
|
||||
import cv2
|
||||
test_img = create_test_image()
|
||||
cv2.imwrite(str(test_dir / 'test_image.png'), test_img)
|
||||
|
||||
# Copy required effect files
|
||||
for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
|
||||
src = EFFECTS_DIR / f'{effect}.sexp'
|
||||
dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
|
||||
if src.exists():
|
||||
shutil.copy(src, dst)
|
||||
|
||||
yield {
|
||||
'dir': test_dir,
|
||||
'image_path': test_dir / 'test_image.png',
|
||||
'test_img': test_img,
|
||||
}
|
||||
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
def create_sexp_file(test_dir, content, filename='test.sexp'):
|
||||
"""Create a test sexp file."""
|
||||
path = test_dir / 'effects' / filename
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
class TestJaxPythonPipelineEquivalence:
|
||||
"""Test that JAX and Python pipelines produce equivalent output."""
|
||||
|
||||
def test_single_rotate_effect(self, test_env):
|
||||
"""Test that a single rotate effect matches between Python and JAX."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
|
||||
(frame (rotate frame :angle 15 :speed 0))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
# Python path
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
# JAX path
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
# Inject test image into globals
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = py_interp._eval(py_interp.frame_pipeline, frame_env)
|
||||
jax_result = jax_interp._eval(jax_interp.frame_pipeline, frame_env)
|
||||
|
||||
# Force deferred if needed
|
||||
py_result = np.asarray(py_interp._maybe_force(py_result))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(jax_result))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_rotate_then_zoom(self, test_env):
|
||||
"""Test rotate followed by zoom - tests effect chain fusion."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame (-> (rotate frame :angle 15 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_zoom_shrink_boundary_handling(self, test_env):
|
||||
"""Test zoom with shrinking factor - critical for boundary handling."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame (zoom frame :amount 0.8 :speed 0))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
# Check corners specifically - these are most affected by boundary handling
|
||||
h, w = test_img.shape[:2]
|
||||
corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)]
|
||||
for y, x in corners:
|
||||
py_val = py_result[y, x]
|
||||
jax_val = jax_result[y, x]
|
||||
corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
|
||||
assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"
|
||||
|
||||
def test_blend_two_clips(self, test_env):
|
||||
"""Test blending two effect chains - the core bug scenario."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(require-primitives "core")
|
||||
(require-primitives "image")
|
||||
(require-primitives "blending")
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
(effect blend :path "../sexp_effects/effects/blend.sexp")
|
||||
|
||||
(frame
|
||||
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
|
||||
(zoom :amount 1.05 :speed 0))
|
||||
clip_b (-> (rotate frame :angle -5 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0))]
|
||||
(blend :base clip_a :overlay clip_b :opacity 0.5)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
# Check edge region specifically
|
||||
edge_diff = diff[0, :].mean()
|
||||
|
||||
assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"
|
||||
|
||||
def test_blend_with_invert(self, test_env):
|
||||
"""Test blending with invert - matches the problematic recipe pattern."""
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(require-primitives "core")
|
||||
(require-primitives "image")
|
||||
(require-primitives "blending")
|
||||
(require-primitives "color_ops")
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
(effect blend :path "../sexp_effects/effects/blend.sexp")
|
||||
(effect invert :path "../sexp_effects/effects/invert.sexp")
|
||||
|
||||
(frame
|
||||
(let [clip_a (-> (rotate frame :angle 5 :speed 0)
|
||||
(zoom :amount 1.05 :speed 0)
|
||||
(invert :amount 1))
|
||||
clip_b (-> (rotate frame :angle -5 :speed 0)
|
||||
(zoom :amount 0.95 :speed 0))]
|
||||
(blend :base clip_a :overlay clip_b :opacity 0.5)))
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, Context
|
||||
import cv2
|
||||
|
||||
test_img = cv2.imread(str(test_env['image_path']))
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
py_interp = StreamInterpreter(sexp_path, use_jax=False)
|
||||
py_interp._init()
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
jax_interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
jax_interp._init()
|
||||
|
||||
ctx = Context(fps=10)
|
||||
ctx.t = 0.5
|
||||
ctx.frame_num = 5
|
||||
|
||||
frame_env = {
|
||||
'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
|
||||
't': ctx.t, 'frame-num': ctx.frame_num,
|
||||
}
|
||||
|
||||
py_interp.globals['frame'] = test_img
|
||||
jax_interp.globals['frame'] = test_img
|
||||
|
||||
py_result = np.asarray(py_interp._maybe_force(
|
||||
py_interp._eval(py_interp.frame_pipeline, frame_env)))
|
||||
jax_result = np.asarray(jax_interp._maybe_force(
|
||||
jax_interp._eval(jax_interp.frame_pipeline, frame_env)))
|
||||
|
||||
diff = np.abs(py_result.astype(float) - jax_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
|
||||
class TestDeferredEffectChainFusion:
|
||||
"""Test the DeferredEffectChain fusion mechanism specifically."""
|
||||
|
||||
def test_manual_vs_fused_chain(self, test_env):
|
||||
"""Compare manual application vs fused DeferredEffectChain."""
|
||||
import jax.numpy as jnp
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
|
||||
|
||||
# Create minimal sexp to load effects
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame frame)
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
interp._init()
|
||||
|
||||
test_img = test_env['test_img']
|
||||
jax_frame = jnp.array(test_img)
|
||||
t = 0.5
|
||||
frame_num = 5
|
||||
|
||||
# Manual step-by-step application
|
||||
rotate_fn = interp.jax_effects['rotate']
|
||||
zoom_fn = interp.jax_effects['zoom']
|
||||
|
||||
rot_angle = -5.0
|
||||
zoom_amount = 0.95
|
||||
|
||||
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
|
||||
angle=rot_angle, speed=0)
|
||||
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
|
||||
amount=zoom_amount, speed=0)
|
||||
manual_result = np.asarray(manual_result)
|
||||
|
||||
# Fused chain application
|
||||
chain = DeferredEffectChain(
|
||||
['rotate'],
|
||||
[{'angle': rot_angle, 'speed': 0}],
|
||||
jax_frame, t, frame_num
|
||||
)
|
||||
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
|
||||
|
||||
fused_result = np.asarray(interp._force_deferred(chain))
|
||||
|
||||
diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
|
||||
mean_diff = np.mean(diff)
|
||||
|
||||
assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"
|
||||
|
||||
# Check specific pixels
|
||||
h, w = test_img.shape[:2]
|
||||
for y in [0, h//2, h-1]:
|
||||
for x in [0, w//2, w-1]:
|
||||
manual_val = manual_result[y, x]
|
||||
fused_val = fused_result[y, x]
|
||||
pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
|
||||
assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"
|
||||
|
||||
def test_chain_with_shrink_zoom_boundary(self, test_env):
|
||||
"""Test that shrinking zoom handles boundaries correctly in chain."""
|
||||
import jax.numpy as jnp
|
||||
from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain
|
||||
|
||||
sexp_content = '''(stream "test"
|
||||
:width 128
|
||||
:height 96
|
||||
:seed 42
|
||||
|
||||
(effect rotate :path "../sexp_effects/effects/rotate.sexp")
|
||||
(effect zoom :path "../sexp_effects/effects/zoom.sexp")
|
||||
|
||||
(frame frame)
|
||||
)
|
||||
'''
|
||||
sexp_path = create_sexp_file(test_env['dir'], sexp_content)
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
interp = StreamInterpreter(sexp_path, use_jax=True)
|
||||
interp._init()
|
||||
|
||||
test_img = test_env['test_img']
|
||||
jax_frame = jnp.array(test_img)
|
||||
t = 0.5
|
||||
frame_num = 5
|
||||
|
||||
# Parameters that shrink the image (zoom < 1.0)
|
||||
rot_angle = -4.555
|
||||
zoom_amount = 0.9494 # This pulls in from edges, exposing boundaries
|
||||
|
||||
# Manual application
|
||||
rotate_fn = interp.jax_effects['rotate']
|
||||
zoom_fn = interp.jax_effects['zoom']
|
||||
|
||||
manual_result = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
|
||||
angle=rot_angle, speed=0)
|
||||
manual_result = zoom_fn(manual_result, t=t, frame_num=frame_num, seed=42,
|
||||
amount=zoom_amount, speed=0)
|
||||
manual_result = np.asarray(manual_result)
|
||||
|
||||
# Fused chain
|
||||
chain = DeferredEffectChain(
|
||||
['rotate'],
|
||||
[{'angle': rot_angle, 'speed': 0}],
|
||||
jax_frame, t, frame_num
|
||||
)
|
||||
chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})
|
||||
|
||||
fused_result = np.asarray(interp._force_deferred(chain))
|
||||
|
||||
# Check top edge specifically - this is where boundary issues manifest
|
||||
top_edge_manual = manual_result[0, :]
|
||||
top_edge_fused = fused_result[0, :]
|
||||
|
||||
edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
|
||||
mean_edge_diff = np.mean(edge_diff)
|
||||
|
||||
assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"
|
||||
|
||||
# Check for zeros at edges that shouldn't be there
|
||||
manual_edge_sum = np.sum(top_edge_manual)
|
||||
fused_edge_sum = np.sum(top_edge_fused)
|
||||
|
||||
if manual_edge_sum > 100: # If manual has significant values
|
||||
assert fused_edge_sum > manual_edge_sum * 0.5, \
|
||||
f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
334
tests/test_jax_primitives.py
Normal file
334
tests/test_jax_primitives.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test framework to verify JAX primitives match Python primitives.
|
||||
|
||||
Compares output of each primitive through:
|
||||
1. Python/NumPy path
|
||||
2. JAX path (CPU)
|
||||
3. JAX path (GPU) - if available
|
||||
|
||||
Reports any mismatches with detailed diffs.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '/home/giles/art/art-celery')
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Test configuration
|
||||
TEST_WIDTH = 64
|
||||
TEST_HEIGHT = 48
|
||||
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
|
||||
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
|
||||
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
primitive: str
|
||||
passed: bool
|
||||
python_mean: float = 0.0
|
||||
jax_mean: float = 0.0
|
||||
diff_mean: float = 0.0
|
||||
diff_max: float = 0.0
|
||||
pct_within_1: float = 0.0
|
||||
error: str = ""
|
||||
|
||||
|
||||
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
|
||||
"""Create a test frame with known pattern."""
|
||||
if pattern == 'gradient':
|
||||
# Diagonal gradient
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
r = (x * 255 / w).astype(np.uint8)
|
||||
g = (y * 255 / h).astype(np.uint8)
|
||||
b = ((x + y) * 127 / (w + h)).astype(np.uint8)
|
||||
return np.stack([r, g, b], axis=2)
|
||||
elif pattern == 'checkerboard':
|
||||
y, x = np.mgrid[0:h, 0:w]
|
||||
check = ((x // 8) + (y // 8)) % 2
|
||||
v = (check * 255).astype(np.uint8)
|
||||
return np.stack([v, v, v], axis=2)
|
||||
elif pattern == 'solid':
|
||||
return np.full((h, w, 3), 128, dtype=np.uint8)
|
||||
else:
|
||||
return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
|
||||
"""Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
|
||||
if py_out is None or jax_out is None:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
if isinstance(py_out, dict) and isinstance(jax_out, dict):
|
||||
# Compare coordinate maps
|
||||
diffs = []
|
||||
for k in py_out:
|
||||
if k in jax_out:
|
||||
py_arr = np.asarray(py_out[k])
|
||||
jax_arr = np.asarray(jax_out[k])
|
||||
if py_arr.shape == jax_arr.shape:
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
diffs.append(diff)
|
||||
if diffs:
|
||||
all_diff = np.concatenate([d.flatten() for d in diffs])
|
||||
return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
py_arr = np.asarray(py_out)
|
||||
jax_arr = np.asarray(jax_out)
|
||||
|
||||
if py_arr.shape != jax_arr.shape:
|
||||
return float('inf'), float('inf'), 0.0
|
||||
|
||||
diff = np.abs(py_arr.astype(float) - jax_arr.astype(float))
|
||||
return float(np.mean(diff)), float(np.max(diff)), float(np.mean(diff <= 1))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Test Definitions
|
||||
# ============================================================================
|
||||
|
||||
PRIMITIVE_TESTS = {
|
||||
# Geometry primitives
|
||||
'geometry:ripple-displace': {
|
||||
'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
|
||||
'returns': 'coords',
|
||||
},
|
||||
'geometry:rotate-img': {
|
||||
'args': ['frame', 45],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:scale-img': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-h': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'geometry:flip-v': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Color operations
|
||||
'color_ops:invert': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:grayscale': {
|
||||
'args': ['frame'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:brightness': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:contrast': {
|
||||
'args': ['frame', 1.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'color_ops:hue-shift': {
|
||||
'args': ['frame', 90],
|
||||
'returns': 'frame',
|
||||
},
|
||||
|
||||
# Image operations
|
||||
'image:width': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:height': {
|
||||
'args': ['frame'],
|
||||
'returns': 'scalar',
|
||||
},
|
||||
'image:channel': {
|
||||
'args': ['frame', 0],
|
||||
'returns': 'array',
|
||||
},
|
||||
|
||||
# Blending
|
||||
'blending:blend': {
|
||||
'args': ['frame', 'frame2', 0.5],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-add': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
'blending:blend-multiply': {
|
||||
'args': ['frame', 'frame2'],
|
||||
'returns': 'frame',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the Python interpreter."""
|
||||
if prim_name not in interp.primitives:
|
||||
return None
|
||||
|
||||
func = interp.primitives[prim_name]
|
||||
args = []
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(frame.copy())
|
||||
elif a == 'frame2':
|
||||
args.append(frame2.copy())
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
try:
|
||||
return func(*args)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
|
||||
"""Run a primitive through the JAX compiler."""
|
||||
try:
|
||||
from streaming.sexp_to_jax import JaxCompiler
|
||||
import jax.numpy as jnp
|
||||
|
||||
compiler = JaxCompiler()
|
||||
|
||||
# Build a simple expression to test the primitive
|
||||
from sexp_effects.parser import Symbol, Keyword
|
||||
|
||||
args = []
|
||||
env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
|
||||
|
||||
for a in test_def['args']:
|
||||
if a == 'frame':
|
||||
args.append(Symbol('frame'))
|
||||
elif a == 'frame2':
|
||||
args.append(Symbol('frame2'))
|
||||
else:
|
||||
args.append(a)
|
||||
|
||||
# Create expression: (prim_name arg1 arg2 ...)
|
||||
expr = [Symbol(prim_name)] + args
|
||||
|
||||
result = compiler._eval(expr, env)
|
||||
|
||||
if hasattr(result, '__array__'):
|
||||
return np.asarray(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
|
||||
"""Test a single primitive."""
|
||||
frame = create_test_frame(pattern='gradient')
|
||||
frame2 = create_test_frame(pattern='checkerboard')
|
||||
|
||||
result = TestResult(primitive=prim_name, passed=False)
|
||||
|
||||
# Run Python version
|
||||
try:
|
||||
py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
|
||||
if py_out is not None and hasattr(py_out, 'shape'):
|
||||
result.python_mean = float(np.mean(py_out))
|
||||
except Exception as e:
|
||||
result.error = f"Python error: {e}"
|
||||
return result
|
||||
|
||||
# Run JAX version
|
||||
try:
|
||||
jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
|
||||
if jax_out is not None and hasattr(jax_out, 'shape'):
|
||||
result.jax_mean = float(np.mean(jax_out))
|
||||
except Exception as e:
|
||||
result.error = f"JAX error: {e}"
|
||||
return result
|
||||
|
||||
if py_out is None:
|
||||
result.error = "Python returned None"
|
||||
return result
|
||||
if jax_out is None:
|
||||
result.error = "JAX returned None"
|
||||
return result
|
||||
|
||||
# Compare
|
||||
diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
|
||||
result.diff_mean = diff_mean
|
||||
result.diff_max = diff_max
|
||||
result.pct_within_1 = pct
|
||||
|
||||
# Check pass/fail
|
||||
result.passed = (
|
||||
diff_mean <= TOLERANCE_MEAN and
|
||||
diff_max <= TOLERANCE_MAX and
|
||||
pct >= TOLERANCE_PCT
|
||||
)
|
||||
|
||||
if not result.passed:
|
||||
result.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def discover_primitives(interp) -> List[str]:
|
||||
"""Discover all primitives available in the interpreter."""
|
||||
return sorted(interp.primitives.keys())
|
||||
|
||||
|
||||
def run_all_tests(verbose=True):
|
||||
"""Run all primitive tests."""
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
import os
|
||||
os.chdir('/home/giles/art/test')
|
||||
|
||||
from streaming.stream_sexp_generic import StreamInterpreter
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
core_mod.set_random_seed(42)
|
||||
|
||||
# Create interpreter to get Python primitives
|
||||
interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
|
||||
interp._init()
|
||||
|
||||
results = []
|
||||
|
||||
print("=" * 60)
|
||||
print("JAX PRIMITIVE TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
# Test defined primitives
|
||||
for prim_name, test_def in PRIMITIVE_TESTS.items():
|
||||
result = test_primitive(interp, prim_name, test_def)
|
||||
results.append(result)
|
||||
|
||||
status = "✓ PASS" if result.passed else "✗ FAIL"
|
||||
if verbose:
|
||||
print(f"{status} {prim_name}")
|
||||
if not result.passed:
|
||||
print(f" {result.error}")
|
||||
|
||||
# Summary
|
||||
passed = sum(1 for r in results if r.passed)
|
||||
failed = sum(1 for r in results if not r.passed)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"SUMMARY: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
|
||||
if failed > 0:
|
||||
print("\nFailed primitives:")
|
||||
for r in results:
|
||||
if not r.passed:
|
||||
print(f" - {r.primitive}: {r.error}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_all_tests()
|
||||
305
tests/test_xector.py
Normal file
305
tests/test_xector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Tests for xector primitives - parallel array operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sexp_effects.primitive_libs.xector import (
|
||||
Xector,
|
||||
xector_red, xector_green, xector_blue, xector_rgb,
|
||||
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
|
||||
xector_dist_from_center,
|
||||
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
|
||||
alpha_sin, alpha_cos, alpha_sq,
|
||||
alpha_lt, alpha_gt, alpha_eq,
|
||||
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
|
||||
xector_where, xector_fill, xector_zeros, xector_ones,
|
||||
is_xector,
|
||||
)
|
||||
|
||||
|
||||
class TestXectorBasics:
|
||||
"""Test Xector class basic operations."""
|
||||
|
||||
def test_create_from_list(self):
|
||||
x = Xector([1, 2, 3])
|
||||
assert len(x) == 3
|
||||
assert is_xector(x)
|
||||
|
||||
def test_create_from_numpy(self):
|
||||
arr = np.array([1.0, 2.0, 3.0])
|
||||
x = Xector(arr)
|
||||
assert len(x) == 3
|
||||
np.testing.assert_array_equal(x.to_numpy(), arr.astype(np.float32))
|
||||
|
||||
def test_implicit_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = a + b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_implicit_mul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([2, 2, 2])
|
||||
c = a * b
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_scalar_broadcast(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = a + 10
|
||||
np.testing.assert_array_equal(c.to_numpy(), [11, 12, 13])
|
||||
|
||||
def test_scalar_broadcast_rmul(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = 2 * a
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
|
||||
class TestAlphaOperations:
|
||||
"""Test α (element-wise) operations."""
|
||||
|
||||
def test_alpha_add(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([4, 5, 6])
|
||||
c = alpha_add(a, b)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [5, 7, 9])
|
||||
|
||||
def test_alpha_add_multi(self):
|
||||
a = Xector([1, 2, 3])
|
||||
b = Xector([1, 1, 1])
|
||||
c = Xector([10, 10, 10])
|
||||
d = alpha_add(a, b, c)
|
||||
np.testing.assert_array_equal(d.to_numpy(), [12, 13, 14])
|
||||
|
||||
def test_alpha_mul_scalar(self):
|
||||
a = Xector([1, 2, 3])
|
||||
c = alpha_mul(a, 2)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [2, 4, 6])
|
||||
|
||||
def test_alpha_sqrt(self):
|
||||
a = Xector([1, 4, 9, 16])
|
||||
c = alpha_sqrt(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_alpha_clamp(self):
|
||||
a = Xector([-5, 0, 5, 10, 15])
|
||||
c = alpha_clamp(a, 0, 10)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [0, 0, 5, 10, 10])
|
||||
|
||||
def test_alpha_sin_cos(self):
|
||||
a = Xector([0, np.pi/2, np.pi])
|
||||
s = alpha_sin(a)
|
||||
c = alpha_cos(a)
|
||||
np.testing.assert_array_almost_equal(s.to_numpy(), [0, 1, 0], decimal=5)
|
||||
np.testing.assert_array_almost_equal(c.to_numpy(), [1, 0, -1], decimal=5)
|
||||
|
||||
def test_alpha_sq(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
c = alpha_sq(a)
|
||||
np.testing.assert_array_equal(c.to_numpy(), [1, 4, 9, 16])
|
||||
|
||||
def test_alpha_comparison(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
b = Xector([2, 2, 2, 2])
|
||||
lt = alpha_lt(a, b)
|
||||
gt = alpha_gt(a, b)
|
||||
eq = alpha_eq(a, b)
|
||||
np.testing.assert_array_equal(lt.to_numpy(), [True, False, False, False])
|
||||
np.testing.assert_array_equal(gt.to_numpy(), [False, False, True, True])
|
||||
np.testing.assert_array_equal(eq.to_numpy(), [False, True, False, False])
|
||||
|
||||
|
||||
class TestBetaOperations:
|
||||
"""Test β (reduction) operations."""
|
||||
|
||||
def test_beta_add(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_add(a) == 10
|
||||
|
||||
def test_beta_mul(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mul(a) == 24
|
||||
|
||||
def test_beta_min_max(self):
|
||||
a = Xector([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
assert beta_min(a) == 1
|
||||
assert beta_max(a) == 9
|
||||
|
||||
def test_beta_mean(self):
|
||||
a = Xector([1, 2, 3, 4])
|
||||
assert beta_mean(a) == 2.5
|
||||
|
||||
def test_beta_count(self):
|
||||
a = Xector([1, 2, 3, 4, 5])
|
||||
assert beta_count(a) == 5
|
||||
|
||||
|
||||
class TestFrameConversion:
|
||||
"""Test frame/xector conversion."""
|
||||
|
||||
def test_extract_channels(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[255, 0, 0], [0, 255, 0]],
|
||||
[[0, 0, 255], [128, 128, 128]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
assert len(r) == 4
|
||||
np.testing.assert_array_equal(r.to_numpy(), [255, 0, 0, 128])
|
||||
np.testing.assert_array_equal(g.to_numpy(), [0, 255, 0, 128])
|
||||
np.testing.assert_array_equal(b.to_numpy(), [0, 0, 255, 128])
|
||||
|
||||
def test_rgb_roundtrip(self):
|
||||
# Create a 2x2 RGB frame
|
||||
frame = np.array([
|
||||
[[100, 150, 200], [50, 75, 100]],
|
||||
[[200, 100, 50], [25, 50, 75]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
reconstructed = xector_rgb(r, g, b)
|
||||
np.testing.assert_array_equal(reconstructed, frame)
|
||||
|
||||
def test_modify_and_reconstruct(self):
|
||||
frame = np.array([
|
||||
[[100, 100, 100], [100, 100, 100]],
|
||||
[[100, 100, 100], [100, 100, 100]]
|
||||
], dtype=np.uint8)
|
||||
|
||||
r = xector_red(frame)
|
||||
g = xector_green(frame)
|
||||
b = xector_blue(frame)
|
||||
|
||||
# Double red channel
|
||||
r_doubled = r * 2
|
||||
|
||||
result = xector_rgb(r_doubled, g, b)
|
||||
|
||||
# Red should be 200, others unchanged
|
||||
assert result[0, 0, 0] == 200
|
||||
assert result[0, 0, 1] == 100
|
||||
assert result[0, 0, 2] == 100
|
||||
|
||||
|
||||
class TestCoordinates:
|
||||
"""Test coordinate generation."""
|
||||
|
||||
def test_x_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
x = xector_x_coords(frame)
|
||||
# Should be [0,1,2, 0,1,2] (x coords repeated for each row)
|
||||
np.testing.assert_array_equal(x.to_numpy(), [0, 1, 2, 0, 1, 2])
|
||||
|
||||
def test_y_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8) # 2 rows, 3 cols
|
||||
y = xector_y_coords(frame)
|
||||
# Should be [0,0,0, 1,1,1] (y coords for each pixel)
|
||||
np.testing.assert_array_equal(y.to_numpy(), [0, 0, 0, 1, 1, 1])
|
||||
|
||||
def test_normalized_coords(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_x_norm(frame)
|
||||
y = xector_y_norm(frame)
|
||||
|
||||
# x should go 0 to 1 across width
|
||||
assert x.to_numpy()[0] == 0
|
||||
assert x.to_numpy()[2] == 1
|
||||
|
||||
# y should go 0 to 1 down height
|
||||
assert y.to_numpy()[0] == 0
|
||||
assert y.to_numpy()[3] == 1
|
||||
|
||||
|
||||
class TestConditional:
|
||||
"""Test conditional operations."""
|
||||
|
||||
def test_where(self):
|
||||
cond = Xector([True, False, True, False])
|
||||
true_val = Xector([1, 1, 1, 1])
|
||||
false_val = Xector([0, 0, 0, 0])
|
||||
|
||||
result = xector_where(cond, true_val, false_val)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [1, 0, 1, 0])
|
||||
|
||||
def test_where_with_comparison(self):
|
||||
a = Xector([1, 5, 3, 7])
|
||||
threshold = 4
|
||||
|
||||
# Elements > 4 become 255, others become 0
|
||||
result = xector_where(alpha_gt(a, threshold), 255, 0)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [0, 255, 0, 255])
|
||||
|
||||
def test_fill(self):
|
||||
frame = np.zeros((2, 3, 3), dtype=np.uint8)
|
||||
x = xector_fill(42, frame)
|
||||
assert len(x) == 6
|
||||
assert all(v == 42 for v in x.to_numpy())
|
||||
|
||||
def test_zeros_ones(self):
|
||||
frame = np.zeros((2, 2, 3), dtype=np.uint8)
|
||||
z = xector_zeros(frame)
|
||||
o = xector_ones(frame)
|
||||
|
||||
assert all(v == 0 for v in z.to_numpy())
|
||||
assert all(v == 1 for v in o.to_numpy())
|
||||
|
||||
|
||||
class TestInterpreterIntegration:
|
||||
"""Test xector operations through the interpreter."""
|
||||
|
||||
def test_xector_vignette_effect(self):
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
|
||||
# Load the xector vignette effect
|
||||
interp.load_effect('sexp_effects/effects/xector_vignette.sexp')
|
||||
|
||||
# Create a test frame (white)
|
||||
frame = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||
|
||||
# Run effect
|
||||
result, state = interp.run_effect('xector_vignette', frame, {'strength': 0.5}, {})
|
||||
|
||||
# Center should be brighter than corners
|
||||
center = result[50, 50]
|
||||
corner = result[0, 0]
|
||||
|
||||
assert center.mean() > corner.mean(), "Center should be brighter than corners"
|
||||
# Corners should be darkened
|
||||
assert corner.mean() < 255, "Corners should be darkened"
|
||||
|
||||
def test_implicit_elementwise(self):
|
||||
"""Test that regular + works element-wise on xectors."""
|
||||
from sexp_effects.interpreter import Interpreter
|
||||
|
||||
interp = Interpreter(minimal_primitives=True)
|
||||
# Load xector primitives
|
||||
from sexp_effects.primitive_libs.xector import PRIMITIVES
|
||||
for name, fn in PRIMITIVES.items():
|
||||
interp.global_env.set(name, fn)
|
||||
|
||||
# Parse and eval a simple xector expression
|
||||
from sexp_effects.parser import parse
|
||||
expr = parse('(+ (red frame) 10)')
|
||||
|
||||
# Create test frame
|
||||
frame = np.full((2, 2, 3), 100, dtype=np.uint8)
|
||||
interp.global_env.set('frame', frame)
|
||||
|
||||
result = interp.eval(expr)
|
||||
|
||||
# Should be a xector with values 110
|
||||
assert is_xector(result)
|
||||
np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user