Add files via upload

This commit is contained in:
Urlo30
2024-12-29 23:18:53 +01:00
committed by GitHub
parent 76987906b5
commit fb60e99822
40 changed files with 5170 additions and 0 deletions

View File

View File

@@ -0,0 +1,376 @@
import asyncio
import hashlib
import json
import logging
import os
import tempfile
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Any
import aiofiles
import aiofiles.os
from pydantic import ValidationError
from mediaflow_proxy.speedtest.models import SpeedTestTask
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """Represents a cache entry with metadata."""

    # Raw cached payload bytes.
    data: bytes
    # Absolute expiry timestamp in the time.time() domain (seconds since epoch).
    expires_at: float
    # Number of times this entry has been served from cache.
    access_count: int = 0
    # Timestamp of the most recent read; 0.0 means never read.
    last_access: float = 0.0
    # Payload size in bytes; drives the LRU byte budget in LRUMemoryCache.
    size: int = 0
class LRUMemoryCache:
"""Thread-safe LRU memory cache with support."""
def __init__(self, maxsize: int):
self.maxsize = maxsize
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._lock = threading.Lock()
self._current_size = 0
def get(self, key: str) -> Optional[CacheEntry]:
with self._lock:
if key in self._cache:
entry = self._cache.pop(key) # Remove and re-insert for LRU
if time.time() < entry.expires_at:
entry.access_count += 1
entry.last_access = time.time()
self._cache[key] = entry
return entry
else:
# Remove expired entry
self._current_size -= entry.size
self._cache.pop(key, None)
return None
def set(self, key: str, entry: CacheEntry) -> None:
with self._lock:
if key in self._cache:
old_entry = self._cache[key]
self._current_size -= old_entry.size
# Check if we need to make space
while self._current_size + entry.size > self.maxsize and self._cache:
_, removed_entry = self._cache.popitem(last=False)
self._current_size -= removed_entry.size
self._cache[key] = entry
self._current_size += entry.size
def remove(self, key: str) -> None:
with self._lock:
if key in self._cache:
entry = self._cache.pop(key)
self._current_size -= entry.size
class HybridCache:
    """High-performance hybrid cache combining memory and file storage.

    Entries live in an in-process LRU byte cache and are mirrored to files
    under a per-cache temp directory.  The file layout is: an 8-byte
    big-endian metadata length, a JSON metadata blob, then the raw payload.
    Keys are MD5-hashed internally so any string is filesystem-safe.
    """

    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        executor_workers: int = 4,
    ):
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()
        # Initialize cache directories
        self._init_cache_dirs()

    def _init_cache_dirs(self):
        """Create the cache directory if it does not already exist."""
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_md5_hash(self, key: str) -> str:
        """Hash a cache key into a filesystem-safe name (not security sensitive)."""
        return hashlib.md5(key.encode()).hexdigest()

    def _get_file_path(self, key: str) -> Path:
        """Return the file path for an already-hashed cache key."""
        return self.cache_dir / key

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        Args:
            key: Cache key (original, unhashed form)
            default: Default value if key not found

        Returns:
            Cached value or default if not found
        """
        hashed_key = self._get_md5_hash(key)
        # Try memory cache first
        entry = self.memory_cache.get(hashed_key)
        if entry is not None:
            return entry.data
        # Fall back to the file cache
        try:
            file_path = self._get_file_path(hashed_key)
            async with aiofiles.open(file_path, "rb") as f:
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())
                # Check expiration
                if metadata["expires_at"] < time.time():
                    # delete() hashes internally, so pass the original key.
                    await self.delete(key)
                    return default
                # The remaining bytes are the payload
                data = await f.read()
                # Promote the entry into the memory cache for faster future hits
                entry = CacheEntry(
                    data=data,
                    expires_at=metadata["expires_at"],
                    access_count=metadata["access_count"] + 1,
                    last_access=time.time(),
                    size=len(data),
                )
                self.memory_cache.set(hashed_key, entry)
                return data
        except FileNotFoundError:
            return default
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            return default

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Set value in both memory and file cache.

        Args:
            key: Cache key (original, unhashed form)
            data: Data to cache
            ttl: Optional TTL override in seconds

        Returns:
            bool: Success status
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise ValueError("Data must be bytes, bytearray, or memoryview")
        expires_at = time.time() + (ttl or self.ttl)
        entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
        hashed_key = self._get_md5_hash(key)
        # Update memory cache
        self.memory_cache.set(hashed_key, entry)
        file_path = self._get_file_path(hashed_key)
        temp_path = file_path.with_suffix(".tmp")
        # Write to a temp file then rename so readers never observe partial data
        try:
            metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
            metadata_bytes = json.dumps(metadata).encode()
            metadata_size = len(metadata_bytes).to_bytes(8, "big")
            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(metadata_size)
                await f.write(metadata_bytes)
                await f.write(data)
            await aiofiles.os.rename(temp_path, file_path)
            return True
        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            # Best-effort cleanup; the temp file may never have been created
            try:
                await aiofiles.os.remove(temp_path)
            except OSError:
                pass
            return False

    async def delete(self, key: str) -> bool:
        """Delete item from both caches.

        Accepts the original (unhashed) key.  BUG FIX: previously this method
        did not hash the key while get()/set() did, so external callers
        (e.g. get_cached_speedtest) could never actually evict an entry.
        """
        hashed_key = self._get_md5_hash(key)
        self.memory_cache.remove(hashed_key)
        try:
            file_path = self._get_file_path(hashed_key)
            await aiofiles.os.remove(file_path)
            return True
        except FileNotFoundError:
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
class AsyncMemoryCache:
    """Async-friendly facade over the synchronous LRUMemoryCache."""

    def __init__(self, max_memory_size: int):
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """Return the cached bytes for ``key``, or ``default`` when absent/expired."""
        entry = self.memory_cache.get(key)
        if entry is None:
            return default
        return entry.data

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """Store ``data`` under ``key``; returns True on success."""
        try:
            effective_ttl = ttl or 3600  # default to one hour when no TTL is given
            now = time.time()
            self.memory_cache.set(
                key,
                CacheEntry(data=data, expires_at=now + effective_ttl, access_count=0, last_access=now, size=len(data)),
            )
        except Exception as e:
            logger.error(f"Error setting cache value: {e}")
            return False
        return True

    async def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; returns True on success."""
        try:
            self.memory_cache.remove(key)
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
        return True
