This commit is contained in:
UrloMythus
2026-02-19 20:15:03 +01:00
parent 7785e8c604
commit cfc6bbabc9
181 changed files with 32141 additions and 4629 deletions

Binary file not shown.

View File

@@ -0,0 +1,684 @@
"""
Acestream session management with cross-process coordination.
This module provides:
- AcestreamSessionManager: Manages acestream sessions per infohash with cross-process coordination
- AcestreamSession: Represents a single acestream session with playback URLs
- AsyncMultiWriter: Fan-out writer for streaming to multiple clients (MPEG-TS mode)
Architecture:
- Uses Redis for cross-worker coordination and session registry
- Each worker can reuse existing session's playback_url (acestream allows multiple connections)
- Session cleanup via command_url?method=stop when all clients disconnect
"""
import asyncio
import hashlib
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from uuid import uuid4
import aiohttp
from mediaflow_proxy.configs import settings
from mediaflow_proxy.utils import redis_utils
logger = logging.getLogger(__name__)
@dataclass
class AcestreamResponse:
    """Response from acestream's format=json API."""
    # URL used to fetch the actual media (consumed by the proxy/clients).
    playback_url: str
    # URL polled for session statistics; also used as a liveness check.
    stat_url: str
    # URL used to send commands to the engine (e.g. ?method=stop).
    command_url: str
    # Infohash echoed back by the engine (falls back to the requested one).
    infohash: str
    # Engine-assigned identifier for this playback session.
    playback_session_id: str
    # Whether the engine reports the content as a live stream.
    is_live: bool
    # Whether the engine reports the content as encrypted.
    is_encrypted: bool
