Add GPU persistence mode and hardware decode support to streaming primitives
- Add CuPy integration for GPU-resident frame output - Add NVDEC hardware decode detection and ffmpeg acceleration - Configurable via STREAMING_GPU_PERSIST environment variable Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,13 +3,53 @@ Streaming primitives for video/audio processing.
|
||||
|
||||
These primitives handle video source reading and audio analysis,
|
||||
keeping the interpreter completely generic.
|
||||
|
||||
GPU Acceleration:
|
||||
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
||||
- Hardware video decoding (NVDEC) is used when available
|
||||
- Dramatically improves performance on GPU nodes
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Try to import CuPy for GPU acceleration
|
||||
try:
|
||||
import cupy as cp
|
||||
CUPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
cp = None
|
||||
CUPY_AVAILABLE = False
|
||||
|
||||
# GPU persistence mode - output CuPy arrays instead of numpy
|
||||
# Disabled by default until all primitives support GPU frames
|
||||
GPU_PERSIST = os.environ.get("STREAMING_GPU_PERSIST", "0") == "1" and CUPY_AVAILABLE
|
||||
|
||||
# Check for hardware decode support (cached)
|
||||
_HWDEC_AVAILABLE = None
|
||||
|
||||
|
||||
def _check_hwdec():
|
||||
"""Check if NVIDIA hardware decode is available."""
|
||||
global _HWDEC_AVAILABLE
|
||||
if _HWDEC_AVAILABLE is not None:
|
||||
return _HWDEC_AVAILABLE
|
||||
|
||||
try:
|
||||
result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
|
||||
if result.returncode != 0:
|
||||
_HWDEC_AVAILABLE = False
|
||||
return False
|
||||
result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
|
||||
_HWDEC_AVAILABLE = "cuda" in result.stdout
|
||||
except Exception:
|
||||
_HWDEC_AVAILABLE = False
|
||||
|
||||
return _HWDEC_AVAILABLE
|
||||
|
||||
|
||||
class VideoSource:
|
||||
"""Video source with persistent streaming pipe for fast sequential reads."""
|
||||
@@ -25,11 +65,20 @@ class VideoSource:
|
||||
self._last_read_time = -1
|
||||
self._cached_frame = None
|
||||
|
||||
# Check if file exists
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"Video file not found: {self.path}")
|
||||
|
||||
# Get video info
|
||||
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
||||
"-show_streams", str(self.path)]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}")
|
||||
try:
|
||||
info = json.loads(result.stdout)
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}")
|
||||
|
||||
for stream in info.get("streams", []):
|
||||
if stream.get("codec_type") == "video":
|
||||
@@ -46,30 +95,64 @@ class VideoSource:
|
||||
self._duration = int(h) * 3600 + int(m) * 60 + float(s)
|
||||
break
|
||||
|
||||
# Fallback: check format duration if stream duration not found
|
||||
if self._duration is None and "format" in info and "duration" in info["format"]:
|
||||
self._duration = float(info["format"]["duration"])
|
||||
|
||||
if not self._frame_size:
|
||||
self._frame_size = (720, 720)
|
||||
|
||||
import sys
|
||||
print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)
|
||||
|
||||
def _start_stream(self, seek_time: float = 0):
|
||||
"""Start or restart the ffmpeg streaming process."""
|
||||
"""Start or restart the ffmpeg streaming process.
|
||||
|
||||
Uses NVIDIA hardware decoding (NVDEC) when available for better performance.
|
||||
"""
|
||||
if self._proc:
|
||||
self._proc.kill()
|
||||
self._proc = None
|
||||
|
||||
# Check file exists before trying to open
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"Video file not found: {self.path}")
|
||||
|
||||
w, h = self._frame_size
|
||||
cmd = [
|
||||
"ffmpeg", "-v", "quiet",
|
||||
|
||||
# Build ffmpeg command with optional hardware decode
|
||||
cmd = ["ffmpeg", "-v", "error"]
|
||||
|
||||
# Use hardware decode if available (significantly faster)
|
||||
if _check_hwdec():
|
||||
cmd.extend(["-hwaccel", "cuda"])
|
||||
|
||||
cmd.extend([
|
||||
"-ss", f"{seek_time:.3f}",
|
||||
"-i", str(self.path),
|
||||
"-f", "rawvideo", "-pix_fmt", "rgb24",
|
||||
"-s", f"{w}x{h}",
|
||||
"-r", str(self.fps), # Output at specified fps
|
||||
"-"
|
||||
]
|
||||
self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
||||
])
|
||||
|
||||
self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
self._stream_time = seek_time
|
||||
|
||||
def _read_frame_from_stream(self) -> np.ndarray:
|
||||
"""Read one frame from the stream."""
|
||||
# Check if process started successfully by reading first bit of stderr
|
||||
import select
|
||||
import sys
|
||||
readable, _, _ = select.select([self._proc.stderr], [], [], 0.5)
|
||||
if readable:
|
||||
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
|
||||
if err:
|
||||
print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr)
|
||||
|
||||
def _read_frame_from_stream(self):
|
||||
"""Read one frame from the stream.
|
||||
|
||||
Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise.
|
||||
"""
|
||||
w, h = self._frame_size
|
||||
frame_size = w * h * 3
|
||||
|
||||
@@ -80,7 +163,12 @@ class VideoSource:
|
||||
if len(data) < frame_size:
|
||||
return None
|
||||
|
||||
return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
||||
frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
||||
|
||||
# Transfer to GPU if persistence mode enabled
|
||||
if GPU_PERSIST:
|
||||
return cp.asarray(frame)
|
||||
return frame
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Read frame (uses last cached or t=0)."""
|
||||
@@ -100,6 +188,9 @@ class VideoSource:
|
||||
seek_time = t
|
||||
if self._duration and self._duration > 0:
|
||||
seek_time = t % self._duration
|
||||
# If we're within 0.1s of the end, wrap to beginning to avoid EOF issues
|
||||
if seek_time > self._duration - 0.1:
|
||||
seek_time = 0.0
|
||||
|
||||
# Decide whether to seek or continue streaming
|
||||
# Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead
|
||||
@@ -118,20 +209,59 @@ class VideoSource:
|
||||
self._start_stream(seek_time)
|
||||
|
||||
# Skip frames to reach target time
|
||||
skip_retries = 0
|
||||
while self._stream_time + self._frame_time <= seek_time:
|
||||
frame = self._read_frame_from_stream()
|
||||
if frame is None:
|
||||
# Stream ended, restart from seek point
|
||||
# Stream ended or failed - restart from seek point
|
||||
import time
|
||||
skip_retries += 1
|
||||
if skip_retries > 3:
|
||||
# Give up skipping, just start fresh at seek_time
|
||||
self._start_stream(seek_time)
|
||||
time.sleep(0.1)
|
||||
break
|
||||
self._start_stream(seek_time)
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
self._stream_time += self._frame_time
|
||||
skip_retries = 0 # Reset on successful read
|
||||
|
||||
# Read the target frame
|
||||
# Read the target frame with retry logic
|
||||
frame = None
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
frame = self._read_frame_from_stream()
|
||||
if frame is not None:
|
||||
break
|
||||
|
||||
# Stream failed - try restarting
|
||||
import sys
|
||||
import time
|
||||
print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr)
|
||||
|
||||
# Check for ffmpeg errors
|
||||
if self._proc and self._proc.stderr:
|
||||
try:
|
||||
import select
|
||||
readable, _, _ = select.select([self._proc.stderr], [], [], 0.1)
|
||||
if readable:
|
||||
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
|
||||
if err:
|
||||
print(f"ffmpeg error: {err}", file=sys.stderr)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Wait a bit and restart
|
||||
time.sleep(0.1)
|
||||
self._start_stream(seek_time)
|
||||
|
||||
# Give ffmpeg time to start
|
||||
time.sleep(0.1)
|
||||
|
||||
if frame is None:
|
||||
import sys
|
||||
print(f"NULL FRAME {self.path.name}: t={t:.2f} seek={seek_time:.2f}", file=sys.stderr)
|
||||
frame = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")
|
||||
else:
|
||||
self._stream_time += self._frame_time
|
||||
|
||||
@@ -160,16 +290,27 @@ class AudioAnalyzer:
|
||||
self.path = Path(path)
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
# Check if file exists
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {self.path}")
|
||||
|
||||
# Load audio via ffmpeg
|
||||
cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path),
|
||||
cmd = ["ffmpeg", "-v", "error", "-i", str(self.path),
|
||||
"-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
|
||||
result = subprocess.run(cmd, capture_output=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}")
|
||||
self._audio = np.frombuffer(result.stdout, dtype=np.float32)
|
||||
if len(self._audio) == 0:
|
||||
raise RuntimeError(f"Audio file is empty or invalid: {self.path}")
|
||||
|
||||
# Get duration
|
||||
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
||||
"-show_format", str(self.path)]
|
||||
info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout)
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}")
|
||||
info = json.loads(result.stdout)
|
||||
self.duration = float(info.get("format", {}).get("duration", 60))
|
||||
|
||||
# Beat detection state
|
||||
@@ -302,3 +443,20 @@ def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int:
|
||||
def prim_audio_duration(analyzer: AudioAnalyzer) -> float:
|
||||
"""Get audio duration in seconds."""
|
||||
return analyzer.duration
|
||||
|
||||
|
||||
# Export primitives
|
||||
PRIMITIVES = {
|
||||
# Video source
|
||||
'make-video-source': prim_make_video_source,
|
||||
'source-read': prim_source_read,
|
||||
'source-skip': prim_source_skip,
|
||||
'source-size': prim_source_size,
|
||||
|
||||
# Audio analyzer
|
||||
'make-audio-analyzer': prim_make_audio_analyzer,
|
||||
'audio-energy': prim_audio_energy,
|
||||
'audio-beat': prim_audio_beat,
|
||||
'audio-beat-count': prim_audio_beat_count,
|
||||
'audio-duration': prim_audio_duration,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user