# Create cache instances
# Shared module-level cache singletons.  TTLs and memory budgets are tuned
# per content type.  HybridCache instances also persist entries to files
# under a per-cache temp directory; the MPD cache is memory-only.
INIT_SEGMENT_CACHE = HybridCache(
    cache_dir_name="init_segment_cache",
    ttl=3600,  # 1 hour
    max_memory_size=500 * 1024 * 1024,  # 500MB for init segments
)
MPD_CACHE = AsyncMemoryCache(
    max_memory_size=100 * 1024 * 1024,  # 100MB for MPD files
)
SPEEDTEST_CACHE = HybridCache(
    cache_dir_name="speedtest_cache",
    ttl=3600,  # 1 hour
    max_memory_size=50 * 1024 * 1024,
)
EXTRACTOR_CACHE = HybridCache(
    cache_dir_name="extractor_cache",
    ttl=5 * 60,  # 5 minutes: extractor results go stale quickly
    max_memory_size=50 * 1024 * 1024,
)
# Specific cache implementations
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
    """Return the initialization segment for ``init_url``, downloading and caching on a miss."""
    cached = await INIT_SEGMENT_CACHE.get(init_url)
    if cached is not None:
        return cached
    # Cache miss: fetch the segment and store it for next time.
    try:
        content = await download_file_with_retry(init_url, headers)
    except Exception as e:
        logger.error(f"Error downloading init segment: {e}")
        return None
    if content:
        await INIT_SEGMENT_CACHE.set(init_url, content)
        return content
    return None
async def get_cached_mpd(
    mpd_url: str,
    headers: dict,
    parse_drm: bool,
    parse_segment_profile_id: Optional[str] = None,
) -> dict:
    """Get a parsed MPD manifest, using the cache when possible.

    Args:
        mpd_url: URL of the MPD manifest.
        headers: Headers to use when downloading.
        parse_drm: Whether to extract DRM information.
        parse_segment_profile_id: Optional profile ID to parse segments for.

    Returns:
        dict: The parsed MPD information.

    Raises:
        DownloadError: If the manifest cannot be downloaded.
    """
    # Try cache first
    cached_data = await MPD_CACHE.get(mpd_url)
    if cached_data is not None:
        try:
            mpd_dict = json.loads(cached_data)
            return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
        except json.JSONDecodeError:
            # Corrupt cache entry: drop it and fall through to re-download
            await MPD_CACHE.delete(mpd_url)
    # Download and parse if not cached
    try:
        mpd_content = await download_file_with_retry(mpd_url, headers)
        mpd_dict = parse_mpd(mpd_content)
        parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
        # Cache the raw MPD dict.  BUG FIX: "minimumUpdatePeriod" is only set
        # for live manifests; indexing it unconditionally raised KeyError for
        # static MPDs.  .get() yields None, which makes the cache fall back
        # to its default TTL.
        await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict.get("minimumUpdatePeriod"))
        return parsed_dict
    except DownloadError as error:
        logger.error(f"Error downloading MPD: {error}")
        raise error
    except Exception as error:
        logger.exception(f"Error processing MPD: {error}")
        raise error
async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
    """Return the cached SpeedTestTask for ``task_id``, or None when absent/invalid."""
    raw = await SPEEDTEST_CACHE.get(task_id)
    if raw is None:
        return None
    try:
        return SpeedTestTask.model_validate_json(raw.decode())
    except ValidationError as e:
        # Stored payload no longer matches the model: purge and miss.
        logger.error(f"Error parsing cached speed test data: {e}")
        await SPEEDTEST_CACHE.delete(task_id)
        return None
async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool:
    """Serialize and cache a SpeedTestTask; returns the success status."""
    try:
        payload = task.model_dump_json().encode()
        return await SPEEDTEST_CACHE.set(task_id, payload)
    except Exception as e:
        logger.error(f"Error caching speed test data: {e}")
        return False
async def get_cached_extractor_result(key: str) -> Optional[dict]:
    """Return the cached extractor result for ``key``, or None when absent/corrupt."""
    raw = await EXTRACTOR_CACHE.get(key)
    if raw is None:
        return None
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Corrupt entry: purge it so it gets regenerated on the next request.
        await EXTRACTOR_CACHE.delete(key)
        return None
async def set_cache_extractor_result(key: str, result: dict) -> bool:
    """Serialize and cache an extractor result; returns the success status."""
    try:
        payload = json.dumps(result).encode()
        return await EXTRACTOR_CACHE.set(key, payload)
    except Exception as e:
        logger.error(f"Error caching extractor result: {e}")
        return False

View File

@@ -0,0 +1,110 @@
import base64
import json
import logging
import time
import traceback
from typing import Optional
from urllib.parse import urlencode
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from mediaflow_proxy.configs import settings
class EncryptionHandler:
    """Encrypts/decrypts short-lived request tokens with AES-256-CBC."""

    def __init__(self, secret_key: str):
        # Derive a fixed 32-byte key by padding/truncating the configured password.
        self.secret_key = secret_key.encode("utf-8").ljust(32)[:32]

    def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str:
        """Serialize ``data`` to JSON and encrypt it into a URL-safe token.

        Args:
            data: Payload to protect (mutated: ``exp``/``ip`` claims are added).
            expiration: Optional lifetime in seconds from now.
            ip: Optional client IP to pin the token to.

        Returns:
            str: URL-safe base64 string of IV + ciphertext.
        """
        if expiration:
            data["exp"] = int(time.time()) + expiration
        if ip:
            data["ip"] = ip
        json_data = json.dumps(data).encode("utf-8")
        iv = get_random_bytes(16)
        cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
        encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
        return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")

    def decrypt_data(self, token: str, client_ip: str) -> dict:
        """Decrypt a token and validate its embedded expiration/IP claims.

        Raises:
            HTTPException: 401 for expired or undecryptable tokens,
                403 for a client-IP mismatch.
        """
        try:
            encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
            iv = encrypted_data[:16]
            cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
            decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
            data = json.loads(decrypted_data)
            if "exp" in data:
                if data["exp"] < time.time():
                    raise HTTPException(status_code=401, detail="Token has expired")
                del data["exp"]  # Remove expiration from the data
            if "ip" in data:
                if data["ip"] != client_ip:
                    raise HTTPException(status_code=403, detail="IP address mismatch")
                del data["ip"]  # Remove IP from the data
            return data
        except HTTPException:
            # BUG FIX: these were previously swallowed by the generic handler
            # below, converting a 403 IP mismatch into a generic 401.
            raise
        except Exception:
            raise HTTPException(status_code=401, detail="Invalid or expired token")