@dataclass
class AcestreamSession:
    """
    An active acestream session shared by one or more clients.

    The first client request for an infohash creates the session; later
    clients reuse the same playback_url rather than opening a new engine
    session.
    """
    infohash: str
    pid: str
    playback_url: str
    command_url: str
    stat_url: str
    playback_session_id: str
    is_live: bool
    created_at: float = field(default_factory=time.time)
    last_access: float = field(default_factory=time.time)
    last_segment_request: float = field(default_factory=time.time)
    client_count: int = 0

    def touch(self) -> None:
        """Record that the session was just accessed."""
        self.last_access = time.time()

    def touch_segment(self) -> None:
        """Record a segment request (active playback) and refresh access time."""
        self.last_segment_request = self.last_access = time.time()

    def is_actively_streaming(self, timeout: float = 30.0) -> bool:
        """Return True if a segment was requested within the last *timeout* seconds."""
        return time.time() - self.last_segment_request < timeout

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the session for the cross-worker registry."""
        # client_count is intentionally omitted: it is per-worker state.
        return dict(
            infohash=self.infohash,
            pid=self.pid,
            playback_url=self.playback_url,
            command_url=self.command_url,
            stat_url=self.stat_url,
            playback_session_id=self.playback_session_id,
            is_live=self.is_live,
            created_at=self.created_at,
            last_access=self.last_access,
            last_segment_request=self.last_segment_request,
            worker_pid=os.getpid(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AcestreamSession":
        """Rebuild a session from registry data (client_count starts at 0)."""
        optional = {
            key: data.get(key, default)
            for key, default in (
                ("is_live", True),
                ("created_at", time.time()),
                ("last_access", time.time()),
                ("last_segment_request", time.time()),
            )
        }
        return cls(
            infohash=data["infohash"],
            pid=data["pid"],
            playback_url=data["playback_url"],
            command_url=data["command_url"],
            stat_url=data["stat_url"],
            playback_session_id=data["playback_session_id"],
            **optional,
        )
class AsyncMultiWriter:
    """
    Fan-out writer that duplicates a byte stream to multiple clients.

    Mirrors acexy's PMultiWriter, ported to asyncio: each write goes to all
    registered writers concurrently, and writers that error out are dropped
    automatically.
    """

    def __init__(self):
        self._writers: List[asyncio.StreamWriter] = []
        self._lock = asyncio.Lock()

    async def add(self, writer: asyncio.StreamWriter) -> None:
        """Register a writer (no-op if it is already registered)."""
        async with self._lock:
            if writer not in self._writers:
                self._writers.append(writer)
                logger.debug(f"[AsyncMultiWriter] Added writer, total: {len(self._writers)}")

    async def remove(self, writer: asyncio.StreamWriter) -> None:
        """Unregister a writer (no-op if it is not registered)."""
        async with self._lock:
            if writer in self._writers:
                self._writers.remove(writer)
                logger.debug(f"[AsyncMultiWriter] Removed writer, total: {len(self._writers)}")

    async def write(self, data: bytes) -> int:
        """
        Broadcast *data* to every registered writer concurrently.

        Writers whose write fails are removed from the list and closed.

        Returns:
            Number of successful writes.
        """
        if not data:
            return 0

        # Snapshot the writer list under the lock so the fan-out itself can
        # run without holding it.
        async with self._lock:
            if not self._writers:
                return 0
            targets = list(self._writers)

        async def _send(writer: asyncio.StreamWriter) -> bool:
            try:
                writer.write(data)
                await writer.drain()
                return True
            except (ConnectionResetError, BrokenPipeError, ConnectionError) as e:
                logger.debug(f"[AsyncMultiWriter] Writer disconnected: {e}")
                return False
            except Exception as e:
                logger.warning(f"[AsyncMultiWriter] Write error: {e}")
                return False

        # Fan out to all writers in parallel.
        outcomes = await asyncio.gather(
            *(_send(w) for w in targets),
            return_exceptions=True,
        )
        successful = sum(1 for ok in outcomes if ok is True)
        dead = [w for w, ok in zip(targets, outcomes) if ok is not True]

        # Drop and close any writer that failed.
        if dead:
            async with self._lock:
                for writer in dead:
                    if writer in self._writers:
                        self._writers.remove(writer)
                    try:
                        writer.close()
                    except Exception:
                        pass
        return successful

    @property
    def count(self) -> int:
        """Number of connected writers."""
        return len(self._writers)

    async def close_all(self) -> None:
        """Close every writer and clear the registry."""
        async with self._lock:
            for w in self._writers:
                try:
                    w.close()
                    await w.wait_closed()
                except Exception:
                    pass
            self._writers.clear()
class AcestreamSessionManager:
    """
    Manages acestream sessions with cross-process coordination.

    Features:
    - Per-worker session tracking
    - Redis-based session registry for cross-worker visibility
    - Session creation via acestream's format=json API
    - Session cleanup via command_url?method=stop
    - Session keepalive via periodic stat_url polling
    """

    # Redis key prefixes
    REGISTRY_PREFIX = "mfp:acestream:session:"
    REGISTRY_TTL = 3600  # 1 hour

    def __init__(self):
        # Per-worker session tracking (infohash -> session)
        self._sessions: Dict[str, AcestreamSession] = {}
        # Background tasks, started lazily by _ensure_tasks()
        self._keepalive_task: Optional[asyncio.Task] = None
        self._cleanup_task: Optional[asyncio.Task] = None
        # HTTP client session, created lazily by _get_http_session()
        self._http_session: Optional[aiohttp.ClientSession] = None
        logger.info("[AcestreamSessionManager] Initialized with Redis backend")

    def _get_registry_key(self, infohash: str) -> str:
        """Get the Redis key for an infohash."""
        # md5 only normalizes the key (fixed length, safe charset); it is not
        # used for any security purpose here.
        hash_key = hashlib.md5(infohash.encode()).hexdigest()
        return f"{self.REGISTRY_PREFIX}{hash_key}"

    async def _get_http_session(self) -> aiohttp.ClientSession:
        """Get or create HTTP client session."""
        # Recreate if never created or already closed (e.g. after close()).
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
        return self._http_session

    async def _read_registry(self, infohash: str) -> Optional[Dict[str, Any]]:
        """Read session data from Redis registry.

        Best-effort: a Redis failure is logged and treated as "no entry".
        """
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(infohash)
            data = await r.get(key)
            if data:
                return json.loads(data)
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Error reading registry: {e}")
        return None

    async def _write_registry(self, session: AcestreamSession) -> None:
        """Write session data to Redis registry.

        Best-effort: failures are logged, not raised. The entry carries a TTL
        so stale entries expire on their own.
        """
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(session.infohash)
            await r.set(key, json.dumps(session.to_dict()), ex=self.REGISTRY_TTL)
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Error writing registry: {e}")

    async def _delete_registry(self, infohash: str) -> None:
        """Delete session from Redis registry (best-effort)."""
        try:
            r = await redis_utils.get_redis()
            key = self._get_registry_key(infohash)
            await r.delete(key)
        except Exception as e:
            logger.warning(f"[AcestreamSessionManager] Error deleting registry: {e}")

    async def _create_acestream_session(self, infohash: str, content_id: Optional[str] = None) -> AcestreamResponse:
        """
        Create a new acestream session via format=json API.

        Args:
            infohash: The infohash of the content (40-char hex from magnet link)
            content_id: Optional content ID (alternative to infohash)

        Returns:
            AcestreamResponse with playback URLs

        Raises:
            Exception if session creation fails
        """
        base_url = f"http://{settings.acestream_host}:{settings.acestream_port}"
        pid = str(uuid4())
        # Build URL with parameters
        # Acestream uses different parameter names:
        # - 'id' or 'content_id' for content IDs
        # - 'infohash' for magnet link hashes (40-char hex)
        params = {
            "format": "json",
            "pid": pid,
        }
        if content_id:
            # Content ID provided - use 'id' parameter
            params["id"] = content_id
        else:
            # Only infohash provided - use 'infohash' parameter
            params["infohash"] = infohash
        # Use manifest.m3u8 for HLS or getstream for MPEG-TS
        # We'll use manifest.m3u8 as the primary since we leverage HLS infrastructure
        url = f"{base_url}/ace/manifest.m3u8"
        session = await self._get_http_session()
        try:
            async with session.get(url, params=params) as response:
                response.raise_for_status()
                data = await response.json()
                # The engine reports errors inside the JSON body even on 200.
                if data.get("error"):
                    raise Exception(f"Acestream error: {data['error']}")
                resp = data.get("response", {})
                return AcestreamResponse(
                    playback_url=resp.get("playback_url", ""),
                    stat_url=resp.get("stat_url", ""),
                    command_url=resp.get("command_url", ""),
                    infohash=resp.get("infohash", infohash),
                    playback_session_id=resp.get("playback_session_id", ""),
                    is_live=bool(resp.get("is_live", 1)),
                    is_encrypted=bool(resp.get("is_encrypted", 0)),
                )
        except aiohttp.ClientError as e:
            logger.error(f"[AcestreamSessionManager] HTTP error creating session: {e}")
            raise

    async def get_or_create_session(
        self,
        infohash: str,
        content_id: Optional[str] = None,
        increment_client: bool = True,
    ) -> AcestreamSession:
        """
        Get an existing session or create a new one.

        Uses Redis locking to coordinate session creation across workers.

        Args:
            infohash: The infohash of the content
            content_id: Optional content ID
            increment_client: Whether to increment client count (False for manifest requests)

        Returns:
            AcestreamSession instance
        """
        # Check if we already have this session in this worker
        if infohash in self._sessions:
            session = self._sessions[infohash]
            session.touch()
            if increment_client:
                session.client_count += 1
            logger.info(
                f"[AcestreamSessionManager] Reusing existing session: {infohash[:16]}... "
                f"(clients: {session.client_count})"
            )
            return session
        # Need to create or fetch session - use Redis lock
        lock_key = f"acestream_session:{infohash}"
        lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=30)
        if not lock_acquired:
            raise Exception(f"Failed to acquire lock for acestream session: {infohash[:16]}...")
        try:
            # Double-check after acquiring lock: another coroutine in this
            # worker may have created the session while we waited.
            if infohash in self._sessions:
                session = self._sessions[infohash]
                session.touch()
                if increment_client:
                    session.client_count += 1
                return session
            # Check registry for existing session from another worker
            registry_data = await self._read_registry(infohash)
            if registry_data:
                # Validate session is still alive by checking stat_url
                if await self._validate_session(registry_data.get("stat_url", "")):
                    logger.info(f"[AcestreamSessionManager] Using existing session from registry: {infohash[:16]}...")
                    session = AcestreamSession.from_dict(registry_data)
                    session.client_count = 1 if increment_client else 0
                    self._sessions[infohash] = session
                    self._ensure_tasks()
                    return session
                else:
                    # Session is stale, remove from registry
                    await self._delete_registry(infohash)
            # Create new session
            logger.info(f"[AcestreamSessionManager] Creating new session: {infohash[:16]}...")
            try:
                response = await self._create_acestream_session(infohash, content_id)
                session = AcestreamSession(
                    infohash=infohash,
                    pid=str(uuid4()),
                    playback_url=response.playback_url,
                    command_url=response.command_url,
                    stat_url=response.stat_url,
                    playback_session_id=response.playback_session_id,
                    is_live=response.is_live,
                    client_count=1 if increment_client else 0,
                )
                self._sessions[infohash] = session
                await self._write_registry(session)
                self._ensure_tasks()
                logger.info(
                    f"[AcestreamSessionManager] Created session: {infohash[:16]}... "
                    f"playback_url: {response.playback_url}"
                )
                return session
            except Exception as e:
                logger.error(f"[AcestreamSessionManager] Failed to create session: {e}")
                raise
        finally:
            await redis_utils.release_lock(lock_key)

    async def _validate_session(self, stat_url: str) -> bool:
        """Check if a session is still valid by polling stat_url."""
        if not stat_url:
            return False
        try:
            session = await self._get_http_session()
            async with session.get(stat_url, timeout=aiohttp.ClientTimeout(total=5)) as response:
                if response.status == 200:
                    return True
        except Exception as e:
            logger.debug(f"[AcestreamSessionManager] Session validation failed: {e}")
        # Non-200 responses and request errors both count as invalid.
        return False

    async def release_session(self, infohash: str) -> None:
        """
        Release a client's hold on a session.

        Decrements client count. When count reaches 0, the session is closed.

        Args:
            infohash: The infohash of the session to release
        """
        if infohash not in self._sessions:
            return
        session = self._sessions[infohash]
        session.client_count -= 1
        logger.info(
            f"[AcestreamSessionManager] Released client from session: {infohash[:16]}... "
            f"(remaining clients: {session.client_count})"
        )
        if session.client_count <= 0:
            await self._close_session(infohash)

    async def invalidate_session(self, infohash: str) -> None:
        """
        Invalidate a stale session (e.g., when we get 403 from acestream).

        This forces the session to be closed and removed from registry,
        so next request will create a fresh session.

        Args:
            infohash: The infohash of the session to invalidate
        """
        logger.warning(f"[AcestreamSessionManager] Invalidating stale session: {infohash[:16]}...")
        if infohash in self._sessions:
            session = self._sessions.pop(infohash)
            # Try to stop the session gracefully
            if session.command_url:
                try:
                    http_session = await self._get_http_session()
                    url = f"{session.command_url}?method=stop"
                    async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=3)) as response:
                        logger.debug(f"[AcestreamSessionManager] Stop command sent: {response.status}")
                except Exception as e:
                    logger.debug(f"[AcestreamSessionManager] Error stopping stale session: {e}")
        # Always remove from registry, even if this worker never tracked it.
        await self._delete_registry(infohash)
        logger.info(f"[AcestreamSessionManager] Session invalidated: {infohash[:16]}...")

    async def _close_session(self, infohash: str) -> None:
        """
        Close an acestream session.

        Calls command_url?method=stop to properly close the session.
        """
        if infohash not in self._sessions:
            return
        # Remove from local tracking first so concurrent callers bail out above.
        session = self._sessions.pop(infohash)
        lock_key = f"acestream_session:{infohash}"
        lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=10, timeout=10)
        try:
            # Check if this is the last worker using this session
            registry_data = await self._read_registry(infohash)
            # Only close if we're the owner or session is stale
            if registry_data and registry_data.get("worker_pid") == os.getpid():
                # We're the owner, close the session
                if session.command_url:
                    try:
                        http_session = await self._get_http_session()
                        url = f"{session.command_url}?method=stop"
                        async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as response:
                            logger.info(
                                f"[AcestreamSessionManager] Closed session: {infohash[:16]}... "
                                f"(status: {response.status})"
                            )
                    except Exception as e:
                        logger.warning(f"[AcestreamSessionManager] Error closing session: {e}")
                await self._delete_registry(infohash)
            else:
                # Another worker may still be using this session
                # NOTE(review): a missing/expired registry entry also lands in
                # this branch; in that case the Redis TTL handles cleanup.
                logger.debug(
                    f"[AcestreamSessionManager] Session {infohash[:16]}... owned by another worker, not closing"
                )
        finally:
            if lock_acquired:
                await redis_utils.release_lock(lock_key)

    def _ensure_tasks(self) -> None:
        """Ensure background tasks are running.

        Called whenever a session is added; restarts a loop if it has exited.
        """
        if self._keepalive_task is None or self._keepalive_task.done():
            self._keepalive_task = asyncio.create_task(self._keepalive_loop())
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _keepalive_loop(self) -> None:
        """Periodically poll stat_url to keep sessions alive with active clients or recent segment activity."""
        while True:
            try:
                await asyncio.sleep(settings.acestream_keepalive_interval)
                for infohash, session in list(self._sessions.items()):
                    # Keepalive sessions with active clients OR recent segment activity
                    # This ensures HLS streams (which don't use client_count) stay alive
                    has_recent_activity = session.is_actively_streaming(timeout=settings.acestream_empty_timeout)
                    if session.client_count <= 0 and not has_recent_activity:
                        logger.debug(
                            f"[AcestreamSessionManager] Skipping keepalive (no clients, no recent segments): "
                            f"{infohash[:16]}..."
                        )
                        continue
                    if session.stat_url:
                        try:
                            http_session = await self._get_http_session()
                            async with http_session.get(
                                session.stat_url,
                                timeout=aiohttp.ClientTimeout(total=5),
                            ) as response:
                                if response.status == 200:
                                    session.touch()
                                    # Refresh the registry entry (and its TTL)
                                    # while the session is confirmed alive.
                                    await self._write_registry(session)
                                    logger.debug(
                                        f"[AcestreamSessionManager] Keepalive OK: {infohash[:16]}... "
                                        f"(clients: {session.client_count}, recent_activity: {has_recent_activity})"
                                    )
                                else:
                                    logger.warning(
                                        f"[AcestreamSessionManager] Keepalive failed: {infohash[:16]}... "
                                        f"(status: {response.status})"
                                    )
                        except Exception as e:
                            logger.warning(f"[AcestreamSessionManager] Keepalive error: {infohash[:16]}... - {e}")
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"[AcestreamSessionManager] Keepalive loop error: {e}")

    async def _cleanup_loop(self) -> None:
        """Periodically clean up stale sessions."""
        while True:
            try:
                await asyncio.sleep(15)  # Check every 15 seconds
                now = time.time()
                timeout = settings.acestream_session_timeout
                empty_timeout = settings.acestream_empty_timeout
                for infohash, session in list(self._sessions.items()):
                    idle_time = now - session.last_access
                    segment_idle_time = now - session.last_segment_request
                    # Don't clean up sessions with recent segment activity (active playback)
                    # Use empty_timeout as the threshold for "recent" activity
                    if segment_idle_time < empty_timeout:
                        logger.debug(
                            f"[AcestreamSessionManager] Session has recent segment activity: {infohash[:16]}... "
                            f"(segment idle: {segment_idle_time:.0f}s)"
                        )
                        continue
                    # Clean up sessions with no clients after empty_timeout (faster cleanup)
                    if session.client_count <= 0 and idle_time > empty_timeout:
                        logger.info(
                            f"[AcestreamSessionManager] Cleaning up empty session: {infohash[:16]}... "
                            f"(idle: {idle_time:.0f}s, segment idle: {segment_idle_time:.0f}s)"
                        )
                        await self._close_session(infohash)
                    # Clean up any session after session_timeout regardless of client count
                    elif idle_time > timeout:
                        logger.info(
                            f"[AcestreamSessionManager] Cleaning up stale session: {infohash[:16]}... "
                            f"(idle: {idle_time:.0f}s, segment idle: {segment_idle_time:.0f}s, clients: {session.client_count})"
                        )
                        await self._close_session(infohash)
                # Note: Redis entries expire via TTL, no manual cleanup needed
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"[AcestreamSessionManager] Cleanup loop error: {e}")

    def get_session(self, infohash: str) -> Optional[AcestreamSession]:
        """Get a session by infohash if it exists in this worker."""
        return self._sessions.get(infohash)

    def get_active_sessions(self) -> Dict[str, AcestreamSession]:
        """Get all active sessions in this worker (shallow copy)."""
        return dict(self._sessions)

    async def close(self) -> None:
        """Close the session manager and clean up resources."""
        # Cancel background tasks
        if self._keepalive_task:
            self._keepalive_task.cancel()
            try:
                await self._keepalive_task
            except asyncio.CancelledError:
                pass
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
        # Close all sessions
        for infohash in list(self._sessions.keys()):
            await self._close_session(infohash)
        # Close HTTP session
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        logger.info("[AcestreamSessionManager] Closed")
# Global session manager instance
# One manager per worker process; importers share this instance so that
# per-worker session tracking happens in a single place.
acestream_manager = AcestreamSessionManager()

View File

@@ -3,6 +3,7 @@
"""Abstract class for AES."""
class AES(object):
def __init__(self, key, mode, IV, implementation):
if len(key) not in (16, 24, 32):
@@ -19,21 +20,21 @@ class AES(object):
self.isAEAD = False
self.block_size = 16
self.implementation = implementation
if len(key)==16:
if len(key) == 16:
self.name = "aes128"
elif len(key)==24:
elif len(key) == 24:
self.name = "aes192"
elif len(key)==32:
elif len(key) == 32:
self.name = "aes256"
else:
raise AssertionError()
#CBC-Mode encryption, returns ciphertext
#WARNING: *MAY* modify the input as well
# CBC-Mode encryption, returns ciphertext
# WARNING: *MAY* modify the input as well
def encrypt(self, plaintext):
assert(len(plaintext) % 16 == 0)
assert len(plaintext) % 16 == 0
#CBC-Mode decryption, returns plaintext
#WARNING: *MAY* modify the input as well
# CBC-Mode decryption, returns plaintext
# WARNING: *MAY* modify the input as well
def decrypt(self, ciphertext):
assert(len(ciphertext) % 16 == 0)
assert len(ciphertext) % 16 == 0

View File

@@ -18,6 +18,7 @@ from . import python_aes
from .constanttime import ct_compare_digest
from .cryptomath import bytesToNumber, numberToByteArray
class AESGCM(object):
"""
AES-GCM implementation. Note: this implementation does not attempt
@@ -39,7 +40,7 @@ class AESGCM(object):
self.key = key
self._rawAesEncrypt = rawAesEncrypt
self._ctr = python_aes.new(self.key, 6, bytearray(b'\x00' * 16))
self._ctr = python_aes.new(self.key, 6, bytearray(b"\x00" * 16))
# The GCM key is AES(0).
h = bytesToNumber(self._rawAesEncrypt(bytearray(16)))
@@ -51,11 +52,8 @@ class AESGCM(object):
self._productTable = [0] * 16
self._productTable[self._reverseBits(1)] = h
for i in range(2, 16, 2):
self._productTable[self._reverseBits(i)] = \
self._gcmShift(self._productTable[self._reverseBits(i//2)])
self._productTable[self._reverseBits(i+1)] = \
self._gcmAdd(self._productTable[self._reverseBits(i)], h)
self._productTable[self._reverseBits(i)] = self._gcmShift(self._productTable[self._reverseBits(i // 2)])
self._productTable[self._reverseBits(i + 1)] = self._gcmAdd(self._productTable[self._reverseBits(i)], h)
def _auth(self, ciphertext, ad, tagMask):
y = 0
@@ -68,7 +66,7 @@ class AESGCM(object):
def _update(self, y, data):
for i in range(0, len(data) // 16):
y ^= bytesToNumber(data[16*i:16*i+16])
y ^= bytesToNumber(data[16 * i : 16 * i + 16])
y = self._mul(y)
extra = len(data) % 16
if extra != 0:
@@ -79,26 +77,26 @@ class AESGCM(object):
return y
def _mul(self, y):
""" Returns y*H, where H is the GCM key. """
"""Returns y*H, where H is the GCM key."""
ret = 0
# Multiply H by y 4 bits at a time, starting with the highest power
# terms.
for i in range(0, 128, 4):
# Multiply by x^4. The reduction for the top four terms is
# precomputed.
retHigh = ret & 0xf
retHigh = ret & 0xF
ret >>= 4
ret ^= (AESGCM._gcmReductionTable[retHigh] << (128-16))
ret ^= AESGCM._gcmReductionTable[retHigh] << (128 - 16)
# Add in y' * H where y' are the next four terms of y, shifted down
# to the x^0..x^4. This is one of the pre-computed multiples of
# H. The multiplication by x^4 shifts them back into place.
ret ^= self._productTable[y & 0xf]
ret ^= self._productTable[y & 0xF]
y >>= 4
assert y == 0
return ret
def seal(self, nonce, plaintext, data=''):
def seal(self, nonce, plaintext, data=""):
"""
Encrypts and authenticates plaintext using nonce and data. Returns the
ciphertext, consisting of the encrypted plaintext and tag concatenated.
@@ -123,7 +121,7 @@ class AESGCM(object):
return ciphertext + tag
def open(self, nonce, ciphertext, data=''):
def open(self, nonce, ciphertext, data=""):
"""
Decrypts and authenticates ciphertext using nonce and data. If the
tag is valid, the plaintext is returned. If the tag is invalid,
@@ -156,8 +154,8 @@ class AESGCM(object):
@staticmethod
def _reverseBits(i):
assert i < 16
i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
i = ((i << 2) & 0xC) | ((i >> 2) & 0x3)
i = ((i << 1) & 0xA) | ((i >> 1) & 0x5)
return i
@staticmethod
@@ -173,12 +171,12 @@ class AESGCM(object):
# The x^127 term was shifted up to x^128, so subtract a 1+x+x^2+x^7
# term. This is 0b11100001 or 0xe1 when represented as an 8-bit
# polynomial.
x ^= 0xe1 << (128-8)
x ^= 0xE1 << (128 - 8)
return x
@staticmethod
def _inc32(counter):
for i in range(len(counter)-1, len(counter)-5, -1):
for i in range(len(counter) - 1, len(counter) - 5, -1):
counter[i] = (counter[i] + 1) % 256
if counter[i] != 0:
break
@@ -188,6 +186,20 @@ class AESGCM(object):
# result is stored as a 16-bit polynomial. This is used in the reduction step to
# multiply elements of GF(2^128) by x^4.
_gcmReductionTable = [
0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
0x0000,
0x1C20,
0x3840,
0x2460,
0x7080,
0x6CA0,
0x48C0,
0x54E0,
0xE100,
0xFD20,
0xD940,
0xC560,
0x9180,
0x8DA0,
0xA9C0,
0xB5E0,
]

View File

@@ -9,56 +9,56 @@ logger = logging.getLogger(__name__)
def is_base64_url(url: str) -> bool:
"""
Check if a URL appears to be base64 encoded.
Args:
url (str): The URL to check.
Returns:
bool: True if the URL appears to be base64 encoded, False otherwise.
"""
# Check if the URL doesn't start with http/https and contains base64-like characters
if url.startswith(('http://', 'https://', 'ftp://', 'ftps://')):
if url.startswith(("http://", "https://", "ftp://", "ftps://")):
return False
# Base64 URLs typically contain only alphanumeric characters, +, /, and =
# and don't contain typical URL characters like ://
base64_chars = set('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=')
# and don't contain typical URL characters like ://
base64_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")
url_chars = set(url)
# If the URL contains characters not in base64 charset, it's likely not base64
if not url_chars.issubset(base64_chars):
return False
# Additional heuristic: base64 strings are typically longer and don't contain common URL patterns
if len(url) < 10: # Too short to be a meaningful base64 encoded URL
return False
return True
def decode_base64_url(encoded_url: str) -> Optional[str]:
"""
Decode a base64 encoded URL.
Args:
encoded_url (str): The base64 encoded URL string.
Returns:
Optional[str]: The decoded URL if successful, None if decoding fails.
"""
try:
# Handle URL-safe base64 encoding (replace - with + and _ with /)
url_safe_encoded = encoded_url.replace('-', '+').replace('_', '/')
url_safe_encoded = encoded_url.replace("-", "+").replace("_", "/")
# Add padding if necessary
missing_padding = len(url_safe_encoded) % 4
if missing_padding:
url_safe_encoded += '=' * (4 - missing_padding)
url_safe_encoded += "=" * (4 - missing_padding)
# Decode the base64 string
decoded_bytes = base64.b64decode(url_safe_encoded)
decoded_url = decoded_bytes.decode('utf-8')
decoded_url = decoded_bytes.decode("utf-8")
# Validate that the decoded string is a valid URL
parsed = urlparse(decoded_url)
if parsed.scheme and parsed.netloc:
@@ -67,7 +67,7 @@ def decode_base64_url(encoded_url: str) -> Optional[str]:
else:
logger.warning(f"Decoded string is not a valid URL: {decoded_url}")
return None
except (base64.binascii.Error, UnicodeDecodeError, ValueError) as e:
logger.debug(f"Failed to decode base64 URL '{encoded_url[:50]}...': {e}")
return None
@@ -76,27 +76,27 @@ def decode_base64_url(encoded_url: str) -> Optional[str]:
def encode_url_to_base64(url: str, url_safe: bool = True) -> str:
"""
Encode a URL to base64.
Args:
url (str): The URL to encode.
url_safe (bool): Whether to use URL-safe base64 encoding (default: True).
Returns:
str: The base64 encoded URL.
"""
try:
url_bytes = url.encode('utf-8')
url_bytes = url.encode("utf-8")
if url_safe:
# Use URL-safe base64 encoding (replace + with - and / with _)
encoded = base64.urlsafe_b64encode(url_bytes).decode('utf-8')
encoded = base64.urlsafe_b64encode(url_bytes).decode("utf-8")
# Remove padding for cleaner URLs
encoded = encoded.rstrip('=')
encoded = encoded.rstrip("=")
else:
encoded = base64.b64encode(url_bytes).decode('utf-8')
encoded = base64.b64encode(url_bytes).decode("utf-8")
logger.debug(f"Encoded URL to base64: {url} -> {encoded}")
return encoded
except Exception as e:
logger.error(f"Failed to encode URL to base64: {e}")
raise
@@ -106,10 +106,10 @@ def process_potential_base64_url(url: str) -> str:
"""
Process a URL that might be base64 encoded. If it's base64 encoded, decode it.
Otherwise, return the original URL.
Args:
url (str): The URL to process.
Returns:
str: The processed URL (decoded if it was base64, original otherwise).
"""
@@ -119,5 +119,5 @@ def process_potential_base64_url(url: str) -> str:
return decoded_url
else:
logger.warning(f"URL appears to be base64 but failed to decode: {url[:50]}...")
return url
return url

View File

@@ -0,0 +1,367 @@
"""
Base prebuffer class with shared functionality for HLS and DASH prebuffering.
This module provides cross-process download coordination using Redis-based locking
to prevent duplicate downloads across multiple uvicorn workers. Both player requests
and background prebuffer tasks use the same coordination mechanism.
"""
import asyncio
import logging
import time
import psutil
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional
from mediaflow_proxy.utils.cache_utils import (
get_cached_segment,
set_cached_segment,
)
from mediaflow_proxy.utils.http_utils import download_file_with_retry
from mediaflow_proxy.utils import redis_utils
logger = logging.getLogger(__name__)
@dataclass
class PrebufferStats:
    """Counters tracking prebuffer cache performance."""
    cache_hits: int = 0
    cache_misses: int = 0
    segments_prebuffered: int = 0
    bytes_prebuffered: int = 0
    prefetch_triggered: int = 0
    downloads_coordinated: int = 0  # Times we waited for existing download
    last_reset: float = field(default_factory=time.time)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate as a percentage (0.0 when nothing was recorded)."""
        lookups = self.cache_hits + self.cache_misses
        if lookups <= 0:
            return 0.0
        return self.cache_hits / lookups * 100

    def reset(self) -> None:
        """Zero every counter and restart the uptime clock."""
        self.cache_hits = self.cache_misses = 0
        self.segments_prebuffered = self.bytes_prebuffered = 0
        self.prefetch_triggered = self.downloads_coordinated = 0
        self.last_reset = time.time()

    def to_dict(self) -> dict:
        """Render the counters as a log-friendly dictionary."""
        megabytes = self.bytes_prebuffered / 1024 / 1024
        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": f"{self.hit_rate:.1f}%",
            "segments_prebuffered": self.segments_prebuffered,
            "bytes_prebuffered_mb": f"{megabytes:.2f}",
            "prefetch_triggered": self.prefetch_triggered,
            "downloads_coordinated": self.downloads_coordinated,
            "uptime_seconds": int(time.time() - self.last_reset),
        }
class BasePrebuffer(ABC):
    """
    Base class for prebuffer systems with cross-process download coordination.

    This class provides:
    - Cross-process coordination using Redis locks to prevent duplicate downloads
    - Memory usage monitoring (via psutil)
    - Cache statistics tracking (per-worker, see ``self.stats``)
    - Shared download and caching logic

    The Redis-based locking ensures that even with multiple uvicorn workers,
    only one worker downloads any given segment at a time.

    Subclasses should implement protocol-specific logic (HLS playlist parsing,
    DASH MPD handling, etc.) while inheriting the core download coordination.
    """

    def __init__(
        self,
        max_cache_size: int,
        prebuffer_segments: int,
        max_memory_percent: float,
        emergency_threshold: float,
        segment_ttl: int = 60,
    ):
        """
        Initialize the base prebuffer.

        Args:
            max_cache_size: Maximum number of segments to track
            prebuffer_segments: Number of segments to pre-buffer ahead
            max_memory_percent: Maximum memory usage percentage (0-100) before
                prebuffering is skipped
            emergency_threshold: Memory usage percentage that triggers the
                emergency condition reported by _check_memory_threshold()
            segment_ttl: TTL for cached segments in seconds
        """
        self.max_cache_size = max_cache_size
        self.prebuffer_segment_count = prebuffer_segments
        self.max_memory_percent = max_memory_percent
        self.emergency_threshold = emergency_threshold
        self.segment_ttl = segment_ttl
        # Statistics (per-worker, not shared - but that's fine for monitoring)
        self.stats = PrebufferStats()
        # Stats logging task (lazily started by _ensure_stats_logging)
        self._stats_task: Optional[asyncio.Task] = None
        self._stats_interval = 60  # Log stats every 60 seconds

    def _get_memory_usage_percent(self) -> float:
        """Return current system memory usage as a percentage (0-100).

        Fails open: returns 0.0 if psutil cannot report, so that a broken
        monitor never blocks prebuffering.
        """
        try:
            memory = psutil.virtual_memory()
            return memory.percent
        except Exception as e:
            logger.warning(f"Failed to get memory usage: {e}")
            return 0.0

    def _check_memory_threshold(self) -> bool:
        """Check if memory usage exceeds the emergency threshold."""
        return self._get_memory_usage_percent() > self.emergency_threshold

    def _should_skip_for_memory(self) -> bool:
        """Check if we should skip prebuffering due to high memory usage."""
        return self._get_memory_usage_percent() > self.max_memory_percent

    def record_cache_hit(self) -> None:
        """Record a cache hit for statistics."""
        self.stats.cache_hits += 1
        self._ensure_stats_logging()

    def record_cache_miss(self) -> None:
        """Record a cache miss for statistics."""
        self.stats.cache_misses += 1
        self._ensure_stats_logging()

    def _ensure_stats_logging(self) -> None:
        """Ensure the periodic stats-logging task is running.

        NOTE(review): asyncio.create_task requires a running event loop;
        callers reach this from async request handling, where one is active —
        confirm no sync/startup code path calls record_cache_hit/miss.
        """
        if self._stats_task is None or self._stats_task.done():
            self._stats_task = asyncio.create_task(self._periodic_stats_logging())

    async def _periodic_stats_logging(self) -> None:
        """Periodically log prebuffer statistics (every _stats_interval seconds)."""
        while True:
            try:
                await asyncio.sleep(self._stats_interval)
                # Only log if there's been activity
                if self.stats.cache_hits > 0 or self.stats.cache_misses > 0:
                    self.log_stats()
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"Error in stats logging: {e}")

    async def get_or_download(
        self,
        url: str,
        headers: Dict[str, str],
        timeout: float = 10.0,
    ) -> Optional[bytes]:
        """
        Get a segment from cache or download it, with cross-process coordination.

        This is the primary method for getting segments. It:
        1. Checks cache first (immediate return if hit)
        2. Acquires a Redis lock to prevent duplicate downloads across workers
        3. Double-checks cache after acquiring the lock
        4. Downloads and caches if needed

        The Redis-based locking ensures that even with multiple uvicorn workers,
        only one worker downloads any given segment at a time.

        Args:
            url: URL of the segment to get
            headers: Headers to use for the request
            timeout: Maximum time to wait for lock acquisition (seconds).
                Keep this short (10s) for player requests - if the lock is held
                too long the caller falls back to direct streaming.

        Returns:
            Segment data if successful, None if failed or timed out
        """
        self._ensure_stats_logging()
        # Check cache first (Redis cache is shared across workers)
        cached = await get_cached_segment(url)
        if cached:
            self.record_cache_hit()
            logger.info(f"[get_or_download] CACHE HIT ({len(cached)} bytes): {url}")
            return cached
        # Cache miss - need to coordinate download across workers
        logger.info(f"[get_or_download] CACHE MISS: {url}")
        lock_key = f"segment_download:{url}"
        lock_acquired = False
        try:
            # Acquire Redis lock - only one worker downloads at a time
            lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=timeout)
            if not lock_acquired:
                logger.warning(f"[get_or_download] Lock TIMEOUT ({timeout}s), falling back to streaming: {url}")
                return None
            # Double-check cache after acquiring lock
            # Another worker may have completed the download while we waited
            cached = await get_cached_segment(url)
            if cached:
                # Count this as a cache hit since we didn't download
                self.record_cache_hit()
                self.stats.downloads_coordinated += 1
                logger.info(f"[get_or_download] Found in cache after lock (coordinated): {url}")
                return cached
            # We're the one who needs to download - count as miss now
            self.record_cache_miss()
            # We're the first - download and cache
            logger.info(f"[get_or_download] Downloading: {url}")
            content = await self._download_and_cache(url, headers)
            return content
        except Exception as e:
            logger.warning(f"[get_or_download] Error during download coordination: {e}")
            return None
        finally:
            # Always release the lock we acquired, even on error/early return
            if lock_acquired:
                await redis_utils.release_lock(lock_key)

    async def _download_and_cache(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> Optional[bytes]:
        """
        Download a segment and cache it.

        This method should only be called while holding the Redis lock.

        Args:
            url: URL to download
            headers: Headers for the request

        Returns:
            Downloaded content if successful, None otherwise
        """
        try:
            content = await download_file_with_retry(url, headers)
            if content:
                logger.info(f"[_download_and_cache] Downloaded {len(content)} bytes, caching: {url}")
                await set_cached_segment(url, content, ttl=self.segment_ttl)
                self.stats.segments_prebuffered += 1
                self.stats.bytes_prebuffered += len(content)
                return content
            else:
                logger.warning(f"[_download_and_cache] Download returned empty: {url}")
                return None
        except Exception as e:
            logger.warning(f"[_download_and_cache] Failed to download: {url} - {e}")
            return None

    async def try_get_cached(self, url: str) -> Optional[bytes]:
        """
        Check cache only, don't download.

        Use this for background prebuffer tasks that shouldn't block
        if a segment isn't available yet.

        Args:
            url: URL to check in cache

        Returns:
            Cached data if available, None otherwise
        """
        return await get_cached_segment(url)

    async def prebuffer_segment(self, url: str, headers: Dict[str, str]) -> None:
        """
        Prebuffer a single segment in the background.

        This method uses Redis locking to prevent duplicate downloads
        across multiple workers. Unlike get_or_download, it gives up quickly
        (1s lock timeout) because it is best-effort background work.

        Args:
            url: URL of segment to prebuffer
            headers: Headers for the request
        """
        if self._should_skip_for_memory():
            logger.debug("Skipping prebuffer due to high memory usage")
            return
        # Check if already cached
        cached = await get_cached_segment(url)
        if cached:
            logger.debug(f"[prebuffer_segment] Already cached, skipping: {url}")
            return
        lock_key = f"segment_download:{url}"
        lock_acquired = False
        try:
            # Try to acquire lock with short timeout for prebuffering
            # If lock is held by another process, skip this segment
            lock_acquired = await redis_utils.acquire_lock(lock_key, ttl=30, timeout=1.0)
            if not lock_acquired:
                # Another process is downloading, skip this segment
                logger.debug(f"[prebuffer_segment] Lock busy, skipping: {url}")
                return
            # Double-check cache after acquiring lock
            cached = await get_cached_segment(url)
            if cached:
                logger.debug(f"[prebuffer_segment] Found in cache after lock: {url}")
                return
            # Download and cache
            logger.info(f"[prebuffer_segment] Downloading: {url}")
            await self._download_and_cache(url, headers)
        except Exception as e:
            logger.warning(f"[prebuffer_segment] Error: {e}")
        finally:
            if lock_acquired:
                await redis_utils.release_lock(lock_key)

    async def prebuffer_segments_batch(
        self,
        urls: list,
        headers: Dict[str, str],
        max_concurrent: int = 2,
    ) -> None:
        """
        Prebuffer multiple segments with concurrency control.

        Args:
            urls: List of segment URLs to prebuffer
            headers: Headers for requests
            max_concurrent: Maximum concurrent downloads (default 2 to avoid
                lock contention with player requests)
        """
        if self._should_skip_for_memory():
            logger.warning("Skipping prebuffer due to high memory usage")
            return
        semaphore = asyncio.Semaphore(max_concurrent)

        async def limited_prebuffer(url: str):
            # Bound concurrency; each segment still does its own Redis locking.
            async with semaphore:
                await self.prebuffer_segment(url, headers)

        # Start all prebuffer tasks; exceptions are collected, not raised
        tasks = [limited_prebuffer(url) for url in urls]
        await asyncio.gather(*tasks, return_exceptions=True)

    def log_stats(self) -> None:
        """Log current prebuffer statistics."""
        logger.info(f"Prebuffer Stats: {self.stats.to_dict()}")

View File

@@ -1,308 +1,31 @@
import asyncio
import hashlib
import json
import logging
import os
import tempfile
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Any
"""
Cache utilities for mediaflow-proxy.
import aiofiles
import aiofiles.os
All caching is now done via Redis for cross-worker sharing.
See redis_utils.py for the underlying Redis operations.
"""
import logging
from typing import Optional
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
from mediaflow_proxy.utils import redis_utils
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """Represents a cache entry with metadata."""
    # Raw cached payload.
    data: bytes
    # Absolute expiry time (epoch seconds); stale once time.time() >= expires_at.
    expires_at: float
    # Number of times the entry has been read (incremented on each cache hit).
    access_count: int = 0
    # Epoch seconds of the most recent read; 0.0 until first access.
    last_access: float = 0.0
    # Size of `data` in bytes, used for the LRU cache's byte-budget accounting.
    size: int = 0
# =============================================================================
# Init Segment Cache
# =============================================================================
class LRUMemoryCache:
    """Thread-safe LRU memory cache bounded by total payload size in bytes."""

    def __init__(self, maxsize: int):
        # maxsize is a byte budget (sum of CacheEntry.size), not an entry count.
        self.maxsize = maxsize
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.Lock()
        self._current_size = 0

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the live entry for key (bumping its LRU position), or None."""
        with self._lock:
            if key in self._cache:
                entry = self._cache.pop(key)  # Remove and re-insert for LRU
                if time.time() < entry.expires_at:
                    entry.access_count += 1
                    entry.last_access = time.time()
                    self._cache[key] = entry
                    return entry
                else:
                    # Remove expired entry
                    self._current_size -= entry.size
                    # NOTE(review): the entry was already popped above; this
                    # second pop is a no-op kept for safety.
                    self._cache.pop(key, None)
            return None

    def set(self, key: str, entry: CacheEntry) -> None:
        """Insert/replace an entry, evicting least-recently-used items to fit."""
        with self._lock:
            if key in self._cache:
                old_entry = self._cache[key]
                # Replacing: release the old entry's bytes from the budget.
                self._current_size -= old_entry.size
            # Check if we need to make space
            # NOTE(review): the old copy of `key` is still in the dict here; if
            # the loop evicts it, its size is subtracted a second time,
            # under-counting _current_size.
            while self._current_size + entry.size > self.maxsize and self._cache:
                _, removed_entry = self._cache.popitem(last=False)
                self._current_size -= removed_entry.size
            self._cache[key] = entry
            self._current_size += entry.size

    def remove(self, key: str) -> None:
        """Drop key if present, releasing its bytes from the budget."""
        with self._lock:
            if key in self._cache:
                entry = self._cache.pop(key)
                self._current_size -= entry.size
class HybridCache:
    """High-performance hybrid cache combining memory and file storage.

    Reads try the in-memory LRU first and fall back to a per-entry file in a
    temp directory; file hits are promoted back into memory. Keys are stored
    under their MD5 hex digest both in memory and on disk. File entries are
    laid out as: 8-byte big-endian metadata length, JSON metadata, raw data.
    """

    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        executor_workers: int = 4,
    ):
        """
        Args:
            cache_dir_name: Subdirectory name under the system temp dir.
            ttl: Default time-to-live for entries, in seconds.
            max_memory_size: Byte budget for the in-memory LRU layer.
            executor_workers: Thread pool size (reserved for blocking work).
        """
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()
        # Initialize cache directories
        self._init_cache_dirs()

    def _init_cache_dirs(self):
        """Create the on-disk cache directory if it does not exist."""
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_md5_hash(self, key: str) -> str:
        """Get the MD5 hash of a cache key (used as memory key and file name)."""
        return hashlib.md5(key.encode()).hexdigest()

    def _get_file_path(self, key: str) -> Path:
        """Get the file path for an already-hashed cache key."""
        return self.cache_dir / key

    async def _delete_hashed(self, hashed_key: str) -> bool:
        """Delete an entry whose key has ALREADY been MD5-hashed.

        Internal helper so code paths that hold a hashed key (e.g. expiry
        handling inside get()) do not hash it a second time.
        """
        self.memory_cache.remove(hashed_key)
        try:
            file_path = self._get_file_path(hashed_key)
            await aiofiles.os.remove(file_path)
            return True
        except FileNotFoundError:
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        Args:
            key: Cache key (un-hashed)
            default: Default value if key not found or expired

        Returns:
            Cached value or default if not found
        """
        key = self._get_md5_hash(key)
        # Try memory cache first
        entry = self.memory_cache.get(key)
        if entry is not None:
            return entry.data
        # Try file cache
        try:
            file_path = self._get_file_path(key)
            async with aiofiles.open(file_path, "rb") as f:
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())
                # Check expiration
                if metadata["expires_at"] < time.time():
                    # BUG FIX: `key` is already hashed here. Calling the public
                    # delete() would hash it a second time and remove nothing,
                    # leaking expired files forever.
                    await self._delete_hashed(key)
                    return default
                # Read data
                data = await f.read()
            # Promote to the memory cache for subsequent fast hits
            entry = CacheEntry(
                data=data,
                expires_at=metadata["expires_at"],
                access_count=metadata["access_count"] + 1,
                last_access=time.time(),
                size=len(data),
            )
            self.memory_cache.set(key, entry)
            return data
        except FileNotFoundError:
            return default
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            return default

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Set value in both memory and file cache.

        Args:
            key: Cache key (un-hashed)
            data: Data to cache
            ttl: Optional TTL override; a value <= 0 removes any existing
                entry instead of caching.

        Returns:
            bool: Success status
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise ValueError("Data must be bytes, bytearray, or memoryview")
        ttl_seconds = self.ttl if ttl is None else ttl
        key = self._get_md5_hash(key)
        if ttl_seconds <= 0:
            # Explicit request to avoid caching - remove any previous entry and return success
            self.memory_cache.remove(key)
            try:
                file_path = self._get_file_path(key)
                await aiofiles.os.remove(file_path)
            except FileNotFoundError:
                pass
            except Exception as e:
                logger.error(f"Error removing cache file: {e}")
            return True
        expires_at = time.time() + ttl_seconds
        # Create cache entry
        entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
        # Update memory cache
        self.memory_cache.set(key, entry)
        file_path = self._get_file_path(key)
        temp_path = file_path.with_suffix(".tmp")
        # Update file cache: write to a temp file then atomically rename so
        # concurrent readers never observe a partially-written entry.
        try:
            metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
            metadata_bytes = json.dumps(metadata).encode()
            metadata_size = len(metadata_bytes).to_bytes(8, "big")
            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(metadata_size)
                await f.write(metadata_bytes)
                await f.write(data)
            await aiofiles.os.rename(temp_path, file_path)
            return True
        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            try:
                await aiofiles.os.remove(temp_path)
            except Exception:
                # Best-effort temp-file cleanup (was a bare `except:`).
                pass
            return False

    async def delete(self, key: str) -> bool:
        """Delete item from both caches (key given in un-hashed form)."""
        return await self._delete_hashed(self._get_md5_hash(key))
class AsyncMemoryCache:
    """Asynchronous facade over LRUMemoryCache.

    The underlying store is synchronous and thread-safe; these coroutines
    simply adapt it to the async cache interface used by the rest of the
    module.
    """

    def __init__(self, max_memory_size: int):
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """Return the cached bytes for key, or default on a miss."""
        hit = self.memory_cache.get(key)
        if hit is None:
            return default
        return hit.data

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """Store data under key; a non-positive ttl evicts instead of caching."""
        try:
            effective_ttl = 3600 if ttl is None else ttl
            if effective_ttl <= 0:
                self.memory_cache.remove(key)
                return True
            new_entry = CacheEntry(
                data=data, expires_at=time.time() + effective_ttl, access_count=0, last_access=time.time(), size=len(data)
            )
            self.memory_cache.set(key, new_entry)
            return True
        except Exception as e:
            logger.error(f"Error setting cache value: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Remove key from the cache; False only on an unexpected error."""
        try:
            self.memory_cache.remove(key)
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
# Create cache instances
# Module-level singletons shared by the helper functions below.
INIT_SEGMENT_CACHE = HybridCache(
    cache_dir_name="init_segment_cache",
    ttl=3600,  # 1 hour
    max_memory_size=500 * 1024 * 1024,  # 500MB for init segments
)
# MPD manifests are small and refresh often, so memory-only is enough.
MPD_CACHE = AsyncMemoryCache(
    max_memory_size=100 * 1024 * 1024,  # 100MB for MPD files
)
EXTRACTOR_CACHE = HybridCache(
    cache_dir_name="extractor_cache",
    ttl=5 * 60,  # 5 minutes
    max_memory_size=50 * 1024 * 1024,
)
# Specific cache implementations
async def get_cached_init_segment(
init_url: str,
headers: dict,
cache_token: str | None = None,
ttl: Optional[int] = None,
byte_range: str | None = None,
) -> Optional[bytes]:
"""Get initialization segment from cache or download it.
@@ -310,29 +33,39 @@ async def get_cached_init_segment(
rely on different DRM keys or initialization payloads (e.g. key rotation).
ttl overrides the default cache TTL; pass a value <= 0 to skip caching entirely.
"""
byte_range specifies a byte range for SegmentBase MPDs (e.g., '0-11568').
"""
use_cache = ttl is None or ttl > 0
cache_key = f"{init_url}|{cache_token}" if cache_token else init_url
# Include byte_range in cache key for SegmentBase
cache_key = f"{init_url}|{cache_token}|{byte_range}" if cache_token or byte_range else init_url
if use_cache:
cached_data = await INIT_SEGMENT_CACHE.get(cache_key)
cached_data = await redis_utils.get_cached_init_segment(cache_key)
if cached_data is not None:
return cached_data
else:
# Remove any previously cached entry when caching is disabled
await INIT_SEGMENT_CACHE.delete(cache_key)
try:
init_content = await download_file_with_retry(init_url, headers)
# Add Range header if byte_range is specified (for SegmentBase MPDs)
request_headers = dict(headers)
if byte_range:
request_headers["Range"] = f"bytes={byte_range}"
init_content = await download_file_with_retry(init_url, request_headers)
if init_content and use_cache:
await INIT_SEGMENT_CACHE.set(cache_key, init_content, ttl=ttl)
cache_ttl = ttl if ttl is not None else redis_utils.DEFAULT_INIT_CACHE_TTL
await redis_utils.set_cached_init_segment(cache_key, init_content, ttl=cache_ttl)
return init_content
except Exception as e:
logger.error(f"Error downloading init segment: {e}")
return None
# =============================================================================
# MPD Cache
# =============================================================================
async def get_cached_mpd(
mpd_url: str,
headers: dict,
@@ -341,13 +74,13 @@ async def get_cached_mpd(
) -> dict:
"""Get MPD from cache or download and parse it."""
# Try cache first
cached_data = await MPD_CACHE.get(mpd_url)
cached_data = await redis_utils.get_cached_mpd(mpd_url)
if cached_data is not None:
try:
mpd_dict = json.loads(cached_data)
return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
except json.JSONDecodeError:
await MPD_CACHE.delete(mpd_url)
return parse_mpd_dict(cached_data, mpd_url, parse_drm, parse_segment_profile_id)
except Exception:
# Invalid cached data, will re-download
pass
# Download and parse if not cached
try:
@@ -355,8 +88,9 @@ async def get_cached_mpd(
mpd_dict = parse_mpd(mpd_content)
parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
# Cache the original MPD dict
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict.get("minimumUpdatePeriod"))
# Cache the original MPD dict with TTL from minimumUpdatePeriod
cache_ttl = parsed_dict.get("minimumUpdatePeriod") or redis_utils.DEFAULT_MPD_CACHE_TTL
await redis_utils.set_cached_mpd(mpd_url, mpd_dict, ttl=cache_ttl)
return parsed_dict
except DownloadError as error:
logger.error(f"Error downloading MPD: {error}")
@@ -366,21 +100,158 @@ async def get_cached_mpd(
raise error
# =============================================================================
# Extractor Cache
# =============================================================================
async def get_cached_extractor_result(key: str) -> Optional[dict]:
"""Get extractor result from cache."""
cached_data = await EXTRACTOR_CACHE.get(key)
if cached_data is not None:
try:
return json.loads(cached_data)
except json.JSONDecodeError:
await EXTRACTOR_CACHE.delete(key)
return None
return await redis_utils.get_cached_extractor(key)
async def set_cache_extractor_result(key: str, result: dict) -> bool:
"""Cache extractor result."""
try:
return await EXTRACTOR_CACHE.set(key, json.dumps(result).encode())
await redis_utils.set_cached_extractor(key, result)
return True
except Exception as e:
logger.error(f"Error caching extractor result: {e}")
return False
# =============================================================================
# Processed Init Segment Cache
# =============================================================================
async def get_cached_processed_init(
    init_url: str,
    key_id: str,
) -> Optional[bytes]:
    """Look up a previously processed (DRM-stripped) init segment.

    Args:
        init_url: URL of the init segment.
        key_id: DRM key ID that was used when processing the segment.

    Returns:
        The processed init segment bytes, or None on a cache miss.
    """
    lookup_key = f"processed|{init_url}|{key_id}"
    cached = await redis_utils.get_cached_processed_init(lookup_key)
    return cached
async def set_cached_processed_init(
    init_url: str,
    key_id: str,
    processed_content: bytes,
    ttl: Optional[int] = None,
) -> bool:
    """Store a processed (DRM-stripped) init segment in the shared cache.

    Args:
        init_url: URL of the init segment.
        key_id: DRM key ID used for processing.
        processed_content: The processed init segment bytes.
        ttl: Optional TTL override; defaults to redis_utils.DEFAULT_PROCESSED_INIT_TTL.

    Returns:
        True if cached successfully, False otherwise.
    """
    storage_key = f"processed|{init_url}|{key_id}"
    if ttl is not None:
        effective_ttl = ttl
    else:
        effective_ttl = redis_utils.DEFAULT_PROCESSED_INIT_TTL
    try:
        await redis_utils.set_cached_processed_init(storage_key, processed_content, ttl=effective_ttl)
    except Exception as e:
        logger.error(f"Error caching processed init segment: {e}")
        return False
    return True
# =============================================================================
# Processed Segment Cache (decrypted/remuxed segments)
# =============================================================================
async def get_cached_processed_segment(
    segment_url: str,
    key_id: Optional[str] = None,
    remux: bool = False,
) -> Optional[bytes]:
    """Get processed (decrypted/remuxed) segment from cache.

    Args:
        segment_url: URL of the segment
        key_id: DRM key ID if decrypted (None when the segment was not decrypted)
        remux: Whether the segment was remuxed to TS

    Returns:
        Processed segment bytes if cached, None otherwise
    """
    # key_id and remux are part of the key so differently-processed variants
    # of the same URL do not collide; must mirror set_cached_processed_segment.
    cache_key = f"proc|{segment_url}|{key_id or ''}|{remux}"
    return await redis_utils.get_cached_segment(cache_key)
async def set_cached_processed_segment(
    segment_url: str,
    content: bytes,
    key_id: Optional[str] = None,
    remux: bool = False,
    ttl: int = 60,
) -> bool:
    """Cache processed (decrypted/remuxed) segment.

    Args:
        segment_url: URL of the segment
        content: Processed segment bytes
        key_id: DRM key ID if decrypted (None when the segment was not decrypted)
        remux: Whether the segment was remuxed to TS
        ttl: Time to live in seconds

    Returns:
        True if cached successfully
    """
    # Key layout must mirror get_cached_processed_segment.
    cache_key = f"proc|{segment_url}|{key_id or ''}|{remux}"
    try:
        await redis_utils.set_cached_segment(cache_key, content, ttl=ttl)
        return True
    except Exception as e:
        logger.error(f"Error caching processed segment: {e}")
        return False
# =============================================================================
# Segment Cache
# =============================================================================
async def get_cached_segment(segment_url: str) -> Optional[bytes]:
    """Fetch a media segment from the shared prebuffer cache.

    Args:
        segment_url: URL of the segment.

    Returns:
        The segment bytes if present in Redis, otherwise None.
    """
    data = await redis_utils.get_cached_segment(segment_url)
    return data
async def set_cached_segment(segment_url: str, content: bytes, ttl: int = 60) -> bool:
    """Store a media segment in the shared prebuffer cache.

    Args:
        segment_url: URL of the segment.
        content: Segment bytes.
        ttl: Time to live in seconds (default 60s, configurable via
            dash_segment_cache_ttl).

    Returns:
        True if cached successfully, False if the write failed.
    """
    try:
        await redis_utils.set_cached_segment(segment_url, content, ttl=ttl)
    except Exception as e:
        logger.error(f"Error caching segment: {e}")
        return False
    return True

View File

@@ -1,27 +1,27 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
# See the LICENSE file for legal information regarding use of this file
"""Classes for reading/writing binary data (such as TLS records)."""
from __future__ import division
import sys
import struct
from struct import pack
from .compat import bytes_to_int
class DecodeError(SyntaxError):
"""Exception raised in case of decoding errors."""
pass
class BadCertificateError(SyntaxError):
"""Exception raised in case of bad certificate."""
pass
class Writer(object):
class Writer:
"""Serialisation helper for complex byte-based structures."""
def __init__(self):
@@ -32,102 +32,51 @@ class Writer(object):
"""Add a single-byte wide element to buffer, see add()."""
self.bytes.append(val)
if sys.version_info < (2, 7):
# struct.pack on Python2.6 does not raise exception if the value
# is larger than can fit inside the specified size
def addTwo(self, val):
"""Add a double-byte wide element to buffer, see add()."""
if not 0 <= val <= 0xffff:
raise ValueError("Can't represent value in specified length")
self.bytes += pack('>H', val)
def addTwo(self, val):
"""Add a double-byte wide element to buffer, see add()."""
try:
self.bytes += pack(">H", val)
except struct.error:
raise ValueError("Can't represent value in specified length")
def addThree(self, val):
"""Add a three-byte wide element to buffer, see add()."""
if not 0 <= val <= 0xffffff:
raise ValueError("Can't represent value in specified length")
self.bytes += pack('>BH', val >> 16, val & 0xffff)
def addThree(self, val):
"""Add a three-byte wide element to buffer, see add()."""
try:
self.bytes += pack(">BH", val >> 16, val & 0xFFFF)
except struct.error:
raise ValueError("Can't represent value in specified length")
def addFour(self, val):
"""Add a four-byte wide element to buffer, see add()."""
if not 0 <= val <= 0xffffffff:
raise ValueError("Can't represent value in specified length")
self.bytes += pack('>I', val)
else:
def addTwo(self, val):
"""Add a double-byte wide element to buffer, see add()."""
try:
self.bytes += pack('>H', val)
except struct.error:
raise ValueError("Can't represent value in specified length")
def addFour(self, val):
"""Add a four-byte wide element to buffer, see add()."""
try:
self.bytes += pack(">I", val)
except struct.error:
raise ValueError("Can't represent value in specified length")
def addThree(self, val):
"""Add a three-byte wide element to buffer, see add()."""
try:
self.bytes += pack('>BH', val >> 16, val & 0xffff)
except struct.error:
raise ValueError("Can't represent value in specified length")
def add(self, x, length):
"""
Add a single positive integer value x, encode it in length bytes.
def addFour(self, val):
"""Add a four-byte wide element to buffer, see add()."""
try:
self.bytes += pack('>I', val)
except struct.error:
raise ValueError("Can't represent value in specified length")
Encode positive integer x in big-endian format using length bytes,
add to the internal buffer.
if sys.version_info >= (3, 0):
# the method is called thousands of times, so it's better to extern
# the version info check
def add(self, x, length):
"""
Add a single positive integer value x, encode it in length bytes
:type x: int
:param x: value to encode
Encode positive integer x in big-endian format using length bytes,
add to the internal buffer.
:type x: int
:param x: value to encode
:type length: int
:param length: number of bytes to use for encoding the value
"""
try:
self.bytes += x.to_bytes(length, 'big')
except OverflowError:
raise ValueError("Can't represent value in specified length")
else:
_addMethods = {1: addOne, 2: addTwo, 3: addThree, 4: addFour}
def add(self, x, length):
"""
Add a single positive integer value x, encode it in length bytes
Encode positive iteger x in big-endian format using length bytes,
add to the internal buffer.
:type x: int
:param x: value to encode
:type length: int
:param length: number of bytes to use for encoding the value
"""
try:
self._addMethods[length](self, x)
except KeyError:
self.bytes += bytearray(length)
newIndex = len(self.bytes) - 1
for i in range(newIndex, newIndex - length, -1):
self.bytes[i] = x & 0xFF
x >>= 8
if x != 0:
raise ValueError("Can't represent value in specified "
"length")
:type length: int
:param length: number of bytes to use for encoding the value
"""
try:
self.bytes += x.to_bytes(length, "big")
except OverflowError:
raise ValueError("Can't represent value in specified length")
def addFixSeq(self, seq, length):
"""
Add a list of items, encode every item in length bytes
Add a list of items, encode every item in length bytes.
Uses the unbounded iterable seq to produce items, each of
which is then encoded to length bytes
which is then encoded to length bytes.
:type seq: iterable of int
:param seq: list of positive integers to encode
@@ -138,72 +87,35 @@ class Writer(object):
for e in seq:
self.add(e, length)
if sys.version_info < (2, 7):
# struct.pack on Python2.6 does not raise exception if the value
# is larger than can fit inside the specified size
def _addVarSeqTwo(self, seq):
"""Helper method for addVarSeq"""
if not all(0 <= i <= 0xffff for i in seq):
raise ValueError("Can't represent value in specified "
"length")
self.bytes += pack('>' + 'H' * len(seq), *seq)
def addVarSeq(self, seq, length, lengthLength):
"""
Add a bounded list of same-sized values.
def addVarSeq(self, seq, length, lengthLength):
"""
Add a bounded list of same-sized values
Create a list of specific length with all items being of the same
size.
Create a list of specific length with all items being of the same
size
:type seq: list of int
:param seq: list of positive integers to encode
:type seq: list of int
:param seq: list of positive integers to encode
:type length: int
:param length: amount of bytes in which to encode every item
:type length: int
:param length: amount of bytes in which to encode every item
:type lengthLength: int
:param lengthLength: amount of bytes in which to encode the overall
length of the array
"""
self.add(len(seq)*length, lengthLength)
if length == 1:
self.bytes.extend(seq)
elif length == 2:
self._addVarSeqTwo(seq)
else:
for i in seq:
self.add(i, length)
else:
def addVarSeq(self, seq, length, lengthLength):
"""
Add a bounded list of same-sized values
Create a list of specific length with all items being of the same
size
:type seq: list of int
:param seq: list of positive integers to encode
:type length: int
:param length: amount of bytes in which to encode every item
:type lengthLength: int
:param lengthLength: amount of bytes in which to encode the overall
length of the array
"""
seqLen = len(seq)
self.add(seqLen*length, lengthLength)
if length == 1:
self.bytes.extend(seq)
elif length == 2:
try:
self.bytes += pack('>' + 'H' * seqLen, *seq)
except struct.error:
raise ValueError("Can't represent value in specified "
"length")
else:
for i in seq:
self.add(i, length)
:type lengthLength: int
:param lengthLength: amount of bytes in which to encode the overall
length of the array
"""
seqLen = len(seq)
self.add(seqLen * length, lengthLength)
if length == 1:
self.bytes.extend(seq)
elif length == 2:
try:
self.bytes += pack(">" + "H" * seqLen, *seq)
except struct.error:
raise ValueError("Can't represent value in specified length")
else:
for i in seq:
self.add(i, length)
def addVarTupleSeq(self, seq, length, lengthLength):
"""
@@ -257,7 +169,7 @@ class Writer(object):
self.bytes += data
class Parser(object):
class Parser:
"""
Parser for TLV and LV byte-based encodings.
@@ -269,9 +181,6 @@ class Parser(object):
read a 4-byte integer from a 2-byte buffer), most methods will raise a
DecodeError exception.
TODO: don't use an exception used by language parser to indicate errors
in application code.
:vartype bytes: bytearray
:ivar bytes: data to be interpreted (buffer)
@@ -285,14 +194,14 @@ class Parser(object):
:ivar indexCheck: position at which the structure begins in buffer
"""
def __init__(self, bytes):
def __init__(self, data):
"""
Bind raw bytes with parser.
:type bytes: bytearray
:param bytes: bytes to be parsed/interpreted
:type data: bytearray
:param data: bytes to be parsed/interpreted
"""
self.bytes = bytes
self.bytes = data
self.index = 0
self.indexCheck = 0
self.lengthCheck = 0
@@ -307,7 +216,7 @@ class Parser(object):
:rtype: int
"""
ret = self.getFixBytes(length)
return bytes_to_int(ret, 'big')
return bytes_to_int(ret, "big")
def getFixBytes(self, lengthBytes):
"""
@@ -358,10 +267,10 @@ class Parser(object):
:rtype: list of int
"""
l = [0] * lengthList
result = [0] * lengthList
for x in range(lengthList):
l[x] = self.get(length)
return l
result[x] = self.get(length)
return result
def getVarList(self, length, lengthLength):
"""
@@ -377,13 +286,12 @@ class Parser(object):
"""
lengthList = self.get(lengthLength)
if lengthList % length != 0:
raise DecodeError("Encoded length not a multiple of element "
"length")
raise DecodeError("Encoded length not a multiple of element length")
lengthList = lengthList // length
l = [0] * lengthList
result = [0] * lengthList
for x in range(lengthList):
l[x] = self.get(length)
return l
result[x] = self.get(length)
return result
def getVarTupleList(self, elemLength, elemNum, lengthLength):
"""
@@ -402,8 +310,7 @@ class Parser(object):
"""
lengthList = self.get(lengthLength)
if lengthList % (elemLength * elemNum) != 0:
raise DecodeError("Encoded length not a multiple of element "
"length")
raise DecodeError("Encoded length not a multiple of element length")
tupleCount = lengthList // (elemLength * elemNum)
tupleList = []
for _ in range(tupleCount):

View File

@@ -1,223 +1,114 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
# See the LICENSE file for legal information regarding use of this file
"""Miscellaneous functions to mask Python version differences."""
"""Miscellaneous utility functions for Python 3.13+."""
import sys
import re
import platform
import binascii
import traceback
import re
import time
if sys.version_info >= (3,0):
def compat26Str(x):
"""Identity function for compatibility."""
return x
def compat26Str(x): return x
# Python 3.3 requires bytes instead of bytearrays for HMAC
# So, python 2.6 requires strings, python 3 requires 'bytes',
# and python 2.7 and 3.5 can handle bytearrays...
# pylint: disable=invalid-name
# we need to keep compatHMAC and `x` for API compatibility
if sys.version_info < (3, 4):
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
return bytes(x)
else:
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
return x
# pylint: enable=invalid-name
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
return x
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
if isinstance(val, str):
return bytes(val, 'ascii')
return val
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
return str(val, 'ascii')
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
if isinstance(val, str):
return bytes(val, "ascii")
return val
def raw_input(s):
return input(s)
# So, the python3 binascii module deals with bytearrays, and python2
# deals with strings... I would rather deal with the "a" part as
# strings, and the "b" part as bytearrays, regardless of python version,
# so...
def a2b_hex(s):
try:
b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
except Exception as e:
raise SyntaxError("base16 error: %s" % e)
return b
def a2b_base64(s):
try:
if isinstance(s, str):
s = bytearray(s, "ascii")
b = bytearray(binascii.a2b_base64(s))
except Exception as e:
raise SyntaxError("base64 error: %s" % e)
return b
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
return str(val, "ascii")
def b2a_hex(b):
return binascii.b2a_hex(b).decode("ascii")
def b2a_base64(b):
return binascii.b2a_base64(b).decode("ascii")
def readStdinBinary():
return sys.stdin.buffer.read()
def a2b_hex(s):
"""Convert hex string to bytearray."""
try:
b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
except Exception as e:
raise SyntaxError(f"base16 error: {e}") from e
return b
def compatLong(num):
return int(num)
int_types = tuple([int])
def a2b_base64(s):
"""Convert base64 string to bytearray."""
try:
if isinstance(s, str):
s = bytearray(s, "ascii")
b = bytearray(binascii.a2b_base64(s))
except Exception as e:
raise SyntaxError(f"base64 error: {e}") from e
return b
def formatExceptionTrace(e):
"""Return exception information formatted as string"""
return str(e)
def time_stamp():
"""Returns system time as a float"""
if sys.version_info >= (3, 3):
return time.perf_counter()
return time.clock()
def b2a_hex(b):
"""Convert bytes to hex string."""
return binascii.b2a_hex(b).decode("ascii")
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
# pylint: disable=invalid-name
# pylint is stupid here and deson't notice it's a function, not
# constant
bytes_to_int = int.from_bytes
# pylint: enable=invalid-name
def b2a_base64(b):
"""Convert bytes to base64 string."""
return binascii.b2a_base64(b).decode("ascii")
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
return val.bit_length()
def int_to_bytes(val, length=None, byteorder="big"):
"""Return number converted to bytes"""
if length is None:
if val:
length = byte_length(val)
else:
length = 1
# for gmpy we need to convert back to native int
if type(val) != int:
val = int(val)
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
def readStdinBinary():
"""Read binary data from stdin."""
import sys
else:
# Python 2.6 requires strings instead of bytearrays in a couple places,
# so we define this function so it does the conversion if needed.
# same thing with very old 2.7 versions
# or on Jython
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) \
or platform.system() == 'Java':
def compat26Str(x): return str(x)
return sys.stdin.buffer.read()
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text)
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
if val == 0:
return 0
return len(bin(val))-2
else:
def compat26Str(x): return x
def compatLong(num):
"""Convert to int (compatibility function)."""
return int(num)
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
return val.bit_length()
int_types = (int,)
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
return val
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
return str(val)
def formatExceptionTrace(e):
"""Return exception information formatted as string."""
return str(e)
# So, python 2.6 requires strings, python 3 requires 'bytes',
# and python 2.7 can handle bytearrays...
def compatHMAC(x): return compat26Str(x)
def a2b_hex(s):
try:
b = bytearray(binascii.a2b_hex(s))
except Exception as e:
raise SyntaxError("base16 error: %s" % e)
return b
def time_stamp():
"""Returns system time as a float."""
return time.perf_counter()
def a2b_base64(s):
try:
b = bytearray(binascii.a2b_base64(s))
except Exception as e:
raise SyntaxError("base64 error: %s" % e)
return b
def b2a_hex(b):
return binascii.b2a_hex(compat26Str(b))
def b2a_base64(b):
return binascii.b2a_base64(compat26Str(b))
def compatLong(num):
return long(num)
def remove_whitespace(text):
"""Removes all whitespace from passed in string."""
return re.sub(r"\s+", "", text, flags=re.UNICODE)
int_types = (int, long)
# pylint on Python3 goes nuts for the sys dereferences...
bytes_to_int = int.from_bytes
#pylint: disable=no-member
def formatExceptionTrace(e):
"""Return exception information formatted as string"""
newStr = "".join(traceback.format_exception(sys.exc_type,
sys.exc_value,
sys.exc_traceback))
return newStr
#pylint: enable=no-member
def time_stamp():
"""Returns system time as a float"""
return time.clock()
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
return val.bit_length()
def bytes_to_int(val, byteorder):
"""Convert bytes to an int."""
if not val:
return 0
if byteorder == "big":
return int(b2a_hex(val), 16)
if byteorder == "little":
return int(b2a_hex(val[::-1]), 16)
raise ValueError("Only 'big' and 'little' endian supported")
def int_to_bytes(val, length=None, byteorder="big"):
"""Return number converted to bytes"""
if length is None:
if val:
length = byte_length(val)
else:
length = 1
if byteorder == "big":
return bytearray((val >> i) & 0xff
for i in reversed(range(0, length*8, 8)))
if byteorder == "little":
return bytearray((val >> i) & 0xff
for i in range(0, length*8, 8))
raise ValueError("Only 'big' or 'little' endian supported")
def int_to_bytes(val, length=None, byteorder="big"):
"""Return number converted to bytes."""
if length is None:
if val:
length = byte_length(val)
else:
length = 1
# for gmpy we need to convert back to native int
if not isinstance(val, int):
val = int(val)
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
def byte_length(val):

View File

@@ -8,6 +8,7 @@ from __future__ import division
from .compat import compatHMAC
import hmac
def ct_lt_u32(val_a, val_b):
"""
Returns 1 if val_a < val_b, 0 otherwise. Constant time.
@@ -18,10 +19,10 @@ def ct_lt_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
val_a &= 0xffffffff
val_b &= 0xffffffff
val_a &= 0xFFFFFFFF
val_b &= 0xFFFFFFFF
return (val_a^((val_a^val_b)|(((val_a-val_b)&0xffffffff)^val_b)))>>31
return (val_a ^ ((val_a ^ val_b) | (((val_a - val_b) & 0xFFFFFFFF) ^ val_b))) >> 31
def ct_gt_u32(val_a, val_b):
@@ -77,8 +78,8 @@ def ct_isnonzero_u32(val):
:param val: an unsigned integer representable as a 32 bit value
:rtype: int
"""
val &= 0xffffffff
return (val|(-val&0xffffffff)) >> 31
val &= 0xFFFFFFFF
return (val | (-val & 0xFFFFFFFF)) >> 31
def ct_neq_u32(val_a, val_b):
@@ -91,10 +92,11 @@ def ct_neq_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
val_a &= 0xffffffff
val_b &= 0xffffffff
val_a &= 0xFFFFFFFF
val_b &= 0xFFFFFFFF
return (((val_a - val_b) & 0xFFFFFFFF) | ((val_b - val_a) & 0xFFFFFFFF)) >> 31
return (((val_a-val_b)&0xffffffff) | ((val_b-val_a)&0xffffffff)) >> 31
def ct_eq_u32(val_a, val_b):
"""
@@ -108,8 +110,8 @@ def ct_eq_u32(val_a, val_b):
"""
return 1 ^ ct_neq_u32(val_a, val_b)
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
block_size=16):
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version, block_size=16):
"""
Check CBC cipher HMAC and padding. Close to constant time.
@@ -135,7 +137,7 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
assert version in ((3, 0), (3, 1), (3, 2), (3, 3))
data_len = len(data)
if mac.digest_size + 1 > data_len: # data_len is public
if mac.digest_size + 1 > data_len: # data_len is public
return False
# 0 - OK
@@ -144,11 +146,11 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
#
# check padding
#
pad_length = data[data_len-1]
pad_length = data[data_len - 1]
pad_start = data_len - pad_length - 1
pad_start = max(0, pad_start)
if version == (3, 0): # version is public
if version == (3, 0): # version is public
# in SSLv3 we can only check if pad is not longer than the cipher
# block size
@@ -179,33 +181,35 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
data_mac = mac.copy()
data_mac.update(compatHMAC(seqnumBytes))
data_mac.update(compatHMAC(bytearray([contentType])))
if version != (3, 0): # version is public
if version != (3, 0): # version is public
data_mac.update(compatHMAC(bytearray([version[0]])))
data_mac.update(compatHMAC(bytearray([version[1]])))
data_mac.update(compatHMAC(bytearray([mac_start >> 8])))
data_mac.update(compatHMAC(bytearray([mac_start & 0xff])))
data_mac.update(compatHMAC(bytearray([mac_start & 0xFF])))
data_mac.update(compatHMAC(data[:start_pos]))
# don't check past the array end (already checked to be >= zero)
end_pos = data_len - mac.digest_size
# calculate all possible
for i in range(start_pos, end_pos): # constant for given overall length
for i in range(start_pos, end_pos): # constant for given overall length
cur_mac = data_mac.copy()
cur_mac.update(compatHMAC(data[start_pos:i]))
mac_compare = bytearray(cur_mac.digest())
# compare the hash for real only if it's the place where mac is
# supposed to be
mask = ct_lsb_prop_u8(ct_eq_u32(i, mac_start))
for j in range(0, mac.digest_size): # digest_size is public
result |= (data[i+j] ^ mac_compare[j]) & mask
for j in range(0, mac.digest_size): # digest_size is public
result |= (data[i + j] ^ mac_compare[j]) & mask
# return python boolean
return result == 0
if hasattr(hmac, 'compare_digest'):
if hasattr(hmac, "compare_digest"):
ct_compare_digest = hmac.compare_digest
else:
def ct_compare_digest(val_a, val_b):
"""Compares if string like objects are equal. Constant time."""
if len(val_a) != len(val_b):

View File

@@ -52,7 +52,7 @@ class EncryptionHandler:
del data["ip"] # Remove IP from the data
return data
except Exception as e:
except Exception:
raise HTTPException(status_code=401, detail="Invalid or expired token")

View File

@@ -1,22 +1,27 @@
# Authors:
# Authors:
# Trevor Perrin
# Martin von Loewis - python 3 port
# Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
#
# See the LICENSE file for legal information regarding use of this file.
# See the LICENSE file for legal information regarding use of this file
"""cryptomath module
This module has basic math/crypto code."""
from __future__ import print_function
import os
import math
import base64
import binascii
from .compat import compat26Str, compatHMAC, compatLong, \
bytes_to_int, int_to_bytes, bit_length, byte_length
import math
import os
import zlib
from .codec import Writer
from .compat import (
bit_length,
byte_length,
bytes_to_int,
compat26Str,
compatHMAC,
int_to_bytes,
)
from . import tlshashlib as hashlib
from . import tlshmac as hmac
@@ -33,27 +38,31 @@ pycryptoLoaded = False
# **************************************************************************
# Check that os.urandom works
import zlib
assert len(zlib.compress(os.urandom(1000))) > 900
def getRandomBytes(howMany):
b = bytearray(os.urandom(howMany))
assert(len(b) == howMany)
assert len(b) == howMany
return b
prngName = "os.urandom"
# **************************************************************************
# Simple hash functions
# **************************************************************************
def MD5(b):
"""Return a MD5 digest of data"""
return secureHash(b, 'md5')
return secureHash(b, "md5")
def SHA1(b):
"""Return a SHA1 digest of data"""
return secureHash(b, 'sha1')
return secureHash(b, "sha1")
def secureHash(data, algorithm):
"""Return a digest of `data` using `algorithm`"""
@@ -61,33 +70,40 @@ def secureHash(data, algorithm):
hashInstance.update(compat26Str(data))
return bytearray(hashInstance.digest())
def secureHMAC(k, b, algorithm):
"""Return a HMAC using `b` and `k` using `algorithm`"""
k = compatHMAC(k)
b = compatHMAC(b)
return bytearray(hmac.new(k, b, getattr(hashlib, algorithm)).digest())
def HMAC_MD5(k, b):
return secureHMAC(k, b, 'md5')
return secureHMAC(k, b, "md5")
def HMAC_SHA1(k, b):
return secureHMAC(k, b, 'sha1')
return secureHMAC(k, b, "sha1")
def HMAC_SHA256(k, b):
return secureHMAC(k, b, 'sha256')
return secureHMAC(k, b, "sha256")
def HMAC_SHA384(k, b):
return secureHMAC(k, b, 'sha384')
return secureHMAC(k, b, "sha384")
def HKDF_expand(PRK, info, L, algorithm):
N = divceil(L, getattr(hashlib, algorithm)().digest_size)
T = bytearray()
Titer = bytearray()
for x in range(1, N+2):
for x in range(1, N + 2):
T += Titer
Titer = secureHMAC(PRK, Titer + info + bytearray([x]), algorithm)
return T[:L]
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
"""
TLS1.3 key derivation function (HKDF-Expand-Label).
@@ -108,6 +124,7 @@ def HKDF_expand_label(secret, label, hashValue, length, algorithm):
return HKDF_expand(secret, hkdfLabel.bytes, length, algorithm)
def derive_secret(secret, label, handshake_hashes, algorithm):
"""
TLS1.3 key derivation function (Derive-Secret).
@@ -123,17 +140,17 @@ def derive_secret(secret, label, handshake_hashes, algorithm):
:rtype: bytearray
"""
if handshake_hashes is None:
hs_hash = secureHash(bytearray(b''), algorithm)
hs_hash = secureHash(bytearray(b""), algorithm)
else:
hs_hash = handshake_hashes.digest(algorithm)
return HKDF_expand_label(secret, label, hs_hash,
getattr(hashlib, algorithm)().digest_size,
algorithm)
return HKDF_expand_label(secret, label, hs_hash, getattr(hashlib, algorithm)().digest_size, algorithm)
# **************************************************************************
# Converter Functions
# **************************************************************************
def bytesToNumber(b, endian="big"):
"""
Convert a number stored in bytearray to an integer.
@@ -156,7 +173,7 @@ def numberToByteArray(n, howManyBytes=None, endian="big"):
if howManyBytes < length:
ret = int_to_bytes(n, length, endian)
if endian == "big":
return ret[length-howManyBytes:length]
return ret[length - howManyBytes : length]
return ret[:howManyBytes]
return int_to_bytes(n, howManyBytes, endian)
@@ -172,12 +189,12 @@ def mpiToNumber(mpi):
def numberToMPI(n):
b = numberToByteArray(n)
ext = 0
#If the high-order bit is going to be set,
#add an extra byte of zeros
if (numBits(n) & 0x7)==0:
# If the high-order bit is going to be set,
# add an extra byte of zeros
if (numBits(n) & 0x7) == 0:
ext = 1
length = numBytes(n) + ext
b = bytearray(4+ext) + b
b = bytearray(4 + ext) + b
b[0] = (length >> 24) & 0xFF
b[1] = (length >> 16) & 0xFF
b[2] = (length >> 8) & 0xFF
@@ -190,75 +207,57 @@ def numberToMPI(n):
# **************************************************************************
# pylint: disable=invalid-name
# pylint recognises them as constants, not function names, also
# we can't change their names without API change
numBits = bit_length
numBytes = byte_length
# pylint: enable=invalid-name
# **************************************************************************
# Big Number Math
# **************************************************************************
def getRandomNumber(low, high):
assert low < high
howManyBits = numBits(high)
howManyBytes = numBytes(high)
lastBits = howManyBits % 8
while 1:
bytes = getRandomBytes(howManyBytes)
while True:
random_bytes = getRandomBytes(howManyBytes)
if lastBits:
bytes[0] = bytes[0] % (1 << lastBits)
n = bytesToNumber(bytes)
if n >= low and n < high:
random_bytes[0] = random_bytes[0] % (1 << lastBits)
n = bytesToNumber(random_bytes)
if low <= n < high:
return n
def gcd(a,b):
a, b = max(a,b), min(a,b)
def gcd(a, b):
a, b = max(a, b), min(a, b)
while b:
a, b = b, a % b
return a
def lcm(a, b):
return (a * b) // gcd(a, b)
# pylint: disable=invalid-name
# disable pylint check as the (a, b) are part of the API
if GMPY2_LOADED:
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
if a == 0:
return 0
return powmod(a, -1, b)
else:
# Use Extended Euclidean Algorithm
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
c, d = a, b
uc, ud = 1, 0
while c != 0:
q = d // c
c, d = d-(q*c), c
uc, ud = ud - (q * uc), uc
if d == 1:
return ud % b
return 0
# pylint: enable=invalid-name
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
c, d = a, b
uc, ud = 1, 0
while c != 0:
q = d // c
c, d = d - (q * c), c
uc, ud = ud - (q * uc), uc
if d == 1:
return ud % b
return 0
if gmpyLoaded or GMPY2_LOADED:
def powMod(base, power, modulus):
base = mpz(base)
power = mpz(power)
modulus = mpz(modulus)
result = pow(base, power, modulus)
return compatLong(result)
else:
powMod = pow
# Use built-in pow for modular exponentiation (Python 3 handles this efficiently)
powMod = pow
def divceil(divident, divisor):
@@ -267,10 +266,10 @@ def divceil(divident, divisor):
return quot + int(bool(r))
#Pre-calculate a sieve of the ~100 primes < 1000:
# Pre-calculate a sieve of the ~100 primes < 1000:
def makeSieve(n):
sieve = list(range(n))
for count in range(2, int(math.sqrt(n))+1):
for count in range(2, int(math.sqrt(n)) + 1):
if sieve[count] == 0:
continue
x = sieve[count] * 2
@@ -280,30 +279,34 @@ def makeSieve(n):
sieve = [x for x in sieve[2:] if x]
return sieve
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
#Trial division with sieve
# Trial division with sieve
for x in sieve:
if x >= n: return True
if n % x == 0: return False
#Passed trial division, proceed to Rabin-Miller
#Rabin-Miller implemented per Ferguson & Schneier
#Compute s, t for Rabin-Miller
if display: print("*", end=' ')
s, t = n-1, 0
if x >= n:
return True
if n % x == 0:
return False
# Passed trial division, proceed to Rabin-Miller
# Rabin-Miller implemented per Ferguson & Schneier
# Compute s, t for Rabin-Miller
if display:
print("*", end=" ")
s, t = n - 1, 0
while s % 2 == 0:
s, t = s//2, t+1
#Repeat Rabin-Miller x times
a = 2 #Use 2 as a base for first iteration speedup, per HAC
for count in range(iterations):
s, t = s // 2, t + 1
# Repeat Rabin-Miller x times
a = 2 # Use 2 as a base for first iteration speedup, per HAC
for _ in range(iterations):
v = powMod(a, s, n)
if v==1:
if v == 1:
continue
i = 0
while v != n-1:
if i == t-1:
while v != n - 1:
if i == t - 1:
return False
else:
v, i = powMod(v, 2, n), i+1
v, i = powMod(v, 2, n), i + 1
a = getRandomNumber(2, n)
return True
@@ -316,16 +319,16 @@ def getRandomPrime(bits, display=False):
larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
"""
assert bits >= 10
#The 1.5 ensures the 2 MSBs are set
#Thus, when used for p,q in RSA, n will have its MSB set
# The 1.5 ensures the 2 MSBs are set
# Thus, when used for p,q in RSA, n will have its MSB set
#
#Since 30 is lcm(2,3,5), we'll set our test numbers to
#29 % 30 and keep them there
low = ((2 ** (bits-1)) * 3) // 2
high = 2 ** bits - 30
# Since 30 is lcm(2,3,5), we'll set our test numbers to
# 29 % 30 and keep them there
low = ((2 ** (bits - 1)) * 3) // 2
high = 2**bits - 30
while True:
if display:
print(".", end=' ')
print(".", end=" ")
cand_p = getRandomNumber(low, high)
# make odd
if cand_p % 2 == 0:
@@ -334,7 +337,7 @@ def getRandomPrime(bits, display=False):
return cand_p
#Unused at the moment...
# Unused at the moment...
def getRandomSafePrime(bits, display=False):
"""Generate a random safe prime.
@@ -342,23 +345,24 @@ def getRandomSafePrime(bits, display=False):
the (p-1)/2 will also be prime.
"""
assert bits >= 10
#The 1.5 ensures the 2 MSBs are set
#Thus, when used for p,q in RSA, n will have its MSB set
# The 1.5 ensures the 2 MSBs are set
# Thus, when used for p,q in RSA, n will have its MSB set
#
#Since 30 is lcm(2,3,5), we'll set our test numbers to
#29 % 30 and keep them there
low = (2 ** (bits-2)) * 3//2
high = (2 ** (bits-1)) - 30
# Since 30 is lcm(2,3,5), we'll set our test numbers to
# 29 % 30 and keep them there
low = (2 ** (bits - 2)) * 3 // 2
high = (2 ** (bits - 1)) - 30
q = getRandomNumber(low, high)
q += 29 - (q % 30)
while 1:
if display: print(".", end=' ')
while True:
if display:
print(".", end=" ")
q += 30
if (q >= high):
if q >= high:
q = getRandomNumber(low, high)
q += 29 - (q % 30)
#Ideas from Tom Wu's SRP code
#Do trial division on p and q before Rabin-Miller
# Ideas from Tom Wu's SRP code
# Do trial division on p and q before Rabin-Miller
if isPrime(q, 0, display=display):
p = (2 * q) + 1
if isPrime(p, display=display):

View File

@@ -1,373 +1,402 @@
import logging
import psutil
from typing import Dict, Optional, List
from urllib.parse import urljoin
import xmltodict
from mediaflow_proxy.utils.http_utils import create_httpx_client
from mediaflow_proxy.configs import settings
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
class DASHPreBuffer:
"""
Pre-buffer system for DASH streams to reduce latency and improve streaming performance.
"""
def __init__(self, max_cache_size: Optional[int] = None, prebuffer_segments: Optional[int] = None):
"""
Initialize the DASH pre-buffer system.
Args:
max_cache_size (int): Maximum number of segments to cache (uses config if None)
prebuffer_segments (int): Number of segments to pre-buffer ahead (uses config if None)
"""
self.max_cache_size = max_cache_size or settings.dash_prebuffer_cache_size
self.prebuffer_segments = prebuffer_segments or settings.dash_prebuffer_segments
self.max_memory_percent = settings.dash_prebuffer_max_memory_percent
self.emergency_threshold = settings.dash_prebuffer_emergency_threshold
# Cache for different types of DASH content
self.segment_cache: Dict[str, bytes] = {}
self.init_segment_cache: Dict[str, bytes] = {}
self.manifest_cache: Dict[str, dict] = {}
# Track segment URLs for each adaptation set
self.adaptation_segments: Dict[str, List[str]] = {}
self.client = create_httpx_client()
def _get_memory_usage_percent(self) -> float:
"""
Get current memory usage percentage.
Returns:
float: Memory usage percentage
"""
try:
memory = psutil.virtual_memory()
return memory.percent
except Exception as e:
logger.warning(f"Failed to get memory usage: {e}")
return 0.0
def _check_memory_threshold(self) -> bool:
"""
Check if memory usage exceeds the emergency threshold.
Returns:
bool: True if emergency cleanup is needed
"""
memory_percent = self._get_memory_usage_percent()
return memory_percent > self.emergency_threshold
def _emergency_cache_cleanup(self) -> None:
"""
Perform emergency cache cleanup when memory usage is high.
"""
if self._check_memory_threshold():
logger.warning("Emergency DASH cache cleanup triggered due to high memory usage")
# Clear 50% of segment cache
segment_cache_size = len(self.segment_cache)
segment_keys_to_remove = list(self.segment_cache.keys())[:segment_cache_size // 2]
for key in segment_keys_to_remove:
del self.segment_cache[key]
# Clear 50% of init segment cache
init_cache_size = len(self.init_segment_cache)
init_keys_to_remove = list(self.init_segment_cache.keys())[:init_cache_size // 2]
for key in init_keys_to_remove:
del self.init_segment_cache[key]
logger.info(f"Emergency cleanup removed {len(segment_keys_to_remove)} segments and {len(init_keys_to_remove)} init segments from cache")
async def prebuffer_dash_manifest(self, mpd_url: str, headers: Dict[str, str]) -> None:
"""
Pre-buffer segments from a DASH manifest.
Args:
mpd_url (str): URL of the DASH manifest
headers (Dict[str, str]): Headers to use for requests
"""
try:
# Download and parse MPD manifest
response = await self.client.get(mpd_url, headers=headers)
response.raise_for_status()
mpd_content = response.text
# Parse MPD XML
mpd_dict = xmltodict.parse(mpd_content)
# Store manifest in cache
self.manifest_cache[mpd_url] = mpd_dict
# Extract initialization segments and first few segments
await self._extract_and_prebuffer_segments(mpd_dict, mpd_url, headers)
logger.info(f"Pre-buffered DASH manifest: {mpd_url}")
except Exception as e:
logger.warning(f"Failed to pre-buffer DASH manifest {mpd_url}: {e}")
async def _extract_and_prebuffer_segments(self, mpd_dict: dict, base_url: str, headers: Dict[str, str]) -> None:
"""
Extract and pre-buffer segments from MPD manifest.
Args:
mpd_dict (dict): Parsed MPD manifest
base_url (str): Base URL for resolving relative URLs
headers (Dict[str, str]): Headers to use for requests
"""
try:
# Extract Period and AdaptationSet information
mpd = mpd_dict.get('MPD', {})
periods = mpd.get('Period', [])
if not isinstance(periods, list):
periods = [periods]
for period in periods:
adaptation_sets = period.get('AdaptationSet', [])
if not isinstance(adaptation_sets, list):
adaptation_sets = [adaptation_sets]
for adaptation_set in adaptation_sets:
# Extract initialization segment
init_segment = adaptation_set.get('SegmentTemplate', {}).get('@initialization')
if init_segment:
init_url = urljoin(base_url, init_segment)
await self._download_init_segment(init_url, headers)
# Extract segment template
segment_template = adaptation_set.get('SegmentTemplate', {})
if segment_template:
await self._prebuffer_template_segments(segment_template, base_url, headers)
# Extract segment list
segment_list = adaptation_set.get('SegmentList', {})
if segment_list:
await self._prebuffer_list_segments(segment_list, base_url, headers)
except Exception as e:
logger.warning(f"Failed to extract segments from MPD: {e}")
async def _download_init_segment(self, init_url: str, headers: Dict[str, str]) -> None:
"""
Download and cache initialization segment.
Args:
init_url (str): URL of the initialization segment
headers (Dict[str, str]): Headers to use for request
"""
try:
# Check memory usage before downloading
memory_percent = self._get_memory_usage_percent()
if memory_percent > self.max_memory_percent:
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping init segment download")
return
response = await self.client.get(init_url, headers=headers)
response.raise_for_status()
# Cache the init segment
self.init_segment_cache[init_url] = response.content
# Check for emergency cleanup
if self._check_memory_threshold():
self._emergency_cache_cleanup()
logger.debug(f"Cached init segment: {init_url}")
except Exception as e:
logger.warning(f"Failed to download init segment {init_url}: {e}")
async def _prebuffer_template_segments(self, segment_template: dict, base_url: str, headers: Dict[str, str]) -> None:
"""
Pre-buffer segments using segment template.
Args:
segment_template (dict): Segment template from MPD
base_url (str): Base URL for resolving relative URLs
headers (Dict[str, str]): Headers to use for requests
"""
try:
media_template = segment_template.get('@media')
if not media_template:
return
# Extract template parameters
start_number = int(segment_template.get('@startNumber', 1))
duration = float(segment_template.get('@duration', 0))
timescale = float(segment_template.get('@timescale', 1))
# Pre-buffer first few segments
for i in range(self.prebuffer_segments):
segment_number = start_number + i
segment_url = media_template.replace('$Number$', str(segment_number))
full_url = urljoin(base_url, segment_url)
await self._download_segment(full_url, headers)
except Exception as e:
logger.warning(f"Failed to pre-buffer template segments: {e}")
async def _prebuffer_list_segments(self, segment_list: dict, base_url: str, headers: Dict[str, str]) -> None:
"""
Pre-buffer segments from segment list.
Args:
segment_list (dict): Segment list from MPD
base_url (str): Base URL for resolving relative URLs
headers (Dict[str, str]): Headers to use for requests
"""
try:
segments = segment_list.get('SegmentURL', [])
if not isinstance(segments, list):
segments = [segments]
# Pre-buffer first few segments
for segment in segments[:self.prebuffer_segments]:
segment_url = segment.get('@src')
if segment_url:
full_url = urljoin(base_url, segment_url)
await self._download_segment(full_url, headers)
except Exception as e:
logger.warning(f"Failed to pre-buffer list segments: {e}")
async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
"""
Download a single segment and cache it.
Args:
segment_url (str): URL of the segment to download
headers (Dict[str, str]): Headers to use for request
"""
try:
# Check memory usage before downloading
memory_percent = self._get_memory_usage_percent()
if memory_percent > self.max_memory_percent:
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping segment download")
return
response = await self.client.get(segment_url, headers=headers)
response.raise_for_status()
# Cache the segment
self.segment_cache[segment_url] = response.content
# Check for emergency cleanup
if self._check_memory_threshold():
self._emergency_cache_cleanup()
# Maintain cache size
elif len(self.segment_cache) > self.max_cache_size:
# Remove oldest entries (simple FIFO)
oldest_key = next(iter(self.segment_cache))
del self.segment_cache[oldest_key]
logger.debug(f"Cached DASH segment: {segment_url}")
except Exception as e:
logger.warning(f"Failed to download DASH segment {segment_url}: {e}")
async def get_segment(self, segment_url: str, headers: Dict[str, str]) -> Optional[bytes]:
"""
Get a segment from cache or download it.
Args:
segment_url (str): URL of the segment
headers (Dict[str, str]): Headers to use for request
Returns:
Optional[bytes]: Cached segment data or None if not available
"""
# Check segment cache first
if segment_url in self.segment_cache:
logger.debug(f"DASH cache hit for segment: {segment_url}")
return self.segment_cache[segment_url]
# Check init segment cache
if segment_url in self.init_segment_cache:
logger.debug(f"DASH cache hit for init segment: {segment_url}")
return self.init_segment_cache[segment_url]
# Check memory usage before downloading
memory_percent = self._get_memory_usage_percent()
if memory_percent > self.max_memory_percent:
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
return None
# Download if not in cache
try:
response = await self.client.get(segment_url, headers=headers)
response.raise_for_status()
segment_data = response.content
# Determine if it's an init segment or regular segment
if 'init' in segment_url.lower() or segment_url.endswith('.mp4'):
self.init_segment_cache[segment_url] = segment_data
else:
self.segment_cache[segment_url] = segment_data
# Check for emergency cleanup
if self._check_memory_threshold():
self._emergency_cache_cleanup()
# Maintain cache size
elif len(self.segment_cache) > self.max_cache_size:
oldest_key = next(iter(self.segment_cache))
del self.segment_cache[oldest_key]
logger.debug(f"Downloaded and cached DASH segment: {segment_url}")
return segment_data
except Exception as e:
logger.warning(f"Failed to get DASH segment {segment_url}: {e}")
return None
async def get_manifest(self, mpd_url: str, headers: Dict[str, str]) -> Optional[dict]:
    """
    Get MPD manifest from cache or download it.

    Args:
        mpd_url (str): URL of the MPD manifest
        headers (Dict[str, str]): Headers to use for request

    Returns:
        Optional[dict]: Cached manifest data or None if not available
    """
    # Check cache first
    if mpd_url in self.manifest_cache:
        logger.debug(f"DASH cache hit for manifest: {mpd_url}")
        return self.manifest_cache[mpd_url]
    # Download if not in cache
    try:
        response = await self.client.get(mpd_url, headers=headers)
        response.raise_for_status()
        mpd_content = response.text
        # Parse the XML manifest into a nested dict via xmltodict.
        mpd_dict = xmltodict.parse(mpd_content)
        # Cache the manifest
        # NOTE(review): this cache has no size bound or TTL — entries are only
        # removed by clear_cache(); confirm this is acceptable for live MPDs.
        self.manifest_cache[mpd_url] = mpd_dict
        logger.debug(f"Downloaded and cached DASH manifest: {mpd_url}")
        return mpd_dict
    except Exception as e:
        logger.warning(f"Failed to get DASH manifest {mpd_url}: {e}")
        return None
def clear_cache(self) -> None:
    """Empty every cache bucket: media segments, init segments, manifests and
    per-adaptation segment tracking."""
    for bucket in (
        self.segment_cache,
        self.init_segment_cache,
        self.manifest_cache,
        self.adaptation_segments,
    ):
        bucket.clear()
    logger.info("DASH pre-buffer cache cleared")
async def close(self) -> None:
    """Close the pre-buffer system and release the underlying HTTP client."""
    await self.client.aclose()
# Global DASH pre-buffer instance: a module-level singleton shared by all
# request handlers in this process.
dash_prebuffer = DASHPreBuffer()
"""
DASH Pre-buffer system for reducing latency and improving streaming performance.
This module extends BasePrebuffer with DASH-specific functionality including
MPD parsing integration, profile handling, and init segment management.
"""
import asyncio
import logging
import time
from typing import Dict, Optional, List
from mediaflow_proxy.utils.base_prebuffer import BasePrebuffer
from mediaflow_proxy.utils.cache_utils import (
get_cached_mpd,
get_cached_init_segment,
)
from mediaflow_proxy.configs import settings
logger = logging.getLogger(__name__)
class DASHPreBuffer(BasePrebuffer):
    """
    Pre-buffer system for DASH streams.

    Extends BasePrebuffer with DASH-specific features:
    - MPD manifest parsing and profile handling
    - Init segment prebuffering
    - Live stream segment tracking
    - Profile-based segment prefetching

    Uses event-based download coordination from BasePrebuffer to prevent
    duplicate downloads between player requests and background prebuffering.
    """

    def __init__(
        self,
        max_cache_size: Optional[int] = None,
        prebuffer_segments: Optional[int] = None,
    ):
        """
        Initialize the DASH pre-buffer system.

        Args:
            max_cache_size: Maximum number of segments to cache (uses config if None)
            prebuffer_segments: Number of segments to pre-buffer ahead (uses config if None)
        """
        super().__init__(
            max_cache_size=max_cache_size or settings.dash_prebuffer_cache_size,
            prebuffer_segments=prebuffer_segments or settings.dash_prebuffer_segments,
            max_memory_percent=settings.dash_prebuffer_max_memory_percent,
            emergency_threshold=settings.dash_prebuffer_emergency_threshold,
            segment_ttl=settings.dash_segment_cache_ttl,
        )
        # Seconds of no access before a stream is dropped by the cleanup loop.
        self.inactivity_timeout = settings.dash_prebuffer_inactivity_timeout
        # DASH-specific state
        # Track active streams for prefetching: mpd_url -> stream_info
        self.active_streams: Dict[str, dict] = {}
        # Background prefetch tasks keyed by mpd_url.
        self.prefetch_tasks: Dict[str, asyncio.Task] = {}
        # Additional stats for DASH
        self.init_segments_prebuffered = 0
        # Cleanup task (lazily started; see _ensure_cleanup_task_running)
        self._cleanup_task: Optional[asyncio.Task] = None

    def log_stats(self) -> None:
        """Log current prebuffer statistics with DASH-specific info."""
        stats = self.stats.to_dict()
        stats["init_segments_prebuffered"] = self.init_segments_prebuffered
        stats["active_streams"] = len(self.active_streams)
        logger.info(f"DASH Prebuffer Stats: {stats}")

    async def prebuffer_dash_manifest(
        self,
        mpd_url: str,
        headers: Dict[str, str],
    ) -> None:
        """
        Pre-buffer segments from a DASH manifest using existing MPD parsing.

        Parses the MPD twice: once for the profile list, then once per profile
        to resolve its segment URLs. Registers the stream for ongoing
        prefetching and starts the inactivity-cleanup task.

        Args:
            mpd_url: URL of the DASH manifest
            headers: Headers to use for requests
        """
        try:
            # First get the basic MPD info without segments
            parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
            if not parsed_mpd:
                logger.warning(f"Failed to get parsed MPD for prebuffering: {mpd_url}")
                return
            is_live = parsed_mpd.get("isLive", False)
            base_profiles = parsed_mpd.get("profiles", [])
            if not base_profiles:
                logger.warning(f"No profiles found in MPD for prebuffering: {mpd_url}")
                return
            # Now get segments for each profile by parsing with profile_id
            profiles_with_segments = []
            for profile in base_profiles:
                profile_id = profile.get("id")
                if profile_id:
                    parsed_with_segments = await get_cached_mpd(
                        mpd_url, headers, parse_drm=False, parse_segment_profile_id=profile_id
                    )
                    # Find the matching profile with segments
                    for p in parsed_with_segments.get("profiles", []):
                        if p.get("id") == profile_id:
                            profiles_with_segments.append(p)
                            break
            # Store stream info for ongoing prefetching
            self.active_streams[mpd_url] = {
                "headers": headers,
                "is_live": is_live,
                "profiles": profiles_with_segments,
                "last_access": time.time(),
            }
            # Prebuffer init segments and media segments
            await self._prebuffer_profiles(profiles_with_segments, headers, is_live)
            # Start cleanup task if not running
            self._ensure_cleanup_task_running()
            logger.info(
                f"Pre-buffered DASH manifest: {mpd_url} (live={is_live}, profiles={len(profiles_with_segments)})"
            )
        except Exception as e:
            logger.warning(f"Failed to pre-buffer DASH manifest {mpd_url}: {e}")

    async def _prebuffer_profiles(
        self,
        profiles: List[dict],
        headers: Dict[str, str],
        is_live: bool = False,
    ) -> None:
        """
        Pre-buffer init segments and media segments for all profiles.

        For live streams, prebuffers from the END of the segment list.
        For VOD, prebuffers from the beginning.

        Args:
            profiles: List of parsed profiles with resolved URLs
            headers: Headers to use for requests
            is_live: Whether this is a live stream
        """
        if self._should_skip_for_memory():
            logger.warning("Memory usage too high, skipping prebuffer")
            return
        # Collect all segment URLs to prebuffer
        segment_urls = []
        init_urls = []
        for profile in profiles:
            # Collect init segment URL
            init_url = profile.get("initUrl")
            if init_url:
                init_urls.append(init_url)
            # Get segments to prebuffer
            segments = profile.get("segments", [])
            if not segments:
                continue
            # For live streams, prebuffer from the END (most recent)
            if is_live:
                segments_to_buffer = segments[-self.prebuffer_segment_count:]
            else:
                segments_to_buffer = segments[:self.prebuffer_segment_count]
            for segment in segments_to_buffer:
                segment_url = segment.get("media")
                if segment_url:
                    segment_urls.append(segment_url)
        # Prebuffer init segments (using special init cache); fire-and-forget tasks.
        for init_url in init_urls:
            asyncio.create_task(self._prebuffer_init_segment(init_url, headers))
        # Prebuffer media segments using base class method
        if segment_urls:
            await self.prebuffer_segments_batch(segment_urls, headers)

    async def _prebuffer_init_segment(
        self,
        init_url: str,
        headers: Dict[str, str],
    ) -> None:
        """
        Prebuffer an init segment using the init segment cache.

        Args:
            init_url: URL of the init segment
            headers: Headers for the request
        """
        try:
            # get_cached_init_segment handles both caching and downloading
            content = await get_cached_init_segment(init_url, headers)
            if content:
                self.init_segments_prebuffered += 1
                self.stats.bytes_prebuffered += len(content)
                logger.debug(f"Prebuffered init segment ({len(content)} bytes)")
        except Exception as e:
            logger.warning(f"Failed to prebuffer init segment: {e}")

    async def prefetch_upcoming_segments(
        self,
        mpd_url: str,
        current_segment_url: str,
        headers: Dict[str, str],
        profile_id: Optional[str] = None,
    ) -> None:
        """
        Prefetch upcoming segments based on current playback position.

        Called when a segment is requested to prefetch the next N segments.

        Args:
            mpd_url: URL of the MPD manifest
            current_segment_url: URL of the currently requested segment
            headers: Headers to use for requests
            profile_id: Optional profile ID to limit prefetching to
        """
        self.stats.prefetch_triggered += 1
        try:
            # First check if we have cached profiles with segments
            if mpd_url in self.active_streams:
                # Update last access time
                self.active_streams[mpd_url]["last_access"] = time.time()
                profiles = self.active_streams[mpd_url].get("profiles", [])
            else:
                # Get parsed MPD
                parsed_mpd = await get_cached_mpd(mpd_url, headers, parse_drm=False)
                if not parsed_mpd:
                    return
                profiles = parsed_mpd.get("profiles", [])
            for profile in profiles:
                pid = profile.get("id")
                if profile_id and pid != profile_id:
                    continue
                segments = profile.get("segments", [])
                # If no segments, try to get them by parsing with profile_id
                if not segments and pid:
                    parsed_with_segments = await get_cached_mpd(
                        mpd_url, headers, parse_drm=False, parse_segment_profile_id=pid
                    )
                    for p in parsed_with_segments.get("profiles", []):
                        if p.get("id") == pid:
                            segments = p.get("segments", [])
                            break
                # Find current segment index by matching the media URL.
                current_index = -1
                for i, segment in enumerate(segments):
                    if segment.get("media") == current_segment_url:
                        current_index = i
                        break
                if current_index < 0:
                    continue
                # Collect next N segment URLs
                segment_urls = []
                end_index = min(current_index + 1 + self.prebuffer_segment_count, len(segments))
                for i in range(current_index + 1, end_index):
                    segment_url = segments[i].get("media")
                    if segment_url:
                        segment_urls.append(segment_url)
                if segment_urls:
                    logger.debug(f"Prefetching {len(segment_urls)} upcoming segments from index {current_index + 1}")
                    # Run prefetch in background
                    asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))
        except Exception as e:
            logger.warning(f"Failed to prefetch upcoming segments: {e}")

    async def prefetch_for_live_playlist(
        self,
        profiles: List[dict],
        headers: Dict[str, str],
    ) -> None:
        """
        Prefetch segments for a live playlist refresh.

        Called from process_playlist to ensure upcoming segments are cached.

        Args:
            profiles: List of profiles with resolved segment URLs
            headers: Headers to use for requests
        """
        segment_urls = []
        for profile in profiles:
            segments = profile.get("segments", [])
            if not segments:
                continue
            # For live, prefetch the last N segments (most recent)
            segments_to_prefetch = segments[-self.prebuffer_segment_count:]
            for segment in segments_to_prefetch:
                segment_url = segment.get("media")
                if segment_url:
                    # Check if already cached before adding
                    cached = await self.try_get_cached(segment_url)
                    if not cached:
                        segment_urls.append(segment_url)
        if segment_urls:
            logger.debug(f"Live playlist prefetch: {len(segment_urls)} segments")
            asyncio.create_task(self.prebuffer_segments_batch(segment_urls, headers, max_concurrent=3))

    def _ensure_cleanup_task_running(self) -> None:
        """Ensure the cleanup task is running (restarts it if it finished)."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_inactive_streams())

    async def _cleanup_inactive_streams(self) -> None:
        """
        Periodically check for and clean up inactive streams.

        Runs in the background and removes streams that haven't been
        accessed recently. Exits on its own when no streams remain.
        """
        while True:
            try:
                await asyncio.sleep(30)  # Check every 30 seconds
                if not self.active_streams:
                    logger.debug("No active DASH streams to monitor, stopping cleanup")
                    return
                current_time = time.time()
                streams_to_remove = []
                for mpd_url, stream_info in self.active_streams.items():
                    last_access = stream_info.get("last_access", 0)
                    time_since_access = current_time - last_access
                    if time_since_access > self.inactivity_timeout:
                        streams_to_remove.append(mpd_url)
                        logger.info(f"Cleaning up inactive DASH stream ({time_since_access:.0f}s idle)")
                # Remove inactive streams (after iteration, to avoid mutating the dict mid-loop)
                for mpd_url in streams_to_remove:
                    self.active_streams.pop(mpd_url, None)
                    task = self.prefetch_tasks.pop(mpd_url, None)
                    if task:
                        task.cancel()
                if streams_to_remove:
                    logger.info(f"Cleaned up {len(streams_to_remove)} inactive DASH stream(s)")
            except asyncio.CancelledError:
                logger.debug("DASH cleanup task cancelled")
                return
            except Exception as e:
                logger.warning(f"Error in DASH cleanup task: {e}")

    def get_stats(self) -> dict:
        """Get current prebuffer statistics (base stats plus DASH counters)."""
        stats = self.stats.to_dict()
        stats["init_segments_prebuffered"] = self.init_segments_prebuffered
        stats["active_streams"] = len(self.active_streams)
        return stats

    def clear_cache(self) -> None:
        """Clear active streams tracking and log final stats."""
        self.log_stats()
        self.active_streams.clear()
        for task in self.prefetch_tasks.values():
            task.cancel()
        self.prefetch_tasks.clear()
        # Cancel cleanup task
        if self._cleanup_task and not self._cleanup_task.done():
            self._cleanup_task.cancel()
            self._cleanup_task = None
        self.stats.reset()
        self.init_segments_prebuffered = 0
        logger.info("DASH pre-buffer state cleared")

    async def close(self) -> None:
        """Close the pre-buffer system."""
        self.clear_cache()
# Global DASH pre-buffer instance: a module-level singleton shared by all
# request handlers in this process.
dash_prebuffer = DASHPreBuffer()

View File

@@ -2,14 +2,13 @@
#
# See the LICENSE file for legal information regarding use of this file.
"""Methods for deprecating old names for arguments or attributes."""
import warnings
import inspect
from functools import wraps
def deprecated_class_name(old_name,
warn="Class name '{old_name}' is deprecated, "
"please use '{new_name}'"):
def deprecated_class_name(old_name, warn="Class name '{old_name}' is deprecated, please use '{new_name}'"):
"""
Class decorator to deprecate a use of class.
@@ -21,14 +20,12 @@ def deprecated_class_name(old_name,
keyword name and the 'new_name' for the current one.
Example: "Old name: {old_name}, use '{new_name}' instead".
"""
def _wrap(obj):
assert callable(obj)
def _warn():
warnings.warn(warn.format(old_name=old_name,
new_name=obj.__name__),
DeprecationWarning,
stacklevel=3)
warnings.warn(warn.format(old_name=old_name, new_name=obj.__name__), DeprecationWarning, stacklevel=3)
def _wrap_with_warn(func, is_inspect):
@wraps(func)
@@ -41,12 +38,12 @@ def deprecated_class_name(old_name,
# isinstance(old_name(), new_name) to work
frame = inspect.currentframe().f_back
code = inspect.getframeinfo(frame).code_context
if [line for line in code
if '{0}('.format(old_name) in line]:
if [line for line in code if "{0}(".format(old_name) in line]:
_warn()
else:
_warn()
return func(*args, **kwargs)
return _func
# Make old name available.
@@ -63,11 +60,11 @@ def deprecated_class_name(old_name,
frame.f_globals[old_name] = placeholder
return obj
return _wrap
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
"please use '{new_name}'"):
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to translate obsolete names and warn about their use.
:param dict names: dictionary with pairs of new_name: old_name
@@ -78,27 +75,24 @@ def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
for new_name, old_name in names.items():
if old_name in kwargs:
if new_name in kwargs:
raise TypeError("got multiple values for keyword "
"argument '{0}'".format(new_name))
warnings.warn(warn.format(old_name=old_name,
new_name=new_name),
DeprecationWarning,
stacklevel=2)
raise TypeError("got multiple values for keyword argument '{0}'".format(new_name))
warnings.warn(warn.format(old_name=old_name, new_name=new_name), DeprecationWarning, stacklevel=2)
kwargs[new_name] = kwargs.pop(old_name)
return func(*args, **kwargs)
return wrapper
return decorator
def deprecated_instance_attrs(names,
warn="Attribute '{old_name}' is deprecated, "
"please use '{new_name}'"):
def deprecated_instance_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to deprecate class instance attributes.
Translates all names in `names` to use new names and emits warnings
@@ -119,27 +113,20 @@ def deprecated_instance_attrs(names,
def decorator(clazz):
def getx(self, name, __old_getx=getattr(clazz, "__getattr__", None)):
if name in names:
warnings.warn(warn.format(old_name=name,
new_name=names[name]),
DeprecationWarning,
stacklevel=2)
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
return getattr(self, names[name])
if __old_getx:
if hasattr(__old_getx, "__func__"):
return __old_getx.__func__(self, name)
return __old_getx(self, name)
raise AttributeError("'{0}' object has no attribute '{1}'"
.format(clazz.__name__, name))
raise AttributeError("'{0}' object has no attribute '{1}'".format(clazz.__name__, name))
getx.__name__ = "__getattr__"
clazz.__getattr__ = getx
def setx(self, name, value, __old_setx=getattr(clazz, "__setattr__")):
if name in names:
warnings.warn(warn.format(old_name=name,
new_name=names[name]),
DeprecationWarning,
stacklevel=2)
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
setattr(self, names[name], value)
else:
__old_setx(self, name, value)
@@ -149,10 +136,7 @@ def deprecated_instance_attrs(names,
def delx(self, name, __old_delx=getattr(clazz, "__delattr__")):
if name in names:
warnings.warn(warn.format(old_name=name,
new_name=names[name]),
DeprecationWarning,
stacklevel=2)
warnings.warn(warn.format(old_name=name, new_name=names[name]), DeprecationWarning, stacklevel=2)
delattr(self, names[name])
else:
__old_delx(self, name)
@@ -161,11 +145,11 @@ def deprecated_instance_attrs(names,
clazz.__delattr__ = delx
return clazz
return decorator
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
"please use '{new_name}'"):
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to deprecate all specified attributes in class.
Translates all names in `names` to use new names and emits warnings
@@ -180,6 +164,7 @@ def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
# prepare metaclass for handling all the class methods, class variables
# and static methods (as they don't go through instance's __getattr__)
class DeprecatedProps(type):
@@ -192,27 +177,33 @@ def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
# apply metaclass
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
slots = orig_vars.get("__slots__")
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
orig_vars.pop("__dict__", None)
orig_vars.pop("__weakref__", None)
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
def deprecated_method(message):
"""Decorator for deprecating methods.
:param str message: The message you want to display.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
warnings.warn("{0} is a deprecated method. {1}".format(func.__name__, message),
DeprecationWarning, stacklevel=2)
warnings.warn(
"{0} is a deprecated method. {1}".format(func.__name__, message), DeprecationWarning, stacklevel=2
)
return func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,151 @@
"""
Helper functions for automatic stream extraction in proxy routes.
This module provides caching and extraction helpers for DLHD/DaddyLive
and Sportsonline/Sportzonline streams that are auto-detected in proxy routes.
"""
import logging
import re
import time
from urllib.parse import urlparse
from fastapi import Request, HTTPException
from mediaflow_proxy.extractors.base import ExtractorError
from mediaflow_proxy.extractors.factory import ExtractorFactory
from mediaflow_proxy.utils.http_utils import ProxyRequestHeaders, DownloadError
logger = logging.getLogger(__name__)
# DLHD extraction cache: {original_url: {"data": extraction_result, "timestamp": time.time()}}
_dlhd_extraction_cache: dict = {}
_dlhd_cache_duration = 600  # 10 minutes in seconds

# Sportsonline extraction cache, same shape as the DLHD cache above.
_sportsonline_extraction_cache: dict = {}
_sportsonline_cache_duration = 600  # 10 minutes in seconds
async def check_and_extract_dlhd_stream(
    request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
    """
    Check if destination contains DLHD/DaddyLive patterns and extract stream directly.

    Uses caching to avoid repeated extractions (10 minute cache).

    Args:
        request (Request): The incoming HTTP request.
        destination (str): The destination URL to check.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        force_refresh (bool): Force re-extraction even if cached data exists.

    Returns:
        dict | None: Extracted stream data if DLHD link detected, None otherwise.

    Raises:
        HTTPException: 400 when extraction fails, 500 on unexpected errors.
    """
    # Check for common DLHD/DaddyLive patterns in the URL
    # This includes stream-XXX pattern and domain names like dlhd.dad or daddylive.sx
    is_dlhd_link = (
        re.search(r"stream-\d+", destination)
        or "dlhd.dad" in urlparse(destination).netloc
        or "daddylive.sx" in urlparse(destination).netloc
    )
    if not is_dlhd_link:
        return None
    logger.info(f"DLHD link detected: {destination}")
    # Check cache first (unless force_refresh is True)
    current_time = time.time()
    if not force_refresh and destination in _dlhd_extraction_cache:
        cached_entry = _dlhd_extraction_cache[destination]
        cache_age = current_time - cached_entry["timestamp"]
        if cache_age < _dlhd_cache_duration:
            logger.info(f"Using cached DLHD data (age: {cache_age:.1f}s)")
            # Return a copy: handing out the cached dict itself would let a
            # downstream mutation corrupt the cache for subsequent hits (the
            # store path below already copies for the same reason).
            return cached_entry["data"].copy()
        else:
            logger.info(f"DLHD cache expired (age: {cache_age:.1f}s), re-extracting...")
            del _dlhd_extraction_cache[destination]
    # Extract stream data
    try:
        logger.info(f"Extracting DLHD stream data from: {destination}")
        extractor = ExtractorFactory.get_extractor("DLHD", proxy_headers.request)
        result = await extractor.extract(destination)
        logger.info(f"DLHD extraction successful. Stream URL: {result.get('destination_url')}")
        # Handle dlhd_key_params - encode them for URL passing
        if "dlhd_key_params" in result:
            key_params = result.pop("dlhd_key_params")
            # Add key params as special query parameters for key URL handling
            result["dlhd_channel_salt"] = key_params.get("channel_salt", "")
            result["dlhd_auth_token"] = key_params.get("auth_token", "")
            result["dlhd_iframe_url"] = key_params.get("iframe_url", "")
            logger.info("DLHD key params extracted for dynamic header computation")
        # Cache a copy of result to prevent downstream mutations from corrupting the cache
        _dlhd_extraction_cache[destination] = {"data": result.copy(), "timestamp": current_time}
        logger.info(f"DLHD data cached for {_dlhd_cache_duration}s")
        return result
    except (ExtractorError, DownloadError) as e:
        logger.error(f"DLHD extraction failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"DLHD extraction failed: {str(e)}")
    except Exception as e:
        logger.exception(f"Unexpected error during DLHD extraction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"DLHD extraction failed: {str(e)}")
async def check_and_extract_sportsonline_stream(
    request: Request, destination: str, proxy_headers: ProxyRequestHeaders, force_refresh: bool = False
) -> dict | None:
    """
    Check if destination contains Sportsonline/Sportzonline patterns and extract stream directly.

    Uses caching to avoid repeated extractions (10 minute cache).

    Args:
        request (Request): The incoming HTTP request.
        destination (str): The destination URL to check.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        force_refresh (bool): Force re-extraction even if cached data exists.

    Returns:
        dict | None: Extracted stream data if Sportsonline link detected, None otherwise.

    Raises:
        HTTPException: 400 when extraction fails, 500 on unexpected errors.
    """
    parsed_netloc = urlparse(destination).netloc
    is_sportsonline_link = "sportzonline." in parsed_netloc or "sportsonline." in parsed_netloc
    if not is_sportsonline_link:
        return None
    logger.info(f"Sportsonline link detected: {destination}")
    current_time = time.time()
    if not force_refresh and destination in _sportsonline_extraction_cache:
        cached_entry = _sportsonline_extraction_cache[destination]
        if current_time - cached_entry["timestamp"] < _sportsonline_cache_duration:
            logger.info(f"Using cached Sportsonline data (age: {current_time - cached_entry['timestamp']:.1f}s)")
            # Return a copy so downstream mutations cannot corrupt the cached entry.
            return cached_entry["data"].copy()
        else:
            logger.info("Sportsonline cache expired, re-extracting...")
            del _sportsonline_extraction_cache[destination]
    try:
        logger.info(f"Extracting Sportsonline stream data from: {destination}")
        extractor = ExtractorFactory.get_extractor("Sportsonline", proxy_headers.request)
        result = await extractor.extract(destination)
        logger.info(f"Sportsonline extraction successful. Stream URL: {result.get('destination_url')}")
        # Cache a copy of result to prevent downstream mutations from corrupting
        # the cache (consistent with check_and_extract_dlhd_stream).
        _sportsonline_extraction_cache[destination] = {"data": result.copy(), "timestamp": current_time}
        logger.info(f"Sportsonline data cached for {_sportsonline_cache_duration}s")
        return result
    except (ExtractorError, DownloadError) as e:
        logger.error(f"Sportsonline extraction failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Sportsonline extraction failed: {str(e)}")
    except Exception as e:
        logger.exception(f"Unexpected error during Sportsonline extraction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Sportsonline extraction failed: {str(e)}")

View File

@@ -1,490 +1,478 @@
import asyncio
import logging
import psutil
from typing import Dict, Optional, List
from urllib.parse import urlparse
import httpx
from mediaflow_proxy.utils.http_utils import create_httpx_client
from mediaflow_proxy.configs import settings
from collections import OrderedDict
import time
from urllib.parse import urljoin
logger = logging.getLogger(__name__)
class HLSPreBuffer:
"""
Pre-buffer system for HLS streams to reduce latency and improve streaming performance.
"""
def __init__(self, max_cache_size: Optional[int] = None, prebuffer_segments: Optional[int] = None):
    """
    Initialize the HLS pre-buffer system.

    Args:
        max_cache_size (int): Maximum number of segments to cache (uses config if None)
        prebuffer_segments (int): Number of segments to pre-buffer ahead (uses config if None)
    """
    # Note: OrderedDict, time and urljoin are imported at module level; the
    # previous redundant function-level imports were removed.
    self.max_cache_size = max_cache_size or settings.hls_prebuffer_cache_size
    self.prebuffer_segments = prebuffer_segments or settings.hls_prebuffer_segments
    self.max_memory_percent = settings.hls_prebuffer_max_memory_percent
    self.emergency_threshold = settings.hls_prebuffer_emergency_threshold
    # LRU cache: oldest entries sit at the left end and are evicted first.
    self.segment_cache: "OrderedDict[str, bytes]" = OrderedDict()
    # Map playlist URL -> ordered list of its segment URLs.
    self.segment_urls: Dict[str, List[str]] = {}
    # Reverse map: segment URL -> (playlist_url, index).
    self.segment_to_playlist: Dict[str, tuple[str, int]] = {}
    # Per-playlist state: {headers, last_access, refresh_task, target_duration}.
    self.playlist_state: Dict[str, dict] = {}
    self.client = create_httpx_client()
async def prebuffer_playlist(self, playlist_url: str, headers: Dict[str, str]) -> None:
    """
    Pre-buffer segments from an HLS playlist.

    For a master playlist this recurses into the first variant. For a media
    playlist it records the segment list, prebuffers the first
    ``prebuffer_segments`` segments and starts a background refresh loop.

    Args:
        playlist_url (str): URL of the HLS playlist
        headers (Dict[str, str]): Headers to use for requests
    """
    try:
        logger.debug(f"Starting pre-buffer for playlist: {playlist_url}")
        response = await self.client.get(playlist_url, headers=headers)
        response.raise_for_status()
        playlist_content = response.text
        # Master playlist: recurse into the first variant (relative-URI safe).
        if "#EXT-X-STREAM-INF" in playlist_content:
            logger.debug(f"Master playlist detected, finding first variant")
            variant_urls = self._extract_variant_urls(playlist_content, playlist_url)
            if variant_urls:
                first_variant_url = variant_urls[0]
                logger.debug(f"Pre-buffering first variant: {first_variant_url}")
                await self.prebuffer_playlist(first_variant_url, headers)
            else:
                logger.warning("No variants found in master playlist")
            return
        # Media playlist: extract segments, save state and launch the refresh loop.
        segment_urls = self._extract_segment_urls(playlist_content, playlist_url)
        self.segment_urls[playlist_url] = segment_urls
        # Update the reverse segment -> (playlist, index) map.
        for idx, u in enumerate(segment_urls):
            self.segment_to_playlist[u] = (playlist_url, idx)
        # Initial prebuffer of the first N segments.
        await self._prebuffer_segments(segment_urls[:self.prebuffer_segments], headers)
        logger.info(f"Pre-buffered {min(self.prebuffer_segments, len(segment_urls))} segments for {playlist_url}")
        # Set up the refresh loop if one is not already running for this playlist.
        target_duration = self._parse_target_duration(playlist_content) or 6
        st = self.playlist_state.get(playlist_url, {})
        if not st.get("refresh_task") or st["refresh_task"].done():
            task = asyncio.create_task(self._refresh_playlist_loop(playlist_url, headers, target_duration))
            self.playlist_state[playlist_url] = {
                "headers": headers,
                "last_access": asyncio.get_event_loop().time(),
                "refresh_task": task,
                "target_duration": target_duration,
            }
    except Exception as e:
        logger.warning(f"Failed to pre-buffer playlist {playlist_url}: {e}")
def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
"""
Extract segment URLs from HLS playlist content.
Args:
playlist_content (str): Content of the HLS playlist
base_url (str): Base URL for resolving relative URLs
Returns:
List[str]: List of segment URLs
"""
segment_urls = []
lines = playlist_content.split('\n')
logger.debug(f"Analyzing playlist with {len(lines)} lines")
for line in lines:
line = line.strip()
if line and not line.startswith('#'):
# Check if line contains a URL (http/https) or is a relative path
if 'http://' in line or 'https://' in line:
segment_urls.append(line)
logger.debug(f"Found absolute URL: {line}")
elif line and not line.startswith('#'):
# This might be a relative path to a segment
parsed_base = urlparse(base_url)
# Ensure proper path joining
if line.startswith('/'):
segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
else:
# Get the directory path from base_url
base_path = parsed_base.path.rsplit('/', 1)[0] if '/' in parsed_base.path else ''
segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{base_path}/{line}"
segment_urls.append(segment_url)
logger.debug(f"Found relative path: {line} -> {segment_url}")
logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
if segment_urls:
logger.debug(f"First segment URL: {segment_urls[0]}")
else:
logger.debug("No segment URLs found in playlist")
# Log first few lines for debugging
for i, line in enumerate(lines[:10]):
logger.debug(f"Line {i}: {line}")
return segment_urls
def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
"""
Estrae le varianti dal master playlist. Corretto per gestire URI relativi:
prende la riga non-commento successiva a #EXT-X-STREAM-INF e la risolve rispetto a base_url.
"""
from urllib.parse import urljoin
variant_urls = []
lines = [l.strip() for l in playlist_content.split('\n')]
take_next_uri = False
for line in lines:
if line.startswith("#EXT-X-STREAM-INF"):
take_next_uri = True
continue
if take_next_uri:
take_next_uri = False
if line and not line.startswith('#'):
variant_urls.append(urljoin(base_url, line))
logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
if variant_urls:
logger.debug(f"First variant URL: {variant_urls[0]}")
return variant_urls
async def _prebuffer_segments(self, segment_urls: List[str], headers: Dict[str, str]) -> None:
"""
Pre-buffer specific segments.
Args:
segment_urls (List[str]): List of segment URLs to pre-buffer
headers (Dict[str, str]): Headers to use for requests
"""
tasks = []
for url in segment_urls:
if url not in self.segment_cache:
tasks.append(self._download_segment(url, headers))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def _get_memory_usage_percent(self) -> float:
"""
Get current memory usage percentage.
Returns:
float: Memory usage percentage
"""
try:
memory = psutil.virtual_memory()
return memory.percent
except Exception as e:
logger.warning(f"Failed to get memory usage: {e}")
return 0.0
def _check_memory_threshold(self) -> bool:
"""
Check if memory usage exceeds the emergency threshold.
Returns:
bool: True if emergency cleanup is needed
"""
memory_percent = self._get_memory_usage_percent()
return memory_percent > self.emergency_threshold
def _emergency_cache_cleanup(self) -> None:
    """
    Evict roughly the oldest half of the cached segments (LRU order) when
    memory usage is above the emergency threshold; otherwise do nothing.
    """
    if not self._check_memory_threshold():
        return
    logger.warning("Emergency cache cleanup triggered due to high memory usage")
    target = max(1, len(self.segment_cache) // 2)
    removed = 0
    while self.segment_cache and removed < target:
        self.segment_cache.popitem(last=False)  # drop least-recently-used entry
        removed += 1
    logger.info(f"Emergency cleanup removed {removed} segments from cache")
async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
    """
    Download a single segment and cache it.

    Best-effort: the download is skipped entirely when system memory usage
    is already above ``max_memory_percent``, and any failure is logged and
    swallowed rather than propagated.

    Args:
        segment_url (str): URL of the segment to download
        headers (Dict[str, str]): Headers to use for request
    """
    try:
        memory_percent = self._get_memory_usage_percent()
        if memory_percent > self.max_memory_percent:
            logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
            return
        response = await self.client.get(segment_url, headers=headers)
        response.raise_for_status()
        # LRU cache: insert and mark as most recently used.
        self.segment_cache[segment_url] = response.content
        self.segment_cache.move_to_end(segment_url, last=True)
        if self._check_memory_threshold():
            # High memory pressure: aggressively evict ~half of the cache.
            self._emergency_cache_cleanup()
        elif len(self.segment_cache) > self.max_cache_size:
            # Evict least-recently-used entries until back under the size cap.
            while len(self.segment_cache) > self.max_cache_size:
                self.segment_cache.popitem(last=False)
        logger.debug(f"Cached segment: {segment_url}")
    except Exception as e:
        logger.warning(f"Failed to download segment {segment_url}: {e}")
async def get_segment(self, segment_url: str, headers: Dict[str, str]) -> Optional[bytes]:
    """
    Get a segment from cache or download it.

    Cache hits refresh both the segment's LRU position and the owning
    playlist's ``last_access`` timestamp (so the refresh loop does not treat
    the playlist as inactive). Downloads are refused under memory pressure.

    Args:
        segment_url (str): URL of the segment
        headers (Dict[str, str]): Headers to use for request

    Returns:
        Optional[bytes]: Cached segment data or None if not available
    """
    # Check cache first
    if segment_url in self.segment_cache:
        logger.debug(f"Cache hit for segment: {segment_url}")
        # LRU touch
        data = self.segment_cache[segment_url]
        self.segment_cache.move_to_end(segment_url, last=True)
        # Update last_access for the owning playlist, if this segment is mapped.
        pl = self.segment_to_playlist.get(segment_url)
        if pl:
            st = self.playlist_state.get(pl[0])
            if st:
                st["last_access"] = asyncio.get_event_loop().time()
        return data
    # Cache miss: refuse to download when memory usage is above the limit.
    memory_percent = self._get_memory_usage_percent()
    if memory_percent > self.max_memory_percent:
        logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
        return None
    try:
        response = await self.client.get(segment_url, headers=headers)
        response.raise_for_status()
        segment_data = response.content
        # LRU cache: insert as most recently used.
        self.segment_cache[segment_url] = segment_data
        self.segment_cache.move_to_end(segment_url, last=True)
        if self._check_memory_threshold():
            self._emergency_cache_cleanup()
        elif len(self.segment_cache) > self.max_cache_size:
            while len(self.segment_cache) > self.max_cache_size:
                self.segment_cache.popitem(last=False)
        # Update last_access for the owning playlist, as on the cache-hit path.
        pl = self.segment_to_playlist.get(segment_url)
        if pl:
            st = self.playlist_state.get(pl[0])
            if st:
                st["last_access"] = asyncio.get_event_loop().time()
        logger.debug(f"Downloaded and cached segment: {segment_url}")
        return segment_data
    except Exception as e:
        logger.warning(f"Failed to get segment {segment_url}: {e}")
        return None
async def prebuffer_from_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
    """
    Given a segment URL, pre-buffer the following segments based on the
    playlist and index recorded in ``segment_to_playlist``.

    No-op when the segment was never mapped to a playlist.
    """
    mapped = self.segment_to_playlist.get(segment_url)
    if not mapped:
        return
    playlist_url, idx = mapped
    # Refresh the playlist's last_access so the refresh loop does not
    # garbage-collect it as inactive.
    st = self.playlist_state.get(playlist_url)
    if st:
        st["last_access"] = asyncio.get_event_loop().time()
    await self.prebuffer_next_segments(playlist_url, idx, headers)
async def prebuffer_next_segments(self, playlist_url: str, current_segment_index: int, headers: Dict[str, str]) -> None:
    """
    Pre-buffer the segments that follow the current playback position.

    Args:
        playlist_url: URL of the playlist being played.
        current_segment_index: Index of the segment currently playing.
        headers: HTTP headers forwarded to each request.
    """
    known = self.segment_urls.get(playlist_url)
    if known is None:
        return
    start = current_segment_index + 1
    upcoming = known[start:start + self.prebuffer_segments]
    if upcoming:
        await self._prebuffer_segments(upcoming, headers)
def clear_cache(self) -> None:
    """Drop every cached segment and all playlist bookkeeping state."""
    for container in (
        self.segment_cache,
        self.segment_urls,
        self.segment_to_playlist,
        self.playlist_state,
    ):
        container.clear()
    logger.info("HLS pre-buffer cache cleared")
async def close(self) -> None:
    """Close the pre-buffer system and release the underlying HTTP client."""
    await self.client.aclose()
# Global pre-buffer instance
# NOTE(review): a second `class HLSPreBuffer` definition appears *below* this
# instantiation and rebinds the name — confirm which class this instance is
# actually meant to come from.
hls_prebuffer = HLSPreBuffer()
class HLSPreBuffer:
    """
    Live-HLS pre-buffer helper.

    Follows a media playlist's sliding window with a periodic refresh loop and
    keeps the LRU segment cache and the segment -> playlist mapping coherent.

    NOTE(review): this class defines no ``__init__``; its methods rely on
    ``client``, ``segment_cache``, ``segment_urls``, ``segment_to_playlist``,
    ``playlist_state`` and ``max_cache_size`` already existing on the
    instance — confirm where those attributes are initialized.
    """

    def _parse_target_duration(self, playlist_content: str) -> Optional[int]:
        """
        Parse EXT-X-TARGETDURATION from a media playlist and return duration in seconds.

        Returns None if not present or unparsable.
        """
        for line in playlist_content.splitlines():
            line = line.strip()
            if line.startswith("#EXT-X-TARGETDURATION:"):
                try:
                    value = line.split(":", 1)[1].strip()
                    # Tolerate fractional durations ("5.9"); report whole seconds.
                    return int(float(value))
                except Exception:
                    return None
        return None

    async def _refresh_playlist_loop(self, playlist_url: str, headers: Dict[str, str], target_duration: int) -> None:
        """
        Periodically refresh the playlist to follow the sliding window and keep
        the cache coherent. Stops and cleans up after prolonged inactivity.
        """
        # Poll roughly once per target duration, clamped to [2, 15] seconds.
        sleep_s = max(2, min(15, int(target_duration)))
        inactivity_timeout = 600  # 10 minutes
        while True:
            try:
                st = self.playlist_state.get(playlist_url)
                now = asyncio.get_event_loop().time()
                if not st:
                    return
                if now - st.get("last_access", now) > inactivity_timeout:
                    # Playlist-specific cleanup: drop only this playlist's data.
                    urls = set(self.segment_urls.get(playlist_url, []))
                    if urls:
                        # Remove only this playlist's segments from the cache.
                        for u in list(self.segment_cache.keys()):
                            if u in urls:
                                self.segment_cache.pop(u, None)
                        # Drop the reverse mappings.
                        for u in urls:
                            self.segment_to_playlist.pop(u, None)
                    self.segment_urls.pop(playlist_url, None)
                    self.playlist_state.pop(playlist_url, None)
                    logger.info(f"Stopped HLS prebuffer for inactive playlist: {playlist_url}")
                    return
                # Refresh the manifest.
                resp = await self.client.get(playlist_url, headers=headers)
                resp.raise_for_status()
                content = resp.text
                new_target = self._parse_target_duration(content)
                if new_target:
                    sleep_s = max(2, min(15, int(new_target)))
                new_urls = self._extract_segment_urls(content, playlist_url)
                if new_urls:
                    self.segment_urls[playlist_url] = new_urls
                    # Rebuild the reverse map for the last N entries only,
                    # bounding memory on long-running live playlists; newer
                    # entries overwrite any stale mapping.
                    for idx, u in enumerate(new_urls[-(self.max_cache_size * 2):]):
                        real_idx = len(new_urls) - (self.max_cache_size * 2) + idx if len(new_urls) > (self.max_cache_size * 2) else idx
                        self.segment_to_playlist[u] = (playlist_url, real_idx)
                    # We do not know the current playback index here, so no
                    # proactive prebuffering is attempted.
            except Exception as e:
                logger.debug(f"Playlist refresh error for {playlist_url}: {e}")
            await asyncio.sleep(sleep_s)

    def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract segment URLs from HLS playlist content.

        Absolute URLs are kept as-is; relative paths are resolved against
        ``base_url``.

        Args:
            playlist_content (str): Content of the HLS playlist
            base_url (str): Base URL for resolving relative URLs

        Returns:
            List[str]: List of segment URLs
        """
        segment_urls = []
        lines = playlist_content.split('\n')
        logger.debug(f"Analyzing playlist with {len(lines)} lines")
        for line in lines:
            line = line.strip()
            if line and not line.startswith('#'):
                # FIX: detect absolute URLs with startswith() instead of a
                # substring test, so a relative URI that merely *contains*
                # "http://" (e.g. inside a query string) is still resolved
                # against base_url instead of being used verbatim.
                if line.startswith('http://') or line.startswith('https://'):
                    segment_urls.append(line)
                    logger.debug(f"Found absolute URL: {line}")
                else:
                    # Relative path to a segment: resolve against base_url.
                    parsed_base = urlparse(base_url)
                    if line.startswith('/'):
                        segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
                    else:
                        # Join with the directory of the base playlist.
                        base_path = parsed_base.path.rsplit('/', 1)[0] if '/' in parsed_base.path else ''
                        segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{base_path}/{line}"
                    segment_urls.append(segment_url)
                    logger.debug(f"Found relative path: {line} -> {segment_url}")
        logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
        if segment_urls:
            logger.debug(f"First segment URL: {segment_urls[0]}")
        else:
            logger.debug("No segment URLs found in playlist")
            # Log first few lines for debugging
            for i, line in enumerate(lines[:10]):
                logger.debug(f"Line {i}: {line}")
        return segment_urls

    def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract variant URLs from a master playlist.

        Takes the first non-comment line after each #EXT-X-STREAM-INF tag and
        resolves it against ``base_url`` so relative URIs are handled.
        """
        from urllib.parse import urljoin
        variant_urls = []
        lines = [l.strip() for l in playlist_content.split('\n')]
        take_next_uri = False
        for line in lines:
            if line.startswith("#EXT-X-STREAM-INF"):
                take_next_uri = True
                continue
            if take_next_uri:
                take_next_uri = False
                if line and not line.startswith('#'):
                    variant_urls.append(urljoin(base_url, line))
        logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
        if variant_urls:
            logger.debug(f"First variant URL: {variant_urls[0]}")
        return variant_urls
"""
HLS Pre-buffer system with priority-based sequential prefetching.
This module provides a smart prebuffering system that:
- Prioritizes player-requested segments (downloaded immediately)
- Prefetches remaining segments sequentially in background
- Supports multiple users watching the same channel (shared prefetcher)
- Cleans up inactive prefetchers automatically
Architecture:
1. When playlist is fetched, register_playlist() creates a PlaylistPrefetcher
2. PlaylistPrefetcher runs a background loop: priority queue -> sequential prefetch
3. When player requests a segment, request_segment() adds it to priority queue
4. Prefetcher downloads priority segment first, then continues sequential
"""
import asyncio
import logging
import time
from typing import Dict, Optional, List
from urllib.parse import urljoin
from mediaflow_proxy.utils.base_prebuffer import BasePrebuffer
from mediaflow_proxy.utils.cache_utils import get_cached_segment
from mediaflow_proxy.configs import settings
logger = logging.getLogger(__name__)
class PlaylistPrefetcher:
    """
    Manages prefetching for a single playlist with priority support.

    Key design for live streams with changing tokens:
    - Does NOT start prefetching immediately on registration
    - Only starts prefetching AFTER player requests a segment
    - This ensures we prefetch from the CURRENT playlist, not stale ones

    The prefetcher runs a background loop that:
    1. Waits for player to request a segment (priority)
    2. Downloads the priority segment first
    3. Then prefetches subsequent segments sequentially
    4. Stops when cancelled or all segments are prefetched
    """

    def __init__(
        self,
        playlist_url: str,
        segment_urls: List[str],
        headers: Dict[str, str],
        prebuffer: "HLSPreBuffer",
        prefetch_limit: int = 5,
    ):
        """
        Initialize a playlist prefetcher.

        Args:
            playlist_url: URL of the HLS playlist
            segment_urls: Ordered list of segment URLs from the playlist
            headers: Headers to use for requests
            prebuffer: Parent HLSPreBuffer instance for download methods
            prefetch_limit: Maximum number of segments to prefetch ahead of player position
        """
        self.playlist_url = playlist_url
        self.segment_urls = segment_urls
        self.headers = headers
        self.prebuffer = prebuffer
        self.prefetch_limit = prefetch_limit
        self.last_access = time.time()
        self.current_index = 0  # Next segment to prefetch sequentially
        self.player_index = 0  # Last segment index requested by player
        self.priority_event = asyncio.Event()  # Signals priority segment available
        self.priority_url: Optional[str] = None  # Current priority segment
        self.cancelled = False
        self._task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()  # Protects priority_url
        # Track which segments are already cached or being downloaded
        self.downloading: set = set()
        # Track if prefetching has been activated by a player request
        self.activated = False

    def start(self) -> None:
        """Start the prefetch background task (no-op if already running)."""
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._run())
            logger.info(f"[PlaylistPrefetcher] Started (waiting for activation): {self.playlist_url}")

    def stop(self) -> None:
        """Stop the prefetch background task."""
        self.cancelled = True
        self.priority_event.set()  # Wake up the loop
        if self._task and not self._task.done():
            self._task.cancel()
        logger.info(f"[PlaylistPrefetcher] Stopped for: {self.playlist_url}")

    def update_segments(self, segment_urls: List[str]) -> None:
        """
        Update segment URLs (called when playlist is refreshed).

        Args:
            segment_urls: New list of segment URLs
        """
        self.segment_urls = segment_urls
        self.last_access = time.time()
        logger.debug(f"[PlaylistPrefetcher] Updated segments ({len(segment_urls)}): {self.playlist_url}")

    async def request_priority(self, segment_url: str) -> None:
        """
        Player requested this segment - update indices and activate prefetching.

        The player will download this segment via get_or_download().
        The prefetcher's job is to prefetch segments AHEAD of the player,
        not to download the segment the player is already requesting.

        For VOD/movie streams: handles seek by detecting large jumps in segment
        index and resetting the prefetch window accordingly.

        Args:
            segment_url: URL of the segment the player needs
        """
        self.last_access = time.time()
        self.activated = True  # Activate prefetching
        # Update player position for prefetch limit calculation
        segment_index = self._find_segment_index(segment_url)
        if segment_index >= 0:
            old_player_index = self.player_index
            self.player_index = segment_index
            # Start prefetching from the NEXT segment (player handles current one)
            self.current_index = segment_index + 1
            # Detect seek: if player jumped more than prefetch_limit segments
            # This handles VOD seek scenarios where user jumps to different position
            jump_distance = abs(segment_index - old_player_index)
            if jump_distance > self.prefetch_limit and old_player_index >= 0:
                logger.info(
                    f"[PlaylistPrefetcher] Seek detected: jumped {jump_distance} segments "
                    f"(from {old_player_index} to {segment_index})"
                )
        # Signal the prefetch loop to wake up and start prefetching ahead
        async with self._lock:
            self.priority_url = segment_url
            self.priority_event.set()

    def _find_segment_index(self, segment_url: str) -> int:
        """Find the index of a segment URL in the list (-1 when absent)."""
        try:
            return self.segment_urls.index(segment_url)
        except ValueError:
            return -1

    async def _run(self) -> None:
        """
        Main prefetch loop.

        For live streams: waits until activated by player request before prefetching.

        Priority: Player-requested segment > Sequential prefetch
        After downloading priority segment, continue sequential from that point.

        Prefetching is LIMITED to `prefetch_limit` segments ahead of the player's
        current position to avoid downloading the entire stream.
        """
        logger.info(f"[PlaylistPrefetcher] Loop started for: {self.playlist_url}")
        while not self.cancelled:
            try:
                # Wait for activation (player request) before doing anything
                # (1s timeout so cancellation is noticed promptly).
                if not self.activated:
                    try:
                        await asyncio.wait_for(self.priority_event.wait(), timeout=1.0)
                    except asyncio.TimeoutError:
                        continue
                # Check for priority segment first
                async with self._lock:
                    priority_url = self.priority_url
                    self.priority_url = None
                    self.priority_event.clear()
                if priority_url:
                    # Player is already downloading this segment via get_or_download()
                    # We just need to update our indices and skip to prefetching NEXT segments
                    # This avoids duplicate download attempts and inflated cache miss stats
                    priority_index = self._find_segment_index(priority_url)
                    if priority_index >= 0:
                        self.player_index = priority_index
                        self.current_index = priority_index + 1  # Start prefetching from next segment
                        logger.info(
                            f"[PlaylistPrefetcher] Player at index {self.player_index}, "
                            f"will prefetch up to {self.prefetch_limit} segments ahead"
                        )
                    continue
                # Calculate prefetch limit based on player position
                max_prefetch_index = self.player_index + self.prefetch_limit + 1
                # No priority - prefetch next sequential segment (only if within limit)
                if (
                    self.activated
                    and self.current_index < len(self.segment_urls)
                    and self.current_index < max_prefetch_index
                ):
                    url = self.segment_urls[self.current_index]
                    # Skip if already cached or being downloaded
                    if url not in self.downloading:
                        cached = await get_cached_segment(url)
                        if not cached:
                            logger.info(
                                f"[PlaylistPrefetcher] Prefetching [{self.current_index}] "
                                f"(player at {self.player_index}, limit {self.prefetch_limit}): {url}"
                            )
                            await self._download_segment(url)
                        else:
                            logger.debug(f"[PlaylistPrefetcher] Already cached [{self.current_index}]: {url}")
                    self.current_index += 1
                else:
                    # Reached prefetch limit or end of segments - wait for player to advance
                    try:
                        await asyncio.wait_for(self.priority_event.wait(), timeout=1.0)
                    except asyncio.TimeoutError:
                        pass
            except asyncio.CancelledError:
                logger.info(f"[PlaylistPrefetcher] Loop cancelled: {self.playlist_url}")
                return
            except Exception as e:
                logger.warning(f"[PlaylistPrefetcher] Error in loop: {e}")
                await asyncio.sleep(0.5)
        logger.info(f"[PlaylistPrefetcher] Loop ended: {self.playlist_url}")

    async def _download_segment(self, url: str) -> None:
        """
        Download and cache a segment using the parent prebuffer.

        Guards against concurrent duplicate downloads via ``self.downloading``.

        Args:
            url: URL of the segment to download
        """
        if url in self.downloading:
            return
        self.downloading.add(url)
        try:
            # Use the base prebuffer's get_or_download for cross-process coordination
            await self.prebuffer.get_or_download(url, self.headers)
        finally:
            self.downloading.discard(url)
class HLSPreBuffer(BasePrebuffer):
    """
    Pre-buffer system for HLS streams with priority-based prefetching.

    Features:
    - Priority queue: Player-requested segments downloaded first
    - Sequential prefetch: Background prefetch of remaining segments
    - Multi-user support: Multiple users share same prefetcher
    - Automatic cleanup: Inactive prefetchers removed after timeout
    """

    def __init__(
        self,
        max_cache_size: Optional[int] = None,
        prebuffer_segments: Optional[int] = None,
    ):
        """
        Initialize the HLS pre-buffer system.

        Args:
            max_cache_size: Maximum number of segments to cache (uses config if None)
            prebuffer_segments: Number of segments to pre-buffer ahead (uses config if None)
        """
        super().__init__(
            max_cache_size=max_cache_size or settings.hls_prebuffer_cache_size,
            prebuffer_segments=prebuffer_segments or settings.hls_prebuffer_segments,
            max_memory_percent=settings.hls_prebuffer_max_memory_percent,
            emergency_threshold=settings.hls_prebuffer_emergency_threshold,
            segment_ttl=settings.hls_segment_cache_ttl,
        )
        self.inactivity_timeout = settings.hls_prebuffer_inactivity_timeout
        # Active prefetchers: playlist_url -> PlaylistPrefetcher
        self.active_prefetchers: Dict[str, PlaylistPrefetcher] = {}
        # Reverse mapping: segment URL -> playlist_url
        self.segment_to_playlist: Dict[str, str] = {}
        # Lock for prefetcher management
        self._prefetcher_lock = asyncio.Lock()
        # Cleanup task (started lazily on first register_playlist call)
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = 30  # Check every 30 seconds

    def log_stats(self) -> None:
        """Log current prebuffer statistics with HLS-specific info."""
        stats = self.stats.to_dict()
        stats["active_prefetchers"] = len(self.active_prefetchers)
        logger.info(f"HLS Prebuffer Stats: {stats}")

    def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract segment URLs from HLS playlist content.

        Args:
            playlist_content: Content of the HLS playlist
            base_url: Base URL for resolving relative URLs

        Returns:
            List of segment URLs
        """
        segment_urls = []
        lines = playlist_content.split("\n")
        for line in lines:
            line = line.strip()
            if line and not line.startswith("#"):
                # Absolute URL
                if line.startswith("http://") or line.startswith("https://"):
                    segment_urls.append(line)
                else:
                    # Relative URL - resolve against base
                    segment_url = urljoin(base_url, line)
                    segment_urls.append(segment_url)
        return segment_urls

    def _is_master_playlist(self, playlist_content: str) -> bool:
        """Check if this is a master playlist (contains variant streams)."""
        return "#EXT-X-STREAM-INF" in playlist_content

    async def register_playlist(
        self,
        playlist_url: str,
        segment_urls: List[str],
        headers: Dict[str, str],
    ) -> None:
        """
        Register a playlist for prefetching.

        Creates a new PlaylistPrefetcher or updates existing one.
        Called by M3U8 processor when a playlist is fetched.

        Args:
            playlist_url: URL of the HLS playlist
            segment_urls: Ordered list of segment URLs from the playlist
            headers: Headers to use for requests
        """
        if not segment_urls:
            logger.debug(f"[register_playlist] No segments, skipping: {playlist_url}")
            return
        async with self._prefetcher_lock:
            # Update reverse mapping
            for url in segment_urls:
                self.segment_to_playlist[url] = playlist_url
            if playlist_url in self.active_prefetchers:
                # Update existing prefetcher
                prefetcher = self.active_prefetchers[playlist_url]
                prefetcher.update_segments(segment_urls)
                prefetcher.headers = headers
                logger.info(f"[register_playlist] Updated existing prefetcher: {playlist_url}")
            else:
                # Create new prefetcher with configured prefetch limit
                prefetcher = PlaylistPrefetcher(
                    playlist_url=playlist_url,
                    segment_urls=segment_urls,
                    headers=headers,
                    prebuffer=self,
                    prefetch_limit=settings.hls_prebuffer_segments,
                )
                self.active_prefetchers[playlist_url] = prefetcher
                prefetcher.start()
                logger.info(
                    f"[register_playlist] Created new prefetcher ({len(segment_urls)} segments, "
                    f"prefetch_limit={settings.hls_prebuffer_segments}): {playlist_url}"
                )
        # Ensure cleanup task is running
        self._ensure_cleanup_task()

    async def request_segment(self, segment_url: str) -> None:
        """
        Player requested a segment - set as priority for prefetching.

        Finds the prefetcher for this segment and adds it to priority queue.
        Called by the segment endpoint when a segment is requested.

        Args:
            segment_url: URL of the segment the player needs
        """
        playlist_url = self.segment_to_playlist.get(segment_url)
        if not playlist_url:
            logger.debug(f"[request_segment] No prefetcher found for: {segment_url}")
            return
        prefetcher = self.active_prefetchers.get(playlist_url)
        if prefetcher:
            await prefetcher.request_priority(segment_url)
        else:
            logger.debug(f"[request_segment] Prefetcher not active for: {playlist_url}")

    def _ensure_cleanup_task(self) -> None:
        """Ensure the cleanup task is running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Periodically clean up inactive prefetchers."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_inactive_prefetchers()
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning(f"[cleanup_loop] Error: {e}")

    async def _cleanup_inactive_prefetchers(self) -> None:
        """Remove prefetchers that haven't been accessed recently."""
        now = time.time()
        to_remove = []
        async with self._prefetcher_lock:
            for playlist_url, prefetcher in self.active_prefetchers.items():
                inactive_time = now - prefetcher.last_access
                if inactive_time > self.inactivity_timeout:
                    to_remove.append(playlist_url)
                    logger.info(f"[cleanup] Removing inactive prefetcher ({inactive_time:.0f}s): {playlist_url}")
            for playlist_url in to_remove:
                prefetcher = self.active_prefetchers.pop(playlist_url, None)
                if prefetcher:
                    prefetcher.stop()
                    # Clean up reverse mapping
                    for url in prefetcher.segment_urls:
                        self.segment_to_playlist.pop(url, None)
        if to_remove:
            logger.info(f"[cleanup] Removed {len(to_remove)} inactive prefetchers")

    def get_stats(self) -> dict:
        """Get current prebuffer statistics."""
        stats = self.stats.to_dict()
        stats["active_prefetchers"] = len(self.active_prefetchers)
        return stats

    def clear_cache(self) -> None:
        """Clear all prebuffer state and log final stats."""
        self.log_stats()
        # Stop all prefetchers
        for prefetcher in self.active_prefetchers.values():
            prefetcher.stop()
        self.active_prefetchers.clear()
        self.segment_to_playlist.clear()
        self.stats.reset()
        logger.info("HLS pre-buffer state cleared")

    async def close(self) -> None:
        """Close the pre-buffer system."""
        self.clear_cache()
        if self._cleanup_task:
            self._cleanup_task.cancel()
# Global HLS pre-buffer instance
# NOTE: module-level singleton, created at import time.
hls_prebuffer = HLSPreBuffer()

View File

@@ -1,11 +1,47 @@
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Any, Optional
from urllib.parse import urljoin
logger = logging.getLogger(__name__)
def find_stream_by_resolution(streams: List[Dict[str, Any]], target_resolution: str) -> Optional[Dict[str, Any]]:
    """
    Pick the stream whose height best matches ``target_resolution``.

    Chooses the highest-resolution stream whose height does not exceed the
    target; when every stream exceeds it, the lowest one is returned instead.

    Args:
        streams: Stream dicts carrying a 'resolution' (width, height) tuple.
        target_resolution: Target resolution string (e.g., '1080p', '720p').

    Returns:
        The chosen stream dictionary, or None if no streams available.
    """
    # "1080p" -> 1080
    wanted_height = int(target_resolution.rstrip("p"))
    usable = [s for s in streams if s.get("resolution", (0, 0))[1] > 0]
    if not usable:
        logger.warning("No streams with valid resolution found")
        return streams[0] if streams else None
    by_height_desc = sorted(usable, key=lambda s: s["resolution"][1], reverse=True)
    # Highest stream not exceeding the target height, if any.
    chosen = next((s for s in by_height_desc if s["resolution"][1] <= wanted_height), None)
    if chosen is not None:
        logger.info(f"Selected stream with resolution {chosen['resolution']} for target {target_resolution}")
        return chosen
    fallback = by_height_desc[-1]
    logger.info(f"All streams higher than target {target_resolution}, using lowest: {fallback['resolution']}")
    return fallback
def parse_hls_playlist(playlist_content: str, base_url: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Parses an HLS master playlist to extract stream information.

    This region was corrupted diff residue (old and new lines interleaved with
    "@@" hunk headers); it has been reconstructed as the post-change version.

    Args:
        playlist_content: Raw text of the master playlist.
        base_url: Optional base URL used to resolve relative variant URIs.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, each representing a stream variant.
    """
    streams = []
    lines = playlist_content.strip().split("\n")
    # Regex to capture attributes from #EXT-X-STREAM-INF
    stream_inf_pattern = re.compile(r"#EXT-X-STREAM-INF:(.*)")
    for i, line in enumerate(lines):
        if line.startswith("#EXT-X-STREAM-INF"):
            stream_info = {"raw_stream_inf": line}
            match = stream_inf_pattern.match(line)
            if not match:
                logger.warning(f"Could not parse #EXT-X-STREAM-INF line: {line}")
                continue
            attributes_str = match.group(1)
            # Parse attributes like BANDWIDTH, RESOLUTION, etc.
            attributes = re.findall(r'([A-Z-]+)=("([^"]+)"|([^,]+))', attributes_str)
            for key, _, quoted_val, unquoted_val in attributes:
                value = quoted_val if quoted_val else unquoted_val
                if key == "RESOLUTION":
                    try:
                        width, height = map(int, value.split("x"))
                        stream_info["resolution"] = (width, height)
                    except ValueError:
                        stream_info["resolution"] = (0, 0)
                else:
                    stream_info[key.lower().replace("-", "_")] = value
            # The next line should be the stream URL
            if i + 1 < len(lines) and not lines[i + 1].startswith("#"):
                stream_url = lines[i + 1].strip()
                stream_info["url"] = urljoin(base_url, stream_url) if base_url else stream_url
            streams.append(stream_info)
    return streams

View File

@@ -0,0 +1,362 @@
"""
aiohttp client factory with URL-based SSL verification and proxy routing.
This module provides a centralized HTTP client factory for aiohttp,
allowing per-URL configuration of SSL verification and proxy routing.
"""
import logging
import ssl
import typing
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse
import aiohttp
from aiohttp import ClientSession, ClientTimeout, TCPConnector
logger = logging.getLogger(__name__)
@dataclass
class RouteMatch:
    """SSL/proxy settings resolved for a matched route."""

    verify_ssl: bool = True
    proxy_url: Optional[str] = None


@dataclass
class URLRoutingConfig:
    """
    URL-based routing configuration for SSL verification and proxy settings.

    Supports pattern matching:
    - "all://*.example.com" - matches all protocols for *.example.com
    - "https://api.example.com" - matches specific protocol and host
    - "all://" - default fallback for all URLs
    """

    # Pattern -> (verify_ssl, proxy_url)
    routes: Dict[str, Tuple[bool, Optional[str]]] = field(default_factory=dict)
    # Global defaults applied when no pattern matches.
    default_verify_ssl: bool = True
    default_proxy_url: Optional[str] = None

    def add_route(
        self,
        pattern: str,
        verify_ssl: bool = True,
        proxy_url: Optional[str] = None,
    ) -> None:
        """
        Register SSL/proxy settings for a URL pattern.

        Args:
            pattern: URL pattern (e.g., "all://*.example.com", "https://api.example.com")
            verify_ssl: Whether to verify SSL for this pattern
            proxy_url: Proxy URL to use for this pattern (None = no proxy)
        """
        self.routes[pattern] = (verify_ssl, proxy_url)

    def match_url(self, url: str) -> RouteMatch:
        """
        Resolve the most specific route for ``url``.

        Falls back to the configured defaults when no pattern matches or the
        URL is empty.
        """
        defaults = RouteMatch(
            verify_ssl=self.default_verify_ssl,
            proxy_url=self.default_proxy_url,
        )
        if not url:
            return defaults
        parsed = urlparse(url)
        scheme = parsed.scheme.lower()
        host = parsed.netloc.lower().split(":")[0]  # strip any port
        winner: Optional[RouteMatch] = None
        winner_score = -1
        for pattern, (verify_ssl, proxy_url) in self.routes.items():
            score = self._match_pattern(pattern, scheme, host)
            if score > winner_score:
                winner_score = score
                winner = RouteMatch(verify_ssl=verify_ssl, proxy_url=proxy_url)
        return winner if winner is not None else defaults

    def _match_pattern(self, pattern: str, scheme: str, host: str) -> int:
        """
        Score how specifically ``pattern`` matches the given scheme and host.

        Returns:
            -1 for no match, 0 for the default route ("all://"),
            2 for a wildcard host match, 3 for an exact host match.
        """
        if "://" not in pattern:
            return -1
        pattern_scheme, pattern_host = pattern.split("://", 1)
        pattern_scheme = pattern_scheme.lower()
        if pattern_scheme not in ("all", scheme):
            return -1
        if not pattern_host:
            return 0  # default route
        if pattern_host.startswith("*."):
            bare = pattern_host[2:]    # e.g. "example.com"
            dotted = pattern_host[1:]  # e.g. ".example.com"
            return 2 if (host.endswith(dotted) or host == bare) else -1
        return 3 if pattern_host == host else -1
# Global routing configuration - will be initialized from settings
_global_routing_config: Optional[URLRoutingConfig] = None
# Set to True once initialize_routing_from_config() has populated the config.
_routing_initialized = False
def get_routing_config() -> URLRoutingConfig:
    """Get the global URL routing configuration.

    Lazily creates an empty (all-defaults) config if routing has not yet been
    initialized from settings.
    """
    global _global_routing_config
    if _global_routing_config is None:
        _global_routing_config = URLRoutingConfig()
    return _global_routing_config
def initialize_routing_from_config(transport_config) -> None:
    """
    Build and install the global routing configuration from TransportConfig.

    Args:
        transport_config: The TransportConfig instance from settings
    """
    global _global_routing_config, _routing_initialized

    # Invariants reused throughout: the global SSL policy and global proxy.
    global_verify = not transport_config.disable_ssl_verification_globally
    global_proxy = transport_config.proxy_url if transport_config.all_proxy else None

    config = URLRoutingConfig(
        default_verify_ssl=global_verify,
        default_proxy_url=global_proxy,
    )

    # Routes declared in configuration. A globally disabled SSL check
    # overrides any per-route verify_ssl=True.
    for pattern, route in transport_config.transport_routes.items():
        verify_ssl = route.verify_ssl if global_verify else False
        # Proxy only when the route opts in; route-specific URL wins over global.
        proxy_url = (route.proxy_url or transport_config.proxy_url) if route.proxy else None
        config.add_route(pattern, verify_ssl=verify_ssl, proxy_url=proxy_url)

    # Hardcoded routes for specific domains (SSL verification disabled)
    for domain in ("all://jxoplay.xyz", "all://dlhd.dad", "all://*.newkso.ru"):
        config.add_route(domain, verify_ssl=False, proxy_url=global_proxy)

    # Catch-all default route when global proxying or SSL-disable is active.
    if transport_config.all_proxy or transport_config.disable_ssl_verification_globally:
        config.add_route(
            "all://",
            verify_ssl=global_verify,
            proxy_url=global_proxy,
        )

    _global_routing_config = config
    _routing_initialized = True
    logger.info(f"Initialized aiohttp routing with {len(config.routes)} routes")
def _ensure_routing_initialized():
    """Apply routing settings on first use; no-op on subsequent calls."""
    global _routing_initialized
    if _routing_initialized:
        return
    # Settings imported locally (presumably to avoid an import cycle at
    # module load — TODO confirm).
    from mediaflow_proxy.configs import settings

    initialize_routing_from_config(settings.transport_config)
def _get_ssl_context(verify: bool) -> ssl.SSLContext:
"""Get an SSL context with the specified verification setting."""
if verify:
return ssl.create_default_context()
else:
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx
def create_proxy_connector(proxy_url: str, verify_ssl: bool = True) -> aiohttp.BaseConnector:
    """
    Create a connector for proxy connections, supporting SOCKS and HTTP proxies.

    Args:
        proxy_url: The proxy URL (socks5://..., socks4a://..., http://..., https://...)
        verify_ssl: Whether to verify SSL certificates

    Returns:
        Appropriate connector for the proxy type

    Raises:
        ImportError: If a SOCKS proxy is requested but aiohttp-socks is not installed.
    """
    parsed = urlparse(proxy_url)
    scheme = parsed.scheme.lower()
    ssl_context = _get_ssl_context(verify_ssl)
    if scheme in ("socks5", "socks5h", "socks4", "socks4a"):
        try:
            from aiohttp_socks import ProxyConnector, ProxyType
            proxy_type_map = {
                "socks5": ProxyType.SOCKS5,
                "socks5h": ProxyType.SOCKS5,
                "socks4": ProxyType.SOCKS4,
                "socks4a": ProxyType.SOCKS4,
            }
            return ProxyConnector(
                proxy_type=proxy_type_map[scheme],
                host=parsed.hostname,
                port=parsed.port or 1080,
                username=parsed.username,
                password=parsed.password,
                # Remote DNS resolution applies to both socks5h AND socks4a
                # (socks4a's defining feature is proxy-side hostname
                # resolution). The previous endswith("h") check wrongly
                # excluded socks4a.
                rdns=scheme in ("socks5h", "socks4a"),
                ssl=ssl_context if not verify_ssl else None,
            )
        except ImportError:
            logger.warning("aiohttp-socks not installed, SOCKS proxy support unavailable")
            raise
    else:
        # HTTP/HTTPS proxy - use standard connector.
        # The proxy URL will be passed to the request method.
        return TCPConnector(
            ssl=ssl_context,
            limit=100,
            limit_per_host=10,
        )
def _create_connector(proxy_url: Optional[str], verify_ssl: bool) -> Tuple[aiohttp.BaseConnector, Optional[str]]:
    """
    Build the connector appropriate for the given proxy configuration.

    Args:
        proxy_url: The proxy URL or None
        verify_ssl: Whether to verify SSL certificates

    Returns:
        Tuple of (connector, effective_proxy_url).
        SOCKS proxies are handled inside the connector, so effective_proxy_url
        is None; HTTP proxies must be passed per-request, so the URL is returned.
    """
    if proxy_url and urlparse(proxy_url).scheme in ("socks5", "socks5h", "socks4", "socks4a"):
        # SOCKS proxy - use special connector, proxy handled internally.
        return create_proxy_connector(proxy_url, verify_ssl), None

    # Direct connection or HTTP proxy: a plain TCP connector either way;
    # an HTTP proxy URL (if any) is forwarded for per-request use.
    connector = TCPConnector(
        ssl=_get_ssl_context(verify_ssl),
        limit=100,
        limit_per_host=10,
    )
    return connector, proxy_url or None
@asynccontextmanager
async def create_aiohttp_session(
    url: str = None,
    timeout: typing.Union[int, float, ClientTimeout] = None,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    verify: typing.Optional[bool] = None,
) -> typing.AsyncGenerator[typing.Tuple[ClientSession, typing.Optional[str]], None]:
    """
    Create an aiohttp ClientSession with configured proxy routing and SSL settings.

    This is the primary way to create HTTP sessions in the application.
    It automatically applies URL-based routing for SSL verification and proxy settings.

    Args:
        url: The URL to configure the session for (used for routing)
        timeout: Request timeout (int/float for total seconds, or ClientTimeout)
        headers: Default headers for the session
        verify: Override SSL verification (None = use routing config)

    Yields:
        Tuple of (session, proxy_url) - proxy_url should be passed to request methods
    """
    _ensure_routing_initialized()

    # Resolve SSL/proxy policy for this URL from the routing table.
    route_match = get_routing_config().match_url(url)

    # An explicit `verify` argument overrides the routed SSL policy.
    use_verify = route_match.verify_ssl if verify is None else verify

    # Normalize the timeout argument into a ClientTimeout.
    if timeout is None:
        from mediaflow_proxy.configs import settings

        timeout_config = ClientTimeout(total=settings.transport_config.timeout)
    elif isinstance(timeout, (int, float)):
        timeout_config = ClientTimeout(total=timeout)
    else:
        timeout_config = timeout

    connector, effective_proxy_url = _create_connector(route_match.proxy_url, use_verify)
    session = ClientSession(
        connector=connector,
        timeout=timeout_config,
        headers=headers,
    )
    try:
        yield session, effective_proxy_url
    finally:
        await session.close()

View File

@@ -1,25 +1,33 @@
import asyncio
import logging
import typing
from dataclasses import dataclass
from functools import partial
from urllib import parse
from urllib.parse import urlencode, urlparse
from urllib.parse import urlencode
import aiohttp
from aiohttp import ClientSession, ClientTimeout, ClientResponse
import anyio
import h11
import httpx
import tenacity
from fastapi import Response
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.requests import Request
from starlette.types import Receive, Send, Scope
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm.asyncio import tqdm as tqdm_asyncio
from mediaflow_proxy.configs import settings
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
from mediaflow_proxy.utils.stream_transformers import StreamTransformer
from mediaflow_proxy.utils.http_client import (
create_aiohttp_session,
get_routing_config,
_ensure_routing_initialized,
_create_connector,
)
logger = logging.getLogger(__name__)
@@ -31,218 +39,279 @@ class DownloadError(Exception):
super().__init__(message)
def create_httpx_client(follow_redirects: bool = True, **kwargs) -> httpx.AsyncClient:
"""Creates an HTTPX client with configured proxy routing"""
mounts = settings.transport_config.get_mounts()
kwargs.setdefault("timeout", settings.transport_config.timeout)
client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, **kwargs)
return client
def retry_if_download_error_not_404(retry_state):
"""Retry on DownloadError except for 404 errors."""
if retry_state.outcome.failed:
exception = retry_state.outcome.exception()
if isinstance(exception, DownloadError):
return exception.status_code != 404
return False
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(DownloadError),
retry=retry_if_download_error_not_404,
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
async def fetch_with_retry(
session: ClientSession,
method: str,
url: str,
headers: dict,
proxy: typing.Optional[str] = None,
**kwargs,
) -> ClientResponse:
"""
Fetches a URL with retry logic.
Fetches a URL with retry logic using native aiohttp.
Args:
client (httpx.AsyncClient): The HTTP client to use for the request.
method (str): The HTTP method to use (e.g., GET, POST).
url (str): The URL to fetch.
headers (dict): The headers to include in the request.
follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
session: The aiohttp ClientSession to use for the request.
method: The HTTP method to use (e.g., GET, POST).
url: The URL to fetch.
headers: The headers to include in the request.
proxy: Optional proxy URL for HTTP proxies.
**kwargs: Additional arguments to pass to the request.
Returns:
httpx.Response: The HTTP response.
ClientResponse: The HTTP response.
Raises:
DownloadError: If the request fails after retries.
"""
try:
response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
response = await session.request(method, url, headers=headers, proxy=proxy, **kwargs)
response.raise_for_status()
return response
except httpx.TimeoutException:
except asyncio.TimeoutError:
logger.warning(f"Timeout while downloading {url}")
raise DownloadError(409, f"Timeout while downloading {url}")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
if e.response.status_code == 404:
logger.error(f"Segment Resource not found: {url}")
raise e
raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
except aiohttp.ClientResponseError as e:
if e.status == 404:
logger.debug(f"Segment not found (404): {url}")
raise DownloadError(404, f"Not found (404): {url}")
logger.error(f"HTTP error {e.status} while downloading {url}")
raise DownloadError(e.status, f"HTTP error {e.status} while downloading {url}")
except aiohttp.ClientError as e:
logger.error(f"Client error downloading {url}: {e}")
raise DownloadError(502, f"Client error downloading {url}: {e}")
except Exception as e:
logger.error(f"Error downloading {url}: {e}")
raise
class Streamer:
# PNG signature and IEND marker for fake PNG header detection (StreamWish/FileMoon)
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
_PNG_IEND_MARKER = b"\x49\x45\x4E\x44\xAE\x42\x60\x82"
"""Handles streaming HTTP responses using aiohttp."""
def __init__(self, client):
def __init__(self, session: ClientSession, proxy_url: typing.Optional[str] = None):
"""
Initializes the Streamer with an HTTP client.
Initializes the Streamer with an aiohttp session.
Args:
client (httpx.AsyncClient): The HTTP client to use for streaming.
session: The aiohttp ClientSession to use for streaming.
proxy_url: Optional proxy URL for HTTP proxies.
"""
self.client = client
self.response = None
self.session = session
self.proxy_url = proxy_url
self.response: typing.Optional[ClientResponse] = None
self.progress_bar = None
self.bytes_transferred = 0
self.start_byte = 0
self.end_byte = 0
self.total_size = 0
# Store request details for potential retry during streaming
self._current_url: typing.Optional[str] = None
self._current_headers: typing.Optional[dict] = None
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(DownloadError),
retry=retry_if_download_error_not_404,
)
async def create_streaming_response(self, url: str, headers: dict):
async def create_streaming_response(self, url: str, headers: dict, method: str = "GET"):
"""
Creates and sends a streaming request.
Args:
url (str): The URL to stream from.
headers (dict): The headers to include in the request.
url: The URL to stream from.
headers: The headers to include in the request.
method: HTTP method to use (GET or HEAD). Defaults to GET.
For HEAD requests, will fallback to GET if server doesn't support HEAD.
"""
# Store request details for potential retry during streaming
self._current_url = url
self._current_headers = headers.copy()
try:
request = self.client.build_request("GET", url, headers=headers)
self.response = await self.client.send(request, stream=True, follow_redirects=True)
self.response.raise_for_status()
except httpx.TimeoutException:
if method.upper() == "HEAD":
# Try HEAD first, fallback to GET if server doesn't support it
try:
self.response = await self.session.head(url, headers=headers, proxy=self.proxy_url)
self.response.raise_for_status()
except (aiohttp.ClientResponseError, aiohttp.ClientError) as head_error:
# HEAD failed, fallback to GET (some servers don't support HEAD)
logger.debug(f"HEAD request failed ({head_error}), falling back to GET")
self.response = await self.session.get(url, headers=headers, proxy=self.proxy_url)
self.response.raise_for_status()
else:
self.response = await self.session.get(url, headers=headers, proxy=self.proxy_url)
self.response.raise_for_status()
except asyncio.TimeoutError:
logger.warning("Timeout while creating streaming response")
raise DownloadError(409, "Timeout while creating streaming response")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error {e.response.status_code} while creating streaming response")
if e.response.status_code == 404:
logger.error(f"Segment Resource not found: {url}")
raise e
raise DownloadError(
e.response.status_code, f"HTTP error {e.response.status_code} while creating streaming response"
)
except httpx.RequestError as e:
except aiohttp.ClientResponseError as e:
if e.status == 404:
logger.debug(f"Segment not found (404): {url}")
raise DownloadError(404, f"Not found (404): {url}")
# Don't retry rate-limit errors (429, 509) - retrying while other connections
# are still active just wastes time. Let the player handle its own retry logic.
if e.status in (429, 509):
logger.warning(f"Rate limited ({e.status}) by upstream: {url}")
raise aiohttp.ClientResponseError(e.request_info, e.history, status=e.status, message=e.message)
logger.error(f"HTTP error {e.status} while creating streaming response")
raise DownloadError(e.status, f"HTTP error {e.status} while creating streaming response")
except aiohttp.ClientError as e:
logger.error(f"Error creating streaming response: {e}")
raise DownloadError(502, f"Error creating streaming response: {e}")
except Exception as e:
logger.error(f"Error creating streaming response: {e}")
raise RuntimeError(f"Error creating streaming response: {e}")
@staticmethod
def _strip_fake_png_wrapper(chunk: bytes) -> bytes:
async def _retry_connection(self, from_byte: int) -> bool:
"""
Strip fake PNG wrapper from chunk data.
Some streaming services (StreamWish, FileMoon) prepend a fake PNG image
to video data to evade detection. This method detects and removes it.
Attempt to reconnect to the upstream using Range header.
Args:
chunk: The raw chunk data that may contain a fake PNG header.
from_byte: The byte position to resume from.
Returns:
The chunk with fake PNG wrapper removed, or original chunk if not present.
bool: True if reconnection was successful, False otherwise.
"""
if not chunk.startswith(Streamer._PNG_SIGNATURE):
return chunk
if not self._current_url or not self._current_headers:
return False
# Find the IEND marker that signals end of PNG data
iend_pos = chunk.find(Streamer._PNG_IEND_MARKER)
if iend_pos == -1:
# IEND not found in this chunk - return as-is to avoid data corruption
logger.debug("PNG signature detected but IEND marker not found in chunk")
return chunk
# Close existing response if any
if self.response:
self.response.close()
self.response = None
# Calculate position after IEND marker
content_start = iend_pos + len(Streamer._PNG_IEND_MARKER)
# Create new headers with Range
retry_headers = self._current_headers.copy()
if self.total_size > 0:
retry_headers["Range"] = f"bytes={from_byte}-{self.total_size - 1}"
else:
retry_headers["Range"] = f"bytes={from_byte}-"
# Skip any padding bytes (null or 0xFF) between PNG and actual content
while content_start < len(chunk) and chunk[content_start] in (0x00, 0xFF):
content_start += 1
try:
self.response = await self.session.get(self._current_url, headers=retry_headers, proxy=self.proxy_url)
# Accept both 200 and 206 (Partial Content) as valid responses
if self.response.status in (200, 206):
logger.info(f"Successfully reconnected at byte {from_byte}")
return True
else:
logger.warning(f"Retry connection returned unexpected status: {self.response.status}")
return False
except Exception as e:
logger.warning(f"Failed to reconnect: {e}")
return False
stripped_bytes = content_start
logger.debug(f"Stripped {stripped_bytes} bytes of fake PNG wrapper from stream")
async def stream_content(
self, transformer: typing.Optional[StreamTransformer] = None
) -> typing.AsyncGenerator[bytes, None]:
"""
Stream content from the response, optionally applying a transformer.
return chunk[content_start:]
Includes automatic retry logic when upstream disconnects mid-stream,
using Range headers to resume from the last successful byte.
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
Args:
transformer: Optional StreamTransformer to apply host-specific
content manipulation (e.g., PNG stripping, TS detection).
If None, content is streamed directly without modification.
Yields:
Bytes chunks from the upstream response.
"""
if not self.response:
raise RuntimeError("No response available for streaming")
is_first_chunk = True
retry_count = 0
max_retries = settings.upstream_retry_attempts if settings.upstream_retry_on_disconnect else 0
try:
self.parse_content_range()
while True:
try:
self.parse_content_range()
if settings.enable_streaming_progress:
with tqdm_asyncio(
total=self.total_size,
initial=self.start_byte,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc="Streaming",
ncols=100,
mininterval=1,
) as self.progress_bar:
async for chunk in self.response.aiter_bytes():
if is_first_chunk:
is_first_chunk = False
chunk = self._strip_fake_png_wrapper(chunk)
# Create async generator from response content
async def raw_chunks():
async for chunk in self.response.content.iter_any():
yield chunk
# Choose the chunk source based on whether we have a transformer
# Note: Transformer state may not survive reconnection properly for all transformers
if transformer and retry_count == 0:
chunk_source = transformer.transform(raw_chunks())
else:
chunk_source = raw_chunks()
if settings.enable_streaming_progress:
with tqdm_asyncio(
total=self.total_size,
initial=self.start_byte,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc="Streaming",
ncols=100,
mininterval=1,
) as self.progress_bar:
async for chunk in chunk_source:
yield chunk
self.bytes_transferred += len(chunk)
self.progress_bar.update(len(chunk))
else:
async for chunk in chunk_source:
yield chunk
self.bytes_transferred += len(chunk)
self.progress_bar.update(len(chunk))
else:
async for chunk in self.response.aiter_bytes():
if is_first_chunk:
is_first_chunk = False
chunk = self._strip_fake_png_wrapper(chunk)
yield chunk
self.bytes_transferred += len(chunk)
# Successfully completed streaming
return
except httpx.TimeoutException:
logger.warning("Timeout while streaming")
raise DownloadError(409, "Timeout while streaming")
except httpx.RemoteProtocolError as e:
# Special handling for connection closed errors
if "peer closed connection without sending complete message body" in str(e):
logger.warning(f"Remote server closed connection prematurely: {e}")
# If we've received some data, just log the warning and return normally
except asyncio.TimeoutError:
logger.warning("Timeout while streaming")
raise DownloadError(409, "Timeout while streaming")
except (aiohttp.ServerDisconnectedError, aiohttp.ClientPayloadError, aiohttp.ClientError) as e:
# Handle connection errors with potential retry
error_type = type(e).__name__
logger.warning(f"{error_type} while streaming after {self.bytes_transferred} bytes: {e}")
# Check if we should retry
if retry_count < max_retries and self.bytes_transferred > 0:
retry_count += 1
resume_from = self.start_byte + self.bytes_transferred
logger.info(f"Attempting reconnection (retry {retry_count}/{max_retries}) from byte {resume_from}")
# Wait before retry
await asyncio.sleep(settings.upstream_retry_delay)
if await self._retry_connection(resume_from):
# Successfully reconnected, continue the loop to resume streaming
continue
else:
logger.warning(f"Reconnection failed on retry {retry_count}")
# No more retries or reconnection failed
if self.bytes_transferred > 0:
logger.info(
f"Partial content received ({self.bytes_transferred} bytes). Continuing with available data."
f"Partial content received ({self.bytes_transferred} bytes). "
f"Graceful termination after {retry_count} retry attempts."
)
return
else:
# If we haven't received any data, raise an error
raise DownloadError(502, f"Remote server closed connection without sending any data: {e}")
else:
logger.error(f"Protocol error while streaming: {e}")
raise DownloadError(502, f"Protocol error while streaming: {e}")
except GeneratorExit:
logger.info("Streaming session stopped by the user")
except httpx.ReadError as e:
# Handle network read errors gracefully - these occur when upstream connection drops
logger.warning(f"ReadError while streaming: {e}")
if self.bytes_transferred > 0:
logger.info(f"Partial content received ({self.bytes_transferred} bytes) before ReadError. Graceful termination.")
raise DownloadError(502, f"{error_type} while streaming: {e}")
except GeneratorExit:
logger.info("Streaming session stopped by the user")
return
else:
raise DownloadError(502, f"ReadError while streaming: {e}")
except Exception as e:
logger.error(f"Error streaming content: {e}")
raise
@staticmethod
def format_bytes(size) -> str:
power = 2**10
@@ -263,41 +332,41 @@ class Streamer:
self.total_size = int(self.response.headers.get("Content-Length", 0))
self.end_byte = self.total_size - 1 if self.total_size > 0 else 0
async def get_text(self, url: str, headers: dict):
async def get_text(self, url: str, headers: dict) -> str:
"""
Sends a GET request to a URL and returns the response text.
Args:
url (str): The URL to send the GET request to.
headers (dict): The headers to include in the request.
url: The URL to send the GET request to.
headers: The headers to include in the request.
Returns:
str: The response text.
"""
try:
self.response = await fetch_with_retry(self.client, "GET", url, headers)
self.response = await fetch_with_retry(self.session, "GET", url, headers, proxy=self.proxy_url)
return await self.response.text()
except tenacity.RetryError as e:
raise e.last_attempt.result()
return self.response.text
async def close(self):
"""
Closes the HTTP client and response.
Closes the HTTP response and session.
"""
if self.response:
await self.response.aclose()
self.response.close()
if self.progress_bar:
self.progress_bar.close()
await self.client.aclose()
await self.session.close()
async def download_file_with_retry(url: str, headers: dict):
async def download_file_with_retry(url: str, headers: dict) -> bytes:
"""
Downloads a file with retry logic.
Args:
url (str): The URL of the file to download.
headers (dict): The headers to include in the request.
url: The URL of the file to download.
headers: The headers to include in the request.
Returns:
bytes: The downloaded file content.
@@ -305,10 +374,10 @@ async def download_file_with_retry(url: str, headers: dict):
Raises:
DownloadError: If the download fails after retries.
"""
async with create_httpx_client() as client:
async with create_aiohttp_session(url) as (session, proxy_url):
try:
response = await fetch_with_retry(client, "GET", url, headers)
return response.content
response = await fetch_with_retry(session, "GET", url, headers, proxy=proxy_url)
return await response.read()
except DownloadError as e:
logger.error(f"Failed to download file: {e}")
raise e
@@ -316,31 +385,85 @@ async def download_file_with_retry(url: str, headers: dict):
raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> ClientResponse:
"""
Sends an HTTP request with retry logic.
Args:
method (str): The HTTP method to use (e.g., GET, POST).
url (str): The URL to send the request to.
headers (dict): The headers to include in the request.
method: The HTTP method to use (e.g., GET, POST).
url: The URL to send the request to.
headers: The headers to include in the request.
**kwargs: Additional arguments to pass to the request.
Returns:
httpx.Response: The HTTP response.
ClientResponse: The HTTP response.
Raises:
DownloadError: If the request fails after retries.
"""
async with create_httpx_client() as client:
async with create_aiohttp_session(url) as (session, proxy_url):
try:
response = await fetch_with_retry(client, method, url, headers, **kwargs)
response = await fetch_with_retry(session, method, url, headers, proxy=proxy_url, **kwargs)
# Read the content so it's available after session closes
await response.read()
return response
except DownloadError as e:
logger.error(f"Failed to download file: {e}")
logger.error(f"Failed to make request: {e}")
raise
async def create_streamer(url: str = None) -> Streamer:
"""
Create a Streamer configured for the given URL.
The Streamer manages its own session lifecycle. Call streamer.close()
when done to release resources.
Args:
url: Optional URL for routing configuration (SSL/proxy settings).
Returns:
Streamer: A configured Streamer instance.
"""
_ensure_routing_initialized()
routing_config = get_routing_config()
route_match = routing_config.match_url(url)
# Use sock_read timeout: no total timeout, but timeout if no data received
# for sock_read seconds. This correctly handles:
# - Live streams (indefinite duration)
# - Large file downloads (total time depends on file size)
# - Seek operations (upstream may take time to seek)
# - Dead connection detection (timeout if no data flows)
timeout_config = ClientTimeout(
total=None,
sock_read=settings.transport_config.timeout,
)
connector, proxy_url = _create_connector(route_match.proxy_url, route_match.verify_ssl)
session = ClientSession(connector=connector, timeout=timeout_config)
return Streamer(session, proxy_url)
# Keep setup_streamer as alias for backward compatibility during transition
async def setup_streamer(url: str = None) -> typing.Tuple[ClientSession, str, Streamer]:
"""
Set up an aiohttp session and streamer.
DEPRECATED: Use create_streamer() instead which returns only the Streamer.
Args:
url: Optional URL for routing configuration.
Returns:
Tuple of (session, proxy_url, streamer)
"""
streamer = await create_streamer(url)
return streamer.session, streamer.proxy_url, streamer
def encode_mediaflow_proxy_url(
mediaflow_proxy_url: str,
endpoint: typing.Optional[str] = None,
@@ -348,25 +471,31 @@ def encode_mediaflow_proxy_url(
query_params: typing.Optional[dict] = None,
request_headers: typing.Optional[dict] = None,
response_headers: typing.Optional[dict] = None,
propagate_response_headers: typing.Optional[dict] = None,
remove_response_headers: typing.Optional[list[str]] = None,
encryption_handler: EncryptionHandler = None,
expiration: int = None,
ip: str = None,
filename: typing.Optional[str] = None,
stream_transformer: typing.Optional[str] = None,
) -> str:
"""
Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.
Args:
mediaflow_proxy_url (str): The base MediaFlow proxy URL.
endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
query_params (dict, optional): Additional query parameters to include. Defaults to None.
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
response_headers (dict, optional): Headers to include as query parameters. Defaults to None.
encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
ip (str, optional): The public IP address to include in the query parameters. Defaults to None.
filename (str, optional): Filename to be preserved for media players like Infuse. Defaults to None.
mediaflow_proxy_url: The base MediaFlow proxy URL.
endpoint: The endpoint to append to the base URL. Defaults to None.
destination_url: The destination URL to include in the query parameters. Defaults to None.
query_params: Additional query parameters to include. Defaults to None.
request_headers: Headers to include as query parameters. Defaults to None.
response_headers: Headers to include as query parameters (r_ prefix). Defaults to None.
propagate_response_headers: Response headers that propagate to segments (rp_ prefix). Defaults to None.
remove_response_headers: List of response header names to remove. Defaults to None.
encryption_handler: The encryption handler to use. Defaults to None.
expiration: The expiration time for the encrypted token. Defaults to None.
ip: The public IP address to include in the query parameters. Defaults to None.
filename: Filename to be preserved for media players like Infuse. Defaults to None.
stream_transformer: ID of the stream transformer to apply. Defaults to None.
Returns:
str: The encoded MediaFlow proxy URL.
@@ -376,15 +505,45 @@ def encode_mediaflow_proxy_url(
if destination_url is not None:
query_params["d"] = destination_url
# Add headers if provided
# Add headers if provided (always use lowercase prefix for consistency)
# Filter out empty values to avoid URLs like &h_if-range=&h_referer=...
# Also exclude dynamic per-request headers (range, if-range) that are already handled
# via SUPPORTED_REQUEST_HEADERS from the player's actual request. Encoding them as h_
# query params would bake in stale values that override the player's real headers on
# subsequent requests (e.g., when seeking to a different position).
if request_headers:
query_params.update(
{key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
{
key if key.lower().startswith("h_") else f"h_{key}": value
for key, value in request_headers.items()
if value and (key.lower().removeprefix("h_") not in SUPPORTED_REQUEST_HEADERS)
}
)
if response_headers:
query_params.update(
{key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
{
key if key.lower().startswith("r_") else f"r_{key}": value
for key, value in response_headers.items()
if value # Skip empty/None values
}
)
# Add propagate response headers (rp_ prefix - these propagate to segments)
if propagate_response_headers:
query_params.update(
{
key if key.lower().startswith("rp_") else f"rp_{key}": value
for key, value in propagate_response_headers.items()
if value # Skip empty/None values
}
)
# Add remove headers if provided (x_ prefix for "exclude")
if remove_response_headers:
query_params["x_headers"] = ",".join(remove_response_headers)
# Add stream transformer if provided
if stream_transformer:
query_params["transformer"] = stream_transformer
# Construct the base URL
if endpoint is None:
@@ -441,10 +600,10 @@ def encode_stremio_proxy_url(
Format: http://127.0.0.1:11470/proxy/d=<encoded_origin>&h=<headers>&r=<response_headers>/<path><query>
Args:
stremio_proxy_url (str): The base Stremio proxy URL.
destination_url (str): The destination URL to proxy.
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
response_headers (dict, optional): Response headers to include as query parameters. Defaults to None.
stremio_proxy_url: The base Stremio proxy URL.
destination_url: The destination URL to proxy.
request_headers: Headers to include as query parameters. Defaults to None.
response_headers: Response headers to include as query parameters. Defaults to None.
Returns:
str: The encoded Stremio proxy URL.
@@ -498,7 +657,7 @@ def get_original_scheme(request: Request) -> str:
Determines the original scheme (http or https) of the request.
Args:
request (Request): The incoming HTTP request.
request: The incoming HTTP request.
Returns:
str: The original scheme ('http' or 'https')
@@ -528,6 +687,35 @@ def get_original_scheme(request: Request) -> str:
class ProxyRequestHeaders:
request: dict
response: dict
remove: list # headers to remove from response
propagate: dict # response headers to propagate to segments (rp_ prefix)
def apply_header_manipulation(
    base_headers: dict, proxy_headers: ProxyRequestHeaders, include_propagate: bool = True
) -> dict:
    """
    Build the final response headers from a base set plus proxy directives.

    Headers named in ``proxy_headers.remove`` are dropped (case-insensitively),
    then propagate headers (rp_) and explicit response headers (r_) are merged
    in, with the explicit response headers taking precedence.

    Args:
        base_headers: The upstream headers to start from.
        proxy_headers: Carries the response additions, propagate additions,
            and the list of header names to strip.
        include_propagate: When True, merge the propagate (rp_) headers as
            well. Use False for manifests, True for segments. Defaults to True.

    Returns:
        dict: The manipulated header mapping.
    """
    excluded = {name.lower() for name in proxy_headers.remove}
    manipulated = {key: value for key, value in base_headers.items() if key.lower() not in excluded}
    # Propagate headers go in first so explicit response headers can override them.
    if include_propagate:
        manipulated.update(proxy_headers.propagate)
    manipulated.update(proxy_headers.response)
    return manipulated
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
@@ -535,32 +723,42 @@ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
Extracts proxy headers from the request query parameters.
Args:
request (Request): The incoming HTTP request.
request: The incoming HTTP request.
Returns:
ProxyRequest: A named tuple containing the request headers and response headers.
ProxyRequest: A named tuple containing the request headers, response headers, and headers to remove.
"""
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS and v}
# Extract h_ prefixed headers from query params, filtering out empty values
for k, v in request.query_params.items():
if k.lower().startswith("h_") and v: # Skip empty values
request_headers[k[2:].lower()] = v
request_headers.setdefault("user-agent", settings.user_agent)
# Handle common misspelling of referer
if "referrer" in request_headers:
if "referer" not in request_headers:
request_headers["referer"] = request_headers.pop("referrer")
dest = request.query_params.get("d", "")
host = urlparse(dest).netloc.lower()
if "vidoza" in host or "videzz" in host:
# Remove ALL empty headers
for h in list(request_headers.keys()):
v = request_headers[h]
if v is None or v.strip() == "":
request_headers.pop(h, None)
response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
return ProxyRequestHeaders(request_headers, response_headers)
# r_ prefix: response headers (manifest only, not propagated to segments)
# Filter out empty values
response_headers = {
k[2:].lower(): v
for k, v in request.query_params.items()
if k.lower().startswith("r_") and not k.lower().startswith("rp_") and v
}
# rp_ prefix: response headers that propagate to segments
# Filter out empty values
propagate_headers = {k[3:].lower(): v for k, v in request.query_params.items() if k.lower().startswith("rp_") and v}
# Parse headers to remove from response (x_headers parameter)
x_headers_param = request.query_params.get("x_headers", "")
remove_headers = [h.strip().lower() for h in x_headers_param.split(",") if h.strip()] if x_headers_param else []
return ProxyRequestHeaders(request_headers, response_headers, remove_headers, propagate_headers)
class EnhancedStreamingResponse(Response):
@@ -632,19 +830,19 @@ class EnhancedStreamingResponse(Response):
# Successfully streamed all content
await send({"type": "http.response.body", "body": b"", "more_body": False})
finalization_sent = True
except (httpx.RemoteProtocolError, httpx.ReadError, h11._util.LocalProtocolError) as e:
except (aiohttp.ServerDisconnectedError, aiohttp.ClientPayloadError, aiohttp.ClientError) as e:
# Handle connection closed / read errors gracefully
if data_sent:
# We've sent some data to the client, so try to complete the response
logger.warning(f"Upstream connection error after partial streaming: {e}")
try:
await send({"type": "http.response.body", "body": b"", "more_body": False})
finalization_sent = True
logger.info(
f"Response finalized after partial content ({self.actual_content_length} bytes transferred)"
)
except Exception as close_err:
logger.warning(f"Could not finalize response after upstream error: {close_err}")
# We've sent some data to the client. With Content-Length set, we cannot
# gracefully finalize a partial response - h11 will raise LocalProtocolError
# if we try to send more_body: False without delivering all promised bytes.
# The best we can do is log and return silently, letting the client handle
# the incomplete response (most players will just stop or retry).
logger.warning(
f"Upstream connection error after partial streaming ({self.actual_content_length} bytes transferred): {e}"
)
# Don't try to finalize - just return and let the connection close naturally
return
else:
# No data was sent, re-raise the error
logger.error(f"Upstream error before any data was streamed: {e}")
@@ -667,13 +865,16 @@ class EnhancedStreamingResponse(Response):
except Exception:
# If we can't send an error response, just log it
pass
elif response_started and not finalization_sent:
# Response already started but not finalized - gracefully close the stream
elif response_started and not finalization_sent and not data_sent:
# Response started but no data sent yet - we can safely finalize
# (If data was sent with Content-Length, we can't finalize without h11 error)
try:
await send({"type": "http.response.body", "body": b"", "more_body": False})
finalization_sent = True
except Exception:
pass
# If data was sent but streaming failed, just return silently
# The client will see an incomplete response which is unavoidable with Content-Length
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:

View File

@@ -1,7 +1,9 @@
import asyncio
import codecs
import logging
import re
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional
from urllib import parse
from mediaflow_proxy.configs import settings
@@ -9,9 +11,121 @@ from mediaflow_proxy.utils.crypto_utils import encryption_handler
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, encode_stremio_proxy_url, get_original_scheme
from mediaflow_proxy.utils.hls_prebuffer import hls_prebuffer
logger = logging.getLogger(__name__)
def generate_graceful_end_playlist(message: str = "Stream ended") -> str:
    """
    Build a minimal, valid m3u8 playlist that marks the stream as finished.

    Served when upstream fails but we prefer the player to see a normal
    end-of-stream (via EXT-X-ENDLIST) rather than a hard error; most players
    treat this as the stream ending cleanly.

    Args:
        message: Optional note embedded as a playlist comment.

    Returns:
        str: A valid m3u8 document terminated by EXT-X-ENDLIST.
    """
    playlist_lines = [
        "#EXTM3U",
        "#EXT-X-VERSION:3",
        "#EXT-X-TARGETDURATION:1",
        "#EXT-X-PLAYLIST-TYPE:VOD",
        f"# {message}",
        "#EXT-X-ENDLIST",
    ]
    return "\n".join(playlist_lines) + "\n"
def generate_error_playlist(error_message: str = "Stream unavailable") -> str:
    """
    Build a minimal, valid m3u8 playlist for error scenarios.

    Unlike generate_graceful_end_playlist, the embedded comment is explicitly
    prefixed with "Error:" to signal that something went wrong, while still
    being a parseable playlist that players can consume.

    Args:
        error_message: Error text embedded as a playlist comment.

    Returns:
        str: A valid m3u8 document terminated by EXT-X-ENDLIST.
    """
    playlist_lines = [
        "#EXTM3U",
        "#EXT-X-VERSION:3",
        "#EXT-X-TARGETDURATION:1",
        "#EXT-X-PLAYLIST-TYPE:VOD",
        f"# Error: {error_message}",
        "#EXT-X-ENDLIST",
    ]
    return "\n".join(playlist_lines) + "\n"
class SkipSegmentFilter:
    """
    Filters HLS segments against configured skip time ranges.

    Maintains a running total of elapsed playback time so that each segment
    can be tested for overlap with any of the configured skip ranges.
    """

    def __init__(self, skip_segments: Optional[List[dict]] = None):
        """
        Initialize the skip segment filter.

        Args:
            skip_segments: List of dicts, each with 'start' and 'end' keys
                expressed in seconds.
        """
        self.skip_segments = skip_segments or []
        self.current_time = 0.0  # Cumulative playback time in seconds

    def should_skip_segment(self, duration: float) -> bool:
        """
        Return True when the segment starting at the current playback
        position overlaps any configured skip range.

        Args:
            duration: Duration of the segment in seconds.
        """
        seg_begin = self.current_time
        seg_finish = seg_begin + duration
        for entry in self.skip_segments:
            range_begin = entry.get("start", 0)
            range_finish = entry.get("end", 0)
            # Two intervals overlap unless one ends at/before the other begins.
            if not (seg_finish <= range_begin or seg_begin >= range_finish):
                logger.debug(
                    f"Skipping segment at {seg_begin:.2f}s-{seg_finish:.2f}s "
                    f"(overlaps with skip range {range_begin:.2f}s-{range_finish:.2f}s)"
                )
                return True
        return False

    def advance_time(self, duration: float):
        """Add the segment duration to the running playback clock."""
        self.current_time += duration

    def has_skip_segments(self) -> bool:
        """Return True when at least one skip range is configured."""
        return bool(self.skip_segments)
class M3U8Processor:
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None, key_only_proxy: bool = False, no_proxy: bool = False):
def __init__(
self,
request,
key_url: str = None,
force_playlist_proxy: bool = None,
key_only_proxy: bool = False,
no_proxy: bool = False,
skip_segments: Optional[List[dict]] = None,
start_offset: Optional[float] = None,
):
"""
Initializes the M3U8Processor with the request and URL prefix.
@@ -21,21 +135,65 @@ class M3U8Processor:
force_playlist_proxy (bool, optional): Force all playlist URLs to be proxied through MediaFlow. Defaults to None.
key_only_proxy (bool, optional): Only proxy the key URL, leaving segment URLs direct. Defaults to False.
no_proxy (bool, optional): If True, returns the manifest without proxying any URLs. Defaults to False.
skip_segments (List[dict], optional): List of time segments to skip. Each dict should have
'start', 'end' (in seconds), and optionally 'type'.
start_offset (float, optional): Time offset in seconds for EXT-X-START tag. Use negative values
for live streams to start behind the live edge.
"""
self.request = request
self.key_url = parse.urlparse(key_url) if key_url else None
self.key_only_proxy = key_only_proxy
self.no_proxy = no_proxy
self.force_playlist_proxy = force_playlist_proxy
self.skip_filter = SkipSegmentFilter(skip_segments)
# Track if user explicitly provided start_offset (vs using default)
self._user_provided_start_offset = start_offset is not None
# Store the explicit value or default (will be applied conditionally for live streams)
self._start_offset_value = start_offset if start_offset is not None else settings.livestream_start_offset
self.mediaflow_proxy_url = str(
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
)
# Base URL for segment proxy - extension will be appended based on actual segment
# url_for with path param returns URL with placeholder, so we build it manually
self.segment_proxy_base_url = str(
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
).replace("/hls/manifest.m3u8", "/hls/segment")
self.playlist_url = None # Will be set when processing starts
def _should_apply_start_offset(self, content: str) -> bool:
"""
Determine if start_offset should be applied to this playlist.
Args:
content: The playlist content to check.
Returns:
True if start_offset should be applied, False otherwise.
"""
if self._start_offset_value is None:
return False
# If user explicitly provided start_offset, always use it
if self._user_provided_start_offset:
return True
# Using default from settings - only apply for live streams
# VOD playlists have #EXT-X-ENDLIST tag or #EXT-X-PLAYLIST-TYPE:VOD
# Also skip master playlists (they have #EXT-X-STREAM-INF)
is_vod = "#EXT-X-ENDLIST" in content or "#EXT-X-PLAYLIST-TYPE:VOD" in content
is_master = "#EXT-X-STREAM-INF" in content
return not is_vod and not is_master
async def process_m3u8(self, content: str, base_url: str) -> str:
"""
Processes the m3u8 content, proxying URLs and handling key lines.
For content filtering with skip_segments, this follows the IntroHater approach:
- Segments within skip ranges are completely removed (EXTINF + URL)
- A #EXT-X-DISCONTINUITY marker is added BEFORE the URL of the first segment
after a skipped section (not before the EXTINF)
Args:
content (str): The m3u8 content to process.
base_url (str): The base URL to resolve relative URLs.
@@ -45,35 +203,131 @@ class M3U8Processor:
"""
# Store the playlist URL for prebuffering
self.playlist_url = base_url
lines = content.splitlines()
processed_lines = []
for line in lines:
# Track if we need to add discontinuity before next URL (after skipping segments)
discontinuity_pending = False
# Buffer the current EXTINF line - only output when we output the URL
pending_extinf: Optional[str] = None
# Track if we've injected EXT-X-START tag
start_offset_injected = False
# Determine if we should apply start_offset (checks if live stream)
apply_start_offset = self._should_apply_start_offset(content)
i = 0
while i < len(lines):
line = lines[i]
# Inject EXT-X-START tag right after #EXTM3U (only for live streams or if user explicitly requested)
if line.strip() == "#EXTM3U" and apply_start_offset and not start_offset_injected:
processed_lines.append(line)
processed_lines.append(f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES")
start_offset_injected = True
i += 1
continue
# Handle EXTINF lines (segment duration markers)
if line.startswith("#EXTINF:"):
duration = self._parse_extinf_duration(line)
if self.skip_filter.has_skip_segments() and self.skip_filter.should_skip_segment(duration):
# Skip this segment entirely - don't buffer the EXTINF
discontinuity_pending = True # Mark that we need discontinuity before next kept segment
self.skip_filter.advance_time(duration)
pending_extinf = None
i += 1
continue
else:
# Keep this segment
self.skip_filter.advance_time(duration)
pending_extinf = line
i += 1
continue
# Handle segment URLs (non-comment, non-empty lines)
if not line.startswith("#") and line.strip():
if pending_extinf is None:
# No pending EXTINF means this segment was skipped
i += 1
continue
# Add discontinuity BEFORE the EXTINF if we just skipped segments
# Per HLS spec, EXT-X-DISCONTINUITY must appear before the first segment of the new content
if discontinuity_pending:
processed_lines.append("#EXT-X-DISCONTINUITY")
discontinuity_pending = False
# Output the buffered EXTINF and proxied URL
processed_lines.append(pending_extinf)
processed_lines.append(await self.proxy_content_url(line, base_url))
pending_extinf = None
i += 1
continue
# Handle existing discontinuity markers - pass through but reset pending flag
if line.startswith("#EXT-X-DISCONTINUITY"):
processed_lines.append(line)
discontinuity_pending = False # Don't add duplicate
i += 1
continue
# Handle key lines
if "URI=" in line:
processed_lines.append(await self.process_key_line(line, base_url))
elif not line.startswith("#") and line.strip():
processed_lines.append(await self.proxy_content_url(line, base_url))
else:
processed_lines.append(line)
# Pre-buffer segments if enabled and this is a playlist
if (settings.enable_hls_prebuffer and
"#EXTM3U" in content and
self.playlist_url):
# Extract headers from request for pre-buffering
headers = {}
for key, value in self.request.query_params.items():
if key.startswith("h_"):
headers[key[2:]] = value
# Start pre-buffering in background using the actual playlist URL
asyncio.create_task(
hls_prebuffer.prebuffer_playlist(self.playlist_url, headers)
)
i += 1
continue
# All other lines (headers, comments, etc.)
processed_lines.append(line)
i += 1
# Log skip statistics
if self.skip_filter.has_skip_segments():
logger.info(f"Content filtering: processed playlist with {len(self.skip_filter.skip_segments)} skip ranges")
# Register playlist with the priority-based prefetcher
if settings.enable_hls_prebuffer and "#EXTM3U" in content and self.playlist_url:
# Skip master playlists
if "#EXT-X-STREAM-INF" not in content:
segment_urls = self._extract_segment_urls_from_content(content, self.playlist_url)
if segment_urls:
headers = {}
for key, value in self.request.query_params.items():
if key.startswith("h_"):
headers[key[2:]] = value
logger.info(
f"[M3U8Processor] Registering playlist ({len(segment_urls)} segments): {self.playlist_url}"
)
asyncio.create_task(
hls_prebuffer.register_playlist(
self.playlist_url,
segment_urls,
headers,
)
)
return "\n".join(processed_lines)
def _parse_extinf_duration(self, line: str) -> float:
"""
Parse the duration from an #EXTINF line.
Args:
line: The #EXTINF line (e.g., "#EXTINF:10.0," or "#EXTINF:10,title")
Returns:
The duration in seconds as a float.
"""
# Format: #EXTINF:<duration>[,<title>]
match = re.match(r"#EXTINF:(\d+(?:\.\d+)?)", line)
if match:
return float(match.group(1))
return 0.0
async def process_m3u8_streaming(
self, content_iterator: AsyncGenerator[bytes, None], base_url: str
) -> AsyncGenerator[str, None]:
@@ -81,20 +335,37 @@ class M3U8Processor:
Processes the m3u8 content on-the-fly, yielding processed lines as they are read.
Optimized to avoid accumulating the entire playlist content in memory.
Note: When skip_segments are configured, this method buffers lines to properly
handle EXTINF + segment URL pairs that need to be skipped together.
Args:
content_iterator: An async iterator that yields chunks of the m3u8 content.
base_url (str): The base URL to resolve relative URLs.
Yields:
str: Processed lines of the m3u8 content.
Raises:
ValueError: If the content is not a valid m3u8 playlist (e.g., HTML error page).
"""
# Store the playlist URL for prebuffering
self.playlist_url = base_url
buffer = "" # String buffer for decoded content
raw_content = "" # Accumulate raw content for prebuffer
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
is_playlist_detected = False
is_prebuffer_started = False
is_html_detected = False
initial_check_done = False
# State for skip segment filtering
discontinuity_pending = False # Track if we need discontinuity before next URL
pending_extinf = None # Buffer EXTINF line until we decide to emit it
# Track if we've injected EXT-X-START tag
start_offset_injected = False
# Buffer header lines until we know if it's a master playlist (for default start_offset)
header_buffer = []
header_flushed = False
# Process the content chunk by chunk
async for chunk in content_iterator:
@@ -104,6 +375,24 @@ class M3U8Processor:
# Incrementally decode the chunk
decoded_chunk = decoder.decode(chunk)
buffer += decoded_chunk
raw_content += decoded_chunk # Accumulate for prebuffer
# Early detection: check if this is HTML instead of m3u8
# This helps catch upstream error pages quickly
if not initial_check_done and len(buffer) > 50:
initial_check_done = True
buffer_lower = buffer.lower().strip()
# Check for HTML markers
if buffer_lower.startswith("<!doctype") or buffer_lower.startswith("<html"):
is_html_detected = True
logger.error(f"Upstream returned HTML instead of m3u8 playlist: {base_url}")
# Raise an error so the HTTP handler returns a proper error response
# This allows the player to retry or show an error instead of thinking
# the stream has ended normally
raise ValueError(
f"Upstream returned HTML instead of m3u8 playlist. "
f"The stream may be offline or unavailable: {base_url}"
)
# Check for playlist marker early to avoid accumulating content
if not is_playlist_detected and "#EXTM3U" in buffer:
@@ -114,40 +403,246 @@ class M3U8Processor:
if len(lines) > 1:
# Process all complete lines except the last one
for line in lines[:-1]:
if line: # Skip empty lines
if not line: # Skip empty lines
continue
# Buffer header lines until we can determine playlist type
# This allows us to decide whether to inject EXT-X-START
if not header_flushed:
# Always buffer the current line first
header_buffer.append(line)
# Check if we can now determine playlist type
# Only check the current line, not raw_content (which may contain future content)
is_master = "#EXT-X-STREAM-INF" in line
is_media = "#EXTINF" in line
if is_master or is_media:
# For non-user-provided (default) start_offset, determine if this
# is a live stream before injecting. We need to avoid injecting
# EXT-X-START with negative offsets into VOD playlists, as players
# like VLC interpret negative offsets as "from the end" and start
# playing near the end of the video.
#
# Live stream indicators (checked in header):
# - No #EXT-X-PLAYLIST-TYPE:VOD tag
# - No #EXT-X-ENDLIST tag (may not be visible yet in streaming)
# - #EXT-X-MEDIA-SEQUENCE > 0 (live windows have rolling sequence)
#
# VOD indicators:
# - #EXT-X-PLAYLIST-TYPE:VOD in header
# - #EXT-X-ENDLIST in raw_content (if small enough to be buffered)
# - #EXT-X-MEDIA-SEQUENCE:0 or absent (VOD starts from beginning)
header_content = "\n".join(header_buffer)
all_content = header_content + "\n" + raw_content
is_explicitly_vod = (
"#EXT-X-PLAYLIST-TYPE:VOD" in all_content or "#EXT-X-ENDLIST" in all_content
)
# Check for live stream indicator: #EXT-X-MEDIA-SEQUENCE with value > 0
# Live streams have a rolling window so their media sequence increments
is_likely_live = False
seq_match = re.search(r"#EXT-X-MEDIA-SEQUENCE:\s*(\d+)", all_content)
if seq_match and int(seq_match.group(1)) > 0:
is_likely_live = True
# Flush header buffer with or without EXT-X-START
should_inject = (
self._start_offset_value is not None
and not is_master
and (
self._user_provided_start_offset
or (is_media and not is_explicitly_vod and is_likely_live)
) # User provided OR it's a live media playlist
)
for header_line in header_buffer:
# Process header lines to rewrite URLs (e.g., #EXT-X-KEY)
processed_header_line = await self.process_line(header_line, base_url)
yield processed_header_line + "\n"
if header_line.strip() == "#EXTM3U" and should_inject and not start_offset_injected:
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
start_offset_injected = True
header_buffer = []
header_flushed = True
# If not master/media yet, continue buffering (line already added above)
continue
# If user explicitly provided start_offset and we haven't injected yet
# (handles edge case where we flush header before seeing EXTINF/STREAM-INF)
if (
line.strip() == "#EXTM3U"
and self._user_provided_start_offset
and self._start_offset_value is not None
and not start_offset_injected
):
yield line + "\n"
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
start_offset_injected = True
continue
# Handle segment filtering if skip_segments are configured
if self.skip_filter.has_skip_segments():
result = await self._process_line_with_filtering(
line, base_url, discontinuity_pending, pending_extinf
)
processed_line, discontinuity_pending, pending_extinf = result
if processed_line is not None:
yield processed_line + "\n"
else:
# No filtering, process normally
processed_line = await self.process_line(line, base_url)
yield processed_line + "\n"
# Keep the last line in the buffer (it might be incomplete)
buffer = lines[-1]
# Start pre-buffering early once we detect this is a playlist
# This avoids waiting until the entire playlist is processed
if (settings.enable_hls_prebuffer and
is_playlist_detected and
not is_prebuffer_started and
self.playlist_url):
# Extract headers from request for pre-buffering
headers = {}
for key, value in self.request.query_params.items():
if key.startswith("h_"):
headers[key[2:]] = value
# Start pre-buffering in background using the actual playlist URL
asyncio.create_task(
hls_prebuffer.prebuffer_playlist(self.playlist_url, headers)
# If HTML was detected, we already returned an error playlist
if is_html_detected:
return
# Flush any remaining header buffer (for short playlists or edge cases)
# At this point we have the full raw_content so we can make a definitive determination
if header_buffer and not header_flushed:
is_master = "#EXT-X-STREAM-INF" in raw_content
is_vod = "#EXT-X-ENDLIST" in raw_content or "#EXT-X-PLAYLIST-TYPE:VOD" in raw_content
# For default offset, also require positive live indicator
is_likely_live = False
seq_match = re.search(r"#EXT-X-MEDIA-SEQUENCE:\s*(\d+)", raw_content)
if seq_match and int(seq_match.group(1)) > 0:
is_likely_live = True
should_inject = (
self._start_offset_value is not None
and not is_master
and (
self._user_provided_start_offset
or (not is_vod and is_likely_live) # Default offset: only inject for live streams
)
is_prebuffer_started = True
)
for header_line in header_buffer:
yield header_line + "\n"
if header_line.strip() == "#EXTM3U" and should_inject and not start_offset_injected:
yield f"#EXT-X-START:TIME-OFFSET={self._start_offset_value:.1f},PRECISE=YES\n"
start_offset_injected = True
header_buffer = []
# Process any remaining data in the buffer plus final bytes
final_chunk = decoder.decode(b"", final=True)
if final_chunk:
buffer += final_chunk
# Final validation: if we never detected a valid m3u8 playlist marker
if not is_playlist_detected:
logger.error(f"Invalid m3u8 content from upstream (no #EXTM3U marker found): {base_url}")
yield "#EXTM3U\n"
yield "#EXT-X-PLAYLIST-TYPE:VOD\n"
yield "# ERROR: Invalid m3u8 content from upstream (no #EXTM3U marker found)\n"
yield "# The upstream server may have returned an error page\n"
yield "#EXT-X-ENDLIST\n"
return
if buffer: # Process the last line if it's not empty
processed_line = await self.process_line(buffer, base_url)
yield processed_line
if self.skip_filter.has_skip_segments():
result = await self._process_line_with_filtering(
buffer, base_url, discontinuity_pending, pending_extinf
)
processed_line, _, _ = result
if processed_line is not None:
yield processed_line
else:
processed_line = await self.process_line(buffer, base_url)
yield processed_line
# Log skip statistics
if self.skip_filter.has_skip_segments():
logger.info(f"Content filtering: processed playlist with {len(self.skip_filter.skip_segments)} skip ranges")
# Register playlist with the priority-based prefetcher
# The prefetcher uses a smart approach:
# 1. When player requests a segment, it gets priority (downloaded first)
# 2. After serving priority segment, prefetcher continues sequentially
# 3. Multiple users watching same channel share the prefetcher
# 4. Inactive prefetchers are cleaned up automatically
if settings.enable_hls_prebuffer and is_playlist_detected and self.playlist_url and raw_content:
# Skip master playlists (they contain variant streams, not segments)
if "#EXT-X-STREAM-INF" not in raw_content:
# Extract segment URLs from the playlist
segment_urls = self._extract_segment_urls_from_content(raw_content, self.playlist_url)
if segment_urls:
# Extract headers for prefetcher
headers = {}
for key, value in self.request.query_params.items():
if key.startswith("h_"):
headers[key[2:]] = value
logger.info(
f"[M3U8Processor] Registering playlist ({len(segment_urls)} segments): {self.playlist_url}"
)
asyncio.create_task(
hls_prebuffer.register_playlist(
self.playlist_url,
segment_urls,
headers,
)
)
async def _process_line_with_filtering(
    self, line: str, base_url: str, discontinuity_pending: bool, pending_extinf: Optional[str]
) -> tuple:
    """
    Process a single playlist line while skip-segment filtering is active.

    Uses the IntroHater approach: the #EXT-X-DISCONTINUITY marker is emitted
    BEFORE the URL of the first kept segment after a skipped section, not
    before its EXTINF line. To make that possible, EXTINF lines are buffered
    in ``pending_extinf`` and only emitted together with their segment URL.

    Args:
        line: The playlist line to process.
        base_url: Base URL used to resolve/proxy relative segment URLs.
        discontinuity_pending: True if segments were skipped since the last
            emitted segment, meaning a discontinuity marker is still owed.
        pending_extinf: The buffered EXTINF line for the upcoming segment
            URL, or None if the preceding segment was skipped.

    Returns:
        tuple: (processed_lines, discontinuity_pending, pending_extinf) where
        processed_lines is None when nothing should be output for this line,
        otherwise a string (possibly several newline-joined lines) to emit.
        The other two values are the updated filtering state the caller must
        carry into the next line.
    """
    # Handle EXTINF lines (segment duration markers)
    if line.startswith("#EXTINF:"):
        duration = self._parse_extinf_duration(line)
        if self.skip_filter.should_skip_segment(duration):
            # Skip this segment - don't buffer the EXTINF
            self.skip_filter.advance_time(duration)
            return (None, True, None)  # discontinuity_pending = True, clear pending
        else:
            # Keep this segment; buffer its EXTINF until the URL is seen
            self.skip_filter.advance_time(duration)
            return (None, discontinuity_pending, line)  # Buffer EXTINF
    # Handle segment URLs (non-comment, non-empty lines)
    if not line.startswith("#") and line.strip():
        if pending_extinf is None:
            # No pending EXTINF means this segment was skipped
            return (None, discontinuity_pending, None)
        # Build output: optional discontinuity + EXTINF + URL
        # Per HLS spec, EXT-X-DISCONTINUITY must appear before the first segment of the new content
        processed_url = await self.proxy_content_url(line, base_url)
        output_lines = []
        if discontinuity_pending:
            output_lines.append("#EXT-X-DISCONTINUITY")
        output_lines.append(pending_extinf)
        output_lines.append(processed_url)
        return ("\n".join(output_lines), False, None)
    # Handle existing discontinuity markers - pass through and reset pending
    # so we don't emit a duplicate marker ourselves
    if line.startswith("#EXT-X-DISCONTINUITY"):
        return (line, False, pending_extinf)
    # Handle key lines (tags carrying a URI="..." attribute, e.g. #EXT-X-KEY)
    if "URI=" in line:
        processed = await self.process_key_line(line, base_url)
        return (processed, discontinuity_pending, pending_extinf)
    # All other lines (headers, comments, etc.) pass through unchanged
    return (line, discontinuity_pending, pending_extinf)
async def process_line(self, line: str, base_url: str) -> str:
"""
@@ -186,14 +681,23 @@ class M3U8Processor:
full_url = parse.urljoin(base_url, original_uri)
line = line.replace(f'URI="{original_uri}"', f'URI="{full_url}"')
return line
uri_match = re.search(r'URI="([^"]+)"', line)
if uri_match:
original_uri = uri_match.group(1)
uri = parse.urlparse(original_uri)
if self.key_url:
# Only substitute key_url scheme/netloc for actual EXT-X-KEY lines.
# EXT-X-MAP (init segments) and other tags must keep their original host,
# otherwise the proxied destination URL gets the wrong upstream hostname.
if self.key_url and line.startswith("#EXT-X-KEY"):
uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
new_uri = await self.proxy_url(uri.geturl(), base_url)
# Check if this is a DLHD stream with key params (needs stream endpoint for header computation)
query_params = dict(self.request.query_params)
is_dlhd_key_request = "dlhd_salt" in query_params and "/key/" in uri.geturl()
# Use stream endpoint for DLHD key URLs, manifest endpoint for others
new_uri = await self.proxy_url(
uri.geturl(), base_url, use_full_url=True, is_playlist=not is_dlhd_key_request
)
line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
return line
@@ -223,16 +727,19 @@ class M3U8Processor:
# Check if we should force MediaFlow proxy for all playlist URLs
if self.force_playlist_proxy:
return await self.proxy_url(full_url, base_url, use_full_url=True)
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
# For playlist URLs, always use MediaFlow proxy regardless of strategy
# Check for actual playlist file extensions, not just substring matches
parsed_url = parse.urlparse(full_url)
if (parsed_url.path.endswith((".m3u", ".m3u8", ".m3u_plus")) or
parse.parse_qs(parsed_url.query).get("type", [""])[0] in ["m3u", "m3u8", "m3u_plus"]):
return await self.proxy_url(full_url, base_url, use_full_url=True)
is_playlist_url = parsed_url.path.endswith((".m3u", ".m3u8", ".m3u_plus")) or parse.parse_qs(
parsed_url.query
).get("type", [""])[0] in ["m3u", "m3u8", "m3u_plus"]
# Route non-playlist content URLs based on strategy
if is_playlist_url:
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=True)
# Route non-playlist content URLs (segments) based on strategy
if routing_strategy == "direct":
# Return the URL directly without any proxying
return full_url
@@ -250,9 +757,33 @@ class M3U8Processor:
)
else:
# Default to MediaFlow proxy (routing_strategy == "mediaflow" or fallback)
return await self.proxy_url(full_url, base_url, use_full_url=True)
# Use stream endpoint for segment URLs
return await self.proxy_url(full_url, base_url, use_full_url=True, is_playlist=False)
async def proxy_url(self, url: str, base_url: str, use_full_url: bool = False) -> str:
def _extract_segment_urls_from_content(self, content: str, base_url: str) -> list:
    """
    Collect the segment URLs referenced by an HLS playlist.

    Every non-empty line that is not a tag/comment (``#``-prefixed) is
    treated as a segment reference; relative references are resolved
    against *base_url*.

    Args:
        content: Raw playlist text.
        base_url: Base URL used to resolve relative segment references.

    Returns:
        A list of absolute segment URLs in playlist order.
    """
    urls = []
    for raw_line in content.split("\n"):
        entry = raw_line.strip()
        # Skip blank lines and playlist tags/comments.
        if not entry or entry.startswith("#"):
            continue
        if entry.startswith(("http://", "https://")):
            # Already absolute - use as-is.
            urls.append(entry)
        else:
            # Relative reference - resolve against the playlist base URL.
            urls.append(parse.urljoin(base_url, entry))
    return urls
async def proxy_url(self, url: str, base_url: str, use_full_url: bool = False, is_playlist: bool = True) -> str:
"""
Proxies a URL, encoding it with the MediaFlow proxy URL.
@@ -260,6 +791,7 @@ class M3U8Processor:
url (str): The URL to proxy.
base_url (str): The base URL to resolve relative URLs.
use_full_url (bool): Whether to use the URL as-is (True) or join with base_url (False).
is_playlist (bool): Whether this is a playlist URL (uses manifest endpoint) or segment URL (uses stream endpoint).
Returns:
str: The proxied URL.
@@ -271,15 +803,52 @@ class M3U8Processor:
query_params = dict(self.request.query_params)
has_encrypted = query_params.pop("has_encrypted", False)
# Remove the response headers from the query params to avoid it being added to the consecutive requests
[query_params.pop(key, None) for key in list(query_params.keys()) if key.startswith("r_")]
# Remove force_playlist_proxy to avoid it being added to subsequent requests
# Remove the response headers (r_) from the query params to avoid it being added to the consecutive requests
# BUT keep rp_ (response propagate) headers as they should propagate to segments
[
query_params.pop(key, None)
for key in list(query_params.keys())
if key.lower().startswith("r_") and not key.lower().startswith("rp_")
]
# Remove manifest-only parameters to avoid them being added to subsequent requests
query_params.pop("force_playlist_proxy", None)
if not is_playlist:
query_params.pop("start_offset", None)
# Use appropriate proxy URL based on content type
if is_playlist:
proxy_url = self.mediaflow_proxy_url
else:
# Check if this is a DLHD key URL (needs /stream endpoint for header computation)
is_dlhd_key = "dlhd_salt" in query_params and "/key/" in full_url
if is_dlhd_key:
# Use /stream endpoint for DLHD key URLs
proxy_url = self.mediaflow_proxy_url.replace("/hls/manifest.m3u8", "/stream")
else:
# Determine segment extension from the URL
# Default to .ts for traditional HLS, but detect fMP4 extensions
segment_ext = "ts"
url_lower = full_url.lower()
# Check for fMP4/CMAF extensions
if url_lower.endswith(".m4s"):
segment_ext = "m4s"
elif url_lower.endswith(".mp4"):
segment_ext = "mp4"
elif url_lower.endswith(".m4a"):
segment_ext = "m4a"
elif url_lower.endswith(".m4v"):
segment_ext = "m4v"
elif url_lower.endswith(".aac"):
segment_ext = "aac"
# Build segment proxy URL with correct extension
proxy_url = f"{self.segment_proxy_base_url}.{segment_ext}"
# Remove h_range header - each segment should handle its own range requests
query_params.pop("h_range", None)
return encode_mediaflow_proxy_url(
self.mediaflow_proxy_url,
"",
proxy_url,
None, # No endpoint - URL is already complete
full_url,
query_params=query_params,
encryption_handler=encryption_handler if has_encrypted else None,
)
)

View File

@@ -10,6 +10,35 @@ import xmltodict
logger = logging.getLogger(__name__)
def resolve_url(base_url: str, relative_url: str) -> str:
    """
    Resolve *relative_url* against *base_url* and return an absolute URL.

    Covers three situations:
      1. ``relative_url`` is already absolute (http/https) -> returned unchanged
      2. ``relative_url`` is an absolute path (leading ``/``) -> joined to the origin
      3. ``relative_url`` is a plain relative path -> joined to the base directory

    Args:
        base_url: The reference URL (typically the MPD manifest URL).
        relative_url: The URL or path to resolve.

    Returns:
        The resolved absolute URL; ``base_url`` itself when *relative_url*
        is empty.
    """
    # Nothing to resolve: fall back to the base URL.
    if not relative_url:
        return base_url
    # urljoin correctly handles both absolute paths (resolved against the
    # origin) and relative paths (resolved against the base directory).
    is_absolute = relative_url.startswith(("http://", "https://"))
    return relative_url if is_absolute else urljoin(base_url, relative_url)
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
"""
Parses the MPD content into a dictionary.
@@ -43,7 +72,6 @@ def parse_mpd_dict(
"""
profiles = []
parsed_dict = {}
source = "/".join(mpd_url.split("/")[:-1])
is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
parsed_dict["isLive"] = is_live
@@ -66,7 +94,10 @@ def parse_mpd_dict(
for period in periods:
parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
for adaptation in period["AdaptationSet"]:
adaptation_sets = period["AdaptationSet"]
adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
for adaptation in adaptation_sets:
representations = adaptation["Representation"]
representations = representations if isinstance(representations, list) else [representations]
@@ -75,7 +106,7 @@ def parse_mpd_dict(
parsed_dict,
representation,
adaptation,
source,
mpd_url,
media_presentation_duration,
parse_segment_profile_id,
)
@@ -195,7 +226,7 @@ def parse_representation(
parsed_dict: dict,
representation: dict,
adaptation: dict,
source: str,
mpd_url: str,
media_presentation_duration: str,
parse_segment_profile_id: Optional[str],
) -> Optional[dict]:
@@ -206,7 +237,7 @@ def parse_representation(
parsed_dict (dict): The parsed MPD data.
representation (dict): The representation data.
adaptation (dict): The adaptation set data.
source (str): The source URL.
mpd_url (str): The URL of the MPD manifest.
media_presentation_duration (str): The media presentation duration.
parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
@@ -263,14 +294,50 @@ def parse_representation(
else:
profile["segment_template_start_number"] = 1
# For SegmentBase profiles, we need to set initUrl even when not parsing segments
# This is needed for the HLS playlist builder to reference the init URL
segment_base_data = representation.get("SegmentBase")
if segment_base_data and "initUrl" not in profile:
base_url = representation.get("BaseURL", "")
profile["initUrl"] = resolve_url(mpd_url, base_url)
# Store initialization range if available
if "Initialization" in segment_base_data:
init_range = segment_base_data["Initialization"].get("@range")
if init_range:
profile["initRange"] = init_range
# For SegmentList profiles, we also need to set initUrl even when not parsing segments
segment_list_data = representation.get("SegmentList") or adaptation.get("SegmentList")
if segment_list_data and "initUrl" not in profile:
if "Initialization" in segment_list_data:
init_data = segment_list_data["Initialization"]
if "@sourceURL" in init_data:
init_url = init_data["@sourceURL"]
profile["initUrl"] = resolve_url(mpd_url, init_url)
elif "@range" in init_data:
base_url = representation.get("BaseURL", "")
profile["initUrl"] = resolve_url(mpd_url, base_url)
profile["initRange"] = init_data["@range"]
if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
return profile
item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
if item:
profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
# Parse segments based on the addressing scheme used
segment_template = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
segment_list = adaptation.get("SegmentList") or representation.get("SegmentList")
# Get BaseURL from representation (can be relative path like "a/b/c/")
base_url = representation.get("BaseURL", "")
if segment_template:
profile["segments"] = parse_segment_template(parsed_dict, segment_template, profile, mpd_url, base_url)
elif segment_list:
# Get timescale from SegmentList or default to 1
timescale = int(segment_list.get("@timescale", 1))
profile["segments"] = parse_segment_list(adaptation, representation, profile, mpd_url, timescale)
else:
profile["segments"] = parse_segment_base(representation, profile, source)
profile["segments"] = parse_segment_base(representation, profile, mpd_url)
return profile
@@ -290,7 +357,9 @@ def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
return representation.get(key, adaptation.get(key, None))
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
def parse_segment_template(
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, base_url: str = ""
) -> List[Dict]:
"""
Parses a segment template and extracts segment information.
@@ -298,7 +367,8 @@ def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment template data.
profile (dict): The profile information.
source (str): The source URL.
mpd_url (str): The URL of the MPD manifest.
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
Returns:
List[Dict]: The list of parsed segments.
@@ -311,20 +381,23 @@ def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source:
media = item["@initialization"]
media = media.replace("$RepresentationID$", profile["id"])
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
if not media.startswith("http"):
media = f"{source}/{media}"
profile["initUrl"] = media
# Combine base_url and media, then resolve against mpd_url
if base_url:
media = base_url + media
profile["initUrl"] = resolve_url(mpd_url, media)
# Segments
if "SegmentTimeline" in item:
segments.extend(parse_segment_timeline(parsed_dict, item, profile, source, timescale))
segments.extend(parse_segment_timeline(parsed_dict, item, profile, mpd_url, timescale, base_url))
elif "@duration" in item:
segments.extend(parse_segment_duration(parsed_dict, item, profile, source, timescale))
segments.extend(parse_segment_duration(parsed_dict, item, profile, mpd_url, timescale, base_url))
return segments
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
def parse_segment_timeline(
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, timescale: int, base_url: str = ""
) -> List[Dict]:
"""
Parses a segment timeline and extracts segment information.
@@ -332,8 +405,9 @@ def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment timeline data.
profile (dict): The profile information.
source (str): The source URL.
mpd_url (str): The URL of the MPD manifest.
timescale (int): The timescale for the segments.
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
Returns:
List[Dict]: The list of parsed segments.
@@ -347,7 +421,7 @@ def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source:
start_number = int(item.get("@startNumber", 1))
segments = [
create_segment_data(timeline, item, profile, source, timescale)
create_segment_data(timeline, item, profile, mpd_url, timescale, base_url)
for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
]
return segments
@@ -379,14 +453,14 @@ def preprocess_timeline(
for _ in range(repeat + 1):
segment_start_time = period_start + timedelta(seconds=(start_time - presentation_time_offset) / timescale)
segment_end_time = segment_start_time + timedelta(seconds=duration / timescale)
presentation_time = start_time - presentation_time_offset
processed_data.append(
{
"number": start_number,
"start_time": segment_start_time,
"end_time": segment_end_time,
"duration": duration,
"time": presentation_time,
"time": start_time,
"duration_mpd_timescale": duration,
}
)
start_time += duration
@@ -397,7 +471,9 @@ def preprocess_timeline(
return processed_data
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
def parse_segment_duration(
parsed_dict: dict, item: dict, profile: dict, mpd_url: str, timescale: int, base_url: str = ""
) -> List[Dict]:
"""
Parses segment duration and extracts segment information.
This is used for static or live MPD manifests.
@@ -406,8 +482,9 @@ def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment duration data.
profile (dict): The profile information.
source (str): The source URL.
mpd_url (str): The URL of the MPD manifest.
timescale (int): The timescale for the segments.
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
Returns:
List[Dict]: The list of parsed segments.
@@ -421,7 +498,7 @@ def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source:
else:
segments = generate_vod_segments(profile, duration, timescale, start_number)
return [create_segment_data(seg, item, profile, source, timescale) for seg in segments]
return [create_segment_data(seg, item, profile, mpd_url, timescale, base_url) for seg in segments]
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
@@ -480,7 +557,9 @@ def generate_vod_segments(profile: dict, duration: int, timescale: int, start_nu
return [{"number": start_number + i, "duration": duration / timescale} for i in range(segment_count)]
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
def create_segment_data(
segment: Dict, item: dict, profile: dict, mpd_url: str, timescale: Optional[int] = None, base_url: str = ""
) -> Dict:
"""
Creates segment data based on the segment information. This includes the segment URL and metadata.
@@ -488,8 +567,9 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
segment (Dict): The segment information.
item (dict): The segment template data.
profile (dict): The profile information.
source (str): The source URL.
mpd_url (str): The URL of the MPD manifest.
timescale (int, optional): The timescale for the segments. Defaults to None.
base_url (str): The BaseURL from the representation (optional, for per-representation paths).
Returns:
Dict: The created segment data.
@@ -503,8 +583,10 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
if "time" in segment and timescale is not None:
media = media.replace("$Time$", str(int(segment["time"])))
if not media.startswith("http"):
media = f"{source}/{media}"
# Combine base_url and media, then resolve against mpd_url
if base_url:
media = base_url + media
media = resolve_url(mpd_url, media)
segment_data = {
"type": "segment",
@@ -528,7 +610,8 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
}
)
elif "start_time" in segment and "duration" in segment:
duration_seconds = segment["duration"] / timescale
# duration here is in timescale units (from timeline segments)
duration_seconds = segment["duration"] / timescale if timescale else segment["duration"]
segment_data.update(
{
"start_time": segment["start_time"],
@@ -537,43 +620,144 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
"program_date_time": segment["start_time"].isoformat() + "Z",
}
)
elif "duration" in segment and timescale is not None:
# Convert duration from timescale units to seconds
segment_data["extinf"] = segment["duration"] / timescale
elif "duration" in segment:
# If no timescale is provided, assume duration is already in seconds
# duration from generate_vod_segments and generate_live_segments is already in seconds
segment_data["extinf"] = segment["duration"]
return segment_data
def parse_segment_base(representation: dict, profile: dict, source: str) -> List[Dict]:
def parse_segment_list(
adaptation: dict, representation: dict, profile: dict, mpd_url: str, timescale: int
) -> List[Dict]:
"""
Parses segment base information and extracts segment data. This is used for single-segment representations.
Parses SegmentList element with explicit SegmentURL entries.
SegmentList MPDs explicitly list each segment URL, unlike SegmentTemplate which uses
URL patterns. This is less common but used by some packagers.
Args:
adaptation (dict): The adaptation set data.
representation (dict): The representation data.
source (str): The source URL.
profile (dict): The profile information.
mpd_url (str): The URL of the MPD manifest.
timescale (int): The timescale for duration calculations.
Returns:
List[Dict]: The list of parsed segments.
"""
segment = representation["SegmentBase"]
start, end = map(int, segment["@indexRange"].split("-"))
if "Initialization" in segment:
start, _ = map(int, segment["Initialization"]["@range"].split("-"))
# Set initUrl for SegmentBase
if not representation['BaseURL'].startswith("http"):
profile["initUrl"] = f"{source}/{representation['BaseURL']}"
else:
profile["initUrl"] = representation['BaseURL']
# SegmentList can be at AdaptationSet or Representation level
segment_list = representation.get("SegmentList") or adaptation.get("SegmentList", {})
segments = []
# Handle Initialization element
if "Initialization" in segment_list:
init_data = segment_list["Initialization"]
if "@sourceURL" in init_data:
init_url = init_data["@sourceURL"]
profile["initUrl"] = resolve_url(mpd_url, init_url)
elif "@range" in init_data:
# Initialization by byte range on the BaseURL
base_url = representation.get("BaseURL", "")
profile["initUrl"] = resolve_url(mpd_url, base_url)
profile["initRange"] = init_data["@range"]
# Get segment duration from SegmentList attributes
duration = int(segment_list.get("@duration", 0))
list_timescale = int(segment_list.get("@timescale", timescale or 1))
segment_duration_sec = duration / list_timescale if list_timescale else 0
# Parse SegmentURL elements
segment_urls = segment_list.get("SegmentURL", [])
if not isinstance(segment_urls, list):
segment_urls = [segment_urls]
for i, seg_url in enumerate(segment_urls):
if seg_url is None:
continue
# Get media URL - can be @media attribute or use BaseURL with @mediaRange
media_url = seg_url.get("@media", "")
media_range = seg_url.get("@mediaRange")
if media_url:
media_url = resolve_url(mpd_url, media_url)
else:
# Use BaseURL with byte range
base_url = representation.get("BaseURL", "")
media_url = resolve_url(mpd_url, base_url)
segment_data = {
"type": "segment",
"media": media_url,
"number": i + 1,
"extinf": segment_duration_sec if segment_duration_sec > 0 else 1.0,
}
# Include media range if specified
if media_range:
segment_data["mediaRange"] = media_range
segments.append(segment_data)
return segments
def parse_segment_base(representation: dict, profile: dict, mpd_url: str) -> List[Dict]:
"""
Parses segment base information and extracts segment data. This is used for single-segment representations
(SegmentBase MPDs, typically GPAC-generated on-demand profiles).
For SegmentBase, the entire media file is treated as a single segment. The initialization data
is specified by the Initialization element's range, and the segment index (SIDX) is at indexRange.
Args:
representation (dict): The representation data.
profile (dict): The profile information.
mpd_url (str): The URL of the MPD manifest.
Returns:
List[Dict]: The list of parsed segments.
"""
segment = representation.get("SegmentBase", {})
base_url = representation.get("BaseURL", "")
# Build the full media URL
media_url = resolve_url(mpd_url, base_url)
# Set initUrl for SegmentBase - this is the URL with the initialization range
# The initialization segment contains codec/track info needed before playing media
profile["initUrl"] = media_url
# For SegmentBase, we need to specify byte ranges for init and media segments
init_range = None
if "Initialization" in segment:
init_range = segment["Initialization"].get("@range")
# Store initialization range in profile for segment endpoint to use
if init_range:
profile["initRange"] = init_range
# Get the index range which points to SIDX box
index_range = segment.get("@indexRange", "")
# Calculate total duration from profile's mediaPresentationDuration
total_duration = profile.get("mediaPresentationDuration")
if isinstance(total_duration, str):
total_duration = parse_duration(total_duration)
elif total_duration is None:
total_duration = 0
# For SegmentBase, we return a single segment representing the entire media
# The media URL is the same as initUrl but will be accessed with different byte ranges
return [
{
"type": "segment",
"range": f"{start}-{end}",
"media": f"{source}/{representation['BaseURL']}",
"media": media_url,
"number": 1,
"extinf": total_duration if total_duration > 0 else 1.0,
"indexRange": index_range,
"initRange": init_range,
}
]

View File

@@ -1,5 +1,5 @@
#Adapted for use in MediaFlowProxy from:
#https://github.com/einars/js-beautify/blob/master/python/jsbeautifier/unpackers/packer.py
# Adapted for use in MediaFlowProxy from:
# https://github.com/einars/js-beautify/blob/master/python/jsbeautifier/unpackers/packer.py
# Unpacker for Dean Edward's p.a.c.k.e.r, a part of javascript beautifier
# by Einar Lielmanis <einar@beautifier.io>
#
@@ -21,8 +21,6 @@ import logging
logger = logging.getLogger(__name__)
def detect(source):
if "eval(function(p,a,c,k,e,d)" in source:
mystr = "smth"
@@ -71,9 +69,7 @@ def _filterargs(source):
raise UnpackingError("Corrupted p.a.c.k.e.r. data.")
# could not find a satisfying regex
raise UnpackingError(
"Could not make sense of p.a.c.k.e.r data (unexpected code structure)"
)
raise UnpackingError("Could not make sense of p.a.c.k.e.r data (unexpected code structure)")
def _replacestrings(source):
@@ -88,7 +84,7 @@ def _replacestrings(source):
for index, value in enumerate(lookup):
source = source.replace(variable % index, '"%s"' % value)
return source[startpoint:]
return source
return source
class Unbaser(object):
@@ -97,10 +93,7 @@ class Unbaser(object):
ALPHABET = {
62: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
95: (
" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
),
95: (" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"),
}
def __init__(self, base):
@@ -118,9 +111,7 @@ class Unbaser(object):
else:
# Build conversion dictionary cache
try:
self.dictionary = dict(
(cipher, index) for index, cipher in enumerate(self.ALPHABET[base])
)
self.dictionary = dict((cipher, index) for index, cipher in enumerate(self.ALPHABET[base]))
except KeyError:
raise TypeError("Unsupported base encoding.")
@@ -135,6 +126,8 @@ class Unbaser(object):
for index, cipher in enumerate(string[::-1]):
ret += (self.base**index) * self.dictionary[cipher]
return ret
class UnpackingError(Exception):
"""Badly packed source or general error. Argument is a
meaningful description."""
@@ -142,11 +135,10 @@ class UnpackingError(Exception):
pass
async def eval_solver(self, url: str, headers: dict[str, str] | None, patterns: list[str]) -> str:
try:
response = await self._make_request(url, headers=headers)
soup = BeautifulSoup(response.text, "lxml",parse_only=SoupStrainer("script"))
soup = BeautifulSoup(response.text, "lxml", parse_only=SoupStrainer("script"))
script_all = soup.find_all("script")
for i in script_all:
if detect(i.text):
@@ -162,4 +154,4 @@ async def eval_solver(self, url: str, headers: dict[str, str] | None, patterns:
raise UnpackingError("No p.a.c.k.e.d JS found or no pattern matched.")
except Exception as e:
logger.exception("Eval solver error for %s", url)
raise UnpackingError("Error in eval_solver") from e
raise UnpackingError("Error in eval_solver") from e

View File

@@ -8,7 +8,7 @@ from .aes import AES
from .rijndael import Rijndael
from .cryptomath import bytesToNumber, numberToByteArray
__all__ = ['new', 'Python_AES']
__all__ = ["new", "Python_AES"]
def new(key, mode, IV):
@@ -37,22 +37,21 @@ class Python_AES(AES):
plaintextBytes = bytearray(plaintext)
chainBytes = self.IV[:]
#CBC Mode: For each block...
for x in range(len(plaintextBytes)//16):
#XOR with the chaining block
blockBytes = plaintextBytes[x*16 : (x*16)+16]
# CBC Mode: For each block...
for x in range(len(plaintextBytes) // 16):
# XOR with the chaining block
blockBytes = plaintextBytes[x * 16 : (x * 16) + 16]
for y in range(16):
blockBytes[y] ^= chainBytes[y]
#Encrypt it
# Encrypt it
encryptedBytes = self.rijndael.encrypt(blockBytes)
#Overwrite the input with the output
# Overwrite the input with the output
for y in range(16):
plaintextBytes[(x*16)+y] = encryptedBytes[y]
plaintextBytes[(x * 16) + y] = encryptedBytes[y]
#Set the next chaining block
# Set the next chaining block
chainBytes = encryptedBytes
self.IV = chainBytes[:]
@@ -64,19 +63,18 @@ class Python_AES(AES):
ciphertextBytes = ciphertext[:]
chainBytes = self.IV[:]
#CBC Mode: For each block...
for x in range(len(ciphertextBytes)//16):
#Decrypt it
blockBytes = ciphertextBytes[x*16 : (x*16)+16]
# CBC Mode: For each block...
for x in range(len(ciphertextBytes) // 16):
# Decrypt it
blockBytes = ciphertextBytes[x * 16 : (x * 16) + 16]
decryptedBytes = self.rijndael.decrypt(blockBytes)
#XOR with the chaining block and overwrite the input with output
# XOR with the chaining block and overwrite the input with output
for y in range(16):
decryptedBytes[y] ^= chainBytes[y]
ciphertextBytes[(x*16)+y] = decryptedBytes[y]
ciphertextBytes[(x * 16) + y] = decryptedBytes[y]
#Set the next chaining block
# Set the next chaining block
chainBytes = blockBytes
self.IV = chainBytes[:]
@@ -89,7 +87,7 @@ class Python_AES_CTR(AES):
self.rijndael = Rijndael(key, 16)
self.IV = IV
self._counter_bytes = 16 - len(self.IV)
self._counter = self.IV + bytearray(b'\x00' * self._counter_bytes)
self._counter = self.IV + bytearray(b"\x00" * self._counter_bytes)
@property
def counter(self):
@@ -102,9 +100,9 @@ class Python_AES_CTR(AES):
def _counter_update(self):
counter_int = bytesToNumber(self._counter) + 1
self._counter = numberToByteArray(counter_int, 16)
if self._counter_bytes > 0 and \
self._counter[-self._counter_bytes:] == \
bytearray(b'\xff' * self._counter_bytes):
if self._counter_bytes > 0 and self._counter[-self._counter_bytes :] == bytearray(
b"\xff" * self._counter_bytes
):
raise OverflowError("CTR counter overflowed")
def encrypt(self, plaintext):

View File

@@ -0,0 +1,204 @@
"""
Rate limit handlers for host-specific rate limiting strategies.
This module provides handler classes that implement specific rate limiting
logic for different streaming hosts (e.g., Vidoza's aggressive 509 rate limiting).
Similar pattern to stream_transformers.py but for rate limiting behavior.
"""
import logging
from typing import Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
class RateLimitHandler:
    """
    Base (no-op) rate limit handler.

    Declares the tunables consulted by the streaming code; every default
    here means "no rate limiting". Host-specific subclasses override only
    the properties they need.
    """

    @property
    def cooldown_seconds(self) -> int:
        """Seconds to wait between upstream connections (0 = no cooldown)."""
        return 0

    @property
    def use_head_cache(self) -> bool:
        """Whether HEAD responses are cached to avoid upstream calls."""
        return False

    @property
    def use_stream_gate(self) -> bool:
        """Whether requests are serialized through a distributed lock."""
        return False

    @property
    def exclusive_stream(self) -> bool:
        """
        Whether the stream gate stays held for the WHOLE stream duration,
        blocking any concurrent connection to the same URL.

        When False the gate is released once headers are received.
        Required for hosts that 509 on ANY concurrent streams.
        """
        return False

    @property
    def retry_after_seconds(self) -> int:
        """Value sent in the Retry-After header alongside a 503 response."""
        return 2
class VidozaRateLimitHandler(RateLimitHandler):
    """
    Rate limit handler tuned for the Vidoza CDN.

    Vidoza answers with 509 as soon as the same IP opens ANY concurrent
    connection to a URL, so this handler:
    - holds an EXCLUSIVE stream gate for the entire stream (only one
      active stream at a time),
    - serves cached HEAD responses so repeated client probes do not open
      upstream connections.

    WARNING: only one client can actively stream at a time; other clients
    wait on the gate (up to timeout) and may receive a 503 if the current
    stream runs too long.
    """

    @property
    def cooldown_seconds(self) -> int:
        # Exclusive streaming already serializes access; no extra cooldown.
        return 0

    @property
    def use_head_cache(self) -> bool:
        return True

    @property
    def use_stream_gate(self) -> bool:
        return True

    @property
    def exclusive_stream(self) -> bool:
        # Hold the gate for the full stream lifetime, not just at start:
        # Vidoza 509s on ANY concurrent connection to the same URL.
        return True

    @property
    def retry_after_seconds(self) -> int:
        return 5
class AggressiveRateLimitHandler(RateLimitHandler):
    """
    Generic handler for hosts with strict rate limiting.

    Intended for hosts that behave similarly to Vidoza but with different
    thresholds: a short cooldown between upstream connections, cached HEAD
    responses, and serialized request starts via the stream gate.
    """

    @property
    def cooldown_seconds(self) -> int:
        return 3

    @property
    def use_head_cache(self) -> bool:
        return True

    @property
    def use_stream_gate(self) -> bool:
        return True

    @property
    def retry_after_seconds(self) -> int:
        return 2
# Registry of available rate limit handlers by ID.
# Keys are the handler IDs accepted by get_rate_limit_handler(handler_id=...).
RATE_LIMIT_HANDLER_REGISTRY: dict[str, type[RateLimitHandler]] = {
    "vidoza": VidozaRateLimitHandler,
    "aggressive": AggressiveRateLimitHandler,
}

# Auto-detection: hostname patterns to handler IDs.
# Each pattern is matched as a substring of the video URL's hostname
# (see get_rate_limit_handler).
#
# NOTE: Vidoza CDN DOES rate limit concurrent connections from the same IP.
# When multiple clients request through the proxy, all requests come from
# the proxy's IP, triggering Vidoza's rate limit (509 errors).
# Stream-level rate limiting serializes requests to avoid this.
#
HOST_PATTERN_TO_HANDLER: dict[str, str] = {
    "vidoza.net": "vidoza",
    "vidoza.org": "vidoza",
    # Add more patterns as needed for hosts that rate-limit CDN streaming:
    # "example-cdn.com": "aggressive",
}
def get_rate_limit_handler(
    handler_id: Optional[str] = None,
    video_url: Optional[str] = None,
) -> RateLimitHandler:
    """
    Build a rate limit handler instance.

    Resolution order:
      1. An explicit ``handler_id``, looked up in the registry.
      2. Auto-detection from the ``video_url`` hostname.
      3. The no-op base handler (no rate limiting).

    Args:
        handler_id: Explicit handler identifier (e.g., "vidoza", "aggressive").
        video_url: Video URL whose hostname is matched against known patterns.

    Returns:
        A rate limit handler instance; the base RateLimitHandler (no-op)
        when no handler is specified and auto-detection finds no match.
    """
    # 1. An explicit, known handler ID wins outright.
    if handler_id:
        handler_class = RATE_LIMIT_HANDLER_REGISTRY.get(handler_id)
        if handler_class is not None:
            logger.debug(f"Using explicit rate limit handler: {handler_id}")
            return handler_class()
        logger.warning(f"Unknown rate limit handler ID: {handler_id}")

    # 2. Fall back to hostname-based auto-detection.
    if video_url:
        try:
            hostname = urlparse(video_url).hostname or ""
            for pattern, detected_handler_id in HOST_PATTERN_TO_HANDLER.items():
                if pattern not in hostname:
                    continue
                handler_class = RATE_LIMIT_HANDLER_REGISTRY.get(detected_handler_id)
                if handler_class is not None:
                    logger.info(f"[RateLimit] Auto-detected handler '{detected_handler_id}' for host: {hostname}")
                    return handler_class()
            logger.debug(f"[RateLimit] No handler matched for hostname: {hostname}")
        except Exception as e:
            # Auto-detection is best-effort: never let a parse error block streaming.
            logger.warning(f"[RateLimit] Error during auto-detection: {e}")

    # 3. Default: no rate limiting.
    return RateLimitHandler()
def get_available_handlers() -> list[str]:
    """Return the IDs of all registered rate limit handlers."""
    return [registered_id for registered_id in RATE_LIMIT_HANDLER_REGISTRY]

View File

@@ -0,0 +1,861 @@
"""
Redis utilities for cross-worker coordination and caching.
Provides:
- Distributed locking (stream gating, generic locks)
- Shared caching (HEAD responses, extractors, MPD, segments, init segments)
- In-flight request deduplication
- Cooldown/throttle tracking
All caches are shared across all uvicorn workers via Redis.
IMPORTANT: Redis is OPTIONAL. If settings.redis_url is None, all functions
gracefully degrade:
- Locks always succeed immediately (no cross-worker coordination)
- Cache operations return None/False (no shared caching)
- Cooldowns always allow (no rate limiting)
This allows single-worker deployments to work without Redis.
"""
import asyncio
import hashlib
import json
import logging
import time
from typing import Optional
from mediaflow_proxy.configs import settings
logger = logging.getLogger(__name__)
# =============================================================================
# Redis Clients (Lazy Singletons)
# =============================================================================
# Two clients: one for text/JSON (decode_responses=True), one for binary data
_redis_client = None
_redis_binary_client = None
_redis_available: Optional[bool] = None # None = not checked yet
def is_redis_configured() -> bool:
    """Return True when a non-blank Redis URL is present in settings."""
    url = settings.redis_url
    return bool(url and url.strip())
async def is_redis_available() -> bool:
    """
    Check if Redis is configured and reachable.

    Caches the result after the first successful/failed connection attempt,
    so transient startup failures are remembered until `close_redis()` resets
    the flag.
    """
    global _redis_available
    if not is_redis_configured():
        return False
    if _redis_available is not None:
        return _redis_available
    # Probe with a short-lived client. Close it in `finally` so the pool is
    # not leaked when ping() fails (the previous code only closed on success).
    client = None
    try:
        import redis.asyncio as redis

        client = redis.from_url(
            settings.redis_url,
            decode_responses=True,
            socket_connect_timeout=2,
            socket_timeout=2,
        )
        await client.ping()
        _redis_available = True
        logger.info(f"Redis is available: {settings.redis_url}")
    except Exception as e:
        _redis_available = False
        logger.warning(f"Redis not available (features will be disabled): {e}")
    finally:
        if client is not None:
            try:
                await client.aclose()
            except Exception:
                # Best-effort cleanup; availability was already decided above.
                pass
    return _redis_available
async def get_redis():
    """
    Get or create the Redis connection pool for text/JSON data (lazy singleton).

    The connection pool is shared across all async tasks in a single worker.
    Each worker process has its own pool, but Redis itself coordinates across workers.

    Returns None if Redis is not configured or not available.
    """
    global _redis_client
    if not is_redis_configured():
        return None
    if _redis_client is None:
        import redis.asyncio as redis

        client = redis.from_url(
            settings.redis_url,
            decode_responses=True,
            socket_connect_timeout=5,
            socket_timeout=5,
        )
        # Verify connectivity BEFORE publishing the client to the module
        # global, so concurrent tasks never observe a half-initialized client.
        try:
            await client.ping()
        except Exception as e:
            logger.error(f"Redis connection failed: {e}")
            # Fix: close the failed pool instead of leaking its sockets.
            try:
                await client.aclose()
            except Exception:
                pass
            return None
        _redis_client = client
        logger.info(f"Redis connected (text): {settings.redis_url}")
    return _redis_client
async def get_redis_binary():
    """
    Get or create the Redis connection pool for binary data (lazy singleton).

    Used for caching segments and init segments without base64 encoding overhead.

    Returns None if Redis is not configured or not available.
    """
    global _redis_binary_client
    if not is_redis_configured():
        return None
    if _redis_binary_client is None:
        import redis.asyncio as redis

        client = redis.from_url(
            settings.redis_url,
            decode_responses=False,  # Keep bytes as-is
            socket_connect_timeout=5,
            socket_timeout=5,
        )
        # Verify connectivity BEFORE publishing the client to the module
        # global, so concurrent tasks never observe a half-initialized client.
        try:
            await client.ping()
        except Exception as e:
            logger.error(f"Redis binary connection failed: {e}")
            # Fix: close the failed pool instead of leaking its sockets.
            try:
                await client.aclose()
            except Exception:
                pass
            return None
        _redis_binary_client = client
        logger.info(f"Redis connected (binary): {settings.redis_url}")
    return _redis_binary_client
async def close_redis():
    """Close all Redis connection pools (call on shutdown)."""
    global _redis_client, _redis_binary_client, _redis_available
    # Detach each singleton before closing so later calls re-create cleanly.
    text_client, _redis_client = _redis_client, None
    if text_client is not None:
        await text_client.aclose()
    binary_client, _redis_binary_client = _redis_binary_client, None
    if binary_client is not None:
        await binary_client.aclose()
    # Forget the cached availability probe as well.
    _redis_available = None
    logger.info("Redis connections closed")
# =============================================================================
# Instance Namespace Helper
# =============================================================================
# Some cached data is bound to the outgoing IP of the pod that produced it
# (e.g. extractor results resolved via the pod's egress IP). Sharing these
# entries across pods in a multi-instance deployment causes other pods to serve
# stale/wrong URLs.
#
# Set CACHE_NAMESPACE (env: CACHE_NAMESPACE) to a unique value per pod (e.g.
# pod name, hostname, or any discriminator). Instance-scoped keys are then
# stored under "<namespace>:<original_key>", while fully-shared keys (MPD,
# init segments, media segments, locks, stream gates) remain unchanged.
def make_instance_key(key: str) -> str:
    """Prefix *key* with the configured instance namespace.

    Intended for cache/coordination keys that must NOT be shared across pods
    because the underlying data is tied to a pod's outgoing IP (e.g.
    extractor results). Fully shared content (MPD, init/media segments)
    should never be namespaced.

    When ``settings.cache_namespace`` is unset the key is returned verbatim,
    so single-instance deployments behave exactly as before.
    """
    namespace = settings.cache_namespace
    if not namespace:
        return key
    return f"{namespace}:{key}"
# =============================================================================
# Stream Gate (Distributed Lock)
# =============================================================================
# Serializes upstream connection handshakes per-URL across all workers.
# Uses SET NX EX for atomic acquire with auto-expiry.
GATE_PREFIX = "mfp:stream_gate:"
GATE_TTL = 15 # seconds - auto-expire if worker crashes mid-request
def _gate_key(url: str) -> str:
"""Generate Redis key for a stream gate."""
url_hash = hashlib.md5(url.encode()).hexdigest()
return f"{GATE_PREFIX}{url_hash}"
async def acquire_stream_gate(url: str, timeout: float = 15.0) -> bool:
"""
Try to acquire a per-URL stream gate (distributed lock).
Only one worker across all processes can hold the gate for a given URL.
The gate auto-expires after GATE_TTL seconds to prevent deadlocks.
If Redis is not available, always returns True (no coordination).
Args:
url: The upstream URL to gate
timeout: Maximum time to wait for the gate (seconds)
Returns:
True if gate acquired (or Redis unavailable), False if timeout
"""
r = await get_redis()
if r is None:
# No Redis - no cross-worker coordination, always allow
return True
key = _gate_key(url)
deadline = time.time() + timeout
while time.time() < deadline:
# SET NX EX is atomic: only succeeds if key doesn't exist
if await r.set(key, "1", nx=True, ex=GATE_TTL):
logger.debug(f"[Redis] Acquired stream gate: {key[:50]}...")
return True
# Another worker holds the gate, wait and retry
await asyncio.sleep(0.05) # 50ms poll interval
logger.warning(f"[Redis] Gate acquisition timeout ({timeout}s): {key[:50]}...")
return False
async def release_stream_gate(url: str):
"""
Release a per-URL stream gate.
Safe to call even if gate wasn't acquired or already expired.
No-op if Redis is not available.
"""
r = await get_redis()
if r is None:
return
key = _gate_key(url)
await r.delete(key)
logger.debug(f"[Redis] Released stream gate: {key[:50]}...")
async def extend_stream_gate(url: str, ttl: int = GATE_TTL):
"""
Extend the TTL of a stream gate to keep it held during long streams.
Should be called periodically (e.g., every 10s) while streaming.
No-op if Redis is not available or gate doesn't exist.
"""
r = await get_redis()
if r is None:
return
key = _gate_key(url)
await r.expire(key, ttl)
logger.debug(f"[Redis] Extended stream gate TTL ({ttl}s): {key[:50]}...")
async def is_stream_gate_held(url: str) -> bool:
"""Check if a stream gate is currently held. Returns False if Redis unavailable."""
r = await get_redis()
if r is None:
return False
key = _gate_key(url)
return await r.exists(key) > 0
# =============================================================================
# HEAD Response Cache
# =============================================================================
# Caches upstream response headers so repeated HEAD probes (e.g., from ExoPlayer)
# can be served without any upstream connection. Shared across all workers.
HEAD_CACHE_PREFIX = "mfp:head_cache:"
HEAD_CACHE_TTL = 60 # seconds - Vidoza CDN URLs typically expire in minutes
def _head_cache_key(url: str) -> str:
"""Generate Redis key for HEAD cache entry."""
url_hash = hashlib.md5(url.encode()).hexdigest()
return f"{HEAD_CACHE_PREFIX}{url_hash}"
async def get_cached_head(url: str) -> Optional[dict]:
"""
Get cached HEAD response metadata for a URL.
Args:
url: The upstream URL
Returns:
Dict with 'headers' and 'status' keys, or None if not cached (or Redis unavailable)
"""
r = await get_redis()
if r is None:
return None
key = _head_cache_key(url)
data = await r.get(key)
if data:
logger.debug(f"[Redis] HEAD cache hit: {key[:50]}...")
return json.loads(data)
return None
async def set_cached_head(url: str, headers: dict, status: int):
"""
Cache HEAD response metadata for a URL.
No-op if Redis is not available.
Args:
url: The upstream URL
headers: Response headers dict (will be JSON serialized)
status: HTTP status code (e.g., 200, 206)
"""
r = await get_redis()
if r is None:
return
key = _head_cache_key(url)
# Only cache headers that are useful for HEAD responses
# Filter to avoid caching large or irrelevant headers
cached_headers = {}
for k, v in headers.items():
k_lower = k.lower()
if k_lower in (
"content-type",
"content-length",
"accept-ranges",
"content-range",
"etag",
"last-modified",
"cache-control",
):
cached_headers[k_lower] = v
payload = json.dumps({"headers": cached_headers, "status": status})
await r.set(key, payload, ex=HEAD_CACHE_TTL)
logger.debug(f"[Redis] HEAD cache set ({HEAD_CACHE_TTL}s TTL): {key[:50]}...")
# =============================================================================
# Generic Distributed Lock
# =============================================================================
# For cross-worker coordination (e.g., segment downloads, prebuffering)
LOCK_PREFIX = "mfp:lock:"
DEFAULT_LOCK_TTL = 30 # seconds
def _lock_key(key: str) -> str:
"""Generate Redis key for a lock."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return f"{LOCK_PREFIX}{key_hash}"
async def acquire_lock(key: str, ttl: int = DEFAULT_LOCK_TTL, timeout: float = 30.0) -> bool:
"""
Acquire a distributed lock.
If Redis is not available, always returns True (no coordination).
Args:
key: The lock identifier
ttl: Lock auto-expiry time in seconds (prevents deadlocks)
timeout: Maximum time to wait for the lock
Returns:
True if lock acquired (or Redis unavailable), False if timeout
"""
r = await get_redis()
if r is None:
return True # No Redis - no coordination
lock_key = _lock_key(key)
deadline = time.time() + timeout
while time.time() < deadline:
if await r.set(lock_key, "1", nx=True, ex=ttl):
logger.debug(f"[Redis] Acquired lock: {key[:60]}...")
return True
await asyncio.sleep(0.05)
logger.warning(f"[Redis] Lock timeout ({timeout}s): {key[:60]}...")
return False
async def release_lock(key: str):
"""Release a distributed lock. No-op if Redis unavailable."""
r = await get_redis()
if r is None:
return
lock_key = _lock_key(key)
await r.delete(lock_key)
logger.debug(f"[Redis] Released lock: {key[:60]}...")
# =============================================================================
# Extractor Cache
# =============================================================================
# Caches extractor results (JSON) across all workers
EXTRACTOR_CACHE_PREFIX = "mfp:extractor:"
EXTRACTOR_CACHE_TTL = 300 # 5 minutes
def _extractor_key(key: str) -> str:
"""Generate Redis key for extractor cache."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return f"{EXTRACTOR_CACHE_PREFIX}{key_hash}"
async def get_cached_extractor(key: str) -> Optional[dict]:
"""Get cached extractor result. Returns None if Redis unavailable."""
r = await get_redis()
if r is None:
return None
redis_key = _extractor_key(key)
data = await r.get(redis_key)
if data:
logger.debug(f"[Redis] Extractor cache hit: {key[:60]}...")
return json.loads(data)
return None
async def set_cached_extractor(key: str, data: dict, ttl: int = EXTRACTOR_CACHE_TTL):
"""Cache extractor result. No-op if Redis unavailable."""
r = await get_redis()
if r is None:
return
redis_key = _extractor_key(key)
await r.set(redis_key, json.dumps(data), ex=ttl)
logger.debug(f"[Redis] Extractor cache set ({ttl}s TTL): {key[:60]}...")
# =============================================================================
# MPD Cache
# =============================================================================
# Caches parsed MPD manifests (JSON) across all workers
MPD_CACHE_PREFIX = "mfp:mpd:"
DEFAULT_MPD_CACHE_TTL = 60 # 1 minute
def _mpd_key(key: str) -> str:
"""Generate Redis key for MPD cache."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return f"{MPD_CACHE_PREFIX}{key_hash}"
async def get_cached_mpd(key: str) -> Optional[dict]:
"""Get cached MPD manifest. Returns None if Redis unavailable."""
r = await get_redis()
if r is None:
return None
redis_key = _mpd_key(key)
data = await r.get(redis_key)
if data:
logger.debug(f"[Redis] MPD cache hit: {key[:60]}...")
return json.loads(data)
return None
async def set_cached_mpd(key: str, data: dict, ttl: int | float = DEFAULT_MPD_CACHE_TTL):
"""Cache MPD manifest. No-op if Redis unavailable."""
r = await get_redis()
if r is None:
return
redis_key = _mpd_key(key)
# Ensure TTL is an integer (Redis requires int for ex parameter)
ttl_int = max(1, int(ttl))
await r.set(redis_key, json.dumps(data), ex=ttl_int)
logger.debug(f"[Redis] MPD cache set ({ttl_int}s TTL): {key[:60]}...")
# =============================================================================
# Segment Cache (Binary)
# =============================================================================
# Caches HLS/DASH segments across all workers
SEGMENT_CACHE_PREFIX = b"mfp:segment:"
DEFAULT_SEGMENT_CACHE_TTL = 60 # 1 minute
def _segment_key(url: str) -> bytes:
"""Generate Redis key for segment cache."""
url_hash = hashlib.md5(url.encode()).hexdigest()
return SEGMENT_CACHE_PREFIX + url_hash.encode()
async def get_cached_segment(url: str) -> Optional[bytes]:
"""Get cached segment data. Returns None if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return None
key = _segment_key(url)
data = await r.get(key)
if data:
logger.debug(f"[Redis] Segment cache hit: {url[:60]}...")
return data
async def set_cached_segment(url: str, data: bytes, ttl: int = DEFAULT_SEGMENT_CACHE_TTL):
"""Cache segment data. No-op if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return
key = _segment_key(url)
await r.set(key, data, ex=ttl)
logger.debug(f"[Redis] Segment cache set ({ttl}s TTL, {len(data)} bytes): {url[:60]}...")
# =============================================================================
# Init Segment Cache (Binary)
# =============================================================================
# Caches initialization segments across all workers
INIT_CACHE_PREFIX = b"mfp:init:"
DEFAULT_INIT_CACHE_TTL = 3600 # 1 hour
def _init_key(url: str) -> bytes:
"""Generate Redis key for init segment cache."""
url_hash = hashlib.md5(url.encode()).hexdigest()
return INIT_CACHE_PREFIX + url_hash.encode()
async def get_cached_init_segment(url: str) -> Optional[bytes]:
"""Get cached init segment data. Returns None if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return None
key = _init_key(url)
data = await r.get(key)
if data:
logger.debug(f"[Redis] Init segment cache hit: {url[:60]}...")
return data
async def set_cached_init_segment(url: str, data: bytes, ttl: int = DEFAULT_INIT_CACHE_TTL):
"""Cache init segment data. No-op if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return
key = _init_key(url)
await r.set(key, data, ex=ttl)
logger.debug(f"[Redis] Init segment cache set ({ttl}s TTL, {len(data)} bytes): {url[:60]}...")
# =============================================================================
# Processed Init Segment Cache (Binary)
# =============================================================================
# Caches DRM-stripped/processed init segments across all workers
PROCESSED_INIT_CACHE_PREFIX = b"mfp:processed_init:"
DEFAULT_PROCESSED_INIT_TTL = 3600 # 1 hour
def _processed_init_key(key: str) -> bytes:
"""Generate Redis key for processed init segment cache."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return PROCESSED_INIT_CACHE_PREFIX + key_hash.encode()
async def get_cached_processed_init(key: str) -> Optional[bytes]:
"""Get cached processed init segment data. Returns None if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return None
redis_key = _processed_init_key(key)
data = await r.get(redis_key)
if data:
logger.debug(f"[Redis] Processed init cache hit: {key[:60]}...")
return data
async def set_cached_processed_init(key: str, data: bytes, ttl: int = DEFAULT_PROCESSED_INIT_TTL):
"""Cache processed init segment data. No-op if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return
redis_key = _processed_init_key(key)
await r.set(redis_key, data, ex=ttl)
logger.debug(f"[Redis] Processed init cache set ({ttl}s TTL, {len(data)} bytes): {key[:60]}...")
# =============================================================================
# In-flight Request Tracking
# =============================================================================
# Prevents duplicate upstream requests across all workers
INFLIGHT_PREFIX = "mfp:inflight:"
DEFAULT_INFLIGHT_TTL = 60 # seconds
def _inflight_key(key: str) -> str:
"""Generate Redis key for in-flight tracking."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return f"{INFLIGHT_PREFIX}{key_hash}"
async def mark_inflight(key: str, ttl: int = DEFAULT_INFLIGHT_TTL) -> bool:
"""
Mark a request as in-flight (being processed by some worker).
If Redis is not available, always returns True (this worker is "first").
Args:
key: The request identifier
ttl: Auto-expiry time in seconds
Returns:
True if this call marked it (first worker), False if already in-flight
"""
r = await get_redis()
if r is None:
return True # No Redis - always proceed
inflight_key = _inflight_key(key)
result = await r.set(inflight_key, "1", nx=True, ex=ttl)
if result:
logger.debug(f"[Redis] Marked in-flight: {key[:60]}...")
return bool(result)
async def is_inflight(key: str) -> bool:
"""Check if a request is currently in-flight. Returns False if Redis unavailable."""
r = await get_redis()
if r is None:
return False
inflight_key = _inflight_key(key)
return await r.exists(inflight_key) > 0
async def clear_inflight(key: str):
"""Clear in-flight marker (call when request completes). No-op if Redis unavailable."""
r = await get_redis()
if r is None:
return
inflight_key = _inflight_key(key)
await r.delete(inflight_key)
logger.debug(f"[Redis] Cleared in-flight: {key[:60]}...")
async def wait_for_completion(key: str, timeout: float = 30.0, poll_interval: float = 0.1) -> bool:
"""
Wait for an in-flight request to complete.
If Redis is not available, returns True immediately.
Args:
key: The request identifier
timeout: Maximum time to wait
poll_interval: Time between checks
Returns:
True if completed (inflight marker gone), False if timeout
"""
r = await get_redis()
if r is None:
return True # No Redis - don't wait
deadline = time.time() + timeout
while time.time() < deadline:
if not await is_inflight(key):
return True
await asyncio.sleep(poll_interval)
return False
# =============================================================================
# Cooldown/Throttle Tracking
# =============================================================================
# Prevents rapid repeated operations (e.g., background refresh throttling)
COOLDOWN_PREFIX = "mfp:cooldown:"
def _cooldown_key(key: str) -> str:
"""Generate Redis key for cooldown tracking."""
key_hash = hashlib.md5(key.encode()).hexdigest()
return f"{COOLDOWN_PREFIX}{key_hash}"
async def check_and_set_cooldown(key: str, cooldown_seconds: int) -> bool:
"""
Check if cooldown has elapsed and set new cooldown if so.
If Redis is not available, always returns True (no rate limiting).
Args:
key: The cooldown identifier
cooldown_seconds: Duration of the cooldown period
Returns:
True if cooldown elapsed (and new cooldown set), False if still in cooldown
"""
r = await get_redis()
if r is None:
return True # No Redis - no rate limiting
cooldown_key = _cooldown_key(key)
# SET NX EX: only succeeds if key doesn't exist
result = await r.set(cooldown_key, "1", nx=True, ex=cooldown_seconds)
if result:
logger.debug(f"[Redis] Cooldown set ({cooldown_seconds}s): {key[:60]}...")
return True
return False
async def is_in_cooldown(key: str) -> bool:
"""Check if currently in cooldown period. Returns False if Redis unavailable."""
r = await get_redis()
if r is None:
return False
cooldown_key = _cooldown_key(key)
return await r.exists(cooldown_key) > 0
# =============================================================================
# HLS Transcode Session (Cross-Worker)
# =============================================================================
# Per-segment HLS transcode caching.
# Each segment is independently transcoded and cached. Segment output metadata
# (video/audio DTS, sequence number) is stored so consecutive segments can
# maintain timeline continuity without a persistent pipeline.
HLS_SEG_PREFIX = b"mfp:hls_seg:"
HLS_INIT_PREFIX = b"mfp:hls_init:"
HLS_SEG_META_PREFIX = "mfp:hls_smeta:"
HLS_SEG_TTL = 60 # 60 s -- short-lived; only for immediate retry/re-request
HLS_INIT_TTL = 3600 # 1 hour -- stable for the viewing session
HLS_SEG_META_TTL = 3600 # 1 hour -- needed for next-segment continuity
def _hls_seg_key(cache_key: str, seg_index: int) -> bytes:
return HLS_SEG_PREFIX + f"{cache_key}:{seg_index}".encode()
def _hls_init_key(cache_key: str) -> bytes:
return HLS_INIT_PREFIX + cache_key.encode()
def _hls_seg_meta_key(cache_key: str, seg_index: int) -> str:
return f"{HLS_SEG_META_PREFIX}{cache_key}:{seg_index}"
async def hls_get_segment(cache_key: str, seg_index: int) -> Optional[bytes]:
"""Get a cached HLS segment from Redis. Returns None if unavailable."""
r = await get_redis_binary()
if r is None:
return None
try:
return await r.get(_hls_seg_key(cache_key, seg_index))
except Exception:
return None
async def hls_set_segment(cache_key: str, seg_index: int, data: bytes) -> None:
"""Store an HLS segment in Redis with short TTL. No-op if Redis unavailable."""
r = await get_redis_binary()
if r is None:
return
try:
await r.set(_hls_seg_key(cache_key, seg_index), data, ex=HLS_SEG_TTL)
except Exception:
logger.debug("[Redis] Failed to cache HLS segment %d", seg_index)
async def hls_get_init(cache_key: str) -> Optional[bytes]:
"""Get the cached HLS init segment from Redis."""
r = await get_redis_binary()
if r is None:
return None
try:
return await r.get(_hls_init_key(cache_key))
except Exception:
return None
async def hls_set_init(cache_key: str, data: bytes) -> None:
"""Store the HLS init segment in Redis."""
r = await get_redis_binary()
if r is None:
return
try:
await r.set(_hls_init_key(cache_key), data, ex=HLS_INIT_TTL)
except Exception:
logger.debug("[Redis] Failed to cache HLS init segment")
async def hls_get_segment_meta(cache_key: str, seg_index: int) -> Optional[dict]:
"""
Get per-segment output metadata from Redis.
Returns a dict with keys like ``video_dts_ms``, ``audio_dts_ms``,
``sequence_number``, or None if unavailable.
"""
r = await get_redis()
if r is None:
return None
try:
raw = await r.get(_hls_seg_meta_key(cache_key, seg_index))
if raw:
return json.loads(raw)
except Exception:
pass
return None
async def hls_set_segment_meta(cache_key: str, seg_index: int, meta: dict) -> None:
"""
Store per-segment output metadata in Redis.
``meta`` should contain keys like ``video_dts_ms``, ``audio_dts_ms``,
``sequence_number`` so the next segment can continue the timeline.
"""
r = await get_redis()
if r is None:
return
try:
await r.set(
_hls_seg_meta_key(cache_key, seg_index),
json.dumps(meta),
ex=HLS_SEG_META_TTL,
)
except Exception:
logger.debug("[Redis] Failed to set HLS segment meta %d", seg_index)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
"""
Stream transformers for host-specific content manipulation.
This module provides transformer classes that can modify streaming content
on-the-fly. Each transformer handles specific content manipulation needs
for different streaming hosts (e.g., PNG wrapper stripping, TS detection).
"""
import logging
import typing
logger = logging.getLogger(__name__)
class StreamTransformer:
    """
    Base class for stream content transformers.

    Subclasses override :meth:`transform` to implement host-specific
    content manipulation; the base implementation is the identity.
    """

    async def transform(self, chunk_iterator: typing.AsyncIterator[bytes]) -> typing.AsyncGenerator[bytes, None]:
        """
        Transform stream chunks.

        Args:
            chunk_iterator: Async iterator of raw bytes from upstream.

        Yields:
            Transformed bytes chunks.
        """
        async for chunk in chunk_iterator:
            yield chunk


class TSStreamTransformer(StreamTransformer):
    """
    Transformer for MPEG-TS streams with obfuscation.

    Handles streams from hosts like TurboVidPlay, StreamWish, and FileMoon
    that may have:
    - Fake PNG wrapper prepended to video data
    - 0xFF padding bytes before actual content
    - Need for TS sync byte detection
    """

    # PNG signature and IEND marker for fake PNG header detection
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
    _PNG_IEND_MARKER = b"\x49\x45\x4e\x44\xae\x42\x60\x82"

    # TS packet constants
    _TS_SYNC = 0x47
    _TS_PACKET_SIZE = 188

    # Maximum bytes to buffer before forcing passthrough
    _MAX_PREFETCH = 512 * 1024  # 512 KB

    def __init__(self):
        # Bytes accumulated while still searching for the TS start.
        self.buffer = bytearray()
        # Once True, chunks are passed through untouched.
        self.ts_started = False
        # Number of wrapper/padding bytes removed (diagnostics only).
        self.bytes_stripped = 0

    @staticmethod
    def _find_ts_start(buffer: bytes) -> typing.Optional[int]:
        """
        Find MPEG-TS sync byte (0x47) aligned on 188 bytes.

        Requires two consecutive sync bytes (at i and i+188) so a stray
        0x47 inside payload data is not mistaken for the stream start.

        Args:
            buffer: Bytes to search for TS sync pattern.

        Returns:
            Offset where TS starts, or None if not found.
        """
        TS_SYNC = 0x47
        TS_PACKET = 188
        max_i = len(buffer) - TS_PACKET
        for i in range(max(0, max_i)):
            if buffer[i] == TS_SYNC and buffer[i + TS_PACKET] == TS_SYNC:
                return i
        return None

    def _strip_fake_png_wrapper(self, chunk: bytes) -> bytes:
        """
        Strip fake PNG wrapper from chunk data.

        Some streaming services prepend a fake PNG image to video data
        to evade detection. This method detects and removes it.

        Args:
            chunk: The raw chunk data that may contain a fake PNG header.

        Returns:
            The chunk with fake PNG wrapper removed, or original chunk if not present.
        """
        if not chunk.startswith(self._PNG_SIGNATURE):
            return chunk
        # Find the IEND marker that signals end of PNG data
        iend_pos = chunk.find(self._PNG_IEND_MARKER)
        if iend_pos == -1:
            # IEND not found in this chunk - return as-is to avoid data corruption
            logger.debug("PNG signature detected but IEND marker not found in chunk")
            return chunk
        # Calculate position after IEND marker
        content_start = iend_pos + len(self._PNG_IEND_MARKER)
        # Skip any padding bytes (null or 0xFF) between PNG and actual content
        while content_start < len(chunk) and chunk[content_start] in (0x00, 0xFF):
            content_start += 1
        self.bytes_stripped = content_start
        logger.debug(f"Stripped {content_start} bytes of fake PNG wrapper from stream")
        return chunk[content_start:]

    async def transform(self, chunk_iterator: typing.AsyncIterator[bytes]) -> typing.AsyncGenerator[bytes, None]:
        """
        Transform TS stream by stripping PNG wrapper and finding TS start.

        Args:
            chunk_iterator: Async iterator of raw bytes from upstream.

        Yields:
            Cleaned TS stream bytes.
        """
        async for chunk in chunk_iterator:
            if self.ts_started:
                # Normal streaming once TS has started
                yield chunk
                continue

            # Prebuffer phase (until we find TS or pass through)
            self.buffer += chunk

            # Fast-path: if it's an m3u8 playlist, don't do TS detection
            if len(self.buffer) >= 7 and self.buffer[:7] in (b"#EXTM3U", b"#EXT-X-"):
                yield bytes(self.buffer)
                self.buffer.clear()
                self.ts_started = True
                continue

            # Strip fake PNG wrapper if present
            if self.buffer.startswith(self._PNG_SIGNATURE):
                if self._PNG_IEND_MARKER in self.buffer:
                    self.buffer = bytearray(self._strip_fake_png_wrapper(bytes(self.buffer)))

            # Skip pure 0xFF padding bytes (TurboVid style).
            # FIX: count the padding and delete it in one slice operation;
            # the previous pop(0)-in-a-loop was O(n^2) on long padding runs.
            pad = 0
            while pad < len(self.buffer) and self.buffer[pad] == 0xFF:
                pad += 1
            if pad:
                del self.buffer[:pad]

            # Re-check for m3u8 playlist after stripping PNG wrapper and padding
            # This handles cases where m3u8 content is wrapped in PNG
            if len(self.buffer) >= 7 and self.buffer[:7] in (b"#EXTM3U", b"#EXT-X-"):
                logger.debug("Found m3u8 content after stripping wrapper - passing through")
                yield bytes(self.buffer)
                self.buffer.clear()
                self.ts_started = True
                continue

            ts_offset = self._find_ts_start(bytes(self.buffer))
            if ts_offset is None:
                # Keep buffering until we find TS or hit limit
                if len(self.buffer) > self._MAX_PREFETCH:
                    logger.warning("TS sync not found after large prebuffer, forcing passthrough")
                    yield bytes(self.buffer)
                    self.buffer.clear()
                    self.ts_started = True
                continue

            # TS found: emit from ts_offset and switch to pass-through
            self.ts_started = True
            out = bytes(self.buffer[ts_offset:])
            self.buffer.clear()
            if out:
                yield out

        # FIX: if the upstream ended while we were still prebuffering, flush
        # whatever we collected -- previously short streams (no TS sync found,
        # under _MAX_PREFETCH) were silently dropped.
        if not self.ts_started and self.buffer:
            yield bytes(self.buffer)
            self.buffer.clear()


# Registry of available transformers
TRANSFORMER_REGISTRY: dict[str, type[StreamTransformer]] = {
    "ts_stream": TSStreamTransformer,
}
def get_transformer(transformer_id: typing.Optional[str]) -> typing.Optional[StreamTransformer]:
    """
    Get a transformer instance by ID.

    Args:
        transformer_id: The transformer identifier (e.g., "ts_stream").

    Returns:
        A new transformer instance, or None if transformer_id is None or not found.
    """
    if transformer_id is None:
        return None
    registered_cls = TRANSFORMER_REGISTRY.get(transformer_id)
    if registered_cls is None:
        logger.warning(f"Unknown transformer ID: {transformer_id}")
        return None
    return registered_cls()
async def apply_transformer_to_bytes(
    data: bytes,
    transformer_id: typing.Optional[str],
) -> bytes:
    """
    Apply a transformer to already-downloaded bytes data.

    Useful when serving cached segments that still need transformation:
    the payload is replayed through the transformer as a one-chunk async
    stream and the transformed output is reassembled.

    Args:
        data: The raw bytes data to transform.
        transformer_id: The transformer identifier (e.g., "ts_stream").

    Returns:
        Transformed bytes, or original data if no transformer specified.
    """
    if not transformer_id:
        return data
    transformer = get_transformer(transformer_id)
    if transformer is None:
        return data

    async def _one_chunk():
        yield data

    pieces = [chunk async for chunk in transformer.transform(_one_chunk())]
    return b"".join(pieces)

File diff suppressed because it is too large Load Diff

View File

@@ -3,13 +3,21 @@
"""hashlib that handles FIPS mode."""
# Because we are extending the hashlib module, we need to import all its
# fields to support the same uses
# pylint: disable=unused-wildcard-import, wildcard-import
from hashlib import *
# pylint: enable=unused-wildcard-import, wildcard-import
import hashlib
# Re-export commonly used hash constructors
sha1 = hashlib.sha1
sha224 = hashlib.sha224
sha256 = hashlib.sha256
sha384 = hashlib.sha384
sha512 = hashlib.sha512
sha3_224 = hashlib.sha3_224
sha3_256 = hashlib.sha3_256
sha3_384 = hashlib.sha3_384
sha3_512 = hashlib.sha3_512
blake2b = hashlib.blake2b
blake2s = hashlib.blake2s
def _fipsFunction(func, *args, **kwargs):
"""Make hash function support FIPS mode."""
@@ -19,8 +27,6 @@ def _fipsFunction(func, *args, **kwargs):
return func(*args, usedforsecurity=False, **kwargs)
# redefining the function is exactly what we intend to do
# pylint: disable=function-redefined
def md5(*args, **kwargs):
    """MD5 constructor that works in FIPS mode."""
    # Delegates through _fipsFunction, which retries the call with
    # usedforsecurity=False when a FIPS-enabled hashlib rejects plain MD5.
    return _fipsFunction(hashlib.md5, *args, **kwargs)
@@ -29,4 +35,3 @@ def md5(*args, **kwargs):
def new(*args, **kwargs):
    """General constructor that works in FIPS mode."""
    # Same FIPS fallback as md5(): retries with usedforsecurity=False
    # via _fipsFunction when the plain hashlib.new call is rejected.
    return _fipsFunction(hashlib.new, *args, **kwargs)
# pylint: enable=function-redefined

View File

@@ -11,17 +11,20 @@ Note that this makes this code FIPS non-compliant!
# fields to support the same uses
from . import tlshashlib
from .compat import compatHMAC
try:
from hmac import compare_digest
__all__ = ["new", "compare_digest", "HMAC"]
except ImportError:
__all__ = ["new", "HMAC"]
try:
from hmac import HMAC, new
# if we can calculate HMAC on MD5, then use the built-in HMAC
# implementation
_val = HMAC(b'some key', b'msg', 'md5')
_val = HMAC(b"some key", b"msg", "md5")
_val.digest()
del _val
except Exception:
@@ -38,10 +41,10 @@ except Exception:
"""
self.key = key
if digestmod is None:
digestmod = 'md5'
digestmod = "md5"
if callable(digestmod):
digestmod = digestmod()
if not hasattr(digestmod, 'digest_size'):
if not hasattr(digestmod, "digest_size"):
digestmod = tlshashlib.new(digestmod)
self.block_size = digestmod.block_size
self.digest_size = digestmod.digest_size
@@ -51,10 +54,10 @@ except Exception:
k_hash.update(compatHMAC(key))
key = k_hash.digest()
if len(key) < self.block_size:
key = key + b'\x00' * (self.block_size - len(key))
key = key + b"\x00" * (self.block_size - len(key))
key = bytearray(key)
ipad = bytearray(b'\x36' * self.block_size)
opad = bytearray(b'\x5c' * self.block_size)
ipad = bytearray(b"\x36" * self.block_size)
opad = bytearray(b"\x5c" * self.block_size)
i_key = bytearray(i ^ j for i, j in zip(key, ipad))
self._o_key = bytearray(i ^ j for i, j in zip(key, opad))
self._context = digestmod.copy()
@@ -82,7 +85,6 @@ except Exception:
new._context = self._context.copy()
return new
def new(*args, **kwargs):
    """General constructor that works in FIPS mode."""
    # Thin factory wrapper so callers can use this module's new() exactly
    # like stdlib hmac.new(), regardless of which HMAC implementation
    # (built-in or fallback) was selected above.
    return HMAC(*args, **kwargs)