- Add CuPy integration for GPU-resident frame output - Add NVDEC hardware decode detection and ffmpeg acceleration - Configurable via STREAMING_GPU_PERSIST environment variable Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
463 lines
16 KiB
Python
463 lines
16 KiB
Python
"""
Streaming primitives for video/audio processing.

These primitives handle video source reading and audio analysis,
keeping the interpreter completely generic.

GPU Acceleration:
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays so frames stay on the
  GPU. This only takes effect when CuPy can be imported, and it is off by
  default. Frames are still decoded into host memory first and then
  uploaded.
- NVIDIA hardware video decoding (NVDEC, via ffmpeg ``-hwaccel cuda``) is
  used automatically. It is enabled when ``nvidia-smi`` succeeds and
  ``ffmpeg -hwaccels`` lists ``cuda``.
- Both options are intended to speed up processing on GPU nodes.
"""
|
|
|
|
import json
import os
import subprocess
import sys
import threading
import time
from pathlib import Path

import numpy as np
|
|
|
|
# Optional CuPy support. When CuPy cannot be imported, ``cp`` is None and
# every frame stays a numpy array.
try:
    import cupy as cp
except ImportError:
    cp = None
CUPY_AVAILABLE = cp is not None

# GPU persistence mode: emit CuPy arrays instead of numpy ones.
# This requires STREAMING_GPU_PERSIST=1 *and* a working CuPy import. It stays
# off by default until all primitives accept GPU frames.
GPU_PERSIST = CUPY_AVAILABLE and os.environ.get("STREAMING_GPU_PERSIST", "0") == "1"

# Memoised NVDEC probe result for _check_hwdec(); None means "not probed yet".
_HWDEC_AVAILABLE = None
|
|
|
|
|
|
def _check_hwdec():
    """Return whether ffmpeg can use NVIDIA hardware decoding (``-hwaccel cuda``).

    The probe runs at most once per process and its result is memoised in
    ``_HWDEC_AVAILABLE``. It checks two things: ``nvidia-smi`` must exit
    successfully, and ``ffmpeg -hwaccels`` must mention ``cuda``. Any
    exception while probing (missing binaries, timeouts, ...) is treated as
    "not available".
    """
    global _HWDEC_AVAILABLE
    if _HWDEC_AVAILABLE is None:
        try:
            smi = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
            if smi.returncode == 0:
                accels = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
                _HWDEC_AVAILABLE = "cuda" in accels.stdout
            else:
                _HWDEC_AVAILABLE = False
        except Exception:
            _HWDEC_AVAILABLE = False
    return _HWDEC_AVAILABLE