class EncryptionMiddleware(BaseHTTPMiddleware):
    """Transparently expands encrypted ``token`` query parameters.

    When a request carries a ``token`` query parameter and encryption is
    configured, the token is decrypted (validating any embedded expiry and
    client-IP claims) and its contents replace the request's query string
    before the route handlers run.
    """

    def __init__(self, app):
        super().__init__(app)
        # Module-level singleton; None when no api_password is configured,
        # in which case tokens are passed through untouched.
        self.encryption_handler = encryption_handler

    async def dispatch(self, request: Request, call_next):
        encrypted_token = request.query_params.get("token")
        if encrypted_token and self.encryption_handler:
            try:
                client_ip = self.get_client_ip(request)
                decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
                # Modify request query parameters with decrypted data
                query_params = dict(request.query_params)
                query_params.pop("token")  # Remove the encrypted token from query params
                query_params.update(decrypted_data)  # Add decrypted data to query params
                # Internal marker consumed downstream (e.g. by the m3u8
                # processor) to know follow-up URLs must be re-encrypted.
                query_params["has_encrypted"] = True
                # Create a new request scope with updated query parameters
                new_query_string = urlencode(query_params)
                request.scope["query_string"] = new_query_string.encode()
                # NOTE(review): relies on Starlette's private _query_params
                # attribute to refresh the cached property — confirm on upgrade.
                request._query_params = query_params
            except HTTPException as e:
                return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)
        try:
            response = await call_next(request)
        except Exception:
            # Last-resort guard: log the traceback, return an opaque 500.
            exc = traceback.format_exc(chain=False)
            logging.error("An error occurred while processing the request, error: %s", exc)
            return JSONResponse(
                content={"error": "An error occurred while processing the request, check the server for logs"},
                status_code=500,
            )
        return response

    @staticmethod
    def get_client_ip(request: Request) -> Optional[str]:
        """
        Extract the client's real IP address from the request headers or fallback to the client host.
        """
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # In some cases, this header can contain multiple IPs
            # separated by commas.
            # The first one is the original client's IP.
            return x_forwarded_for.split(",")[0].strip()
        # Fallback to X-Real-IP if X-Forwarded-For is not available
        x_real_ip = request.headers.get("X-Real-IP")
        if x_real_ip:
            return x_real_ip
        return request.client.host if request.client else "127.0.0.1"
# App-wide singleton; disabled (None) when no API password is configured.
encryption_handler = EncryptionHandler(settings.api_password) if settings.api_password else None

View File

@@ -0,0 +1,430 @@
import logging
import typing
from dataclasses import dataclass
from functools import partial
from urllib import parse
from urllib.parse import urlencode
import anyio
import httpx
import tenacity
from fastapi import Response
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.requests import Request
from starlette.types import Receive, Send, Scope
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm.asyncio import tqdm as tqdm_asyncio
from mediaflow_proxy.configs import settings
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
logger = logging.getLogger(__name__)
class DownloadError(Exception):
    """Raised when a download fails; carries an HTTP-like status code."""

    def __init__(self, status_code, message):
        super().__init__(message)
        # Keep both pieces accessible to callers that map errors to responses.
        self.status_code = status_code
        self.message = message
