mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-06-10 09:10:23 +00:00
update
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,684 @@
|
||||
"""
|
||||
Acestream session management with cross-process coordination.
|
||||
|
||||
This module provides:
|
||||
- AcestreamSessionManager: Manages acestream sessions per infohash with cross-process coordination
|
||||
- AcestreamSession: Represents a single acestream session with playback URLs
|
||||
- AsyncMultiWriter: Fan-out writer for streaming to multiple clients (MPEG-TS mode)
|
||||
|
||||
Architecture:
|
||||
- Uses Redis for cross-worker coordination and session registry
|
||||
- Each worker can reuse existing session's playback_url (acestream allows multiple connections)
|
||||
- Session cleanup via command_url?method=stop when all clients disconnect
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any, List
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
from mediaflow_proxy.utils import redis_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class AcestreamResponse:
    """Response from acestream's format=json API.

    Plain value object holding the fields read from the ``response`` object of
    the engine's JSON reply when a session is created.
    """

    # URL the stream is played back from.
    playback_url: str
    # URL polled (HTTP GET, 200 expected) to validate a session and keep it alive.
    stat_url: str
    # Control URL; the manager appends "?method=stop" to it to stop the session.
    command_url: str
    # Infohash reported by the engine (the requested infohash when the reply omits it).
    infohash: str
    # Engine-assigned identifier of this playback session; treated as opaque here.
    playback_session_id: str
    # Truthiness of the reply's "is_live" field (taken as live when the field is absent).
    is_live: bool
    # Truthiness of the reply's "is_encrypted" field (False when the field is absent).
    is_encrypted: bool
|
||||
|
||||
|
||||
@dataclass
class AcestreamSession:
    """
    State of one active acestream session, keyed by infohash.

    The session comes into existence when the first client asks for a given
    infohash; any further client shares it (and therefore its playback_url).
    """

    infohash: str
    pid: str
    playback_url: str
    command_url: str
    stat_url: str
    playback_session_id: str
    is_live: bool
    created_at: float = field(default_factory=time.time)
    last_access: float = field(default_factory=time.time)
    last_segment_request: float = field(default_factory=time.time)
    client_count: int = 0

    def touch(self) -> None:
        """Record that the session was accessed just now."""
        self.last_access = time.time()

    def touch_segment(self) -> None:
        """Record a segment request; a segment request also counts as an access."""
        # One clock reading feeds both timestamps so they stay equal.
        self.last_access = self.last_segment_request = time.time()

    def is_actively_streaming(self, timeout: float = 30.0) -> bool:
        """Return True when a segment was requested less than ``timeout`` seconds ago."""
        seconds_since_segment = time.time() - self.last_segment_request
        return seconds_since_segment < timeout

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the session for the shared registry.

        ``client_count`` is not exported; the pid of the calling worker process
        is added under ``worker_pid``.
        """
        exported_fields = (
            "infohash",
            "pid",
            "playback_url",
            "command_url",
            "stat_url",
            "playback_session_id",
            "is_live",
            "created_at",
            "last_access",
            "last_segment_request",
        )
        payload: Dict[str, Any] = {name: getattr(self, name) for name in exported_fields}
        payload["worker_pid"] = os.getpid()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AcestreamSession":
        """Rebuild a session from a ``to_dict``-style mapping.

        The six identifying keys are mandatory (``KeyError`` when absent);
        ``is_live`` falls back to True and the timestamps to the current time.
        """
        mandatory_keys = (
            "infohash",
            "pid",
            "playback_url",
            "command_url",
            "stat_url",
            "playback_session_id",
        )
        kwargs: Dict[str, Any] = {key: data[key] for key in mandatory_keys}
        kwargs["is_live"] = data.get("is_live", True)
        for stamp in ("created_at", "last_access", "last_segment_request"):
            kwargs[stamp] = data.get(stamp, time.time())
        return cls(**kwargs)
|
||||
|
||||
|
||||
class AsyncMultiWriter:
    """
    Async multi-writer for fan-out streaming to multiple clients.

    Based on acexy's PMultiWriter but adapted for Python asyncio.
    Writes are done in parallel to all connected writers.
    Writers that fail are automatically removed.

    The internal lock only guards the writer list; it is never held while
    waiting on a client, so one slow peer cannot block add/remove/write.
    """

    def __init__(self):
        self._writers: List[asyncio.StreamWriter] = []
        self._lock = asyncio.Lock()

    async def add(self, writer: asyncio.StreamWriter) -> None:
        """Add a writer to the list (no-op if it is already registered)."""
        async with self._lock:
            if writer not in self._writers:
                self._writers.append(writer)
                logger.debug(f"[AsyncMultiWriter] Added writer, total: {len(self._writers)}")

    async def remove(self, writer: asyncio.StreamWriter) -> None:
        """Remove a writer from the list. The writer is not closed."""
        async with self._lock:
            if writer in self._writers:
                self._writers.remove(writer)
                logger.debug(f"[AsyncMultiWriter] Removed writer, total: {len(self._writers)}")

    async def write(self, data: bytes) -> int:
        """
        Write data to all connected writers in parallel.

        Writers that fail are automatically removed and closed.

        Args:
            data: Chunk to send; an empty chunk is ignored.

        Returns:
            Number of successful writes.
        """
        if not data:
            return 0

        async with self._lock:
            # Snapshot so the lock is released before awaiting any client.
            targets = list(self._writers)

        if not targets:
            return 0

        async def write_to_single(writer: asyncio.StreamWriter) -> bool:
            try:
                writer.write(data)
                await writer.drain()
            except ConnectionError as e:
                # ConnectionResetError and BrokenPipeError are subclasses of
                # ConnectionError, so one clause covers every disconnect case.
                logger.debug(f"[AsyncMultiWriter] Writer disconnected: {e}")
                return False
            except Exception as e:
                logger.warning(f"[AsyncMultiWriter] Write error: {e}")
                return False
            return True

        # Write to all writers in parallel. return_exceptions=True keeps one
        # writer's unexpected failure (e.g. cancellation) from aborting the rest.
        results = await asyncio.gather(
            *(write_to_single(w) for w in targets),
            return_exceptions=True,
        )

        # Anything other than an explicit True (False or an exception object)
        # counts as a failed writer.
        failed_writers = [w for w, outcome in zip(targets, results) if outcome is not True]

        if failed_writers:
            async with self._lock:
                for writer in failed_writers:
                    # Skip writers that were removed by someone else meanwhile.
                    if writer in self._writers:
                        self._writers.remove(writer)
                        try:
                            writer.close()
                        except Exception as e:
                            # Best effort: the peer is already gone.
                            logger.debug(f"[AsyncMultiWriter] Error closing failed writer: {e}")

        return len(targets) - len(failed_writers)

    @property
    def count(self) -> int:
        """Number of connected writers."""
        return len(self._writers)

    async def close_all(self) -> None:
        """Close all writers and forget them."""
        async with self._lock:
            # Detach the list under the lock, then close outside it so a slow
            # wait_closed() does not block other operations on this object.
            closing = list(self._writers)
            self._writers.clear()

        for writer in closing:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as e:
                # Best effort: shutting down, nothing more to do for this peer.
                logger.debug(f"[AsyncMultiWriter] Error closing writer: {e}")
|
||||
|
||||
|
||||
class AcestreamSessionManager:
    """
    Manages acestream sessions with cross-process coordination.

    Features:
    - Per-worker session tracking
    - Redis-based session registry for cross-worker visibility
    - Session creation via acestream's format=json API
    - Session cleanup via command_url?method=stop
    - Session keepalive via periodic stat_url polling

    NOTE(review): ``AcestreamSession.to_dict`` stamps the *calling* worker's pid
    as ``worker_pid``, so whichever worker last wrote the registry entry (e.g.
    from its keepalive) is the one ``_close_session`` treats as the owner.
    Confirm this hand-over is intended.
    """

    # Redis key prefixes
    REGISTRY_PREFIX = "mfp:acestream:session:"
    REGISTRY_TTL = 3600  # 1 hour

    def __init__(self):
        # Per-worker session tracking (infohash -> session)
        self._sessions: Dict[str, AcestreamSession] = {}

        # Background tasks, started lazily by _ensure_tasks()
        self._keepalive_task: Optional[asyncio.Task] = None
        self._cleanup_task: Optional[asyncio.Task] = None

        # HTTP client session, created lazily by _get_http_session()
        self._http_session: Optional[aiohttp.ClientSession] = None

        logger.info("[AcestreamSessionManager] Initialized with Redis backend")

    def _get_registry_key(self, infohash: str) -> str:
        """Get the Redis key for an infohash."""
        # md5 is used only to derive a fixed-length key, not for security.
        hash_key = hashlib.md5(infohash.encode()).hexdigest()
        return f"{self.REGISTRY_PREFIX}{hash_key}"

    async def _get_http_session(self) -> aiohttp.ClientSession:
        """Get or create HTTP client session (30s total timeout by default)."""
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
        return self._http_session

    async def _read_registry(self, infohash: str) -> Optional[Dict[str, Any]]:
        """Read session data from Redis registry.

        Returns:
            The decoded entry, or None when it is missing or cannot be read.
        """
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(infohash)
            data = await r.get(key)
            if data:
                return json.loads(data)
        except Exception as e:
            # Registry problems are non-fatal: callers fall back to creating a session.
            logger.warning(f"[AcestreamSessionManager] Error reading registry: {e}")
        return None

    async def _write_registry(self, session: AcestreamSession) -> None:
        """Write session data to Redis registry (expires after REGISTRY_TTL seconds)."""
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(session.infohash)
            await r.set(key, json.dumps(session.to_dict()), ex=self.REGISTRY_TTL)
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Error writing registry: {e}")

    async def _delete_registry(self, infohash: str) -> None:
        """Delete session from Redis registry."""
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(infohash)
            await r.delete(key)
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Error deleting registry: {e}")

    async def _send_stop(self, command_url: str, timeout: float) -> int:
        """Send ``command_url?method=stop`` and return the HTTP status.

        Transport errors propagate; callers decide how loudly to log them.
        """
        http_session = await self._get_http_session()
        url = f"{command_url}?method=stop"
        async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as response:
            return response.status

    async def _create_acestream_session(
        self,
        infohash: str,
        content_id: Optional[str] = None,
        pid: Optional[str] = None,
    ) -> AcestreamResponse:
        """
        Create a new acestream session via format=json API.

        Args:
            infohash: The infohash of the content (40-char hex from magnet link)
            content_id: Optional content ID (alternative to infohash)
            pid: Player id sent to the engine. A random one is generated when
                omitted; pass it explicitly when the caller needs to remember
                which pid the engine session was opened with.

        Returns:
            AcestreamResponse with playback URLs

        Raises:
            Exception if session creation fails (HTTP error, engine-reported
            error, or a reply without a playback_url)
        """
        base_url = f"http://{settings.acestream_host}:{settings.acestream_port}"
        if pid is None:
            pid = str(uuid4())

        # Build URL with parameters
        # Acestream uses different parameter names:
        # - 'id' or 'content_id' for content IDs
        # - 'infohash' for magnet link hashes (40-char hex)
        params = {
            "format": "json",
            "pid": pid,
        }

        if content_id:
            # Content ID provided - use 'id' parameter
            params["id"] = content_id
        else:
            # Only infohash provided - use 'infohash' parameter
            params["infohash"] = infohash

        # Use manifest.m3u8 for HLS or getstream for MPEG-TS
        # We'll use manifest.m3u8 as the primary since we leverage HLS infrastructure
        url = f"{base_url}/ace/manifest.m3u8"

        session = await self._get_http_session()
        try:
            async with session.get(url, params=params) as response:
                response.raise_for_status()
                data = await response.json()
        except aiohttp.ClientError as e:
            logger.error(f"[AcestreamSessionManager] HTTP error creating session: {e}")
            raise

        if data.get("error"):
            raise Exception(f"Acestream error: {data['error']}")

        # "response" may be present but null; treat that like a missing object.
        resp = data.get("response") or {}
        playback_url = resp.get("playback_url", "")
        if not playback_url:
            # Without a playback URL the session is unusable; fail now instead
            # of caching and registering a broken session.
            raise Exception("Acestream error: response did not include a playback_url")

        return AcestreamResponse(
            playback_url=playback_url,
            stat_url=resp.get("stat_url", ""),
            command_url=resp.get("command_url", ""),
            infohash=resp.get("infohash", infohash),
            playback_session_id=resp.get("playback_session_id", ""),
            is_live=bool(resp.get("is_live", 1)),
            is_encrypted=bool(resp.get("is_encrypted", 0)),
        )

    async def get_or_create_session(
        self,
        infohash: str,
        content_id: Optional[str] = None,
        increment_client: bool = True,
    ) -> AcestreamSession:
        """
        Get an existing session or create a new one.

        Uses Redis locking to coordinate session creation across workers.

        Args:
            infohash: The infohash of the content
            content_id: Optional content ID
            increment_client: Whether to increment client count (False for manifest requests)

        Returns:
            AcestreamSession instance

        Raises:
            Exception if the lock cannot be acquired or session creation fails
        """
        # Check if we already have this session in this worker
        if infohash in self._sessions:
            session = self._sessions[infohash]
            session.touch()
            if increment_client:
                session.client_count += 1
            logger.info(
                f"[AcestreamSessionManager] Reusing existing session: {infohash[:16]}... "
                f"(clients: {session.client_count})"
            )
            return session

        # Need to create or fetch session - use Redis lock
        lock_key = f"acestream_session:{infohash}"
        lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=30)

        if not lock_acquired:
            raise Exception(f"Failed to acquire lock for acestream session: {infohash[:16]}...")

        try:
            # Double-check after acquiring lock: another coroutine of this
            # worker may have created the session while we were waiting.
            if infohash in self._sessions:
                session = self._sessions[infohash]
                session.touch()
                if increment_client:
                    session.client_count += 1
                return session

            # Check registry for existing session from another worker
            registry_data = await self._read_registry(infohash)

            if registry_data:
                adopted: Optional[AcestreamSession] = None
                # Validate session is still alive by checking stat_url
                if await self._validate_session(registry_data.get("stat_url", "")):
                    try:
                        adopted = AcestreamSession.from_dict(registry_data)
                    except (KeyError, TypeError) as e:
                        # A malformed entry must not block this infohash until
                        # the TTL expires; drop it and create a fresh session.
                        logger.warning(
                            f"[AcestreamSessionManager] Ignoring malformed registry entry: "
                            f"{infohash[:16]}... - {e!r}"
                        )

                if adopted is not None:
                    logger.info(f"[AcestreamSessionManager] Using existing session from registry: {infohash[:16]}...")
                    adopted.client_count = 1 if increment_client else 0
                    self._sessions[infohash] = adopted
                    self._ensure_tasks()
                    return adopted

                # Session is stale (or unreadable), remove from registry
                await self._delete_registry(infohash)

            # Create new session
            logger.info(f"[AcestreamSessionManager] Creating new session: {infohash[:16]}...")
            try:
                # Generate the pid once so the value stored on the session is
                # the one the engine session was actually opened with.
                pid = str(uuid4())
                response = await self._create_acestream_session(infohash, content_id, pid=pid)

                session = AcestreamSession(
                    infohash=infohash,
                    pid=pid,
                    playback_url=response.playback_url,
                    command_url=response.command_url,
                    stat_url=response.stat_url,
                    playback_session_id=response.playback_session_id,
                    is_live=response.is_live,
                    client_count=1 if increment_client else 0,
                )

                self._sessions[infohash] = session
                await self._write_registry(session)
                self._ensure_tasks()

                logger.info(
                    f"[AcestreamSessionManager] Created session: {infohash[:16]}... "
                    f"playback_url: {response.playback_url}"
                )
                return session

            except Exception as e:
                logger.error(f"[AcestreamSessionManager] Failed to create session: {e}")
                raise
        finally:
            await redis_utils.release_lock(lock_key)

    async def _validate_session(self, stat_url: str) -> bool:
        """Check if a session is still valid by polling stat_url (True only on HTTP 200)."""
        if not stat_url:
            return False

        try:
            session = await self._get_http_session()
            async with session.get(stat_url, timeout=aiohttp.ClientTimeout(total=5)) as response:
                if response.status == 200:
                    return True
        except Exception as e:
            logger.debug(f"[AcestreamSessionManager] Session validation failed: {e}")
        return False

    async def release_session(self, infohash: str) -> None:
        """
        Release a client's hold on a session.

        Decrements client count. When count reaches 0, the session is closed.

        Args:
            infohash: The infohash of the session to release
        """
        session = self._sessions.get(infohash)
        if session is None:
            return

        # Clamp at zero so an unmatched release cannot drive the count negative.
        session.client_count = max(0, session.client_count - 1)

        logger.info(
            f"[AcestreamSessionManager] Released client from session: {infohash[:16]}... "
            f"(remaining clients: {session.client_count})"
        )

        if session.client_count <= 0:
            await self._close_session(infohash)

    async def invalidate_session(self, infohash: str) -> None:
        """
        Invalidate a stale session (e.g., when we get 403 from acestream).

        This forces the session to be closed and removed from registry,
        so next request will create a fresh session.

        Args:
            infohash: The infohash of the session to invalidate
        """
        logger.warning(f"[AcestreamSessionManager] Invalidating stale session: {infohash[:16]}...")

        session = self._sessions.pop(infohash, None)
        # Try to stop the session gracefully
        if session is not None and session.command_url:
            try:
                status = await self._send_stop(session.command_url, timeout=3)
                logger.debug(f"[AcestreamSessionManager] Stop command sent: {status}")
            except Exception as e:
                logger.debug(f"[AcestreamSessionManager] Error stopping stale session: {e}")

        # Always remove from registry
        await self._delete_registry(infohash)
        logger.info(f"[AcestreamSessionManager] Session invalidated: {infohash[:16]}...")

    async def _close_session(self, infohash: str) -> None:
        """
        Close an acestream session.

        Drops the session from this worker and, when the registry entry names
        this worker as owner, calls command_url?method=stop and deletes the
        registry entry.
        """
        session = self._sessions.pop(infohash, None)
        if session is None:
            return

        lock_key = f"acestream_session:{infohash}"
        lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=10, timeout=10)
        if not lock_acquired:
            # Best effort: still attempt the cleanup, but make the race visible.
            logger.warning(
                f"[AcestreamSessionManager] Could not acquire lock to close session: {infohash[:16]}... "
                f"(continuing without lock)"
            )

        try:
            # Check if this is the last worker using this session
            registry_data = await self._read_registry(infohash)

            # Only close if we're the owner or session is stale
            if registry_data and registry_data.get("worker_pid") == os.getpid():
                # We're the owner, close the session
                if session.command_url:
                    try:
                        status = await self._send_stop(session.command_url, timeout=5)
                        logger.info(
                            f"[AcestreamSessionManager] Closed session: {infohash[:16]}... "
                            f"(status: {status})"
                        )
                    except Exception as e:
                        logger.warning(f"[AcestreamSessionManager] Error closing session: {e}")

                await self._delete_registry(infohash)
            else:
                # Another worker may still be using this session
                logger.debug(
                    f"[AcestreamSessionManager] Session {infohash[:16]}... owned by another worker, not closing"
                )
        finally:
            if lock_acquired:
                await redis_utils.release_lock(lock_key)

    def _ensure_tasks(self) -> None:
        """Ensure background tasks are running (must be called from a running event loop)."""
        if self._keepalive_task is None or self._keepalive_task.done():
            self._keepalive_task = asyncio.create_task(self._keepalive_loop())

        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _poll_keepalive(self, infohash: str, session: AcestreamSession, has_recent_activity: bool) -> None:
        """Poll one session's stat_url; on HTTP 200 refresh it locally and in the registry."""
        try:
            http_session = await self._get_http_session()
            async with http_session.get(
                session.stat_url,
                timeout=aiohttp.ClientTimeout(total=5),
            ) as response:
                if response.status == 200:
                    session.touch()
                    await self._write_registry(session)
                    logger.debug(
                        f"[AcestreamSessionManager] Keepalive OK: {infohash[:16]}... "
                        f"(clients: {session.client_count}, recent_activity: {has_recent_activity})"
                    )
                else:
                    logger.warning(
                        f"[AcestreamSessionManager] Keepalive failed: {infohash[:16]}... "
                        f"(status: {response.status})"
                    )
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Keepalive error: {infohash[:16]}... - {e}")

    async def _keepalive_loop(self) -> None:
        """Periodically poll stat_url to keep sessions alive with active clients or recent segment activity."""
        while True:
            try:
                await asyncio.sleep(settings.acestream_keepalive_interval)

                # Iterate over a copy: sessions may be added/removed while we await.
                for infohash, session in list(self._sessions.items()):
                    # Keepalive sessions with active clients OR recent segment activity
                    # This ensures HLS streams (which don't use client_count) stay alive
                    has_recent_activity = session.is_actively_streaming(timeout=settings.acestream_empty_timeout)

                    if session.client_count <= 0 and not has_recent_activity:
                        logger.debug(
                            f"[AcestreamSessionManager] Skipping keepalive (no clients, no recent segments): "
                            f"{infohash[:16]}..."
                        )
                        continue

                    if session.stat_url:
                        await self._poll_keepalive(infohash, session, has_recent_activity)

            except asyncio.CancelledError:
                # Cancelled by close(): end the task quietly.
                return
            except Exception as e:
                logger.warning(f"[AcestreamSessionManager] Keepalive loop error: {e}")

    async def _cleanup_loop(self) -> None:
        """Periodically clean up stale sessions."""
        while True:
            try:
                await asyncio.sleep(15)  # Check every 15 seconds

                now = time.time()
                timeout = settings.acestream_session_timeout
                empty_timeout = settings.acestream_empty_timeout

                # Iterate over a copy: _close_session mutates self._sessions.
                for infohash, session in list(self._sessions.items()):
                    idle_time = now - session.last_access
                    segment_idle_time = now - session.last_segment_request

                    # Don't clean up sessions with recent segment activity (active playback)
                    # Use empty_timeout as the threshold for "recent" activity
                    if segment_idle_time < empty_timeout:
                        logger.debug(
                            f"[AcestreamSessionManager] Session has recent segment activity: {infohash[:16]}... "
                            f"(segment idle: {segment_idle_time:.0f}s)"
                        )
                        continue

                    # Clean up sessions with no clients after empty_timeout (faster cleanup)
                    if session.client_count <= 0 and idle_time > empty_timeout:
                        logger.info(
                            f"[AcestreamSessionManager] Cleaning up empty session: {infohash[:16]}... "
                            f"(idle: {idle_time:.0f}s, segment idle: {segment_idle_time:.0f}s)"
                        )
                        await self._close_session(infohash)
                    # Clean up any session after session_timeout regardless of client count
                    elif idle_time > timeout:
                        logger.info(
                            f"[AcestreamSessionManager] Cleaning up stale session: {infohash[:16]}... "
                            f"(idle: {idle_time:.0f}s, segment idle: {segment_idle_time:.0f}s, clients: {session.client_count})"
                        )
                        await self._close_session(infohash)

                # Note: Redis entries expire via TTL, no manual cleanup needed

            except asyncio.CancelledError:
                # Cancelled by close(): end the task quietly.
                return
            except Exception as e:
                logger.warning(f"[AcestreamSessionManager] Cleanup loop error: {e}")

    def get_session(self, infohash: str) -> Optional[AcestreamSession]:
        """Get a session by infohash if it exists in this worker."""
        return self._sessions.get(infohash)

    def get_active_sessions(self) -> Dict[str, AcestreamSession]:
        """Get all active sessions in this worker (a shallow copy of the internal map)."""
        return dict(self._sessions)

    async def close(self) -> None:
        """Close the session manager and clean up resources."""
        # Cancel background tasks
        for task in (self._keepalive_task, self._cleanup_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    # Raised when the task was cancelled before it ever ran.
                    pass
        self._keepalive_task = None
        self._cleanup_task = None

        # Close all sessions
        for infohash in list(self._sessions.keys()):
            await self._close_session(infohash)

        # Close HTTP session
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()

        logger.info("[AcestreamSessionManager] Closed")
|
||||
|
||||
|
||||
# Global session manager instance
# Module-level singleton created at import time. Construction only sets up
# in-memory state and logs; the HTTP client session and the keepalive/cleanup
# tasks are created later, on first use.
acestream_manager = AcestreamSessionManager()
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
"""Abstract class for AES."""
|
||||
|
||||
|
||||
class AES(object):
|
||||
def __init__(self, key, mode, IV, implementation):
|
||||
if len(key) not in (16, 24, 32):
|
||||
@@ -19,21 +20,21 @@ class AES(object):
|
||||
self.isAEAD = False
|
||||
self.block_size = 16
|
||||
self.implementation = implementation
|
||||
if len(key)==16:
|
||||
if len(key) == 16:
|
||||
self.name = "aes128"
|
||||
elif len(key)==24:
|
||||
elif len(key) == 24:
|
||||
self.name = "aes192"
|
||||
elif len(key)==32:
|
||||
elif len(key) == 32:
|
||||
self.name = "aes256"
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
#CBC-Mode encryption, returns ciphertext
|
||||
#WARNING: *MAY* modify the input as well
|
||||
# CBC-Mode encryption, returns ciphertext
|
||||
# WARNING: *MAY* modify the input as well
|
||||
def encrypt(self, plaintext):
|
||||
assert(len(plaintext) % 16 == 0)
|
||||
assert len(plaintext) % 16 == 0
|
||||
|
||||
#CBC-Mode decryption, returns plaintext
|
||||
#WARNING: *MAY* modify the input as well
|
||||
# CBC-Mode decryption, returns plaintext
|
||||
# WARNING: *MAY* modify the input as well
|
||||
def decrypt(self, ciphertext):
|
||||
assert(len(ciphertext) % 16 == 0)
|
||||
assert len(ciphertext) % 16 == 0
|
||||
|
||||
@@ -18,6 +18,7 @@ from . import python_aes
|
||||
from .constanttime import ct_compare_digest
|
||||
from .cryptomath import bytesToNumber, numberToByteArray
|
||||
|
||||
|
||||
class AESGCM(object):
|
||||
"""
|
||||
AES-GCM implementation. Note: this implementation does not attempt
|
||||
@@ -39,7 +40,7 @@ class AESGCM(object):
|
||||
self.key = key
|
||||
|
||||
self._rawAesEncrypt = rawAesEncrypt
|
||||
self._ctr = python_aes.new(self.key, 6, bytearray(b'\x00' * 16))
|
||||
self._ctr = python_aes.new(self.key, 6, bytearray(b"\x00" * 16))
|
||||
|
||||
# The GCM key is AES(0).
|
||||
h = bytesToNumber(self._rawAesEncrypt(bytearray(16)))
|
||||
@@ -51,11 +52,8 @@ class AESGCM(object):
|
||||
self._productTable = [0] * 16
|
||||
self._productTable[self._reverseBits(1)] = h
|
||||
for i in range(2, 16, 2):
|
||||
self._productTable[self._reverseBits(i)] = \
|
||||
self._gcmShift(self._productTable[self._reverseBits(i//2)])
|
||||
self._productTable[self._reverseBits(i+1)] = \
|
||||
self._gcmAdd(self._productTable[self._reverseBits(i)], h)
|
||||
|
||||
self._productTable[self._reverseBits(i)] = self._gcmShift(self._productTable[self._reverseBits(i // 2)])
|
||||
self._productTable[self._reverseBits(i + 1)] = self._gcmAdd(self._productTable[self._reverseBits(i)], h)
|
||||
|
||||
def _auth(self, ciphertext, ad, tagMask):
|
||||
y = 0
|
||||
@@ -68,7 +66,7 @@ class AESGCM(object):
|
||||
|
||||
def _update(self, y, data):
|
||||
for i in range(0, len(data) // 16):
|
||||
y ^= bytesToNumber(data[16*i:16*i+16])
|
||||
y ^= bytesToNumber(data[16 * i : 16 * i + 16])
|
||||
y = self._mul(y)
|
||||
extra = len(data) % 16
|
||||
if extra != 0:
|
||||
@@ -79,26 +77,26 @@ class AESGCM(object):
|
||||
return y
|
||||
|
||||
def _mul(self, y):
|
||||
""" Returns y*H, where H is the GCM key. """
|
||||
"""Returns y*H, where H is the GCM key."""
|
||||
ret = 0
|
||||
# Multiply H by y 4 bits at a time, starting with the highest power
|
||||
# terms.
|
||||
for i in range(0, 128, 4):
|
||||
# Multiply by x^4. The reduction for the top four terms is
|
||||
# precomputed.
|
||||
retHigh = ret & 0xf
|
||||
retHigh = ret & 0xF
|
||||
ret >>= 4
|
||||
ret ^= (AESGCM._gcmReductionTable[retHigh] << (128-16))
|
||||
ret ^= AESGCM._gcmReductionTable[retHigh] << (128 - 16)
|
||||
|
||||
# Add in y' * H where y' are the next four terms of y, shifted down
|
||||
# to the x^0..x^4. This is one of the pre-computed multiples of
|
||||
# H. The multiplication by x^4 shifts them back into place.
|
||||
ret ^= self._productTable[y & 0xf]
|
||||
ret ^= self._productTable[y & 0xF]
|
||||
y >>= 4
|
||||
assert y == 0
|
||||
return ret
|
||||
|
||||
def seal(self, nonce, plaintext, data=''):
|
||||
def seal(self, nonce, plaintext, data=""):
|
||||
"""
|
||||
Encrypts and authenticates plaintext using nonce and data. Returns the
|
||||
ciphertext, consisting of the encrypted plaintext and tag concatenated.
|
||||
@@ -123,7 +121,7 @@ class AESGCM(object):
|
||||
|
||||
return ciphertext + tag
|
||||
|
||||
def open(self, nonce, ciphertext, data=''):
|
||||
def open(self, nonce, ciphertext, data=""):
|
||||
"""
|
||||
Decrypts and authenticates ciphertext using nonce and data. If the
|
||||
tag is valid, the plaintext is returned. If the tag is invalid,
|
||||
@@ -156,8 +154,8 @@ class AESGCM(object):
|
||||
@staticmethod
|
||||
def _reverseBits(i):
|
||||
assert i < 16
|
||||
i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
|
||||
i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
|
||||
i = ((i << 2) & 0xC) | ((i >> 2) & 0x3)
|
||||
i = ((i << 1) & 0xA) | ((i >> 1) & 0x5)
|
||||
return i
|
||||
|
||||
@staticmethod
|
||||
@@ -173,12 +171,12 @@ class AESGCM(object):
|
||||
# The x^127 term was shifted up to x^128, so subtract a 1+x+x^2+x^7
|
||||
# term. This is 0b11100001 or 0xe1 when represented as an 8-bit
|
||||
# polynomial.
|
||||
x ^= 0xe1 << (128-8)
|
||||
x ^= 0xE1 << (128 - 8)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _inc32(counter):
|
||||
for i in range(len(counter)-1, len(counter)-5, -1):
|
||||
for i in range(len(counter) - 1, len(counter) - 5, -1):
|
||||
counter[i] = (counter[i] + 1) % 256
|
||||
if counter[i] != 0:
|
||||
break
|
||||
@@ -188,6 +186,20 @@ class AESGCM(object):
|
||||
# result is stored as a 16-bit polynomial. This is used in the reduction step to
|
||||
# multiply elements of GF(2^128) by x^4.
|
||||
_gcmReductionTable = [
|
||||
0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
|
||||
0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
|
||||
0x0000,
|
||||
0x1C20,
|
||||
0x3840,
|
||||
0x2460,
|
||||
0x7080,
|
||||
0x6CA0,
|
||||
0x48C0,
|
||||
0x54E0,
|
||||
0xE100,
|
||||
0xFD20,
|
||||
0xD940,
|
||||
0xC560,
|
||||
0x9180,
|
||||
0x8DA0,
|
||||
0xA9C0,
|
||||
0xB5E0,
|
||||
]
|
||||
|
||||
@@ -9,56 +9,56 @@ logger = logging.getLogger(__name__)
|
||||
def is_base64_url(url: str) -> bool:
    """
    Check if a URL appears to be base64 encoded.

    This is a heuristic: the value must not start with a known URL scheme,
    must be at least 10 characters long, and must consist solely of
    characters from the standard or URL-safe base64 alphabets. The URL-safe
    characters ``-`` and ``_`` are accepted because ``encode_url_to_base64``
    produces URL-safe output by default and ``decode_base64_url`` accepts it.

    Args:
        url (str): The URL to check.

    Returns:
        bool: True if the URL appears to be base64 encoded, False otherwise.
    """
    # A recognizable scheme means this is already a plain URL.
    if url.startswith(("http://", "https://", "ftp://", "ftps://")):
        return False

    # Too short to be a meaningful base64 encoded URL
    if len(url) < 10:
        return False

    # Standard alphabet (+ and /), URL-safe alphabet (- and _) and padding (=).
    # Typical URL punctuation such as ':' or '.' is deliberately absent.
    base64_chars = frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=-_")
    return base64_chars.issuperset(url)
|
||||
|
||||
|
||||
def decode_base64_url(encoded_url: str) -> Optional[str]:
|
||||
"""
|
||||
Decode a base64 encoded URL.
|
||||
|
||||
|
||||
Args:
|
||||
encoded_url (str): The base64 encoded URL string.
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: The decoded URL if successful, None if decoding fails.
|
||||
"""
|
||||
try:
|
||||
# Handle URL-safe base64 encoding (replace - with + and _ with /)
|
||||
url_safe_encoded = encoded_url.replace('-', '+').replace('_', '/')
|
||||
|
||||
url_safe_encoded = encoded_url.replace("-", "+").replace("_", "/")
|
||||
|
||||
# Add padding if necessary
|
||||
missing_padding = len(url_safe_encoded) % 4
|
||||
if missing_padding:
|
||||
url_safe_encoded += '=' * (4 - missing_padding)
|
||||
|
||||
url_safe_encoded += "=" * (4 - missing_padding)
|
||||
|
||||
# Decode the base64 string
|
||||
decoded_bytes = base64.b64decode(url_safe_encoded)
|
||||
decoded_url = decoded_bytes.decode('utf-8')
|
||||
|
||||
decoded_url = decoded_bytes.decode("utf-8")
|
||||
|
||||
# Validate that the decoded string is a valid URL
|
||||
parsed = urlparse(decoded_url)
|
||||
if parsed.scheme and parsed.netloc:
|
||||
@@ -67,7 +67,7 @@ def decode_base64_url(encoded_url: str) -> Optional[str]:
|
||||
else:
|
||||
logger.warning(f"Decoded string is not a valid URL: {decoded_url}")
|
||||
return None
|
||||
|
||||
|
||||
except (base64.binascii.Error, UnicodeDecodeError, ValueError) as e:
|
||||
logger.debug(f"Failed to decode base64 URL '{encoded_url[:50]}...': {e}")
|
||||
return None
|
||||
@@ -76,27 +76,27 @@ def decode_base64_url(encoded_url: str) -> Optional[str]:
|
||||
def encode_url_to_base64(url: str, url_safe: bool = True) -> str:
    """
    Return the base64 representation of a URL.

    Args:
        url (str): The URL to encode (encoded as UTF-8 before base64).
        url_safe (bool): When True (default) the URL-safe alphabet is used
            and trailing ``=`` padding is stripped; when False the standard
            alphabet is used and padding is kept.

    Returns:
        str: The base64 encoded URL.

    Raises:
        Exception: Any error raised while encoding is logged and re-raised.
    """
    try:
        raw = url.encode("utf-8")
        if not url_safe:
            encoded = base64.b64encode(raw).decode("utf-8")
        else:
            # URL-safe alphabet (- and _); padding dropped for cleaner URLs
            encoded = base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")
    except Exception as e:
        logger.error(f"Failed to encode URL to base64: {e}")
        raise
    else:
        try:
            logger.debug(f"Encoded URL to base64: {url} -> {encoded}")
        except Exception as e:
            logger.error(f"Failed to encode URL to base64: {e}")
            raise
        return encoded
|
||||
@@ -106,10 +106,10 @@ def process_potential_base64_url(url: str) -> str:
|
||||
"""
|
||||
Process a URL that might be base64 encoded. If it's base64 encoded, decode it.
|
||||
Otherwise, return the original URL.
|
||||
|
||||
|
||||
Args:
|
||||
url (str): The URL to process.
|
||||
|
||||
|
||||
Returns:
|
||||
str: The processed URL (decoded if it was base64, original otherwise).
|
||||
"""
|
||||
@@ -119,5 +119,5 @@ def process_potential_base64_url(url: str) -> str:
|
||||
return decoded_url
|
||||
else:
|
||||
logger.warning(f"URL appears to be base64 but failed to decode: {url[:50]}...")
|
||||
|
||||
return url
|
||||
|
||||
return url
|
||||
|
||||
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Base prebuffer class with shared functionality for HLS and DASH prebuffering.
|
||||
|
||||
This module provides cross-process download coordination using Redis-based locking
|
||||
to prevent duplicate downloads across multiple uvicorn workers. Both player requests
|
||||
and background prebuffer tasks use the same coordination mechanism.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import psutil
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mediaflow_proxy.utils.cache_utils import (
|
||||
get_cached_segment,
|
||||
set_cached_segment,
|
||||
)
|
||||
from mediaflow_proxy.utils.http_utils import download_file_with_retry
|
||||
from mediaflow_proxy.utils import redis_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PrebufferStats:
    """Counters describing how well the prebuffer is performing."""

    cache_hits: int = 0
    cache_misses: int = 0
    segments_prebuffered: int = 0
    bytes_prebuffered: int = 0
    prefetch_triggered: int = 0
    downloads_coordinated: int = 0  # Times we waited for existing download
    last_reset: float = field(default_factory=time.time)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate as a percentage (0.0 when nothing was looked up)."""
        lookups = self.cache_hits + self.cache_misses
        if lookups > 0:
            return self.cache_hits / lookups * 100
        return 0.0

    def reset(self) -> None:
        """Zero every counter and restart the uptime clock."""
        for counter in (
            "cache_hits",
            "cache_misses",
            "segments_prebuffered",
            "bytes_prebuffered",
            "prefetch_triggered",
            "downloads_coordinated",
        ):
            setattr(self, counter, 0)
        self.last_reset = time.time()

    def to_dict(self) -> dict:
        """Return a log-friendly snapshot of the statistics."""
        snapshot = {}
        snapshot["cache_hits"] = self.cache_hits
        snapshot["cache_misses"] = self.cache_misses
        snapshot["hit_rate"] = f"{self.hit_rate:.1f}%"
        snapshot["segments_prebuffered"] = self.segments_prebuffered
        snapshot["bytes_prebuffered_mb"] = f"{self.bytes_prebuffered / 1024 / 1024:.2f}"
        snapshot["prefetch_triggered"] = self.prefetch_triggered
        snapshot["downloads_coordinated"] = self.downloads_coordinated
        # Seconds elapsed since construction or the last reset()
        snapshot["uptime_seconds"] = int(time.time() - self.last_reset)
        return snapshot
|
||||
|
||||
|
||||
class BasePrebuffer(ABC):
    """
    Base class for prebuffer systems with cross-process download coordination.

    This class provides:
    - Cross-process coordination using Redis locks to prevent duplicate downloads
    - Memory usage monitoring
    - Cache statistics tracking
    - Shared download and caching logic

    Every download path takes the ``segment_download:<url>`` lock through
    ``redis_utils`` before fetching, so that multiple uvicorn workers do not
    download the same segment at the same time.

    Subclasses should implement protocol-specific logic (HLS playlist parsing,
    DASH MPD handling, etc.) while inheriting the core download coordination.
    """

    def __init__(
        self,
        max_cache_size: int,
        prebuffer_segments: int,
        max_memory_percent: float,
        emergency_threshold: float,
        segment_ttl: int = 60,
    ):
        """
        Initialize the base prebuffer.

        Args:
            max_cache_size: Maximum number of segments to track
            prebuffer_segments: Number of segments to pre-buffer ahead
            max_memory_percent: Maximum memory usage percentage before skipping prebuffer
            emergency_threshold: Memory threshold for emergency cleanup
            segment_ttl: TTL for cached segments in seconds
        """
        self.max_cache_size = max_cache_size
        self.prebuffer_segment_count = prebuffer_segments
        self.max_memory_percent = max_memory_percent
        self.emergency_threshold = emergency_threshold
        self.segment_ttl = segment_ttl

        # Statistics (per-worker, not shared - but that's fine for monitoring)
        self.stats = PrebufferStats()

        # Stats logging task, started lazily on first activity
        self._stats_task: Optional[asyncio.Task] = None
        self._stats_interval = 60  # Log stats every 60 seconds

    def _get_memory_usage_percent(self) -> float:
        """Get current system memory usage percentage (0.0 if it cannot be read)."""
        try:
            memory = psutil.virtual_memory()
            return memory.percent
        except Exception as e:
            logger.warning(f"Failed to get memory usage: {e}")
            return 0.0

    def _check_memory_threshold(self) -> bool:
        """Check if memory usage exceeds the emergency threshold."""
        return self._get_memory_usage_percent() > self.emergency_threshold

    def _should_skip_for_memory(self) -> bool:
        """Check if we should skip prebuffering due to high memory usage."""
        return self._get_memory_usage_percent() > self.max_memory_percent

    def record_cache_hit(self) -> None:
        """Record a cache hit for statistics."""
        self.stats.cache_hits += 1
        self._ensure_stats_logging()

    def record_cache_miss(self) -> None:
        """Record a cache miss for statistics."""
        self.stats.cache_misses += 1
        self._ensure_stats_logging()

    def _ensure_stats_logging(self) -> None:
        """
        Ensure the stats logging task is running.

        The record_* methods are synchronous and may be called where no event
        loop is running. In that case the task simply is not started (it will
        be started by the next call made from inside a loop) instead of
        raising RuntimeError out of a pure bookkeeping call.
        """
        if self._stats_task is not None and not self._stats_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: nothing can be scheduled right now.
            return
        self._stats_task = loop.create_task(self._periodic_stats_logging())

    async def _periodic_stats_logging(self) -> None:
        """Periodically log prebuffer statistics until the task is cancelled."""
        while True:
            try:
                await asyncio.sleep(self._stats_interval)

                # Only log if there's been activity
                if self.stats.cache_hits > 0 or self.stats.cache_misses > 0:
                    self.log_stats()
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"Error in stats logging: {e}")

    async def _release_download_lock(self, lock_key: str) -> None:
        """
        Release a download lock without letting a release failure escape.

        This runs inside ``finally`` blocks; an exception raised there would
        replace the pending return value (for example a segment that was just
        downloaded successfully). The lock is acquired with ``ttl=30``, so it
        presumably expires on its own if the release fails - verify against
        ``redis_utils``.
        """
        try:
            await redis_utils.release_lock(lock_key)
        except Exception as e:
            logger.warning(f"Failed to release download lock {lock_key}: {e}")

    async def get_or_download(
        self,
        url: str,
        headers: Dict[str, str],
        timeout: float = 10.0,
    ) -> Optional[bytes]:
        """
        Get a segment from cache or download it, with cross-process coordination.

        This is the primary method for getting segments. It:
        1. Checks cache first (immediate return if hit)
        2. Acquires Redis lock to prevent duplicate downloads across workers
        3. Double-checks cache after acquiring lock
        4. Downloads and caches if needed

        Args:
            url: URL of the segment to get
            headers: Headers to use for the request
            timeout: Maximum time to wait for lock acquisition (seconds).
                     Keep this short (10s) for player requests - if lock is held
                     too long, fall back to direct streaming.

        Returns:
            Segment data if successful, None if failed or timed out
        """
        self._ensure_stats_logging()

        # Check cache first (Redis cache is shared across workers)
        cached = await get_cached_segment(url)
        if cached:
            self.record_cache_hit()
            logger.info(f"[get_or_download] CACHE HIT ({len(cached)} bytes): {url}")
            return cached

        # Cache miss - need to coordinate download across workers
        logger.info(f"[get_or_download] CACHE MISS: {url}")

        lock_key = f"segment_download:{url}"
        lock_acquired = False

        try:
            # Acquire Redis lock - only one worker downloads at a time
            lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=timeout)

            if not lock_acquired:
                logger.warning(f"[get_or_download] Lock TIMEOUT ({timeout}s), falling back to streaming: {url}")
                return None

            # Double-check cache after acquiring lock
            # Another worker may have completed the download while we waited
            cached = await get_cached_segment(url)
            if cached:
                # Count this as a cache hit since we didn't download
                self.record_cache_hit()
                self.stats.downloads_coordinated += 1
                logger.info(f"[get_or_download] Found in cache after lock (coordinated): {url}")
                return cached

            # We're the one who needs to download - count as miss now
            self.record_cache_miss()

            # We're the first - download and cache
            logger.info(f"[get_or_download] Downloading: {url}")
            content = await self._download_and_cache(url, headers)
            return content

        except Exception as e:
            logger.warning(f"[get_or_download] Error during download coordination: {e}")
            return None
        finally:
            if lock_acquired:
                await self._release_download_lock(lock_key)

    async def _download_and_cache(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> Optional[bytes]:
        """
        Download a segment and cache it.

        This method should only be called while holding the Redis lock.

        Args:
            url: URL to download
            headers: Headers for the request

        Returns:
            Downloaded content if successful, None otherwise
        """
        try:
            content = await download_file_with_retry(url, headers)
            if content:
                logger.info(f"[_download_and_cache] Downloaded {len(content)} bytes, caching: {url}")
                await set_cached_segment(url, content, ttl=self.segment_ttl)
                self.stats.segments_prebuffered += 1
                self.stats.bytes_prebuffered += len(content)
                return content
            else:
                logger.warning(f"[_download_and_cache] Download returned empty: {url}")
                return None
        except Exception as e:
            logger.warning(f"[_download_and_cache] Failed to download: {url} - {e}")
            return None

    async def try_get_cached(self, url: str) -> Optional[bytes]:
        """
        Check cache only, don't download.

        Use this for background prebuffer tasks that shouldn't block
        if segment isn't available yet.

        Args:
            url: URL to check in cache

        Returns:
            Cached data if available, None otherwise
        """
        return await get_cached_segment(url)

    async def prebuffer_segment(self, url: str, headers: Dict[str, str]) -> None:
        """
        Prebuffer a single segment in the background.

        Uses the same Redis lock as ``get_or_download`` but only waits one
        second for it: if another worker is already downloading the segment,
        this call returns without doing anything.

        Args:
            url: URL of segment to prebuffer
            headers: Headers for the request
        """
        if self._should_skip_for_memory():
            logger.debug("Skipping prebuffer due to high memory usage")
            return

        # Check if already cached
        cached = await get_cached_segment(url)
        if cached:
            logger.debug(f"[prebuffer_segment] Already cached, skipping: {url}")
            return

        lock_key = f"segment_download:{url}"
        lock_acquired = False

        try:
            # Try to acquire lock with short timeout for prebuffering
            # If lock is held by another process, skip this segment
            lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=1.0)

            if not lock_acquired:
                # Another process is downloading, skip this segment
                logger.debug(f"[prebuffer_segment] Lock busy, skipping: {url}")
                return

            # Double-check cache after acquiring lock
            cached = await get_cached_segment(url)
            if cached:
                logger.debug(f"[prebuffer_segment] Found in cache after lock: {url}")
                return

            # Download and cache
            logger.info(f"[prebuffer_segment] Downloading: {url}")
            await self._download_and_cache(url, headers)

        except Exception as e:
            logger.warning(f"[prebuffer_segment] Error: {e}")
        finally:
            if lock_acquired:
                await self._release_download_lock(lock_key)

    async def prebuffer_segments_batch(
        self,
        urls: list,
        headers: Dict[str, str],
        max_concurrent: int = 2,
    ) -> None:
        """
        Prebuffer multiple segments with concurrency control.

        Args:
            urls: List of segment URLs to prebuffer
            headers: Headers for requests
            max_concurrent: Maximum concurrent downloads (default 2 to avoid
                           lock contention with player requests)
        """
        if self._should_skip_for_memory():
            logger.warning("Skipping prebuffer due to high memory usage")
            return

        semaphore = asyncio.Semaphore(max_concurrent)

        async def limited_prebuffer(url: str):
            async with semaphore:
                await self.prebuffer_segment(url, headers)

        # Start all prebuffer tasks; failures of individual segments are
        # collected by gather() rather than aborting the whole batch.
        tasks = [limited_prebuffer(url) for url in urls]
        await asyncio.gather(*tasks, return_exceptions=True)

    def log_stats(self) -> None:
        """Log current prebuffer statistics."""
        logger.info(f"Prebuffer Stats: {self.stats.to_dict()}")
|
||||
@@ -1,308 +1,31 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Any
|
||||
"""
|
||||
Cache utilities for mediaflow-proxy.
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
All caching is now done via Redis for cross-worker sharing.
|
||||
See redis_utils.py for the underlying Redis operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
|
||||
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
|
||||
from mediaflow_proxy.utils import redis_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """A cached payload together with its expiry and access bookkeeping."""

    # Cached payload
    data: bytes
    # Absolute expiry timestamp (same clock as time.time())
    expires_at: float
    # Number of times the entry has been read
    access_count: int = 0
    # Timestamp of the most recent read
    last_access: float = 0.0
    # Payload size used for cache size accounting
    size: int = 0
|
||||
# =============================================================================
|
||||
# Init Segment Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LRUMemoryCache:
    """
    Thread-safe, size-bounded LRU memory cache with per-entry expiry.

    Entries are kept in an ``OrderedDict`` ordered from least to most recently
    used. ``maxsize`` bounds the sum of the ``size`` attributes of the stored
    entries; every public method holds a ``threading.Lock`` for its duration.
    """

    def __init__(self, maxsize: int):
        """
        Args:
            maxsize: Upper bound for the combined ``size`` of all entries.
        """
        self.maxsize = maxsize
        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
        self._lock = threading.Lock()
        self._current_size = 0

    def get(self, key: str) -> Optional["CacheEntry"]:
        """
        Return the live entry for ``key`` and mark it most recently used.

        Expired entries are dropped and reported as missing.

        Returns:
            The entry, or None if the key is absent or expired.
        """
        with self._lock:
            entry = self._cache.pop(key, None)
            if entry is None:
                return None

            now = time.time()
            if now < entry.expires_at:
                entry.access_count += 1
                entry.last_access = now
                # Re-inserting moves the key to the most-recently-used end
                self._cache[key] = entry
                return entry

            # Expired: the entry is already removed, just release its size
            self._current_size -= entry.size
            return None

    def set(self, key: str, entry: "CacheEntry") -> None:
        """
        Store ``entry`` under ``key``, evicting least recently used entries
        until it fits.

        An entry larger than ``maxsize`` is still stored after everything
        else has been evicted.
        """
        with self._lock:
            # Remove the entry being replaced *before* evicting. Leaving it in
            # place would let the eviction loop pop it and subtract its size a
            # second time, and would keep the key at its stale LRU position.
            replaced = self._cache.pop(key, None)
            if replaced is not None:
                self._current_size -= replaced.size

            # Make space, oldest entries first
            while self._cache and self._current_size + entry.size > self.maxsize:
                _, evicted = self._cache.popitem(last=False)
                self._current_size -= evicted.size

            self._cache[key] = entry
            self._current_size += entry.size

    def remove(self, key: str) -> None:
        """Remove ``key`` if present; missing keys are ignored."""
        with self._lock:
            entry = self._cache.pop(key, None)
            if entry is not None:
                self._current_size -= entry.size
|
||||
|
||||
|
||||
class HybridCache:
    """
    Hybrid cache combining an in-memory LRU with file storage.

    Keys are MD5-hashed; the hash is used both as the memory-cache key and as
    the file name inside ``cache_dir``. Each cache file consists of an 8-byte
    big-endian metadata length, the JSON metadata (``expires_at``,
    ``access_count``, ``last_access``) and then the raw payload.
    """

    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        executor_workers: int = 4,
    ):
        """
        Args:
            cache_dir_name: Directory name created under the system temp dir.
            ttl: Default time-to-live in seconds for entries.
            max_memory_size: Size bound passed to the in-memory LRU cache.
            executor_workers: Worker count for the thread pool executor.
        """
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()

        # Initialize cache directories
        self._init_cache_dirs()

    def _init_cache_dirs(self):
        """Create the cache directory if it does not exist yet."""
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_md5_hash(self, key: str) -> str:
        """Get the MD5 hash of a cache key."""
        return hashlib.md5(key.encode()).hexdigest()

    def _get_file_path(self, key: str) -> Path:
        """Get the file path for an (already hashed) cache key."""
        return self.cache_dir / key

    async def _remove_hashed(self, hashed_key: str, error_message: str) -> bool:
        """
        Remove an entry identified by its *hashed* key from memory and disk.

        Args:
            hashed_key: Result of ``_get_md5_hash`` for the caller's key.
            error_message: Prefix for the log line if the file removal fails.

        Returns:
            True if the entry is gone (including "was never there"),
            False if removing the file failed.
        """
        self.memory_cache.remove(hashed_key)
        try:
            await aiofiles.os.remove(self._get_file_path(hashed_key))
        except FileNotFoundError:
            pass
        except Exception as e:
            logger.error(f"{error_message}: {e}")
            return False
        return True

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        Args:
            key: Cache key
            default: Default value if key not found

        Returns:
            Cached value or default if not found
        """
        hashed_key = self._get_md5_hash(key)

        # Try memory cache first
        entry = self.memory_cache.get(hashed_key)
        if entry is not None:
            return entry.data

        # Try file cache
        try:
            async with aiofiles.open(self._get_file_path(hashed_key), "rb") as f:
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())

                expired = metadata["expires_at"] < time.time()
                data = b"" if expired else await f.read()

            if expired:
                # Remove by the hashed key (going through delete() would hash
                # it a second time and miss the file), and only after the
                # file handle above has been closed.
                await self._remove_hashed(hashed_key, "Error deleting from cache")
                return default

            # Promote the entry into the memory cache for subsequent reads
            entry = CacheEntry(
                data=data,
                expires_at=metadata["expires_at"],
                access_count=metadata["access_count"] + 1,
                last_access=time.time(),
                size=len(data),
            )
            self.memory_cache.set(hashed_key, entry)

            return data

        except FileNotFoundError:
            return default
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            return default

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Set value in both memory and file cache.

        Args:
            key: Cache key
            data: Data to cache
            ttl: Optional TTL override; a value <= 0 removes any existing
                entry instead of caching

        Returns:
            bool: Success status

        Raises:
            ValueError: If ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise ValueError("Data must be bytes, bytearray, or memoryview")

        ttl_seconds = self.ttl if ttl is None else ttl

        hashed_key = self._get_md5_hash(key)

        if ttl_seconds <= 0:
            # Explicit request to avoid caching - remove any previous entry and return success
            await self._remove_hashed(hashed_key, "Error removing cache file")
            return True

        expires_at = time.time() + ttl_seconds

        # Create cache entry
        entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))

        # Update memory cache
        self.memory_cache.set(hashed_key, entry)
        file_path = self._get_file_path(hashed_key)
        temp_path = file_path.with_suffix(".tmp")

        # Update file cache: write to a temp file, then rename into place so
        # readers never see a half-written file
        try:
            metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
            metadata_bytes = json.dumps(metadata).encode()
            metadata_size = len(metadata_bytes).to_bytes(8, "big")

            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(metadata_size)
                await f.write(metadata_bytes)
                await f.write(data)

            await aiofiles.os.rename(temp_path, file_path)
            return True

        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            try:
                await aiofiles.os.remove(temp_path)
            except Exception:
                # Best-effort cleanup; the temp file may never have been created
                pass
            return False

    async def delete(self, key: str) -> bool:
        """
        Delete item from both caches.

        Args:
            key: The caller's (unhashed) cache key.

        Returns:
            True if the entry is gone, False if removing the file failed.
        """
        return await self._remove_hashed(self._get_md5_hash(key), "Error deleting from cache")
|
||||
|
||||
|
||||
class AsyncMemoryCache:
    """Coroutine-style facade over an in-process ``LRUMemoryCache``."""

    def __init__(self, max_memory_size: int):
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """Return the cached payload for ``key``, or ``default`` when absent."""
        hit = self.memory_cache.get(key)
        if hit is None:
            return default
        return hit.data

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Store ``data`` under ``key``.

        ``ttl`` defaults to 3600 seconds; a value <= 0 removes the key
        instead. Returns False (after logging) if anything goes wrong.
        """
        try:
            lifetime = ttl if ttl is not None else 3600
            if lifetime > 0:
                expiry = time.time() + lifetime
                self.memory_cache.set(
                    key,
                    CacheEntry(
                        data=data,
                        expires_at=expiry,
                        access_count=0,
                        last_access=time.time(),
                        size=len(data),
                    ),
                )
            else:
                self.memory_cache.remove(key)
        except Exception as e:
            logger.error(f"Error setting cache value: {e}")
            return False
        return True

    async def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; returns False only if removal raised."""
        try:
            self.memory_cache.remove(key)
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
        return True
|
||||
|
||||
|
||||
# Module-level cache singletons

# Init segments: file-backed, kept for one hour, up to 500MB in memory
INIT_SEGMENT_CACHE = HybridCache(
    cache_dir_name="init_segment_cache",
    ttl=60 * 60,
    max_memory_size=500 * 1024 * 1024,
)

# MPD files: memory only, up to 100MB
MPD_CACHE = AsyncMemoryCache(max_memory_size=100 * 1024 * 1024)

# Extractor results: file-backed, kept for five minutes, up to 50MB in memory
EXTRACTOR_CACHE = HybridCache(
    cache_dir_name="extractor_cache",
    ttl=5 * 60,
    max_memory_size=50 * 1024 * 1024,
)
|
||||
|
||||
|
||||
# Specific cache implementations
|
||||
async def get_cached_init_segment(
|
||||
init_url: str,
|
||||
headers: dict,
|
||||
cache_token: str | None = None,
|
||||
ttl: Optional[int] = None,
|
||||
byte_range: str | None = None,
|
||||
) -> Optional[bytes]:
|
||||
"""Get initialization segment from cache or download it.
|
||||
|
||||
@@ -310,29 +33,39 @@ async def get_cached_init_segment(
|
||||
rely on different DRM keys or initialization payloads (e.g. key rotation).
|
||||
|
||||
ttl overrides the default cache TTL; pass a value <= 0 to skip caching entirely.
|
||||
"""
|
||||
|
||||
byte_range specifies a byte range for SegmentBase MPDs (e.g., '0-11568').
|
||||
"""
|
||||
use_cache = ttl is None or ttl > 0
|
||||
cache_key = f"{init_url}|{cache_token}" if cache_token else init_url
|
||||
# Include byte_range in cache key for SegmentBase
|
||||
cache_key = f"{init_url}|{cache_token}|{byte_range}" if cache_token or byte_range else init_url
|
||||
|
||||
if use_cache:
|
||||
cached_data = await INIT_SEGMENT_CACHE.get(cache_key)
|
||||
cached_data = await redis_utils.get_cached_init_segment(cache_key)
|
||||
if cached_data is not None:
|
||||
return cached_data
|
||||
else:
|
||||
# Remove any previously cached entry when caching is disabled
|
||||
await INIT_SEGMENT_CACHE.delete(cache_key)
|
||||
|
||||
try:
|
||||
init_content = await download_file_with_retry(init_url, headers)
|
||||
# Add Range header if byte_range is specified (for SegmentBase MPDs)
|
||||
request_headers = dict(headers)
|
||||
if byte_range:
|
||||
request_headers["Range"] = f"bytes={byte_range}"
|
||||
|
||||
init_content = await download_file_with_retry(init_url, request_headers)
|
||||
if init_content and use_cache:
|
||||
await INIT_SEGMENT_CACHE.set(cache_key, init_content, ttl=ttl)
|
||||
cache_ttl = ttl if ttl is not None else redis_utils.DEFAULT_INIT_CACHE_TTL
|
||||
await redis_utils.set_cached_init_segment(cache_key, init_content, ttl=cache_ttl)
|
||||
return init_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading init segment: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MPD Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_cached_mpd(
|
||||
mpd_url: str,
|
||||
headers: dict,
|
||||
@@ -341,13 +74,13 @@ async def get_cached_mpd(
|
||||
) -> dict:
|
||||
"""Get MPD from cache or download and parse it."""
|
||||
# Try cache first
|
||||
cached_data = await MPD_CACHE.get(mpd_url)
|
||||
cached_data = await redis_utils.get_cached_mpd(mpd_url)
|
||||
if cached_data is not None:
|
||||
try:
|
||||
mpd_dict = json.loads(cached_data)
|
||||
return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
|
||||
except json.JSONDecodeError:
|
||||
await MPD_CACHE.delete(mpd_url)
|
||||
return parse_mpd_dict(cached_data, mpd_url, parse_drm, parse_segment_profile_id)
|
||||
except Exception:
|
||||
# Invalid cached data, will re-download
|
||||
pass
|
||||
|
||||
# Download and parse if not cached
|
||||
try:
|
||||
@@ -355,8 +88,9 @@ async def get_cached_mpd(
|
||||
mpd_dict = parse_mpd(mpd_content)
|
||||
parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
|
||||
|
||||
# Cache the original MPD dict
|
||||
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict.get("minimumUpdatePeriod"))
|
||||
# Cache the original MPD dict with TTL from minimumUpdatePeriod
|
||||
cache_ttl = parsed_dict.get("minimumUpdatePeriod") or redis_utils.DEFAULT_MPD_CACHE_TTL
|
||||
await redis_utils.set_cached_mpd(mpd_url, mpd_dict, ttl=cache_ttl)
|
||||
return parsed_dict
|
||||
except DownloadError as error:
|
||||
logger.error(f"Error downloading MPD: {error}")
|
||||
@@ -366,21 +100,158 @@ async def get_cached_mpd(
|
||||
raise error
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Extractor Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_cached_extractor_result(key: str) -> Optional[dict]:
|
||||
"""Get extractor result from cache."""
|
||||
cached_data = await EXTRACTOR_CACHE.get(key)
|
||||
if cached_data is not None:
|
||||
try:
|
||||
return json.loads(cached_data)
|
||||
except json.JSONDecodeError:
|
||||
await EXTRACTOR_CACHE.delete(key)
|
||||
return None
|
||||
return await redis_utils.get_cached_extractor(key)
|
||||
|
||||
|
||||
async def set_cache_extractor_result(key: str, result: dict) -> bool:
|
||||
"""Cache extractor result."""
|
||||
try:
|
||||
return await EXTRACTOR_CACHE.set(key, json.dumps(result).encode())
|
||||
await redis_utils.set_cached_extractor(key, result)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching extractor result: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Processed Init Segment Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_cached_processed_init(
    init_url: str,
    key_id: str,
) -> Optional[bytes]:
    """Get processed (DRM-stripped) init segment from cache.

    The lookup key is ``processed|<init_url>|<key_id>``, so the same init
    segment processed with different key IDs is stored under separate
    entries.

    Args:
        init_url: URL of the init segment
        key_id: DRM key ID used for processing

    Returns:
        Processed init segment bytes if cached, None otherwise
    """
    # Must stay in sync with the key built in set_cached_processed_init.
    cache_key = f"processed|{init_url}|{key_id}"
    return await redis_utils.get_cached_processed_init(cache_key)
|
||||
|
||||
|
||||
async def set_cached_processed_init(
    init_url: str,
    key_id: str,
    processed_content: bytes,
    ttl: Optional[int] = None,
) -> bool:
    """Cache processed (DRM-stripped) init segment.

    Caching is best-effort: any exception raised while storing the entry is
    logged and reported through the return value instead of propagating.

    Args:
        init_url: URL of the init segment
        key_id: DRM key ID used for processing
        processed_content: The processed init segment bytes
        ttl: Optional TTL override; when None,
            ``redis_utils.DEFAULT_PROCESSED_INIT_TTL`` is used

    Returns:
        True if cached successfully, False if storing raised an exception
    """
    # Must stay in sync with the key built in get_cached_processed_init.
    cache_key = f"processed|{init_url}|{key_id}"
    try:
        # Compare against None (not truthiness) so an explicit ttl of 0 is
        # passed through rather than replaced by the default.
        cache_ttl = ttl if ttl is not None else redis_utils.DEFAULT_PROCESSED_INIT_TTL
        await redis_utils.set_cached_processed_init(cache_key, processed_content, ttl=cache_ttl)
        return True
    except Exception as e:
        logger.error(f"Error caching processed init segment: {e}")
        return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Processed Segment Cache (decrypted/remuxed segments)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_cached_processed_segment(
    segment_url: str,
    key_id: Optional[str] = None,
    remux: bool = False,
) -> Optional[bytes]:
    """Get processed (decrypted/remuxed) segment from cache.

    The entry is read through ``redis_utils.get_cached_segment`` using the
    composite key ``proc|<segment_url>|<key_id or ''>|<remux>``; the
    ``proc|`` prefix keeps it distinct from an entry stored under the bare
    segment URL.

    Args:
        segment_url: URL of the segment
        key_id: DRM key ID if decrypted (None and "" map to the same key)
        remux: Whether the segment was remuxed to TS

    Returns:
        Processed segment bytes if cached, None otherwise
    """
    # Must stay in sync with the key built in set_cached_processed_segment.
    cache_key = f"proc|{segment_url}|{key_id or ''}|{remux}"
    return await redis_utils.get_cached_segment(cache_key)
|
||||
|
||||
|
||||
async def set_cached_processed_segment(
    segment_url: str,
    content: bytes,
    key_id: Optional[str] = None,
    remux: bool = False,
    ttl: int = 60,
) -> bool:
    """Cache processed (decrypted/remuxed) segment.

    The entry is written through ``redis_utils.set_cached_segment`` using the
    composite key ``proc|<segment_url>|<key_id or ''>|<remux>``. Caching is
    best-effort: any exception raised while storing is logged and reported
    through the return value instead of propagating.

    Args:
        segment_url: URL of the segment
        content: Processed segment bytes
        key_id: DRM key ID if decrypted (None and "" map to the same key)
        remux: Whether the segment was remuxed to TS
        ttl: Time to live in seconds

    Returns:
        True if cached successfully, False if storing raised an exception
    """
    # Must stay in sync with the key built in get_cached_processed_segment.
    cache_key = f"proc|{segment_url}|{key_id or ''}|{remux}"
    try:
        await redis_utils.set_cached_segment(cache_key, content, ttl=ttl)
        return True
    except Exception as e:
        logger.error(f"Error caching processed segment: {e}")
        return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Segment Cache
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_cached_segment(segment_url: str) -> Optional[bytes]:
    """Get media segment from prebuffer cache.

    Thin wrapper that delegates to ``redis_utils.get_cached_segment`` with
    the segment URL itself as the cache key.

    Args:
        segment_url: URL of the segment

    Returns:
        Segment bytes if cached, None otherwise
    """
    return await redis_utils.get_cached_segment(segment_url)
|
||||
|
||||
|
||||
async def set_cached_segment(segment_url: str, content: bytes, ttl: int = 60) -> bool:
    """Cache media segment with configurable TTL.

    Caching is best-effort: any exception raised by
    ``redis_utils.set_cached_segment`` is logged and reported through the
    return value instead of propagating.

    Args:
        segment_url: URL of the segment (used directly as the cache key)
        content: Segment bytes
        ttl: Time to live in seconds (default 60s, configurable via dash_segment_cache_ttl)

    Returns:
        True if cached successfully, False if storing raised an exception
    """
    try:
        await redis_utils.set_cached_segment(segment_url, content, ttl=ttl)
        return True
    except Exception as e:
        logger.error(f"Error caching segment: {e}")
        return False
|
||||
|
||||
+78
-171
@@ -1,27 +1,27 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
# See the LICENSE file for legal information regarding use of this file
|
||||
|
||||
"""Classes for reading/writing binary data (such as TLS records)."""
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import sys
|
||||
import struct
|
||||
from struct import pack
|
||||
|
||||
from .compat import bytes_to_int
|
||||
|
||||
|
||||
class DecodeError(SyntaxError):
    """Exception raised in case of decoding errors.

    Derives from the built-in SyntaxError, so an ``except SyntaxError``
    clause also catches it.
    """
|
||||
|
||||
|
||||
class BadCertificateError(SyntaxError):
    """Exception raised in case of bad certificate.

    Derives from the built-in SyntaxError, so an ``except SyntaxError``
    clause also catches it.
    """
|
||||
|
||||
|
||||
class Writer(object):
|
||||
class Writer:
|
||||
"""Serialisation helper for complex byte-based structures."""
|
||||
|
||||
def __init__(self):
|
||||
@@ -32,102 +32,51 @@ class Writer(object):
|
||||
"""Add a single-byte wide element to buffer, see add()."""
|
||||
self.bytes.append(val)
|
||||
|
||||
if sys.version_info < (2, 7):
|
||||
# struct.pack on Python2.6 does not raise exception if the value
|
||||
# is larger than can fit inside the specified size
|
||||
def addTwo(self, val):
|
||||
"""Add a double-byte wide element to buffer, see add()."""
|
||||
if not 0 <= val <= 0xffff:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
self.bytes += pack('>H', val)
|
||||
def addTwo(self, val):
|
||||
"""Add a double-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack(">H", val)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
|
||||
def addThree(self, val):
|
||||
"""Add a three-byte wide element to buffer, see add()."""
|
||||
if not 0 <= val <= 0xffffff:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
self.bytes += pack('>BH', val >> 16, val & 0xffff)
|
||||
def addThree(self, val):
|
||||
"""Add a three-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack(">BH", val >> 16, val & 0xFFFF)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
|
||||
def addFour(self, val):
|
||||
"""Add a four-byte wide element to buffer, see add()."""
|
||||
if not 0 <= val <= 0xffffffff:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
self.bytes += pack('>I', val)
|
||||
else:
|
||||
def addTwo(self, val):
|
||||
"""Add a double-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack('>H', val)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
def addFour(self, val):
|
||||
"""Add a four-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack(">I", val)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
|
||||
def addThree(self, val):
|
||||
"""Add a three-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack('>BH', val >> 16, val & 0xffff)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
def add(self, x, length):
|
||||
"""
|
||||
Add a single positive integer value x, encode it in length bytes.
|
||||
|
||||
def addFour(self, val):
|
||||
"""Add a four-byte wide element to buffer, see add()."""
|
||||
try:
|
||||
self.bytes += pack('>I', val)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
Encode positive integer x in big-endian format using length bytes,
|
||||
add to the internal buffer.
|
||||
|
||||
if sys.version_info >= (3, 0):
|
||||
# the method is called thousands of times, so it's better to extern
|
||||
# the version info check
|
||||
def add(self, x, length):
|
||||
"""
|
||||
Add a single positive integer value x, encode it in length bytes
|
||||
:type x: int
|
||||
:param x: value to encode
|
||||
|
||||
Encode positive integer x in big-endian format using length bytes,
|
||||
add to the internal buffer.
|
||||
|
||||
:type x: int
|
||||
:param x: value to encode
|
||||
|
||||
:type length: int
|
||||
:param length: number of bytes to use for encoding the value
|
||||
"""
|
||||
try:
|
||||
self.bytes += x.to_bytes(length, 'big')
|
||||
except OverflowError:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
else:
|
||||
_addMethods = {1: addOne, 2: addTwo, 3: addThree, 4: addFour}
|
||||
|
||||
def add(self, x, length):
|
||||
"""
|
||||
Add a single positive integer value x, encode it in length bytes
|
||||
|
||||
Encode positive iteger x in big-endian format using length bytes,
|
||||
add to the internal buffer.
|
||||
|
||||
:type x: int
|
||||
:param x: value to encode
|
||||
|
||||
:type length: int
|
||||
:param length: number of bytes to use for encoding the value
|
||||
"""
|
||||
try:
|
||||
self._addMethods[length](self, x)
|
||||
except KeyError:
|
||||
self.bytes += bytearray(length)
|
||||
newIndex = len(self.bytes) - 1
|
||||
for i in range(newIndex, newIndex - length, -1):
|
||||
self.bytes[i] = x & 0xFF
|
||||
x >>= 8
|
||||
if x != 0:
|
||||
raise ValueError("Can't represent value in specified "
|
||||
"length")
|
||||
:type length: int
|
||||
:param length: number of bytes to use for encoding the value
|
||||
"""
|
||||
try:
|
||||
self.bytes += x.to_bytes(length, "big")
|
||||
except OverflowError:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
|
||||
def addFixSeq(self, seq, length):
|
||||
"""
|
||||
Add a list of items, encode every item in length bytes
|
||||
Add a list of items, encode every item in length bytes.
|
||||
|
||||
Uses the unbounded iterable seq to produce items, each of
|
||||
which is then encoded to length bytes
|
||||
which is then encoded to length bytes.
|
||||
|
||||
:type seq: iterable of int
|
||||
:param seq: list of positive integers to encode
|
||||
@@ -138,72 +87,35 @@ class Writer(object):
|
||||
for e in seq:
|
||||
self.add(e, length)
|
||||
|
||||
if sys.version_info < (2, 7):
|
||||
# struct.pack on Python2.6 does not raise exception if the value
|
||||
# is larger than can fit inside the specified size
|
||||
def _addVarSeqTwo(self, seq):
|
||||
"""Helper method for addVarSeq"""
|
||||
if not all(0 <= i <= 0xffff for i in seq):
|
||||
raise ValueError("Can't represent value in specified "
|
||||
"length")
|
||||
self.bytes += pack('>' + 'H' * len(seq), *seq)
|
||||
def addVarSeq(self, seq, length, lengthLength):
|
||||
"""
|
||||
Add a bounded list of same-sized values.
|
||||
|
||||
def addVarSeq(self, seq, length, lengthLength):
|
||||
"""
|
||||
Add a bounded list of same-sized values
|
||||
Create a list of specific length with all items being of the same
|
||||
size.
|
||||
|
||||
Create a list of specific length with all items being of the same
|
||||
size
|
||||
:type seq: list of int
|
||||
:param seq: list of positive integers to encode
|
||||
|
||||
:type seq: list of int
|
||||
:param seq: list of positive integers to encode
|
||||
:type length: int
|
||||
:param length: amount of bytes in which to encode every item
|
||||
|
||||
:type length: int
|
||||
:param length: amount of bytes in which to encode every item
|
||||
|
||||
:type lengthLength: int
|
||||
:param lengthLength: amount of bytes in which to encode the overall
|
||||
length of the array
|
||||
"""
|
||||
self.add(len(seq)*length, lengthLength)
|
||||
if length == 1:
|
||||
self.bytes.extend(seq)
|
||||
elif length == 2:
|
||||
self._addVarSeqTwo(seq)
|
||||
else:
|
||||
for i in seq:
|
||||
self.add(i, length)
|
||||
else:
|
||||
def addVarSeq(self, seq, length, lengthLength):
|
||||
"""
|
||||
Add a bounded list of same-sized values
|
||||
|
||||
Create a list of specific length with all items being of the same
|
||||
size
|
||||
|
||||
:type seq: list of int
|
||||
:param seq: list of positive integers to encode
|
||||
|
||||
:type length: int
|
||||
:param length: amount of bytes in which to encode every item
|
||||
|
||||
:type lengthLength: int
|
||||
:param lengthLength: amount of bytes in which to encode the overall
|
||||
length of the array
|
||||
"""
|
||||
seqLen = len(seq)
|
||||
self.add(seqLen*length, lengthLength)
|
||||
if length == 1:
|
||||
self.bytes.extend(seq)
|
||||
elif length == 2:
|
||||
try:
|
||||
self.bytes += pack('>' + 'H' * seqLen, *seq)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified "
|
||||
"length")
|
||||
else:
|
||||
for i in seq:
|
||||
self.add(i, length)
|
||||
:type lengthLength: int
|
||||
:param lengthLength: amount of bytes in which to encode the overall
|
||||
length of the array
|
||||
"""
|
||||
seqLen = len(seq)
|
||||
self.add(seqLen * length, lengthLength)
|
||||
if length == 1:
|
||||
self.bytes.extend(seq)
|
||||
elif length == 2:
|
||||
try:
|
||||
self.bytes += pack(">" + "H" * seqLen, *seq)
|
||||
except struct.error:
|
||||
raise ValueError("Can't represent value in specified length")
|
||||
else:
|
||||
for i in seq:
|
||||
self.add(i, length)
|
||||
|
||||
def addVarTupleSeq(self, seq, length, lengthLength):
|
||||
"""
|
||||
@@ -257,7 +169,7 @@ class Writer(object):
|
||||
self.bytes += data
|
||||
|
||||
|
||||
class Parser(object):
|
||||
class Parser:
|
||||
"""
|
||||
Parser for TLV and LV byte-based encodings.
|
||||
|
||||
@@ -269,9 +181,6 @@ class Parser(object):
|
||||
read a 4-byte integer from a 2-byte buffer), most methods will raise a
|
||||
DecodeError exception.
|
||||
|
||||
TODO: don't use an exception used by language parser to indicate errors
|
||||
in application code.
|
||||
|
||||
:vartype bytes: bytearray
|
||||
:ivar bytes: data to be interpreted (buffer)
|
||||
|
||||
@@ -285,14 +194,14 @@ class Parser(object):
|
||||
:ivar indexCheck: position at which the structure begins in buffer
|
||||
"""
|
||||
|
||||
def __init__(self, bytes):
|
||||
def __init__(self, data):
|
||||
"""
|
||||
Bind raw bytes with parser.
|
||||
|
||||
:type bytes: bytearray
|
||||
:param bytes: bytes to be parsed/interpreted
|
||||
:type data: bytearray
|
||||
:param data: bytes to be parsed/interpreted
|
||||
"""
|
||||
self.bytes = bytes
|
||||
self.bytes = data
|
||||
self.index = 0
|
||||
self.indexCheck = 0
|
||||
self.lengthCheck = 0
|
||||
@@ -307,7 +216,7 @@ class Parser(object):
|
||||
:rtype: int
|
||||
"""
|
||||
ret = self.getFixBytes(length)
|
||||
return bytes_to_int(ret, 'big')
|
||||
return bytes_to_int(ret, "big")
|
||||
|
||||
def getFixBytes(self, lengthBytes):
|
||||
"""
|
||||
@@ -358,10 +267,10 @@ class Parser(object):
|
||||
|
||||
:rtype: list of int
|
||||
"""
|
||||
l = [0] * lengthList
|
||||
result = [0] * lengthList
|
||||
for x in range(lengthList):
|
||||
l[x] = self.get(length)
|
||||
return l
|
||||
result[x] = self.get(length)
|
||||
return result
|
||||
|
||||
def getVarList(self, length, lengthLength):
|
||||
"""
|
||||
@@ -377,13 +286,12 @@ class Parser(object):
|
||||
"""
|
||||
lengthList = self.get(lengthLength)
|
||||
if lengthList % length != 0:
|
||||
raise DecodeError("Encoded length not a multiple of element "
|
||||
"length")
|
||||
raise DecodeError("Encoded length not a multiple of element length")
|
||||
lengthList = lengthList // length
|
||||
l = [0] * lengthList
|
||||
result = [0] * lengthList
|
||||
for x in range(lengthList):
|
||||
l[x] = self.get(length)
|
||||
return l
|
||||
result[x] = self.get(length)
|
||||
return result
|
||||
|
||||
def getVarTupleList(self, elemLength, elemNum, lengthLength):
|
||||
"""
|
||||
@@ -402,8 +310,7 @@ class Parser(object):
|
||||
"""
|
||||
lengthList = self.get(lengthLength)
|
||||
if lengthList % (elemLength * elemNum) != 0:
|
||||
raise DecodeError("Encoded length not a multiple of element "
|
||||
"length")
|
||||
raise DecodeError("Encoded length not a multiple of element length")
|
||||
tupleCount = lengthList // (elemLength * elemNum)
|
||||
tupleList = []
|
||||
for _ in range(tupleCount):
|
||||
|
||||
+71
-180
@@ -1,223 +1,114 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
# See the LICENSE file for legal information regarding use of this file
|
||||
|
||||
"""Miscellaneous functions to mask Python version differences."""
|
||||
"""Miscellaneous utility functions for Python 3.13+."""
|
||||
|
||||
import sys
|
||||
import re
|
||||
import platform
|
||||
import binascii
|
||||
import traceback
|
||||
import re
|
||||
import time
|
||||
|
||||
|
||||
if sys.version_info >= (3,0):
|
||||
def compat26Str(x):
|
||||
"""Identity function for compatibility."""
|
||||
return x
|
||||
|
||||
def compat26Str(x): return x
|
||||
|
||||
# Python 3.3 requires bytes instead of bytearrays for HMAC
|
||||
# So, python 2.6 requires strings, python 3 requires 'bytes',
|
||||
# and python 2.7 and 3.5 can handle bytearrays...
|
||||
# pylint: disable=invalid-name
|
||||
# we need to keep compatHMAC and `x` for API compatibility
|
||||
if sys.version_info < (3, 4):
|
||||
def compatHMAC(x):
|
||||
"""Convert bytes-like input to format acceptable for HMAC."""
|
||||
return bytes(x)
|
||||
else:
|
||||
def compatHMAC(x):
|
||||
"""Convert bytes-like input to format acceptable for HMAC."""
|
||||
return x
|
||||
# pylint: enable=invalid-name
|
||||
def compatHMAC(x):
|
||||
"""Convert bytes-like input to format acceptable for HMAC."""
|
||||
return x
|
||||
|
||||
def compatAscii2Bytes(val):
|
||||
"""Convert ASCII string to bytes."""
|
||||
if isinstance(val, str):
|
||||
return bytes(val, 'ascii')
|
||||
return val
|
||||
|
||||
def compat_b2a(val):
|
||||
"""Convert an ASCII bytes string to string."""
|
||||
return str(val, 'ascii')
|
||||
def compatAscii2Bytes(val):
|
||||
"""Convert ASCII string to bytes."""
|
||||
if isinstance(val, str):
|
||||
return bytes(val, "ascii")
|
||||
return val
|
||||
|
||||
def raw_input(s):
|
||||
return input(s)
|
||||
|
||||
# So, the python3 binascii module deals with bytearrays, and python2
|
||||
# deals with strings... I would rather deal with the "a" part as
|
||||
# strings, and the "b" part as bytearrays, regardless of python version,
|
||||
# so...
|
||||
def a2b_hex(s):
|
||||
try:
|
||||
b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base16 error: %s" % e)
|
||||
return b
|
||||
|
||||
def a2b_base64(s):
|
||||
try:
|
||||
if isinstance(s, str):
|
||||
s = bytearray(s, "ascii")
|
||||
b = bytearray(binascii.a2b_base64(s))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base64 error: %s" % e)
|
||||
return b
|
||||
def compat_b2a(val):
|
||||
"""Convert an ASCII bytes string to string."""
|
||||
return str(val, "ascii")
|
||||
|
||||
def b2a_hex(b):
|
||||
return binascii.b2a_hex(b).decode("ascii")
|
||||
|
||||
def b2a_base64(b):
|
||||
return binascii.b2a_base64(b).decode("ascii")
|
||||
|
||||
def readStdinBinary():
|
||||
return sys.stdin.buffer.read()
|
||||
def a2b_hex(s):
|
||||
"""Convert hex string to bytearray."""
|
||||
try:
|
||||
b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
|
||||
except Exception as e:
|
||||
raise SyntaxError(f"base16 error: {e}") from e
|
||||
return b
|
||||
|
||||
def compatLong(num):
|
||||
return int(num)
|
||||
|
||||
int_types = tuple([int])
|
||||
def a2b_base64(s):
|
||||
"""Convert base64 string to bytearray."""
|
||||
try:
|
||||
if isinstance(s, str):
|
||||
s = bytearray(s, "ascii")
|
||||
b = bytearray(binascii.a2b_base64(s))
|
||||
except Exception as e:
|
||||
raise SyntaxError(f"base64 error: {e}") from e
|
||||
return b
|
||||
|
||||
def formatExceptionTrace(e):
|
||||
"""Return exception information formatted as string"""
|
||||
return str(e)
|
||||
|
||||
def time_stamp():
|
||||
"""Returns system time as a float"""
|
||||
if sys.version_info >= (3, 3):
|
||||
return time.perf_counter()
|
||||
return time.clock()
|
||||
def b2a_hex(b):
|
||||
"""Convert bytes to hex string."""
|
||||
return binascii.b2a_hex(b).decode("ascii")
|
||||
|
||||
def remove_whitespace(text):
|
||||
"""Removes all whitespace from passed in string"""
|
||||
return re.sub(r"\s+", "", text, flags=re.UNICODE)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint is wrong here and doesn't notice it's a function, not
|
||||
# constant
|
||||
bytes_to_int = int.from_bytes
|
||||
# pylint: enable=invalid-name
|
||||
def b2a_base64(b):
|
||||
"""Convert bytes to base64 string."""
|
||||
return binascii.b2a_base64(b).decode("ascii")
|
||||
|
||||
def bit_length(val):
|
||||
"""Return number of bits necessary to represent an integer."""
|
||||
return val.bit_length()
|
||||
|
||||
def int_to_bytes(val, length=None, byteorder="big"):
|
||||
"""Return number converted to bytes"""
|
||||
if length is None:
|
||||
if val:
|
||||
length = byte_length(val)
|
||||
else:
|
||||
length = 1
|
||||
# for gmpy we need to convert back to native int
|
||||
if type(val) != int:
|
||||
val = int(val)
|
||||
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
|
||||
def readStdinBinary():
|
||||
"""Read binary data from stdin."""
|
||||
import sys
|
||||
|
||||
else:
|
||||
# Python 2.6 requires strings instead of bytearrays in a couple places,
|
||||
# so we define this function so it does the conversion if needed.
|
||||
# same thing with very old 2.7 versions
|
||||
# or on Jython
|
||||
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) \
|
||||
or platform.system() == 'Java':
|
||||
def compat26Str(x): return str(x)
|
||||
return sys.stdin.buffer.read()
|
||||
|
||||
def remove_whitespace(text):
|
||||
"""Removes all whitespace from passed in string"""
|
||||
return re.sub(r"\s+", "", text)
|
||||
|
||||
def bit_length(val):
|
||||
"""Return number of bits necessary to represent an integer."""
|
||||
if val == 0:
|
||||
return 0
|
||||
return len(bin(val))-2
|
||||
else:
|
||||
def compat26Str(x): return x
|
||||
def compatLong(num):
|
||||
"""Convert to int (compatibility function)."""
|
||||
return int(num)
|
||||
|
||||
def remove_whitespace(text):
|
||||
"""Removes all whitespace from passed in string"""
|
||||
return re.sub(r"\s+", "", text, flags=re.UNICODE)
|
||||
|
||||
def bit_length(val):
|
||||
"""Return number of bits necessary to represent an integer."""
|
||||
return val.bit_length()
|
||||
int_types = (int,)
|
||||
|
||||
def compatAscii2Bytes(val):
|
||||
"""Convert ASCII string to bytes."""
|
||||
return val
|
||||
|
||||
def compat_b2a(val):
|
||||
"""Convert an ASCII bytes string to string."""
|
||||
return str(val)
|
||||
def formatExceptionTrace(e):
|
||||
"""Return exception information formatted as string."""
|
||||
return str(e)
|
||||
|
||||
# So, python 2.6 requires strings, python 3 requires 'bytes',
|
||||
# and python 2.7 can handle bytearrays...
|
||||
def compatHMAC(x): return compat26Str(x)
|
||||
|
||||
def a2b_hex(s):
|
||||
try:
|
||||
b = bytearray(binascii.a2b_hex(s))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base16 error: %s" % e)
|
||||
return b
|
||||
def time_stamp():
|
||||
"""Returns system time as a float."""
|
||||
return time.perf_counter()
|
||||
|
||||
def a2b_base64(s):
|
||||
try:
|
||||
b = bytearray(binascii.a2b_base64(s))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base64 error: %s" % e)
|
||||
return b
|
||||
|
||||
def b2a_hex(b):
|
||||
return binascii.b2a_hex(compat26Str(b))
|
||||
|
||||
def b2a_base64(b):
|
||||
return binascii.b2a_base64(compat26Str(b))
|
||||
|
||||
def compatLong(num):
|
||||
return long(num)
|
||||
def remove_whitespace(text):
|
||||
"""Removes all whitespace from passed in string."""
|
||||
return re.sub(r"\s+", "", text, flags=re.UNICODE)
|
||||
|
||||
int_types = (int, long)
|
||||
|
||||
# pylint on Python3 goes nuts for the sys dereferences...
|
||||
bytes_to_int = int.from_bytes
|
||||
|
||||
#pylint: disable=no-member
|
||||
def formatExceptionTrace(e):
|
||||
"""Return exception information formatted as string"""
|
||||
newStr = "".join(traceback.format_exception(sys.exc_type,
|
||||
sys.exc_value,
|
||||
sys.exc_traceback))
|
||||
return newStr
|
||||
#pylint: enable=no-member
|
||||
|
||||
def time_stamp():
|
||||
"""Returns system time as a float"""
|
||||
return time.clock()
|
||||
def bit_length(val):
|
||||
"""Return number of bits necessary to represent an integer."""
|
||||
return val.bit_length()
|
||||
|
||||
def bytes_to_int(val, byteorder):
|
||||
"""Convert bytes to an int."""
|
||||
if not val:
|
||||
return 0
|
||||
if byteorder == "big":
|
||||
return int(b2a_hex(val), 16)
|
||||
if byteorder == "little":
|
||||
return int(b2a_hex(val[::-1]), 16)
|
||||
raise ValueError("Only 'big' and 'little' endian supported")
|
||||
|
||||
def int_to_bytes(val, length=None, byteorder="big"):
|
||||
"""Return number converted to bytes"""
|
||||
if length is None:
|
||||
if val:
|
||||
length = byte_length(val)
|
||||
else:
|
||||
length = 1
|
||||
if byteorder == "big":
|
||||
return bytearray((val >> i) & 0xff
|
||||
for i in reversed(range(0, length*8, 8)))
|
||||
if byteorder == "little":
|
||||
return bytearray((val >> i) & 0xff
|
||||
for i in range(0, length*8, 8))
|
||||
raise ValueError("Only 'big' or 'little' endian supported")
|
||||
def int_to_bytes(val, length=None, byteorder="big"):
|
||||
"""Return number converted to bytes."""
|
||||
if length is None:
|
||||
if val:
|
||||
length = byte_length(val)
|
||||
else:
|
||||
length = 1
|
||||
# for gmpy we need to convert back to native int
|
||||
if not isinstance(val, int):
|
||||
val = int(val)
|
||||
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
|
||||
|
||||
|
||||
def byte_length(val):
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import division
|
||||
from .compat import compatHMAC
|
||||
import hmac
|
||||
|
||||
|
||||
def ct_lt_u32(val_a, val_b):
|
||||
"""
|
||||
Returns 1 if val_a < val_b, 0 otherwise. Constant time.
|
||||
@@ -18,10 +19,10 @@ def ct_lt_u32(val_a, val_b):
|
||||
:param val_b: an unsigned integer representable as a 32 bit value
|
||||
:rtype: int
|
||||
"""
|
||||
val_a &= 0xffffffff
|
||||
val_b &= 0xffffffff
|
||||
val_a &= 0xFFFFFFFF
|
||||
val_b &= 0xFFFFFFFF
|
||||
|
||||
return (val_a^((val_a^val_b)|(((val_a-val_b)&0xffffffff)^val_b)))>>31
|
||||
return (val_a ^ ((val_a ^ val_b) | (((val_a - val_b) & 0xFFFFFFFF) ^ val_b))) >> 31
|
||||
|
||||
|
||||
def ct_gt_u32(val_a, val_b):
|
||||
@@ -77,8 +78,8 @@ def ct_isnonzero_u32(val):
|
||||
:param val: an unsigned integer representable as a 32 bit value
|
||||
:rtype: int
|
||||
"""
|
||||
val &= 0xffffffff
|
||||
return (val|(-val&0xffffffff)) >> 31
|
||||
val &= 0xFFFFFFFF
|
||||
return (val | (-val & 0xFFFFFFFF)) >> 31
|
||||
|
||||
|
||||
def ct_neq_u32(val_a, val_b):
|
||||
@@ -91,10 +92,11 @@ def ct_neq_u32(val_a, val_b):
|
||||
:param val_b: an unsigned integer representable as a 32 bit value
|
||||
:rtype: int
|
||||
"""
|
||||
val_a &= 0xffffffff
|
||||
val_b &= 0xffffffff
|
||||
val_a &= 0xFFFFFFFF
|
||||
val_b &= 0xFFFFFFFF
|
||||
|
||||
return (((val_a - val_b) & 0xFFFFFFFF) | ((val_b - val_a) & 0xFFFFFFFF)) >> 31
|
||||
|
||||
return (((val_a-val_b)&0xffffffff) | ((val_b-val_a)&0xffffffff)) >> 31
|
||||
|
||||
def ct_eq_u32(val_a, val_b):
|
||||
"""
|
||||
@@ -108,8 +110,8 @@ def ct_eq_u32(val_a, val_b):
|
||||
"""
|
||||
return 1 ^ ct_neq_u32(val_a, val_b)
|
||||
|
||||
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
|
||||
block_size=16):
|
||||
|
||||
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version, block_size=16):
|
||||
"""
|
||||
Check CBC cipher HMAC and padding. Close to constant time.
|
||||
|
||||
@@ -135,7 +137,7 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
|
||||
assert version in ((3, 0), (3, 1), (3, 2), (3, 3))
|
||||
|
||||
data_len = len(data)
|
||||
if mac.digest_size + 1 > data_len: # data_len is public
|
||||
if mac.digest_size + 1 > data_len: # data_len is public
|
||||
return False
|
||||
|
||||
# 0 - OK
|
||||
@@ -144,11 +146,11 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
|
||||
#
|
||||
# check padding
|
||||
#
|
||||
pad_length = data[data_len-1]
|
||||
pad_length = data[data_len - 1]
|
||||
pad_start = data_len - pad_length - 1
|
||||
pad_start = max(0, pad_start)
|
||||
|
||||
if version == (3, 0): # version is public
|
||||
if version == (3, 0): # version is public
|
||||
# in SSLv3 we can only check if pad is not longer than the cipher
|
||||
# block size
|
||||
|
||||
@@ -179,33 +181,35 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
|
||||
data_mac = mac.copy()
|
||||
data_mac.update(compatHMAC(seqnumBytes))
|
||||
data_mac.update(compatHMAC(bytearray([contentType])))
|
||||
if version != (3, 0): # version is public
|
||||
if version != (3, 0): # version is public
|
||||
data_mac.update(compatHMAC(bytearray([version[0]])))
|
||||
data_mac.update(compatHMAC(bytearray([version[1]])))
|
||||
data_mac.update(compatHMAC(bytearray([mac_start >> 8])))
|
||||
data_mac.update(compatHMAC(bytearray([mac_start & 0xff])))
|
||||
data_mac.update(compatHMAC(bytearray([mac_start & 0xFF])))
|
||||
data_mac.update(compatHMAC(data[:start_pos]))
|
||||
|
||||
# don't check past the array end (already checked to be >= zero)
|
||||
end_pos = data_len - mac.digest_size
|
||||
|
||||
# calculate all possible
|
||||
for i in range(start_pos, end_pos): # constant for given overall length
|
||||
for i in range(start_pos, end_pos): # constant for given overall length
|
||||
cur_mac = data_mac.copy()
|
||||
cur_mac.update(compatHMAC(data[start_pos:i]))
|
||||
mac_compare = bytearray(cur_mac.digest())
|
||||
# compare the hash for real only if it's the place where mac is
|
||||
# supposed to be
|
||||
mask = ct_lsb_prop_u8(ct_eq_u32(i, mac_start))
|
||||
for j in range(0, mac.digest_size): # digest_size is public
|
||||
result |= (data[i+j] ^ mac_compare[j]) & mask
|
||||
for j in range(0, mac.digest_size): # digest_size is public
|
||||
result |= (data[i + j] ^ mac_compare[j]) & mask
|
||||
|
||||
# return python boolean
|
||||
return result == 0
|
||||
|
||||
if hasattr(hmac, 'compare_digest'):
|
||||
|
||||
if hasattr(hmac, "compare_digest"):
|
||||
ct_compare_digest = hmac.compare_digest
|
||||
else:
|
||||
|
||||
def ct_compare_digest(val_a, val_b):
|
||||
"""Compares if string like objects are equal. Constant time."""
|
||||
if len(val_a) != len(val_b):
|
||||
|
||||
@@ -52,7 +52,7 @@ class EncryptionHandler:
|
||||
del data["ip"] # Remove IP from the data
|
||||
|
||||
return data
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
|
||||
|
||||
+110
-106
@@ -1,22 +1,27 @@
|
||||
# Authors:
|
||||
# Authors:
|
||||
# Trevor Perrin
|
||||
# Martin von Loewis - python 3 port
|
||||
# Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
|
||||
#
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
# See the LICENSE file for legal information regarding use of this file
|
||||
|
||||
"""cryptomath module
|
||||
|
||||
This module has basic math/crypto code."""
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import math
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
from .compat import compat26Str, compatHMAC, compatLong, \
|
||||
bytes_to_int, int_to_bytes, bit_length, byte_length
|
||||
import math
|
||||
import os
|
||||
import zlib
|
||||
|
||||
from .codec import Writer
|
||||
from .compat import (
|
||||
bit_length,
|
||||
byte_length,
|
||||
bytes_to_int,
|
||||
compat26Str,
|
||||
compatHMAC,
|
||||
int_to_bytes,
|
||||
)
|
||||
|
||||
from . import tlshashlib as hashlib
|
||||
from . import tlshmac as hmac
|
||||
@@ -33,27 +38,31 @@ pycryptoLoaded = False
|
||||
# **************************************************************************
|
||||
|
||||
# Check that os.urandom works
|
||||
import zlib
|
||||
assert len(zlib.compress(os.urandom(1000))) > 900
|
||||
|
||||
|
||||
def getRandomBytes(howMany):
    """Return `howMany` random bytes read from ``os.urandom``.

    :type howMany: int
    :param howMany: number of random bytes to return
    :rtype: bytearray
    """
    b = bytearray(os.urandom(howMany))
    # Sanity check kept from the original code: the OS source must deliver
    # exactly the requested amount of data.
    assert len(b) == howMany
    return b
|
||||
|
||||
|
||||
prngName = "os.urandom"
|
||||
|
||||
# **************************************************************************
|
||||
# Simple hash functions
|
||||
# **************************************************************************
|
||||
|
||||
|
||||
def MD5(b):
    """Return a MD5 digest of data

    :param b: data to hash
    :rtype: bytearray
    """
    return secureHash(b, "md5")
|
||||
|
||||
|
||||
def SHA1(b):
    """Return a SHA1 digest of data

    :param b: data to hash
    :rtype: bytearray
    """
    return secureHash(b, "sha1")
|
||||
|
||||
|
||||
def secureHash(data, algorithm):
|
||||
"""Return a digest of `data` using `algorithm`"""
|
||||
@@ -61,33 +70,40 @@ def secureHash(data, algorithm):
|
||||
hashInstance.update(compat26Str(data))
|
||||
return bytearray(hashInstance.digest())
|
||||
|
||||
|
||||
def secureHMAC(k, b, algorithm):
    """Return the HMAC of message `b` under key `k` using hash `algorithm`.

    :param k: HMAC key
    :param b: message to authenticate
    :type algorithm: str
    :param algorithm: name of the hash, looked up as an attribute of the
        module-level ``hashlib``
    :rtype: bytearray
    """
    key = compatHMAC(k)
    message = compatHMAC(b)
    digestmod = getattr(hashlib, algorithm)
    mac = hmac.new(key, message, digestmod)
    return bytearray(mac.digest())
|
||||
|
||||
|
||||
def HMAC_MD5(k, b):
    """Return the HMAC-MD5 of message `b` under key `k` as a bytearray."""
    return secureHMAC(k, b, "md5")
|
||||
|
||||
|
||||
def HMAC_SHA1(k, b):
    """Return the HMAC-SHA1 of message `b` under key `k` as a bytearray."""
    return secureHMAC(k, b, "sha1")
|
||||
|
||||
|
||||
def HMAC_SHA256(k, b):
    """Return the HMAC-SHA256 of message `b` under key `k` as a bytearray."""
    return secureHMAC(k, b, "sha256")
|
||||
|
||||
|
||||
def HMAC_SHA384(k, b):
    """Return the HMAC-SHA384 of message `b` under key `k` as a bytearray."""
    return secureHMAC(k, b, "sha384")
|
||||
|
||||
|
||||
def HKDF_expand(PRK, info, L, algorithm):
    """HKDF-Expand step: stretch `PRK` into `L` bytes of keying material.

    :param PRK: pseudorandom key used as the HMAC key
    :param info: context/application specific data mixed into every block
    :type L: int
    :param L: number of output bytes requested
    :type algorithm: str
    :param algorithm: name of the hash used for the HMAC
    :rtype: bytearray
    :raises ValueError: when more than 255 HMAC blocks would be needed
    """
    N = divceil(L, getattr(hashlib, algorithm)().digest_size)
    # The block counter is a single octet, so at most 255 blocks can be
    # produced. The previous loop ran one extra iteration and therefore
    # failed on bytearray([256]) even for the permitted maximum N == 255.
    if N > 255:
        raise ValueError("requested HKDF output length is too large")
    T = bytearray()
    Titer = bytearray()
    for x in range(1, N + 1):
        # T(x) = HMAC(PRK, T(x-1) | info | x), with T(0) being empty
        Titer = secureHMAC(PRK, Titer + info + bytearray([x]), algorithm)
        T += Titer
    return T[:L]
|
||||
|
||||
|
||||
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
|
||||
"""
|
||||
TLS1.3 key derivation function (HKDF-Expand-Label).
|
||||
@@ -108,6 +124,7 @@ def HKDF_expand_label(secret, label, hashValue, length, algorithm):
|
||||
|
||||
return HKDF_expand(secret, hkdfLabel.bytes, length, algorithm)
|
||||
|
||||
|
||||
def derive_secret(secret, label, handshake_hashes, algorithm):
|
||||
"""
|
||||
TLS1.3 key derivation function (Derive-Secret).
|
||||
@@ -123,17 +140,17 @@ def derive_secret(secret, label, handshake_hashes, algorithm):
|
||||
:rtype: bytearray
|
||||
"""
|
||||
if handshake_hashes is None:
|
||||
hs_hash = secureHash(bytearray(b''), algorithm)
|
||||
hs_hash = secureHash(bytearray(b""), algorithm)
|
||||
else:
|
||||
hs_hash = handshake_hashes.digest(algorithm)
|
||||
return HKDF_expand_label(secret, label, hs_hash,
|
||||
getattr(hashlib, algorithm)().digest_size,
|
||||
algorithm)
|
||||
return HKDF_expand_label(secret, label, hs_hash, getattr(hashlib, algorithm)().digest_size, algorithm)
|
||||
|
||||
|
||||
# **************************************************************************
|
||||
# Converter Functions
|
||||
# **************************************************************************
|
||||
|
||||
|
||||
def bytesToNumber(b, endian="big"):
|
||||
"""
|
||||
Convert a number stored in bytearray to an integer.
|
||||
@@ -156,7 +173,7 @@ def numberToByteArray(n, howManyBytes=None, endian="big"):
|
||||
if howManyBytes < length:
|
||||
ret = int_to_bytes(n, length, endian)
|
||||
if endian == "big":
|
||||
return ret[length-howManyBytes:length]
|
||||
return ret[length - howManyBytes : length]
|
||||
return ret[:howManyBytes]
|
||||
return int_to_bytes(n, howManyBytes, endian)
|
||||
|
||||
@@ -172,12 +189,12 @@ def mpiToNumber(mpi):
|
||||
def numberToMPI(n):
|
||||
b = numberToByteArray(n)
|
||||
ext = 0
|
||||
#If the high-order bit is going to be set,
|
||||
#add an extra byte of zeros
|
||||
if (numBits(n) & 0x7)==0:
|
||||
# If the high-order bit is going to be set,
|
||||
# add an extra byte of zeros
|
||||
if (numBits(n) & 0x7) == 0:
|
||||
ext = 1
|
||||
length = numBytes(n) + ext
|
||||
b = bytearray(4+ext) + b
|
||||
b = bytearray(4 + ext) + b
|
||||
b[0] = (length >> 24) & 0xFF
|
||||
b[1] = (length >> 16) & 0xFF
|
||||
b[2] = (length >> 8) & 0xFF
|
||||
@@ -190,75 +207,57 @@ def numberToMPI(n):
|
||||
# **************************************************************************
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint recognises them as constants, not function names, also
|
||||
# we can't change their names without API change
|
||||
numBits = bit_length
|
||||
|
||||
|
||||
numBytes = byte_length
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
# **************************************************************************
|
||||
# Big Number Math
|
||||
# **************************************************************************
|
||||
|
||||
|
||||
def getRandomNumber(low, high):
    """Return a random integer `n` such that ``low <= n < high``.

    Candidates of the same byte length as `high` are drawn and rejected until
    one falls inside the requested range.

    :type low: int
    :param low: inclusive lower bound
    :type high: int
    :param high: exclusive upper bound, must be greater than `low`
    :rtype: int
    """
    assert low < high
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        random_bytes = getRandomBytes(howManyBytes)
        if lastBits:
            # Clear the excess high bits of the leading byte so the candidate
            # has no more bits than `high`, keeping the rejection rate low.
            random_bytes[0] = random_bytes[0] % (1 << lastBits)
        n = bytesToNumber(random_bytes)
        if low <= n < high:
            return n
|
||||
|
||||
def gcd(a, b):
    """Return the greatest common divisor of `a` and `b` (Euclid's algorithm).

    ``gcd(0, 0)`` is 0.
    """
    # Start with the larger value as the dividend
    a, b = max(a, b), min(a, b)
    while b:
        a, b = b, a % b
    return a
|
||||
|
||||
|
||||
def lcm(a, b):
    """Return the least common multiple of `a` and `b`."""
    product = a * b
    return product // gcd(a, b)
|
||||
|
||||
# pylint: disable=invalid-name
# disable pylint check as the (a, b) are part of the API
def invMod(a, b):
    """Return inverse of a mod b, zero if none.

    Uses the Extended Euclidean Algorithm.

    :type a: int
    :param a: value to invert
    :type b: int
    :param b: modulus
    :rtype: int
    """
    c, d = a, b
    uc, ud = 1, 0
    # Loop invariant: c == uc * a (mod b) and d == ud * a (mod b)
    while c != 0:
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    # d now holds gcd(a, b); an inverse exists only when it equals 1
    if d == 1:
        return ud % b
    return 0


# pylint: enable=invalid-name
|
||||
|
||||
|
||||
# Use built-in pow for modular exponentiation (Python 3 handles this efficiently)
# powMod(base, power, modulus) returns (base ** power) % modulus.
powMod = pow
|
||||
|
||||
|
||||
def divceil(divident, divisor):
|
||||
@@ -267,10 +266,10 @@ def divceil(divident, divisor):
|
||||
return quot + int(bool(r))
|
||||
|
||||
|
||||
#Pre-calculate a sieve of the ~100 primes < 1000:
|
||||
# Pre-calculate a sieve of the ~100 primes < 1000:
|
||||
def makeSieve(n):
|
||||
sieve = list(range(n))
|
||||
for count in range(2, int(math.sqrt(n))+1):
|
||||
for count in range(2, int(math.sqrt(n)) + 1):
|
||||
if sieve[count] == 0:
|
||||
continue
|
||||
x = sieve[count] * 2
|
||||
@@ -280,30 +279,34 @@ def makeSieve(n):
|
||||
sieve = [x for x in sieve[2:] if x]
|
||||
return sieve
|
||||
|
||||
|
||||
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
    """Probabilistic primality test: trial division, then Rabin-Miller.

    :type n: int
    :param n: number to test
    :type iterations: int
    :param iterations: number of Rabin-Miller rounds; with 0 only the trial
        division is performed
    :type display: bool
    :param display: print a ``*`` progress marker once trial division passed
    :param sieve: small primes used for trial division; the default is
        computed once, when the function is defined, and is only read here
    :rtype: bool
    :returns: False if `n` is certainly composite (or smaller than 2),
        True if it passed all checks
    """
    # Numbers below 2 are not prime; without this guard the first sieve
    # entry satisfies ``x >= n`` and they would be reported as prime.
    if n < 2:
        return False
    # Trial division with sieve
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    # Passed trial division, proceed to Rabin-Miller
    # Rabin-Miller implemented per Ferguson & Schneier
    # Compute s, t for Rabin-Miller
    if display:
        print("*", end=" ")
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    # Repeat Rabin-Miller x times
    a = 2  # Use 2 as a base for first iteration speedup, per HAC
    for _ in range(iterations):
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            v, i = powMod(v, 2, n), i + 1
        a = getRandomNumber(2, n)
    return True
|
||||
|
||||
@@ -316,16 +319,16 @@ def getRandomPrime(bits, display=False):
|
||||
larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
|
||||
"""
|
||||
assert bits >= 10
|
||||
#The 1.5 ensures the 2 MSBs are set
|
||||
#Thus, when used for p,q in RSA, n will have its MSB set
|
||||
# The 1.5 ensures the 2 MSBs are set
|
||||
# Thus, when used for p,q in RSA, n will have its MSB set
|
||||
#
|
||||
#Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||||
#29 % 30 and keep them there
|
||||
low = ((2 ** (bits-1)) * 3) // 2
|
||||
high = 2 ** bits - 30
|
||||
# Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||||
# 29 % 30 and keep them there
|
||||
low = ((2 ** (bits - 1)) * 3) // 2
|
||||
high = 2**bits - 30
|
||||
while True:
|
||||
if display:
|
||||
print(".", end=' ')
|
||||
print(".", end=" ")
|
||||
cand_p = getRandomNumber(low, high)
|
||||
# make odd
|
||||
if cand_p % 2 == 0:
|
||||
@@ -334,7 +337,7 @@ def getRandomPrime(bits, display=False):
|
||||
return cand_p
|
||||
|
||||
|
||||
#Unused at the moment...
|
||||
# Unused at the moment...
|
||||
def getRandomSafePrime(bits, display=False):
|
||||
"""Generate a random safe prime.
|
||||
|
||||
@@ -342,23 +345,24 @@ def getRandomSafePrime(bits, display=False):
|
||||
the (p-1)/2 will also be prime.
|
||||
"""
|
||||
assert bits >= 10
|
||||
#The 1.5 ensures the 2 MSBs are set
|
||||
#Thus, when used for p,q in RSA, n will have its MSB set
|
||||
# The 1.5 ensures the 2 MSBs are set
|
||||
# Thus, when used for p,q in RSA, n will have its MSB set
|
||||
#
|
||||
#Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||||
#29 % 30 and keep them there
|
||||
low = (2 ** (bits-2)) * 3//2
|
||||
high = (2 ** (bits-1)) - 30
|
||||
# Since 30 is lcm(2,3,5), we'll set our test numbers to
|
||||
# 29 % 30 and keep them there
|
||||
low = (2 ** (bits - 2)) * 3 // 2
|
||||
high = (2 ** (bits - 1)) - 30
|
||||
q = getRandomNumber(low, high)
|
||||
q += 29 - (q % 30)
|
||||
while 1:
|
||||
if display: print(".", end=' ')
|
||||
while True:
|
||||
if display:
|
||||
print(".", end=" ")
|
||||
q += 30
|
||||
if (q >= high):
|
||||
if q >= high:
|
||||
q = getRandomNumber(low, high)
|
||||
q += 29 - (q % 30)
|
||||
#Ideas from Tom Wu's SRP code
|
||||
#Do trial division on p and q before Rabin-Miller
|
||||
# Ideas from Tom Wu's SRP code
|
||||
# Do trial division on p and q before Rabin-Miller
|
||||
if isPrime(q, 0, display=display):
|
||||
p = (2 * q) + 1
|
||||
if isPrime(p, display=display):
|
||||
|
||||
@@ -1,373 +1,402 @@
|
||||
import logging
|
||||
import psutil
|
||||
from typing import Dict, Optional, List
|
||||
from urllib.parse import urljoin
|
||||
import xmltodict
|
||||
from mediaflow_proxy.utils.http_utils import create_httpx_client
|
||||
from mediaflow_proxy.configs import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DASHPreBuffer:
    """
    Pre-buffer system for DASH streams to reduce latency and improve streaming performance.

    Media segments, initialization segments and parsed manifests are cached in
    memory, keyed by URL. Both segment caches are bounded by ``max_cache_size``
    (oldest entry evicted first) and are halved when system memory usage
    exceeds ``emergency_threshold``. The manifest cache is not bounded.
    """

    def __init__(self, max_cache_size: Optional[int] = None, prebuffer_segments: Optional[int] = None):
        """
        Initialize the DASH pre-buffer system.

        Args:
            max_cache_size (int): Maximum number of segments to cache (uses config if None)
            prebuffer_segments (int): Number of segments to pre-buffer ahead (uses config if None)
        """
        self.max_cache_size = max_cache_size or settings.dash_prebuffer_cache_size
        self.prebuffer_segments = prebuffer_segments or settings.dash_prebuffer_segments
        self.max_memory_percent = settings.dash_prebuffer_max_memory_percent
        self.emergency_threshold = settings.dash_prebuffer_emergency_threshold

        # Cache for different types of DASH content
        self.segment_cache: Dict[str, bytes] = {}
        self.init_segment_cache: Dict[str, bytes] = {}
        self.manifest_cache: Dict[str, dict] = {}

        # Track segment URLs for each adaptation set
        self.adaptation_segments: Dict[str, List[str]] = {}
        self.client = create_httpx_client()

    @staticmethod
    def _as_list(value) -> list:
        """Normalise a parsed XML child (missing, single element or list) to a list."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    @staticmethod
    def _drop_oldest_half(cache: Dict[str, bytes]) -> int:
        """Delete the oldest half of ``cache`` (insertion order) and return the number removed."""
        doomed = list(cache)[: len(cache) // 2]
        for key in doomed:
            del cache[key]
        return len(doomed)

    def _get_memory_usage_percent(self) -> float:
        """
        Get current memory usage percentage.

        Returns:
            float: Memory usage percentage, or 0.0 when it cannot be determined
        """
        try:
            memory = psutil.virtual_memory()
            return memory.percent
        except Exception as e:
            logger.warning(f"Failed to get memory usage: {e}")
            return 0.0

    def _check_memory_threshold(self) -> bool:
        """
        Check if memory usage exceeds the emergency threshold.

        Returns:
            bool: True if emergency cleanup is needed
        """
        memory_percent = self._get_memory_usage_percent()
        return memory_percent > self.emergency_threshold

    def _memory_limit_exceeded(self, skipped_action: str) -> bool:
        """Return True, and log which action is being skipped, when memory usage is above ``max_memory_percent``."""
        memory_percent = self._get_memory_usage_percent()
        if memory_percent > self.max_memory_percent:
            logger.warning(
                f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping {skipped_action}"
            )
            return True
        return False

    def _store(self, cache: Dict[str, bytes], url: str, data: bytes) -> None:
        """Insert ``data`` into ``cache``, then run emergency cleanup or trim the cache to ``max_cache_size``."""
        cache[url] = data
        if self._check_memory_threshold():
            self._emergency_cache_cleanup()
            return
        # Simple FIFO eviction. It is applied to whichever cache received the
        # entry, so the init segment cache cannot grow without bound either.
        while cache and len(cache) > self.max_cache_size:
            del cache[next(iter(cache))]

    def _emergency_cache_cleanup(self) -> None:
        """
        Perform emergency cache cleanup when memory usage is high.

        Removes the oldest 50% of both the segment and the init segment cache.
        Does nothing when memory usage is below the emergency threshold.
        """
        if not self._check_memory_threshold():
            return

        logger.warning("Emergency DASH cache cleanup triggered due to high memory usage")
        removed_segments = self._drop_oldest_half(self.segment_cache)
        removed_init = self._drop_oldest_half(self.init_segment_cache)
        logger.info(
            f"Emergency cleanup removed {removed_segments} segments and {removed_init} init segments from cache"
        )

    async def _fetch_manifest(self, mpd_url: str, headers: Dict[str, str]) -> dict:
        """Download and parse the MPD at ``mpd_url``, store it in the manifest cache and return it.

        Download and parse errors propagate to the caller.
        """
        response = await self.client.get(mpd_url, headers=headers)
        response.raise_for_status()
        mpd_dict = xmltodict.parse(response.text)
        self.manifest_cache[mpd_url] = mpd_dict
        return mpd_dict

    async def prebuffer_dash_manifest(self, mpd_url: str, headers: Dict[str, str]) -> None:
        """
        Pre-buffer segments from a DASH manifest.

        Failures are logged and never raised.

        Args:
            mpd_url (str): URL of the DASH manifest
            headers (Dict[str, str]): Headers to use for requests
        """
        try:
            mpd_dict = await self._fetch_manifest(mpd_url, headers)

            # Extract initialization segments and first few segments
            await self._extract_and_prebuffer_segments(mpd_dict, mpd_url, headers)

            logger.info(f"Pre-buffered DASH manifest: {mpd_url}")

        except Exception as e:
            logger.warning(f"Failed to pre-buffer DASH manifest {mpd_url}: {e}")

    async def _extract_and_prebuffer_segments(self, mpd_dict: dict, base_url: str, headers: Dict[str, str]) -> None:
        """
        Extract and pre-buffer segments from MPD manifest.

        Only ``SegmentTemplate`` / ``SegmentList`` elements placed directly on
        an ``AdaptationSet`` are considered.

        Args:
            mpd_dict (dict): Parsed MPD manifest
            base_url (str): Base URL for resolving relative URLs
            headers (Dict[str, str]): Headers to use for requests
        """
        try:
            mpd = mpd_dict.get("MPD") or {}
            for period in self._as_list(mpd.get("Period")):
                for adaptation_set in self._as_list(period.get("AdaptationSet")):
                    segment_template = adaptation_set.get("SegmentTemplate") or {}

                    # Initialization segment first, then the media segments
                    init_segment = segment_template.get("@initialization")
                    if init_segment:
                        await self._download_init_segment(urljoin(base_url, init_segment), headers)
                    if segment_template:
                        await self._prebuffer_template_segments(segment_template, base_url, headers)

                    segment_list = adaptation_set.get("SegmentList") or {}
                    if segment_list:
                        await self._prebuffer_list_segments(segment_list, base_url, headers)

        except Exception as e:
            logger.warning(f"Failed to extract segments from MPD: {e}")

    async def _download_init_segment(self, init_url: str, headers: Dict[str, str]) -> None:
        """
        Download and cache initialization segment.

        Args:
            init_url (str): URL of the initialization segment
            headers (Dict[str, str]): Headers to use for request
        """
        try:
            # Check memory usage before downloading
            if self._memory_limit_exceeded("init segment download"):
                return

            response = await self.client.get(init_url, headers=headers)
            response.raise_for_status()

            self._store(self.init_segment_cache, init_url, response.content)
            logger.debug(f"Cached init segment: {init_url}")

        except Exception as e:
            logger.warning(f"Failed to download init segment {init_url}: {e}")

    async def _prebuffer_template_segments(
        self, segment_template: dict, base_url: str, headers: Dict[str, str]
    ) -> None:
        """
        Pre-buffer segments using segment template.

        Only plain ``$Number$`` templates are supported.

        Args:
            segment_template (dict): Segment template from MPD
            base_url (str): Base URL for resolving relative URLs
            headers (Dict[str, str]): Headers to use for requests
        """
        try:
            media_template = segment_template.get("@media")
            if not media_template:
                return
            if "$Number$" not in media_template:
                # Without the placeholder every iteration would resolve to
                # the very same URL, so there is nothing useful to prefetch.
                logger.debug(f"Skipping template without $Number$ placeholder: {media_template}")
                return

            start_number = int(segment_template.get("@startNumber", 1))

            # Pre-buffer first few segments
            for offset in range(self.prebuffer_segments):
                segment_url = media_template.replace("$Number$", str(start_number + offset))
                await self._download_segment(urljoin(base_url, segment_url), headers)

        except Exception as e:
            logger.warning(f"Failed to pre-buffer template segments: {e}")

    async def _prebuffer_list_segments(self, segment_list: dict, base_url: str, headers: Dict[str, str]) -> None:
        """
        Pre-buffer segments from segment list.

        Args:
            segment_list (dict): Segment list from MPD
            base_url (str): Base URL for resolving relative URLs
            headers (Dict[str, str]): Headers to use for requests
        """
        try:
            segments = self._as_list(segment_list.get("SegmentURL"))

            # Pre-buffer first few segments
            for segment in segments[: self.prebuffer_segments]:
                # MPD SegmentURL elements carry the location in their "media"
                # attribute; "@src" is still honoured for backward compatibility.
                segment_url = segment.get("@media") or segment.get("@src")
                if segment_url:
                    await self._download_segment(urljoin(base_url, segment_url), headers)

        except Exception as e:
            logger.warning(f"Failed to pre-buffer list segments: {e}")

    async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
        """
        Download a single segment and cache it.

        Args:
            segment_url (str): URL of the segment to download
            headers (Dict[str, str]): Headers to use for request
        """
        try:
            # Check memory usage before downloading
            if self._memory_limit_exceeded("segment download"):
                return

            response = await self.client.get(segment_url, headers=headers)
            response.raise_for_status()

            self._store(self.segment_cache, segment_url, response.content)
            logger.debug(f"Cached DASH segment: {segment_url}")

        except Exception as e:
            logger.warning(f"Failed to download DASH segment {segment_url}: {e}")

    async def get_segment(self, segment_url: str, headers: Dict[str, str]) -> Optional[bytes]:
        """
        Get a segment from cache or download it.

        Args:
            segment_url (str): URL of the segment
            headers (Dict[str, str]): Headers to use for request

        Returns:
            Optional[bytes]: Segment data, or None if it is not cached and could not be downloaded
        """
        # Check the segment cache first, then the init segment cache
        for label, cache in (("segment", self.segment_cache), ("init segment", self.init_segment_cache)):
            if segment_url in cache:
                logger.debug(f"DASH cache hit for {label}: {segment_url}")
                return cache[segment_url]

        # Check memory usage before downloading
        if self._memory_limit_exceeded("download"):
            return None

        # Download if not in cache
        try:
            response = await self.client.get(segment_url, headers=headers)
            response.raise_for_status()
            segment_data = response.content

            # URL based heuristic to tell init segments from regular segments
            looks_like_init = "init" in segment_url.lower() or segment_url.endswith(".mp4")
            target_cache = self.init_segment_cache if looks_like_init else self.segment_cache
            self._store(target_cache, segment_url, segment_data)

            logger.debug(f"Downloaded and cached DASH segment: {segment_url}")
            return segment_data

        except Exception as e:
            logger.warning(f"Failed to get DASH segment {segment_url}: {e}")
            return None

    async def get_manifest(self, mpd_url: str, headers: Dict[str, str]) -> Optional[dict]:
        """
        Get MPD manifest from cache or download it.

        Args:
            mpd_url (str): URL of the MPD manifest
            headers (Dict[str, str]): Headers to use for request

        Returns:
            Optional[dict]: Parsed manifest, or None if it could not be downloaded or parsed
        """
        # Check cache first
        if mpd_url in self.manifest_cache:
            logger.debug(f"DASH cache hit for manifest: {mpd_url}")
            return self.manifest_cache[mpd_url]

        # Download if not in cache
        try:
            mpd_dict = await self._fetch_manifest(mpd_url, headers)
            logger.debug(f"Downloaded and cached DASH manifest: {mpd_url}")
            return mpd_dict

        except Exception as e:
            logger.warning(f"Failed to get DASH manifest {mpd_url}: {e}")
            return None

    def clear_cache(self) -> None:
        """Clear the DASH cache."""
        self.segment_cache.clear()
        self.init_segment_cache.clear()
        self.manifest_cache.clear()
        self.adaptation_segments.clear()
        logger.info("DASH pre-buffer cache cleared")

    async def close(self) -> None:
        """Close the pre-buffer system."""
        await self.client.aclose()
|
||||
|
||||
|
||||
# Global DASH pre-buffer instance
|
||||
dash_prebuffer = DASHPreBuffer()
|
||||
"""
|
||||
DASH Pre-buffer system for reducing latency and improving streaming performance.
|
||||
|
||||
This module extends BasePrebuffer with DASH-specific functionality including
|
||||
MPD parsing integration, profile handling, and init segment management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from mediaflow_proxy.utils.base_prebuffer import BasePrebuffer
|
||||
from mediaflow_proxy.utils.cache_utils import (
|
||||
get_cached_mpd,
|
||||
get_cached_init_segment,
|
||||
)
|
||||
from mediaflow_proxy.configs import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DASHPreBuffer(BasePrebuffer):
|
||||
"""
|
||||
Pre-buffer system for DASH streams.
|
||||
|
||||
Extends BasePrebuffer with DASH-specific features:
|
||||
- MPD manifest parsing and profile handling
|
||||
- Init segment prebuffering
|
||||
- Live stream segment tracking
|
||||
- Profile-based segment prefetching
|
||||
|
||||
Uses event-based download coordination from BasePrebuffer to prevent
|
||||
duplicate downloads between player requests and background prebuffering.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    max_cache_size: Optional[int] = None,
    prebuffer_segments: Optional[int] = None,
):
    """
    Set up the DASH pre-buffer.

    Args:
        max_cache_size: Maximum number of segments to cache (uses config if None)
        prebuffer_segments: Number of segments to pre-buffer ahead (uses config if None)
    """
    # Explicit arguments win; falsy values fall back to the configuration
    cache_size = max_cache_size or settings.dash_prebuffer_cache_size
    segment_count = prebuffer_segments or settings.dash_prebuffer_segments
    super().__init__(
        max_cache_size=cache_size,
        prebuffer_segments=segment_count,
        max_memory_percent=settings.dash_prebuffer_max_memory_percent,
        emergency_threshold=settings.dash_prebuffer_emergency_threshold,
        segment_ttl=settings.dash_segment_cache_ttl,
    )

    self.inactivity_timeout = settings.dash_prebuffer_inactivity_timeout

    # Streams currently tracked for prefetching: mpd_url -> stream_info
    self.active_streams: Dict[str, dict] = {}
    self.prefetch_tasks: Dict[str, asyncio.Task] = {}

    # DASH-specific counter reported by log_stats()
    self.init_segments_prebuffered = 0

    # Handle of the background cleanup task; None until one is stored here
    self._cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
def log_stats(self) -> None:
    """Log the prebuffer statistics together with the DASH-specific counters."""
    stats = self.stats.to_dict()
    stats.update(
        init_segments_prebuffered=self.init_segments_prebuffered,
        active_streams=len(self.active_streams),
    )
    logger.info(f"DASH Prebuffer Stats: {stats}")
|
||||
|
||||
async def prebuffer_dash_manifest(
    self,
    mpd_url: str,
    headers: Dict[str, str],
) -> None:
    """
    Pre-buffer segments from a DASH manifest using existing MPD parsing.

    The manifest is first parsed without segments to discover the profiles,
    then once per profile to resolve that profile's segment list. Failures
    are logged and never raised.

    Args:
        mpd_url: URL of the DASH manifest
        headers: Headers to use for requests
    """
    try:
        # First get the basic MPD info without segments
        parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
        if not parsed_mpd:
            logger.warning(f"Failed to get parsed MPD for prebuffering: {mpd_url}")
            return

        is_live = parsed_mpd.get("isLive", False)
        base_profiles = parsed_mpd.get("profiles", [])

        if not base_profiles:
            logger.warning(f"No profiles found in MPD for prebuffering: {mpd_url}")
            return

        # Now get segments for each profile by parsing with profile_id
        profiles_with_segments = []
        for profile in base_profiles:
            profile_id = profile.get("id")
            if not profile_id:
                continue

            parsed_with_segments = await get_cached_mpd(
                mpd_url, headers, parse_drm=False, parse_segment_profile_id=profile_id
            )
            if not parsed_with_segments:
                # Apply the same falsy check as for the first parse so that
                # one unresolved profile does not abort all the others.
                logger.warning(f"Failed to get segments for profile {profile_id}: {mpd_url}")
                continue

            # Find the matching profile with segments
            matching_profile = next(
                (p for p in parsed_with_segments.get("profiles", []) if p.get("id") == profile_id),
                None,
            )
            if matching_profile is not None:
                profiles_with_segments.append(matching_profile)

        # Store stream info for ongoing prefetching
        self.active_streams[mpd_url] = {
            "headers": headers,
            "is_live": is_live,
            "profiles": profiles_with_segments,
            "last_access": time.time(),
        }

        # Prebuffer init segments and media segments
        await self._prebuffer_profiles(profiles_with_segments, headers, is_live)

        # Start cleanup task if not running
        self._ensure_cleanup_task_running()

        logger.info(
            f"Pre-buffered DASH manifest: {mpd_url} (live={is_live}, profiles={len(profiles_with_segments)})"
        )

    except Exception as e:
        logger.warning(f"Failed to pre-buffer DASH manifest {mpd_url}: {e}")
|
||||
|
||||
async def _prebuffer_profiles(
    self,
    profiles: List[dict],
    headers: Dict[str, str],
    is_live: bool = False,
) -> None:
    """
    Pre-buffer init segments and media segments for all profiles.

    For live streams the window is taken from the END of each profile's
    segment list (most recent segments); for VOD it is taken from the
    beginning. Nothing is done when the memory check asks to skip.

    Args:
        profiles: List of parsed profiles with resolved URLs
        headers: Headers to use for requests
        is_live: Whether this is a live stream
    """
    if self._should_skip_for_memory():
        logger.warning("Memory usage too high, skipping prebuffer")
        return

    # A non-positive count must select no media segments. Without the clamp
    # the live slice ``segments[-0:]`` would return the WHOLE list.
    count = max(self.prebuffer_segment_count, 0)

    init_urls = []
    segment_urls = []

    for profile in profiles:
        init_url = profile.get("initUrl")
        if init_url:
            init_urls.append(init_url)

        segments = profile.get("segments", [])
        if not segments or count == 0:
            continue

        window = segments[-count:] if is_live else segments[:count]
        for segment in window:
            segment_url = segment.get("media")
            if segment_url:
                segment_urls.append(segment_url)

    # Init segments go through the dedicated init-segment path and are
    # started in the background (not awaited here).
    # NOTE(review): asyncio keeps only weak references to tasks; consider
    # retaining these task objects somewhere on the instance.
    for init_url in init_urls:
        asyncio.create_task(self._prebuffer_init_segment(init_url, headers))

    # Media segments are fetched through the batch helper and awaited.
    if segment_urls:
        await self.prebuffer_segments_batch(segment_urls, headers)
|
||||
|
||||
async def _prebuffer_init_segment(
    self,
    init_url: str,
    headers: Dict[str, str],
) -> None:
    """
    Fetch one init segment through the init-segment cache and update stats.

    Any exception is logged as a warning and swallowed, so a failed init
    prebuffer never propagates to the caller.

    Args:
        init_url: URL of the init segment
        headers: Headers for the request
    """
    try:
        # Download and caching are both delegated to get_cached_init_segment.
        payload = await get_cached_init_segment(init_url, headers)
        if not payload:
            return

        size = len(payload)
        self.init_segments_prebuffered += 1
        self.stats.bytes_prebuffered += size
        logger.debug(f"Prebuffered init segment ({size} bytes)")
    except Exception as exc:
        logger.warning(f"Failed to prebuffer init segment: {exc}")
|
||||
|
||||
async def prefetch_upcoming_segments(
    self,
    mpd_url: str,
    current_segment_url: str,
    headers: Dict[str, str],
    profile_id: Optional[str] = None,
) -> None:
    """
    Prefetch upcoming segments based on current playback position.

    Locates ``current_segment_url`` in each (matching) profile's segment
    list and schedules a background batch download of the segments that
    follow it, up to ``self.prebuffer_segment_count``. Errors are logged
    and swallowed.

    Args:
        mpd_url: URL of the MPD manifest
        current_segment_url: URL of the currently requested segment
        headers: Headers to use for requests
        profile_id: Optional profile ID to limit prefetching to
    """
    self.stats.prefetch_triggered += 1

    try:
        if mpd_url in self.active_streams:
            # Known stream: refresh its idle timer and reuse stored profiles.
            stream = self.active_streams[mpd_url]
            stream["last_access"] = time.time()
            profiles = stream.get("profiles", [])
        else:
            parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
            if not parsed_mpd:
                return
            profiles = parsed_mpd.get("profiles", [])

        count = max(self.prebuffer_segment_count, 0)

        for profile in profiles:
            pid = profile.get("id")
            if profile_id and pid != profile_id:
                continue

            segments = profile.get("segments", [])

            # No segments on the profile: re-parse the MPD for this profile.
            if not segments and pid:
                parsed_with_segments = await get_cached_mpd(
                    mpd_url, headers, parse_drm=False, parse_segment_profile_id=pid
                )
                # The lookup can come back falsy (the base lookup above is
                # guarded for the same reason); treat that as "no segments"
                # instead of letting one profile abort all the others.
                for candidate in (parsed_with_segments or {}).get("profiles", []):
                    if candidate.get("id") == pid:
                        segments = candidate.get("segments", [])
                        break

            current_index = next(
                (i for i, segment in enumerate(segments) if segment.get("media") == current_segment_url),
                -1,
            )
            if current_index < 0:
                continue

            start = current_index + 1
            segment_urls = [
                url for url in (segment.get("media") for segment in segments[start : start + count]) if url
            ]

            if segment_urls:
                logger.debug(f"Prefetching {len(segment_urls)} upcoming segments from index {current_index + 1}")
                # Run prefetch in background
                asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))

    except Exception as e:
        logger.warning(f"Failed to prefetch upcoming segments: {e}")
|
||||
|
||||
async def prefetch_for_live_playlist(
    self,
    profiles: List[dict],
    headers: Dict[str, str],
) -> None:
    """
    Prefetch segments for a live playlist refresh.

    For each profile, takes the last ``self.prebuffer_segment_count``
    segments, drops those already available via ``try_get_cached`` and
    schedules a background batch download of the rest.

    Args:
        profiles: List of profiles with resolved segment URLs
        headers: Headers to use for requests
    """
    count = self.prebuffer_segment_count
    if count <= 0:
        # ``segments[-0:]`` would select every segment, so bail out early.
        return

    segment_urls = []

    for profile in profiles:
        segments = profile.get("segments", [])
        if not segments:
            continue

        # For live, only the most recent segments are of interest.
        for segment in segments[-count:]:
            segment_url = segment.get("media")
            if not segment_url:
                continue
            # Skip anything the cache can already serve.
            if not await self.try_get_cached(segment_url):
                segment_urls.append(segment_url)

    if segment_urls:
        logger.debug(f"Live playlist prefetch: {len(segment_urls)} segments")
        asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))
|
||||
|
||||
def _ensure_cleanup_task_running(self) -> None:
    """Start the inactive-stream cleanup task unless one is still running."""
    existing = self._cleanup_task
    if existing is not None and not existing.done():
        return
    self._cleanup_task = asyncio.create_task(self._cleanup_inactive_streams())
|
||||
|
||||
async def _cleanup_inactive_streams(self) -> None:
    """
    Background loop that evicts streams idle longer than the timeout.

    Wakes every 30 seconds. Exits on its own once no streams are tracked,
    and exits when cancelled; any other error is logged and the loop
    continues.
    """
    while True:
        try:
            await asyncio.sleep(30)  # Check every 30 seconds

            if not self.active_streams:
                logger.debug("No active DASH streams to monitor, stopping cleanup")
                return

            now = time.time()
            expired = []

            for url, info in self.active_streams.items():
                idle_for = now - info.get("last_access", 0)
                if idle_for > self.inactivity_timeout:
                    expired.append(url)
                    logger.info(f"Cleaning up inactive DASH stream ({idle_for:.0f}s idle)")

            # Removal happens after iteration so the dict is not mutated
            # while it is being walked.
            for url in expired:
                self.active_streams.pop(url, None)
                pending = self.prefetch_tasks.pop(url, None)
                if pending:
                    pending.cancel()

            if expired:
                logger.info(f"Cleaned up {len(expired)} inactive DASH stream(s)")

        except asyncio.CancelledError:
            logger.debug("DASH cleanup task cancelled")
            return
        except Exception as exc:
            logger.warning(f"Error in DASH cleanup task: {exc}")
|
||||
|
||||
def get_stats(self) -> dict:
    """Return the base stats dict extended with DASH-specific counters."""
    snapshot = self.stats.to_dict()
    snapshot.update(
        init_segments_prebuffered=self.init_segments_prebuffered,
        active_streams=len(self.active_streams),
    )
    return snapshot
|
||||
|
||||
def clear_cache(self) -> None:
    """Log final stats, then drop stream tracking, cancel tasks and reset counters."""
    self.log_stats()
    self.active_streams.clear()

    for pending in self.prefetch_tasks.values():
        pending.cancel()
    self.prefetch_tasks.clear()

    # Stop the periodic cleanup loop if it is still alive.
    cleanup = self._cleanup_task
    if cleanup and not cleanup.done():
        cleanup.cancel()
    self._cleanup_task = None

    self.stats.reset()
    self.init_segments_prebuffered = 0
    logger.info("DASH pre-buffer state cleared")
|
||||
|
||||
async def close(self) -> None:
    """Shut the pre-buffer down by clearing all tracked state via clear_cache()."""
    self.clear_cache()
|
||||
|
||||
|
||||
# Global DASH pre-buffer instance
|
||||
# Module-level instance, constructed once when this module is imported.
dash_prebuffer = DASHPreBuffer()
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
#
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
"""Methods for deprecating old names for arguments or attributes."""
|
||||
|
||||
import warnings
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def deprecated_class_name(old_name,
|
||||
warn="Class name '{old_name}' is deprecated, "
|
||||
"please use '{new_name}'"):
|
||||
def deprecated_class_name(old_name, warn="Class name '{old_name}' is deprecated, please use '{new_name}'"):
|
||||
"""
|
||||
Class decorator to deprecate a use of class.
|
||||
|
||||
@@ -21,14 +20,12 @@ def deprecated_class_name(old_name,
|
||||
keyword name and the 'new_name' for the current one.
|
||||
Example: "Old name: {old_nam}, use '{new_name}' instead".
|
||||
"""
|
||||
|
||||
def _wrap(obj):
|
||||
assert callable(obj)
|
||||
|
||||
def _warn():
|
||||
warnings.warn(warn.format(old_name=old_name,
|
||||
new_name=obj.__name__),
|
||||
DeprecationWarning,
|
||||
stacklevel=3)
|
||||
warnings.warn(warn.format(old_name=old_name, new_name=obj.__name__), DeprecationWarning, stacklevel=3)
|
||||
|
||||
def _wrap_with_warn(func, is_inspect):
|
||||
@wraps(func)
|
||||
@@ -41,12 +38,12 @@ def deprecated_class_name(old_name,
|
||||
# isinstance(old_name(), new_name) to work
|
||||
frame = inspect.currentframe().f_back
|
||||
code = inspect.getframeinfo(frame).code_context
|
||||
if [line for line in code
|
||||
if '{0}('.format(old_name) in line]:
|
||||
if [line for line in code if "{0}(".format(old_name) in line]:
|
||||
_warn()
|
||||
else:
|
||||
_warn()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _func
|
||||
|
||||
# Make old name available.
|
||||
@@ -63,11 +60,11 @@ def deprecated_class_name(old_name,
|
||||
frame.f_globals[old_name] = placeholder
|
||||
|
||||
return obj
|
||||
|
||||
return _wrap
|
||||
|
||||
|
||||
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
|
||||
"please use '{new_name}'"):
|
||||
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, please use '{new_name}'"):
|
||||
"""Decorator to translate obsolete names and warn about their use.
|
||||
|
||||
:param dict names: dictionary with pairs of new_name: old_name
|
||||
@@ -78,27 +75,24 @@ def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
|
||||
deprecated keyword name and 'new_name' for the current one.
|
||||
Example: "Old name: {old_name}, use {new_name} instead".
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
for new_name, old_name in names.items():
|
||||
if old_name in kwargs:
|
||||
if new_name in kwargs:
|
||||
raise TypeError("got multiple values for keyword "
|
||||
"argument '{0}'".format(new_name))
|
||||
warnings.warn(warn.format(old_name=old_name,
|
||||
new_name=new_name),
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
raise TypeError("got multiple values for keyword argument '{0}'".format(new_name))
|
||||
warnings.warn(warn.format(old_name=old_name, new_name=new_name), DeprecationWarning, stacklevel=2)
|
||||
kwargs[new_name] = kwargs.pop(old_name)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated_instance_attrs(names,
|
||||
warn="Attribute '{old_name}' is deprecated, "
|
||||
"please use '{new_name}'"):
|
||||
def deprecated_instance_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
|
||||
"""Decorator to deprecate class instance attributes.
|
||||
|
||||
Translates all names in `names` to use new names and emits warnings
|
||||
@@ -119,27 +113,20 @@ def deprecated_instance_attrs(names,
|
||||
def decorator(clazz):
|
||||
def getx(self, name, __old_getx=getattr(clazz, "__getattr__", None)):
|
||||
if name in names:
|
||||
warnings.warn(warn.format(old_name=name,
|
||||
new_name=names[name]),
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
|
||||
return getattr(self, names[name])
|
||||
if __old_getx:
|
||||
if hasattr(__old_getx, "__func__"):
|
||||
return __old_getx.__func__(self, name)
|
||||
return __old_getx(self, name)
|
||||
raise AttributeError("'{0}' object has no attribute '{1}'"
|
||||
.format(clazz.__name__, name))
|
||||
raise AttributeError("'{0}' object has no attribute '{1}'".format(clazz.__name__, name))
|
||||
|
||||
getx.__name__ = "__getattr__"
|
||||
clazz.__getattr__ = getx
|
||||
|
||||
def setx(self, name, value, __old_setx=getattr(clazz, "__setattr__")):
|
||||
if name in names:
|
||||
warnings.warn(warn.format(old_name=name,
|
||||
new_name=names[name]),
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
|
||||
setattr(self, names[name], value)
|
||||
else:
|
||||
__old_setx(self, name, value)
|
||||
@@ -149,10 +136,7 @@ def deprecated_instance_attrs(names,
|
||||
|
||||
def delx(self, name, __old_delx=getattr(clazz, "__delattr__")):
|
||||
if name in names:
|
||||
warnings.warn(warn.format(old_name=name,
|
||||
new_name=names[name]),
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
|
||||
delattr(self, names[name])
|
||||
else:
|
||||
__old_delx(self, name)
|
||||
@@ -161,11 +145,11 @@ def deprecated_instance_attrs(names,
|
||||
clazz.__delattr__ = delx
|
||||
|
||||
return clazz
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
|
||||
"please use '{new_name}'"):
|
||||
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
|
||||
"""Decorator to deprecate all specified attributes in class.
|
||||
|
||||
Translates all names in `names` to use new names and emits warnings
|
||||
@@ -180,6 +164,7 @@ def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
|
||||
deprecated keyword name and 'new_name' for the current one.
|
||||
Example: "Old name: {old_name}, use {new_name} instead".
|
||||
"""
|
||||
|
||||
# prepare metaclass for handling all the class methods, class variables
|
||||
# and static methods (as they don't go through instance's __getattr__)
|
||||
class DeprecatedProps(type):
|
||||
@@ -192,27 +177,33 @@ def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
|
||||
|
||||
# apply metaclass
|
||||
orig_vars = cls.__dict__.copy()
|
||||
slots = orig_vars.get('__slots__')
|
||||
slots = orig_vars.get("__slots__")
|
||||
if slots is not None:
|
||||
if isinstance(slots, str):
|
||||
slots = [slots]
|
||||
for slots_var in slots:
|
||||
orig_vars.pop(slots_var)
|
||||
orig_vars.pop('__dict__', None)
|
||||
orig_vars.pop('__weakref__', None)
|
||||
orig_vars.pop("__dict__", None)
|
||||
orig_vars.pop("__weakref__", None)
|
||||
return metaclass(cls.__name__, cls.__bases__, orig_vars)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def deprecated_method(message):
    """Decorator for deprecating methods.

    Each call to the decorated callable emits exactly one
    ``DeprecationWarning`` of the form
    ``"<name> is a deprecated method. <message>"`` and then forwards to the
    original callable.

    :param str message: The message you want to display.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller of the
            # deprecated method rather than to this wrapper.
            warnings.warn(
                "{0} is a deprecated method. {1}".format(func.__name__, message), DeprecationWarning, stacklevel=2
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Helper functions for automatic stream extraction in proxy routes.
|
||||
|
||||
This module provides caching and extraction helpers for DLHD/DaddyLive
|
||||
and Sportsonline/Sportzonline streams that are auto-detected in proxy routes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
|
||||
from mediaflow_proxy.extractors.base import ExtractorError
|
||||
from mediaflow_proxy.extractors.factory import ExtractorFactory
|
||||
from mediaflow_proxy.utils.http_utils import ProxyRequestHeaders, DownloadError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DLHD extraction cache: {original_url: {"data": extraction_result, "timestamp": time.time()}}
# In-process dict keyed by the destination URL. An expired entry is only
# removed when the same URL is looked up again; there is no size bound.
_dlhd_extraction_cache: dict = {}
_dlhd_cache_duration = 600  # 10 minutes in seconds

# Sportsonline extraction cache; entries use the same {"data", "timestamp"} layout.
_sportsonline_extraction_cache: dict = {}
_sportsonline_cache_duration = 600  # 10 minutes in seconds
|
||||
|
||||
|
||||
async def check_and_extract_dlhd_stream(
    request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
    """
    Check if destination contains DLHD/DaddyLive patterns and extract stream directly.

    Uses caching to avoid repeated extractions (10 minute cache). Both the
    stored entry and every cache hit are shallow copies, so callers may
    mutate the returned dict without corrupting the cache.

    Args:
        request (Request): The incoming HTTP request (not used by this function).
        destination (str): The destination URL to check.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        force_refresh (bool): Force re-extraction even if cached data exists.

    Returns:
        dict | None: Extracted stream data if DLHD link detected, None otherwise.

    Raises:
        HTTPException: 400 when the extractor reports ExtractorError/DownloadError,
            500 for any other failure during extraction.
    """
    # Detection: a "stream-<digits>" token anywhere in the URL, or a known
    # DLHD/DaddyLive domain in the host part (e.g. dlhd.dad, daddylive.sx).
    is_dlhd_link = bool(re.search(r"stream-\d+", destination))
    if not is_dlhd_link:
        netloc = urlparse(destination).netloc
        is_dlhd_link = "dlhd.dad" in netloc or "daddylive.sx" in netloc

    if not is_dlhd_link:
        return None

    logger.info(f"DLHD link detected: {destination}")

    # Check cache first (unless force_refresh is True)
    current_time = time.time()
    if not force_refresh:
        cached_entry = _dlhd_extraction_cache.get(destination)
        if cached_entry is not None:
            cache_age = current_time - cached_entry["timestamp"]
            if cache_age < _dlhd_cache_duration:
                logger.info(f"Using cached DLHD data (age: {cache_age:.1f}s)")
                # Hand out a copy: the stored dict must stay pristine for
                # later hits, mirroring the copy made when it was stored.
                return cached_entry["data"].copy()

            logger.info(f"DLHD cache expired (age: {cache_age:.1f}s), re-extracting...")
            _dlhd_extraction_cache.pop(destination, None)

    # Extract stream data
    try:
        logger.info(f"Extracting DLHD stream data from: {destination}")
        extractor = ExtractorFactory.get_extractor("DLHD", proxy_headers.request)
        result = await extractor.extract(destination)

        logger.info(f"DLHD extraction successful. Stream URL: {result.get('destination_url')}")

        # Flatten dlhd_key_params into top-level keys so they can be passed
        # along as query parameters for key URL handling.
        if "dlhd_key_params" in result:
            key_params = result.pop("dlhd_key_params")
            result["dlhd_channel_salt"] = key_params.get("channel_salt", "")
            result["dlhd_auth_token"] = key_params.get("auth_token", "")
            result["dlhd_iframe_url"] = key_params.get("iframe_url", "")
            logger.info("DLHD key params extracted for dynamic header computation")

        # Cache a copy of result to prevent downstream mutations from corrupting the cache
        _dlhd_extraction_cache[destination] = {"data": result.copy(), "timestamp": current_time}
        logger.info(f"DLHD data cached for {_dlhd_cache_duration}s")

        return result

    except (ExtractorError, DownloadError) as e:
        logger.error(f"DLHD extraction failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"DLHD extraction failed: {str(e)}") from e
    except Exception as e:
        logger.exception(f"Unexpected error during DLHD extraction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"DLHD extraction failed: {str(e)}") from e
|
||||
|
||||
|
||||
async def check_and_extract_sportsonline_stream(
    request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
    """
    Check if destination contains Sportsonline/Sportzonline patterns and extract stream directly.

    Uses caching to avoid repeated extractions (10 minute cache). The cache
    stores a shallow copy and every hit returns a fresh shallow copy, so
    callers may mutate the returned dict without corrupting the cache.

    Args:
        request (Request): The incoming HTTP request (not used by this function).
        destination (str): The destination URL to check.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        force_refresh (bool): Force re-extraction even if cached data exists.

    Returns:
        dict | None: Extracted stream data if Sportsonline link detected, None otherwise.

    Raises:
        HTTPException: 400 when the extractor reports ExtractorError/DownloadError,
            500 for any other failure during extraction.
    """
    parsed_netloc = urlparse(destination).netloc
    is_sportsonline_link = "sportzonline." in parsed_netloc or "sportsonline." in parsed_netloc

    if not is_sportsonline_link:
        return None

    logger.info(f"Sportsonline link detected: {destination}")

    current_time = time.time()
    if not force_refresh:
        cached_entry = _sportsonline_extraction_cache.get(destination)
        if cached_entry is not None:
            cache_age = current_time - cached_entry["timestamp"]
            if cache_age < _sportsonline_cache_duration:
                logger.info(f"Using cached Sportsonline data (age: {cache_age:.1f}s)")
                # Return a copy so the stored entry cannot be mutated by callers.
                return dict(cached_entry["data"])

            logger.info("Sportsonline cache expired, re-extracting...")
            _sportsonline_extraction_cache.pop(destination, None)

    try:
        logger.info(f"Extracting Sportsonline stream data from: {destination}")
        extractor = ExtractorFactory.get_extractor("Sportsonline", proxy_headers.request)
        result = await extractor.extract(destination)
        logger.info(f"Sportsonline extraction successful. Stream URL: {result.get('destination_url')}")

        # Store a copy, matching the DLHD helper, so downstream mutation of
        # the returned dict does not leak into the cache.
        _sportsonline_extraction_cache[destination] = {"data": dict(result), "timestamp": current_time}
        logger.info(f"Sportsonline data cached for {_sportsonline_cache_duration}s")
        return result

    except (ExtractorError, DownloadError) as e:
        logger.error(f"Sportsonline extraction failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Sportsonline extraction failed: {str(e)}") from e
    except Exception as e:
        logger.exception(f"Unexpected error during Sportsonline extraction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Sportsonline extraction failed: {str(e)}") from e
|
||||
@@ -1,490 +1,478 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import psutil
|
||||
from typing import Dict, Optional, List
|
||||
from urllib.parse import urlparse
|
||||
import httpx
|
||||
from mediaflow_proxy.utils.http_utils import create_httpx_client
|
||||
from mediaflow_proxy.configs import settings
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HLSPreBuffer:
|
||||
"""
|
||||
Pre-buffer system for HLS streams to reduce latency and improve streaming performance.
|
||||
"""
|
||||
|
||||
def __init__(self, max_cache_size: Optional[int] = None, prebuffer_segments: Optional[int] = None):
    """
    Initialize the HLS pre-buffer system.

    Args:
        max_cache_size (int): Maximum number of segments to cache; a falsy
            value (None or 0) falls back to the configured setting
        prebuffer_segments (int): Number of segments to pre-buffer ahead; a
            falsy value (None or 0) falls back to the configured setting
    """
    # OrderedDict is already imported at module level; the function-local
    # re-imports (including the unused ``time`` and ``urljoin``) were dropped.
    self.max_cache_size = max_cache_size or settings.hls_prebuffer_cache_size
    self.prebuffer_segments = prebuffer_segments or settings.hls_prebuffer_segments
    self.max_memory_percent = settings.hls_prebuffer_max_memory_percent
    self.emergency_threshold = settings.hls_prebuffer_emergency_threshold
    # LRU cache: segment URL -> segment bytes (oldest entries first)
    self.segment_cache: "OrderedDict[str, bytes]" = OrderedDict()
    # Playlist URL -> list of its segment URLs
    self.segment_urls: Dict[str, List[str]] = {}
    # Reverse map: segment URL -> (playlist_url, index)
    self.segment_to_playlist: Dict[str, tuple[str, int]] = {}
    # Per-playlist state: {headers, last_access, refresh_task, target_duration}
    self.playlist_state: Dict[str, dict] = {}
    self.client = create_httpx_client()
|
||||
|
||||
async def prebuffer_playlist(self, playlist_url: str, headers: Dict[str, str]) -> None:
    """
    Pre-buffer segments from an HLS playlist.

    A master playlist is followed into its first variant only. For a media
    playlist the segment list is recorded, the first
    ``self.prebuffer_segments`` segments are downloaded, and a refresh task
    is started unless one is already running. Failures are logged and
    swallowed.

    Args:
        playlist_url (str): URL of the HLS playlist
        headers (Dict[str, str]): Headers to use for requests
    """
    try:
        logger.debug(f"Starting pre-buffer for playlist: {playlist_url}")
        resp = await self.client.get(playlist_url, headers=headers)
        resp.raise_for_status()
        body = resp.text

        # Master playlist: recurse into the first variant and stop here.
        if "#EXT-X-STREAM-INF" in body:
            logger.debug("Master playlist detected, finding first variant")
            variants = self._extract_variant_urls(body, playlist_url)
            if not variants:
                logger.warning("No variants found in master playlist")
                return
            chosen = variants[0]
            logger.debug(f"Pre-buffering first variant: {chosen}")
            await self.prebuffer_playlist(chosen, headers)
            return

        # Media playlist: record segments and the reverse lookup map.
        urls = self._extract_segment_urls(body, playlist_url)
        self.segment_urls[playlist_url] = urls
        self.segment_to_playlist.update((url, (playlist_url, pos)) for pos, url in enumerate(urls))

        # Initial prebuffer of the leading segments.
        await self._prebuffer_segments(urls[: self.prebuffer_segments], headers)
        logger.info(f"Pre-buffered {min(self.prebuffer_segments, len(urls))} segments for {playlist_url}")

        # Start the refresh loop unless a live one already exists.
        target_duration = self._parse_target_duration(body) or 6
        current = self.playlist_state.get(playlist_url, {}).get("refresh_task")
        if current and not current.done():
            return

        refresh = asyncio.create_task(self._refresh_playlist_loop(playlist_url, headers, target_duration))
        self.playlist_state[playlist_url] = {
            "headers": headers,
            "last_access": asyncio.get_event_loop().time(),
            "refresh_task": refresh,
            "target_duration": target_duration,
        }
    except Exception as exc:
        logger.warning(f"Failed to pre-buffer playlist {playlist_url}: {exc}")
|
||||
|
||||
def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
    """
    Extract segment URLs from HLS playlist content.

    Every non-blank line that is not a tag/comment (does not start with
    ``#``) is treated as a segment URI and resolved against ``base_url``
    with ``urljoin``, the same resolution the variant extractor uses.
    Absolute URIs pass through unchanged; root-relative, relative and
    ``../`` URIs are resolved against the playlist location.

    Args:
        playlist_content (str): Content of the HLS playlist
        base_url (str): Base URL for resolving relative URLs

    Returns:
        List[str]: List of segment URLs
    """
    lines = playlist_content.split("\n")
    logger.debug(f"Analyzing playlist with {len(lines)} lines")

    segment_urls = []
    for raw_line in lines:
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue

        segment_url = urljoin(base_url, entry)
        segment_urls.append(segment_url)
        if segment_url == entry:
            logger.debug(f"Found absolute URL: {entry}")
        else:
            logger.debug(f"Found relative path: {entry} -> {segment_url}")

    logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
    if segment_urls:
        logger.debug(f"First segment URL: {segment_urls[0]}")
    else:
        logger.debug("No segment URLs found in playlist")
        # Dump the head of the playlist to help diagnose why nothing matched.
        for i, line in enumerate(lines[:10]):
            logger.debug(f"Line {i}: {line}")

    return segment_urls
|
||||
|
||||
def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
    """
    Extract variant playlist URLs from a master playlist.

    For each ``#EXT-X-STREAM-INF`` tag, takes the next non-blank,
    non-comment line as the variant URI and resolves it against
    ``base_url`` (so relative URIs are handled).

    Args:
        playlist_content (str): Content of the master playlist
        base_url (str): URL the playlist was fetched from

    Returns:
        List[str]: Resolved variant URLs in playlist order
    """
    variant_urls = []
    expecting_uri = False
    for raw_line in playlist_content.split("\n"):
        line = raw_line.strip()
        if line.startswith("#EXT-X-STREAM-INF"):
            expecting_uri = True
        elif expecting_uri and line and not line.startswith("#"):
            # Blank lines and other tags between the STREAM-INF tag and its
            # URI are skipped rather than causing the variant to be dropped.
            variant_urls.append(urljoin(base_url, line))
            expecting_uri = False

    logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
    if variant_urls:
        logger.debug(f"First variant URL: {variant_urls[0]}")
    return variant_urls
|
||||
|
||||
async def _prebuffer_segments(self, segment_urls: List[str], headers: Dict[str, str]) -> None:
    """
    Download, concurrently, every listed segment that is not cached yet.

    Per-segment failures are collected by ``gather`` rather than raised.

    Args:
        segment_urls (List[str]): List of segment URLs to pre-buffer
        headers (Dict[str, str]): Headers to use for requests
    """
    pending = [self._download_segment(url, headers) for url in segment_urls if url not in self.segment_cache]
    if not pending:
        return
    await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
def _get_memory_usage_percent(self) -> float:
    """
    Return the system memory usage percentage reported by psutil.

    Returns:
        float: Memory usage percentage, or 0.0 if it could not be read
    """
    try:
        return psutil.virtual_memory().percent
    except Exception as exc:
        logger.warning(f"Failed to get memory usage: {exc}")
        return 0.0
|
||||
|
||||
def _check_memory_threshold(self) -> bool:
|
||||
"""
|
||||
Check if memory usage exceeds the emergency threshold.
|
||||
|
||||
Returns:
|
||||
bool: True if emergency cleanup is needed
|
||||
"""
|
||||
memory_percent = self._get_memory_usage_percent()
|
||||
return memory_percent > self.emergency_threshold
|
||||
|
||||
def _emergency_cache_cleanup(self) -> None:
|
||||
"""
|
||||
Esegue cleanup LRU rimuovendo il 50% più vecchio.
|
||||
"""
|
||||
if self._check_memory_threshold():
|
||||
logger.warning("Emergency cache cleanup triggered due to high memory usage")
|
||||
to_remove = max(1, len(self.segment_cache) // 2)
|
||||
removed = 0
|
||||
while removed < to_remove and self.segment_cache:
|
||||
self.segment_cache.popitem(last=False) # rimuovi LRU
|
||||
removed += 1
|
||||
logger.info(f"Emergency cleanup removed {removed} segments from cache")
|
||||
|
||||
async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
|
||||
"""
|
||||
Download a single segment and cache it.
|
||||
|
||||
Args:
|
||||
segment_url (str): URL of the segment to download
|
||||
headers (Dict[str, str]): Headers to use for request
|
||||
"""
|
||||
try:
|
||||
memory_percent = self._get_memory_usage_percent()
|
||||
if memory_percent > self.max_memory_percent:
|
||||
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
|
||||
return
|
||||
|
||||
response = await self.client.get(segment_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
# Cache LRU
|
||||
self.segment_cache[segment_url] = response.content
|
||||
self.segment_cache.move_to_end(segment_url, last=True)
|
||||
|
||||
if self._check_memory_threshold():
|
||||
self._emergency_cache_cleanup()
|
||||
elif len(self.segment_cache) > self.max_cache_size:
|
||||
# Evict LRU finché non rientra
|
||||
while len(self.segment_cache) > self.max_cache_size:
|
||||
self.segment_cache.popitem(last=False)
|
||||
|
||||
logger.debug(f"Cached segment: {segment_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to download segment {segment_url}: {e}")
|
||||
|
||||
async def get_segment(self, segment_url: str, headers: Dict[str, str]) -> Optional[bytes]:
    """
    Return a segment from the cache, downloading and caching it on a miss.

    A cache hit refreshes the entry's LRU position. On a miss the download
    is skipped (returning None) when memory usage is above
    ``max_memory_percent``. In both the hit and the successful-download
    case, the ``last_access`` time of the owning playlist is refreshed when
    the segment is mapped to one.

    Args:
        segment_url (str): URL of the segment
        headers (Dict[str, str]): Headers to send with the request

    Returns:
        Optional[bytes]: Segment data, or None if it could not be obtained
    """

    def touch_playlist() -> None:
        # Refresh last_access of the playlist this segment belongs to, if any
        mapped = self.segment_to_playlist.get(segment_url)
        if not mapped:
            return
        state = self.playlist_state.get(mapped[0])
        if state:
            state["last_access"] = asyncio.get_event_loop().time()

    # Fast path: serve from cache
    if segment_url in self.segment_cache:
        logger.debug(f"Cache hit for segment: {segment_url}")
        data = self.segment_cache[segment_url]
        self.segment_cache.move_to_end(segment_url, last=True)  # LRU touch
        touch_playlist()
        return data

    memory_percent = self._get_memory_usage_percent()
    if memory_percent > self.max_memory_percent:
        logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
        return None

    try:
        response = await self.client.get(segment_url, headers=headers)
        response.raise_for_status()
        segment_data = response.content

        # Insert as the most-recently-used entry
        self.segment_cache[segment_url] = segment_data
        self.segment_cache.move_to_end(segment_url, last=True)

        if self._check_memory_threshold():
            self._emergency_cache_cleanup()
        else:
            while len(self.segment_cache) > self.max_cache_size:
                self.segment_cache.popitem(last=False)

        touch_playlist()

        logger.debug(f"Downloaded and cached segment: {segment_url}")
        return segment_data
    except Exception as e:
        logger.warning(f"Failed to get segment {segment_url}: {e}")
        return None
|
||||
|
||||
async def prebuffer_from_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
    """
    Pre-buffer the segments that follow ``segment_url`` in its playlist.

    The segment is looked up in the reverse map to find its playlist and
    index; unmapped segments are ignored. The playlist's ``last_access``
    time is refreshed before the following segments are pre-buffered.

    Args:
        segment_url (str): URL of the segment currently being played
        headers (Dict[str, str]): Headers to send with the requests
    """
    entry = self.segment_to_playlist.get(segment_url)
    if entry:
        playlist_url, idx = entry
        # Record the access so the playlist is not treated as inactive
        state = self.playlist_state.get(playlist_url)
        if state:
            state["last_access"] = asyncio.get_event_loop().time()
        await self.prebuffer_next_segments(playlist_url, idx, headers)
|
||||
|
||||
async def prebuffer_next_segments(self, playlist_url: str, current_segment_index: int, headers: Dict[str, str]) -> None:
    """
    Pre-buffer the segments right after the current playback position.

    Up to ``self.prebuffer_segments`` segments following
    ``current_segment_index`` are handed to ``_prebuffer_segments``.
    Unknown playlists are ignored.

    Args:
        playlist_url (str): URL of the playlist
        current_segment_index (int): Index of the segment being played
        headers (Dict[str, str]): Headers to send with the requests
    """
    if playlist_url in self.segment_urls:
        playlist_segments = self.segment_urls[playlist_url]
        first = current_segment_index + 1
        upcoming = playlist_segments[first:first + self.prebuffer_segments]
        if upcoming:
            await self._prebuffer_segments(upcoming, headers)
|
||||
|
||||
def clear_cache(self) -> None:
    """Empty the segment cache together with all playlist bookkeeping."""
    stores = (
        self.segment_cache,
        self.segment_urls,
        self.segment_to_playlist,
        self.playlist_state,
    )
    for store in stores:
        store.clear()
    logger.info("HLS pre-buffer cache cleared")
|
||||
|
||||
async def close(self) -> None:
    """Shut down the pre-buffer system by closing its HTTP client."""
    http_client = self.client
    await http_client.aclose()
|
||||
|
||||
|
||||
# Global pre-buffer instance
|
||||
hls_prebuffer = HLSPreBuffer()
|
||||
|
||||
|
||||
class HLSPreBuffer:
    """
    Playlist parsing and live-playlist refresh helpers for the HLS pre-buffer.

    The methods here read ``self.client``, ``self.segment_cache``,
    ``self.segment_urls``, ``self.segment_to_playlist``,
    ``self.playlist_state`` and ``self.max_cache_size``; none of them is
    initialised in this class body.
    """

    def _parse_target_duration(self, playlist_content: str) -> Optional[int]:
        """
        Parse EXT-X-TARGETDURATION from a media playlist.

        Returns:
            Optional[int]: Duration in whole seconds, or None if the tag is
            missing or its value cannot be parsed.
        """
        for line in playlist_content.splitlines():
            line = line.strip()
            if line.startswith("#EXT-X-TARGETDURATION:"):
                try:
                    return int(float(line.split(":", 1)[1].strip()))
                except (ValueError, OverflowError):
                    # Non-numeric or infinite value
                    return None
        return None

    def _drop_playlist(self, playlist_url: str) -> None:
        """Remove one playlist's cached segments, reverse mappings and state."""
        for u in set(self.segment_urls.get(playlist_url, [])):
            # Only this playlist's segments are touched; other entries stay cached
            self.segment_cache.pop(u, None)
            self.segment_to_playlist.pop(u, None)
        self.segment_urls.pop(playlist_url, None)
        self.playlist_state.pop(playlist_url, None)

    def _apply_refreshed_segments(self, playlist_url: str, new_urls: List[str]) -> None:
        """
        Store a refreshed segment list and bring the reverse map in line with it.

        The newest ``max_cache_size * 2`` segments are (re)mapped to their
        current index. Segments outside that window keep a mapping only if
        they already had one for this playlist, and it is updated to the new
        index. Mappings of segments that left the playlist are removed so
        the map cannot grow without bound on a live stream.
        """
        previous_urls = self.segment_urls.get(playlist_url, [])
        self.segment_urls[playlist_url] = new_urls

        window = self.max_cache_size * 2
        first_mapped = max(0, len(new_urls) - window) if window > 0 else 0
        for idx, u in enumerate(new_urls):
            existing = self.segment_to_playlist.get(u)
            if idx >= first_mapped or (existing and existing[0] == playlist_url):
                self.segment_to_playlist[u] = (playlist_url, idx)

        still_listed = set(new_urls)
        for u in previous_urls:
            if u in still_listed:
                continue
            existing = self.segment_to_playlist.get(u)
            # Leave entries that now belong to a different playlist alone
            if existing and existing[0] == playlist_url:
                del self.segment_to_playlist[u]

    async def _refresh_playlist_loop(self, playlist_url: str, headers: Dict[str, str], target_duration: int) -> None:
        """
        Periodically re-fetch a playlist to follow its sliding window.

        The loop ends when the playlist has no state entry any more, or
        after 10 minutes without access, in which case the playlist's cached
        segments and bookkeeping are removed. Refresh errors are logged at
        debug level and retried on the next tick.

        Args:
            playlist_url (str): URL of the media playlist to refresh
            headers (Dict[str, str]): Headers to send with the requests
            target_duration (int): Initial target duration in seconds; the
                refresh interval is this value clamped to 2..15 seconds
        """
        sleep_s = max(2, min(15, int(target_duration)))
        inactivity_timeout = 600  # 10 minutes
        loop = asyncio.get_running_loop()
        while True:
            try:
                st = self.playlist_state.get(playlist_url)
                if not st:
                    return
                now = loop.time()
                if now - st.get("last_access", now) > inactivity_timeout:
                    self._drop_playlist(playlist_url)
                    logger.info(f"Stopped HLS prebuffer for inactive playlist: {playlist_url}")
                    return

                # Refresh the manifest
                resp = await self.client.get(playlist_url, headers=headers)
                resp.raise_for_status()
                content = resp.text
                new_target = self._parse_target_duration(content)
                if new_target:
                    sleep_s = max(2, min(15, int(new_target)))

                new_urls = self._extract_segment_urls(content, playlist_url)
                if new_urls:
                    self._apply_refreshed_segments(playlist_url, new_urls)

                # The current playback index is not known here, so no
                # proactive prebuffering is attempted.
            except Exception as e:
                logger.debug(f"Playlist refresh error for {playlist_url}: {e}")
            await asyncio.sleep(sleep_s)

    def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract segment URLs from HLS playlist content.

        Args:
            playlist_content (str): Content of the HLS playlist
            base_url (str): Base URL for resolving relative URLs

        Returns:
            List[str]: Absolute segment URLs in playlist order
        """
        from urllib.parse import urljoin

        segment_urls = []
        lines = playlist_content.split('\n')

        logger.debug(f"Analyzing playlist with {len(lines)} lines")

        for line in lines:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if line.startswith(('http://', 'https://')):
                segment_urls.append(line)
                logger.debug(f"Found absolute URL: {line}")
            else:
                # urljoin handles '/abs', '//host/path', '../x' and plain names
                segment_url = urljoin(base_url, line)
                segment_urls.append(segment_url)
                logger.debug(f"Found relative path: {line} -> {segment_url}")

        logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
        if segment_urls:
            logger.debug(f"First segment URL: {segment_urls[0]}")
        else:
            logger.debug("No segment URLs found in playlist")
            # Log the first few lines for debugging
            for i, line in enumerate(lines[:10]):
                logger.debug(f"Line {i}: {line}")

        return segment_urls

    def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract the variant playlist URLs from an HLS master playlist.

        Every ``#EXT-X-STREAM-INF`` tag is paired with the next line that is
        neither blank nor a tag/comment, resolved against ``base_url``.

        Args:
            playlist_content (str): Text of the master playlist
            base_url (str): URL the playlist was fetched from

        Returns:
            List[str]: Absolute variant URLs in playlist order
        """
        from urllib.parse import urljoin

        variant_urls = []
        awaiting_uri = False
        for raw_line in playlist_content.split('\n'):
            line = raw_line.strip()
            if line.startswith("#EXT-X-STREAM-INF"):
                awaiting_uri = True
                continue
            if not awaiting_uri or not line or line.startswith('#'):
                # Blank lines are ignored in playlists; keep waiting for the URI
                continue
            awaiting_uri = False
            variant_urls.append(urljoin(base_url, line))

        logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
        if variant_urls:
            logger.debug(f"First variant URL: {variant_urls[0]}")
        return variant_urls
|
||||
"""
|
||||
HLS Pre-buffer system with priority-based sequential prefetching.
|
||||
|
||||
This module provides a smart prebuffering system that:
|
||||
- Prioritizes player-requested segments (downloaded immediately)
|
||||
- Prefetches remaining segments sequentially in background
|
||||
- Supports multiple users watching the same channel (shared prefetcher)
|
||||
- Cleans up inactive prefetchers automatically
|
||||
|
||||
Architecture:
|
||||
1. When playlist is fetched, register_playlist() creates a PlaylistPrefetcher
|
||||
2. PlaylistPrefetcher runs a background loop: priority queue -> sequential prefetch
|
||||
3. When player requests a segment, request_segment() adds it to priority queue
|
||||
4. Prefetcher downloads priority segment first, then continues sequential
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from mediaflow_proxy.utils.base_prebuffer import BasePrebuffer
|
||||
from mediaflow_proxy.utils.cache_utils import get_cached_segment
|
||||
from mediaflow_proxy.configs import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlaylistPrefetcher:
    """
    Manages prefetching for a single playlist with priority support.

    Key design for live streams with changing tokens:
    - Does NOT start prefetching immediately on registration
    - Only starts prefetching AFTER player requests a segment
    - This ensures we prefetch from the CURRENT playlist, not stale ones

    The prefetcher runs a background loop that:
    1. Waits for the player to request a segment (activation)
    2. Re-anchors its position on the segment the player requested
    3. Prefetches the following segments sequentially, at most
       ``prefetch_limit`` segments ahead of the player
    4. Stops when cancelled
    """

    def __init__(
        self,
        playlist_url: str,
        segment_urls: List[str],
        headers: Dict[str, str],
        prebuffer: "HLSPreBuffer",
        prefetch_limit: int = 5,
    ):
        """
        Initialize a playlist prefetcher.

        Args:
            playlist_url: URL of the HLS playlist
            segment_urls: Ordered list of segment URLs from the playlist
            headers: Headers to use for requests
            prebuffer: Parent HLSPreBuffer instance for download methods
            prefetch_limit: Maximum number of segments to prefetch ahead of player position
        """
        self.playlist_url = playlist_url
        self.segment_urls = segment_urls
        self.headers = headers
        self.prebuffer = prebuffer
        self.prefetch_limit = prefetch_limit

        self.last_access = time.time()
        self.current_index = 0  # Next segment to prefetch sequentially
        self.player_index = 0  # Last segment index requested by player
        self.priority_event = asyncio.Event()  # Signals priority segment available
        self.priority_url: Optional[str] = None  # Current priority segment
        self.cancelled = False
        self._task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()  # Protects priority_url

        # Segments currently being downloaded by this prefetcher
        self.downloading: set = set()

        # Set once a player has requested a segment of this playlist
        self.activated = False

    def start(self) -> None:
        """Start the prefetch background task (no-op while one is still running)."""
        if self._task is None or self._task.done():
            # stop() leaves the flag set; clear it so a restarted loop does
            # not exit on its first iteration.
            self.cancelled = False
            self._task = asyncio.create_task(self._run())
            logger.info(f"[PlaylistPrefetcher] Started (waiting for activation): {self.playlist_url}")

    def stop(self) -> None:
        """Stop the prefetch background task."""
        self.cancelled = True
        self.priority_event.set()  # Wake up the loop
        if self._task and not self._task.done():
            self._task.cancel()
        logger.info(f"[PlaylistPrefetcher] Stopped for: {self.playlist_url}")

    def update_segments(self, segment_urls: List[str]) -> None:
        """
        Update segment URLs (called when playlist is refreshed).

        When the segment the player last requested is still present in the
        new list, the player and prefetch positions are shifted by the same
        amount, so a sliding-window refresh neither skips nor repeats
        segments. If it is gone (e.g. URLs carry a new token), the positions
        are left untouched until the next player request re-anchors them.

        Args:
            segment_urls: New list of segment URLs
        """
        previous_urls = self.segment_urls
        self.segment_urls = segment_urls
        self.last_access = time.time()

        if self.activated and 0 <= self.player_index < len(previous_urls):
            new_player_index = self._find_segment_index(previous_urls[self.player_index])
            if new_player_index >= 0:
                shift = new_player_index - self.player_index
                self.player_index = new_player_index
                # Never fall back onto the segment the player handles itself
                self.current_index = max(self.current_index + shift, new_player_index + 1)

        logger.debug(f"[PlaylistPrefetcher] Updated segments ({len(segment_urls)}): {self.playlist_url}")

    async def request_priority(self, segment_url: str) -> None:
        """
        Player requested this segment - update indices and activate prefetching.

        The player will download this segment via get_or_download().
        The prefetcher's job is to prefetch segments AHEAD of the player,
        not to download the segment the player is already requesting.

        For VOD/movie streams a jump of more than ``prefetch_limit``
        segments between two player requests is logged as a seek; the
        prefetch window restarts right after the requested segment either way.

        Args:
            segment_url: URL of the segment the player needs
        """
        was_activated = self.activated
        self.last_access = time.time()
        self.activated = True  # Activate prefetching

        # Update player position for prefetch limit calculation
        segment_index = self._find_segment_index(segment_url)
        if segment_index >= 0:
            old_player_index = self.player_index
            self.player_index = segment_index
            # Start prefetching from the NEXT segment (player handles current one)
            self.current_index = segment_index + 1

            # The very first request has no previous position to jump from,
            # so it is never reported as a seek.
            jump_distance = abs(segment_index - old_player_index)
            if was_activated and jump_distance > self.prefetch_limit:
                logger.info(
                    f"[PlaylistPrefetcher] Seek detected: jumped {jump_distance} segments "
                    f"(from {old_player_index} to {segment_index})"
                )

        # Signal the prefetch loop to wake up and start prefetching ahead
        async with self._lock:
            self.priority_url = segment_url
            self.priority_event.set()

    def _find_segment_index(self, segment_url: str) -> int:
        """Return the index of a segment URL in the current list, or -1 if absent."""
        try:
            return self.segment_urls.index(segment_url)
        except ValueError:
            return -1

    async def _run(self) -> None:
        """
        Main prefetch loop.

        For live streams: waits until activated by player request before prefetching.
        A pending player request re-anchors the position first; otherwise the
        next sequential segment is prefetched.

        Prefetching is LIMITED to `prefetch_limit` segments ahead of the player's
        current position to avoid downloading the entire stream.
        """
        logger.info(f"[PlaylistPrefetcher] Loop started for: {self.playlist_url}")

        while not self.cancelled:
            try:
                # Wait for activation (player request) before doing anything
                if not self.activated:
                    try:
                        await asyncio.wait_for(self.priority_event.wait(), timeout=1.0)
                    except asyncio.TimeoutError:
                        continue

                # Consume the pending player request, if any
                async with self._lock:
                    priority_url = self.priority_url
                    self.priority_url = None
                    self.priority_event.clear()

                if priority_url:
                    # Player is already downloading this segment via get_or_download()
                    # We just need to update our indices and skip to prefetching NEXT segments
                    # This avoids duplicate download attempts and inflated cache miss stats
                    priority_index = self._find_segment_index(priority_url)
                    if priority_index >= 0:
                        self.player_index = priority_index
                        self.current_index = priority_index + 1  # Start prefetching from next segment
                        logger.info(
                            f"[PlaylistPrefetcher] Player at index {self.player_index}, "
                            f"will prefetch up to {self.prefetch_limit} segments ahead"
                        )
                    continue

                # Calculate prefetch limit based on player position
                max_prefetch_index = self.player_index + self.prefetch_limit + 1

                # No priority - prefetch next sequential segment (only if within limit)
                if (
                    self.activated
                    and self.current_index < len(self.segment_urls)
                    and self.current_index < max_prefetch_index
                ):
                    url = self.segment_urls[self.current_index]

                    # Skip if already cached or being downloaded
                    if url not in self.downloading:
                        cached = await get_cached_segment(url)
                        if not cached:
                            logger.info(
                                f"[PlaylistPrefetcher] Prefetching [{self.current_index}] "
                                f"(player at {self.player_index}, limit {self.prefetch_limit}): {url}"
                            )
                            await self._download_segment(url)
                        else:
                            logger.debug(f"[PlaylistPrefetcher] Already cached [{self.current_index}]: {url}")

                    self.current_index += 1
                else:
                    # Reached prefetch limit or end of segments - wait for player to advance
                    try:
                        await asyncio.wait_for(self.priority_event.wait(), timeout=1.0)
                    except asyncio.TimeoutError:
                        pass

            except asyncio.CancelledError:
                logger.info(f"[PlaylistPrefetcher] Loop cancelled: {self.playlist_url}")
                return
            except Exception as e:
                logger.warning(f"[PlaylistPrefetcher] Error in loop: {e}")
                await asyncio.sleep(0.5)

        logger.info(f"[PlaylistPrefetcher] Loop ended: {self.playlist_url}")

    async def _download_segment(self, url: str) -> None:
        """
        Download and cache a segment using the parent prebuffer.

        A URL that this prefetcher is already downloading is skipped.

        Args:
            url: URL of the segment to download
        """
        if url in self.downloading:
            return

        self.downloading.add(url)
        try:
            # Use the base prebuffer's get_or_download for cross-process coordination
            await self.prebuffer.get_or_download(url, self.headers)
        finally:
            self.downloading.discard(url)
|
||||
|
||||
|
||||
class HLSPreBuffer(BasePrebuffer):
    """
    Pre-buffer system for HLS streams with priority-based prefetching.

    Features:
    - Priority handling: a player request re-anchors the prefetch position
    - Sequential prefetch: Background prefetch of the following segments
    - Multi-user support: Multiple users share same prefetcher
    - Automatic cleanup: Inactive prefetchers removed after timeout
    """

    def __init__(
        self,
        max_cache_size: Optional[int] = None,
        prebuffer_segments: Optional[int] = None,
    ):
        """
        Initialize the HLS pre-buffer system.

        Args:
            max_cache_size: Maximum number of segments to cache (uses config if None)
            prebuffer_segments: Number of segments to pre-buffer ahead (uses config if None)
        """
        # Resolved once so new prefetchers honour an explicit constructor
        # override instead of always reading the global setting.
        self._prefetch_limit = prebuffer_segments or settings.hls_prebuffer_segments

        super().__init__(
            max_cache_size=max_cache_size or settings.hls_prebuffer_cache_size,
            prebuffer_segments=self._prefetch_limit,
            max_memory_percent=settings.hls_prebuffer_max_memory_percent,
            emergency_threshold=settings.hls_prebuffer_emergency_threshold,
            segment_ttl=settings.hls_segment_cache_ttl,
        )

        self.inactivity_timeout = settings.hls_prebuffer_inactivity_timeout

        # Active prefetchers: playlist_url -> PlaylistPrefetcher
        self.active_prefetchers: Dict[str, PlaylistPrefetcher] = {}

        # Reverse mapping: segment URL -> playlist_url
        self.segment_to_playlist: Dict[str, str] = {}

        # Lock for prefetcher management
        self._prefetcher_lock = asyncio.Lock()

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = 30  # Check every 30 seconds

    def log_stats(self) -> None:
        """Log current prebuffer statistics with HLS-specific info."""
        logger.info(f"HLS Prebuffer Stats: {self.get_stats()}")

    def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract segment URLs from HLS playlist content.

        Args:
            playlist_content: Content of the HLS playlist
            base_url: Base URL for resolving relative URLs

        Returns:
            List of absolute segment URLs in playlist order
        """
        segment_urls = []
        for raw_line in playlist_content.split("\n"):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith(("http://", "https://")):
                # Absolute URL - keep exactly as written
                segment_urls.append(line)
            else:
                # Relative URL - resolve against base
                segment_urls.append(urljoin(base_url, line))
        return segment_urls

    def _is_master_playlist(self, playlist_content: str) -> bool:
        """Check if this is a master playlist (contains variant streams)."""
        return "#EXT-X-STREAM-INF" in playlist_content

    async def register_playlist(
        self,
        playlist_url: str,
        segment_urls: List[str],
        headers: Dict[str, str],
    ) -> None:
        """
        Register a playlist for prefetching.

        Creates a new PlaylistPrefetcher or updates existing one.
        Called by M3U8 processor when a playlist is fetched.

        Args:
            playlist_url: URL of the HLS playlist
            segment_urls: Ordered list of segment URLs from the playlist
            headers: Headers to use for requests
        """
        if not segment_urls:
            logger.debug(f"[register_playlist] No segments, skipping: {playlist_url}")
            return

        async with self._prefetcher_lock:
            # Update reverse mapping. Entries of segments from earlier
            # versions of the playlist are kept on purpose so late player
            # requests still reach the prefetcher; they are purged when the
            # prefetcher is cleaned up.
            for url in segment_urls:
                self.segment_to_playlist[url] = playlist_url

            prefetcher = self.active_prefetchers.get(playlist_url)
            if prefetcher is not None:
                # Update existing prefetcher
                prefetcher.update_segments(segment_urls)
                prefetcher.headers = headers
                logger.info(f"[register_playlist] Updated existing prefetcher: {playlist_url}")
            else:
                # Create new prefetcher with configured prefetch limit
                prefetcher = PlaylistPrefetcher(
                    playlist_url=playlist_url,
                    segment_urls=segment_urls,
                    headers=headers,
                    prebuffer=self,
                    prefetch_limit=self._prefetch_limit,
                )
                self.active_prefetchers[playlist_url] = prefetcher
                prefetcher.start()
                logger.info(
                    f"[register_playlist] Created new prefetcher ({len(segment_urls)} segments, "
                    f"prefetch_limit={self._prefetch_limit}): {playlist_url}"
                )

        # Ensure cleanup task is running
        self._ensure_cleanup_task()

    async def request_segment(self, segment_url: str) -> None:
        """
        Player requested a segment - set as priority for prefetching.

        Finds the prefetcher that owns this segment and notifies it.
        Called by the segment endpoint when a segment is requested.

        Args:
            segment_url: URL of the segment the player needs
        """
        playlist_url = self.segment_to_playlist.get(segment_url)
        if not playlist_url:
            logger.debug(f"[request_segment] No prefetcher found for: {segment_url}")
            return

        prefetcher = self.active_prefetchers.get(playlist_url)
        if prefetcher:
            await prefetcher.request_priority(segment_url)
        else:
            logger.debug(f"[request_segment] Prefetcher not active for: {playlist_url}")

    def _ensure_cleanup_task(self) -> None:
        """Ensure the cleanup task is running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Periodically clean up inactive prefetchers."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_inactive_prefetchers()
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"[cleanup_loop] Error: {e}")

    async def _cleanup_inactive_prefetchers(self) -> None:
        """
        Remove prefetchers that haven't been accessed recently.

        Every reverse mapping that points at a removed playlist is purged,
        including URLs from earlier refreshes of that playlist. Mappings
        owned by other playlists are left alone.
        """
        now = time.time()
        to_remove = []

        async with self._prefetcher_lock:
            for playlist_url, prefetcher in self.active_prefetchers.items():
                inactive_time = now - prefetcher.last_access
                if inactive_time > self.inactivity_timeout:
                    to_remove.append(playlist_url)
                    logger.info(f"[cleanup] Removing inactive prefetcher ({inactive_time:.0f}s): {playlist_url}")

            for playlist_url in to_remove:
                prefetcher = self.active_prefetchers.pop(playlist_url, None)
                if prefetcher:
                    prefetcher.stop()

            if to_remove:
                removed = set(to_remove)
                stale_urls = [url for url, owner in self.segment_to_playlist.items() if owner in removed]
                for url in stale_urls:
                    del self.segment_to_playlist[url]

        if to_remove:
            logger.info(f"[cleanup] Removed {len(to_remove)} inactive prefetchers")

    def get_stats(self) -> dict:
        """Get current prebuffer statistics."""
        stats = self.stats.to_dict()
        stats["active_prefetchers"] = len(self.active_prefetchers)
        return stats

    def clear_cache(self) -> None:
        """Clear all prebuffer state and log final stats."""
        self.log_stats()

        # Stop all prefetchers
        for prefetcher in self.active_prefetchers.values():
            prefetcher.stop()

        self.active_prefetchers.clear()
        self.segment_to_playlist.clear()
        self.stats.reset()

        logger.info("HLS pre-buffer state cleared")

    async def close(self) -> None:
        """Close the pre-buffer system."""
        self.clear_cache()
        if self._cleanup_task:
            self._cleanup_task.cancel()
||||
|
||||
# Global HLS pre-buffer instance
|
||||
hls_prebuffer = HLSPreBuffer()
|
||||
|
||||
@@ -1,11 +1,47 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_stream_by_resolution(streams: List[Dict[str, Any]], target_resolution: str) -> Optional[Dict[str, Any]]:
    """
    Find stream matching target resolution (e.g., '1080p', '720p').
    Falls back to closest lower resolution if exact match not found.

    Args:
        streams: List of stream dictionaries with 'resolution' key as (width, height) tuple.
            A missing or empty 'resolution' is treated as (0, 0), i.e. invalid.
        target_resolution: Target resolution string (e.g., '1080p', '720p').
            Surrounding whitespace and the case of the 'p' suffix are ignored.

    Returns:
        The tallest stream whose height does not exceed the target. If every
        valid stream is taller than the target, the lowest one is returned.
        If no stream has a valid resolution, the first stream is returned,
        or None if ``streams`` is empty.

    Raises:
        ValueError: If ``target_resolution`` does not contain an integer height.
    """
    # Parse target height from "1080p" -> 1080 (also accepts " 1080P ")
    target_height = int(target_resolution.strip().lower().rstrip("p"))

    def height_of(stream: Dict[str, Any]) -> int:
        # `or` also covers an explicit "resolution": None
        return (stream.get("resolution") or (0, 0))[1]

    # Keep streams with a valid resolution (height > 0)
    valid_streams = [s for s in streams if height_of(s) > 0]
    if not valid_streams:
        logger.warning("No streams with valid resolution found")
        return streams[0] if streams else None

    # Tallest first; the sort is stable, so equal heights keep input order
    sorted_streams = sorted(valid_streams, key=height_of, reverse=True)

    # Find exact match or closest lower
    for stream in sorted_streams:
        if height_of(stream) <= target_height:
            logger.info(f"Selected stream with resolution {stream['resolution']} for target {target_resolution}")
            return stream

    # If all streams are higher than target, return lowest available
    lowest_stream = sorted_streams[-1]
    logger.info(f"All streams higher than target {target_resolution}, using lowest: {lowest_stream['resolution']}")
    return lowest_stream
|
||||
|
||||
|
||||
def parse_hls_playlist(playlist_content: str, base_url: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses an HLS master playlist to extract stream information.
|
||||
@@ -18,37 +54,37 @@ def parse_hls_playlist(playlist_content: str, base_url: Optional[str] = None) ->
|
||||
List[Dict[str, Any]]: A list of dictionaries, each representing a stream variant.
|
||||
"""
|
||||
streams = []
|
||||
lines = playlist_content.strip().split('\n')
|
||||
|
||||
lines = playlist_content.strip().split("\n")
|
||||
|
||||
# Regex to capture attributes from #EXT-X-STREAM-INF
|
||||
stream_inf_pattern = re.compile(r'#EXT-X-STREAM-INF:(.*)')
|
||||
|
||||
stream_inf_pattern = re.compile(r"#EXT-X-STREAM-INF:(.*)")
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith('#EXT-X-STREAM-INF'):
|
||||
stream_info = {'raw_stream_inf': line}
|
||||
if line.startswith("#EXT-X-STREAM-INF"):
|
||||
stream_info = {"raw_stream_inf": line}
|
||||
match = stream_inf_pattern.match(line)
|
||||
if not match:
|
||||
logger.warning(f"Could not parse #EXT-X-STREAM-INF line: {line}")
|
||||
continue
|
||||
attributes_str = match.group(1)
|
||||
|
||||
|
||||
# Parse attributes like BANDWIDTH, RESOLUTION, etc.
|
||||
attributes = re.findall(r'([A-Z-]+)=("([^"]+)"|([^,]+))', attributes_str)
|
||||
for key, _, quoted_val, unquoted_val in attributes:
|
||||
value = quoted_val if quoted_val else unquoted_val
|
||||
if key == 'RESOLUTION':
|
||||
if key == "RESOLUTION":
|
||||
try:
|
||||
width, height = map(int, value.split('x'))
|
||||
stream_info['resolution'] = (width, height)
|
||||
width, height = map(int, value.split("x"))
|
||||
stream_info["resolution"] = (width, height)
|
||||
except ValueError:
|
||||
stream_info['resolution'] = (0, 0)
|
||||
stream_info["resolution"] = (0, 0)
|
||||
else:
|
||||
stream_info[key.lower().replace('-', '_')] = value
|
||||
|
||||
stream_info[key.lower().replace("-", "_")] = value
|
||||
|
||||
# The next line should be the stream URL
|
||||
if i + 1 < len(lines) and not lines[i + 1].startswith('#'):
|
||||
if i + 1 < len(lines) and not lines[i + 1].startswith("#"):
|
||||
stream_url = lines[i + 1].strip()
|
||||
stream_info['url'] = urljoin(base_url, stream_url) if base_url else stream_url
|
||||
stream_info["url"] = urljoin(base_url, stream_url) if base_url else stream_url
|
||||
streams.append(stream_info)
|
||||
|
||||
return streams
|
||||
|
||||
return streams
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
aiohttp client factory with URL-based SSL verification and proxy routing.
|
||||
|
||||
This module provides a centralized HTTP client factory for aiohttp,
|
||||
allowing per-URL configuration of SSL verification and proxy routing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RouteMatch:
    """Transport settings resolved for a single URL by URLRoutingConfig.match_url."""

    # Whether TLS certificates are verified for the matched URL.
    verify_ssl: bool = True
    # Proxy URL to route through; None means no proxy.
    proxy_url: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class URLRoutingConfig:
    """
    URL-based routing configuration for SSL verification and proxy settings.

    Supports pattern matching:
    - "all://*.example.com" - matches all protocols for example.com and its subdomains
    - "https://api.example.com" - matches specific protocol and host
    - "https://" - default fallback for a single protocol
    - "all://" - default fallback for all URLs

    Hosts are compared case-insensitively; ports and credentials in the URL are ignored.
    """

    # Pattern -> (verify_ssl, proxy_url)
    routes: Dict[str, Tuple[bool, Optional[str]]] = field(default_factory=dict)

    # Global defaults, used when no route matches
    default_verify_ssl: bool = True
    default_proxy_url: Optional[str] = None

    def add_route(
        self,
        pattern: str,
        verify_ssl: bool = True,
        proxy_url: Optional[str] = None,
    ) -> None:
        """
        Add a route configuration, replacing any existing route with the same pattern.

        Args:
            pattern: URL pattern (e.g., "all://*.example.com", "https://api.example.com")
            verify_ssl: Whether to verify SSL for this pattern
            proxy_url: Proxy URL to use for this pattern (None = no proxy)
        """
        self.routes[pattern] = (verify_ssl, proxy_url)

    def match_url(self, url: str) -> RouteMatch:
        """
        Find the best matching route for a URL.

        The route with the highest specificity wins; on a tie the route that
        was added first is kept.

        Args:
            url: The URL to match

        Returns:
            RouteMatch with SSL and proxy settings (the global defaults when
            the URL is empty or no route matches)
        """
        if not url:
            return RouteMatch(
                verify_ssl=self.default_verify_ssl,
                proxy_url=self.default_proxy_url,
            )

        parsed = urlparse(url)
        scheme = parsed.scheme.lower()
        # `hostname` strips credentials and port, unwraps IPv6 brackets and
        # lower-cases the result. Splitting netloc on ":" would return the
        # username for "user:pw@host" and "[" for IPv6 literals.
        host = parsed.hostname or ""

        best_match: Optional[RouteMatch] = None
        best_specificity = -1

        for pattern, (verify_ssl, proxy_url) in self.routes.items():
            specificity = self._match_pattern(pattern, scheme, host)
            if specificity > best_specificity:
                best_specificity = specificity
                best_match = RouteMatch(verify_ssl=verify_ssl, proxy_url=proxy_url)

        if best_match is not None:
            return best_match

        # No route matched: fall back to the global defaults
        return RouteMatch(
            verify_ssl=self.default_verify_ssl,
            proxy_url=self.default_proxy_url,
        )

    def _match_pattern(self, pattern: str, scheme: str, host: str) -> int:
        """
        Check if a pattern matches the given scheme and host.

        `scheme` and `host` are expected to be lower-case already.

        Returns specificity score (higher = more specific match):
        - -1: No match
        - 0: Default match (all://)
        - 1: Scheme match only (e.g. https://)
        - 2: Wildcard host match
        - 3: Exact host match
        """
        pattern_scheme, separator, pattern_host = pattern.partition("://")
        if not separator:
            # Patterns must be written as "<scheme>://<host>"
            return -1

        pattern_scheme = pattern_scheme.lower()
        if pattern_scheme not in ("all", scheme):
            return -1

        # Empty host = default route; a concrete scheme outranks "all"
        if not pattern_host:
            return 0 if pattern_scheme == "all" else 1

        # Normalise the pattern host the same way the URL host was normalised
        pattern_host = pattern_host.lower()
        if pattern_host.startswith("[") and pattern_host.endswith("]"):
            pattern_host = pattern_host[1:-1]

        if pattern_host.startswith("*."):
            # "*.example.com" covers every subdomain and the bare domain itself
            suffix = pattern_host[1:]  # ".example.com"
            if host.endswith(suffix) or host == pattern_host[2:]:
                return 2
            return -1

        return 3 if pattern_host == host else -1
|
||||
|
||||
|
||||
# Global routing configuration - will be initialized from settings
|
||||
_global_routing_config: Optional[URLRoutingConfig] = None
|
||||
_routing_initialized = False
|
||||
|
||||
|
||||
def get_routing_config() -> URLRoutingConfig:
    """Return the module-level URL routing configuration, creating an empty one on first use."""
    global _global_routing_config
    if _global_routing_config is not None:
        return _global_routing_config
    _global_routing_config = URLRoutingConfig()
    return _global_routing_config
|
||||
|
||||
|
||||
def initialize_routing_from_config(transport_config) -> None:
    """
    Build the global routing configuration from a TransportConfig.

    Routes are registered in this order, a later route replacing an earlier
    one that uses the same pattern:
    1. the routes configured in `transport_config.transport_routes`,
    2. built-in domains whose SSL verification is always disabled,
    3. an "all://" default, only when a global proxy or globally disabled
       SSL verification is configured.

    Args:
        transport_config: The TransportConfig instance from settings
    """
    global _global_routing_config, _routing_initialized

    verify_by_default = not transport_config.disable_ssl_verification_globally
    global_proxy = transport_config.proxy_url if transport_config.all_proxy else None

    config = URLRoutingConfig(
        default_verify_ssl=verify_by_default,
        default_proxy_url=global_proxy,
    )

    # Configured routes: a global "disable verification" overrides per-route settings
    for pattern, route in transport_config.transport_routes.items():
        verify_ssl = route.verify_ssl if verify_by_default else False
        if route.proxy:
            # Route-specific proxy first, then the globally configured proxy
            proxy_url = route.proxy_url or transport_config.proxy_url
        else:
            proxy_url = None
        config.add_route(pattern, verify_ssl=verify_ssl, proxy_url=proxy_url)

    # Hardcoded routes for specific domains (SSL verification disabled)
    for domain in ("all://jxoplay.xyz", "all://dlhd.dad", "all://*.newkso.ru"):
        config.add_route(domain, verify_ssl=False, proxy_url=global_proxy)

    # Default route for global settings
    if transport_config.all_proxy or transport_config.disable_ssl_verification_globally:
        config.add_route("all://", verify_ssl=verify_by_default, proxy_url=global_proxy)

    _global_routing_config = config
    _routing_initialized = True

    logger.info(f"Initialized aiohttp routing with {len(config.routes)} routes")
|
||||
|
||||
|
||||
def _ensure_routing_initialized():
    """Initialize the routing configuration from settings unless that has already happened."""
    global _routing_initialized
    if _routing_initialized:
        return

    # Imported at call time rather than at module import
    # (presumably to avoid a circular import - TODO confirm).
    from mediaflow_proxy.configs import settings

    initialize_routing_from_config(settings.transport_config)
|
||||
|
||||
|
||||
def _get_ssl_context(verify: bool) -> ssl.SSLContext:
    """
    Build a fresh default SSL context, with verification switched off on request.

    Args:
        verify: True to keep certificate and hostname verification enabled,
            False to disable both.

    Returns:
        A new ssl.SSLContext instance.
    """
    context = ssl.create_default_context()
    if not verify:
        # Hostname checking has to be turned off before verify_mode can be set to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
    return context
|
||||
|
||||
|
||||
def create_proxy_connector(proxy_url: str, verify_ssl: bool = True) -> aiohttp.BaseConnector:
    """
    Create a connector for proxy connections, supporting SOCKS and HTTP proxies.

    Args:
        proxy_url: The proxy URL (socks5://, socks5h://, socks4://, socks4a://,
            http://..., https://...)
        verify_ssl: Whether to verify SSL certificates

    Returns:
        A SOCKS ProxyConnector for socks* schemes, otherwise a plain
        TCPConnector (the HTTP proxy URL is then passed per request).

    Raises:
        ImportError: If a SOCKS proxy is requested but aiohttp-socks is not installed.
    """
    parsed = urlparse(proxy_url)
    scheme = parsed.scheme.lower()

    ssl_context = _get_ssl_context(verify_ssl)

    if scheme not in ("socks5", "socks5h", "socks4", "socks4a"):
        # HTTP/HTTPS proxy - use standard connector
        # The proxy URL will be passed to the request method
        return TCPConnector(
            ssl=ssl_context,
            limit=100,
            limit_per_host=10,
        )

    # Only the optional import is guarded, so errors raised while building the
    # connector are not mistaken for a missing dependency.
    try:
        from aiohttp_socks import ProxyConnector, ProxyType
    except ImportError:
        logger.warning("aiohttp-socks not installed, SOCKS proxy support unavailable")
        raise

    proxy_type_map = {
        "socks5": ProxyType.SOCKS5,
        "socks5h": ProxyType.SOCKS5,
        "socks4": ProxyType.SOCKS4,
        "socks4a": ProxyType.SOCKS4,
    }

    return ProxyConnector(
        proxy_type=proxy_type_map[scheme],
        host=parsed.hostname,
        port=parsed.port or 1080,
        username=parsed.username,
        password=parsed.password,
        # Both "socks5h" and "socks4a" ask the proxy to resolve host names.
        rdns=scheme in ("socks5h", "socks4a"),
        ssl=ssl_context if not verify_ssl else None,
    )
|
||||
|
||||
|
||||
def _create_connector(proxy_url: Optional[str], verify_ssl: bool) -> Tuple[aiohttp.BaseConnector, Optional[str]]:
    """
    Pick the connector that fits the proxy configuration.

    Args:
        proxy_url: The proxy URL or None
        verify_ssl: Whether to verify SSL certificates

    Returns:
        Tuple of (connector, effective_proxy_url):
        - SOCKS proxy: the SOCKS connector and None (the connector does the proxying)
        - HTTP proxy: a TCPConnector and the proxy URL to pass to each request
        - no proxy: a TCPConnector and None
    """
    if proxy_url and urlparse(proxy_url).scheme in ("socks5", "socks5h", "socks4", "socks4a"):
        # SOCKS proxy - use special connector, proxy handled internally
        return create_proxy_connector(proxy_url, verify_ssl), None

    # Direct connection or HTTP proxy: both use the same standard connector
    tcp_connector = TCPConnector(ssl=_get_ssl_context(verify_ssl), limit=100, limit_per_host=10)
    return tcp_connector, (proxy_url if proxy_url else None)
|
||||
|
||||
|
||||
@asynccontextmanager
async def create_aiohttp_session(
    url: str = None,
    timeout: typing.Union[int, float, ClientTimeout] = None,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    verify: typing.Optional[bool] = None,
) -> typing.AsyncGenerator[typing.Tuple[ClientSession, typing.Optional[str]], None]:
    """
    Async context manager yielding an aiohttp ClientSession set up for `url`.

    The URL is matched against the global routing configuration to decide
    on SSL verification and proxy usage. The session is closed on exit.

    Args:
        url: The URL to configure the session for (used for routing)
        timeout: Request timeout (int/float for total seconds, or ClientTimeout);
            None uses settings.transport_config.timeout
        headers: Default headers for the session
        verify: Override SSL verification (None = use routing config)

    Yields:
        Tuple of (session, proxy_url) - proxy_url should be passed to request methods
    """
    _ensure_routing_initialized()

    # Resolve routing for the URL; an explicit `verify` wins over the route
    route_match = get_routing_config().match_url(url)
    use_verify = route_match.verify_ssl if verify is None else verify

    if isinstance(timeout, (int, float)):
        timeout_config = ClientTimeout(total=timeout)
    elif timeout is None:
        from mediaflow_proxy.configs import settings

        timeout_config = ClientTimeout(total=settings.transport_config.timeout)
    else:
        # Already a ClientTimeout
        timeout_config = timeout

    connector, effective_proxy_url = _create_connector(route_match.proxy_url, use_verify)

    # Leaving the `async with` block closes the session
    async with ClientSession(connector=connector, timeout=timeout_config, headers=headers) as session:
        yield session, effective_proxy_url
|
||||
+401
-200
@@ -1,25 +1,33 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from urllib import parse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientSession, ClientTimeout, ClientResponse
|
||||
import anyio
|
||||
import h11
|
||||
import httpx
|
||||
import tenacity
|
||||
from fastapi import Response
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Receive, Send, Scope
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from tqdm.asyncio import tqdm as tqdm_asyncio
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
|
||||
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
|
||||
from mediaflow_proxy.utils.stream_transformers import StreamTransformer
|
||||
from mediaflow_proxy.utils.http_client import (
|
||||
create_aiohttp_session,
|
||||
get_routing_config,
|
||||
_ensure_routing_initialized,
|
||||
_create_connector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,218 +39,279 @@ class DownloadError(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def create_httpx_client(follow_redirects: bool = True, **kwargs) -> httpx.AsyncClient:
    """
    Create an HTTPX async client that uses the configured transport mounts.

    Args:
        follow_redirects: Whether the client follows redirects.
        **kwargs: Extra arguments for httpx.AsyncClient; "timeout" defaults to
            settings.transport_config.timeout when not supplied.

    Returns:
        The configured httpx.AsyncClient.
    """
    transport = settings.transport_config
    mounts = transport.get_mounts()
    kwargs.setdefault("timeout", transport.timeout)
    return httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, **kwargs)
|
||||
def retry_if_download_error_not_404(retry_state):
    """
    Retry predicate: retry only attempts that failed with a non-404 DownloadError.

    Args:
        retry_state: Object whose `outcome` exposes `failed` and `exception()`.

    Returns:
        True when the attempt raised a DownloadError whose status_code is not
        404, False otherwise.
    """
    outcome = retry_state.outcome
    if not outcome.failed:
        return False
    error = outcome.exception()
    return isinstance(error, DownloadError) and error.status_code != 404
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_download_error_not_404,
)
async def fetch_with_retry(
    session: ClientSession,
    method: str,
    url: str,
    headers: dict,
    proxy: typing.Optional[str] = None,
    **kwargs,
) -> ClientResponse:
    """
    Fetches a URL with retry logic using native aiohttp.

    Up to three attempts are made; only DownloadError failures other than
    404 are retried (see retry_if_download_error_not_404).

    Args:
        session: The aiohttp ClientSession to use for the request.
        method: The HTTP method to use (e.g., GET, POST).
        url: The URL to fetch.
        headers: The headers to include in the request.
        proxy: Optional proxy URL for HTTP proxies.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        ClientResponse: The HTTP response.

    Raises:
        DownloadError: 409 on timeout, the upstream status on an HTTP error
            (404 is not retried), 502 on any other aiohttp client error.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        response = await session.request(method, url, headers=headers, proxy=proxy, **kwargs)
        response.raise_for_status()
        return response
    except asyncio.TimeoutError as e:
        logger.warning(f"Timeout while downloading {url}")
        raise DownloadError(409, f"Timeout while downloading {url}") from e
    except aiohttp.ClientResponseError as e:
        # Must be handled before ClientError, which is its base class
        if e.status == 404:
            logger.debug(f"Segment not found (404): {url}")
            raise DownloadError(404, f"Not found (404): {url}") from e
        logger.error(f"HTTP error {e.status} while downloading {url}")
        raise DownloadError(e.status, f"HTTP error {e.status} while downloading {url}") from e
    except aiohttp.ClientError as e:
        logger.error(f"Client error downloading {url}: {e}")
        raise DownloadError(502, f"Client error downloading {url}: {e}") from e
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        raise
|
||||
|
||||
|
||||
class Streamer:
|
||||
# PNG signature and IEND marker for fake PNG header detection (StreamWish/FileMoon)
|
||||
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
||||
_PNG_IEND_MARKER = b"\x49\x45\x4E\x44\xAE\x42\x60\x82"
|
||||
"""Handles streaming HTTP responses using aiohttp."""
|
||||
|
||||
def __init__(self, session: ClientSession, proxy_url: typing.Optional[str] = None):
    """
    Set up a Streamer around an aiohttp session.

    Args:
        session: The aiohttp ClientSession to use for streaming.
        proxy_url: Optional proxy URL for HTTP proxies.
    """
    self.session = session
    self.proxy_url = proxy_url

    # Current upstream response and optional progress display
    self.response: typing.Optional[ClientResponse] = None
    self.progress_bar = None

    # Transfer bookkeeping
    self.bytes_transferred = 0
    self.start_byte = 0
    self.end_byte = 0
    self.total_size = 0

    # Store request details for potential retry during streaming
    self._current_url: typing.Optional[str] = None
    self._current_headers: typing.Optional[dict] = None
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_download_error_not_404,
)
async def create_streaming_response(self, url: str, headers: dict, method: str = "GET"):
    """
    Creates and sends a streaming request, storing the result in self.response.

    Args:
        url: The URL to stream from.
        headers: The headers to include in the request.
        method: HTTP method to use (GET or HEAD). Defaults to GET.
            For HEAD requests, will fallback to GET if server doesn't support HEAD.

    Raises:
        DownloadError: 409 on timeout, 404 when the resource is missing, the
            upstream status on other HTTP errors, 502 on other client errors.
        aiohttp.ClientResponseError: For rate-limit statuses (429, 509), which
            are passed through so that they are not retried.
        RuntimeError: For any other unexpected error.
    """
    # Store request details for potential retry during streaming
    self._current_url = url
    self._current_headers = headers.copy()

    try:
        if method.upper() == "HEAD":
            # Try HEAD first, fallback to GET if server doesn't support it
            try:
                self.response = await self.session.head(url, headers=headers, proxy=self.proxy_url)
                self.response.raise_for_status()
            except aiohttp.ClientError as head_error:
                # ClientError also covers ClientResponseError (bad status)
                logger.debug(f"HEAD request failed ({head_error}), falling back to GET")
                self.response = await self.session.get(url, headers=headers, proxy=self.proxy_url)
                self.response.raise_for_status()
        else:
            self.response = await self.session.get(url, headers=headers, proxy=self.proxy_url)
            self.response.raise_for_status()
    except asyncio.TimeoutError as e:
        logger.warning("Timeout while creating streaming response")
        raise DownloadError(409, "Timeout while creating streaming response") from e
    except aiohttp.ClientResponseError as e:
        if e.status == 404:
            logger.debug(f"Segment not found (404): {url}")
            raise DownloadError(404, f"Not found (404): {url}") from e
        # Don't retry rate-limit errors (429, 509) - retrying while other connections
        # are still active just wastes time. Let the player handle its own retry logic.
        if e.status in (429, 509):
            logger.warning(f"Rate limited ({e.status}) by upstream: {url}")
            # Re-raise the original error so its headers and traceback are kept
            raise
        logger.error(f"HTTP error {e.status} while creating streaming response")
        raise DownloadError(e.status, f"HTTP error {e.status} while creating streaming response") from e
    except aiohttp.ClientError as e:
        logger.error(f"Error creating streaming response: {e}")
        raise DownloadError(502, f"Error creating streaming response: {e}") from e
    except Exception as e:
        logger.error(f"Error creating streaming response: {e}")
        raise RuntimeError(f"Error creating streaming response: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _strip_fake_png_wrapper(chunk: bytes) -> bytes:
|
||||
async def _retry_connection(self, from_byte: int) -> bool:
    """
    Attempt to reconnect to the upstream using Range header.

    The URL and headers stored by the last create_streaming_response call
    are reused, with a Range header added that starts at `from_byte`.

    Args:
        from_byte: The byte position to resume from.

    Returns:
        bool: True if reconnection was successful, False otherwise.
    """
    # Compare the headers against None explicitly: an empty headers dict is a
    # valid stored request and must not disable reconnection.
    if not self._current_url or self._current_headers is None:
        return False

    # Close existing response if any
    if self.response:
        self.response.close()
        self.response = None

    # Create new headers with Range; the end is bounded when the total size is known
    retry_headers = self._current_headers.copy()
    if self.total_size > 0:
        retry_headers["Range"] = f"bytes={from_byte}-{self.total_size - 1}"
    else:
        retry_headers["Range"] = f"bytes={from_byte}-"

    try:
        self.response = await self.session.get(self._current_url, headers=retry_headers, proxy=self.proxy_url)
        # Accept both 200 and 206 (Partial Content) as valid responses
        # NOTE(review): a 200 reply presumably means the Range header was
        # ignored, so the body may restart from the beginning - confirm that
        # callers can cope with that.
        if self.response.status in (200, 206):
            logger.info(f"Successfully reconnected at byte {from_byte}")
            return True
        logger.warning(f"Retry connection returned unexpected status: {self.response.status}")
        return False
    except Exception as e:
        # Best effort: any failure just reports that reconnecting did not work
        logger.warning(f"Failed to reconnect: {e}")
        return False
|
||||
|
||||
stripped_bytes = content_start
|
||||
logger.debug(f"Stripped {stripped_bytes} bytes of fake PNG wrapper from stream")
|
||||
async def stream_content(
|
||||
self, transformer: typing.Optional[StreamTransformer] = None
|
||||
) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
Stream content from the response, optionally applying a transformer.
|
||||
|
||||
return chunk[content_start:]
|
||||
Includes automatic retry logic when upstream disconnects mid-stream,
|
||||
using Range headers to resume from the last successful byte.
|
||||
|
||||
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
Args:
|
||||
transformer: Optional StreamTransformer to apply host-specific
|
||||
content manipulation (e.g., PNG stripping, TS detection).
|
||||
If None, content is streamed directly without modification.
|
||||
|
||||
Yields:
|
||||
Bytes chunks from the upstream response.
|
||||
"""
|
||||
if not self.response:
|
||||
raise RuntimeError("No response available for streaming")
|
||||
|
||||
is_first_chunk = True
|
||||
retry_count = 0
|
||||
max_retries = settings.upstream_retry_attempts if settings.upstream_retry_on_disconnect else 0
|
||||
|
||||
try:
|
||||
self.parse_content_range()
|
||||
while True:
|
||||
try:
|
||||
self.parse_content_range()
|
||||
|
||||
if settings.enable_streaming_progress:
|
||||
with tqdm_asyncio(
|
||||
total=self.total_size,
|
||||
initial=self.start_byte,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
desc="Streaming",
|
||||
ncols=100,
|
||||
mininterval=1,
|
||||
) as self.progress_bar:
|
||||
async for chunk in self.response.aiter_bytes():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
chunk = self._strip_fake_png_wrapper(chunk)
|
||||
# Create async generator from response content
|
||||
async def raw_chunks():
|
||||
async for chunk in self.response.content.iter_any():
|
||||
yield chunk
|
||||
|
||||
# Choose the chunk source based on whether we have a transformer
|
||||
# Note: Transformer state may not survive reconnection properly for all transformers
|
||||
if transformer and retry_count == 0:
|
||||
chunk_source = transformer.transform(raw_chunks())
|
||||
else:
|
||||
chunk_source = raw_chunks()
|
||||
|
||||
if settings.enable_streaming_progress:
|
||||
with tqdm_asyncio(
|
||||
total=self.total_size,
|
||||
initial=self.start_byte,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
desc="Streaming",
|
||||
ncols=100,
|
||||
mininterval=1,
|
||||
) as self.progress_bar:
|
||||
async for chunk in chunk_source:
|
||||
yield chunk
|
||||
self.bytes_transferred += len(chunk)
|
||||
self.progress_bar.update(len(chunk))
|
||||
else:
|
||||
async for chunk in chunk_source:
|
||||
yield chunk
|
||||
self.bytes_transferred += len(chunk)
|
||||
self.progress_bar.update(len(chunk))
|
||||
else:
|
||||
async for chunk in self.response.aiter_bytes():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
chunk = self._strip_fake_png_wrapper(chunk)
|
||||
|
||||
yield chunk
|
||||
self.bytes_transferred += len(chunk)
|
||||
# Successfully completed streaming
|
||||
return
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Timeout while streaming")
|
||||
raise DownloadError(409, "Timeout while streaming")
|
||||
except httpx.RemoteProtocolError as e:
|
||||
# Special handling for connection closed errors
|
||||
if "peer closed connection without sending complete message body" in str(e):
|
||||
logger.warning(f"Remote server closed connection prematurely: {e}")
|
||||
# If we've received some data, just log the warning and return normally
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout while streaming")
|
||||
raise DownloadError(409, "Timeout while streaming")
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientPayloadError, aiohttp.ClientError) as e:
|
||||
# Handle connection errors with potential retry
|
||||
error_type = type(e).__name__
|
||||
logger.warning(f"{error_type} while streaming after {self.bytes_transferred} bytes: {e}")
|
||||
|
||||
# Check if we should retry
|
||||
if retry_count < max_retries and self.bytes_transferred > 0:
|
||||
retry_count += 1
|
||||
resume_from = self.start_byte + self.bytes_transferred
|
||||
logger.info(f"Attempting reconnection (retry {retry_count}/{max_retries}) from byte {resume_from}")
|
||||
|
||||
# Wait before retry
|
||||
await asyncio.sleep(settings.upstream_retry_delay)
|
||||
|
||||
if await self._retry_connection(resume_from):
|
||||
# Successfully reconnected, continue the loop to resume streaming
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"Reconnection failed on retry {retry_count}")
|
||||
|
||||
# No more retries or reconnection failed
|
||||
if self.bytes_transferred > 0:
|
||||
logger.info(
|
||||
f"Partial content received ({self.bytes_transferred} bytes). Continuing with available data."
|
||||
f"Partial content received ({self.bytes_transferred} bytes). "
|
||||
f"Graceful termination after {retry_count} retry attempts."
|
||||
)
|
||||
return
|
||||
else:
|
||||
# If we haven't received any data, raise an error
|
||||
raise DownloadError(502, f"Remote server closed connection without sending any data: {e}")
|
||||
else:
|
||||
logger.error(f"Protocol error while streaming: {e}")
|
||||
raise DownloadError(502, f"Protocol error while streaming: {e}")
|
||||
except GeneratorExit:
|
||||
logger.info("Streaming session stopped by the user")
|
||||
except httpx.ReadError as e:
|
||||
# Handle network read errors gracefully - these occur when upstream connection drops
|
||||
logger.warning(f"ReadError while streaming: {e}")
|
||||
if self.bytes_transferred > 0:
|
||||
logger.info(f"Partial content received ({self.bytes_transferred} bytes) before ReadError. Graceful termination.")
|
||||
raise DownloadError(502, f"{error_type} while streaming: {e}")
|
||||
except GeneratorExit:
|
||||
logger.info("Streaming session stopped by the user")
|
||||
return
|
||||
else:
|
||||
raise DownloadError(502, f"ReadError while streaming: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming content: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def format_bytes(size) -> str:
|
||||
power = 2**10
|
||||
@@ -263,41 +332,41 @@ class Streamer:
|
||||
self.total_size = int(self.response.headers.get("Content-Length", 0))
|
||||
self.end_byte = self.total_size - 1 if self.total_size > 0 else 0
|
||||
|
||||
async def get_text(self, url: str, headers: dict):
|
||||
async def get_text(self, url: str, headers: dict) -> str:
|
||||
"""
|
||||
Sends a GET request to a URL and returns the response text.
|
||||
|
||||
Args:
|
||||
url (str): The URL to send the GET request to.
|
||||
headers (dict): The headers to include in the request.
|
||||
url: The URL to send the GET request to.
|
||||
headers: The headers to include in the request.
|
||||
|
||||
Returns:
|
||||
str: The response text.
|
||||
"""
|
||||
try:
|
||||
self.response = await fetch_with_retry(self.client, "GET", url, headers)
|
||||
self.response = await fetch_with_retry(self.session, "GET", url, headers, proxy=self.proxy_url)
|
||||
return await self.response.text()
|
||||
except tenacity.RetryError as e:
|
||||
raise e.last_attempt.result()
|
||||
return self.response.text
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the HTTP client and response.
|
||||
Closes the HTTP response and session.
|
||||
"""
|
||||
if self.response:
|
||||
await self.response.aclose()
|
||||
self.response.close()
|
||||
if self.progress_bar:
|
||||
self.progress_bar.close()
|
||||
await self.client.aclose()
|
||||
await self.session.close()
|
||||
|
||||
|
||||
async def download_file_with_retry(url: str, headers: dict):
|
||||
async def download_file_with_retry(url: str, headers: dict) -> bytes:
|
||||
"""
|
||||
Downloads a file with retry logic.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the file to download.
|
||||
headers (dict): The headers to include in the request.
|
||||
url: The URL of the file to download.
|
||||
headers: The headers to include in the request.
|
||||
|
||||
Returns:
|
||||
bytes: The downloaded file content.
|
||||
@@ -305,10 +374,10 @@ async def download_file_with_retry(url: str, headers: dict):
|
||||
Raises:
|
||||
DownloadError: If the download fails after retries.
|
||||
"""
|
||||
async with create_httpx_client() as client:
|
||||
async with create_aiohttp_session(url) as (session, proxy_url):
|
||||
try:
|
||||
response = await fetch_with_retry(client, "GET", url, headers)
|
||||
return response.content
|
||||
response = await fetch_with_retry(session, "GET", url, headers, proxy=proxy_url)
|
||||
return await response.read()
|
||||
except DownloadError as e:
|
||||
logger.error(f"Failed to download file: {e}")
|
||||
raise e
|
||||
@@ -316,31 +385,85 @@ async def download_file_with_retry(url: str, headers: dict):
|
||||
raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
|
||||
|
||||
|
||||
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
|
||||
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> ClientResponse:
|
||||
"""
|
||||
Sends an HTTP request with retry logic.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method to use (e.g., GET, POST).
|
||||
url (str): The URL to send the request to.
|
||||
headers (dict): The headers to include in the request.
|
||||
method: The HTTP method to use (e.g., GET, POST).
|
||||
url: The URL to send the request to.
|
||||
headers: The headers to include in the request.
|
||||
**kwargs: Additional arguments to pass to the request.
|
||||
|
||||
Returns:
|
||||
httpx.Response: The HTTP response.
|
||||
ClientResponse: The HTTP response.
|
||||
|
||||
Raises:
|
||||
DownloadError: If the request fails after retries.
|
||||
"""
|
||||
async with create_httpx_client() as client:
|
||||
async with create_aiohttp_session(url) as (session, proxy_url):
|
||||
try:
|
||||
response = await fetch_with_retry(client, method, url, headers, **kwargs)
|
||||
response = await fetch_with_retry(session, method, url, headers, proxy=proxy_url, **kwargs)
|
||||
# Read the content so it's available after session closes
|
||||
await response.read()
|
||||
return response
|
||||
except DownloadError as e:
|
||||
logger.error(f"Failed to download file: {e}")
|
||||
logger.error(f"Failed to make request: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def create_streamer(url: str = None) -> Streamer:
    """
    Build a Streamer backed by its own aiohttp session for the given URL.

    The returned Streamer holds the session created here; call
    ``streamer.close()`` once finished so the session is released.

    Args:
        url: Optional URL passed to the routing configuration to pick the
            proxy and SSL-verification settings.

    Returns:
        Streamer: A Streamer wrapping the new session and the proxy URL
        returned by the connector factory.
    """
    _ensure_routing_initialized()
    route_match = get_routing_config().match_url(url)

    # Only sock_read is bounded; there is deliberately no total timeout.
    # This keeps the following cases working:
    # - Live streams (indefinite duration)
    # - Large file downloads (total time depends on file size)
    # - Seek operations (upstream may take time to seek)
    # - Dead connection detection (timeout if no data flows)
    read_timeout = ClientTimeout(total=None, sock_read=settings.transport_config.timeout)

    connector, proxy_url = _create_connector(route_match.proxy_url, route_match.verify_ssl)

    return Streamer(ClientSession(connector=connector, timeout=read_timeout), proxy_url)
|
||||
|
||||
|
||||
# Keep setup_streamer as alias for backward compatibility during transition
|
||||
async def setup_streamer(url: str = None) -> typing.Tuple[ClientSession, str, Streamer]:
    """
    Create a Streamer and also expose its session and proxy URL.

    DEPRECATED: prefer create_streamer(), which returns only the Streamer.

    Args:
        url: Optional URL forwarded to create_streamer() for routing configuration.

    Returns:
        Tuple of (session, proxy_url, streamer), where session and proxy_url
        are read from the created streamer.
    """
    created = await create_streamer(url)
    session, proxy_url = created.session, created.proxy_url
    return session, proxy_url, created
|
||||
|
||||
|
||||
def encode_mediaflow_proxy_url(
|
||||
mediaflow_proxy_url: str,
|
||||
endpoint: typing.Optional[str] = None,
|
||||
@@ -348,25 +471,31 @@ def encode_mediaflow_proxy_url(
|
||||
query_params: typing.Optional[dict] = None,
|
||||
request_headers: typing.Optional[dict] = None,
|
||||
response_headers: typing.Optional[dict] = None,
|
||||
propagate_response_headers: typing.Optional[dict] = None,
|
||||
remove_response_headers: typing.Optional[list[str]] = None,
|
||||
encryption_handler: EncryptionHandler = None,
|
||||
expiration: int = None,
|
||||
ip: str = None,
|
||||
filename: typing.Optional[str] = None,
|
||||
stream_transformer: typing.Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.
|
||||
|
||||
Args:
|
||||
mediaflow_proxy_url (str): The base MediaFlow proxy URL.
|
||||
endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
|
||||
destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
|
||||
query_params (dict, optional): Additional query parameters to include. Defaults to None.
|
||||
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
|
||||
response_headers (dict, optional): Headers to include as query parameters. Defaults to None.
|
||||
encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
|
||||
expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
|
||||
ip (str, optional): The public IP address to include in the query parameters. Defaults to None.
|
||||
filename (str, optional): Filename to be preserved for media players like Infuse. Defaults to None.
|
||||
mediaflow_proxy_url: The base MediaFlow proxy URL.
|
||||
endpoint: The endpoint to append to the base URL. Defaults to None.
|
||||
destination_url: The destination URL to include in the query parameters. Defaults to None.
|
||||
query_params: Additional query parameters to include. Defaults to None.
|
||||
request_headers: Headers to include as query parameters. Defaults to None.
|
||||
response_headers: Headers to include as query parameters (r_ prefix). Defaults to None.
|
||||
propagate_response_headers: Response headers that propagate to segments (rp_ prefix). Defaults to None.
|
||||
remove_response_headers: List of response header names to remove. Defaults to None.
|
||||
encryption_handler: The encryption handler to use. Defaults to None.
|
||||
expiration: The expiration time for the encrypted token. Defaults to None.
|
||||
ip: The public IP address to include in the query parameters. Defaults to None.
|
||||
filename: Filename to be preserved for media players like Infuse. Defaults to None.
|
||||
stream_transformer: ID of the stream transformer to apply. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The encoded MediaFlow proxy URL.
|
||||
@@ -376,15 +505,45 @@ def encode_mediaflow_proxy_url(
|
||||
if destination_url is not None:
|
||||
query_params["d"] = destination_url
|
||||
|
||||
# Add headers if provided
|
||||
# Add headers if provided (always use lowercase prefix for consistency)
|
||||
# Filter out empty values to avoid URLs like &h_if-range=&h_referer=...
|
||||
# Also exclude dynamic per-request headers (range, if-range) that are already handled
|
||||
# via SUPPORTED_REQUEST_HEADERS from the player's actual request. Encoding them as h_
|
||||
# query params would bake in stale values that override the player's real headers on
|
||||
# subsequent requests (e.g., when seeking to a different position).
|
||||
if request_headers:
|
||||
query_params.update(
|
||||
{key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
|
||||
{
|
||||
key if key.lower().startswith("h_") else f"h_{key}": value
|
||||
for key, value in request_headers.items()
|
||||
if value and (key.lower().removeprefix("h_") not in SUPPORTED_REQUEST_HEADERS)
|
||||
}
|
||||
)
|
||||
if response_headers:
|
||||
query_params.update(
|
||||
{key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
|
||||
{
|
||||
key if key.lower().startswith("r_") else f"r_{key}": value
|
||||
for key, value in response_headers.items()
|
||||
if value # Skip empty/None values
|
||||
}
|
||||
)
|
||||
# Add propagate response headers (rp_ prefix - these propagate to segments)
|
||||
if propagate_response_headers:
|
||||
query_params.update(
|
||||
{
|
||||
key if key.lower().startswith("rp_") else f"rp_{key}": value
|
||||
for key, value in propagate_response_headers.items()
|
||||
if value # Skip empty/None values
|
||||
}
|
||||
)
|
||||
|
||||
# Add remove headers if provided (x_ prefix for "exclude")
|
||||
if remove_response_headers:
|
||||
query_params["x_headers"] = ",".join(remove_response_headers)
|
||||
|
||||
# Add stream transformer if provided
|
||||
if stream_transformer:
|
||||
query_params["transformer"] = stream_transformer
|
||||
|
||||
# Construct the base URL
|
||||
if endpoint is None:
|
||||
@@ -441,10 +600,10 @@ def encode_stremio_proxy_url(
|
||||
Format: http://127.0.0.1:11470/proxy/d=<encoded_origin>&h=<headers>&r=<response_headers>/<path><query>
|
||||
|
||||
Args:
|
||||
stremio_proxy_url (str): The base Stremio proxy URL.
|
||||
destination_url (str): The destination URL to proxy.
|
||||
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
|
||||
response_headers (dict, optional): Response headers to include as query parameters. Defaults to None.
|
||||
stremio_proxy_url: The base Stremio proxy URL.
|
||||
destination_url: The destination URL to proxy.
|
||||
request_headers: Headers to include as query parameters. Defaults to None.
|
||||
response_headers: Response headers to include as query parameters. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The encoded Stremio proxy URL.
|
||||
@@ -498,7 +657,7 @@ def get_original_scheme(request: Request) -> str:
|
||||
Determines the original scheme (http or https) of the request.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
request: The incoming HTTP request.
|
||||
|
||||
Returns:
|
||||
str: The original scheme ('http' or 'https')
|
||||
@@ -528,6 +687,35 @@ def get_original_scheme(request: Request) -> str:
|
||||
class ProxyRequestHeaders:
|
||||
request: dict
|
||||
response: dict
|
||||
remove: list # headers to remove from response
|
||||
propagate: dict # response headers to propagate to segments (rp_ prefix)
|
||||
|
||||
|
||||
def apply_header_manipulation(
    base_headers: dict, proxy_headers: "ProxyRequestHeaders", include_propagate: bool = True
) -> dict:
    """
    Apply response header additions and removals.

    Headers named in ``proxy_headers.remove`` are filtered out of
    ``base_headers`` (case-insensitively). Then ``proxy_headers.propagate``
    (optional) and ``proxy_headers.response`` are merged in, with response
    headers taking precedence over propagate headers.

    Header names are matched case-insensitively throughout: a base header
    that an override replaces under a different casing (e.g. base
    ``Content-Type`` vs. override ``content-type``) is dropped, so the result
    never carries the same header twice.

    Args:
        base_headers: The base headers to start with. Not modified.
        proxy_headers: Object exposing ``remove`` (iterable of names),
            ``response`` (dict) and ``propagate`` (dict).
        include_propagate: Whether to include propagate headers (rp_).
            Set to False for manifests, True for segments. Defaults to True.

    Returns:
        dict: A new dict with the manipulated headers.
    """
    removed = {name.lower() for name in proxy_headers.remove}

    # Propagate headers first (for segments), then response headers so that
    # response values win when both define the same key.
    overrides: dict = {}
    if include_propagate:
        overrides.update(proxy_headers.propagate)
    overrides.update(proxy_headers.response)
    override_names = {name.lower() for name in overrides}

    # Keep a base header unless it is explicitly removed or it is shadowed by
    # an override spelled with different casing. Exact-case matches stay in
    # place and simply get their value overwritten by the update below.
    result = {
        key: value
        for key, value in base_headers.items()
        if key.lower() not in removed and (key in overrides or key.lower() not in override_names)
    }
    result.update(overrides)
    return result
|
||||
|
||||
|
||||
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
|
||||
@@ -535,32 +723,42 @@ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
|
||||
Extracts proxy headers from the request query parameters.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming HTTP request.
|
||||
request: The incoming HTTP request.
|
||||
|
||||
Returns:
|
||||
ProxyRequest: A named tuple containing the request headers and response headers.
|
||||
ProxyRequest: A named tuple containing the request headers, response headers, and headers to remove.
|
||||
"""
|
||||
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
|
||||
request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
|
||||
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS and v}
|
||||
|
||||
# Extract h_ prefixed headers from query params, filtering out empty values
|
||||
for k, v in request.query_params.items():
|
||||
if k.lower().startswith("h_") and v: # Skip empty values
|
||||
request_headers[k[2:].lower()] = v
|
||||
|
||||
request_headers.setdefault("user-agent", settings.user_agent)
|
||||
|
||||
# Handle common misspelling of referer
|
||||
if "referrer" in request_headers:
|
||||
if "referer" not in request_headers:
|
||||
request_headers["referer"] = request_headers.pop("referrer")
|
||||
|
||||
dest = request.query_params.get("d", "")
|
||||
host = urlparse(dest).netloc.lower()
|
||||
|
||||
if "vidoza" in host or "videzz" in host:
|
||||
# Remove ALL empty headers
|
||||
for h in list(request_headers.keys()):
|
||||
v = request_headers[h]
|
||||
if v is None or v.strip() == "":
|
||||
request_headers.pop(h, None)
|
||||
|
||||
response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
|
||||
return ProxyRequestHeaders(request_headers, response_headers)
|
||||
# r_ prefix: response headers (manifest only, not propagated to segments)
|
||||
# Filter out empty values
|
||||
response_headers = {
|
||||
k[2:].lower(): v
|
||||
for k, v in request.query_params.items()
|
||||
if k.lower().startswith("r_") and not k.lower().startswith("rp_") and v
|
||||
}
|
||||
|
||||
# rp_ prefix: response headers that propagate to segments
|
||||
# Filter out empty values
|
||||
propagate_headers = {k[3:].lower(): v for k, v in request.query_params.items() if k.lower().startswith("rp_") and v}
|
||||
|
||||
# Parse headers to remove from response (x_headers parameter)
|
||||
x_headers_param = request.query_params.get("x_headers", "")
|
||||
remove_headers = [h.strip().lower() for h in x_headers_param.split(",") if h.strip()] if x_headers_param else []
|
||||
|
||||
return ProxyRequestHeaders(request_headers, response_headers, remove_headers, propagate_headers)
|
||||
|
||||
|
||||
class EnhancedStreamingResponse(Response):
|
||||
@@ -632,19 +830,19 @@ class EnhancedStreamingResponse(Response):
|
||||
# Successfully streamed all content
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finalization_sent = True
|
||||
except (httpx.RemoteProtocolError, httpx.ReadError, h11._util.LocalProtocolError) as e:
|
||||
except (aiohttp.ServerDisconnectedError, aiohttp.ClientPayloadError, aiohttp.ClientError) as e:
|
||||
# Handle connection closed / read errors gracefully
|
||||
if data_sent:
|
||||
# We've sent some data to the client, so try to complete the response
|
||||
logger.warning(f"Upstream connection error after partial streaming: {e}")
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finalization_sent = True
|
||||
logger.info(
|
||||
f"Response finalized after partial content ({self.actual_content_length} bytes transferred)"
|
||||
)
|
||||
except Exception as close_err:
|
||||
logger.warning(f"Could not finalize response after upstream error: {close_err}")
|
||||
# We've sent some data to the client. With Content-Length set, we cannot
|
||||
# gracefully finalize a partial response - h11 will raise LocalProtocolError
|
||||
# if we try to send more_body: False without delivering all promised bytes.
|
||||
# The best we can do is log and return silently, letting the client handle
|
||||
# the incomplete response (most players will just stop or retry).
|
||||
logger.warning(
|
||||
f"Upstream connection error after partial streaming ({self.actual_content_length} bytes transferred): {e}"
|
||||
)
|
||||
# Don't try to finalize - just return and let the connection close naturally
|
||||
return
|
||||
else:
|
||||
# No data was sent, re-raise the error
|
||||
logger.error(f"Upstream error before any data was streamed: {e}")
|
||||
@@ -667,13 +865,16 @@ class EnhancedStreamingResponse(Response):
|
||||
except Exception:
|
||||
# If we can't send an error response, just log it
|
||||
pass
|
||||
elif response_started and not finalization_sent:
|
||||
# Response already started but not finalized - gracefully close the stream
|
||||
elif response_started and not finalization_sent and not data_sent:
|
||||
# Response started but no data sent yet - we can safely finalize
|
||||
# (If data was sent with Content-Length, we can't finalize without h11 error)
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finalization_sent = True
|
||||
except Exception:
|
||||
pass
|
||||
# If data was sent but streaming failed, just return silently
|
||||
# The client will see an incomplete response which is unavoidable with Content-Length
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async with anyio.create_task_group() as task_group:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
import logging
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from urllib import parse
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
@@ -9,9 +11,121 @@ from mediaflow_proxy.utils.crypto_utils import encryption_handler
|
||||
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, encode_stremio_proxy_url, get_original_scheme
|
||||
from mediaflow_proxy.utils.hls_prebuffer import hls_prebuffer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_graceful_end_playlist(message: str = "Stream ended") -> str:
    """
    Generate a minimal valid m3u8 playlist that signals stream end.

    This is used when upstream fails but we want to hand the player a
    playlist that ends (``#EXT-X-ENDLIST``) instead of an abrupt error.

    Args:
        message: Optional message to include as a comment. Any line breaks
            in it are replaced by spaces so the text always stays on a
            single comment line and cannot add extra playlist lines.

    Returns:
        str: A valid m3u8 playlist string with EXT-X-ENDLIST.
    """
    # A playlist is line-oriented: an embedded newline in the message would
    # start a new line that players could read as a tag or segment URI.
    comment = " ".join(str(message).splitlines())
    return (
        "#EXTM3U\n"
        "#EXT-X-VERSION:3\n"
        "#EXT-X-TARGETDURATION:1\n"
        "#EXT-X-PLAYLIST-TYPE:VOD\n"
        f"# {comment}\n"
        "#EXT-X-ENDLIST\n"
    )
|
||||
|
||||
|
||||
def generate_error_playlist(error_message: str = "Stream unavailable") -> str:
    """
    Generate a minimal valid m3u8 playlist for error scenarios.

    The playlist carries the same tags as the one produced by
    generate_graceful_end_playlist; the only difference is the comment line,
    which is prefixed with ``Error:``. It contains no segments and ends with
    ``#EXT-X-ENDLIST``, so it remains parseable by players.

    Args:
        error_message: Error message to include as a comment. Any line
            breaks in it are replaced by spaces so the text always stays on
            a single comment line and cannot add extra playlist lines.

    Returns:
        str: A valid m3u8 playlist string.
    """
    # Error text often comes from exceptions and may span several lines;
    # flatten it so it cannot introduce tags or segment URIs.
    comment = " ".join(str(error_message).splitlines())
    return (
        "#EXTM3U\n"
        "#EXT-X-VERSION:3\n"
        "#EXT-X-TARGETDURATION:1\n"
        "#EXT-X-PLAYLIST-TYPE:VOD\n"
        f"# Error: {comment}\n"
        "#EXT-X-ENDLIST\n"
    )
|
||||
|
||||
|
||||
class SkipSegmentFilter:
    """
    Helper class to filter HLS segments based on time ranges.

    Tracks cumulative playback time and determines which segments
    should be skipped based on the provided skip segment list.
    """

    def __init__(self, skip_segments: Optional[List[dict]] = None) -> None:
        """
        Initialize the skip segment filter.

        Args:
            skip_segments: List of skip segment dicts with 'start' and 'end' keys
                (seconds). None or an empty list disables filtering.
        """
        self.skip_segments = skip_segments or []
        self.current_time = 0.0  # Cumulative playback time in seconds

    def should_skip_segment(self, duration: float) -> bool:
        """
        Determine if the current segment should be skipped.

        The segment is taken to span ``current_time`` to
        ``current_time + duration``. This method does not advance the clock;
        call advance_time() for that.

        Args:
            duration: Duration of the current segment in seconds.

        Returns:
            True if the segment overlaps with any skip range, False otherwise.
        """
        segment_start = self.current_time
        segment_end = segment_start + duration

        for skip in self.skip_segments:
            # `or 0` covers both a missing key and a key present with a
            # null value, which would otherwise raise TypeError below.
            skip_start = skip.get("start") or 0
            skip_end = skip.get("end") or 0

            # Overlap: segment starts before the skip range ends AND ends after it starts
            if segment_start < skip_end and segment_end > skip_start:
                # Lazy %-formatting: the message is only built when DEBUG is enabled
                logger.debug(
                    "Skipping segment at %.2fs-%.2fs (overlaps with skip range %.2fs-%.2fs)",
                    segment_start,
                    segment_end,
                    skip_start,
                    skip_end,
                )
                return True

        return False

    def advance_time(self, duration: float) -> None:
        """Advance the cumulative playback time by ``duration`` seconds."""
        self.current_time += duration

    def has_skip_segments(self) -> bool:
        """Check if there are any skip segments configured."""
        return bool(self.skip_segments)
|
||||
|
||||
|
||||
class M3U8Processor:
|
||||
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None, key_only_proxy: bool = False, no_proxy: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
request,
|
||||
key_url: str = None,
|
||||
force_playlist_proxy: bool = None,
|
||||
key_only_proxy: bool = False,
|
||||
no_proxy: bool = False,
|
||||
skip_segments: Optional[List[dict]] = None,
|
||||
start_offset: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the M3U8Processor with the request and URL prefix.
|
||||
|
||||
@@ -21,21 +135,65 @@ class M3U8Processor:
|
||||
force_playlist_proxy (bool, optional): Force all playlist URLs to be proxied through MediaFlow. Defaults to None.
|
||||
key_only_proxy (bool, optional): Only proxy the key URL, leaving segment URLs direct. Defaults to False.
|
||||
no_proxy (bool, optional): If True, returns the manifest without proxying any URLs. Defaults to False.
|
||||
skip_segments (List[dict], optional): List of time segments to skip. Each dict should have
|
||||
'start', 'end' (in seconds), and optionally 'type'.
|
||||
start_offset (float, optional): Time offset in seconds for EXT-X-START tag. Use negative values
|
||||
for live streams to start behind the live edge.
|
||||
"""
|
||||
self.request = request
|
||||
self.key_url = parse.urlparse(key_url) if key_url else None
|
||||
self.key_only_proxy = key_only_proxy
|
||||
self.no_proxy = no_proxy
|
||||
self.force_playlist_proxy = force_playlist_proxy
|
||||
self.skip_filter = SkipSegmentFilter(skip_segments)
|
||||
# Track if user explicitly provided start_offset (vs using default)
|
||||
self._user_provided_start_offset = start_offset is not None
|
||||
# Store the explicit value or default (will be applied conditionally for live streams)
|
||||
self._start_offset_value = start_offset if start_offset is not None else settings.livestream_start_offset
|
||||
self.mediaflow_proxy_url = str(
|
||||
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
|
||||
)
|
||||
# Base URL for segment proxy - extension will be appended based on actual segment
|
||||
# url_for with path param returns URL with placeholder, so we build it manually
|
||||
self.segment_proxy_base_url = str(
|
||||
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
|
||||
).replace("/hls/manifest.m3u8", "/hls/segment")
|
||||
self.playlist_url = None # Will be set when processing starts
|
||||
|
||||
def _should_apply_start_offset(self, content: str) -> bool:
    """
    Decide whether an EXT-X-START offset belongs in this playlist.

    Args:
        content: The playlist content to check.

    Returns:
        True if start_offset should be applied, False otherwise.
    """
    # Nothing to inject when no offset value is available at all.
    if self._start_offset_value is None:
        return False

    # An offset the user passed explicitly is applied to any playlist.
    if self._user_provided_start_offset:
        return True

    # The offset came from the settings default, which is meant for live
    # media playlists only. VOD playlists carry #EXT-X-ENDLIST or
    # #EXT-X-PLAYLIST-TYPE:VOD; master playlists carry #EXT-X-STREAM-INF.
    non_live_markers = ("#EXT-X-ENDLIST", "#EXT-X-PLAYLIST-TYPE:VOD", "#EXT-X-STREAM-INF")
    return not any(marker in content for marker in non_live_markers)
|
||||
|
||||
async def process_m3u8(self, content: str, base_url: str) -> str:
|
||||
"""
|
||||
Processes the m3u8 content, proxying URLs and handling key lines.
|
||||
|
||||
For content filtering with skip_segments, this follows the IntroHater approach:
|
||||
- Segments within skip ranges are completely removed (EXTINF + URL)
|
||||
- A #EXT-X-DISCONTINUITY marker is added BEFORE the URL of the first segment
|
||||
after a skipped section (not before the EXTINF)
|
||||
|
||||
Args:
|
||||
content (str): The m3u8 content to process.
|
||||
base_url (str): The base URL to resolve relative URLs.
|
||||
@@ -45,35 +203,131 @@ class M3U8Processor:
|
||||
"""
|
||||
# Store the playlist URL for prebuffering
|
||||
self.playlist_url = base_url
|
||||
|
||||
|
||||
lines = content.splitlines()
|
||||
processed_lines = []
|
||||
for line in lines:
|
||||
|
||||
# Track if we need to add discontinuity before next URL (after skipping segments)
|
||||
discontinuity_pending = False
|
||||
# Buffer the current EXTINF line - only output when we output the URL
|
||||
pending_extinf: Optional[str] = None
|
||||
# Track if we've injected EXT-X-START tag
|
||||
start_offset_injected = False
|
||||
# Determine if we should apply start_offset (checks if live stream)
|
||||
apply_start_offset = self._should_apply_start_offset(content)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
# Inject EXT-X-START tag right after #EXTM3U (only for live streams or if user explicitly requested)
|
||||
if line.strip() == "#EXTM3U" and apply_start_offset and not start_offset_injected:
|
||||
processed_lines.append(line)
|
||||
processed_lines.append(f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES")
|
||||
start_offset_injected = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle EXTINF lines (segment duration markers)
|
||||
if line.startswith("#EXTINF:"):
|
||||
duration = self._parse_extinf_duration(line)
|
||||
|
||||
if self.skip_filter.has_skip_segments() and self.skip_filter.should_skip_segment(duration):
|
||||
# Skip this segment entirely - don't buffer the EXTINF
|
||||
discontinuity_pending = True # Mark that we need discontinuity before next kept segment
|
||||
self.skip_filter.advance_time(duration)
|
||||
pending_extinf = None
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
# Keep this segment
|
||||
self.skip_filter.advance_time(duration)
|
||||
pending_extinf = line
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle segment URLs (non-comment, non-empty lines)
|
||||
if not line.startswith("#") and line.strip():
|
||||
if pending_extinf is None:
|
||||
# No pending EXTINF means this segment was skipped
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Add discontinuity BEFORE the EXTINF if we just skipped segments
|
||||
# Per HLS spec, EXT-X-DISCONTINUITY must appear before the first segment of the new content
|
||||
if discontinuity_pending:
|
||||
processed_lines.append("#EXT-X-DISCONTINUITY")
|
||||
discontinuity_pending = False
|
||||
|
||||
# Output the buffered EXTINF and proxied URL
|
||||
processed_lines.append(pending_extinf)
|
||||
processed_lines.append(await self.proxy_content_url(line, base_url))
|
||||
pending_extinf = None
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle existing discontinuity markers - pass through but reset pending flag
|
||||
if line.startswith("#EXT-X-DISCONTINUITY"):
|
||||
processed_lines.append(line)
|
||||
discontinuity_pending = False # Don't add duplicate
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle key lines
|
||||
if "URI=" in line:
|
||||
processed_lines.append(await self.process_key_line(line, base_url))
|
||||
elif not line.startswith("#") and line.strip():
|
||||
processed_lines.append(await self.proxy_content_url(line, base_url))
|
||||
else:
|
||||
processed_lines.append(line)
|
||||
|
||||
# Pre-buffer segments if enabled and this is a playlist
|
||||
if (settings.enable_hls_prebuffer and
|
||||
"#EXTM3U" in content and
|
||||
self.playlist_url):
|
||||
|
||||
# Extract headers from request for pre-buffering
|
||||
headers = {}
|
||||
for key, value in self.request.query_params.items():
|
||||
if key.startswith("h_"):
|
||||
headers[key[2:]] = value
|
||||
|
||||
# Start pre-buffering in background using the actual playlist URL
|
||||
asyncio.create_task(
|
||||
hls_prebuffer.prebuffer_playlist(self.playlist_url, headers)
|
||||
)
|
||||
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# All other lines (headers, comments, etc.)
|
||||
processed_lines.append(line)
|
||||
i += 1
|
||||
|
||||
# Log skip statistics
|
||||
if self.skip_filter.has_skip_segments():
|
||||
logger.info(f"Content filtering: processed playlist with {len(self.skip_filter.skip_segments)} skip ranges")
|
||||
|
||||
# Register playlist with the priority-based prefetcher
|
||||
if settings.enable_hls_prebuffer and "#EXTM3U" in content and self.playlist_url:
|
||||
# Skip master playlists
|
||||
if "#EXT-X-STREAM-INF" not in content:
|
||||
segment_urls = self._extract_segment_urls_from_content(content, self.playlist_url)
|
||||
|
||||
if segment_urls:
|
||||
headers = {}
|
||||
for key, value in self.request.query_params.items():
|
||||
if key.startswith("h_"):
|
||||
headers[key[2:]] = value
|
||||
|
||||
logger.info(
|
||||
f"[M3U8Processor] Registering playlist ({len(segment_urls)} segments): {self.playlist_url}"
|
||||
)
|
||||
asyncio.create_task(
|
||||
hls_prebuffer.register_playlist(
|
||||
self.playlist_url,
|
||||
segment_urls,
|
||||
headers,
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
def _parse_extinf_duration(self, line: str) -> float:
    """
    Read the segment duration out of an #EXTINF tag.

    Args:
        line: An #EXTINF line such as "#EXTINF:10.0," or "#EXTINF:10,title".

    Returns:
        The duration in seconds, or 0.0 when the line does not begin with
        "#EXTINF:" followed by an unsigned integer or decimal number.
    """
    # Tag layout: #EXTINF:<duration>[,<title>] - only the numeric part is captured.
    found = re.match(r"#EXTINF:(\d+(?:\.\d+)?)", line)
    return float(found.group(1)) if found else 0.0
|
||||
|
||||
async def process_m3u8_streaming(
|
||||
self, content_iterator: AsyncGenerator[bytes, None], base_url: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
@@ -81,20 +335,37 @@ class M3U8Processor:
|
||||
Processes the m3u8 content on-the-fly, yielding processed lines as they are read.
|
||||
Optimized to avoid accumulating the entire playlist content in memory.
|
||||
|
||||
Note: When skip_segments are configured, this method buffers lines to properly
|
||||
handle EXTINF + segment URL pairs that need to be skipped together.
|
||||
|
||||
Args:
|
||||
content_iterator: An async iterator that yields chunks of the m3u8 content.
|
||||
base_url (str): The base URL to resolve relative URLs.
|
||||
|
||||
Yields:
|
||||
str: Processed lines of the m3u8 content.
|
||||
|
||||
Raises:
|
||||
ValueError: If the content is not a valid m3u8 playlist (e.g., HTML error page).
|
||||
"""
|
||||
# Store the playlist URL for prebuffering
|
||||
self.playlist_url = base_url
|
||||
|
||||
|
||||
buffer = "" # String buffer for decoded content
|
||||
raw_content = "" # Accumulate raw content for prebuffer
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
is_playlist_detected = False
|
||||
is_prebuffer_started = False
|
||||
is_html_detected = False
|
||||
initial_check_done = False
|
||||
|
||||
# State for skip segment filtering
|
||||
discontinuity_pending = False # Track if we need discontinuity before next URL
|
||||
pending_extinf = None # Buffer EXTINF line until we decide to emit it
|
||||
# Track if we've injected EXT-X-START tag
|
||||
start_offset_injected = False
|
||||
# Buffer header lines until we know if it's a master playlist (for default start_offset)
|
||||
header_buffer = []
|
||||
header_flushed = False
|
||||
|
||||
# Process the content chunk by chunk
|
||||
async for chunk in content_iterator:
|
||||
@@ -104,6 +375,24 @@ class M3U8Processor:
|
||||
# Incrementally decode the chunk
|
||||
decoded_chunk = decoder.decode(chunk)
|
||||
buffer += decoded_chunk
|
||||
raw_content += decoded_chunk # Accumulate for prebuffer
|
||||
|
||||
# Early detection: check if this is HTML instead of m3u8
|
||||
# This helps catch upstream error pages quickly
|
||||
if not initial_check_done and len(buffer) > 50:
|
||||
initial_check_done = True
|
||||
buffer_lower = buffer.lower().strip()
|
||||
# Check for HTML markers
|
||||
if buffer_lower.startswith("<!doctype") or buffer_lower.startswith("<html"):
|
||||
is_html_detected = True
|
||||
logger.error(f"Upstream returned HTML instead of m3u8 playlist: {base_url}")
|
||||
# Raise an error so the HTTP handler returns a proper error response
|
||||
# This allows the player to retry or show an error instead of thinking
|
||||
# the stream has ended normally
|
||||
raise ValueError(
|
||||
f"Upstream returned HTML instead of m3u8 playlist. "
|
||||
f"The stream may be offline or unavailable: {base_url}"
|
||||
)
|
||||
|
||||
# Check for playlist marker early to avoid accumulating content
|
||||
if not is_playlist_detected and "#EXTM3U" in buffer:
|
||||
@@ -114,40 +403,246 @@ class M3U8Processor:
|
||||
if len(lines) > 1:
|
||||
# Process all complete lines except the last one
|
||||
for line in lines[:-1]:
|
||||
if line: # Skip empty lines
|
||||
if not line: # Skip empty lines
|
||||
continue
|
||||
|
||||
# Buffer header lines until we can determine playlist type
|
||||
# This allows us to decide whether to inject EXT-X-START
|
||||
if not header_flushed:
|
||||
# Always buffer the current line first
|
||||
header_buffer.append(line)
|
||||
|
||||
# Check if we can now determine playlist type
|
||||
# Only check the current line, not raw_content (which may contain future content)
|
||||
is_master = "#EXT-X-STREAM-INF" in line
|
||||
is_media = "#EXTINF" in line
|
||||
|
||||
if is_master or is_media:
|
||||
# For non-user-provided (default) start_offset, determine if this
|
||||
# is a live stream before injecting. We need to avoid injecting
|
||||
# EXT-X-START with negative offsets into VOD playlists, as players
|
||||
# like VLC interpret negative offsets as "from the end" and start
|
||||
# playing near the end of the video.
|
||||
#
|
||||
# Live stream indicators (checked in header):
|
||||
# - No #EXT-X-PLAYLIST-TYPE:VOD tag
|
||||
# - No #EXT-X-ENDLIST tag (may not be visible yet in streaming)
|
||||
# - #EXT-X-MEDIA-SEQUENCE > 0 (live windows have rolling sequence)
|
||||
#
|
||||
# VOD indicators:
|
||||
# - #EXT-X-PLAYLIST-TYPE:VOD in header
|
||||
# - #EXT-X-ENDLIST in raw_content (if small enough to be buffered)
|
||||
# - #EXT-X-MEDIA-SEQUENCE:0 or absent (VOD starts from beginning)
|
||||
header_content = "\n".join(header_buffer)
|
||||
all_content = header_content + "\n" + raw_content
|
||||
|
||||
is_explicitly_vod = (
|
||||
"#EXT-X-PLAYLIST-TYPE:VOD" in all_content or "#EXT-X-ENDLIST" in all_content
|
||||
)
|
||||
|
||||
# Check for live stream indicator: #EXT-X-MEDIA-SEQUENCE with value > 0
|
||||
# Live streams have a rolling window so their media sequence increments
|
||||
is_likely_live = False
|
||||
seq_match = re.search(r"#EXT-X-MEDIA-SEQUENCE:\s*(\d+)", all_content)
|
||||
if seq_match and int(seq_match.group(1)) > 0:
|
||||
is_likely_live = True
|
||||
|
||||
# Flush header buffer with or without EXT-X-START
|
||||
should_inject = (
|
||||
self._start_offset_value is not None
|
||||
and not is_master
|
||||
and (
|
||||
self._user_provided_start_offset
|
||||
or (is_media and not is_explicitly_vod and is_likely_live)
|
||||
) # User provided OR it's a live media playlist
|
||||
)
|
||||
|
||||
for header_line in header_buffer:
|
||||
# Process header lines to rewrite URLs (e.g., #EXT-X-KEY)
|
||||
processed_header_line = await self.process_line(header_line, base_url)
|
||||
yield processed_header_line + "\n"
|
||||
if header_line.strip() == "#EXTM3U" and should_inject and not start_offset_injected:
|
||||
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
|
||||
start_offset_injected = True
|
||||
|
||||
header_buffer = []
|
||||
header_flushed = True
|
||||
# If not master/media yet, continue buffering (line already added above)
|
||||
continue
|
||||
|
||||
# If user explicitly provided start_offset and we haven't injected yet
|
||||
# (handles edge case where we flush header before seeing EXTINF/STREAM-INF)
|
||||
if (
|
||||
line.strip() == "#EXTM3U"
|
||||
and self._user_provided_start_offset
|
||||
and self._start_offset_value is not None
|
||||
and not start_offset_injected
|
||||
):
|
||||
yield line + "\n"
|
||||
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
|
||||
start_offset_injected = True
|
||||
continue
|
||||
|
||||
# Handle segment filtering if skip_segments are configured
|
||||
if self.skip_filter.has_skip_segments():
|
||||
result = await self._process_line_with_filtering(
|
||||
line, base_url, discontinuity_pending, pending_extinf
|
||||
)
|
||||
processed_line, discontinuity_pending, pending_extinf = result
|
||||
if processed_line is not None:
|
||||
yield processed_line + "\n"
|
||||
else:
|
||||
# No filtering, process normally
|
||||
processed_line = await self.process_line(line, base_url)
|
||||
yield processed_line + "\n"
|
||||
|
||||
# Keep the last line in the buffer (it might be incomplete)
|
||||
buffer = lines[-1]
|
||||
|
||||
# Start pre-buffering early once we detect this is a playlist
|
||||
# This avoids waiting until the entire playlist is processed
|
||||
if (settings.enable_hls_prebuffer and
|
||||
is_playlist_detected and
|
||||
not is_prebuffer_started and
|
||||
self.playlist_url):
|
||||
|
||||
# Extract headers from request for pre-buffering
|
||||
headers = {}
|
||||
for key, value in self.request.query_params.items():
|
||||
if key.startswith("h_"):
|
||||
headers[key[2:]] = value
|
||||
|
||||
# Start pre-buffering in background using the actual playlist URL
|
||||
asyncio.create_task(
|
||||
hls_prebuffer.prebuffer_playlist(self.playlist_url, headers)
|
||||
# If HTML was detected, we already returned an error playlist
|
||||
if is_html_detected:
|
||||
return
|
||||
|
||||
# Flush any remaining header buffer (for short playlists or edge cases)
|
||||
# At this point we have the full raw_content so we can make a definitive determination
|
||||
if header_buffer and not header_flushed:
|
||||
is_master = "#EXT-X-STREAM-INF" in raw_content
|
||||
is_vod = "#EXT-X-ENDLIST" in raw_content or "#EXT-X-PLAYLIST-TYPE:VOD" in raw_content
|
||||
# For default offset, also require positive live indicator
|
||||
is_likely_live = False
|
||||
seq_match = re.search(r"#EXT-X-MEDIA-SEQUENCE:\s*(\d+)", raw_content)
|
||||
if seq_match and int(seq_match.group(1)) > 0:
|
||||
is_likely_live = True
|
||||
should_inject = (
|
||||
self._start_offset_value is not None
|
||||
and not is_master
|
||||
and (
|
||||
self._user_provided_start_offset
|
||||
or (not is_vod and is_likely_live) # Default offset: only inject for live streams
|
||||
)
|
||||
is_prebuffer_started = True
|
||||
)
|
||||
for header_line in header_buffer:
|
||||
yield header_line + "\n"
|
||||
if header_line.strip() == "#EXTM3U" and should_inject and not start_offset_injected:
|
||||
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
|
||||
start_offset_injected = True
|
||||
header_buffer = []
|
||||
|
||||
# Process any remaining data in the buffer plus final bytes
|
||||
final_chunk = decoder.decode(b"", final=True)
|
||||
if final_chunk:
|
||||
buffer += final_chunk
|
||||
|
||||
# Final validation: if we never detected a valid m3u8 playlist marker
|
||||
if not is_playlist_detected:
|
||||
logger.error(f"Invalid m3u8 content from upstream (no #EXTM3U marker found): {base_url}")
|
||||
yield "#EXTM3U\n"
|
||||
yield "#EXT-X-PLAYLIST-TYPE:VOD\n"
|
||||
yield "# ERROR: Invalid m3u8 content from upstream (no #EXTM3U marker found)\n"
|
||||
yield "# The upstream server may have returned an error page\n"
|
||||
yield "#EXT-X-ENDLIST\n"
|
||||
return
|
||||
|
||||
if buffer: # Process the last line if it's not empty
|
||||
processed_line = await self.process_line(buffer, base_url)
|
||||
yield processed_line
|
||||
if self.skip_filter.has_skip_segments():
|
||||
result = await self._process_line_with_filtering(
|
||||
buffer, base_url, discontinuity_pending, pending_extinf
|
||||
)
|
||||
processed_line, _, _ = result
|
||||
if processed_line is not None:
|
||||
yield processed_line
|
||||
else:
|
||||
processed_line = await self.process_line(buffer, base_url)
|
||||
yield processed_line
|
||||
|
||||
# Log skip statistics
|
||||
if self.skip_filter.has_skip_segments():
|
||||
logger.info(f"Content filtering: processed playlist with {len(self.skip_filter.skip_segments)} skip ranges")
|
||||
|
||||
# Register playlist with the priority-based prefetcher
|
||||
# The prefetcher uses a smart approach:
|
||||
# 1. When player requests a segment, it gets priority (downloaded first)
|
||||
# 2. After serving priority segment, prefetcher continues sequentially
|
||||
# 3. Multiple users watching same channel share the prefetcher
|
||||
# 4. Inactive prefetchers are cleaned up automatically
|
||||
if settings.enable_hls_prebuffer and is_playlist_detected and self.playlist_url and raw_content:
|
||||
# Skip master playlists (they contain variant streams, not segments)
|
||||
if "#EXT-X-STREAM-INF" not in raw_content:
|
||||
# Extract segment URLs from the playlist
|
||||
segment_urls = self._extract_segment_urls_from_content(raw_content, self.playlist_url)
|
||||
|
||||
if segment_urls:
|
||||
# Extract headers for prefetcher
|
||||
headers = {}
|
||||
for key, value in self.request.query_params.items():
|
||||
if key.startswith("h_"):
|
||||
headers[key[2:]] = value
|
||||
|
||||
logger.info(
|
||||
f"[M3U8Processor] Registering playlist ({len(segment_urls)} segments): {self.playlist_url}"
|
||||
)
|
||||
asyncio.create_task(
|
||||
hls_prebuffer.register_playlist(
|
||||
self.playlist_url,
|
||||
segment_urls,
|
||||
headers,
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_line_with_filtering(
    self, line: str, base_url: str, discontinuity_pending: bool, pending_extinf: Optional[str]
) -> tuple:
    """
    Run a single playlist line through the skip-segment filter.

    An #EXTINF line is never emitted by itself: it is either dropped (the skip
    filter rejects the segment) or held back until the segment URL arrives.
    When the URL of a kept segment arrives, the output is the held #EXTINF and
    the proxied URL, preceded by #EXT-X-DISCONTINUITY if a segment was dropped
    since the last emitted one.

    Args:
        line: The playlist line to process.
        base_url: Base URL handed to the URL/key rewriting helpers.
        discontinuity_pending: Whether a discontinuity marker is still owed.
        pending_extinf: The held-back #EXTINF line, or None if there is none.

    Returns:
        A tuple (output, discontinuity_pending, pending_extinf). output is None
        when nothing should be emitted for this line, otherwise the text to
        emit (possibly several lines joined with newlines). The other two
        items are the updated state to pass in with the next line.
    """
    # --- Segment duration marker -------------------------------------------
    if line.startswith("#EXTINF:"):
        seconds = self._parse_extinf_duration(line)
        drop_segment = self.skip_filter.should_skip_segment(seconds)
        # advance_time is called for kept and skipped segments alike.
        self.skip_filter.advance_time(seconds)
        if drop_segment:
            # Forget the EXTINF and remember that a discontinuity is owed.
            return (None, True, None)
        # Hold the EXTINF until its URL shows up.
        return (None, discontinuity_pending, line)

    # --- Segment URL (non-blank line that is not a tag/comment) ------------
    if not line.startswith("#") and line.strip():
        if pending_extinf is None:
            # No held EXTINF: this URL belongs to a dropped segment.
            return (None, discontinuity_pending, None)

        proxied = await self.proxy_content_url(line, base_url)
        # The marker goes ahead of the EXTINF so it precedes the first segment
        # of the new content.
        pieces = ["#EXT-X-DISCONTINUITY"] if discontinuity_pending else []
        pieces += [pending_extinf, proxied]
        return ("\n".join(pieces), False, None)

    # --- Discontinuity already present upstream ----------------------------
    if line.startswith("#EXT-X-DISCONTINUITY"):
        # Pass it through and clear the flag so no duplicate is added.
        return (line, False, pending_extinf)

    # --- Tags carrying a URI attribute (key lines) -------------------------
    if "URI=" in line:
        rewritten = await self.process_key_line(line, base_url)
        return (rewritten, discontinuity_pending, pending_extinf)

    # --- Everything else (headers, comments, blank lines) ------------------
    return (line, discontinuity_pending, pending_extinf)
|
||||
|
||||
async def process_line(self, line: str, base_url: str) -> str:
|
||||
"""
|
||||
@@ -186,14 +681,23 @@ class M3U8Processor:
|
||||
full_url = parse.urljoin(base_url, original_uri)
|
||||
line = line.replace(f'URI="{original_uri}"', f'URI="{full_url}"')
|
||||
return line
|
||||
|
||||
|
||||
uri_match = re.search(r'URI="([^"]+)"', line)
|
||||
if uri_match:
|
||||
original_uri = uri_match.group(1)
|
||||
uri = parse.urlparse(original_uri)
|
||||
if self.key_url:
|
||||
# Only substitute key_url scheme/netloc for actual EXT-X-KEY lines.
|
||||
# EXT-X-MAP (init segments) and other tags must keep their original host,
|
||||
# otherwise the proxied destination URL gets the wrong upstream hostname.
|
||||
if self.key_url and line.startswith("#EXT-X-KEY"):
|
||||
uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
|
||||
new_uri = await self.proxy_url(uri.geturl(), base_url)
|
||||
# Check if this is a DLHD stream with key params (needs stream endpoint for header computation)
|
||||
query_params = dict(self.request.query_params)
|
||||
is_dlhd_key_request = "dlhd_salt" in query_params and "/key/" in uri.geturl()
|
||||
# Use stream endpoint for DLHD key URLs, manifest endpoint for others
|
||||
new_uri = await self.proxy_url(
|
||||
uri.geturl(), base_url, use_full_url=True, is_playlist=not is_dlhd_key_request
|
||||
)
|
||||
line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
|
||||
return line
|
||||
|
||||
@@ -223,16 +727,19 @@ class M3U8Processor:
|
||||
|
||||
# Check if we should force MediaFlow proxy for all playlist URLs
|
||||
if self.force_playlist_proxy:
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True)
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
|
||||
|
||||
# For playlist URLs, always use MediaFlow proxy regardless of strategy
|
||||
# Check for actual playlist file extensions, not just substring matches
|
||||
parsed_url = parse.urlparse(full_url)
|
||||
if (parsed_url.path.endswith((".m3u", ".m3u8", ".m3u_plus")) or
|
||||
parse.parse_qs(parsed_url.query).get("type", [""])[0] in ["m3u", "m3u8", "m3u_plus"]):
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True)
|
||||
is_playlist_url = parsed_url.path.endswith((".m3u", ".m3u8", ".m3u_plus")) or parse.parse_qs(
|
||||
parsed_url.query
|
||||
).get("type", [""])[0] in ["m3u", "m3u8", "m3u_plus"]
|
||||
|
||||
# Route non-playlist content URLs based on strategy
|
||||
if is_playlist_url:
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
|
||||
|
||||
# Route non-playlist content URLs (segments) based on strategy
|
||||
if routing_strategy == "direct":
|
||||
# Return the URL directly without any proxying
|
||||
return full_url
|
||||
@@ -250,9 +757,33 @@ class M3U8Processor:
|
||||
)
|
||||
else:
|
||||
# Default to MediaFlow proxy (routing_strategy == "mediaflow" or fallback)
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True)
|
||||
# Use stream endpoint for segment URLs
|
||||
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=False)
|
||||
|
||||
async def proxy_url(self, url: str, base_url: str, use_full_url: bool = False) -> str:
|
||||
def _extract_segment_urls_from_content(self, content: str, base_url: str) -> list:
    """
    Collect the segment URLs listed in an HLS playlist.

    Every non-blank line that does not start with "#" (after trimming
    surrounding whitespace) is treated as a segment reference.

    Args:
        content: Raw playlist text.
        base_url: URL that relative references are resolved against.

    Returns:
        Absolute segment URLs in playlist order. Lines already starting with
        http:// or https:// are kept as they are; all others are joined
        onto base_url.
    """
    trimmed = (raw.strip() for raw in content.split("\n"))
    return [
        entry if entry.startswith(("http://", "https://")) else parse.urljoin(base_url, entry)
        for entry in trimmed
        if entry and not entry.startswith("#")
    ]
|
||||
|
||||
async def proxy_url(self, url: str, base_url: str, use_full_url: bool = False, is_playlist: bool = True) -> str:
|
||||
"""
|
||||
Proxies a URL, encoding it with the MediaFlow proxy URL.
|
||||
|
||||
@@ -260,6 +791,7 @@ class M3U8Processor:
|
||||
url (str): The URL to proxy.
|
||||
base_url (str): The base URL to resolve relative URLs.
|
||||
use_full_url (bool): Whether to use the URL as-is (True) or join with base_url (False).
|
||||
is_playlist (bool): Whether this is a playlist URL (uses manifest endpoint) or segment URL (uses stream endpoint).
|
||||
|
||||
Returns:
|
||||
str: The proxied URL.
|
||||
@@ -271,15 +803,52 @@ class M3U8Processor:
|
||||
|
||||
query_params = dict(self.request.query_params)
|
||||
has_encrypted = query_params.pop("has_encrypted", False)
|
||||
# Remove the response headers from the query params to avoid it being added to the consecutive requests
|
||||
[query_params.pop(key, None) for key in list(query_params.keys()) if key.startswith("r_")]
|
||||
# Remove force_playlist_proxy to avoid it being added to subsequent requests
|
||||
# Remove the response headers (r_) from the query params to avoid it being added to the consecutive requests
|
||||
# BUT keep rp_ (response propagate) headers as they should propagate to segments
|
||||
[
|
||||
query_params.pop(key, None)
|
||||
for key in list(query_params.keys())
|
||||
if key.lower().startswith("r_") and not key.lower().startswith("rp_")
|
||||
]
|
||||
# Remove manifest-only parameters to avoid them being added to subsequent requests
|
||||
query_params.pop("force_playlist_proxy", None)
|
||||
if not is_playlist:
|
||||
query_params.pop("start_offset", None)
|
||||
|
||||
# Use appropriate proxy URL based on content type
|
||||
if is_playlist:
|
||||
proxy_url = self.mediaflow_proxy_url
|
||||
else:
|
||||
# Check if this is a DLHD key URL (needs /stream endpoint for header computation)
|
||||
is_dlhd_key = "dlhd_salt" in query_params and "/key/" in full_url
|
||||
if is_dlhd_key:
|
||||
# Use /stream endpoint for DLHD key URLs
|
||||
proxy_url = self.mediaflow_proxy_url.replace("/hls/manifest.m3u8", "/stream")
|
||||
else:
|
||||
# Determine segment extension from the URL
|
||||
# Default to .ts for traditional HLS, but detect fMP4 extensions
|
||||
segment_ext = "ts"
|
||||
url_lower = full_url.lower()
|
||||
# Check for fMP4/CMAF extensions
|
||||
if url_lower.endswith(".m4s"):
|
||||
segment_ext = "m4s"
|
||||
elif url_lower.endswith(".mp4"):
|
||||
segment_ext = "mp4"
|
||||
elif url_lower.endswith(".m4a"):
|
||||
segment_ext = "m4a"
|
||||
elif url_lower.endswith(".m4v"):
|
||||
segment_ext = "m4v"
|
||||
elif url_lower.endswith(".aac"):
|
||||
segment_ext = "aac"
|
||||
# Build segment proxy URL with correct extension
|
||||
proxy_url = f"{self.segment_proxy_base_url}.{segment_ext}"
|
||||
# Remove h_range header - each segment should handle its own range requests
|
||||
query_params.pop("h_range", None)
|
||||
|
||||
return encode_mediaflow_proxy_url(
|
||||
self.mediaflow_proxy_url,
|
||||
"",
|
||||
proxy_url,
|
||||
None, # No endpoint - URL is already complete
|
||||
full_url,
|
||||
query_params=query_params,
|
||||
encryption_handler=encryption_handler if has_encrypted else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -10,6 +10,35 @@ import xmltodict
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_url(base_url: str, relative_url: str) -> str:
    """
    Turn a possibly-relative URL into an absolute one.

    Outcomes:
    - Empty relative_url: base_url is returned unchanged.
    - relative_url starting with http:// or https://: returned unchanged.
    - Anything else is joined onto base_url with urljoin, so a leading "/"
      resolves against the origin and other paths resolve against the
      directory of base_url.

    Args:
        base_url: The base URL (typically the MPD URL).
        relative_url: The URL to resolve.

    Returns:
        The resolved absolute URL.
    """
    if not relative_url:
        return base_url

    already_absolute = relative_url.startswith("http://") or relative_url.startswith("https://")
    return relative_url if already_absolute else urljoin(base_url, relative_url)
|
||||
|
||||
|
||||
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
|
||||
"""
|
||||
Parses the MPD content into a dictionary.
|
||||
@@ -43,7 +72,6 @@ def parse_mpd_dict(
|
||||
"""
|
||||
profiles = []
|
||||
parsed_dict = {}
|
||||
source = "/".join(mpd_url.split("/")[:-1])
|
||||
|
||||
is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
|
||||
parsed_dict["isLive"] = is_live
|
||||
@@ -66,7 +94,10 @@ def parse_mpd_dict(
|
||||
|
||||
for period in periods:
|
||||
parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
|
||||
for adaptation in period["AdaptationSet"]:
|
||||
adaptation_sets = period["AdaptationSet"]
|
||||
adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
|
||||
|
||||
for adaptation in adaptation_sets:
|
||||
representations = adaptation["Representation"]
|
||||
representations = representations if isinstance(representations, list) else [representations]
|
||||
|
||||
@@ -75,7 +106,7 @@ def parse_mpd_dict(
|
||||
parsed_dict,
|
||||
representation,
|
||||
adaptation,
|
||||
source,
|
||||
mpd_url,
|
||||
media_presentation_duration,
|
||||
parse_segment_profile_id,
|
||||
)
|
||||
@@ -195,7 +226,7 @@ def parse_representation(
|
||||
parsed_dict: dict,
|
||||
representation: dict,
|
||||
adaptation: dict,
|
||||
source: str,
|
||||
mpd_url: str,
|
||||
media_presentation_duration: str,
|
||||
parse_segment_profile_id: Optional[str],
|
||||
) -> Optional[dict]:
|
||||
@@ -206,7 +237,7 @@ def parse_representation(
|
||||
parsed_dict (dict): The parsed MPD data.
|
||||
representation (dict): The representation data.
|
||||
adaptation (dict): The adaptation set data.
|
||||
source (str): The source URL.
|
||||
mpd_url (str): The URL of the MPD manifest.
|
||||
media_presentation_duration (str): The media presentation duration.
|
||||
parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
|
||||
|
||||
@@ -263,14 +294,50 @@ def parse_representation(
|
||||
else:
|
||||
profile["segment_template_start_number"] = 1
|
||||
|
||||
# For SegmentBase profiles, we need to set initUrl even when not parsing segments
|
||||
# This is needed for the HLS playlist builder to reference the init URL
|
||||
segment_base_data = representation.get("SegmentBase")
|
||||
if segment_base_data and "initUrl" not in profile:
|
||||
base_url = representation.get("BaseURL", "")
|
||||
profile["initUrl"] = resolve_url(mpd_url, base_url)
|
||||
|
||||
# Store initialization range if available
|
||||
if "Initialization" in segment_base_data:
|
||||
init_range = segment_base_data["Initialization"].get("@range")
|
||||
if init_range:
|
||||
profile["initRange"] = init_range
|
||||
|
||||
# For SegmentList profiles, we also need to set initUrl even when not parsing segments
|
||||
segment_list_data = representation.get("SegmentList") or adaptation.get("SegmentList")
|
||||
if segment_list_data and "initUrl" not in profile:
|
||||
if "Initialization" in segment_list_data:
|
||||
init_data = segment_list_data["Initialization"]
|
||||
if "@sourceURL" in init_data:
|
||||
init_url = init_data["@sourceURL"]
|
||||
profile["initUrl"] = resolve_url(mpd_url, init_url)
|
||||
elif "@range" in init_data:
|
||||
base_url = representation.get("BaseURL", "")
|
||||
profile["initUrl"] = resolve_url(mpd_url, base_url)
|
||||
profile["initRange"] = init_data["@range"]
|
||||
|
||||
if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
|
||||
return profile
|
||||
|
||||
item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
|
||||
if item:
|
||||
profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
|
||||
# Parse segments based on the addressing scheme used
|
||||
segment_template = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
|
||||
segment_list = adaptation.get("SegmentList") or representation.get("SegmentList")
|
||||
|
||||
# Get BaseURL from representation (can be relative path like "a/b/c/")
|
||||
base_url = representation.get("BaseURL", "")
|
||||
|
||||
if segment_template:
|
||||
profile["segments"] = parse_segment_template(parsed_dict, segment_template, profile, mpd_url, base_url)
|
||||
elif segment_list:
|
||||
# Get timescale from SegmentList or default to 1
|
||||
timescale = int(segment_list.get("@timescale", 1))
|
||||
profile["segments"] = parse_segment_list(adaptation, representation, profile, mpd_url, timescale)
|
||||
else:
|
||||
profile["segments"] = parse_segment_base(representation, profile, source)
|
||||
profile["segments"] = parse_segment_base(representation, profile, mpd_url)
|
||||
|
||||
return profile
|
||||
|
||||
@@ -290,7 +357,9 @@ def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
|
||||
return representation.get(key, adaptation.get(key, None))
|
||||
|
||||
|
||||
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
|
||||
def parse_segment_template(
|
||||
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, base_url: str = ""
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Parses a segment template and extracts segment information.
|
||||
|
||||
@@ -298,7 +367,8 @@ def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
parsed_dict (dict): The parsed MPD data.
|
||||
item (dict): The segment template data.
|
||||
profile (dict): The profile information.
|
||||
source (str): The source URL.
|
||||
mpd_url (str): The URL of the MPD manifest.
|
||||
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
|
||||
|
||||
Returns:
|
||||
List[Dict]: The list of parsed segments.
|
||||
@@ -311,20 +381,23 @@ def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
media = item["@initialization"]
|
||||
media = media.replace("$RepresentationID$", profile["id"])
|
||||
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
|
||||
if not media.startswith("http"):
|
||||
media = f"{source}/{media}"
|
||||
profile["initUrl"] = media
|
||||
# Combine base_url and media, then resolve against mpd_url
|
||||
if base_url:
|
||||
media = base_url + media
|
||||
profile["initUrl"] = resolve_url(mpd_url, media)
|
||||
|
||||
# Segments
|
||||
if "SegmentTimeline" in item:
|
||||
segments.extend(parse_segment_timeline(parsed_dict, item, profile, source, timescale))
|
||||
segments.extend(parse_segment_timeline(parsed_dict, item, profile, mpd_url, timescale, base_url))
|
||||
elif "@duration" in item:
|
||||
segments.extend(parse_segment_duration(parsed_dict, item, profile, source, timescale))
|
||||
segments.extend(parse_segment_duration(parsed_dict, item, profile, mpd_url, timescale, base_url))
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
|
||||
def parse_segment_timeline(
|
||||
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, timescale: int, base_url: str = ""
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Parses a segment timeline and extracts segment information.
|
||||
|
||||
@@ -332,8 +405,9 @@ def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
parsed_dict (dict): The parsed MPD data.
|
||||
item (dict): The segment timeline data.
|
||||
profile (dict): The profile information.
|
||||
source (str): The source URL.
|
||||
mpd_url (str): The URL of the MPD manifest.
|
||||
timescale (int): The timescale for the segments.
|
||||
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
|
||||
|
||||
Returns:
|
||||
List[Dict]: The list of parsed segments.
|
||||
@@ -347,7 +421,7 @@ def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
start_number = int(item.get("@startNumber", 1))
|
||||
|
||||
segments = [
|
||||
create_segment_data(timeline, item, profile, source, timescale)
|
||||
create_segment_data(timeline, item, profile, mpd_url, timescale, base_url)
|
||||
for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
|
||||
]
|
||||
return segments
|
||||
@@ -379,14 +453,14 @@ def preprocess_timeline(
|
||||
for _ in range(repeat + 1):
|
||||
segment_start_time = period_start + timedelta(seconds=(start_time - presentation_time_offset) / timescale)
|
||||
segment_end_time = segment_start_time + timedelta(seconds=duration / timescale)
|
||||
presentation_time = start_time - presentation_time_offset
|
||||
processed_data.append(
|
||||
{
|
||||
"number": start_number,
|
||||
"start_time": segment_start_time,
|
||||
"end_time": segment_end_time,
|
||||
"duration": duration,
|
||||
"time": presentation_time,
|
||||
"time": start_time,
|
||||
"duration_mpd_timescale": duration,
|
||||
}
|
||||
)
|
||||
start_time += duration
|
||||
@@ -397,7 +471,9 @@ def preprocess_timeline(
|
||||
return processed_data
|
||||
|
||||
|
||||
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
|
||||
def parse_segment_duration(
|
||||
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, timescale: int, base_url: str = ""
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Parses segment duration and extracts segment information.
|
||||
This is used for static or live MPD manifests.
|
||||
@@ -406,8 +482,9 @@ def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
parsed_dict (dict): The parsed MPD data.
|
||||
item (dict): The segment duration data.
|
||||
profile (dict): The profile information.
|
||||
source (str): The source URL.
|
||||
mpd_url (str): The URL of the MPD manifest.
|
||||
timescale (int): The timescale for the segments.
|
||||
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
|
||||
|
||||
Returns:
|
||||
List[Dict]: The list of parsed segments.
|
||||
@@ -421,7 +498,7 @@ def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source:
|
||||
else:
|
||||
segments = generate_vod_segments(profile, duration, timescale, start_number)
|
||||
|
||||
return [create_segment_data(seg, item, profile, source, timescale) for seg in segments]
|
||||
return [create_segment_data(seg, item, profile, mpd_url, timescale, base_url) for seg in segments]
|
||||
|
||||
|
||||
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
|
||||
@@ -480,7 +557,9 @@ def generate_vod_segments(profile: dict, duration: int, timescale: int, start_nu
|
||||
return [{"number": start_number + i, "duration": duration / timescale} for i in range(segment_count)]
|
||||
|
||||
|
||||
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
|
||||
def create_segment_data(
|
||||
segment: Dict, item: dict, profile: dict, mpd_url: str, timescale: Optional[int] = None, base_url: str = ""
|
||||
) -> Dict:
|
||||
"""
|
||||
Creates segment data based on the segment information. This includes the segment URL and metadata.
|
||||
|
||||
@@ -488,8 +567,9 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
|
||||
segment (Dict): The segment information.
|
||||
item (dict): The segment template data.
|
||||
profile (dict): The profile information.
|
||||
source (str): The source URL.
|
||||
mpd_url (str): The URL of the MPD manifest.
|
||||
timescale (int, optional): The timescale for the segments. Defaults to None.
|
||||
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
|
||||
|
||||
Returns:
|
||||
Dict: The created segment data.
|
||||
@@ -503,8 +583,10 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
|
||||
if "time" in segment and timescale is not None:
|
||||
media = media.replace("$Time$", str(int(segment["time"])))
|
||||
|
||||
if not media.startswith("http"):
|
||||
media = f"{source}/{media}"
|
||||
# Combine base_url and media, then resolve against mpd_url
|
||||
if base_url:
|
||||
media = base_url + media
|
||||
media = resolve_url(mpd_url, media)
|
||||
|
||||
segment_data = {
|
||||
"type": "segment",
|
||||
@@ -528,7 +610,8 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
|
||||
}
|
||||
)
|
||||
elif "start_time" in segment and "duration" in segment:
|
||||
duration_seconds = segment["duration"] / timescale
|
||||
# duration here is in timescale units (from timeline segments)
|
||||
duration_seconds = segment["duration"] / timescale if timescale else segment["duration"]
|
||||
segment_data.update(
|
||||
{
|
||||
"start_time": segment["start_time"],
|
||||
@@ -537,43 +620,144 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
|
||||
"program_date_time": segment["start_time"].isoformat() + "Z",
|
||||
}
|
||||
)
|
||||
elif "duration" in segment and timescale is not None:
|
||||
# Convert duration from timescale units to seconds
|
||||
segment_data["extinf"] = segment["duration"] / timescale
|
||||
elif "duration" in segment:
|
||||
# If no timescale is provided, assume duration is already in seconds
|
||||
# duration from generate_vod_segments and generate_live_segments is already in seconds
|
||||
segment_data["extinf"] = segment["duration"]
|
||||
|
||||
return segment_data
|
||||
|
||||
|
||||
def parse_segment_base(representation: dict, profile: dict, source: str) -> List[Dict]:
|
||||
def parse_segment_list(
    adaptation: dict, representation: dict, profile: dict, mpd_url: str, timescale: int
) -> List[Dict]:
    """
    Parses a SegmentList element with explicit SegmentURL entries.

    SegmentList MPDs explicitly list each segment URL, unlike SegmentTemplate which uses
    URL patterns. This is less common but used by some packagers.

    Side effects: when the SegmentList carries an Initialization element, sets
    profile["initUrl"] (and profile["initRange"] for byte-range initialization).

    Args:
        adaptation (dict): The adaptation set data.
        representation (dict): The representation data.
        profile (dict): The profile information (updated in place).
        mpd_url (str): The URL of the MPD manifest.
        timescale (int): Fallback timescale, used when SegmentList has no @timescale.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    # SegmentList can be at AdaptationSet or Representation level. A missing *or*
    # null element is normalised to {} so the lookups below cannot fail on None.
    segment_list = representation.get("SegmentList") or adaptation.get("SegmentList") or {}
    # BaseURL is the target for byte-range addressing (init range / mediaRange).
    base_url = representation.get("BaseURL", "")
    segments = []

    # Handle Initialization element (may be present but empty/null)
    init_data = segment_list.get("Initialization") or {}
    if "@sourceURL" in init_data:
        profile["initUrl"] = resolve_url(mpd_url, init_data["@sourceURL"])
    elif "@range" in init_data:
        # Initialization by byte range on the BaseURL
        profile["initUrl"] = resolve_url(mpd_url, base_url)
        profile["initRange"] = init_data["@range"]

    # Every segment shares the duration declared on the SegmentList itself
    duration = int(segment_list.get("@duration", 0))
    list_timescale = int(segment_list.get("@timescale", timescale or 1))
    segment_duration_sec = duration / list_timescale if list_timescale else 0
    # Fall back to 1 second when no usable duration is declared
    extinf = segment_duration_sec if segment_duration_sec > 0 else 1.0

    # A single SegmentURL child is not wrapped in a list; normalise to a list
    segment_urls = segment_list.get("SegmentURL", [])
    if not isinstance(segment_urls, list):
        segment_urls = [segment_urls]

    for i, seg_url in enumerate(segment_urls):
        if seg_url is None:
            continue

        # Media comes from @media, or from BaseURL combined with @mediaRange.
        # NOTE(review): @media is resolved against mpd_url only, not against the
        # representation BaseURL - confirm this matches the packagers in use.
        media_url = seg_url.get("@media", "")
        media_range = seg_url.get("@mediaRange")

        segment_data = {
            "type": "segment",
            "media": resolve_url(mpd_url, media_url or base_url),
            "number": i + 1,
            "extinf": extinf,
        }

        # Include media range if specified
        if media_range:
            segment_data["mediaRange"] = media_range

        segments.append(segment_data)

    return segments
|
||||
|
||||
|
||||
def parse_segment_base(representation: dict, profile: dict, mpd_url: str) -> List[Dict]:
    """
    Parses segment base information and extracts segment data. This is used for single-segment
    representations (SegmentBase MPDs, typically GPAC-generated on-demand profiles).

    For SegmentBase, the entire media file is treated as a single segment. The initialization data
    is specified by the Initialization element's range, and the segment index (SIDX) is at indexRange.

    Side effects: sets profile["initUrl"], and profile["initRange"] when an
    initialization range is present.

    Args:
        representation (dict): The representation data.
        profile (dict): The profile information (updated in place).
        mpd_url (str): The URL of the MPD manifest.

    Returns:
        List[Dict]: A one-element list describing the whole media file.
    """
    # A missing *or* null SegmentBase element is normalised to {} so the
    # lookups below cannot fail on None.
    segment = representation.get("SegmentBase") or {}
    base_url = representation.get("BaseURL", "")

    # Build the full media URL
    media_url = resolve_url(mpd_url, base_url)

    # The init data lives in the same file as the media, addressed by byte range
    profile["initUrl"] = media_url

    # Initialization may be absent or an empty element; both yield no range
    init_range = (segment.get("Initialization") or {}).get("@range")

    # Store initialization range in profile for segment endpoint to use
    if init_range:
        profile["initRange"] = init_range

    # Get the index range which points to SIDX box
    index_range = segment.get("@indexRange", "")

    # Total duration comes from the profile's mediaPresentationDuration
    total_duration = profile.get("mediaPresentationDuration")
    if isinstance(total_duration, str):
        total_duration = parse_duration(total_duration)
    elif total_duration is None:
        total_duration = 0

    # One segment represents the entire media; it shares its URL with initUrl
    # and is told apart only by byte ranges.
    return [
        {
            "type": "segment",
            "media": media_url,
            "number": 1,
            "extinf": total_duration if total_duration > 0 else 1.0,
            "indexRange": index_range,
            "initRange": init_range,
        }
    ]
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#Adapted for use in MediaFlowProxy from:
|
||||
#https://github.com/einars/js-beautify/blob/master/python/jsbeautifier/unpackers/packer.py
|
||||
# Adapted for use in MediaFlowProxy from:
|
||||
# https://github.com/einars/js-beautify/blob/master/python/jsbeautifier/unpackers/packer.py
|
||||
# Unpacker for Dean Edward's p.a.c.k.e.r, a part of javascript beautifier
|
||||
# by Einar Lielmanis <einar@beautifier.io>
|
||||
#
|
||||
@@ -21,8 +21,6 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
def detect(source):
|
||||
if "eval(function(p,a,c,k,e,d)" in source:
|
||||
mystr = "smth"
|
||||
@@ -71,9 +69,7 @@ def _filterargs(source):
|
||||
raise UnpackingError("Corrupted p.a.c.k.e.r. data.")
|
||||
|
||||
# could not find a satisfying regex
|
||||
raise UnpackingError(
|
||||
"Could not make sense of p.a.c.k.e.r data (unexpected code structure)"
|
||||
)
|
||||
raise UnpackingError("Could not make sense of p.a.c.k.e.r data (unexpected code structure)")
|
||||
|
||||
|
||||
def _replacestrings(source):
|
||||
@@ -88,7 +84,7 @@ def _replacestrings(source):
|
||||
for index, value in enumerate(lookup):
|
||||
source = source.replace(variable % index, '"%s"' % value)
|
||||
return source[startpoint:]
|
||||
return source
|
||||
return source
|
||||
|
||||
|
||||
class Unbaser(object):
|
||||
@@ -97,10 +93,7 @@ class Unbaser(object):
|
||||
|
||||
ALPHABET = {
|
||||
62: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
95: (
|
||||
" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
|
||||
),
|
||||
95: (" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"),
|
||||
}
|
||||
|
||||
def __init__(self, base):
|
||||
@@ -118,9 +111,7 @@ class Unbaser(object):
|
||||
else:
|
||||
# Build conversion dictionary cache
|
||||
try:
|
||||
self.dictionary = dict(
|
||||
(cipher, index) for index, cipher in enumerate(self.ALPHABET[base])
|
||||
)
|
||||
self.dictionary = dict((cipher, index) for index, cipher in enumerate(self.ALPHABET[base]))
|
||||
except KeyError:
|
||||
raise TypeError("Unsupported base encoding.")
|
||||
|
||||
@@ -135,6 +126,8 @@ class Unbaser(object):
|
||||
for index, cipher in enumerate(string[::-1]):
|
||||
ret += (self.base**index) * self.dictionary[cipher]
|
||||
return ret
|
||||
|
||||
|
||||
class UnpackingError(Exception):
|
||||
"""Badly packed source or general error. Argument is a
|
||||
meaningful description."""
|
||||
@@ -142,11 +135,10 @@ class UnpackingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def eval_solver(self, url: str, headers: dict[str, str] | None, patterns: list[str]) -> str:
|
||||
try:
|
||||
response = await self._make_request(url, headers=headers)
|
||||
soup = BeautifulSoup(response.text, "lxml",parse_only=SoupStrainer("script"))
|
||||
soup = BeautifulSoup(response.text, "lxml", parse_only=SoupStrainer("script"))
|
||||
script_all = soup.find_all("script")
|
||||
for i in script_all:
|
||||
if detect(i.text):
|
||||
@@ -162,4 +154,4 @@ async def eval_solver(self, url: str, headers: dict[str, str] | None, patterns:
|
||||
raise UnpackingError("No p.a.c.k.e.d JS found or no pattern matched.")
|
||||
except Exception as e:
|
||||
logger.exception("Eval solver error for %s", url)
|
||||
raise UnpackingError("Error in eval_solver") from e
|
||||
raise UnpackingError("Error in eval_solver") from e
|
||||
|
||||
@@ -8,7 +8,7 @@ from .aes import AES
|
||||
from .rijndael import Rijndael
|
||||
from .cryptomath import bytesToNumber, numberToByteArray
|
||||
|
||||
__all__ = ['new', 'Python_AES']
|
||||
__all__ = ["new", "Python_AES"]
|
||||
|
||||
|
||||
def new(key, mode, IV):
|
||||
@@ -37,22 +37,21 @@ class Python_AES(AES):
|
||||
plaintextBytes = bytearray(plaintext)
|
||||
chainBytes = self.IV[:]
|
||||
|
||||
#CBC Mode: For each block...
|
||||
for x in range(len(plaintextBytes)//16):
|
||||
|
||||
#XOR with the chaining block
|
||||
blockBytes = plaintextBytes[x*16 : (x*16)+16]
|
||||
# CBC Mode: For each block...
|
||||
for x in range(len(plaintextBytes) // 16):
|
||||
# XOR with the chaining block
|
||||
blockBytes = plaintextBytes[x * 16 : (x * 16) + 16]
|
||||
for y in range(16):
|
||||
blockBytes[y] ^= chainBytes[y]
|
||||
|
||||
#Encrypt it
|
||||
# Encrypt it
|
||||
encryptedBytes = self.rijndael.encrypt(blockBytes)
|
||||
|
||||
#Overwrite the input with the output
|
||||
# Overwrite the input with the output
|
||||
for y in range(16):
|
||||
plaintextBytes[(x*16)+y] = encryptedBytes[y]
|
||||
plaintextBytes[(x * 16) + y] = encryptedBytes[y]
|
||||
|
||||
#Set the next chaining block
|
||||
# Set the next chaining block
|
||||
chainBytes = encryptedBytes
|
||||
|
||||
self.IV = chainBytes[:]
|
||||
@@ -64,19 +63,18 @@ class Python_AES(AES):
|
||||
ciphertextBytes = ciphertext[:]
|
||||
chainBytes = self.IV[:]
|
||||
|
||||
#CBC Mode: For each block...
|
||||
for x in range(len(ciphertextBytes)//16):
|
||||
|
||||
#Decrypt it
|
||||
blockBytes = ciphertextBytes[x*16 : (x*16)+16]
|
||||
# CBC Mode: For each block...
|
||||
for x in range(len(ciphertextBytes) // 16):
|
||||
# Decrypt it
|
||||
blockBytes = ciphertextBytes[x * 16 : (x * 16) + 16]
|
||||
decryptedBytes = self.rijndael.decrypt(blockBytes)
|
||||
|
||||
#XOR with the chaining block and overwrite the input with output
|
||||
# XOR with the chaining block and overwrite the input with output
|
||||
for y in range(16):
|
||||
decryptedBytes[y] ^= chainBytes[y]
|
||||
ciphertextBytes[(x*16)+y] = decryptedBytes[y]
|
||||
ciphertextBytes[(x * 16) + y] = decryptedBytes[y]
|
||||
|
||||
#Set the next chaining block
|
||||
# Set the next chaining block
|
||||
chainBytes = blockBytes
|
||||
|
||||
self.IV = chainBytes[:]
|
||||
@@ -89,7 +87,7 @@ class Python_AES_CTR(AES):
|
||||
self.rijndael = Rijndael(key, 16)
|
||||
self.IV = IV
|
||||
self._counter_bytes = 16 - len(self.IV)
|
||||
self._counter = self.IV + bytearray(b'\x00' * self._counter_bytes)
|
||||
self._counter = self.IV + bytearray(b"\x00" * self._counter_bytes)
|
||||
|
||||
@property
|
||||
def counter(self):
|
||||
@@ -102,9 +100,9 @@ class Python_AES_CTR(AES):
|
||||
def _counter_update(self):
|
||||
counter_int = bytesToNumber(self._counter) + 1
|
||||
self._counter = numberToByteArray(counter_int, 16)
|
||||
if self._counter_bytes > 0 and \
|
||||
self._counter[-self._counter_bytes:] == \
|
||||
bytearray(b'\xff' * self._counter_bytes):
|
||||
if self._counter_bytes > 0 and self._counter[-self._counter_bytes :] == bytearray(
|
||||
b"\xff" * self._counter_bytes
|
||||
):
|
||||
raise OverflowError("CTR counter overflowed")
|
||||
|
||||
def encrypt(self, plaintext):
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Rate limit handlers for host-specific rate limiting strategies.
|
||||
|
||||
This module provides handler classes that implement specific rate limiting
|
||||
logic for different streaming hosts (e.g., Vidoza's aggressive 509 rate limiting).
|
||||
|
||||
Similar pattern to stream_transformers.py but for rate limiting behavior.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitHandler:
    """
    Default (no-op) rate limit policy.

    Every knob is exposed as a read-only property; host-specific subclasses
    override the ones they need.
    """

    @property
    def cooldown_seconds(self) -> int:
        """Seconds to wait between upstream connections; 0 means no cooldown."""
        return 0

    @property
    def use_head_cache(self) -> bool:
        """Whether HEAD responses are cached to avoid upstream calls (off by default)."""
        return False

    @property
    def use_stream_gate(self) -> bool:
        """Whether requests are serialized through a distributed lock (off by default)."""
        return False

    @property
    def exclusive_stream(self) -> bool:
        """
        Whether the stream gate is held for the ENTIRE duration of the stream.

        When True no concurrent connection to the same URL is allowed, which is
        required for hosts that answer 509 on ANY concurrent stream. The default
        (False) releases the gate once headers have been received.
        """
        return False

    @property
    def retry_after_seconds(self) -> int:
        """Value for the Retry-After header sent with a 503 (default 2)."""
        return 2
|
||||
|
||||
|
||||
class VidozaRateLimitHandler(RateLimitHandler):
    """
    Rate limit policy for the Vidoza CDN.

    Vidoza aggressively rate-limits (509) if ANY concurrent connections exist
    to the same URL from the same IP. This policy therefore:
    - holds an EXCLUSIVE stream gate: only ONE stream at a time, for its whole duration
    - caches HEAD responses so repeated probes need no upstream connection
    - makes ExoPlayer/clients wait for the current stream to finish before starting another

    WARNING: only one client can actively stream at a time. Other clients wait
    (up to the timeout) and eventually get 503 if the current stream is too long.
    """

    @property
    def cooldown_seconds(self) -> int:
        # Exclusive streaming replaces the cooldown, so none is needed.
        return 0

    @property
    def use_head_cache(self) -> bool:
        return True

    @property
    def use_stream_gate(self) -> bool:
        return True

    @property
    def exclusive_stream(self) -> bool:
        """
        Hold the stream gate for the ENTIRE stream, not just at its start.

        Prevents any concurrent connection, which hosts like Vidoza punish
        with a 509.
        """
        return True

    @property
    def retry_after_seconds(self) -> int:
        return 5
|
||||
|
||||
|
||||
class AggressiveRateLimitHandler(RateLimitHandler):
    """
    Generic policy for hosts with strict rate limiting.

    Intended for hosts that behave like Vidoza but may use different
    thresholds: a 3 second cooldown, HEAD caching and a (non-exclusive)
    stream gate.
    """

    @property
    def cooldown_seconds(self) -> int:
        return 3

    @property
    def use_head_cache(self) -> bool:
        return True

    @property
    def use_stream_gate(self) -> bool:
        return True

    @property
    def retry_after_seconds(self) -> int:
        return 2
|
||||
|
||||
|
||||
# Registry of available rate limit handlers by ID.
# Maps the handler ID accepted by get_rate_limit_handler() to the class that is
# instantiated for it.
RATE_LIMIT_HANDLER_REGISTRY: dict[str, type[RateLimitHandler]] = {
    "vidoza": VidozaRateLimitHandler,
    "aggressive": AggressiveRateLimitHandler,
}

# Auto-detection: hostname patterns to handler IDs.
# These patterns are checked against the video URL hostname; each value must be
# a key of RATE_LIMIT_HANDLER_REGISTRY.
#
# NOTE: Vidoza CDN DOES rate limit concurrent connections from the same IP.
# When multiple clients request through the proxy, all requests come from
# the proxy's IP, triggering Vidoza's rate limit (509 errors).
# Stream-level rate limiting serializes requests to avoid this.
#
HOST_PATTERN_TO_HANDLER: dict[str, str] = {
    "vidoza.net": "vidoza",
    "vidoza.org": "vidoza",
    # Add more patterns as needed for hosts that rate-limit CDN streaming:
    # "example-cdn.com": "aggressive",
}
|
||||
|
||||
|
||||
def get_rate_limit_handler(
    handler_id: Optional[str] = None,
    video_url: Optional[str] = None,
) -> RateLimitHandler:
    """
    Pick the rate limit handler for a request and return a fresh instance.

    Resolution order:
    1. ``handler_id`` when it names a registered handler
    2. the first host pattern contained in the hostname of ``video_url``
    3. the base ``RateLimitHandler`` (no rate limiting)

    Args:
        handler_id: Explicit handler identifier (e.g., "vidoza", "aggressive").
            An unknown ID is logged and then ignored.
        video_url: Video URL used for hostname-based auto-detection.

    Returns:
        A rate limit handler instance; the no-op base handler when nothing matches
        or auto-detection raises.
    """
    # Step 1: an explicit ID wins when it is registered.
    if handler_id:
        explicit_class = RATE_LIMIT_HANDLER_REGISTRY.get(handler_id)
        if explicit_class:
            logger.debug(f"Using explicit rate limit handler: {handler_id}")
            return explicit_class()
        logger.warning(f"Unknown rate limit handler ID: {handler_id}")

    # Step 2: fall back to hostname-based detection (best effort, never raises).
    if video_url:
        try:
            hostname = urlparse(video_url).hostname or ""
            # NOTE(review): patterns are matched as substrings of the hostname,
            # so "vidoza.net" also matches e.g. "notvidoza.net" - confirm intended.
            for pattern, detected_handler_id in HOST_PATTERN_TO_HANDLER.items():
                if pattern not in hostname:
                    continue
                detected_class = RATE_LIMIT_HANDLER_REGISTRY.get(detected_handler_id)
                if detected_class:
                    logger.info(f"[RateLimit] Auto-detected handler '{detected_handler_id}' for host: {hostname}")
                    return detected_class()
            logger.debug(f"[RateLimit] No handler matched for hostname: {hostname}")
        except Exception as e:
            logger.warning(f"[RateLimit] Error during auto-detection: {e}")

    # Step 3: nothing matched - no rate limiting.
    return RateLimitHandler()
|
||||
|
||||
|
||||
def get_available_handlers() -> list[str]:
    """Return the IDs of all registered rate limit handlers, in registry order."""
    return list(RATE_LIMIT_HANDLER_REGISTRY)
|
||||
@@ -0,0 +1,861 @@
|
||||
"""
|
||||
Redis utilities for cross-worker coordination and caching.
|
||||
|
||||
Provides:
|
||||
- Distributed locking (stream gating, generic locks)
|
||||
- Shared caching (HEAD responses, extractors, MPD, segments, init segments)
|
||||
- In-flight request deduplication
|
||||
- Cooldown/throttle tracking
|
||||
|
||||
All caches are shared across all uvicorn workers via Redis.
|
||||
|
||||
IMPORTANT: Redis is OPTIONAL. If settings.redis_url is None, all functions
|
||||
gracefully degrade:
|
||||
- Locks always succeed immediately (no cross-worker coordination)
|
||||
- Cache operations return None/False (no shared caching)
|
||||
- Cooldowns always allow (no rate limiting)
|
||||
|
||||
This allows single-worker deployments to work without Redis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
# Redis Clients (Lazy Singletons)
# =============================================================================
# Two clients: one for text/JSON (decode_responses=True), one for binary data.
# Both start as None, are created on first use by get_redis() /
# get_redis_binary(), and are reset to None by close_redis().

_redis_client = None
_redis_binary_client = None
# Cached result of the is_redis_available() probe.
_redis_available: Optional[bool] = None  # None = not checked yet
|
||||
|
||||
|
||||
def is_redis_configured() -> bool:
    """Return True when settings.redis_url is set to a non-blank string."""
    url = settings.redis_url
    if url is None:
        return False
    return url.strip() != ""
|
||||
|
||||
|
||||
async def is_redis_available() -> bool:
    """
    Check if Redis is configured and reachable.

    The first call opens a short-lived probe client and sends a PING. The
    outcome (success or failure) is cached in the module-level
    ``_redis_available`` flag and returned by later calls without reconnecting,
    until close_redis() resets it.

    Returns:
        True if Redis answered the probe, False if it is not configured or the
        probe failed.
    """
    global _redis_available

    if not is_redis_configured():
        return False

    if _redis_available is not None:
        return _redis_available

    # Try to connect
    try:
        import redis.asyncio as redis

        client = redis.from_url(
            settings.redis_url,
            decode_responses=True,
            socket_connect_timeout=2,
            socket_timeout=2,
        )
        try:
            await client.ping()
        finally:
            # Always release the probe client, even when the ping fails;
            # otherwise its connection pool is leaked on every failed probe.
            await client.aclose()
        _redis_available = True
        # NOTE(review): redis_url may embed credentials (redis://user:pass@host);
        # consider redacting it before logging.
        logger.info(f"Redis is available: {settings.redis_url}")
    except Exception as e:
        _redis_available = False
        logger.warning(f"Redis not available (features will be disabled): {e}")

    return _redis_available
|
||||
|
||||
|
||||
async def get_redis():
    """
    Get or create the Redis client for text/JSON data (lazy singleton).

    The client (decode_responses=True) is created on first use, verified with a
    PING and then reused by every later call in this process. Each worker
    process has its own client, but Redis itself coordinates across workers.

    Returns:
        The shared client, or None if Redis is not configured or the connection
        test fails. A failed attempt is not cached: the next call tries again.
    """
    global _redis_client

    if not is_redis_configured():
        return None

    if _redis_client is None:
        import redis.asyncio as redis

        client = redis.from_url(
            settings.redis_url,
            decode_responses=True,
            socket_connect_timeout=5,
            socket_timeout=5,
        )
        _redis_client = client
        # Test connection
        try:
            await client.ping()
            # NOTE(review): redis_url may embed credentials; consider redacting.
            logger.info(f"Redis connected (text): {settings.redis_url}")
        except Exception as e:
            logger.error(f"Redis connection failed: {e}")
            _redis_client = None
            # Release the failed client's pool instead of leaving it to the GC.
            try:
                await client.aclose()
            except Exception:
                logger.debug("Ignoring error while closing failed Redis client", exc_info=True)
            return None
    return _redis_client
|
||||
|
||||
|
||||
async def get_redis_binary():
    """
    Get or create the Redis client for binary data (lazy singleton).

    The client is created with ``decode_responses=False`` so values are
    returned as raw bytes. It is only stored in ``_redis_binary_client`` after
    a successful ``ping()``; a client whose ping fails is closed instead of
    being abandoned.

    Returns:
        The shared binary client, or None if Redis is not configured or the
        connection test failed (the next call will retry).
    """
    global _redis_binary_client

    if not is_redis_configured():
        return None

    if _redis_binary_client is not None:
        return _redis_binary_client

    import redis.asyncio as redis

    candidate = redis.from_url(
        settings.redis_url,
        decode_responses=False,  # Keep bytes as-is
        socket_connect_timeout=5,
        socket_timeout=5,
    )

    # Test connection before publishing the client.
    try:
        await candidate.ping()
    except Exception as e:
        logger.error(f"Redis binary connection failed: {e}")
        try:
            await candidate.aclose()
        except Exception:
            logger.debug("Failed to close unusable Redis binary client", exc_info=True)
        return None

    if _redis_binary_client is None:
        _redis_binary_client = candidate
        # NOTE(review): redis_url may embed credentials; consider redacting before logging.
        logger.info(f"Redis connected (binary): {settings.redis_url}")
    else:
        # Another task finished initialisation while we were pinging;
        # keep its client and release ours.
        try:
            await candidate.aclose()
        except Exception:
            logger.debug("Failed to close redundant Redis binary client", exc_info=True)

    return _redis_binary_client
|
||||
|
||||
|
||||
async def close_redis():
    """Shut down both Redis clients and forget the cached availability result."""
    global _redis_client, _redis_binary_client, _redis_available

    text_client = _redis_client
    if text_client is not None:
        await text_client.aclose()
        _redis_client = None

    binary_client = _redis_binary_client
    if binary_client is not None:
        await binary_client.aclose()
        _redis_binary_client = None

    # Force the next availability check to probe Redis again.
    _redis_available = None
    logger.info("Redis connections closed")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Instance Namespace Helper
|
||||
# =============================================================================
|
||||
# Some cached data is bound to the outgoing IP of the pod that produced it
|
||||
# (e.g. extractor results resolved via the pod's egress IP). Sharing these
|
||||
# entries across pods in a multi-instance deployment causes other pods to serve
|
||||
# stale/wrong URLs.
|
||||
#
|
||||
# Set CACHE_NAMESPACE (env: CACHE_NAMESPACE) to a unique value per pod (e.g.
|
||||
# pod name, hostname, or any discriminator). Instance-scoped keys are then
|
||||
# stored under "<namespace>:<original_key>", while fully-shared keys (MPD,
|
||||
# init segments, media segments, locks, stream gates) remain unchanged.
|
||||
|
||||
|
||||
def make_instance_key(key: str) -> str:
    """Return *key* scoped to this instance's cache namespace.

    When ``settings.cache_namespace`` is set, the result is
    ``"<namespace>:<key>"``; when it is empty or unset, *key* is returned
    unchanged.
    """
    namespace = settings.cache_namespace
    if not namespace:
        return key
    return f"{namespace}:{key}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stream Gate (Distributed Lock)
|
||||
# =============================================================================
|
||||
# Serializes upstream connection handshakes per-URL across all workers.
|
||||
# Uses SET NX EX for atomic acquire with auto-expiry.
|
||||
|
||||
GATE_PREFIX = "mfp:stream_gate:"
GATE_TTL = 15  # seconds - auto-expire if worker crashes mid-request


def _gate_key(url: str) -> str:
    """Generate the Redis key for a stream gate: ``GATE_PREFIX`` + MD5 hex digest of *url*."""
    # MD5 only shortens the URL into a fixed-length key; it is not a security
    # primitive here. usedforsecurity=False keeps the call working on
    # FIPS-restricted builds and yields the same digest.
    url_hash = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()
    return f"{GATE_PREFIX}{url_hash}"
|
||||
|
||||
|
||||
async def acquire_stream_gate(url: str, timeout: float = 15.0) -> bool:
    """
    Try to acquire a per-URL stream gate (distributed lock).

    The gate is a Redis key written with ``SET NX EX``: the write succeeds only
    if the key does not exist yet, and the key expires after ``GATE_TTL``
    seconds. While the key exists, the call polls every 50 ms until *timeout*
    elapses.

    If Redis is not available, always returns True (no coordination).

    Args:
        url: The upstream URL to gate
        timeout: Maximum time to wait for the gate (seconds). At least one
            acquisition attempt is made even when this is zero or negative.

    Returns:
        True if gate acquired (or Redis unavailable), False if timeout
    """
    r = await get_redis()
    if r is None:
        # No Redis - no cross-worker coordination, always allow
        return True

    key = _gate_key(url)
    # Monotonic clock: wall-clock adjustments must not stretch or shrink the wait.
    deadline = time.monotonic() + timeout

    while True:
        # SET NX EX is atomic: only succeeds if key doesn't exist
        if await r.set(key, "1", nx=True, ex=GATE_TTL):
            logger.debug(f"[Redis] Acquired stream gate: {key[:50]}...")
            return True
        if time.monotonic() >= deadline:
            break
        # Another worker holds the gate, wait and retry
        await asyncio.sleep(0.05)  # 50ms poll interval

    logger.warning(f"[Redis] Gate acquisition timeout ({timeout}s): {key[:50]}...")
    return False
|
||||
|
||||
|
||||
async def release_stream_gate(url: str):
    """
    Delete the stream-gate key for *url*.

    Deleting a key that no longer exists is harmless, so this can be called
    whether or not the gate is still present. Does nothing when Redis is
    unavailable.
    """
    client = await get_redis()
    if client is not None:
        key = _gate_key(url)
        await client.delete(key)
        logger.debug(f"[Redis] Released stream gate: {key[:50]}...")
|
||||
|
||||
|
||||
async def extend_stream_gate(url: str, ttl: int = GATE_TTL):
    """
    Reset the expiry of the stream gate for *url* to *ttl* seconds.

    Intended to be called repeatedly while a stream is in progress so the
    gate does not expire underneath it. Does nothing when Redis is
    unavailable.
    """
    client = await get_redis()
    if client is not None:
        key = _gate_key(url)
        await client.expire(key, ttl)
        logger.debug(f"[Redis] Extended stream gate TTL ({ttl}s): {key[:50]}...")
|
||||
|
||||
|
||||
async def is_stream_gate_held(url: str) -> bool:
    """Tell whether the gate key for *url* currently exists (False without Redis)."""
    client = await get_redis()
    if client is None:
        return False

    existing = await client.exists(_gate_key(url))
    return existing > 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HEAD Response Cache
|
||||
# =============================================================================
|
||||
# Caches upstream response headers so repeated HEAD probes (e.g., from ExoPlayer)
|
||||
# can be served without any upstream connection. Shared across all workers.
|
||||
|
||||
HEAD_CACHE_PREFIX = "mfp:head_cache:"
HEAD_CACHE_TTL = 60  # seconds - Vidoza CDN URLs typically expire in minutes


def _head_cache_key(url: str) -> str:
    """Generate the Redis key for a HEAD cache entry: ``HEAD_CACHE_PREFIX`` + MD5 hex digest of *url*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    url_hash = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()
    return f"{HEAD_CACHE_PREFIX}{url_hash}"
|
||||
|
||||
|
||||
async def get_cached_head(url: str) -> Optional[dict]:
    """
    Look up the cached HEAD metadata stored for *url*.

    Args:
        url: The upstream URL

    Returns:
        The JSON-decoded cache entry, or None when nothing is cached or Redis
        is unavailable.
    """
    client = await get_redis()
    if client is None:
        return None

    key = _head_cache_key(url)
    raw = await client.get(key)
    if not raw:
        return None

    logger.debug(f"[Redis] HEAD cache hit: {key[:50]}...")
    return json.loads(raw)
|
||||
|
||||
|
||||
async def set_cached_head(url: str, headers: dict, status: int):
    """
    Store HEAD response metadata for *url* with a ``HEAD_CACHE_TTL`` expiry.

    Only a fixed allow-list of headers is kept; their names are lower-cased
    before being stored. Does nothing when Redis is unavailable.

    Args:
        url: The upstream URL
        headers: Response headers dict (will be JSON serialized)
        status: HTTP status code (e.g., 200, 206)
    """
    client = await get_redis()
    if client is None:
        return

    key = _head_cache_key(url)

    # Headers worth replaying on a HEAD response; everything else is dropped.
    wanted = (
        "content-type",
        "content-length",
        "accept-ranges",
        "content-range",
        "etag",
        "last-modified",
        "cache-control",
    )
    cached_headers = {name.lower(): value for name, value in headers.items() if name.lower() in wanted}

    entry = {"headers": cached_headers, "status": status}
    await client.set(key, json.dumps(entry), ex=HEAD_CACHE_TTL)
    logger.debug(f"[Redis] HEAD cache set ({HEAD_CACHE_TTL}s TTL): {key[:50]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Generic Distributed Lock
|
||||
# =============================================================================
|
||||
# For cross-worker coordination (e.g., segment downloads, prebuffering)
|
||||
|
||||
LOCK_PREFIX = "mfp:lock:"
DEFAULT_LOCK_TTL = 30  # seconds


def _lock_key(key: str) -> str:
    """Generate the Redis key for a lock: ``LOCK_PREFIX`` + MD5 hex digest of *key*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return f"{LOCK_PREFIX}{key_hash}"
|
||||
|
||||
|
||||
async def acquire_lock(key: str, ttl: int = DEFAULT_LOCK_TTL, timeout: float = 30.0) -> bool:
    """
    Acquire a distributed lock.

    The lock is a Redis key written with ``SET NX EX``; while the key exists
    the call polls every 50 ms until *timeout* elapses.

    If Redis is not available, always returns True (no coordination).

    Args:
        key: The lock identifier
        ttl: Lock auto-expiry time in seconds (prevents deadlocks)
        timeout: Maximum time to wait for the lock. At least one acquisition
            attempt is made even when this is zero or negative.

    Returns:
        True if lock acquired (or Redis unavailable), False if timeout
    """
    r = await get_redis()
    if r is None:
        return True  # No Redis - no coordination

    lock_key = _lock_key(key)
    # Monotonic clock: wall-clock adjustments must not stretch or shrink the wait.
    deadline = time.monotonic() + timeout

    while True:
        if await r.set(lock_key, "1", nx=True, ex=ttl):
            logger.debug(f"[Redis] Acquired lock: {key[:60]}...")
            return True
        if time.monotonic() >= deadline:
            break
        await asyncio.sleep(0.05)

    logger.warning(f"[Redis] Lock timeout ({timeout}s): {key[:60]}...")
    return False
|
||||
|
||||
|
||||
async def release_lock(key: str):
    """Delete the lock key for *key*; does nothing when Redis is unavailable."""
    client = await get_redis()
    if client is not None:
        await client.delete(_lock_key(key))
        logger.debug(f"[Redis] Released lock: {key[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Extractor Cache
|
||||
# =============================================================================
|
||||
# Caches extractor results (JSON) across all workers
|
||||
|
||||
EXTRACTOR_CACHE_PREFIX = "mfp:extractor:"
EXTRACTOR_CACHE_TTL = 300  # 5 minutes


def _extractor_key(key: str) -> str:
    """Generate the Redis key for the extractor cache: ``EXTRACTOR_CACHE_PREFIX`` + MD5 hex digest of *key*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return f"{EXTRACTOR_CACHE_PREFIX}{key_hash}"
|
||||
|
||||
|
||||
async def get_cached_extractor(key: str) -> Optional[dict]:
    """Fetch the JSON-decoded extractor result stored for *key* (None on miss or without Redis)."""
    client = await get_redis()
    if client is None:
        return None

    raw = await client.get(_extractor_key(key))
    if not raw:
        return None

    logger.debug(f"[Redis] Extractor cache hit: {key[:60]}...")
    return json.loads(raw)
|
||||
|
||||
|
||||
async def set_cached_extractor(key: str, data: dict, ttl: int = EXTRACTOR_CACHE_TTL):
    """Store *data* as JSON under the extractor key for *key*, expiring after *ttl* seconds."""
    client = await get_redis()
    if client is not None:
        serialized = json.dumps(data)
        await client.set(_extractor_key(key), serialized, ex=ttl)
        logger.debug(f"[Redis] Extractor cache set ({ttl}s TTL): {key[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MPD Cache
|
||||
# =============================================================================
|
||||
# Caches parsed MPD manifests (JSON) across all workers
|
||||
|
||||
MPD_CACHE_PREFIX = "mfp:mpd:"
DEFAULT_MPD_CACHE_TTL = 60  # 1 minute


def _mpd_key(key: str) -> str:
    """Generate the Redis key for the MPD cache: ``MPD_CACHE_PREFIX`` + MD5 hex digest of *key*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return f"{MPD_CACHE_PREFIX}{key_hash}"
|
||||
|
||||
|
||||
async def get_cached_mpd(key: str) -> Optional[dict]:
    """Fetch the JSON-decoded MPD manifest stored for *key* (None on miss or without Redis)."""
    client = await get_redis()
    if client is None:
        return None

    raw = await client.get(_mpd_key(key))
    if not raw:
        return None

    logger.debug(f"[Redis] MPD cache hit: {key[:60]}...")
    return json.loads(raw)
|
||||
|
||||
|
||||
async def set_cached_mpd(key: str, data: dict, ttl: int | float = DEFAULT_MPD_CACHE_TTL):
    """Store *data* as JSON under the MPD key for *key*.

    *ttl* is truncated to whole seconds and raised to at least 1 before being
    used as the expiry. Does nothing when Redis is unavailable.
    """
    client = await get_redis()
    if client is None:
        return

    redis_key = _mpd_key(key)
    # The expiry must be a positive integer number of seconds.
    ttl_int = int(ttl)
    if ttl_int < 1:
        ttl_int = 1
    await client.set(redis_key, json.dumps(data), ex=ttl_int)
    logger.debug(f"[Redis] MPD cache set ({ttl_int}s TTL): {key[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Segment Cache (Binary)
|
||||
# =============================================================================
|
||||
# Caches HLS/DASH segments across all workers
|
||||
|
||||
SEGMENT_CACHE_PREFIX = b"mfp:segment:"
DEFAULT_SEGMENT_CACHE_TTL = 60  # 1 minute


def _segment_key(url: str) -> bytes:
    """Generate the Redis key (bytes) for the segment cache: ``SEGMENT_CACHE_PREFIX`` + MD5 hex digest of *url*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    url_hash = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()
    return SEGMENT_CACHE_PREFIX + url_hash.encode()
|
||||
|
||||
|
||||
async def get_cached_segment(url: str) -> Optional[bytes]:
    """Fetch the raw segment bytes cached for *url*; None when Redis is unavailable."""
    client = await get_redis_binary()
    if client is None:
        return None

    cached = await client.get(_segment_key(url))
    if cached:
        logger.debug(f"[Redis] Segment cache hit: {url[:60]}...")
    # Whatever Redis returned is handed back as-is, hit or miss.
    return cached
|
||||
|
||||
|
||||
async def set_cached_segment(url: str, data: bytes, ttl: int = DEFAULT_SEGMENT_CACHE_TTL):
    """Store segment bytes for *url*, expiring after *ttl* seconds (skipped without Redis)."""
    client = await get_redis_binary()
    if client is not None:
        await client.set(_segment_key(url), data, ex=ttl)
        logger.debug(f"[Redis] Segment cache set ({ttl}s TTL, {len(data)} bytes): {url[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Init Segment Cache (Binary)
|
||||
# =============================================================================
|
||||
# Caches initialization segments across all workers
|
||||
|
||||
INIT_CACHE_PREFIX = b"mfp:init:"
DEFAULT_INIT_CACHE_TTL = 3600  # 1 hour


def _init_key(url: str) -> bytes:
    """Generate the Redis key (bytes) for the init segment cache: ``INIT_CACHE_PREFIX`` + MD5 hex digest of *url*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    url_hash = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()
    return INIT_CACHE_PREFIX + url_hash.encode()
|
||||
|
||||
|
||||
async def get_cached_init_segment(url: str) -> Optional[bytes]:
    """Fetch the raw init-segment bytes cached for *url*; None when Redis is unavailable."""
    client = await get_redis_binary()
    if client is None:
        return None

    cached = await client.get(_init_key(url))
    if cached:
        logger.debug(f"[Redis] Init segment cache hit: {url[:60]}...")
    # Whatever Redis returned is handed back as-is, hit or miss.
    return cached
|
||||
|
||||
|
||||
async def set_cached_init_segment(url: str, data: bytes, ttl: int = DEFAULT_INIT_CACHE_TTL):
    """Store init-segment bytes for *url*, expiring after *ttl* seconds (skipped without Redis)."""
    client = await get_redis_binary()
    if client is not None:
        await client.set(_init_key(url), data, ex=ttl)
        logger.debug(f"[Redis] Init segment cache set ({ttl}s TTL, {len(data)} bytes): {url[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Processed Init Segment Cache (Binary)
|
||||
# =============================================================================
|
||||
# Caches DRM-stripped/processed init segments across all workers
|
||||
|
||||
PROCESSED_INIT_CACHE_PREFIX = b"mfp:processed_init:"
DEFAULT_PROCESSED_INIT_TTL = 3600  # 1 hour


def _processed_init_key(key: str) -> bytes:
    """Generate the Redis key (bytes) for the processed init segment cache.

    The key is ``PROCESSED_INIT_CACHE_PREFIX`` followed by the MD5 hex digest
    of *key*.
    """
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return PROCESSED_INIT_CACHE_PREFIX + key_hash.encode()
|
||||
|
||||
|
||||
async def get_cached_processed_init(key: str) -> Optional[bytes]:
    """Fetch the processed init-segment bytes cached for *key*; None when Redis is unavailable."""
    client = await get_redis_binary()
    if client is None:
        return None

    cached = await client.get(_processed_init_key(key))
    if cached:
        logger.debug(f"[Redis] Processed init cache hit: {key[:60]}...")
    # Whatever Redis returned is handed back as-is, hit or miss.
    return cached
|
||||
|
||||
|
||||
async def set_cached_processed_init(key: str, data: bytes, ttl: int = DEFAULT_PROCESSED_INIT_TTL):
    """Store processed init-segment bytes for *key*, expiring after *ttl* seconds (skipped without Redis)."""
    client = await get_redis_binary()
    if client is not None:
        await client.set(_processed_init_key(key), data, ex=ttl)
        logger.debug(f"[Redis] Processed init cache set ({ttl}s TTL, {len(data)} bytes): {key[:60]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# In-flight Request Tracking
|
||||
# =============================================================================
|
||||
# Prevents duplicate upstream requests across all workers
|
||||
|
||||
INFLIGHT_PREFIX = "mfp:inflight:"
DEFAULT_INFLIGHT_TTL = 60  # seconds


def _inflight_key(key: str) -> str:
    """Generate the Redis key for in-flight tracking: ``INFLIGHT_PREFIX`` + MD5 hex digest of *key*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return f"{INFLIGHT_PREFIX}{key_hash}"
|
||||
|
||||
|
||||
async def mark_inflight(key: str, ttl: int = DEFAULT_INFLIGHT_TTL) -> bool:
    """
    Claim the in-flight marker for *key*.

    The marker is written with ``SET NX EX``, so only the first caller while
    the marker is absent succeeds. Without Redis every caller is treated as
    the first one.

    Args:
        key: The request identifier
        ttl: Auto-expiry time in seconds

    Returns:
        True if this call created the marker, False if it already existed
    """
    client = await get_redis()
    if client is None:
        return True  # No Redis - always proceed

    created = await client.set(_inflight_key(key), "1", nx=True, ex=ttl)
    if created:
        logger.debug(f"[Redis] Marked in-flight: {key[:60]}...")
        return True
    return False
|
||||
|
||||
|
||||
async def is_inflight(key: str) -> bool:
    """Tell whether the in-flight marker for *key* currently exists (False without Redis)."""
    client = await get_redis()
    if client is None:
        return False

    existing = await client.exists(_inflight_key(key))
    return existing > 0
|
||||
|
||||
|
||||
async def clear_inflight(key: str):
    """Remove the in-flight marker for *key*; does nothing when Redis is unavailable."""
    client = await get_redis()
    if client is not None:
        await client.delete(_inflight_key(key))
        logger.debug(f"[Redis] Cleared in-flight: {key[:60]}...")
|
||||
|
||||
|
||||
async def wait_for_completion(key: str, timeout: float = 30.0, poll_interval: float = 0.1) -> bool:
    """
    Wait for an in-flight request to complete.

    Polls ``is_inflight(key)`` every *poll_interval* seconds until the marker
    is gone or *timeout* has elapsed. The marker is checked at least once,
    even when *timeout* is zero or negative.

    If Redis is not available, returns True immediately.

    Args:
        key: The request identifier
        timeout: Maximum time to wait
        poll_interval: Time between checks

    Returns:
        True if completed (inflight marker gone), False if timeout
    """
    r = await get_redis()
    if r is None:
        return True  # No Redis - don't wait

    # Monotonic clock: wall-clock adjustments must not stretch or shrink the wait.
    deadline = time.monotonic() + timeout
    while True:
        if not await is_inflight(key):
            return True
        if time.monotonic() >= deadline:
            return False
        await asyncio.sleep(poll_interval)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cooldown/Throttle Tracking
|
||||
# =============================================================================
|
||||
# Prevents rapid repeated operations (e.g., background refresh throttling)
|
||||
|
||||
COOLDOWN_PREFIX = "mfp:cooldown:"


def _cooldown_key(key: str) -> str:
    """Generate the Redis key for cooldown tracking: ``COOLDOWN_PREFIX`` + MD5 hex digest of *key*."""
    # MD5 is a key shortener here, not a security primitive;
    # usedforsecurity=False keeps it usable on FIPS-restricted builds.
    key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()
    return f"{COOLDOWN_PREFIX}{key_hash}"
|
||||
|
||||
|
||||
async def check_and_set_cooldown(key: str, cooldown_seconds: int) -> bool:
    """
    Start a cooldown for *key* unless one is already running.

    The cooldown marker is written with ``SET NX EX``: the write succeeds only
    when no marker exists, and the marker expires after *cooldown_seconds*.
    Without Redis there is no rate limiting and the call always succeeds.

    Args:
        key: The cooldown identifier
        cooldown_seconds: Duration of the cooldown period

    Returns:
        True if a new cooldown was started, False if one is still active
    """
    client = await get_redis()
    if client is None:
        return True  # No Redis - no rate limiting

    started = await client.set(_cooldown_key(key), "1", nx=True, ex=cooldown_seconds)
    if not started:
        return False

    logger.debug(f"[Redis] Cooldown set ({cooldown_seconds}s): {key[:60]}...")
    return True
|
||||
|
||||
|
||||
async def is_in_cooldown(key: str) -> bool:
    """Tell whether the cooldown marker for *key* currently exists (False without Redis)."""
    client = await get_redis()
    if client is None:
        return False

    existing = await client.exists(_cooldown_key(key))
    return existing > 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HLS Transcode Session (Cross-Worker)
|
||||
# =============================================================================
|
||||
# Per-segment HLS transcode caching.
|
||||
# Each segment is independently transcoded and cached. Segment output metadata
|
||||
# (video/audio DTS, sequence number) is stored so consecutive segments can
|
||||
# maintain timeline continuity without a persistent pipeline.
|
||||
|
||||
HLS_SEG_PREFIX = b"mfp:hls_seg:"
HLS_INIT_PREFIX = b"mfp:hls_init:"
HLS_SEG_META_PREFIX = "mfp:hls_smeta:"

HLS_SEG_TTL = 60  # 60 s -- short-lived; only for immediate retry/re-request
HLS_INIT_TTL = 3600  # 1 hour -- stable for the viewing session
HLS_SEG_META_TTL = 3600  # 1 hour -- needed for next-segment continuity


def _hls_seg_key(cache_key: str, seg_index: int) -> bytes:
    """Build the bytes key ``HLS_SEG_PREFIX + "<cache_key>:<seg_index>"`` for a segment."""
    suffix = f"{cache_key}:{seg_index}"
    return HLS_SEG_PREFIX + suffix.encode()


def _hls_init_key(cache_key: str) -> bytes:
    """Build the bytes key ``HLS_INIT_PREFIX + cache_key`` for an init segment."""
    encoded = cache_key.encode()
    return HLS_INIT_PREFIX + encoded


def _hls_seg_meta_key(cache_key: str, seg_index: int) -> str:
    """Build the text key ``HLS_SEG_META_PREFIX + "<cache_key>:<seg_index>"`` for segment metadata."""
    suffix = f"{cache_key}:{seg_index}"
    return HLS_SEG_META_PREFIX + suffix
|
||||
|
||||
|
||||
async def hls_get_segment(cache_key: str, seg_index: int) -> Optional[bytes]:
    """Fetch a cached HLS segment; any failure or missing Redis yields None."""
    client = await get_redis_binary()
    if client is None:
        return None

    try:
        segment = await client.get(_hls_seg_key(cache_key, seg_index))
    except Exception:
        # Best-effort cache: an error is treated the same as a miss.
        return None
    return segment
|
||||
|
||||
|
||||
async def hls_set_segment(cache_key: str, seg_index: int, data: bytes) -> None:
    """Cache an HLS segment for ``HLS_SEG_TTL`` seconds; failures are only logged."""
    client = await get_redis_binary()
    if client is not None:
        try:
            await client.set(_hls_seg_key(cache_key, seg_index), data, ex=HLS_SEG_TTL)
        except Exception:
            # Best-effort cache: a failed write must not break the caller.
            logger.debug("[Redis] Failed to cache HLS segment %d", seg_index)
|
||||
|
||||
|
||||
async def hls_get_init(cache_key: str) -> Optional[bytes]:
    """Fetch the cached HLS init segment; any failure or missing Redis yields None."""
    client = await get_redis_binary()
    if client is None:
        return None

    try:
        init_segment = await client.get(_hls_init_key(cache_key))
    except Exception:
        # Best-effort cache: an error is treated the same as a miss.
        return None
    return init_segment
|
||||
|
||||
|
||||
async def hls_set_init(cache_key: str, data: bytes) -> None:
    """Cache the HLS init segment for ``HLS_INIT_TTL`` seconds; failures are only logged."""
    client = await get_redis_binary()
    if client is not None:
        try:
            await client.set(_hls_init_key(cache_key), data, ex=HLS_INIT_TTL)
        except Exception:
            # Best-effort cache: a failed write must not break the caller.
            logger.debug("[Redis] Failed to cache HLS init segment")
|
||||
|
||||
|
||||
async def hls_get_segment_meta(cache_key: str, seg_index: int) -> Optional[dict]:
    """
    Fetch the JSON-decoded output metadata stored for one HLS segment.

    Returns:
        The decoded dict, or None when nothing is stored, Redis is
        unavailable, or the lookup/decoding raised.
    """
    client = await get_redis()
    if client is None:
        return None

    try:
        raw = await client.get(_hls_seg_meta_key(cache_key, seg_index))
        return json.loads(raw) if raw else None
    except Exception:
        # Best-effort cache: an error is treated the same as a miss.
        return None
|
||||
|
||||
|
||||
async def hls_set_segment_meta(cache_key: str, seg_index: int, meta: dict) -> None:
    """
    Store output metadata for one HLS segment as JSON.

    The entry expires after ``HLS_SEG_META_TTL`` seconds. Failures are only
    logged; nothing happens when Redis is unavailable.
    """
    client = await get_redis()
    if client is not None:
        try:
            meta_key = _hls_seg_meta_key(cache_key, seg_index)
            await client.set(meta_key, json.dumps(meta), ex=HLS_SEG_META_TTL)
        except Exception:
            # Best-effort cache: a failed write must not break the caller.
            logger.debug("[Redis] Failed to set HLS segment meta %d", seg_index)
|
||||
+3707
-897
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Stream transformers for host-specific content manipulation.
|
||||
|
||||
This module provides transformer classes that can modify streaming content
|
||||
on-the-fly. Each transformer handles specific content manipulation needs
|
||||
for different streaming hosts (e.g., PNG wrapper stripping, TS detection).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import typing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamTransformer:
    """
    Base class for stream content transformers.

    Subclasses should override the transform method to implement
    specific content manipulation logic.
    """

    async def transform(self, chunk_iterator: typing.AsyncIterator[bytes]) -> typing.AsyncGenerator[bytes, None]:
        """
        Transform stream chunks (the base implementation passes them through).

        Args:
            chunk_iterator: Async iterator of raw bytes from upstream.

        Yields:
            Transformed bytes chunks.
        """
        async for chunk in chunk_iterator:
            yield chunk


class TSStreamTransformer(StreamTransformer):
    """
    Transformer for MPEG-TS streams with obfuscation.

    Handles streams from hosts like TurboVidPlay, StreamWish, and FileMoon
    that may have:
    - Fake PNG wrapper prepended to video data
    - 0xFF padding bytes before actual content
    - Need for TS sync byte detection

    Incoming data is buffered until the real start of the content is known;
    from then on chunks are forwarded untouched.
    """

    # PNG signature and IEND marker for fake PNG header detection
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
    _PNG_IEND_MARKER = b"\x49\x45\x4e\x44\xae\x42\x60\x82"

    # TS packet constants
    _TS_SYNC = 0x47
    _TS_PACKET_SIZE = 188

    # Maximum bytes to buffer before forcing passthrough
    _MAX_PREFETCH = 512 * 1024  # 512 KB

    # Prefixes that identify an m3u8 playlist (passed through unmodified)
    _PLAYLIST_PREFIXES = (b"#EXTM3U", b"#EXT-X-")

    def __init__(self):
        self.buffer = bytearray()  # data held back during the prebuffer phase
        self.ts_started = False  # True once the transformer is in pass-through mode
        self.bytes_stripped = 0  # size of the last PNG wrapper (plus padding) removed

    @staticmethod
    def _find_ts_start(buffer: bytes) -> typing.Optional[int]:
        """
        Find MPEG-TS sync byte (0x47) aligned on 188 bytes.

        A position qualifies when it holds a sync byte and the byte exactly
        one packet (188 bytes) later is a sync byte as well.

        Args:
            buffer: Bytes (or bytearray) to search for TS sync pattern.

        Returns:
            Offset where TS starts, or None if not found.
        """
        TS_SYNC = 0x47
        TS_PACKET = 188

        # A candidate i needs buffer[i + TS_PACKET] to exist, i.e. i < limit.
        limit = len(buffer) - TS_PACKET
        if limit <= 0:
            return None

        # Let find() skip to each sync-byte candidate instead of testing every
        # offset one by one.
        pos = buffer.find(b"\x47", 0, limit)
        while pos != -1:
            if buffer[pos + TS_PACKET] == TS_SYNC:
                return pos
            pos = buffer.find(b"\x47", pos + 1, limit)
        return None

    def _strip_fake_png_wrapper(self, chunk: bytes) -> bytes:
        """
        Strip fake PNG wrapper from chunk data.

        Some streaming services prepend a fake PNG image to video data
        to evade detection. This method detects and removes it.

        Args:
            chunk: The raw chunk data that may contain a fake PNG header.

        Returns:
            The chunk with fake PNG wrapper removed, or original chunk if not present.
        """
        if not chunk.startswith(self._PNG_SIGNATURE):
            return chunk

        # Find the IEND marker that signals end of PNG data
        iend_pos = chunk.find(self._PNG_IEND_MARKER)
        if iend_pos == -1:
            # IEND not found in this chunk - return as-is to avoid data corruption
            logger.debug("PNG signature detected but IEND marker not found in chunk")
            return chunk

        # Drop everything up to and including IEND, then any padding bytes
        # (null or 0xFF) between the PNG and the actual content.
        after_png = chunk[iend_pos + len(self._PNG_IEND_MARKER):]
        payload = after_png.lstrip(b"\x00\xff")
        content_start = len(chunk) - len(payload)

        self.bytes_stripped = content_start
        logger.debug(f"Stripped {content_start} bytes of fake PNG wrapper from stream")

        return payload

    def _drain_buffer(self, start: int = 0) -> bytes:
        """Return the buffered bytes from *start*, empty the buffer and switch to pass-through."""
        out = bytes(self.buffer[start:])
        self.buffer.clear()
        self.ts_started = True
        return out

    async def transform(self, chunk_iterator: typing.AsyncIterator[bytes]) -> typing.AsyncGenerator[bytes, None]:
        """
        Transform TS stream by stripping PNG wrapper and finding TS start.

        Data still held in the prebuffer when the upstream ends (for example a
        stream too short to confirm a TS sync) is emitted unchanged rather
        than discarded.

        Args:
            chunk_iterator: Async iterator of raw bytes from upstream.

        Yields:
            Cleaned TS stream bytes.
        """
        async for chunk in chunk_iterator:
            if self.ts_started:
                # Normal streaming once TS has started
                yield chunk
                continue

            # Prebuffer phase (until we find TS or pass through)
            self.buffer += chunk

            # Fast-path: if it's an m3u8 playlist, don't do TS detection
            if self.buffer.startswith(self._PLAYLIST_PREFIXES):
                yield self._drain_buffer()
                continue

            # Strip fake PNG wrapper if present
            if self.buffer.startswith(self._PNG_SIGNATURE):
                if self._PNG_IEND_MARKER in self.buffer:
                    self.buffer = bytearray(self._strip_fake_png_wrapper(bytes(self.buffer)))
                elif len(self.buffer) <= self._MAX_PREFETCH:
                    # The wrapper is still incomplete. Searching for TS sync
                    # inside PNG payload bytes gives false positives, so wait
                    # for IEND (up to the prefetch limit) before looking.
                    continue

            # Skip pure 0xFF padding bytes (TurboVid style) in one slice
            # deletion instead of popping byte by byte.
            padding = len(self.buffer) - len(self.buffer.lstrip(b"\xff"))
            if padding:
                del self.buffer[:padding]

            # Re-check for m3u8 playlist after stripping PNG wrapper and padding
            # This handles cases where m3u8 content is wrapped in PNG
            if self.buffer.startswith(self._PLAYLIST_PREFIXES):
                logger.debug("Found m3u8 content after stripping wrapper - passing through")
                yield self._drain_buffer()
                continue

            ts_offset = self._find_ts_start(self.buffer)
            if ts_offset is None:
                # Keep buffering until we find TS or hit limit
                if len(self.buffer) > self._MAX_PREFETCH:
                    logger.warning("TS sync not found after large prebuffer, forcing passthrough")
                    yield self._drain_buffer()
                continue

            # TS found: emit from ts_offset and switch to pass-through
            out = self._drain_buffer(ts_offset)
            if out:
                yield out

        # Upstream ended while still prebuffering: flush what we have so the
        # data is not silently lost.
        if not self.ts_started and self.buffer:
            yield self._drain_buffer()
|
||||
|
||||
|
||||
# Lookup table from transformer ID to the transformer class registered for it.
TRANSFORMER_REGISTRY: dict[str, type[StreamTransformer]] = {"ts_stream": TSStreamTransformer}
|
||||
|
||||
|
||||
def get_transformer(transformer_id: typing.Optional[str]) -> typing.Optional[StreamTransformer]:
    """
    Build a fresh transformer for the given identifier.

    Args:
        transformer_id: Key into ``TRANSFORMER_REGISTRY`` (e.g., "ts_stream"),
            or None when no transformation is requested.

    Returns:
        A newly constructed transformer, or None when ``transformer_id`` is
        None or has no registry entry (the latter also logs a warning).
    """
    if transformer_id is not None:
        factory = TRANSFORMER_REGISTRY.get(transformer_id)
        if factory is not None:
            # A new instance is created on every call.
            return factory()
        logger.warning(f"Unknown transformer ID: {transformer_id}")

    return None
|
||||
|
||||
|
||||
async def apply_transformer_to_bytes(
    data: bytes,
    transformer_id: typing.Optional[str],
) -> bytes:
    """
    Run an in-memory payload through a transformer and return the result.

    The payload is fed to the transformer as a one-chunk async stream and
    everything the transformer yields is concatenated. Intended for data
    that is already fully downloaded, such as cached segments.

    Args:
        data: The raw bytes data to transform.
        transformer_id: The transformer identifier (e.g., "ts_stream").

    Returns:
        The concatenated transformer output, or ``data`` itself when no
        transformer ID is given or the ID does not resolve to a transformer.
    """
    transformer = get_transformer(transformer_id) if transformer_id else None
    if not transformer:
        return data

    async def _one_shot():
        # Present the whole payload as a single upstream chunk.
        yield data

    pieces = [piece async for piece in transformer.transform(_one_shot())]
    return b"".join(pieces)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,13 +3,21 @@
|
||||
|
||||
"""hashlib that handles FIPS mode."""
|
||||
|
||||
# Because we are extending the hashlib module, we need to import all its
|
||||
# fields to support the same uses
|
||||
# pylint: disable=unused-wildcard-import, wildcard-import
|
||||
from hashlib import *
|
||||
# pylint: enable=unused-wildcard-import, wildcard-import
|
||||
import hashlib
|
||||
|
||||
# Re-export commonly used hash constructors
|
||||
sha1 = hashlib.sha1
|
||||
sha224 = hashlib.sha224
|
||||
sha256 = hashlib.sha256
|
||||
sha384 = hashlib.sha384
|
||||
sha512 = hashlib.sha512
|
||||
sha3_224 = hashlib.sha3_224
|
||||
sha3_256 = hashlib.sha3_256
|
||||
sha3_384 = hashlib.sha3_384
|
||||
sha3_512 = hashlib.sha3_512
|
||||
blake2b = hashlib.blake2b
|
||||
blake2s = hashlib.blake2s
|
||||
|
||||
|
||||
def _fipsFunction(func, *args, **kwargs):
|
||||
"""Make hash function support FIPS mode."""
|
||||
@@ -19,8 +27,6 @@ def _fipsFunction(func, *args, **kwargs):
|
||||
return func(*args, usedforsecurity=False, **kwargs)
|
||||
|
||||
|
||||
# redefining the function is exactly what we intend to do
|
||||
# pylint: disable=function-redefined
|
||||
def md5(*args, **kwargs):
    """Create an MD5 hash object through the FIPS-mode wrapper."""
    md5_factory = hashlib.md5
    return _fipsFunction(md5_factory, *args, **kwargs)
|
||||
@@ -29,4 +35,3 @@ def md5(*args, **kwargs):
|
||||
def new(*args, **kwargs):
    """Create a hash object via ``hashlib.new`` through the FIPS-mode wrapper."""
    generic_factory = hashlib.new
    return _fipsFunction(generic_factory, *args, **kwargs)
|
||||
# pylint: enable=function-redefined
|
||||
|
||||
@@ -11,17 +11,20 @@ Note that this makes this code FIPS non-compliant!
|
||||
# fields to support the same uses
|
||||
from . import tlshashlib
|
||||
from .compat import compatHMAC
|
||||
|
||||
try:
|
||||
from hmac import compare_digest
|
||||
|
||||
__all__ = ["new", "compare_digest", "HMAC"]
|
||||
except ImportError:
|
||||
__all__ = ["new", "HMAC"]
|
||||
|
||||
try:
|
||||
from hmac import HMAC, new
|
||||
|
||||
# if we can calculate HMAC on MD5, then use the built-in HMAC
|
||||
# implementation
|
||||
_val = HMAC(b'some key', b'msg', 'md5')
|
||||
_val = HMAC(b"some key", b"msg", "md5")
|
||||
_val.digest()
|
||||
del _val
|
||||
except Exception:
|
||||
@@ -38,10 +41,10 @@ except Exception:
|
||||
"""
|
||||
self.key = key
|
||||
if digestmod is None:
|
||||
digestmod = 'md5'
|
||||
digestmod = "md5"
|
||||
if callable(digestmod):
|
||||
digestmod = digestmod()
|
||||
if not hasattr(digestmod, 'digest_size'):
|
||||
if not hasattr(digestmod, "digest_size"):
|
||||
digestmod = tlshashlib.new(digestmod)
|
||||
self.block_size = digestmod.block_size
|
||||
self.digest_size = digestmod.digest_size
|
||||
@@ -51,10 +54,10 @@ except Exception:
|
||||
k_hash.update(compatHMAC(key))
|
||||
key = k_hash.digest()
|
||||
if len(key) < self.block_size:
|
||||
key = key + b'\x00' * (self.block_size - len(key))
|
||||
key = key + b"\x00" * (self.block_size - len(key))
|
||||
key = bytearray(key)
|
||||
ipad = bytearray(b'\x36' * self.block_size)
|
||||
opad = bytearray(b'\x5c' * self.block_size)
|
||||
ipad = bytearray(b"\x36" * self.block_size)
|
||||
opad = bytearray(b"\x5c" * self.block_size)
|
||||
i_key = bytearray(i ^ j for i, j in zip(key, ipad))
|
||||
self._o_key = bytearray(i ^ j for i, j in zip(key, opad))
|
||||
self._context = digestmod.copy()
|
||||
@@ -82,7 +85,6 @@ except Exception:
|
||||
new._context = self._context.copy()
|
||||
return new
|
||||
|
||||
|
||||
def new(*args, **kwargs):
    """Build an ``HMAC`` object, forwarding all arguments unchanged."""
    hmac_object = HMAC(*args, **kwargs)
    return hmac_object
|
||||
|
||||
Reference in New Issue
Block a user