#!/usr/bin/env python3
"""
Test styled TextStrip rendering against PIL.

Each case renders the same text twice -- once directly with PIL's
ImageDraw and once through the JAX TextStrip pipeline
(render_text_strip + place_text_strip_jax) -- and compares the two frames
pixel by pixel. On a mismatch, debug images are written to the system
temporary directory. The script exits with status 1 if any case fails.
"""
import sys
import tempfile
from pathlib import Path

import numpy as np
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont

from streaming.jax_typography import (
    render_text_strip, place_text_strip_jax, _load_font
)


def render_pil(text, x, y, font_size=36, frame_size=(400, 100),
               stroke_width=0, stroke_fill=None, anchor="la",
               multiline=False, line_spacing=4, align="left"):
    """Render white text onto a black frame directly with PIL.

    Args:
        text: String to draw.
        x, y: Anchor position in pixels.
        font_size: Size passed to ``_load_font(None, font_size)``.
        frame_size: ``(width, height)`` of the output frame.
        stroke_width: Outline width in pixels (0 disables the outline).
        stroke_fill: Outline RGB colour; ``None`` means black.
        anchor: PIL text anchor string (e.g. ``"la"``, ``"mm"``).
        multiline: Draw with ``multiline_text`` instead of ``text``.
        line_spacing: Pixel spacing between lines (multiline only).
        align: Line alignment (multiline only).

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` built from a uint8
        RGB frame.
    """
    frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)
    img = Image.fromarray(frame)
    draw = ImageDraw.Draw(img)
    font = _load_font(None, font_size)

    # Default outline colour is black.
    if stroke_fill is None:
        stroke_fill = (0, 0, 0)

    if multiline:
        draw.multiline_text((x, y), text, fill=(255, 255, 255), font=font,
                            stroke_width=stroke_width,
                            stroke_fill=stroke_fill,
                            spacing=line_spacing, align=align, anchor=anchor)
    else:
        draw.text((x, y), text, fill=(255, 255, 255), font=font,
                  stroke_width=stroke_width, stroke_fill=stroke_fill,
                  anchor=anchor)

    return np.array(img)


def render_strip(text, x, y, font_size=36, frame_size=(400, 100),
                 stroke_width=0, stroke_fill=None, anchor="la",
                 multiline=False, line_spacing=4, align="left"):
    """Render text via a pre-rendered TextStrip composited with JAX.

    The text is first rasterised by ``render_text_strip`` and then placed
    in white onto a black ``jnp.uint8`` frame by ``place_text_strip_jax``,
    using the strip's own baseline, bearing, anchor and stroke metadata.

    Parameters mirror ``render_pil``. ``stroke_fill`` is forwarded
    unchanged (``None`` included). NOTE(review): presumably
    render_text_strip applies its own default outline colour for
    ``None``; confirm that it matches render_pil's black.

    Returns:
        ``np.ndarray`` converted from the composited JAX frame.
    """
    frame = jnp.zeros((frame_size[1], frame_size[0], 3), dtype=jnp.uint8)

    strip = render_text_strip(
        text, None, font_size,
        stroke_width=stroke_width, stroke_fill=stroke_fill,
        anchor=anchor, multiline=multiline,
        line_spacing=line_spacing, align=align
    )

    strip_img = jnp.asarray(strip.image)
    color = jnp.array([255, 255, 255], dtype=jnp.float32)

    result = place_text_strip_jax(
        frame, strip_img, x, y,
        strip.baseline_y, strip.bearing_x,
        color, 1.0,
        anchor_x=strip.anchor_x, anchor_y=strip.anchor_y,
        stroke_width=strip.stroke_width
    )

    return np.array(result)


def _shift_image(img, dy, dx):
    """Translate ``img`` by ``(dy, dx)`` pixels, filling vacated pixels with 0.

    The shift direction matches ``np.roll(np.roll(img, dy, 0), dx, 1)``.
    Unlike ``np.roll``, content pushed past one edge is dropped instead of
    wrapping around to the opposite edge, so it cannot create spurious
    matches there.
    """
    out = np.zeros_like(img)
    h, w = img.shape[:2]
    if abs(dy) >= h or abs(dx) >= w:
        # Everything has been shifted out of the frame.
        return out
    src_rows = slice(max(0, -dy), h - max(0, dy))
    dst_rows = slice(max(0, dy), h - max(0, -dy))
    src_cols = slice(max(0, -dx), w - max(0, dx))
    dst_cols = slice(max(0, dx), w - max(0, -dx))
    out[dst_rows, dst_cols] = img[src_rows, src_cols]
    return out


def _shift_tolerant_diff(pil, strip, tolerance):
    """Per-pixel minimum ``|pil - shifted(strip)|`` over all shifts in range.

    Every shift ``(dy, dx)`` with ``|dy|, |dx| <= tolerance`` is tried,
    including the unshifted one. Each pixel keeps its own best match.
    """
    pil16 = pil.astype(np.int16)  # hoisted: reused for every shift
    best = np.abs(pil16 - strip.astype(np.int16))
    for dy in range(-tolerance, tolerance + 1):
        for dx in range(-tolerance, tolerance + 1):
            if dy == 0 and dx == 0:
                continue
            shifted = _shift_image(strip, dy, dx).astype(np.int16)
            best = np.minimum(best, np.abs(pil16 - shifted))
    return best


