Add GPU persistence mode and hardware decode support to streaming primitives
- Add CuPy integration for GPU-resident frame output
- Add NVDEC hardware decode detection and ffmpeg acceleration
- Configurable via STREAMING_GPU_PERSIST environment variable

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,13 +3,53 @@ Streaming primitives for video/audio processing.
|
|||||||
|
|
||||||
These primitives handle video source reading and audio analysis,
|
These primitives handle video source reading and audio analysis,
|
||||||
keeping the interpreter completely generic.
|
keeping the interpreter completely generic.
|
||||||
|
|
||||||
|
GPU Acceleration:
|
||||||
|
- Set STREAMING_GPU_PERSIST=1 to output CuPy arrays (frames stay on GPU)
|
||||||
|
- Hardware video decoding (NVDEC) is used when available
|
||||||
|
- Dramatically improves performance on GPU nodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import subprocess
|
import subprocess
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# CuPy is an optional dependency: when it cannot be imported, `cp` is None
# and the module keeps working with NumPy arrays only.
try:
    import cupy as cp
except ImportError:
    cp = None

CUPY_AVAILABLE = cp is not None

# GPU persistence mode - frames are returned as CuPy arrays instead of numpy.
# Opt-in only (STREAMING_GPU_PERSIST=1) until all primitives support GPU
# frames, and never enabled when CuPy is missing.
GPU_PERSIST = CUPY_AVAILABLE and os.environ.get("STREAMING_GPU_PERSIST", "0") == "1"
|
||||||
|
|
||||||
|
# Cached result of the hardware decode probe; None means "not probed yet".
_HWDEC_AVAILABLE = None


def _check_hwdec():
    """Report whether NVIDIA hardware decoding can be used via ffmpeg.

    The probe runs at most once per process: it requires `nvidia-smi` to
    exit successfully and the output of `ffmpeg -hwaccels` to mention
    "cuda". Any failure while probing (missing binary, timeout, ...) is
    treated as "not available". The outcome is stored in the module-level
    `_HWDEC_AVAILABLE` and returned directly on later calls.

    Returns:
        True if hardware decode was detected, False otherwise.
    """
    global _HWDEC_AVAILABLE

    if _HWDEC_AVAILABLE is None:
        detected = False
        try:
            smi = subprocess.run(["nvidia-smi"], capture_output=True, timeout=2)
            if smi.returncode == 0:
                listing = subprocess.run(
                    ["ffmpeg", "-hwaccels"],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
                detected = "cuda" in listing.stdout
        except Exception:
            # Best-effort probe: any error simply means no hardware decode.
            detected = False
        _HWDEC_AVAILABLE = detected

    return _HWDEC_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
class VideoSource:
|
class VideoSource:
|
||||||
"""Video source with persistent streaming pipe for fast sequential reads."""
|
"""Video source with persistent streaming pipe for fast sequential reads."""
|
||||||
@@ -25,11 +65,20 @@ class VideoSource:
|
|||||||
self._last_read_time = -1
|
self._last_read_time = -1
|
||||||
self._cached_frame = None
|
self._cached_frame = None
|
||||||
|
|
||||||
|
# Check if file exists
|
||||||
|
if not self.path.exists():
|
||||||
|
raise FileNotFoundError(f"Video file not found: {self.path}")
|
||||||
|
|
||||||
# Get video info
|
# Get video info
|
||||||
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
||||||
"-show_streams", str(self.path)]
|
"-show_streams", str(self.path)]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Failed to probe video '{self.path}': {result.stderr}")
|
||||||
|
try:
|
||||||
info = json.loads(result.stdout)
|
info = json.loads(result.stdout)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise RuntimeError(f"Invalid video file or ffprobe failed: {self.path}")
|
||||||
|
|
||||||
for stream in info.get("streams", []):
|
for stream in info.get("streams", []):
|
||||||
if stream.get("codec_type") == "video":
|
if stream.get("codec_type") == "video":
|
||||||
@@ -46,30 +95,64 @@ class VideoSource:
|
|||||||
self._duration = int(h) * 3600 + int(m) * 60 + float(s)
|
self._duration = int(h) * 3600 + int(m) * 60 + float(s)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Fallback: check format duration if stream duration not found
|
||||||
|
if self._duration is None and "format" in info and "duration" in info["format"]:
|
||||||
|
self._duration = float(info["format"]["duration"])
|
||||||
|
|
||||||
if not self._frame_size:
|
if not self._frame_size:
|
||||||
self._frame_size = (720, 720)
|
self._frame_size = (720, 720)
|
||||||
|
|
||||||
|
import sys
|
||||||
|
print(f"VideoSource: {self.path.name} duration={self._duration} size={self._frame_size}", file=sys.stderr)
|
||||||
|
|
||||||
def _start_stream(self, seek_time: float = 0):
    """Start or restart the ffmpeg streaming process.

    Uses NVIDIA hardware decoding (NVDEC) when available for better performance.

    Any previously running ffmpeg process is killed, its pipes are closed
    and the child is reaped before the new one is spawned.

    Args:
        seek_time: Position in seconds passed to ffmpeg's ``-ss``; it also
            becomes the new ``self._stream_time``.

    Raises:
        FileNotFoundError: If ``self.path`` does not exist.
    """
    import select
    import sys

    old_proc = self._proc
    if old_proc:
        self._proc = None
        old_proc.kill()
        # Close our pipe ends and reap the child; without this every
        # restart leaks two file descriptors and leaves a zombie process.
        for pipe in (old_proc.stdout, old_proc.stderr):
            if pipe is not None:
                pipe.close()
        old_proc.wait()

    # Check file exists before trying to open
    if not self.path.exists():
        raise FileNotFoundError(f"Video file not found: {self.path}")

    w, h = self._frame_size

    # Build ffmpeg command with optional hardware decode
    cmd = ["ffmpeg", "-v", "error"]

    # Use hardware decode if available (significantly faster)
    if _check_hwdec():
        cmd.extend(["-hwaccel", "cuda"])

    cmd.extend([
        "-ss", f"{seek_time:.3f}",
        "-i", str(self.path),
        "-f", "rawvideo", "-pix_fmt", "rgb24",
        "-s", f"{w}x{h}",
        "-r", str(self.fps),  # Output at specified fps
        "-",
    ])

    self._proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    self._stream_time = seek_time

    # Check if the process started successfully. Waiting on stdout as well
    # as stderr lets us return as soon as the first frame data is available
    # instead of always sleeping for the full timeout on a healthy start.
    stdout, stderr = self._proc.stdout, self._proc.stderr
    ready, _, _ = select.select([stdout, stderr], [], [], 0.5)
    if stderr in ready:
        # read1() returns whatever is currently available; a plain
        # read(4096) on the buffered pipe would block until 4096 bytes or
        # EOF, hanging on a short warning while ffmpeg keeps running.
        err = stderr.read1(4096).decode('utf-8', errors='ignore')
        if err:
            print(f"ffmpeg error for {self.path.name}: {err}", file=sys.stderr)