def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient:
    """Creates an HTTPX client with configured proxy routing"""
    # Proxy/transport mounts come from the application-level transport config.
    mounts = settings.transport_config.get_mounts()
    # kwargs are forwarded verbatim so callers can tweak limits, headers, etc.
    client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, timeout=timeout, **kwargs)
    return client
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(DownloadError),  # only DownloadError is retried
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
    """
    Fetches a URL with retry logic.

    Args:
        client (httpx.AsyncClient): The HTTP client to use for the request.
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to fetch.
        headers (dict): The headers to include in the request.
        follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    try:
        response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
        response.raise_for_status()
        return response
    except httpx.TimeoutException:
        logger.warning(f"Timeout while downloading {url}")
        # 409 is used here as an internal marker for retryable timeouts.
        raise DownloadError(409, f"Timeout while downloading {url}")
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
        if e.response.status_code == 404:
            # 404 is permanent: re-raise the original httpx error so the
            # retry decorator (which only retries DownloadError) gives up.
            logger.error(f"Segment Resource not found: {url}")
            raise e
        raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
    except Exception as e:
        # Unknown failure: propagate unchanged (not retried).
        logger.error(f"Error downloading {url}: {e}")
        raise
class Streamer:
    """Streams a remote HTTP resource chunk-by-chunk, with optional progress."""

    def __init__(self, client):
        """
        Initializes the Streamer with an HTTP client.

        Args:
            client (httpx.AsyncClient): The HTTP client to use for streaming.
        """
        self.client = client
        self.response = None  # set by create_streaming_response()/get_text()
        self.progress_bar = None  # tqdm instance when progress display is enabled
        self.bytes_transferred = 0  # running count of bytes yielded so far
        self.start_byte = 0  # first byte of the current range (0 for full body)
        self.end_byte = 0  # last byte of the current range
        self.total_size = 0  # total size reported by the server

    async def create_streaming_response(self, url: str, headers: dict):
        """
        Creates and sends a streaming request.

        Args:
            url (str): The URL to stream from.
            headers (dict): The headers to include in the request.
        """
        request = self.client.build_request("GET", url, headers=headers)
        # stream=True defers body download until stream_content() iterates it.
        self.response = await self.client.send(request, stream=True, follow_redirects=True)
        self.response.raise_for_status()

    async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Streams the content from the response.

        Yields raw byte chunks; requires create_streaming_response() first.
        """
        if not self.response:
            raise RuntimeError("No response available for streaming")
        try:
            self.parse_content_range()
            if settings.enable_streaming_progress:
                # Progress-bar path: identical streaming, plus tqdm bookkeeping.
                with tqdm_asyncio(
                    total=self.total_size,
                    initial=self.start_byte,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc="Streaming",
                    ncols=100,
                    mininterval=1,
                ) as self.progress_bar:
                    async for chunk in self.response.aiter_bytes():
                        yield chunk
                        chunk_size = len(chunk)
                        self.bytes_transferred += chunk_size
                        self.progress_bar.set_postfix_str(
                            f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
                        )
                        self.progress_bar.update(chunk_size)
            else:
                async for chunk in self.response.aiter_bytes():
                    yield chunk
                    self.bytes_transferred += len(chunk)
        except httpx.TimeoutException:
            logger.warning("Timeout while streaming")
            raise DownloadError(409, "Timeout while streaming")
        except GeneratorExit:
            # Raised when the consumer stops iterating (client disconnect);
            # deliberately swallowed so teardown is quiet.
            logger.info("Streaming session stopped by the user")
        except Exception as e:
            logger.error(f"Error streaming content: {e}")
            raise

    @staticmethod
    def format_bytes(size) -> str:
        """Render a byte count as a human-readable string (e.g. '1.50 MB')."""
        power = 2**10
        n = 0
        units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
        while size > power:
            size /= power
            n += 1
        return f"{size:.2f} {units[n]}"

    def parse_content_range(self):
        """Populate start/end/total from Content-Range, else Content-Length."""
        # "Content-Range: bytes <start>-<end>/<total>": take the last token,
        # then split on "-" (after turning "/" into "-") to get the three ints.
        content_range = self.response.headers.get("Content-Range", "")
        if content_range:
            range_info = content_range.split()[-1]
            self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-"))
        else:
            self.start_byte = 0
            self.total_size = int(self.response.headers.get("Content-Length", 0))
            self.end_byte = self.total_size - 1 if self.total_size > 0 else 0

    async def get_text(self, url: str, headers: dict):
        """
        Sends a GET request to a URL and returns the response text.

        Args:
            url (str): The URL to send the GET request to.
            headers (dict): The headers to include in the request.

        Returns:
            str: The response text.
        """
        try:
            self.response = await fetch_with_retry(self.client, "GET", url, headers)
        except tenacity.RetryError as e:
            # Surface the underlying exception from the final attempt.
            raise e.last_attempt.result()
        return self.response.text

    async def close(self):
        """
        Closes the HTTP client and response.
        """
        if self.response:
            await self.response.aclose()
        if self.progress_bar:
            self.progress_bar.close()
        await self.client.aclose()
async def download_file_with_retry(url: str, headers: dict):
    """
    Download ``url`` fully into memory, retrying transient failures.

    Args:
        url (str): The URL of the file to download.
        headers (dict): The headers to include in the request.

    Returns:
        bytes: The downloaded file content.

    Raises:
        DownloadError: If the download fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            response = await fetch_with_retry(client, "GET", url, headers)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise e
        except tenacity.RetryError as e:
            # All attempts exhausted: surface the last failure as a 502.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
        return response.content
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
    """
    Send an HTTP request with retry logic using a fresh client.

    Args:
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to send the request to.
        headers (dict): The headers to include in the request.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            return await fetch_with_retry(client, method, url, headers, **kwargs)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise
def encode_mediaflow_proxy_url(
    mediaflow_proxy_url: str,
    endpoint: typing.Optional[str] = None,
    destination_url: typing.Optional[str] = None,
    query_params: typing.Optional[dict] = None,
    request_headers: typing.Optional[dict] = None,
    response_headers: typing.Optional[dict] = None,
    encryption_handler: EncryptionHandler = None,
    expiration: int = None,
    ip: str = None,
) -> str:
    """
    Build (and optionally encrypt) a MediaFlow proxy URL.

    Args:
        mediaflow_proxy_url (str): The base MediaFlow proxy URL.
        endpoint (str, optional): Endpoint appended to the base URL. Defaults to None.
        destination_url (str, optional): Destination URL placed in the "d" parameter. Defaults to None.
        query_params (dict, optional): Additional query parameters. Defaults to None.
        request_headers (dict, optional): Headers encoded as "h_"-prefixed parameters. Defaults to None.
        response_headers (dict, optional): Headers encoded as "r_"-prefixed parameters. Defaults to None.
        encryption_handler (EncryptionHandler, optional): When given, all parameters
            are folded into a single encrypted "token" parameter. Defaults to None.
        expiration (int, optional): Lifetime in seconds for the encrypted token. Defaults to None.
        ip (str, optional): Client IP to pin the encrypted token to. Defaults to None.

    Returns:
        str: The encoded MediaFlow proxy URL.
    """
    query_params = query_params or {}
    if destination_url is not None:
        query_params["d"] = destination_url

    def _with_prefix(headers: dict, prefix: str) -> dict:
        # Leave already-prefixed keys untouched; prefix the rest.
        return {name if name.startswith(prefix) else f"{prefix}{name}": value for name, value in headers.items()}

    if request_headers:
        query_params.update(_with_prefix(request_headers, "h_"))
    if response_headers:
        query_params.update(_with_prefix(response_headers, "r_"))

    if encryption_handler:
        token = encryption_handler.encrypt_data(query_params, expiration, ip)
        encoded_params = urlencode({"token": token})
    else:
        encoded_params = urlencode(query_params)

    # Construct the full URL
    base_url = mediaflow_proxy_url if endpoint is None else parse.urljoin(mediaflow_proxy_url, endpoint)
    return f"{base_url}?{encoded_params}"
def get_original_scheme(request: "Request") -> str:
    """
    Determine the original scheme (http or https) of a possibly-proxied request.

    Checks the standard ``X-Forwarded-Proto`` header first, then the
    request's own scheme and a few common SSL-indicating headers.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        str: The original scheme ('http' or 'https')
    """
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    if forwarded_proto:
        # Multiple proxies may append to this header, producing a
        # comma-separated chain; the first entry is the client-facing scheme.
        return forwarded_proto.split(",")[0].strip()
    # A single combined check; the original code tested X-Forwarded-Ssl twice.
    if (
        request.url.scheme == "https"
        or request.headers.get("X-Forwarded-Ssl") == "on"
        or request.headers.get("X-Forwarded-Protocol") == "https"
        or request.headers.get("X-Url-Scheme") == "https"
    ):
        return "https"
    # Default to http if no indicators of https are found
    return "http"
