mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-04-11 11:50:51 +00:00
Add files via upload
This commit is contained in:
0
mediaflow_proxy/utils/__init__.py
Normal file
0
mediaflow_proxy/utils/__init__.py
Normal file
376
mediaflow_proxy/utils/cache_utils.py
Normal file
376
mediaflow_proxy/utils/cache_utils.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mediaflow_proxy.speedtest.models import SpeedTestTask
|
||||
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
|
||||
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """Represents a cache entry with metadata."""

    # Raw cached payload.
    data: bytes
    # Absolute expiry timestamp (seconds since the epoch, as from time.time()).
    expires_at: float
    # Number of times this entry has been served from the cache.
    access_count: int = 0
    # Timestamp of the most recent access (0.0 if never accessed).
    last_access: float = 0.0
    # Payload size in bytes; used by LRUMemoryCache for size accounting.
    size: int = 0
|
||||
|
||||
|
||||
class LRUMemoryCache:
    """Thread-safe LRU memory cache bounded by the total byte size of entries.

    Entries are evicted least-recently-used first whenever the combined size
    of stored entries would exceed ``maxsize``. Expired entries are purged
    lazily on access.
    """

    def __init__(self, maxsize: int):
        """
        Args:
            maxsize: Maximum combined size (in bytes) of all cached entries.
        """
        self.maxsize = maxsize
        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
        self._lock = threading.Lock()
        self._current_size = 0

    def get(self, key: str) -> "Optional[CacheEntry]":
        """Return the entry for ``key``, or None if absent or expired.

        A hit moves the entry to the most-recently-used position and updates
        its access statistics; an expired entry is removed.
        """
        with self._lock:
            if key not in self._cache:
                return None
            entry = self._cache.pop(key)  # removed here; re-inserted below on a live hit
            if time.time() >= entry.expires_at:
                # Expired: already popped above, just release its accounted size.
                self._current_size -= entry.size
                return None
            entry.access_count += 1
            entry.last_access = time.time()
            self._cache[key] = entry  # re-insert as most recently used
            return entry

    def set(self, key: str, entry: "CacheEntry") -> None:
        """Store ``entry`` under ``key``, evicting LRU entries to make room."""
        with self._lock:
            # Remove any existing entry first so the eviction loop below cannot
            # pop it and decrement its size a second time (accounting bug in
            # the previous version, which left the old entry in the dict).
            old_entry = self._cache.pop(key, None)
            if old_entry is not None:
                self._current_size -= old_entry.size

            # Evict least-recently-used entries until the new entry fits.
            while self._current_size + entry.size > self.maxsize and self._cache:
                _, removed_entry = self._cache.popitem(last=False)
                self._current_size -= removed_entry.size

            self._cache[key] = entry
            self._current_size += entry.size

    def remove(self, key: str) -> None:
        """Remove ``key`` from the cache if present (no-op otherwise)."""
        with self._lock:
            if key in self._cache:
                self._current_size -= self._cache.pop(key).size
|
||||
|
||||
|
||||
class HybridCache:
    """High-performance hybrid cache combining memory and file storage.

    Lookups hit the in-memory LRU cache first and fall back to a file cache
    stored under the system temp directory. Files are written atomically
    (temp file + rename) and carry a small JSON metadata header.

    On-disk layout per entry: 8-byte big-endian metadata length, JSON
    metadata ({expires_at, access_count, last_access}), then the raw payload.
    """

    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        executor_workers: int = 4,
    ):
        """
        Args:
            cache_dir_name: Directory name (under the temp dir) for file storage.
            ttl: Default time-to-live in seconds for cached entries.
            max_memory_size: Memory-cache budget in bytes.
            executor_workers: Worker threads reserved for blocking file work.
        """
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()

        # Initialize cache directories
        self._init_cache_dirs()

    def _init_cache_dirs(self):
        """Create the on-disk cache directory if it does not already exist."""
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_md5_hash(self, key: str) -> str:
        """Return the MD5 hex digest of a cache key (used as the file name)."""
        return hashlib.md5(key.encode()).hexdigest()

    def _get_file_path(self, hashed_key: str) -> Path:
        """Return the file path for an already-hashed cache key."""
        return self.cache_dir / hashed_key

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        Args:
            key: Cache key
            default: Default value if key not found

        Returns:
            Cached value or default if not found
        """
        hashed_key = self._get_md5_hash(key)
        # Try memory cache first
        entry = self.memory_cache.get(hashed_key)
        if entry is not None:
            return entry.data

        # Try file cache
        try:
            file_path = self._get_file_path(hashed_key)
            async with aiofiles.open(file_path, "rb") as f:
                # Header: 8-byte big-endian metadata length, then JSON metadata.
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())

                # Check expiration
                if metadata["expires_at"] < time.time():
                    await self._delete_hashed(hashed_key)
                    return default

                # Read data (remainder of the file after the header)
                data = await f.read()

            # Promote the file hit into the memory cache for faster re-access.
            entry = CacheEntry(
                data=data,
                expires_at=metadata["expires_at"],
                access_count=metadata["access_count"] + 1,
                last_access=time.time(),
                size=len(data),
            )
            self.memory_cache.set(hashed_key, entry)

            return data

        except FileNotFoundError:
            return default
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            return default

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Set value in both memory and file cache.

        Args:
            key: Cache key
            data: Data to cache
            ttl: Optional TTL override

        Returns:
            bool: Success status
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise ValueError("Data must be bytes, bytearray, or memoryview")
        # Copy so later mutation of a caller's bytearray/memoryview cannot
        # corrupt the cached payload.
        data = bytes(data)

        expires_at = time.time() + (ttl or self.ttl)

        # Create cache entry
        entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))

        hashed_key = self._get_md5_hash(key)
        # Update memory cache
        self.memory_cache.set(hashed_key, entry)
        file_path = self._get_file_path(hashed_key)
        temp_path = file_path.with_suffix(".tmp")

        # Update file cache: write to a temp file, then atomically rename so
        # readers never observe a partially-written entry.
        try:
            metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
            metadata_bytes = json.dumps(metadata).encode()
            metadata_size = len(metadata_bytes).to_bytes(8, "big")

            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(metadata_size)
                await f.write(metadata_bytes)
                await f.write(data)

            await aiofiles.os.rename(temp_path, file_path)
            return True

        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            try:
                await aiofiles.os.remove(temp_path)
            except OSError:
                # Best-effort cleanup of the temp file; it may not exist.
                pass
            return False

    async def delete(self, key: str) -> bool:
        """Delete item from both caches.

        The key is hashed here, so callers pass the same raw key they used
        with get()/set(). (Previously the key was NOT hashed, so external
        deletes silently missed the hashed memory entry and hash-named file.)
        """
        return await self._delete_hashed(self._get_md5_hash(key))

    async def _delete_hashed(self, hashed_key: str) -> bool:
        """Delete an already-hashed key from memory and disk."""
        self.memory_cache.remove(hashed_key)

        try:
            file_path = self._get_file_path(hashed_key)
            await aiofiles.os.remove(file_path)
            return True
        except FileNotFoundError:
            # Already gone on disk: treat as success.
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
|
||||
|
||||
|
||||
class AsyncMemoryCache:
    """Async facade over the synchronous LRUMemoryCache."""

    def __init__(self, max_memory_size: int):
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """Return the cached payload for ``key``, or ``default`` on a miss."""
        hit = self.memory_cache.get(key)
        if hit is None:
            return default
        return hit.data

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """Store ``data`` under ``key``; returns True on success."""
        try:
            now = time.time()
            deadline = now + (ttl or 3600)  # Default 1 hour TTL if not specified
            self.memory_cache.set(
                key,
                CacheEntry(data=data, expires_at=deadline, access_count=0, last_access=now, size=len(data)),
            )
            return True
        except Exception as e:
            logger.error(f"Error setting cache value: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; returns True on success."""
        try:
            self.memory_cache.remove(key)
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
|
||||
|
||||
|
||||
# Create cache instances
# Init segments are large and reused across sessions: hybrid (memory + disk).
INIT_SEGMENT_CACHE = HybridCache(
    cache_dir_name="init_segment_cache",
    ttl=3600,  # 1 hour
    max_memory_size=500 * 1024 * 1024,  # 500MB for init segments
)