|
|
|
|
|
|
class VideoSource:
    """Video source with a persistent streaming pipe for fast sequential reads.

    A long-running ffmpeg process decodes the file to raw RGB24 frames at
    ``fps`` on its stdout.

    - Requests up to 2 s ahead of the current pipe position consume the pipe
      directly, discarding frames as needed.
    - Backward moves of more than one frame, larger forward jumps, or a dead
      process restart ffmpeg with ``-ss``.
    - When the duration is known, request times wrap around it.

    Frames are numpy ``uint8`` arrays of shape ``(height, width, 3)``, or
    CuPy arrays when ``GPU_PERSIST`` is enabled. Call :meth:`close`, or use
    the source as a context manager, to stop the ffmpeg process.
    """

    def __init__(self, path: str, fps: float = 30):
        """Probe ``path`` with ffprobe; decoding starts lazily on the first read.

        Args:
            path: Video file path.
            fps: Output frame rate of the decoded stream.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            RuntimeError: ffprobe fails or prints invalid JSON.
        """
        self.path = Path(path)
        self.fps = fps  # Output fps for the stream
        self._frame_size = None
        self._duration = None
        self._proc = None  # Persistent ffmpeg process
        self._stream_time = 0.0  # Time of the next frame the pipe will yield
        self._frame_time = 1.0 / fps  # Time per frame at output fps
        self._last_read_time = -1
        self._cached_frame = None

        if not self.path.exists():
            raise FileNotFoundError(f"Video file not found: {self.path}")

        # -show_format is required for the container-duration fallback.
        cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
               "-show_streams", "-show_format", str(self.path)]
        result = subprocess.run(cmd, capture_output=True, text=True,
                                stdin=subprocess.DEVNULL)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}")
        try:
            info = json.loads(result.stdout)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}") from err

        self._frame_size, self._duration = self._parse_probe(info)

        print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)

    @staticmethod
    def _parse_duration(value):
        """Parse seconds ("12.5") or H:MM:SS[.frac] ("00:01:00.124000000").

        Returns None for missing or unparsable values (e.g. "N/A").
        """
        if value is None:
            return None
        parts = str(value).split(":")
        try:
            if len(parts) == 3:
                h, m, s = parts
                return int(h) * 3600 + int(m) * 60 + float(s)
            if len(parts) == 1:
                return float(parts[0])
        except ValueError:
            pass
        return None

    @classmethod
    def _parse_probe(cls, info):
        """Extract ``((width, height), duration)`` from ffprobe JSON.

        Only the first video stream is used. The duration is taken from, in
        order:

        1. that stream's ``duration`` field;
        2. its ``DURATION`` tag (the WebM/Matroska style);
        3. the container's ``format.duration``.

        Missing values fall back to a 720x720 size and a ``None`` duration.
        """
        frame_size = None
        duration = None
        for stream in info.get("streams", []):
            if stream.get("codec_type") == "video":
                frame_size = (stream.get("width", 720), stream.get("height", 720))
                duration = cls._parse_duration(stream.get("duration"))
                if duration is None:
                    duration = cls._parse_duration(stream.get("tags", {}).get("DURATION"))
                break
        if duration is None:
            duration = cls._parse_duration(info.get("format", {}).get("duration"))
        return frame_size or (720, 720), duration

    @staticmethod
    def _drain_stderr(stream, name):
        """Forward ffmpeg's stderr lines to our stderr until EOF, then close the pipe.

        Runs on a daemon thread so the pipe can never fill up and stall
        ffmpeg, which would block our frame reads.
        """
        try:
            with stream:
                for raw in iter(stream.readline, b""):
                    msg = raw.decode("utf-8", errors="ignore").rstrip()
                    if msg:
                        print(f"ffmpeg error for {name}: {msg}", file=sys.stderr)
        except (OSError, ValueError):
            pass  # best-effort logging; the pipe/stderr may be closed at shutdown

    def _stop_stream(self):
        """Kill the current ffmpeg process (if any), reap it, and close its stdout."""
        proc, self._proc = self._proc, None
        if proc is None:
            return
        try:
            proc.kill()
        except OSError:
            pass  # already gone
        try:
            proc.wait(timeout=5)  # reap, so no zombie is left behind
        except subprocess.TimeoutExpired:
            pass
        if proc.stdout is not None:
            proc.stdout.close()
        # stderr is closed by its drain thread once it sees EOF.

    def _start_stream(self, seek_time: float = 0):
        """Start or restart ffmpeg decoding at ``seek_time`` seconds.

        Uses NVIDIA hardware decoding (NVDEC) when available.

        Raises:
            FileNotFoundError: the file has disappeared since construction.
        """
        self._stop_stream()

        if not self.path.exists():
            raise FileNotFoundError(f"Video file not found: {self.path}")

        w, h = self._frame_size

        cmd = ["ffmpeg", "-v", "error"]
        if _check_hwdec():
            cmd.extend(["-hwaccel", "cuda"])
        cmd.extend([
            "-ss", f"{seek_time:.3f}",
            "-i", str(self.path),
            "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", f"{w}x{h}",
            "-r", str(self.fps),  # Output at specified fps
            "-",
        ])

        # stdin=DEVNULL keeps ffmpeg from reading the parent's (terminal) input.
        self._proc = subprocess.Popen(cmd, stdin=subprocess.DEVNULL,
                                      stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        self._stream_time = seek_time
        threading.Thread(target=self._drain_stderr,
                         args=(self._proc.stderr, self.path.name),
                         daemon=True).start()

    def _read_frame_from_stream(self):
        """Read one frame from the pipe, or return None if no full frame is available.

        Returns a CuPy array if GPU_PERSIST is enabled, a numpy array otherwise.
        """
        if self._proc is None:
            return None

        w, h = self._frame_size
        nbytes = w * h * 3
        # Read even if ffmpeg has already exited: frames it wrote before
        # exiting can still be buffered in the pipe. A dead, empty pipe simply
        # returns a short read.
        data = self._proc.stdout.read(nbytes)
        if len(data) < nbytes:
            return None

        frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3))
        if GPU_PERSIST:
            return cp.asarray(frame)  # host->device copy; no extra host copy needed
        return frame.copy()  # frombuffer over bytes is read-only; hand out a writable array

    def _wrap_time(self, t):
        """Map request time ``t`` to a seek position inside the file.

        With a known duration, time loops as ``t % duration``, and the final
        0.1 s snaps back to 0 to avoid EOF issues. Without a duration,
        negative times are clamped to 0 rather than passed to ffmpeg as a
        negative ``-ss``.
        """
        if self._duration and self._duration > 0:
            seek_time = t % self._duration
            if seek_time > self._duration - 0.1:
                seek_time = 0.0
            return seek_time
        return max(0.0, t)

    def _skip_to(self, seek_time):
        """Discard frames until the next frame in the pipe is the one for ``seek_time``."""
        while self._stream_time + self._frame_time <= seek_time:
            if self._read_frame_from_stream() is None:
                # Stream ended or failed mid-skip: restart decoding at the target.
                # _start_stream sets _stream_time = seek_time, which ends this loop.
                self._start_stream(seek_time)
                time.sleep(0.05)
                break
            self._stream_time += self._frame_time

    def _read_target_frame(self, t, seek_time, max_retries=3):
        """Read the next frame, restarting ffmpeg at ``seek_time`` after each failed attempt."""
        for attempt in range(max_retries):
            frame = self._read_frame_from_stream()
            if frame is not None:
                return frame
            print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr)
            # Back off, restart, then give ffmpeg a moment to start up.
            time.sleep(0.1)
            self._start_stream(seek_time)
            time.sleep(0.1)
        raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")

    def read(self) -> np.ndarray:
        """Read frame (uses last cached or t=0)."""
        if self._cached_frame is not None:
            return self._cached_frame
        return self.read_at(0)

    def read_at(self, t: float) -> np.ndarray:
        """Return the frame at time ``t`` seconds, seeking only when needed.

        Repeated requests for the same ``t`` return the cached frame.

        Raises:
            RuntimeError: no frame could be read even after restarting ffmpeg.
        """
        if t == self._last_read_time and self._cached_frame is not None:
            return self._cached_frame

        seek_time = self._wrap_time(t)

        # Seek when any of these holds:
        # - there is no stream, or the stream has died;
        # - the request goes backwards by more than one frame (the one-frame
        #   tolerance absorbs float/timing jitter);
        # - the request jumps more than 2 seconds ahead.
        need_seek = (
            self._proc is None
            or self._proc.poll() is not None
            or seek_time < self._stream_time - self._frame_time
            or seek_time > self._stream_time + 2.0
        )
        if need_seek:
            reason = ("no proc" if self._proc is None
                      else "proc dead" if self._proc.poll() is not None
                      else "backward" if seek_time < self._stream_time
                      else "jump")
            print(f"SEEK {self.path.name}: t={t:.4f} seek={seek_time:.4f} stream={self._stream_time:.4f} ({reason})", file=sys.stderr)
            self._start_stream(seek_time)

        self._skip_to(seek_time)
        frame = self._read_target_frame(t, seek_time)
        self._stream_time += self._frame_time

        self._last_read_time = t
        self._cached_frame = frame
        return frame

    def skip(self):
        """No-op for seek-based reading."""
        pass

    @property
    def size(self):
        """Frame size as ``(width, height)``."""
        return self._frame_size

    def close(self):
        """Stop the ffmpeg process. A later read transparently restarts it."""
        self._stop_stream()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