@dataclass
class ProxyRequestHeaders:
    """Header sets extracted from a proxied request."""

    # Headers to forward with the upstream request.
    request: dict
    # Headers to inject into the downstream response.
    response: dict
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
    """
    Build request/response header sets from a proxied request.

    Forwarded request headers come from the incoming headers (whitelisted
    by SUPPORTED_REQUEST_HEADERS) plus any "h_"-prefixed query parameters;
    response headers come from "r_"-prefixed query parameters.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        ProxyRequestHeaders: The request headers and response headers.
    """
    request_headers = {}
    for name, value in request.headers.items():
        if name in SUPPORTED_REQUEST_HEADERS:
            request_headers[name] = value
    response_headers = {}
    for name, value in request.query_params.items():
        if name.startswith("h_"):
            # Query-parameter overrides win over incoming headers.
            request_headers[name[2:].lower()] = value
        elif name.startswith("r_"):
            response_headers[name[2:].lower()] = value
    return ProxyRequestHeaders(request_headers, response_headers)
class EnhancedStreamingResponse(Response):
    """Streaming response that stops cleanly when the client disconnects.

    A dedicated task listens for the ASGI ``http.disconnect`` event and
    cancels the body stream; broken-pipe style errors while sending are
    logged and swallowed instead of propagating.
    """

    body_iterator: typing.AsyncIterable[typing.Any]

    def __init__(
        self,
        content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        if isinstance(content, typing.AsyncIterable):
            self.body_iterator = content
        else:
            # Drive synchronous iterables in a thread pool so they don't block the loop.
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    @staticmethod
    async def listen_for_disconnect(receive: Receive) -> None:
        # Poll the ASGI receive channel until the client goes away.
        try:
            while True:
                message = await receive()
                if message["type"] == "http.disconnect":
                    logger.debug("Client disconnected")
                    break
        except Exception as e:
            logger.error(f"Error in listen_for_disconnect: {str(e)}")

    async def stream_response(self, send: Send) -> None:
        # Emit the response start, then forward chunks as ASGI body messages.
        try:
            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            async for chunk in self.body_iterator:
                if not isinstance(chunk, (bytes, memoryview)):
                    chunk = chunk.encode(self.charset)
                try:
                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
                except (ConnectionResetError, anyio.BrokenResourceError):
                    # Client went away mid-stream; stop quietly.
                    logger.info("Client disconnected during streaming")
                    return
            # Terminate the response with an empty final body message.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        except Exception as e:
            logger.exception(f"Error in stream_response: {str(e)}")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Race the body stream against a disconnect listener; whichever
        # finishes first cancels the other via the shared cancel scope.
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                try:
                    await func()
                except Exception as e:
                    if not isinstance(e, anyio.get_cancelled_exc_class()):
                        logger.exception("Error in streaming task")
                        raise
                finally:
                    task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))
        if self.background is not None:
            await self.background()

View File

@@ -0,0 +1,87 @@
import re
from urllib import parse
from mediaflow_proxy.utils.crypto_utils import encryption_handler
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
class M3U8Processor:
    """Rewrites m3u8 playlists so every URL is routed through the proxy."""

    def __init__(self, request, key_url: str = None):
        """
        Initializes the M3U8Processor with the request and URL prefix.

        Args:
            request (Request): The incoming HTTP request.
            key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
        """
        self.request = request
        self.key_url = parse.urlparse(key_url) if key_url else None
        # Base proxy endpoint, preserving the scheme the client originally used.
        self.mediaflow_proxy_url = str(
            request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
        )

    async def process_m3u8(self, content: str, base_url: str) -> str:
        """
        Processes the m3u8 content, proxying URLs and handling key lines.

        Args:
            content (str): The m3u8 content to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed m3u8 content.
        """
        lines = content.splitlines()
        processed_lines = []
        for line in lines:
            if "URI=" in line:
                # Tag lines carrying a URI attribute (e.g. #EXT-X-KEY).
                processed_lines.append(await self.process_key_line(line, base_url))
            elif not line.startswith("#") and line.strip():
                # Plain segment/playlist URL lines.
                processed_lines.append(await self.proxy_url(line, base_url))
            else:
                processed_lines.append(line)
        return "\n".join(processed_lines)

    async def process_key_line(self, line: str, base_url: str) -> str:
        """
        Processes a key line in the m3u8 content, proxying the URI.

        Args:
            line (str): The key line to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed key line.
        """
        uri_match = re.search(r'URI="([^"]+)"', line)
        if uri_match:
            original_uri = uri_match.group(1)
            uri = parse.urlparse(original_uri)
            if self.key_url:
                # Redirect key requests to the configured key server.
                uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
            new_uri = await self.proxy_url(uri.geturl(), base_url)
            line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
        return line

    async def proxy_url(self, url: str, base_url: str) -> str:
        """
        Proxies a URL, encoding it with the MediaFlow proxy URL.

        Args:
            url (str): The URL to proxy.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The proxied URL.
        """
        full_url = parse.urljoin(base_url, url)
        query_params = dict(self.request.query_params)
        # "has_encrypted" is an internal marker set by the encryption
        # middleware and must not be forwarded as a regular parameter.
        has_encrypted = query_params.pop("has_encrypted", False)
        return encode_mediaflow_proxy_url(
            self.mediaflow_proxy_url,
            "",
            full_url,
            # BUG FIX: pass the dict with "has_encrypted" stripped; previously
            # the raw query params (marker included) were re-read here, leaking
            # the marker into every proxied URL.
            query_params=query_params,
            encryption_handler=encryption_handler if has_encrypted else None,
        )

View File

@@ -0,0 +1,555 @@
import logging
import math
import re
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Optional, Union
from urllib.parse import urljoin
import xmltodict
logger = logging.getLogger(__name__)
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
    """
    Parses the MPD content into a dictionary.

    Args:
        mpd_content (Union[str, bytes]): The MPD content to parse.

    Returns:
        dict: The parsed MPD content as a dictionary.
    """
    # xmltodict mirrors the XML tree: attributes are keyed with an "@" prefix
    # and repeated elements become lists (single elements remain plain dicts).
    return xmltodict.parse(mpd_content)