# MPD manifests are small and refreshed often: memory-only is sufficient.
MPD_CACHE = AsyncMemoryCache(
    max_memory_size=100 * 1024 * 1024,  # 100MB for MPD files
)

SPEEDTEST_CACHE = HybridCache(
    cache_dir_name="speedtest_cache",
    ttl=3600,  # 1 hour
    max_memory_size=50 * 1024 * 1024,
)

EXTRACTOR_CACHE = HybridCache(
    cache_dir_name="extractor_cache",
    ttl=5 * 60,  # 5 minutes
    max_memory_size=50 * 1024 * 1024,
)
|
||||
|
||||
|
||||
# Specific cache implementations
|
||||
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
    """Get initialization segment from cache or download it."""
    # Serve from cache when possible.
    cached_data = await INIT_SEGMENT_CACHE.get(init_url)
    if cached_data is not None:
        return cached_data

    # Cache miss: download, store non-empty content, and return it.
    try:
        init_content = await download_file_with_retry(init_url, headers)
    except Exception as e:
        logger.error(f"Error downloading init segment: {e}")
        return None
    if init_content:
        await INIT_SEGMENT_CACHE.set(init_url, init_content)
    return init_content
|
||||
|
||||
|
||||
async def get_cached_mpd(
    mpd_url: str,
    headers: dict,
    parse_drm: bool,
    parse_segment_profile_id: Optional[str] = None,
) -> dict:
    """Get MPD from cache or download and parse it."""
    # Serve from cache when a valid JSON payload is present.
    cached_data = await MPD_CACHE.get(mpd_url)
    if cached_data is not None:
        try:
            return parse_mpd_dict(json.loads(cached_data), mpd_url, parse_drm, parse_segment_profile_id)
        except json.JSONDecodeError:
            # Corrupt cache entry: drop it and fall through to a fresh download.
            await MPD_CACHE.delete(mpd_url)

    # Download and parse if not cached
    try:
        mpd_content = await download_file_with_retry(mpd_url, headers)
        mpd_dict = parse_mpd(mpd_content)
        parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)

        # Cache the original (unparsed) MPD dict; TTL follows the manifest's
        # own refresh interval.
        await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict["minimumUpdatePeriod"])
        return parsed_dict
    except DownloadError as error:
        logger.error(f"Error downloading MPD: {error}")
        raise error
    except Exception as error:
        logger.exception(f"Error processing MPD: {error}")
        raise error
|
||||
|
||||
|
||||
async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
    """Get speed test results from cache."""
    raw = await SPEEDTEST_CACHE.get(task_id)
    if raw is None:
        return None
    try:
        return SpeedTestTask.model_validate_json(raw.decode())
    except ValidationError as e:
        # Stale/incompatible payload: log, evict, and report a miss.
        logger.error(f"Error parsing cached speed test data: {e}")
        await SPEEDTEST_CACHE.delete(task_id)
        return None
|
||||
|
||||
|
||||
async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool:
    """Cache speed test results."""
    try:
        serialized = task.model_dump_json().encode()
        return await SPEEDTEST_CACHE.set(task_id, serialized)
    except Exception as e:
        logger.error(f"Error caching speed test data: {e}")
        return False
|
||||
|
||||
|
||||
async def get_cached_extractor_result(key: str) -> Optional[dict]:
    """Get extractor result from cache."""
    raw = await EXTRACTOR_CACHE.get(key)
    if raw is None:
        return None
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Corrupt entry: evict it and report a miss.
        await EXTRACTOR_CACHE.delete(key)
        return None
|
||||
|
||||
|
||||
async def set_cache_extractor_result(key: str, result: dict) -> bool:
    """Cache extractor result."""
    try:
        payload = json.dumps(result).encode()
        return await EXTRACTOR_CACHE.set(key, payload)
    except Exception as e:
        logger.error(f"Error caching extractor result: {e}")
        return False
