# Source: UnHided/mediaflow_proxy/utils/dash_prebuffer.py
# Last commit: cfc6bbabc9 "update" by UrloMythus, 2026-02-19 20:15:03 +01:00
# (403 lines, 15 KiB, Python)

"""
DASH Pre-buffer system for reducing latency and improving streaming performance.
This module extends BasePrebuffer with DASH-specific functionality including
MPD parsing integration, profile handling, and init segment management.
"""
import asyncio
import logging
import time
from typing import Dict, Optional, List
from mediaflow_proxy.utils.base_prebuffer import BasePrebuffer
from mediaflow_proxy.utils.cache_utils import (
get_cached_mpd,
get_cached_init_segment,
)
from mediaflow_proxy.configs import settings
# Module logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class DASHPreBuffer(BasePrebuffer):
    """
    Pre-buffer system for DASH streams.

    Extends BasePrebuffer with DASH-specific features:
    - MPD manifest parsing and per-profile segment resolution
    - Init segment prebuffering
    - Live stream tracking with inactivity cleanup
    - Prefetching of the segments that follow the one currently requested

    Media segments are downloaded through BasePrebuffer.prebuffer_segments_batch,
    which (per the original design notes) uses event-based download coordination
    to prevent duplicate downloads between player requests and background
    prebuffering.
    """

    # Seconds between sweeps of ``active_streams`` for idle entries.
    _CLEANUP_INTERVAL_SECONDS = 30

    def __init__(
        self,
        max_cache_size: Optional[int] = None,
        prebuffer_segments: Optional[int] = None,
    ):
        """
        Initialize the DASH pre-buffer system.

        Args:
            max_cache_size: Maximum number of segments to cache. Falls back to
                ``settings.dash_prebuffer_cache_size`` when None or any other
                falsy value (including 0).
            prebuffer_segments: Number of segments to pre-buffer ahead. Falls back
                to ``settings.dash_prebuffer_segments`` when None or any other
                falsy value (including 0).
        """
        super().__init__(
            max_cache_size=max_cache_size or settings.dash_prebuffer_cache_size,
            prebuffer_segments=prebuffer_segments or settings.dash_prebuffer_segments,
            max_memory_percent=settings.dash_prebuffer_max_memory_percent,
            emergency_threshold=settings.dash_prebuffer_emergency_threshold,
            segment_ttl=settings.dash_segment_cache_ttl,
        )
        self.inactivity_timeout = settings.dash_prebuffer_inactivity_timeout

        # Streams eligible for prefetching:
        # mpd_url -> {"headers", "is_live", "profiles", "last_access"}.
        self.active_streams: Dict[str, dict] = {}
        # mpd_url -> task; cancelled when the stream goes idle or on clear_cache().
        self.prefetch_tasks: Dict[str, asyncio.Task] = {}
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so without this they could be collected mid-run.
        self._background_tasks: set = set()
        # DASH-specific counter reported alongside the base stats.
        self.init_segments_prebuffered = 0
        # Background task that evicts idle entries from active_streams.
        self._cleanup_task: Optional[asyncio.Task] = None

    # ------------------------------------------------------------------ helpers

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a fire-and-forget task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        # Drop our reference once the task finishes so the set does not grow.
        task.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _newest_segments(segments: List[dict], count: int) -> List[dict]:
        """Return the last ``count`` segments, or an empty list when ``count`` <= 0.

        A bare ``segments[-count:]`` would return the WHOLE list for ``count == 0``
        because ``-0 == 0``.
        """
        if count <= 0:
            return []
        return segments[-count:]

    async def _fetch_profile_with_segments(
        self,
        mpd_url: str,
        headers: Dict[str, str],
        profile_id: str,
    ) -> Optional[dict]:
        """Parse the MPD for one profile and return that profile dict (with its segments).

        Returns None when the MPD could not be fetched/parsed (falsy result from
        ``get_cached_mpd``) or when the requested profile is not in the result.
        """
        parsed = await get_cached_mpd(
            mpd_url, headers, parse_drm=False, parse_segment_profile_id=profile_id
        )
        if not parsed:
            return None
        return next(
            (p for p in parsed.get("profiles", []) if p.get("id") == profile_id),
            None,
        )

    # -------------------------------------------------------------------- stats

    def get_stats(self) -> dict:
        """Get current prebuffer statistics, including DASH-specific counters."""
        stats = self.stats.to_dict()
        stats["init_segments_prebuffered"] = self.init_segments_prebuffered
        stats["active_streams"] = len(self.active_streams)
        return stats

    def log_stats(self) -> None:
        """Log current prebuffer statistics with DASH-specific info."""
        logger.info(f"DASH Prebuffer Stats: {self.get_stats()}")

    # --------------------------------------------------------------- prebuffer

    async def prebuffer_dash_manifest(
        self,
        mpd_url: str,
        headers: Dict[str, str],
    ) -> None:
        """
        Pre-buffer segments from a DASH manifest using existing MPD parsing.

        Registers the stream in ``active_streams`` so later prefetches can reuse
        the resolved profiles, prebuffers init and media segments, and makes sure
        the inactivity-cleanup task is running. Errors (``Exception`` subclasses)
        are logged rather than raised.

        Args:
            mpd_url: URL of the DASH manifest
            headers: Headers to use for requests
        """
        try:
            # First pass: basic MPD info (profile list) without per-profile segments.
            parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
            if not parsed_mpd:
                logger.warning(f"Failed to get parsed MPD for prebuffering: {mpd_url}")
                return

            is_live = parsed_mpd.get("isLive", False)
            base_profiles = parsed_mpd.get("profiles", [])
            if not base_profiles:
                logger.warning(f"No profiles found in MPD for prebuffering: {mpd_url}")
                return

            # Second pass: resolve each profile's segment list individually.
            profiles_with_segments = []
            for profile in base_profiles:
                profile_id = profile.get("id")
                if not profile_id:
                    continue
                resolved = await self._fetch_profile_with_segments(mpd_url, headers, profile_id)
                if resolved is None:
                    # One unresolvable profile must not abort prebuffering of the rest.
                    logger.debug(f"Could not resolve segments for profile {profile_id} of {mpd_url}")
                    continue
                profiles_with_segments.append(resolved)

            # Store stream info for ongoing prefetching.
            self.active_streams[mpd_url] = {
                "headers": headers,
                "is_live": is_live,
                "profiles": profiles_with_segments,
                "last_access": time.time(),
            }

            await self._prebuffer_profiles(profiles_with_segments, headers, is_live)
            self._ensure_cleanup_task_running()

            logger.info(
                f"Pre-buffered DASH manifest: {mpd_url} (live={is_live}, profiles={len(profiles_with_segments)})"
            )
        except Exception as e:
            logger.warning(f"Failed to pre-buffer DASH manifest {mpd_url}: {e}")

    async def _prebuffer_profiles(
        self,
        profiles: List[dict],
        headers: Dict[str, str],
        is_live: bool = False,
    ) -> None:
        """
        Pre-buffer init segments and media segments for all profiles.

        For live streams, prebuffers from the END of the segment list (most
        recent segments). For VOD, prebuffers from the beginning. Init segments
        are scheduled as background tasks; media segments are awaited.

        Args:
            profiles: List of parsed profiles with resolved URLs
            headers: Headers to use for requests
            is_live: Whether this is a live stream
        """
        if self._should_skip_for_memory():
            logger.warning("Memory usage too high, skipping prebuffer")
            return

        count = self.prebuffer_segment_count
        init_urls: List[str] = []
        segment_urls: List[str] = []
        for profile in profiles:
            init_url = profile.get("initUrl")
            if init_url:
                init_urls.append(init_url)

            segments = profile.get("segments") or []
            if not segments:
                continue

            if is_live:
                selected = self._newest_segments(segments, count)
            else:
                # Clamp so a negative count cannot turn into "all but the last N".
                selected = segments[: max(count, 0)]

            segment_urls.extend(url for url in (seg.get("media") for seg in selected) if url)

        # Init segments go through the dedicated init-segment cache.
        for init_url in init_urls:
            self._spawn_background(self._prebuffer_init_segment(init_url, headers))

        # Media segments use the base-class batch download.
        if segment_urls:
            await self.prebuffer_segments_batch(segment_urls, headers)

    async def _prebuffer_init_segment(
        self,
        init_url: str,
        headers: Dict[str, str],
    ) -> None:
        """
        Prebuffer an init segment using the init segment cache.

        Args:
            init_url: URL of the init segment
            headers: Headers for the request
        """
        try:
            # get_cached_init_segment handles both caching and downloading.
            content = await get_cached_init_segment(init_url, headers)
            if content:
                self.init_segments_prebuffered += 1
                self.stats.bytes_prebuffered += len(content)
                logger.debug(f"Prebuffered init segment ({len(content)} bytes)")
        except Exception as e:
            logger.warning(f"Failed to prebuffer init segment: {e}")

    # ---------------------------------------------------------------- prefetch

    async def prefetch_upcoming_segments(
        self,
        mpd_url: str,
        current_segment_url: str,
        headers: Dict[str, str],
        profile_id: Optional[str] = None,
    ) -> None:
        """
        Prefetch upcoming segments based on current playback position.

        Called when a segment is requested: locates that segment in each
        (matching) profile and schedules a background download of the next
        ``prebuffer_segment_count`` segments.

        Args:
            mpd_url: URL of the MPD manifest
            current_segment_url: URL of the currently requested segment
            headers: Headers to use for requests
            profile_id: Optional profile ID to limit prefetching to
        """
        self.stats.prefetch_triggered += 1
        try:
            stream = self.active_streams.get(mpd_url)
            if stream is not None:
                # Known stream: refresh its idle timer and reuse resolved profiles.
                stream["last_access"] = time.time()
                profiles = stream.get("profiles", [])
            else:
                parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
                if not parsed_mpd:
                    return
                profiles = parsed_mpd.get("profiles", [])

            for profile in profiles:
                pid = profile.get("id")
                if profile_id and pid != profile_id:
                    continue

                segments = profile.get("segments") or []
                # If no segments are attached, resolve them for this profile.
                if not segments and pid:
                    resolved = await self._fetch_profile_with_segments(mpd_url, headers, pid)
                    if resolved is not None:
                        segments = resolved.get("segments") or []

                current_index = next(
                    (i for i, seg in enumerate(segments) if seg.get("media") == current_segment_url),
                    -1,
                )
                if current_index < 0:
                    continue

                start = current_index + 1
                upcoming = segments[start : start + max(self.prebuffer_segment_count, 0)]
                segment_urls = [url for url in (seg.get("media") for seg in upcoming) if url]

                if segment_urls:
                    logger.debug(f"Prefetching {len(segment_urls)} upcoming segments from index {current_index + 1}")
                    # Run prefetch in background; do not block the current request.
                    self._spawn_background(
                        self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3)
                    )
        except Exception as e:
            logger.warning(f"Failed to prefetch upcoming segments: {e}")

    async def prefetch_for_live_playlist(
        self,
        profiles: List[dict],
        headers: Dict[str, str],
    ) -> None:
        """
        Prefetch segments for a live playlist refresh.

        Intended to be called on each live playlist refresh (e.g. from
        process_playlist). Takes the newest ``prebuffer_segment_count`` segments
        of every profile, skips those already cached, and schedules the rest as a
        background batch download.

        Args:
            profiles: List of profiles with resolved segment URLs
            headers: Headers to use for requests
        """
        segment_urls: List[str] = []
        for profile in profiles:
            newest = self._newest_segments(profile.get("segments") or [], self.prebuffer_segment_count)
            for segment in newest:
                segment_url = segment.get("media")
                # Skip segments that are already in the cache.
                if segment_url and not await self.try_get_cached(segment_url):
                    segment_urls.append(segment_url)

        if segment_urls:
            logger.debug(f"Live playlist prefetch: {len(segment_urls)} segments")
            self._spawn_background(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))

    # ----------------------------------------------------------------- cleanup

    def _ensure_cleanup_task_running(self) -> None:
        """Ensure the cleanup task is running (start it if absent or finished)."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_inactive_streams())

    async def _cleanup_inactive_streams(self) -> None:
        """
        Periodically remove streams that have not been accessed recently.

        Every ``_CLEANUP_INTERVAL_SECONDS`` seconds, drops entries of
        ``active_streams`` idle for longer than ``inactivity_timeout`` and cancels
        their prefetch task, if any. Exits once no streams remain;
        ``_ensure_cleanup_task_running`` restarts it when needed. Cancellation is
        logged and propagated.
        """
        while True:
            try:
                await asyncio.sleep(self._CLEANUP_INTERVAL_SECONDS)

                if not self.active_streams:
                    logger.debug("No active DASH streams to monitor, stopping cleanup")
                    return

                now = time.time()
                stale_urls = []
                for mpd_url, stream_info in self.active_streams.items():
                    idle_for = now - stream_info.get("last_access", 0)
                    if idle_for > self.inactivity_timeout:
                        stale_urls.append(mpd_url)
                        logger.info(f"Cleaning up inactive DASH stream ({idle_for:.0f}s idle)")

                # Mutate only after iteration to avoid changing the dict mid-loop.
                for mpd_url in stale_urls:
                    self.active_streams.pop(mpd_url, None)
                    task = self.prefetch_tasks.pop(mpd_url, None)
                    if task:
                        task.cancel()

                if stale_urls:
                    logger.info(f"Cleaned up {len(stale_urls)} inactive DASH stream(s)")
            except asyncio.CancelledError:
                logger.debug("DASH cleanup task cancelled")
                # Propagate so the task is marked cancelled, as asyncio expects.
                raise
            except Exception as e:
                # Keep the sweeper alive on unexpected errors.
                logger.warning(f"Error in DASH cleanup task: {e}")

    def clear_cache(self) -> None:
        """Log final stats, then clear stream tracking, cancel owned tasks, and reset counters."""
        self.log_stats()
        self.active_streams.clear()
        for task in self.prefetch_tasks.values():
            task.cancel()
        self.prefetch_tasks.clear()

        if self._cleanup_task and not self._cleanup_task.done():
            self._cleanup_task.cancel()
        self._cleanup_task = None

        self.stats.reset()
        self.init_segments_prebuffered = 0
        logger.info("DASH pre-buffer state cleared")

    async def close(self) -> None:
        """Close the pre-buffer system."""
        self.clear_cache()
# Process-wide DASH pre-buffer, constructed with no arguments so every
# tunable comes from the configuration defaults.
dash_prebuffer: DASHPreBuffer = DASHPreBuffer()