new version

This commit is contained in:
UrloMythus
2026-04-15 19:23:14 +02:00
parent 5120b19d0b
commit 8134936d59
135 changed files with 3013 additions and 1589 deletions
+4 -1
View File
@@ -90,6 +90,7 @@ class BasePrebuffer(ABC):
max_memory_percent: float,
emergency_threshold: float,
segment_ttl: int = 60,
prebuffer_lock_timeout: float = 1.0,
):
"""
Initialize the base prebuffer.
@@ -100,12 +101,14 @@ class BasePrebuffer(ABC):
max_memory_percent: Maximum memory usage percentage before skipping prebuffer
emergency_threshold: Memory threshold for emergency cleanup
segment_ttl: TTL for cached segments in seconds
prebuffer_lock_timeout: Lock acquisition timeout (seconds) for background prebuffer tasks
"""
self.max_cache_size = max_cache_size
self.prebuffer_segment_count = prebuffer_segments
self.max_memory_percent = max_memory_percent
self.emergency_threshold = emergency_threshold
self.segment_ttl = segment_ttl
self.prebuffer_lock_timeout = prebuffer_lock_timeout
# Statistics (per-worker, not shared - but that's fine for monitoring)
self.stats = PrebufferStats()
@@ -310,7 +313,7 @@ class BasePrebuffer(ABC):
try:
# Try to acquire lock with short timeout for prebuffering
# If lock is held by another process, skip this segment
lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=1.0)
lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=self.prebuffer_lock_timeout)
if not lock_acquired:
# Another process is downloading, skip this segment
+57 -7
View File
@@ -52,6 +52,7 @@ class DASHPreBuffer(BasePrebuffer):
max_memory_percent=settings.dash_prebuffer_max_memory_percent,
emergency_threshold=settings.dash_prebuffer_emergency_threshold,
segment_ttl=settings.dash_segment_cache_ttl,
prebuffer_lock_timeout=settings.dash_prebuffer_lock_timeout,
)
self.inactivity_timeout = settings.dash_prebuffer_inactivity_timeout
@@ -100,9 +101,26 @@ class DASHPreBuffer(BasePrebuffer):
logger.warning(f"No profiles found in MPD for prebuffering: {mpd_url}")
return
# Now get segments for each profile by parsing with profile_id
# Early-out for SegmentBase VOD content.
# SegmentBase profiles have initRange set (byte range within a large single file).
# Prebuffering them means downloading 100 MB+ per profile — not useful.
# If every profile is SegmentBase, skip the expensive per-profile segment
# parsing entirely and only prewarm init segments via the direct cache path.
all_segment_base = all(p.get("initRange") is not None for p in base_profiles)
if all_segment_base and not is_live:
logger.info(
f"[prebuffer_dash_manifest] Skipping SegmentBase VOD: {mpd_url} "
f"({len(base_profiles)} profiles, all SegmentBase)"
)
return
# Now get segments for each profile by parsing with profile_id.
# Skip SegmentBase profiles — their "segments" are whole-file downloads.
profiles_with_segments = []
for profile in base_profiles:
if profile.get("initRange") is not None:
# SegmentBase profile — skip expensive segment parsing
continue
profile_id = profile.get("id")
if profile_id:
parsed_with_segments = await get_cached_mpd(
@@ -122,8 +140,10 @@ class DASHPreBuffer(BasePrebuffer):
"last_access": time.time(),
}
# Prebuffer init segments and media segments
await self._prebuffer_profiles(profiles_with_segments, headers, is_live)
# For live streams we only prewarm init segments by default; media prefetch
# is driven by player/playlist requests to avoid lock contention storms.
include_media = not is_live or settings.dash_live_initial_media_prebuffer
await self._prebuffer_profiles(profiles_with_segments, headers, is_live, include_media=include_media)
# Start cleanup task if not running
self._ensure_cleanup_task_running()
@@ -140,6 +160,7 @@ class DASHPreBuffer(BasePrebuffer):
profiles: List[dict],
headers: Dict[str, str],
is_live: bool = False,
include_media: bool = True,
) -> None:
"""
Pre-buffer init segments and media segments for all profiles.
@@ -151,6 +172,7 @@ class DASHPreBuffer(BasePrebuffer):
profiles: List of parsed profiles with resolved URLs
headers: Headers to use for requests
is_live: Whether this is a live stream
include_media: Whether media segments should be prebuffered (init segments are always prewarmed)
"""
if self._should_skip_for_memory():
logger.warning("Memory usage too high, skipping prebuffer")
@@ -161,13 +183,29 @@ class DASHPreBuffer(BasePrebuffer):
init_urls = []
for profile in profiles:
# Skip SegmentBase profiles entirely.
# SegmentBase "segments" are byte-range slices of a single large file
# (e.g. a25.mp4 = 122 MB). Prebuffering them without a Range header
# would download the entire file per profile — potentially GB of data
# for all audio tracks before a single second of playback is served.
# SegmentBase is identified by the profile having initRange (byte range
# within the base file) or by segments having mediaRange set.
segments = profile.get("segments", [])
is_segment_base = profile.get("initRange") is not None or (
segments and segments[0].get("mediaRange") is not None
)
if is_segment_base:
logger.debug(f"[_prebuffer_profiles] Skipping SegmentBase profile: {profile.get('id')}")
continue
# Collect init segment URL
init_url = profile.get("initUrl")
if init_url:
init_urls.append(init_url)
# Get segments to prebuffer
segments = profile.get("segments", [])
if not include_media:
continue
if not segments:
continue
@@ -283,7 +321,13 @@ class DASHPreBuffer(BasePrebuffer):
if segment_urls:
logger.debug(f"Prefetching {len(segment_urls)} upcoming segments from index {current_index + 1}")
# Run prefetch in background
asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))
asyncio.create_task(
self.prebuffer_segments_batch(
segment_urls,
headers,
max_concurrent=max(settings.dash_prefetch_max_concurrent, 1),
)
)
except Exception as e:
logger.warning(f"Failed to prefetch upcoming segments: {e}")
@@ -322,7 +366,13 @@ class DASHPreBuffer(BasePrebuffer):
if segment_urls:
logger.debug(f"Live playlist prefetch: {len(segment_urls)} segments")
asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))
asyncio.create_task(
self.prebuffer_segments_batch(
segment_urls,
headers,
max_concurrent=max(settings.dash_prefetch_max_concurrent, 1),
)
)
def _ensure_cleanup_task_running(self) -> None:
"""Ensure the cleanup task is running."""
+9 -84
View File
@@ -1,12 +1,12 @@
"""
Helper functions for automatic stream extraction in proxy routes.
This module provides caching and extraction helpers for DLHD/DaddyLive
and Sportsonline/Sportzonline streams that are auto-detected in proxy routes.
This module provides caching and extraction helpers for Sportsonline/Sportzonline
streams that are auto-detected in proxy routes.
"""
import copy
import logging
import re
import time
from urllib.parse import urlparse
@@ -19,88 +19,11 @@ from mediaflow_proxy.utils.http_utils import ProxyRequestHeaders, DownloadError
logger = logging.getLogger(__name__)
# DLHD extraction cache: {original_url: {"data": extraction_result, "timestamp": time.time()}}
_dlhd_extraction_cache: dict = {}
_dlhd_cache_duration = 600 # 10 minutes in seconds
# Sportsonline extraction cache
_sportsonline_extraction_cache: dict = {}
_sportsonline_cache_duration = 600 # 10 minutes in seconds
async def check_and_extract_dlhd_stream(
request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
"""
Check if destination contains DLHD/DaddyLive patterns and extract stream directly.
Uses caching to avoid repeated extractions (10 minute cache).
Args:
request (Request): The incoming HTTP request.
destination (str): The destination URL to check.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
force_refresh (bool): Force re-extraction even if cached data exists.
Returns:
dict | None: Extracted stream data if DLHD link detected, None otherwise.
"""
# Check for common DLHD/DaddyLive patterns in the URL
# This includes stream-XXX pattern and domain names like dlhd.dad or daddylive.sx
is_dlhd_link = (
re.search(r"stream-\d+", destination)
or "dlhd.dad" in urlparse(destination).netloc
or "daddylive.sx" in urlparse(destination).netloc
)
if not is_dlhd_link:
return None
logger.info(f"DLHD link detected: {destination}")
# Check cache first (unless force_refresh is True)
current_time = time.time()
if not force_refresh and destination in _dlhd_extraction_cache:
cached_entry = _dlhd_extraction_cache[destination]
cache_age = current_time - cached_entry["timestamp"]
if cache_age < _dlhd_cache_duration:
logger.info(f"Using cached DLHD data (age: {cache_age:.1f}s)")
return cached_entry["data"]
else:
logger.info(f"DLHD cache expired (age: {cache_age:.1f}s), re-extracting...")
del _dlhd_extraction_cache[destination]
# Extract stream data
try:
logger.info(f"Extracting DLHD stream data from: {destination}")
extractor = ExtractorFactory.get_extractor("DLHD", proxy_headers.request)
result = await extractor.extract(destination)
logger.info(f"DLHD extraction successful. Stream URL: {result.get('destination_url')}")
# Handle dlhd_key_params - encode them for URL passing
if "dlhd_key_params" in result:
key_params = result.pop("dlhd_key_params")
# Add key params as special query parameters for key URL handling
result["dlhd_channel_salt"] = key_params.get("channel_salt", "")
result["dlhd_auth_token"] = key_params.get("auth_token", "")
result["dlhd_iframe_url"] = key_params.get("iframe_url", "")
logger.info("DLHD key params extracted for dynamic header computation")
# Cache a copy of result to prevent downstream mutations from corrupting the cache
_dlhd_extraction_cache[destination] = {"data": result.copy(), "timestamp": current_time}
logger.info(f"DLHD data cached for {_dlhd_cache_duration}s")
return result
except (ExtractorError, DownloadError) as e:
logger.error(f"DLHD extraction failed: {str(e)}")
raise HTTPException(status_code=400, detail=f"DLHD extraction failed: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error during DLHD extraction: {str(e)}")
raise HTTPException(status_code=500, detail=f"DLHD extraction failed: {str(e)}")
async def check_and_extract_sportsonline_stream(
request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
@@ -117,8 +40,9 @@ async def check_and_extract_sportsonline_stream(
Returns:
dict | None: Extracted stream data if Sportsonline link detected, None otherwise.
"""
parsed_netloc = urlparse(destination).netloc
is_sportsonline_link = "sportzonline." in parsed_netloc or "sportsonline." in parsed_netloc
hostname = (urlparse(destination).hostname or "").lower()
hostname_labels = {part for part in hostname.split(".") if part}
is_sportsonline_link = bool(hostname_labels & {"sportzonline", "sportsonline", "sportzsonline"})
if not is_sportsonline_link:
return None
@@ -130,7 +54,7 @@ async def check_and_extract_sportsonline_stream(
cached_entry = _sportsonline_extraction_cache[destination]
if current_time - cached_entry["timestamp"] < _sportsonline_cache_duration:
logger.info(f"Using cached Sportsonline data (age: {current_time - cached_entry['timestamp']:.1f}s)")
return cached_entry["data"]
return copy.deepcopy(cached_entry["data"])
else:
logger.info("Sportsonline cache expired, re-extracting...")
del _sportsonline_extraction_cache[destination]
@@ -140,7 +64,8 @@ async def check_and_extract_sportsonline_stream(
extractor = ExtractorFactory.get_extractor("Sportsonline", proxy_headers.request)
result = await extractor.extract(destination)
logger.info(f"Sportsonline extraction successful. Stream URL: {result.get('destination_url')}")
_sportsonline_extraction_cache[destination] = {"data": result, "timestamp": current_time}
# Cache a copy of result to prevent downstream mutations from corrupting the cache
_sportsonline_extraction_cache[destination] = {"data": copy.deepcopy(result), "timestamp": current_time}
logger.info(f"Sportsonline data cached for {_sportsonline_cache_duration}s")
return result
except (ExtractorError, DownloadError) as e:
-1
View File
@@ -181,7 +181,6 @@ def initialize_routing_from_config(transport_config) -> None:
# Hardcoded routes for specific domains (SSL verification disabled)
hardcoded_domains = [
"all://jxoplay.xyz",
"all://dlhd.dad",
"all://*.newkso.ru",
]
+15 -2
View File
@@ -360,13 +360,21 @@ class Streamer:
await self.session.close()
async def download_file_with_retry(url: str, headers: dict) -> bytes:
async def download_file_with_retry(
url: str,
headers: dict,
timeout: typing.Optional[ClientTimeout] = None,
) -> bytes:
"""
Downloads a file with retry logic.
Args:
url: The URL of the file to download.
headers: The headers to include in the request.
timeout: Optional aiohttp ClientTimeout override. When None the global
transport timeout is used. Pass a ClientTimeout with sock_read set
(and total=None) for large ranged downloads so that the per-chunk
read deadline is used instead of a hard total-download limit.
Returns:
bytes: The downloaded file content.
@@ -374,7 +382,7 @@ async def download_file_with_retry(url: str, headers: dict) -> bytes:
Raises:
DownloadError: If the download fails after retries.
"""
async with create_aiohttp_session(url) as (session, proxy_url):
async with create_aiohttp_session(url, timeout=timeout) as (session, proxy_url):
try:
response = await fetch_with_retry(session, "GET", url, headers, proxy=proxy_url)
return await response.read()
@@ -689,6 +697,7 @@ class ProxyRequestHeaders:
response: dict
remove: list # headers to remove from response
propagate: dict # response headers to propagate to segments (rp_ prefix)
auto_added_range: bool = False # True if range header was auto-added by proxy (not from client)
def apply_header_manipulation(
@@ -754,6 +763,10 @@ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
# Filter out empty values
propagate_headers = {k[3:].lower(): v for k, v in request.query_params.items() if k.lower().startswith("rp_") and v}
for k, v in propagate_headers.items():
if k not in request_headers:
request_headers[k] = v
# Parse headers to remove from response (x_headers parameter)
x_headers_param = request.query_params.get("x_headers", "")
remove_headers = [h.strip().lower() for h in x_headers_param.split(",") if h.strip()] if x_headers_param else []
+72 -33
View File
@@ -159,6 +159,11 @@ class M3U8Processor:
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
).replace("/hls/manifest.m3u8", "/hls/segment")
self.playlist_url = None # Will be set when processing starts
# Per HLS spec, any URI on the line immediately following #EXT-X-STREAM-INF
# is a variant sub-playlist, not a segment. Track this so proxy_content_url
# can route it to the manifest endpoint regardless of file extension or query
# params (e.g. VixCloud uses ?type=video which looks like a segment URL).
self._after_stream_inf = False
def _should_apply_start_offset(self, content: str) -> bool:
"""
@@ -656,10 +661,25 @@ class M3U8Processor:
str: The processed line.
"""
if "URI=" in line:
self._after_stream_inf = False
return await self.process_key_line(line, base_url)
elif not line.startswith("#") and line.strip():
if self._after_stream_inf:
# Per HLS spec §4.3.4.2, the URI on the line immediately following
# #EXT-X-STREAM-INF is always a variant sub-playlist, never a segment.
# Route it to the manifest proxy regardless of extension or query params
# (e.g. VixCloud uses ?type=video rather than a .m3u8 extension).
self._after_stream_inf = False
full_url = parse.urljoin(base_url, line)
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
return await self.proxy_content_url(line, base_url)
else:
if line.startswith("#EXT-X-STREAM-INF"):
self._after_stream_inf = True
elif line.startswith("#"):
# Any other tag resets the flag — EXT-X-STREAM-INF must be
# immediately followed by a URI with no intervening tags.
self._after_stream_inf = False
return line
async def process_key_line(self, line: str, base_url: str) -> str:
@@ -691,12 +711,20 @@ class M3U8Processor:
# otherwise the proxied destination URL gets the wrong upstream hostname.
if self.key_url and line.startswith("#EXT-X-KEY"):
uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
# Check if this is a DLHD stream with key params (needs stream endpoint for header computation)
query_params = dict(self.request.query_params)
is_dlhd_key_request = "dlhd_salt" in query_params and "/key/" in uri.geturl()
# Use stream endpoint for DLHD key URLs, manifest endpoint for others
resolved_key_url = uri.geturl()
if not uri.scheme:
resolved_key_url = parse.urljoin(base_url, resolved_key_url)
# #EXT-X-MEDIA and #EXT-X-I-FRAME-STREAM-INF carry sub-playlist URIs
# (audio/subtitle/i-frame rendition playlists) that must be routed through
# the manifest proxy so their own segment URLs get rewritten.
# All other tags with URI= (#EXT-X-KEY, #EXT-X-MAP, #EXT-X-SESSION-KEY)
# reference raw binary data and must go through the segment proxy so the
# raw bytes are returned to the player without M3U8 parsing.
is_sub_playlist = line.startswith(("#EXT-X-MEDIA", "#EXT-X-I-FRAME-STREAM-INF"))
new_uri = await self.proxy_url(
uri.geturl(), base_url, use_full_url=True, is_playlist=not is_dlhd_key_request
resolved_key_url, base_url, use_full_url=True, is_playlist=is_sub_playlist, is_key=not is_sub_playlist
)
line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
return line
@@ -725,9 +753,18 @@ class M3U8Processor:
# Determine routing strategy based on configuration
routing_strategy = settings.m3u8_content_routing
# Check if we should force MediaFlow proxy for all playlist URLs
# If force_playlist_proxy is set, route all content through the proxy but
# distinguish actual playlists from segments: sending raw MPEG-TS bytes to
# the manifest endpoint (is_playlist=True) causes a "#EXTM3U not found" error
# because the endpoint tries to parse binary segment data as an HLS playlist.
# Segments disguised as .js/.css (e.g. chevy streams) must go to the segment
# endpoint (is_playlist=False) so the bytes are streamed back as-is.
if self.force_playlist_proxy:
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
parsed_url = parse.urlparse(full_url)
content_is_playlist = parsed_url.path.endswith((".m3u", ".m3u8", ".m3u_plus")) or parse.parse_qs(
parsed_url.query
).get("type", [""])[0] in ["m3u", "m3u8", "m3u_plus"]
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=content_is_playlist)
# For playlist URLs, always use MediaFlow proxy regardless of strategy
# Check for actual playlist file extensions, not just substring matches
@@ -783,7 +820,14 @@ class M3U8Processor:
segment_urls.append(parse.urljoin(base_url, line))
return segment_urls
async def proxy_url(self, url: str, base_url: str, use_full_url: bool = False, is_playlist: bool = True) -> str:
async def proxy_url(
self,
url: str,
base_url: str,
use_full_url: bool = False,
is_playlist: bool = True,
is_key: bool = False,
) -> str:
"""
Proxies a URL, encoding it with the MediaFlow proxy URL.
@@ -792,6 +836,7 @@ class M3U8Processor:
base_url (str): The base URL to resolve relative URLs.
use_full_url (bool): Whether to use the URL as-is (True) or join with base_url (False).
is_playlist (bool): Whether this is a playlist URL (uses manifest endpoint) or segment URL (uses stream endpoint).
is_key (bool): Whether this is a key/init-segment URL (suppresses AES param injection).
Returns:
str: The proxied URL.
@@ -819,31 +864,25 @@ class M3U8Processor:
if is_playlist:
proxy_url = self.mediaflow_proxy_url
else:
# Check if this is a DLHD key URL (needs /stream endpoint for header computation)
is_dlhd_key = "dlhd_salt" in query_params and "/key/" in full_url
if is_dlhd_key:
# Use /stream endpoint for DLHD key URLs
proxy_url = self.mediaflow_proxy_url.replace("/hls/manifest.m3u8", "/stream")
else:
# Determine segment extension from the URL
# Default to .ts for traditional HLS, but detect fMP4 extensions
segment_ext = "ts"
url_lower = full_url.lower()
# Check for fMP4/CMAF extensions
if url_lower.endswith(".m4s"):
segment_ext = "m4s"
elif url_lower.endswith(".mp4"):
segment_ext = "mp4"
elif url_lower.endswith(".m4a"):
segment_ext = "m4a"
elif url_lower.endswith(".m4v"):
segment_ext = "m4v"
elif url_lower.endswith(".aac"):
segment_ext = "aac"
# Build segment proxy URL with correct extension
proxy_url = f"{self.segment_proxy_base_url}.{segment_ext}"
# Remove h_range header - each segment should handle its own range requests
query_params.pop("h_range", None)
# Determine segment extension from the URL
# Default to .ts for traditional HLS, but detect fMP4 extensions
segment_ext = "ts"
url_lower = full_url.lower()
# Check for fMP4/CMAF extensions
if url_lower.endswith(".m4s"):
segment_ext = "m4s"
elif url_lower.endswith(".mp4"):
segment_ext = "mp4"
elif url_lower.endswith(".m4a"):
segment_ext = "m4a"
elif url_lower.endswith(".m4v"):
segment_ext = "m4v"
elif url_lower.endswith(".aac"):
segment_ext = "aac"
# Build segment proxy URL with correct extension
proxy_url = f"{self.segment_proxy_base_url}.{segment_ext}"
# Remove h_range header - each segment should handle its own range requests
query_params.pop("h_range", None)
return encode_mediaflow_proxy_url(
proxy_url,
+251 -22
View File
@@ -1,6 +1,8 @@
import logging
import math
import re
import statistics
import struct
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Optional, Union
from urllib.parse import urljoin
@@ -250,12 +252,22 @@ def parse_representation(
if "video" not in mime_type and "audio" not in mime_type:
return None
# Raw rep.id from XML — used for $RepresentationID$ URL template expansion.
rep_id = representation.get("@id") or adaptation.get("@id") or "0"
adapt_id = adaptation.get("@id") or "0"
bandwidth_for_id = int(representation.get("@bandwidth") or adaptation.get("@bandwidth") or 0)
# Globally unique profile ID: "{adapt_id}_{rep_id}_{bandwidth}".
# Within one AdaptationSet multiple representations can share the same @id
# (same quality tier, different codec variant), so we include bandwidth to distinguish.
unique_id = f"{adapt_id}_{rep_id}_{bandwidth_for_id}"
profile = {
"id": representation.get("@id") or adaptation.get("@id"),
"id": unique_id,
"rep_id": rep_id, # raw XML @id for $RepresentationID$ template expansion
"mimeType": mime_type,
"lang": representation.get("@lang") or adaptation.get("@lang"),
"codecs": representation.get("@codecs") or adaptation.get("@codecs"),
"bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
"bandwidth": bandwidth_for_id,
"startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
"mediaPresentationDuration": media_presentation_duration,
}
@@ -287,12 +299,18 @@ def parse_representation(
# Extract segment template start number for adaptive sequence calculation
segment_template_data = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
if segment_template_data:
profile["segment_template_start_number_explicit"] = "@startNumber" in segment_template_data
try:
profile["segment_template_start_number"] = int(segment_template_data.get("@startNumber", 1))
except (ValueError, TypeError):
profile["segment_template_start_number"] = 1
try:
profile["segment_template_timescale"] = int(segment_template_data.get("@timescale", 1))
except (ValueError, TypeError):
profile["segment_template_timescale"] = 1
else:
profile["segment_template_start_number"] = 1
profile["segment_template_start_number_explicit"] = False
# For SegmentBase profiles, we need to set initUrl even when not parsing segments
# This is needed for the HLS playlist builder to reference the init URL
@@ -375,11 +393,19 @@ def parse_segment_template(
"""
segments = []
timescale = int(item.get("@timescale", 1))
profile["segment_template_timescale"] = timescale
profile["segment_template_start_number_explicit"] = "@startNumber" in item
try:
profile["segment_template_start_number"] = int(
item.get("@startNumber", profile.get("segment_template_start_number", 1))
)
except (ValueError, TypeError):
profile["segment_template_start_number"] = 1
# Initialization
if "@initialization" in item:
media = item["@initialization"]
media = media.replace("$RepresentationID$", profile["id"])
media = media.replace("$RepresentationID$", profile.get("rep_id", profile["id"]))
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
# Combine base_url and media, then resolve against mpd_url
if base_url:
@@ -390,6 +416,10 @@ def parse_segment_template(
if "SegmentTimeline" in item:
segments.extend(parse_segment_timeline(parsed_dict, item, profile, mpd_url, timescale, base_url))
elif "@duration" in item:
try:
profile["nominal_duration_mpd_timescale"] = int(item["@duration"])
except (ValueError, TypeError):
pass
segments.extend(parse_segment_duration(parsed_dict, item, profile, mpd_url, timescale, base_url))
return segments
@@ -420,13 +450,36 @@ def parse_segment_timeline(
presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
start_number = int(item.get("@startNumber", 1))
timeline_segments = preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
nominal_duration = _resolve_nominal_timeline_duration(timeline_segments)
if nominal_duration:
profile["nominal_duration_mpd_timescale"] = nominal_duration
segments = [
create_segment_data(timeline, item, profile, mpd_url, timescale, base_url)
for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
create_segment_data(timeline, item, profile, mpd_url, timescale, base_url) for timeline in timeline_segments
]
return segments
def _resolve_nominal_timeline_duration(timeline_segments: List[Dict]) -> Optional[int]:
"""
Resolve a stable nominal segment duration from expanded SegmentTimeline entries.
Live timelines often contain occasional shorter segments; using median keeps
sequence calculations stable when the window slides.
"""
durations = []
for segment in timeline_segments:
duration = segment.get("duration_mpd_timescale")
if isinstance(duration, (int, float)) and duration > 0:
durations.append(int(duration))
if not durations:
return None
return int(statistics.median_low(durations))
def preprocess_timeline(
timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
@@ -491,25 +544,41 @@ def parse_segment_duration(
"""
duration = int(item["@duration"])
start_number = int(item.get("@startNumber", 1))
presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
segment_duration_sec = duration / timescale
if parsed_dict["isLive"]:
segments = generate_live_segments(parsed_dict, segment_duration_sec, start_number)
profile["nominal_duration_mpd_timescale"] = duration
segments = generate_live_segments(
parsed_dict,
segment_duration_sec,
start_number,
duration_mpd_timescale=duration,
presentation_time_offset=presentation_time_offset,
)
else:
segments = generate_vod_segments(profile, duration, timescale, start_number)
return [create_segment_data(seg, item, profile, mpd_url, timescale, base_url) for seg in segments]
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
def generate_live_segments(
parsed_dict: dict,
segment_duration_sec: float,
start_number: int,
duration_mpd_timescale: Optional[int] = None,
presentation_time_offset: int = 0,
) -> List[Dict]:
"""
Generates live segments based on the segment duration and start number.
This is used for live MPD manifests.
Args:
parsed_dict (dict): The parsed MPD data.
segment_duration_sec (float): The segment duration in seconds.
start_number (int): The starting segment number.
segment_duration_sec: The segment duration in seconds.
start_number: The starting segment number.
duration_mpd_timescale: Segment duration in MPD timescale units.
presentation_time_offset: MPD presentationTimeOffset, in timescale units.
Returns:
List[Dict]: The list of generated live segments.
@@ -524,15 +593,22 @@ def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start
start_number,
)
return [
{
segments = []
for number in range(earliest_segment_number, earliest_segment_number + segment_count):
start_time = parsed_dict["availabilityStartTime"] + timedelta(
seconds=(number - start_number) * segment_duration_sec
)
segment = {
"number": number,
"start_time": parsed_dict["availabilityStartTime"]
+ timedelta(seconds=(number - start_number) * segment_duration_sec),
"duration": segment_duration_sec,
"start_time": start_time,
"end_time": start_time + timedelta(seconds=segment_duration_sec),
"duration": duration_mpd_timescale if duration_mpd_timescale is not None else segment_duration_sec,
}
for number in range(earliest_segment_number, earliest_segment_number + segment_count)
]
if duration_mpd_timescale is not None:
segment["duration_mpd_timescale"] = duration_mpd_timescale
segment["time"] = presentation_time_offset + (number - start_number) * duration_mpd_timescale
segments.append(segment)
return segments
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
@@ -575,13 +651,40 @@ def create_segment_data(
Dict: The created segment data.
"""
media_template = item["@media"]
media = media_template.replace("$RepresentationID$", profile["id"])
media = media_template.replace("$RepresentationID$", profile.get("rep_id", profile["id"]))
media = media.replace("$Number%04d$", f"{segment['number']:04d}")
media = media.replace("$Number$", str(segment["number"]))
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
if "time" in segment and timescale is not None:
media = media.replace("$Time$", str(int(segment["time"])))
if "$Time$" in media and timescale is not None:
time_value = None
if "time" in segment:
time_value = int(segment["time"])
else:
duration_mpd_timescale = segment.get("duration_mpd_timescale")
if duration_mpd_timescale is None:
try:
duration_mpd_timescale = int(item.get("@duration", 0))
except (TypeError, ValueError):
duration_mpd_timescale = 0
if duration_mpd_timescale:
try:
start_number = int(item.get("@startNumber", profile.get("segment_template_start_number", 1)))
except (TypeError, ValueError):
start_number = profile.get("segment_template_start_number", 1)
try:
presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
except (TypeError, ValueError):
presentation_time_offset = 0
time_value = presentation_time_offset + (int(segment["number"]) - start_number) * int(
duration_mpd_timescale
)
if time_value is not None:
media = media.replace("$Time$", str(time_value))
if "$Time$" in media:
logger.warning("Unresolved $Time$ placeholder in segment URL template: %s", media_template)
# Combine base_url and media, then resolve against mpd_url
if base_url:
@@ -597,7 +700,9 @@ def create_segment_data(
# Add time and duration metadata for adaptive sequence calculation
if "time" in segment:
segment_data["time"] = segment["time"]
if "duration" in segment:
if "duration_mpd_timescale" in segment:
segment_data["duration_mpd_timescale"] = segment["duration_mpd_timescale"]
elif "time" in segment and "duration" in segment and timescale is not None:
segment_data["duration_mpd_timescale"] = segment["duration"]
if "start_time" in segment and "end_time" in segment:
@@ -610,8 +715,14 @@ def create_segment_data(
}
)
elif "start_time" in segment and "duration" in segment:
# duration here is in timescale units (from timeline segments)
duration_seconds = segment["duration"] / timescale if timescale else segment["duration"]
duration_mpd_timescale = segment.get("duration_mpd_timescale")
if duration_mpd_timescale is not None and timescale:
duration_seconds = duration_mpd_timescale / timescale
elif "time" in segment and timescale:
# Timeline-based segments store duration in MPD timescale units.
duration_seconds = segment["duration"] / timescale
else:
duration_seconds = segment["duration"]
segment_data.update(
{
"start_time": segment["start_time"],
@@ -748,6 +859,17 @@ def parse_segment_base(representation: dict, profile: dict, mpd_url: str) -> Lis
elif total_duration is None:
total_duration = 0
# Derive the media byte range: media data starts immediately after the init segment.
# init_range is "start-end" (e.g. "0-657"), so media begins at end+1.
# Without this, the proxy tries to download the whole file which causes CDN timeouts.
media_range = None
if init_range:
try:
end_byte = int(init_range.split("-")[1])
media_range = f"{end_byte + 1}-"
except (ValueError, IndexError):
pass
# For SegmentBase, we return a single segment representing the entire media
# The media URL is the same as initUrl but will be accessed with different byte ranges
return [
@@ -758,10 +880,117 @@ def parse_segment_base(representation: dict, profile: dict, mpd_url: str) -> Lis
"extinf": total_duration if total_duration > 0 else 1.0,
"indexRange": index_range,
"initRange": init_range,
"mediaRange": media_range,
}
]
def parse_sidx_fragments(sidx_bytes: bytes, index_range_start: int) -> List[Dict]:
"""
Parse a SIDX (Segment Index) box and return per-fragment byte ranges and durations.
The SIDX box lists subsegment sizes and durations so the player can seek directly
to any fragment without scanning the whole file. We use this to generate one HLS
segment entry per fragment, enabling efficient seeking in SegmentBase MPDs.
Args:
sidx_bytes: Raw bytes starting at index_range_start in the media file.
May include other boxes (e.g. styp) before the sidx box.
index_range_start: Byte offset of sidx_bytes[0] within the media file.
Returns:
List of dicts, each with:
'start' first byte of fragment in media file (inclusive)
'end' last byte of fragment in media file (inclusive)
'duration_timescale' subsegment duration in SIDX timescale units
'timescale' SIDX timescale (ticks per second)
"""
# Scan forward to find the sidx box (may be preceded by styp or other boxes)
offset = 0
sidx_box_start = -1
while offset + 8 <= len(sidx_bytes):
box_size = struct.unpack_from(">I", sidx_bytes, offset)[0]
box_type = sidx_bytes[offset + 4 : offset + 8]
if box_type == b"sidx":
sidx_box_start = offset
break
if box_size < 8:
break
offset += box_size
if sidx_box_start < 0:
return []
sidx_file_start = index_range_start + sidx_box_start
sidx_size = struct.unpack_from(">I", sidx_bytes, sidx_box_start)[0]
sidx_file_end = sidx_file_start + sidx_size # first byte AFTER the sidx box
off = sidx_box_start + 8 # skip box-size (4) + box-type (4)
if off >= len(sidx_bytes):
return []
version = sidx_bytes[off]
off += 4 # version (1) + flags (3)
off += 4 # reference_id (4)
if off + 4 > len(sidx_bytes):
return []
timescale = struct.unpack_from(">I", sidx_bytes, off)[0]
off += 4
if version == 0:
off += 4 # earliest_presentation_time (4)
if off + 4 > len(sidx_bytes):
return []
first_offset = struct.unpack_from(">I", sidx_bytes, off)[0]
off += 4
else:
off += 8 # earliest_presentation_time (8)
if off + 8 > len(sidx_bytes):
return []
first_offset = struct.unpack_from(">Q", sidx_bytes, off)[0]
off += 8
off += 2 # reserved (2)
if off + 2 > len(sidx_bytes):
return []
reference_count = struct.unpack_from(">H", sidx_bytes, off)[0]
off += 2
# First fragment starts immediately after SIDX + first_offset bytes of gap
frag_start = sidx_file_end + first_offset
fragments = []
for _ in range(reference_count):
if off + 12 > len(sidx_bytes):
break
ref_field = struct.unpack_from(">I", sidx_bytes, off)[0]
off += 4
ref_type = (ref_field >> 31) & 1
referenced_size = ref_field & 0x7FFF_FFFF
duration = struct.unpack_from(">I", sidx_bytes, off)[0]
off += 4
off += 4 # SAP field (ignored)
if ref_type == 0: # media reference (not an index-of-indexes reference)
fragments.append(
{
"start": frag_start,
"end": frag_start + referenced_size - 1,
"duration_timescale": duration,
"timescale": timescale,
}
)
frag_start += referenced_size
return fragments
def parse_duration(duration_str: str) -> float:
"""
Parses a duration ISO 8601 string into seconds.
+52
View File
@@ -454,6 +454,58 @@ async def set_cached_extractor(key: str, data: dict, ttl: int = EXTRACTOR_CACHE_
logger.debug(f"[Redis] Extractor cache set ({ttl}s TTL): {key[:60]}...")
# =============================================================================
# Telegram Document -> Message Cache
# =============================================================================
# Caches (session_fingerprint, chat_id, document_id) -> message_id mappings.
TELEGRAM_DOC_MSG_CACHE_PREFIX = "mfp:telegram_doc_msg:"
def _telegram_doc_msg_key(session_fingerprint: str, chat_id: str, document_id: int) -> str:
raw = f"{session_fingerprint}:{chat_id}:{document_id}"
raw_hash = hashlib.md5(raw.encode()).hexdigest()
return f"{TELEGRAM_DOC_MSG_CACHE_PREFIX}{raw_hash}"
async def get_cached_telegram_doc_message_id(
session_fingerprint: str,
chat_id: str,
document_id: int,
) -> Optional[int]:
"""Get cached Telegram message_id for a chat/document pair."""
r = await get_redis()
if r is None:
return None
redis_key = _telegram_doc_msg_key(session_fingerprint, chat_id, document_id)
data = await r.get(redis_key)
if data is None:
return None
try:
return int(data)
except (TypeError, ValueError):
return None
async def set_cached_telegram_doc_message_id(
session_fingerprint: str,
chat_id: str,
document_id: int,
message_id: int,
ttl: Optional[int] = None,
) -> None:
"""Cache Telegram message_id for a chat/document pair."""
r = await get_redis()
if r is None:
return
redis_key = _telegram_doc_msg_key(session_fingerprint, chat_id, document_id)
cache_ttl = ttl if ttl is not None else settings.telegram_document_cache_ttl
await r.set(redis_key, str(message_id), ex=max(1, int(cache_ttl)))
# =============================================================================
# MPD Cache
# =============================================================================
+200 -23
View File
@@ -11,6 +11,7 @@ Based on FastTelethon technique from mautrix-telegram for parallel downloads.
import asyncio
import base64
import hashlib
import logging
import math
import re
@@ -41,6 +42,10 @@ from telethon.tl.types import (
)
from mediaflow_proxy.configs import settings
from mediaflow_proxy.utils.redis_utils import (
get_cached_telegram_doc_message_id,
set_cached_telegram_doc_message_id,
)
logger = logging.getLogger(__name__)
@@ -201,6 +206,15 @@ class TelegramMediaRef:
chat_id: Optional[Union[int, str]] = None # Channel/group/user ID or username
message_id: Optional[int] = None # Message ID for t.me links
file_id: Optional[str] = None # Direct file reference
document_id: Optional[int] = None # Document ID used for chat scan resolution
class TelegramDocumentNotFoundError(Exception):
"""Raised when document_id cannot be resolved to a chat message."""
class TelegramMessageNotFoundError(Exception):
"""Raised when message_id cannot be resolved in a chat."""
@dataclass
@@ -654,6 +668,7 @@ class TelegramSessionManager:
self._media_info_cache: dict[str, tuple["MediaInfo", float]] = {}
# Persistent sender pool for single-connection downloads (HLS).
self._sender_pool = _SingleSenderPool()
self._session_fingerprint_cache: Optional[str] = None
async def get_client(self) -> TelegramClient:
"""
@@ -703,30 +718,190 @@ class TelegramSessionManager:
logger.info("Telegram client initialized successfully")
return self._client
async def get_message(self, ref: TelegramMediaRef) -> Message:
def _session_fingerprint(self) -> str:
"""Build a stable fingerprint for the active Telegram session."""
if self._session_fingerprint_cache:
return self._session_fingerprint_cache
if settings.telegram_session_string:
raw_session = settings.telegram_session_string.get_secret_value()
else:
raw_session = "telegram-session-missing"
self._session_fingerprint_cache = hashlib.sha256(raw_session.encode()).hexdigest()[:16]
return self._session_fingerprint_cache
@staticmethod
def _chat_id_candidates(chat_id: Union[int, str]) -> list[Union[int, str]]:
"""Return plausible chat_id forms (original first, then normalized forms)."""
candidates: list[Union[int, str]] = [chat_id]
if isinstance(chat_id, int):
if chat_id > 0:
candidates.append(int(f"-100{chat_id}"))
elif isinstance(chat_id, str):
raw = chat_id.strip()
if raw and not raw.startswith("@"):
try:
numeric = int(raw)
candidates.extend(TelegramSessionManager._chat_id_candidates(numeric))
except ValueError:
pass
# de-duplicate while preserving order
unique: list[Union[int, str]] = []
seen = set()
for candidate in candidates:
marker = f"{type(candidate).__name__}:{candidate}"
if marker in seen:
continue
seen.add(marker)
unique.append(candidate)
return unique
@staticmethod
def _is_entity_lookup_error(exc: Exception) -> bool:
"""True when Telethon cannot resolve the chat/entity from the provided chat_id."""
error_name = type(exc).__name__
if error_name in {"PeerIdInvalidError", "ChannelInvalidError", "UsernameNotOccupiedError"}:
return True
if isinstance(exc, ValueError):
message = str(exc).lower()
return "input entity" in message or "peeruser" in message or "peerchannel" in message
return False
@staticmethod
def _extract_document_id_from_file_id(file_id: str) -> Optional[int]:
"""Extract document id from a Bot API file_id when possible."""
media = utils.resolve_bot_file_id(file_id)
if isinstance(media, (Document, Photo)):
return int(media.id)
try:
decoded = decode_file_id(file_id)
return int(decoded.id)
except Exception:
return None
@staticmethod
def _is_matching_document_message(message: Message | None, document_id: int, file_size: Optional[int]) -> bool:
"""Check if message media is the requested Telegram document."""
if message is None or message.media is None:
return False
if not isinstance(message.media, MessageMediaDocument):
return False
document = message.media.document
if not isinstance(document, Document):
return False
if int(document.id) != int(document_id):
return False
if file_size is not None and int(document.size) != int(file_size):
return False
return True
async def _try_get_message_by_id(self, chat_id: Union[int, str], message_id: int) -> Message | None:
"""Fetch a message by ID and normalize 'not found' to None."""
client = await self.get_client()
try:
message = await client.get_messages(chat_id, ids=message_id)
except Exception as e:
if type(e).__name__ == "MessageIdInvalidError" or self._is_entity_lookup_error(e):
return None
raise
if isinstance(message, list):
message = message[0] if message else None
if not message or getattr(message, "id", None) is None:
return None
return message
async def _resolve_document_message(
self,
chat_id: Union[int, str],
document_id: int,
file_size: Optional[int] = None,
) -> tuple[Message, Union[int, str]]:
"""
Get a message by its reference.
Resolve a document_id to a message in the given chat.
Args:
ref: TelegramMediaRef with chat_id and message_id
Returns:
The Message object
Raises:
ValueError: If reference is incomplete
Various Telegram errors: ChannelPrivateError, MessageIdInvalidError, etc.
Lookup order:
1) Redis cache (session+chat+document -> message_id)
2) scan recent chat messages (up to configured limit)
"""
if ref.chat_id is None or ref.message_id is None:
raise ValueError("chat_id and message_id are required to fetch a message")
chat_key = str(chat_id)
session_fp = self._session_fingerprint()
chat_candidates = self._chat_id_candidates(chat_id)
cached_message_id = await get_cached_telegram_doc_message_id(session_fp, chat_key, int(document_id))
if cached_message_id is not None:
for candidate in chat_candidates:
cached_message = await self._try_get_message_by_id(candidate, cached_message_id)
if self._is_matching_document_message(cached_message, document_id, file_size):
return cached_message, candidate
client = await self.get_client()
messages = await client.get_messages(ref.chat_id, ids=ref.message_id)
scan_limit = max(1, int(settings.telegram_document_scan_limit))
for candidate in chat_candidates:
try:
async for message in client.iter_messages(candidate, limit=scan_limit):
if not self._is_matching_document_message(message, document_id, file_size):
continue
if not messages:
raise ValueError(f"Message {ref.message_id} not found in {ref.chat_id}")
if getattr(message, "id", None) is not None:
await set_cached_telegram_doc_message_id(
session_fp,
chat_key,
int(document_id),
int(message.id),
ttl=settings.telegram_document_cache_ttl,
)
return message, candidate
except Exception as e:
if self._is_entity_lookup_error(e):
continue
raise
return messages
raise TelegramDocumentNotFoundError(
f"document_id {document_id} not found in chat {chat_id} (scanned last {scan_limit} messages)"
)
async def get_message(self, ref: TelegramMediaRef, file_size: Optional[int] = None) -> Message:
"""
Get a message by message_id or by resolving document_id within chat history.
Fallback behavior:
- If chat_id+message_id is provided, try it first.
- If not found and file_id is provided, decode document_id from file_id and scan.
- If chat_id+document_id is provided, scan directly.
"""
if ref.chat_id is None:
raise ValueError("chat_id is required to fetch a message")
fallback_document_id = ref.document_id
if fallback_document_id is None and ref.file_id:
fallback_document_id = self._extract_document_id_from_file_id(ref.file_id)
if ref.message_id is not None:
for chat_candidate in self._chat_id_candidates(ref.chat_id):
message = await self._try_get_message_by_id(chat_candidate, ref.message_id)
if message is not None:
ref.chat_id = chat_candidate
return message
if fallback_document_id is None:
raise TelegramMessageNotFoundError(f"Message {ref.message_id} not found in {ref.chat_id}")
if fallback_document_id is None:
raise ValueError("message_id or document_id is required with chat_id")
message, resolved_chat_id = await self._resolve_document_message(
ref.chat_id, fallback_document_id, file_size=file_size
)
ref.chat_id = resolved_chat_id
ref.message_id = int(message.id)
ref.document_id = int(fallback_document_id)
return message
def resolve_file_id(self, file_id: str) -> tuple[Union[Document, Photo], int]:
"""
@@ -783,10 +958,12 @@ class TelegramSessionManager:
def _media_info_cache_key(self, ref: TelegramMediaRef) -> str:
"""Derive an in-memory cache key for a TelegramMediaRef."""
if ref.file_id and not ref.message_id:
if ref.file_id and ref.chat_id is None and not ref.message_id:
return f"fid:{ref.file_id}"
if ref.chat_id is not None and ref.message_id is not None:
return f"chat:{ref.chat_id}:msg:{ref.message_id}"
if ref.chat_id is not None and ref.document_id is not None:
return f"chat:{ref.chat_id}:doc:{ref.document_id}"
return ""
async def get_media_info(self, ref: TelegramMediaRef, file_size: Optional[int] = None) -> MediaInfo:
@@ -832,7 +1009,7 @@ class TelegramSessionManager:
) -> MediaInfo:
"""Uncached implementation of get_media_info."""
# Handle file_id reference
if ref.file_id and not ref.message_id:
if ref.file_id and ref.chat_id is None and not ref.message_id:
media, dc_id = self.resolve_file_id(ref.file_id)
if isinstance(media, Document):
@@ -904,7 +1081,7 @@ class TelegramSessionManager:
raise ValueError(f"Unsupported media type from file_id: {type(media)}")
# Handle message-based reference
message = await self.get_message(ref)
message = await self.get_message(ref, file_size=file_size)
if not message.media:
raise ValueError(f"Message {ref.message_id} does not contain media")
@@ -984,7 +1161,7 @@ class TelegramSessionManager:
"""
client = await self.get_client()
if ref.file_id and not ref.message_id:
if ref.file_id and ref.chat_id is None and not ref.message_id:
media, dc_id = self.resolve_file_id(ref.file_id)
if isinstance(media, Document):
@@ -1038,7 +1215,7 @@ class TelegramSessionManager:
``(file_location, dc_id, actual_file_size)``
"""
# Handle file_id reference (no message needed, fast local parse)
if ref.file_id and not ref.message_id:
if ref.file_id and ref.chat_id is None and not ref.message_id:
media, dc_id = self.resolve_file_id(ref.file_id)
if isinstance(media, Document):
@@ -1076,7 +1253,7 @@ class TelegramSessionManager:
raise ValueError(f"Unsupported media type from file_id: {type(media)}")
# Handle message-based reference (requires Telegram API call)
message = await self.get_message(ref)
message = await self.get_message(ref, file_size=file_size)
if not message.media:
raise ValueError(f"Message {ref.message_id} does not contain media")