|
||||
110
mediaflow_proxy/utils/crypto_utils.py
Normal file
110
mediaflow_proxy/utils/crypto_utils.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from fastapi import HTTPException, Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
|
||||
|
||||
class EncryptionHandler:
    """AES-CBC encrypt/decrypt helper for URL-safe request tokens."""

    def __init__(self, secret_key: str):
        # AES-256 requires a 32-byte key: pad with spaces / truncate as needed.
        self.secret_key = secret_key.encode("utf-8").ljust(32)[:32]

    def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str:
        """Serialize ``data`` to JSON and encrypt it into a URL-safe token.

        Args:
            data: Payload to protect (mutated in place with optional claims).
            expiration: Token lifetime in seconds (adds an ``exp`` claim).
            ip: Client IP to bind the token to (adds an ``ip`` claim).

        Returns:
            URL-safe base64 string of IV + ciphertext.
        """
        if expiration:
            data["exp"] = int(time.time()) + expiration
        if ip:
            data["ip"] = ip
        json_data = json.dumps(data).encode("utf-8")
        iv = get_random_bytes(16)
        cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
        encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
        return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")

    def decrypt_data(self, token: str, client_ip: str) -> dict:
        """Decrypt a token and validate its ``exp``/``ip`` claims.

        Raises:
            HTTPException: 401 if the token is malformed or expired,
                403 if the bound IP does not match ``client_ip``.
        """
        try:
            encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
            iv = encrypted_data[:16]
            cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
            decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
            data = json.loads(decrypted_data)
        except Exception:
            # Malformed base64, bad padding, or invalid JSON all mean a bogus
            # token. Claim validation happens OUTSIDE this try so the specific
            # 401 "expired" / 403 "IP mismatch" errors are not swallowed and
            # re-raised as a generic 401 (previous behavior).
            raise HTTPException(status_code=401, detail="Invalid or expired token")

        if "exp" in data:
            if data["exp"] < time.time():
                raise HTTPException(status_code=401, detail="Token has expired")
            del data["exp"]  # Remove expiration from the data

        if "ip" in data:
            if data["ip"] != client_ip:
                raise HTTPException(status_code=403, detail="IP address mismatch")
            del data["ip"]  # Remove IP from the data

        return data
|
||||
|
||||
|
||||
class EncryptionMiddleware(BaseHTTPMiddleware):
    """Decrypts a ``token`` query parameter (if present) into plain query
    parameters before the request reaches the route handlers, and converts
    unhandled downstream exceptions into a generic 500 JSON response."""

    def __init__(self, app):
        super().__init__(app)
        # Module-level singleton; None when no api_password is configured,
        # which disables token decryption entirely.
        self.encryption_handler = encryption_handler

    async def dispatch(self, request: Request, call_next):
        """Decrypt the token (if any), rewrite the request's query string,
        then forward the request down the middleware chain."""
        encrypted_token = request.query_params.get("token")
        if encrypted_token and self.encryption_handler:
            try:
                client_ip = self.get_client_ip(request)
                decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
                # Modify request query parameters with decrypted data
                query_params = dict(request.query_params)
                query_params.pop("token")  # Remove the encrypted token from query params
                query_params.update(decrypted_data)  # Add decrypted data to query params
                query_params["has_encrypted"] = True

                # Create a new request scope with updated query parameters
                # (route handlers read query params from the ASGI scope, so
                # both the scope and the cached _query_params are updated).
                new_query_string = urlencode(query_params)
                request.scope["query_string"] = new_query_string.encode()
                request._query_params = query_params
            except HTTPException as e:
                # Invalid/expired token or IP mismatch: short-circuit with the
                # handler's status code instead of calling the app.
                return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)

        try:
            response = await call_next(request)
        except Exception:
            # Last-resort guard: log the full traceback server-side but keep
            # the client-facing message generic.
            exc = traceback.format_exc(chain=False)
            logging.error("An error occurred while processing the request, error: %s", exc)
            return JSONResponse(
                content={"error": "An error occurred while processing the request, check the server for logs"},
                status_code=500,
            )
        return response

    @staticmethod
    def get_client_ip(request: Request) -> Optional[str]:
        """
        Extract the client's real IP address from the request headers or fallback to the client host.
        """
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # In some cases, this header can contain multiple IPs
            # separated by commas.
            # The first one is the original client's IP.
            return x_forwarded_for.split(",")[0].strip()
        # Fallback to X-Real-IP if X-Forwarded-For is not available
        x_real_ip = request.headers.get("X-Real-IP")
        if x_real_ip:
            return x_real_ip
        return request.client.host if request.client else "127.0.0.1"
|
||||
|
||||
|
||||
# Module-level singleton used by EncryptionMiddleware and URL encoding helpers;
# None (no api_password configured) disables token encryption/decryption.
encryption_handler = EncryptionHandler(settings.api_password) if settings.api_password else None
|
||||
430
mediaflow_proxy/utils/http_utils.py
Normal file
430
mediaflow_proxy/utils/http_utils.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import logging
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from urllib import parse
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import tenacity
|
||||
from fastapi import Response
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Receive, Send, Scope
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
from tqdm.asyncio import tqdm as tqdm_asyncio
|
||||
|
||||
from mediaflow_proxy.configs import settings
|
||||
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
|
||||
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadError(Exception):
    """Raised when an HTTP download fails; carries the associated status code."""

    def __init__(self, status_code, message):
        super().__init__(message)
        self.status_code = status_code
        self.message = message