|
||||||
|
|
||||||
|
def _read_frame_from_stream(self):
|
||||||
|
"""Read one frame from the stream.
|
||||||
|
|
||||||
|
Returns CuPy array if GPU_PERSIST is enabled, numpy array otherwise.
|
||||||
|
"""
|
||||||
w, h = self._frame_size
|
w, h = self._frame_size
|
||||||
frame_size = w * h * 3
|
frame_size = w * h * 3
|
||||||
|
|
||||||
@@ -80,7 +163,12 @@ class VideoSource:
|
|||||||
if len(data) < frame_size:
|
if len(data) < frame_size:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
frame = np.frombuffer(data, dtype=np.uint8).reshape((h, w, 3)).copy()
|
||||||
|
|
||||||
|
# Transfer to GPU if persistence mode enabled
|
||||||
|
if GPU_PERSIST:
|
||||||
|
return cp.asarray(frame)
|
||||||
|
return frame
|
||||||
|
|
||||||
def read(self) -> np.ndarray:
|
def read(self) -> np.ndarray:
|
||||||
"""Read frame (uses last cached or t=0)."""
|
"""Read frame (uses last cached or t=0)."""
|
||||||
@@ -100,6 +188,9 @@ class VideoSource:
|
|||||||
seek_time = t
|
seek_time = t
|
||||||
if self._duration and self._duration > 0:
|
if self._duration and self._duration > 0:
|
||||||
seek_time = t % self._duration
|
seek_time = t % self._duration
|
||||||
|
# If we're within 0.1s of the end, wrap to beginning to avoid EOF issues
|
||||||
|
if seek_time > self._duration - 0.1:
|
||||||
|
seek_time = 0.0
|
||||||
|
|
||||||
# Decide whether to seek or continue streaming
|
# Decide whether to seek or continue streaming
|
||||||
# Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead
|
# Seek if: no stream, going backwards (more than 1 frame), or jumping more than 2 seconds ahead
|
||||||
@@ -118,20 +209,59 @@ class VideoSource:
|
|||||||
self._start_stream(seek_time)
|
self._start_stream(seek_time)
|
||||||
|
|
||||||
# Skip frames to reach target time
|
# Skip frames to reach target time
|
||||||
|
skip_retries = 0
|
||||||
while self._stream_time + self._frame_time <= seek_time:
|
while self._stream_time + self._frame_time <= seek_time:
|
||||||
frame = self._read_frame_from_stream()
|
frame = self._read_frame_from_stream()
|
||||||
if frame is None:
|
if frame is None:
|
||||||
# Stream ended, restart from seek point
|
# Stream ended or failed - restart from seek point
|
||||||
|
import time
|
||||||
|
skip_retries += 1
|
||||||
|
if skip_retries > 3:
|
||||||
|
# Give up skipping, just start fresh at seek_time
|
||||||
self._start_stream(seek_time)
|
self._start_stream(seek_time)
|
||||||
|
time.sleep(0.1)
|
||||||
break
|
break
|
||||||
|
self._start_stream(seek_time)
|
||||||
|
time.sleep(0.05)
|
||||||
|
continue
|
||||||
self._stream_time += self._frame_time
|
self._stream_time += self._frame_time
|
||||||
|
skip_retries = 0 # Reset on successful read
|
||||||
|
|
||||||
# Read the target frame
|
# Read the target frame with retry logic
|
||||||
|
frame = None
|
||||||
|
max_retries = 3
|
||||||
|
for attempt in range(max_retries):
|
||||||
frame = self._read_frame_from_stream()
|
frame = self._read_frame_from_stream()
|
||||||
|
if frame is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Stream failed - try restarting
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
print(f"RETRY {self.path.name}: attempt {attempt+1}/{max_retries} at t={t:.2f}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Check for ffmpeg errors
|
||||||
|
if self._proc and self._proc.stderr:
|
||||||
|
try:
|
||||||
|
import select
|
||||||
|
readable, _, _ = select.select([self._proc.stderr], [], [], 0.1)
|
||||||
|
if readable:
|
||||||
|
err = self._proc.stderr.read(4096).decode('utf-8', errors='ignore')
|
||||||
|
if err:
|
||||||
|
print(f"ffmpeg error: {err}", file=sys.stderr)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait a bit and restart
|
||||||
|
time.sleep(0.1)
|
||||||
|
self._start_stream(seek_time)
|
||||||
|
|
||||||
|
# Give ffmpeg time to start
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
import sys
|
import sys
|
||||||
print(f"NULL FRAME {self.path.name}: t={t:.2f} seek={seek_time:.2f}", file=sys.stderr)
|
raise RuntimeError(f"Failed to read video frame from {self.path.name} at t={t:.2f} after {max_retries} retries")
|
||||||
frame = np.zeros((h, w, 3), dtype=np.uint8)
|
|
||||||
else:
|
else:
|
||||||
self._stream_time += self._frame_time
|
self._stream_time += self._frame_time
|
||||||
|
|
||||||
@@ -160,16 +290,27 @@ class AudioAnalyzer:
|
|||||||
self.path = Path(path)
|
self.path = Path(path)
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
|
|
||||||
|
# Check if file exists
|
||||||
|
if not self.path.exists():
|
||||||
|
raise FileNotFoundError(f"Audio file not found: {self.path}")
|
||||||
|
|
||||||
# Load audio via ffmpeg
|
# Load audio via ffmpeg
|
||||||
cmd = ["ffmpeg", "-v", "quiet", "-i", str(self.path),
|
cmd = ["ffmpeg", "-v", "error", "-i", str(self.path),
|
||||||
"-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
|
"-f", "f32le", "-ac", "1", "-ar", str(sample_rate), "-"]
|
||||||
result = subprocess.run(cmd, capture_output=True)
|
result = subprocess.run(cmd, capture_output=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Failed to load audio '{self.path}': {result.stderr.decode()}")
|
||||||
self._audio = np.frombuffer(result.stdout, dtype=np.float32)
|
self._audio = np.frombuffer(result.stdout, dtype=np.float32)
|
||||||
|
if len(self._audio) == 0:
|
||||||
|
raise RuntimeError(f"Audio file is empty or invalid: {self.path}")
|
||||||
|
|
||||||
# Get duration
|
# Get duration
|
||||||
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json",
|
||||||
"-show_format", str(self.path)]
|
"-show_format", str(self.path)]
|
||||||
info = json.loads(subprocess.run(cmd, capture_output=True, text=True).stdout)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Failed to probe audio '{self.path}': {result.stderr}")
|
||||||
|
info = json.loads(result.stdout)
|
||||||
self.duration = float(info.get("format", {}).get("duration", 60))
|
self.duration = float(info.get("format", {}).get("duration", 60))
|
||||||
|
|
||||||
# Beat detection state
|
# Beat detection state
|
||||||
@@ -302,3 +443,20 @@ def prim_audio_beat_count(analyzer: AudioAnalyzer, t: float) -> int:
|
|||||||
def prim_audio_duration(analyzer: AudioAnalyzer) -> float:
    """Return the length of the analyzed audio in seconds.

    The value is read from the analyzer's ``duration`` attribute.
    """
    duration = analyzer.duration
    return duration
|
||||||
|
|
||||||
|
|
||||||
|
# Primitive export table: maps each primitive's name to the function
# implementing it.
PRIMITIVES = {
    # --- video source ---
    "make-video-source": prim_make_video_source,
    "source-read": prim_source_read,
    "source-skip": prim_source_skip,
    "source-size": prim_source_size,
    # --- audio analyzer ---
    "make-audio-analyzer": prim_make_audio_analyzer,
    "audio-energy": prim_audio_energy,
    "audio-beat": prim_audio_beat,
    "audio-beat-count": prim_audio_beat_count,
    "audio-duration": prim_audio_duration,
}
|
||||||
|
|||||||
Reference in New Issue
Block a user