Files
rose-ash/test_styled_text.py
giles 80c94ebea7 Squashed 'l1/' content from commit 670aa58
git-subtree-dir: l1
git-subtree-split: 670aa582df99e87fca7c247b949baf452e8c234f
2026-02-24 23:07:19 +00:00

177 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""
Test styled TextStrip rendering against PIL.
"""
import numpy as np
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont
from streaming.jax_typography import (
render_text_strip, place_text_strip_jax, _load_font
)
def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
               stroke_width=0, stroke_fill=None, anchor="la",
               multiline=False, line_spacing=4, align="left"):
    """Draw white text with PIL on a black frame (the reference rendering).

    Args:
        text: String to draw.
        x, y: Position handed to PIL as the ``(x, y)`` drawing coordinate.
        font_size: Size forwarded to ``_load_font``.
        frame_size: ``(width, height)`` of the frame in pixels.
        stroke_width: Outline width forwarded to PIL.
        stroke_fill: Outline colour; ``None`` is replaced by ``(0, 0, 0)``.
        anchor: PIL anchor code forwarded unchanged.
        multiline: When true, ``multiline_text`` is used instead of ``text``.
        line_spacing: Passed as ``spacing`` (multiline only).
        align: Passed as ``align`` (multiline only).

    Returns:
        The drawn image converted to a NumPy array.
    """
    # Start from an all-zero (height, width, 3) uint8 frame.
    canvas = Image.fromarray(
        np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)
    )
    pen = ImageDraw.Draw(canvas)

    # Options shared by the single-line and multi-line draw calls.
    shared = {
        "fill": (255, 255, 255),
        # NOTE(review): None presumably selects a default font - confirm
        # against _load_font.
        "font": _load_font(None, font_size),
        "stroke_width": stroke_width,
        "stroke_fill": (0, 0, 0) if stroke_fill is None else stroke_fill,
        "anchor": anchor,
    }

    if multiline:
        pen.multiline_text((x, y), text, spacing=line_spacing, align=align,
                           **shared)
    else:
        pen.text((x, y), text, **shared)

    return np.array(canvas)
def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
                 stroke_width=0, stroke_fill=None, anchor="la",
                 multiline=False, line_spacing=4, align="left"):
    """Render text through the TextStrip path and return the composed frame.

    The text is first rendered to a strip with ``render_text_strip`` and
    that strip is then placed onto a zero-filled frame with
    ``place_text_strip_jax``.

    Args:
        text: String to draw.
        x, y: Placement position forwarded to ``place_text_strip_jax``.
        font_size: Size forwarded to ``render_text_strip``.
        frame_size: ``(width, height)`` of the frame in pixels.
        stroke_width, stroke_fill, anchor, multiline, line_spacing, align:
            Styling options forwarded unchanged to ``render_text_strip``.

    Returns:
        The placed result converted to a NumPy array.
    """
    canvas = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)

    styling = {
        "stroke_width": stroke_width,
        "stroke_fill": stroke_fill,
        "anchor": anchor,
        "multiline": multiline,
        "line_spacing": line_spacing,
        "align": align,
    }
    # NOTE(review): the second positional argument (None) is presumably the
    # font name/path - confirm against render_text_strip.
    strip = render_text_strip(text, None, font_size, **styling)

    white = jnp.array([255, 255, 255], dtype=jnp.float32)
    placed = place_text_strip_jax(
        canvas, jnp.asarray(strip.image), x, y,
        strip.baseline_y, strip.bearing_x,
        # NOTE(review): 1.0 is presumably the opacity - confirm against
        # place_text_strip_jax.
        white, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        stroke_width=strip.stroke_width,
    )
    return np.array(placed)
def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
            tolerance=0, **kwargs):
    """Compare PIL and TextStrip rendering of the same text.

    Both renderers are invoked with identical arguments and the resulting
    frames are compared channel by channel.

    Args:
        name: Label used in the printed report and in debug image file names.
        text, x, y, font_size, frame_size: Forwarded to ``render_pil`` and
            ``render_strip``.
        tolerance: Allowed position shift in pixels.
            ``0`` requires an exact pixel match.
            ``1`` allows a 1-pixel position shift (for sub-pixel rendering
            differences in center-aligned multiline text where the strip is
            pre-rendered at a different base position than the final
            placement).
        **kwargs: Extra styling options forwarded to both renderers.

    Returns:
        True when the frames match (exactly, or within ``tolerance``),
        False otherwise. On failure the two frames and a 10x-amplified
        difference image are written to the system temporary directory.
    """
    import os
    import tempfile

    pil = render_pil(text, x, y, font_size, frame_size, **kwargs)
    strip = render_strip(text, x, y, font_size, frame_size, **kwargs)

    # Widen to a signed type once so the subtraction cannot wrap around.
    pil_signed = pil.astype(np.int16)
    diff = np.abs(pil_signed - strip.astype(np.int16))
    max_diff = diff.max()

    if max_diff == 0:
        print(f"PASS: {name}")
        print(" Max diff: 0, Pixels different: 0")
        return True

    if tolerance > 0:
        # Check if the difference is just a sub-pixel position shift: keep,
        # per element, the smallest difference seen over every shift of the
        # strip frame within +/- tolerance pixels.
        # NOTE: np.roll wraps around the frame edges.
        best_diff = diff.copy()
        for dy in range(-tolerance, tolerance + 1):
            for dx in range(-tolerance, tolerance + 1):
                if dy == 0 and dx == 0:
                    continue  # the unshifted case is already in best_diff
                shifted = np.roll(np.roll(strip, dy, axis=0), dx, axis=1)
                shifted_diff = np.abs(pil_signed - shifted.astype(np.int16))
                best_diff = np.minimum(best_diff, shifted_diff)

        if best_diff.max() == 0:
            print(f"PASS: {name} (within {tolerance}px position tolerance)")
            print(f" Raw diff: {max_diff}, After shift tolerance: 0")
            return True

    # Number of pixels where at least one channel differs.
    pixels_diff = (diff > 0).any(axis=2).sum()
    print(f"FAIL: {name}")
    print(f" Max diff: {max_diff}, Pixels different: {pixels_diff}")

    # Save debug images in the platform's temporary directory rather than a
    # hard-coded /tmp, which does not exist on every system.
    tmp_dir = tempfile.gettempdir()
    pil_path = os.path.join(tmp_dir, f"pil_{name}.png")
    strip_path = os.path.join(tmp_dir, f"strip_{name}.png")
    diff_path = os.path.join(tmp_dir, f"diff_{name}.png")

    Image.fromarray(pil).save(pil_path)
    Image.fromarray(strip).save(strip_path)
    # Amplify the difference so small deviations are visible.
    diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8)
    Image.fromarray(diff_scaled).save(diff_path)
    print(f" Saved: {pil_path}, {strip_path}, {diff_path}")

    return False
def main():
    """Run every comparison case and print a summary.

    Returns:
        0 when every case passed, 1 otherwise, so the value can be used
        directly as the process exit status.
    """
    banner = "=" * 60
    print(banner)
    print("Styled TextStrip vs PIL Comparison")
    print(banner)

    # The list literal evaluates the cases top to bottom, in this order.
    results = [
        # Basic text
        compare("basic", "Hello World", 20, 50),

        # Stroke/outline
        compare("stroke_2", "Outlined", 20, 50,
                stroke_width=2, stroke_fill=(255, 0, 0)),
        compare("stroke_5", "Big Outline", 30, 60, font_size=48,
                frame_size=(500, 120),
                stroke_width=5, stroke_fill=(0, 0, 0)),

        # Anchors - center
        compare("anchor_mm", "Center", 200, 50, frame_size=(400, 100),
                anchor="mm"),

        # Anchors - right
        compare("anchor_rm", "Right", 380, 50, frame_size=(400, 100),
                anchor="rm"),

        # Multiline
        compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20,
                frame_size=(400, 150),
                multiline=True, line_spacing=8),

        # Multiline centered (1px tolerance: sub-pixel rendering differs
        # because the strip is pre-rendered at an integer position while
        # PIL's center alignment uses fractional getlength values for the
        # 'm' anchor shift)
        compare("multiline_center", "Short\nMedium Length\nX", 200, 20,
                frame_size=(400, 150),
                multiline=True, anchor="ma", align="center",
                tolerance=1),

        # Stroke + multiline
        compare("stroke_multiline", "Line A\nLine B", 20, 20,
                frame_size=(400, 120),
                stroke_width=2, stroke_fill=(0, 0, 255),
                multiline=True),
    ]

    print(banner)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")
    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")
    print(banner)

    return 0 if passed == total else 1


if __name__ == "__main__":
    # Propagate failures through the exit status so shells and CI notice.
    raise SystemExit(main())