Updated to newest version

This commit is contained in:
UrloMythus
2025-04-29 18:52:23 +02:00
parent c54be91e39
commit 323ca2d1b6
11 changed files with 358 additions and 237 deletions

View File

@@ -325,7 +325,7 @@ async def get_cached_mpd(
parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
# Cache the original MPD dict
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict["minimumUpdatePeriod"])
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict.get("minimumUpdatePeriod"))
return parsed_dict
except DownloadError as error:
logger.error(f"Error downloading MPD: {error}")

View File

@@ -29,11 +29,13 @@ class EncryptionHandler:
iv = get_random_bytes(16)
cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")
return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8").rstrip("=")
def decrypt_data(self, token: str, client_ip: str) -> dict:
try:
encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
padding_needed = (4 - len(token) % 4) % 4
encrypted_token_b64_padded = token + ("=" * padding_needed)
encrypted_data = base64.urlsafe_b64decode(encrypted_token_b64_padded.encode("utf-8"))
iv = encrypted_data[:16]
cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
@@ -60,14 +62,55 @@ class EncryptionMiddleware(BaseHTTPMiddleware):
self.encryption_handler = encryption_handler
async def dispatch(self, request: Request, call_next):
encrypted_token = request.query_params.get("token")
path = request.url.path
token_marker = "/_token_"
encrypted_token = None
# Check for token in path
if token_marker in path and self.encryption_handler:
try:
# Extract token from path
token_start = path.find(token_marker) + len(token_marker)
token_end = path.find("/", token_start)
if token_end == -1: # No trailing slash (no filename after token)
token_end = len(path)
filename_part = ""
else:
# There's something after the token (likely a filename)
filename_part = path[token_end:]
# Get the encrypted token
encrypted_token = path[token_start:token_end]
# Modify the path to remove the token part but preserve the filename
original_path = path[: path.find(token_marker)]
original_path += filename_part # Add back the filename part
request.scope["path"] = original_path
# Update the raw path as well
request.scope["raw_path"] = original_path.encode()
except Exception as e:
logging.error(f"Error processing token in path: {str(e)}")
return JSONResponse(content={"error": f"Invalid token in path: {str(e)}"}, status_code=400)
# Check for token in query parameters (original method)
if not encrypted_token: # Only check if we didn't already find a token in the path
encrypted_token = request.query_params.get("token")
# Process the token if found (from either source)
if encrypted_token and self.encryption_handler:
try:
client_ip = self.get_client_ip(request)
decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
# Modify request query parameters with decrypted data
query_params = dict(request.query_params)
query_params.pop("token") # Remove the encrypted token from query params
if "token" in query_params:
query_params.pop("token") # Remove the encrypted token from query params
query_params.update(decrypted_data) # Add decrypted data to query params
query_params["has_encrypted"] = True
@@ -75,8 +118,12 @@ class EncryptionMiddleware(BaseHTTPMiddleware):
new_query_string = urlencode(query_params)
request.scope["query_string"] = new_query_string.encode()
request._query_params = query_params
except HTTPException as e:
return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)
except Exception as e:
logging.error(f"Error decrypting token: {str(e)}")
return JSONResponse(content={"error": f"Invalid token: {str(e)}"}, status_code=400)
try:
response = await call_next(request)

View File

