Add GPU persistence mode and hardware decode support to streaming primitives

- Add CuPy integration for GPU-resident frame output
- Add NVDEC hardware decode detection and ffmpeg acceleration
- Configurable via STREAMING_GPU_PERSIST environment variable

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-02-06 15:13:02 +00:00
parent 95fcc67dcc
commit f2edc20cba

View File

@@ -3,13 +3,53 @@ Streaming primitives for video/audio processing.
These primitives handle video source reading and audio analysis, These primitives handle video source reading and audio analysis,
keeping the interpreter completely generic. keeping the interpreter completely generic.
GPU Acceleration:
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
- Hardware video decoding (NVDEC) is used when available
- Dramatically improves performance on GPU nodes
""" """
import os
import numpy as np import numpy as np
import subprocess import subprocess
import json import json
from pathlib import Path from pathlib import Path
# Try to import CuPy for GPU acceleration
try:
import cupy as cp
CUPY_AVAILABLE = True
except ImportError:
cp = None
CUPY_AVAILABLE = False
# GPU persistence mode - output CuPy arrays instead of numpy
# Disabled by default until all primitives support GPU frames
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
# Check for hardware decode support (cached)
_HWDEC_AVAILABLE = None
def _check_hwdec():
"""Check if NVIDIA hardware decode is available."""
global _HWDEC_AVAILABLE
if _HWDEC_AVAILABLE is not None:
return _HWDEC_AVAILABLE
try:
result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
if result.returncode != 0:
_HWDEC_AVAILABLE = False
return False
result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
_HWDEC_AVAILABLE = "cuda" in result.stdout
except Exception:
_HWDEC_AVAILABLE = False
return _HWDEC_AVAILABLE
class VideoSource: class VideoSource:
"""Video source with persistent streaming pipe for fast sequential reads.""" """Video source with persistent streaming pipe for fast sequential reads."""
@@ -25,11 +65,20 @@ class VideoSource:
self._last_read_time = -1 self._last_read_time = -1
self._cached_frame = None self._cached_frame = None
# Check if file exists
if not self.path.exists():
raise FileNotFoundError(f"Video file not found: {self.path}")
# Get video info # Get video info
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
"-show_streams", str(self.path)] "-show_streams", str(self.path)]
result = subprocess.run(cmd, capture_output=True, text=True) result = subprocess.run(cmd, capture_output=True, text=True)
info = json.loads(result.stdout) if result.returncode != 0:
raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}")
try:
info = json.loads(result.stdout)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}")
for stream in info.get("streams", []): for stream in info.get("streams", []):
if stream.get("codec_type") == "video": if stream.get("codec_type") == "video":
@@ -46,30 +95,64 @@ class VideoSource:
self._duration = int(h) * 3600 + int(m) * 60 + float(s) self._duration = int(h) * 3600 + int(m) * 60 + float(s)
break break
# Fallback: check format duration if stream duration not found
if self._duration is None and "format" in info and "duration" in info["format"]:
self._duration = float(info["format"]["duration"])
if not self._frame_size: if not self._frame_size:
self._frame_size = (720, 720) self._frame_size = (720, 720)
import sys
print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)
def _start_stream(self, seek_time: float = 0): def _start_stream(self, seek_time: float = 0):
"""Start or restart the ffmpeg streaming process.""" """Start or restart the ffmpeg streaming process.
Uses NVIDIA hardware decoding (NVDEC) when available for better performance.
"""
if self._proc: if self._proc:
self._proc.kill() self._proc.kill()
self._proc = None self._proc = None
# Check file exists before trying to open
if not self.path.exists():
raise FileNotFoundError(f"Video file not found: {self.path}")
w, h = self._frame_size w, h = self._frame_size
cmd = [
"ffmpeg", "-v", "quiet", # Build ffmpeg command with optional hardware decode
cmd = ["ffmpeg", "-v", "error"]
# Use hardware decode if available (significantly faster)
if _check_hwdec():
cmd.extend(["-hwaccel", "cuda"])
cmd.extend([
"-ss", f"{seek_time:.3f}", "-ss", f"{seek_time:.3f}",
"-i", str(self.path), "-i", str(self.path),
"-f", "rawvideo", "-pix_fmt", "rgb24", "-f", "rawvideo", "-pix_fmt", "rgb24",
"-s", f"{w}x{h}", "-s", f"{w}x{h}",
"-r", str(self.fps), # Output at specified fps "-r", str(self.fps), # Output at specified fps
"-" "-"
] ])
self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self._stream_time = seek_time self._stream_time = seek_time
def _read_frame_from_stream(self) -> np.ndarray: # Check if process started successfully by reading first bit of stderr
"""Read one frame from the stream.""" import select
import sys
readable, _, _ = select.select([self._proc.stderr], [], [], 0.5)
if readable:
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
if err:
print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr)
def _read_frame_from_stream(self):
"""Read one frame from the stream.
Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise.
"""
w, h = self._frame_size w, h = self._frame_size
frame_size = w * h * 3 frame_size = w * h * 3
@@ -80,7 +163,12 @@ class VideoSource:
if len(data) < frame_size: if len(data) < frame_size:
return None return None
return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy() frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
# Transfer to GPU if persistence mode enabled
if GPU_PERSIST:
return cp.asarray(frame)
return frame
def read(self) -> np.ndarray: def read(self) -> np.ndarray:
"""Read frame (uses last cached or t=0).""" """Read frame (uses last cached or t=0)."""
@@ -100,6 +188,9 @@ class VideoSource:
seek_time = t seek_time = t
if self._duration and self._duration > 0: if self._duration and self._duration > 0:
seek_time = t % self._duration seek_time = t % self._duration
# If we're within 0.1s of the end, wrap to beginning to avoid EOF issues
if seek_time > self._duration - 0.1:
seek_time = 0.0
# Decide whether to seek or continue streaming # Decide whether to seek or continue streaming
# Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead # Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead
@@ -118,20 +209,59 @@ class VideoSource:
self._start_stream(seek_time) self._start_stream(seek_time)
# Skip frames to reach target time # Skip frames to reach target time
skip_retries = 0
while self._stream_time + self._frame_time <= seek_time: while self._stream_time + self._frame_time <= seek_time:
frame = self._read_frame_from_stream() frame = self._read_frame_from_stream()
if frame is None: if frame is None:
# Stream ended, restart from seek point # Stream ended or failed - restart from seek point
import time
skip_retries += 1
if skip_retries > 3:
# Give up skipping, just start fresh at seek_time
self._start_stream(seek_time)
time.sleep(0.1)
break
self._start_stream(seek_time) self._start_stream(seek_time)
break time.sleep(0.05)
continue
self._stream_time += self._frame_time self._stream_time += self._frame_time
skip_retries = 0 # Reset on successful read
# Read the target frame with retry logic
frame = None
max_retries = 3
for attempt in range(max_retries):
frame = self._read_frame_from_stream()
if frame is not None:
break
# Stream failed - try restarting
import sys
import time
print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr)
# Check for ffmpeg errors
if self._proc and self._proc.stderr:
try:
import select
readable, _, _ = select.select([self._proc.stderr], [], [], 0.1)
if readable:
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
if err:
print(f"ffmpeg error: {err}", file=sys.stderr)
except:
pass
# Wait a bit and restart
time.sleep(0.1)
self._start_stream(seek_time)
# Give ffmpeg time to start
time.sleep(0.1)
# Read the target frame
frame = self._read_frame_from_stream()
if frame is None: if frame is None:
import sys import sys
print(f"NULL FRAME {self.path.name}: t={t:.2f} seek={seek_time:.2f}", file=sys.stderr) raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")
frame = np.zeros((h, w, 3), dtype=np.uint8)
else: else:
self._stream_time += self._frame_time self._stream_time += self._frame_time
@@ -160,16 +290,27 @@ class AudioAnalyzer:
self.path = Path(path) self.path = Path(path)
self.sample_rate = sample_rate self.sample_rate = sample_rate
# Check if file exists
if not self.path.exists():
raise FileNotFoundError(f"Audio file not found: {self.path}")
# Load audio via ffmpeg # Load audio via ffmpeg
cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path), cmd = ["ffmpeg", "-v", "error", "-i", str(self.path),
"-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"] "-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
result = subprocess.run(cmd, capture_output=True) result = subprocess.run(cmd, capture_output=True)
if result.returncode != 0:
raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}")
self._audio = np.frombuffer(result.stdout, dtype=np.float32) self._audio = np.frombuffer(result.stdout, dtype=np.float32)
if len(self._audio) == 0:
raise RuntimeError(f"Audio file is empty or invalid: {self.path}")
# Get duration # Get duration
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
"-show_format", str(self.path)] "-show_format", str(self.path)]
info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout) result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}")
info = json.loads(result.stdout)
self.duration = float(info.get("format", {}).get("duration", 60)) self.duration = float(info.get("format", {}).get("duration", 60))
# Beat detection state # Beat detection state
@@ -302,3 +443,20 @@ def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int:
def prim_audio_duration(analyzer: AudioAnalyzer) -> float: def prim_audio_duration(analyzer: AudioAnalyzer) -> float:
"""Get audio duration in seconds.""" """Get audio duration in seconds."""
return analyzer.duration return analyzer.duration
# Export primitives
PRIMITIVES = {
# Video source
'make-video-source': prim_make_video_source,
'source-read': prim_source_read,
'source-skip': prim_source_skip,
'source-size': prim_source_size,
# Audio analyzer
'make-audio-analyzer': prim_make_audio_analyzer,
'audio-energy': prim_audio_energy,
'audio-beat': prim_audio_beat,
'audio-beat-count': prim_audio_beat_count,
'audio-duration': prim_audio_duration,
}