|
||||
|
||||
|
||||
def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient:
    """Creates an HTTPX client with configured proxy routing"""
    # Mounts route URL patterns through the proxies configured in settings.
    return httpx.AsyncClient(
        mounts=settings.transport_config.get_mounts(),
        follow_redirects=follow_redirects,
        timeout=timeout,
        **kwargs,
    )
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(DownloadError),
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
    """
    Fetches a URL with retry logic.

    Only DownloadError triggers a retry (up to 3 attempts with exponential
    backoff); any other exception propagates immediately. When attempts are
    exhausted, tenacity raises RetryError (callers handle this).

    Args:
        client (httpx.AsyncClient): The HTTP client to use for the request.
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to fetch.
        headers (dict): The headers to include in the request.
        follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    try:
        response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
        response.raise_for_status()
        return response
    except httpx.TimeoutException:
        # Raised as DownloadError so tenacity's retry predicate fires.
        logger.warning(f"Timeout while downloading {url}")
        raise DownloadError(409, f"Timeout while downloading {url}")
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
        if e.response.status_code == 404:
            # Re-raise the original HTTPStatusError: it is not a DownloadError,
            # so a 404 is deliberately NOT retried.
            logger.error(f"Segment Resource not found: {url}")
            raise e
        raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
    except Exception as e:
        # Unknown failure: propagate without retry.
        logger.error(f"Error downloading {url}: {e}")
        raise
|
||||
|
||||
|
||||
class Streamer:
    """Wraps an httpx.AsyncClient to stream a remote resource chunk by chunk,
    optionally reporting progress via tqdm."""

    def __init__(self, client):
        """
        Initializes the Streamer with an HTTP client.

        Args:
            client (httpx.AsyncClient): The HTTP client to use for streaming.
        """
        self.client = client
        self.response = None          # set by create_streaming_response()/get_text()
        self.progress_bar = None      # tqdm instance when progress is enabled
        self.bytes_transferred = 0    # running count of bytes yielded so far
        self.start_byte = 0           # from Content-Range, when the server sends one
        self.end_byte = 0
        self.total_size = 0

    async def create_streaming_response(self, url: str, headers: dict):
        """
        Creates and sends a streaming request.

        Args:
            url (str): The URL to stream from.
            headers (dict): The headers to include in the request.

        """
        request = self.client.build_request("GET", url, headers=headers)
        self.response = await self.client.send(request, stream=True, follow_redirects=True)
        self.response.raise_for_status()

    async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Streams the content from the response.

        Yields raw chunks as received; each chunk is yielded BEFORE the
        progress counters are updated so the consumer is never delayed by
        bookkeeping. A GeneratorExit (client went away) is swallowed silently.
        """
        if not self.response:
            raise RuntimeError("No response available for streaming")

        try:
            self.parse_content_range()

            if settings.enable_streaming_progress:
                with tqdm_asyncio(
                    total=self.total_size,
                    initial=self.start_byte,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc="Streaming",
                    ncols=100,
                    mininterval=1,
                ) as self.progress_bar:
                    async for chunk in self.response.aiter_bytes():
                        yield chunk
                        chunk_size = len(chunk)
                        self.bytes_transferred += chunk_size
                        self.progress_bar.set_postfix_str(
                            f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
                        )
                        self.progress_bar.update(chunk_size)
            else:
                # Progress disabled: just count bytes.
                async for chunk in self.response.aiter_bytes():
                    yield chunk
                    self.bytes_transferred += len(chunk)

        except httpx.TimeoutException:
            logger.warning("Timeout while streaming")
            raise DownloadError(409, "Timeout while streaming")
        except GeneratorExit:
            # Consumer closed the generator (client disconnect): not an error.
            logger.info("Streaming session stopped by the user")
        except Exception as e:
            logger.error(f"Error streaming content: {e}")
            raise

    @staticmethod
    def format_bytes(size) -> str:
        """Format a byte count as a human-readable string (up to TB)."""
        power = 2**10
        n = 0
        units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
        while size > power:
            size /= power
            n += 1
        return f"{size:.2f} {units[n]}"

    def parse_content_range(self):
        """Populate start/end/total byte counters from the response headers.

        Parses "Content-Range: bytes a-b/total" by turning "a-b/total" into
        "a-b-total" and splitting; falls back to Content-Length when the
        server did not send a range.
        """
        content_range = self.response.headers.get("Content-Range", "")
        if content_range:
            range_info = content_range.split()[-1]
            self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-"))
        else:
            self.start_byte = 0
            self.total_size = int(self.response.headers.get("Content-Length", 0))
            self.end_byte = self.total_size - 1 if self.total_size > 0 else 0

    async def get_text(self, url: str, headers: dict):
        """
        Sends a GET request to a URL and returns the response text.

        Args:
            url (str): The URL to send the GET request to.
            headers (dict): The headers to include in the request.

        Returns:
            str: The response text.
        """
        try:
            self.response = await fetch_with_retry(self.client, "GET", url, headers)
        except tenacity.RetryError as e:
            # Unwrap tenacity's wrapper and raise the final attempt's exception.
            raise e.last_attempt.result()
        return self.response.text

    async def close(self):
        """
        Closes the HTTP client and response.
        """
        if self.response:
            await self.response.aclose()
        if self.progress_bar:
            self.progress_bar.close()
        await self.client.aclose()
