Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
- Add JAX text rendering with font atlas, styled text placement, and typography primitives - Add xector (element-wise/reduction) operations library and sexp effects - Add deferred effect chain fusion for JIT-compiled effect pipelines - Expand drawing primitives with font management, alignment, shadow, and outline - Add interpreter support for function-style define and require - Add GPU persistence mode and hardware decode support to streaming - Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions - Add path registry for asset resolution - Add integration, primitives, and xector tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
477
path_registry.py
Normal file
477
path_registry.py
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
"""
|
||||||
|
Path Registry - Maps human-friendly paths to content-addressed IDs.
|
||||||
|
|
||||||
|
This module provides a bidirectional mapping between:
|
||||||
|
- Human-friendly paths (e.g., "effects/ascii_fx_zone.sexp")
|
||||||
|
- Content-addressed IDs (IPFS CIDs or SHA3-256 hashes)
|
||||||
|
|
||||||
|
The registry is useful for:
|
||||||
|
- Looking up effects by their friendly path name
|
||||||
|
- Resolving cids back to the original path for debugging
|
||||||
|
- Maintaining a stable naming scheme across cache updates
|
||||||
|
|
||||||
|
Storage:
|
||||||
|
- Uses the existing item_types table in the database (path column)
|
||||||
|
- Caches in Redis for fast lookups across distributed workers
|
||||||
|
|
||||||
|
The registry uses a system actor (@system@local) for global path mappings,
|
||||||
|
allowing effects to be resolved by path without requiring user context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# System actor for global path mappings (effects, recipes, analyzers)
SYSTEM_ACTOR = "@system@local"


@dataclass
class PathEntry:
    """A registered path with its content-addressed ID."""
    path: str  # Human-friendly path (relative, '/'-separated, normalized)
    cid: str  # Content-addressed ID (IPFS CID or hash)
    content_type: str  # Type: "effect", "recipe", "analyzer", etc.
    actor_id: str = SYSTEM_ACTOR  # Owner (system actor for global mappings)
    description: Optional[str] = None  # Optional human-readable description
    created_at: float = 0.0  # UTC unix timestamp of registration (0.0 = unknown)


class PathRegistry:
    """
    Registry for mapping paths to content-addressed IDs.

    Uses the existing item_types table for persistence and Redis
    for fast lookups in distributed Celery workers.
    """

    def __init__(self, redis_client=None):
        # Redis is optional; without it every lookup falls through to the DB.
        self._redis = redis_client
        self._redis_path_to_cid_key = "artdag:path_to_cid"
        self._redis_cid_to_path_key = "artdag:cid_to_path"

    def _run_async(self, coro):
        """Run an async coroutine from a sync context.

        If called while an event loop is already running (e.g. inside an
        async worker), the coroutine executes on a private loop in a helper
        thread, because run_until_complete() cannot be nested.

        Raises:
            TimeoutError: if the helper thread does not finish within 30s.
        """
        import asyncio

        # Detect a running loop in an isolated try/except. The previous code
        # wrapped the whole thread dance in the same try block, so a
        # RuntimeError raised BY THE COROUTINE (re-raised via `raise error[0]`)
        # was caught by `except RuntimeError` and the already-consumed
        # coroutine was run a second time.
        try:
            asyncio.get_running_loop()
            in_loop = True
        except RuntimeError:
            in_loop = False

        if not in_loop:
            # No running loop: use (or lazily create) this thread's loop.
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)

        import threading
        result = [None]
        error = [None]

        def run_in_thread():
            try:
                new_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(new_loop)
                try:
                    result[0] = new_loop.run_until_complete(coro)
                finally:
                    new_loop.close()
            except Exception as e:
                error[0] = e

        thread = threading.Thread(target=run_in_thread)
        thread.start()
        thread.join(timeout=30)
        if thread.is_alive():
            # Previously a timeout silently returned None; surface it so the
            # callers' except-blocks log a warning instead of caching None.
            raise TimeoutError("_run_async helper thread timed out after 30s")
        if error[0]:
            raise error[0]
        return result[0]

    def _normalize_path(self, path: str) -> str:
        """Normalize a path for consistent storage.

        Converts backslashes to '/', strips leading "./" prefixes and any
        leading '/', and collapses duplicate slashes.
        """
        # Normalize separators first so Windows-style paths are handled too
        path = path.replace('\\', '/')
        # Remove leading "./" prefixes and a leading "/".
        # NOTE: the previous `path.lstrip('./')` stripped ANY leading '.'
        # or '/' characters, corrupting names like ".hidden/x" -> "hidden/x".
        while path.startswith('./'):
            path = path[2:]
        path = path.lstrip('/')
        # Collapse duplicate slashes
        while '//' in path:
            path = path.replace('//', '/')
        return path

    def register(
        self,
        path: str,
        cid: str,
        content_type: str = "effect",
        actor_id: str = SYSTEM_ACTOR,
        description: Optional[str] = None,
    ) -> PathEntry:
        """
        Register a path -> cid mapping.

        Args:
            path: Human-friendly path (e.g., "effects/ascii_fx_zone.sexp")
            cid: Content-addressed ID (IPFS CID or hash)
            content_type: Type of content ("effect", "recipe", "analyzer")
            actor_id: Owner (default: system for global mappings)
            description: Optional description

        Returns:
            The created PathEntry
        """
        norm_path = self._normalize_path(path)
        now = datetime.now(timezone.utc).timestamp()

        entry = PathEntry(
            path=norm_path,
            cid=cid,
            content_type=content_type,
            actor_id=actor_id,
            description=description,
            created_at=now,
        )

        # Store in database (item_types table)
        self._save_to_db(entry)

        # Update Redis cache
        self._update_redis_cache(norm_path, cid)

        logger.info(f"Registered path '{norm_path}' -> {cid[:16]}...")
        return entry

    def _save_to_db(self, entry: PathEntry):
        """Save entry to database using the item_types table (best-effort)."""
        import database

        async def save():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                # Ensure the cache_item row exists (item_types FK target)
                await conn.execute(
                    "INSERT INTO cache_items (cid) VALUES ($1) ON CONFLICT DO NOTHING",
                    entry.cid
                )
                # Insert or update item_type with path
                await conn.execute(
                    """
                    INSERT INTO item_types (cid, actor_id, type, path, description)
                    VALUES ($1, $2, $3, $4, $5)
                    ON CONFLICT (cid, actor_id, type, path) DO UPDATE SET
                        description = COALESCE(EXCLUDED.description, item_types.description)
                    """,
                    entry.cid, entry.actor_id, entry.content_type, entry.path, entry.description
                )
            finally:
                await conn.close()

        try:
            self._run_async(save())
        except Exception as e:
            # Best-effort: registration still succeeds in-memory/Redis
            logger.warning(f"Failed to save path registry to DB: {e}")

    def _update_redis_cache(self, path: str, cid: str):
        """Update both direction hashes in Redis (best-effort)."""
        if self._redis:
            try:
                self._redis.hset(self._redis_path_to_cid_key, path, cid)
                self._redis.hset(self._redis_cid_to_path_key, cid, path)
            except Exception as e:
                logger.warning(f"Failed to update Redis cache: {e}")

    def get_cid(self, path: str, content_type: str = None) -> Optional[str]:
        """
        Get the cid for a path.

        Args:
            path: Human-friendly path
            content_type: Optional type filter (only applied on the DB path;
                the Redis cache is not partitioned by type)

        Returns:
            The cid, or None if not found
        """
        norm_path = self._normalize_path(path)

        # Try Redis first (fast path)
        if self._redis:
            try:
                val = self._redis.hget(self._redis_path_to_cid_key, norm_path)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Redis lookup failed: {e}")

        # Fall back to database
        return self._get_cid_from_db(norm_path, content_type)

    def _get_cid_from_db(self, path: str, content_type: str = None) -> Optional[str]:
        """Get cid from the database's item_types table."""
        import database

        async def get():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                if content_type:
                    row = await conn.fetchrow(
                        "SELECT cid FROM item_types WHERE path = $1 AND type = $2",
                        path, content_type
                    )
                else:
                    row = await conn.fetchrow(
                        "SELECT cid FROM item_types WHERE path = $1",
                        path
                    )
                return row["cid"] if row else None
            finally:
                await conn.close()

        try:
            result = self._run_async(get())
            # Warm the Redis cache on a successful DB hit
            if result and self._redis:
                self._update_redis_cache(path, result)
            return result
        except Exception as e:
            logger.warning(f"Failed to get from DB: {e}")
            return None

    def get_path(self, cid: str) -> Optional[str]:
        """
        Get the path for a cid (reverse lookup, mainly for debugging).

        Args:
            cid: Content-addressed ID

        Returns:
            The path, or None if not found
        """
        # Try Redis first
        if self._redis:
            try:
                val = self._redis.hget(self._redis_cid_to_path_key, cid)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Redis lookup failed: {e}")

        # Fall back to database
        return self._get_path_from_db(cid)

    def _get_path_from_db(self, cid: str) -> Optional[str]:
        """Get path from the database's item_types table (oldest entry wins)."""
        import database

        async def get():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                row = await conn.fetchrow(
                    "SELECT path FROM item_types WHERE cid = $1 AND path IS NOT NULL ORDER BY created_at LIMIT 1",
                    cid
                )
                return row["path"] if row else None
            finally:
                await conn.close()

        try:
            result = self._run_async(get())
            # Warm the Redis cache on a successful DB hit
            if result and self._redis:
                self._update_redis_cache(result, cid)
            return result
        except Exception as e:
            logger.warning(f"Failed to get from DB: {e}")
            return None

    def list_by_type(self, content_type: str, actor_id: str = None) -> List[PathEntry]:
        """
        List all entries of a given type.

        Args:
            content_type: Type to filter by ("effect", "recipe", etc.)
            actor_id: Optional actor filter (None = all, SYSTEM_ACTOR = global)

        Returns:
            List of PathEntry objects (empty list on DB failure)
        """
        import database

        async def list_entries():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                if actor_id:
                    rows = await conn.fetch(
                        """
                        SELECT cid, path, type, actor_id, description,
                               EXTRACT(EPOCH FROM created_at) as created_at
                        FROM item_types
                        WHERE type = $1 AND actor_id = $2 AND path IS NOT NULL
                        ORDER BY path
                        """,
                        content_type, actor_id
                    )
                else:
                    rows = await conn.fetch(
                        """
                        SELECT cid, path, type, actor_id, description,
                               EXTRACT(EPOCH FROM created_at) as created_at
                        FROM item_types
                        WHERE type = $1 AND path IS NOT NULL
                        ORDER BY path
                        """,
                        content_type
                    )
                return [
                    PathEntry(
                        path=row["path"],
                        cid=row["cid"],
                        content_type=row["type"],
                        actor_id=row["actor_id"],
                        description=row["description"],
                        created_at=row["created_at"] or 0,
                    )
                    for row in rows
                ]
            finally:
                await conn.close()

        try:
            return self._run_async(list_entries())
        except Exception as e:
            logger.warning(f"Failed to list from DB: {e}")
            return []

    def delete(self, path: str, content_type: str = None) -> bool:
        """
        Delete a path registration.

        Args:
            path: The path to delete
            content_type: Optional type filter

        Returns:
            True if deleted, False if not found
        """
        norm_path = self._normalize_path(path)

        # Resolve cid first so the reverse Redis mapping can be cleaned up
        cid = self.get_cid(norm_path, content_type)

        # Delete from database
        deleted = self._delete_from_db(norm_path, content_type)

        # Clean up Redis
        if deleted and cid and self._redis:
            try:
                self._redis.hdel(self._redis_path_to_cid_key, norm_path)
                self._redis.hdel(self._redis_cid_to_path_key, cid)
            except Exception as e:
                logger.warning(f"Failed to clean up Redis: {e}")

        return deleted

    def _delete_from_db(self, path: str, content_type: str = None) -> bool:
        """Delete from database; return True only if at least one row matched."""
        import database

        async def delete():
            import asyncpg
            conn = await asyncpg.connect(database.DATABASE_URL)
            try:
                if content_type:
                    result = await conn.execute(
                        "DELETE FROM item_types WHERE path = $1 AND type = $2",
                        path, content_type
                    )
                else:
                    result = await conn.execute(
                        "DELETE FROM item_types WHERE path = $1",
                        path
                    )
                # asyncpg returns a status tag like "DELETE <n>". The old
                # check (`"DELETE" in result`) reported True even for
                # "DELETE 0", i.e. when nothing was actually removed.
                return result.startswith("DELETE") and not result.endswith(" 0")
            finally:
                await conn.close()

        try:
            return self._run_async(delete())
        except Exception as e:
            logger.warning(f"Failed to delete from DB: {e}")
            return False

    def register_effect(
        self,
        path: str,
        cid: str,
        description: Optional[str] = None,
    ) -> PathEntry:
        """
        Convenience method to register an effect owned by the system actor.

        Args:
            path: Effect path (e.g., "effects/ascii_fx_zone.sexp")
            cid: IPFS CID of the effect file
            description: Optional description

        Returns:
            The created PathEntry
        """
        return self.register(
            path=path,
            cid=cid,
            content_type="effect",
            actor_id=SYSTEM_ACTOR,
            description=description,
        )

    def get_effect_cid(self, path: str) -> Optional[str]:
        """
        Get CID for an effect by path.

        Args:
            path: Effect path

        Returns:
            IPFS CID or None
        """
        return self.get_cid(path, content_type="effect")

    def list_effects(self) -> List[PathEntry]:
        """List all effects registered under the system actor."""
        return self.list_by_type("effect", actor_id=SYSTEM_ACTOR)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
_registry: Optional[PathRegistry] = None


def get_path_registry() -> PathRegistry:
    """Return the process-wide path registry, creating it on first use.

    The Redis client is configured from REDIS_URL
    (default: redis://localhost:6379/5).
    """
    global _registry
    if _registry is not None:
        return _registry

    import redis
    from urllib.parse import urlparse

    parsed = urlparse(os.environ.get('REDIS_URL', 'redis://localhost:6379/5'))
    client = redis.Redis(
        host=parsed.hostname or 'localhost',
        port=parsed.port or 6379,
        db=int(parsed.path.lstrip('/') or 0),
        socket_timeout=5,
        socket_connect_timeout=5,
    )
    _registry = PathRegistry(redis_client=client)
    return _registry


def reset_path_registry():
    """Reset the singleton (for testing)."""
    global _registry
    _registry = None
|
||||||
206
sexp_effects/derived.sexp
Normal file
206
sexp_effects/derived.sexp
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
;; Derived Operations
;;
;; These are built from true primitives using S-expressions.
;; Load with: (require "derived")
;;
;; NOTE(review): unprefixed operators here (+, *, where, ...) appear to
;; broadcast element-wise over xectors as well as scalars — confirm against
;; the xector primitives documentation.

;; =============================================================================
;; Math Helpers (derivable from where + basic ops)
;; =============================================================================

;; Absolute value (branchless, via per-element where)
(define (abs x) (where (< x 0) (- x) x))

;; Minimum of two values
(define (min2 a b) (where (< a b) a b))

;; Maximum of two values
(define (max2 a b) (where (> a b) a b))

;; Clamp x to range [lo, hi]
(define (clamp x lo hi) (max2 lo (min2 hi x)))

;; Square of x
(define (sq x) (* x x))

;; Linear interpolation: a*(1-t) + b*t
(define (lerp a b t) (+ (* a (- 1 t)) (* b t)))

;; Smooth Hermite interpolation between edge0 and edge1 (GLSL-style smoothstep)
(define (smoothstep edge0 edge1 x)
  (let ((t (clamp (/ (- x edge0) (- edge1 edge0)) 0 1)))
    (* t (* t (- 3 (* 2 t))))))

;; =============================================================================
;; Channel Shortcuts (derivable from channel primitive)
;; =============================================================================

;; Extract red channel as xector
(define (red frame) (channel frame 0))

;; Extract green channel as xector
(define (green frame) (channel frame 1))

;; Extract blue channel as xector
(define (blue frame) (channel frame 2))

;; Convert to grayscale xector using ITU-R BT.601 luma weights
(define (gray frame)
  (+ (* (red frame) 0.299)
     (* (green frame) 0.587)
     (* (blue frame) 0.114)))

;; Alias for gray
(define (luminance frame) (gray frame))

;; =============================================================================
;; Coordinate Generators (derivable from iota + repeat/tile)
;; =============================================================================

;; X coordinate for each pixel [0, width)
(define (x-coords frame) (tile (iota (width frame)) (height frame)))

;; Y coordinate for each pixel [0, height)
(define (y-coords frame) (repeat (iota (height frame)) (width frame)))

;; Normalized X coordinate [0, 1]
;; (max2 guards against division by zero on 1-pixel-wide frames)
(define (x-norm frame) (/ (x-coords frame) (max2 1 (- (width frame) 1))))

;; Normalized Y coordinate [0, 1]
(define (y-norm frame) (/ (y-coords frame) (max2 1 (- (height frame) 1))))

;; Euclidean distance from frame center for each pixel
(define (dist-from-center frame)
  (let* ((cx (/ (width frame) 2))
         (cy (/ (height frame) 2))
         (dx (- (x-coords frame) cx))
         (dy (- (y-coords frame) cy)))
    (sqrt (+ (sq dx) (sq dy)))))

;; Normalized distance from center [0, ~1]
;; (βmax presumably reduces the xector to its scalar maximum — confirm)
(define (dist-norm frame)
  (let ((d (dist-from-center frame)))
    (/ d (max2 1 (βmax d)))))

;; =============================================================================
;; Cell/Grid Operations (derivable from floor + basic math)
;; =============================================================================

;; Cell row index for each pixel
(define (cell-row frame cell-size) (floor (/ (y-coords frame) cell-size)))

;; Cell column index for each pixel
(define (cell-col frame cell-size) (floor (/ (x-coords frame) cell-size)))

;; Number of cell rows
(define (num-rows frame cell-size) (floor (/ (height frame) cell-size)))

;; Number of cell columns
(define (num-cols frame cell-size) (floor (/ (width frame) cell-size)))

;; Flat (row-major) cell index for each pixel
(define (cell-indices frame cell-size)
  (+ (* (cell-row frame cell-size) (num-cols frame cell-size))
     (cell-col frame cell-size)))

;; Total number of cells
(define (num-cells frame cell-size)
  (* (num-rows frame cell-size) (num-cols frame cell-size)))

;; X position within cell [0, cell-size)
(define (local-x frame cell-size) (mod (x-coords frame) cell-size))

;; Y position within cell [0, cell-size)
(define (local-y frame cell-size) (mod (y-coords frame) cell-size))

;; Normalized X within cell [0, 1]
(define (local-x-norm frame cell-size)
  (/ (local-x frame cell-size) (max2 1 (- cell-size 1))))

;; Normalized Y within cell [0, 1]
(define (local-y-norm frame cell-size)
  (/ (local-y frame cell-size) (max2 1 (- cell-size 1))))

;; =============================================================================
;; Fill Operations (derivable from iota)
;; =============================================================================

;; Xector of n zeros
(define (zeros n) (* (iota n) 0))

;; Xector of n ones
(define (ones n) (+ (zeros n) 1))

;; Xector of n copies of val
(define (fill val n) (+ (zeros n) val))

;; Xector of zeros matching x's length
(define (zeros-like x) (* x 0))

;; Xector of ones matching x's length
(define (ones-like x) (+ (zeros-like x) 1))

;; =============================================================================
;; Pooling (derivable from group-reduce)
;; =============================================================================