def parse_mpd_dict(
    mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: Optional[str] = None
) -> dict:
    """
    Parses the MPD dictionary and extracts relevant information.

    Args:
        mpd_dict (dict): The MPD content as a dictionary.
        mpd_url (str): The URL of the MPD manifest.
        parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict: The parsed MPD information including profiles and DRM info.

    This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
    It handles both live and static MPD manifests.
    """
    profiles = []
    parsed_dict = {}
    source = "/".join(mpd_url.split("/")[:-1])
    is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
    parsed_dict["isLive"] = is_live
    media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration")
    # Parse additional MPD attributes for live streams
    if is_live:
        parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S"))
        parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M"))
        parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
            mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00")
        )
        # BUG FIX: @publishTime is optional; the previous code defaulted to ""
        # and datetime.fromisoformat("") raises ValueError.
        publish_time = mpd_dict["MPD"].get("@publishTime")
        if publish_time:
            parsed_dict["publishTime"] = datetime.fromisoformat(publish_time.replace("Z", "+00:00"))
    periods = mpd_dict["MPD"]["Period"]
    periods = periods if isinstance(periods, list) else [periods]
    for period in periods:
        parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
        # BUG FIX: xmltodict yields a bare dict for a single AdaptationSet;
        # iterating a dict would walk its keys, so normalize to a list
        # (matching how Representation is handled below).
        adaptation_sets = period["AdaptationSet"]
        adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
        for adaptation in adaptation_sets:
            representations = adaptation["Representation"]
            representations = representations if isinstance(representations, list) else [representations]
            for representation in representations:
                profile = parse_representation(
                    parsed_dict,
                    representation,
                    adaptation,
                    source,
                    media_presentation_duration,
                    parse_segment_profile_id,
                )
                if profile:
                    profiles.append(profile)
    parsed_dict["profiles"] = profiles
    if parse_drm:
        drm_info = extract_drm_info(periods, mpd_url)
    else:
        drm_info = {}
    parsed_dict["drmInfo"] = drm_info
    return parsed_dict
def pad_base64(encoded_key_id):
    """
    Pads a base64 encoded key ID so its length is a multiple of 4.

    Args:
        encoded_key_id (str): The base64 encoded key ID, possibly missing padding.

    Returns:
        str: The padded base64 encoded key ID (unchanged when already aligned).
    """
    # (-len) % 4 yields 0 for an already-aligned input. The previous formula,
    # 4 - len % 4, appended four spurious "=" in that case, which makes strict
    # base64 decoders raise "invalid padding".
    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
    """
    Collects DRM information from every period of the MPD.

    Args:
        periods (List[Dict]): The parsed MPD periods.
        mpd_url (str): URL of the MPD manifest, used to absolutize a relative
            license acquisition URL.

    Returns:
        Dict: DRM details (system, key id, pssh, license URL) together with an
        ``isDrmProtected`` flag.
    """
    drm_info = {"isDrmProtected": False}
    for period in periods:
        adaptation_sets: Union[list[dict], dict] = period.get("AdaptationSet", [])
        adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
        for adaptation_set in adaptation_sets:
            # ContentProtection may appear on the AdaptationSet itself...
            process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)
            # ...and/or on each Representation inside it.
            reps: Union[list[dict], dict] = adaptation_set.get("Representation", [])
            reps = reps if isinstance(reps, list) else [reps]
            for rep in reps:
                process_content_protection(rep.get("ContentProtection", []), drm_info)
    # A relative license URL is resolved against the manifest location.
    if "laUrl" in drm_info and not drm_info["laUrl"].startswith(("http://", "https://")):
        drm_info["laUrl"] = urljoin(mpd_url, drm_info["laUrl"])
    return drm_info
def process_content_protection(content_protection: Union[list[dict], dict], drm_info: dict):
    """
    Folds one or more ContentProtection elements into ``drm_info``.

    Args:
        content_protection (Union[list[dict], dict]): A single ContentProtection
            element or a list of them.
        drm_info (dict): Mutable accumulator for the discovered DRM data.

    The first ``laUrl``/``keyId`` encountered wins; ``drmSystem`` reflects the
    last recognized scheme. Returns ``drm_info`` for convenience.
    """
    protections = content_protection if isinstance(content_protection, list) else [content_protection]
    for protection in protections:
        # The mere presence of a ContentProtection element marks the stream as protected.
        drm_info["isDrmProtected"] = True
        scheme = protection.get("@schemeIdUri", "").lower()
        if "clearkey" in scheme:
            drm_info["drmSystem"] = "clearkey"
            if "clearkey:Laurl" in protection:
                candidate_url = protection["clearkey:Laurl"].get("#text")
                if candidate_url:
                    drm_info.setdefault("laUrl", candidate_url)
        elif "widevine" in scheme or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme:
            drm_info["drmSystem"] = "widevine"
            pssh = protection.get("cenc:pssh", {}).get("#text")
            if pssh:
                drm_info["pssh"] = pssh
        elif "playready" in scheme or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme:
            drm_info["drmSystem"] = "playready"
        # The default KID and ms:laurl attributes apply regardless of scheme.
        if "@cenc:default_KID" in protection:
            drm_info.setdefault("keyId", protection["@cenc:default_KID"].replace("-", ""))
        if "ms:laurl" in protection:
            candidate_url = protection["ms:laurl"].get("@licenseUrl")
            if candidate_url:
                drm_info.setdefault("laUrl", candidate_url)
    return drm_info