@@ -30,10 +30,11 @@ class DownloadError(Exception):
super().__init__(message)
def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient:
def create_httpx_client(follow_redirects: bool = True, **kwargs) -> httpx.AsyncClient:
"""Creates an HTTPX client with configured proxy routing"""
mounts = settings.transport_config.get_mounts()
client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, timeout=timeout, **kwargs)
kwargs.setdefault("timeout", settings.transport_config.timeout)
client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, **kwargs)
return client
@@ -94,6 +95,11 @@ class Streamer:
self.end_byte = 0
self.total_size = 0
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(DownloadError),
)
async def create_streaming_response(self, url: str, headers: dict):
"""
Creates and sends a streaming request.
@@ -103,9 +109,27 @@ class Streamer:
headers (dict): The headers to include in the request.
"""
request = self.client.build_request("GET", url, headers=headers)
self.response = await self.client.send(request, stream=True, follow_redirects=True)
self.response.raise_for_status()
try:
request = self.client.build_request("GET", url, headers=headers)
self.response = await self.client.send(request, stream=True, follow_redirects=True)
self.response.raise_for_status()
except httpx.TimeoutException:
logger.warning("Timeout while creating streaming response")
raise DownloadError(409, "Timeout while creating streaming response")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error {e.response.status_code} while creating streaming response")
if e.response.status_code == 404:
logger.error(f"Segment Resource not found: {url}")
raise e
raise DownloadError(
e.response.status_code, f"HTTP error {e.response.status_code} while creating streaming response"
)
except httpx.RequestError as e:
logger.error(f"Error creating streaming response: {e}")
raise DownloadError(502, f"Error creating streaming response: {e}")
except Exception as e:
logger.error(f"Error creating streaming response: {e}")
raise RuntimeError(f"Error creating streaming response: {e}")
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
"""
@@ -258,6 +282,7 @@ def encode_mediaflow_proxy_url(
encryption_handler: EncryptionHandler = None,
expiration: int = None,
ip: str = None,
filename: typing.Optional[str] = None,
) -> str:
"""
Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.
@@ -272,10 +297,12 @@ def encode_mediaflow_proxy_url(
encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
ip (str, optional): The public IP address to include in the query parameters. Defaults to None.
filename (str, optional): Filename to be preserved for media players like Infuse. Defaults to None.
Returns:
str: The encoded MediaFlow proxy URL.
"""
# Prepare query parameters
query_params = query_params or {}
if destination_url is not None:
query_params["d"] = destination_url
@@ -290,18 +317,37 @@ def encode_mediaflow_proxy_url(
{key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
)
# Construct the base URL
if endpoint is None:
base_url = mediaflow_proxy_url
else:
base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
# Ensure base_url doesn't end with a slash for consistent handling
if base_url.endswith("/"):
base_url = base_url[:-1]
# Handle encryption if needed
if encryption_handler:
encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip)
encoded_params = urlencode({"token": encrypted_token})
# Build the URL with token in path
path_parts = [base_url, f"_token_{encrypted_token}"]
# Add filename at the end if provided
if filename:
path_parts.append(parse.quote(filename))
return "/".join(path_parts)
else:
encoded_params = urlencode(query_params)
# No encryption, use regular query parameters
url = base_url
if filename:
url = f"{url}/{parse.quote(filename)}"
# Construct the full URL
if endpoint is None:
return f"{mediaflow_proxy_url}?{encoded_params}"
base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
return f"{base_url}?{encoded_params}"
if query_params:
return f"{url}?{urlencode(query_params)}"
return url
def get_original_scheme(request: Request) -> str:

View File

@@ -1,4 +1,6 @@
import codecs
import re
from typing import AsyncGenerator
from urllib import parse
from mediaflow_proxy.utils.crypto_utils import encryption_handler
@@ -42,6 +44,70 @@ class M3U8Processor:
processed_lines.append(line)
return "\n".join(processed_lines)
async def process_m3u8_streaming(
self, content_iterator: AsyncGenerator[bytes, None], base_url: str
) -> AsyncGenerator[str, None]:
"""
Processes the m3u8 content on-the-fly, yielding processed lines as they are read.
Args:
content_iterator: An async iterator that yields chunks of the m3u8 content.
base_url (str): The base URL to resolve relative URLs.
Yields:
str: Processed lines of the m3u8 content.
"""
buffer = "" # String buffer for decoded content
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# Process the content chunk by chunk
async for chunk in content_iterator:
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
# Incrementally decode the chunk
decoded_chunk = decoder.decode(chunk)
buffer += decoded_chunk
# Process complete lines
lines = buffer.split("\n")
if len(lines) > 1:
# Process all complete lines except the last one
for line in lines[:-1]:
if line: # Skip empty lines
processed_line = await self.process_line(line, base_url)
yield processed_line + "\n"
# Keep the last line in the buffer (it might be incomplete)
buffer = lines[-1]
# Process any remaining data in the buffer plus final bytes
final_chunk = decoder.decode(b"", final=True)
if final_chunk:
buffer += final_chunk
if buffer: # Process the last line if it's not empty
processed_line = await self.process_line(buffer, base_url)
yield processed_line
async def process_line(self, line: str, base_url: str) -> str:
"""
Process a single line from the m3u8 content.
Args:
line (str): The line to process.
base_url (str): The base URL to resolve relative URLs.
Returns:
str: The processed line.
"""
if "URI=" in line:
return await self.process_key_line(line, base_url)
elif not line.startswith("#") and line.strip():
return await self.proxy_url(line, base_url)
else:
return line
async def process_key_line(self, line: str, base_url: str) -> str:
"""
Processes a key line in the m3u8 content, proxying the URI.

View File

@@ -317,7 +317,7 @@ def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source:
"""
timelines = item["SegmentTimeline"]["S"]
timelines = timelines if isinstance(timelines, list) else [timelines]
period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
period_start = parsed_dict.get("availabilityStartTime", datetime.fromtimestamp(0, tz=timezone.utc)) + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
start_number = int(item.get("@startNumber", 1))