|
||||
|
||||
|
||||
async def download_file_with_retry(url: str, headers: dict):
    """
    Downloads a file with retry logic.

    Args:
        url (str): The URL of the file to download.
        headers (dict): The headers to include in the request.

    Returns:
        bytes: The downloaded file content.

    Raises:
        DownloadError: If the download fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            response = await fetch_with_retry(client, "GET", url, headers)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise e
        except tenacity.RetryError as e:
            # All retry attempts exhausted: surface a single DownloadError.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
        return response.content
|
||||
|
||||
|
||||
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
    """
    Sends an HTTP request with retry logic.

    Args:
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to send the request to.
        headers (dict): The headers to include in the request.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            response = await fetch_with_retry(client, method, url, headers, **kwargs)
            return response
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise
        except tenacity.RetryError as e:
            # Consistency with download_file_with_retry: when retries are
            # exhausted, convert tenacity's RetryError into the DownloadError
            # callers are documented to catch (previously it leaked through).
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
|
||||
|
||||
|
||||
def encode_mediaflow_proxy_url(
    mediaflow_proxy_url: str,
    endpoint: typing.Optional[str] = None,
    destination_url: typing.Optional[str] = None,
    query_params: typing.Optional[dict] = None,
    request_headers: typing.Optional[dict] = None,
    response_headers: typing.Optional[dict] = None,
    # Annotation as a string (lazy forward reference) and corrected to
    # Optional: the previous `EncryptionHandler = None` / `int = None` /
    # `str = None` annotations were mistyped.
    encryption_handler: "typing.Optional[EncryptionHandler]" = None,
    expiration: typing.Optional[int] = None,
    ip: typing.Optional[str] = None,
) -> str:
    """
    Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.

    Args:
        mediaflow_proxy_url (str): The base MediaFlow proxy URL.
        endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
        destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
        query_params (dict, optional): Additional query parameters to include. Defaults to None.
        request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
        response_headers (dict, optional): Headers to include as query parameters. Defaults to None.
        encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
        expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
        ip (str, optional): The public IP address to include in the query parameters. Defaults to None.

    Returns:
        str: The encoded MediaFlow proxy URL.
    """
    query_params = query_params or {}
    if destination_url is not None:
        query_params["d"] = destination_url

    # Add headers if provided, prefixing with h_/r_ unless already prefixed.
    if request_headers:
        query_params.update(
            {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
        )
    if response_headers:
        query_params.update(
            {key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
        )

    if encryption_handler:
        # All parameters are folded into a single opaque encrypted token.
        encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip)
        encoded_params = urlencode({"token": encrypted_token})
    else:
        encoded_params = urlencode(query_params)

    # Construct the full URL
    if endpoint is None:
        return f"{mediaflow_proxy_url}?{encoded_params}"

    base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
    return f"{base_url}?{encoded_params}"
|
||||
|
||||
|
||||
def get_original_scheme(request: "Request") -> str:
    """
    Determines the original scheme (http or https) of the request.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        str: The original scheme ('http' or 'https')
    """
    # Trust the standard reverse-proxy header first.
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    if forwarded_proto:
        return forwarded_proto

    # Direct TLS, or any of the common proxy headers that signal HTTPS.
    # (The previous version tested X-Forwarded-Ssl twice in two branches.)
    if (
        request.url.scheme == "https"
        or request.headers.get("X-Forwarded-Ssl") == "on"
        or request.headers.get("X-Forwarded-Protocol") == "https"
        or request.headers.get("X-Url-Scheme") == "https"
    ):
        return "https"

    # Default to http if no indicators of https are found
    return "http"
|
||||
|
||||
|
||||
@dataclass
class ProxyRequestHeaders:
    """Headers split by direction: forwarded upstream vs. added to the client response."""

    # Headers sent with the proxied request to the destination server.
    request: dict
    # Headers merged into the response returned to the client.
    response: dict
|
||||
|
||||
|
||||
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
    """
    Extracts proxy headers from the request query parameters.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        ProxyRequest: A named tuple containing the request headers and response headers.
    """
    # Start from the whitelisted incoming headers; h_* query params override.
    request_headers = {name: value for name, value in request.headers.items() if name in SUPPORTED_REQUEST_HEADERS}
    response_headers = {}
    for param, value in request.query_params.items():
        if param.startswith("h_"):
            request_headers[param[2:].lower()] = value
        elif param.startswith("r_"):
            response_headers[param[2:].lower()] = value
    return ProxyRequestHeaders(request_headers, response_headers)
|
||||
|
||||
|
||||
class EnhancedStreamingResponse(Response):
    """Streaming response that races body delivery against client disconnect
    detection, logging (rather than crashing on) broken-connection errors."""

    body_iterator: typing.AsyncIterable[typing.Any]

    def __init__(
        self,
        content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        if isinstance(content, typing.AsyncIterable):
            self.body_iterator = content
        else:
            # Sync iterables are consumed in a threadpool so they don't block
            # the event loop.
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    @staticmethod
    async def listen_for_disconnect(receive: Receive) -> None:
        """Consume ASGI messages until the client disconnects."""
        try:
            while True:
                message = await receive()
                if message["type"] == "http.disconnect":
                    logger.debug("Client disconnected")
                    break
        except Exception as e:
            logger.error(f"Error in listen_for_disconnect: {str(e)}")

    async def stream_response(self, send: Send) -> None:
        """Send the response start message, then the body chunk by chunk."""
        try:
            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            async for chunk in self.body_iterator:
                if not isinstance(chunk, (bytes, memoryview)):
                    chunk = chunk.encode(self.charset)
                try:
                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
                except (ConnectionResetError, anyio.BrokenResourceError):
                    # Client went away mid-stream: stop quietly.
                    logger.info("Client disconnected during streaming")
                    return

            # Terminating message: empty body with more_body=False.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        except Exception as e:
            logger.exception(f"Error in stream_response: {str(e)}")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run streaming and disconnect-listening
        concurrently; whichever finishes first cancels the other."""
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                try:
                    await func()
                except Exception as e:
                    if not isinstance(e, anyio.get_cancelled_exc_class()):
                        logger.exception("Error in streaming task")
                        raise
                finally:
                    # First task to finish cancels the whole group.
                    task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
|
||||
87
mediaflow_proxy/utils/m3u8_processor.py
Normal file
87
mediaflow_proxy/utils/m3u8_processor.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import re
|
||||
from urllib import parse
|
||||
|
||||
from mediaflow_proxy.utils.crypto_utils import encryption_handler
|
||||
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
|
||||
|
||||
|
||||
class M3U8Processor:
    def __init__(self, request, key_url: str = None):
        """
        Initializes the M3U8Processor with the request and URL prefix.

        Args:
            request (Request): The incoming HTTP request.
            key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
        """
        self.request = request
        self.key_url = parse.urlparse(key_url) if key_url else None
        # Absolute URL of this proxy's HLS manifest endpoint, rewritten to the
        # scheme the client originally used (relevant behind TLS-terminating proxies).
        self.mediaflow_proxy_url = str(
            request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
        )

    async def process_m3u8(self, content: str, base_url: str) -> str:
        """
        Processes the m3u8 content, proxying URLs and handling key lines.

        Args:
            content (str): The m3u8 content to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed m3u8 content.
        """
        lines = content.splitlines()
        processed_lines = []
        for line in lines:
            if "URI=" in line:
                # Tag lines carrying a URI attribute (e.g. EXT-X-KEY, EXT-X-MEDIA).
                processed_lines.append(await self.process_key_line(line, base_url))
            elif not line.startswith("#") and line.strip():
                # Plain segment / playlist URL lines.
                processed_lines.append(await self.proxy_url(line, base_url))
            else:
                processed_lines.append(line)
        return "\n".join(processed_lines)

    async def process_key_line(self, line: str, base_url: str) -> str:
        """
        Processes a key line in the m3u8 content, proxying the URI.

        Args:
            line (str): The key line to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed key line.
        """
        uri_match = re.search(r'URI="([^"]+)"', line)
        if uri_match:
            original_uri = uri_match.group(1)
            uri = parse.urlparse(original_uri)
            if self.key_url:
                # Redirect key requests to the configured key server host.
                uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
            new_uri = await self.proxy_url(uri.geturl(), base_url)
            line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
        return line

    async def proxy_url(self, url: str, base_url: str) -> str:
        """
        Proxies a URL, encoding it with the MediaFlow proxy URL.

        Args:
            url (str): The URL to proxy.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The proxied URL.
        """
        full_url = parse.urljoin(base_url, url)
        query_params = dict(self.request.query_params)
        # Internal flag controlling encryption of the generated URL; it must
        # not be forwarded as a query parameter of the proxied URL itself.
        has_encrypted = query_params.pop("has_encrypted", False)

        return encode_mediaflow_proxy_url(
            self.mediaflow_proxy_url,
            "",
            full_url,
            # BUG FIX: previously a fresh dict(self.request.query_params) was
            # passed here, re-introducing the just-popped "has_encrypted" flag
            # into every generated URL; forward the cleaned copy instead.
            query_params=query_params,
            encryption_handler=encryption_handler if has_encrypted else None,
        )
|
||||
555
mediaflow_proxy/utils/mpd_utils.py
Normal file
555
mediaflow_proxy/utils/mpd_utils.py
Normal file
@@ -0,0 +1,555 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Dict, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import xmltodict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
    """
    Convert raw MPD manifest XML into a nested dictionary.

    Thin wrapper around ``xmltodict.parse``; attributes appear as ``@``-prefixed
    keys and element text as ``#text`` in the result.

    Args:
        mpd_content (Union[str, bytes]): The MPD content to parse.

    Returns:
        dict: The parsed MPD content as a dictionary.
    """
    parsed = xmltodict.parse(mpd_content)
    return parsed
|
||||
|
||||
|
||||
def parse_mpd_dict(
    mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: Optional[str] = None
) -> dict:
    """
    Parses the MPD dictionary and extracts relevant information.

    Args:
        mpd_dict (dict): The MPD content as a dictionary (output of parse_mpd).
        mpd_url (str): The URL of the MPD manifest.
        parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict: The parsed MPD information including profiles and DRM info.

    This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
    It handles both live and static MPD manifests.
    """
    profiles = []
    parsed_dict = {}
    # Base URL for resolving relative segment paths: manifest URL minus its last path component.
    source = "/".join(mpd_url.split("/")[:-1])

    is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
    parsed_dict["isLive"] = is_live

    media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration")

    # Parse additional MPD attributes for live streams
    if is_live:
        parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S"))
        parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M"))
        # "Z" is normalized to "+00:00" because datetime.fromisoformat (pre-3.11)
        # does not accept the "Z" suffix.
        parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
            mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00")
        )
        # NOTE(review): if @publishTime is absent this calls fromisoformat("")
        # and raises ValueError — confirm all live sources provide @publishTime.
        parsed_dict["publishTime"] = datetime.fromisoformat(
            mpd_dict["MPD"].get("@publishTime", "").replace("Z", "+00:00")
        )

    # xmltodict yields a dict for a single Period, a list for several.
    periods = mpd_dict["MPD"]["Period"]
    periods = periods if isinstance(periods, list) else [periods]

    for period in periods:
        parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
        for adaptation in period["AdaptationSet"]:
            representations = adaptation["Representation"]
            representations = representations if isinstance(representations, list) else [representations]

            for representation in representations:
                # Returns None for non-audio/video representations, which are skipped.
                profile = parse_representation(
                    parsed_dict,
                    representation,
                    adaptation,
                    source,
                    media_presentation_duration,
                    parse_segment_profile_id,
                )
                if profile:
                    profiles.append(profile)
    parsed_dict["profiles"] = profiles

    if parse_drm:
        drm_info = extract_drm_info(periods, mpd_url)
    else:
        drm_info = {}
    parsed_dict["drmInfo"] = drm_info

    return parsed_dict
|
||||
|
||||
|
||||
def pad_base64(encoded_key_id):
    """
    Pads a base64 encoded key ID to make its length a multiple of 4.

    The previous expression ``"=" * (4 - len(s) % 4)`` appended FOUR ``=``
    characters when the length was already a multiple of 4, producing input
    that ``base64.b64decode`` rejects as invalid padding. ``-len(s) % 4``
    yields 0 in that case and the correct 1–3 otherwise.

    Args:
        encoded_key_id (str): The base64 encoded key ID.

    Returns:
        str: The padded base64 encoded key ID.
    """
    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)
|
||||
|
||||
|
||||
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
    """
    Extract DRM information from the MPD periods.

    Walks every AdaptationSet and Representation, feeding each
    ContentProtection element to ``process_content_protection``, which
    accumulates the detected DRM system (ClearKey / Widevine / PlayReady),
    key IDs, PSSH boxes and license URLs into one dictionary.

    Args:
        periods (List[Dict]): The list of periods in the MPD.
        mpd_url (str): The URL of the MPD manifest (used to absolutize laUrl).

    Returns:
        Dict: The extracted DRM information; ``isDrmProtected`` is always present.
    """
    drm_info = {"isDrmProtected": False}

    def _as_list(node):
        # xmltodict collapses single children to a dict; normalize to a list.
        return node if isinstance(node, list) else [node]

    for period in periods:
        for adaptation_set in _as_list(period.get("AdaptationSet", [])):
            # ContentProtection may appear at the AdaptationSet level...
            process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)
            # ...or inside each Representation.
            for representation in _as_list(adaptation_set.get("Representation", [])):
                process_content_protection(representation.get("ContentProtection", []), drm_info)

    # A relative license acquisition URL is resolved against the manifest URL.
    if "laUrl" in drm_info and not drm_info["laUrl"].startswith(("http://", "https://")):
        drm_info["laUrl"] = urljoin(mpd_url, drm_info["laUrl"])

    return drm_info
|
||||
|
||||
|
||||
def process_content_protection(content_protection: Union[list[dict], dict], drm_info: dict):
    """
    Fold ContentProtection elements into the shared DRM-info dictionary.

    Any ContentProtection element marks the stream as DRM protected. The
    scheme URI selects the DRM system (ClearKey, Widevine, PlayReady); key
    IDs and license URLs are recorded first-wins.

    Args:
        content_protection (Union[list[dict], dict]): The ContentProtection element(s).
        drm_info (dict): Accumulator dictionary, mutated in place.

    Returns:
        dict: The same ``drm_info`` dictionary, for convenience.
    """
    protections = content_protection if isinstance(content_protection, list) else [content_protection]

    for protection in protections:
        drm_info["isDrmProtected"] = True
        scheme = protection.get("@schemeIdUri", "").lower()

        if "clearkey" in scheme:
            drm_info["drmSystem"] = "clearkey"
            if "clearkey:Laurl" in protection:
                la_url = protection["clearkey:Laurl"].get("#text")
                if la_url and "laUrl" not in drm_info:
                    drm_info["laUrl"] = la_url
        elif "widevine" in scheme or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme:
            drm_info["drmSystem"] = "widevine"
            pssh = protection.get("cenc:pssh", {}).get("#text")
            if pssh:
                drm_info["pssh"] = pssh
        elif "playready" in scheme or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme:
            drm_info["drmSystem"] = "playready"

        if "@cenc:default_KID" in protection:
            # KIDs are stored as bare hex, dashes stripped.
            key_id = protection["@cenc:default_KID"].replace("-", "")
            if "keyId" not in drm_info:
                drm_info["keyId"] = key_id

        if "ms:laurl" in protection:
            la_url = protection["ms:laurl"].get("@licenseUrl")
            if la_url and "laUrl" not in drm_info:
                drm_info["laUrl"] = la_url

    return drm_info
|
||||
|
||||
|
||||
def parse_representation(
    parsed_dict: dict,
    representation: dict,
    adaptation: dict,
    source: str,
    media_presentation_duration: str,
    parse_segment_profile_id: Optional[str],
) -> Optional[dict]:
    """
    Parses a representation and extracts profile information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        representation (dict): The representation data.
        adaptation (dict): The adaptation set data.
        source (str): The source URL.
        media_presentation_duration (str): The media presentation duration.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        Optional[dict]: The parsed profile information, or None for
        non-audio/video representations.
    """
    # @mimeType may live on either level; when absent, guess from the codec
    # string ("avc" => H.264 video, otherwise assume audio).
    mime_type = _get_key(adaptation, representation, "@mimeType") or (
        "video/mp4" if "avc" in representation["@codecs"] else "audio/mp4"
    )
    if "video" not in mime_type and "audio" not in mime_type:
        # Subtitles/images etc. are not exposed as profiles.
        return None

    # Most attributes can be declared on the Representation or inherited
    # from its AdaptationSet; the Representation wins.
    profile = {
        "id": representation.get("@id") or adaptation.get("@id"),
        "mimeType": mime_type,
        "lang": representation.get("@lang") or adaptation.get("@lang"),
        "codecs": representation.get("@codecs") or adaptation.get("@codecs"),
        "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
        "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
        "mediaPresentationDuration": media_presentation_duration,
    }

    if "audio" in profile["mimeType"]:
        profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
        # Default to stereo when AudioChannelConfiguration is missing.
        profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
    else:
        profile["width"] = int(representation["@width"])
        profile["height"] = int(representation["@height"])
        # Frame rate may be "num/den" or a bare number; normalize to a ratio,
        # defaulting to NTSC 29.97 fps.
        frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
        frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1"
        profile["frameRate"] = round(int(frame_rate.split("/")[0]) / int(frame_rate.split("/")[1]), 3)
        profile["sar"] = representation.get("@sar", "1:1")

    # Segment expansion is expensive; only do it for the requested profile.
    if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
        return profile

    item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
    if item:
        profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
    else:
        # No template: single-file representation described by SegmentBase.
        profile["segments"] = parse_segment_base(representation, source)

    return profile
|
||||
|
||||
|
||||
def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
|
||||
"""
|
||||
Retrieves a key from the representation or adaptation set.
|
||||
|
||||
Args:
|
||||
adaptation (dict): The adaptation set data.
|
||||
representation (dict): The representation data.
|
||||
key (str): The key to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The value of the key or None if not found.
|
||||
"""
|
||||
return representation.get(key, adaptation.get(key, None))
|
||||
|
||||
|
||||
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
    """
    Parse a SegmentTemplate element into a list of segments.

    Resolves the initialization-segment URL into ``profile["initUrl"]``, then
    delegates to the timeline-based or fixed-duration expansion depending on
    which form the template uses.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The SegmentTemplate data.
        profile (dict): The profile information (mutated: may gain "initUrl").
        source (str): The source URL for absolutizing relative paths.

    Returns:
        List[Dict]: The list of parsed segments (empty if the template has
        neither a SegmentTimeline nor a @duration).
    """
    timescale = int(item.get("@timescale", 1))

    # Initialization segment URL with template placeholders expanded.
    if "@initialization" in item:
        init_url = (
            item["@initialization"]
            .replace("$RepresentationID$", profile["id"])
            .replace("$Bandwidth$", str(profile["bandwidth"]))
        )
        if not init_url.startswith("http"):
            init_url = f"{source}/{init_url}"
        profile["initUrl"] = init_url

    # Explicit timeline takes precedence over a uniform @duration.
    if "SegmentTimeline" in item:
        return parse_segment_timeline(parsed_dict, item, profile, source, timescale)
    if "@duration" in item:
        return parse_segment_duration(parsed_dict, item, profile, source, timescale)
    return []
|
||||
|
||||
|
||||
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses a segment timeline and extracts segment information.

    Args:
        parsed_dict (dict): The parsed MPD data (must contain
            "availabilityStartTime"; "PeriodStart" is optional).
        item (dict): The SegmentTemplate data containing a SegmentTimeline.
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    # <S> entries: single dict or list depending on xmltodict collapsing.
    timelines = item["SegmentTimeline"]["S"]
    timelines = timelines if isinstance(timelines, list) else [timelines]
    # Wall-clock anchor for segment timestamps.
    period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
    # Offset (in timescale units) subtracted from each @t value.
    presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
    start_number = int(item.get("@startNumber", 1))

    segments = [
        create_segment_data(timeline, item, profile, source, timescale)
        for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
    ]
    return segments
|
||||
|
||||
|
||||
def preprocess_timeline(
    timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
    """
    Expand SegmentTimeline <S> entries into one record per segment.

    Each <S> entry carries a duration ``@d``, an optional explicit start
    ``@t`` (defaulting to the running position) and an optional repeat count
    ``@r`` (the entry describes ``@r + 1`` consecutive segments).

    Args:
        timelines (List[Dict]): The list of <S> timeline entries.
        start_number (int): The number assigned to the first segment.
        period_start (datetime): Wall-clock start of the period.
        presentation_time_offset (int): Offset subtracted from @t (timescale units).
        timescale (int): Ticks per second for @t/@d values.

    Returns:
        List[Dict]: One dict per segment with number, wall-clock start/end,
        duration and raw timeline time (both in timescale units).
    """
    expanded = []
    elapsed = 0
    number = start_number

    for entry in timelines:
        repeats = int(entry.get("@r", 0))
        duration = int(entry["@d"])
        # Missing @t means "continue from where the previous entry ended".
        tick = int(entry.get("@t", elapsed))

        for _ in range(repeats + 1):
            begins = period_start + timedelta(seconds=(tick - presentation_time_offset) / timescale)
            ends = begins + timedelta(seconds=duration / timescale)
            expanded.append(
                {
                    "number": number,
                    "start_time": begins,
                    "end_time": ends,
                    "duration": duration,
                    "time": tick,
                }
            )
            tick += duration
            number += 1

        elapsed = tick

    return expanded
|
||||
|
||||
|
||||
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses segment duration and extracts segment information.
    This is used for static or live MPD manifests.

    Args:
        parsed_dict (dict): The parsed MPD data (must contain "isLive").
        item (dict): The SegmentTemplate data (must contain "@duration").
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    # @duration is expressed in timescale units.
    duration = int(item["@duration"])
    start_number = int(item.get("@startNumber", 1))
    segment_duration_sec = duration / timescale

    if parsed_dict["isLive"]:
        # Live: derive the currently-available window from wall clock.
        segments = generate_live_segments(parsed_dict, segment_duration_sec, start_number)
    else:
        # VOD: enumerate the whole presentation.
        segments = generate_vod_segments(profile, duration, timescale, start_number)

    return [create_segment_data(seg, item, profile, source, timescale) for seg in segments]
|
||||
|
||||
|
||||
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
    """
    Generate the currently-available window of live segments.

    The window size follows the manifest's timeShiftBufferDepth; the first
    available segment number is derived from how much wall-clock time has
    elapsed since availabilityStartTime, clamped to start_number.

    Args:
        parsed_dict (dict): The parsed MPD data ("availabilityStartTime" required,
            "timeShiftBufferDepth" optional, defaulting to 60 seconds).
        segment_duration_sec (float): The segment duration in seconds.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The generated live segments (number, start_time, duration).
    """
    buffer_seconds = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60)).total_seconds()
    segment_count = math.ceil(buffer_seconds / segment_duration_sec)

    availability_start = parsed_dict["availabilityStartTime"]
    now = datetime.now(tz=timezone.utc)
    elapsed_segments = math.floor((now - availability_start).total_seconds() / segment_duration_sec)
    # Oldest segment still inside the time-shift buffer, never before start_number.
    first_number = max(start_number + elapsed_segments - segment_count, start_number)

    segments = []
    for number in range(first_number, first_number + segment_count):
        offset_sec = (number - start_number) * segment_duration_sec
        segments.append(
            {
                "number": number,
                "start_time": availability_start + timedelta(seconds=offset_sec),
                "duration": segment_duration_sec,
            }
        )
    return segments
|
||||
|
||||
|
||||
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
    """
    Enumerate all segments of a static (VOD) presentation.

    Args:
        profile (dict): The profile information; "mediaPresentationDuration"
            may be an ISO 8601 string, a number of seconds, or missing (treated as 0).
        duration (int): The per-segment duration in timescale units.
        timescale (int): The timescale for the segments.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The generated VOD segments (number, duration in seconds).
    """
    total_duration = profile.get("mediaPresentationDuration") or 0
    if isinstance(total_duration, str):
        # ISO 8601 duration string → seconds.
        total_duration = parse_duration(total_duration)
    segment_count = math.ceil(total_duration * timescale / duration)

    per_segment_sec = duration / timescale
    segments = []
    for index in range(segment_count):
        segments.append({"number": start_number + index, "duration": per_segment_sec})
    return segments
|
||||
|
||||
|
||||
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
    """
    Creates segment data based on the segment information. This includes the segment URL and metadata.

    Expands the DASH ``$...$`` placeholders in the SegmentTemplate ``@media``
    attribute and attaches HLS-oriented timing metadata when available.

    Args:
        segment (Dict): The segment information (number; optional time/start_time/end_time/duration).
        item (dict): The segment template data (must contain "@media").
        profile (dict): The profile information (id, bandwidth).
        source (str): The source URL used to absolutize relative media URLs.
        timescale (int, optional): The timescale for the segments. Defaults to None.

    Returns:
        Dict: The created segment data.
    """

    def _utc_iso(dt: datetime) -> str:
        # BUG FIX: isoformat() + "Z" produced invalid "...+00:00Z" for the
        # tz-aware datetimes built from availabilityStartTime upstream.
        iso = dt.isoformat()
        return iso.replace("+00:00", "Z") if iso.endswith("+00:00") else iso + "Z"

    media = item["@media"]
    media = media.replace("$RepresentationID$", profile["id"])
    media = media.replace("$Number%04d$", f"{segment['number']:04d}")
    media = media.replace("$Number$", str(segment["number"]))
    media = media.replace("$Bandwidth$", str(profile["bandwidth"]))

    if "time" in segment:
        # BUG FIX: segment["time"] already holds the raw @t value in timescale
        # units (see preprocess_timeline); multiplying by the timescale again
        # (previous behavior) inflated $Time$ by that factor.
        media = media.replace("$Time$", str(int(segment["time"])))

    if not media.startswith("http"):
        media = f"{source}/{media}"

    segment_data = {
        "type": "segment",
        "media": media,
        "number": segment["number"],
    }

    if "start_time" in segment and "end_time" in segment:
        # Timeline-derived segment: exact wall-clock bounds are known.
        segment_data.update(
            {
                "start_time": segment["start_time"],
                "end_time": segment["end_time"],
                "extinf": (segment["end_time"] - segment["start_time"]).total_seconds(),
                "program_date_time": _utc_iso(segment["start_time"]),
            }
        )
    elif "start_time" in segment and "duration" in segment:
        # Live fixed-duration segment: derive the end from the duration.
        duration = segment["duration"]
        segment_data.update(
            {
                "start_time": segment["start_time"],
                "end_time": segment["start_time"] + timedelta(seconds=duration),
                "extinf": duration,
                "program_date_time": _utc_iso(segment["start_time"]),
            }
        )
    elif "duration" in segment:
        # VOD segment without wall-clock anchoring.
        segment_data["extinf"] = segment["duration"]

    return segment_data
|
||||
|
||||
|
||||
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
    """
    Parses segment base information and extracts segment data. This is used for single-segment representations.

    The byte range spans from the Initialization range start (when present,
    otherwise the index range start) through the end of the sidx index range.

    Args:
        representation (dict): The representation data (SegmentBase + BaseURL).
        source (str): The source URL.

    Returns:
        List[Dict]: A single-element list describing the media file and range.
    """
    segment_base = representation["SegmentBase"]
    range_start, range_end = [int(part) for part in segment_base["@indexRange"].split("-")]
    if "Initialization" in segment_base:
        range_start = int(segment_base["Initialization"]["@range"].split("-")[0])

    return [
        {
            "type": "segment",
            "range": f"{range_start}-{range_end}",
            "media": f"{source}/{representation['BaseURL']}",
        }
    ]
|
||||
|
||||
|
||||
def parse_duration(duration_str: str) -> float:
    """
    Parses a duration ISO 8601 string into seconds.

    Calendar components are approximated: a month counts as 30 days and a
    year as 365 days (MPD durations rarely use them).

    Args:
        duration_str (str): The duration string to parse (e.g. "PT1H2M3.5S").

    Returns:
        float: The parsed duration in seconds.

    Raises:
        ValueError: If the string is not a valid ISO 8601 duration.
    """
    pattern = re.compile(r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?T?(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?")
    match = pattern.match(duration_str)
    if match is None:
        raise ValueError(f"Invalid duration format: {duration_str}")

    years, months, days, hours, minutes, seconds = (float(group) if group else 0 for group in match.groups())

    total = seconds
    total += minutes * 60
    total += hours * 3600
    total += days * 24 * 3600
    total += months * 30 * 24 * 3600
    total += years * 365 * 24 * 3600
    return total
|
||||
Reference in New Issue
Block a user