def parse_representation(
    parsed_dict: dict,
    representation: dict,
    adaptation: dict,
    source: str,
    media_presentation_duration: str,
    parse_segment_profile_id: Optional[str],
) -> Optional[dict]:
    """
    Parses a representation and extracts profile information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        representation (dict): The representation data.
        adaptation (dict): The adaptation set data.
        source (str): The source URL.
        media_presentation_duration (str): The media presentation duration.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        Optional[dict]: The parsed profile information, or None for non audio/video content.
    """
    # When @mimeType is absent, guess it from the codec string. The codec may
    # live on either the Representation or the AdaptationSet and may be missing
    # entirely, so default to "" instead of indexing representation["@codecs"]
    # directly (which raised KeyError for adaptation-level codecs).
    codecs = _get_key(adaptation, representation, "@codecs") or ""
    mime_type = _get_key(adaptation, representation, "@mimeType") or (
        "video/mp4" if "avc" in codecs else "audio/mp4"
    )
    if "video" not in mime_type and "audio" not in mime_type:
        return None
    profile = {
        "id": representation.get("@id") or adaptation.get("@id"),
        "mimeType": mime_type,
        "lang": representation.get("@lang") or adaptation.get("@lang"),
        "codecs": representation.get("@codecs") or adaptation.get("@codecs"),
        "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
        "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
        "mediaPresentationDuration": media_presentation_duration,
    }
    if "audio" in profile["mimeType"]:
        profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
        profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
    else:
        profile["width"] = int(representation["@width"])
        profile["height"] = int(representation["@height"])
        frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
        frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1"
        numerator, denominator = frame_rate.split("/")
        profile["frameRate"] = round(int(numerator) / int(denominator), 3)
        profile["sar"] = representation.get("@sar", "1:1")
    # Segment expansion is expensive; only do it for the explicitly requested profile.
    if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
        return profile
    item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
    if item:
        profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
    else:
        profile["segments"] = parse_segment_base(representation, source)
    return profile
def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
"""
Retrieves a key from the representation or adaptation set.
Args:
adaptation (dict): The adaptation set data.
representation (dict): The representation data.
key (str): The key to retrieve.
Returns:
Optional[str]: The value of the key or None if not found.
"""
return representation.get(key, adaptation.get(key, None))
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
    """
    Expands a SegmentTemplate element into a list of segment descriptors.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The SegmentTemplate element.
        profile (dict): The profile information (mutated: gains "initUrl").
        source (str): Base URL used to absolutize relative segment paths.

    Returns:
        List[Dict]: The parsed segments (empty when neither a SegmentTimeline
        nor an @duration is present).
    """
    timescale = int(item.get("@timescale", 1))
    # Resolve the initialization segment URL, substituting template variables.
    if "@initialization" in item:
        init_url = (
            item["@initialization"]
            .replace("$RepresentationID$", profile["id"])
            .replace("$Bandwidth$", str(profile["bandwidth"]))
        )
        if not init_url.startswith("http"):
            init_url = f"{source}/{init_url}"
        profile["initUrl"] = init_url
    # An explicit SegmentTimeline takes precedence over a fixed @duration.
    if "SegmentTimeline" in item:
        return list(parse_segment_timeline(parsed_dict, item, profile, source, timescale))
    if "@duration" in item:
        return list(parse_segment_duration(parsed_dict, item, profile, source, timescale))
    return []
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Expands a SegmentTimeline into concrete segment descriptors.

    Args:
        parsed_dict (dict): The parsed MPD data. Must carry
            "availabilityStartTime" — NOTE(review): that key is only set for
            live manifests; confirm static manifests never reach this path.
        item (dict): The SegmentTemplate element holding the timeline.
        profile (dict): The profile information.
        source (str): Base URL for relative segment paths.
        timescale (int): Ticks per second for the timeline values.

    Returns:
        List[Dict]: The parsed segments.
    """
    raw_entries = item["SegmentTimeline"]["S"]
    if not isinstance(raw_entries, list):
        raw_entries = [raw_entries]
    period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
    presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
    first_number = int(item.get("@startNumber", 1))
    expanded = preprocess_timeline(raw_entries, first_number, period_start, presentation_time_offset, timescale)
    return [create_segment_data(entry, item, profile, source, timescale) for entry in expanded]
def preprocess_timeline(
    timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
    """
    Flattens SegmentTimeline "S" entries (honoring @r repeats) into one record
    per segment.

    Args:
        timelines (List[Dict]): The "S" entries of the timeline.
        start_number (int): Number assigned to the first segment.
        period_start (datetime): Wall-clock start of the period.
        presentation_time_offset (int): Offset in timescale ticks subtracted from @t.
        timescale (int): Ticks per second.

    Returns:
        List[Dict]: One dict per segment with its number, wall-clock
        start/end times, duration (ticks) and start time (ticks).

    NOTE(review): @r == -1 ("repeat until period end" in DASH) is treated as
    zero repeats here — confirm upstream manifests never use it.
    """
    flattened = []
    next_tick = 0
    counter = start_number
    for entry in timelines:
        duration_ticks = int(entry["@d"])
        # @t resets the clock; otherwise segments are contiguous.
        tick = int(entry.get("@t", next_tick))
        for _ in range(int(entry.get("@r", 0)) + 1):
            begins = period_start + timedelta(seconds=(tick - presentation_time_offset) / timescale)
            flattened.append(
                {
                    "number": counter,
                    "start_time": begins,
                    "end_time": begins + timedelta(seconds=duration_ticks / timescale),
                    "duration": duration_ticks,
                    "time": tick,
                }
            )
            tick += duration_ticks
            counter += 1
        next_tick = tick
    return flattened
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Builds segments from a fixed-@duration SegmentTemplate (no timeline).
    Works for both live and static manifests.

    Args:
        parsed_dict (dict): The parsed MPD data (reads "isLive").
        item (dict): The SegmentTemplate element with "@duration".
        profile (dict): The profile information.
        source (str): Base URL for relative segment paths.
        timescale (int): Ticks per second.

    Returns:
        List[Dict]: The parsed segments.
    """
    duration_ticks = int(item["@duration"])
    first_number = int(item.get("@startNumber", 1))
    if parsed_dict["isLive"]:
        # Live: derive the sliding window of currently available segments.
        raw_segments = generate_live_segments(parsed_dict, duration_ticks / timescale, first_number)
    else:
        # VOD: the whole presentation is available up front.
        raw_segments = generate_vod_segments(profile, duration_ticks, timescale, first_number)
    return [create_segment_data(seg, item, profile, source, timescale) for seg in raw_segments]
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
    """
    Computes the sliding window of segments currently available on a live stream.

    Args:
        parsed_dict (dict): Parsed MPD data; needs a tz-aware
            "availabilityStartTime" and optionally "timeShiftBufferDepth"
            in seconds (defaults to 60).
        segment_duration_sec (float): Duration of one segment in seconds.
        start_number (int): Number of the very first segment of the stream.

    Returns:
        List[Dict]: Segment records (number, wall-clock start_time, duration)
        spanning the time-shift buffer up to "now".
    """
    buffer_seconds = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60)).total_seconds()
    window_size = math.ceil(buffer_seconds / segment_duration_sec)
    availability_start = parsed_dict["availabilityStartTime"]
    elapsed = (datetime.now(tz=timezone.utc) - availability_start).total_seconds()
    # Oldest segment still inside the buffer, clamped so we never go below the
    # stream's first segment number.
    first_available = max(
        start_number + math.floor(elapsed / segment_duration_sec) - window_size,
        start_number,
    )
    segments = []
    for number in range(first_available, first_available + window_size):
        segments.append(
            {
                "number": number,
                "start_time": availability_start
                + timedelta(seconds=(number - start_number) * segment_duration_sec),
                "duration": segment_duration_sec,
            }
        )
    return segments
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
    """
    Builds the full segment list for a static (VOD) manifest.

    Args:
        profile (dict): Profile info; "mediaPresentationDuration" may be an
            ISO 8601 string, a number of seconds, or absent (treated as 0).
        duration (int): Segment duration in timescale ticks.
        timescale (int): Ticks per second.
        start_number (int): Number of the first segment.

    Returns:
        List[Dict]: One record per segment with its number and duration in seconds.
    """
    total = profile.get("mediaPresentationDuration") or 0
    if isinstance(total, str):
        total = parse_duration(total)
    count = math.ceil(total * timescale / duration)
    seconds_each = duration / timescale
    return [{"number": start_number + idx, "duration": seconds_each} for idx in range(count)]
