Files
celery/sexp_effects/primitive_libs/streaming.py
gilesb fc9597456f
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 1m28s
Add JAX typography, xector primitives, deferred effect chains, and GPU streaming
- Add JAX text rendering with font atlas, styled text placement, and typography primitives
- Add xector (element-wise/reduction) operations library and sexp effects
- Add deferred effect chain fusion for JIT-compiled effect pipelines
- Expand drawing primitives with font management, alignment, shadow, and outline
- Add interpreter support for function-style define and require
- Add GPU persistence mode and hardware decode support to streaming
- Add new sexp effects: cell_pattern, halftone, mosaic, and derived definitions
- Add path registry for asset resolution
- Add integration, primitives, and xector tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:41:19 +00:00

594 lines
20 KiB
Python

"""
Streaming primitives for video/audio processing.
These primitives handle video source reading and audio analysis,
keeping the interpreter completely generic.
GPU Acceleration:
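- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays so frames stay on the GPU
  (off by default; only takes effect when CuPy is installed)
- NVIDIA hardware video decoding (NVDEC, ffmpeg `-hwaccel cuda`) is used
  automatically when `nvidia-smi` runs and ffmpeg lists the cuda hwaccel
- Together these can significantly speed up processing on GPU nodes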
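Async Prefetching:
- Enabled by default; set STREAMING_PREFETCH=0 to disable background frame prefetching
- A background thread decodes upcoming frames while the current frame is being processed
- STREAMING_PREFETCH_SIZE sets how many frames are kept decoded ahead (default 10)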
"""
import json
import os
import select
import subprocess
import sys
import threading
import time
from collections import deque
from pathlib import Path

import numpy as np
# CuPy is optional: when it imports cleanly, decoded frames may be kept on the GPU.
try:
    import cupy as cp
except ImportError:
    cp = None
    CUPY_AVAILABLE = False
else:
    CUPY_AVAILABLE = True

# GPU persistence: hand out CuPy arrays instead of NumPy arrays.
# Opt-in (off by default) because not every primitive accepts GPU frames yet,
# and it can only be switched on when CuPy is importable.
GPU_PERSIST = CUPY_AVAILABLE and os.environ.get("STREAMING_GPU_PERSIST", "0") == "1"

# Background frame prefetching is on unless STREAMING_PREFETCH is something other than "1".
PREFETCH_ENABLED = os.environ.get("STREAMING_PREFETCH", "1") == "1"
# How many frames the prefetch thread tries to keep decoded ahead of the playhead.
PREFETCH_BUFFER_SIZE = int(os.environ.get("STREAMING_PREFETCH_SIZE", "10"))
# Cached result of _check_hwdec(); None until the first probe runs.
_HWDEC_AVAILABLE = None


def _check_hwdec():
    """Report whether ffmpeg can decode on an NVIDIA GPU (``-hwaccel cuda``).

    The probe runs once per process: ``nvidia-smi`` must exit successfully and
    ``ffmpeg -hwaccels`` must list ``cuda``. Any failure (missing binaries,
    timeouts, other errors) counts as "not available". The answer is cached in
    the module-level ``_HWDEC_AVAILABLE``.
    """
    global _HWDEC_AVAILABLE
    if _HWDEC_AVAILABLE is None:
        available = False
        try:
            gpu_check = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
            if gpu_check.returncode == 0:
                accels = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
                available = "cuda" in accels.stdout
        except Exception:
            available = False
        _HWDEC_AVAILABLE = available
    return _HWDEC_AVAILABLE
class VideoSource:
    """Video source with a persistent ffmpeg pipe for fast sequential reads.

    Frames are requested by timestamp through :meth:`read_at`. Consecutive
    requests reuse one long-running ``ffmpeg`` process that emits raw RGB24
    frames at ``fps``. The process is restarted at the requested position
    ("seek") only in three cases:

    - there is no live process;
    - the request goes back more than one frame;
    - the request jumps more than 2 s ahead.

    Timestamps beyond the clip's duration wrap around, so short clips loop.

    Frames are ``(height, width, 3)`` ``uint8`` arrays. They are NumPy arrays
    by default, or CuPy arrays when ``GPU_PERSIST`` is enabled.

    Not thread-safe: all reads share one ffmpeg pipe, so concurrent callers
    must serialize access themselves.
    """

    def __init__(self, path: str, fps: float = 30):
        """Probe ``path`` with ffprobe. Decoding starts lazily on the first read.

        Args:
            path: Video file to read.
            fps: Output frame rate requested from ffmpeg.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            RuntimeError: ffprobe failed or returned unparsable output.
        """
        self.path = Path(path)
        self.fps = fps  # Output fps for the stream
        self._frame_size = None  # (width, height)
        self._duration = None  # seconds; None when unknown (no looping)
        self._proc = None  # Persistent ffmpeg process
        self._stream_time = 0.0  # Timestamp of the next frame the pipe will yield
        self._frame_time = 1.0 / fps  # Time per frame at output fps
        self._last_read_time = -1
        self._cached_frame = None

        if not self.path.exists():
            raise FileNotFoundError(f"Video file not found: {self.path}")

        # -show_format is required for the container-duration fallback below.
        cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
               "-show_streams", "-show_format", str(self.path)]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}")
        try:
            info = json.loads(result.stdout)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}") from exc

        for stream in info.get("streams", []):
            if stream.get("codec_type") == "video":
                self._frame_size = (stream.get("width", 720), stream.get("height", 720))
                self._duration = self._stream_duration(stream)
                break

        # Fallback: container-level duration when the video stream reports none.
        if self._duration is None:
            try:
                self._duration = float(info.get("format", {})["duration"])
            except (KeyError, TypeError, ValueError):
                pass  # unknown duration: read_at() simply does not loop

        if not self._frame_size:
            self._frame_size = (720, 720)

        print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)

    @staticmethod
    def _stream_duration(stream):
        """Return a probed stream's duration in seconds, or None if absent.

        Uses the ``duration`` field when present. Otherwise it parses the
        ``tags.DURATION`` string ("HH:MM:SS.fraction"), which WebM files carry
        instead.
        """
        if "duration" in stream:
            return float(stream["duration"])
        dur_str = stream.get("tags", {}).get("DURATION")
        if dur_str:
            parts = dur_str.split(":")
            if len(parts) == 3:
                h, m, s = parts
                return int(h) * 3600 + int(m) * 60 + float(s)
        return None

    def _stop_stream(self):
        """Terminate the current ffmpeg process (if any), close its pipes and reap it."""
        proc, self._proc = self._proc, None
        if proc is None:
            return
        try:
            proc.kill()
        except OSError:
            pass  # already exited
        for pipe in (proc.stdout, proc.stderr):
            if pipe is not None:
                try:
                    pipe.close()
                except OSError:
                    pass
        try:
            # Reap the child so killed decoders do not linger as zombies.
            proc.wait(timeout=1.0)
        except subprocess.TimeoutExpired:
            pass

    def _report_stderr(self, timeout: float, prefix: str) -> None:
        """Print whatever ffmpeg has written to stderr, waiting up to ``timeout`` s.

        ``read1`` consumes only bytes already available. A plain ``read(4096)``
        on the buffered pipe would block until 4096 bytes or EOF. That would
        stall on a short warning from a still-running ffmpeg, which is itself
        blocked writing frames nobody reads.

        This is best-effort: pipe errors are ignored. ``select`` on pipes also
        fails on Windows.
        """
        proc = self._proc
        if proc is None or proc.stderr is None:
            return
        try:
            readable, _, _ = select.select([proc.stderr], [], [], timeout)
            if readable:
                err = proc.stderr.read1(4096).decode('utf-8', errors='ignore')
                if err:
                    print(f"{prefix}{err}", file=sys.stderr)
        except (OSError, ValueError):
            pass

    def _start_stream(self, seek_time: float = 0):
        """(Re)start ffmpeg so that the next frame in the pipe is at ``seek_time``.

        Any previous process is terminated and reaped first. NVIDIA hardware
        decoding (``-hwaccel cuda``) is requested when ``_check_hwdec()``
        reports it is available.
        """
        self._stop_stream()
        # Check file exists before trying to open
        if not self.path.exists():
            raise FileNotFoundError(f"Video file not found: {self.path}")

        w, h = self._frame_size
        cmd = ["ffmpeg", "-v", "error"]
        if _check_hwdec():
            cmd.extend(["-hwaccel", "cuda"])
        cmd.extend([
            "-ss", f"{seek_time:.3f}",
            "-i", str(self.path),
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{w}x{h}",
            "-r", str(self.fps),  # Output at specified fps
            "-",
        ])
        self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        self._stream_time = seek_time
        # Give ffmpeg a moment to fail fast (bad file, unsupported codec).
        self._report_stderr(0.5, f"ffmpeg error for {self.path.name}: ")

    def _read_frame_from_stream(self):
        """Read the next raw frame from the pipe.

        Returns an ``(h, w, 3)`` uint8 frame: a CuPy array when GPU_PERSIST is
        enabled, NumPy otherwise. Returns None when there is no process or the
        pipe reaches EOF before a full frame.
        """
        if self._proc is None:
            return None
        w, h = self._frame_size
        frame_bytes = w * h * 3
        # Deliberately no poll() check: after ffmpeg exits, its last frame may
        # still sit in the pipe. A short read is the real EOF signal.
        data = self._proc.stdout.read(frame_bytes)
        if len(data) < frame_bytes:
            return None
        frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
        if GPU_PERSIST:
            return cp.asarray(frame)
        return frame

    def read(self) -> np.ndarray:
        """Return the most recently read frame, or the frame at t=0 if none yet."""
        if self._cached_frame is not None:
            return self._cached_frame
        return self.read_at(0)

    def _loop_time(self, t: float) -> float:
        """Map ``t`` into the clip's duration so shorter clips loop."""
        if not self._duration or self._duration <= 0:
            return t
        seek_time = t % self._duration
        # Within 0.1 s of the end, wrap to the start to avoid EOF issues.
        if seek_time > self._duration - 0.1:
            seek_time = 0.0
        return seek_time

    def _seek_reason(self, seek_time: float):
        """Return why ffmpeg must be restarted to reach ``seek_time``, or None.

        There is a one-frame backward tolerance to absorb float/timing jitter.
        """
        if self._proc is None:
            return "no proc"
        if self._proc.poll() is not None:
            return "proc dead"
        if seek_time < self._stream_time - self._frame_time:
            return "backward"
        if seek_time > self._stream_time + 2.0:
            return "jump"
        return None

    def _skip_to(self, seek_time: float) -> None:
        """Discard frames until the next frame in the pipe is the one at ``seek_time``.

        A failed read restarts ffmpeg at ``seek_time``, which also makes it the
        next frame, after a short pause.
        """
        skip_retries = 0
        while self._stream_time + self._frame_time <= seek_time:
            if self._read_frame_from_stream() is not None:
                self._stream_time += self._frame_time
                skip_retries = 0
                continue
            skip_retries += 1
            self._start_stream(seek_time)
            if skip_retries > 3:
                time.sleep(0.1)
                break
            time.sleep(0.05)

    def _read_target_frame(self, t: float, seek_time: float, max_retries: int = 3):
        """Read the frame at the head of the pipe, restarting ffmpeg on failure.

        Raises:
            RuntimeError: no frame could be read after ``max_retries`` attempts.
        """
        for attempt in range(1, max_retries + 1):
            frame = self._read_frame_from_stream()
            if frame is not None:
                return frame
            print(f"RETRY {self.path.name}: attempt {attempt}/{max_retries} at t={t:.2f}", file=sys.stderr)
            self._report_stderr(0.1, "ffmpeg error: ")
            if attempt < max_retries:  # no point restarting right before giving up
                time.sleep(0.1)
                self._start_stream(seek_time)
                time.sleep(0.1)  # give ffmpeg time to start
        raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")

    def read_at(self, t: float) -> np.ndarray:
        """Return the frame at time ``t`` seconds (looped into the clip's duration).

        Repeated calls with the same ``t`` return the cached frame. The frame
        may be a CuPy array when GPU_PERSIST is enabled.

        Raises:
            RuntimeError: the frame could not be decoded after several restarts.
        """
        if t == self._last_read_time and self._cached_frame is not None:
            return self._cached_frame

        seek_time = self._loop_time(t)
        reason = self._seek_reason(seek_time)
        if reason is not None:
            print(f"SEEK {self.path.name}: t={t:.4f} seek={seek_time:.4f} stream={self._stream_time:.4f} ({reason})", file=sys.stderr)
            self._start_stream(seek_time)

        self._skip_to(seek_time)
        frame = self._read_target_frame(t, seek_time)

        self._stream_time += self._frame_time
        self._last_read_time = t
        self._cached_frame = frame
        return frame

    def skip(self):
        """No-op: reads are addressed by timestamp, so there is nothing to skip."""

    @property
    def size(self):
        """Frame size as ``(width, height)``."""
        return self._frame_size

    def close(self):
        """Terminate the ffmpeg process (if any) and release its pipes."""
        self._stop_stream()
class PrefetchingVideoSource:
    """Video source that decodes upcoming frames on a background thread.

    Wraps a ``VideoSource``. A daemon thread tries to keep ``buffer_size``
    frames decoded ahead of the most recently requested time, so that
    :meth:`read_at` can usually return without waiting on ffmpeg.

    Frames are cached per output-frame slot, ``round(t / frame_time)``. Any
    request within about half a frame of a cached slot is served from the
    buffer.

    The wrapped ``VideoSource`` shares one ffmpeg pipe and is not thread-safe.
    Every call into it is therefore serialized through ``_source_lock``.
    """

    def __init__(self, path: str, fps: float = 30, buffer_size: int = None):
        self._source = VideoSource(path, fps)
        self._buffer_size = buffer_size or PREFETCH_BUFFER_SIZE
        self._buffer = {}  # frame slot index -> frame
        self._buffer_lock = threading.Lock()
        # Serializes all use of self._source.
        # Lock order: _source_lock, then _buffer_lock.
        self._source_lock = threading.Lock()
        self._frame_time = 1.0 / fps
        self._stop_event = threading.Event()
        self._request_event = threading.Event()
        self._target_time = 0.0
        self._last_frame = None  # frame returned by the latest read_at()

        # Start prefetch thread (all attributes above must exist first)
        self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
        self._thread.start()
        print(f"PrefetchingVideoSource: {path} buffer_size={self._buffer_size}", file=sys.stderr)

    def _slot(self, t: float) -> int:
        """Index of the output frame nearest to time ``t``."""
        return round(t / self._frame_time)

    def _prefetch_loop(self):
        """Background thread: keep the frames just ahead of the playhead decoded."""
        while not self._stop_event.is_set():
            # Wake when read_at() moves the target, or poll every 10 ms.
            self._request_event.wait(timeout=0.01)
            self._request_event.clear()
            if self._stop_event.is_set():
                break

            target = self._target_time
            target_slot = self._slot(target)
            with self._buffer_lock:
                # Drop frames more than one second behind the playhead.
                stale = [s for s in self._buffer if s * self._frame_time < target - 1.0]
                for s in stale:
                    del self._buffer[s]
                next_slot = next(
                    (s for s in range(target_slot, target_slot + self._buffer_size)
                     if s not in self._buffer),
                    None,
                )
            if next_slot is None:
                continue  # look-ahead window already full

            prefetch_t = next_slot * self._frame_time
            with self._source_lock:
                if self._stop_event.is_set():
                    break
                with self._buffer_lock:
                    if next_slot in self._buffer:
                        continue  # read_at() decoded it while we waited for the lock
                try:
                    frame = self._source.read_at(prefetch_t)
                except Exception as e:
                    print(f"Prefetch error at t={prefetch_t}: {e}", file=sys.stderr)
                    continue
            with self._buffer_lock:
                self._buffer[next_slot] = frame

    def read_at(self, t: float) -> np.ndarray:
        """Return the frame at time ``t``, from the prefetch buffer when possible.

        This also moves the prefetch target to ``t``. On a buffer miss, the
        frame is decoded synchronously (blocking) and cached.
        """
        self._target_time = t
        self._request_event.set()  # Wake up prefetch thread
        slot = self._slot(t)

        with self._buffer_lock:
            frame = self._buffer.get(slot)
        if frame is None:
            with self._source_lock:
                # Re-check: the prefetcher may have filled this slot while we
                # waited. Decoding it again could force a backward seek.
                with self._buffer_lock:
                    frame = self._buffer.get(slot)
                if frame is None:
                    frame = self._source.read_at(t)
                    with self._buffer_lock:
                        self._buffer[slot] = frame
        self._last_frame = frame
        return frame

    def read(self) -> np.ndarray:
        """Return the most recently read frame, or the frame at t=0 if none yet."""
        if self._last_frame is not None:
            return self._last_frame
        return self.read_at(0)

    def skip(self):
        """No-op for seek-based reading."""

    @property
    def size(self):
        return self._source.size

    @property
    def path(self):
        return self._source.path

    @property
    def _stream_time(self):
        """Decode position of the wrapped source (read by debug logging)."""
        return self._source._stream_time

    def close(self):
        """Stop the prefetch thread and close the underlying source."""
        self._stop_event.set()
        self._request_event.set()  # Wake up thread so it sees the stop flag
        self._thread.join(timeout=1.0)
        self._source.close()
class AudioAnalyzer:
    """Audio analyzer for energy and beat detection.

    The whole file is decoded up front with ffmpeg into mono float32 samples
    at ``sample_rate`` Hz.

    :meth:`get_energy` is a pure lookup. :meth:`get_beat` is a streaming
    spectral-flux onset detector whose adaptive threshold depends on earlier
    calls. It is therefore meant to be queried at successive, increasing times,
    e.g. once per output frame.
    """

    _FLUX_HISTORY_LEN = 50  # flux samples retained
    _FLUX_WINDOW = 20  # most recent flux samples used for the adaptive threshold

    def __init__(self, path: str, sample_rate: int = 22050):
        """Decode ``path`` to mono samples and probe its duration.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            RuntimeError: ffmpeg/ffprobe failed, or the audio decoded to nothing.
        """
        self.path = Path(path)
        self.sample_rate = sample_rate

        if not self.path.exists():
            raise FileNotFoundError(f"Audio file not found: {self.path}")

        cmd = ["ffmpeg", "-v", "error", "-i", str(self.path),
               "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            # errors="replace": ffmpeg may echo non-UTF-8 file names or metadata.
            stderr = result.stderr.decode("utf-8", errors="replace")
            raise RuntimeError(f"Failed to load audio '{self.path}': {stderr}")
        self._audio = np.frombuffer(result.stdout, dtype=np.float32)
        if len(self._audio) == 0:
            raise RuntimeError(f"Audio file is empty or invalid: {self.path}")

        self.duration = self._probe_duration()
        self._reset_beat_state()

    @classmethod
    def from_samples(cls, samples, sample_rate: int = 22050):
        """Build an analyzer from in-memory mono samples instead of a file.

        ``duration`` is derived from the sample count, and ``path`` is None.

        Raises:
            ValueError: ``samples`` is empty or not one-dimensional.
        """
        analyzer = cls.__new__(cls)
        analyzer.path = None
        analyzer.sample_rate = sample_rate
        analyzer._audio = np.asarray(samples, dtype=np.float32)
        if analyzer._audio.ndim != 1:
            raise ValueError("Audio samples must be a 1-D (mono) sequence")
        if len(analyzer._audio) == 0:
            raise ValueError("Audio samples are empty")
        analyzer.duration = len(analyzer._audio) / sample_rate
        analyzer._reset_beat_state()
        return analyzer

    def _probe_duration(self) -> float:
        """Return the duration in seconds from ffprobe's container metadata.

        Falls back to the decoded sample count when ffprobe reports no usable
        duration (missing key, "N/A", or malformed JSON).

        Raises:
            RuntimeError: ffprobe exited with an error.
        """
        cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
               "-show_format", str(self.path)]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}")
        try:
            return float(json.loads(result.stdout)["format"]["duration"])
        except (KeyError, TypeError, ValueError):  # ValueError covers JSONDecodeError
            return len(self._audio) / self.sample_rate

    def _reset_beat_state(self):
        """Initialise the streaming beat-detector state."""
        self._flux_history = deque(maxlen=self._FLUX_HISTORY_LEN)  # (t, flux) pairs
        self._last_beat_time = -1
        self._beat_count = 0
        self._last_beat_check_time = -1
        # Cache beat result for current time (so multiple scans see same result)
        self._beat_cache_time = -1
        self._beat_cache_result = False

    def get_energy(self, t: float) -> float:
        """Return the energy level at ``t`` in [0, 1].

        This is 3x the RMS of the ~1024 samples centred on ``t``, clipped to 1.
        It returns 0.0 when ``t`` lies outside the audio.
        """
        idx = int(t * self.sample_rate)
        start = max(0, idx - 512)
        end = min(len(self._audio), idx + 512)
        if start >= end:
            return 0.0
        return float(min(1.0, np.sqrt(np.mean(self._audio[start:end] ** 2)) * 3.0))

    def get_beat(self, t: float) -> bool:
        """Return True if a beat (spectral-flux onset) is detected at ``t``.

        Repeated queries for the same ``t`` return the cached answer without
        touching the detector state.
        """
        if t == self._beat_cache_time:
            return self._beat_cache_result
        is_beat = self._detect_beat(t)
        self._beat_cache_time = t
        self._beat_cache_result = is_beat
        return is_beat

    def _detect_beat(self, t: float) -> bool:
        """Run one detector step at ``t`` and update the flux history and beat count."""
        idx = int(t * self.sample_rate)
        size = 2048
        start, end = max(0, idx - size // 2), min(len(self._audio), idx + size // 2)
        if end - start < size / 2:
            return False
        # Previous window: the same span shifted 512 samples earlier.
        pstart, pend = max(0, start - 512), max(0, end - 512)
        if pend <= pstart:
            return False

        curr = self._audio[start:end]
        prev = self._audio[pstart:pend]
        curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr))))
        prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev))))
        n = min(len(curr_spec), len(prev_spec))
        # Spectral flux: total increase in magnitude across bins.
        flux = np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1)

        self._flux_history.append((t, flux))
        if len(self._flux_history) < 5:
            return False  # not enough history for a meaningful threshold

        recent = [f for _, f in list(self._flux_history)[-self._FLUX_WINDOW:]]
        threshold = np.mean(recent) + 1.5 * np.std(recent)
        # Require >100 ms since the last beat to suppress double triggers.
        is_beat = bool(flux > threshold) and (t - self._last_beat_time) > 0.1
        if is_beat:
            self._last_beat_time = t
            if t > self._last_beat_check_time:
                self._beat_count += 1
            self._last_beat_check_time = t
        return is_beat

    def get_beat_count(self, t: float) -> int:
        """Return the cumulative beat count after running detection at ``t``."""
        self.get_beat(t)
        return self._beat_count
# === Primitives ===
def prim_make_video_source(path: str, fps: float = 30):
    """Open ``path`` as a video source that yields frames at ``fps``.

    Returns a PrefetchingVideoSource (background decoding) when
    STREAMING_PREFETCH is enabled, which is the default. Otherwise it returns
    a plain VideoSource.
    """
    source_cls = PrefetchingVideoSource if PREFETCH_ENABLED else VideoSource
    return source_cls(path, fps)
def prim_source_read(source: "VideoSource | PrefetchingVideoSource", t: float = None):
    """Read a frame from a video source.

    With ``t`` given, it returns the frame at ``t`` seconds. Without ``t``, it
    returns whatever ``source.read()`` yields, i.e. the source's last frame.

    Debug: while ``t`` is in the first tenth of a whole second, one line is
    printed to stderr. The decode position is fetched with ``getattr``
    because not every source type exposes ``_stream_time``.
    """
    if t is None:
        return source.read()
    frame = source.read_at(t)
    if int(t * 10) % 10 == 0:
        stream_time = getattr(source, "_stream_time", None)
        stream = f"{stream_time:.2f}" if stream_time is not None else "n/a"
        print(f"READ {source.path.name}: t={t:.2f} stream={stream}", file=sys.stderr)
    return frame
def prim_source_skip(source: "VideoSource | PrefetchingVideoSource"):
    """Advance a source by one frame.

    Both source types read by timestamp, so their ``skip()`` is a no-op. The
    primitive remains so that scripts which call it keep working.
    """
    source.skip()
def prim_source_size(source: "VideoSource | PrefetchingVideoSource"):
    """Return the source's frame size as ``(width, height)`` in pixels."""
    return source.size
def prim_make_audio_analyzer(path: str):
    """Decode the audio at ``path`` and wrap it in an AudioAnalyzer with default settings."""
    analyzer = AudioAnalyzer(path)
    return analyzer
def prim_audio_energy(analyzer: AudioAnalyzer, t: float) -> float:
    """Return the analyzer's energy level (0-1) at ``t`` seconds."""
    energy = analyzer.get_energy(t)
    return energy
def prim_audio_beat(analyzer: AudioAnalyzer, t: float) -> bool:
    """Return whether the analyzer detects a beat at ``t`` seconds."""
    is_beat = analyzer.get_beat(t)
    return is_beat
def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int:
    """Return how many beats the analyzer has counted up to ``t`` seconds."""
    count = analyzer.get_beat_count(t)
    return count
def prim_audio_duration(analyzer: AudioAnalyzer) -> float:
    """Return the analyzed audio's length in seconds."""
    seconds = analyzer.duration
    return seconds
# Exported primitive table: script-level name -> Python callable.
PRIMITIVES = dict([
    # Video source
    ('make-video-source', prim_make_video_source),
    ('source-read', prim_source_read),
    ('source-skip', prim_source_skip),
    ('source-size', prim_source_size),
    # Audio analyzer
    ('make-audio-analyzer', prim_make_audio_analyzer),
    ('audio-energy', prim_audio_energy),
    ('audio-beat', prim_audio_beat),
    ('audio-beat-count', prim_audio_beat_count),
    ('audio-duration', prim_audio_duration),
])