def _save_debug_images(name, pil, strip, diff):
    """Write the PIL frame, the strip frame and a 10x-amplified diff image.

    Files go to the system temporary directory as ``pil_<name>.png``,
    ``strip_<name>.png`` and ``diff_<name>.png``; their paths are printed.
    """
    out_dir = Path(tempfile.gettempdir())
    paths = [out_dir / f"{prefix}_{name}.png"
             for prefix in ("pil", "strip", "diff")]
    # Amplify small differences so they are visible; clip to the uint8 range.
    diff_scaled = np.clip(diff * 10, 0, 255).astype(np.uint8)
    for path, img in zip(paths, (pil, strip, diff_scaled)):
        Image.fromarray(img).save(path)
    print(f"  Saved: {', '.join(str(p) for p in paths)}")


def compare(name, text, x, y, font_size=36, frame_size=(400, 100),
            tolerance=0, **kwargs):
    """Compare PIL and TextStrip rendering.

    tolerance=0: exact pixel match required
    tolerance=1: allow 1-pixel position shift (for sub-pixel rendering
                 differences in center-aligned multiline text where the
                 strip is pre-rendered at a different base position than
                 the final placement)

    Remaining keyword arguments are forwarded to both ``render_pil`` and
    ``render_strip``. Prints a PASS/FAIL summary. On failure it also
    writes debug images to the system temporary directory.

    Returns:
        True if the frames match (exactly, or within the shift tolerance),
        otherwise False.
    """
    pil = render_pil(text, x, y, font_size, frame_size, **kwargs)
    strip = render_strip(text, x, y, font_size, frame_size, **kwargs)

    diff = np.abs(pil.astype(np.int16) - strip.astype(np.int16))
    max_diff = diff.max()
    pixels_diff = (diff > 0).any(axis=2).sum()

    if max_diff == 0:
        print(f"PASS: {name}")
        print("  Max diff: 0, Pixels different: 0")
        return True

    shift_summary = None
    if tolerance > 0:
        # Accept differences explained by a small position shift of the
        # strip: each pixel may match PIL under any shift within tolerance.
        best_diff = _shift_tolerant_diff(pil, strip, tolerance)
        max_shift_diff = best_diff.max()
        if max_shift_diff == 0:
            print(f"PASS: {name} (within {tolerance}px position tolerance)")
            print(f"  Raw diff: {max_diff}, After shift tolerance: 0")
            return True
        pixels_shift_diff = (best_diff > 0).any(axis=2).sum()
        shift_summary = (f"  After {tolerance}px shift tolerance: "
                         f"max diff {max_shift_diff}, "
                         f"pixels different {pixels_shift_diff}")

    print(f"FAIL: {name}")
    print(f"  Max diff: {max_diff}, Pixels different: {pixels_diff}")
    if shift_summary is not None:
        print(shift_summary)

    _save_debug_images(name, pil, strip, diff)
    return False


def main():
    """Run every comparison case and print a summary.

    Returns:
        0 if all cases pass, 1 otherwise (used as the process exit code).
    """
    print("=" * 60)
    print("Styled TextStrip vs PIL Comparison")
    print("=" * 60)

    results = []

    # Basic text
    results.append(compare("basic", "Hello World", 20, 50))

    # Stroke/outline
    results.append(compare("stroke_2", "Outlined", 20, 50,
                           stroke_width=2, stroke_fill=(255, 0, 0)))
    results.append(compare("stroke_5", "Big Outline", 30, 60,
                           font_size=48, frame_size=(500, 120),
                           stroke_width=5, stroke_fill=(0, 0, 0)))

    # Anchors - center
    results.append(compare("anchor_mm", "Center", 200, 50,
                           frame_size=(400, 100), anchor="mm"))

    # Anchors - right
    results.append(compare("anchor_rm", "Right", 380, 50,
                           frame_size=(400, 100), anchor="rm"))

    # Multiline
    results.append(compare("multiline", "Line 1\nLine 2\nLine 3", 20, 20,
                           frame_size=(400, 150), multiline=True,
                           line_spacing=8))

    # Multiline centered (1px tolerance: sub-pixel rendering differs because
    # the strip is pre-rendered at an integer position while PIL's center
    # alignment uses fractional getlength values for the 'm' anchor shift)
    results.append(compare("multiline_center", "Short\nMedium Length\nX",
                           200, 20, frame_size=(400, 150), multiline=True,
                           anchor="ma", align="center", tolerance=1))

    # Stroke + multiline
    results.append(compare("stroke_multiline", "Line A\nLine B", 20, 20,
                           frame_size=(400, 120), stroke_width=2,
                           stroke_fill=(0, 0, 255), multiline=True))

    print("=" * 60)
    passed = sum(results)
    total = len(results)
    print(f"Results: {passed}/{total} passed")

    if passed == total:
        print("ALL TESTS PASSED!")
    else:
        print(f"FAILED: {total - passed} tests")

    print("=" * 60)
    # Non-zero exit status lets CI detect rendering regressions.
    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())