;; Pool a channel by cell index (mean of each cell's pixels)
(define (pool-channel chan cell-idx num-cells)
  (group-reduce chan cell-idx num-cells "mean"))

;; Pool red channel to cells
(define (pool-red frame cell-size)
  (pool-channel (red frame)
                (cell-indices frame cell-size)
                (num-cells frame cell-size)))

;; Pool green channel to cells
(define (pool-green frame cell-size)
  (pool-channel (green frame)
                (cell-indices frame cell-size)
                (num-cells frame cell-size)))

;; Pool blue channel to cells
(define (pool-blue frame cell-size)
  (pool-channel (blue frame)
                (cell-indices frame cell-size)
                (num-cells frame cell-size)))

;; Pool grayscale to cells
(define (pool-gray frame cell-size)
  (pool-channel (gray frame)
                (cell-indices frame cell-size)
                (num-cells frame cell-size)))

;; =============================================================================
;; Blending (derivable from math; channels assumed in [0, 255])
;; =============================================================================

;; Additive blend
(define (blend-add a b) (clamp (+ a b) 0 255))

;; Multiply blend (normalized)
(define (blend-multiply a b) (* (/ a 255) b))

;; Screen blend
(define (blend-screen a b) (- 255 (* (/ (- 255 a) 255) (- 255 b))))

;; Overlay blend: multiply in shadows, screen in highlights
(define (blend-overlay a b)
  (where (< a 128)
         (* 2 (/ (* a b) 255))
         (- 255 (* 2 (/ (* (- 255 a) (- 255 b)) 255)))))

;; =============================================================================
;; Simple Effects (derivable from primitives)
;; =============================================================================

;; Invert a channel (255 - c)
(define (invert-channel c) (- 255 c))

;; Binary threshold
(define (threshold-channel c thresh) (where (> c thresh) 255 0))

;; Reduce to n discrete levels (quantize then rescale)
(define (posterize-channel c levels)
  (let ((step (/ 255 (- levels 1))))
    (* (round (/ c step)) step)))
||||||
@@ -5,7 +5,7 @@
|
|||||||
:params (
|
:params (
|
||||||
(char_size :type int :default 8 :range [4 32])
|
(char_size :type int :default 8 :range [4 32])
|
||||||
(alphabet :type string :default "standard")
|
(alphabet :type string :default "standard")
|
||||||
(color_mode :type string :default "color" :desc ""color", "mono", "invert", or any color name/hex")
|
(color_mode :type string :default "color" :desc "color, mono, invert, or any color name/hex")
|
||||||
(background_color :type string :default "black" :desc "background color name/hex")
|
(background_color :type string :default "black" :desc "background color name/hex")
|
||||||
(invert_colors :type int :default 0 :desc "swap foreground and background colors")
|
(invert_colors :type int :default 0 :desc "swap foreground and background colors")
|
||||||
(contrast :type float :default 1.5 :range [1 3])
|
(contrast :type float :default 1.5 :range [1 3])
|
||||||
|
|||||||
65
sexp_effects/effects/cell_pattern.sexp
Normal file
65
sexp_effects/effects/cell_pattern.sexp
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
;; Cell Pattern effect - custom patterns within cells
;;
;; Demonstrates building arbitrary per-cell visuals from primitives.
;; Uses local coordinates within cells to draw patterns scaled by luminance.
;;
;; NOTE(review): α-prefixed ops (α+, α<, αabs, α², ...) appear to be the
;; element-wise xector forms of the corresponding scalar ops — confirm
;; against the xector primitives documentation.

(require-primitives "xector")

(define-effect cell_pattern
  :params (
    (cell-size :type int :default 16 :range [8 48] :desc "Cell size")
    (pattern :type string :default "diagonal" :desc "Pattern: diagonal, cross, ring")
  )
  (let* (
         ;; Pool to get per-cell average colors (r, g, b, luminance xectors)
         (pooled (pool-frame frame cell-size))
         (cell-r (nth pooled 0))
         (cell-g (nth pooled 1))
         (cell-b (nth pooled 2))
         ;; Luminance normalized to [0, 1]
         (cell-lum (α/ (nth pooled 3) 255))

         ;; Cell index for each pixel
         (cell-idx (cell-indices frame cell-size))

         ;; Broadcast cell values back to per-pixel xectors
         (pix-r (gather cell-r cell-idx))
         (pix-g (gather cell-g cell-idx))
         (pix-b (gather cell-b cell-idx))
         (pix-lum (gather cell-lum cell-idx))

         ;; Local position within cell, normalized to [0, 1]
         (lx (local-x-norm frame cell-size))
         (ly (local-y-norm frame cell-size))

         ;; Boolean per-pixel mask selected by pattern type
         (mask
          (cond
            ;; Diagonal lines - stripe thickness scales with cell luminance
            ((= pattern "diagonal")
             (let* ((diag (αmod (α+ lx ly) 0.25))
                    (thickness (α* pix-lum 0.125)))
               (α< diag thickness)))

            ;; Cross pattern - arm thickness scales with cell luminance
            ((= pattern "cross")
             (let* ((cx (αabs (α- lx 0.5)))
                    (cy (αabs (α- ly 0.5)))
                    (thickness (α* pix-lum 0.25)))
               (αor (α< cx thickness) (α< cy thickness))))

            ;; Ring pattern - ring radius scales with cell luminance
            ((= pattern "ring")
             (let* ((dx (α- lx 0.5))
                    (dy (α- ly 0.5))
                    (dist (αsqrt (α+ (α² dx) (α² dy))))
                    (target (α* pix-lum 0.4))
                    (thickness 0.05))
               (α< (αabs (α- dist target)) thickness)))

            ;; Default: solid fill wherever the cell has any luminance
            (else (α> pix-lum 0)))))

    ;; Apply mask: show cell color where mask is true, black elsewhere
    (rgb (where mask pix-r 0)
         (where mask pix-g 0)
         (where mask pix-b 0))))
||||||
@@ -6,10 +6,10 @@
|
|||||||
(num_echoes :type int :default 4 :range [1 20])
|
(num_echoes :type int :default 4 :range [1 20])
|
||||||
(decay :type float :default 0.5 :range [0 1])
|
(decay :type float :default 0.5 :range [0 1])
|
||||||
)
|
)
|
||||||
(let* ((buffer (state-get 'buffer (list)))
|
(let* ((buffer (state-get "buffer" (list)))
|
||||||
(new-buffer (take (cons frame buffer) (+ num_echoes 1))))
|
(new-buffer (take (cons frame buffer) (+ num_echoes 1))))
|
||||||
(begin
|
(begin
|
||||||
(state-set 'buffer new-buffer)
|
(state-set "buffer" new-buffer)
|
||||||
;; Blend frames with decay
|
;; Blend frames with decay
|
||||||
(if (< (length new-buffer) 2)
|
(if (< (length new-buffer) 2)
|
||||||
frame
|
frame
|
||||||
|
|||||||
49
sexp_effects/effects/halftone.sexp
Normal file
49
sexp_effects/effects/halftone.sexp
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
;; Halftone/dot effect - built from primitive xector operations
;;
;; Uses:
;;   pool-frame     - downsample to cell luminances
;;   cell-indices   - which cell each pixel belongs to
;;   gather         - look up cell value for each pixel
;;   local-x/y-norm - position within cell [0,1]
;;   where          - conditional per-pixel

(require-primitives "xector")

(define-effect halftone
  :params (
    (cell-size :type int :default 12 :range [4 32] :desc "Size of halftone cells")
    (dot-scale :type float :default 0.9 :range [0.1 1.0] :desc "Max dot radius")
    (invert :type bool :default false :desc "Invert (white dots on black)")
  )
  (let* (
         ;; Pool frame to get luminance per cell
         (pooled (pool-frame frame cell-size))
         (cell-lum (nth pooled 3)) ; luminance is 4th element of pool-frame result

         ;; For each output pixel, get its cell index
         (cell-idx (cell-indices frame cell-size))

         ;; Cell luminance per pixel, normalized to [0, 1]
         (pixel-lum (α/ (gather cell-lum cell-idx) 255))

         ;; Position within cell, recentered to [-0.5, 0.5]
         (lx (α- (local-x-norm frame cell-size) 0.5))
         (ly (α- (local-y-norm frame cell-size) 0.5))

         ;; Distance from cell center (0 at center, ~0.7 at corners)
         (dist (αsqrt (α+ (α² lx) (α² ly))))

         ;; Dot radius from luminance (brighter = bigger dot; inverted swaps)
         (radius (α* (if invert (α- 1 pixel-lum) pixel-lum)
                     (α* dot-scale 0.5)))

         ;; Is this pixel inside the dot?
         (inside (α< dist radius))

         ;; Foreground/background intensities (swapped when inverted)
         (fg (if invert 255 0))
         (bg (if invert 0 255))
         (out (where inside fg bg)))

    ;; Grayscale output: same intensity on all three channels
    (rgb out out out)))
|
||||||
30
sexp_effects/effects/mosaic.sexp
Normal file
30
sexp_effects/effects/mosaic.sexp
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
;; Mosaic effect - built from primitive xector operations
|
||||||
|
;;
|
||||||
|
;; Uses:
|
||||||
|
;; pool-frame - downsample to cell averages
|
||||||
|
;; cell-indices - which cell each pixel belongs to
|
||||||
|
;; gather - look up cell value for each pixel
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect mosaic
|
||||||
|
:params (
|
||||||
|
(cell-size :type int :default 16 :range [4 64] :desc "Size of mosaic cells")
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Pool frame to get average color per cell (returns r,g,b,lum xectors)
|
||||||
|
(pooled (pool-frame frame cell-size))
|
||||||
|
(cell-r (nth pooled 0))
|
||||||
|
(cell-g (nth pooled 1))
|
||||||
|
(cell-b (nth pooled 2))
|
||||||
|
|
||||||
|
;; For each output pixel, get its cell index
|
||||||
|
(cell-idx (cell-indices frame cell-size))
|
||||||
|
|
||||||
|
;; Gather: look up cell color for each pixel
|
||||||
|
(out-r (gather cell-r cell-idx))
|
||||||
|
(out-g (gather cell-g cell-idx))
|
||||||
|
(out-b (gather cell-b cell-idx)))
|
||||||
|
|
||||||
|
;; Reconstruct frame
|
||||||
|
(rgb out-r out-g out-b)))
|
||||||
@@ -5,9 +5,9 @@
|
|||||||
:params (
|
:params (
|
||||||
(thickness :type int :default 2 :range [1 10])
|
(thickness :type int :default 2 :range [1 10])
|
||||||
(threshold :type int :default 100 :range [20 300])
|
(threshold :type int :default 100 :range [20 300])
|
||||||
(color :type list :default (list 0 0 0)
|
(color :type list :default (list 0 0 0))
|
||||||
|
(fill_mode :type string :default "original")
|
||||||
)
|
)
|
||||||
(fill_mode "original"))
|
|
||||||
(let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold))
|
(let* ((edge-img (image:edge-detect frame (/ threshold 2) threshold))
|
||||||
(dilated (if (> thickness 1)
|
(dilated (if (> thickness 1)
|
||||||
(dilate edge-img thickness)
|
(dilate edge-img thickness)
|
||||||
|
|||||||
@@ -5,12 +5,12 @@
|
|||||||
:params (
|
:params (
|
||||||
(frame_rate :type int :default 12 :range [1 60])
|
(frame_rate :type int :default 12 :range [1 60])
|
||||||
)
|
)
|
||||||
(let* ((held (state-get 'held nil))
|
(let* ((held (state-get "held" nil))
|
||||||
(held-until (state-get 'held-until 0))
|
(held-until (state-get "held-until" 0))
|
||||||
(frame-duration (/ 1 frame_rate)))
|
(frame-duration (/ 1 frame_rate)))
|
||||||
(if (or (core:is-nil held) (>= t held-until))
|
(if (or (core:is-nil held) (>= t held-until))
|
||||||
(begin
|
(begin
|
||||||
(state-set 'held (copy frame))
|
(state-set "held" (copy frame))
|
||||||
(state-set 'held-until (+ t frame-duration))
|
(state-set "held-until" (+ t frame-duration))
|
||||||
frame)
|
frame)
|
||||||
held)))
|
held)))
|
||||||
|
|||||||
@@ -5,16 +5,16 @@
|
|||||||
:params (
|
:params (
|
||||||
(persistence :type float :default 0.8 :range [0 0.99])
|
(persistence :type float :default 0.8 :range [0 0.99])
|
||||||
)
|
)
|
||||||
(let* ((buffer (state-get 'buffer nil))
|
(let* ((buffer (state-get "buffer" nil))
|
||||||
(current frame))
|
(current frame))
|
||||||
(if (= buffer nil)
|
(if (= buffer nil)
|
||||||
(begin
|
(begin
|
||||||
(state-set 'buffer (copy frame))
|
(state-set "buffer" (copy frame))
|
||||||
frame)
|
frame)
|
||||||
(let* ((faded (blending:blend-images buffer
|
(let* ((faded (blending:blend-images buffer
|
||||||
(make-image (image:width frame) (image:height frame) (list 0 0 0))
|
(make-image (image:width frame) (image:height frame) (list 0 0 0))
|
||||||
(- 1 persistence)))
|
(- 1 persistence)))
|
||||||
(result (blending:blend-mode faded current "lighten")))
|
(result (blending:blend-mode faded current "lighten")))
|
||||||
(begin
|
(begin
|
||||||
(state-set 'buffer result)
|
(state-set "buffer" result)
|
||||||
result)))))
|
result)))))
|
||||||
|
|||||||
44
sexp_effects/effects/xector_feathered_blend.sexp
Normal file
44
sexp_effects/effects/xector_feathered_blend.sexp
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
;; Feathered blend - blend two same-size frames with distance-based falloff
|
||||||
|
;; Center shows overlay, edges show background, with smooth transition
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect xector_feathered_blend
|
||||||
|
:params (
|
||||||
|
(inner-radius :type float :default 0.3 :range [0 1] :desc "Radius where overlay is 100% (fraction of size)")
|
||||||
|
(fade-width :type float :default 0.2 :range [0 0.5] :desc "Width of fade region (fraction of size)")
|
||||||
|
(overlay :type frame :default nil :desc "Frame to blend in center")
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Get normalized distance from center (0 at center, ~1 at corners)
|
||||||
|
(dist (dist-from-center frame))
|
||||||
|
(max-dist (βmax dist))
|
||||||
|
(dist-norm (α/ dist max-dist))
|
||||||
|
|
||||||
|
;; Calculate blend factor:
|
||||||
|
;; - 1.0 when dist-norm < inner-radius (fully overlay)
|
||||||
|
;; - 0.0 when dist-norm > inner-radius + fade-width (fully background)
|
||||||
|
;; - linear ramp between
|
||||||
|
(t (α/ (α- dist-norm inner-radius) fade-width))
|
||||||
|
(blend (α- 1 (αclamp t 0 1)))
|
||||||
|
(inv-blend (α- 1 blend))
|
||||||
|
|
||||||
|
;; Background channels
|
||||||
|
(bg-r (red frame))
|
||||||
|
(bg-g (green frame))
|
||||||
|
(bg-b (blue frame)))
|
||||||
|
|
||||||
|
(if (nil? overlay)
|
||||||
|
;; No overlay - visualize the blend mask
|
||||||
|
(let ((vis (α* blend 255)))
|
||||||
|
(rgb vis vis vis))
|
||||||
|
|
||||||
|
;; Blend overlay with background using the mask
|
||||||
|
(let* ((ov-r (red overlay))
|
||||||
|
(ov-g (green overlay))
|
||||||
|
(ov-b (blue overlay))
|
||||||
|
;; lerp: bg * (1-blend) + overlay * blend
|
||||||
|
(r-out (α+ (α* bg-r inv-blend) (α* ov-r blend)))
|
||||||
|
(g-out (α+ (α* bg-g inv-blend) (α* ov-g blend)))
|
||||||
|
(b-out (α+ (α* bg-b inv-blend) (α* ov-b blend))))
|
||||||
|
(rgb r-out g-out b-out)))))
|
||||||
34
sexp_effects/effects/xector_grain.sexp
Normal file
34
sexp_effects/effects/xector_grain.sexp
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
;; Film grain effect using xector operations
|
||||||
|
;; Demonstrates random xectors and mixing scalar/xector math
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect xector_grain
|
||||||
|
:params (
|
||||||
|
(intensity :type float :default 0.2 :range [0 1] :desc "Grain intensity")
|
||||||
|
(colored :type bool :default false :desc "Use colored grain")
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Extract channels
|
||||||
|
(r (red frame))
|
||||||
|
(g (green frame))
|
||||||
|
(b (blue frame))
|
||||||
|
|
||||||
|
;; Generate noise xector(s)
|
||||||
|
;; randn-x generates normal distribution noise
|
||||||
|
(grain-amount (* intensity 50)))
|
||||||
|
|
||||||
|
(if colored
|
||||||
|
;; Colored grain: different noise per channel
|
||||||
|
(let* ((nr (randn-x frame 0 grain-amount))
|
||||||
|
(ng (randn-x frame 0 grain-amount))
|
||||||
|
(nb (randn-x frame 0 grain-amount)))
|
||||||
|
(rgb (αclamp (α+ r nr) 0 255)
|
||||||
|
(αclamp (α+ g ng) 0 255)
|
||||||
|
(αclamp (α+ b nb) 0 255)))
|
||||||
|
|
||||||
|
;; Monochrome grain: same noise for all channels
|
||||||
|
(let ((n (randn-x frame 0 grain-amount)))
|
||||||
|
(rgb (αclamp (α+ r n) 0 255)
|
||||||
|
(αclamp (α+ g n) 0 255)
|
||||||
|
(αclamp (α+ b n) 0 255))))))
|
||||||
57
sexp_effects/effects/xector_inset_blend.sexp
Normal file
57
sexp_effects/effects/xector_inset_blend.sexp
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
;; Inset blend - fade a smaller frame into a larger background
|
||||||
|
;; Uses distance-based alpha for smooth transition (no hard edges)
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect xector_inset_blend
|
||||||
|
:params (
|
||||||
|
(x :type int :default 0 :desc "X position of inset")
|
||||||
|
(y :type int :default 0 :desc "Y position of inset")
|
||||||
|
(fade-width :type int :default 50 :desc "Width of fade region in pixels")
|
||||||
|
(overlay :type frame :default nil :desc "The smaller frame to inset")
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Get dimensions
|
||||||
|
(bg-h (first (list (nth (list (red frame)) 0)))) ;; TODO: need image:height
|
||||||
|
(bg-w bg-h) ;; placeholder
|
||||||
|
|
||||||
|
;; For now, create a simple centered circular blend
|
||||||
|
;; Distance from center of overlay position
|
||||||
|
(cx (+ x (/ (- bg-w (* 2 x)) 2)))
|
||||||
|
(cy (+ y (/ (- bg-h (* 2 y)) 2)))
|
||||||
|
|
||||||
|
;; Get coordinates as xectors
|
||||||
|
(px (x-coords frame))
|
||||||
|
(py (y-coords frame))
|
||||||
|
|
||||||
|
;; Distance from center
|
||||||
|
(dx (α- px cx))
|
||||||
|
(dy (α- py cy))
|
||||||
|
(dist (αsqrt (α+ (α* dx dx) (α* dy dy))))
|
||||||
|
|
||||||
|
;; Inner radius (fully overlay) and outer radius (fully background)
|
||||||
|
(inner-r (- (/ bg-w 2) x fade-width))
|
||||||
|
(outer-r (- (/ bg-w 2) x))
|
||||||
|
|
||||||
|
;; Blend factor: 1.0 inside inner-r, 0.0 outside outer-r, linear between
|
||||||
|
(t (α/ (α- dist inner-r) fade-width))
|
||||||
|
(blend (α- 1 (αclamp t 0 1)))
|
||||||
|
|
||||||
|
;; Extract channels from both frames
|
||||||
|
(bg-r (red frame))
|
||||||
|
(bg-g (green frame))
|
||||||
|
(bg-b (blue frame)))
|
||||||
|
|
||||||
|
;; If overlay provided, blend it
|
||||||
|
(if overlay
|
||||||
|
(let* ((ov-r (red overlay))
|
||||||
|
(ov-g (green overlay))
|
||||||
|
(ov-b (blue overlay))
|
||||||
|
;; Linear blend: result = bg * (1-blend) + overlay * blend
|
||||||
|
(r-out (α+ (α* bg-r (α- 1 blend)) (α* ov-r blend)))
|
||||||
|
(g-out (α+ (α* bg-g (α- 1 blend)) (α* ov-g blend)))
|
||||||
|
(b-out (α+ (α* bg-b (α- 1 blend)) (α* ov-b blend))))
|
||||||
|
(rgb r-out g-out b-out))
|
||||||
|
;; No overlay - just show the blend mask for debugging
|
||||||
|
(let ((mask-vis (α* blend 255)))
|
||||||
|
(rgb mask-vis mask-vis mask-vis)))))
|
||||||
27
sexp_effects/effects/xector_threshold.sexp
Normal file
27
sexp_effects/effects/xector_threshold.sexp
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
;; Threshold effect using xector operations
|
||||||
|
;; Demonstrates where (conditional select) and β (reduction) for normalization
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect xector_threshold
|
||||||
|
:params (
|
||||||
|
(threshold :type float :default 0.5 :range [0 1] :desc "Brightness threshold (0-1)")
|
||||||
|
(invert :type bool :default false :desc "Invert the threshold")
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Get grayscale luminance as xector
|
||||||
|
(luma (gray frame))
|
||||||
|
|
||||||
|
;; Normalize to 0-1 range
|
||||||
|
(luma-norm (α/ luma 255))
|
||||||
|
|
||||||
|
;; Create boolean mask: pixels above threshold
|
||||||
|
(mask (if invert
|
||||||
|
(α< luma-norm threshold)
|
||||||
|
(α>= luma-norm threshold)))
|
||||||
|
|
||||||
|
;; Use where to select: white (255) if above threshold, black (0) if below
|
||||||
|
(out (where mask 255 0)))
|
||||||
|
|
||||||
|
;; Output as grayscale (same value for R, G, B)
|
||||||
|
(rgb out out out)))
|
||||||
36
sexp_effects/effects/xector_vignette.sexp
Normal file
36
sexp_effects/effects/xector_vignette.sexp
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
;; Vignette effect using xector operations
|
||||||
|
;; Demonstrates α (element-wise) and β (reduction) patterns
|
||||||
|
|
||||||
|
(require-primitives "xector")
|
||||||
|
|
||||||
|
(define-effect xector_vignette
|
||||||
|
:params (
|
||||||
|
(strength :type float :default 0.5 :range [0 1])
|
||||||
|
(radius :type float :default 1.0 :range [0.5 2])
|
||||||
|
)
|
||||||
|
(let* (
|
||||||
|
;; Get normalized distance from center for each pixel
|
||||||
|
(dist (dist-from-center frame))
|
||||||
|
|
||||||
|
;; Calculate max distance (corner distance)
|
||||||
|
(max-dist (* (βmax dist) radius))
|
||||||
|
|
||||||
|
;; Calculate brightness factor per pixel: 1 - (dist/max-dist * strength)
|
||||||
|
;; Using explicit α operators
|
||||||
|
(factor (α- 1 (α* (α/ dist max-dist) strength)))
|
||||||
|
|
||||||
|
;; Clamp factor to [0, 1]
|
||||||
|
(factor (αclamp factor 0 1))
|
||||||
|
|
||||||
|
;; Extract channels as xectors
|
||||||
|
(r (red frame))
|
||||||
|
(g (green frame))
|
||||||
|
(b (blue frame))
|
||||||
|
|
||||||
|
;; Apply factor to each channel (implicit element-wise via Xector operators)
|
||||||
|
(r-out (* r factor))
|
||||||
|
(g-out (* g factor))
|
||||||
|
(b-out (* b factor)))
|
||||||
|
|
||||||
|
;; Combine back to frame
|
||||||
|
(rgb r-out g-out b-out)))
|
||||||
@@ -156,11 +156,21 @@ class Interpreter:
|
|||||||
if form == 'define':
|
if form == 'define':
|
||||||
name = expr[1]
|
name = expr[1]
|
||||||
if _is_symbol(name):
|
if _is_symbol(name):
|
||||||
|
# Simple define: (define name value)
|
||||||
value = self.eval(expr[2], env)
|
value = self.eval(expr[2], env)
|
||||||
self.global_env.set(name.name, value)
|
self.global_env.set(name.name, value)
|
||||||
return value
|
return value
|
||||||
|
elif isinstance(name, list) and len(name) >= 1 and _is_symbol(name[0]):
|
||||||
|
# Function define: (define (fn-name args...) body)
|
||||||
|
# Desugars to: (define fn-name (lambda (args...) body))
|
||||||
|
fn_name = name[0].name
|
||||||
|
params = [p.name if _is_symbol(p) else p for p in name[1:]]
|
||||||
|
body = expr[2]
|
||||||
|
fn = Lambda(params, body, env)
|
||||||
|
self.global_env.set(fn_name, fn)
|
||||||
|
return fn
|
||||||
else:
|
else:
|
||||||
raise SyntaxError(f"define requires symbol, got {name}")
|
raise SyntaxError(f"define requires symbol or (name args...), got {name}")
|
||||||
|
|
||||||
# Define-effect
|
# Define-effect
|
||||||
if form == 'define-effect':
|
if form == 'define-effect':
|
||||||
@@ -276,6 +286,10 @@ class Interpreter:
|
|||||||
if form == 'require-primitives':
|
if form == 'require-primitives':
|
||||||
return self._eval_require_primitives(expr, env)
|
return self._eval_require_primitives(expr, env)
|
||||||
|
|
||||||
|
# require - load .sexp file into current scope
|
||||||
|
if form == 'require':
|
||||||
|
return self._eval_require(expr, env)
|
||||||
|
|
||||||
# Function call
|
# Function call
|
||||||
fn = self.eval(head, env)
|
fn = self.eval(head, env)
|
||||||
args = [self.eval(arg, env) for arg in expr[1:]]
|
args = [self.eval(arg, env) for arg in expr[1:]]
|
||||||
@@ -488,6 +502,61 @@ class Interpreter:
|
|||||||
from .primitive_libs import load_primitive_library
|
from .primitive_libs import load_primitive_library
|
||||||
return load_primitive_library(name, path)
|
return load_primitive_library(name, path)
|
||||||
|
|
||||||
|
def _eval_require(self, expr: Any, env: Environment) -> Any:
|
||||||
|
"""
|
||||||
|
Evaluate require: load a .sexp file and evaluate its definitions.
|
||||||
|
|
||||||
|
Syntax:
|
||||||
|
(require "derived") ; loads derived.sexp from sexp_effects/
|
||||||
|
(require "path/to/file.sexp") ; loads from explicit path
|
||||||
|
|
||||||
|
Definitions from the file are added to the current environment.
|
||||||
|
"""
|
||||||
|
for lib_expr in expr[1:]:
|
||||||
|
if _is_symbol(lib_expr):
|
||||||
|
lib_name = lib_expr.name
|
||||||
|
else:
|
||||||
|
lib_name = lib_expr
|
||||||
|
|
||||||
|
# Find the .sexp file
|
||||||
|
sexp_path = self._find_sexp_file(lib_name)
|
||||||
|
if sexp_path is None:
|
||||||
|
raise ValueError(f"Cannot find sexp file: {lib_name}")
|
||||||
|
|
||||||
|
# Parse and evaluate the file
|
||||||
|
content = parse_file(sexp_path)
|
||||||
|
|
||||||
|
# Evaluate all top-level expressions
|
||||||
|
if isinstance(content, list) and content and isinstance(content[0], list):
|
||||||
|
for e in content:
|
||||||
|
self.eval(e, env)
|
||||||
|
else:
|
||||||
|
self.eval(content, env)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_sexp_file(self, name: str) -> Optional[str]:
|
||||||
|
"""Find a .sexp file by name."""
|
||||||
|
# Try various locations
|
||||||
|
candidates = [
|
||||||
|
# Explicit path
|
||||||
|
name,
|
||||||
|
name + '.sexp',
|
||||||
|
# In sexp_effects directory
|
||||||
|
Path(__file__).parent / f'{name}.sexp',
|
||||||
|
Path(__file__).parent / name,
|
||||||
|
# In effects directory
|
||||||
|
Path(__file__).parent / 'effects' / f'{name}.sexp',
|
||||||
|
Path(__file__).parent / 'effects' / name,
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in candidates:
|
||||||
|
p = Path(path) if not isinstance(path, Path) else path
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any:
|
def _eval_ascii_fx_zone(self, expr: Any, env: Environment) -> Any:
|
||||||
"""
|
"""
|
||||||
Evaluate ascii-fx-zone special form.
|
Evaluate ascii-fx-zone special form.
|
||||||
@@ -876,8 +945,8 @@ class Interpreter:
|
|||||||
for pname, pdefault in effect.params.items():
|
for pname, pdefault in effect.params.items():
|
||||||
value = params.get(pname)
|
value = params.get(pname)
|
||||||
if value is None:
|
if value is None:
|
||||||
# Evaluate default if it's an expression (list)
|
# Evaluate default if it's an expression (list) or a symbol (like 'nil')
|
||||||
if isinstance(pdefault, list):
|
if isinstance(pdefault, list) or _is_symbol(pdefault):
|
||||||
value = self.eval(pdefault, env)
|
value = self.eval(pdefault, env)
|
||||||
else:
|
else:
|
||||||
value = pdefault
|
value = pdefault
|
||||||
|
|||||||
@@ -71,7 +71,8 @@ class Tokenizer:
|
|||||||
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
|
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
|
||||||
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
|
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
|
||||||
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
|
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
|
||||||
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*')
|
# Symbol pattern includes Greek letters α (alpha) and β (beta) for xector operations
|
||||||
|
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?αβ²λ][a-zA-Z0-9_*+\-><=/!?.:αβ²λ]*')
|
||||||
|
|
||||||
def __init__(self, text: str):
|
def __init__(self, text: str):
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|||||||
@@ -1,126 +1,680 @@
|
|||||||
"""
|
"""
|
||||||
Drawing Primitives Library
|
Drawing Primitives Library
|
||||||
|
|
||||||
Draw shapes, text, and characters on images.
|
Draw shapes, text, and characters on images with sophisticated text handling.
|
||||||
|
|
||||||
|
Text Features:
|
||||||
|
- Font loading from files or system fonts
|
||||||
|
- Text measurement and fitting
|
||||||
|
- Alignment (left/center/right, top/middle/bottom)
|
||||||
|
- Opacity for fade effects
|
||||||
|
- Multi-line text support
|
||||||
|
- Shadow and outline effects
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import os
|
||||||
|
import glob as glob_module
|
||||||
|
from typing import Optional, Tuple, List, Union
|
||||||
|
|
||||||
|
|
||||||
# Default font (will be loaded lazily)
|
# =============================================================================
|
||||||
_default_font = None
|
# Font Management
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Font cache: (path, size) -> font object
|
||||||
|
_font_cache = {}
|
||||||
|
|
||||||
|
# Common system font directories
|
||||||
|
FONT_DIRS = [
|
||||||
|
"/usr/share/fonts",
|
||||||
|
"/usr/local/share/fonts",
|
||||||
|
"~/.fonts",
|
||||||
|
"~/.local/share/fonts",
|
||||||
|
"/System/Library/Fonts", # macOS
|
||||||
|
"/Library/Fonts", # macOS
|
||||||
|
"C:/Windows/Fonts", # Windows
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default fonts to try (in order of preference)
|
||||||
|
DEFAULT_FONTS = [
|
||||||
|
"DejaVuSans.ttf",
|
||||||
|
"DejaVuSansMono.ttf",
|
||||||
|
"Arial.ttf",
|
||||||
|
"Helvetica.ttf",
|
||||||
|
"FreeSans.ttf",
|
||||||
|
"LiberationSans-Regular.ttf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _get_default_font(size=16):
|
def _find_font_file(name: str) -> Optional[str]:
|
||||||
"""Get default font, creating if needed."""
|
"""Find a font file by name in system directories."""
|
||||||
global _default_font
|
# If it's already a full path
|
||||||
if _default_font is None or _default_font.size != size:
|
if os.path.isfile(name):
|
||||||
|
return name
|
||||||
|
|
||||||
|
# Expand user paths
|
||||||
|
expanded = os.path.expanduser(name)
|
||||||
|
if os.path.isfile(expanded):
|
||||||
|
return expanded
|
||||||
|
|
||||||
|
# Search in font directories
|
||||||
|
for font_dir in FONT_DIRS:
|
||||||
|
font_dir = os.path.expanduser(font_dir)
|
||||||
|
if not os.path.isdir(font_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Direct match
|
||||||
|
direct = os.path.join(font_dir, name)
|
||||||
|
if os.path.isfile(direct):
|
||||||
|
return direct
|
||||||
|
|
||||||
|
# Recursive search
|
||||||
|
for root, dirs, files in os.walk(font_dir):
|
||||||
|
for f in files:
|
||||||
|
if f.lower() == name.lower():
|
||||||
|
return os.path.join(root, f)
|
||||||
|
# Also match without extension
|
||||||
|
base = os.path.splitext(f)[0]
|
||||||
|
if base.lower() == name.lower():
|
||||||
|
return os.path.join(root, f)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_font(size: int = 24) -> ImageFont.FreeTypeFont:
|
||||||
|
"""Get a default font at the given size."""
|
||||||
|
for font_name in DEFAULT_FONTS:
|
||||||
|
path = _find_font_file(font_name)
|
||||||
|
if path:
|
||||||
|
try:
|
||||||
|
return ImageFont.truetype(path, size)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Last resort: PIL default
|
||||||
|
return ImageFont.load_default()
|
||||||
|
|
||||||
|
|
||||||
|
def prim_make_font(name_or_path: str, size: int = 24) -> ImageFont.FreeTypeFont:
|
||||||
|
"""
|
||||||
|
Load a font by name or path.
|
||||||
|
|
||||||
|
(make-font "Arial" 32) ; system font by name
|
||||||
|
(make-font "/path/to/font.ttf" 24) ; font file path
|
||||||
|
(make-font "DejaVuSans" 48) ; searches common locations
|
||||||
|
|
||||||
|
Returns a font object for use with text primitives.
|
||||||
|
"""
|
||||||
|
size = int(size)
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_key = (name_or_path, size)
|
||||||
|
if cache_key in _font_cache:
|
||||||
|
return _font_cache[cache_key]
|
||||||
|
|
||||||
|
# Find the font file
|
||||||
|
path = _find_font_file(name_or_path)
|
||||||
|
if not path:
|
||||||
|
raise FileNotFoundError(f"Font not found: {name_or_path}")
|
||||||
|
|
||||||
|
# Load and cache
|
||||||
|
font = ImageFont.truetype(path, size)
|
||||||
|
_font_cache[cache_key] = font
|
||||||
|
return font
|
||||||
|
|
||||||
|
|
||||||
|
def prim_list_fonts() -> List[str]:
|
||||||
|
"""
|
||||||
|
List available system fonts.
|
||||||
|
|
||||||
|
(list-fonts) ; -> ("Arial.ttf" "DejaVuSans.ttf" ...)
|
||||||
|
|
||||||
|
Returns list of font filenames found in system directories.
|
||||||
|
"""
|
||||||
|
fonts = set()
|
||||||
|
|
||||||
|
for font_dir in FONT_DIRS:
|
||||||
|
font_dir = os.path.expanduser(font_dir)
|
||||||
|
if not os.path.isdir(font_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(font_dir):
|
||||||
|
for f in files:
|
||||||
|
if f.lower().endswith(('.ttf', '.otf', '.ttc')):
|
||||||
|
fonts.add(f)
|
||||||
|
|
||||||
|
return sorted(fonts)
|
||||||
|
|
||||||
|
|
||||||
|
def prim_font_size(font: ImageFont.FreeTypeFont) -> int:
|
||||||
|
"""
|
||||||
|
Get the size of a font.
|
||||||
|
|
||||||
|
(font-size my-font) ; -> 24
|
||||||
|
"""
|
||||||
|
return font.size
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Text Measurement
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def prim_text_size(text: str, font=None, font_size: int = 24) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Measure text dimensions.
|
||||||
|
|
||||||
|
(text-size "Hello" my-font) ; -> (width height)
|
||||||
|
(text-size "Hello" :font-size 32) ; -> (width height) with default font
|
||||||
|
|
||||||
|
For multi-line text, returns total bounding box.
|
||||||
|
"""
|
||||||
|
if font is None:
|
||||||
|
font = _get_default_font(int(font_size))
|
||||||
|
elif isinstance(font, (int, float)):
|
||||||
|
font = _get_default_font(int(font))
|
||||||
|
|
||||||
|
# Create temporary image for measurement
|
||||||
|
img = Image.new('RGB', (1, 1))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
bbox = draw.textbbox((0, 0), str(text), font=font)
|
||||||
|
width = bbox[2] - bbox[0]
|
||||||
|
height = bbox[3] - bbox[1]
|
||||||
|
|
||||||
|
return (width, height)
|
||||||
|
|
||||||
|
|
||||||
|
def prim_text_metrics(font=None, font_size: int = 24) -> dict:
|
||||||
|
"""
|
||||||
|
Get font metrics.
|
||||||
|
|
||||||
|
(text-metrics my-font) ; -> {ascent: 20, descent: 5, height: 25}
|
||||||
|
|
||||||
|
Useful for precise text layout.
|
||||||
|
"""
|
||||||
|
if font is None:
|
||||||
|
font = _get_default_font(int(font_size))
|
||||||
|
elif isinstance(font, (int, float)):
|
||||||
|
font = _get_default_font(int(font))
|
||||||
|
|
||||||
|
ascent, descent = font.getmetrics()
|
||||||
|
return {
|
||||||
|
'ascent': ascent,
|
||||||
|
'descent': descent,
|
||||||
|
'height': ascent + descent,
|
||||||
|
'size': font.size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def prim_fit_text_size(text: str, max_width: int, max_height: int,
|
||||||
|
font_name: str = None, min_size: int = 8,
|
||||||
|
max_size: int = 500) -> int:
|
||||||
|
"""
|
||||||
|
Calculate font size to fit text within bounds.
|
||||||
|
|
||||||
|
(fit-text-size "Hello World" 400 100) ; -> 48
|
||||||
|
(fit-text-size "Title" 800 200 :font-name "Arial")
|
||||||
|
|
||||||
|
Returns the largest font size that fits within max_width x max_height.
|
||||||
|
"""
|
||||||
|
max_width = int(max_width)
|
||||||
|
max_height = int(max_height)
|
||||||
|
min_size = int(min_size)
|
||||||
|
max_size = int(max_size)
|
||||||
|
text = str(text)
|
||||||
|
|
||||||
|
# Binary search for optimal size
|
||||||
|
best_size = min_size
|
||||||
|
low, high = min_size, max_size
|
||||||
|
|
||||||
|
while low <= high:
|
||||||
|
mid = (low + high) // 2
|
||||||
|
|
||||||
|
if font_name:
|
||||||
|
try:
|
||||||
|
font = prim_make_font(font_name, mid)
|
||||||
|
except:
|
||||||
|
font = _get_default_font(mid)
|
||||||
|
else:
|
||||||
|
font = _get_default_font(mid)
|
||||||
|
|
||||||
|
w, h = prim_text_size(text, font)
|
||||||
|
|
||||||
|
if w <= max_width and h <= max_height:
|
||||||
|
best_size = mid
|
||||||
|
low = mid + 1
|
||||||
|
else:
|
||||||
|
high = mid - 1
|
||||||
|
|
||||||
|
return best_size
|
||||||
|
|
||||||
|
|
||||||
|
def prim_fit_font(text: str, max_width: int, max_height: int,
|
||||||
|
font_name: str = None, min_size: int = 8,
|
||||||
|
max_size: int = 500) -> ImageFont.FreeTypeFont:
|
||||||
|
"""
|
||||||
|
Create a font sized to fit text within bounds.
|
||||||
|
|
||||||
|
(fit-font "Hello World" 400 100) ; -> font object
|
||||||
|
(fit-font "Title" 800 200 :font-name "Arial")
|
||||||
|
|
||||||
|
Returns a font object at the optimal size.
|
||||||
|
"""
|
||||||
|
size = prim_fit_text_size(text, max_width, max_height,
|
||||||
|
font_name, min_size, max_size)
|
||||||
|
|
||||||
|
if font_name:
|
||||||
try:
|
try:
|
||||||
_default_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", size)
|
return prim_make_font(font_name, size)
|
||||||
except:
|
except:
|
||||||
_default_font = ImageFont.load_default()
|
pass
|
||||||
return _default_font
|
|
||||||
|
|
||||||
|
return _get_default_font(size)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Text Drawing
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def prim_text(img: np.ndarray, text: str,
|
||||||
|
x: int = None, y: int = None,
|
||||||
|
width: int = None, height: int = None,
|
||||||
|
font=None, font_size: int = 24, font_name: str = None,
|
||||||
|
color=None, opacity: float = 1.0,
|
||||||
|
align: str = "left", valign: str = "top",
|
||||||
|
fit: bool = False,
|
||||||
|
shadow: bool = False, shadow_color=None, shadow_offset: int = 2,
|
||||||
|
outline: bool = False, outline_color=None, outline_width: int = 1,
|
||||||
|
line_spacing: float = 1.2) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Draw text with alignment, opacity, and effects.
|
||||||
|
|
||||||
|
Basic usage:
|
||||||
|
(text frame "Hello" :x 100 :y 50)
|
||||||
|
|
||||||
|
Centered in frame:
|
||||||
|
(text frame "Title" :align "center" :valign "middle")
|
||||||
|
|
||||||
|
Fit to box:
|
||||||
|
(text frame "Big Text" :x 50 :y 50 :width 400 :height 100 :fit true)
|
||||||
|
|
||||||
|
With fade (for animations):
|
||||||
|
(text frame "Fading" :x 100 :y 100 :opacity 0.5)
|
||||||
|
|
||||||
|
With effects:
|
||||||
|
(text frame "Shadow" :x 100 :y 100 :shadow true)
|
||||||
|
(text frame "Outline" :x 100 :y 100 :outline true :outline-color (0 0 0))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: Input frame
|
||||||
|
text: Text to draw
|
||||||
|
x, y: Position (if not specified, uses alignment in full frame)
|
||||||
|
width, height: Bounding box (for fit and alignment within box)
|
||||||
|
font: Font object from make-font
|
||||||
|
font_size: Size if no font specified
|
||||||
|
font_name: Font name to load
|
||||||
|
color: RGB tuple (default white)
|
||||||
|
opacity: 0.0 (invisible) to 1.0 (opaque) for fading
|
||||||
|
align: "left", "center", "right"
|
||||||
|
valign: "top", "middle", "bottom"
|
||||||
|
fit: If true, auto-size font to fit in box
|
||||||
|
shadow: Draw drop shadow
|
||||||
|
shadow_color: Shadow color (default black)
|
||||||
|
shadow_offset: Shadow offset in pixels
|
||||||
|
outline: Draw text outline
|
||||||
|
outline_color: Outline color (default black)
|
||||||
|
outline_width: Outline thickness
|
||||||
|
line_spacing: Multiplier for line height (for multi-line)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Frame with text drawn
|
||||||
|
"""
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
text = str(text)
|
||||||
|
|
||||||
|
# Default colors
|
||||||
|
if color is None:
|
||||||
|
color = (255, 255, 255)
|
||||||
|
else:
|
||||||
|
color = tuple(int(c) for c in color)
|
||||||
|
|
||||||
|
if shadow_color is None:
|
||||||
|
shadow_color = (0, 0, 0)
|
||||||
|
else:
|
||||||
|
shadow_color = tuple(int(c) for c in shadow_color)
|
||||||
|
|
||||||
|
if outline_color is None:
|
||||||
|
outline_color = (0, 0, 0)
|
||||||
|
else:
|
||||||
|
outline_color = tuple(int(c) for c in outline_color)
|
||||||
|
|
||||||
|
# Determine bounding box
|
||||||
|
if x is None:
|
||||||
|
x = 0
|
||||||
|
if width is None:
|
||||||
|
width = w
|
||||||
|
if y is None:
|
||||||
|
y = 0
|
||||||
|
if height is None:
|
||||||
|
height = h
|
||||||
|
|
||||||
|
x, y = int(x), int(y)
|
||||||
|
box_width = int(width) if width else w - x
|
||||||
|
box_height = int(height) if height else h - y
|
||||||
|
|
||||||
|
# Get or create font
|
||||||
|
if font is None:
|
||||||
|
if fit:
|
||||||
|
font = prim_fit_font(text, box_width, box_height, font_name)
|
||||||
|
elif font_name:
|
||||||
|
try:
|
||||||
|
font = prim_make_font(font_name, int(font_size))
|
||||||
|
except:
|
||||||
|
font = _get_default_font(int(font_size))
|
||||||
|
else:
|
||||||
|
font = _get_default_font(int(font_size))
|
||||||
|
|
||||||
|
# Measure text
|
||||||
|
text_w, text_h = prim_text_size(text, font)
|
||||||
|
|
||||||
|
# Calculate position based on alignment
|
||||||
|
if align == "center":
|
||||||
|
draw_x = x + (box_width - text_w) // 2
|
||||||
|
elif align == "right":
|
||||||
|
draw_x = x + box_width - text_w
|
||||||
|
else: # left
|
||||||
|
draw_x = x
|
||||||
|
|
||||||
|
if valign == "middle":
|
||||||
|
draw_y = y + (box_height - text_h) // 2
|
||||||
|
elif valign == "bottom":
|
||||||
|
draw_y = y + box_height - text_h
|
||||||
|
else: # top
|
||||||
|
draw_y = y
|
||||||
|
|
||||||
|
# Create RGBA image for compositing with opacity
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
|
||||||
|
# Create text layer with transparency
|
||||||
|
text_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(text_layer)
|
||||||
|
|
||||||
|
# Draw shadow first (if enabled)
|
||||||
|
if shadow:
|
||||||
|
shadow_x = draw_x + shadow_offset
|
||||||
|
shadow_y = draw_y + shadow_offset
|
||||||
|
shadow_rgba = shadow_color + (int(255 * opacity * 0.5),)
|
||||||
|
draw.text((shadow_x, shadow_y), text, fill=shadow_rgba, font=font)
|
||||||
|
|
||||||
|
# Draw outline (if enabled)
|
||||||
|
if outline:
|
||||||
|
outline_rgba = outline_color + (int(255 * opacity),)
|
||||||
|
ow = int(outline_width)
|
||||||
|
for dx in range(-ow, ow + 1):
|
||||||
|
for dy in range(-ow, ow + 1):
|
||||||
|
if dx != 0 or dy != 0:
|
||||||
|
draw.text((draw_x + dx, draw_y + dy), text,
|
||||||
|
fill=outline_rgba, font=font)
|
||||||
|
|
||||||
|
# Draw main text
|
||||||
|
text_rgba = color + (int(255 * opacity),)
|
||||||
|
draw.text((draw_x, draw_y), text, fill=text_rgba, font=font)
|
||||||
|
|
||||||
|
# Composite
|
||||||
|
result = Image.alpha_composite(pil_img, text_layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
|
def prim_text_box(img: np.ndarray, text: str,
|
||||||
|
x: int, y: int, width: int, height: int,
|
||||||
|
font=None, font_size: int = 24, font_name: str = None,
|
||||||
|
color=None, opacity: float = 1.0,
|
||||||
|
align: str = "center", valign: str = "middle",
|
||||||
|
fit: bool = True,
|
||||||
|
padding: int = 0,
|
||||||
|
background=None, background_opacity: float = 0.5,
|
||||||
|
**kwargs) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Draw text fitted within a box, optionally with background.
|
||||||
|
|
||||||
|
(text-box frame "Title" 50 50 400 100)
|
||||||
|
(text-box frame "Subtitle" 50 160 400 50
|
||||||
|
:background (0 0 0) :background-opacity 0.7)
|
||||||
|
|
||||||
|
Convenience wrapper around text() for common box-with-text pattern.
|
||||||
|
"""
|
||||||
|
x, y = int(x), int(y)
|
||||||
|
width, height = int(width), int(height)
|
||||||
|
padding = int(padding)
|
||||||
|
|
||||||
|
result = img.copy()
|
||||||
|
|
||||||
|
# Draw background if specified
|
||||||
|
if background is not None:
|
||||||
|
bg_color = tuple(int(c) for c in background)
|
||||||
|
|
||||||
|
# Create background with opacity
|
||||||
|
pil_img = Image.fromarray(result).convert('RGBA')
|
||||||
|
bg_layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
bg_draw = ImageDraw.Draw(bg_layer)
|
||||||
|
bg_rgba = bg_color + (int(255 * background_opacity),)
|
||||||
|
bg_draw.rectangle([x, y, x + width, y + height], fill=bg_rgba)
|
||||||
|
result = np.array(Image.alpha_composite(pil_img, bg_layer).convert('RGB'))
|
||||||
|
|
||||||
|
# Draw text within padded box
|
||||||
|
return prim_text(result, text,
|
||||||
|
x=x + padding, y=y + padding,
|
||||||
|
width=width - 2 * padding, height=height - 2 * padding,
|
||||||
|
font=font, font_size=font_size, font_name=font_name,
|
||||||
|
color=color, opacity=opacity,
|
||||||
|
align=align, valign=valign, fit=fit,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Legacy text functions (keep for compatibility)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
def prim_draw_char(img, char, x, y, font_size=16, color=None):
|
def prim_draw_char(img, char, x, y, font_size=16, color=None):
|
||||||
"""Draw a single character at (x, y)."""
|
"""Draw a single character at (x, y). Legacy function."""
|
||||||
if color is None:
|
return prim_text(img, str(char), x=int(x), y=int(y),
|
||||||
color = [255, 255, 255]
|
font_size=int(font_size), color=color)
|
||||||
|
|
||||||
pil_img = Image.fromarray(img)
|
|
||||||
draw = ImageDraw.Draw(pil_img)
|
|
||||||
font = _get_default_font(font_size)
|
|
||||||
draw.text((x, y), char, fill=tuple(color), font=font)
|
|
||||||
return np.array(pil_img)
|
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_text(img, text, x, y, font_size=16, color=None):
|
def prim_draw_text(img, text, x, y, font_size=16, color=None):
|
||||||
"""Draw text string at (x, y)."""
|
"""Draw text string at (x, y). Legacy function."""
|
||||||
|
return prim_text(img, str(text), x=int(x), y=int(y),
|
||||||
|
font_size=int(font_size), color=color)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Shape Drawing
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def prim_fill_rect(img, x, y, w, h, color=None, opacity: float = 1.0):
|
||||||
|
"""
|
||||||
|
Fill a rectangle with color.
|
||||||
|
|
||||||
|
(fill-rect frame 10 10 100 50 (255 0 0))
|
||||||
|
(fill-rect frame 10 10 100 50 (255 0 0) :opacity 0.5)
|
||||||
|
"""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
pil_img = Image.fromarray(img)
|
|
||||||
draw = ImageDraw.Draw(pil_img)
|
|
||||||
font = _get_default_font(font_size)
|
|
||||||
draw.text((x, y), text, fill=tuple(color), font=font)
|
|
||||||
return np.array(pil_img)
|
|
||||||
|
|
||||||
|
|
||||||
def prim_fill_rect(img, x, y, w, h, color=None):
|
|
||||||
"""Fill a rectangle with color."""
|
|
||||||
if color is None:
|
|
||||||
color = [255, 255, 255]
|
|
||||||
|
|
||||||
result = img.copy()
|
|
||||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||||
result[y:y+h, x:x+w] = color
|
|
||||||
return result
|
if opacity >= 1.0:
|
||||||
|
result = img.copy()
|
||||||
|
result[y:y+h, x:x+w] = color
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity, use alpha compositing
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(layer)
|
||||||
|
fill_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
draw.rectangle([x, y, x + w, y + h], fill=fill_rgba)
|
||||||
|
result = Image.alpha_composite(pil_img, layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1):
|
def prim_draw_rect(img, x, y, w, h, color=None, thickness=1, opacity: float = 1.0):
|
||||||
"""Draw rectangle outline."""
|
"""Draw rectangle outline."""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
result = img.copy()
|
if opacity >= 1.0:
|
||||||
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)),
|
result = img.copy()
|
||||||
tuple(color), thickness)
|
cv2.rectangle(result, (int(x), int(y)), (int(x+w), int(y+h)),
|
||||||
return result
|
tuple(int(c) for c in color), int(thickness))
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(layer)
|
||||||
|
outline_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
draw.rectangle([int(x), int(y), int(x+w), int(y+h)],
|
||||||
|
outline=outline_rgba, width=int(thickness))
|
||||||
|
result = Image.alpha_composite(pil_img, layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1):
|
def prim_draw_line(img, x1, y1, x2, y2, color=None, thickness=1, opacity: float = 1.0):
|
||||||
"""Draw a line from (x1, y1) to (x2, y2)."""
|
"""Draw a line from (x1, y1) to (x2, y2)."""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
result = img.copy()
|
if opacity >= 1.0:
|
||||||
cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)),
|
result = img.copy()
|
||||||
tuple(color), thickness)
|
cv2.line(result, (int(x1), int(y1)), (int(x2), int(y2)),
|
||||||
return result
|
tuple(int(c) for c in color), int(thickness))
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(layer)
|
||||||
|
line_rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
draw.line([(int(x1), int(y1)), (int(x2), int(y2))],
|
||||||
|
fill=line_rgba, width=int(thickness))
|
||||||
|
result = Image.alpha_composite(pil_img, layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1, fill=False):
|
def prim_draw_circle(img, cx, cy, radius, color=None, thickness=1,
|
||||||
|
fill=False, opacity: float = 1.0):
|
||||||
"""Draw a circle."""
|
"""Draw a circle."""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
result = img.copy()
|
if opacity >= 1.0:
|
||||||
t = -1 if fill else thickness
|
result = img.copy()
|
||||||
cv2.circle(result, (int(cx), int(cy)), int(radius), tuple(color), t)
|
t = -1 if fill else int(thickness)
|
||||||
return result
|
cv2.circle(result, (int(cx), int(cy)), int(radius),
|
||||||
|
tuple(int(c) for c in color), t)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(layer)
|
||||||
|
cx, cy, r = int(cx), int(cy), int(radius)
|
||||||
|
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
|
||||||
|
if fill:
|
||||||
|
draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=rgba)
|
||||||
|
else:
|
||||||
|
draw.ellipse([cx - r, cy - r, cx + r, cy + r],
|
||||||
|
outline=rgba, width=int(thickness))
|
||||||
|
|
||||||
|
result = Image.alpha_composite(pil_img, layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None, thickness=1, fill=False):
|
def prim_draw_ellipse(img, cx, cy, rx, ry, angle=0, color=None,
|
||||||
|
thickness=1, fill=False, opacity: float = 1.0):
|
||||||
"""Draw an ellipse."""
|
"""Draw an ellipse."""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
result = img.copy()
|
if opacity >= 1.0:
|
||||||
t = -1 if fill else thickness
|
result = img.copy()
|
||||||
cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)),
|
t = -1 if fill else int(thickness)
|
||||||
angle, 0, 360, tuple(color), t)
|
cv2.ellipse(result, (int(cx), int(cy)), (int(rx), int(ry)),
|
||||||
return result
|
float(angle), 0, 360, tuple(int(c) for c in color), t)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity (note: PIL doesn't support rotated ellipses easily)
|
||||||
|
# Fall back to cv2 on a separate layer
|
||||||
|
layer = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
|
||||||
|
t = -1 if fill else int(thickness)
|
||||||
|
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
cv2.ellipse(layer, (int(cx), int(cy)), (int(rx), int(ry)),
|
||||||
|
float(angle), 0, 360, rgba, t)
|
||||||
|
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
pil_layer = Image.fromarray(layer)
|
||||||
|
result = Image.alpha_composite(pil_img, pil_layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
def prim_draw_polygon(img, points, color=None, thickness=1, fill=False):
|
def prim_draw_polygon(img, points, color=None, thickness=1,
|
||||||
|
fill=False, opacity: float = 1.0):
|
||||||
"""Draw a polygon from list of [x, y] points."""
|
"""Draw a polygon from list of [x, y] points."""
|
||||||
if color is None:
|
if color is None:
|
||||||
color = [255, 255, 255]
|
color = [255, 255, 255]
|
||||||
|
|
||||||
result = img.copy()
|
if opacity >= 1.0:
|
||||||
pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
|
result = img.copy()
|
||||||
|
pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
|
||||||
|
if fill:
|
||||||
|
cv2.fillPoly(result, [pts], tuple(int(c) for c in color))
|
||||||
|
else:
|
||||||
|
cv2.polylines(result, [pts], True,
|
||||||
|
tuple(int(c) for c in color), int(thickness))
|
||||||
|
return result
|
||||||
|
|
||||||
|
# With opacity
|
||||||
|
pil_img = Image.fromarray(img).convert('RGBA')
|
||||||
|
layer = Image.new('RGBA', (pil_img.width, pil_img.height), (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(layer)
|
||||||
|
|
||||||
|
pts_flat = [(int(p[0]), int(p[1])) for p in points]
|
||||||
|
rgba = tuple(int(c) for c in color) + (int(255 * opacity),)
|
||||||
|
|
||||||
if fill:
|
if fill:
|
||||||
cv2.fillPoly(result, [pts], tuple(color))
|
draw.polygon(pts_flat, fill=rgba)
|
||||||
else:
|
else:
|
||||||
cv2.polylines(result, [pts], True, tuple(color), thickness)
|
draw.polygon(pts_flat, outline=rgba, width=int(thickness))
|
||||||
|
|
||||||
return result
|
result = Image.alpha_composite(pil_img, layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PRIMITIVES Export
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
PRIMITIVES = {
|
PRIMITIVES = {
|
||||||
# Text
|
# Font management
|
||||||
|
'make-font': prim_make_font,
|
||||||
|
'list-fonts': prim_list_fonts,
|
||||||
|
'font-size': prim_font_size,
|
||||||
|
|
||||||
|
# Text measurement
|
||||||
|
'text-size': prim_text_size,
|
||||||
|
'text-metrics': prim_text_metrics,
|
||||||
|
'fit-text-size': prim_fit_text_size,
|
||||||
|
'fit-font': prim_fit_font,
|
||||||
|
|
||||||
|
# Text drawing
|
||||||
|
'text': prim_text,
|
||||||
|
'text-box': prim_text_box,
|
||||||
|
|
||||||
|
# Legacy text (compatibility)
|
||||||
'draw-char': prim_draw_char,
|
'draw-char': prim_draw_char,
|
||||||
'draw-text': prim_draw_text,
|
'draw-text': prim_draw_text,
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,18 @@ GPU Acceleration:
|
|||||||
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
||||||
- Hardware video decoding (NVDEC) is used when available
|
- Hardware video decoding (NVDEC) is used when available
|
||||||
- Dramatically improves performance on GPU nodes
|
- Dramatically improves performance on GPU nodes
|
||||||
|
|
||||||
|
Async Prefetching:
|
||||||
|
- Set STREAMING_PREFETCH=1 to enable background frame prefetching
|
||||||
|
- Decodes upcoming frames while current frame is being processed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import subprocess
|
import subprocess
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Try to import CuPy for GPU acceleration
|
# Try to import CuPy for GPU acceleration
|
||||||
@@ -28,6 +34,10 @@ except ImportError:
|
|||||||
# Disabled by default until all primitives support GPU frames
|
# Disabled by default until all primitives support GPU frames
|
||||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
|
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
|
||||||
|
|
||||||
|
# Async prefetch mode - decode frames in background thread
|
||||||
|
PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1"
|
||||||
|
PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10"))
|
||||||
|
|
||||||
# Check for hardware decode support (cached)
|
# Check for hardware decode support (cached)
|
||||||
_HWDEC_AVAILABLE = None
|
_HWDEC_AVAILABLE = None
|
||||||
|
|
||||||
@@ -283,6 +293,122 @@ class VideoSource:
|
|||||||
self._proc = None
|
self._proc = None
|
||||||
|
|
||||||
|
|
||||||
|
class PrefetchingVideoSource:
|
||||||
|
"""
|
||||||
|
Video source with background prefetching for improved performance.
|
||||||
|
|
||||||
|
Wraps VideoSource and adds a background thread that pre-decodes
|
||||||
|
upcoming frames while the main thread processes the current frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: str, fps: float = 30, buffer_size: int = None):
|
||||||
|
self._source = VideoSource(path, fps)
|
||||||
|
self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE
|
||||||
|
self._buffer = {} # time -> frame
|
||||||
|
self._buffer_lock = threading.Lock()
|
||||||
|
self._prefetch_time = 0.0
|
||||||
|
self._frame_time = 1.0 / fps
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
self._request_event = threading.Event()
|
||||||
|
self._target_time = 0.0
|
||||||
|
|
||||||
|
# Start prefetch thread
|
||||||
|
self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
import sys
|
||||||
|
print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr)
|
||||||
|
|
||||||
|
def _prefetch_loop(self):
|
||||||
|
"""Background thread that pre-reads frames."""
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
# Wait for work or timeout
|
||||||
|
self._request_event.wait(timeout=0.01)
|
||||||
|
self._request_event.clear()
|
||||||
|
|
||||||
|
if self._stop_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
# Prefetch frames ahead of target time
|
||||||
|
target = self._target_time
|
||||||
|
with self._buffer_lock:
|
||||||
|
# Clean old frames (more than 1 second behind)
|
||||||
|
old_times = [t for t in self._buffer.keys() if t < target - 1.0]
|
||||||
|
for t in old_times:
|
||||||
|
del self._buffer[t]
|
||||||
|
|
||||||
|
# Count how many frames we have buffered ahead
|
||||||
|
buffered_ahead = sum(1 for t in self._buffer.keys() if t >= target)
|
||||||
|
|
||||||
|
# Prefetch if buffer not full
|
||||||
|
if buffered_ahead < self._buffer_size:
|
||||||
|
# Find next time to prefetch
|
||||||
|
prefetch_t = target
|
||||||
|
with self._buffer_lock:
|
||||||
|
existing_times = set(self._buffer.keys())
|
||||||
|
for _ in range(self._buffer_size):
|
||||||
|
if prefetch_t not in existing_times:
|
||||||
|
break
|
||||||
|
prefetch_t += self._frame_time
|
||||||
|
|
||||||
|
# Read the frame (this is the slow part)
|
||||||
|
try:
|
||||||
|
frame = self._source.read_at(prefetch_t)
|
||||||
|
with self._buffer_lock:
|
||||||
|
self._buffer[prefetch_t] = frame
|
||||||
|
except Exception as e:
|
||||||
|
import sys
|
||||||
|
print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
def read_at(self, t: float) -> np.ndarray:
|
||||||
|
"""Read frame at specific time, using prefetch buffer if available."""
|
||||||
|
self._target_time = t
|
||||||
|
self._request_event.set() # Wake up prefetch thread
|
||||||
|
|
||||||
|
# Round to frame time for buffer lookup
|
||||||
|
t_key = round(t / self._frame_time) * self._frame_time
|
||||||
|
|
||||||
|
# Check buffer first
|
||||||
|
with self._buffer_lock:
|
||||||
|
if t_key in self._buffer:
|
||||||
|
return self._buffer[t_key]
|
||||||
|
# Also check for close matches (within half frame time)
|
||||||
|
for buf_t, frame in self._buffer.items():
|
||||||
|
if abs(buf_t - t) < self._frame_time * 0.5:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
# Not in buffer - read directly (blocking)
|
||||||
|
frame = self._source.read_at(t)
|
||||||
|
|
||||||
|
# Store in buffer
|
||||||
|
with self._buffer_lock:
|
||||||
|
self._buffer[t_key] = frame
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def read(self) -> np.ndarray:
|
||||||
|
"""Read frame (uses last cached or t=0)."""
|
||||||
|
return self.read_at(0)
|
||||||
|
|
||||||
|
def skip(self):
|
||||||
|
"""No-op for seek-based reading."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self):
|
||||||
|
return self._source.size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return self._source.path
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._stop_event.set()
|
||||||
|
self._request_event.set() # Wake up thread to exit
|
||||||
|
self._thread.join(timeout=1.0)
|
||||||
|
self._source.close()
|
||||||
|
|
||||||
|
|
||||||
class AudioAnalyzer:
|
class AudioAnalyzer:
|
||||||
"""Audio analyzer for energy and beat detection."""
|
"""Audio analyzer for energy and beat detection."""
|
||||||
|
|
||||||
@@ -394,7 +520,12 @@ class AudioAnalyzer:
|
|||||||
# === Primitives ===
|
# === Primitives ===
|
||||||
|
|
||||||
def prim_make_video_source(path: str, fps: float = 30):
|
def prim_make_video_source(path: str, fps: float = 30):
|
||||||
"""Create a video source from a file path."""
|
"""Create a video source from a file path.
|
||||||
|
|
||||||
|
Uses PrefetchingVideoSource if STREAMING_PREFETCH=1 (default).
|
||||||
|
"""
|
||||||
|
if PREFETCH_ENABLED:
|
||||||
|
return PrefetchingVideoSource(path, fps)
|
||||||
return VideoSource(path, fps)
|
return VideoSource(path, fps)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1382
sexp_effects/primitive_libs/xector.py
Normal file
1382
sexp_effects/primitive_libs/xector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -797,31 +797,63 @@ def prim_tan(x: float) -> float:
|
|||||||
return math.tan(x)
|
return math.tan(x)
|
||||||
|
|
||||||
|
|
||||||
def prim_atan2(y: float, x: float) -> float:
|
def prim_atan2(y, x):
|
||||||
|
if hasattr(y, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.arctan2(y._data, x._data if hasattr(x, '_data') else x), y._shape)
|
||||||
return math.atan2(y, x)
|
return math.atan2(y, x)
|
||||||
|
|
||||||
|
|
||||||
def prim_sqrt(x: float) -> float:
|
def prim_sqrt(x):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.sqrt(np.maximum(0, x._data)), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.sqrt(np.maximum(0, x))
|
||||||
return math.sqrt(max(0, x))
|
return math.sqrt(max(0, x))
|
||||||
|
|
||||||
|
|
||||||
def prim_pow(x: float, y: float) -> float:
|
def prim_pow(x, y):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
y_data = y._data if hasattr(y, '_data') else y
|
||||||
|
return Xector(np.power(x._data, y_data), x._shape)
|
||||||
return math.pow(x, y)
|
return math.pow(x, y)
|
||||||
|
|
||||||
|
|
||||||
def prim_abs(x: float) -> float:
|
def prim_abs(x):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.abs(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.abs(x)
|
||||||
return abs(x)
|
return abs(x)
|
||||||
|
|
||||||
|
|
||||||
def prim_floor(x: float) -> int:
|
def prim_floor(x):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.floor(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.floor(x)
|
||||||
return int(math.floor(x))
|
return int(math.floor(x))
|
||||||
|
|
||||||
|
|
||||||
def prim_ceil(x: float) -> int:
|
def prim_ceil(x):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.ceil(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.ceil(x)
|
||||||
return int(math.ceil(x))
|
return int(math.ceil(x))
|
||||||
|
|
||||||
|
|
||||||
def prim_round(x: float) -> int:
|
def prim_round(x):
|
||||||
|
if hasattr(x, '_data'): # Xector
|
||||||
|
from sexp_effects.primitive_libs.xector import Xector
|
||||||
|
return Xector(np.round(x._data), x._shape)
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.round(x)
|
||||||
return int(round(x))
|
return int(round(x))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
860
streaming/jax_typography.py
Normal file
860
streaming/jax_typography.py
Normal file
@@ -0,0 +1,860 @@
|
|||||||
|
"""
|
||||||
|
JAX Typography Primitives
|
||||||
|
|
||||||
|
Two approaches for text rendering, both compile to JAX/GPU:
|
||||||
|
|
||||||
|
## 1. TextStrip - Pixel-perfect static text
|
||||||
|
Pre-render entire strings at compile time using PIL.
|
||||||
|
Perfect sub-pixel anti-aliasing, exact match with PIL.
|
||||||
|
Use for: static titles, labels, any text without per-character effects.
|
||||||
|
|
||||||
|
S-expression:
|
||||||
|
(let ((strip (render-text-strip "Hello World" 48)))
|
||||||
|
(place-text-strip frame strip x y :color white))
|
||||||
|
|
||||||
|
## 2. Glyph-by-glyph - Dynamic text effects
|
||||||
|
Individual glyph placement for wave, arc, audio-reactive effects.
|
||||||
|
Each character can have independent position, color, opacity.
|
||||||
|
Note: slight anti-aliasing differences vs PIL due to integer positioning.
|
||||||
|
|
||||||
|
S-expression:
|
||||||
|
; Wave text - y oscillates with character index
|
||||||
|
(let ((glyphs (text-glyphs "Wavy" 48)))
|
||||||
|
(first
|
||||||
|
(fold glyphs (list frame 0)
|
||||||
|
(lambda (acc g)
|
||||||
|
(let ((frm (first acc))
|
||||||
|
(cursor (second acc))
|
||||||
|
(i (length acc))) ; approximate index
|
||||||
|
(list
|
||||||
|
(place-glyph frm (glyph-image g)
|
||||||
|
(+ x cursor)
|
||||||
|
(+ y (* amplitude (sin (* i frequency))))
|
||||||
|
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||||
|
white 1.0)
|
||||||
|
(+ cursor (glyph-advance g))))))))
|
||||||
|
|
||||||
|
; Audio-reactive spacing
|
||||||
|
(let ((glyphs (text-glyphs "Bass" 48))
|
||||||
|
(bass (audio-band 0 200)))
|
||||||
|
(first
|
||||||
|
(fold glyphs (list frame 0)
|
||||||
|
(lambda (acc g)
|
||||||
|
(let ((frm (first acc))
|
||||||
|
(cursor (second acc)))
|
||||||
|
(list
|
||||||
|
(place-glyph frm (glyph-image g)
|
||||||
|
(+ x cursor) y
|
||||||
|
(glyph-bearing-x g) (glyph-bearing-y g)
|
||||||
|
white 1.0)
|
||||||
|
(+ cursor (glyph-advance g) (* bass 20))))))))
|
||||||
|
|
||||||
|
Kerning support:
|
||||||
|
; With kerning adjustment
|
||||||
|
(+ cursor (glyph-advance g) (glyph-kerning g next-g font-size))
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from typing import Tuple, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Glyph Data (computed at compile time)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GlyphData:
|
||||||
|
"""Glyph data computed at compile time.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
char: The character
|
||||||
|
image: RGBA image as numpy array (H, W, 4) - converted to JAX at runtime
|
||||||
|
advance: Horizontal advance (distance to next glyph origin)
|
||||||
|
bearing_x: Left side bearing (x offset from origin to first pixel)
|
||||||
|
bearing_y: Top bearing (y offset from baseline to top of glyph)
|
||||||
|
width: Image width
|
||||||
|
height: Image height
|
||||||
|
"""
|
||||||
|
char: str
|
||||||
|
image: np.ndarray # (H, W, 4) RGBA uint8
|
||||||
|
advance: float
|
||||||
|
bearing_x: float
|
||||||
|
bearing_y: float
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
|
||||||
|
|
||||||
|
# Font cache: (font_name, font_size) -> {char: GlyphData}
|
||||||
|
_GLYPH_CACHE: Dict[Tuple, Dict[str, GlyphData]] = {}
|
||||||
|
|
||||||
|
# Font metrics cache: (font_name, font_size) -> (ascent, descent)
|
||||||
|
_METRICS_CACHE: Dict[Tuple, Tuple[float, float]] = {}
|
||||||
|
|
||||||
|
# Kerning cache: (font_name, font_size) -> {(char1, char2): adjustment}
|
||||||
|
# Kerning adjustment is added to advance: new_advance = advance + kerning
|
||||||
|
# Typically negative (characters move closer together)
|
||||||
|
_KERNING_CACHE: Dict[Tuple, Dict[Tuple[str, str], float]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_font(font_name: str = None, font_size: int = 32):
|
||||||
|
"""Load a font. Called at compile time."""
|
||||||
|
from PIL import ImageFont
|
||||||
|
|
||||||
|
candidates = [
|
||||||
|
font_name,
|
||||||
|
'/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||||
|
'/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf',
|
||||||
|
'/usr/share/fonts/truetype/freefont/FreeSans.ttf',
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in candidates:
|
||||||
|
if path is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
return ImageFont.truetype(path, font_size)
|
||||||
|
except (IOError, OSError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ImageFont.load_default()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_glyph_cache(font_name: str = None, font_size: int = 32) -> Dict[str, GlyphData]:
|
||||||
|
"""Get or create glyph cache for a font. Called at compile time."""
|
||||||
|
cache_key = (font_name, font_size)
|
||||||
|
|
||||||
|
if cache_key in _GLYPH_CACHE:
|
||||||
|
return _GLYPH_CACHE[cache_key]
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
font = _load_font(font_name, font_size)
|
||||||
|
ascent, descent = font.getmetrics()
|
||||||
|
_METRICS_CACHE[cache_key] = (ascent, descent)
|
||||||
|
|
||||||
|
glyphs = {}
|
||||||
|
charset = ''.join(chr(i) for i in range(32, 127))
|
||||||
|
|
||||||
|
temp_img = Image.new('RGBA', (200, 200), (0, 0, 0, 0))
|
||||||
|
temp_draw = ImageDraw.Draw(temp_img)
|
||||||
|
|
||||||
|
for char in charset:
|
||||||
|
# Get metrics
|
||||||
|
bbox = temp_draw.textbbox((0, 0), char, font=font)
|
||||||
|
advance = font.getlength(char)
|
||||||
|
|
||||||
|
x_min, y_min, x_max, y_max = bbox
|
||||||
|
|
||||||
|
# Create glyph image with padding
|
||||||
|
padding = 2
|
||||||
|
img_w = max(int(x_max - x_min) + padding * 2, 1)
|
||||||
|
img_h = max(int(y_max - y_min) + padding * 2, 1)
|
||||||
|
|
||||||
|
glyph_img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0))
|
||||||
|
glyph_draw = ImageDraw.Draw(glyph_img)
|
||||||
|
|
||||||
|
# Draw at position accounting for bbox offset
|
||||||
|
draw_x = padding - x_min
|
||||||
|
draw_y = padding - y_min
|
||||||
|
glyph_draw.text((draw_x, draw_y), char, fill=(255, 255, 255, 255), font=font)
|
||||||
|
|
||||||
|
glyphs[char] = GlyphData(
|
||||||
|
char=char,
|
||||||
|
image=np.array(glyph_img, dtype=np.uint8),
|
||||||
|
advance=float(advance),
|
||||||
|
bearing_x=float(x_min),
|
||||||
|
bearing_y=float(-y_min), # Distance from baseline to top
|
||||||
|
width=img_w,
|
||||||
|
height=img_h,
|
||||||
|
)
|
||||||
|
|
||||||
|
_GLYPH_CACHE[cache_key] = glyphs
|
||||||
|
return glyphs
|
||||||
|
|
||||||
|
|
||||||
|
def _get_kerning_cache(font_name: str = None, font_size: int = 32) -> Dict[Tuple[str, str], float]:
|
||||||
|
"""Get or create kerning cache for a font. Called at compile time.
|
||||||
|
|
||||||
|
Kerning is computed as:
|
||||||
|
kerning(a, b) = getlength(a + b) - getlength(a) - getlength(b)
|
||||||
|
|
||||||
|
This gives the adjustment needed when placing 'b' after 'a'.
|
||||||
|
Typically negative (characters move closer together).
|
||||||
|
"""
|
||||||
|
cache_key = (font_name, font_size)
|
||||||
|
|
||||||
|
if cache_key in _KERNING_CACHE:
|
||||||
|
return _KERNING_CACHE[cache_key]
|
||||||
|
|
||||||
|
font = _load_font(font_name, font_size)
|
||||||
|
kerning = {}
|
||||||
|
|
||||||
|
# Compute kerning for all printable ASCII pairs
|
||||||
|
charset = ''.join(chr(i) for i in range(32, 127))
|
||||||
|
|
||||||
|
# Pre-compute individual character lengths
|
||||||
|
char_lengths = {c: font.getlength(c) for c in charset}
|
||||||
|
|
||||||
|
# Compute kerning for each pair
|
||||||
|
for c1 in charset:
|
||||||
|
for c2 in charset:
|
||||||
|
pair_length = font.getlength(c1 + c2)
|
||||||
|
individual_sum = char_lengths[c1] + char_lengths[c2]
|
||||||
|
kern = pair_length - individual_sum
|
||||||
|
|
||||||
|
# Only store non-zero kerning to save memory
|
||||||
|
if abs(kern) > 0.01:
|
||||||
|
kerning[(c1, c2)] = kern
|
||||||
|
|
||||||
|
_KERNING_CACHE[cache_key] = kerning
|
||||||
|
return kerning
|
||||||
|
|
||||||
|
|
||||||
|
def get_kerning(char1: str, char2: str, font_name: str = None, font_size: int = 32) -> float:
|
||||||
|
"""Get kerning adjustment between two characters. Compile-time.
|
||||||
|
|
||||||
|
Returns the adjustment to add to char1's advance when char2 follows.
|
||||||
|
Typically negative (characters move closer).
|
||||||
|
|
||||||
|
Usage in S-expression:
|
||||||
|
(+ (glyph-advance g1) (kerning g1 g2))
|
||||||
|
"""
|
||||||
|
kerning_cache = _get_kerning_cache(font_name, font_size)
|
||||||
|
return kerning_cache.get((char1, char2), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TextStrip:
    """Pre-rendered text strip with proper sub-pixel anti-aliasing.

    Rendered at compile time using PIL for exact matching.
    At runtime, just composite onto frame at integer positions.

    Attributes:
        text: The original text
        image: RGBA image as numpy array (H, W, 4)
        width: Strip width
        height: Strip height
        baseline_y: Y position of baseline within the strip
        bearing_x: Left side bearing of first character
        anchor_x: X offset for anchor point (0 for left, width/2 for center, width for right)
        anchor_y: Y offset for anchor point (depends on anchor type)
        stroke_width: Stroke width used when rendering
    """
    text: str
    # (H, W, 4) uint8; glyphs are rendered white and tinted at composite time
    image: np.ndarray
    width: int
    height: int
    baseline_y: int
    bearing_x: float
    # Anchor offsets default to 0 (i.e. top-left of the strip)
    anchor_x: float = 0.0
    anchor_y: float = 0.0
    stroke_width: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Text strip cache: cache_key -> TextStrip
# Keyed by the full set of render parameters (see render_text_strip); lives
# for the process lifetime so repeated renders of identical text are free.
_TEXT_STRIP_CACHE: Dict[Tuple, TextStrip] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def render_text_strip(
    text: str,
    font_name: str = None,
    font_size: int = 32,
    stroke_width: int = 0,
    stroke_fill: tuple = None,
    anchor: str = "la",  # left-ascender (PIL default is "la")
    multiline: bool = False,
    line_spacing: int = 4,
    align: str = "left",
) -> TextStrip:
    """Render text to a strip at compile time. Perfect sub-pixel anti-aliasing.

    The text is drawn white; color is applied at composite time by tinting.
    Results are memoized in _TEXT_STRIP_CACHE keyed by all parameters.

    Args:
        text: Text to render
        font_name: Path to font file (None for default)
        font_size: Font size in pixels
        stroke_width: Outline width in pixels (0 for no outline)
        stroke_fill: Outline color as (R,G,B) or (R,G,B,A), default black
        anchor: PIL anchor code - first char: l=left, m=middle, r=right
                second char: a=ascender, t=top, m=middle, s=baseline, d=descender
        multiline: If True, handle newlines in text
        line_spacing: Extra pixels between lines (for multiline)
        align: 'left', 'center', 'right' (for multiline)

    Returns:
        TextStrip with pre-rendered text
    """
    # Build cache key from all parameters
    cache_key = (text, font_name, font_size, stroke_width, stroke_fill, anchor, multiline, line_spacing, align)
    if cache_key in _TEXT_STRIP_CACHE:
        return _TEXT_STRIP_CACHE[cache_key]

    from PIL import Image, ImageDraw

    font = _load_font(font_name, font_size)
    ascent, descent = font.getmetrics()

    # Default stroke fill to black
    if stroke_fill is None:
        stroke_fill = (0, 0, 0, 255)
    elif len(stroke_fill) == 3:
        stroke_fill = (*stroke_fill, 255)

    # Get text bbox (accounting for stroke) using a throwaway 1x1 canvas
    temp = Image.new('RGBA', (1, 1))
    temp_draw = ImageDraw.Draw(temp)

    if multiline:
        bbox = temp_draw.multiline_textbbox((0, 0), text, font=font, spacing=line_spacing,
                                            stroke_width=stroke_width)
    else:
        bbox = temp_draw.textbbox((0, 0), text, font=font, stroke_width=stroke_width)

    # bbox is (left, top, right, bottom) relative to origin
    x_min, y_min, x_max, y_max = bbox

    # Create image with padding (extra for stroke)
    padding = 2 + stroke_width
    img_width = max(int(x_max - x_min) + padding * 2, 1)
    img_height = max(int(y_max - y_min) + padding * 2, 1)

    # Create RGBA image (fully transparent background)
    img = Image.new('RGBA', (img_width, img_height), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Draw text at position that puts it in the image
    draw_x = padding - x_min
    draw_y = padding - y_min

    if multiline:
        draw.multiline_text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
                            spacing=line_spacing, align=align,
                            stroke_width=stroke_width, stroke_fill=stroke_fill)
    else:
        draw.text((draw_x, draw_y), text, fill=(255, 255, 255, 255), font=font,
                  stroke_width=stroke_width, stroke_fill=stroke_fill)

    # Baseline is at y=0 in text coordinates, which is at draw_y in image
    baseline_y = draw_y

    # Convert to numpy for pixel analysis
    img_array = np.array(img, dtype=np.uint8)

    # Calculate anchor offsets
    # For 'm' (middle) anchors, compute from actual rendered pixels for pixel-perfect matching
    h_anchor = anchor[0] if len(anchor) > 0 else 'l'
    v_anchor = anchor[1] if len(anchor) > 1 else 'a'

    # Find actual pixel bounds (for middle anchor calculations)
    alpha = img_array[:, :, 3]
    nonzero_cols = np.where(alpha.max(axis=0) > 0)[0]
    nonzero_rows = np.where(alpha.max(axis=1) > 0)[0]

    if len(nonzero_cols) > 0:
        pixel_x_min = nonzero_cols.min()
        pixel_x_max = nonzero_cols.max()
        pixel_x_center = (pixel_x_min + pixel_x_max) / 2.0
    else:
        # Fully transparent render (e.g. whitespace-only text)
        pixel_x_center = img_width / 2.0

    # NOTE(review): pixel_y_center is computed but never used — vertical
    # anchors below use font metrics (ascent/descent) instead.
    if len(nonzero_rows) > 0:
        pixel_y_min = nonzero_rows.min()
        pixel_y_max = nonzero_rows.max()
        pixel_y_center = (pixel_y_min + pixel_y_max) / 2.0
    else:
        pixel_y_center = img_height / 2.0

    # Horizontal offset
    text_width = x_max - x_min
    if h_anchor == 'l':  # left edge of text
        anchor_x = float(draw_x)
    elif h_anchor == 'm':  # middle - use actual pixel center for perfect matching
        anchor_x = pixel_x_center
    elif h_anchor == 'r':  # right edge of text
        anchor_x = float(draw_x + text_width)
    else:
        anchor_x = float(draw_x)

    # Vertical offset
    # PIL anchor positions are based on font metrics (ascent/descent):
    # - 'a' (ascender): at the ascender line → draw_y in strip
    # - 't' (top): at top of text bounding box → padding in strip
    # - 'm' (middle): center of em-square = (ascent + descent) / 2 below ascender
    # - 's' (baseline): at baseline = ascent below ascender
    # - 'd' (descender): at descender line = ascent + descent below ascender

    if v_anchor == 'a':  # ascender
        anchor_y = float(draw_y)
    elif v_anchor == 't':  # top of bbox
        anchor_y = float(padding)
    elif v_anchor == 'm':  # middle (center of em-square, per PIL's calculation)
        anchor_y = float(draw_y + (ascent + descent) / 2.0)
    elif v_anchor == 's':  # baseline
        anchor_y = float(draw_y + ascent)
    elif v_anchor == 'd':  # descender
        anchor_y = float(draw_y + ascent + descent)
    else:
        anchor_y = float(draw_y)  # default to ascender

    strip = TextStrip(
        text=text,
        image=img_array,
        width=img_width,
        height=img_height,
        baseline_y=baseline_y,
        bearing_x=float(x_min),
        anchor_x=anchor_x,
        anchor_y=anchor_y,
        stroke_width=stroke_width,
    )

    _TEXT_STRIP_CACHE[cache_key] = strip
    return strip
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Compile-time functions (called during S-expression compilation)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def get_glyph(char: str, font_name: str = None, font_size: int = 32) -> GlyphData:
    """Look up glyph data for one character, falling back to space. Compile-time."""
    glyphs = _get_glyph_cache(font_name, font_size)
    fallback = glyphs.get(' ')
    return glyphs.get(char, fallback)
|
||||||
|
|
||||||
|
|
||||||
|
def get_glyphs(text: str, font_name: str = None, font_size: int = 32) -> list:
    """Map every character of *text* to its glyph data. Compile-time.

    Characters missing from the glyph cache fall back to the space glyph.
    """
    glyphs = _get_glyph_cache(font_name, font_size)
    fallback = glyphs.get(' ')
    result = []
    for ch in text:
        result.append(glyphs.get(ch, fallback))
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_font_ascent(font_name: str = None, font_size: int = 32) -> float:
    """Get font ascent. Compile-time."""
    # Building the glyph cache populates _METRICS_CACHE as a side effect
    _get_glyph_cache(font_name, font_size)
    metrics = _METRICS_CACHE[(font_name, font_size)]
    return metrics[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_font_descent(font_name: str = None, font_size: int = 32) -> float:
    """Get font descent. Compile-time."""
    # Building the glyph cache populates _METRICS_CACHE as a side effect
    _get_glyph_cache(font_name, font_size)
    metrics = _METRICS_CACHE[(font_name, font_size)]
    return metrics[1]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# JAX Runtime Primitives
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def place_glyph_jax(
    frame: jnp.ndarray,
    glyph_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    bearing_x: float,
    bearing_y: float,
    color: jnp.ndarray,  # (3,) RGB 0-255
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Place a glyph onto a frame. This is the core JAX primitive.

    All positioning math can use traced values (x, y from audio, time, etc.)
    The glyph_image is static (determined at compile time).

    Args:
        frame: (H, W, 3) RGB frame
        glyph_image: (gh, gw, 4) RGBA glyph (pre-converted to JAX array)
        x: X position of glyph origin (baseline point)
        y: Y position of baseline
        bearing_x: Left side bearing
        bearing_y: Top bearing (from baseline to top)
        color: RGB color array
        opacity: Opacity 0-1

    Returns:
        Frame with glyph composited
    """
    h, w = frame.shape[:2]
    gh, gw = glyph_image.shape[:2]

    # Calculate destination position
    # bearing_x: how far right of origin the glyph starts (can be negative)
    # bearing_y: how far up from baseline the glyph extends
    padding = 2  # Must match padding used in glyph creation
    # BUGFIX: wrap in jnp.asarray so plain Python floats work — a bare
    # float has no .astype(), which previously crashed non-traced calls.
    # The int32 cast truncates, preserving the original placement semantics.
    dst_x = jnp.asarray(x + bearing_x - padding)
    dst_y = jnp.asarray(y - bearing_y - padding)

    # Extract glyph RGB and alpha
    glyph_rgb = glyph_image[:, :, :3].astype(jnp.float32) / 255.0
    # Match PIL's integer alpha math: (coverage * int(opacity*255) + 127) // 255
    opacity_int = jnp.round(opacity * 255)
    glyph_a_raw = glyph_image[:, :, 3:4].astype(jnp.float32)
    glyph_alpha = jnp.floor(glyph_a_raw * opacity_int / 255.0 + 0.5) / 255.0

    # Apply color tint (glyph is white, multiply by color)
    color_normalized = color.astype(jnp.float32) / 255.0
    tinted = glyph_rgb * color_normalized

    from jax.lax import dynamic_update_slice

    # Use padded buffer to avoid XLA's dynamic_update_slice clamping, which
    # would silently shift the glyph instead of clipping it at frame edges.
    buf_h = h + 2 * gh
    buf_w = w + 2 * gw
    rgb_buf = jnp.zeros((buf_h, buf_w, 3), dtype=jnp.float32)
    alpha_buf = jnp.zeros((buf_h, buf_w, 1), dtype=jnp.float32)

    dst_x_int = dst_x.astype(jnp.int32)
    dst_y_int = dst_y.astype(jnp.int32)
    # Offset by (gh, gw) so dst=0 maps to (gh, gw) in the padded buffer
    place_y = jnp.maximum(dst_y_int + gh, 0).astype(jnp.int32)
    place_x = jnp.maximum(dst_x_int + gw, 0).astype(jnp.int32)

    rgb_buf = dynamic_update_slice(rgb_buf, tinted, (place_y, place_x, 0))
    alpha_buf = dynamic_update_slice(alpha_buf, glyph_alpha, (place_y, place_x, 0))

    # Extract frame-sized region (gh, gw are compile-time constants)
    rgb_layer = rgb_buf[gh:gh + h, gw:gw + w, :]
    alpha_layer = alpha_buf[gh:gh + h, gw:gw + w, :]

    # Alpha composite using PIL-compatible integer arithmetic:
    # result = (src * alpha + dst * (255 - alpha) + 127) // 255
    src_int = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    alpha_int = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    dst_int = frame.astype(jnp.int32)
    result = (src_int * alpha_int + dst_int * (255 - alpha_int) + 127) // 255

    return jnp.clip(result, 0, 255).astype(jnp.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def place_text_strip_jax(
    frame: jnp.ndarray,
    strip_image: jnp.ndarray,  # (H, W, 4) RGBA
    x: float,
    y: float,
    baseline_y: int,
    bearing_x: float,
    color: jnp.ndarray,
    opacity: float = 1.0,
    anchor_x: float = 0.0,
    anchor_y: float = 0.0,
    stroke_width: int = 0,
) -> jnp.ndarray:
    """
    Composite a pre-rendered text strip onto a frame.

    The strip was rendered at compile time with proper sub-pixel
    anti-aliasing; this routine only tints and alpha-composites it at the
    requested position using PIL-compatible integer math.

    Args:
        frame: (H, W, 3) RGB frame
        strip_image: (sh, sw, 4) RGBA text strip
        x: X position for anchor point
        y: Y position for anchor point
        baseline_y: Y position of baseline within the strip (accepted for
            interface compatibility; not read here)
        bearing_x: Left side bearing (accepted for interface compatibility;
            not read here)
        color: RGB color
        opacity: Opacity 0-1
        anchor_x: X offset of anchor point within strip
        anchor_y: Y offset of anchor point within strip
        stroke_width: Stroke width used when rendering (accepted for
            interface compatibility; not read here)

    Returns:
        Frame with text composited
    """
    from jax.lax import dynamic_update_slice

    frame_h, frame_w = frame.shape[:2]
    strip_h, strip_w = strip_image.shape[:2]

    # The strip's anchor point (anchor_x, anchor_y) lands at (x, y).
    # floor(v + 0.5) gives half-up rounding; jnp.round would use
    # banker's rounding and disagree with the compile-time renderer.
    left = jnp.floor(x - anchor_x + 0.5).astype(jnp.int32)
    top = jnp.floor(y - anchor_y + 0.5).astype(jnp.int32)

    # Strip pixels are white, so multiplying by the normalized color
    # produces the final tinted RGB.
    tinted = (strip_image[:, :, :3].astype(jnp.float32) / 255.0) \
        * (color.astype(jnp.float32) / 255.0)

    # PIL-compatible opacity: (coverage * int(opacity*255) + 127) // 255.
    # jnp.round (banker's rounding) matches Python's round() used by PIL.
    coverage = strip_image[:, :, 3:4].astype(jnp.float32)
    eff_alpha = jnp.floor(coverage * jnp.round(opacity * 255) / 255.0 + 0.5) / 255.0

    # Composite through a buffer padded by the strip size on every side.
    # XLA clamps dynamic_update_slice indices so the update always fits,
    # which would silently shift the strip; padding then cropping instead
    # yields true clipping for both negative and overflowing positions.
    pad_h = frame_h + 2 * strip_h
    pad_w = frame_w + 2 * strip_w
    color_plane = jnp.zeros((pad_h, pad_w, 3), dtype=jnp.float32)
    alpha_plane = jnp.zeros((pad_h, pad_w, 1), dtype=jnp.float32)

    # Offset so destination 0 maps to (strip_h, strip_w) in the buffer
    row = jnp.maximum(top + strip_h, 0).astype(jnp.int32)
    col = jnp.maximum(left + strip_w, 0).astype(jnp.int32)

    color_plane = dynamic_update_slice(color_plane, tinted, (row, col, 0))
    alpha_plane = dynamic_update_slice(alpha_plane, eff_alpha, (row, col, 0))

    # Crop back to frame size (strip dims are compile-time constants)
    rgb_layer = color_plane[strip_h:strip_h + frame_h, strip_w:strip_w + frame_w, :]
    alpha_layer = alpha_plane[strip_h:strip_h + frame_h, strip_w:strip_w + frame_w, :]

    # PIL-compatible integer alpha composite:
    # out = (src * alpha + dst * (255 - alpha) + 127) // 255
    src = jnp.floor(rgb_layer * 255 + 0.5).astype(jnp.int32)
    cov = jnp.floor(alpha_layer * 255 + 0.5).astype(jnp.int32)
    base = frame.astype(jnp.int32)
    blended = (src * cov + base * (255 - cov) + 127) // 255

    return jnp.clip(blended, 0, 255).astype(jnp.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def place_glyph_simple(
    frame: jnp.ndarray,
    glyph: GlyphData,
    x: float,
    y: float,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
) -> jnp.ndarray:
    """
    Convenience wrapper around place_glyph_jax that takes GlyphData directly.

    Converts the glyph's image and the color tuple to JAX arrays before
    delegating. For S-expression use, prefer place_glyph_jax with
    pre-converted arrays.
    """
    return place_glyph_jax(
        frame,
        jnp.asarray(glyph.image),
        x,
        y,
        glyph.bearing_x,
        glyph.bearing_y,
        jnp.array(color, dtype=jnp.float32),
        opacity,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# S-Expression Primitive Bindings
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def bind_typography_primitives(env: dict) -> dict:
    """
    Add typography primitives to an S-expression environment.

    Mutates *env* in place (via dict.update) and also returns it.

    Primitives added:
        (text-glyphs text font-size) -> list of glyph data
        (glyph-image g) -> JAX array (H, W, 4)
        (glyph-advance g) -> float
        (glyph-bearing-x g) -> float
        (glyph-bearing-y g) -> float
        (glyph-width g) -> int
        (glyph-height g) -> int
        (font-ascent font-size) -> float
        (font-descent font-size) -> float
        (place-glyph frame glyph-img x y bearing-x bearing-y color opacity) -> frame
    """

    # Arguments arriving from the S-expression layer are coerced with
    # str()/int() before use — TODO confirm the evaluator's atom types.
    def prim_text_glyphs(text, font_size=32, font_name=None):
        """Get list of glyph data for text. Compile-time."""
        return get_glyphs(str(text), font_name, int(font_size))

    def prim_glyph_image(glyph):
        """Get glyph image as JAX array."""
        return jnp.asarray(glyph.image)

    def prim_glyph_advance(glyph):
        """Get glyph advance width."""
        return glyph.advance

    def prim_glyph_bearing_x(glyph):
        """Get glyph left side bearing."""
        return glyph.bearing_x

    def prim_glyph_bearing_y(glyph):
        """Get glyph top bearing."""
        return glyph.bearing_y

    def prim_glyph_width(glyph):
        """Get glyph image width."""
        return glyph.width

    def prim_glyph_height(glyph):
        """Get glyph image height."""
        return glyph.height

    def prim_font_ascent(font_size=32, font_name=None):
        """Get font ascent."""
        return get_font_ascent(font_name, int(font_size))

    def prim_font_descent(font_size=32, font_name=None):
        """Get font descent."""
        return get_font_descent(font_name, int(font_size))

    def prim_place_glyph(frame, glyph_img, x, y, bearing_x, bearing_y,
                         color=(255, 255, 255), opacity=1.0):
        """Place glyph on frame. Runtime JAX operation."""
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_glyph_jax(frame, glyph_img, x, y, bearing_x, bearing_y,
                               color_arr, opacity)

    def prim_glyph_kerning(glyph1, glyph2, font_size=32, font_name=None):
        """Get kerning adjustment between two glyphs. Compile-time.

        Returns adjustment to add to glyph1's advance when glyph2 follows.
        Typically negative (characters move closer).

        Usage: (+ (glyph-advance g) (glyph-kerning g next-g font-size))
        """
        return get_kerning(glyph1.char, glyph2.char, font_name, int(font_size))

    def prim_char_kerning(char1, char2, font_size=32, font_name=None):
        """Get kerning adjustment between two characters. Compile-time."""
        return get_kerning(str(char1), str(char2), font_name, int(font_size))

    # TextStrip primitives for pre-rendered text with proper anti-aliasing
    def prim_render_text_strip(text, font_size=32, font_name=None):
        """Render text to a strip at compile time. Perfect anti-aliasing."""
        return render_text_strip(str(text), font_name, int(font_size))

    def prim_render_text_strip_styled(
        text, font_size=32, font_name=None,
        stroke_width=0, stroke_fill=None,
        anchor="la", multiline=False, line_spacing=4, align="left"
    ):
        """Render styled text to a strip. Supports stroke, anchors, multiline.

        Args:
            text: Text to render
            font_size: Size in pixels
            font_name: Path to font file
            stroke_width: Outline width (0 = no outline)
            stroke_fill: Outline color as (R,G,B) or (R,G,B,A)
            anchor: 2-char anchor code (e.g., "mm" for center, "la" for left-ascender)
            multiline: If True, handle newlines
            line_spacing: Extra pixels between lines
            align: "left", "center", "right" for multiline
        """
        return render_text_strip(
            str(text), font_name, int(font_size),
            stroke_width=int(stroke_width),
            stroke_fill=stroke_fill,
            anchor=str(anchor),
            multiline=bool(multiline),
            line_spacing=int(line_spacing),
            align=str(align),
        )

    def prim_text_strip_image(strip):
        """Get text strip image as JAX array."""
        return jnp.asarray(strip.image)

    def prim_text_strip_width(strip):
        """Get text strip width."""
        return strip.width

    def prim_text_strip_height(strip):
        """Get text strip height."""
        return strip.height

    def prim_text_strip_baseline_y(strip):
        """Get text strip baseline Y position."""
        return strip.baseline_y

    def prim_text_strip_bearing_x(strip):
        """Get text strip left bearing."""
        return strip.bearing_x

    def prim_text_strip_anchor_x(strip):
        """Get text strip anchor X offset."""
        return strip.anchor_x

    def prim_text_strip_anchor_y(strip):
        """Get text strip anchor Y offset."""
        return strip.anchor_y

    def prim_place_text_strip(frame, strip, x, y, color=(255, 255, 255), opacity=1.0):
        """Place pre-rendered text strip on frame. Runtime JAX operation."""
        strip_img = jnp.asarray(strip.image)
        color_arr = jnp.array(color, dtype=jnp.float32)
        return place_text_strip_jax(
            frame, strip_img, x, y,
            strip.baseline_y, strip.bearing_x,
            color_arr, opacity,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width
        )

    # Add to environment
    env.update({
        # Glyph-by-glyph primitives (for wave, arc, audio-reactive effects)
        'text-glyphs': prim_text_glyphs,
        'glyph-image': prim_glyph_image,
        'glyph-advance': prim_glyph_advance,
        'glyph-bearing-x': prim_glyph_bearing_x,
        'glyph-bearing-y': prim_glyph_bearing_y,
        'glyph-width': prim_glyph_width,
        'glyph-height': prim_glyph_height,
        'glyph-kerning': prim_glyph_kerning,
        'char-kerning': prim_char_kerning,
        'font-ascent': prim_font_ascent,
        'font-descent': prim_font_descent,
        'place-glyph': prim_place_glyph,
        # TextStrip primitives (for pixel-perfect static text)
        'render-text-strip': prim_render_text_strip,
        'render-text-strip-styled': prim_render_text_strip_styled,
        'text-strip-image': prim_text_strip_image,
        'text-strip-width': prim_text_strip_width,
        'text-strip-height': prim_text_strip_height,
        'text-strip-baseline-y': prim_text_strip_baseline_y,
        'text-strip-bearing-x': prim_text_strip_bearing_x,
        'text-strip-anchor-x': prim_text_strip_anchor_x,
        'text-strip-anchor-y': prim_text_strip_anchor_y,
        'place-text-strip': prim_place_text_strip,
    })

    return env
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Example: Render text using primitives (for testing)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def render_text_primitives(
    frame: jnp.ndarray,
    text: str,
    x: float,
    y: float,
    font_size: int = 32,
    color: tuple = (255, 255, 255),
    opacity: float = 1.0,
    use_kerning: bool = True,
) -> jnp.ndarray:
    """
    Render text glyph-by-glyph using the JAX primitives (for testing).

    This mirrors what a compiled S-expression would do: place each glyph
    at a running cursor, then advance the cursor by the glyph's advance
    width plus (optionally) the kerning toward the next character.

    Args:
        use_kerning: If True, apply kerning adjustments between characters
    """
    glyph_list = get_glyphs(text, None, font_size)
    tint = jnp.array(color, dtype=jnp.float32)

    pen_x = x
    last = len(glyph_list) - 1
    for idx, glyph in enumerate(glyph_list):
        frame = place_glyph_jax(
            frame,
            jnp.asarray(glyph.image),
            pen_x,
            y,
            glyph.bearing_x,
            glyph.bearing_y,
            tint,
            opacity,
        )
        # Advance cursor, folding in kerning toward the next glyph
        step = glyph.advance
        if use_kerning and idx < last:
            step += get_kerning(glyph.char, glyph_list[idx + 1].char, None, font_size)
        pen_x = pen_x + step

    return frame
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,7 @@ Context (ctx) is passed explicitly to frame evaluation:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -62,6 +63,38 @@ class Context:
|
|||||||
fps: float = 30.0
|
fps: float = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
class DeferredEffectChain:
|
||||||
|
"""
|
||||||
|
Represents a chain of JAX effects that haven't been executed yet.
|
||||||
|
|
||||||
|
Allows effects to be accumulated through let bindings and fused
|
||||||
|
into a single JIT-compiled function when the result is needed.
|
||||||
|
"""
|
||||||
|
__slots__ = ('effects', 'params_list', 'base_frame', 't', 'frame_num')
|
||||||
|
|
||||||
|
def __init__(self, effects: list, params_list: list, base_frame, t: float, frame_num: int):
|
||||||
|
self.effects = effects # List of effect names, innermost first
|
||||||
|
self.params_list = params_list # List of param dicts, matching effects
|
||||||
|
self.base_frame = base_frame # The actual frame array at the start
|
||||||
|
self.t = t
|
||||||
|
self.frame_num = frame_num
|
||||||
|
|
||||||
|
def extend(self, effect_name: str, params: dict) -> 'DeferredEffectChain':
|
||||||
|
"""Add another effect to the chain (outermost)."""
|
||||||
|
return DeferredEffectChain(
|
||||||
|
self.effects + [effect_name],
|
||||||
|
self.params_list + [params],
|
||||||
|
self.base_frame,
|
||||||
|
self.t,
|
||||||
|
self.frame_num
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
"""Allow shape check without forcing execution."""
|
||||||
|
return self.base_frame.shape if hasattr(self.base_frame, 'shape') else None
|
||||||
|
|
||||||
|
|
||||||
class StreamInterpreter:
|
class StreamInterpreter:
|
||||||
"""
|
"""
|
||||||
Fully generic streaming sexp interpreter.
|
Fully generic streaming sexp interpreter.
|
||||||
@@ -98,6 +131,9 @@ class StreamInterpreter:
|
|||||||
self.use_jax = use_jax
|
self.use_jax = use_jax
|
||||||
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
self.jax_effects: Dict[str, Callable] = {} # Cache of JAX-compiled effects
|
||||||
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
self.jax_effect_paths: Dict[str, Path] = {} # Track source paths for effects
|
||||||
|
self.jax_fused_chains: Dict[str, Callable] = {} # Cache of fused effect chains
|
||||||
|
self.jax_batched_chains: Dict[str, Callable] = {} # Cache of vmapped chains
|
||||||
|
self.jax_batch_size: int = int(os.environ.get("JAX_BATCH_SIZE", "30")) # Configurable via env
|
||||||
if use_jax:
|
if use_jax:
|
||||||
if _init_jax():
|
if _init_jax():
|
||||||
print("JAX acceleration enabled", file=sys.stderr)
|
print("JAX acceleration enabled", file=sys.stderr)
|
||||||
@@ -238,6 +274,8 @@ class StreamInterpreter:
|
|||||||
"""Load primitives from a Python library file.
|
"""Load primitives from a Python library file.
|
||||||
|
|
||||||
Prefers GPU-accelerated versions (*_gpu.py) when available.
|
Prefers GPU-accelerated versions (*_gpu.py) when available.
|
||||||
|
Uses cached modules from sys.modules to ensure consistent state
|
||||||
|
(e.g., same RNG instance for all interpreters).
|
||||||
"""
|
"""
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
@@ -264,9 +302,26 @@ class StreamInterpreter:
|
|||||||
if not lib_path:
|
if not lib_path:
|
||||||
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
|
raise FileNotFoundError(f"Primitive library '{lib_name}' not found. Searched paths: {lib_paths}")
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
# Use cached module if already imported to preserve state (e.g., RNG)
|
||||||
module = importlib.util.module_from_spec(spec)
|
# This is critical for deterministic random number sequences
|
||||||
spec.loader.exec_module(module)
|
# Check multiple possible module keys (standard import paths and our cache)
|
||||||
|
possible_keys = [
|
||||||
|
f"sexp_effects.primitive_libs.{actual_lib_name}",
|
||||||
|
f"sexp_primitives.{actual_lib_name}",
|
||||||
|
]
|
||||||
|
|
||||||
|
module = None
|
||||||
|
for key in possible_keys:
|
||||||
|
if key in sys.modules:
|
||||||
|
module = sys.modules[key]
|
||||||
|
break
|
||||||
|
|
||||||
|
if module is None:
|
||||||
|
spec = importlib.util.spec_from_file_location(actual_lib_name, lib_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
# Cache for future use under our key
|
||||||
|
sys.modules[f"sexp_primitives.{actual_lib_name}"] = module
|
||||||
|
|
||||||
# Check if this is a GPU-accelerated module
|
# Check if this is a GPU-accelerated module
|
||||||
is_gpu = actual_lib_name.endswith('_gpu')
|
is_gpu = actual_lib_name.endswith('_gpu')
|
||||||
@@ -452,30 +507,353 @@ class StreamInterpreter:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
jax_fn = self.jax_effects[name]
|
jax_fn = self.jax_effects[name]
|
||||||
# Ensure frame is numpy array
|
# Handle GPU frames (CuPy) - need to move to CPU for CPU JAX
|
||||||
|
# JAX handles numpy and JAX arrays natively, no conversion needed
|
||||||
if hasattr(frame, 'cpu'):
|
if hasattr(frame, 'cpu'):
|
||||||
frame = frame.cpu
|
frame = frame.cpu
|
||||||
elif hasattr(frame, 'get'):
|
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||||
frame = frame.get()
|
frame = frame.get() # CuPy array -> numpy
|
||||||
|
|
||||||
# Get seed from config for deterministic random
|
# Get seed from config for deterministic random
|
||||||
seed = self.config.get('seed', 42)
|
seed = self.config.get('seed', 42)
|
||||||
|
|
||||||
# Call JAX function with parameters
|
# Call JAX function with parameters
|
||||||
result = jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
# Return JAX array directly - don't block or convert per-effect
|
||||||
|
# Conversion to numpy happens once at frame write time
|
||||||
# Convert result back to numpy if needed
|
return jax_fn(frame, t=t, frame_num=frame_num, seed=seed, **params)
|
||||||
if hasattr(result, 'block_until_ready'):
|
|
||||||
result.block_until_ready() # Ensure computation is complete
|
|
||||||
if hasattr(result, '__array__'):
|
|
||||||
result = np.asarray(result)
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Fall back to interpreter on error
|
# Fall back to interpreter on error
|
||||||
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
print(f"JAX effect {name} error, falling back: {e}", file=sys.stderr)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _is_jax_effect_expr(self, expr) -> bool:
|
||||||
|
"""Check if an expression is a JAX-compiled effect call."""
|
||||||
|
if not isinstance(expr, list) or not expr:
|
||||||
|
return False
|
||||||
|
head = expr[0]
|
||||||
|
if not isinstance(head, Symbol):
|
||||||
|
return False
|
||||||
|
return head.name in self.jax_effects
|
||||||
|
|
||||||
|
def _extract_effect_chain(self, expr, env) -> Optional[Tuple[list, list, Any]]:
|
||||||
|
"""
|
||||||
|
Extract a chain of JAX effects from an expression.
|
||||||
|
|
||||||
|
Returns: (effect_names, params_list, base_frame_expr) or None if not a chain.
|
||||||
|
effect_names and params_list are in execution order (innermost first).
|
||||||
|
|
||||||
|
For (effect1 (effect2 frame :p1 v1) :p2 v2):
|
||||||
|
Returns: (['effect2', 'effect1'], [params2, params1], frame_expr)
|
||||||
|
"""
|
||||||
|
if not self._is_jax_effect_expr(expr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
chain = []
|
||||||
|
params_list = []
|
||||||
|
current = expr
|
||||||
|
|
||||||
|
while self._is_jax_effect_expr(current):
|
||||||
|
head = current[0]
|
||||||
|
effect_name = head.name
|
||||||
|
args = current[1:]
|
||||||
|
|
||||||
|
# Extract params for this effect
|
||||||
|
effect = self.effects[effect_name]
|
||||||
|
effect_params = {}
|
||||||
|
for pname, pdef in effect['params'].items():
|
||||||
|
effect_params[pname] = pdef.get('default', 0)
|
||||||
|
|
||||||
|
# Find the frame argument (first positional) and other params
|
||||||
|
frame_arg = None
|
||||||
|
i = 0
|
||||||
|
while i < len(args):
|
||||||
|
if isinstance(args[i], Keyword):
|
||||||
|
pname = args[i].name
|
||||||
|
if pname in effect['params'] and i + 1 < len(args):
|
||||||
|
effect_params[pname] = self._eval(args[i + 1], env)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
if frame_arg is None:
|
||||||
|
frame_arg = args[i] # First positional is frame
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
chain.append(effect_name)
|
||||||
|
params_list.append(effect_params)
|
||||||
|
|
||||||
|
if frame_arg is None:
|
||||||
|
return None # No frame argument found
|
||||||
|
|
||||||
|
# Check if frame_arg is another effect call
|
||||||
|
if self._is_jax_effect_expr(frame_arg):
|
||||||
|
current = frame_arg
|
||||||
|
else:
|
||||||
|
# End of chain - frame_arg is the base frame
|
||||||
|
# Reverse to get innermost-first execution order
|
||||||
|
chain.reverse()
|
||||||
|
params_list.reverse()
|
||||||
|
return (chain, params_list, frame_arg)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_chain_key(self, effect_names: list, params_list: list) -> str:
|
||||||
|
"""Generate a cache key for an effect chain.
|
||||||
|
|
||||||
|
Includes static param values in the key since they affect compilation.
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
for name, params in zip(effect_names, params_list):
|
||||||
|
param_parts = []
|
||||||
|
for pname in sorted(params.keys()):
|
||||||
|
pval = params[pname]
|
||||||
|
# Include static values in key (strings, bools)
|
||||||
|
if isinstance(pval, (str, bool)):
|
||||||
|
param_parts.append(f"{pname}={pval}")
|
||||||
|
else:
|
||||||
|
param_parts.append(pname)
|
||||||
|
parts.append(f"{name}:{','.join(param_parts)}")
|
||||||
|
return '|'.join(parts)
|
||||||
|
|
||||||
|
def _compile_effect_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||||
|
"""
|
||||||
|
Compile a chain of effects into a single fused JAX function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||||
|
params_list: List of param dicts for each effect (used to detect static types)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JIT-compiled function: (frame, t, frame_num, seed, **all_params) -> frame
|
||||||
|
"""
|
||||||
|
if not _JAX_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
|
||||||
|
# Get the individual JAX functions
|
||||||
|
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||||
|
|
||||||
|
# Pre-extract param names and identify static params from actual values
|
||||||
|
effect_param_names = []
|
||||||
|
static_params = ['seed'] # seed is always static
|
||||||
|
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||||
|
param_names = list(params.keys())
|
||||||
|
effect_param_names.append(param_names)
|
||||||
|
# Check actual values to identify static types
|
||||||
|
for pname, pval in params.items():
|
||||||
|
if isinstance(pval, (str, bool)):
|
||||||
|
static_params.append(f"_p{i}_{pname}")
|
||||||
|
|
||||||
|
def fused_fn(frame, t, frame_num, seed, **kwargs):
|
||||||
|
result = frame
|
||||||
|
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||||
|
# Extract params for this effect from kwargs
|
||||||
|
effect_kwargs = {}
|
||||||
|
for pname in param_names:
|
||||||
|
key = f"_p{i}_{pname}"
|
||||||
|
if key in kwargs:
|
||||||
|
effect_kwargs[pname] = kwargs[key]
|
||||||
|
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# JIT with static params (seed + any string/bool params)
|
||||||
|
return jax.jit(fused_fn, static_argnames=static_params)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to compile effect chain {effect_names}: {e}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _apply_effect_chain(self, effect_names: list, params_list: list, frame, t: float, frame_num: int):
|
||||||
|
"""Apply a chain of effects, using fused compilation if available."""
|
||||||
|
chain_key = self._get_chain_key(effect_names, params_list)
|
||||||
|
|
||||||
|
# Try to get or compile fused chain
|
||||||
|
if chain_key not in self.jax_fused_chains:
|
||||||
|
fused_fn = self._compile_effect_chain(effect_names, params_list)
|
||||||
|
self.jax_fused_chains[chain_key] = fused_fn
|
||||||
|
if fused_fn:
|
||||||
|
print(f" [JAX fused chain: {' -> '.join(effect_names)}]", file=sys.stderr)
|
||||||
|
|
||||||
|
fused_fn = self.jax_fused_chains.get(chain_key)
|
||||||
|
|
||||||
|
if fused_fn is not None:
|
||||||
|
# Build kwargs with prefixed param names
|
||||||
|
kwargs = {}
|
||||||
|
for i, params in enumerate(params_list):
|
||||||
|
for pname, pval in params.items():
|
||||||
|
kwargs[f"_p{i}_{pname}"] = pval
|
||||||
|
|
||||||
|
seed = self.config.get('seed', 42)
|
||||||
|
|
||||||
|
# Handle GPU frames
|
||||||
|
if hasattr(frame, 'cpu'):
|
||||||
|
frame = frame.cpu
|
||||||
|
elif hasattr(frame, 'get') and hasattr(frame, '__cuda_array_interface__'):
|
||||||
|
frame = frame.get()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return fused_fn(frame, t=t, frame_num=frame_num, seed=seed, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Fused chain error: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Fall back to sequential application
|
||||||
|
result = frame
|
||||||
|
for name, params in zip(effect_names, params_list):
|
||||||
|
result = self._apply_jax_effect(name, result, params, t, frame_num)
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _force_deferred(self, deferred: DeferredEffectChain):
|
||||||
|
"""Execute a deferred effect chain and return the actual array."""
|
||||||
|
if len(deferred.effects) == 0:
|
||||||
|
return deferred.base_frame
|
||||||
|
|
||||||
|
return self._apply_effect_chain(
|
||||||
|
deferred.effects,
|
||||||
|
deferred.params_list,
|
||||||
|
deferred.base_frame,
|
||||||
|
deferred.t,
|
||||||
|
deferred.frame_num
|
||||||
|
)
|
||||||
|
|
||||||
|
def _maybe_force(self, value):
|
||||||
|
"""Force a deferred chain if needed, otherwise return as-is."""
|
||||||
|
if isinstance(value, DeferredEffectChain):
|
||||||
|
return self._force_deferred(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _compile_batched_chain(self, effect_names: list, params_list: list) -> Optional[Callable]:
|
||||||
|
"""
|
||||||
|
Compile a vmapped version of an effect chain for batch processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
effect_names: List of effect names in order [innermost, ..., outermost]
|
||||||
|
params_list: List of param dicts (used to detect static types)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Batched function: (frames, ts, frame_nums, seed, **batched_params) -> frames
|
||||||
|
Where frames is (N, H, W, 3), ts/frame_nums are (N,), params are (N,) or scalar
|
||||||
|
"""
|
||||||
|
if not _JAX_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
# Get the individual JAX functions
|
||||||
|
jax_fns = [self.jax_effects[name] for name in effect_names]
|
||||||
|
|
||||||
|
# Pre-extract param info
|
||||||
|
effect_param_names = []
|
||||||
|
static_params = set()
|
||||||
|
for i, (name, params) in enumerate(zip(effect_names, params_list)):
|
||||||
|
param_names = list(params.keys())
|
||||||
|
effect_param_names.append(param_names)
|
||||||
|
for pname, pval in params.items():
|
||||||
|
if isinstance(pval, (str, bool)):
|
||||||
|
static_params.add(f"_p{i}_{pname}")
|
||||||
|
|
||||||
|
# Single-frame function (will be vmapped)
|
||||||
|
def single_frame_fn(frame, t, frame_num, seed, **kwargs):
|
||||||
|
result = frame
|
||||||
|
for i, (jax_fn, param_names) in enumerate(zip(jax_fns, effect_param_names)):
|
||||||
|
effect_kwargs = {}
|
||||||
|
for pname in param_names:
|
||||||
|
key = f"_p{i}_{pname}"
|
||||||
|
if key in kwargs:
|
||||||
|
effect_kwargs[pname] = kwargs[key]
|
||||||
|
result = jax_fn(result, t=t, frame_num=frame_num, seed=seed, **effect_kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Return unbatched function - we'll vmap at call time with proper in_axes
|
||||||
|
return jax.jit(single_frame_fn, static_argnames=['seed'] + list(static_params))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to compile batched chain {effect_names}: {e}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _apply_batched_chain(self, effect_names: list, params_list_batch: list,
|
||||||
|
frames: list, ts: list, frame_nums: list) -> Optional[list]:
|
||||||
|
"""
|
||||||
|
Apply an effect chain to a batch of frames using vmap.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
effect_names: List of effect names
|
||||||
|
params_list_batch: List of params_list for each frame in batch
|
||||||
|
frames: List of input frames
|
||||||
|
ts: List of time values
|
||||||
|
frame_nums: List of frame numbers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of output frames, or None on failure
|
||||||
|
"""
|
||||||
|
if not frames:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use first frame's params for chain key (assume same structure)
|
||||||
|
chain_key = self._get_chain_key(effect_names, params_list_batch[0])
|
||||||
|
batch_key = f"batch:{chain_key}"
|
||||||
|
|
||||||
|
# Compile batched version if needed
|
||||||
|
if batch_key not in self.jax_batched_chains:
|
||||||
|
batched_fn = self._compile_batched_chain(effect_names, params_list_batch[0])
|
||||||
|
self.jax_batched_chains[batch_key] = batched_fn
|
||||||
|
if batched_fn:
|
||||||
|
print(f" [JAX batched chain: {' -> '.join(effect_names)} x{len(frames)}]", file=sys.stderr)
|
||||||
|
|
||||||
|
batched_fn = self.jax_batched_chains.get(batch_key)
|
||||||
|
|
||||||
|
if batched_fn is not None:
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
# Stack frames into batch array
|
||||||
|
frames_array = jnp.stack([f if not hasattr(f, 'get') else f.get() for f in frames])
|
||||||
|
ts_array = jnp.array(ts)
|
||||||
|
frame_nums_array = jnp.array(frame_nums)
|
||||||
|
|
||||||
|
# Build kwargs - all numeric params as arrays for vmap
|
||||||
|
kwargs = {}
|
||||||
|
static_kwargs = {} # Non-vmapped (strings, bools)
|
||||||
|
|
||||||
|
for i, plist in enumerate(zip(*[p for p in params_list_batch])):
|
||||||
|
for j, pname in enumerate(params_list_batch[0][i].keys()):
|
||||||
|
key = f"_p{i}_{pname}"
|
||||||
|
values = [p[pname] for p in [params_list_batch[b][i] for b in range(len(frames))]]
|
||||||
|
|
||||||
|
first = values[0]
|
||||||
|
if isinstance(first, (str, bool)):
|
||||||
|
# Static params - not vmapped
|
||||||
|
static_kwargs[key] = first
|
||||||
|
elif isinstance(first, (int, float)):
|
||||||
|
# Always batch numeric params for simplicity
|
||||||
|
kwargs[key] = jnp.array(values)
|
||||||
|
elif hasattr(first, 'shape'):
|
||||||
|
kwargs[key] = jnp.stack(values)
|
||||||
|
else:
|
||||||
|
kwargs[key] = jnp.array(values)
|
||||||
|
|
||||||
|
seed = self.config.get('seed', 42)
|
||||||
|
|
||||||
|
# Create wrapper that unpacks the params dict
|
||||||
|
def single_call(frame, t, frame_num, params_dict):
|
||||||
|
return batched_fn(frame, t, frame_num, seed, **params_dict, **static_kwargs)
|
||||||
|
|
||||||
|
# vmap over frame, t, frame_num, and the params dict (as pytree)
|
||||||
|
vmapped_fn = jax.vmap(single_call, in_axes=(0, 0, 0, 0))
|
||||||
|
|
||||||
|
# Stack kwargs into a dict of arrays (pytree with matching structure)
|
||||||
|
results = vmapped_fn(frames_array, ts_array, frame_nums_array, kwargs)
|
||||||
|
|
||||||
|
# Unstack results
|
||||||
|
return [results[i] for i in range(len(frames))]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Batched chain error: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Fall back to sequential
|
||||||
|
return None
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
"""Initialize from sexp - load primitives, effects, defs, scans."""
|
||||||
# Set random seed for deterministic output
|
# Set random seed for deterministic output
|
||||||
@@ -869,6 +1247,22 @@ class StreamInterpreter:
|
|||||||
# === Effects ===
|
# === Effects ===
|
||||||
|
|
||||||
if op in self.effects:
|
if op in self.effects:
|
||||||
|
# Try to detect and fuse effect chains for JAX acceleration
|
||||||
|
if self.use_jax and op in self.jax_effects:
|
||||||
|
chain_info = self._extract_effect_chain(expr, env)
|
||||||
|
if chain_info is not None:
|
||||||
|
effect_names, params_list, base_frame_expr = chain_info
|
||||||
|
# Only use chain if we have 2+ effects (worth fusing)
|
||||||
|
if len(effect_names) >= 2:
|
||||||
|
base_frame = self._eval(base_frame_expr, env)
|
||||||
|
if base_frame is not None and hasattr(base_frame, 'shape'):
|
||||||
|
t = env.get('t', 0.0)
|
||||||
|
frame_num = env.get('frame-num', 0)
|
||||||
|
result = self._apply_effect_chain(effect_names, params_list, base_frame, t, frame_num)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
# Fall through if chain application fails
|
||||||
|
|
||||||
effect = self.effects[op]
|
effect = self.effects[op]
|
||||||
effect_env = dict(env)
|
effect_env = dict(env)
|
||||||
|
|
||||||
@@ -895,17 +1289,28 @@ class StreamInterpreter:
|
|||||||
positional_idx += 1
|
positional_idx += 1
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
# Try JAX-accelerated execution first
|
# Try JAX-accelerated execution with deferred chaining
|
||||||
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
if self.use_jax and op in self.jax_effects and frame_val is not None:
|
||||||
# Build params dict for JAX (exclude 'frame')
|
# Build params dict for JAX (exclude 'frame')
|
||||||
jax_params = {k: v for k, v in effect_env.items()
|
jax_params = {k: self._maybe_force(v) for k, v in effect_env.items()
|
||||||
if k != 'frame' and k in effect['params']}
|
if k != 'frame' and k in effect['params']}
|
||||||
t = env.get('t', 0.0)
|
t = env.get('t', 0.0)
|
||||||
frame_num = env.get('frame-num', 0)
|
frame_num = env.get('frame-num', 0)
|
||||||
result = self._apply_jax_effect(op, frame_val, jax_params, t, frame_num)
|
|
||||||
if result is not None:
|
# Check if input is a deferred chain - if so, extend it
|
||||||
return result
|
if isinstance(frame_val, DeferredEffectChain):
|
||||||
# Fall through to interpreter if JAX fails
|
return frame_val.extend(op, jax_params)
|
||||||
|
|
||||||
|
# Check if input is a valid frame - create new deferred chain
|
||||||
|
if hasattr(frame_val, 'shape'):
|
||||||
|
return DeferredEffectChain([op], [jax_params], frame_val, t, frame_num)
|
||||||
|
|
||||||
|
# Fall through to interpreter if not a valid frame
|
||||||
|
|
||||||
|
# Force any deferred frame before interpreter evaluation
|
||||||
|
if isinstance(frame_val, DeferredEffectChain):
|
||||||
|
frame_val = self._force_deferred(frame_val)
|
||||||
|
effect_env['frame'] = frame_val
|
||||||
|
|
||||||
return self._eval(effect['body'], effect_env)
|
return self._eval(effect['body'], effect_env)
|
||||||
|
|
||||||
@@ -922,10 +1327,15 @@ class StreamInterpreter:
|
|||||||
if isinstance(args[i], Keyword):
|
if isinstance(args[i], Keyword):
|
||||||
k = args[i].name
|
k = args[i].name
|
||||||
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
|
v = self._eval(args[i + 1], env) if i + 1 < len(args) else None
|
||||||
|
# Force deferred chains before passing to primitives
|
||||||
|
v = self._maybe_force(v)
|
||||||
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
|
kwargs[k] = self._maybe_to_numpy(v, for_gpu_primitive=is_gpu_prim)
|
||||||
i += 2
|
i += 2
|
||||||
else:
|
else:
|
||||||
evaluated_args.append(self._maybe_to_numpy(self._eval(args[i], env), for_gpu_primitive=is_gpu_prim))
|
val = self._eval(args[i], env)
|
||||||
|
# Force deferred chains before passing to primitives
|
||||||
|
val = self._maybe_force(val)
|
||||||
|
evaluated_args.append(self._maybe_to_numpy(val, for_gpu_primitive=is_gpu_prim))
|
||||||
i += 1
|
i += 1
|
||||||
try:
|
try:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
@@ -1152,6 +1562,61 @@ class StreamInterpreter:
|
|||||||
eval_times = []
|
eval_times = []
|
||||||
write_times = []
|
write_times = []
|
||||||
|
|
||||||
|
# Batch accumulation for JAX
|
||||||
|
batch_deferred = [] # Accumulated DeferredEffectChains
|
||||||
|
batch_times = [] # Corresponding time values
|
||||||
|
batch_start_frame = 0
|
||||||
|
|
||||||
|
def flush_batch():
|
||||||
|
"""Execute accumulated batch and write results."""
|
||||||
|
nonlocal batch_deferred, batch_times
|
||||||
|
if not batch_deferred:
|
||||||
|
return
|
||||||
|
|
||||||
|
t_flush = time.time()
|
||||||
|
|
||||||
|
# Check if all chains have same structure (can batch)
|
||||||
|
first = batch_deferred[0]
|
||||||
|
can_batch = (
|
||||||
|
self.use_jax and
|
||||||
|
len(batch_deferred) >= 2 and
|
||||||
|
all(d.effects == first.effects for d in batch_deferred)
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_batch:
|
||||||
|
# Try batched execution
|
||||||
|
frames = [d.base_frame for d in batch_deferred]
|
||||||
|
ts = [d.t for d in batch_deferred]
|
||||||
|
frame_nums = [d.frame_num for d in batch_deferred]
|
||||||
|
params_batch = [d.params_list for d in batch_deferred]
|
||||||
|
|
||||||
|
results = self._apply_batched_chain(
|
||||||
|
first.effects, params_batch, frames, ts, frame_nums
|
||||||
|
)
|
||||||
|
|
||||||
|
if results is not None:
|
||||||
|
# Write batched results
|
||||||
|
for result, t in zip(results, batch_times):
|
||||||
|
if hasattr(result, 'block_until_ready'):
|
||||||
|
result.block_until_ready()
|
||||||
|
result = np.asarray(result)
|
||||||
|
out.write(result, t)
|
||||||
|
batch_deferred = []
|
||||||
|
batch_times = []
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fall back to sequential execution
|
||||||
|
for deferred, t in zip(batch_deferred, batch_times):
|
||||||
|
result = self._force_deferred(deferred)
|
||||||
|
if result is not None and hasattr(result, 'shape'):
|
||||||
|
if hasattr(result, 'block_until_ready'):
|
||||||
|
result.block_until_ready()
|
||||||
|
result = np.asarray(result)
|
||||||
|
out.write(result, t)
|
||||||
|
|
||||||
|
batch_deferred = []
|
||||||
|
batch_times = []
|
||||||
|
|
||||||
for frame_num in range(start_frame, n_frames):
|
for frame_num in range(start_frame, n_frames):
|
||||||
if not out.is_open:
|
if not out.is_open:
|
||||||
break
|
break
|
||||||
@@ -1182,8 +1647,23 @@ class StreamInterpreter:
|
|||||||
eval_times.append(time.time() - t1)
|
eval_times.append(time.time() - t1)
|
||||||
|
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
if result is not None and hasattr(result, 'shape'):
|
if result is not None:
|
||||||
out.write(result, ctx.t)
|
if isinstance(result, DeferredEffectChain):
|
||||||
|
# Accumulate for batching
|
||||||
|
batch_deferred.append(result)
|
||||||
|
batch_times.append(ctx.t)
|
||||||
|
|
||||||
|
# Flush when batch is full
|
||||||
|
if len(batch_deferred) >= self.jax_batch_size:
|
||||||
|
flush_batch()
|
||||||
|
else:
|
||||||
|
# Not deferred - flush any pending batch first, then write
|
||||||
|
flush_batch()
|
||||||
|
if hasattr(result, 'shape'):
|
||||||
|
if hasattr(result, 'block_until_ready'):
|
||||||
|
result.block_until_ready()
|
||||||
|
result = np.asarray(result)
|
||||||
|
out.write(result, ctx.t)
|
||||||
write_times.append(time.time() - t2)
|
write_times.append(time.time() - t2)
|
||||||
|
|
||||||
frame_elapsed = time.time() - frame_start
|
frame_elapsed = time.time() - frame_start
|
||||||
@@ -1219,6 +1699,9 @@ class StreamInterpreter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: progress callback failed: {e}", file=sys.stderr)
|
print(f"Warning: progress callback failed: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Flush any remaining batch
|
||||||
|
flush_batch()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
out.close()
|
out.close()
|
||||||
# Store output for access to properties like playlist_cid
|
# Store output for access to properties like playlist_cid
|
||||||
|
|||||||
542
test_funky_text.py
Normal file
542
test_funky_text.py
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Funky comparison tests: PIL vs TextStrip system.
|
||||||
|
Tests colors, opacity, fonts, sizes, edge positions, clipping, overlaps, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from streaming.jax_typography import (
|
||||||
|
render_text_strip, place_text_strip_jax, _load_font
|
||||||
|
)
|
||||||
|
|
||||||
|
FONTS = {
|
||||||
|
'sans': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
|
||||||
|
'bold': '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
|
||||||
|
'serif': '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf',
|
||||||
|
'serif_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf',
|
||||||
|
'mono': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf',
|
||||||
|
'mono_bold': '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf',
|
||||||
|
'narrow': '/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf',
|
||||||
|
'italic': '/usr/share/fonts/truetype/freefont/FreeSansOblique.ttf',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def render_pil(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
|
||||||
|
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
|
||||||
|
stroke_width=0, stroke_fill=None, anchor="la",
|
||||||
|
multiline=False, line_spacing=4, align="left"):
|
||||||
|
"""Render with PIL directly, including color/opacity."""
|
||||||
|
frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
|
||||||
|
# For opacity, render to RGBA then composite
|
||||||
|
txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
|
||||||
|
draw = ImageDraw.Draw(txt_layer)
|
||||||
|
font = _load_font(font_path, font_size)
|
||||||
|
|
||||||
|
if stroke_fill is None:
|
||||||
|
stroke_fill = (0, 0, 0)
|
||||||
|
|
||||||
|
# PIL fill with alpha for opacity
|
||||||
|
alpha_int = int(round(opacity * 255))
|
||||||
|
fill_rgba = (*fill, alpha_int)
|
||||||
|
stroke_rgba = (*stroke_fill, alpha_int) if stroke_width > 0 else None
|
||||||
|
|
||||||
|
if multiline:
|
||||||
|
draw.multiline_text((x, y), text, fill=fill_rgba, font=font,
|
||||||
|
stroke_width=stroke_width, stroke_fill=stroke_rgba,
|
||||||
|
spacing=line_spacing, align=align, anchor=anchor)
|
||||||
|
else:
|
||||||
|
draw.text((x, y), text, fill=fill_rgba, font=font,
|
||||||
|
stroke_width=stroke_width, stroke_fill=stroke_rgba, anchor=anchor)
|
||||||
|
|
||||||
|
# Composite onto background
|
||||||
|
bg_img = Image.fromarray(frame).convert('RGBA')
|
||||||
|
result = Image.alpha_composite(bg_img, txt_layer)
|
||||||
|
return np.array(result.convert('RGB'))
|
||||||
|
|
||||||
|
|
||||||
|
def render_strip(text, x, y, font_path=None, font_size=36, frame_size=(400, 100),
|
||||||
|
fill=(255, 255, 255), bg=(0, 0, 0), opacity=1.0,
|
||||||
|
stroke_width=0, stroke_fill=None, anchor="la",
|
||||||
|
multiline=False, line_spacing=4, align="left"):
|
||||||
|
"""Render with TextStrip system."""
|
||||||
|
frame = jnp.full((frame_size[1], frame_size[0], 3), jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)
|
||||||
|
|
||||||
|
strip = render_text_strip(
|
||||||
|
text, font_path, font_size,
|
||||||
|
stroke_width=stroke_width, stroke_fill=stroke_fill,
|
||||||
|
anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align
|
||||||
|
)
|
||||||
|
strip_img = jnp.asarray(strip.image)
|
||||||
|
color = jnp.array(fill, dtype=jnp.float32)
|
||||||
|
|
||||||
|
result = place_text_strip_jax(
|
||||||
|
frame, strip_img, x, y,
|
||||||
|
strip.baseline_y, strip.bearing_x,
|
||||||
|
color, opacity,
|
||||||
|
anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
|
||||||
|
stroke_width=strip.stroke_width
|
||||||
|
)
|
||||||
|
return np.array(result)
|
||||||
|
|
||||||
|
|
||||||
|
def compare(name, tolerance=0, **kwargs):
|
||||||
|
"""Compare PIL and TextStrip rendering."""
|
||||||
|
pil = render_pil(**kwargs)
|
||||||
|
strip = render_strip(**kwargs)
|
||||||
|
|
||||||
|
diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
|
||||||
|
max_diff = diff.max()
|
||||||
|
pixels_diff = (diff > 0).any(axis=2).sum()
|
||||||
|
|
||||||
|
if max_diff == 0:
|
||||||
|
print(f" PASS: {name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if tolerance > 0:
|
||||||
|
best_diff = diff.copy()
|
||||||
|
for dy in range(-tolerance, tolerance + 1):
|
||||||
|
for dx in range(-tolerance, tolerance + 1):
|
||||||
|
if dy == 0 and dx == 0:
|
||||||
|
continue
|
||||||
|
shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
|
||||||
|
sdiff = np.abs(pil.astype(np.int16) - shifted.astype(np.int16))
|
||||||
|
best_diff = np.minimum(best_diff, sdiff)
|
||||||
|
if best_diff.max() == 0:
|
||||||
|
print(f" PASS: {name} (within {tolerance}px tolerance)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(f" FAIL: {name}")
|
||||||
|
print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")
|
||||||
|
Image.fromarray(pil).save(f"/tmp/pil_{name}.png")
|
||||||
|
Image.fromarray(strip).save(f"/tmp/strip_{name}.png")
|
||||||
|
diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
|
||||||
|
Image.fromarray(diff_vis).save(f"/tmp/diff_{name}.png")
|
||||||
|
print(f" Saved: /tmp/pil_{name}.png /tmp/strip_{name}.png /tmp/diff_{name}.png")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_colors():
|
||||||
|
"""Test various text colors on various backgrounds."""
|
||||||
|
print("\n--- Colors ---")
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# White on black (baseline)
|
||||||
|
results.append(compare("color_white_on_black",
|
||||||
|
text="Hello", x=20, y=30, fill=(255, 255, 255), bg=(0, 0, 0)))
|
||||||
|
|
||||||
|
# Red on black
|
||||||
|
results.append(compare("color_red",
|
||||||
|
text="Red Text", x=20, y=30, fill=(255, 0, 0), bg=(0, 0, 0)))
|
||||||
|
|
||||||
|
# Green on black
|
||||||
|
results.append(compare("color_green",
|
||||||
|
text="Green!", x=20, y=30, fill=(0, 255, 0), bg=(0, 0, 0)))
|
||||||
|
|
||||||
|
# Blue on black
|
||||||
|
results.append(compare("color_blue",
|
||||||
|
text="Blue sky", x=20, y=30, fill=(0, 100, 255), bg=(0, 0, 0)))
|
||||||
|
|
||||||
|
# Yellow on dark gray
|
||||||
|
results.append(compare("color_yellow_on_gray",
|
||||||
|
text="Yellow", x=20, y=30, fill=(255, 255, 0), bg=(40, 40, 40)))
|
||||||
|
|
||||||
|
# Magenta on white
|
||||||
|
results.append(compare("color_magenta_on_white",
|
||||||
|
text="Magenta", x=20, y=30, fill=(255, 0, 255), bg=(255, 255, 255)))
|
||||||
|
|
||||||
|
# Subtle: gray text on slightly lighter gray
|
||||||
|
results.append(compare("color_subtle_gray",
|
||||||
|
text="Subtle", x=20, y=30, fill=(128, 128, 128), bg=(64, 64, 64)))
|
||||||
|
|
||||||
|
# Orange on deep blue
|
||||||
|
results.append(compare("color_orange_on_blue",
|
||||||
|
text="Warm", x=20, y=30, fill=(255, 165, 0), bg=(0, 0, 80)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def test_opacity():
|
||||||
|
"""Test different opacity levels."""
|
||||||
|
print("\n--- Opacity ---")
|
||||||
|
results = []
|
||||||
|
|
||||||
|
results.append(compare("opacity_100",
|
||||||
|
text="Full", x=20, y=30, opacity=1.0))
|
||||||
|
|
||||||
|
results.append(compare("opacity_75",
|
||||||
|
text="75%", x=20, y=30, opacity=0.75))
|
||||||
|
|
||||||
|
results.append(compare("opacity_50",
|
||||||
|
text="Half", x=20, y=30, opacity=0.5))
|
||||||
|
|
||||||
|
results.append(compare("opacity_25",
|
||||||
|
text="Quarter", x=20, y=30, opacity=0.25))
|
||||||
|
|
||||||
|
results.append(compare("opacity_10",
|
||||||
|
text="Ghost", x=20, y=30, opacity=0.1))
|
||||||
|
|
||||||
|
# Opacity on colored background
|
||||||
|
results.append(compare("opacity_on_colored_bg",
|
||||||
|
text="Overlay", x=20, y=30, fill=(255, 255, 255), bg=(100, 0, 0),
|
||||||
|
opacity=0.5))
|
||||||
|
|
||||||
|
# Color + opacity combo
|
||||||
|
results.append(compare("opacity_red_on_green",
|
||||||
|
text="Blend", x=20, y=30, fill=(255, 0, 0), bg=(0, 100, 0),
|
||||||
|
opacity=0.6))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def test_fonts():
    """Compare rendering across every registered font family and several sizes."""
    print("\n--- Fonts & Sizes ---")
    results = []

    # One sample render per registered font family
    for label, path in FONTS.items():
        results.append(compare(f"font_{label}",
                               text="Quick Fox", x=20, y=30, font_path=path, font_size=28,
                               frame_size=(300, 80)))

    # Size extremes: (case name, sample text, x, y, point size, frame size)
    size_cases = [
        ("size_tiny", "Tiny text at 12px", 10, 15, 12, (200, 40)),
        ("size_big", "BIG", 20, 10, 72, (300, 100)),
        ("size_huge", "XL", 10, 10, 120, (300, 160)),
    ]
    for case, sample, cx, cy, size, frame in size_cases:
        results.append(compare(case,
                               text=sample, x=cx, y=cy, font_size=size,
                               frame_size=frame))

    return results
def test_anchors():
    """Exercise every horizontal x vertical anchor combination at a fixed point."""
    print("\n--- Anchors ---")

    # PIL anchor codes: horizontal in {l, m, r}, vertical in {a, m, s, d}
    combos = [h + v for h in "lmr" for v in "amsd"]
    return [
        compare(f"anchor_{code}",
                text="Anchor", x=200, y=50, anchor=code,
                frame_size=(400, 100))
        for code in combos
    ]
def test_strokes():
    """Compare stroke (outline) rendering across widths, colors, and fonts."""
    print("\n--- Strokes ---")
    results = []

    # Width sweep with a plain black outline
    for sw in (1, 2, 3, 4, 6, 8):
        results.append(compare(f"stroke_w{sw}",
                               text="Stroke", x=30, y=20, font_size=40,
                               stroke_width=sw, stroke_fill=(0, 0, 0),
                               frame_size=(300, 80)))

    # Color/font variants: (case name, kwargs)
    variants = [
        # Colored strokes
        ("stroke_red", dict(text="Red outline", x=20, y=20, font_size=36,
                            stroke_width=3, stroke_fill=(255, 0, 0),
                            frame_size=(350, 80))),
        ("stroke_white_on_black", dict(text="Glow", x=20, y=20, font_size=40,
                                       fill=(255, 255, 255), stroke_width=4, stroke_fill=(0, 0, 255),
                                       frame_size=(250, 80))),
        # Stroke with bold font
        ("stroke_bold", dict(text="Bold+Stroke", x=20, y=20,
                             font_path=FONTS['bold'], font_size=36,
                             stroke_width=3, stroke_fill=(0, 0, 0),
                             frame_size=(400, 80))),
        # Stroke + colored text on colored bg
        ("stroke_colored_on_bg", dict(text="Party", x=20, y=20, font_size=48,
                                      fill=(255, 255, 0), bg=(50, 0, 80),
                                      stroke_width=3, stroke_fill=(255, 0, 0),
                                      frame_size=(300, 80))),
    ]
    for case, kwargs in variants:
        results.append(compare(case, **kwargs))

    return results
def test_edge_clipping():
    """Compare clipping behavior for text placed at or beyond the frame edges."""
    print("\n--- Edge Clipping ---")

    # (case name, kwargs) pairs exercising each edge and overflow direction
    cases = [
        # Text at very left edge
        ("clip_left_edge", dict(text="LEFT", x=0, y=30, frame_size=(200, 80))),
        # Text partially off right edge
        ("clip_right_edge", dict(text="RIGHT SIDE", x=150, y=30, frame_size=(200, 80))),
        # Text at top edge
        ("clip_top", dict(text="TOP", x=20, y=0, frame_size=(200, 80))),
        # Text at bottom edge - partially clipped
        ("clip_bottom", dict(text="BOTTOM", x=20, y=55, font_size=40,
                             frame_size=(200, 80))),
        # Large text overflowing all sides from center
        ("clip_overflow_center", dict(text="HUGE", x=75, y=40, font_size=100, anchor="mm",
                                      frame_size=(150, 80))),
        # Corner placement
        ("clip_corner_br", dict(text="Corner", x=350, y=70, font_size=36,
                                frame_size=(400, 100))),
    ]
    return [compare(case, **kwargs) for case, kwargs in cases]
def test_multiline_fancy():
    """Compare multiline rendering with alignment, spacing, stroke, and opacity variants."""
    print("\n--- Multiline Fancy ---")

    # (case name, kwargs); tolerance=1 cases allow a 1px position shift for
    # the sub-pixel rendering difference of non-left-aligned multiline text.
    cases = [
        # Right-aligned (1px tolerance: same sub-pixel issue as center alignment)
        ("multi_right", dict(text="Right\nAligned\nText", x=380, y=20,
                             frame_size=(400, 150),
                             multiline=True, anchor="ra", align="right",
                             tolerance=1)),
        # Center + stroke
        ("multi_center_stroke", dict(text="Title\nSubtitle", x=200, y=20,
                                     font_size=32, frame_size=(400, 120),
                                     multiline=True, anchor="ma", align="center",
                                     stroke_width=2, stroke_fill=(0, 0, 0),
                                     tolerance=1)),
        # Wide line spacing
        ("multi_wide_spacing", dict(text="Line A\nLine B\nLine C", x=20, y=10,
                                    frame_size=(300, 200),
                                    multiline=True, line_spacing=20)),
        # Zero extra spacing
        ("multi_tight_spacing", dict(text="Tight\nPacked\nLines", x=20, y=10,
                                     frame_size=(300, 150),
                                     multiline=True, line_spacing=0)),
        # Many lines
        ("multi_many_lines", dict(text="One\nTwo\nThree\nFour\nFive\nSix", x=20, y=5,
                                  font_size=20, frame_size=(200, 200),
                                  multiline=True, line_spacing=4)),
        # Bold multiline with stroke
        ("multi_bold_stroke", dict(text="BOLD\nSTROKE", x=20, y=10,
                                   font_path=FONTS['bold'], font_size=48,
                                   stroke_width=3, stroke_fill=(200, 0, 0),
                                   frame_size=(350, 150), multiline=True)),
        # Multiline on colored bg with opacity
        ("multi_opacity_on_bg", dict(text="Semi\nTransparent", x=20, y=10,
                                     fill=(255, 255, 0), bg=(0, 50, 100), opacity=0.7,
                                     frame_size=(300, 120), multiline=True)),
    ]
    return [compare(case, **kwargs) for case, kwargs in cases]
def test_special_chars():
    """Compare rendering of digits, punctuation, symbols, and length extremes."""
    print("\n--- Special Characters ---")

    # (case name, sample, x, y, font size or None for default, frame size)
    cases = [
        # Numbers and symbols
        ("chars_numbers", "0123456789", 20, 30, None, (300, 80)),
        ("chars_punctuation", "Hello, World! (v2.0)", 10, 30, None, (350, 80)),
        ("chars_symbols", "@#$%^&*+=", 20, 30, None, (300, 80)),
        # Single character
        ("chars_single", "X", 50, 30, 48, (100, 80)),
        # Very long text (clipped)
        ("chars_long", "The quick brown fox jumps over the lazy dog", 10, 30, 24, (400, 80)),
        # Mixed case
        ("chars_mixed_case", "AaBbCcDdEeFf", 10, 30, None, (350, 80)),
    ]
    results = []
    for case, sample, cx, cy, size, frame in cases:
        kwargs = dict(text=sample, x=cx, y=cy, frame_size=frame)
        if size is not None:
            kwargs['font_size'] = size
        results.append(compare(case, **kwargs))
    return results
def test_combos():
    """Compare complex combinations of font, color, stroke, and opacity features."""
    print("\n--- Combos ---")

    cases = [
        # Big bold stroke + color + opacity
        ("combo_all_features", dict(
            text="EPIC", x=20, y=10,
            font_path=FONTS['bold'], font_size=64,
            fill=(255, 200, 0), bg=(20, 0, 40), opacity=0.85,
            stroke_width=4, stroke_fill=(180, 0, 0),
            frame_size=(350, 100))),
        # Small mono on busy background
        ("combo_mono_code", dict(
            text="fn main() {}", x=10, y=15,
            font_path=FONTS['mono'], font_size=16,
            fill=(0, 255, 100), bg=(30, 30, 30),
            frame_size=(250, 50))),
        # Serif italic multiline with stroke
        ("combo_serif_italic_multi", dict(
            text="Once upon\na time...", x=20, y=10,
            font_path=FONTS['italic'], font_size=28,
            stroke_width=1, stroke_fill=(80, 80, 80),
            frame_size=(300, 120), multiline=True)),
        # Narrow font, big stroke, center anchored
        ("combo_narrow_stroke_center", dict(
            text="NARROW", x=150, y=40,
            font_path=FONTS['narrow'], font_size=44,
            stroke_width=5, stroke_fill=(0, 0, 0),
            anchor="mm", frame_size=(300, 80))),
        # Opacity blending of red text over a blue background
        ("combo_opacity_blend", dict(
            text="Layered", x=20, y=30,
            fill=(255, 0, 0), bg=(0, 0, 255), opacity=0.5,
            font_size=48, frame_size=(300, 80))),
    ]
    return [compare(case, **kwargs) for case, kwargs in cases]
def test_multi_strip_overlay():
    """Compare layering three text strips on one frame against an equivalent PIL composite."""
    print("\n--- Multi-Strip Overlay ---")
    results = []

    frame_size = (400, 150)
    bg = (20, 20, 40)

    # --- PIL reference: three draw calls onto one RGBA layer, then composite ---
    font1 = _load_font(FONTS['bold'], 48)
    font2 = _load_font(None, 24)
    font3 = _load_font(FONTS['mono'], 18)

    pil_frame = np.full((frame_size[1], frame_size[0], 3), bg, dtype=np.uint8)
    txt_layer = Image.new('RGBA', frame_size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(txt_layer)
    draw.text((20, 10), "TITLE", fill=(255, 255, 0, 255), font=font1,
              stroke_width=2, stroke_fill=(0, 0, 0, 255))
    draw.text((20, 70), "Subtitle here", fill=(200, 200, 200, 255), font=font2)
    draw.text((20, 110), "code_snippet()", fill=(0, 255, 128, 200), font=font3)
    bg_img = Image.fromarray(pil_frame).convert('RGBA')
    pil_result = np.array(Image.alpha_composite(bg_img, txt_layer).convert('RGB'))

    # --- Strip version: sequential placements onto the same frame ---
    frame = jnp.full((frame_size[1], frame_size[0], 3),
                     jnp.array(bg, dtype=jnp.uint8), dtype=jnp.uint8)

    # (strip, y position, RGB color, opacity) — x is 20 for all three,
    # matching the PIL draws above (the third uses alpha 200/255).
    placements = [
        (render_text_strip("TITLE", FONTS['bold'], 48, stroke_width=2, stroke_fill=(0, 0, 0)),
         10, (255, 255, 0), 1.0),
        (render_text_strip("Subtitle here", None, 24),
         70, (200, 200, 200), 1.0),
        (render_text_strip("code_snippet()", FONTS['mono'], 18),
         110, (0, 255, 128), 200 / 255),
    ]
    for strip, y_pos, rgb, alpha in placements:
        frame = place_text_strip_jax(
            frame, jnp.asarray(strip.image), 20, y_pos,
            strip.baseline_y, strip.bearing_x,
            jnp.array(rgb, dtype=jnp.float32), alpha,
            anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
            stroke_width=strip.stroke_width)

    strip_result = np.array(frame)

    diff = np.abs(pil_result.astype(np.int16) - strip_result.astype(np.int16))
    max_diff = diff.max()
    pixels_diff = (diff > 0).any(axis=2).sum()

    if max_diff <= 1:
        print(f"  PASS: multi_overlay (max_diff={max_diff})")
        results.append(True)
    else:
        print(f"  FAIL: multi_overlay")
        print(f"  Max diff: {max_diff}, Pixels different: {pixels_diff}")
        Image.fromarray(pil_result).save("/tmp/pil_multi_overlay.png")
        Image.fromarray(strip_result).save("/tmp/strip_multi_overlay.png")
        diff_vis = np.clip(diff * 10, 0, 255).astype(np.uint8)
        Image.fromarray(diff_vis).save("/tmp/diff_multi_overlay.png")
        print(f"  Saved: /tmp/pil_multi_overlay.png /tmp/strip_multi_overlay.png /tmp/diff_multi_overlay.png")
        results.append(False)

    return results
def main():
    """Run every comparison suite and print a pass/fail summary."""
    banner = "=" * 60
    print(banner)
    print("Funky TextStrip vs PIL Comparison")
    print(banner)

    suites = [
        test_colors, test_opacity, test_fonts, test_anchors, test_strokes,
        test_edge_clipping, test_multiline_fancy, test_special_chars,
        test_combos, test_multi_strip_overlay,
    ]
    all_results = []
    for suite in suites:
        all_results.extend(suite())

    print("\n" + banner)
    passed = sum(all_results)
    total = len(all_results)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        failed = [i for i, ok in enumerate(all_results) if not ok]
        print(f"FAILED: {total - passed} tests (indices: {failed})")
    print(banner)


if __name__ == "__main__":
    main()
||||||
183
test_pil_options.py
Normal file
183
test_pil_options.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Explore PIL text options and test if we can match them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
def load_font(font_name=None, font_size=32):
    """Return the first loadable TrueType font, trying *font_name* then DejaVu fallbacks.

    Falls back to PIL's built-in bitmap font when nothing else loads.
    """
    fallbacks = (
        '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
        '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf',
        '/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf',
    )
    for candidate in (font_name,) + fallbacks:
        if candidate is None:
            continue
        try:
            return ImageFont.truetype(candidate, font_size)
        except (IOError, OSError):
            continue
    return ImageFont.load_default()
def test_pil_options():
    """Render a catalogue of PIL text-drawing options for visual inspection.

    Each sub-test exercises one feature (stroke, anchors, multiline, alpha,
    color, emoji) on a transparent RGBA frame, is saved to
    /tmp/pil_test_<name>.png, and a summary of draw.text() parameters is
    printed at the end.
    """

    # Create a test frame
    frame_size = (600, 400)

    font = load_font(None, 36)
    font_bold = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 36)
    font_italic = load_font('/usr/share/fonts/truetype/dejavu/DejaVuSans-Oblique.ttf', 36)

    tests = []

    # Test 1: Basic text
    def basic_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Basic Text", fill=(255, 255, 255, 255), font=font)
        return img, "basic"
    tests.append(basic_text)

    # Test 2: Stroke/outline
    def stroke_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Stroke Text", fill=(255, 255, 255, 255), font=font,
                  stroke_width=2, stroke_fill=(255, 0, 0, 255))
        return img, "stroke"
    tests.append(stroke_text)

    # Test 3: Bold font
    def bold_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Bold Text", fill=(255, 255, 255, 255), font=font_bold)
        return img, "bold"
    tests.append(bold_text)

    # Test 4: Italic font
    def italic_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Italic Text", fill=(255, 255, 255, 255), font=font_italic)
        return img, "italic"
    tests.append(italic_text)

    # Test 5: Different anchors
    def anchor_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        # Draw crosshairs at anchor points
        for x in [100, 300, 500]:
            draw.line([(x-10, 50), (x+10, 50)], fill=(100, 100, 100, 255))
            draw.line([(x, 40), (x, 60)], fill=(100, 100, 100, 255))

        draw.text((100, 50), "Left", fill=(255, 255, 255, 255), font=font, anchor="lm")
        draw.text((300, 50), "Center", fill=(255, 255, 255, 255), font=font, anchor="mm")
        draw.text((500, 50), "Right", fill=(255, 255, 255, 255), font=font, anchor="rm")
        return img, "anchor"
    tests.append(anchor_text)

    # Test 6: Multiline text
    def multiline_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.multiline_text((20, 20), "Line One\nLine Two\nLine Three",
                            fill=(255, 255, 255, 255), font=font, spacing=10)
        return img, "multiline"
    tests.append(multiline_text)

    # Test 7: Semi-transparent text
    def alpha_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Alpha 100%", fill=(255, 255, 255, 255), font=font)
        draw.text((20, 60), "Alpha 50%", fill=(255, 255, 255, 128), font=font)
        draw.text((20, 100), "Alpha 25%", fill=(255, 255, 255, 64), font=font)
        return img, "alpha"
    tests.append(alpha_text)

    # Test 8: Colored text
    def colored_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Red", fill=(255, 0, 0, 255), font=font)
        draw.text((20, 60), "Green", fill=(0, 255, 0, 255), font=font)
        draw.text((20, 100), "Blue", fill=(0, 0, 255, 255), font=font)
        draw.text((20, 140), "Yellow", fill=(255, 255, 0, 255), font=font)
        return img, "colored"
    tests.append(colored_text)

    # Test 9: Large stroke
    def large_stroke():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text((20, 20), "Big Stroke", fill=(255, 255, 255, 255), font=font,
                  stroke_width=5, stroke_fill=(0, 0, 0, 255))
        return img, "large_stroke"
    tests.append(large_stroke)

    # Test 10: Emoji (if supported)
    def emoji_text():
        img = Image.new('RGBA', frame_size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)
        try:
            # Try to find an emoji font
            emoji_font = None
            emoji_paths = [
                '/usr/share/fonts/truetype/noto/NotoColorEmoji.ttf',
                '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf',
            ]
            for p in emoji_paths:
                try:
                    emoji_font = ImageFont.truetype(p, 36)
                    break
                except (IOError, OSError):
                    # Font missing or unreadable: try the next candidate.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt and programming errors.)
                    continue
            if emoji_font:
                draw.text((20, 20), "Hello 🎵 World 🎸", fill=(255, 255, 255, 255), font=emoji_font)
            else:
                draw.text((20, 20), "No emoji font found", fill=(255, 255, 255, 255), font=font)
        except Exception as e:
            draw.text((20, 20), f"Emoji error: {e}", fill=(255, 255, 255, 255), font=font)
        return img, "emoji"
    tests.append(emoji_text)

    # Run all tests
    print("PIL Text Options Test")
    print("=" * 60)

    for test_fn in tests:
        img, name = test_fn()
        fname = f"/tmp/pil_test_{name}.png"
        img.save(fname)
        print(f"Saved: {fname}")

    print("\nCheck /tmp/pil_test_*.png for results")

    # Print available parameters
    print("\n" + "=" * 60)
    print("PIL draw.text() parameters:")
    print("  - xy: position tuple")
    print("  - text: string to draw")
    print("  - fill: color (R,G,B) or (R,G,B,A)")
    print("  - font: ImageFont object")
    print("  - anchor: 2-char code (la=left-ascender, mm=middle-middle, etc.)")
    print("  - spacing: line spacing for multiline")
    print("  - align: 'left', 'center', 'right' for multiline")
    print("  - direction: 'rtl', 'ltr', 'ttb' (requires libraqm)")
    print("  - features: OpenType features list")
    print("  - language: language code for shaping")
    print("  - stroke_width: outline width in pixels")
    print("  - stroke_fill: outline color")
    print("  - embedded_color: use embedded color glyphs (emoji)")


if __name__ == "__main__":
    test_pil_options()
||||||
176
test_styled_text.py
Normal file
176
test_styled_text.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test styled TextStrip rendering against PIL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
from streaming.jax_typography import (
|
||||||
|
render_text_strip, place_text_strip_jax, _load_font
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
               stroke_width=0, stroke_fill=None, anchor="la",
               multiline=False, line_spacing=4, align="left"):
    """Reference renderer: draw white text onto a black frame via PIL and return it as an array."""
    canvas = Image.fromarray(
        np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8))
    draw = ImageDraw.Draw(canvas)
    font = _load_font(None, font_size)

    # Black is the default outline color when none is given
    outline = (0, 0, 0) if stroke_fill is None else stroke_fill

    if multiline:
        draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font,
                            stroke_width=stroke_width, stroke_fill=outline,
                            spacing=line_spacing, align=align, anchor=anchor)
    else:
        draw.text((x, y), text, fill=(255, 255, 255), font=font,
                  stroke_width=stroke_width, stroke_fill=outline, anchor=anchor)

    return np.array(canvas)
def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
                 stroke_width=0, stroke_fill=None, anchor="la",
                 multiline=False, line_spacing=4, align="left"):
    """Render white text via the TextStrip pipeline onto a black frame and return it as an array."""
    canvas = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)

    strip = render_text_strip(
        text, None, font_size,
        stroke_width=stroke_width, stroke_fill=stroke_fill,
        anchor=anchor, multiline=multiline, line_spacing=line_spacing, align=align,
    )

    white = jnp.array([255, 255, 255], dtype=jnp.float32)
    placed = place_text_strip_jax(
        canvas, jnp.asarray(strip.image), x, y,
        strip.baseline_y, strip.bearing_x,
        white, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        stroke_width=strip.stroke_width,
    )

    return np.array(placed)
def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
            tolerance=0, **kwargs):
    """Render *text* through both PIL and TextStrip and report whether they match.

    tolerance=0: exact pixel match required
    tolerance=1: allow 1-pixel position shift (for sub-pixel rendering differences
                 in center-aligned multiline text where the strip is pre-rendered
                 at a different base position than the final placement)
    """
    reference = render_pil(text, x, y, font_size, frame_size, **kwargs)
    candidate = render_strip(text, x, y, font_size, frame_size, **kwargs)

    delta = np.abs(reference.astype(np.int16) - candidate.astype(np.int16))
    max_diff = delta.max()
    pixels_diff = (delta > 0).any(axis=2).sum()

    if max_diff == 0:
        print(f"PASS: {name}")
        print(f"  Max diff: 0, Pixels different: 0")
        return True

    if tolerance > 0:
        # Treat a pure sub-pixel position shift as a pass: keep the
        # element-wise minimum diff over every shift within the tolerance.
        best = delta.copy()
        offsets = [(dy, dx)
                   for dy in range(-tolerance, tolerance + 1)
                   for dx in range(-tolerance, tolerance + 1)
                   if (dy, dx) != (0, 0)]
        for dy, dx in offsets:
            moved = np.roll(np.roll(candidate, dy, axis=0), dx, axis=1)
            best = np.minimum(
                best, np.abs(reference.astype(np.int16) - moved.astype(np.int16)))
        if best.max() == 0:
            print(f"PASS: {name} (within {tolerance}px position tolerance)")
            print(f"  Raw diff: {max_diff}, After shift tolerance: 0")
            return True

    status = "FAIL"
    print(f"{status}: {name}")
    print(f"  Max diff: {max_diff}, Pixels different: {pixels_diff}")

    # Save debug images
    Image.fromarray(reference).save(f"/tmp/pil_{name}.png")
    Image.fromarray(candidate).save(f"/tmp/strip_{name}.png")
    diff_scaled = np.clip(delta * 10, 0, 255).astype(np.uint8)
    Image.fromarray(diff_scaled).save(f"/tmp/diff_{name}.png")
    print(f"  Saved: /tmp/pil_{name}.png, /tmp/strip_{name}.png, /tmp/diff_{name}.png")

    return False
def main():
    """Run the styled TextStrip vs PIL comparison cases and print a summary."""
    banner = "=" * 60
    print(banner)
    print("Styled TextStrip vs PIL Comparison")
    print(banner)

    # (positional args, keyword args) for each comparison case; the two
    # tolerance=1 cases allow a 1px shift for sub-pixel differences in
    # non-left-aligned multiline rendering (PIL uses fractional getlength
    # values for the 'm' anchor shift; the strip is pre-rendered at an
    # integer position).
    cases = [
        # Basic text
        (("basic", "Hello World", 20, 50), {}),
        # Stroke/outline
        (("stroke_2", "Outlined", 20, 50),
         dict(stroke_width=2, stroke_fill=(255, 0, 0))),
        (("stroke_5", "Big Outline", 30, 60),
         dict(font_size=48, frame_size=(500, 120),
              stroke_width=5, stroke_fill=(0, 0, 0))),
        # Anchors - center
        (("anchor_mm", "Center", 200, 50),
         dict(frame_size=(400, 100), anchor="mm")),
        # Anchors - right
        (("anchor_rm", "Right", 380, 50),
         dict(frame_size=(400, 100), anchor="rm")),
        # Multiline
        (("multiline", "Line 1\nLine 2\nLine 3", 20, 20),
         dict(frame_size=(400, 150), multiline=True, line_spacing=8)),
        # Multiline centered
        (("multiline_center", "Short\nMedium Length\nX", 200, 20),
         dict(frame_size=(400, 150), multiline=True, anchor="ma", align="center",
              tolerance=1)),
        # Stroke + multiline
        (("stroke_multiline", "Line A\nLine B", 20, 20),
         dict(frame_size=(400, 120), stroke_width=2, stroke_fill=(0, 0, 255),
              multiline=True)),
    ]
    results = [compare(*args, **kwargs) for args, kwargs in cases]

    print(banner)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")

    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print(banner)


if __name__ == "__main__":
    main()
||||||
517
tests/test_jax_pipeline_integration.py
Normal file
517
tests/test_jax_pipeline_integration.py
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Integration tests comparing JAX and Python rendering pipelines.
|
||||||
|
|
||||||
|
These tests ensure the JAX-compiled effect chains produce identical output
|
||||||
|
to the Python/NumPy path. They test:
|
||||||
|
1. Full effect pipelines through both interpreters
|
||||||
|
2. Multi-frame sequences (to catch phase-dependent bugs)
|
||||||
|
3. Compiled effect chain fusion
|
||||||
|
4. Edge cases like shrinking/zooming that affect boundary handling
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the art-celery module is importable
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from sexp_effects.primitive_libs import core as core_mod
|
||||||
|
|
||||||
|
|
||||||
|
# Path to test resources
|
||||||
|
TEST_DIR = Path('/home/giles/art/test')
|
||||||
|
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
|
||||||
|
TEMPLATES_DIR = TEST_DIR / 'templates'
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_image(h=96, w=128):
    """Create an (h, w, 3) uint8 test image with gradients and distinct features.

    Channel 0 rises left-to-right, channel 1 top-to-bottom, channel 2 is a
    constant 128; a filled circle and rectangle are drawn on top so spatial
    effects (rotate, zoom, ...) are visually detectable.

    Args:
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        np.ndarray of shape (h, w, 3), dtype uint8.
    """
    import cv2
    img = np.empty((h, w, 3), dtype=np.uint8)

    # Vectorized gradients replacing the original per-pixel Python loop.
    # `astype(np.uint8)` truncates toward zero exactly like int() does for
    # the non-negative float values 255*x/w, so results are bit-identical.
    img[..., 0] = (255 * np.arange(w, dtype=np.float64) / w).astype(np.uint8)[np.newaxis, :]
    img[..., 1] = (255 * np.arange(h, dtype=np.float64) / h).astype(np.uint8)[:, np.newaxis]
    img[..., 2] = 128

    # Add features
    cv2.circle(img, (w // 2, h // 2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)

    return img
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up a temp working dir with the sexp effect files and a synthetic test image.

    Changes the process cwd into the temp dir for the duration of the module
    and restores it afterwards.
    """
    work_dir = tmp_path_factory.mktemp('sexp_test')
    prev_cwd = os.getcwd()
    os.chdir(work_dir)

    # Directory layout the interpreters expect
    (work_dir / 'effects').mkdir()
    (work_dir / 'sexp_effects' / 'effects').mkdir(parents=True)

    # Write the synthetic test image
    import cv2
    test_img = create_test_image()
    cv2.imwrite(str(work_dir / 'test_image.png'), test_img)

    # Copy the effect definitions the tests reference (skip any missing ones)
    for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
        src = EFFECTS_DIR / f'{effect}.sexp'
        if src.exists():
            shutil.copy(src, work_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp')

    yield {
        'dir': work_dir,
        'image_path': work_dir / 'test_image.png',
        'test_img': test_img,
    }

    os.chdir(prev_cwd)
def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write *content* to <test_dir>/effects/<filename> and return its path.

    Args:
        test_dir: base directory (Path or str); must contain an 'effects' subdir.
        content: full text of the sexp program.
        filename: file name inside the effects subdirectory.

    Returns:
        str: path of the written file (StreamInterpreter takes a plain string).
    """
    # pathlib write_text replaces the manual open/write/close dance.
    path = Path(test_dir) / 'effects' / filename
    path.write_text(content)
    return str(path)
class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output.

    Each test renders the same frame through a Python interpreter and a
    JAX interpreter and asserts the outputs agree within a tolerance.
    The previously copy-pasted setup (seeding, interpreter construction,
    frame injection, evaluation, deferred forcing) now lives in
    ``_run_both_pipelines`` so each test states only its pipeline and
    its assertions.
    """

    @staticmethod
    def _run_both_pipelines(test_env, sexp_content):
        """Run *sexp_content* through both backends on the shared test image.

        Both interpreters are seeded identically (42) so any divergence
        is a backend difference, not randomness.

        Returns:
            (py_result, jax_result, test_img) — all numpy arrays.
        """
        from streaming.stream_sexp_generic import StreamInterpreter, Context
        import cv2

        sexp_path = create_sexp_file(test_env['dir'], sexp_content)
        test_img = cv2.imread(str(test_env['image_path']))

        interps = []
        for use_jax in (False, True):
            core_mod.set_random_seed(42)
            interp = StreamInterpreter(sexp_path, use_jax=use_jax)
            interp._init()
            interps.append(interp)

        ctx = Context(fps=10)
        ctx.t = 0.5
        ctx.frame_num = 5

        frame_env = {
            'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
            't': ctx.t, 'frame-num': ctx.frame_num,
        }

        results = []
        for interp in interps:
            # Inject the test image so `frame` resolves inside the pipeline.
            interp.globals['frame'] = test_img
            out = interp._eval(interp.frame_pipeline, frame_env)
            # Force any deferred (fused) effect chain to a concrete array.
            results.append(np.asarray(interp._maybe_force(out)))

        return results[0], results[1], test_img

    def test_single_rotate_effect(self, test_env):
        """Test that a single rotate effect matches between Python and JAX."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")

  (frame (rotate frame :angle 15 :speed 0))
)
'''
        py_result, jax_result, _ = self._run_both_pipelines(test_env, sexp_content)

        diff = np.abs(py_result.astype(float) - jax_result.astype(float))
        mean_diff = np.mean(diff)

        assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"

    def test_rotate_then_zoom(self, test_env):
        """Test rotate followed by zoom - tests effect chain fusion."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (-> (rotate frame :angle 15 :speed 0)
             (zoom :amount 0.95 :speed 0)))
)
'''
        py_result, jax_result, _ = self._run_both_pipelines(test_env, sexp_content)

        mean_diff = np.mean(np.abs(py_result.astype(float) - jax_result.astype(float)))

        assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"

    def test_zoom_shrink_boundary_handling(self, test_env):
        """Test zoom with shrinking factor - critical for boundary handling."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (zoom frame :amount 0.8 :speed 0))
)
'''
        py_result, jax_result, test_img = self._run_both_pipelines(test_env, sexp_content)

        # Check corners specifically - these are most affected by boundary handling.
        h, w = test_img.shape[:2]
        for y, x in [(0, 0), (0, w - 1), (h - 1, 0), (h - 1, w - 1)]:
            py_val = py_result[y, x]
            jax_val = jax_result[y, x]
            corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
            assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"

    def test_blend_two_clips(self, test_env):
        """Test blending two effect chains - the core bug scenario."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._run_both_pipelines(test_env, sexp_content)

        diff = np.abs(py_result.astype(float) - jax_result.astype(float))
        mean_diff = np.mean(diff)
        # Check the top-row edge region specifically.
        edge_diff = diff[0, :].mean()

        assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
        assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"

    def test_blend_with_invert(self, test_env):
        """Test blending with invert - matches the problematic recipe pattern."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (require-primitives "color_ops")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")
  (effect invert :path "../sexp_effects/effects/invert.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0)
                     (invert :amount 1))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._run_both_pipelines(test_env, sexp_content)

        mean_diff = np.mean(np.abs(py_result.astype(float) - jax_result.astype(float)))

        assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
class TestDeferredEffectChainFusion:
    """Test the DeferredEffectChain fusion mechanism specifically.

    Both tests apply a rotate->zoom pair once step-by-step and once via a
    fused DeferredEffectChain; the shared machinery lives in
    ``_manual_and_fused`` (the two tests previously duplicated it).
    """

    # Minimal stream definition: loads rotate/zoom effects but applies
    # nothing in the frame pipeline; tests drive the effects directly.
    _BASE_SEXP = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame frame)
)
'''

    @classmethod
    def _manual_and_fused(cls, test_env, rot_angle, zoom_amount):
        """Apply rotate->zoom two ways and return (manual, fused) numpy arrays.

        "Manual" calls the two JAX effect functions in sequence; "fused"
        builds a DeferredEffectChain and forces it once.  Both use the
        same seed/time so the results should match.
        """
        import jax.numpy as jnp
        from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain

        sexp_path = create_sexp_file(test_env['dir'], cls._BASE_SEXP)

        core_mod.set_random_seed(42)
        interp = StreamInterpreter(sexp_path, use_jax=True)
        interp._init()

        jax_frame = jnp.array(test_env['test_img'])
        t = 0.5
        frame_num = 5

        # Manual step-by-step application.
        rotate_fn = interp.jax_effects['rotate']
        zoom_fn = interp.jax_effects['zoom']
        manual = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
                           angle=rot_angle, speed=0)
        manual = zoom_fn(manual, t=t, frame_num=frame_num, seed=42,
                         amount=zoom_amount, speed=0)

        # Fused chain application.
        chain = DeferredEffectChain(
            ['rotate'],
            [{'angle': rot_angle, 'speed': 0}],
            jax_frame, t, frame_num
        )
        chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})

        return np.asarray(manual), np.asarray(interp._force_deferred(chain))

    def test_manual_vs_fused_chain(self, test_env):
        """Compare manual application vs fused DeferredEffectChain."""
        manual_result, fused_result = self._manual_and_fused(test_env, -5.0, 0.95)

        diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
        mean_diff = np.mean(diff)
        assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"

        # Check specific pixels across the image (corners and center).
        h, w = test_env['test_img'].shape[:2]
        for y in [0, h // 2, h - 1]:
            for x in [0, w // 2, w - 1]:
                manual_val = manual_result[y, x]
                fused_val = fused_result[y, x]
                pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
                assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"

    def test_chain_with_shrink_zoom_boundary(self, test_env):
        """Test that shrinking zoom handles boundaries correctly in chain."""
        # zoom < 1.0 pulls in from the edges, exposing boundary handling.
        manual_result, fused_result = self._manual_and_fused(test_env, -4.555, 0.9494)

        # Check the top edge specifically - where boundary issues manifest.
        top_edge_manual = manual_result[0, :]
        top_edge_fused = fused_result[0, :]

        edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
        mean_edge_diff = np.mean(edge_diff)
        assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"

        # Guard against spurious zero (black) pixels introduced by fusion.
        manual_edge_sum = np.sum(top_edge_manual)
        fused_edge_sum = np.sum(top_edge_fused)
        if manual_edge_sum > 100:  # If manual has significant values
            assert fused_edge_sum > manual_edge_sum * 0.5, \
                f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
# Allow running this module directly (python <file>) without invoking pytest.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
334
tests/test_jax_primitives.py
Normal file
334
tests/test_jax_primitives.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test framework to verify JAX primitives match Python primitives.
|
||||||
|
|
||||||
|
Compares output of each primitive through:
|
||||||
|
1. Python/NumPy path
|
||||||
|
2. JAX path (CPU)
|
||||||
|
3. JAX path (GPU) - if available
|
||||||
|
|
||||||
|
Reports any mismatches with detailed diffs.
|
||||||
|
"""
|
||||||
|
import sys

# Make the project package importable when this file is run as a script
# (hard-coded developer checkout path — NOTE(review): prefer packaging or a
# conftest.py so the suite works outside this machine).
sys.path.insert(0, '/home/giles/art/art-celery')
import numpy as np
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Any, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
# Test configuration: frame size for generated inputs and tolerances used
# when comparing the Python and JAX outputs of each primitive.
TEST_WIDTH = 64
TEST_HEIGHT = 48
TOLERANCE_MEAN = 1.0   # Max allowed mean difference
TOLERANCE_MAX = 10.0   # Max allowed single pixel difference
TOLERANCE_PCT = 0.95   # Min % of pixels within ±1
@dataclass
class TestResult:
    """Outcome of comparing one primitive's Python and JAX outputs."""

    # Fix: the class name starts with "Test", so pytest tries to collect it
    # as a test class (and warns because it has an __init__).  __test__ is
    # a plain class attribute (no annotation), so it is NOT a dataclass
    # field and the generated __init__ signature is unchanged.
    __test__ = False

    primitive: str          # fully-qualified name, e.g. 'color_ops:invert'
    passed: bool            # True when all tolerances were met
    python_mean: float = 0.0   # mean pixel value of the Python output
    jax_mean: float = 0.0      # mean pixel value of the JAX output
    diff_mean: float = 0.0     # mean absolute difference
    diff_max: float = 0.0      # max absolute difference
    pct_within_1: float = 0.0  # fraction of elements differing by <= 1
    error: str = ""            # human-readable failure description
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
    """Build an (h, w, 3) uint8 frame with a known pixel pattern.

    pattern: 'gradient' (diagonal RGB ramps), 'checkerboard' (8-pixel
    black/white squares), 'solid' (mid-grey), anything else -> random noise.
    """
    if pattern == 'gradient':
        # R ramps left->right, G top->bottom, B along the diagonal.
        rows, cols = np.mgrid[0:h, 0:w]
        channels = [
            (cols * 255 / w).astype(np.uint8),
            (rows * 255 / h).astype(np.uint8),
            ((cols + rows) * 127 / (w + h)).astype(np.uint8),
        ]
        return np.stack(channels, axis=2)

    if pattern == 'checkerboard':
        rows, cols = np.mgrid[0:h, 0:w]
        cell = ((cols // 8) + (rows // 8)) % 2
        mono = (cell * 255).astype(np.uint8)
        return np.stack([mono] * 3, axis=2)

    if pattern == 'solid':
        return np.full((h, w, 3), 128, dtype=np.uint8)

    # Fallback: uniform random noise (high bound exclusive per np.random.randint).
    return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
    """Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
    mismatch = (float('inf'), float('inf'), 0.0)

    if py_out is None or jax_out is None:
        return mismatch

    if isinstance(py_out, dict) and isinstance(jax_out, dict):
        # Coordinate-map outputs: pool the per-key diffs over shared keys
        # whose shapes agree, then reduce once over the pooled values.
        pooled = []
        for key, py_val in py_out.items():
            if key not in jax_out:
                continue
            lhs = np.asarray(py_val)
            rhs = np.asarray(jax_out[key])
            if lhs.shape == rhs.shape:
                pooled.append(np.abs(lhs.astype(float) - rhs.astype(float)).flatten())
        if not pooled:
            return mismatch
        all_diff = np.concatenate(pooled)
        return float(np.mean(all_diff)), float(np.max(all_diff)), float(np.mean(all_diff <= 1))

    lhs = np.asarray(py_out)
    rhs = np.asarray(jax_out)
    if lhs.shape != rhs.shape:
        return mismatch

    delta = np.abs(lhs.astype(float) - rhs.astype(float))
    return float(np.mean(delta)), float(np.max(delta)), float(np.mean(delta <= 1))
# ============================================================================
# Primitive Test Definitions
# ============================================================================
# Maps primitive name -> test spec consumed by test_primitive():
#   'args':    positional arguments; the placeholder strings 'frame' /
#              'frame2' are replaced with the generated test frames.
#   'returns': expected result category ('frame', 'coords', 'scalar',
#              'array') — descriptive only in this visible chunk.

PRIMITIVE_TESTS = {
    # Geometry primitives
    'geometry:ripple-displace': {
        'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
        'returns': 'coords',
    },
    'geometry:rotate-img': {
        'args': ['frame', 45],
        'returns': 'frame',
    },
    'geometry:scale-img': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'geometry:flip-h': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'geometry:flip-v': {
        'args': ['frame'],
        'returns': 'frame',
    },

    # Color operations
    'color_ops:invert': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:grayscale': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:brightness': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:contrast': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:hue-shift': {
        'args': ['frame', 90],
        'returns': 'frame',
    },

    # Image operations
    'image:width': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:height': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:channel': {
        'args': ['frame', 0],
        'returns': 'array',
    },

    # Blending
    'blending:blend': {
        'args': ['frame', 'frame2', 0.5],
        'returns': 'frame',
    },
    'blending:blend-add': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
    'blending:blend-multiply': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
}
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the Python interpreter."""
    if prim_name not in interp.primitives:
        return None

    func = interp.primitives[prim_name]

    # Substitute placeholder strings with *copies* of the test frames so a
    # primitive that mutates in place cannot corrupt later comparisons.
    substitutions = {'frame': frame, 'frame2': frame2}
    call_args = [
        substitutions[a].copy() if isinstance(a, str) and a in substitutions else a
        for a in test_def['args']
    ]

    try:
        return func(*call_args)
    except Exception:
        # Best-effort by design: a failing primitive reports as None.
        return None
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the JAX compiler."""
    try:
        from streaming.sexp_to_jax import JaxCompiler
        import jax.numpy as jnp
        from sexp_effects.parser import Symbol, Keyword

        compiler = JaxCompiler()

        # The frames are bound in the evaluation environment; placeholder
        # strings in the arg list become Symbols that resolve against it.
        env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
        call_args = [Symbol(a) if a in ('frame', 'frame2') else a
                     for a in test_def['args']]

        # Evaluate the expression: (prim_name arg1 arg2 ...)
        outcome = compiler._eval([Symbol(prim_name)] + call_args, env)

        # Normalize array-likes to numpy for comparison with the Python path.
        if hasattr(outcome, '__array__'):
            return np.asarray(outcome)
        return outcome

    except Exception:
        # Best-effort: any failure (missing deps included) reports as None.
        return None
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
    """Test a single primitive."""
    gradient_frame = create_test_frame(pattern='gradient')
    checker_frame = create_test_frame(pattern='checkerboard')

    outcome = TestResult(primitive=prim_name, passed=False)

    # Python reference implementation.
    try:
        py_out = run_python_primitive(interp, prim_name, test_def, gradient_frame, checker_frame)
        if py_out is not None and hasattr(py_out, 'shape'):
            outcome.python_mean = float(np.mean(py_out))
    except Exception as e:
        outcome.error = f"Python error: {e}"
        return outcome

    # JAX implementation.
    try:
        jax_out = run_jax_primitive(prim_name, test_def, gradient_frame, checker_frame)
        if jax_out is not None and hasattr(jax_out, 'shape'):
            outcome.jax_mean = float(np.mean(jax_out))
    except Exception as e:
        outcome.error = f"JAX error: {e}"
        return outcome

    # Either side refusing to produce output is a failure in itself.
    if py_out is None:
        outcome.error = "Python returned None"
        return outcome
    if jax_out is None:
        outcome.error = "JAX returned None"
        return outcome

    # Compare the two outputs and record the metrics.
    diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
    outcome.diff_mean = diff_mean
    outcome.diff_max = diff_max
    outcome.pct_within_1 = pct

    # Pass only when all global tolerances are met.
    outcome.passed = (
        diff_mean <= TOLERANCE_MEAN
        and diff_max <= TOLERANCE_MAX
        and pct >= TOLERANCE_PCT
    )

    if not outcome.passed:
        outcome.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"

    return outcome
def discover_primitives(interp) -> List[str]:
    """Discover all primitives available in the interpreter."""
    # Iterating a dict yields its keys, so .keys() is implicit here.
    return sorted(interp.primitives)
def run_all_tests(verbose=True):
    """Run all primitive tests and print a pass/fail report.

    Args:
        verbose: print one line per primitive (plus its error when failing).

    Returns:
        list[TestResult]: one record per entry in PRIMITIVE_TESTS.

    Robustness fix: the working directory is switched to the test-asset
    checkout only for the duration of the run and restored in ``finally``,
    so a crash no longer leaves the process in the assets directory.
    """
    import warnings
    warnings.filterwarnings('ignore')

    import os
    previous_cwd = os.getcwd()
    # Interpreter setup resolves effect paths relative to the asset checkout.
    os.chdir('/home/giles/art/test')

    try:
        from streaming.stream_sexp_generic import StreamInterpreter
        from sexp_effects.primitive_libs import core as core_mod

        # Fixed seed so any randomized primitive is deterministic.
        core_mod.set_random_seed(42)

        # Create interpreter to get the Python primitive registry.
        interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
        interp._init()

        results = []

        print("=" * 60)
        print("JAX PRIMITIVE TEST SUITE")
        print("=" * 60)

        # Test each defined primitive.
        for prim_name, test_def in PRIMITIVE_TESTS.items():
            result = test_primitive(interp, prim_name, test_def)
            results.append(result)

            status = "✓ PASS" if result.passed else "✗ FAIL"
            if verbose:
                print(f"{status} {prim_name}")
                if not result.passed:
                    print(f"    {result.error}")

        # Summary
        passed = sum(1 for r in results if r.passed)
        failed = sum(1 for r in results if not r.passed)

        print("\n" + "=" * 60)
        print(f"SUMMARY: {passed} passed, {failed} failed")
        print("=" * 60)

        if failed > 0:
            print("\nFailed primitives:")
            for r in results:
                if not r.passed:
                    print(f"  - {r.primitive}: {r.error}")

        return results
    finally:
        # Always restore the caller's working directory.
        os.chdir(previous_cwd)
# Script entry point: run the full primitive comparison suite directly.
if __name__ == '__main__':
    run_all_tests()
305
tests/test_xector.py
Normal file
305
tests/test_xector.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
Tests for xector primitives - parallel array operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from sexp_effects.primitive_libs.xector import (
|
||||||
|
Xector,
|
||||||
|
xector_red, xector_green, xector_blue, xector_rgb,
|
||||||
|
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
|
||||||
|
xector_dist_from_center,
|
||||||
|
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
|
||||||
|
alpha_sin, alpha_cos, alpha_sq,
|
||||||
|
alpha_lt, alpha_gt, alpha_eq,
|
||||||
|
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
|
||||||
|
xector_where, xector_fill, xector_zeros, xector_ones,
|
||||||
|
is_xector,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestXectorBasics:
    """Test Xector class basic operations."""

    def test_create_from_list(self):
        # Plain Python lists are accepted as construction input.
        vec = Xector([1, 2, 3])
        assert len(vec) == 3
        assert is_xector(vec)

    def test_create_from_numpy(self):
        # NumPy input round-trips through to_numpy() as float32.
        source = np.array([1.0, 2.0, 3.0])
        vec = Xector(source)
        assert len(vec) == 3
        np.testing.assert_array_equal(vec.to_numpy(), source.astype(np.float32))

    def test_implicit_add(self):
        total = Xector([1, 2, 3]) + Xector([4, 5, 6])
        np.testing.assert_array_equal(total.to_numpy(), [5, 7, 9])

    def test_implicit_mul(self):
        product = Xector([1, 2, 3]) * Xector([2, 2, 2])
        np.testing.assert_array_equal(product.to_numpy(), [2, 4, 6])

    def test_scalar_broadcast(self):
        # Scalar on the right-hand side broadcasts element-wise.
        shifted = Xector([1, 2, 3]) + 10
        np.testing.assert_array_equal(shifted.to_numpy(), [11, 12, 13])

    def test_scalar_broadcast_rmul(self):
        # Scalar on the left-hand side exercises the reflected operator.
        doubled = 2 * Xector([1, 2, 3])
        np.testing.assert_array_equal(doubled.to_numpy(), [2, 4, 6])
class TestAlphaOperations:
    """Test α (element-wise) operations."""

    def test_alpha_add(self):
        out = alpha_add(Xector([1, 2, 3]), Xector([4, 5, 6]))
        np.testing.assert_array_equal(out.to_numpy(), [5, 7, 9])

    def test_alpha_add_multi(self):
        # Variadic form: all operands are summed element-wise.
        out = alpha_add(Xector([1, 2, 3]), Xector([1, 1, 1]), Xector([10, 10, 10]))
        np.testing.assert_array_equal(out.to_numpy(), [12, 13, 14])

    def test_alpha_mul_scalar(self):
        out = alpha_mul(Xector([1, 2, 3]), 2)
        np.testing.assert_array_equal(out.to_numpy(), [2, 4, 6])

    def test_alpha_sqrt(self):
        out = alpha_sqrt(Xector([1, 4, 9, 16]))
        np.testing.assert_array_equal(out.to_numpy(), [1, 2, 3, 4])

    def test_alpha_clamp(self):
        out = alpha_clamp(Xector([-5, 0, 5, 10, 15]), 0, 10)
        np.testing.assert_array_equal(out.to_numpy(), [0, 0, 5, 10, 10])

    def test_alpha_sin_cos(self):
        angles = Xector([0, np.pi/2, np.pi])
        np.testing.assert_array_almost_equal(alpha_sin(angles).to_numpy(), [0, 1, 0], decimal=5)
        np.testing.assert_array_almost_equal(alpha_cos(angles).to_numpy(), [1, 0, -1], decimal=5)

    def test_alpha_sq(self):
        out = alpha_sq(Xector([1, 2, 3, 4]))
        np.testing.assert_array_equal(out.to_numpy(), [1, 4, 9, 16])

    def test_alpha_comparison(self):
        values = Xector([1, 2, 3, 4])
        pivot = Xector([2, 2, 2, 2])
        np.testing.assert_array_equal(alpha_lt(values, pivot).to_numpy(), [True, False, False, False])
        np.testing.assert_array_equal(alpha_gt(values, pivot).to_numpy(), [False, False, True, True])
        np.testing.assert_array_equal(alpha_eq(values, pivot).to_numpy(), [False, True, False, False])
class TestBetaOperations:
    """Exercise the β (reduction) primitive operations."""

    def test_beta_add(self):
        # Sum reduction: 1 + 2 + 3 + 4.
        assert beta_add(Xector([1, 2, 3, 4])) == 10

    def test_beta_mul(self):
        # Product reduction: 4! = 24.
        assert beta_mul(Xector([1, 2, 3, 4])) == 24

    def test_beta_min_max(self):
        digits = Xector([3, 1, 4, 1, 5, 9, 2, 6])
        assert beta_min(digits) == 1
        assert beta_max(digits) == 9

    def test_beta_mean(self):
        # 2.5 is exactly representable in binary, so == is safe here.
        assert beta_mean(Xector([1, 2, 3, 4])) == 2.5

    def test_beta_count(self):
        assert beta_count(Xector([1, 2, 3, 4, 5])) == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrameConversion:
    """Round-trip RGB frames through per-channel xectors."""

    def test_extract_channels(self):
        # 2x2 RGB frame: pure red, pure green / pure blue, mid grey.
        frame = np.array(
            [[[255, 0, 0], [0, 255, 0]],
             [[0, 0, 255], [128, 128, 128]]],
            dtype=np.uint8,
        )

        red = xector_red(frame)
        green = xector_green(frame)
        blue = xector_blue(frame)

        # Channels flatten row-major: 4 pixels -> length-4 xectors.
        assert len(red) == 4
        np.testing.assert_array_equal(red.to_numpy(), [255, 0, 0, 128])
        np.testing.assert_array_equal(green.to_numpy(), [0, 255, 0, 128])
        np.testing.assert_array_equal(blue.to_numpy(), [0, 0, 255, 128])

    def test_rgb_roundtrip(self):
        # Splitting into channels and recombining must be lossless.
        frame = np.array(
            [[[100, 150, 200], [50, 75, 100]],
             [[200, 100, 50], [25, 50, 75]]],
            dtype=np.uint8,
        )

        rebuilt = xector_rgb(
            xector_red(frame), xector_green(frame), xector_blue(frame))
        np.testing.assert_array_equal(rebuilt, frame)

    def test_modify_and_reconstruct(self):
        # Uniform grey frame so the edited channel stands out.
        frame = np.array(
            [[[100, 100, 100], [100, 100, 100]],
             [[100, 100, 100], [100, 100, 100]]],
            dtype=np.uint8,
        )

        red = xector_red(frame)
        green = xector_green(frame)
        blue = xector_blue(frame)

        # Scale only the red channel, then rebuild the frame.
        result = xector_rgb(red * 2, green, blue)

        assert result[0, 0, 0] == 200  # red doubled
        assert result[0, 0, 1] == 100  # green untouched
        assert result[0, 0, 2] == 100  # blue untouched
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoordinates:
    """Coordinate-generation helpers keyed off a frame's shape."""

    def test_x_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # height 2, width 3
        xs = xector_x_coords(frame)
        # Column index repeats per row under row-major flattening.
        np.testing.assert_array_equal(xs.to_numpy(), [0, 1, 2, 0, 1, 2])

    def test_y_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # height 2, width 3
        ys = xector_y_coords(frame)
        # Row index is constant within each row.
        np.testing.assert_array_equal(ys.to_numpy(), [0, 0, 0, 1, 1, 1])

    def test_normalized_coords(self):
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        xs = xector_x_norm(frame)
        ys = xector_y_norm(frame)

        # Normalized x spans [0, 1] left to right across the width.
        assert xs.to_numpy()[0] == 0
        assert xs.to_numpy()[2] == 1

        # Normalized y spans [0, 1] top to bottom down the height.
        assert ys.to_numpy()[0] == 0
        assert ys.to_numpy()[3] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditional:
    """Conditional selection and constant-fill helpers."""

    def test_where(self):
        mask = Xector([True, False, True, False])
        on_true = Xector([1, 1, 1, 1])
        on_false = Xector([0, 0, 0, 0])

        picked = xector_where(mask, on_true, on_false)
        np.testing.assert_array_equal(picked.to_numpy(), [1, 0, 1, 0])

    def test_where_with_comparison(self):
        values = Xector([1, 5, 3, 7])
        threshold = 4

        # Threshold at 4: above -> 255, at-or-below -> 0.
        binarized = xector_where(alpha_gt(values, threshold), 255, 0)
        np.testing.assert_array_equal(binarized.to_numpy(), [0, 255, 0, 255])

    def test_fill(self):
        # Fill takes its length from the frame: 2x3 pixels -> 6 elements.
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        filled = xector_fill(42, frame)
        assert len(filled) == 6
        assert all(v == 42 for v in filled.to_numpy())

    def test_zeros_ones(self):
        frame = np.zeros((2, 2, 3), dtype=np.uint8)
        assert all(v == 0 for v in xector_zeros(frame).to_numpy())
        assert all(v == 1 for v in xector_ones(frame).to_numpy())
|
||||||
|
|
||||||
|
|
||||||
|
class TestInterpreterIntegration:
    """Drive xector operations end-to-end through the sexp interpreter."""

    def test_xector_vignette_effect(self):
        from sexp_effects.interpreter import Interpreter

        interp = Interpreter(minimal_primitives=True)
        interp.load_effect('sexp_effects/effects/xector_vignette.sexp')

        # All-white input makes the darkening easy to measure.
        frame = np.full((100, 100, 3), 255, dtype=np.uint8)

        result, _state = interp.run_effect(
            'xector_vignette', frame, {'strength': 0.5}, {})

        center = result[50, 50]
        corner = result[0, 0]

        assert center.mean() > corner.mean(), "Center should be brighter than corners"
        assert corner.mean() < 255, "Corners should be darkened"

    def test_implicit_elementwise(self):
        """Test that regular + works element-wise on xectors."""
        from sexp_effects.interpreter import Interpreter
        from sexp_effects.parser import parse
        from sexp_effects.primitive_libs.xector import PRIMITIVES

        interp = Interpreter(minimal_primitives=True)

        # Register the xector primitives into the global environment by hand.
        for name, fn in PRIMITIVES.items():
            interp.global_env.set(name, fn)

        # Evaluate (+ (red frame) 10) against a uniform grey 2x2 frame.
        expr = parse('(+ (red frame) 10)')
        interp.global_env.set('frame', np.full((2, 2, 3), 100, dtype=np.uint8))

        result = interp.eval(expr)

        # 100 + 10 for each of the four pixels' red values.
        assert is_xector(result)
        np.testing.assert_array_equal(result.to_numpy(), [110, 110, 110, 110])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Allow running this test module directly, without the pytest CLI.
    pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user