|
|
|
|
|
|
class AudioAnalyzer:
    """Audio analyzer for energy and beat detection.

    The whole file is decoded once into memory as mono float32 samples at
    ``sample_rate``. All queries are then answered from that buffer.
    """

    def __init__(self, path: str, sample_rate: int = 22050):
        """Decode ``path`` with ffmpeg and read its duration with ffprobe.

        Args:
            path: Audio (or audio-bearing) file path.
            sample_rate: Decode sample rate in Hz.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            RuntimeError: decoding or probing fails, or no samples were decoded.
        """
        self.path = Path(path)
        self.sample_rate = sample_rate

        if not self.path.exists():
            raise FileNotFoundError(f"Audio file not found: {self.path}")

        # Load audio via ffmpeg. stdin=DEVNULL keeps ffmpeg from reading the
        # parent's terminal input.
        cmd = ["ffmpeg", "-v", "error", "-i", str(self.path),
               "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
        result = subprocess.run(cmd, capture_output=True, stdin=subprocess.DEVNULL)
        if result.returncode != 0:
            err = result.stderr.decode("utf-8", errors="replace")
            raise RuntimeError(f"Failed to load audio '{self.path}': {err}")
        # f32le is little-endian float32 regardless of the host byte order.
        self._audio = np.frombuffer(result.stdout, dtype="<f4")
        if len(self._audio) == 0:
            raise RuntimeError(f"Audio file is empty or invalid: {self.path}")

        # Get duration
        cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
               "-show_format", str(self.path)]
        result = subprocess.run(cmd, capture_output=True, text=True, stdin=subprocess.DEVNULL)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}")
        try:
            info = json.loads(result.stdout)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"Invalid audio file or ffprobe failed: {self.path}") from err
        try:
            self.duration = float(info.get("format", {}).get("duration"))
        except (TypeError, ValueError):
            # No usable container duration: derive it from the decoded samples.
            # A fixed 60 s guess would be wrong for any other length.
            self.duration = len(self._audio) / self.sample_rate

        self._reset_beat_state()

    def _reset_beat_state(self):
        """Initialise the incremental beat-detection state."""
        self._flux_history = []  # (t, flux) pairs, at most the last 50
        self._last_beat_time = -1
        self._beat_count = 0
        self._last_beat_check_time = -1
        # Cache the beat result for the current time, so multiple scans of one
        # frame see the same answer.
        self._beat_cache_time = -1
        self._beat_cache_result = False

    def get_energy(self, t: float) -> float:
        """Return the energy level at time ``t``, in ``[0, 1]``.

        The energy is the RMS of a window of up to 1024 samples centred on
        ``t``, multiplied by 3 and clipped to 1.0. It is 0.0 when the window
        falls outside the audio.
        """
        center = int(t * self.sample_rate)
        start = max(0, center - 512)
        end = min(len(self._audio), center + 512)
        if start >= end:
            return 0.0
        rms = np.sqrt(np.mean(self._audio[start:end] ** 2))
        return float(min(1.0, rms * 3.0))

    def get_beat(self, t: float) -> bool:
        """Return whether there is a beat (spectral-flux onset) at time ``t``.

        The method is stateful. Each new ``t`` updates the flux history and
        the beat counter, so it should be driven with increasing times.
        Repeated calls with the same ``t`` return the cached result without
        updating state.
        """
        if t == self._beat_cache_time:
            return self._beat_cache_result
        is_beat = self._detect_beat(t)
        self._beat_cache_time = t
        self._beat_cache_result = is_beat
        return is_beat

    def _detect_beat(self, t):
        """Run one spectral-flux onset test at ``t``, updating history and counters.

        The test compares the Hann-windowed magnitude spectrum of a
        2048-sample window centred on ``t`` with the same window shifted 512
        samples earlier.

        A beat requires both of the following:
        - the positive flux exceeds mean + 1.5 * std of the last 20 flux values;
        - more than 0.1 s has passed since the previous beat.
        """
        idx = int(t * self.sample_rate)
        size = 2048

        start, end = max(0, idx - size // 2), min(len(self._audio), idx + size // 2)
        if end - start < size / 2:
            return False
        curr = self._audio[start:end]

        pstart, pend = max(0, start - 512), max(0, end - 512)
        if pend <= pstart:
            return False
        prev = self._audio[pstart:pend]

        curr_spec = np.abs(np.fft.rfft(curr * np.hanning(len(curr))))
        prev_spec = np.abs(np.fft.rfft(prev * np.hanning(len(prev))))

        n = min(len(curr_spec), len(prev_spec))
        flux = float(np.sum(np.maximum(0, curr_spec[:n] - prev_spec[:n])) / (n + 1))

        self._flux_history.append((t, flux))
        if len(self._flux_history) > 50:
            self._flux_history = self._flux_history[-50:]
        if len(self._flux_history) < 5:
            return False

        recent = [f for _, f in self._flux_history[-20:]]
        threshold = np.mean(recent) + 1.5 * np.std(recent)

        # bool(): the numpy comparison would otherwise leak numpy.bool_.
        is_beat = bool(flux > threshold and (t - self._last_beat_time) > 0.1)
        if is_beat:
            self._last_beat_time = t
            if t > self._last_beat_check_time:
                self._beat_count += 1
                self._last_beat_check_time = t
        return is_beat

    def get_beat_count(self, t: float) -> int:
        """Return the number of beats detected so far, including any beat at ``t``.

        This runs ``get_beat(t)`` first. The count only reflects the times
        this analyzer has been queried with, so it is an accurate "beats up
        to t" only when queries advance through time in order.
        """
        self.get_beat(t)
        return self._beat_count
|
|
|
|
|
|
# === Primitives ===
|
|
|
|
def prim_make_video_source(path: str, fps: float = 30):
    """Build a ``VideoSource`` for the file at ``path``, decoded at ``fps``."""
    return VideoSource(path=path, fps=fps)
|
|
|
|
|
|
def prim_source_read(source: VideoSource, t: float = None):
    """Return a frame from ``source``: the frame at ``t``, or the current one if ``t`` is None."""
    import sys
    if t is None:
        return source.read()

    frame = source.read_at(t)
    # Debug trace. It fires whenever int(t * 10) is a multiple of 10. For
    # non-negative t that means every read in [k, k + 0.1), i.e. a few frames
    # at the start of each second, not exactly once per second.
    if int(t * 10) % 10 == 0:
        print(f"READ {source.path.name}: t={t:.2f} stream={source._stream_time:.2f}", file=sys.stderr)
    return frame
|
|
|
|
|
|
def prim_source_skip(source: VideoSource):
    """Ask ``source`` to skip one frame so a piped reader stays in sync.

    ``VideoSource.skip`` is a no-op because its reads are seek-based.
    """
    source.skip()
    return None
|
|
|
|
|
|
def prim_source_size(source: VideoSource):
    """Return the source's frame size as a ``(width, height)`` pair."""
    width_height = source.size
    return width_height
|
|
|
|
|
|
def prim_make_audio_analyzer(path: str):
    """Build an ``AudioAnalyzer`` for the audio file at ``path`` (default sample rate)."""
    return AudioAnalyzer(path=path)
|
|
|
|
|
|
def prim_audio_energy(analyzer: AudioAnalyzer, t: float) -> float:
    """Return the analyzer's 0-1 energy level at ``t`` seconds."""
    level = analyzer.get_energy(t)
    return level
|
|
|
|
|
|
def prim_audio_beat(analyzer: AudioAnalyzer, t: float) -> bool:
    """Return whether the analyzer detects a beat at time ``t``.

    The result is coerced to a builtin ``bool``. The analyzer's numpy
    comparison can otherwise yield ``numpy.bool_`` (e.g. ``numpy.False_``),
    for which identity checks such as ``result is False`` fail.
    """
    return bool(analyzer.get_beat(t))
|
|
|
|
|
|
def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int:
    """Return the analyzer's running beat count after checking time ``t``."""
    count = analyzer.get_beat_count(t)
    return count
|
|
|
|
|
|
def prim_audio_duration(analyzer: AudioAnalyzer) -> float:
    """Return the analyzed audio's duration in seconds."""
    seconds = analyzer.duration
    return seconds
|
|
|
|
|
|
# Export primitives: interpreter-facing name -> implementation.
PRIMITIVES = dict([
    # Video source
    ('make-video-source', prim_make_video_source),
    ('source-read', prim_source_read),
    ('source-skip', prim_source_skip),
    ('source-size', prim_source_size),
    # Audio analyzer
    ('make-audio-analyzer', prim_make_audio_analyzer),
    ('audio-energy', prim_audio_energy),
    ('audio-beat', prim_audio_beat),
    ('audio-beat-count', prim_audio_beat_count),
    ('audio-duration', prim_audio_duration),
])
|