def _format_program_date_time(moment) -> str:
    """Render a datetime as RFC 3339 with a single 'Z' for UTC."""
    iso = moment.isoformat()
    if iso.endswith("+00:00"):
        # Aware UTC datetimes already carry "+00:00"; normalize to "Z" instead
        # of the malformed "...+00:00Z" the old code produced.
        return iso[:-6] + "Z"
    if moment.tzinfo is None:
        # Naive datetimes are assumed UTC, matching the original behavior.
        return iso + "Z"
    return iso


def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
    """
    Builds the final segment descriptor (media URL plus timing metadata).

    Args:
        segment (Dict): Raw segment record ("number", optional "time",
            "start_time", "end_time", "duration").
        item (dict): SegmentTemplate element providing the "@media" URL template.
        profile (dict): Profile info ("id", "bandwidth") for template substitution.
        source (str): Base URL for relative media paths.
        timescale (int, optional): Kept for interface compatibility; $Time$
            values are already expressed in timescale ticks. Defaults to None.

    Returns:
        Dict: Descriptor with "media" URL, segment "number" and, when timing is
        known, start/end times, EXTINF duration and a program date-time string.
    """
    media = item["@media"].replace("$RepresentationID$", profile["id"])
    # DASH allows width-formatted numbers, e.g. $Number%05d$. Expand any width
    # instead of special-casing only %04d as before.
    media = re.sub(r"\$Number%0(\d+)d\$", lambda m: f"{segment['number']:0{int(m.group(1))}d}", media)
    media = media.replace("$Number$", str(segment["number"]))
    media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
    if "time" in segment:
        # $Time$ must carry the timeline's @t value, which is already in
        # timescale ticks; the old code multiplied by timescale a second time.
        media = media.replace("$Time$", str(int(segment["time"])))
    if not media.startswith("http"):
        media = f"{source}/{media}"
    segment_data = {
        "type": "segment",
        "media": media,
        "number": segment["number"],
    }
    if "start_time" in segment and "end_time" in segment:
        start, end = segment["start_time"], segment["end_time"]
        segment_data.update(
            {
                "start_time": start,
                "end_time": end,
                "extinf": (end - start).total_seconds(),
                "program_date_time": _format_program_date_time(start),
            }
        )
    elif "start_time" in segment and "duration" in segment:
        start = segment["start_time"]
        duration = segment["duration"]
        segment_data.update(
            {
                "start_time": start,
                "end_time": start + timedelta(seconds=duration),
                "extinf": duration,
                "program_date_time": _format_program_date_time(start),
            }
        )
    elif "duration" in segment:
        segment_data["extinf"] = segment["duration"]
    return segment_data
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
    """
    Builds the single-segment descriptor for a SegmentBase representation.

    Args:
        representation (dict): Representation holding "SegmentBase" and "BaseURL".
        source (str): Base URL of the manifest directory.

    Returns:
        List[Dict]: One segment whose byte "range" spans from the start of the
        Initialization range (when present) to the end of the index range.
    """
    segment_base = representation["SegmentBase"]
    index_start, index_end = (int(part) for part in segment_base["@indexRange"].split("-"))
    range_start = index_start
    if "Initialization" in segment_base:
        # Widen the range so it also covers the initialization bytes.
        range_start = int(segment_base["Initialization"]["@range"].split("-")[0])
    return [
        {
            "type": "segment",
            "range": f"{range_start}-{index_end}",
            "media": f"{source}/{representation['BaseURL']}",
        }
    ]
def parse_duration(duration_str: str) -> float:
    """
    Parses an ISO 8601 duration string into seconds.

    Args:
        duration_str (str): The duration string (e.g. "PT1H30M", "P1DT2S", "P1W").

    Returns:
        float: The duration in seconds. Years and months use the 365-day /
        30-day approximations the rest of this module expects.

    Raises:
        ValueError: If the string does not start with a parsable "P..." form.
    """
    # The week designator (e.g. "P1W") is part of ISO 8601; the previous
    # pattern silently parsed such strings as 0 seconds.
    pattern = re.compile(
        r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)W)?(?:(\d+)D)?T?"
        r"(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?"
    )
    match = pattern.match(duration_str)
    if not match:
        raise ValueError(f"Invalid duration format: {duration_str}")
    years, months, weeks, days, hours, minutes, seconds = (float(g) if g else 0.0 for g in match.groups())
    return (
        years * 365 * 24 * 3600
        + months * 30 * 24 * 3600
        + weeks * 7 * 24 * 3600
        + days * 24 * 3600
        + hours * 3600
        + minutes * 60
        + seconds
    )